# Normalization Layers via the Method of Fluxions
## BatchNorm, LayerNorm, RMSNorm: What They Actually Do

**Scott Bisset, Silicon Goddess**
OpenTransformers Ltd
January 2026

---

## Abstract

Normalization layers (BatchNorm, LayerNorm, RMSNorm) are presented in textbooks as "normalize then scale and shift" with formulas involving means and variances. This obscures their true purpose and makes the backward-pass derivation seem magical. We reformulate normalization using fluxions, revealing: (1) normalization as signal conditioning, (2) the backward pass as sensitivity redistribution, and (3) why different norms suit different architectures. The fluxion view also explains the computational structure that enables fused kernels.

---

## 1. Why Normalize?

### 1.1 The Problem

Deep networks suffer from internal covariate shift:
- Each layer's input distribution changes during training
- Later layers constantly adapt to moving targets
- Training becomes unstable

### 1.2 The Solution Intuition

Force each layer's inputs to have consistent statistics.
"Standardize the signal before processing it."

---

## 2. The General Normalization Framework

### 2.1 Forward Pass

All normalization layers follow:

```
1. Compute statistics (μ, σ) over some dimension
2. Normalize: x̂ = (x - μ) / σ
3. Scale and shift: y = γ·x̂ + β
```

What differs is WHICH dimensions we compute statistics over.

### 2.2 Fluxion View

Let x be the input signal.

```
μ = mean(x)        # Signal center
σ = std(x)         # Signal spread
x̂ = (x - μ) / σ    # Centered and scaled to unit variance
y = γ·x̂ + β        # Learnable rescaling
```

**γ** (gamma) = learned amplitude
**β** (beta) = learned offset

Without γ and β, normalization would constrain representational power.
With them, the network can learn to undo the normalization if needed.
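
As a concrete reference, here is a minimal NumPy sketch of this shared skeleton (the function name and `axis` argument are illustrative, not from any library): on a [B, D] input, `axis=0` gives BatchNorm-style statistics and `axis=-1` gives LayerNorm-style statistics.

```python
import numpy as np

def normalize(x, gamma, beta, axis, eps=1e-5):
    """Generic normalize-scale-shift; only `axis` changes between the norms."""
    mu = x.mean(axis=axis, keepdims=True)      # statistics over the chosen dimension
    var = x.var(axis=axis, keepdims=True)      # biased variance, as in the formulas above
    x_hat = (x - mu) / np.sqrt(var + eps)      # standardize
    return gamma * x_hat + beta                # learnable rescaling

x = np.random.randn(4, 8).astype(np.float32)
gamma, beta = np.ones(8, dtype=np.float32), np.zeros(8, dtype=np.float32)
y_batch_style = normalize(x, gamma, beta, axis=0)    # stats across the batch
y_layer_style = normalize(x, gamma, beta, axis=-1)   # stats across the features
```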

---

## 3. BatchNorm: Normalize Across Batch

### 3.1 The Idea

For each feature/channel, compute mean and variance ACROSS the batch.

```
Input: X of shape [B, D]  (B samples, D features)

For each feature d:
  μ_d = (1/B) Σᵢ X[i,d]                  # Mean of feature d across the batch
  σ_d = sqrt((1/B) Σᵢ (X[i,d] - μ_d)²)   # Std of feature d

  X̂[:,d] = (X[:,d] - μ_d) / σ_d
  Y[:,d] = γ_d · X̂[:,d] + β_d
```

### 3.2 Fluxion Forward Pass

```
μ = mean(X, dim=batch)    # Shape: [D]
σ = std(X, dim=batch)     # Shape: [D]
X̂ = (X - μ) / σ           # Shape: [B, D]
Y = γ·X̂ + β               # Shape: [B, D]
```

### 3.3 The Backward Pass (Fluxion Derivation)

Given L̇ʸ (the upstream gradient), find L̇ˣ, L̇ᵞ, L̇ᵝ.

**Easy ones first:**

```
L̇ᵝ = sum(L̇ʸ, dim=batch)       # β gradient = sum of upstream
L̇ᵞ = sum(L̇ʸ · X̂, dim=batch)   # γ gradient = upstream weighted by normalized input
```

**L̇ˣ is tricky because μ and σ depend on ALL x values.**

Let's trace the wiggle:

```
If x[i,d] wiggles:
1. Direct effect on X̂[i,d]: ∂X̂[i,d]/∂x[i,d] = 1/σ
2. Indirect effect via μ: changing x[i,d] shifts μ, which affects ALL of X̂[:,d]
3. Indirect effect via σ: changing x[i,d] changes σ, which affects ALL of X̂[:,d]
```

Full derivative (per feature d, with sums over the batch):

```
L̇ˣ̂ = L̇ʸ · γ                              # Gradient through the scale

L̇σ = -sum(L̇ˣ̂ · (X-μ) / σ², dim=batch)     # How a σ wiggle affects the loss

L̇μ = -sum(L̇ˣ̂ / σ, dim=batch)              # The σ-via-μ term vanishes since sum(X-μ) = 0

L̇ˣ = L̇ˣ̂/σ + L̇σ·(X-μ)/(B·σ) + L̇μ/B        # Direct path + via σ (∂σ/∂x = (x-μ)/(B·σ)) + via μ
```

### 3.4 Simplified Form

After some algebra, the BatchNorm backward pass becomes:

```
L̇ˣ = (1/σ) · (L̇ˣ̂ - mean(L̇ˣ̂) - X̂·mean(L̇ˣ̂·X̂))
```

**Interpretation:**
- Start with the scaled upstream gradient
- Subtract its mean (center the gradient)
- Subtract its correlation with the normalized input (decorrelate)

"BatchNorm's backward pass REDISTRIBUTES the gradient: it removes the components that would shift the batch mean or rescale the batch variance."
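
A minimal NumPy sketch of the full forward/backward pair, using the simplified formula above (the function names and caching convention are illustrative, not from any library):

```python
import numpy as np

def batchnorm_forward(X, gamma, beta, eps=1e-5):
    mu = X.mean(axis=0)                       # [D], statistics across the batch
    sigma = np.sqrt(X.var(axis=0) + eps)      # [D]
    X_hat = (X - mu) / sigma                  # [B, D]
    Y = gamma * X_hat + beta
    return Y, (X_hat, sigma, gamma)           # cache exactly what backward reuses

def batchnorm_backward(dY, cache):
    X_hat, sigma, gamma = cache
    dbeta = dY.sum(axis=0)                    # L̇ᵝ
    dgamma = (dY * X_hat).sum(axis=0)         # L̇ᵞ
    dX_hat = dY * gamma                       # L̇ˣ̂
    # Simplified form: center, then remove the X̂-correlated component
    dX = (dX_hat
          - dX_hat.mean(axis=0)
          - X_hat * (dX_hat * X_hat).mean(axis=0)) / sigma
    return dX, dgamma, dbeta
```

A quick sanity check is to compare `dX` against finite differences of a scalar loss such as `Y.sum()`.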

### 3.5 Inference Mode

At inference we don't want one example's output to depend on whatever else is in the batch (and the batch may be a single example), so we use running averages accumulated during training:

```
μ_running = momentum·μ_running + (1-momentum)·μ_batch
σ_running = momentum·σ_running + (1-momentum)·σ_batch
```

Then normalize with the running stats.
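
A sketch of how the running statistics might be tracked and applied (the class and method names are mine; note that some libraries, e.g. PyTorch's BatchNorm, define their `momentum` argument with the opposite convention):

```python
import numpy as np

class RunningBatchNormStats:
    """Track running mean/std during training; use them at inference."""
    def __init__(self, dim, momentum=0.9, eps=1e-5):
        self.mu = np.zeros(dim)
        self.sigma = np.ones(dim)
        self.momentum = momentum
        self.eps = eps

    def update(self, X):
        # Exponential moving average, matching the update rule above
        self.mu = self.momentum * self.mu + (1 - self.momentum) * X.mean(axis=0)
        self.sigma = self.momentum * self.sigma + (1 - self.momentum) * X.std(axis=0)

    def normalize_eval(self, X):
        # Inference: the output no longer depends on the rest of the batch
        return (X - self.mu) / (self.sigma + self.eps)
```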

### 3.6 Problems with BatchNorm

1. **Batch size dependence**: small batches give noisy statistics
2. **Awkward for sequence models**: each position's statistics depend on the other sequences in the batch
3. **Inference/training mismatch**: running stats ≠ batch stats

---

## 4. LayerNorm: Normalize Across Features

### 4.1 The Idea

For each sample, compute mean and variance ACROSS features.

```
Input: X of shape [B, D]

For each sample i:
  μᵢ = (1/D) Σ_d X[i,d]                  # Mean across features
  σᵢ = sqrt((1/D) Σ_d (X[i,d] - μᵢ)²)    # Std across features

  X̂[i,:] = (X[i,:] - μᵢ) / σᵢ
  Y[i,:] = γ · X̂[i,:] + β
```

### 4.2 Fluxion Forward Pass

```
μ = mean(X, dim=features)   # Shape: [B]
σ = std(X, dim=features)    # Shape: [B]
X̂ = (X - μ) / σ             # Shape: [B, D]
Y = γ·X̂ + β                 # Shape: [B, D]
```

### 4.3 Key Difference from BatchNorm

```
BatchNorm: statistics across the BATCH    (each feature normalized independently)
LayerNorm: statistics across the FEATURES (each sample normalized independently)
```

### 4.4 Why LayerNorm for Transformers?

1. **No batch dependence**: each token is normalized independently
2. **Works with any batch size**: including batch size 1 at inference
3. **Sequence-friendly**: position i doesn't need position j's statistics
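
All three properties follow from taking statistics over the last dimension only. A small NumPy illustration (shapes and names are mine): on a [B, T, D] activation tensor, each token's output is unchanged if every other token and batch element is removed.

```python
import numpy as np

def layernorm(x, gamma, beta, eps=1e-5):
    # Statistics over the last (feature) dimension only
    mu = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return gamma * (x - mu) / np.sqrt(var + eps) + beta

B, T, D = 2, 5, 8
x = np.random.randn(B, T, D)
gamma, beta = np.ones(D), np.zeros(D)

y_full = layernorm(x, gamma, beta)
y_alone = layernorm(x[0:1, 2:3, :], gamma, beta)   # the same token, processed on its own

# Each token's output is independent of every other token and batch element
assert np.allclose(y_full[0, 2], y_alone[0, 0])
```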

### 4.5 Backward Pass

Same structure as BatchNorm, but the sums and means run over features instead of the batch:

```
L̇ˣ̂ = L̇ʸ · γ

L̇ˣ = (1/σ) · (L̇ˣ̂ - mean(L̇ˣ̂, dim=features)
              - X̂·mean(L̇ˣ̂·X̂, dim=features))
```
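
As a sanity check (a small numerical experiment, not part of the original derivation), the closed form can be compared against finite differences for a single row:

```python
import numpy as np

def ln_forward(x, gamma, eps=1e-5):
    mu, sigma = x.mean(), np.sqrt(x.var() + eps)
    x_hat = (x - mu) / sigma
    return gamma * x_hat, x_hat, sigma

D = 6
x = np.random.randn(D)
gamma = np.random.randn(D)
dY = np.random.randn(D)          # arbitrary upstream gradient L̇ʸ

# Closed-form backward from the formula above
_, x_hat, sigma = ln_forward(x, gamma)
dx_hat = dY * gamma
dx_closed = (dx_hat - dx_hat.mean() - x_hat * (dx_hat * x_hat).mean()) / sigma

# Finite-difference gradient of the scalar dY·y with respect to x
h = 1e-5
dx_numeric = np.zeros(D)
for i in range(D):
    xp, xm = x.copy(), x.copy()
    xp[i] += h
    xm[i] -= h
    dx_numeric[i] = (dY @ ln_forward(xp, gamma)[0] - dY @ ln_forward(xm, gamma)[0]) / (2 * h)

assert np.allclose(dx_closed, dx_numeric, atol=1e-4)
```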

---

## 5. RMSNorm: Skip the Mean

### 5.1 The Simplification

LayerNorm computes mean AND variance.
RMSNorm: "What if we skip the mean centering?"

```
RMS(x) = sqrt(mean(x²))
X̂ = X / RMS(X)
Y = γ · X̂
```

No β parameter (no shift), no mean subtraction.

### 5.2 Fluxion Forward Pass

```
rms = sqrt(mean(X², dim=features))   # Shape: [B]
X̂ = X / rms                          # Shape: [B, D]
Y = γ · X̂                            # Shape: [B, D]
```

### 5.3 Why It Works

Empirically, the mean-centering in LayerNorm contributes little.
RMSNorm achieves similar performance with:
- Fewer operations
- Simpler backward pass
- Better numerical stability

### 5.4 Backward Pass

Much simpler without the mean:

```
L̇ˣ̂ = L̇ʸ · γ
L̇ʳᵐˢ = -sum(L̇ˣ̂ · X / rms²)
L̇ˣ = L̇ˣ̂/rms + L̇ʳᵐˢ · X/(D·rms)
```

Simplified:
```
L̇ˣ = (1/rms) · (L̇ˣ̂ - X̂·mean(L̇ˣ̂·X̂))
```

One fewer term than LayerNorm!

### 5.5 Usage

RMSNorm is used in:
- LLaMA
- Mistral
- Most modern LLMs

It's becoming the default for transformers.
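
A minimal PyTorch module in this style (a sketch in the spirit of LLaMA-family RMSNorm code, not a copy of any particular implementation; `eps` is added inside the square root for stability):

```python
import torch
import torch.nn as nn

class RMSNorm(nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))   # γ only; no β

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # 1 / sqrt(mean(x²) + eps) over the feature dimension
        inv_rms = torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        return self.weight * (x * inv_rms)

norm = RMSNorm(dim=512)
y = norm(torch.randn(2, 16, 512))   # works for any leading shape [..., 512]
```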

---

## 6. Comparison Table

| Property | BatchNorm | LayerNorm | RMSNorm |
|----------|-----------|-----------|---------|
| Stats over | Batch | Features | Features |
| Learnable | γ, β | γ, β | γ only |
| Mean centering | Yes | Yes | No |
| Batch dependent | Yes | No | No |
| Inference mode | Running stats | Same as training | Same as training |
| Use case | CNNs | Transformers | Modern LLMs |
| Operations | Most | Medium | Fewest |

---

## 7. Pre-Norm vs Post-Norm

### 7.1 Post-Norm (Original Transformer)

```
X → Attention → Add(X) → LayerNorm → FFN → Add → LayerNorm → Output
```

Normalize AFTER the residual connection.

### 7.2 Pre-Norm (Modern Default)

```
X → LayerNorm → Attention → Add(X) → LayerNorm → FFN → Add → Output
```

Normalize BEFORE attention/FFN.

### 7.3 Fluxion Analysis

**Post-Norm gradient flow:**
```
L̇ˣ = LayerNorm_backward(L̇ᵒᵘᵗ)   # Every gradient must flow through a norm
```

**Pre-Norm gradient flow:**
```
L̇ˣ = L̇ᵒᵘᵗ + LayerNorm_backward(Attention_backward(L̇ᵒᵘᵗ))
      ↑
      Direct highway!
```

Pre-Norm has a direct gradient path that bypasses normalization.
This stabilizes training for deep networks.
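
For concreteness, a minimal pre-norm block sketch in PyTorch (module and layer names are illustrative; real blocks add dropout, masking, and so on):

```python
import torch
import torch.nn as nn

class PreNormBlock(nn.Module):
    """x + Attn(LN(x)), then x + FFN(LN(x)): the residual path skips every norm."""
    def __init__(self, dim: int, n_heads: int):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(dim, n_heads, batch_first=True)
        self.norm2 = nn.LayerNorm(dim)
        self.ffn = nn.Sequential(nn.Linear(dim, 4 * dim), nn.GELU(), nn.Linear(4 * dim, dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        h = self.norm1(x)
        x = x + self.attn(h, h, h, need_weights=False)[0]   # residual add AFTER the sublayer
        x = x + self.ffn(self.norm2(x))
        return x

block = PreNormBlock(dim=256, n_heads=4)
out = block(torch.randn(2, 10, 256))   # [batch, sequence, dim]
```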

---

## 8. Numerical Stability

### 8.1 The Variance Problem

Computing variance naively:
```
var = mean(x²) - mean(x)²
```

If mean(x²) ≈ mean(x)², the subtraction causes catastrophic cancellation.
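
A small float32 demonstration of the failure mode (the numbers are arbitrary; any data whose mean is large relative to its spread triggers it):

```python
import numpy as np

# Data with mean ≈ 10000 and std ≈ 0.1, so the true variance is ≈ 0.01
x = (10_000.0 + 0.1 * np.random.randn(10_000)).astype(np.float32)

naive = np.mean(x * x) - np.mean(x) ** 2      # one-pass formula: catastrophic cancellation
two_pass = np.mean((x - np.mean(x)) ** 2)     # subtract the mean first: stable

print(naive, two_pass)   # naive is garbage (possibly even negative); two_pass ≈ 0.01
```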

### 8.2 Welford's Algorithm

Compute the variance in a single pass, numerically stably:

```python
def welford_var(x):
    """One-pass variance via Welford's online update."""
    n = 0
    mean = 0.0
    M2 = 0.0                     # running sum of squared deviations from the current mean
    for xi in x:
        n += 1
        delta = xi - mean
        mean += delta / n        # update the running mean
        delta2 = xi - mean       # deviation from the *updated* mean
        M2 += delta * delta2
    return M2 / n                # biased (population) variance, matching the formulas above
```

### 8.3 Fused Kernels

The fluxion view reveals that forward and backward are tightly coupled:

Forward computes: μ, σ, X̂
Backward reuses: σ, X̂ (plus the upstream gradient)

A fused kernel can therefore:
1. Compute μ and σ in one pass over the data
2. Save only X̂ and σ for the backward pass (X and μ need not be kept)
3. Reuse X̂ directly in the backward formulas

This is why PyTorch's native LayerNorm is much faster than a naive implementation.
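
A rough way to see the effect (a benchmark sketch; absolute numbers depend entirely on hardware and tensor sizes, the point is only that the fused call does the same math in fewer passes over memory):

```python
import time
import torch
import torch.nn.functional as F

x = torch.randn(4096, 1024)
w, b = torch.ones(1024), torch.zeros(1024)

def manual_layernorm(x, w, b, eps=1e-5):
    # Unfused: several separate reduction/elementwise operations and temporaries
    mu = x.mean(dim=-1, keepdim=True)
    var = x.var(dim=-1, unbiased=False, keepdim=True)
    return w * (x - mu) * torch.rsqrt(var + eps) + b

for name, fn in [("manual", lambda: manual_layernorm(x, w, b)),
                 ("native", lambda: F.layer_norm(x, (1024,), w, b))]:
    t0 = time.perf_counter()
    for _ in range(100):
        fn()
    print(name, time.perf_counter() - t0)
```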

---

## 9. Gradient Flow Analysis

### 9.1 Without Normalization

Deep-network gradient flow is a long product of weight matrices:
```
L̇ˣ⁽⁰⁾ = W⁽¹⁾ᵀ · W⁽²⁾ᵀ · ... · W⁽ᴸ⁾ᵀ · L̇ʸ
```

If each layer scales the gradient by slightly more than 1: it explodes.
If by slightly less than 1: it vanishes.
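
A toy NumPy illustration (depth, width, and the scale factors are arbitrary, chosen only to exaggerate the compounding effect):

```python
import numpy as np

def backprop_norm(scale, depth=50, dim=64):
    rng = np.random.default_rng(0)
    g = rng.standard_normal(dim)        # upstream gradient L̇ʸ
    for _ in range(depth):
        # Random layer whose average gain is roughly `scale`
        W = scale * rng.standard_normal((dim, dim)) / np.sqrt(dim)
        g = W.T @ g                     # one backward step through a linear layer
    return np.linalg.norm(g)

print(backprop_norm(scale=1.1))   # grows roughly like 1.1**50: explodes
print(backprop_norm(scale=0.9))   # shrinks roughly like 0.9**50: vanishes
```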

### 9.2 With Normalization

Each layer's activations are forced to unit variance, and gradient magnitudes stabilize:

```
X̂ has unit variance by construction
→ the 1/σ factor in the backward pass rescales the gradient to match
→ L̇ˣ has controlled magnitude: no explosion or vanishing
```

### 9.3 The Jacobian View

The LayerNorm Jacobian ∂X̂/∂X is NOT diagonal (μ and σ couple every coordinate); γ only rescales its rows.

But it has a special structure:
```
J = (1/σ) · (I - (1/D)·1·1ᵀ - (1/D)·X̂·X̂ᵀ)
```

This projects out the mean direction (1) and the X̂ direction.
Its eigenvalues are bounded, preventing gradient explosion.
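
The two projected-out directions are easy to verify numerically (a small check, not from the original text; ε omitted for clarity):

```python
import numpy as np

D = 8
x = np.random.randn(D)
mu, sigma = x.mean(), x.std()
x_hat = (x - mu) / sigma

# Jacobian of x̂ with respect to x, as given above
J = (np.eye(D) - np.ones((D, D)) / D - np.outer(x_hat, x_hat) / D) / sigma

assert np.allclose(J @ np.ones(D), 0)   # the mean direction is annihilated
assert np.allclose(J @ x_hat, 0)        # so is the x̂ direction (a pure variance rescaling)
```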

---

## 10. Implementation Details

### 10.1 Memory Layout Matters

LayerNorm over the last dimension (features):
```python
# Contiguous memory access pattern
for b in range(B):
    for d in range(D):  # Sequential access
        ...
```

BatchNorm over the batch dimension:
```python
# Strided memory access pattern
for d in range(D):
    for b in range(B):  # Jumping through memory
        ...
```

LayerNorm is more cache-friendly for typical tensor layouts [B, ..., D].

### 10.2 Epsilon Placement

```python
# Standard and safe: eps inside the square root, so the denominator is never zero
# and the backward pass through sqrt is well-behaved even when var = 0
x_hat = (x - mean) / sqrt(var + eps)

# Safe in the forward pass, but sqrt(var) has an unbounded derivative at var = 0,
# so the backward pass can blow up
x_hat = (x - mean) / (sqrt(var) + eps)

# Same as the first form, fused
x_hat = (x - mean) * rsqrt(var + eps)
```

The `rsqrt` (reciprocal square root) is a single GPU instruction.

---

## 11. Summary

### 11.1 Fluxion View of Normalization

**Forward:**
```
μ, σ computed from X
X̂ = (X - μ)/σ    # Standardize
Y = γ·X̂ + β      # Scale and shift
```

**Backward:**
```
L̇ᵞ = sum(L̇ʸ · X̂)   # Scale gradient
L̇ᵝ = sum(L̇ʸ)       # Shift gradient
L̇ˣ = redistributed gradient (centered, decorrelated)
```

### 11.2 Key Insight

Normalization doesn't just scale activations: it COUPLES gradient flow across the normalized dimension.

Each input's gradient depends on ALL other inputs in the normalization group.

This coupling:
- Stabilizes gradient magnitudes
- Prevents single features from dominating
- Enables deeper networks

---

## References

1. Ioffe & Szegedy (2015). "Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift."
2. Ba, Kiros & Hinton (2016). "Layer Normalization."
3. Zhang & Sennrich (2019). "Root Mean Square Layer Normalization."
4. Xiong et al. (2020). "On Layer Normalization in the Transformer Architecture."

---

*Correspondence: scott@opentransformers.online*