import base64
import io
import warnings
from typing import Optional

import matplotlib.patches as mpatches
import matplotlib.pyplot as plt
import numpy as np

from data import ModelBenchmarkData, load_data
from plot_utils import get_color_for_config


# Sort priority of each attention implementation (lower sorts first).
_ATTN_IMPL_PRIORITY = {
    "flash_attention_2": 0,
    "sdpa": 1,
    "eager": 2,
    "flex_attention": 3,
}

# Sort priority of each SDPA backend; ``None`` (no SDPA backend set) sorts first.
_SDPA_BACKEND_PRIORITY = {
    None: -1,
    "flash_attention": 0,
    "math": 1,
    "efficient_attention": 2,
    "cudnn_attention": 3,
}

# Legend names for the non-SDPA attention implementations.
_ATTN_IMPL_LABELS = {
    "eager": "Eager",
    "flash_attention_2": "Flash attention",
    "flex_attention": "Flex attention",
}

# Legend names for the SDPA backends.
_SDPA_BACKEND_LABELS = {
    "flash_attention": "SDPA (flash attention)",
    "efficient_attention": "SDPA (efficient_attention)",
    "cudnn_attention": "SDPA (cudnn)",
    "math": "SDPA (math)",
}

# Extra horizontal space (in bar widths) between the bar groups of two devices.
_DEVICE_GROUP_GAP = 1.5


def reorder_data(per_scenario_data: dict) -> dict:
    """Return a copy of ``per_scenario_data`` with scenarios in display order.

    Scenarios are ordered by attention implementation, then SDPA backend, then
    kernelization, then whether compilation is enabled. The sort is stable, so
    ties keep their input order. Implementations or backends missing from the
    priority tables sort after the known ones instead of raising ``KeyError``,
    consistent with how ``infer_bar_label`` tolerates unknown values.

    Args:
        per_scenario_data: mapping of scenario name -> scenario data. Each value
            must hold a ``"config"`` dict with the keys ``attn_implementation``,
            ``sdpa_backend``, ``kernelize`` and ``compile_mode``.

    Returns:
        A new dict with the same items, in sorted order.
    """

    def sorting_fn(item: tuple) -> tuple:
        cfg = item[1]["config"]
        return (
            _ATTN_IMPL_PRIORITY.get(cfg["attn_implementation"], len(_ATTN_IMPL_PRIORITY)),
            _SDPA_BACKEND_PRIORITY.get(cfg["sdpa_backend"], len(_SDPA_BACKEND_PRIORITY)),
            # Truthiness, matching how ``infer_bar_label`` reads this flag.
            bool(cfg["kernelize"]),
            cfg["compile_mode"] is not None,
        )

    return dict(sorted(per_scenario_data.items(), key=sorting_fn))


def infer_bar_label(config: dict) -> str:
    """Format a readable legend label, e.g. ``"Eager, compiled, no kernels"``.

    Args:
        config: scenario config with the keys ``attn_implementation``,
            ``compile_mode`` and ``kernelize`` (plus ``sdpa_backend`` when the
            implementation is ``"sdpa"``).

    Returns:
        ``"<attention>, <compiled | no compile>, <kernelized | no kernels>"``.
    """
    attn_implementation = config["attn_implementation"]
    if attn_implementation == "sdpa":
        attn_label = _SDPA_BACKEND_LABELS.get(config["sdpa_backend"], "SDPA (unknown backend)")
    else:
        attn_label = _ATTN_IMPL_LABELS.get(attn_implementation, "Unknown")
    # ``compile_label`` rather than ``compile`` to avoid shadowing the builtin.
    compile_label = "compiled" if config["compile_mode"] is not None else "no compile"
    kernels_label = "kernelized" if config["kernelize"] else "no kernels"
    return f"{attn_label}, {compile_label}, {kernels_label}"


def infer_bar_hatch(config: dict) -> str:
    """Return the bar hatch: ``"/"`` for compiled scenarios, ``""`` otherwise."""
    return "/" if config["compile_mode"] is not None else ""


