# Function Mapping: Original Implementation → Our Implementation

This document maps functions from the original ResShift implementation to our corresponding functions and explains what each function does.

## Core Diffusion Functions

### Forward Process (Noise Addition)

| Original Function | Our Implementation | Description |
|------------------|-------------------|-------------|
| `GaussianDiffusion.q_sample(x_start, y, t, noise=None)` | `trainer.py: training_step()` (line 434) | **Forward diffusion process**: Adds noise to the HR image according to the ResShift schedule. Formula: `x_t = x_0 + η_t * (y - x_0) + κ * √η_t * ε` |
| `GaussianDiffusion.q_mean_variance(x_start, y, t)` | Not directly used | Computes the mean and variance of the forward process `q(x_t \| x_0, y)` |
| `GaussianDiffusion.q_posterior_mean_variance(x_start, x_t, t)` | `trainer.py: validation()` (line 845) | Computes the posterior mean and variance `q(x_{t-1} \| x_t, x_0)` for backward sampling. Used in equation (7) |

### Backward Process (Sampling)

| Original Function | Our Implementation | Description |
|------------------|-------------------|-------------|
| `GaussianDiffusion.p_mean_variance(model, x_t, y, t, ...)` | `trainer.py: validation()` (lines 844-848)<br>`inference.py: inference_single_image()` (lines 251-255)<br>`app.py: super_resolve()` (lines 154-158) | **Computes backward step parameters**: Calculates the mean `μ_θ` and variance `Σ_θ` for equation (7). Mean: `μ_θ = (η_{t-1}/η_t) * x_t + (α_t/η_t) * f_θ(x_t, y_0, t)`; variance: `Σ_θ = κ² * (η_{t-1}/η_t) * α_t`, where `α_t = η_t - η_{t-1}` |
| `GaussianDiffusion.p_sample(model, x, y, t, ...)` | `trainer.py: validation()` (lines 850-853)<br>`inference.py: inference_single_image()` (lines 257-260)<br>`app.py: super_resolve()` (lines 160-163) | **Single backward sampling step**: Samples `x_{t-1}` from `p(x_{t-1} \| x_t, y)` using equation (7). Formula: `x_{t-1} = μ_θ + √Σ_θ * ε` (the `nonzero_mask` ensures no noise is added at t = 0) |
| `GaussianDiffusion.p_sample_loop(y, model, ...)` | `trainer.py: validation()` (lines 822-856)<br>`inference.py: inference_single_image()` (lines 229-263)<br>`app.py: super_resolve()` (lines 133-165) | **Full sampling loop**: Iterates from t = T-1 down to t = 0, calling `p_sample` at each step. Returns the final denoised sample |
| `GaussianDiffusion.p_sample_loop_progressive(y, model, ...)` | Same as above (we do not use the progressive variant) | Same as `p_sample_loop`, but yields the intermediate sample at each timestep |
| `GaussianDiffusion.prior_sample(y, noise=None)` | `trainer.py: validation()` (lines 813-815)<br>`inference.py: inference_single_image()` (lines 223-226)<br>`app.py: super_resolve()` (lines 128-130) | **Initializes `x_T` for sampling**: Generates the starting point from the prior distribution. Formula: `x_T = y + κ * √η_T * ε` (starts from the LR latent plus noise) |

### Input Scaling

| Original Function | Our Implementation | Description |
|------------------|-------------------|-------------|
| `GaussianDiffusion._scale_input(inputs, t)` | `trainer.py: _scale_input()` (lines 367-389)<br>`inference.py: _scale_input()` (lines 151-173)<br>`app.py: _scale_input()` (lines 50-72) | **Normalizes input variance**: Scales the input `x_t` to roughly unit variance at every timestep for training stability. Formula: `x_scaled = x_t / std` where `std = √(η_t * κ² + 1)` in latent space |

### Training Loss

| Original Function | Our Implementation | Description |
|------------------|-------------------|-------------|
| `GaussianDiffusion.training_losses(model, x_start, y, t, ...)` | `trainer.py: training_step()` (lines 392-493) | **Computes the training loss**: Encodes HR/LR to latent space, adds noise via `q_sample`, predicts `x_0`, and computes the MSE loss. Our implementation predicts `x_0` directly (`ModelMeanType.START_X`) |

### Autoencoder Functions

| Original Function | Our Implementation | Description |
|------------------|-------------------|-------------|
| `GaussianDiffusion.encode_first_stage(y, first_stage_model, up_sample=False)` | `autoencoder.py: VQGANWrapper.encode()`<br>`data.py: SRDatasetOnTheFly.__getitem__()` (line 160) | **Encodes an image to latent space** using the VQGAN encoder. If `up_sample=True`, upsamples the LR image before encoding |
| `GaussianDiffusion.decode_first_stage(z_sample, first_stage_model, ...)` | `autoencoder.py: VQGANWrapper.decode()`<br>`trainer.py: validation()` (line 861)<br>`inference.py: inference_single_image()` (line 269) | **Decodes a latent-space tensor back to pixel space** using the VQGAN decoder |

## Trainer Functions

### Original Trainer (`original_trainer.py`)

| Original Function | Our Implementation | Description |
|------------------|-------------------|-------------|
| `Trainer.__init__(configs, ...)` | `trainer.py: Trainer.__init__()` (lines 36-75) | **Initializes the trainer**: Sets up the device, checkpoint directory, noise schedule, loss function, and WandB |
| `Trainer.setup_seed(seed=None)` | `trainer.py: setup_seed()` (lines 76-95) | **Sets random seeds**: Ensures reproducibility by seeding `random`, NumPy, PyTorch, and CUDA |
| `Trainer.build_model()` | `trainer.py: build_model()` (lines 212-267) | **Builds the model and autoencoder**: Initializes the FullUNET model, optionally compiles it, loads the VQGAN autoencoder, and initializes the LPIPS metric |
| `Trainer.setup_optimization()` | `trainer.py: setup_optimization()` (lines 141-211) | **Sets up the optimizer and scheduler**: Initializes the AdamW optimizer, the AMP scaler (if enabled), and the CosineAnnealingLR scheduler (if enabled) |
| `Trainer.build_dataloader()` | `trainer.py: build_dataloader()` (lines 283-344) | **Builds the data loaders**: Creates the train and validation DataLoaders and wraps the train loader to cycle infinitely |
| `Trainer.training_losses()` | `trainer.py: training_step()` (lines 392-493) | **Training step**: Implements micro-batching, adds noise, runs the forward pass, computes the loss, backpropagates, clips gradients, and steps the optimizer |
| `Trainer.validation(phase='val')` | `trainer.py: validation()` (lines 748-963) | **Validation loop**: Runs the full diffusion sampling loop, decodes the results, computes PSNR/SSIM/LPIPS metrics, and logs to WandB |
| `Trainer.adjust_lr(current_iters=None)` | `trainer.py: adjust_lr()` (lines 495-519) | **Learning-rate scheduling**: Implements linear warmup, then cosine annealing (if enabled) |
| `Trainer.save_ckpt()` | `trainer.py: save_ckpt()` (lines 520-562) | **Saves a checkpoint**: Saves the model, optimizer, AMP scaler, and LR scheduler states, plus the current iteration |
| `Trainer.resume_from_ckpt(ckpt_path)` | `trainer.py: resume_from_ckpt()` (lines 563-670) | **Resumes from a checkpoint**: Loads the model, optimizer, scaler, and scheduler states, and restores the iteration count and LR schedule |
| `Trainer.log_step_train(...)` | `trainer.py: log_step_train()` (lines 671-747) | **Logs training metrics**: Logs the loss, learning rate, and images (HR, LR, noisy input, prediction) to WandB at the configured frequencies |
| `Trainer.reload_ema_model()` | `trainer.py: validation()` (line 754) | **Loads the EMA model**: Uses the EMA model for validation if `use_ema_val=True` |

## Inference Functions

### Original Sampler (`original_sampler.py`)

| Original Function | Our Implementation | Description |
|------------------|-------------------|-------------|
| `ResShiftSampler.sample_func(y0, noise_repeat=False, mask=False)` | `inference.py: inference_single_image()` (lines 175-274)<br>`app.py: super_resolve()` (lines 93-165) | **Single-image inference**: Encodes the LR image to latent space, runs the full diffusion sampling loop, and decodes back to pixel space |
| `ResShiftSampler.inference(in_path, out_path, ...)` | `inference.py: main()` (lines 385-509) | **Batch inference**: Processes a single image or a directory of images; handles chopping for large images |
| `ResShiftSampler.build_model()` | `inference.py: load_model()` (lines 322-384) | **Loads the model and autoencoder**: Loads the checkpoint, handles compiled-model checkpoints (strips the `_orig_mod.` prefix), and loads the EMA weights if specified |

## Key Differences and Notes

### 1. **Model Prediction Type**
- **Original**: Supports multiple prediction types (START_X, EPSILON, RESIDUAL, EPSILON_SCALE)
- **Our Implementation**: Only uses START_X (predicts `x_0` directly), matching the ResShift paper

### 2. **Sampling Initialization**
- **Original**: Uses `prior_sample(y, noise)` → `x_T = y + κ * √η_T * noise` (starts from LR + noise)
- **Our Implementation**: Same approach in inference and validation (fixed in validation to match the original)

### 3. **Backward Equation (Equation 7)**
- **Original**: Uses `p_mean_variance()` → computes `μ_θ = (η_{t-1}/η_t) * x_t + (α_t/η_t) * x0_pred` and `Σ_θ = κ² * (η_{t-1}/η_t) * α_t`
- **Our Implementation**: Same equation, implemented directly in the sampling loops

### 4. **Input Scaling**
- **Original**: `_scale_input()` normalizes the variance: `x_scaled = x_t / √(η_t * κ² + 1)` in latent space
- **Our Implementation**: Same formula, applied in training and inference

### 5. **Training Loss**
- **Original**: Supports MSE weighted by the posterior variance
- **Our Implementation**: Uses plain MSE (predicts `x_0`, compares with the HR latent)

### 6. **Validation**
- **Original**: Uses `p_sample_loop_progressive()` with the EMA model, computes PSNR/LPIPS
- **Our Implementation**: Same approach, also computes SSIM; uses the EMA model if `use_ema_val=True`

### 7. **Checkpoint Handling**
- **Original**: Standard checkpoint loading
- **Our Implementation**: Handles compiled-model checkpoints (strips the `_orig_mod.` prefix) for `torch.compile()` compatibility

## Function Call Flow

### Training Flow

```
train.py: train()
  → Trainer.__init__()
  → Trainer.build_model()
  → Trainer.setup_optimization()
  → Trainer.build_dataloader()
  → Loop:
      → Trainer.training_step()   # q_sample + forward + loss + backward
      → Trainer.adjust_lr()
      → Trainer.validation()      # p_sample_loop
      → Trainer.log_step_train()
      → Trainer.save_ckpt()
```

### Inference Flow

```
inference.py: main()
  → load_model()                  # Load checkpoint
  → get_vqgan()                   # Load autoencoder
  → inference_single_image()
      → autoencoder.encode()      # LR → latent
      → prior_sample()            # Initialize x_T
      → Loop: p_sample()          # Denoise T-1 → 0
      → autoencoder.decode()      # Latent → pixel
```

## Summary

Our implementation closely follows the original ResShift implementation, with the following key mappings:

- **Forward process**: `q_sample` → noise addition in `training_step()`
- **Backward process**: `p_sample` → sampling loops in `validation()` and `inference_single_image()`
- **Training**: `training_losses` → `training_step()`
- **Autoencoder**: `encode_first_stage`/`decode_first_stage` → `VQGANWrapper.encode()`/`decode()`
- **Input scaling**: `_scale_input` → `_scale_input()` (same name, same logic)

All core diffusion equations (forward process, backward equation (7), input scaling) match the original implementation.
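To make the forward formula `x_t = x_0 + η_t * (y - x_0) + κ * √η_t * ε` concrete, here is a minimal NumPy sketch of `q_sample`; the shapes and schedule values are illustrative only, and the real implementation operates on PyTorch tensors with a per-sample timestep batch.

```python
import numpy as np

def q_sample(x0, y, eta_t, kappa, noise):
    """ResShift forward step: shift x0 toward the LR latent y, then add noise.

    x_t = x_0 + eta_t * (y - x_0) + kappa * sqrt(eta_t) * eps
    """
    return x0 + eta_t * (y - x0) + kappa * np.sqrt(eta_t) * noise

x0 = np.zeros((4, 4))   # toy HR latent
y = np.ones((4, 4))     # toy LR latent
eps = np.random.randn(4, 4)

x_early = q_sample(x0, y, eta_t=0.01, kappa=2.0, noise=eps)  # still close to x0
x_late = q_sample(x0, y, eta_t=1.0, kappa=2.0, noise=eps)    # mean has moved onto y
```

Note how the schedule interpolates: at `eta_t = 0` the sample is exactly `x0`, and at `eta_T = 1` the mean collapses onto the LR latent `y`, which is why sampling can start from `y` plus noise.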
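The backward process (prior initialization, the equation (7) step, and the full loop) can be sketched in NumPy as follows. The `etas` schedule and the oracle `model` below are toy stand-ins, not the actual configuration.

```python
import numpy as np

def prior_sample(y, eta_T, kappa):
    """x_T = y + kappa * sqrt(eta_T) * eps: start from the LR latent plus noise."""
    return y + kappa * np.sqrt(eta_T) * np.random.randn(*y.shape)

def p_sample(x_t, x0_pred, eta_t, eta_prev, kappa):
    """One backward step of equation (7); alpha_t = eta_t - eta_{t-1}."""
    alpha_t = eta_t - eta_prev
    mean = (eta_prev / eta_t) * x_t + (alpha_t / eta_t) * x0_pred
    if eta_prev == 0.0:                       # final step (t = 0): no noise added
        return mean
    var = kappa**2 * (eta_prev / eta_t) * alpha_t
    return mean + np.sqrt(var) * np.random.randn(*x_t.shape)

def p_sample_loop(y, model, etas, kappa):
    """Iterate t = T-1 .. 0, each step pulling x_t toward the predicted x0."""
    x = prior_sample(y, etas[-1], kappa)
    for t in reversed(range(len(etas))):
        x0_pred = model(x, y, t)              # f_theta predicts x0 directly (START_X)
        eta_prev = etas[t - 1] if t > 0 else 0.0
        x = p_sample(x, x0_pred, etas[t], eta_prev, kappa)
    return x

# With an oracle model that always returns the true x0, the final step lands
# exactly on x0: its mean weights are 0 on x_t and alpha_t/eta_t = 1 on x0_pred.
x0_true = np.full((4, 4), 0.5)
y = np.ones((4, 4))
etas = np.array([0.05, 0.3, 1.0])             # toy monotone schedule with eta_T = 1
out = p_sample_loop(y, lambda x, y, t: x0_true, etas, kappa=2.0)
```

The oracle check is a useful property of the parameterization: because `η_{t-1} = 0` at the last step, the quality of the final sample is bounded only by the accuracy of the last `x0` prediction.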
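The input-scaling rule is a single division by a timestep-dependent standard deviation; a NumPy sketch of the latent-space case described above:

```python
import numpy as np

def scale_input(x_t, eta_t, kappa):
    """Divide x_t by sqrt(eta_t * kappa^2 + 1) so the model sees roughly
    unit-variance inputs at every timestep (latent-space variant)."""
    return x_t / np.sqrt(eta_t * kappa**2 + 1.0)
```

At `eta_t = 0` the divisor is 1 and scaling is the identity; as `eta_t` grows, the divisor grows with the injected-noise variance `eta_t * kappa^2`, undoing the variance inflation of the forward process.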
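A hedged sketch of the loss computed per sample in `training_step()`: noise the HR latent via the forward formula, scale the input, predict `x0`, and take the MSE. Micro-batching, AMP, and the VQGAN encode are omitted, and `model` is an arbitrary callable standing in for the UNet.

```python
import numpy as np

def training_loss(model, x0, y, t, etas, kappa):
    """One ResShift training loss: q_sample + _scale_input + MSE against x0."""
    eps = np.random.randn(*x0.shape)
    x_t = x0 + etas[t] * (y - x0) + kappa * np.sqrt(etas[t]) * eps   # q_sample
    x_scaled = x_t / np.sqrt(etas[t] * kappa**2 + 1.0)               # _scale_input
    x0_pred = model(x_scaled, y, t)                                  # predicts x0 (START_X)
    return np.mean((x0_pred - x0) ** 2)

x0 = np.zeros((4, 4))
y = np.ones((4, 4))
etas = np.array([0.05, 0.3, 1.0])
# An oracle that returns the exact x0 gives zero loss
loss = training_loss(lambda xs, y, t: x0, x0, y, t=1, etas=etas, kappa=2.0)
```

This mirrors the plain-MSE choice noted in the Key Differences section: no posterior-variance weighting, just `x0_pred` against the HR latent.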
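The warmup-then-cosine behavior of `adjust_lr()` can be sketched as below; the parameter names and the exact ramp shape are illustrative assumptions, not the actual config keys or implementation.

```python
import math

def adjust_lr(step, base_lr, warmup_steps, total_steps, min_lr=0.0):
    """Linear warmup to base_lr, then cosine annealing down to min_lr."""
    if step < warmup_steps:
        return base_lr * (step + 1) / warmup_steps        # linear ramp
    frac = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return min_lr + 0.5 * (base_lr - min_lr) * (1.0 + math.cos(math.pi * frac))
```

The schedule peaks at `base_lr` exactly when warmup ends (`frac = 0`, `cos(0) = 1`) and decays smoothly to `min_lr` at `total_steps`.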
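The compiled-checkpoint handling amounts to a key rename: `torch.compile()` wraps the module, so every state-dict key gains an `_orig_mod.` prefix that must be stripped before loading into an uncompiled model. A self-contained sketch (the helper name is ours, not the codebase's):

```python
def strip_compile_prefix(state_dict):
    """Remove the '_orig_mod.' prefix that torch.compile() adds to state-dict
    keys, so a compiled-model checkpoint loads into an uncompiled model."""
    prefix = "_orig_mod."
    return {k[len(prefix):] if k.startswith(prefix) else k: v
            for k, v in state_dict.items()}

cleaned = strip_compile_prefix({"_orig_mod.unet.weight": 1, "step": 2})
# → {"unet.weight": 1, "step": 2}
```

Keys without the prefix (e.g. bookkeeping entries saved alongside the weights) pass through unchanged.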