# DiffusionSR/src/trainer.py
import torch
import torch.nn as nn
import random
import numpy as np
from model import FullUNET
from noiseControl import resshift_schedule
from torch.utils.data import DataLoader
from data import mini_dataset, train_dataset, valid_dataset, get_vqgan_model
import torch.optim as optim
from config import (
batch_size, device, learning_rate, iterations,
weight_decay, T, k, _project_root, num_workers,
use_amp, lr, lr_min, lr_schedule, warmup_iterations,
compile_flag, compile_mode, batch, prefetch_factor, microbatch,
save_freq, log_freq, val_freq, val_y_channel, ema_rate, use_ema_val,
seed, global_seeding, normalize_input, latent_flag
)
import wandb
import os
import math
import time
from pathlib import Path
from itertools import cycle
from contextlib import nullcontext
from dotenv import load_dotenv
from metrics import compute_psnr, compute_ssim, compute_lpips
from ema import EMA
import lpips
class Trainer:
"""
Modular trainer class following the original ResShift trainer structure.
"""
def __init__(self, save_dir=None, resume_ckpt=None):
"""
Initialize trainer with config values.
Args:
save_dir: Directory to save checkpoints (defaults to _project_root / 'checkpoints')
resume_ckpt: Path to checkpoint file to resume from (optional)
"""
self.device = device
self.current_iters = 0
self.iters_start = 0
self.resume_ckpt = resume_ckpt
# Setup checkpoint directory
if save_dir is None:
save_dir = _project_root / 'checkpoints'
self.save_dir = Path(save_dir)
self.ckpt_dir = self.save_dir / 'ckpts'
self.ckpt_dir.mkdir(parents=True, exist_ok=True)
# Initialize noise schedule (eta values for ResShift)
self.eta = resshift_schedule().to(self.device)
self.eta = self.eta[:, None, None, None] # shape (T, 1, 1, 1)
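# With t of shape (B,), self.eta[t] yields (B, 1, 1, 1) and broadcasts cleanly
# against (B, C, H, W) latents.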
# Loss criterion
self.criterion = nn.MSELoss()
# Timing for checkpoint saving
self.tic = None
# EMA will be initialized after model is built
self.ema = None
self.ema_model = None
# Set random seeds for reproducibility
self.setup_seed()
# Initialize WandB
self.init_wandb()
def setup_seed(self, seed_val=None, global_seeding_val=None):
"""
Set random seeds for reproducibility.
Sets seeds for:
- Python random module
- NumPy
- PyTorch (CPU and CUDA)
Args:
seed_val: Seed value (defaults to config.seed)
global_seeding_val: Whether to use global seeding (defaults to config.global_seeding)
"""
if seed_val is None:
seed_val = seed
if global_seeding_val is None:
global_seeding_val = global_seeding
# Set Python random seed
random.seed(seed_val)
# Set NumPy random seed
np.random.seed(seed_val)
# Set PyTorch random seed
torch.manual_seed(seed_val)
# Set CUDA random seeds (if available)
if torch.cuda.is_available():
if global_seeding_val:
torch.cuda.manual_seed_all(seed_val)
else:
torch.cuda.manual_seed(seed_val)
# For multi-GPU, each GPU would get seed + rank (not implemented here)
# Make deterministic (may impact performance)
# torch.backends.cudnn.deterministic = True
# torch.backends.cudnn.benchmark = False
print(f"✓ Random seeds set: seed={seed_val}, global_seeding={global_seeding_val}")
def init_wandb(self):
"""Initialize WandB logging."""
load_dotenv(os.path.join(_project_root, '.env'))
wandb.init(
project="diffusionsr",
name="reshift_training",
config={
"learning_rate": learning_rate,
"batch_size": batch_size,
"steps": iterations,
"model": "ResShift",
"T": T,
"k": k,
"optimizer": "AdamW" if weight_decay > 0 else "Adam",
"betas": (0.9, 0.999),
"grad_clip": 1.0,
"criterion": "MSE",
"device": str(device),
"training_space": "latent_64x64",
"use_amp": use_amp,
"ema_rate": 0.999 if hasattr(self, 'ema_rate') else None
}
)
def setup_optimization(self):
"""
Component 1: Setup optimizer and AMP scaler.
Sets up:
- Optimizer (AdamW with weight decay or Adam)
- AMP GradScaler if use_amp is True
"""
# Use AdamW if weight_decay > 0, otherwise Adam
if weight_decay > 0:
self.optimizer = optim.AdamW(
self.model.parameters(),
lr=learning_rate,
weight_decay=weight_decay,
betas=(0.9, 0.999)
)
else:
self.optimizer = optim.Adam(
self.model.parameters(),
lr=learning_rate,
weight_decay=weight_decay,
betas=(0.9, 0.999)
)
# AMP settings: Create GradScaler if use_amp is True and CUDA is available
if use_amp and torch.cuda.is_available():
self.amp_scaler = torch.amp.GradScaler('cuda')
else:
self.amp_scaler = None
if use_amp and not torch.cuda.is_available():
print(" ⚠ Warning: AMP requested but CUDA not available. Disabling AMP.")
# Learning rate scheduler (cosine annealing after warmup)
self.lr_scheduler = None
if lr_schedule == 'cosin':
self.lr_scheduler = optim.lr_scheduler.CosineAnnealingLR(
optimizer=self.optimizer,
T_max=iterations - warmup_iterations,
eta_min=lr_min
)
print(f" - LR scheduler: CosineAnnealingLR (T_max={iterations - warmup_iterations}, eta_min={lr_min})")
# Load pending optimizer state if resuming
if hasattr(self, '_pending_optimizer_state'):
self.optimizer.load_state_dict(self._pending_optimizer_state)
print(f" - Loaded optimizer state from checkpoint")
delattr(self, '_pending_optimizer_state')
# Load pending LR scheduler state if resuming
scheduler_state_loaded = False
if hasattr(self, '_pending_lr_scheduler_state') and self.lr_scheduler is not None:
self.lr_scheduler.load_state_dict(self._pending_lr_scheduler_state)
scheduler_state_loaded = True
print(f" - Loaded LR scheduler state from checkpoint")
delattr(self, '_pending_lr_scheduler_state')
# Restore the LR schedule by replaying adjust_lr for all previous iterations so the
# LR is at the correct value for the resumed iteration. Skip the replay when the
# scheduler state was restored directly from the checkpoint, otherwise the cosine
# schedule would be advanced twice.
if hasattr(self, '_resume_iters') and self._resume_iters > 0 and not scheduler_state_loaded:
print(f" - Restoring learning rate schedule to iteration {self._resume_iters}...")
for ii in range(1, self._resume_iters + 1):
self.adjust_lr(ii)
print(f" - ✓ Learning rate schedule restored")
if hasattr(self, '_resume_iters'):
delattr(self, '_resume_iters')
print(f"✓ Setup optimization:")
print(f" - Optimizer: {type(self.optimizer).__name__}")
print(f" - Learning rate: {learning_rate}")
print(f" - Weight decay: {weight_decay}")
print(f" - Warmup iterations: {warmup_iterations}")
print(f" - LR schedule: {lr_schedule if lr_schedule else 'None (fixed LR)'}")
print(f" - AMP enabled: {use_amp} ({'GradScaler active' if self.amp_scaler else 'disabled'})")
def build_model(self):
"""
Component 2: Build model and autoencoder (VQGAN).
Sets up:
- FullUNET model
- Model compilation (optional)
- VQGAN autoencoder for encoding/decoding
- Model info printing
"""
# Build main model
print("Building FullUNET model...")
self.model = FullUNET()
self.model = self.model.to(self.device)
# Optional: Compile model for optimization
# Model compilation can provide 20-30% speedup on modern GPUs
# but requires PyTorch 2.0+ and may have compatibility issues
self.model_compiled = False
if compile_flag:
try:
print(f"Compiling model with mode: {compile_mode}...")
self.model = torch.compile(self.model, mode=compile_mode)
self.model_compiled = True
print("✓ Model compilation done")
except Exception as e:
print(f"⚠ Warning: Model compilation failed: {e}")
print(" Continuing without compilation...")
self.model_compiled = False
# Load VQGAN autoencoder
print("Loading VQGAN autoencoder...")
self.autoencoder = get_vqgan_model()
print("✓ VQGAN autoencoder loaded")
# Initialize LPIPS model for validation
print("Loading LPIPS metric...")
self.lpips_model = lpips.LPIPS(net='vgg').to(self.device)
for params in self.lpips_model.parameters():
params.requires_grad_(False)
self.lpips_model.eval()
print("✓ LPIPS metric loaded")
# Initialize EMA if enabled
if ema_rate > 0:
print(f"Initializing EMA with rate: {ema_rate}...")
self.ema = EMA(self.model, ema_rate=ema_rate, device=self.device)
# Add Swin Transformer relative position index to ignore keys
self.ema.add_ignore_key('relative_position_index')
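# relative_position_index is a fixed integer buffer in Swin attention blocks
# (a lookup table, not a learned weight), so it is excluded from EMA averaging.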
print("✓ EMA initialized")
else:
print("⚠ EMA disabled (ema_rate = 0)")
# Print model information
self.print_model_info()
def print_model_info(self):
"""Print model parameter count and architecture info."""
# Count parameters
total_params = sum(p.numel() for p in self.model.parameters())
trainable_params = sum(p.numel() for p in self.model.parameters() if p.requires_grad)
print(f"\n✓ Model built successfully:")
print(f" - Model: FullUNET")
print(f" - Total parameters: {total_params / 1e6:.2f}M")
print(f" - Trainable parameters: {trainable_params / 1e6:.2f}M")
print(f" - Device: {self.device}")
print(f" - Compiled: {'Yes' if getattr(self, 'model_compiled', False) else 'No'}")
if self.autoencoder is not None:
print(f" - Autoencoder: VQGAN (loaded)")
def build_dataloader(self):
"""
Component 3: Build train and validation dataloaders.
Sets up:
- Train dataloader with infinite cycle wrapper
- Validation dataloader (if validation dataset exists)
- Proper batch sizes, num_workers, pin_memory, etc.
"""
def _wrap_loader(loader):
"""Wrap dataloader to cycle infinitely."""
while True:
yield from loader
# Create datasets dictionary
datasets = {'train': train_dataset}
if valid_dataset is not None:
datasets['val'] = valid_dataset
# Print dataset sizes
for phase, dataset in datasets.items():
print(f" - {phase.capitalize()} dataset: {len(dataset)} images")
# Create train dataloader
train_batch_size = batch[0] # Use first value from batch list
train_loader = DataLoader(
datasets['train'],
batch_size=train_batch_size,
shuffle=True,
drop_last=True, # Drop last incomplete batch
num_workers=min(num_workers, 4), # Limit num_workers
pin_memory=True if torch.cuda.is_available() else False,
prefetch_factor=prefetch_factor if num_workers > 0 else None,
)
# Wrap train loader to cycle infinitely
self.dataloaders = {'train': _wrap_loader(train_loader)}
# Create validation dataloader if validation dataset exists
if 'val' in datasets:
val_batch_size = batch[1] if len(batch) > 1 else batch[0] # Use second value or fallback
val_loader = DataLoader(
datasets['val'],
batch_size=val_batch_size,
shuffle=False,
drop_last=False, # Don't drop last batch in validation
num_workers=0, # No multiprocessing for validation (safer)
pin_memory=True if torch.cuda.is_available() else False,
)
self.dataloaders['val'] = val_loader
# Store datasets
self.datasets = datasets
print(f"\n✓ Dataloaders built:")
print(f" - Train batch size: {train_batch_size}")
print(f" - Train num_workers: {min(num_workers, 4)}")
print(f" - Train drop_last: True")
if 'val' in self.dataloaders:
print(f" - Val batch size: {val_batch_size}")
print(f" - Val num_workers: 0")
def backward_step(self, loss, num_grad_accumulate=1):
"""
Component 4: Handle backward pass with AMP support and gradient accumulation.
Args:
loss: The computed loss tensor
num_grad_accumulate: Number of gradient accumulation steps (for micro-batching)
Returns:
loss: The loss tensor (for logging)
"""
# Normalize loss by gradient accumulation steps
loss = loss / num_grad_accumulate
# Backward pass: use AMP scaler if available, otherwise direct backward
if self.amp_scaler is None:
loss.backward()
else:
self.amp_scaler.scale(loss).backward()
return loss
def _scale_input(self, x_t, t):
"""
Scale input based on timestep for training stability.
Matches original GaussianDiffusion._scale_input for latent space.
For latent space: std = sqrt(etas[t] * kappa^2 + 1)
This normalizes the input variance across different timesteps.
Args:
x_t: Noisy input tensor (B, C, H, W)
t: Timestep tensor (B,)
Returns:
x_t_scaled: Scaled input tensor (B, C, H, W)
"""
if normalize_input and latent_flag:
# For latent space: std = sqrt(etas[t] * kappa^2 + 1)
# Extract eta_t for each sample in batch
eta_t = self.eta[t] # (B, 1, 1, 1)
std = torch.sqrt(eta_t * k**2 + 1)
x_t_scaled = x_t / std
else:
x_t_scaled = x_t
return x_t_scaled
def training_step(self, hr_latent, lr_latent):
"""
Component 5: Main training step with micro-batching and gradient accumulation.
Args:
hr_latent: High-resolution latent tensor (B, C, 64, 64)
lr_latent: Low-resolution latent tensor (B, C, 64, 64)
Returns:
loss: Average loss value for logging
timing_dict: Dictionary with timing information
"""
step_start = time.time()
self.model.train()
current_batchsize = hr_latent.shape[0]
micro_batchsize = microbatch
num_grad_accumulate = math.ceil(current_batchsize / micro_batchsize)
total_loss = 0.0
forward_time = 0.0
backward_time = 0.0
# Process in micro-batches for gradient accumulation
for jj in range(0, current_batchsize, micro_batchsize):
# Extract micro-batch
end_idx = min(jj + micro_batchsize, current_batchsize)
hr_micro = hr_latent[jj:end_idx].to(self.device)
lr_micro = lr_latent[jj:end_idx].to(self.device)
last_batch = (end_idx >= current_batchsize)
# Compute residual in latent space
residual = (lr_micro - hr_micro)
# Generate random timesteps for each sample in micro-batch
t = torch.randint(0, T, (hr_micro.shape[0],)).to(self.device)
# Add noise in latent space (ResShift noise schedule)
epsilon = torch.randn_like(hr_micro) # Noise in latent space
eta_t = self.eta[t] # (B, 1, 1, 1)
x_t = hr_micro + eta_t * residual + k * torch.sqrt(eta_t) * epsilon
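# Equivalent to sampling from the ResShift marginal
#   q(x_t | x_0, y) = N(x_0 + eta_t * (y - x_0), k^2 * eta_t * I)
# with x_0 = hr_micro (clean HR latent) and y = lr_micro (LR latent).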
# Forward pass with autocast if AMP is enabled
forward_start = time.time()
if use_amp and torch.cuda.is_available():
context = torch.amp.autocast('cuda')
else:
context = nullcontext()
with context:
# Scale input for training stability (normalize variance across timesteps)
x_t_scaled = self._scale_input(x_t, t)
# Forward pass: Model predicts x0 (clean HR latent), not noise
# ResShift uses predict_type = "xstart"
x0_pred = self.model(x_t_scaled, t, lq=lr_micro)
# Loss: Compare predicted x0 with ground truth HR latent
loss = self.criterion(x0_pred, hr_micro)
forward_time += time.time() - forward_start
# Store loss value for logging (before dividing for gradient accumulation)
total_loss += loss.item()
# Backward step (handles gradient accumulation and AMP)
backward_start = time.time()
self.backward_step(loss, num_grad_accumulate)
backward_time += time.time() - backward_start
# Gradient clipping before optimizer step
if self.amp_scaler is None:
torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
self.optimizer.step()
else:
# Unscale gradients before clipping when using AMP
self.amp_scaler.unscale_(self.optimizer)
torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
self.amp_scaler.step(self.optimizer)
self.amp_scaler.update()
# Zero gradients
self.model.zero_grad()
# Update EMA after optimizer step
if self.ema is not None:
self.ema.update(self.model)
# Compute total step time
step_time = time.time() - step_start
# Return average loss (average across micro-batches)
num_micro_batches = math.ceil(current_batchsize / micro_batchsize)
avg_loss = total_loss / num_micro_batches if num_micro_batches > 0 else total_loss
# Return timing information
timing_dict = {
'step_time': step_time,
'forward_time': forward_time,
'backward_time': backward_time,
'num_micro_batches': num_micro_batches
}
return avg_loss, timing_dict
def adjust_lr(self, current_iters=None):
"""
Component 6: Adjust learning rate with warmup and optional cosine annealing.
Learning rate schedule:
- Warmup phase (iters <= warmup_iterations): Linear increase from 0 to base_lr
- After warmup: Use cosine annealing scheduler if lr_schedule == 'cosin', else keep base_lr
Args:
current_iters: Current iteration number (defaults to self.current_iters)
"""
base_lr = learning_rate
warmup_steps = warmup_iterations
current_iters = self.current_iters if current_iters is None else current_iters
if current_iters <= warmup_steps:
# Warmup phase: linear increase from 0 to base_lr
warmup_lr = (current_iters / warmup_steps) * base_lr
for params_group in self.optimizer.param_groups:
params_group['lr'] = warmup_lr
else:
# After warmup: use scheduler if available
if self.lr_scheduler is not None:
self.lr_scheduler.step()
def save_ckpt(self):
"""
Component 7: Save checkpoint with model state, optimizer state, and training info.
Saves:
- Model state dict
- Optimizer state dict
- Current iteration number
- AMP scaler state (if AMP is enabled)
- LR scheduler state (if scheduler exists)
"""
ckpt_path = self.ckpt_dir / f'model_{self.current_iters}.pth'
# Prepare checkpoint dictionary
ckpt = {
'iters_start': self.current_iters,
'state_dict': self.model.state_dict(),
}
# Add optimizer state if available
if hasattr(self, 'optimizer'):
ckpt['optimizer'] = self.optimizer.state_dict()
# Add AMP scaler state if available
if self.amp_scaler is not None:
ckpt['amp_scaler'] = self.amp_scaler.state_dict()
# Add LR scheduler state if available
if self.lr_scheduler is not None:
ckpt['lr_scheduler'] = self.lr_scheduler.state_dict()
# Save checkpoint
torch.save(ckpt, ckpt_path)
print(f"✓ Checkpoint saved: {ckpt_path}")
# Save EMA checkpoint separately if EMA is enabled
if self.ema is not None:
ema_ckpt_path = self.ckpt_dir / f'ema_model_{self.current_iters}.pth'
torch.save(self.ema.state_dict(), ema_ckpt_path)
print(f"✓ EMA checkpoint saved: {ema_ckpt_path}")
return ckpt_path
def resume_from_ckpt(self, ckpt_path):
"""
Resume training from a checkpoint.
Loads:
- Model state dict
- Optimizer state dict
- AMP scaler state (if AMP is enabled)
- LR scheduler state (if scheduler exists)
- Current iteration number
- Restores LR schedule by replaying adjust_lr for previous iterations
Args:
ckpt_path: Path to checkpoint file (.pth)
"""
if not os.path.isfile(ckpt_path):
raise FileNotFoundError(f"Checkpoint file not found: {ckpt_path}")
if not str(ckpt_path).endswith('.pth'):
raise ValueError(f"Checkpoint file must have .pth extension: {ckpt_path}")
print(f"\n{'=' * 100}")
print(f"Resuming from checkpoint: {ckpt_path}")
print(f"{'=' * 100}")
# Load checkpoint
ckpt = torch.load(ckpt_path, map_location=self.device)
# Load model state dict
if 'state_dict' in ckpt:
self.model.load_state_dict(ckpt['state_dict'])
print(f"✓ Loaded model state dict")
else:
# If checkpoint is just the state dict
self.model.load_state_dict(ckpt)
print(f"✓ Loaded model state dict (direct)")
# Load optimizer state dict (must be done after optimizer is created)
if 'optimizer' in ckpt:
if hasattr(self, 'optimizer'):
self.optimizer.load_state_dict(ckpt['optimizer'])
print(f"✓ Loaded optimizer state dict")
else:
print(f"⚠ Warning: Optimizer state found in checkpoint but optimizer not yet created.")
print(f" Optimizer will be loaded after setup_optimization() is called.")
self._pending_optimizer_state = ckpt['optimizer']
# Load AMP scaler state
if 'amp_scaler' in ckpt:
if hasattr(self, 'amp_scaler') and self.amp_scaler is not None:
self.amp_scaler.load_state_dict(ckpt['amp_scaler'])
print(f"✓ Loaded AMP scaler state")
else:
print(f"⚠ Warning: AMP scaler state found but AMP not enabled or scaler not created.")
# Load LR scheduler state
if 'lr_scheduler' in ckpt:
if hasattr(self, 'lr_scheduler') and self.lr_scheduler is not None:
self.lr_scheduler.load_state_dict(ckpt['lr_scheduler'])
print(f"✓ Loaded LR scheduler state")
else:
print(f"⚠ Warning: LR scheduler state found but scheduler not yet created.")
self._pending_lr_scheduler_state = ckpt['lr_scheduler']
# Load EMA state if available (must be done after EMA is initialized)
# EMA checkpoint naming: ema_model_{iters}.pth (matches save pattern)
ckpt_path_obj = Path(ckpt_path)
# Extract iteration number from checkpoint name (e.g., "model_10000.pth" -> "10000")
if 'iters_start' in ckpt:
iters = ckpt['iters_start']
ema_ckpt_path = ckpt_path_obj.parent / f"ema_model_{iters}.pth"
else:
# Fallback: try to extract from filename
try:
iters = int(ckpt_path_obj.stem.split('_')[-1])
ema_ckpt_path = ckpt_path_obj.parent / f"ema_model_{iters}.pth"
except ValueError:
ema_ckpt_path = None
if ema_ckpt_path is not None and ema_ckpt_path.exists() and self.ema is not None:
ema_ckpt = torch.load(ema_ckpt_path, map_location=self.device)
self.ema.load_state_dict(ema_ckpt)
print(f"✓ Loaded EMA state from: {ema_ckpt_path}")
elif ema_ckpt_path is not None and ema_ckpt_path.exists() and self.ema is None:
print(f"⚠ Warning: EMA checkpoint found but EMA not enabled. Skipping EMA load.")
elif self.ema is not None:
print(f"⚠ Warning: EMA enabled but no EMA checkpoint found. Starting with fresh EMA.")
# Restore iteration number
if 'iters_start' in ckpt:
self.iters_start = ckpt['iters_start']
self.current_iters = ckpt['iters_start']
print(f"✓ Resuming from iteration: {self.iters_start}")
else:
print(f"⚠ Warning: No iteration number found in checkpoint. Starting from 0.")
self.iters_start = 0
self.current_iters = 0
# Note: LR schedule restoration will be done after setup_optimization()
# Store the iteration number for later restoration
self._resume_iters = self.iters_start
print(f"{'=' * 100}\n")
# Clear CUDA cache
if torch.cuda.is_available():
torch.cuda.empty_cache()
def log_step_train(self, loss, hr_latent, lr_latent, x_t, pred, phase='train'):
"""
Component 8: Log training metrics and images to WandB.
Logs:
- Loss and learning rate (at log_freq[0] intervals)
- Training images: HR, LR, and predictions (at log_freq[1] intervals)
- Elapsed time for checkpoint intervals
Args:
loss: Training loss value (float)
hr_latent: High-resolution latent tensor (B, C, 64, 64)
lr_latent: Low-resolution latent tensor (B, C, 64, 64)
x_t: Noisy input tensor (B, C, 64, 64)
pred: Model prediction (x0_pred - clean HR latent) (B, C, 64, 64)
phase: Training phase ('train' or 'val')
"""
# Log loss and learning rate at log_freq[0] intervals
if self.current_iters % log_freq[0] == 0:
current_lr = self.optimizer.param_groups[0]['lr']
# Get timing info if available (passed from training_step)
timing_info = {}
if hasattr(self, '_last_timing'):
timing_info = {
'train/step_time': self._last_timing.get('step_time', 0),
'train/forward_time': self._last_timing.get('forward_time', 0),
'train/backward_time': self._last_timing.get('backward_time', 0),
'train/iterations_per_sec': 1.0 / self._last_timing.get('step_time', 1.0) if self._last_timing.get('step_time', 0) > 0 else 0
}
wandb.log({
'loss': loss,
'learning_rate': current_lr,
'step': self.current_iters,
**timing_info
})
# Print to console
timing_str = ""
if hasattr(self, '_last_timing') and self._last_timing.get('step_time', 0) > 0:
timing_str = f", Step: {self._last_timing['step_time']:.3f}s, Forward: {self._last_timing['forward_time']:.3f}s, Backward: {self._last_timing['backward_time']:.3f}s"
print(f"Train: {self.current_iters:06d}/{iterations:06d}, "
f"Loss: {loss:.6f}, LR: {current_lr:.2e}{timing_str}")
# Log images at log_freq[1] intervals (only if x_t and pred are provided)
if self.current_iters % log_freq[1] == 0 and x_t is not None and pred is not None:
with torch.no_grad():
# Decode latents to pixel space for visualization
# Take first sample from batch
hr_pixel = self.autoencoder.decode(hr_latent[0:1]) # (1, 3, 256, 256)
lr_pixel = self.autoencoder.decode(lr_latent[0:1]) # (1, 3, 256, 256)
# Decode noisy input for visualization
x_t_pixel = self.autoencoder.decode(x_t[0:1]) # (1, 3, 256, 256)
# Decode predicted x0 (clean HR latent) for visualization
pred_pixel = self.autoencoder.decode(pred[0:1]) # (1, 3, 256, 256)
# Log images to WandB
wandb.log({
f'{phase}/hr_sample': wandb.Image(hr_pixel[0].cpu().clamp(0, 1)),
f'{phase}/lr_sample': wandb.Image(lr_pixel[0].cpu().clamp(0, 1)),
f'{phase}/noisy_input': wandb.Image(x_t_pixel[0].cpu().clamp(0, 1)),
f'{phase}/pred_sample': wandb.Image(pred_pixel[0].cpu().clamp(0, 1)),
'step': self.current_iters
})
# Track elapsed time for checkpoint intervals
if self.current_iters % save_freq == 1:
self.tic = time.time()
if self.current_iters % save_freq == 0 and self.tic is not None:
self.toc = time.time()
elapsed = self.toc - self.tic
print(f"Elapsed time for {save_freq} iterations: {elapsed:.2f}s")
print("=" * 100)
def validation(self):
"""
Run validation on validation dataset with full diffusion sampling loop.
Performs iterative denoising from t = T-1 down to t = 0, matching the
original ResShift implementation. This is slower but more accurate than
single-step prediction.
Computes:
- PSNR, SSIM, and LPIPS metrics
- Logs validation images to WandB
"""
if 'val' not in self.dataloaders:
print("No validation dataset available. Skipping validation.")
return
print("\n" + "=" * 100)
print("Running Validation")
print("=" * 100)
val_start = time.time()
# Use EMA model for validation if enabled
if use_ema_val and self.ema is not None:
# Create EMA model copy if it doesn't exist
if self.ema_model is None:
from copy import deepcopy
self.ema_model = deepcopy(self.model)
# Load EMA state into EMA model
self.ema.apply_to_model(self.ema_model)
self.ema_model.eval()
val_model = self.ema_model
print("Using EMA model for validation")
else:
self.model.eval()
val_model = self.model
if use_ema_val and self.ema is None:
print("⚠ Warning: use_ema_val=True but EMA not enabled. Using regular model.")
val_iter = iter(self.dataloaders['val'])
total_psnr = 0.0
total_ssim = 0.0
total_lpips = 0.0
total_val_loss = 0.0
num_samples = 0
total_sampling_time = 0.0
total_forward_time = 0.0
total_decode_time = 0.0
total_metric_time = 0.0
with torch.no_grad():
for batch_idx, (hr_latent, lr_latent) in enumerate(val_iter):
batch_start = time.time()
# Move to device
hr_latent = hr_latent.to(self.device)
lr_latent = lr_latent.to(self.device)
# Full diffusion sampling loop (iterative denoising)
# Start from maximum timestep and iterate backwards: T-1 → T-2 → ... → 1 → 0
sampling_start = time.time()
# Initialize x_t at maximum timestep (T-1)
# Start from LR with maximum noise (prior_sample: x_T = y + kappa * sqrt(eta_T) * noise)
epsilon_init = torch.randn_like(lr_latent)
eta_max = self.eta[T - 1]
x_t = lr_latent + k * torch.sqrt(eta_max) * epsilon_init
# Track forward pass time during sampling
sampling_forward_time = 0.0
# Iterative sampling: denoise from t = T-1 down to t = 0
for t_step in range(T - 1, -1, -1): # T-1, T-2, ..., 1, 0
t = torch.full((hr_latent.shape[0],), t_step, device=self.device, dtype=torch.long)
# Scale input for training stability (normalize variance across timesteps)
x_t_scaled = self._scale_input(x_t, t)
# Predict x0 from current noisy state x_t
forward_start = time.time()
x0_pred = val_model(x_t_scaled, t, lq=lr_latent)
sampling_forward_time += time.time() - forward_start
# If not the last step, compute x_{t-1} from predicted x0 using equation (7)
if t_step > 0:
# Equation (7) from ResShift paper:
# μ_θ = (η_{t-1}/η_t) * x_t + (α_t/η_t) * f_θ(x_t, y_0, t)
# Σ_θ = κ² * (η_{t-1}/η_t) * α_t
# x_{t-1} = μ_θ + sqrt(Σ_θ) * ε
eta_t = self.eta[t_step]
eta_t_minus_1 = self.eta[t_step - 1]
# Compute alpha_t = η_t - η_{t-1}
alpha_t = eta_t - eta_t_minus_1
# Compute mean: μ_θ = (η_{t-1}/η_t) * x_t + (α_t/η_t) * x0_pred
mean = (eta_t_minus_1 / eta_t) * x_t + (alpha_t / eta_t) * x0_pred
# Compute variance: Σ_θ = κ² * (η_{t-1}/η_t) * α_t
variance = k**2 * (eta_t_minus_1 / eta_t) * alpha_t
# Sample: x_{t-1} = μ_θ + sqrt(Σ_θ) * ε
# (no nonzero mask is needed: this branch only runs for t_step > 0, and the
# t_step == 0 case below uses the predicted x0 directly)
noise = torch.randn_like(x_t)
x_t = mean + torch.sqrt(variance) * noise
else:
# Final step: use predicted x0 as final output
x_t = x0_pred
# Final prediction after full sampling loop
x0_final = x_t
# Compute validation loss (MSE in latent space, same as training loss)
val_loss = self.criterion(x0_final, hr_latent).item()
total_val_loss += val_loss * hr_latent.shape[0]
sampling_time = time.time() - sampling_start
total_sampling_time += sampling_time
total_forward_time += sampling_forward_time
# Decode latents to pixel space for metrics and visualization
decode_start = time.time()
hr_pixel = self.autoencoder.decode(hr_latent)
lr_pixel = self.autoencoder.decode(lr_latent)
sr_pixel = self.autoencoder.decode(x0_final) # Final SR output after full sampling
decode_time = time.time() - decode_start
total_decode_time += decode_time
# Convert to [0, 1] range if needed
hr_pixel = hr_pixel.clamp(0, 1)
sr_pixel = sr_pixel.clamp(0, 1)
# Compute metrics using simple functions
metric_start = time.time()
batch_psnr = compute_psnr(hr_pixel, sr_pixel)
total_psnr += batch_psnr * hr_latent.shape[0]
batch_ssim = compute_ssim(hr_pixel, sr_pixel)
total_ssim += batch_ssim * hr_latent.shape[0]
batch_lpips = compute_lpips(hr_pixel, sr_pixel, self.lpips_model)
total_lpips += batch_lpips * hr_latent.shape[0]
metric_time = time.time() - metric_start
total_metric_time += metric_time
num_samples += hr_latent.shape[0]
batch_time = time.time() - batch_start
# Print timing for first batch
if batch_idx == 0:
print(f"\nValidation Batch 0 Timing:")
print(f" - Sampling loop: {sampling_time:.3f}s ({sampling_forward_time:.3f}s forward)")
print(f" - Decoding: {decode_time:.3f}s")
print(f" - Metrics: {metric_time:.3f}s")
print(f" - Total batch: {batch_time:.3f}s")
# Log validation images periodically
if batch_idx == 0:
wandb.log({
'val/hr_sample': wandb.Image(hr_pixel[0].cpu()),
'val/lr_sample': wandb.Image(lr_pixel[0].cpu()),
'val/sr_sample': wandb.Image(sr_pixel[0].cpu()),
'step': self.current_iters
})
# Compute average metrics and timing
val_total_time = time.time() - val_start
num_batches = batch_idx + 1
if num_samples > 0:
mean_psnr = total_psnr / num_samples
mean_ssim = total_ssim / num_samples
mean_lpips = total_lpips / num_samples
mean_val_loss = total_val_loss / num_samples
avg_sampling_time = total_sampling_time / num_batches
avg_forward_time = total_forward_time / num_batches
avg_decode_time = total_decode_time / num_batches
avg_metric_time = total_metric_time / num_batches
avg_batch_time = val_total_time / num_batches
print(f"\nValidation Metrics:")
print(f" - Loss: {mean_val_loss:.6f}")
print(f" - PSNR: {mean_psnr:.2f} dB")
print(f" - SSIM: {mean_ssim:.4f}")
print(f" - LPIPS: {mean_lpips:.4f}")
print(f"\nValidation Timing (Total: {val_total_time:.2f}s, {num_batches} batches):")
print(f" - Avg sampling loop: {avg_sampling_time:.3f}s/batch ({avg_forward_time:.3f}s forward)")
print(f" - Avg decoding: {avg_decode_time:.3f}s/batch")
print(f" - Avg metrics: {avg_metric_time:.3f}s/batch")
print(f" - Avg batch time: {avg_batch_time:.3f}s/batch")
wandb.log({
'val/loss': mean_val_loss,
'val/psnr': mean_psnr,
'val/ssim': mean_ssim,
'val/lpips': mean_lpips,
'val/total_time': val_total_time,
'val/avg_sampling_time': avg_sampling_time,
'val/avg_forward_time': avg_forward_time,
'val/avg_decode_time': avg_decode_time,
'val/avg_metric_time': avg_metric_time,
'val/avg_batch_time': avg_batch_time,
'val/num_batches': num_batches,
'val/num_samples': num_samples,
'step': self.current_iters
})
print("=" * 100)
# Set model back to training mode
self.model.train()
if self.ema_model is not None:
self.ema_model.train() # Keep in sync, but won't be used for training
# Note: Main training script is in train.py
# This file contains the Trainer class implementation
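# Typical wiring (sketch only; the actual entry point in train.py may differ):
#   trainer = Trainer(resume_ckpt=None)
#   trainer.build_model()
#   if trainer.resume_ckpt is not None:
#       trainer.resume_from_ckpt(trainer.resume_ckpt)
#   trainer.setup_optimization()
#   trainer.build_dataloader()
#   while trainer.current_iters < iterations:
#       hr_latent, lr_latent = next(trainer.dataloaders['train'])
#       trainer.current_iters += 1
#       trainer.adjust_lr()
#       loss, timing = trainer.training_step(hr_latent, lr_latent)
#       trainer._last_timing = timing
#       trainer.log_step_train(loss, hr_latent, lr_latent, None, None)
#       if trainer.current_iters % val_freq == 0:
#           trainer.validation()
#       if trainer.current_iters % save_freq == 0:
#           trainer.save_ckpt()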