def make_bar_kwargs(
    per_device_data: dict[str, ModelBenchmarkData], key: str
) -> tuple[dict, list, list]:
    """Flatten per-device benchmark data into arguments for ``Axes.bar``.

    Bars of one device sit one unit apart. Consecutive device groups are
    separated by an extra ``_DEVICE_GROUP_GAP``.

    Args:
        per_device_data: mapping of device name -> benchmark data.
        key: metric looked up in each scenario's data (e.g. ``"ttft"``, ``"itl"``).

    Returns:
        ``(bar_kwargs, error_bars, x_ticks)`` where

        - ``bar_kwargs`` maps ``x``, ``height`` (median of the metric),
          ``color``, ``label`` and ``hatch`` to per-bar lists;
        - ``error_bars`` holds the standard deviation of the metric per bar;
        - ``x_ticks`` holds one ``(x position, device name)`` pair per device
          that has at least one scenario, centered under that device's bars.
    """
    bar_kwargs = {"x": [], "height": [], "color": [], "label": [], "hatch": []}
    error_bars = []
    x_ticks = []
    current_x = 0
    for device_name, device_data in per_device_data.items():
        per_scenario_data = reorder_data(device_data.get_bar_plot_data())
        device_xs = []
        for scenario_data in per_scenario_data.values():
            config = scenario_data["config"]
            values = scenario_data[key]
            bar_kwargs["x"].append(current_x)
            bar_kwargs["height"].append(np.median(values))
            bar_kwargs["color"].append(get_color_for_config(config))
            bar_kwargs["label"].append(infer_bar_label(config))
            bar_kwargs["hatch"].append(infer_bar_hatch(config))
            error_bars.append(np.std(values))
            device_xs.append(current_x)
            current_x += 1
        # A device without scenarios gets no tick: np.mean([]) would be NaN.
        if device_xs:
            x_ticks.append((np.mean(device_xs), device_name))
        current_x += _DEVICE_GROUP_GAP
    return bar_kwargs, error_bars, x_ticks


def _add_common_legend(fig: plt.Figure, bar_kwargs: dict) -> None:
    """Add one legend entry per distinct bar label, centered below the subplots."""
    # First (color, hatch) seen for each label, in bar order.
    styles_by_label = {}
    for label, color, hatch in zip(bar_kwargs["label"], bar_kwargs["color"], bar_kwargs["hatch"]):
        styles_by_label.setdefault(label, (color, hatch))

    legend_handles = [
        mpatches.Patch(facecolor=color, hatch=hatch, label=label, edgecolor="white")
        for label, (color, hatch) in styles_by_label.items()
    ]
    # Anchored just below the figure (negative y) so it does not cover the bars.
    fig.legend(
        handles=legend_handles,
        loc="lower center",
        ncol=4,
        bbox_to_anchor=(0.515, -0.11),
        facecolor="black",
        edgecolor="white",
        labelcolor="white",
        fontsize=14,
    )


def _figure_to_base64_png(fig: plt.Figure) -> str:
    """Encode ``fig`` as a base64 PNG string (high DPI for crisp text)."""
    buffer = io.BytesIO()
    fig.savefig(buffer, format="png", facecolor="#000000", bbox_inches="tight", dpi=150)
    return base64.b64encode(buffer.getvalue()).decode()


