ror HF Staff committed on
Commit
165f130
·
1 Parent(s): dc41c89

Better look

Browse files
Files changed (3) hide show
  1. bar_plot.py +28 -8
  2. data.py +11 -1
  3. plot_utils.py +22 -24
bar_plot.py CHANGED
@@ -21,14 +21,26 @@ def reorder_data(per_scenario_data: dict) -> dict:
21
  return per_scenario_data
22
 
23
 
24
- def make_bar_kwargs(key: str) -> tuple[dict, list]:
 
 
 
 
 
 
 
 
 
 
 
 
25
  # Prepare accumulators
26
  current_x = 0
27
  bar_kwargs = {"x": [], "height": [], "color": [], "label": []}
28
  errors_bars = []
29
  x_ticks = []
30
 
31
- for device_name, device_data in load_data(keep_common_scenarios_only=False).items():
32
  per_scenario_data = device_data.get_bar_plot_data()
33
  per_scenario_data = reorder_data(per_scenario_data)
34
  device_xs = []
@@ -37,7 +49,7 @@ def make_bar_kwargs(key: str) -> tuple[dict, list]:
37
  bar_kwargs["x"].append(current_x)
38
  bar_kwargs["height"].append(np.median(scenario_data[key]))
39
  bar_kwargs["color"].append(get_color_for_config(scenario_data["config"]))
40
- bar_kwargs["label"].append(scenario_name)
41
  errors_bars.append(np.std(scenario_data[key]))
42
  device_xs.append(current_x)
43
  current_x += 1
@@ -54,12 +66,19 @@ def create_matplotlib_bar_plot() -> None:
54
  fig, axs = plt.subplots(2, 1, figsize=(30, 16), sharex=True)
55
  fig.patch.set_facecolor('#000000')
56
 
 
 
 
 
 
 
 
57
  # TTFT Plot (left)
58
- ttft_bars, ttft_errors, x_ticks = make_bar_kwargs("ttft")
59
  draw_bar_plot(axs[0], ttft_bars, ttft_errors, "Time to first token and inter-token latency (lower is better)", "TTFT (seconds)", x_ticks)
60
 
61
  # # ITL Plot (right)
62
- itl_bars, itl_errors, x_ticks = make_bar_kwargs("itl")
63
  draw_bar_plot(axs[1], itl_bars, itl_errors, None, "ITL (seconds)", x_ticks)
64
 
65
  # # E2E Plot (right)
@@ -68,8 +87,9 @@ def create_matplotlib_bar_plot() -> None:
68
  plt.tight_layout()
69
 
70
  # Add common legend with full text
71
- legend_labels = ttft_bars["label"] # Use full labels without truncation
72
- legend_handles = [plt.Rectangle((0,0),1,1, color=color) for color in ttft_bars["color"]]
 
73
 
74
  # Put a legend to the right of the current axis
