ror HF Staff commited on
Commit
dc41c89
·
1 Parent(s): e1f4b73

Good bar chart

Browse files
Files changed (4) hide show
  1. app.py +3 -7
  2. bar_plot.py +66 -108
  3. data.py +12 -0
  4. plot_utils.py +53 -0
app.py CHANGED
@@ -2,7 +2,6 @@ import gradio as gr
2
  import matplotlib.pyplot as plt
3
  import matplotlib
4
 
5
- from data import ModelBenchmarkData
6
  from bar_plot import create_matplotlib_bar_plot
7
 
8
 
@@ -11,9 +10,6 @@ matplotlib.use('Agg')
11
  plt.ioff()
12
 
13
 
14
- DATA = ModelBenchmarkData("data.json")
15
-
16
-
17
  def load_css():
18
  """Load CSS styling."""
19
  try:
@@ -23,10 +19,10 @@ def load_css():
23
  return "body { background: #000; color: #fff; }"
24
 
25
 
26
-
27
  def refresh_plot():
28
  """Generate new matplotlib charts and update description."""
29
- return create_matplotlib_bar_plot(DATA.get_bar_plot_data()), "**Transformer CI Dashboard**<br>-<br>**AMD runs on MI325**<br>**NVIDIA runs on A10**<br><br>*This dashboard only tracks important models*<br>*(Data refreshed)*"
 
30
 
31
  # Create Gradio interface
32
  with gr.Blocks(title="Random Data Dashboard", css=load_css(), fill_height=True, fill_width=True) as demo:
