| | """ |
| | Generate all visualizations for speculative decoding paper. |
| | |
| | Creates publication-quality figures matching PAPER_OUTLINE.md specifications. |
| | |
| | Author: Claude Code |
| | Date: 2025-11-30 |
| | """ |
| |
|
| | import pandas as pd |
| | import numpy as np |
| | import matplotlib.pyplot as plt |
| | import seaborn as sns |
| | from pathlib import Path |
| | from typing import Dict, List |
| |
|
| | |
# Shared publication styling applied to every figure in this script:
# seaborn "paper" preset, colorblind-safe palette, 300-dpi rendering and
# compact font sizes.
plt.style.use('seaborn-v0_8-paper')
sns.set_palette("colorblind")
plt.rcParams.update({
    'figure.dpi': 300,
    'savefig.dpi': 300,
    'font.size': 10,
    'axes.labelsize': 11,
    'axes.titlesize': 12,
    'xtick.labelsize': 9,
    'ytick.labelsize': 9,
})
| |
|
| | |
# Project layout: the root is the parent of the directory holding this script.
# Inputs are read from <root>/data; figures are written to <root>/paper/figures.
_PROJECT_ROOT = Path(__file__).parent.parent
DATA_DIR = _PROJECT_ROOT / "data"
FIGURES_DIR = _PROJECT_ROOT.joinpath("paper", "figures")
# Created at import time so every figure function can save without checking.
FIGURES_DIR.mkdir(parents=True, exist_ok=True)
| |
|
| |
|
def figure3_rejection_by_domain(df: pd.DataFrame):
    """Bar chart of the mean draft-rejection rate per domain (Figure 3).

    Args:
        df: Per-token records with a ``domain`` column and an ``is_rejected``
            column whose per-domain mean is plotted as a percentage
            (presumably a 0/1 or boolean flag -- TODO confirm).

    Side effects:
        Writes ``figure3_rejection_by_domain.png`` into ``FIGURES_DIR``.
    """
    print("\n📊 Generating Figure 3: Rejection by Domain...")

    # Ascending order so the bars read from least to most rejected.
    rejection_rates = df.groupby('domain')['is_rejected'].mean().sort_values()
    percentages = rejection_rates.to_numpy(dtype=float) * 100
    positions = range(len(rejection_rates))

    fig, ax = plt.subplots(figsize=(8, 5))

    # matplotlib cycles through this list if there are more than four domains.
    colors = ['#2ecc71', '#3498db', '#e74c3c', '#e67e22']
    bars = ax.bar(positions, percentages, color=colors)

    ax.set_xlabel('Domain')
    ax.set_ylabel('Rejection Rate (%)')
    ax.set_title('Draft Rejection Rates by Domain')
    ax.set_xticks(positions)
    ax.set_xticklabels([d.replace('_', '-').title() for d in rejection_rates.index],
                       rotation=15, ha='right')

    # Keep the fixed 0-40% axis by default, but extend it when a bar (plus its
    # value label, drawn 1 unit above the bar) would otherwise be clipped.
    finite = percentages[np.isfinite(percentages)]
    y_top = max(40.0, float(finite.max()) + 5.0) if finite.size else 40.0
    ax.set_ylim(0, y_top)
    ax.grid(axis='y', alpha=0.3)

    # Annotate each bar with its percentage just above the bar top.
    for bar, pct in zip(bars, percentages):
        ax.text(bar.get_x() + bar.get_width() / 2, pct + 1, f'{pct:.1f}%',
                ha='center', va='bottom', fontsize=9, fontweight='bold')

    plt.tight_layout()
    output_path = FIGURES_DIR / "figure3_rejection_by_domain.png"
    plt.savefig(output_path, bbox_inches='tight')
    plt.close(fig)

    print(f" ✅ Saved: {output_path}")