def create_matplotlib_bar_plot() -> Optional[str]:
    """Render stacked TTFT and ITL bar charts and return them as an HTML snippet.

    Benchmark data comes from ``load_data``. Every device must report the same
    batch size, sequence length and number of tokens to generate.

    Returns:
        An HTML ``<div>`` embedding the plot as a base64-encoded PNG. Returns
        ``None``, after emitting a ``UserWarning`` with the mismatching values,
        if the devices disagree on those parameters.
    """
    # Check coherence before creating any figure, so the mismatch path leaks nothing.
    per_device_data = load_data()
    batch_size, sequence_length, num_tokens_to_generate = None, None, None
    for device_data in per_device_data.values():
        bs, seqlen, n_tok = device_data.ensure_coherence()
        if batch_size is None:
            batch_size, sequence_length, num_tokens_to_generate = bs, seqlen, n_tok
        elif (bs, seqlen, n_tok) != (batch_size, sequence_length, num_tokens_to_generate):
            warnings.warn(
                "Mismatch for batch size, sequence length and number of tokens to generate "
                f"between configs: {bs} != {batch_size}, {seqlen} != {sequence_length}, "
                f"{n_tok} != {num_tokens_to_generate}",
                stacklevel=2,
            )
            return None

    # Dark theme, large figure: TTFT on top, ITL below, sharing the x axis.
    plt.style.use("dark_background")
    fig, axs = plt.subplots(2, 1, figsize=(20, 11), sharex=True)
    fig.patch.set_facecolor("#000000")
    try:
        # TTFT plot (top)
        ttft_bars, ttft_errors, x_ticks = make_bar_kwargs(per_device_data, "ttft")
        draw_bar_plot(axs[0], ttft_bars, ttft_errors, "TTFT (seconds)", x_ticks)

        # ITL plot (bottom)
        itl_bars, itl_errors, x_ticks = make_bar_kwargs(per_device_data, "itl")
        draw_bar_plot(axs[1], itl_bars, itl_errors, "ITL (seconds)", x_ticks)

        title = "\n".join(
            [
                "Time to first token and inter-token latency (lower is better)",
                f"Batch size: {batch_size}, sequence length: {sequence_length}, "
                f"new tokens: {num_tokens_to_generate}",
            ]
        )
        fig.suptitle(title, color="white", fontsize=20, y=1.005, linespacing=1.5)
        fig.tight_layout()

        # Both subplots are built from the same configs, so the TTFT bars carry every label.
        _add_common_legend(fig, ttft_bars)

        img_data = _figure_to_base64_png(fig)
    finally:
        # Always release the figure, even if drawing or saving fails.
        plt.close(fig)

    # Self-contained HTML: the PNG is inlined as a data URI spanning the full width.
    return (
        '<div style="width: 100%; margin: 0; padding: 0; background-color: #000000;">'
        f'<img src="data:image/png;base64,{img_data}" alt="TTFT and ITL bar plots" '
        'style="width: 100%; height: auto; display: block;">'
        "</div>"
    )


def draw_bar_plot(
    ax: plt.Axes,
    bar_kwargs: dict,
    errors: list,
    ylabel: str,
    xticks: list[tuple[float, str]],
    adapt_ylim: bool = False,
) -> None:
    """Draw dark-themed bars with white error bars on ``ax``.

    Args:
        ax: axes to draw on.
        bar_kwargs: keyword arguments for ``Axes.bar`` as built by
            ``make_bar_kwargs`` (``x``, ``height``, ``color``, ``label``, ``hatch``).
        errors: error-bar half-length for each bar, aligned with ``bar_kwargs``.
        ylabel: y-axis label.
        xticks: ``(x position, label)`` pairs for the x axis.
        adapt_ylim: if True, shrink the y range around the bars +/- errors with
            a 2% margin. The range is never extended beyond the default limits.
    """
    ax.set_facecolor("#000000")
    ax.grid(True, alpha=0.3, color="white", axis="y", zorder=0)

    # Bars above the grid (zorder 3), error bars above the bars (zorder 4).
    ax.bar(**bar_kwargs, width=1.0, edgecolor="white", linewidth=1, zorder=3)
    ax.errorbar(
        bar_kwargs["x"],
        bar_kwargs["height"],
        yerr=errors,
        fmt="none",
        ecolor="white",
        alpha=0.8,
        elinewidth=1.5,
        capthick=1.5,
        capsize=4,
        zorder=4,
    )

    # Labels and ticks
    ax.set_ylabel(ylabel, color="white", fontsize=16)
    ax.tick_params(colors="white", labelsize=13)
    ax.set_xticks([xt[0] for xt in xticks], [xt[1] for xt in xticks], fontsize=16)

    # Truncate the axis to better fit the bars. Skipped without bars, where there
    # is nothing to fit and the computed range would be inverted.
    if adapt_ylim and bar_kwargs["height"]:
        heights_and_errors = list(zip(bar_kwargs["height"], errors))
        new_ymin = min(0.98 * (h - e) for h, e in heights_and_errors)
        new_ymax = max(1.02 * (h + e) for h, e in heights_and_errors)
        ymin, ymax = ax.get_ylim()
        ax.set_ylim(max(ymin, new_ymin), min(ymax, new_ymax))