import torch
import torch.nn as nn
import random
import numpy as np
from model import FullUNET
from noiseControl import resshift_schedule
from torch.utils.data import DataLoader
from data import mini_dataset, train_dataset, valid_dataset, get_vqgan_model
import torch.optim as optim
from config import (
    batch_size, device, learning_rate, iterations,
    weight_decay, T, k, _project_root, num_workers,
    use_amp, lr, lr_min, lr_schedule, warmup_iterations,
    compile_flag, compile_mode, batch, prefetch_factor, microbatch,
    save_freq, log_freq, val_freq, val_y_channel, ema_rate, use_ema_val,
    seed, global_seeding, normalize_input, latent_flag
)
import wandb
import os
import math
import time
from pathlib import Path
from itertools import cycle
from contextlib import nullcontext
from dotenv import load_dotenv
from metrics import compute_psnr, compute_ssim, compute_lpips
from ema import EMA
import lpips


class Trainer:
    """
    Modular trainer class following the original ResShift trainer structure.
    """

    def __init__(self, save_dir=None, resume_ckpt=None):
        """
        Initialize trainer with config values.

        Args:
            save_dir: Directory to save checkpoints (defaults to _project_root / 'checkpoints')
            resume_ckpt: Path to checkpoint file to resume from (optional)
        """
        self.device = device
        self.current_iters = 0
        self.iters_start = 0
        self.resume_ckpt = resume_ckpt

        # Setup checkpoint directory
        if save_dir is None:
            save_dir = _project_root / 'checkpoints'
        self.save_dir = Path(save_dir)
        self.ckpt_dir = self.save_dir / 'ckpts'
        self.ckpt_dir.mkdir(parents=True, exist_ok=True)

        # Initialize noise schedule (eta values for ResShift)
        self.eta = resshift_schedule().to(self.device)
        self.eta = self.eta[:, None, None, None]  # shape (T, 1, 1, 1)
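        # The eta schedule parameterizes the ResShift forward process used in
        # training_step(): with residual e0 = y0 - x0 (LR latent minus HR latent),
        #   q(x_t | x_0, y_0) = N(x_t; x_0 + eta_t * e0, k^2 * eta_t * I),
        # i.e. eta_t controls how much of the LR residual and noise is mixed in at step t.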

        # Loss criterion
        self.criterion = nn.MSELoss()

        # Timing for checkpoint saving
        self.tic = None

        # EMA will be initialized after model is built
        self.ema = None
        self.ema_model = None

        # Set random seeds for reproducibility
        self.setup_seed()

        # Initialize WandB
        self.init_wandb()

    def setup_seed(self, seed_val=None, global_seeding_val=None):
        """
        Set random seeds for reproducibility.

        Sets seeds for:
            - Python random module
            - NumPy
            - PyTorch (CPU and CUDA)

        Args:
            seed_val: Seed value (defaults to config.seed)
            global_seeding_val: Whether to use global seeding (defaults to config.global_seeding)
        """
        if seed_val is None:
            seed_val = seed
        if global_seeding_val is None:
            global_seeding_val = global_seeding

        # Set Python random seed
        random.seed(seed_val)
        # Set NumPy random seed
        np.random.seed(seed_val)
        # Set PyTorch random seed
        torch.manual_seed(seed_val)

        # Set CUDA random seeds (if available)
        if torch.cuda.is_available():
            if global_seeding_val:
                torch.cuda.manual_seed_all(seed_val)
            else:
                torch.cuda.manual_seed(seed_val)
                # For multi-GPU, each GPU would get seed + rank (not implemented here)

        # Make deterministic (may impact performance)
        # torch.backends.cudnn.deterministic = True
        # torch.backends.cudnn.benchmark = False

        print(f"✓ Random seeds set: seed={seed_val}, global_seeding={global_seeding_val}")

    def init_wandb(self):
        """Initialize WandB logging."""
        load_dotenv(os.path.join(_project_root, '.env'))
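        # Assumes the .env file (or the environment) provides the WandB credentials,
        # e.g. WANDB_API_KEY, which wandb.init() picks up automatically.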
        wandb.init(
            project="diffusionsr",
            name="reshift_training",
            config={
                "learning_rate": learning_rate,
                "batch_size": batch_size,
                "steps": iterations,
                "model": "ResShift",
                "T": T,
                "k": k,
                "optimizer": "AdamW" if weight_decay > 0 else "Adam",
                "betas": (0.9, 0.999),
                "grad_clip": 1.0,
                "criterion": "MSE",
                "device": str(device),
                "training_space": "latent_64x64",
                "use_amp": use_amp,
                "ema_rate": ema_rate if ema_rate > 0 else None,
            }
        )

    def setup_optimization(self):
        """
        Component 1: Setup optimizer and AMP scaler.

        Sets up:
            - Optimizer (AdamW with weight decay or Adam)
            - AMP GradScaler if use_amp is True
        """
        # Use AdamW if weight_decay > 0, otherwise Adam
        if weight_decay > 0:
            self.optimizer = optim.AdamW(
                self.model.parameters(),
                lr=learning_rate,
                weight_decay=weight_decay,
                betas=(0.9, 0.999)
            )
        else:
            self.optimizer = optim.Adam(
                self.model.parameters(),
                lr=learning_rate,
                weight_decay=weight_decay,
                betas=(0.9, 0.999)
            )

        # AMP settings: Create GradScaler if use_amp is True and CUDA is available
        if use_amp and torch.cuda.is_available():
            self.amp_scaler = torch.amp.GradScaler('cuda')
        else:
            self.amp_scaler = None
        if use_amp and not torch.cuda.is_available():
            print(" ⚠ Warning: AMP requested but CUDA not available. Disabling AMP.")

        # Learning rate scheduler (cosine annealing after warmup)
        self.lr_scheduler = None
        if lr_schedule == 'cosin':
            self.lr_scheduler = optim.lr_scheduler.CosineAnnealingLR(
                optimizer=self.optimizer,
                T_max=iterations - warmup_iterations,
                eta_min=lr_min
            )
            print(f" - LR scheduler: CosineAnnealingLR (T_max={iterations - warmup_iterations}, eta_min={lr_min})")

        # Load pending optimizer state if resuming
        if hasattr(self, '_pending_optimizer_state'):
            self.optimizer.load_state_dict(self._pending_optimizer_state)
            print(" - Loaded optimizer state from checkpoint")
            delattr(self, '_pending_optimizer_state')

        # Load pending LR scheduler state if resuming
        if hasattr(self, '_pending_lr_scheduler_state') and self.lr_scheduler is not None:
            self.lr_scheduler.load_state_dict(self._pending_lr_scheduler_state)
            print(" - Loaded LR scheduler state from checkpoint")
            delattr(self, '_pending_lr_scheduler_state')

        # Restore LR schedule by replaying adjust_lr for all previous iterations
        # This ensures the LR is at the correct value for the resumed iteration
        if hasattr(self, '_resume_iters') and self._resume_iters > 0:
            print(f" - Restoring learning rate schedule to iteration {self._resume_iters}...")
            for ii in range(1, self._resume_iters + 1):
                self.adjust_lr(ii)
            print(" - ✓ Learning rate schedule restored")
            delattr(self, '_resume_iters')

        print("✓ Setup optimization:")
        print(f" - Optimizer: {type(self.optimizer).__name__}")
        print(f" - Learning rate: {learning_rate}")
        print(f" - Weight decay: {weight_decay}")
        print(f" - Warmup iterations: {warmup_iterations}")
        print(f" - LR schedule: {lr_schedule if lr_schedule else 'None (fixed LR)'}")
        print(f" - AMP enabled: {use_amp} ({'GradScaler active' if self.amp_scaler else 'disabled'})")

    def build_model(self):
        """
        Component 2: Build model and autoencoder (VQGAN).

        Sets up:
            - FullUNET model
            - Model compilation (optional)
            - VQGAN autoencoder for encoding/decoding
            - Model info printing
        """
        # Build main model
        print("Building FullUNET model...")
        self.model = FullUNET()
        self.model = self.model.to(self.device)

        # Optional: Compile model for optimization
        # Model compilation can provide 20-30% speedup on modern GPUs
        # but requires PyTorch 2.0+ and may have compatibility issues
        self.model_compiled = False
        if compile_flag:
            try:
                print(f"Compiling model with mode: {compile_mode}...")
                self.model = torch.compile(self.model, mode=compile_mode)
                self.model_compiled = True
                print("✓ Model compilation done")
            except Exception as e:
                print(f"⚠ Warning: Model compilation failed: {e}")
                print(" Continuing without compilation...")
                self.model_compiled = False

        # Load VQGAN autoencoder
        print("Loading VQGAN autoencoder...")
        self.autoencoder = get_vqgan_model()
        print("✓ VQGAN autoencoder loaded")

        # Initialize LPIPS model for validation
        print("Loading LPIPS metric...")
        self.lpips_model = lpips.LPIPS(net='vgg').to(self.device)
        for params in self.lpips_model.parameters():
            params.requires_grad_(False)
        self.lpips_model.eval()
        print("✓ LPIPS metric loaded")

        # Initialize EMA if enabled
        if ema_rate > 0:
            print(f"Initializing EMA with rate: {ema_rate}...")
            self.ema = EMA(self.model, ema_rate=ema_rate, device=self.device)
            # Add Swin Transformer relative position index to ignore keys
            self.ema.add_ignore_key('relative_position_index')
            print("✓ EMA initialized")
        else:
            print("⚠ EMA disabled (ema_rate = 0)")

        # Print model information
        self.print_model_info()

    def print_model_info(self):
        """Print model parameter count and architecture info."""
        # Count parameters
        total_params = sum(p.numel() for p in self.model.parameters())
        trainable_params = sum(p.numel() for p in self.model.parameters() if p.requires_grad)

        print("\n✓ Model built successfully:")
        print(" - Model: FullUNET")
        print(f" - Total parameters: {total_params / 1e6:.2f}M")
        print(f" - Trainable parameters: {trainable_params / 1e6:.2f}M")
        print(f" - Device: {self.device}")
        print(f" - Compiled: {'Yes' if getattr(self, 'model_compiled', False) else 'No'}")
        if self.autoencoder is not None:
            print(" - Autoencoder: VQGAN (loaded)")

    def build_dataloader(self):
        """
        Component 3: Build train and validation dataloaders.

        Sets up:
            - Train dataloader with infinite cycle wrapper
            - Validation dataloader (if validation dataset exists)
            - Proper batch sizes, num_workers, pin_memory, etc.
        """
        def _wrap_loader(loader):
            """Wrap dataloader to cycle infinitely."""
            while True:
                yield from loader
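
        # Note: the wrapped train loader is a generator, so the training loop is
        # expected to pull batches with next(self.dataloaders['train']) rather than
        # iterating over epochs.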
        # Create datasets dictionary
        datasets = {'train': train_dataset}
        if valid_dataset is not None:
            datasets['val'] = valid_dataset

        # Print dataset sizes
        for phase, dataset in datasets.items():
            print(f" - {phase.capitalize()} dataset: {len(dataset)} images")

        # Create train dataloader
        train_batch_size = batch[0]  # Use first value from batch list
        train_loader = DataLoader(
            datasets['train'],
            batch_size=train_batch_size,
            shuffle=True,
            drop_last=True,  # Drop last incomplete batch
            num_workers=min(num_workers, 4),  # Limit num_workers
            pin_memory=torch.cuda.is_available(),
            prefetch_factor=prefetch_factor if num_workers > 0 else None,
        )

        # Wrap train loader to cycle infinitely
        self.dataloaders = {'train': _wrap_loader(train_loader)}

        # Create validation dataloader if validation dataset exists
        if 'val' in datasets:
            val_batch_size = batch[1] if len(batch) > 1 else batch[0]  # Use second value or fallback
            val_loader = DataLoader(
                datasets['val'],
                batch_size=val_batch_size,
                shuffle=False,
                drop_last=False,  # Don't drop last batch in validation
                num_workers=0,  # No multiprocessing for validation (safer)
                pin_memory=torch.cuda.is_available(),
            )
            self.dataloaders['val'] = val_loader

        # Store datasets
        self.datasets = datasets

        print("\n✓ Dataloaders built:")
        print(f" - Train batch size: {train_batch_size}")
        print(f" - Train num_workers: {min(num_workers, 4)}")
        print(" - Train drop_last: True")
        if 'val' in self.dataloaders:
            print(f" - Val batch size: {val_batch_size}")
            print(" - Val num_workers: 0")

    def backward_step(self, loss, num_grad_accumulate=1):
        """
        Component 4: Handle backward pass with AMP support and gradient accumulation.

        Args:
            loss: The computed loss tensor
            num_grad_accumulate: Number of gradient accumulation steps (for micro-batching)

        Returns:
            loss: The loss tensor (for logging)
        """
        # Normalize loss by gradient accumulation steps
        loss = loss / num_grad_accumulate

        # Backward pass: use AMP scaler if available, otherwise direct backward
        if self.amp_scaler is None:
            loss.backward()
        else:
            self.amp_scaler.scale(loss).backward()
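
        # Note: this method only accumulates gradients; gradient clipping, optimizer.step(),
        # and (for AMP) scaler.step()/update() are left to the caller (see training_step()).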
        return loss

    def _scale_input(self, x_t, t):
        """
        Scale input based on timestep for training stability.
        Matches original GaussianDiffusion._scale_input for latent space.

        For latent space: std = sqrt(etas[t] * kappa^2 + 1)
        This normalizes the input variance across different timesteps.

        Args:
            x_t: Noisy input tensor (B, C, H, W)
            t: Timestep tensor (B,)

        Returns:
            x_t_scaled: Scaled input tensor (B, C, H, W)
        """
        if normalize_input and latent_flag:
            # For latent space: std = sqrt(etas[t] * kappa^2 + 1)
            # Extract eta_t for each sample in batch
            eta_t = self.eta[t]  # (B, 1, 1, 1)
            std = torch.sqrt(eta_t * k**2 + 1)
            x_t_scaled = x_t / std
        else:
            x_t_scaled = x_t
        return x_t_scaled

    def training_step(self, hr_latent, lr_latent):
        """
        Component 5: Main training step with micro-batching and gradient accumulation.

        Args:
            hr_latent: High-resolution latent tensor (B, C, 64, 64)
            lr_latent: Low-resolution latent tensor (B, C, 64, 64)

        Returns:
            loss: Average loss value for logging
            timing_dict: Dictionary with timing information
        """
        step_start = time.time()
        self.model.train()

        current_batchsize = hr_latent.shape[0]
        micro_batchsize = microbatch
        num_grad_accumulate = math.ceil(current_batchsize / micro_batchsize)

        total_loss = 0.0
        forward_time = 0.0
        backward_time = 0.0

        # Process in micro-batches for gradient accumulation
        for jj in range(0, current_batchsize, micro_batchsize):
            # Extract micro-batch
            end_idx = min(jj + micro_batchsize, current_batchsize)
            hr_micro = hr_latent[jj:end_idx].to(self.device)
            lr_micro = lr_latent[jj:end_idx].to(self.device)
            last_batch = (end_idx >= current_batchsize)

            # Compute residual in latent space
            residual = (lr_micro - hr_micro)

            # Generate random timesteps for each sample in micro-batch
            t = torch.randint(0, T, (hr_micro.shape[0],)).to(self.device)

            # Add noise in latent space (ResShift noise schedule)
            epsilon = torch.randn_like(hr_micro)  # Noise in latent space
            eta_t = self.eta[t]  # (B, 1, 1, 1)
            x_t = hr_micro + eta_t * residual + k * torch.sqrt(eta_t) * epsilon

            # Forward pass with autocast if AMP is enabled
            forward_start = time.time()
            if use_amp and torch.cuda.is_available():
                context = torch.amp.autocast('cuda')
            else:
                context = nullcontext()
            with context:
                # Scale input for training stability (normalize variance across timesteps)
                x_t_scaled = self._scale_input(x_t, t)
                # Forward pass: Model predicts x0 (clean HR latent), not noise
                # ResShift uses predict_type = "xstart"
                x0_pred = self.model(x_t_scaled, t, lq=lr_micro)
                # Loss: Compare predicted x0 with ground truth HR latent
                loss = self.criterion(x0_pred, hr_micro)
            forward_time += time.time() - forward_start

            # Store loss value for logging (before dividing for gradient accumulation)
            total_loss += loss.item()

            # Backward step (handles gradient accumulation and AMP)
            backward_start = time.time()
            self.backward_step(loss, num_grad_accumulate)
            backward_time += time.time() - backward_start

        # Gradient clipping before optimizer step
        if self.amp_scaler is None:
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
            self.optimizer.step()
        else:
            # Unscale gradients before clipping when using AMP
            self.amp_scaler.unscale_(self.optimizer)
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
            self.amp_scaler.step(self.optimizer)
            self.amp_scaler.update()

        # Zero gradients
        self.model.zero_grad()

        # Update EMA after optimizer step
        if self.ema is not None:
            self.ema.update(self.model)

        # Compute total step time
        step_time = time.time() - step_start

        # Return average loss (average across micro-batches)
        num_micro_batches = math.ceil(current_batchsize / micro_batchsize)
        avg_loss = total_loss / num_micro_batches if num_micro_batches > 0 else total_loss

        # Return timing information
        timing_dict = {
            'step_time': step_time,
            'forward_time': forward_time,
            'backward_time': backward_time,
            'num_micro_batches': num_micro_batches
        }
        return avg_loss, timing_dict

    def adjust_lr(self, current_iters=None):
        """
        Component 6: Adjust learning rate with warmup and optional cosine annealing.

        Learning rate schedule:
            - Warmup phase (iters <= warmup_iterations): Linear increase from 0 to base_lr
            - After warmup: Use cosine annealing scheduler if lr_schedule == 'cosin', else keep base_lr

        Args:
            current_iters: Current iteration number (defaults to self.current_iters)
        """
        base_lr = learning_rate
        warmup_steps = warmup_iterations
        current_iters = self.current_iters if current_iters is None else current_iters

        if current_iters <= warmup_steps:
            # Warmup phase: linear increase from 0 to base_lr
            warmup_lr = (current_iters / warmup_steps) * base_lr
            for params_group in self.optimizer.param_groups:
                params_group['lr'] = warmup_lr
        else:
            # After warmup: use scheduler if available
            if self.lr_scheduler is not None:
                self.lr_scheduler.step()
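
        # adjust_lr() is expected to be called once per training iteration; on resume,
        # setup_optimization() replays it for all past iterations to restore the schedule.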

    def save_ckpt(self):
        """
        Component 7: Save checkpoint with model state, optimizer state, and training info.

        Saves:
            - Model state dict
            - Optimizer state dict
            - Current iteration number
            - AMP scaler state (if AMP is enabled)
            - LR scheduler state (if scheduler exists)
        """
        ckpt_path = self.ckpt_dir / f'model_{self.current_iters}.pth'

        # Prepare checkpoint dictionary
        ckpt = {
            'iters_start': self.current_iters,
            'state_dict': self.model.state_dict(),
        }

        # Add optimizer state if available
        if hasattr(self, 'optimizer'):
            ckpt['optimizer'] = self.optimizer.state_dict()

        # Add AMP scaler state if available
        if self.amp_scaler is not None:
            ckpt['amp_scaler'] = self.amp_scaler.state_dict()

        # Add LR scheduler state if available
        if self.lr_scheduler is not None:
            ckpt['lr_scheduler'] = self.lr_scheduler.state_dict()

        # Save checkpoint
        torch.save(ckpt, ckpt_path)
        print(f"✓ Checkpoint saved: {ckpt_path}")
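
        # Keep this naming in sync with resume_from_ckpt(), which looks for a matching
        # ema_model_{iters}.pth next to model_{iters}.pth when restoring EMA weights.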
        # Save EMA checkpoint separately if EMA is enabled
        if self.ema is not None:
            ema_ckpt_path = self.ckpt_dir / f'ema_model_{self.current_iters}.pth'
            torch.save(self.ema.state_dict(), ema_ckpt_path)
            print(f"✓ EMA checkpoint saved: {ema_ckpt_path}")

        return ckpt_path

    def resume_from_ckpt(self, ckpt_path):
        """
        Resume training from a checkpoint.

        Loads:
            - Model state dict
            - Optimizer state dict
            - AMP scaler state (if AMP is enabled)
            - LR scheduler state (if scheduler exists)
            - Current iteration number
            - Restores LR schedule by replaying adjust_lr for previous iterations

        Args:
            ckpt_path: Path to checkpoint file (.pth)
        """
        if not os.path.isfile(ckpt_path):
            raise FileNotFoundError(f"Checkpoint file not found: {ckpt_path}")
        # Accept both str and Path arguments
        if not str(ckpt_path).endswith('.pth'):
            raise ValueError(f"Checkpoint file must have .pth extension: {ckpt_path}")

        print(f"\n{'=' * 100}")
        print(f"Resuming from checkpoint: {ckpt_path}")
        print(f"{'=' * 100}")

        # Load checkpoint
        ckpt = torch.load(ckpt_path, map_location=self.device)

        # Load model state dict
        if 'state_dict' in ckpt:
            self.model.load_state_dict(ckpt['state_dict'])
            print("✓ Loaded model state dict")
        else:
            # If checkpoint is just the state dict
            self.model.load_state_dict(ckpt)
            print("✓ Loaded model state dict (direct)")

        # Load optimizer state dict (must be done after optimizer is created)
        if 'optimizer' in ckpt:
            if hasattr(self, 'optimizer'):
                self.optimizer.load_state_dict(ckpt['optimizer'])
                print("✓ Loaded optimizer state dict")
            else:
                print("⚠ Warning: Optimizer state found in checkpoint but optimizer not yet created.")
                print(" Optimizer will be loaded after setup_optimization() is called.")
                self._pending_optimizer_state = ckpt['optimizer']

        # Load AMP scaler state
        if 'amp_scaler' in ckpt:
            if hasattr(self, 'amp_scaler') and self.amp_scaler is not None:
                self.amp_scaler.load_state_dict(ckpt['amp_scaler'])
                print("✓ Loaded AMP scaler state")
            else:
                print("⚠ Warning: AMP scaler state found but AMP not enabled or scaler not created.")

        # Load LR scheduler state
        if 'lr_scheduler' in ckpt:
            if hasattr(self, 'lr_scheduler') and self.lr_scheduler is not None:
                self.lr_scheduler.load_state_dict(ckpt['lr_scheduler'])
                print("✓ Loaded LR scheduler state")
            else:
                print("⚠ Warning: LR scheduler state found but scheduler not yet created.")
                self._pending_lr_scheduler_state = ckpt['lr_scheduler']

        # Load EMA state if available (must be done after EMA is initialized)
        # EMA checkpoint naming: ema_model_{iters}.pth (matches save pattern)
        ckpt_path_obj = Path(ckpt_path)
        # Extract iteration number from checkpoint name (e.g., "model_10000.pth" -> "10000")
        if 'iters_start' in ckpt:
            iters = ckpt['iters_start']
            ema_ckpt_path = ckpt_path_obj.parent / f"ema_model_{iters}.pth"
        else:
            # Fallback: try to extract from filename
            try:
                iters = int(ckpt_path_obj.stem.split('_')[-1])
                ema_ckpt_path = ckpt_path_obj.parent / f"ema_model_{iters}.pth"
            except (ValueError, IndexError):
                ema_ckpt_path = None

        if ema_ckpt_path is not None and ema_ckpt_path.exists() and self.ema is not None:
            ema_ckpt = torch.load(ema_ckpt_path, map_location=self.device)
            self.ema.load_state_dict(ema_ckpt)
            print(f"✓ Loaded EMA state from: {ema_ckpt_path}")
        elif ema_ckpt_path is not None and ema_ckpt_path.exists() and self.ema is None:
            print("⚠ Warning: EMA checkpoint found but EMA not enabled. Skipping EMA load.")
        elif self.ema is not None:
            print("⚠ Warning: EMA enabled but no EMA checkpoint found. Starting with fresh EMA.")

        # Restore iteration number
        if 'iters_start' in ckpt:
            self.iters_start = ckpt['iters_start']
            self.current_iters = ckpt['iters_start']
            print(f"✓ Resuming from iteration: {self.iters_start}")
        else:
            print("⚠ Warning: No iteration number found in checkpoint. Starting from 0.")
            self.iters_start = 0
            self.current_iters = 0

        # Note: LR schedule restoration will be done after setup_optimization()
        # Store the iteration number for later restoration
        self._resume_iters = self.iters_start

        print(f"{'=' * 100}\n")

        # Clear CUDA cache
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    def log_step_train(self, loss, hr_latent, lr_latent, x_t, pred, phase='train'):
        """
        Component 8: Log training metrics and images to WandB.

        Logs:
            - Loss and learning rate (at log_freq[0] intervals)
            - Training images: HR, LR, and predictions (at log_freq[1] intervals)
            - Elapsed time for checkpoint intervals

        Args:
            loss: Training loss value (float)
            hr_latent: High-resolution latent tensor (B, C, 64, 64)
            lr_latent: Low-resolution latent tensor (B, C, 64, 64)
            x_t: Noisy input tensor (B, C, 64, 64)
            pred: Model prediction (x0_pred - clean HR latent) (B, C, 64, 64)
            phase: Training phase ('train' or 'val')
        """
        # Log loss and learning rate at log_freq[0] intervals
        if self.current_iters % log_freq[0] == 0:
            current_lr = self.optimizer.param_groups[0]['lr']

            # Get timing info if available (passed from training_step)
            timing_info = {}
            if hasattr(self, '_last_timing'):
                step_time = self._last_timing.get('step_time', 0)
                timing_info = {
                    'train/step_time': step_time,
                    'train/forward_time': self._last_timing.get('forward_time', 0),
                    'train/backward_time': self._last_timing.get('backward_time', 0),
                    'train/iterations_per_sec': 1.0 / step_time if step_time > 0 else 0
                }

            wandb.log({
                'loss': loss,
                'learning_rate': current_lr,
                'step': self.current_iters,
                **timing_info
            })

            # Print to console
            timing_str = ""
            if hasattr(self, '_last_timing') and self._last_timing.get('step_time', 0) > 0:
                timing_str = (f", Step: {self._last_timing['step_time']:.3f}s, "
                              f"Forward: {self._last_timing['forward_time']:.3f}s, "
                              f"Backward: {self._last_timing['backward_time']:.3f}s")
            print(f"Train: {self.current_iters:06d}/{iterations:06d}, "
                  f"Loss: {loss:.6f}, LR: {current_lr:.2e}{timing_str}")

        # Log images at log_freq[1] intervals (only if x_t and pred are provided)
        if self.current_iters % log_freq[1] == 0 and x_t is not None and pred is not None:
            with torch.no_grad():
                # Decode latents to pixel space for visualization
                # Take first sample from batch
                hr_pixel = self.autoencoder.decode(hr_latent[0:1])  # (1, 3, 256, 256)
                lr_pixel = self.autoencoder.decode(lr_latent[0:1])  # (1, 3, 256, 256)
                # Decode noisy input for visualization
                x_t_pixel = self.autoencoder.decode(x_t[0:1])  # (1, 3, 256, 256)
                # Decode predicted x0 (clean HR latent) for visualization
                pred_pixel = self.autoencoder.decode(pred[0:1])  # (1, 3, 256, 256)

                # Log images to WandB
                wandb.log({
                    f'{phase}/hr_sample': wandb.Image(hr_pixel[0].cpu().clamp(0, 1)),
                    f'{phase}/lr_sample': wandb.Image(lr_pixel[0].cpu().clamp(0, 1)),
                    f'{phase}/noisy_input': wandb.Image(x_t_pixel[0].cpu().clamp(0, 1)),
                    f'{phase}/pred_sample': wandb.Image(pred_pixel[0].cpu().clamp(0, 1)),
                    'step': self.current_iters
                })

        # Track elapsed time for checkpoint intervals
        if self.current_iters % save_freq == 1:
            self.tic = time.time()
        if self.current_iters % save_freq == 0 and self.tic is not None:
            self.toc = time.time()
            elapsed = self.toc - self.tic
            print(f"Elapsed time for {save_freq} iterations: {elapsed:.2f}s")
            print("=" * 100)

    def validation(self):
        """
        Run validation on validation dataset with full diffusion sampling loop.

        Performs iterative denoising from t = T-1 down to t = 0, matching the
        original ResShift implementation. This is slower but more accurate than
        single-step prediction.

        Computes:
            - PSNR, SSIM, and LPIPS metrics
            - Logs validation images to WandB
        """
        if 'val' not in self.dataloaders:
            print("No validation dataset available. Skipping validation.")
            return

        print("\n" + "=" * 100)
        print("Running Validation")
        print("=" * 100)

        val_start = time.time()

        # Use EMA model for validation if enabled
        if use_ema_val and self.ema is not None:
            # Create EMA model copy if it doesn't exist
            if self.ema_model is None:
                from copy import deepcopy
                self.ema_model = deepcopy(self.model)
            # Load EMA state into EMA model
            self.ema.apply_to_model(self.ema_model)
            self.ema_model.eval()
            val_model = self.ema_model
            print("Using EMA model for validation")
        else:
            self.model.eval()
            val_model = self.model
            if use_ema_val and self.ema is None:
                print("⚠ Warning: use_ema_val=True but EMA not enabled. Using regular model.")

        val_iter = iter(self.dataloaders['val'])

        total_psnr = 0.0
        total_ssim = 0.0
        total_lpips = 0.0
        total_val_loss = 0.0
        num_samples = 0

        total_sampling_time = 0.0
        total_forward_time = 0.0
        total_decode_time = 0.0
        total_metric_time = 0.0

        with torch.no_grad():
            for batch_idx, (hr_latent, lr_latent) in enumerate(val_iter):
                batch_start = time.time()

                # Move to device
                hr_latent = hr_latent.to(self.device)
                lr_latent = lr_latent.to(self.device)

                # Full diffusion sampling loop (iterative denoising)
                # Start from maximum timestep and iterate backwards: T-1 → T-2 → ... → 1 → 0
                sampling_start = time.time()

                # Initialize x_t at maximum timestep (T-1)
                # Start from LR with maximum noise (prior_sample: x_T = y + kappa * sqrt(eta_T) * noise)
                epsilon_init = torch.randn_like(lr_latent)
                eta_max = self.eta[T - 1]
                x_t = lr_latent + k * torch.sqrt(eta_max) * epsilon_init

                # Track forward pass time during sampling
                sampling_forward_time = 0.0

                # Iterative sampling: denoise from t = T-1 down to t = 0
                for t_step in range(T - 1, -1, -1):  # T-1, T-2, ..., 1, 0
                    t = torch.full((hr_latent.shape[0],), t_step, device=self.device, dtype=torch.long)

                    # Scale input for training stability (normalize variance across timesteps)
                    x_t_scaled = self._scale_input(x_t, t)

                    # Predict x0 from current noisy state x_t
                    forward_start = time.time()
                    x0_pred = val_model(x_t_scaled, t, lq=lr_latent)
                    sampling_forward_time += time.time() - forward_start

                    # If not the last step, compute x_{t-1} from predicted x0 using equation (7)
                    if t_step > 0:
                        # Equation (7) from ResShift paper:
                        #   μ_θ = (η_{t-1}/η_t) * x_t + (α_t/η_t) * f_θ(x_t, y_0, t)
                        #   Σ_θ = κ² * (η_{t-1}/η_t) * α_t
                        #   x_{t-1} = μ_θ + sqrt(Σ_θ) * ε
                        eta_t = self.eta[t_step]
                        eta_t_minus_1 = self.eta[t_step - 1]
                        # Compute alpha_t = η_t - η_{t-1}
                        alpha_t = eta_t - eta_t_minus_1
                        # Compute mean: μ_θ = (η_{t-1}/η_t) * x_t + (α_t/η_t) * x0_pred
                        mean = (eta_t_minus_1 / eta_t) * x_t + (alpha_t / eta_t) * x0_pred
                        # Compute variance: Σ_θ = κ² * (η_{t-1}/η_t) * α_t
                        variance = k**2 * (eta_t_minus_1 / eta_t) * alpha_t
                        # Sample: x_{t-1} = μ_θ + sqrt(Σ_θ) * ε
                        # (noise is only injected for t > 0, which this branch guarantees)
                        noise = torch.randn_like(x_t)
                        x_t = mean + torch.sqrt(variance) * noise
                    else:
                        # Final step: use predicted x0 as final output
                        x_t = x0_pred

                # Final prediction after full sampling loop
                x0_final = x_t

                # Compute validation loss (MSE in latent space, same as training loss)
                val_loss = self.criterion(x0_final, hr_latent).item()
                total_val_loss += val_loss * hr_latent.shape[0]

                sampling_time = time.time() - sampling_start
                total_sampling_time += sampling_time
                total_forward_time += sampling_forward_time

                # Decode latents to pixel space for metrics and visualization
                decode_start = time.time()
                hr_pixel = self.autoencoder.decode(hr_latent)
                lr_pixel = self.autoencoder.decode(lr_latent)
                sr_pixel = self.autoencoder.decode(x0_final)  # Final SR output after full sampling
                decode_time = time.time() - decode_start
                total_decode_time += decode_time

                # Convert to [0, 1] range if needed
                hr_pixel = hr_pixel.clamp(0, 1)
                sr_pixel = sr_pixel.clamp(0, 1)

                # Compute metrics using simple functions
                metric_start = time.time()
                batch_psnr = compute_psnr(hr_pixel, sr_pixel)
                total_psnr += batch_psnr * hr_latent.shape[0]

                batch_ssim = compute_ssim(hr_pixel, sr_pixel)
                total_ssim += batch_ssim * hr_latent.shape[0]

                batch_lpips = compute_lpips(hr_pixel, sr_pixel, self.lpips_model)
                total_lpips += batch_lpips * hr_latent.shape[0]
                metric_time = time.time() - metric_start
                total_metric_time += metric_time

                num_samples += hr_latent.shape[0]
                batch_time = time.time() - batch_start

                # Print timing for first batch
                if batch_idx == 0:
                    print("\nValidation Batch 0 Timing:")
                    print(f" - Sampling loop: {sampling_time:.3f}s ({sampling_forward_time:.3f}s forward)")
                    print(f" - Decoding: {decode_time:.3f}s")
                    print(f" - Metrics: {metric_time:.3f}s")
                    print(f" - Total batch: {batch_time:.3f}s")

                # Log validation images periodically
                if batch_idx == 0:
                    wandb.log({
                        'val/hr_sample': wandb.Image(hr_pixel[0].cpu()),
                        'val/lr_sample': wandb.Image(lr_pixel[0].cpu()),
                        'val/sr_sample': wandb.Image(sr_pixel[0].cpu()),
                        'step': self.current_iters
                    })

        # Compute average metrics and timing
        val_total_time = time.time() - val_start

        if num_samples > 0:
            num_batches = batch_idx + 1
            mean_psnr = total_psnr / num_samples
            mean_ssim = total_ssim / num_samples
            mean_lpips = total_lpips / num_samples
            mean_val_loss = total_val_loss / num_samples

            avg_sampling_time = total_sampling_time / num_batches
            avg_forward_time = total_forward_time / num_batches
            avg_decode_time = total_decode_time / num_batches
            avg_metric_time = total_metric_time / num_batches
            avg_batch_time = val_total_time / num_batches

            print("\nValidation Metrics:")
            print(f" - Loss: {mean_val_loss:.6f}")
            print(f" - PSNR: {mean_psnr:.2f} dB")
            print(f" - SSIM: {mean_ssim:.4f}")
            print(f" - LPIPS: {mean_lpips:.4f}")

            print(f"\nValidation Timing (Total: {val_total_time:.2f}s, {num_batches} batches):")
            print(f" - Avg sampling loop: {avg_sampling_time:.3f}s/batch ({avg_forward_time:.3f}s forward)")
            print(f" - Avg decoding: {avg_decode_time:.3f}s/batch")
            print(f" - Avg metrics: {avg_metric_time:.3f}s/batch")
            print(f" - Avg batch time: {avg_batch_time:.3f}s/batch")

            wandb.log({
                'val/loss': mean_val_loss,
                'val/psnr': mean_psnr,
                'val/ssim': mean_ssim,
                'val/lpips': mean_lpips,
                'val/total_time': val_total_time,
                'val/avg_sampling_time': avg_sampling_time,
                'val/avg_forward_time': avg_forward_time,
                'val/avg_decode_time': avg_decode_time,
                'val/avg_metric_time': avg_metric_time,
                'val/avg_batch_time': avg_batch_time,
                'val/num_batches': num_batches,
                'val/num_samples': num_samples,
                'step': self.current_iters
            })

        print("=" * 100)

        # Set model back to training mode
        self.model.train()
        if self.ema_model is not None:
            self.ema_model.train()  # Keep in sync, but won't be used for training


# Note: Main training script is in train.py
# This file contains the Trainer class implementation
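
# A minimal usage sketch (an assumption about how train.py drives the Trainer;
# the actual loop in train.py is the canonical entry point and may differ):
#
#     trainer = Trainer(resume_ckpt=None)
#     trainer.build_model()
#     if trainer.resume_ckpt is not None:
#         trainer.resume_from_ckpt(trainer.resume_ckpt)   # pending states are consumed in setup_optimization()
#     trainer.build_dataloader()
#     trainer.setup_optimization()
#     while trainer.current_iters < iterations:
#         trainer.current_iters += 1
#         trainer.adjust_lr()                              # warmup + cosine schedule
#         hr_latent, lr_latent = next(trainer.dataloaders['train'])
#         loss, timing = trainer.training_step(hr_latent, lr_latent)
#         trainer._last_timing = timing                    # consumed by log_step_train
#         trainer.log_step_train(loss, hr_latent, lr_latent, None, None)
#         if trainer.current_iters % save_freq == 0:
#             trainer.save_ckpt()
#         if trainer.current_iters % val_freq == 0:
#             trainer.validation()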