| |
|
| |
|
def figure4_rejection_vs_position(df: pd.DataFrame):
    """Line plot of draft-rejection rate against token position (Figure 4).

    Token positions are split into 20 equal-width bins; the mean of
    ``is_rejected`` in each bin is plotted at the bin's midpoint.

    Args:
        df: Per-token records with ``token_position`` and ``is_rejected``
            columns. It is not modified.

    Side effects:
        Writes ``figure4_rejection_vs_position.png`` into ``FIGURES_DIR``.
    """
    print("\n📊 Generating Figure 4: Rejection vs Position...")

    # Bin into a standalone Series instead of adding a 'position_bin' column to
    # the caller's DataFrame, which main() reuses for later figures.
    position_bins = pd.cut(df['token_position'], bins=20)
    # observed=False keeps empty bins (their NaN rate shows as a gap in the
    # line); passing it explicitly avoids pandas' FutureWarning about the
    # changing default.
    position_rates = df['is_rejected'].groupby(position_bins, observed=False).mean()

    bin_centers = [(interval.left + interval.right) / 2 for interval in position_rates.index]
    rates_pct = position_rates.to_numpy(dtype=float) * 100

    fig, ax = plt.subplots(figsize=(10, 5))

    ax.plot(bin_centers, rates_pct, marker='o', linewidth=2, markersize=6,
            color='#3498db', label='Rejection Rate')

    # Shade the early (<20) and late (>100) position ranges.
    ax.axvspan(0, 20, alpha=0.1, color='red', label='Early (<20)')
    x_max = max(bin_centers)
    # Only shade the late range when the data reaches past position 100;
    # otherwise axvspan(100, x_max) would shade [x_max, 100] under a
    # "Late (>100)" label.
    if x_max > 100:
        ax.axvspan(100, x_max, alpha=0.1, color='green', label='Late (>100)')

    ax.set_xlabel('Token Position in Sequence')
    ax.set_ylabel('Rejection Rate (%)')
    ax.set_title('Draft Rejection Rate by Token Position')

    # Default to a 20-35% window, widened (never below 0) when any binned rate
    # falls outside it so no point is clipped.
    finite = rates_pct[np.isfinite(rates_pct)]
    y_low, y_high = 20.0, 35.0
    if finite.size:
        y_low = max(0.0, min(y_low, float(finite.min()) - 2.0))
        y_high = max(y_high, float(finite.max()) + 2.0)
    ax.set_ylim(y_low, y_high)

    ax.grid(alpha=0.3)
    ax.legend()

    plt.tight_layout()
    output_path = FIGURES_DIR / "figure4_rejection_vs_position.png"
    plt.savefig(output_path, bbox_inches='tight')
    plt.close(fig)

    print(f" ✅ Saved: {output_path}")
| |
|
| |
|
def figure5_mask_performance_heatmap(df: pd.DataFrame):
    """Heatmap of acceptance rate (%) per domain and attention mask (Figure 5).

    Args:
        df: Ablation records with ``domain``, ``mask_type`` and
            ``is_accepted`` columns.

    Rows and columns follow a fixed ordering. Any domain or mask from that
    ordering that has no data is skipped with a warning instead of raising
    KeyError and aborting the remaining figures.

    Side effects:
        Writes ``figure5_mask_performance_heatmap.png`` into ``FIGURES_DIR``.
    """
    print("\n📊 Generating Figure 5: Mask Performance Heatmap...")

    # Rows = domain, columns = mask_type, cells = mean acceptance in percent.
    pivot = df.groupby(['domain', 'mask_type'])['is_accepted'].mean().unstack() * 100

    mask_order = ['causal', 'tidar', 'bidirectional', 'windowed', 'strided']
    domain_order = ['code', 'math', 'translation']
    # .loc with an absent label raises KeyError, so keep only the categories
    # that actually occur (preserving the fixed order) and report the rest.
    domains = [d for d in domain_order if d in pivot.index]
    masks = [m for m in mask_order if m in pivot.columns]
    missing = ([d for d in domain_order if d not in domains]
               + [m for m in mask_order if m not in masks])
    if missing:
        print(f"  Warning: no ablation data for {', '.join(missing)}; omitted from Figure 5")
    if not domains or not masks:
        print("  Skipped: nothing to plot for Figure 5")
        return
    pivot = pivot.loc[domains, masks]

    fig, ax = plt.subplots(figsize=(10, 5))

    sns.heatmap(pivot, annot=True, fmt='.1f', cmap='RdYlGn', vmin=5, vmax=35,
                cbar_kws={'label': 'Acceptance Rate (%)'}, ax=ax, linewidths=0.5)

    ax.set_xlabel('Attention Mask Type')
    ax.set_ylabel('Domain')
    ax.set_title('Acceptance Rate by Domain and Attention Mask')
    ax.set_yticklabels([d.replace('_', '-').title() for d in domains], rotation=0)
    ax.set_xticklabels([m.title() for m in masks], rotation=15, ha='right')

    plt.tight_layout()
    output_path = FIGURES_DIR / "figure5_mask_performance_heatmap.png"
    plt.savefig(output_path, bbox_inches='tight')
    plt.close(fig)

    print(f" ✅ Saved: {output_path}")
