Spaces:
Running
Running
Good bar chart
Browse files- app.py +3 -7
- bar_plot.py +66 -108
- data.py +12 -0
- plot_utils.py +53 -0
app.py
CHANGED
|
@@ -2,7 +2,6 @@ import gradio as gr
|
|
| 2 |
import matplotlib.pyplot as plt
|
| 3 |
import matplotlib
|
| 4 |
|
| 5 |
-
from data import ModelBenchmarkData
|
| 6 |
from bar_plot import create_matplotlib_bar_plot
|
| 7 |
|
| 8 |
|
|
@@ -11,9 +10,6 @@ matplotlib.use('Agg')
|
|
| 11 |
plt.ioff()
|
| 12 |
|
| 13 |
|
| 14 |
-
DATA = ModelBenchmarkData("data.json")
|
| 15 |
-
|
| 16 |
-
|
| 17 |
def load_css():
|
| 18 |
"""Load CSS styling."""
|
| 19 |
try:
|
|
@@ -23,10 +19,10 @@ def load_css():
|
|
| 23 |
return "body { background: #000; color: #fff; }"
|
| 24 |
|
| 25 |
|
| 26 |
-
|
| 27 |
def refresh_plot():
|
| 28 |
"""Generate new matplotlib charts and update description."""
|
| 29 |
-
|
|
|
|
| 30 |
|
| 31 |
# Create Gradio interface
|
| 32 |
with gr.Blocks(title="Random Data Dashboard", css=load_css(), fill_height=True, fill_width=True) as demo:
|
|
@@ -40,7 +36,7 @@ with gr.Blocks(title="Random Data Dashboard", css=load_css(), fill_height=True,
|
|
| 40 |
# Main plot area
|
| 41 |
with gr.Column(elem_classes=["main-content"]):
|
| 42 |
plot = gr.HTML(
|
| 43 |
-
create_matplotlib_bar_plot(
|
| 44 |
elem_classes=["plot-container"],
|
| 45 |
)
|
| 46 |
|
|
|
|
| 2 |
import matplotlib.pyplot as plt
|
| 3 |
import matplotlib
|
| 4 |
|
|
|
|
| 5 |
from bar_plot import create_matplotlib_bar_plot
|
| 6 |
|
| 7 |
|
|
|
|
| 10 |
plt.ioff()
|
| 11 |
|
| 12 |
|
|
|
|
|
|
|
|
|
|
| 13 |
def load_css():
|
| 14 |
"""Load CSS styling."""
|
| 15 |
try:
|
|
|
|
| 19 |
return "body { background: #000; color: #fff; }"
|
| 20 |
|
| 21 |
|
|
|
|
| 22 |
def refresh_plot():
|
| 23 |
"""Generate new matplotlib charts and update description."""
|
| 24 |
+
sidebar_text = "**Transformer CI Dashboard**<br>-<br>**AMD runs on MI325**<br>**NVIDIA runs on A10**<br><br>*This dashboard only tracks important models*<br>*(Data refreshed)*"
|
| 25 |
+
return create_matplotlib_bar_plot(), sidebar_text
|
| 26 |
|
| 27 |
# Create Gradio interface
|
| 28 |
with gr.Blocks(title="Random Data Dashboard", css=load_css(), fill_height=True, fill_width=True) as demo:
|
|
|
|
| 36 |
# Main plot area
|
| 37 |
with gr.Column(elem_classes=["main-content"]):
|
| 38 |
plot = gr.HTML(
|
| 39 |
+
create_matplotlib_bar_plot(),
|
| 40 |
elem_classes=["plot-container"],
|
| 41 |
)
|
| 42 |
|
bar_plot.py
CHANGED
|
@@ -3,66 +3,9 @@ import io
|
|
| 3 |
import numpy as np
|
| 4 |
import base64
|
| 5 |
|
|
|
|
|
|
|
| 6 |
|
| 7 |
-
# Color manipulation functions
|
| 8 |
-
def hex_to_rgb(hex_color):
|
| 9 |
-
hex_color = hex_color.lstrip('#')
|
| 10 |
-
r, g, b = int(hex_color[0:2], 16), int(hex_color[2:4], 16), int(hex_color[4:6], 16)
|
| 11 |
-
return r, g, b
|
| 12 |
-
|
| 13 |
-
def blend_colors(rgb, hex_color, blend_strength):
|
| 14 |
-
other_rgb = hex_to_rgb(hex_color)
|
| 15 |
-
return tuple(map(lambda i: int(rgb[i] * blend_strength + other_rgb[i] * (1 - blend_strength)), range(3)))
|
| 16 |
-
|
| 17 |
-
def increase_brightness(r, g, b, factor):
|
| 18 |
-
return tuple(map(lambda x: int(x + (255 - x) * factor), (r, g, b)))
|
| 19 |
-
|
| 20 |
-
def decrease_brightness(r, g, b, factor):
|
| 21 |
-
return tuple(map(lambda x: int(x * factor), (r, g, b)))
|
| 22 |
-
|
| 23 |
-
def increase_saturation(r, g, b, factor) -> tuple[int, int, int]:
|
| 24 |
-
gray = 0.299 * r + 0.587 * g + 0.114 * b
|
| 25 |
-
return tuple(map(lambda x: int(gray + (x - gray) * factor), (r, g, b)))
|
| 26 |
-
|
| 27 |
-
def rgb_to_hex(r, g, b):
|
| 28 |
-
r, g, b = map(lambda x: min(max(x, 0), 255), (r, g, b))
|
| 29 |
-
return f"#{r:02x}{g:02x}{b:02x}"
|
| 30 |
-
|
| 31 |
-
# Color assignment function
|
| 32 |
-
def get_color_for_config(config, filtered_on_compile_mode: bool = False):
|
| 33 |
-
|
| 34 |
-
# Determine the main hue for the attention implementation
|
| 35 |
-
attn_implementation, sdpa_backend = config["attn_implementation"], config["sdpa_backend"]
|
| 36 |
-
compilation = config["compilation"]
|
| 37 |
-
if attn_implementation == "eager":
|
| 38 |
-
main_hue = "#FF4B4BFF" if compilation else "#FF4141FF"
|
| 39 |
-
elif attn_implementation == "sdpa":
|
| 40 |
-
main_hue = {
|
| 41 |
-
None: "#4A90E2" if compilation else "#2E82E1FF",
|
| 42 |
-
"math": "#408DDB" if compilation else "#227BD3FF",
|
| 43 |
-
"flash_attention": "#35A34D" if compilation else "#219F3CFF",
|
| 44 |
-
"efficient_attention": "#605895" if compilation else "#423691FF",
|
| 45 |
-
"cudnn_attention": "#774AE2" if compilation else "#5D27DCFF",
|
| 46 |
-
}[sdpa_backend] # fmt: off
|
| 47 |
-
elif attn_implementation == "flash_attention_2":
|
| 48 |
-
main_hue = "#FFD700" if compilation else "#FFBF00FF"
|
| 49 |
-
else:
|
| 50 |
-
raise ValueError(f"Unknown attention implementation: {attn_implementation}")
|
| 51 |
-
|
| 52 |
-
# Apply color modifications for compilation and kernelization
|
| 53 |
-
r, g, b = hex_to_rgb(main_hue)
|
| 54 |
-
if config["compilation"]:
|
| 55 |
-
delta = 0.2
|
| 56 |
-
delta += 0.2 * (len(config["compile_mode"]) - 7) / 8 if filtered_on_compile_mode else 0
|
| 57 |
-
r, g, b = increase_brightness(r, g, b, delta)
|
| 58 |
-
if config["kernelize"]:
|
| 59 |
-
pass
|
| 60 |
-
# r, g, b = blend_colors((r, g, b), "#FF00F2FF", 0.7)
|
| 61 |
-
r, g, b = decrease_brightness(r, g, b, 0.8)
|
| 62 |
-
# r, g, b = increase_saturation(r, g, b, 0.9)
|
| 63 |
-
|
| 64 |
-
# Return the color as a hex string
|
| 65 |
-
return rgb_to_hex(r, g, b)
|
| 66 |
|
| 67 |
def reorder_data(per_scenario_data: dict) -> dict:
|
| 68 |
keys = list(per_scenario_data.keys())
|
|
@@ -78,70 +21,59 @@ def reorder_data(per_scenario_data: dict) -> dict:
|
|
| 78 |
return per_scenario_data
|
| 79 |
|
| 80 |
|
| 81 |
-
def make_bar_kwargs(
|
|
|
|
|
|
|
| 82 |
bar_kwargs = {"x": [], "height": [], "color": [], "label": []}
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
|
| 92 |
-
|
| 93 |
-
|
| 94 |
-
|
| 95 |
-
|
| 96 |
-
|
| 97 |
-
|
| 98 |
-
|
| 99 |
-
|
| 100 |
-
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
|
| 105 |
-
# Set ticks and grid
|
| 106 |
-
ax.set_xticks([])
|
| 107 |
-
ax.tick_params(colors='white', labelsize=13)
|
| 108 |
-
# Truncate axis to better fit the bars
|
| 109 |
-
new_ymin, new_ymax = 1e9, -1e9
|
| 110 |
-
for h, e in zip(bar_kwargs["height"], errors):
|
| 111 |
-
new_ymin = min(new_ymin, 0.98 * (h - e))
|
| 112 |
-
new_ymax = max(new_ymax, 1.02 * (h + e))
|
| 113 |
-
ymin, ymax = ax.get_ylim()
|
| 114 |
-
ax.set_ylim(max(ymin, new_ymin), min(ymax, new_ymax))
|
| 115 |
-
|
| 116 |
-
|
| 117 |
-
def create_matplotlib_bar_plot(per_scenario_data: dict):
|
| 118 |
"""Create side-by-side matplotlib bar charts for TTFT and TPOT data."""
|
| 119 |
|
| 120 |
# Create figure with dark theme - maximum size for full screen
|
| 121 |
plt.style.use('dark_background')
|
| 122 |
-
fig, axs = plt.subplots(
|
| 123 |
fig.patch.set_facecolor('#000000')
|
| 124 |
|
| 125 |
-
# Reorganize data
|
| 126 |
-
per_scenario_data = reorder_data(per_scenario_data)
|
| 127 |
-
|
| 128 |
# TTFT Plot (left)
|
| 129 |
-
ttft_bars, ttft_errors = make_bar_kwargs(
|
| 130 |
-
draw_bar_plot(axs[0], ttft_bars, ttft_errors, "Time to first token (lower is better)", "TTFT (seconds)")
|
| 131 |
|
| 132 |
-
# ITL Plot (right)
|
| 133 |
-
itl_bars, itl_errors = make_bar_kwargs(
|
| 134 |
-
draw_bar_plot(axs[1], itl_bars, itl_errors,
|
| 135 |
|
| 136 |
-
# E2E Plot (right)
|
| 137 |
-
e2e_bars, e2e_errors = make_bar_kwargs(
|
| 138 |
-
draw_bar_plot(axs
|
|
|
|
| 139 |
|
| 140 |
# Add common legend with full text
|
| 141 |
legend_labels = ttft_bars["label"] # Use full labels without truncation
|
| 142 |
legend_handles = [plt.Rectangle((0,0),1,1, color=color) for color in ttft_bars["color"]]
|
|
|
|
|
|
|
| 143 |
fig.legend(legend_handles, legend_labels, loc='lower center', ncol=4,
|
| 144 |
-
bbox_to_anchor=(0.
|
| 145 |
labelcolor='white', fontsize=14)
|
| 146 |
|
| 147 |
# Save plot to bytes with high DPI for crisp text
|
|
@@ -161,3 +93,29 @@ def create_matplotlib_bar_plot(per_scenario_data: dict):
|
|
| 161 |
</div>
|
| 162 |
"""
|
| 163 |
return html
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3 |
import numpy as np
|
| 4 |
import base64
|
| 5 |
|
| 6 |
+
from plot_utils import get_color_for_config
|
| 7 |
+
from data import load_data
|
| 8 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 9 |
|
| 10 |
def reorder_data(per_scenario_data: dict) -> dict:
|
| 11 |
keys = list(per_scenario_data.keys())
|
|
|
|
| 21 |
return per_scenario_data
|
| 22 |
|
| 23 |
|
| 24 |
+
def make_bar_kwargs(key: str) -> tuple[dict, list]:
|
| 25 |
+
# Prepare accumulators
|
| 26 |
+
current_x = 0
|
| 27 |
bar_kwargs = {"x": [], "height": [], "color": [], "label": []}
|
| 28 |
+
errors_bars = []
|
| 29 |
+
x_ticks = []
|
| 30 |
+
|
| 31 |
+
for device_name, device_data in load_data(keep_common_scenarios_only=False).items():
|
| 32 |
+
per_scenario_data = device_data.get_bar_plot_data()
|
| 33 |
+
per_scenario_data = reorder_data(per_scenario_data)
|
| 34 |
+
device_xs = []
|
| 35 |
+
|
| 36 |
+
for scenario_name, scenario_data in per_scenario_data.items():
|
| 37 |
+
bar_kwargs["x"].append(current_x)
|
| 38 |
+
bar_kwargs["height"].append(np.median(scenario_data[key]))
|
| 39 |
+
bar_kwargs["color"].append(get_color_for_config(scenario_data["config"]))
|
| 40 |
+
bar_kwargs["label"].append(scenario_name)
|
| 41 |
+
errors_bars.append(np.std(scenario_data[key]))
|
| 42 |
+
device_xs.append(current_x)
|
| 43 |
+
current_x += 1
|
| 44 |
+
|
| 45 |
+
x_ticks.append((np.mean(device_xs), device_name))
|
| 46 |
+
current_x += 1.5
|
| 47 |
+
return bar_kwargs, errors_bars, x_ticks
|
| 48 |
+
|
| 49 |
+
def create_matplotlib_bar_plot() -> None:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 50 |
"""Create side-by-side matplotlib bar charts for TTFT and TPOT data."""
|
| 51 |
|
| 52 |
# Create figure with dark theme - maximum size for full screen
|
| 53 |
plt.style.use('dark_background')
|
| 54 |
+
fig, axs = plt.subplots(2, 1, figsize=(30, 16), sharex=True)
|
| 55 |
fig.patch.set_facecolor('#000000')
|
| 56 |
|
|
|
|
|
|
|
|
|
|
| 57 |
# TTFT Plot (left)
|
| 58 |
+
ttft_bars, ttft_errors, x_ticks = make_bar_kwargs("ttft")
|
| 59 |
+
draw_bar_plot(axs[0], ttft_bars, ttft_errors, "Time to first token and inter-token latency (lower is better)", "TTFT (seconds)", x_ticks)
|
| 60 |
|
| 61 |
+
# # ITL Plot (right)
|
| 62 |
+
itl_bars, itl_errors, x_ticks = make_bar_kwargs("itl")
|
| 63 |
+
draw_bar_plot(axs[1], itl_bars, itl_errors, None, "ITL (seconds)", x_ticks)
|
| 64 |
|
| 65 |
+
# # E2E Plot (right)
|
| 66 |
+
# e2e_bars, e2e_errors = make_bar_kwargs("e2e")
|
| 67 |
+
# draw_bar_plot(axs, e2e_bars, e2e_errors, "End-to-end latency (lower is better)", "E2E (seconds)")
|
| 68 |
+
plt.tight_layout()
|
| 69 |
|
| 70 |
# Add common legend with full text
|
| 71 |
legend_labels = ttft_bars["label"] # Use full labels without truncation
|
| 72 |
legend_handles = [plt.Rectangle((0,0),1,1, color=color) for color in ttft_bars["color"]]
|
| 73 |
+
|
| 74 |
+
# Put a legend to the right of the current axis
|
| 75 |
fig.legend(legend_handles, legend_labels, loc='lower center', ncol=4,
|
| 76 |
+
bbox_to_anchor=(0.515, -0.15), facecolor='black', edgecolor='white',
|
| 77 |
labelcolor='white', fontsize=14)
|
| 78 |
|
| 79 |
# Save plot to bytes with high DPI for crisp text
|
|
|
|
| 93 |
</div>
|
| 94 |
"""
|
| 95 |
return html
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
def draw_bar_plot(ax: plt.Axes, bar_kwargs: dict, errors: list, title: str, ylabel: str, xticks: list[tuple[float, str]]):
|
| 99 |
+
ax.set_facecolor('#000000')
|
| 100 |
+
ax.grid(True, alpha=0.2, color='white', zorder=0)
|
| 101 |
+
# Draw bars
|
| 102 |
+
_ = ax.bar(**bar_kwargs, width=1.0, edgecolor='white', linewidth=1, zorder=3)
|
| 103 |
+
# Add error bars
|
| 104 |
+
ax.errorbar(
|
| 105 |
+
bar_kwargs["x"], bar_kwargs["height"], yerr=errors,
|
| 106 |
+
fmt='none', ecolor='white', alpha=0.8, elinewidth=1.5, capthick=1.5, capsize=4,
|
| 107 |
+
)
|
| 108 |
+
# Set labels and title
|
| 109 |
+
ax.set_ylabel(ylabel, color='white', fontsize=16)
|
| 110 |
+
ax.set_title(title, color='white', fontsize=18, pad=20)
|
| 111 |
+
# Set ticks and grid
|
| 112 |
+
ax.set_xticks([])
|
| 113 |
+
ax.tick_params(colors='white', labelsize=13)
|
| 114 |
+
ax.set_xticks([xt[0] for xt in xticks], [xt[1] for xt in xticks], fontsize=16)
|
| 115 |
+
# Truncate axis to better fit the bars
|
| 116 |
+
new_ymin, new_ymax = 1e9, -1e9
|
| 117 |
+
for h, e in zip(bar_kwargs["height"], errors):
|
| 118 |
+
new_ymin = min(new_ymin, 0.98 * (h - e))
|
| 119 |
+
new_ymax = max(new_ymax, 1.02 * (h + e))
|
| 120 |
+
ymin, ymax = ax.get_ylim()
|
| 121 |
+
ax.set_ylim(max(ymin, new_ymin), min(ymax, new_ymax))
|
data.py
CHANGED
|
@@ -50,3 +50,15 @@ class ModelBenchmarkData:
|
|
| 50 |
per_scenario_data = {k: per_scenario_data[k] for k, _ in collapsed_keys.values()}
|
| 51 |
|
| 52 |
return per_scenario_data
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 50 |
per_scenario_data = {k: per_scenario_data[k] for k, _ in collapsed_keys.values()}
|
| 51 |
|
| 52 |
return per_scenario_data
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
def load_data(keep_common_scenarios_only: bool = True) -> dict[str, ModelBenchmarkData]:
|
| 56 |
+
data = {
|
| 57 |
+
"MI325": ModelBenchmarkData("mi325_data.json"),
|
| 58 |
+
"H100": ModelBenchmarkData("h100_data.json"),
|
| 59 |
+
}
|
| 60 |
+
if keep_common_scenarios_only:
|
| 61 |
+
common_scenarios = set(data["MI325"].data.keys()) & set(data["H100"].data.keys())
|
| 62 |
+
for device_name, device_data in data.items():
|
| 63 |
+
device_data.data = {k: v for k, v in device_data.data.items() if k in common_scenarios}
|
| 64 |
+
return data
|
plot_utils.py
ADDED
|
@@ -0,0 +1,53 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Color manipulation functions
|
| 2 |
+
def hex_to_rgb(hex_color):
|
| 3 |
+
hex_color = hex_color.lstrip('#')
|
| 4 |
+
r, g, b = int(hex_color[0:2], 16), int(hex_color[2:4], 16), int(hex_color[4:6], 16)
|
| 5 |
+
return r, g, b
|
| 6 |
+
|
| 7 |
+
def blend_colors(rgb, hex_color, blend_strength):
|
| 8 |
+
other_rgb = hex_to_rgb(hex_color)
|
| 9 |
+
return tuple(map(lambda i: int(rgb[i] * blend_strength + other_rgb[i] * (1 - blend_strength)), range(3)))
|
| 10 |
+
|
| 11 |
+
def increase_brightness(r, g, b, factor):
|
| 12 |
+
return tuple(map(lambda x: int(x + (255 - x) * factor), (r, g, b)))
|
| 13 |
+
|
| 14 |
+
def decrease_brightness(r, g, b, factor):
|
| 15 |
+
return tuple(map(lambda x: int(x * factor), (r, g, b)))
|
| 16 |
+
|
| 17 |
+
def increase_saturation(r, g, b, factor) -> tuple[int, int, int]:
|
| 18 |
+
gray = 0.299 * r + 0.587 * g + 0.114 * b
|
| 19 |
+
return tuple(map(lambda x: int(gray + (x - gray) * factor), (r, g, b)))
|
| 20 |
+
|
| 21 |
+
def rgb_to_hex(r, g, b):
|
| 22 |
+
r, g, b = map(lambda x: min(max(x, 0), 255), (r, g, b))
|
| 23 |
+
return f"#{r:02x}{g:02x}{b:02x}"
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
# Color assignment function
|
| 27 |
+
def get_color_for_config(config: dict):
|
| 28 |
+
# Determine the main hue for the attention implementation
|
| 29 |
+
attn_implementation, sdpa_backend = config["attn_implementation"], config["sdpa_backend"]
|
| 30 |
+
compilation = config["compilation"]
|
| 31 |
+
if attn_implementation == "eager":
|
| 32 |
+
main_hue = "#FF4B4BFF" if compilation else "#FF4141FF"
|
| 33 |
+
elif attn_implementation == "sdpa":
|
| 34 |
+
main_hue = {
|
| 35 |
+
None: "#4A90E2" if compilation else "#2E82E1FF",
|
| 36 |
+
"math": "#408DDB" if compilation else "#227BD3FF",
|
| 37 |
+
"flash_attention": "#35A34D" if compilation else "#219F3CFF",
|
| 38 |
+
"efficient_attention": "#605895" if compilation else "#423691FF",
|
| 39 |
+
"cudnn_attention": "#774AE2" if compilation else "#5D27DCFF",
|
| 40 |
+
}[sdpa_backend] # fmt: off
|
| 41 |
+
elif attn_implementation == "flash_attention_2":
|
| 42 |
+
main_hue = "#FFD700" if compilation else "#FFBF00FF"
|
| 43 |
+
else:
|
| 44 |
+
raise ValueError(f"Unknown attention implementation: {attn_implementation}")
|
| 45 |
+
# Apply color modifications for compilation and kernelization
|
| 46 |
+
r, g, b = hex_to_rgb(main_hue)
|
| 47 |
+
if config["compilation"]:
|
| 48 |
+
delta = 0.2 + 0.2 * (len(config["compile_mode"]) - 7) / 8
|
| 49 |
+
r, g, b = increase_brightness(r, g, b, delta)
|
| 50 |
+
if config["kernelize"]:
|
| 51 |
+
r, g, b = decrease_brightness(r, g, b, 0.8)
|
| 52 |
+
# Return the color as a hex string
|
| 53 |
+
return rgb_to_hex(r, g, b)
|