# Function Mapping: Original Implementation → Our Implementation
This document maps functions from the original ResShift implementation to our corresponding functions and explains what each function does.
## Core Diffusion Functions
### Forward Process (Noise Addition)
| Original Function | Our Implementation | Description |
|------------------|-------------------|-------------|
| `GaussianDiffusion.q_sample(x_start, y, t, noise=None)` | `trainer.py: training_step()` (line 434) | **Forward diffusion process**: Adds noise to HR image according to ResShift schedule. Formula: `x_t = x_0 + η_t * (y - x_0) + κ * √η_t * ε` |
| `GaussianDiffusion.q_mean_variance(x_start, y, t)` | Not directly used | Computes mean and variance of forward process `q(x_t | x_0, y)` |
| `GaussianDiffusion.q_posterior_mean_variance(x_start, x_t, t)` | `trainer.py: validation()` (line 845) | Computes posterior mean and variance `q(x_{t-1} | x_t, x_0)` for backward sampling. Used in equation (7) |
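A minimal sketch of the forward noising step as applied inside our `training_step()`, assuming `eta` is the precomputed η schedule as a 1-D tensor, `kappa` is the κ hyperparameter, and all tensors live in latent space (the function and variable names here are illustrative, not the exact code):

```python
import torch

def q_sample(x0, y, t, eta, kappa, noise=None):
    """Forward ResShift step: x_t = x_0 + eta_t * (y - x_0) + kappa * sqrt(eta_t) * eps.

    x0  : HR latent, shape (B, C, H, W)
    y   : LR latent, same shape as x0
    t   : integer timesteps, shape (B,)
    eta : 1-D tensor of schedule values, shape (T,)
    """
    if noise is None:
        noise = torch.randn_like(x0)
    eta_t = eta[t].view(-1, 1, 1, 1)  # broadcast per-sample eta_t over C, H, W
    return x0 + eta_t * (y - x0) + kappa * eta_t.sqrt() * noise
```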
### Backward Process (Sampling)
| Original Function | Our Implementation | Description |
|------------------|-------------------|-------------|
| `GaussianDiffusion.p_mean_variance(model, x_t, y, t, ...)` | `trainer.py: validation()` (lines 844-848)<br>`inference.py: inference_single_image()` (lines 251-255)<br>`app.py: super_resolve()` (lines 154-158) | **Computes backward step parameters**: Calculates mean `μ_θ` and variance `Σ_θ` for equation (7). Mean: `μ_θ = (η_{t-1}/η_t) * x_t + (α_t/η_t) * f_θ(x_t, y_0, t)`<br>Variance: `Σ_θ = κ² * (η_{t-1}/η_t) * α_t` |
| `GaussianDiffusion.p_sample(model, x, y, t, ...)` | `trainer.py: validation()` (lines 850-853)<br>`inference.py: inference_single_image()` (lines 257-260)<br>`app.py: super_resolve()` (lines 160-163) | **Single backward sampling step**: Samples `x_{t-1}` from `p(x_{t-1} | x_t, y)` using equation (7). Formula: `x_{t-1} = μ_θ + √Σ_θ * ε` (with nonzero_mask for t > 0) |
| `GaussianDiffusion.p_sample_loop(y, model, ...)` | `trainer.py: validation()` (lines 822-856)<br>`inference.py: inference_single_image()` (lines 229-263)<br>`app.py: super_resolve()` (lines 133-165) | **Full sampling loop**: Iterates from t = T-1 down to t = 0, calling `p_sample` at each step. Returns final denoised sample |
| `GaussianDiffusion.p_sample_loop_progressive(y, model, ...)` | Same as above (we don't use progressive) | Same as `p_sample_loop` but yields intermediate samples at each timestep |
| `GaussianDiffusion.prior_sample(y, noise=None)` | `trainer.py: validation()` (lines 813-815)<br>`inference.py: inference_single_image()` (lines 223-226)<br>`app.py: super_resolve()` (lines 128-130) | **Initializes x_T for sampling**: Generates starting point from prior distribution. Formula: `x_T = y + κ * √η_T * noise` (starts from LR + noise) |
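The sketch below strings the rows above together: prior initialization (`prior_sample`), the per-step mean/variance of equation (7), and the T → 0 loop. It uses α_t = η_t − η_{t−1} as in ResShift; the `model(x, y, t)` call signature and helper names are assumptions for illustration, and input scaling via `_scale_input` is omitted for brevity:

```python
import torch

@torch.no_grad()
def sample_loop(model, y, eta, kappa, T):
    """Backward sampling: start from x_T = y + kappa * sqrt(eta_T) * eps, then apply eq. (7)."""
    x = y + kappa * eta[T - 1].sqrt() * torch.randn_like(y)  # prior_sample
    for t in reversed(range(T)):
        t_batch = torch.full((y.shape[0],), t, device=y.device, dtype=torch.long)
        x0_pred = model(x, y, t_batch)                        # f_theta predicts x_0 directly
        eta_t = eta[t]
        eta_prev = eta[t - 1] if t > 0 else eta.new_zeros(())
        alpha_t = eta_t - eta_prev
        mean = (eta_prev / eta_t) * x + (alpha_t / eta_t) * x0_pred
        var = kappa ** 2 * (eta_prev / eta_t) * alpha_t
        if t > 0:                                             # nonzero_mask: no noise at t = 0
            x = mean + var.sqrt() * torch.randn_like(x)
        else:
            x = mean
    return x
```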
### Input Scaling
| Original Function | Our Implementation | Description |
|------------------|-------------------|-------------|
| `GaussianDiffusion._scale_input(inputs, t)` | `trainer.py: _scale_input()` (lines 367-389)<br>`inference.py: _scale_input()` (lines 151-173)<br>`app.py: _scale_input()` (lines 50-72) | **Normalizes input variance**: Scales input `x_t` to normalize variance across timesteps for training stability. Formula: `x_scaled = x_t / std` where `std = √(η_t * κ² + 1)` for latent space |
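A minimal sketch of the latent-space scaling rule from the table above (again with illustrative names; the real `_scale_input()` also handles the pixel-space case):

```python
import torch

def scale_input(x_t, t, eta, kappa):
    """Normalize the variance of x_t before feeding it to the network (latent-space case).

    The std grows with eta_t because more of the (y - x_0) residual and more noise
    have been mixed in; dividing by it keeps the input variance roughly constant.
    """
    eta_t = eta[t].view(-1, 1, 1, 1)
    std = torch.sqrt(eta_t * kappa ** 2 + 1.0)
    return x_t / std
```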
### Training Loss
| Original Function | Our Implementation | Description |
|------------------|-------------------|-------------|
| `GaussianDiffusion.training_losses(model, x_start, y, t, ...)` | `trainer.py: training_step()` (lines 392-493) | **Computes training loss**: Encodes HR/LR to latent, adds noise via `q_sample`, predicts x0, computes MSE loss. Our implementation: predicts x0 directly (ModelMeanType.START_X) |
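A sketch of a single loss computation, reusing the hypothetical `q_sample` and `scale_input` helpers from the sketches above; the real `training_step()` additionally handles micro-batching, AMP, and gradient clipping:

```python
import torch
import torch.nn.functional as F

def training_loss(model, x0, y, eta, kappa):
    """One ResShift-style training loss: noise the HR latent, predict x_0, take the MSE."""
    B = x0.shape[0]
    T = eta.shape[0]
    t = torch.randint(0, T, (B,), device=x0.device)         # uniform timestep sampling
    x_t = q_sample(x0, y, t, eta, kappa)                     # forward process (sketched earlier)
    x0_pred = model(scale_input(x_t, t, eta, kappa), y, t)   # network predicts x_0 directly
    return F.mse_loss(x0_pred, x0)
```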
### Autoencoder Functions
| Original Function | Our Implementation | Description |
|------------------|-------------------|-------------|
| `GaussianDiffusion.encode_first_stage(y, first_stage_model, up_sample=False)` | `autoencoder.py: VQGANWrapper.encode()`<br>`data.py: SRDatasetOnTheFly.__getitem__()` (line 160) | **Encodes image to latent**: Encodes pixel-space image to latent space using VQGAN. If `up_sample=True`, upsamples LR before encoding |
| `GaussianDiffusion.decode_first_stage(z_sample, first_stage_model, ...)` | `autoencoder.py: VQGANWrapper.decode()`<br>`trainer.py: validation()` (line 861)<br>`inference.py: inference_single_image()` (line 269) | **Decodes latent to image**: Decodes latent-space tensor back to pixel space using VQGAN decoder |
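How the VQGAN wrapper is used around the diffusion loop, as a hedged sketch: the `encode`/`decode` method names follow the table above, while the ×4 scale factor and the [0, 1] output range are assumptions:

```python
import torch
import torch.nn.functional as F

@torch.no_grad()
def lr_to_latent(vqgan, lr_image, scale=4):
    """Upsample the LR image to the HR resolution (up_sample=True path), then encode to latent space."""
    lr_up = F.interpolate(lr_image, scale_factor=scale, mode="bicubic")
    return vqgan.encode(lr_up)

@torch.no_grad()
def latent_to_image(vqgan, z):
    """Decode a latent sample back to pixel space and clamp to the assumed [0, 1] range."""
    return vqgan.decode(z).clamp(0.0, 1.0)
```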
## Trainer Functions
### Original Trainer (`original_trainer.py`)
| Original Function | Our Implementation | Description |
|------------------|-------------------|-------------|
| `Trainer.__init__(configs, ...)` | `trainer.py: Trainer.__init__()` (lines 36-75) | **Initializes trainer**: Sets up device, checkpoint directory, noise schedule, loss function, WandB |
| `Trainer.setup_seed(seed=None)` | `trainer.py: setup_seed()` (lines 76-95) | **Sets random seeds**: Ensures reproducibility by setting seeds for random, numpy, torch, and CUDA |
| `Trainer.build_model()` | `trainer.py: build_model()` (lines 212-267) | **Builds model and autoencoder**: Initializes FullUNET model, optionally compiles it, loads VQGAN autoencoder, initializes LPIPS metric |
| `Trainer.setup_optimization()` | `trainer.py: setup_optimization()` (lines 141-211) | **Sets up optimizer and scheduler**: Initializes AdamW optimizer, AMP scaler (if enabled), CosineAnnealingLR scheduler (if enabled) |
| `Trainer.build_dataloader()` | `trainer.py: build_dataloader()` (lines 283-344) | **Builds data loaders**: Creates train and validation DataLoaders, wraps train loader to cycle infinitely |
| `Trainer.training_losses()` | `trainer.py: training_step()` (lines 392-493) | **Training step**: Implements micro-batching, adds noise, forward pass, loss computation, backward step, gradient clipping, optimizer step |
| `Trainer.validation(phase='val')` | `trainer.py: validation()` (lines 748-963) | **Validation loop**: Runs full diffusion sampling loop, decodes results, computes PSNR/SSIM/LPIPS metrics, logs to WandB |
| `Trainer.adjust_lr(current_iters=None)` | `trainer.py: adjust_lr()` (lines 495-519) | **Learning rate scheduling**: Implements linear warmup, then cosine annealing (if enabled) |
| `Trainer.save_ckpt()` | `trainer.py: save_ckpt()` (lines 520-562) | **Saves checkpoint**: Saves model, optimizer, AMP scaler, LR scheduler states, current iteration |
| `Trainer.resume_from_ckpt(ckpt_path)` | `trainer.py: resume_from_ckpt()` (lines 563-670) | **Resumes from checkpoint**: Loads model, optimizer, scaler, scheduler states, restores iteration count and LR schedule |
| `Trainer.log_step_train(...)` | `trainer.py: log_step_train()` (lines 671-747) | **Logs training metrics**: Logs loss, learning rate, images (HR, LR, noisy input, prediction) to WandB at specified frequencies |
| `Trainer.reload_ema_model()` | `trainer.py: validation()` (line 754) | **Loads EMA model**: Uses EMA model for validation if `use_ema_val=True` |
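For reference, a minimal sketch of the learning-rate policy described for `adjust_lr()` above (linear warmup followed by cosine annealing); the function name matches the table, but the signature and the `min_lr` default are illustrative assumptions, and the real code defers the cosine part to `CosineAnnealingLR` when enabled:

```python
import math

def adjust_lr(optimizer, step, base_lr, warmup_steps, total_steps, min_lr=1e-6):
    """Linear warmup to base_lr, then cosine annealing down to min_lr."""
    if step < warmup_steps:
        lr = base_lr * (step + 1) / warmup_steps
    else:
        progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
        lr = min_lr + 0.5 * (base_lr - min_lr) * (1.0 + math.cos(math.pi * progress))
    for group in optimizer.param_groups:
        group["lr"] = lr
    return lr
```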
## Inference Functions
### Original Sampler (`original_sampler.py`)
| Original Function | Our Implementation | Description |
|------------------|-------------------|-------------|
| `ResShiftSampler.sample_func(y0, noise_repeat=False, mask=False)` | `inference.py: inference_single_image()` (lines 175-274)<br>`app.py: super_resolve()` (lines 93-165) | **Single image inference**: Encodes LR to latent, runs full diffusion sampling loop, decodes to pixel space |
| `ResShiftSampler.inference(in_path, out_path, ...)` | `inference.py: main()` (lines 385-509) | **Batch inference**: Processes single image or directory of images, handles chopping for large images |
| `ResShiftSampler.build_model()` | `inference.py: load_model()` (lines 322-384) | **Loads model and autoencoder**: Loads checkpoint, handles compiled model checkpoints (strips `_orig_mod.` prefix), loads EMA if specified |
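A hedged sketch of the compiled-checkpoint handling mentioned for `load_model()`: `torch.compile()` prefixes every parameter key with `_orig_mod.`, so the prefix is stripped before loading into a plain model. The checkpoint key names (`state_dict`, `ema_state_dict`) are assumptions for illustration:

```python
import torch

def load_uncompiled_state(model, ckpt_path, use_ema=False, device="cpu"):
    """Load a checkpoint that may have been saved from a torch.compile()'d model."""
    ckpt = torch.load(ckpt_path, map_location=device)
    state = ckpt.get("ema_state_dict" if use_ema else "state_dict", ckpt)
    # Strip the '_orig_mod.' prefix added by torch.compile() so keys match the plain model.
    state = {k.removeprefix("_orig_mod."): v for k, v in state.items()}
    model.load_state_dict(state)
    return model
```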
## Key Differences and Notes
### 1. **Model Prediction Type**
- **Original**: Supports multiple prediction types (START_X, EPSILON, RESIDUAL, EPSILON_SCALE)
- **Our Implementation**: Only uses START_X (predicts x0 directly), matching ResShift paper
### 2. **Sampling Initialization**
- **Original**: Uses `prior_sample(y, noise)` → `x_T = y + κ * √η_T * noise` (starts from LR + noise)
- **Our Implementation**: Same approach in inference and validation (fixed in validation to match original)
### 3. **Backward Equation (Equation 7)**
- **Original**: Uses `p_mean_variance()` → computes `μ_θ = (η_{t-1}/η_t) * x_t + (α_t/η_t) * x0_pred` and `Σ_θ = κ² * (η_{t-1}/η_t) * α_t`
- **Our Implementation**: Same equation, implemented directly in sampling loops
### 4. **Input Scaling**
- **Original**: `_scale_input()` normalizes variance: `x_scaled = x_t / √(η_t * κ² + 1)` for latent space
- **Our Implementation**: Same formula, applied in training and inference
### 5. **Training Loss**
- **Original**: Supports weighted MSE based on posterior variance
- **Our Implementation**: Uses simple MSE loss (predicts x0, compares with HR latent)
### 6. **Validation**
- **Original**: Uses `p_sample_loop_progressive()` with EMA model, computes PSNR/LPIPS
- **Our Implementation**: Same approach, also computes SSIM, uses EMA if `use_ema_val=True`
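Of the validation metrics, PSNR is simple enough to sketch directly in torch (assuming decoded images in [0, 1]); SSIM and LPIPS rely on their respective library implementations and are not reproduced here:

```python
import torch

def psnr(pred, target, data_range=1.0):
    """Peak signal-to-noise ratio between two image batches in [0, data_range]."""
    mse = torch.mean((pred - target) ** 2)
    return 10.0 * torch.log10(data_range ** 2 / mse)
```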
### 7. **Checkpoint Handling**
- **Original**: Standard checkpoint loading
- **Our Implementation**: Handles compiled model checkpoints (strips `_orig_mod.` prefix) for `torch.compile()` compatibility
## Function Call Flow
### Training Flow
```
train.py: train()
→ Trainer.__init__()
→ Trainer.build_model()
→ Trainer.setup_optimization()
→ Trainer.build_dataloader()
→ Loop:
→ Trainer.training_step() # q_sample + forward + loss + backward
→ Trainer.adjust_lr()
→ Trainer.validation() # p_sample_loop
→ Trainer.log_step_train()
→ Trainer.save_ckpt()
```
### Inference Flow
```
inference.py: main()
→ load_model() # Load checkpoint
→ get_vqgan() # Load autoencoder
→ inference_single_image()
→ autoencoder.encode() # LR → latent
→ prior_sample() # Initialize x_T
→ Loop: p_sample() # Denoise T-1 → 0
→ autoencoder.decode() # Latent → pixel
```
## Summary
Our implementation closely follows the original ResShift implementation, with the following key mappings:
- **Forward process**: `q_sample` → `training_step()` noise addition
- **Backward process**: `p_sample` → sampling loops in `validation()` and `inference_single_image()`
- **Training**: `training_losses` → `training_step()`
- **Autoencoder**: `encode_first_stage`/`decode_first_stage` → `VQGANWrapper.encode()`/`decode()`
- **Input scaling**: `_scale_input` → `_scale_input()` (same name, same logic)
All core diffusion equations (forward process, backward equation 7, input scaling) match the original implementation.