@@ -40,7 +36,7 @@ with gr.Blocks(title="Random Data Dashboard", css=load_css(), fill_height=True,
40
  # Main plot area
41
  with gr.Column(elem_classes=["main-content"]):
42
  plot = gr.HTML(
43
- create_matplotlib_bar_plot(DATA.get_bar_plot_data()),
44
  elem_classes=["plot-container"],
45
  )
46
 
 
2
  import matplotlib.pyplot as plt
3
  import matplotlib
4
 
 
5
  from bar_plot import create_matplotlib_bar_plot
6
 
7
 
 
10
  plt.ioff()
11
 
12
 
 
 
 
13
  def load_css():
14
  """Load CSS styling."""
15
  try:
 
19
  return "body { background: #000; color: #fff; }"
20
 
21
 
 
22
  def refresh_plot():
23
  """Generate new matplotlib charts and update description."""
24
+ sidebar_text = "**Transformer CI Dashboard**<br>-<br>**AMD runs on MI325**<br>**NVIDIA runs on A10**<br><br>*This dashboard only tracks important models*<br>*(Data refreshed)*"
25
+ return create_matplotlib_bar_plot(), sidebar_text
26
 
27
  # Create Gradio interface
28
  with gr.Blocks(title="Random Data Dashboard", css=load_css(), fill_height=True, fill_width=True) as demo:
 
36
  # Main plot area
37
  with gr.Column(elem_classes=["main-content"]):
38
  plot = gr.HTML(
39
+ create_matplotlib_bar_plot(),
40
  elem_classes=["plot-container"],
41
  )
42
 
bar_plot.py CHANGED
@@ -3,66 +3,9 @@ import io
3
  import numpy as np
4
  import base64
5
 
 
 
6
 
7
- # Color manipulation functions
8
- def hex_to_rgb(hex_color):
9
- hex_color = hex_color.lstrip('#')
10
- r, g, b = int(hex_color[0:2], 16), int(hex_color[2:4], 16), int(hex_color[4:6], 16)
11
- return r, g, b
12
-
13
- def blend_colors(rgb, hex_color, blend_strength):
14
- other_rgb = hex_to_rgb(hex_color)
15
- return tuple(map(lambda i: int(rgb[i] * blend_strength + other_rgb[i] * (1 - blend_strength)), range(3)))
16
-
17
- def increase_brightness(r, g, b, factor):
18
- return tuple(map(lambda x: int(x + (255 - x) * factor), (r, g, b)))
19
-
20
- def decrease_brightness(r, g, b, factor):
21
- return tuple(map(lambda x: int(x * factor), (r, g, b)))
22
-
23
- def increase_saturation(r, g, b, factor) -> tuple[int, int, int]:
24
- gray = 0.299 * r + 0.587 * g + 0.114 * b
25
- return tuple(map(lambda x: int(gray + (x - gray) * factor), (r, g, b)))
26
-
27
- def rgb_to_hex(r, g, b):
28
- r, g, b = map(lambda x: min(max(x, 0), 255), (r, g, b))
29
- return f"#{r:02x}{g:02x}{b:02x}"
30
-
31
- # Color assignment function
32
- def get_color_for_config(config, filtered_on_compile_mode: bool = False):
33
-
34
- # Determine the main hue for the attention implementation
35
- attn_implementation, sdpa_backend = config["attn_implementation"], config["sdpa_backend"]
36
- compilation = config["compilation"]
37
- if attn_implementation == "eager":
38
- main_hue = "#FF4B4BFF" if compilation else "#FF4141FF"
39
- elif attn_implementation == "sdpa":
40
- main_hue = {
41
- None: "#4A90E2" if compilation else "#2E82E1FF",
42
- "math": "#408DDB" if compilation else "#227BD3FF",
43
- "flash_attention": "#35A34D" if compilation else "#219F3CFF",
44
- "efficient_attention": "#605895" if compilation else "#423691FF",
45
- "cudnn_attention": "#774AE2" if compilation else "#5D27DCFF",
46
- }[sdpa_backend] # fmt: off
47
- elif attn_implementation == "flash_attention_2":
48
- main_hue = "#FFD700" if compilation else "#FFBF00FF"
49
- else:
50
- raise ValueError(f"Unknown attention implementation: {attn_implementation}")
51
-
52
- # Apply color modifications for compilation and kernelization
53
- r, g, b = hex_to_rgb(main_hue)
54
- if config["compilation"]:
55
- delta = 0.2
56
- delta += 0.2 * (len(config["compile_mode"]) - 7) / 8 if filtered_on_compile_mode else 0
57
- r, g, b = increase_brightness(r, g, b, delta)
58
- if config["kernelize"]:
59
- pass
60
- # r, g, b = blend_colors((r, g, b), "#FF00F2FF", 0.7)
61
- r, g, b = decrease_brightness(r, g, b, 0.8)
62
- # r, g, b = increase_saturation(r, g, b, 0.9)
63
-
64
- # Return the color as a hex string
65
- return rgb_to_hex(r, g, b)
66
 
67
  def reorder_data(per_scenario_data: dict) -> dict:
68
  keys = list(per_scenario_data.keys())
@@ -78,70 +21,59 @@ def reorder_data(per_scenario_data: dict) -> dict:
78
  return per_scenario_data
79
 
80
 
81
- def make_bar_kwargs(per_scenario_data: dict, key: str) -> tuple[dict, list]:
 
 
82
  bar_kwargs = {"x": [], "height": [], "color": [], "label": []}
83
- errors = []
84
- for i, (name, data) in enumerate(per_scenario_data.items()):
85
- bar_kwargs["x"].append(i)
86
- bar_kwargs["height"].append(np.median(data[key]))
87
- bar_kwargs["color"].append(get_color_for_config(data["config"]))
88
- bar_kwargs["label"].append(name)
89
- errors.append(np.std(data[key]))
90
- return bar_kwargs, errors
91
-
92
- def draw_bar_plot(ax: plt.Axes, bar_kwargs: dict, errors: list, title: str, ylabel: str):
93
- ax.set_facecolor('#000000')
94
- # ax.grid(True, alpha=0.3, color='white')
95
- # Draw bars
96
- _ = ax.bar(**bar_kwargs, width=1.0, edgecolor='white', linewidth=1)
97
- # Add error bars
98
- ax.errorbar(
99
- bar_kwargs["x"], bar_kwargs["height"], yerr=errors,
100
- fmt='none', ecolor='white', alpha=0.8, elinewidth=1.5, capthick=1.5, capsize=4,
101
- )
102
- # Set labels and title
103
- ax.set_ylabel(ylabel, color='white', fontsize=16)
104
- ax.set_title(title, color='white', fontsize=18, pad=20)
105
- # Set ticks and grid
106
- ax.set_xticks([])
107
- ax.tick_params(colors='white', labelsize=13)
108
- # Truncate axis to better fit the bars
109
- new_ymin, new_ymax = 1e9, -1e9
110
- for h, e in zip(bar_kwargs["height"], errors):
111
- new_ymin = min(new_ymin, 0.98 * (h - e))
112
- new_ymax = max(new_ymax, 1.02 * (h + e))
113
- ymin, ymax = ax.get_ylim()
114
- ax.set_ylim(max(ymin, new_ymin), min(ymax, new_ymax))
115
-
116
-
117
- def create_matplotlib_bar_plot(per_scenario_data: dict):
118
  """Create side-by-side matplotlib bar charts for TTFT and TPOT data."""
119
 
120
  # Create figure with dark theme - maximum size for full screen
121
  plt.style.use('dark_background')
122
- fig, axs = plt.subplots(1, 3, figsize=(30, 16))
123
  fig.patch.set_facecolor('#000000')
124
 
125
- # Reorganize data
126
- per_scenario_data = reorder_data(per_scenario_data)
127
-
128
  # TTFT Plot (left)
129
- ttft_bars, ttft_errors = make_bar_kwargs(per_scenario_data, "ttft")
130
- draw_bar_plot(axs[0], ttft_bars, ttft_errors, "Time to first token (lower is better)", "TTFT (seconds)")
131
 
132
- # ITL Plot (right)
133
- itl_bars, itl_errors = make_bar_kwargs(per_scenario_data, "itl")
134
- draw_bar_plot(axs[1], itl_bars, itl_errors, "Inter token latency (lower is better)", "ITL (seconds)")
135
 
136
- # E2E Plot (right)
137
- e2e_bars, e2e_errors = make_bar_kwargs(per_scenario_data, "e2e")
138
- draw_bar_plot(axs[2], e2e_bars, e2e_errors, "End-to-end latency (lower is better)", "E2E (seconds)")
 
139
 
140
  # Add common legend with full text
141
  legend_labels = ttft_bars["label"] # Use full labels without truncation
142
  legend_handles = [plt.Rectangle((0,0),1,1, color=color) for color in ttft_bars["color"]]
 
 
143
  fig.legend(legend_handles, legend_labels, loc='lower center', ncol=4,
144
- bbox_to_anchor=(0.5, -0.05), facecolor='black', edgecolor='white',
145
  labelcolor='white', fontsize=14)
146
 
147
  # Save plot to bytes with high DPI for crisp text
@@ -161,3 +93,29 @@ def create_matplotlib_bar_plot(per_scenario_data: dict):
161
  </div>
162
  """
163
  return html
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3
  import numpy as np
4
  import base64
5
 
6
+ from plot_utils import get_color_for_config
7
+ from data import load_data
8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
 
10
  def reorder_data(per_scenario_data: dict) -> dict:
11
  keys = list(per_scenario_data.keys())
 
21
  return per_scenario_data
22
 
23
 
24
+ def make_bar_kwargs(key: str) -> tuple[dict, list]:
25
+ # Prepare accumulators
26
+ current_x = 0
27
  bar_kwargs = {"x": [], "height": [], "color": [], "label": []}
28
+ errors_bars = []
29
+ x_ticks = []
30
+
31
+ for device_name, device_data in load_data(keep_common_scenarios_only=False).items():
32
+ per_scenario_data = device_data.get_bar_plot_data()
33
+ per_scenario_data = reorder_data(per_scenario_data)
34
+ device_xs = []
35
+
36
+ for scenario_name, scenario_data in per_scenario_data.items():
37
+ bar_kwargs["x"].append(current_x)
38
+ bar_kwargs["height"].append(np.median(scenario_data[key]))
39
+ bar_kwargs["color"].append(get_color_for_config(scenario_data["config"]))
40
+ bar_kwargs["label"].append(scenario_name)
41
+ errors_bars.append(np.std(scenario_data[key]))
42
+ device_xs.append(current_x)
43
+ current_x += 1
44
+
45
+ x_ticks.append((np.mean(device_xs), device_name))
46
+ current_x += 1.5
47
+ return bar_kwargs, errors_bars, x_ticks
48
+
49
+ def create_matplotlib_bar_plot() -> None:
 
 
 
 
 
 
 
 
 
 
 
 
 
50
  """Create side-by-side matplotlib bar charts for TTFT and TPOT data."""
51
 
52
  # Create figure with dark theme - maximum size for full screen
53
  plt.style.use('dark_background')
54
+ fig, axs = plt.subplots(2, 1, figsize=(30, 16), sharex=True)
55
  fig.patch.set_facecolor('#000000')
56
 
 
 
 
57
  # TTFT Plot (left)
58
+ ttft_bars, ttft_errors, x_ticks = make_bar_kwargs("ttft")
59
+ draw_bar_plot(axs[0], ttft_bars, ttft_errors, "Time to first token and inter-token latency (lower is better)", "TTFT (seconds)", x_ticks)
60
 
61
+ # # ITL Plot (right)
62
+ itl_bars, itl_errors, x_ticks = make_bar_kwargs("itl")
63
+ draw_bar_plot(axs[1], itl_bars, itl_errors, None, "ITL (seconds)", x_ticks)
64
 
65
+ # # E2E Plot (right)
66
+ # e2e_bars, e2e_errors = make_bar_kwargs("e2e")
67
+ # draw_bar_plot(axs, e2e_bars, e2e_errors, "End-to-end latency (lower is better)", "E2E (seconds)")
68
+ plt.tight_layout()
69
 
70
  # Add common legend with full text
71
  legend_labels = ttft_bars["label"] # Use full labels without truncation
72
  legend_handles = [plt.Rectangle((0,0),1,1, color=color) for color in ttft_bars["color"]]
73
+
74
+ # Put a legend to the right of the current axis
75
  fig.legend(legend_handles, legend_labels, loc='lower center', ncol=4,
76
+ bbox_to_anchor=(0.515, -0.15), facecolor='black', edgecolor='white',
77
  labelcolor='white', fontsize=14)
78
 
79
  # Save plot to bytes with high DPI for crisp text
 
93
  </div>
94
  """
95
  return html
96
+
97
+
98
+ def draw_bar_plot(ax: plt.Axes, bar_kwargs: dict, errors: list, title: str, ylabel: str, xticks: list[tuple[float, str]]):
99
+ ax.set_facecolor('#000000')
100
+ ax.grid(True, alpha=0.2, color='white', zorder=0)
101
+ # Draw bars
102
+ _ = ax.bar(**bar_kwargs, width=1.0, edgecolor='white', linewidth=1, zorder=3)
103
+ # Add error bars
104
+ ax.errorbar(
105
+ bar_kwargs["x"], bar_kwargs["height"], yerr=errors,
106
+ fmt='none', ecolor='white', alpha=0.8, elinewidth=1.5, capthick=1.5, capsize=4,
107
+ )
108
+ # Set labels and title
109
+ ax.set_ylabel(ylabel, color='white', fontsize=16)
110
+ ax.set_title(title, color='white', fontsize=18, pad=20)
111
+ # Set ticks and grid
112
+ ax.set_xticks([])
113
+ ax.tick_params(colors='white', labelsize=13)
114
+ ax.set_xticks([xt[0] for xt in xticks], [xt[1] for xt in xticks], fontsize=16)
115
+ # Truncate axis to better fit the bars
116
+ new_ymin, new_ymax = 1e9, -1e9
117
+ for h, e in zip(bar_kwargs["height"], errors):
118
+ new_ymin = min(new_ymin, 0.98 * (h - e))
119
+ new_ymax = max(new_ymax, 1.02 * (h + e))
120
+ ymin, ymax = ax.get_ylim()
121
+ ax.set_ylim(max(ymin, new_ymin), min(ymax, new_ymax))
data.py CHANGED
@@ -50,3 +50,15 @@ class ModelBenchmarkData:
50
  per_scenario_data = {k: per_scenario_data[k] for k, _ in collapsed_keys.values()}
51
 
52
  return per_scenario_data
 
 
 
 
 
 
 
 
 
 
 
 
 
50
  per_scenario_data = {k: per_scenario_data[k] for k, _ in collapsed_keys.values()}
51
 
52
  return per_scenario_data
53
+
54
+
55
+ def load_data(keep_common_scenarios_only: bool = True) -> dict[str, ModelBenchmarkData]:
56
+ data = {
57
+ "MI325": ModelBenchmarkData("mi325_data.json"),
58
+ "H100": ModelBenchmarkData("h100_data.json"),
59
+ }
60
+ if keep_common_scenarios_only:
61
+ common_scenarios = set(data["MI325"].data.keys()) & set(data["H100"].data.keys())
62
+ for device_name, device_data in data.items():
63
+ device_data.data = {k: v for k, v in device_data.data.items() if k in common_scenarios}
64
+ return data
plot_utils.py ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Color manipulation functions
2
+ def hex_to_rgb(hex_color):
3
+ hex_color = hex_color.lstrip('#')
4
+ r, g, b = int(hex_color[0:2], 16), int(hex_color[2:4], 16), int(hex_color[4:6], 16)
5
+ return r, g, b
6
+
7
+ def blend_colors(rgb, hex_color, blend_strength):
8
+ other_rgb = hex_to_rgb(hex_color)
9
+ return tuple(map(lambda i: int(rgb[i] * blend_strength + other_rgb[i] * (1 - blend_strength)), range(3)))
10
+
11
+ def increase_brightness(r, g, b, factor):
12
+ return tuple(map(lambda x: int(x + (255 - x) * factor), (r, g, b)))
13
+
14
+ def decrease_brightness(r, g, b, factor):
15
+ return tuple(map(lambda x: int(x * factor), (r, g, b)))
16
+
17
+ def increase_saturation(r, g, b, factor) -> tuple[int, int, int]:
18
+ gray = 0.299 * r + 0.587 * g + 0.114 * b
19
+ return tuple(map(lambda x: int(gray + (x - gray) * factor), (r, g, b)))
20
+
21
+ def rgb_to_hex(r, g, b):
22
+ r, g, b = map(lambda x: min(max(x, 0), 255), (r, g, b))
23
+ return f"#{r:02x}{g:02x}{b:02x}"
24
+
25
+
26
+ # Color assignment function
27
+ def get_color_for_config(config: dict):
28
+ # Determine the main hue for the attention implementation
29
+ attn_implementation, sdpa_backend = config["attn_implementation"], config["sdpa_backend"]
30
+ compilation = config["compilation"]
31
+ if attn_implementation == "eager":
32
+ main_hue = "#FF4B4BFF" if compilation else "#FF4141FF"
33
+ elif attn_implementation == "sdpa":
34
+ main_hue = {
35
+ None: "#4A90E2" if compilation else "#2E82E1FF",
36
+ "math": "#408DDB" if compilation else "#227BD3FF",
37
+ "flash_attention": "#35A34D" if compilation else "#219F3CFF",
38
+ "efficient_attention": "#605895" if compilation else "#423691FF",
39
+ "cudnn_attention": "#774AE2" if compilation else "#5D27DCFF",
40
+ }[sdpa_backend] # fmt: off
41
+ elif attn_implementation == "flash_attention_2":
42
+ main_hue = "#FFD700" if compilation else "#FFBF00FF"
43
+ else:
44
+ raise ValueError(f"Unknown attention implementation: {attn_implementation}")
45
+ # Apply color modifications for compilation and kernelization
46
+ r, g, b = hex_to_rgb(main_hue)
47
+ if config["compilation"]:
48
+ delta = 0.2 + 0.2 * (len(config["compile_mode"]) - 7) / 8
49
+ r, g, b = increase_brightness(r, g, b, delta)
50
+ if config["kernelize"]:
51
+ r, g, b = decrease_brightness(r, g, b, 0.8)
52
+ # Return the color as a hex string
53
+ return rgb_to_hex(r, g, b)