75
  fig.legend(legend_handles, legend_labels, loc='lower center', ncol=4,
@@ -103,7 +123,7 @@ def draw_bar_plot(ax: plt.Axes, bar_kwargs: dict, errors: list, title: str, ylab
103
  # Add error bars
104
  ax.errorbar(
105
  bar_kwargs["x"], bar_kwargs["height"], yerr=errors,
106
- fmt='none', ecolor='white', alpha=0.8, elinewidth=1.5, capthick=1.5, capsize=4,
107
  )
108
  # Set labels and title
109
  ax.set_ylabel(ylabel, color='white', fontsize=16)
 
21
  return per_scenario_data
22
 
23
 
24
def infer_bar_label(config: dict) -> str:
    """Format a scenario config as a human-readable legend label.

    Args:
        config: Scenario config with keys "attn_implementation" (one of
            "flash_attention_2", "sdpa", "eager"), "compilation" (truthy
            when the model was compiled) and "kernelize" (truthy when
            custom kernels were used).

    Returns:
        A label such as "SDPA, compiled, no kernels".

    Raises:
        KeyError: If config["attn_implementation"] is not a known implementation.
    """
    attn_display_names = {
        "flash_attention_2": "Flash attention",
        "sdpa": "SDPA",
        "eager": "Eager",
    }
    attn_label = attn_display_names[config["attn_implementation"]]
    # Renamed from `compile`, which shadowed the builtin of the same name.
    compile_label = "compiled" if config["compilation"] else "no compile"
    kernel_label = "kernelized" if config["kernelize"] else "no kernels"
    return f"{attn_label}, {compile_label}, {kernel_label}"
34
+
35
+
36
+ def make_bar_kwargs(per_device_data: dict, key: str) -> tuple[dict, list]:
37
  # Prepare accumulators
38
  current_x = 0
39
  bar_kwargs = {"x": [], "height": [], "color": [], "label": []}
40
  errors_bars = []
41
  x_ticks = []
42
 
43
+ for device_name, device_data in per_device_data.items():
44
  per_scenario_data = device_data.get_bar_plot_data()
45
  per_scenario_data = reorder_data(per_scenario_data)
46
  device_xs = []
 
49
  bar_kwargs["x"].append(current_x)
50
  bar_kwargs["height"].append(np.median(scenario_data[key]))
51
  bar_kwargs["color"].append(get_color_for_config(scenario_data["config"]))
52
+ bar_kwargs["label"].append(infer_bar_label(scenario_data["config"]))
53
  errors_bars.append(np.std(scenario_data[key]))
54
  device_xs.append(current_x)
55
  current_x += 1
 
66
  fig, axs = plt.subplots(2, 1, figsize=(30, 16), sharex=True)
67
  fig.patch.set_facecolor('#000000')
68
 
69
+ # Load and sanitize data
70
+ per_device_data = load_data()
71
+ batch_sizes = {name: device_data.get_main_batch_size() for name, device_data in per_device_data.items()}
72
+ if len(set(batch_sizes.values())) > 1:
73
+ fig.suptitle(f"Unmatched batch sizes: {batch_sizes}", color='white', fontsize=18, pad=20)
74
+ return None
75
+
76
  # TTFT Plot (left)
77
+ ttft_bars, ttft_errors, x_ticks = make_bar_kwargs(per_device_data, "ttft")
78
  draw_bar_plot(axs[0], ttft_bars, ttft_errors, "Time to first token and inter-token latency (lower is better)", "TTFT (seconds)", x_ticks)
79
 
80
  # # ITL Plot (right)
81
+ itl_bars, itl_errors, x_ticks = make_bar_kwargs(per_device_data, "itl")
82
  draw_bar_plot(axs[1], itl_bars, itl_errors, None, "ITL (seconds)", x_ticks)
83
 
84
  # # E2E Plot (right)
 
87
  plt.tight_layout()
88
 
89
  # Add common legend with full text
90
+ unique_bars = len(ttft_bars["label"]) // 2
91
+ legend_labels, legend_colors = ttft_bars["label"][:unique_bars], ttft_bars["color"][:unique_bars]
92
+ legend_handles = [plt.Rectangle((0,0),1,1, color=color) for color in legend_colors]
93
 
94
  # Put a legend to the right of the current axis
95
  fig.legend(legend_handles, legend_labels, loc='lower center', ncol=4,
 
123
  # Add error bars
124
  ax.errorbar(
125
  bar_kwargs["x"], bar_kwargs["height"], yerr=errors,
126
+ fmt='none', ecolor='white', alpha=0.8, elinewidth=1.5, capthick=1.5, capsize=4, zorder=4,
127
  )
128
  # Set labels and title
129
  ax.set_ylabel(ylabel, color='white', fontsize=16)
data.py CHANGED
@@ -25,6 +25,16 @@ class ModelBenchmarkData:
25
  num_tokens = len(measures["t_tokens"]) - 1
26
  return delta_t / num_tokens
27
 
 
 
 
 
 
 
 
 
 
 
28
  def get_bar_plot_data(self, collapse_on_cache: bool = True, collapse_on_compile_mode: bool = True) -> dict:
29
  # Gather data for each scenario
30
  per_scenario_data = {}
@@ -52,7 +62,7 @@ class ModelBenchmarkData:
52
  return per_scenario_data
53
 
54
 
55
- def load_data(keep_common_scenarios_only: bool = True) -> dict[str, ModelBenchmarkData]:
56
  data = {
57
  "MI325": ModelBenchmarkData("mi325_data.json"),
58
  "H100": ModelBenchmarkData("h100_data.json"),
 
25
  num_tokens = len(measures["t_tokens"]) - 1
26
  return delta_t / num_tokens
27
 
28
def get_main_batch_size(self) -> int:
    """Return the most frequent batch size across all recorded measures.

    Tallies the "batch_size" of every measure in every config of
    `self.data` and returns the modal value. Ties resolve to the
    first-seen batch size, matching `max` semantics.

    Returns:
        The batch size that occurs most often.

    Raises:
        ValueError: If no measures exist at all.
    """
    counts: dict[int, int] = {}
    # Config names are irrelevant here; only the measures matter.
    for device_config_data in self.data.values():
        for measure in device_config_data["measures"]:
            bs = measure["batch_size"]
            counts[bs] = counts.get(bs, 0) + 1
    if not counts:
        # Clearer than the bare "max() arg is an empty sequence" ValueError.
        raise ValueError("No measures found; cannot determine a main batch size.")
    return max(counts, key=counts.get)
37
+
38
  def get_bar_plot_data(self, collapse_on_cache: bool = True, collapse_on_compile_mode: bool = True) -> dict:
39
  # Gather data for each scenario
40
  per_scenario_data = {}
 
62
  return per_scenario_data
63
 
64
 
65
+ def load_data(keep_common_scenarios_only: bool = False) -> dict[str, ModelBenchmarkData]:
66
  data = {
67
  "MI325": ModelBenchmarkData("mi325_data.json"),
68
  "H100": ModelBenchmarkData("h100_data.json"),
plot_utils.py CHANGED
@@ -4,9 +4,11 @@ def hex_to_rgb(hex_color):
4
  r, g, b = int(hex_color[0:2], 16), int(hex_color[2:4], 16), int(hex_color[4:6], 16)
5
  return r, g, b
6
 
7
- def blend_colors(rgb, hex_color, blend_strength):
8
- other_rgb = hex_to_rgb(hex_color)
9
- return tuple(map(lambda i: int(rgb[i] * blend_strength + other_rgb[i] * (1 - blend_strength)), range(3)))
 
 
10
 
11
def increase_brightness(r, g, b, factor):
    """Lighten an RGB color by moving each channel toward 255 by `factor` (0..1)."""
    return tuple(int(channel + (255 - channel) * factor) for channel in (r, g, b))
@@ -25,29 +27,25 @@ def rgb_to_hex(r, g, b):
25
 
26
  # Color assignment function
27
  def get_color_for_config(config: dict):
28
- # Determine the main hue for the attention implementation
29
  attn_implementation, sdpa_backend = config["attn_implementation"], config["sdpa_backend"]
30
- compilation = config["compilation"]
 
 
31
  if attn_implementation == "eager":
32
- main_hue = "#FF4B4BFF" if compilation else "#FF4141FF"
33
- elif attn_implementation == "sdpa":
34
- main_hue = {
35
- None: "#4A90E2" if compilation else "#2E82E1FF",
36
- "math": "#408DDB" if compilation else "#227BD3FF",
37
- "flash_attention": "#35A34D" if compilation else "#219F3CFF",
38
- "efficient_attention": "#605895" if compilation else "#423691FF",
39
- "cudnn_attention": "#774AE2" if compilation else "#5D27DCFF",
40
- }[sdpa_backend] # fmt: off
 
 
41
  elif attn_implementation == "flash_attention_2":
42
- main_hue = "#FFD700" if compilation else "#FFBF00FF"
43
  else:
44
  raise ValueError(f"Unknown attention implementation: {attn_implementation}")
45
- # Apply color modifications for compilation and kernelization
46
- r, g, b = hex_to_rgb(main_hue)
47
- if config["compilation"]:
48
- delta = 0.2 + 0.2 * (len(config["compile_mode"]) - 7) / 8
49
- r, g, b = increase_brightness(r, g, b, delta)
50
- if config["kernelize"]:
51
- r, g, b = decrease_brightness(r, g, b, 0.8)
52
- # Return the color as a hex string
53
- return rgb_to_hex(r, g, b)
 
4
  r, g, b = int(hex_color[0:2], 16), int(hex_color[2:4], 16), int(hex_color[4:6], 16)
5
  return r, g, b
6
 
7
def blend_colors(color1, color2, blend_strength):
    """Linearly interpolate two hex colors.

    `blend_strength` weights `color1`; 1.0 returns `color1`, 0.0 returns
    `color2`. The result is returned as a hex string.
    """
    rgb1 = hex_to_rgb(color1)
    rgb2 = hex_to_rgb(color2)
    blended = tuple(
        int(c1 * blend_strength + c2 * (1 - blend_strength))
        for c1, c2 in zip(rgb1, rgb2)
    )
    return rgb_to_hex(*blended)
12
 
13
  def increase_brightness(r, g, b, factor):
14
  return tuple(map(lambda x: int(x + (255 - x) * factor), (r, g, b)))
 
27
 
28
  # Color assignment function
29
def get_color_for_config(config: dict):
    """Map a scenario config to a hex color string.

    The base hue encodes the attention implementation (and, for SDPA, the
    backend); the shade encodes optimization level: compilation and
    kernelization each pull the color toward the darker gradient endpoint.

    Args:
        config: Scenario config with keys "attn_implementation",
            "sdpa_backend", "compilation" and "kernelize".

    Returns:
        A hex color string as produced by `blend_colors`.

    Raises:
        ValueError: On an unknown attention implementation, or an
            unsupported SDPA backend.
    """
    attn_implementation, sdpa_backend = config["attn_implementation"], config["sdpa_backend"]
    # 1.0 = light end (no compile, no kernels); kernelize weighs twice as much
    # as compilation, so "both" lands at the dark end (0.0).
    barycenter = 1 - (config["compilation"] + 2 * config["kernelize"]) / 3

    # (light, dark) gradient endpoints for the SDPA backends we support.
    sdpa_gradients = {
        "math": ("#7AB8FFFF", "#277CD0FF"),
        "flash_attention": ("#81FF9CFF", "#219F3CFF"),
    }

    if attn_implementation == "eager":
        color = blend_colors("#FA7F7FFF", "#FF2D2DFF", barycenter)
    elif attn_implementation == "sdpa":
        if sdpa_backend not in sdpa_gradients:
            # Previously this fell through to "Unknown attention implementation:
            # sdpa", which misreported the actual problem (the backend).
            raise ValueError(f"Unsupported SDPA backend: {sdpa_backend}")
        light, dark = sdpa_gradients[sdpa_backend]
        color = blend_colors(light, dark, barycenter)
    elif attn_implementation == "flash_attention_2":
        color = blend_colors("#FFDB70FF", "#DFD002FF", barycenter)
    else:
        raise ValueError(f"Unknown attention implementation: {attn_implementation}")

    return color