# Normalization Layers via the Method of Fluxions
## BatchNorm, LayerNorm, RMSNorm: What They Actually Do

**Scott Bisset, Silicon Goddess**
OpenTransformers Ltd
January 2026

---

## Abstract

Normalization layers (BatchNorm, LayerNorm, RMSNorm) are presented in textbooks as "normalize then scale and shift" with formulas involving means and variances. This obscures their true purpose and makes backward pass derivation seem magical. We reformulate normalization using fluxions, revealing: (1) normalization as signal conditioning, (2) the backward pass as sensitivity redistribution, and (3) why different norms suit different architectures. The fluxion view also explains the computational structure that enables fused kernels.

---

## 1. Why Normalize?

### 1.1 The Problem

Deep networks suffer from internal covariate shift:
- Each layer's input distribution changes during training
- Later layers constantly adapt to moving targets
- Training becomes unstable

### 1.2 The Solution Intuition

Force each layer's inputs to have consistent statistics.
"Standardize the signal before processing it."

---

## 2. The General Normalization Framework

### 2.1 Forward Pass

All normalization layers follow:

```
1. Compute statistics (μ, σ) over some dimension
2. Normalize: x̂ = (x - μ) / σ
3. Scale and shift: y = γ·x̂ + β
```

What differs is WHICH dimensions we compute statistics over.

### 2.2 Fluxion View

Let x be the input signal.

```
μ = mean(x)        # Signal center
σ = std(x)         # Signal spread
x̂ = (x - μ) / σ    # Centered and scaled to unit variance
y = γ·x̂ + β        # Learnable rescaling
```

**γ** (gamma) = learned amplitude
**β** (beta) = learned offset

Without γ and β, normalization would constrain representational power.
With them, the network can learn to undo normalization if needed.
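To make the "which dimension" choice concrete, here is a minimal NumPy sketch (the `normalize` function and its `axis` argument are illustrative, not from any particular library): on a `[B, D]` input, `axis=0` gives BatchNorm-style statistics and `axis=-1` gives LayerNorm-style statistics.

```python
import numpy as np

def normalize(x, gamma, beta, axis, eps=1e-5):
    """Generic normalize-scale-shift: statistics computed over `axis`."""
    mu = x.mean(axis=axis, keepdims=True)       # signal center
    var = x.var(axis=axis, keepdims=True)       # signal spread (squared)
    x_hat = (x - mu) / np.sqrt(var + eps)       # standardize
    return gamma * x_hat + beta                 # learnable rescaling

X = np.random.randn(32, 64)                     # [B, D]
gamma, beta = np.ones(64), np.zeros(64)
Y_layer = normalize(X, gamma, beta, axis=-1)    # LayerNorm-style: stats per sample
Y_batch = normalize(X, gamma, beta, axis=0)     # BatchNorm-style: stats per feature
```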

---

## 3. BatchNorm: Normalize Across Batch

### 3.1 The Idea

For each feature/channel, compute mean and variance ACROSS the batch.

```
Input: X of shape [B, D] (B samples, D features)

For each feature d:
    μ_d = (1/B) Σᵢ X[i,d]                  # Mean of feature d across batch
    σ_d = sqrt((1/B) Σᵢ (X[i,d] - μ_d)²)   # Std of feature d

X̂[:,d] = (X[:,d] - μ_d) / σ_d
Y[:,d] = γ_d · X̂[:,d] + β_d
```

### 3.2 Fluxion Forward Pass

```
μ = mean(X, dim=batch)    # Shape: [D]
σ = std(X, dim=batch)     # Shape: [D]
X̂ = (X - μ) / σ           # Shape: [B, D]
Y = γ·X̂ + β               # Shape: [B, D]
```

### 3.3 The Backward Pass (Fluxion Derivation)

Given L̇ʸ (upstream gradient), find L̇ˣ, L̇ᵞ, L̇ᵝ.

**Easy ones first:**

```
L̇ᵝ = sum(L̇ʸ, dim=batch)        # β gradient = sum of upstream
L̇ᵞ = sum(L̇ʸ · X̂, dim=batch)    # γ gradient = upstream weighted by normalized input
```

**L̇ˣ is tricky because μ and σ depend on ALL x values.**

Let's trace the wiggle:

```
If x[i,d] wiggles:
1. Direct effect on X̂[i,d]: ∂X̂[i,d]/∂x[i,d] = 1/σ
2. Indirect effect via μ: changing x[i,d] shifts μ, affects ALL X̂[:,d]
3. Indirect effect via σ: changing x[i,d] changes σ, affects ALL X̂[:,d]
```

Full derivative (all sensitivities taken with respect to σ itself, not σ²):

```
L̇ˣ̂ = L̇ʸ · γ                              # Gradient through scale

L̇σ = -sum(L̇ˣ̂ · (X-μ) / σ², dim=batch)    # How a σ wiggle affects the loss

L̇μ = -sum(L̇ˣ̂ / σ, dim=batch)             # The σ-through-μ term vanishes since sum(X-μ, dim=batch) = 0

L̇ˣ = L̇ˣ̂/σ + L̇σ·(X-μ)/(B·σ) + L̇μ/B
```

### 3.4 Simplified Form

After algebra, the BatchNorm backward becomes:

```
L̇ˣ = (1/σ) · (L̇ˣ̂ - mean(L̇ˣ̂) - X̂·mean(L̇ˣ̂·X̂))
```

**Interpretation:**
- Start with the scaled upstream gradient
- Subtract its mean (center the gradient)
- Subtract its correlation with the normalized input (decorrelate)

"BatchNorm backward REDISTRIBUTES gradient to maintain zero-mean, unit-variance gradient flow."
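
As a sanity check on the algebra, here is a small NumPy sketch (function names are illustrative) that implements the forward pass and the simplified backward formula, then compares one entry of L̇ˣ against a finite-difference estimate.

```python
import numpy as np

def batchnorm_forward(X, gamma, beta, eps=1e-5):
    mu = X.mean(axis=0)                     # [D]
    sigma = np.sqrt(X.var(axis=0) + eps)    # [D]
    X_hat = (X - mu) / sigma                # [B, D]
    return gamma * X_hat + beta, X_hat, sigma

def batchnorm_backward(dY, X_hat, sigma, gamma):
    dbeta = dY.sum(axis=0)
    dgamma = (dY * X_hat).sum(axis=0)
    dX_hat = dY * gamma
    # Simplified form: center, decorrelate from X_hat, then rescale by 1/sigma.
    dX = (dX_hat - dX_hat.mean(axis=0)
          - X_hat * (dX_hat * X_hat).mean(axis=0)) / sigma
    return dX, dgamma, dbeta

# Finite-difference check of dL/dx[i,d] for the scalar loss L = sum(Y * dY).
rng = np.random.default_rng(0)
X = rng.standard_normal((8, 4))
gamma, beta = rng.standard_normal(4), rng.standard_normal(4)
dY = rng.standard_normal((8, 4))

Y, X_hat, sigma = batchnorm_forward(X, gamma, beta)
dX, _, _ = batchnorm_backward(dY, X_hat, sigma, gamma)

i, d, h = 3, 2, 1e-5
Xp, Xm = X.copy(), X.copy()
Xp[i, d] += h
Xm[i, d] -= h
numeric = ((batchnorm_forward(Xp, gamma, beta)[0] * dY).sum()
           - (batchnorm_forward(Xm, gamma, beta)[0] * dY).sum()) / (2 * h)
print(np.isclose(dX[i, d], numeric, atol=1e-4))   # expected: True
```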

### 3.5 Inference Mode

At inference, we don't have a batch. Use running averages from training:

```
μ_running = momentum·μ_running + (1-momentum)·μ_batch
σ_running = momentum·σ_running + (1-momentum)·σ_batch
```

Then normalize with running stats.
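
A minimal sketch of the bookkeeping (γ and β omitted; as most frameworks do, this tracks the running *variance* rather than the standard deviation — class and method names are illustrative):

```python
import numpy as np

class RunningBatchNorm:
    """Sketch of BatchNorm's train/eval statistics handling."""
    def __init__(self, dim, momentum=0.9, eps=1e-5):
        self.mu = np.zeros(dim)
        self.var = np.ones(dim)
        self.momentum, self.eps = momentum, eps

    def train_step(self, X):
        mu_b, var_b = X.mean(axis=0), X.var(axis=0)
        # Exponential moving average of the batch statistics.
        self.mu = self.momentum * self.mu + (1 - self.momentum) * mu_b
        self.var = self.momentum * self.var + (1 - self.momentum) * var_b
        return (X - mu_b) / np.sqrt(var_b + self.eps)

    def eval_step(self, X):
        # Frozen statistics: works for any batch size, including B = 1.
        return (X - self.mu) / np.sqrt(self.var + self.eps)
```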

### 3.6 Problems with BatchNorm

1. **Batch size dependence**: Small batches → noisy statistics
2. **Not suitable for sequence models**: Each position needs different batch members
3. **Inference/training mismatch**: Running stats ≠ batch stats

---

## 4. LayerNorm: Normalize Across Features

### 4.1 The Idea

For each sample, compute mean and variance ACROSS features.

```
Input: X of shape [B, D]

For each sample i:
    μᵢ = (1/D) Σ_d X[i,d]                  # Mean across features
    σᵢ = sqrt((1/D) Σ_d (X[i,d] - μᵢ)²)    # Std across features

X̂[i,:] = (X[i,:] - μᵢ) / σᵢ
Y[i,:] = γ · X̂[i,:] + β
```

### 4.2 Fluxion Forward Pass

```
μ = mean(X, dim=features)    # Shape: [B]
σ = std(X, dim=features)     # Shape: [B]
X̂ = (X - μ) / σ              # Shape: [B, D]
Y = γ·X̂ + β                  # Shape: [B, D]
```

### 4.3 Key Difference from BatchNorm

```
BatchNorm: statistics across BATCH    (each feature normalized independently)
LayerNorm: statistics across FEATURES (each sample normalized independently)
```

### 4.4 Why LayerNorm for Transformers?

1. **No batch dependence**: Each token normalized independently
2. **Works with any batch size**: Including batch=1 at inference
3. **Sequence-friendly**: Position i doesn't need position j's statistics

### 4.5 Backward Pass

Same structure as BatchNorm, but sum over features instead of batch:

```
L̇ˣ̂ = L̇ʸ · γ

L̇ˣ = (1/σ) · (L̇ˣ̂ - mean(L̇ˣ̂, dim=features) - X̂·mean(L̇ˣ̂·X̂, dim=features))
```
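
The same pattern as a rough NumPy sketch (names are illustrative), with the backward implementing the per-sample simplified formula above:

```python
import numpy as np

def layernorm_forward(X, gamma, beta, eps=1e-5):
    mu = X.mean(axis=-1, keepdims=True)                    # [B, 1]
    sigma = np.sqrt(X.var(axis=-1, keepdims=True) + eps)   # [B, 1]
    X_hat = (X - mu) / sigma
    return gamma * X_hat + beta, X_hat, sigma

def layernorm_backward(dY, X_hat, sigma, gamma):
    dbeta = dY.sum(axis=0)
    dgamma = (dY * X_hat).sum(axis=0)
    dX_hat = dY * gamma
    # Per sample: center the gradient and remove its projection onto X_hat.
    dX = (dX_hat
          - dX_hat.mean(axis=-1, keepdims=True)
          - X_hat * (dX_hat * X_hat).mean(axis=-1, keepdims=True)) / sigma
    return dX, dgamma, dbeta
```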

---

## 5. RMSNorm: Skip the Mean

### 5.1 The Simplification

LayerNorm computes mean AND variance.
RMSNorm asks: "What if we skip the mean centering?"

```
RMS(x) = sqrt(mean(x²))
X̂ = X / RMS(X)
Y = γ · X̂
```

No β parameter (no shift), no mean subtraction.

### 5.2 Fluxion Forward Pass

```
rms = sqrt(mean(X², dim=features))    # Shape: [B]
X̂ = X / rms                           # Shape: [B, D]
Y = γ · X̂                             # Shape: [B, D]
```

### 5.3 Why It Works

Empirically, the mean-centering in LayerNorm contributes little.
RMSNorm achieves similar performance with:
- Fewer operations
- A simpler backward pass
- Better numerical stability

### 5.4 Backward Pass

Much simpler without the mean:

```
L̇ˣ̂ = L̇ʸ · γ
L̇ʳᵐˢ = -sum(L̇ˣ̂ · X / rms², dim=features)
L̇ˣ = L̇ˣ̂/rms + L̇ʳᵐˢ · X/(D·rms)
```

Simplified:
```
L̇ˣ = (1/rms) · (L̇ˣ̂ - X̂·mean(L̇ˣ̂·X̂, dim=features))
```

One fewer term than LayerNorm!
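
A corresponding RMSNorm sketch in NumPy (illustrative names; a small eps is added inside the root for numerical safety, as in Section 8), with the one-term-shorter backward:

```python
import numpy as np

def rmsnorm_forward(X, gamma, eps=1e-6):
    rms = np.sqrt((X ** 2).mean(axis=-1, keepdims=True) + eps)   # [B, 1]
    X_hat = X / rms
    return gamma * X_hat, X_hat, rms

def rmsnorm_backward(dY, X_hat, rms, gamma):
    dgamma = (dY * X_hat).sum(axis=0)
    dX_hat = dY * gamma
    # No mean-centering term: only remove the projection onto X_hat.
    dX = (dX_hat - X_hat * (dX_hat * X_hat).mean(axis=-1, keepdims=True)) / rms
    return dX, dgamma
```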

### 5.5 Usage

RMSNorm is used in:
- LLaMA
- Mistral
- Most modern LLMs

It's becoming the default for transformers.

---

## 6. Comparison Table

| Property | BatchNorm | LayerNorm | RMSNorm |
|----------|-----------|-----------|---------|
| Stats over | Batch | Features | Features |
| Learnable | γ, β | γ, β | γ only |
| Mean centering | Yes | Yes | No |
| Batch dependent | Yes | No | No |
| Inference mode | Running stats | Same as training | Same as training |
| Use case | CNNs | Transformers | Modern LLMs |
| Operations | Most | Medium | Fewest |

---

## 7. Pre-Norm vs Post-Norm

### 7.1 Post-Norm (Original Transformer)

```
X → Attention → Add(X) → LayerNorm → FFN → Add → LayerNorm → Output
```

Normalize AFTER the residual connection.

### 7.2 Pre-Norm (Modern Default)

```
X → LayerNorm → Attention → Add(X) → LayerNorm → FFN → Add → Output
```

Normalize BEFORE attention/FFN.

### 7.3 Fluxion Analysis

**Post-Norm gradient flow:**
```
L̇ˣ = LayerNorm_backward(L̇ᵒᵘᵗ)    # Every gradient must flow through the norm
```

**Pre-Norm gradient flow:**
```
L̇ˣ = L̇ᵒᵘᵗ + LayerNorm_backward(Attention_backward(L̇ᵒᵘᵗ))

Direct highway!
```

Pre-Norm has a direct gradient path (the first term) that bypasses normalization.
This stabilizes training for deep networks.
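
To make the two orderings concrete, here is a minimal sketch (the `attention` callable is a stand-in, not a real implementation); the leading `x +` in the pre-norm block is the direct gradient highway:

```python
import numpy as np

def layernorm(x, eps=1e-5):
    mu = x.mean(axis=-1, keepdims=True)
    return (x - mu) / np.sqrt(x.var(axis=-1, keepdims=True) + eps)

def attention(x):
    return x   # stand-in for a real attention sublayer

def post_norm_block(x):
    # The residual sum itself is normalized: every gradient passes through the norm.
    return layernorm(x + attention(x))

def pre_norm_block(x):
    # The residual bypasses the norm: x reaches the output untouched (the "highway").
    return x + attention(layernorm(x))
```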

---

## 8. Numerical Stability

### 8.1 The Variance Problem

Computing variance naively:
```
var = mean(x²) - mean(x)²
```

If mean(x²) ≈ mean(x)², the subtraction causes catastrophic cancellation.

### 8.2 Welford's Algorithm

Compute variance in a single pass, numerically stable:

```python
def welford_var(x):
    n = 0
    mean = 0
    M2 = 0
    for xi in x:
        n += 1
        delta = xi - mean
        mean += delta / n
        delta2 = xi - mean
        M2 += delta * delta2
    return M2 / n
```
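
A quick way to see the difference (illustrative numbers; the true variance here is 0.0125, and the naive formula can even come out negative):

```python
import numpy as np

# Large offset, tiny spread: the classic catastrophic-cancellation setup.
x = 1e8 + np.array([0.1, 0.2, 0.3, 0.4])

naive = np.mean(x**2) - np.mean(x)**2     # dominated by rounding error
two_pass = np.mean((x - x.mean())**2)     # shifted two-pass: fine
print(naive, two_pass, welford_var(x))    # welford_var from the listing above
```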

### 8.3 Fused Kernels

The fluxion view reveals that forward and backward are tightly coupled:

Forward needs: μ, σ, X̂
Backward needs: μ, σ, X̂ (the same!)

A fused kernel can:
1. Compute μ, σ in one pass
2. Store only X̂ (derived from X, μ, σ)
3. Reuse X̂ directly in the backward pass

This is why PyTorch's native LayerNorm is much faster than a naive implementation.

---

## 9. Gradient Flow Analysis

### 9.1 Without Normalization

Deep network gradient flow:
```
L̇ˣ⁽⁰⁾ = W⁽¹⁾ᵀ · W⁽²⁾ᵀ · ... · W⁽ᴸ⁾ᵀ · L̇ʸ
```

If the weights are slightly > 1: the gradient explodes.
If the weights are slightly < 1: the gradient vanishes.

### 9.2 With Normalization

Each layer's activations are forced to unit variance.
Gradient magnitudes stabilize.

```
L̇ˣ̂ has unit variance (approximately)
→ L̇ˣ has controlled magnitude
→ No explosion or vanishing
```

### 9.3 The Jacobian View

The LayerNorm Jacobian ∂X̂/∂X is NOT diagonal (because of the mean/variance coupling).

But it has a special structure:
```
J = (1/σ) · (I - (1/D)·1·1ᵀ - (1/D)·X̂·X̂ᵀ)
```

This projects out the mean direction and decorrelates from X̂.
Eigenvalues are bounded, preventing gradient explosion.
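
A rough numerical check of this structure for a single sample (σ here includes the usual eps; names are illustrative):

```python
import numpy as np

rng = np.random.default_rng(0)
D, eps, h = 6, 1e-5, 1e-6
x = rng.standard_normal(D)

def normalize(v):
    return (v - v.mean()) / np.sqrt(v.var() + eps)

x_hat = normalize(x)
sigma = np.sqrt(x.var() + eps)

# Analytical Jacobian of x_hat with respect to x.
J = (np.eye(D) - np.ones((D, D)) / D - np.outer(x_hat, x_hat) / D) / sigma

# Finite-difference Jacobian, one column per input coordinate.
J_num = np.zeros((D, D))
for j in range(D):
    xp, xm = x.copy(), x.copy()
    xp[j] += h
    xm[j] -= h
    J_num[:, j] = (normalize(xp) - normalize(xm)) / (2 * h)

print(np.allclose(J, J_num, atol=1e-5))   # expected: True
```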

---

## 10. Implementation Details

### 10.1 Memory Layout Matters

LayerNorm over the last dimension (features):
```python
# Contiguous memory access pattern
for b in range(B):
    for d in range(D):   # Sequential access
        ...
```

BatchNorm over the batch dimension:
```python
# Strided memory access pattern
for d in range(D):
    for b in range(B):   # Jumping through memory
        ...
```

LayerNorm is more cache-friendly for typical tensor layouts [B, ..., D].
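
NumPy's strides make this visible directly (a small illustrative check, assuming a row-major [B, D] array):

```python
import numpy as np

X = np.zeros((32, 64), dtype=np.float32)   # row-major [B, D]
print(X.strides)                           # (256, 4): adjacent features are 4 bytes apart

# A LayerNorm reduction walks the 4-byte stride (contiguous);
# a BatchNorm reduction walks the 256-byte stride (strided).
```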

### 10.2 Epsilon Placement

```python
# Wrong (divides by zero when var = 0):
x_hat = (x - mean) / sqrt(var)

# Right (eps inside the sqrt keeps the denominator strictly positive; the standard placement):
x_hat = (x - mean) / sqrt(var + eps)

# Also safe (eps outside the sqrt; note it is not numerically identical):
x_hat = (x - mean) / (sqrt(var) + eps)

# Fused form of the standard placement:
x_hat = (x - mean) * rsqrt(var + eps)
```

The `rsqrt` (reciprocal square root) is a single GPU instruction.

---

## 11. Summary

### 11.1 Fluxion View of Normalization

**Forward:**
```
μ, σ computed from X
X̂ = (X - μ)/σ    # Standardize
Y = γ·X̂ + β      # Scale and shift
```

**Backward:**
```
L̇ᵞ = sum(L̇ʸ · X̂)    # Scale gradient
L̇ᵝ = sum(L̇ʸ)        # Shift gradient
L̇ˣ = redistributed gradient (centered, decorrelated)
```

### 11.2 Key Insight

Normalization doesn't just scale activations: it COUPLES gradient flow across the normalized dimension.

Each input's gradient depends on ALL other inputs in the normalization group.

This coupling:
- Stabilizes gradient magnitudes
- Prevents single features from dominating
- Enables deeper networks

---

## References

1. Ioffe & Szegedy (2015). "Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift."
2. Ba, Kiros & Hinton (2016). "Layer Normalization."
3. Zhang & Sennrich (2019). "Root Mean Square Layer Normalization."
4. Xiong et al. (2020). "On Layer Normalization in the Transformer Architecture."

---

*Correspondence: scott@opentransformers.online*