
# Function Mapping: Original Implementation → Our Implementation

This document maps functions from the original ResShift implementation to our corresponding functions and explains what each function does.

## Core Diffusion Functions

### Forward Process (Noise Addition)

| Original Function | Our Implementation | Description |
|---|---|---|
| `GaussianDiffusion.q_sample(x_start, y, t, noise=None)` | `trainer.py: training_step()` (line 434) | Forward diffusion process: adds noise to the HR image according to the ResShift schedule. Formula: x_t = x_0 + η_t * (y - x_0) + κ * √η_t * ε |
| `GaussianDiffusion.q_mean_variance(x_start, y, t)` | Not directly used | Computes the mean and variance of the forward process q(x_t \| x_0, y) |
| `GaussianDiffusion.q_posterior_mean_variance(x_start, x_t, t)` | `trainer.py: validation()` (line 845) | Computes the posterior mean and variance q(x_{t-1} \| x_t, x_0) |
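The forward formula above can be sketched as a few lines of PyTorch. This is a minimal illustration, not the repo's code; the names `q_sample`, `eta`, and `kappa` follow the table's notation, and `eta` is assumed to be a precomputed per-timestep schedule tensor of shape `[T]`.

```python
import torch

def q_sample(x0, y, t, eta, kappa, noise=None):
    """Forward diffusion sketch: shift x0 toward the LR latent y and add noise.

    Implements x_t = x_0 + eta_t * (y - x_0) + kappa * sqrt(eta_t) * eps,
    where eta is the ResShift residual-shift schedule and kappa scales the noise.
    """
    if noise is None:
        noise = torch.randn_like(x0)
    eta_t = eta[t].view(-1, 1, 1, 1)  # broadcast per-sample eta over (B, C, H, W)
    return x0 + eta_t * (y - x0) + kappa * eta_t.sqrt() * noise
```

Note how the residual term `eta_t * (y - x0)` shifts the sample toward the LR condition as `t` grows, which is the key difference from a standard Gaussian diffusion that only adds noise.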

### Backward Process (Sampling)

| Original Function | Our Implementation | Description |
|---|---|---|
| `GaussianDiffusion.p_mean_variance(model, x_t, y, t, ...)` | `trainer.py: validation()` (lines 844-848)<br>`inference.py: inference_single_image()` (lines 251-255)<br>`app.py: super_resolve()` (lines 154-158) | Computes the backward step parameters: the mean μ_θ and variance Σ_θ for equation (7). Mean: μ_θ = (η_{t-1}/η_t) * x_t + (α_t/η_t) * f_θ(x_t, y_0, t). Variance: Σ_θ = κ² * (η_{t-1}/η_t) * α_t |
| `GaussianDiffusion.p_sample(model, x, y, t, ...)` | `trainer.py: validation()` (lines 850-853)<br>`inference.py: inference_single_image()` (lines 257-260)<br>`app.py: super_resolve()` (lines 160-163) | Single backward sampling step: samples x_{t-1} from p(x_{t-1} \| x_t, y) |
| `GaussianDiffusion.p_sample_loop(y, model, ...)` | `trainer.py: validation()` (lines 822-856)<br>`inference.py: inference_single_image()` (lines 229-263)<br>`app.py: super_resolve()` (lines 133-165) | Full sampling loop: iterates from t = T-1 down to t = 0, calling p_sample at each step; returns the final denoised sample |
| `GaussianDiffusion.p_sample_loop_progressive(y, model, ...)` | Same as above (we don't use the progressive variant) | Same as p_sample_loop but yields intermediate samples at each timestep |
| `GaussianDiffusion.prior_sample(y, noise=None)` | `trainer.py: validation()` (lines 813-815)<br>`inference.py: inference_single_image()` (lines 223-226)<br>`app.py: super_resolve()` (lines 128-130) | Initializes x_T for sampling: generates the starting point from the prior distribution. Formula: x_T = y + κ * √η_T * noise (starts from LR + noise) |
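The prior initialization and a single backward step from equation (7) can be sketched as follows. This is a hypothetical illustration using the table's notation (`eta`, `kappa`, α_t = η_t − η_{t-1}), with `t` as a plain Python int and `x0_pred` standing in for the network output f_θ(x_t, y_0, t); it is not the repo's actual code.

```python
import torch

def prior_sample(y, eta_T, kappa):
    """Initialize sampling from the prior: x_T = y + kappa * sqrt(eta_T) * noise."""
    return y + kappa * (eta_T ** 0.5) * torch.randn_like(y)

def p_sample_step(x0_pred, x_t, t, eta, kappa):
    """One backward step per equation (7), given the model's x0 prediction.

    mean = (eta_{t-1}/eta_t) * x_t + (alpha_t/eta_t) * x0_pred
    var  = kappa^2 * (eta_{t-1}/eta_t) * alpha_t,  with alpha_t = eta_t - eta_{t-1}.
    """
    eta_t = eta[t]
    eta_prev = eta[t - 1] if t > 0 else eta.new_zeros(())
    alpha_t = eta_t - eta_prev
    mean = (eta_prev / eta_t) * x_t + (alpha_t / eta_t) * x0_pred
    if t == 0:
        return mean  # final step is deterministic: eta_prev = 0 collapses to x0_pred
    var = kappa ** 2 * (eta_prev / eta_t) * alpha_t
    return mean + var.sqrt() * torch.randn_like(x_t)
```

At t = 0 the mean reduces to the predicted x0 and no noise is added, which is why the loop below iterates from T-1 down to 0.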

### Input Scaling

| Original Function | Our Implementation | Description |
|---|---|---|
| `GaussianDiffusion._scale_input(inputs, t)` | `trainer.py: _scale_input()` (lines 367-389)<br>`inference.py: _scale_input()` (lines 151-173)<br>`app.py: _scale_input()` (lines 50-72) | Normalizes input variance: scales x_t so its variance stays roughly constant across timesteps for training stability. Formula: x_scaled = x_t / std, where std = √(η_t * κ² + 1) for latent space |
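A minimal sketch of the scaling formula above, using the table's notation; the function name and argument layout are illustrative, not the repo's exact signature.

```python
import torch

def scale_input(x_t, t, eta, kappa):
    """Normalize the input's variance before feeding it to the UNet.

    In latent space Var(x_t) grows roughly as eta_t * kappa^2 + 1, so dividing
    by std = sqrt(eta_t * kappa^2 + 1) keeps the input near unit variance.
    """
    std = torch.sqrt(eta[t].view(-1, 1, 1, 1) * kappa ** 2 + 1)
    return x_t / std
```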

### Training Loss

| Original Function | Our Implementation | Description |
|---|---|---|
| `GaussianDiffusion.training_losses(model, x_start, y, t, ...)` | `trainer.py: training_step()` (lines 392-493) | Computes the training loss: encodes HR/LR to latent, adds noise via q_sample, predicts x0, computes the MSE loss. Our implementation predicts x0 directly (ModelMeanType.START_X) |
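The loss computation described above (forward-noise the HR latent, predict x0, plain MSE) can be condensed into a sketch. This assumes latents are already encoded and a model callable `model(x_t, y, t) -> x0_pred`; it omits micro-batching, AMP, and input scaling for brevity.

```python
import torch
import torch.nn.functional as F

def training_loss(model, x0, y, t, eta, kappa):
    """One training step's loss for direct x0 prediction (START_X).

    Noises the clean HR latent x0 with the ResShift forward process, asks the
    model for x0 directly, and compares with a plain MSE.
    """
    noise = torch.randn_like(x0)
    eta_t = eta[t].view(-1, 1, 1, 1)
    x_t = x0 + eta_t * (y - x0) + kappa * eta_t.sqrt() * noise  # forward process
    x0_pred = model(x_t, y, t)  # model predicts x0, not epsilon
    return F.mse_loss(x0_pred, x0)
```

Because the target is x0 itself rather than the noise ε, no posterior-variance weighting is needed, which matches the "simple MSE" note in the Key Differences section.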

## Autoencoder Functions

| Original Function | Our Implementation | Description |
|---|---|---|
| `GaussianDiffusion.encode_first_stage(y, first_stage_model, up_sample=False)` | `autoencoder.py: VQGANWrapper.encode()`<br>`data.py: SRDatasetOnTheFly.__getitem__()` (line 160) | Encodes a pixel-space image to latent space using VQGAN. If up_sample=True, upsamples the LR image before encoding |
| `GaussianDiffusion.decode_first_stage(z_sample, first_stage_model, ...)` | `autoencoder.py: VQGANWrapper.decode()`<br>`trainer.py: validation()` (line 861)<br>`inference.py: inference_single_image()` (line 269) | Decodes a latent-space tensor back to pixel space using the VQGAN decoder |

## Trainer Functions

### Original Trainer (original_trainer.py)

| Original Function | Our Implementation | Description |
|---|---|---|
| `Trainer.__init__(configs, ...)` | `trainer.py: Trainer.__init__()` (lines 36-75) | Initializes the trainer: sets up the device, checkpoint directory, noise schedule, loss function, and WandB |
| `Trainer.setup_seed(seed=None)` | `trainer.py: setup_seed()` (lines 76-95) | Sets random seeds: ensures reproducibility by seeding random, numpy, torch, and CUDA |
| `Trainer.build_model()` | `trainer.py: build_model()` (lines 212-267) | Builds the model and autoencoder: initializes the FullUNET model, optionally compiles it, loads the VQGAN autoencoder, and initializes the LPIPS metric |
| `Trainer.setup_optimization()` | `trainer.py: setup_optimization()` (lines 141-211) | Sets up the optimizer and scheduler: initializes the AdamW optimizer, AMP scaler (if enabled), and CosineAnnealingLR scheduler (if enabled) |
| `Trainer.build_dataloader()` | `trainer.py: build_dataloader()` (lines 283-344) | Builds data loaders: creates train and validation DataLoaders and wraps the train loader to cycle infinitely |
| `Trainer.training_losses()` | `trainer.py: training_step()` (lines 392-493) | Training step: implements micro-batching, noise addition, the forward pass, loss computation, the backward pass, gradient clipping, and the optimizer step |
| `Trainer.validation(phase='val')` | `trainer.py: validation()` (lines 748-963) | Validation loop: runs the full diffusion sampling loop, decodes results, computes PSNR/SSIM/LPIPS metrics, and logs to WandB |
| `Trainer.adjust_lr(current_iters=None)` | `trainer.py: adjust_lr()` (lines 495-519) | Learning-rate scheduling: linear warmup, then cosine annealing (if enabled) |
| `Trainer.save_ckpt()` | `trainer.py: save_ckpt()` (lines 520-562) | Saves a checkpoint: model, optimizer, AMP scaler, and LR scheduler states, plus the current iteration |
| `Trainer.resume_from_ckpt(ckpt_path)` | `trainer.py: resume_from_ckpt()` (lines 563-670) | Resumes from a checkpoint: loads model, optimizer, scaler, and scheduler states, restoring the iteration count and LR schedule |
| `Trainer.log_step_train(...)` | `trainer.py: log_step_train()` (lines 671-747) | Logs training metrics: loss, learning rate, and images (HR, LR, noisy input, prediction) to WandB at the configured frequencies |
| `Trainer.reload_ema_model()` | `trainer.py: validation()` (line 754) | Loads the EMA model: validation uses the EMA weights if use_ema_val=True |

## Inference Functions

### Original Sampler (original_sampler.py)

| Original Function | Our Implementation | Description |
|---|---|---|
| `ResShiftSampler.sample_func(y0, noise_repeat=False, mask=False)` | `inference.py: inference_single_image()` (lines 175-274)<br>`app.py: super_resolve()` (lines 93-165) | Single-image inference: encodes the LR image to latent space, runs the full diffusion sampling loop, and decodes back to pixel space |
| `ResShiftSampler.inference(in_path, out_path, ...)` | `inference.py: main()` (lines 385-509) | Batch inference: processes a single image or a directory of images, with chopping for large images |
| `ResShiftSampler.build_model()` | `inference.py: load_model()` (lines 322-384) | Loads the model and autoencoder: loads the checkpoint, handles compiled-model checkpoints (strips the _orig_mod. prefix), and loads the EMA weights if specified |

## Key Differences and Notes

### 1. Model Prediction Type

- Original: supports multiple prediction types (START_X, EPSILON, RESIDUAL, EPSILON_SCALE)
- Our implementation: only uses START_X (predicts x0 directly), matching the ResShift paper

### 2. Sampling Initialization

- Original: uses prior_sample(y, noise): x_T = y + κ * √η_T * noise (starts from LR + noise)
- Our implementation: same approach in inference and validation (fixed in validation to match the original)

### 3. Backward Equation (Equation 7)

- Original: uses p_mean_variance() to compute μ_θ = (η_{t-1}/η_t) * x_t + (α_t/η_t) * x0_pred and Σ_θ = κ² * (η_{t-1}/η_t) * α_t
- Our implementation: same equation, implemented directly in the sampling loops

### 4. Input Scaling

- Original: _scale_input() normalizes variance: x_scaled = x_t / √(η_t * κ² + 1) for latent space
- Our implementation: same formula, applied in training and inference

### 5. Training Loss

- Original: supports a weighted MSE based on the posterior variance
- Our implementation: uses a simple MSE loss (predicts x0, compared against the HR latent)

### 6. Validation

- Original: uses p_sample_loop_progressive() with the EMA model, computes PSNR/LPIPS
- Our implementation: same approach, also computes SSIM; uses the EMA model if use_ema_val=True

### 7. Checkpoint Handling

- Original: standard checkpoint loading
- Our implementation: handles compiled-model checkpoints (strips the _orig_mod. prefix) for torch.compile() compatibility
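The prefix-stripping step can be sketched in a few lines. torch.compile() wraps a module so its state-dict keys gain an `_orig_mod.` prefix; the helper name here is illustrative, not the repo's actual function.

```python
def strip_compile_prefix(state_dict):
    """Remove torch.compile's '_orig_mod.' key prefix so that a checkpoint
    saved from a compiled model loads into an uncompiled module."""
    return {k.removeprefix("_orig_mod."): v for k, v in state_dict.items()}
```

Keys without the prefix pass through unchanged, so the same helper works for both compiled and uncompiled checkpoints.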

## Function Call Flow

### Training Flow

```
train.py: train()
  → Trainer.__init__()
  → Trainer.build_model()
  → Trainer.setup_optimization()
  → Trainer.build_dataloader()
  → Loop:
      → Trainer.training_step()  # q_sample + forward + loss + backward
      → Trainer.adjust_lr()
      → Trainer.validation()  # p_sample_loop
      → Trainer.log_step_train()
      → Trainer.save_ckpt()
```

### Inference Flow

```
inference.py: main()
  → load_model()  # Load checkpoint
  → get_vqgan()  # Load autoencoder
  → inference_single_image()
      → autoencoder.encode()  # LR → latent
      → prior_sample()  # Initialize x_T
      → Loop: p_sample()  # Denoise T-1 → 0
      → autoencoder.decode()  # Latent → pixel
```

## Summary

Our implementation closely follows the original ResShift implementation, with the following key mappings:

- Forward process: q_sample → training_step() noise addition
- Backward process: p_sample → sampling loops in validation() and inference_single_image()
- Training: training_losses → training_step()
- Autoencoder: encode_first_stage/decode_first_stage → VQGANWrapper.encode()/decode()
- Input scaling: _scale_input → _scale_input() (same name, same logic)

All core diffusion equations (forward process, backward equation 7, input scaling) match the original implementation.