# Function Mapping: Original Implementation → Our Implementation
This document maps functions from the original ResShift implementation to our corresponding functions and explains what each function does.
## Core Diffusion Functions

### Forward Process (Noise Addition)

| Original Function | Our Implementation | Description |
|---|---|---|
| `GaussianDiffusion.q_sample(x_start, y, t, noise=None)` | `trainer.py`: `training_step()` (line 434) | Forward diffusion process: adds noise to the HR image according to the ResShift schedule. Formula: `x_t = x_0 + η_t * (y - x_0) + κ * √η_t * ε` |
| `GaussianDiffusion.q_mean_variance(x_start, y, t)` | Not directly used | Computes the mean and variance of the forward process `q(x_t \| x_0, y)` |
| `GaussianDiffusion.q_posterior_mean_variance(x_start, x_t, t)` | `trainer.py`: `validation()` (line 845) | Computes the posterior mean and variance `q(x_{t-1} \| x_t, x_0)` |
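To make the forward formula concrete, here is a minimal NumPy sketch. This is illustrative only, not the repository's code: the actual `training_step()` operates on latent torch tensors with a per-timestep `η` schedule.

```python
import numpy as np

def q_sample(x0, y, eta_t, kappa, noise=None):
    """Forward diffusion step: x_t = x_0 + eta_t*(y - x_0) + kappa*sqrt(eta_t)*eps.

    x0: HR latent, y: LR latent, eta_t: shift-schedule value in [0, 1],
    kappa: noise-scale hyperparameter.
    """
    if noise is None:
        noise = np.random.randn(*x0.shape)
    return x0 + eta_t * (y - x0) + kappa * np.sqrt(eta_t) * noise
```

At `eta_t = 0` this returns `x0` unchanged; as `eta_t → 1` the sample shifts toward the LR latent `y` plus noise, which is the residual-shifting idea behind ResShift.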
### Backward Process (Sampling)

| Original Function | Our Implementation | Description |
|---|---|---|
| `GaussianDiffusion.p_mean_variance(model, x_t, y, t, ...)` | `trainer.py`: `validation()` (lines 844-848)<br>`inference.py`: `inference_single_image()` (lines 251-255)<br>`app.py`: `super_resolve()` (lines 154-158) | Computes backward step parameters: calculates the mean μ_θ and variance Σ_θ for equation (7). Mean: `μ_θ = (η_{t-1}/η_t) * x_t + (α_t/η_t) * f_θ(x_t, y_0, t)`<br>Variance: `Σ_θ = κ² * (η_{t-1}/η_t) * α_t` |
| `GaussianDiffusion.p_sample(model, x, y, t, ...)` | `trainer.py`: `validation()` (lines 850-853)<br>`inference.py`: `inference_single_image()` (lines 257-260)<br>`app.py`: `super_resolve()` (lines 160-163) | Single backward sampling step: samples x_{t-1} from `p(x_{t-1} \| x_t, y)` using the mean and variance above |
| `GaussianDiffusion.p_sample_loop(y, model, ...)` | `trainer.py`: `validation()` (lines 822-856)<br>`inference.py`: `inference_single_image()` (lines 229-263)<br>`app.py`: `super_resolve()` (lines 133-165) | Full sampling loop: iterates from t = T-1 down to t = 0, calling `p_sample` at each step. Returns the final denoised sample |
| `GaussianDiffusion.p_sample_loop_progressive(y, model, ...)` | Same as above (we don't use progressive) | Same as `p_sample_loop` but yields intermediate samples at each timestep |
| `GaussianDiffusion.prior_sample(y, noise=None)` | `trainer.py`: `validation()` (lines 813-815)<br>`inference.py`: `inference_single_image()` (lines 223-226)<br>`app.py`: `super_resolve()` (lines 128-130) | Initializes x_T for sampling: generates the starting point from the prior distribution. Formula: `x_T = y + κ * √η_T * noise` (starts from LR + noise) |
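The sampling functions above compose into one loop. The NumPy sketch below assumes `α_t = η_t − η_{t-1}` (as in the ResShift schedule); `predict_x0` stands in for the trained network `f_θ`, and the scaling by `√(η_t·κ² + 1)` mirrors `_scale_input`:

```python
import numpy as np

def sample_loop(y, etas, kappa, predict_x0, rng=None):
    """Backward sampling sketch: prior_sample followed by one p_sample per step.

    y: LR latent, etas: increasing schedule [eta_1, ..., eta_T],
    predict_x0: stand-in for the network (takes scaled x_t and timestep index t).
    """
    if rng is None:
        rng = np.random.default_rng(0)
    # prior_sample: x_T = y + kappa * sqrt(eta_T) * noise
    x = y + kappa * np.sqrt(etas[-1]) * rng.standard_normal(y.shape)
    for t in range(len(etas) - 1, -1, -1):
        eta_t = etas[t]
        eta_prev = etas[t - 1] if t > 0 else 0.0
        alpha_t = eta_t - eta_prev  # assumed: alpha_t = eta_t - eta_{t-1}
        x0_pred = predict_x0(x / np.sqrt(eta_t * kappa**2 + 1.0), t)
        # Equation (7): mean and variance of p(x_{t-1} | x_t, y)
        mean = (eta_prev / eta_t) * x + (alpha_t / eta_t) * x0_pred
        var = kappa**2 * (eta_prev / eta_t) * alpha_t
        x = mean + np.sqrt(var) * rng.standard_normal(y.shape)
    return x
```

Note that at the final step (`t = 0`, so `η_{t-1} = 0`) the variance vanishes and the loop returns the network's x0 prediction exactly.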
### Input Scaling

| Original Function | Our Implementation | Description |
|---|---|---|
| `GaussianDiffusion._scale_input(inputs, t)` | `trainer.py`: `_scale_input()` (lines 367-389)<br>`inference.py`: `_scale_input()` (lines 151-173)<br>`app.py`: `_scale_input()` (lines 50-72) | Normalizes input variance: scales the input x_t to normalize variance across timesteps for training stability. Formula: `x_scaled = x_t / std` where `std = √(η_t * κ² + 1)` for latent space |
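The scaling formula follows from a variance argument: if the clean latent has roughly unit variance, the noise term `κ·√η_t·ε` adds variance `η_t·κ²`, so dividing by `√(η_t·κ² + 1)` keeps the network's input near unit variance at every timestep. A hypothetical sketch:

```python
import numpy as np

def scale_input(x_t, eta_t, kappa):
    """Divide x_t by its approximate std so the network sees unit-variance inputs.

    Assumes the clean latent has roughly unit variance; the forward-process
    noise term kappa*sqrt(eta_t)*eps then contributes variance eta_t*kappa**2.
    """
    std = np.sqrt(eta_t * kappa**2 + 1.0)
    return x_t / std
```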
### Training Loss

| Original Function | Our Implementation | Description |
|---|---|---|
| `GaussianDiffusion.training_losses(model, x_start, y, t, ...)` | `trainer.py`: `training_step()` (lines 392-493) | Computes the training loss: encodes HR/LR to latent space, adds noise via `q_sample`, predicts x0, computes the MSE loss. Our implementation predicts x0 directly (`ModelMeanType.START_X`) |
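Putting the pieces together, the loss computation can be sketched as follows. This is a simplified stand-in for `training_step()` (no micro-batching, AMP, or encoding); `predict_x0` represents the UNet:

```python
import numpy as np

def training_loss(x0, y, eta_t, kappa, predict_x0, rng=None):
    """One training step's loss: noise via q_sample, predict x0, simple MSE."""
    if rng is None:
        rng = np.random.default_rng(0)
    noise = rng.standard_normal(x0.shape)
    x_t = x0 + eta_t * (y - x0) + kappa * np.sqrt(eta_t) * noise  # q_sample
    x_scaled = x_t / np.sqrt(eta_t * kappa**2 + 1.0)              # _scale_input
    x0_pred = predict_x0(x_scaled, eta_t)                         # START_X prediction
    return np.mean((x0_pred - x0) ** 2)                           # plain MSE
```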
## Autoencoder Functions

| Original Function | Our Implementation | Description |
|---|---|---|
| `GaussianDiffusion.encode_first_stage(y, first_stage_model, up_sample=False)` | `autoencoder.py`: `VQGANWrapper.encode()`<br>`data.py`: `SRDatasetOnTheFly.__getitem__()` (line 160) | Encodes a pixel-space image to latent space using VQGAN. If `up_sample=True`, upsamples the LR image before encoding |
| `GaussianDiffusion.decode_first_stage(z_sample, first_stage_model, ...)` | `autoencoder.py`: `VQGANWrapper.decode()`<br>`trainer.py`: `validation()` (line 861)<br>`inference.py`: `inference_single_image()` (line 269) | Decodes a latent-space tensor back to pixel space using the VQGAN decoder |
## Trainer Functions

### Original Trainer (original_trainer.py)

| Original Function | Our Implementation | Description |
|---|---|---|
| `Trainer.__init__(configs, ...)` | `trainer.py`: `Trainer.__init__()` (lines 36-75) | Initializes the trainer: sets up the device, checkpoint directory, noise schedule, loss function, and WandB |
| `Trainer.setup_seed(seed=None)` | `trainer.py`: `setup_seed()` (lines 76-95) | Sets random seeds: ensures reproducibility by seeding random, numpy, torch, and CUDA |
| `Trainer.build_model()` | `trainer.py`: `build_model()` (lines 212-267) | Builds the model and autoencoder: initializes the FullUNET model, optionally compiles it, loads the VQGAN autoencoder, initializes the LPIPS metric |
| `Trainer.setup_optimization()` | `trainer.py`: `setup_optimization()` (lines 141-211) | Sets up optimization: initializes the AdamW optimizer, AMP scaler (if enabled), and CosineAnnealingLR scheduler (if enabled) |
| `Trainer.build_dataloader()` | `trainer.py`: `build_dataloader()` (lines 283-344) | Builds data loaders: creates train and validation DataLoaders, wraps the train loader to cycle infinitely |
| `Trainer.training_losses()` | `trainer.py`: `training_step()` (lines 392-493) | Training step: implements micro-batching, noise addition, forward pass, loss computation, backward step, gradient clipping, and the optimizer step |
| `Trainer.validation(phase='val')` | `trainer.py`: `validation()` (lines 748-963) | Validation loop: runs the full diffusion sampling loop, decodes the results, computes PSNR/SSIM/LPIPS metrics, logs to WandB |
| `Trainer.adjust_lr(current_iters=None)` | `trainer.py`: `adjust_lr()` (lines 495-519) | Learning rate scheduling: implements linear warmup, then cosine annealing (if enabled) |
| `Trainer.save_ckpt()` | `trainer.py`: `save_ckpt()` (lines 520-562) | Saves a checkpoint: model, optimizer, AMP scaler, and LR scheduler states, plus the current iteration |
| `Trainer.resume_from_ckpt(ckpt_path)` | `trainer.py`: `resume_from_ckpt()` (lines 563-670) | Resumes from a checkpoint: loads model, optimizer, scaler, and scheduler states; restores the iteration count and LR schedule |
| `Trainer.log_step_train(...)` | `trainer.py`: `log_step_train()` (lines 671-747) | Logs training metrics: loss, learning rate, and images (HR, LR, noisy input, prediction) to WandB at the specified frequencies |
| `Trainer.reload_ema_model()` | `trainer.py`: `validation()` (line 754) | Loads the EMA model: uses EMA weights for validation if `use_ema_val=True` |
## Inference Functions

### Original Sampler (original_sampler.py)

| Original Function | Our Implementation | Description |
|---|---|---|
| `ResShiftSampler.sample_func(y0, noise_repeat=False, mask=False)` | `inference.py`: `inference_single_image()` (lines 175-274)<br>`app.py`: `super_resolve()` (lines 93-165) | Single-image inference: encodes the LR image to latent space, runs the full diffusion sampling loop, decodes back to pixel space |
| `ResShiftSampler.inference(in_path, out_path, ...)` | `inference.py`: `main()` (lines 385-509) | Batch inference: processes a single image or a directory of images, handles chopping for large images |
| `ResShiftSampler.build_model()` | `inference.py`: `load_model()` (lines 322-384) | Loads the model and autoencoder: loads the checkpoint, handles compiled-model checkpoints (strips the `_orig_mod.` prefix), loads EMA weights if specified |
## Key Differences and Notes

### 1. Model Prediction Type
- Original: Supports multiple prediction types (`START_X`, `EPSILON`, `RESIDUAL`, `EPSILON_SCALE`)
- Our Implementation: Only uses `START_X` (predicts x0 directly), matching the ResShift paper

### 2. Sampling Initialization
- Original: Uses `prior_sample(y, noise)` → `x_T = y + κ * √η_T * noise` (starts from LR + noise)
- Our Implementation: Same approach in inference and validation (fixed in validation to match the original)

### 3. Backward Equation (Equation 7)
- Original: Uses `p_mean_variance()` → computes `μ_θ = (η_{t-1}/η_t) * x_t + (α_t/η_t) * x0_pred` and `Σ_θ = κ² * (η_{t-1}/η_t) * α_t`
- Our Implementation: Same equation, implemented directly in the sampling loops

### 4. Input Scaling
- Original: `_scale_input()` normalizes variance: `x_scaled = x_t / √(η_t * κ² + 1)` for latent space
- Our Implementation: Same formula, applied in training and inference

### 5. Training Loss
- Original: Supports weighted MSE based on the posterior variance
- Our Implementation: Uses simple MSE loss (predicts x0, compares with the HR latent)

### 6. Validation
- Original: Uses `p_sample_loop_progressive()` with the EMA model, computes PSNR/LPIPS
- Our Implementation: Same approach, also computes SSIM; uses EMA if `use_ema_val=True`

### 7. Checkpoint Handling
- Original: Standard checkpoint loading
- Our Implementation: Handles compiled-model checkpoints (strips the `_orig_mod.` prefix) for `torch.compile()` compatibility
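The prefix stripping in item 7 amounts to a plain dict rewrite, sketched below. Note that `strip_compile_prefix` is a hypothetical helper name for illustration, not necessarily what `load_model()` calls it:

```python
def strip_compile_prefix(state_dict):
    """torch.compile() wraps the model, so checkpoints saved from a compiled
    model prefix every key with '_orig_mod.'. Stripping the prefix lets the
    checkpoint load into an uncompiled model."""
    prefix = "_orig_mod."
    return {
        (k[len(prefix):] if k.startswith(prefix) else k): v
        for k, v in state_dict.items()
    }
```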
## Function Call Flow

### Training Flow

```
train.py: train()
  → Trainer.__init__()
  → Trainer.build_model()
  → Trainer.setup_optimization()
  → Trainer.build_dataloader()
  → Loop:
      → Trainer.training_step()   # q_sample + forward + loss + backward
      → Trainer.adjust_lr()
      → Trainer.validation()      # p_sample_loop
      → Trainer.log_step_train()
      → Trainer.save_ckpt()
```
### Inference Flow

```
inference.py: main()
  → load_model()                  # Load checkpoint
  → get_vqgan()                   # Load autoencoder
  → inference_single_image()
      → autoencoder.encode()      # LR → latent
      → prior_sample()            # Initialize x_T
      → Loop: p_sample()          # Denoise T-1 → 0
      → autoencoder.decode()      # Latent → pixel
```
## Summary

Our implementation closely follows the original ResShift implementation, with the following key mappings:

- Forward process: `q_sample` → `training_step()` noise addition
- Backward process: `p_sample` → sampling loops in `validation()` and `inference_single_image()`
- Training: `training_losses` → `training_step()`
- Autoencoder: `encode_first_stage`/`decode_first_stage` → `VQGANWrapper.encode()`/`decode()`
- Input scaling: `_scale_input` → `_scale_input()` (same name, same logic)
All core diffusion equations (forward process, backward equation 7, input scaling) match the original implementation.