| |
|
| |
|
def figure6_throughput_quality_tradeoff(ablation_df: pd.DataFrame):
    """Scatter of mean throughput vs. mean acceptance per attention mask (Figure 6).

    Args:
        ablation_df: Ablation records with ``mask_type``,
            ``throughput_tokens_per_sec`` and ``is_accepted`` columns.

    Side effects:
        Writes ``figure6_throughput_quality_tradeoff.png`` into ``FIGURES_DIR``.
    """
    print("\n📊 Generating Figure 6: Throughput-Quality Trade-off...")

    # One point per mask: average throughput and average acceptance.
    mask_stats = ablation_df.groupby('mask_type').agg({
        'throughput_tokens_per_sec': 'mean',
        'is_accepted': 'mean'
    }).reset_index()

    fig, ax = plt.subplots(figsize=(8, 6))

    # Unknown mask types fall back to gray below.
    colors = {'causal': '#3498db', 'tidar': '#9b59b6', 'bidirectional': '#2ecc71',
              'windowed': '#e74c3c', 'strided': '#e67e22'}

    for _, row in mask_stats.iterrows():
        x = row['throughput_tokens_per_sec']
        y = row['is_accepted'] * 100
        ax.scatter(x, y, s=200, color=colors.get(row['mask_type'], 'gray'),
                   alpha=0.7, edgecolors='black', linewidth=1.5)
        # Offset the label in screen points rather than data units so it stays
        # beside its marker whatever the axis range turns out to be.
        ax.annotate(row['mask_type'].title(), (x, y), xytext=(6, 6),
                    textcoords='offset points', fontsize=9, fontweight='bold')

    ax.set_xlabel('Throughput (tokens/second)')
    ax.set_ylabel('Acceptance Rate (%)')
    ax.set_title('Throughput-Quality Trade-off Across Attention Masks')
    ax.grid(alpha=0.3)

    # Default to a 40-150 tokens/s window, widened when a mask's throughput
    # falls outside it (a fixed range would silently hide those points).
    throughput = mask_stats['throughput_tokens_per_sec'].to_numpy(dtype=float)
    finite = throughput[np.isfinite(throughput)]
    x_low, x_high = 40.0, 150.0
    if finite.size:
        x_low = min(x_low, float(finite.min()) - 10.0)
        # Extra right-hand room for the text label next to the last point.
        x_high = max(x_high, float(finite.max()) + 20.0)
    ax.set_xlim(x_low, x_high)

    plt.tight_layout()
    output_path = FIGURES_DIR / "figure6_throughput_quality_tradeoff.png"
    plt.savefig(output_path, bbox_inches='tight')
    plt.close(fig)

    print(f" ✅ Saved: {output_path}")
