ror HF Staff committed on
Commit
79e7993
·
1 Parent(s): 59644f0

Probably v1

Browse files
Files changed (6) hide show
  1. app.py +15 -4
  2. bar_plot.py +109 -34
  3. data.py +30 -12
  4. h100_data.json +2 -2
  5. mi325_data.json +2 -2
  6. plot_utils.py +26 -10
app.py CHANGED
@@ -6,7 +6,7 @@ from bar_plot import create_matplotlib_bar_plot
6
 
7
 
8
  # Configure matplotlib for better performance
9
- matplotlib.use('Agg')
10
  plt.ioff()
11
 
12
 
@@ -24,14 +24,25 @@ def refresh_plot():
24
  sidebar_text = "**Transformer CI Dashboard**<br>-<br>**AMD runs on MI325**<br>**NVIDIA runs on A10**<br><br>*This dashboard only tracks important models*<br>*(Data refreshed)*"
25
  return create_matplotlib_bar_plot(), sidebar_text
26
 
 
27
  # Create Gradio interface
28
- with gr.Blocks(title="Random Data Dashboard", css=load_css(), fill_height=True, fill_width=True) as demo:
 
 
29
  with gr.Row():
30
  # Sidebar
31
  with gr.Column(scale=1, elem_classes=["sidebar"]):
32
  gr.Markdown("# 🤖 TCID", elem_classes=["sidebar-title"])
33
- description = gr.Markdown("**Transformer CI Dashboard**<br>-<br>**AMD runs on MI325**<br>**NVIDIA runs on A10**<br><br>*This dashboard only tracks important models*", elem_classes=["sidebar-description"])
34
- summary_btn = gr.Button("summary\n📊", variant="primary", size="lg", elem_classes=["summary-button"])
 
 
 
 
 
 
 
 
35
 
36
  # Main plot area
37
  with gr.Column(elem_classes=["main-content"]):
 
6
 
7
 
8
  # Configure matplotlib for better performance
9
+ matplotlib.use("Agg")
10
  plt.ioff()
11
 
12
 
 
24
  sidebar_text = "**Transformer CI Dashboard**<br>-<br>**AMD runs on MI325**<br>**NVIDIA runs on A10**<br><br>*This dashboard only tracks important models*<br>*(Data refreshed)*"
25
  return create_matplotlib_bar_plot(), sidebar_text
26
 
27
+
28
  # Create Gradio interface
29
+ with gr.Blocks(
30
+ title="Random Data Dashboard", css=load_css(), fill_height=True, fill_width=True
31
+ ) as demo:
32
  with gr.Row():
33
  # Sidebar
34
  with gr.Column(scale=1, elem_classes=["sidebar"]):
35
  gr.Markdown("# 🤖 TCID", elem_classes=["sidebar-title"])
36
+ description = gr.Markdown(
37
+ "**Transformer CI Dashboard**<br>-<br>**AMD runs on MI325**<br>**NVIDIA runs on A10**<br><br>*This dashboard only tracks important models*",
38
+ elem_classes=["sidebar-description"],
39
+ )
40
+ summary_btn = gr.Button(
41
+ "summary\n📊",
42
+ variant="primary",
43
+ size="lg",
44
+ elem_classes=["summary-button"],
45
+ )
46
 
47
  # Main plot area
48
  with gr.Column(elem_classes=["main-content"]):
bar_plot.py CHANGED
@@ -1,10 +1,11 @@
1
  import matplotlib.pyplot as plt
 
2
  import io
3
  import numpy as np
4
  import base64
5
 
6
  from plot_utils import get_color_for_config
7
- from data import load_data
8
 
9
 
10
  def reorder_data(per_scenario_data: dict) -> dict:
@@ -13,8 +14,25 @@ def reorder_data(per_scenario_data: dict) -> dict:
13
  def sorting_fn(key: str) -> float:
14
  cfg = per_scenario_data[key]["config"]
15
  attn_implementation = cfg["attn_implementation"]
16
- attn_implementation_prio = {"flash_attention_2": 0, "sdpa": 1, "eager": 2}[attn_implementation]
17
- return attn_implementation_prio, cfg["sdpa_backend"], cfg["kernelize"], cfg["compilation"]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
 
19
  keys.sort(key=sorting_fn)
20
  per_scenario_data = {k: per_scenario_data[k] for k in keys}
