Spaces:
Running
Running
File size: 5,403 Bytes
6b42508 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 |
import plotly.graph_objects as go
import plotly.io as pio
import numpy as np
"""
Stacked bar chart: GPU memory breakdown vs sequence length, with menus for Model Size and Recomputation.
Responsive, no zoom/pan, clean hover; styled to match the minimal theme.
"""
# Axes
seq_labels = ["1024", "2048", "4096", "8192"]
seq_scale = np.array([1, 2, 4, 8], dtype=float)
# Components and colors (aligned with the provided example)
components = [
("parameters", "rgb(78, 165, 183)"),
("gradients", "rgb(227, 138, 66)"),
("optimizer", "rgb(232, 137, 171)"),
("activations", "rgb(206, 192, 250)"),
]
# Model sizes and base memory (GB) for params/grad/opt (constant vs seq), by size
model_sizes = ["1B", "3B", "8B", "70B", "405B"]
params_mem = {
"1B": 4.0,
"3B": 13.3,
"8B": 26.0,
"70B": 244.0,
"405B": 1520.0,
}
# Optimizer ~= 2x params; gradients ~= params (illustrative)
# Activations base coefficient per size (growth ~ coeff * (seq/1024)^2)
act_coeff = {
"1B": 3.6,
"3B": 9.3,
"8B": 46.2,
"70B": 145.7,
"405B": 1519.9,
}
def activations_curve(size_key: str, recompute: str) -> np.ndarray:
base = act_coeff[size_key] * (seq_scale ** 2)
if recompute == "selective":
return base * 0.25
if recompute == "full":
return base * (1.0/16.0)
return base
def stack_for(size_key: str, recompute: str):
p = np.full_like(seq_scale, params_mem[size_key], dtype=float)
g = np.full_like(seq_scale, params_mem[size_key], dtype=float)
o = np.full_like(seq_scale, 2.0 * params_mem[size_key], dtype=float)
a = activations_curve(size_key, recompute)
return {
"parameters": p,
"gradients": g,
"optimizer": o,
"activations": a,
}
# Precompute all combinations
recomp_modes = ["none", "selective", "full"]
Y = {mode: {size: stack_for(size, mode) for size in model_sizes} for mode in recomp_modes}
# Build traces: 4 traces per size (20 total). Start with size index 0 visible
fig = go.Figure()
for size in model_sizes:
for comp_name, color in components:
fig.add_bar(
x=seq_labels,
y=Y["none"][size][comp_name],
name=comp_name,
marker=dict(color=color),
hovertemplate="Seq len=%{x}<br>Mem=%{y:.1f}GB<br>%{data.name}<extra></extra>",
showlegend=True,
visible=(size == model_sizes[0]),
)
# Compute y-axis ranges per size and recomputation
def max_total(size: str, mode: str) -> float:
stacks = Y[mode][size]
totals = stacks["parameters"] + stacks["gradients"] + stacks["optimizer"] + stacks["activations"]
return float(np.max(totals))
layout_y_ranges = {mode: {size: 1.05 * max_total(size, mode) for size in model_sizes} for mode in recomp_modes}
# Layout
fig.update_layout(
barmode="stack",
autosize=True,
paper_bgcolor="rgba(0,0,0,0)",
plot_bgcolor="rgba(0,0,0,0)",
margin=dict(l=40, r=28, t=20, b=40),
hovermode="x unified",
legend=dict(orientation="h", yanchor="bottom", y=1.02, xanchor="left", x=0),
xaxis=dict(title=dict(text="Sequence Length"), fixedrange=True),
yaxis=dict(title=dict(text="Memory (GB)"), fixedrange=True),
)
# Updatemenus: Model Size (toggle visibility)
buttons_sizes = []
for i, size in enumerate(model_sizes):
visible = [False] * (len(model_sizes) * len(components))
start = i * len(components)
for j in range(len(components)):
visible[start + j] = True
buttons_sizes.append(dict(
label=size,
method="update",
args=[
{"visible": visible},
{"yaxis": {"range": [0, layout_y_ranges["none"][size]]}},
],
))
# Updatemenus: Recomputation (restyle y across all traces)
def y_for_mode(mode: str):
ys = []
for size in model_sizes:
stacks = Y[mode][size]
for comp_name, _ in components:
ys.append(stacks[comp_name])
return ys
buttons_recomp = []
for mode, label in [("none", "None"), ("selective", "selective"), ("full", "full")]:
ys = y_for_mode(mode)
# Flatten into the format expected by Plotly for multiple traces
buttons_recomp.append(dict(
label=label,
method="update",
args=[
{"y": ys},
{"yaxis": {"range": [0, max(layout_y_ranges[mode].values())]}},
],
))
fig.update_layout(
updatemenus=[
dict(
type="dropdown",
x=1.03, xanchor="left",
y=0.60, yanchor="top",
showactive=True,
active=0,
buttons=buttons_sizes,
),
dict(
type="dropdown",
x=1.03, xanchor="left",
y=0.40, yanchor="top",
showactive=True,
active=0,
buttons=buttons_recomp,
),
],
annotations=[
dict(text="Model Size:", x=1.03, xanchor="left", xref="paper", y=0.60, yanchor="bottom", yref="paper", showarrow=False),
dict(text="Recomputation:", x=1.03, xanchor="left", xref="paper", y=0.40, yanchor="bottom", yref="paper", showarrow=False),
],
)
# Write fragment
fig.write_html("./plotly-bar.html",
include_plotlyjs=False,
full_html=False,
config={
'displayModeBar': False,
'responsive': True,
'scrollZoom': False,
})
|