# Function Mapping: Original Implementation → Our Implementation
This document maps functions from the original ResShift implementation to our corresponding functions and explains what each function does.
## Core Diffusion Functions
### Forward Process (Noise Addition)
| Original Function | Our Implementation | Description |
|------------------|-------------------|-------------|
| `GaussianDiffusion.q_sample(x_start, y, t, noise=None)` | `trainer.py: training_step()` (line 434) | **Forward diffusion process**: Shifts the HR latent `x_0` toward the LR latent `y` and adds noise according to the ResShift schedule. Formula: `x_t = x_0 + η_t * (y - x_0) + κ * √η_t * ε` |
| `GaussianDiffusion.q_mean_variance(x_start, y, t)` | Not directly used | Computes the mean and variance of the forward process `q(x_t \| x_0, y)` |
| `GaussianDiffusion.q_posterior_mean_variance(x_start, x_t, t)` | `trainer.py: validation()` (line 845) | Computes the posterior mean and variance of `q(x_{t-1} \| x_t, x_0)` for backward sampling; used in equation (7) |
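The three forward-process rows above can be sketched with plain Python floats; the same expressions apply elementwise to latent tensors. This is a minimal sketch, not the code in `trainer.py`: the parameters `eta_t`, `eta_prev` (η_{t-1}) and `kappa` are hypothetical stand-ins for values looked up from the noise schedule.

```python
import math
import random


def q_sample(x0: float, y: float, eta_t: float, kappa: float, noise: float) -> float:
    """Forward step: x_t = x_0 + eta_t * (y - x_0) + kappa * sqrt(eta_t) * noise."""
    return x0 + eta_t * (y - x0) + kappa * math.sqrt(eta_t) * noise


def q_mean_variance(x0: float, y: float, eta_t: float, kappa: float) -> tuple[float, float]:
    """Mean and variance of q(x_t | x_0, y)."""
    return x0 + eta_t * (y - x0), kappa**2 * eta_t


def q_posterior_mean_variance(x0: float, x_t: float, eta_t: float, eta_prev: float,
                              kappa: float) -> tuple[float, float]:
    """Mean and variance of q(x_{t-1} | x_t, x_0), with alpha_t = eta_t - eta_{t-1}."""
    alpha_t = eta_t - eta_prev
    mean = (eta_prev / eta_t) * x_t + (alpha_t / eta_t) * x0
    variance = kappa**2 * (eta_prev / eta_t) * alpha_t
    return mean, variance


# Usage: with zero noise, x_t is the deterministic shift from x_0 toward y.
x_t = q_sample(x0=0.0, y=1.0, eta_t=0.5, kappa=2.0, noise=0.0)
print(x_t)  # 0.5
x_t_noisy = q_sample(0.0, 1.0, 0.5, 2.0, noise=random.gauss(0.0, 1.0))  # noise ~ N(0, 1)
```

As a consistency check, averaging the posterior mean over `x_t ~ q(x_t | x_0, y)` gives `x_0 + η_{t-1} * (y - x_0)`, which is exactly the forward mean at step t-1, so the two distributions agree.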
### Backward Process (Sampling)
| Original Function | Our Implementation | Description |
|------------------|-------------------|-------------|
| `GaussianDiffusion.p_mean_variance(model, x_t, y, t, ...)` | `trainer.py: validation()` (lines 844-848)<br>`inference.py: inference_single_image()` (lines 251-255)<br>`app.py: super_resolve()` (lines 154-158) | **Computes backward step parameters**: Calculates the mean `μ_θ` and variance `Σ_θ` of equation (7), where `α_t = η_t - η_{t-1}`.<br>Mean: `μ_θ = (η_{t-1}/η_t) * x_t + (α_t/η_t) * f_θ(x_t, y, t)`<br>Variance: `Σ_θ = κ² * (η_{t-1}/η_t) * α_t` |
| `GaussianDiffusion.p_sample(model, x, y, t, ...)` | `trainer.py: validation()` (lines 850-853)<br>`inference.py: inference_single_image()` (lines 257-260)<br>`app.py: super_resolve()` (lines 160-163) | **Single backward sampling step**: Samples `x_{t-1}` from `p(x_{t-1} \| x_t, y)` using equation (7). Formula: `x_{t-1} = μ_θ + √Σ_θ * ε` (with nonzero_mask for t > 0) |
| `GaussianDiffusion.p_sample_loop(y, model, ...)` | `trainer.py: validation()` (lines 822-856)<br>`inference.py: inference_single_image()` (lines 229-263)<br>`app.py: super_resolve()` (lines 133-165) | **Full sampling loop**: Iterates from t = T-1 down to t = 0, calling `p_sample` at each step. Returns the final denoised sample |
| `GaussianDiffusion.p_sample_loop_progressive(y, model, ...)` | Same as above (we don't use progressive) | Same as `p_sample_loop` but yields intermediate samples at each timestep |
| `GaussianDiffusion.prior_sample(y, noise=None)` | `trainer.py: validation()` (lines 813-815)<br>`inference.py: inference_single_image()` (lines 223-226)<br>`app.py: super_resolve()` (lines 128-130) | **Initializes x_T for sampling**: Generates the starting point from the prior distribution. Formula: `x_T = y + κ * √η_T * noise` (starts from LR + noise) |
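The sampling rows above fit together into one loop. The sketch below uses scalars for readability and is an illustration under stated assumptions, not the loop in `validation()` or `inference_single_image()`: `model` is a hypothetical callable returning the x0 prediction `f_θ(x_t, y, t)` (START_X), `etas` is a hypothetical list holding η_1..η_T at 0-based indices 0..T-1, η before the first step is taken as 0, and input scaling is omitted.

```python
import math
import random
from typing import Callable, Sequence

# Hypothetical stand-in for f_theta: (x_t, y, t) -> x0 prediction (START_X).
Model = Callable[[float, float, int], float]


def prior_sample(y: float, etas: Sequence[float], kappa: float, noise: float) -> float:
    """x_T = y + kappa * sqrt(eta_T) * noise: start from the LR latent plus noise."""
    return y + kappa * math.sqrt(etas[-1]) * noise


def p_mean_variance(x_t: float, x0_pred: float, etas: Sequence[float],
                    t: int, kappa: float) -> tuple[float, float]:
    """Mean and variance of p(x_{t-1} | x_t, y) from equation (7), 0-based step t."""
    eta_t = etas[t]
    eta_prev = etas[t - 1] if t > 0 else 0.0  # eta before the first step taken as 0
    alpha_t = eta_t - eta_prev
    mean = (eta_prev / eta_t) * x_t + (alpha_t / eta_t) * x0_pred
    variance = kappa**2 * (eta_prev / eta_t) * alpha_t
    return mean, variance


def p_sample_loop(y: float, model: Model, etas: Sequence[float],
                  kappa: float, rng: random.Random) -> float:
    """Run t = T-1, ..., 0 and return the final sample (an x_0 estimate)."""
    x = prior_sample(y, etas, kappa, rng.gauss(0.0, 1.0))
    for t in reversed(range(len(etas))):
        x0_pred = model(x, y, t)  # input scaling omitted in this sketch
        mean, variance = p_mean_variance(x, x0_pred, etas, t, kappa)
        nonzero_mask = 1.0 if t > 0 else 0.0  # no noise is added on the last step
        x = mean + nonzero_mask * math.sqrt(variance) * rng.gauss(0.0, 1.0)
    return x


# Usage with a dummy model that always predicts 0.3 and a toy 4-step schedule.
etas = [0.04, 0.2, 0.5, 0.99]
result = p_sample_loop(1.0, lambda x, y, t: 0.3, etas, kappa=2.0, rng=random.Random(0))
```

Because η before the first step is 0, the last step has zero variance and returns the model's x0 prediction unchanged, so the dummy model above yields exactly 0.3.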
### Input Scaling
| Original Function | Our Implementation | Description |
|------------------|-------------------|-------------|
| `GaussianDiffusion._scale_input(inputs, t)` | `trainer.py: _scale_input()` (lines 367-389)<br>`inference.py: _scale_input()` (lines 151-173)<br>`app.py: _scale_input()` (lines 50-72) | **Normalizes input variance**: Scales input `x_t` to normalize variance across timesteps for training stability. Formula: `x_scaled = x_t / std` where `std = √(η_t * κ² + 1)` for latent space |
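The latent-space scaling formula can be sketched as a standalone function. `scale_input` is a hypothetical name used only here; the heuristic behind the choice of `std` is that the latent itself has variance near 1 and the forward process adds noise of variance `κ² * η_t`.

```python
import math


def scale_input(x_t: float, eta_t: float, kappa: float) -> float:
    """Latent-space input normalization: x_t / sqrt(eta_t * kappa**2 + 1)."""
    std = math.sqrt(eta_t * kappa**2 + 1.0)
    return x_t / std


# Usage: eta_t = 2.0 and kappa = 2.0 give std = sqrt(9) = 3.
print(scale_input(x_t=3.0, eta_t=2.0, kappa=2.0))  # 1.0
```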
### Training Loss
| Original Function | Our Implementation | Description |
|------------------|-------------------|-------------|
| `GaussianDiffusion.training_losses(model, x_start, y, t, ...)` | `trainer.py: training_step()` (lines 392-493) | **Computes training loss**: Encodes HR/LR to latent, adds noise via `q_sample`, predicts x0, computes MSE loss. Our implementation: predicts x0 directly (ModelMeanType.START_X) |
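The training row above can be sketched as a single loss evaluation, with plain Python lists standing in for latents. This is a minimal sketch under assumptions: the `model` callable, its argument order, and the `etas` list are hypothetical, and micro-batching, AMP, gradient clipping, and the optimizer step are left out.

```python
import math
import random
from typing import Callable, Sequence

# Hypothetical model signature: (scaled x_t, y, t) -> predicted x_0, elementwise lists.
Model = Callable[[list[float], Sequence[float], int], list[float]]


def training_loss(x0: Sequence[float], y: Sequence[float], model: Model,
                  etas: Sequence[float], kappa: float, rng: random.Random) -> float:
    """One START_X loss evaluation: sample t, build x_t, predict x_0, return the MSE."""
    t = rng.randrange(len(etas))  # timestep drawn uniformly from 0..T-1
    eta_t = etas[t]
    # q_sample: x_t = x_0 + eta_t * (y - x_0) + kappa * sqrt(eta_t) * noise
    x_t = [a + eta_t * (b - a) + kappa * math.sqrt(eta_t) * rng.gauss(0.0, 1.0)
           for a, b in zip(x0, y)]
    # _scale_input (latent space): divide by sqrt(eta_t * kappa**2 + 1)
    std = math.sqrt(eta_t * kappa**2 + 1.0)
    x0_pred = model([v / std for v in x_t], y, t)
    # Plain (unweighted) MSE against the HR latent
    return sum((p - a) ** 2 for p, a in zip(x0_pred, x0)) / len(x0)


# Usage: an "oracle" model that returns the true x_0 gives zero loss.
hr = [0.1, -0.2, 0.3]
lr = [0.4, 0.0, -0.1]
loss = training_loss(hr, lr, lambda x, y, t: list(hr), [0.1, 0.5, 0.9], 2.0, random.Random(0))
```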
### Autoencoder Functions
| Original Function | Our Implementation | Description |
|------------------|-------------------|-------------|
| `GaussianDiffusion.encode_first_stage(y, first_stage_model, up_sample=False)` | `autoencoder.py: VQGANWrapper.encode()`<br>`data.py: SRDatasetOnTheFly.__getitem__()` (line 160) | **Encodes image to latent**: Encodes pixel-space image to latent space using VQGAN. If `up_sample=True`, upsamples LR before encoding |
| `GaussianDiffusion.decode_first_stage(z_sample, first_stage_model, ...)` | `autoencoder.py: VQGANWrapper.decode()`<br>`trainer.py: validation()` (line 861)<br>`inference.py: inference_single_image()` (line 269) | **Decodes latent to image**: Decodes latent-space tensor back to pixel space using VQGAN decoder |
## Trainer Functions
### Original Trainer (`original_trainer.py`)
| Original Function | Our Implementation | Description |
|------------------|-------------------|-------------|
| `Trainer.__init__(configs, ...)` | `trainer.py: Trainer.__init__()` (lines 36-75) | **Initializes trainer**: Sets up device, checkpoint directory, noise schedule, loss function, WandB |
| `Trainer.setup_seed(seed=None)` | `trainer.py: setup_seed()` (lines 76-95) | **Sets random seeds**: Ensures reproducibility by setting seeds for random, numpy, torch, and CUDA |
| `Trainer.build_model()` | `trainer.py: build_model()` (lines 212-267) | **Builds model and autoencoder**: Initializes FullUNET model, optionally compiles it, loads VQGAN autoencoder, initializes LPIPS metric |
| `Trainer.setup_optimization()` | `trainer.py: setup_optimization()` (lines 141-211) | **Sets up optimizer and scheduler**: Initializes AdamW optimizer, AMP scaler (if enabled), CosineAnnealingLR scheduler (if enabled) |
| `Trainer.build_dataloader()` | `trainer.py: build_dataloader()` (lines 283-344) | **Builds data loaders**: Creates train and validation DataLoaders, wraps train loader to cycle infinitely |
| `Trainer.training_losses()` | `trainer.py: training_step()` (lines 392-493) | **Training step**: Implements micro-batching, adds noise, forward pass, loss computation, backward step, gradient clipping, optimizer step |
| `Trainer.validation(phase='val')` | `trainer.py: validation()` (lines 748-963) | **Validation loop**: Runs full diffusion sampling loop, decodes results, computes PSNR/SSIM/LPIPS metrics, logs to WandB |
| `Trainer.adjust_lr(current_iters=None)` | `trainer.py: adjust_lr()` (lines 495-519) | **Learning rate scheduling**: Implements linear warmup, then cosine annealing (if enabled) |
| `Trainer.save_ckpt()` | `trainer.py: save_ckpt()` (lines 520-562) | **Saves checkpoint**: Saves model, optimizer, AMP scaler, LR scheduler states, current iteration |
| `Trainer.resume_from_ckpt(ckpt_path)` | `trainer.py: resume_from_ckpt()` (lines 563-670) | **Resumes from checkpoint**: Loads model, optimizer, scaler, scheduler states, restores iteration count and LR schedule |
| `Trainer.log_step_train(...)` | `trainer.py: log_step_train()` (lines 671-747) | **Logs training metrics**: Logs loss, learning rate, images (HR, LR, noisy input, prediction) to WandB at specified frequencies |
| `Trainer.reload_ema_model()` | `trainer.py: validation()` (line 754) | **Loads EMA model**: Uses EMA model for validation if `use_ema_val=True` |
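The `adjust_lr()` behavior (linear warmup, then cosine annealing) can be sketched as a pure function of the iteration count. The name `lr_at` and its parameters are hypothetical; the cosine part uses the same closed form documented for `torch.optim.lr_scheduler.CosineAnnealingLR`, `η_min + (η_max - η_min) * (1 + cos(π * T_cur / T_max)) / 2`.

```python
import math


def lr_at(it: int, base_lr: float, warmup_iters: int, total_iters: int,
          min_lr: float = 0.0) -> float:
    """Learning rate at iteration `it`: linear warmup, then cosine annealing to min_lr."""
    if warmup_iters > 0 and it < warmup_iters:
        return base_lr * (it + 1) / warmup_iters  # reaches base_lr at it = warmup_iters - 1
    span = max(1, total_iters - warmup_iters)
    progress = min(1.0, (it - warmup_iters) / span)  # 0 at the end of warmup, 1 at total_iters
    return min_lr + 0.5 * (base_lr - min_lr) * (1.0 + math.cos(math.pi * progress))


# Usage: a few points of a 1000-iteration schedule with 100 warmup iterations.
for it in (0, 99, 100, 550, 1000):
    print(it, lr_at(it, base_lr=1e-4, warmup_iters=100, total_iters=1000))
```

Writing the schedule as a function of the iteration count also makes resuming straightforward: restoring the iteration counter is enough to recover the learning rate.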
## Inference Functions
### Original Sampler (`original_sampler.py`)
| Original Function | Our Implementation | Description |
|------------------|-------------------|-------------|
| `ResShiftSampler.sample_func(y0, noise_repeat=False, mask=False)` | `inference.py: inference_single_image()` (lines 175-274)<br>`app.py: super_resolve()` (lines 93-165) | **Single image inference**: Encodes LR to latent, runs full diffusion sampling loop, decodes to pixel space |
| `ResShiftSampler.inference(in_path, out_path, ...)` | `inference.py: main()` (lines 385-509) | **Batch inference**: Processes a single image or a directory of images, and chops large images into patches for processing |
| `ResShiftSampler.build_model()` | `inference.py: load_model()` (lines 322-384) | **Loads model and autoencoder**: Loads checkpoint, handles compiled model checkpoints (strips `_orig_mod.` prefix), loads EMA if specified |
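The chopping mentioned in the `inference()` row can be illustrated by how overlapping patch positions might be computed. This is a simplified, hypothetical sketch: `tile_starts`, `tile_boxes`, and the patch and stride values are not taken from either codebase, the actual splitter may differ, and merging the overlapping patch outputs back into one image is not shown.

```python
def tile_starts(length: int, patch: int, stride: int) -> list[int]:
    """Start offsets of overlapping patches along one axis (assumes 0 < stride <= patch).

    Patches step by `stride`; if the last regular patch stops short of the border,
    one extra patch is aligned to end exactly at `length`.
    """
    if length <= patch:
        return [0]  # the whole axis fits in one patch
    starts = list(range(0, length - patch + 1, stride))
    if starts[-1] + patch < length:
        starts.append(length - patch)
    return starts


def tile_boxes(height: int, width: int, patch: int, stride: int) -> list[tuple[int, int, int, int]]:
    """(top, left, bottom, right) boxes covering an image; bottom/right are exclusive."""
    return [
        (top, left, min(top + patch, height), min(left + patch, width))
        for top in tile_starts(height, patch, stride)
        for left in tile_starts(width, patch, stride)
    ]


# Usage: a 1000x700 image, 512-pixel patches, stride 448 (64-pixel overlap).
boxes = tile_boxes(1000, 700, patch=512, stride=448)
```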
## Key Differences and Notes
### 1. **Model Prediction Type**
- **Original**: Supports multiple prediction types (START_X, EPSILON, RESIDUAL, EPSILON_SCALE)
- **Our Implementation**: Only uses START_X (predicts x0 directly), matching ResShift paper
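Predicting x0 directly also keeps sampling simple, because equation (7) needs an x0 estimate and the other parameterizations must first be converted into one. The sketch below inverts the forward formula `x_t = x_0 + η_t * (y - x_0) + κ * √η_t * ε`; the function names are hypothetical, the residual is assumed to be defined as `y - x_0`, and EPSILON_SCALE is not covered.

```python
import math


def x0_from_eps(x_t: float, y: float, eps: float, eta_t: float, kappa: float) -> float:
    """EPSILON: invert x_t = (1 - eta_t) * x_0 + eta_t * y + kappa * sqrt(eta_t) * eps.

    Only defined for eta_t < 1; the division amplifies errors in eps as eta_t -> 1.
    """
    return (x_t - eta_t * y - kappa * math.sqrt(eta_t) * eps) / (1.0 - eta_t)


def x0_from_residual(y: float, residual: float) -> float:
    """RESIDUAL: the network predicts y - x_0, so x_0 = y - residual."""
    return y - residual


# START_X needs no conversion: the network output is already the x_0 estimate
# that enters the posterior mean of equation (7).
```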
### 2. **Sampling Initialization**
- **Original**: Uses `prior_sample(y, noise)` → `x_T = y + κ * √η_T * noise` (starts from LR + noise)
- **Our Implementation**: Same approach in both inference and validation (validation was corrected to match the original)
### 3. **Backward Equation (Equation 7)**
- **Original**: Uses `p_mean_variance()` → computes `μ_θ = (η_{t-1}/η_t) * x_t + (α_t/η_t) * x0_pred` and `Σ_θ = κ² * (η_{t-1}/η_t) * α_t`
- **Our Implementation**: Same equation, implemented directly in sampling loops
### 4. **Input Scaling**
- **Original**: `_scale_input()` normalizes variance: `x_scaled = x_t / √(η_t * κ² + 1)` for latent space
- **Our Implementation**: Same formula, applied in training and inference
### 5. **Training Loss**
- **Original**: Supports weighted MSE based on posterior variance
- **Our Implementation**: Uses simple MSE loss (predicts x0, compares with HR latent)
### 6. **Validation**
- **Original**: Uses `p_sample_loop_progressive()` with EMA model, computes PSNR/LPIPS
- **Our Implementation**: Same approach, also computes SSIM, uses EMA if `use_ema_val=True`
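PSNR, the simplest of the three validation metrics, is sketched below over flat lists of pixel values assumed to lie in `[0, max_val]`. The function name and the flat-list input are assumptions for illustration; SSIM (windowed statistics) and LPIPS (a pretrained network) are not shown.

```python
import math
from typing import Sequence


def psnr(pred: Sequence[float], target: Sequence[float], max_val: float = 1.0) -> float:
    """PSNR in dB: 10 * log10(max_val**2 / MSE); identical inputs give infinity."""
    mse = sum((p - t) ** 2 for p, t in zip(pred, target)) / len(target)
    if mse == 0:
        return math.inf
    return 10.0 * math.log10(max_val**2 / mse)


# Usage: a uniform error of 0.1 on a [0, 1] image gives MSE 0.01, i.e. about 20 dB.
print(psnr([0.0, 0.0], [0.1, 0.1]))
```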
### 7. **Checkpoint Handling**
- **Original**: Standard checkpoint loading
- **Our Implementation**: Handles compiled model checkpoints (strips `_orig_mod.` prefix) for `torch.compile()` compatibility
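`torch.compile()` wraps a module so that its state-dict keys carry a `_orig_mod.` prefix. A minimal sketch of stripping it, using a toy dict in place of a real checkpoint; `strip_compile_prefix` is a hypothetical helper, and where the state dict lives inside a saved checkpoint is not specified here. In real use its result would be passed to `model.load_state_dict(...)`.

```python
from typing import Any


def strip_compile_prefix(state_dict: dict[str, Any], prefix: str = "_orig_mod.") -> dict[str, Any]:
    """Return a copy of `state_dict` with the torch.compile() wrapper prefix removed.

    Keys without the prefix are kept unchanged, so plain checkpoints pass through as-is.
    """
    return {key.removeprefix(prefix): value for key, value in state_dict.items()}


# Usage with a toy dict standing in for a real checkpoint.
ckpt = {"_orig_mod.conv.weight": 1, "_orig_mod.conv.bias": 2, "time_embed.0.weight": 3}
clean = strip_compile_prefix(ckpt)
```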
## Function Call Flow
### Training Flow
```
train.py: train()
→ Trainer.__init__()
→ Trainer.build_model()
→ Trainer.setup_optimization()
→ Trainer.build_dataloader()
→ Loop:
→ Trainer.training_step() # q_sample + forward + loss + backward
→ Trainer.adjust_lr()
→ Trainer.validation() # p_sample_loop
→ Trainer.log_step_train()
→ Trainer.save_ckpt()
```
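The `Loop:` above draws batches from a train loader that `build_dataloader()` wraps to cycle infinitely. One common way to do that is a small generator like the one below; it is an assumption about the approach, not a copy of `trainer.py`.

```python
from itertools import islice
from typing import Iterable, Iterator, TypeVar

T = TypeVar("T")


def cycle(loader: Iterable[T]) -> Iterator[T]:
    """Yield batches forever, starting a fresh pass over `loader` each time it is exhausted.

    `loader` must be non-empty, otherwise this loops forever without yielding.
    """
    while True:
        for batch in loader:
            yield batch


# Usage: a plain list stands in for a DataLoader.
first_five = list(islice(cycle([1, 2]), 5))
```

`itertools.cycle` would also loop forever, but it saves a copy of every element from the first pass and replays those, so a shuffling DataLoader would repeat the first epoch's order and keep all of its batches in memory. Re-iterating the loader on each pass avoids both.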
### Inference Flow
```
inference.py: main()
→ load_model() # Load checkpoint
→ get_vqgan() # Load autoencoder
→ inference_single_image()
→ autoencoder.encode() # LR → latent
→ prior_sample() # Initialize x_T
→ Loop: p_sample() # Denoise T-1 → 0
→ autoencoder.decode() # Latent → pixel
```
## Summary
Our implementation closely follows the original ResShift implementation, with the following key mappings:
- **Forward process**: `q_sample` → `training_step()` noise addition
- **Backward process**: `p_sample` → sampling loops in `validation()` and `inference_single_image()`
- **Training**: `training_losses` → `training_step()`
- **Autoencoder**: `encode_first_stage`/`decode_first_stage` → `VQGANWrapper.encode()`/`decode()`
- **Input scaling**: `_scale_input` → `_scale_input()` (same name, same logic)
All core diffusion equations (forward process, backward equation 7, input scaling) match the original implementation.