@@ -27,6 +45,8 @@ def infer_bar_label(config: dict) -> str:
27
  attn_implementation = "Eager"
28
  elif config["attn_implementation"] == "flash_attention_2":
29
  attn_implementation = "Flash attention"
 
 
30
  elif config["attn_implementation"] == "sdpa":
31
  attn_implementation = {
32
  "flash_attention": "SDPA (flash attention)",
@@ -37,15 +57,24 @@ def infer_bar_label(config: dict) -> str:
37
  else:
38
  attn_implementation = "Unknown"
39
 
40
- compile = "compiled" if config["compilation"] else "no compile"
41
  kernels = "kernelized" if config["kernelize"] else "no kernels"
42
  return f"{attn_implementation}, {compile}, {kernels}"
43
 
44
 
45
- def make_bar_kwargs(per_device_data: dict, key: str) -> tuple[dict, list]:
 
 
 
 
 
 
 
 
 
46
  # Prepare accumulators
47
  current_x = 0
48
- bar_kwargs = {"x": [], "height": [], "color": [], "label": []}
49
  errors_bars = []
50
  x_ticks = []
51
 
@@ -53,12 +82,13 @@ def make_bar_kwargs(per_device_data: dict, key: str) -> tuple[dict, list]:
53
  per_scenario_data = device_data.get_bar_plot_data()
54
  per_scenario_data = reorder_data(per_scenario_data)
55
  device_xs = []
56
-
57
- for scenario_name, scenario_data in per_scenario_data.items():
58
  bar_kwargs["x"].append(current_x)
59
  bar_kwargs["height"].append(np.median(scenario_data[key]))
60
  bar_kwargs["color"].append(get_color_for_config(scenario_data["config"]))
61
  bar_kwargs["label"].append(infer_bar_label(scenario_data["config"]))
 
62
  errors_bars.append(np.std(scenario_data[key]))
63
  device_xs.append(current_x)
64
  current_x += 1
@@ -67,13 +97,14 @@ def make_bar_kwargs(per_device_data: dict, key: str) -> tuple[dict, list]:
67
  current_x += 1.5
68
  return bar_kwargs, errors_bars, x_ticks
69
 
 
70
  def create_matplotlib_bar_plot() -> None:
71
  """Create side-by-side matplotlib bar charts for TTFT and TPOT data."""
72
 
73
  # Create figure with dark theme - maximum size for full screen
74
- plt.style.use('dark_background')
75
  fig, axs = plt.subplots(2, 1, figsize=(20, 11), sharex=True) # used to be 30, 16
76
- fig.patch.set_facecolor('#000000')
77
 
78
  # Load data and ensure coherence
79
  per_device_data = load_data()
@@ -82,11 +113,16 @@ def create_matplotlib_bar_plot() -> None:
82
  bs, seqlen, n_tok = device_data.ensure_coherence()
83
  if batch_size is None:
84
  batch_size, sequence_length, num_tokens_to_generate = bs, seqlen, n_tok
85
- elif (bs, seqlen, n_tok) != (batch_size, sequence_length, num_tokens_to_generate):
 
 
 
 
86
  fig.suptitle(
87
  f"Mismatch for batch size, sequence length and number of tokens to generate between configs: {bs} "
88
  f"!= {batch_size}, {seqlen} != {sequence_length}, {n_tok} != {num_tokens_to_generate}",
89
- color='white', fontsize=18
 
90
  )
91
  return None
92
 
@@ -99,27 +135,58 @@ def create_matplotlib_bar_plot() -> None:
99
  draw_bar_plot(axs[1], itl_bars, itl_errors, "ITL (seconds)", x_ticks)
100
 
101
  # Title and tight layout
102
- title = "\n".join([
103
- "Time to first token and inter-token latency (lower is better)",
104
- f"Batch size: {batch_size}, sequence length: {sequence_length}, new tokens: {num_tokens_to_generate}",
105
- ])
106
- fig.suptitle(title, color='white', fontsize=20, y=1.005, linespacing=1.5)
 
 
107
  plt.tight_layout()
108
 
109
  # Add common legend with full text
110
- unique_bars = len(ttft_bars["label"]) // 2
111
- legend_labels, legend_colors = ttft_bars["label"][:unique_bars], ttft_bars["color"][:unique_bars]
112
- legend_handles = [plt.Rectangle((0,0),1,1, color=color) for color in legend_colors]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
113
 
114
  # Put a legend to the right of the current axis
115
- fig.legend(legend_handles, legend_labels, loc='lower center', ncol=4,
116
- bbox_to_anchor=(0.515, -0.11), facecolor='black', edgecolor='white',
117
- labelcolor='white', fontsize=14)
 
 
 
 
 
 
 
118
 
119
  # Save plot to bytes with high DPI for crisp text
120
  buffer = io.BytesIO()
121
- plt.savefig(buffer, format='png', facecolor='#000000',
122
- bbox_inches='tight', dpi=150)
123
  buffer.seek(0)
124
 
125
  # Convert to base64 for HTML embedding
@@ -136,26 +203,34 @@ def create_matplotlib_bar_plot() -> None:
136
 
137
 
138
  def draw_bar_plot(
139
- ax: plt.Axes,
140
  bar_kwargs: dict,
141
  errors: list,
142
  ylabel: str,
143
  xticks: list[tuple[float, str]],
144
  adapt_ylim: bool = False,
145
  ) -> None:
146
- ax.set_facecolor('#000000')
147
- ax.grid(True, alpha=0.3, color='white', axis='y', zorder=0)
148
  # Draw bars
149
- _ = ax.bar(**bar_kwargs, width=1.0, edgecolor='white', linewidth=1, zorder=3)
150
  # Add error bars
151
  ax.errorbar(
152
- bar_kwargs["x"], bar_kwargs["height"], yerr=errors,
153
- fmt='none', ecolor='white', alpha=0.8, elinewidth=1.5, capthick=1.5, capsize=4, zorder=4,
 
 
 
 
 
 
 
 
154
  )
155
  # Set labels, ticks and grid
156
- ax.set_ylabel(ylabel, color='white', fontsize=16)
157
  ax.set_xticks([])
158
- ax.tick_params(colors='white', labelsize=13)
159
  ax.set_xticks([xt[0] for xt in xticks], [xt[1] for xt in xticks], fontsize=16)
160
  # Truncate axis to better fit the bars
161
  if adapt_ylim:
@@ -163,5 +238,5 @@ def draw_bar_plot(
163
  for h, e in zip(bar_kwargs["height"], errors):
164
  new_ymin = min(new_ymin, 0.98 * (h - e))
165
  new_ymax = max(new_ymax, 1.02 * (h + e))
166
- ymin, ymax = ax.get_ylim()
167
  ax.set_ylim(max(ymin, new_ymin), min(ymax, new_ymax))
 
1
  import matplotlib.pyplot as plt
2
+ import matplotlib.patches as mpatches
3
  import io
4
  import numpy as np
5
  import base64
6
 
7
  from plot_utils import get_color_for_config
8
+ from data import load_data, ModelBenchmarkData
9
 
10
 
11
  def reorder_data(per_scenario_data: dict) -> dict:
 
14
  def sorting_fn(key: str) -> float:
15
  cfg = per_scenario_data[key]["config"]
16
  attn_implementation = cfg["attn_implementation"]
17
+ attn_impl_prio = {
18
+ "flash_attention_2": 0,
19
+ "sdpa": 1,
20
+ "eager": 2,
21
+ "flex_attention": 3,
22
+ }[attn_implementation]
23
+ sdpa_backend_prio = {
24
+ None: -1,
25
+ "flash_attention": 0,
26
+ "math": 1,
27
+ "efficient_attention": 2,
28
+ "cudnn_attention": 3,
29
+ }[cfg["sdpa_backend"]]
30
+ return (
31
+ attn_impl_prio,
32
+ sdpa_backend_prio,
33
+ cfg["kernelize"],
34
+ cfg["compile_mode"] is not None,
35
+ )
36
 
37
  keys.sort(key=sorting_fn)
38
  per_scenario_data = {k: per_scenario_data[k] for k in keys}
 
45
  attn_implementation = "Eager"
46
  elif config["attn_implementation"] == "flash_attention_2":
47
  attn_implementation = "Flash attention"
48
+ elif config["attn_implementation"] == "flex_attention":
49
+ attn_implementation = "Flex attention"
50
  elif config["attn_implementation"] == "sdpa":
51
  attn_implementation = {
52
  "flash_attention": "SDPA (flash attention)",
 
57
  else:
58
  attn_implementation = "Unknown"
59
 
60
+ compile = "compiled" if config["compile_mode"] is not None else "no compile"
61
  kernels = "kernelized" if config["kernelize"] else "no kernels"
62
  return f"{attn_implementation}, {compile}, {kernels}"
63
 
64
 
65
+ def infer_bar_hatch(config: dict) -> str:
66
+ if config["compile_mode"] is not None:
67
+ return "/"
68
+ else:
69
+ return ""
70
+
71
+
72
+ def make_bar_kwargs(
73
+ per_device_data: dict[str, ModelBenchmarkData], key: str
74
+ ) -> tuple[dict, list]:
75
  # Prepare accumulators
76
  current_x = 0
77
+ bar_kwargs = {"x": [], "height": [], "color": [], "label": [], "hatch": []}
78
  errors_bars = []
79
  x_ticks = []
80
 
 
82
  per_scenario_data = device_data.get_bar_plot_data()
83
  per_scenario_data = reorder_data(per_scenario_data)
84
  device_xs = []
85
+
86
+ for scenario_name, scenario_data in per_scenario_data.items():
87
  bar_kwargs["x"].append(current_x)
88
  bar_kwargs["height"].append(np.median(scenario_data[key]))
89
  bar_kwargs["color"].append(get_color_for_config(scenario_data["config"]))
90
  bar_kwargs["label"].append(infer_bar_label(scenario_data["config"]))
91
+ bar_kwargs["hatch"].append(infer_bar_hatch(scenario_data["config"]))
92
  errors_bars.append(np.std(scenario_data[key]))
93
  device_xs.append(current_x)
94
  current_x += 1
 
97
  current_x += 1.5
98
  return bar_kwargs, errors_bars, x_ticks
99
 
100
+
101
  def create_matplotlib_bar_plot() -> None:
102
  """Create side-by-side matplotlib bar charts for TTFT and TPOT data."""
103
 
104
  # Create figure with dark theme - maximum size for full screen
105
+ plt.style.use("dark_background")
106
  fig, axs = plt.subplots(2, 1, figsize=(20, 11), sharex=True) # used to be 30, 16
107
+ fig.patch.set_facecolor("#000000")
108
 
109
  # Load data and ensure coherence
110
  per_device_data = load_data()
 
113
  bs, seqlen, n_tok = device_data.ensure_coherence()
114
  if batch_size is None:
115
  batch_size, sequence_length, num_tokens_to_generate = bs, seqlen, n_tok
116
+ elif (bs, seqlen, n_tok) != (
117
+ batch_size,
118
+ sequence_length,
119
+ num_tokens_to_generate,
120
+ ):
121
  fig.suptitle(
122
  f"Mismatch for batch size, sequence length and number of tokens to generate between configs: {bs} "
123
  f"!= {batch_size}, {seqlen} != {sequence_length}, {n_tok} != {num_tokens_to_generate}",
124
+ color="white",
125
+ fontsize=18,
126
  )
127
  return None
128
 
 
135
  draw_bar_plot(axs[1], itl_bars, itl_errors, "ITL (seconds)", x_ticks)
136
 
137
  # Title and tight layout
138
+ title = "\n".join(
139
+ [
140
+ "Time to first token and inter-token latency (lower is better)",
141
+ f"Batch size: {batch_size}, sequence length: {sequence_length}, new tokens: {num_tokens_to_generate}",
142
+ ]
143
+ )
144
+ fig.suptitle(title, color="white", fontsize=20, y=1.005, linespacing=1.5)
145
  plt.tight_layout()
146
 
147
  # Add common legend with full text
148
+ legend_labels, legend_colors, legend_hatches = [], [], []
149
+ for label, color, hatch in zip(
150
+ ttft_bars["label"], ttft_bars["color"], ttft_bars["hatch"]
151
+ ):
152
+ if label not in legend_labels:
153
+ legend_labels.append(label)
154
+ legend_colors.append(color)
155
+ legend_hatches.append(hatch)
156
+
157
+ # Make sure all attn implementations are equally represented
158
+ # implementations = {}
159
+ # for label, color, hatch in zip(legend_labels, legend_colors, legend_hatches):
160
+ # impl = label.split(",")[0]
161
+ # implementations[impl] = implementations.get(impl, []) + [(label, color, hatch)]
162
+
163
+ # n_max = max(len(impls) for impls in implementations.values())
164
+ # for label_color_pairs in implementations.values():
165
+ # for _ in range(len(label_color_pairs), n_max):
166
+ # label_color_pairs.append(("", "#000000"))
167
+
168
+ # legend_labels, legend_colors = zip(*sum(implementations.values(), []))
169
+
170
+ legend_handles = [
171
+ mpatches.Patch(facecolor=color, hatch=hatch, label=label, edgecolor="white")
172
+ for color, hatch, label in zip(legend_colors, legend_hatches, legend_labels)
173
+ ]
174
 
175
  # Put a legend to the right of the current axis
176
+ fig.legend(
177
+ handles=legend_handles,
178
+ loc="lower center",
179
+ ncol=4,
180
+ bbox_to_anchor=(0.515, -0.11),
181
+ facecolor="black",
182
+ edgecolor="white",
183
+ labelcolor="white",
184
+ fontsize=14,
185
+ )
186
 
187
  # Save plot to bytes with high DPI for crisp text
188
  buffer = io.BytesIO()
189
+ plt.savefig(buffer, format="png", facecolor="#000000", bbox_inches="tight", dpi=150)
 
190
  buffer.seek(0)
191
 
192
  # Convert to base64 for HTML embedding
 
203
 
204
 
205
  def draw_bar_plot(
206
+ ax: plt.Axes,
207
  bar_kwargs: dict,
208
  errors: list,
209
  ylabel: str,
210
  xticks: list[tuple[float, str]],
211
  adapt_ylim: bool = False,
212
  ) -> None:
213
+ ax.set_facecolor("#000000")
214
+ ax.grid(True, alpha=0.3, color="white", axis="y", zorder=0)
215
  # Draw bars
216
+ _ = ax.bar(**bar_kwargs, width=1.0, edgecolor="white", linewidth=1, zorder=3)
217
  # Add error bars
218
  ax.errorbar(
219
+ bar_kwargs["x"],
220
+ bar_kwargs["height"],
221
+ yerr=errors,
222
+ fmt="none",
223
+ ecolor="white",
224
+ alpha=0.8,
225
+ elinewidth=1.5,
226
+ capthick=1.5,
227
+ capsize=4,
228
+ zorder=4,
229
  )
230
  # Set labels, ticks and grid
231
+ ax.set_ylabel(ylabel, color="white", fontsize=16)
232
  ax.set_xticks([])
233
+ ax.tick_params(colors="white", labelsize=13)
234
  ax.set_xticks([xt[0] for xt in xticks], [xt[1] for xt in xticks], fontsize=16)
235
  # Truncate axis to better fit the bars
236
  if adapt_ylim:
 
238
  for h, e in zip(bar_kwargs["height"], errors):
239
  new_ymin = min(new_ymin, 0.98 * (h - e))
240
  new_ymax = max(new_ymax, 1.02 * (h + e))
241
+ ymin, ymax = ax.get_ylim()
242
  ax.set_ylim(max(ymin, new_ymin), min(ymax, new_ymax))
data.py CHANGED
@@ -1,20 +1,22 @@
1
  import json
 
 
2
  import numpy as np
3
- from typing import Optional
4
 
5
  def make_id(config: dict, keys_to_ignore: list[str]) -> str:
6
  keys = sorted(set(config.keys()))
7
  return "_".join(str(config[k]) for k in keys if k not in keys_to_ignore)
8
 
9
- class ModelBenchmarkData:
10
 
 
11
  def __init__(self, json_path: str) -> None:
12
  with open(json_path, "r") as f:
13
  self.data: dict = json.load(f)
14
 
15
  def compute_ttft(self, measures: dict) -> list[float]:
16
  return [dts[0] for dts in measures["dt_tokens"]]
17
-
18
  def compute_itl(self, measures: dict) -> list[float]:
19
  return [
20
  (dts[-1] - dts[0]) / (len(dts) - 1) if len(dts) > 2 else 0
@@ -34,7 +36,11 @@ class ModelBenchmarkData:
34
  all_hyperparams = set()
35
  for data in self.data.values():
36
  config = data["config"]
37
- hyperparams = (config["batch_size"], config["sequence_length"], config["num_tokens_to_generate"])
 
 
 
 
38
  all_hyperparams.add(hyperparams)
39
  if len(all_hyperparams) > 1:
40
  raise ValueError(
@@ -42,7 +48,9 @@ class ModelBenchmarkData:
42
  )
43
  return all_hyperparams.pop()
44
 
45
- def get_bar_plot_data(self, collapse_on_cache: bool = True, collapse_on_compile_mode: bool = True) -> dict:
 
 
46
  # Gather data for each scenario
47
  per_scenario_data = {}
48
  for cfg_name, data in self.data.items():
@@ -57,25 +65,35 @@ class ModelBenchmarkData:
57
  collapsed_keys = {}
58
  for cfg_name, data in per_scenario_data.items():
59
  keys_to_ignore = ["name"]
60
- keys_to_ignore += (["use_cache"] if collapse_on_cache else [])
61
- keys_to_ignore += (["compile_mode"] if collapse_on_compile_mode else [])
62
- cfg_id = make_id(data["config"], keys_to_ignore)
 
 
63
  cfg_e2e = np.mean(data["e2e"])
64
  other_name, other_e2e = collapsed_keys.get(cfg_id, (None, 1e16))
65
  if cfg_e2e < other_e2e:
66
  collapsed_keys[cfg_id] = (cfg_name, cfg_e2e)
67
- per_scenario_data = {k: per_scenario_data[k] for k, _ in collapsed_keys.values()}
 
 
68
 
69
  return per_scenario_data
70
 
71
 
72
- def load_data(keep_common_scenarios_only: bool = False) -> dict[str, ModelBenchmarkData]:
 
 
73
  data = {
74
  "MI325": ModelBenchmarkData("mi325_data.json"),
75
  "H100": ModelBenchmarkData("h100_data.json"),
76
  }
77
  if keep_common_scenarios_only:
78
- common_scenarios = set(data["MI325"].data.keys()) & set(data["H100"].data.keys())
 
 
79
  for device_name, device_data in data.items():
80
- device_data.data = {k: v for k, v in device_data.data.items() if k in common_scenarios}
 
 
81
  return data
 
1
  import json
2
+ from copy import deepcopy
3
+
4
  import numpy as np
5
+
6
 
7
def make_id(config: dict, keys_to_ignore: list[str]) -> str:
    """Build a stable identifier for a config by joining its values.

    Keys are visited in sorted order; any key listed in keys_to_ignore is
    skipped, so configs differing only in those keys collapse to the same id.
    """
    parts = []
    for key in sorted(config):
        if key not in keys_to_ignore:
            parts.append(str(config[key]))
    return "_".join(parts)
10
 
 
11
 
12
+ class ModelBenchmarkData:
13
  def __init__(self, json_path: str) -> None:
14
  with open(json_path, "r") as f:
15
  self.data: dict = json.load(f)
16
 
17
  def compute_ttft(self, measures: dict) -> list[float]:
18
  return [dts[0] for dts in measures["dt_tokens"]]
19
+
20
  def compute_itl(self, measures: dict) -> list[float]:
21
  return [
22
  (dts[-1] - dts[0]) / (len(dts) - 1) if len(dts) > 2 else 0
 
36
  all_hyperparams = set()
37
  for data in self.data.values():
38
  config = data["config"]
39
+ hyperparams = (
40
+ config["batch_size"],
41
+ config["sequence_length"],
42
+ config["num_tokens_to_generate"],
43
+ )
44
  all_hyperparams.add(hyperparams)
45
  if len(all_hyperparams) > 1:
46
  raise ValueError(
 
48
  )
49
  return all_hyperparams.pop()
50
 
51
+ def get_bar_plot_data(
52
+ self, collapse_on_cache: bool = True, collapse_on_compile_mode: bool = True
53
+ ) -> dict:
54
  # Gather data for each scenario
55
  per_scenario_data = {}
56
  for cfg_name, data in self.data.items():
 
65
  collapsed_keys = {}
66
  for cfg_name, data in per_scenario_data.items():
67
  keys_to_ignore = ["name"]
68
+ keys_to_ignore += ["use_cache"] if collapse_on_cache else []
69
+ keys_to_ignore += ["compile_mode"] if collapse_on_compile_mode else []
70
+ duply_cfg = deepcopy(data["config"])
71
+ duply_cfg["compiled"] = duply_cfg["compile_mode"] is not None
72
+ cfg_id = make_id(duply_cfg, keys_to_ignore)
73
  cfg_e2e = np.mean(data["e2e"])
74
  other_name, other_e2e = collapsed_keys.get(cfg_id, (None, 1e16))
75
  if cfg_e2e < other_e2e:
76
  collapsed_keys[cfg_id] = (cfg_name, cfg_e2e)
77
+ per_scenario_data = {
78
+ k: per_scenario_data[k] for k, _ in collapsed_keys.values()
79
+ }
80
 
81
  return per_scenario_data
82
 
83
 
84
def load_data(
    keep_common_scenarios_only: bool = False,
) -> dict[str, ModelBenchmarkData]:
    """Load benchmark data for both devices.

    Args:
        keep_common_scenarios_only: when True, drop any scenario that is not
            present for both devices so the two bar plots stay comparable.

    Returns:
        Mapping of device name ("MI325", "H100") to its ModelBenchmarkData,
        loaded from the corresponding JSON file in the working directory.
    """
    data = {
        "MI325": ModelBenchmarkData("mi325_data.json"),
        "H100": ModelBenchmarkData("h100_data.json"),
    }
    if keep_common_scenarios_only:
        # Dict key views are set-like: intersect them directly, no set() copies.
        common_scenarios = data["MI325"].data.keys() & data["H100"].data.keys()
        # Only the per-device data is needed here, not the device name.
        for device_data in data.values():
            device_data.data = {
                k: v for k, v in device_data.data.items() if k in common_scenarios
            }
    return data
h100_data.json CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:2f843d3f436d7919f67c071824fde3bc247b7e3a096a92d8abb191988d86a9d2
3
- size 2627476
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ee66b31725d29b9faaf38a437f4ca3ba8251f3ddc6eb9733650dac8b414bd73e
3
+ size 1848790
mi325_data.json CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:628d50b1cbd1eca36f26d3e2c3e31d3062996ffbeb203f3c7ee04b604fec039e
3
- size 2674620
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e85e274fdf29798e4e1093df3beec82d1369fe306e720811670fe68176e9bc51
3
+ size 1872352
plot_utils.py CHANGED
@@ -1,25 +1,35 @@
1
  # Color manipulation functions
2
  def hex_to_rgb(hex_color):
3
- hex_color = hex_color.lstrip('#')
4
  r, g, b = int(hex_color[0:2], 16), int(hex_color[2:4], 16), int(hex_color[4:6], 16)
5
  return r, g, b
6
 
 
7
  def blend_colors(color1, color2, blend_strength):
8
  rgb1 = hex_to_rgb(color1)
9
  rgb2 = hex_to_rgb(color2)
10
- new_color = tuple(map(lambda i: int(rgb1[i] * blend_strength + rgb2[i] * (1 - blend_strength)), range(3)))
 
 
 
 
 
11
  return rgb_to_hex(*new_color)
12
 
 
13
  def increase_brightness(r, g, b, factor):
14
  return tuple(map(lambda x: int(x + (255 - x) * factor), (r, g, b)))
15
 
 
16
  def decrease_brightness(r, g, b, factor):
17
  return tuple(map(lambda x: int(x * factor), (r, g, b)))
18
 
 
19
  def increase_saturation(r, g, b, factor) -> tuple[int, int, int]:
20
  gray = 0.299 * r + 0.587 * g + 0.114 * b
21
  return tuple(map(lambda x: int(gray + (x - gray) * factor), (r, g, b)))
22
 
 
23
  def rgb_to_hex(r, g, b):
24
  r, g, b = map(lambda x: min(max(x, 0), 255), (r, g, b))
25
  return f"#{r:02x}{g:02x}{b:02x}"
@@ -27,25 +37,31 @@ def rgb_to_hex(r, g, b):
27
 
28
  # Color assignment function
29
  def get_color_for_config(config: dict):
30
- attn_implementation, sdpa_backend = config["attn_implementation"], config["sdpa_backend"]
31
- barycenter = 1 - (config["compilation"] + 2 * config["kernelize"]) / 3
 
 
 
 
32
 
33
- # Eager
34
  if attn_implementation == "eager":
35
  color = blend_colors("#FA7F7FFF", "#FF2D2DFF", barycenter)
36
-
37
  # SDPA - math
38
  elif attn_implementation == "sdpa" and sdpa_backend == "math":
39
  color = blend_colors("#7AB8FFFF", "#277CD0FF", barycenter)
40
-
41
  # SDPA - flash attention
42
- elif attn_implementation == "sdpa" and sdpa_backend == "flash_attention":
43
  color = blend_colors("#81FF9CFF", "#219F3CFF", barycenter)
44
-
 
 
45
  # Flash attention
46
  elif attn_implementation == "flash_attention_2":
47
  color = blend_colors("#FFDB70FF", "#DFD002FF", barycenter)
 
 
 
48
  else:
49
  raise ValueError(f"Unknown attention implementation: {attn_implementation}")
50
-
51
  return color
 
1
# Color manipulation functions
def hex_to_rgb(hex_color):
    """Parse a '#RRGGBB' (optionally longer) hex string into an (r, g, b) int tuple."""
    digits = hex_color.lstrip("#")
    return int(digits[:2], 16), int(digits[2:4], 16), int(digits[4:6], 16)
6
 
7
+
8
def blend_colors(color1, color2, blend_strength):
    """Linear blend of two hex colors: strength 1.0 yields color1, 0.0 yields color2."""
    rgb_a = hex_to_rgb(color1)
    rgb_b = hex_to_rgb(color2)
    mixed = tuple(
        int(a * blend_strength + b * (1 - blend_strength))
        for a, b in zip(rgb_a, rgb_b)
    )
    return rgb_to_hex(*mixed)
18
 
19
+
20
def increase_brightness(r, g, b, factor):
    """Move every channel toward white (255) by the given fraction."""
    return tuple(int(channel + (255 - channel) * factor) for channel in (r, g, b))
22
 
23
+
24
def decrease_brightness(r, g, b, factor):
    """Scale every channel toward black by the given factor."""
    return tuple(int(channel * factor) for channel in (r, g, b))
26
 
27
+
28
def increase_saturation(r, g, b, factor) -> tuple[int, int, int]:
    """Push channels away from their gray (luma) value by the given factor."""
    gray = 0.299 * r + 0.587 * g + 0.114 * b  # Rec. 601 luma weights
    return tuple(int(gray + (channel - gray) * factor) for channel in (r, g, b))
31
 
32
+
33
def rgb_to_hex(r, g, b):
    """Format (r, g, b) as '#rrggbb', clamping each channel to [0, 255]."""
    clamped = [max(0, min(255, channel)) for channel in (r, g, b)]
    return "#{:02x}{:02x}{:02x}".format(*clamped)
 
37
 
38
# Color assignment function
def get_color_for_config(config: dict):
    """Pick a bar color for a benchmark scenario config.

    The hue encodes the attention implementation (and SDPA backend); the blend
    position encodes compilation/kernelization (darker = more optimized).

    Raises:
        ValueError: for an attention implementation (or SDPA backend) with no
            assigned palette.
    """
    attn = config["attn_implementation"]
    backend = config["sdpa_backend"]
    is_compiled = config["compile_mode"] is not None
    # 1.0 = plain run; compiled and/or kernelized shifts toward the darker shade.
    barycenter = 1 - (is_compiled + 2 * config["kernelize"]) / 3

    if attn == "eager":
        # Eager: reds
        return blend_colors("#FA7F7FFF", "#FF2D2DFF", barycenter)
    if attn == "sdpa" and backend == "math":
        # SDPA - math: blues
        return blend_colors("#7AB8FFFF", "#277CD0FF", barycenter)
    if attn == "sdpa" and backend in [None, "flash_attention"]:
        # SDPA - flash attention (None treated as flash): greens
        return blend_colors("#81FF9CFF", "#219F3CFF", barycenter)
    if attn == "sdpa" and backend == "efficient_attention":
        # SDPA - efficient attention: purples
        return blend_colors("#DB81FFFF", "#9C33B1FF", barycenter)
    if attn == "flash_attention_2":
        # Flash attention: yellows
        return blend_colors("#FFDB70FF", "#DFD002FF", barycenter)
    if attn == "flex_attention":
        # Flex attention — NOTE(review): same purples as SDPA efficient
        # attention; confirm the color overlap is intended.
        return blend_colors("#DB81FFFF", "#9C33B1FF", barycenter)
    raise ValueError(f"Unknown attention implementation: {attn}")