| |
|
| |
|
| | def figure_domain_comparison_table(df: pd.DataFrame, quality_df: pd.DataFrame): |
| | """Generate formatted table image for domain comparison.""" |
| |
|
| | print("\n๐ Generating Table 1: Domain Comparison...") |
| |
|
| | |
| | domain_stats = df.groupby('domain').agg({ |
| | 'is_rejected': 'mean', |
| | 'sequence_length': 'mean' |
| | }).reset_index() |
| |
|
| | |
| | domain_stats = domain_stats.merge(quality_df, on='domain', how='left') |
| |
|
| | |
| | fig, ax = plt.subplots(figsize=(12, 4)) |
| | ax.axis('tight') |
| | ax.axis('off') |
| |
|
| | table_data = [] |
| | for _, row in domain_stats.iterrows(): |
| | table_data.append([ |
| | row['domain'].replace('_', '-').title(), |
| | f"{row['is_rejected']*100:.1f}%", |
| | f"{row['metric']}", |
| | f"{row['value']:.2f}" if row['value'] < 1 else f"{row['value']:.1f}", |
| | f"{row['samples']}" |
| | ]) |
| |
|
| | headers = ['Domain', 'Rejection Rate', 'Quality Metric', 'Score', 'Samples'] |
| |
|
| | table = ax.table(cellText=table_data, colLabels=headers, loc='center', |
| | cellLoc='center', colWidths=[0.2, 0.2, 0.2, 0.15, 0.15]) |
| |
|
| | table.auto_set_font_size(False) |
| | table.set_fontsize(10) |
| | table.scale(1, 2) |
| |
|
| | |
| | for i in range(len(headers)): |
| | table[(0, i)].set_facecolor('#3498db') |
| | table[(0, i)].set_text_props(weight='bold', color='white') |
| |
|
| | |
| | for i in range(1, len(table_data) + 1): |
| | for j in range(len(headers)): |
| | if i % 2 == 0: |
| | table[(i, j)].set_facecolor('#ecf0f1') |
| |
|
| | plt.title('Table 1: Domain-Specific Rejection Rates and Quality Metrics', |
| | fontsize=12, fontweight='bold', pad=20) |
| |
|
| | output_path = FIGURES_DIR / "table1_domain_comparison.png" |
| | plt.savefig(output_path, bbox_inches='tight', dpi=300) |
| | plt.close() |
| |
|
| | print(f" โ
Saved: {output_path}") |
| |
|
| |
|
def main():
    """Load the experiment CSVs and generate every paper figure and table.

    Reads ``phase1_cross_domain.csv``, ``phase3_ablation.csv`` and
    ``quality_metrics.csv`` from ``DATA_DIR``, then renders Figures 3-6 and
    Table 1 into ``FIGURES_DIR``. Finally it lists every PNG in that
    directory, including any left over from earlier runs.
    """
    print("=" * 60)
    print("Generating Publication-Quality Visualizations")
    print("=" * 60)

    print("\nLoading data...")
    cross_domain_df = pd.read_csv(DATA_DIR / "phase1_cross_domain.csv")
    ablation_df = pd.read_csv(DATA_DIR / "phase3_ablation.csv")
    quality_df = pd.read_csv(DATA_DIR / "quality_metrics.csv")
    print("✅ Data loaded")

    figure3_rejection_by_domain(cross_domain_df)
    figure4_rejection_vs_position(cross_domain_df)
    figure5_mask_performance_heatmap(ablation_df)
    figure6_throughput_quality_tradeoff(ablation_df)
    figure_domain_comparison_table(cross_domain_df, quality_df)

    print("\n" + "=" * 60)
    print("✅ All figures generated!")
    print(f" Saved to: {FIGURES_DIR}")
    print("=" * 60)

    print("\n=== Generated Figures ===")
    for fig_path in sorted(FIGURES_DIR.glob("*.png")):
        print(f" - {fig_path.name}")


# Run the full figure pipeline only when executed as a script, not on import.
if __name__ == "__main__":
    main()
| |
|