Spaces:
Running
Running
Better look
Browse files- bar_plot.py +28 -8
- data.py +11 -1
- plot_utils.py +22 -24
bar_plot.py
CHANGED
|
@@ -21,14 +21,26 @@ def reorder_data(per_scenario_data: dict) -> dict:
|
|
| 21 |
return per_scenario_data
|
| 22 |
|
| 23 |
|
| 24 |
-
def
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 25 |
# Prepare accumulators
|
| 26 |
current_x = 0
|
| 27 |
bar_kwargs = {"x": [], "height": [], "color": [], "label": []}
|
| 28 |
errors_bars = []
|
| 29 |
x_ticks = []
|
| 30 |
|
| 31 |
-
for device_name, device_data in
|
| 32 |
per_scenario_data = device_data.get_bar_plot_data()
|
| 33 |
per_scenario_data = reorder_data(per_scenario_data)
|
| 34 |
device_xs = []
|
|
@@ -37,7 +49,7 @@ def make_bar_kwargs(key: str) -> tuple[dict, list]:
|
|
| 37 |
bar_kwargs["x"].append(current_x)
|
| 38 |
bar_kwargs["height"].append(np.median(scenario_data[key]))
|
| 39 |
bar_kwargs["color"].append(get_color_for_config(scenario_data["config"]))
|
| 40 |
-
bar_kwargs["label"].append(
|
| 41 |
errors_bars.append(np.std(scenario_data[key]))
|
| 42 |
device_xs.append(current_x)
|
| 43 |
current_x += 1
|
|
@@ -54,12 +66,19 @@ def create_matplotlib_bar_plot() -> None:
|
|
| 54 |
fig, axs = plt.subplots(2, 1, figsize=(30, 16), sharex=True)
|
| 55 |
fig.patch.set_facecolor('#000000')
|
| 56 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 57 |
# TTFT Plot (left)
|
| 58 |
-
ttft_bars, ttft_errors, x_ticks = make_bar_kwargs("ttft")
|
| 59 |
draw_bar_plot(axs[0], ttft_bars, ttft_errors, "Time to first token and inter-token latency (lower is better)", "TTFT (seconds)", x_ticks)
|
| 60 |
|
| 61 |
# # ITL Plot (right)
|
| 62 |
-
itl_bars, itl_errors, x_ticks = make_bar_kwargs("itl")
|
| 63 |
draw_bar_plot(axs[1], itl_bars, itl_errors, None, "ITL (seconds)", x_ticks)
|
| 64 |
|
| 65 |
# # E2E Plot (right)
|
|
@@ -68,8 +87,9 @@ def create_matplotlib_bar_plot() -> None:
|
|
| 68 |
plt.tight_layout()
|
| 69 |
|
| 70 |
# Add common legend with full text
|
| 71 |
-
|
| 72 |
-
|
|
|
|
| 73 |
|
| 74 |
# Put a legend to the right of the current axis
|
| 75 |
fig.legend(legend_handles, legend_labels, loc='lower center', ncol=4,
|
|
@@ -103,7 +123,7 @@ def draw_bar_plot(ax: plt.Axes, bar_kwargs: dict, errors: list, title: str, ylab
|
|
| 103 |
# Add error bars
|
| 104 |
ax.errorbar(
|
| 105 |
bar_kwargs["x"], bar_kwargs["height"], yerr=errors,
|
| 106 |
-
fmt='none', ecolor='white', alpha=0.8, elinewidth=1.5, capthick=1.5, capsize=4,
|
| 107 |
)
|
| 108 |
# Set labels and title
|
| 109 |
ax.set_ylabel(ylabel, color='white', fontsize=16)
|
|
|
|
| 21 |
return per_scenario_data
|
| 22 |
|
| 23 |
|
| 24 |
+
def infer_bar_label(config: dict) -> str:
|
| 25 |
+
"""Format legend labels to be more readable."""
|
| 26 |
+
attn_implementation = {
|
| 27 |
+
"flash_attention_2": "Flash attention",
|
| 28 |
+
"sdpa": "SDPA",
|
| 29 |
+
"eager": "Eager",
|
| 30 |
+
}[config["attn_implementation"]]
|
| 31 |
+
compile = "compiled" if config["compilation"] else "no compile"
|
| 32 |
+
kernels = "kernelized" if config["kernelize"] else "no kernels"
|
| 33 |
+
return f"{attn_implementation}, {compile}, {kernels}"
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def make_bar_kwargs(per_device_data: dict, key: str) -> tuple[dict, list]:
|
| 37 |
# Prepare accumulators
|
| 38 |
current_x = 0
|
| 39 |
bar_kwargs = {"x": [], "height": [], "color": [], "label": []}
|
| 40 |
errors_bars = []
|
| 41 |
x_ticks = []
|
| 42 |
|
| 43 |
+
for device_name, device_data in per_device_data.items():
|
| 44 |
per_scenario_data = device_data.get_bar_plot_data()
|
| 45 |
per_scenario_data = reorder_data(per_scenario_data)
|
| 46 |
device_xs = []
|
|
|
|
| 49 |
bar_kwargs["x"].append(current_x)
|
| 50 |
bar_kwargs["height"].append(np.median(scenario_data[key]))
|
| 51 |
bar_kwargs["color"].append(get_color_for_config(scenario_data["config"]))
|
| 52 |
+
bar_kwargs["label"].append(infer_bar_label(scenario_data["config"]))
|
| 53 |
errors_bars.append(np.std(scenario_data[key]))
|
| 54 |
device_xs.append(current_x)
|
| 55 |
current_x += 1
|
|
|
|
| 66 |
fig, axs = plt.subplots(2, 1, figsize=(30, 16), sharex=True)
|
| 67 |
fig.patch.set_facecolor('#000000')
|
| 68 |
|
| 69 |
+
# Load and sanitize data
|
| 70 |
+
per_device_data = load_data()
|
| 71 |
+
batch_sizes = {name: device_data.get_main_batch_size() for name, device_data in per_device_data.items()}
|
| 72 |
+
if len(set(batch_sizes.values())) > 1:
|
| 73 |
+
fig.suptitle(f"Unmatched batch sizes: {batch_sizes}", color='white', fontsize=18, pad=20)
|
| 74 |
+
return None
|
| 75 |
+
|
| 76 |
# TTFT Plot (left)
|
| 77 |
+
ttft_bars, ttft_errors, x_ticks = make_bar_kwargs(per_device_data, "ttft")
|
| 78 |
draw_bar_plot(axs[0], ttft_bars, ttft_errors, "Time to first token and inter-token latency (lower is better)", "TTFT (seconds)", x_ticks)
|
| 79 |
|
| 80 |
# # ITL Plot (right)
|
| 81 |
+
itl_bars, itl_errors, x_ticks = make_bar_kwargs(per_device_data, "itl")
|
| 82 |
draw_bar_plot(axs[1], itl_bars, itl_errors, None, "ITL (seconds)", x_ticks)
|
| 83 |
|
| 84 |
# # E2E Plot (right)
|
|
|
|
| 87 |
plt.tight_layout()
|
| 88 |
|
| 89 |
# Add common legend with full text
|
| 90 |
+
unique_bars = len(ttft_bars["label"]) // 2
|
| 91 |
+
legend_labels, legend_colors = ttft_bars["label"][:unique_bars], ttft_bars["color"][:unique_bars]
|
| 92 |
+
legend_handles = [plt.Rectangle((0,0),1,1, color=color) for color in legend_colors]
|
| 93 |
|
| 94 |
# Put a legend to the right of the current axis
|
| 95 |
fig.legend(legend_handles, legend_labels, loc='lower center', ncol=4,
|
|
|
|
| 123 |
# Add error bars
|
| 124 |
ax.errorbar(
|
| 125 |
bar_kwargs["x"], bar_kwargs["height"], yerr=errors,
|
| 126 |
+
fmt='none', ecolor='white', alpha=0.8, elinewidth=1.5, capthick=1.5, capsize=4, zorder=4,
|
| 127 |
)
|
| 128 |
# Set labels and title
|
| 129 |
ax.set_ylabel(ylabel, color='white', fontsize=16)
|
data.py
CHANGED
|
@@ -25,6 +25,16 @@ class ModelBenchmarkData:
|
|
| 25 |
num_tokens = len(measures["t_tokens"]) - 1
|
| 26 |
return delta_t / num_tokens
|
| 27 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 28 |
def get_bar_plot_data(self, collapse_on_cache: bool = True, collapse_on_compile_mode: bool = True) -> dict:
|
| 29 |
# Gather data for each scenario
|
| 30 |
per_scenario_data = {}
|
|
@@ -52,7 +62,7 @@ class ModelBenchmarkData:
|
|
| 52 |
return per_scenario_data
|
| 53 |
|
| 54 |
|
| 55 |
-
def load_data(keep_common_scenarios_only: bool =
|
| 56 |
data = {
|
| 57 |
"MI325": ModelBenchmarkData("mi325_data.json"),
|
| 58 |
"H100": ModelBenchmarkData("h100_data.json"),
|
|
|
|
| 25 |
num_tokens = len(measures["t_tokens"]) - 1
|
| 26 |
return delta_t / num_tokens
|
| 27 |
|
| 28 |
+
def get_main_batch_size(self) -> int:
|
| 29 |
+
batch_sizes = {}
|
| 30 |
+
for cfg_name, data in self.data.items():
|
| 31 |
+
for measure in data["measures"]:
|
| 32 |
+
bs = measure["batch_size"]
|
| 33 |
+
if bs not in batch_sizes:
|
| 34 |
+
batch_sizes[bs] = 0
|
| 35 |
+
batch_sizes[bs] += 1
|
| 36 |
+
return max(batch_sizes, key=batch_sizes.get)
|
| 37 |
+
|
| 38 |
def get_bar_plot_data(self, collapse_on_cache: bool = True, collapse_on_compile_mode: bool = True) -> dict:
|
| 39 |
# Gather data for each scenario
|
| 40 |
per_scenario_data = {}
|
|
|
|
| 62 |
return per_scenario_data
|
| 63 |
|
| 64 |
|
| 65 |
+
def load_data(keep_common_scenarios_only: bool = False) -> dict[str, ModelBenchmarkData]:
|
| 66 |
data = {
|
| 67 |
"MI325": ModelBenchmarkData("mi325_data.json"),
|
| 68 |
"H100": ModelBenchmarkData("h100_data.json"),
|
plot_utils.py
CHANGED
|
@@ -4,9 +4,11 @@ def hex_to_rgb(hex_color):
|
|
| 4 |
r, g, b = int(hex_color[0:2], 16), int(hex_color[2:4], 16), int(hex_color[4:6], 16)
|
| 5 |
return r, g, b
|
| 6 |
|
| 7 |
-
def blend_colors(
|
| 8 |
-
|
| 9 |
-
|
|
|
|
|
|
|
| 10 |
|
| 11 |
def increase_brightness(r, g, b, factor):
|
| 12 |
return tuple(map(lambda x: int(x + (255 - x) * factor), (r, g, b)))
|
|
@@ -25,29 +27,25 @@ def rgb_to_hex(r, g, b):
|
|
| 25 |
|
| 26 |
# Color assignment function
|
| 27 |
def get_color_for_config(config: dict):
|
| 28 |
-
# Determine the main hue for the attention implementation
|
| 29 |
attn_implementation, sdpa_backend = config["attn_implementation"], config["sdpa_backend"]
|
| 30 |
-
|
|
|
|
|
|
|
| 31 |
if attn_implementation == "eager":
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
|
|
|
|
|
|
| 41 |
elif attn_implementation == "flash_attention_2":
|
| 42 |
-
|
| 43 |
else:
|
| 44 |
raise ValueError(f"Unknown attention implementation: {attn_implementation}")
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
if config["compilation"]:
|
| 48 |
-
delta = 0.2 + 0.2 * (len(config["compile_mode"]) - 7) / 8
|
| 49 |
-
r, g, b = increase_brightness(r, g, b, delta)
|
| 50 |
-
if config["kernelize"]:
|
| 51 |
-
r, g, b = decrease_brightness(r, g, b, 0.8)
|
| 52 |
-
# Return the color as a hex string
|
| 53 |
-
return rgb_to_hex(r, g, b)
|
|
|
|
| 4 |
r, g, b = int(hex_color[0:2], 16), int(hex_color[2:4], 16), int(hex_color[4:6], 16)
|
| 5 |
return r, g, b
|
| 6 |
|
| 7 |
+
def blend_colors(color1, color2, blend_strength):
|
| 8 |
+
rgb1 = hex_to_rgb(color1)
|
| 9 |
+
rgb2 = hex_to_rgb(color2)
|
| 10 |
+
new_color = tuple(map(lambda i: int(rgb1[i] * blend_strength + rgb2[i] * (1 - blend_strength)), range(3)))
|
| 11 |
+
return rgb_to_hex(*new_color)
|
| 12 |
|
| 13 |
def increase_brightness(r, g, b, factor):
|
| 14 |
return tuple(map(lambda x: int(x + (255 - x) * factor), (r, g, b)))
|
|
|
|
| 27 |
|
| 28 |
# Color assignment function
|
| 29 |
def get_color_for_config(config: dict):
|
|
|
|
| 30 |
attn_implementation, sdpa_backend = config["attn_implementation"], config["sdpa_backend"]
|
| 31 |
+
barycenter = 1 - (config["compilation"] + 2 * config["kernelize"]) / 3
|
| 32 |
+
|
| 33 |
+
# Eager
|
| 34 |
if attn_implementation == "eager":
|
| 35 |
+
color = blend_colors("#FA7F7FFF", "#FF2D2DFF", barycenter)
|
| 36 |
+
|
| 37 |
+
# SDPA - math
|
| 38 |
+
elif attn_implementation == "sdpa" and sdpa_backend == "math":
|
| 39 |
+
color = blend_colors("#7AB8FFFF", "#277CD0FF", barycenter)
|
| 40 |
+
|
| 41 |
+
# SDPA - flash attention
|
| 42 |
+
elif attn_implementation == "sdpa" and sdpa_backend == "flash_attention":
|
| 43 |
+
color = blend_colors("#81FF9CFF", "#219F3CFF", barycenter)
|
| 44 |
+
|
| 45 |
+
# Flash attention
|
| 46 |
elif attn_implementation == "flash_attention_2":
|
| 47 |
+
color = blend_colors("#FFDB70FF", "#DFD002FF", barycenter)
|
| 48 |
else:
|
| 49 |
raise ValueError(f"Unknown attention implementation: {attn_implementation}")
|
| 50 |
+
|
| 51 |
+
return color
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|