import torch
import torch.nn as nn
import random
import numpy as np
from model import FullUNET
from noiseControl import resshift_schedule
from torch.utils.data import DataLoader
from data import mini_dataset, train_dataset, valid_dataset, get_vqgan_model
import torch.optim as optim
from config import (
batch_size, device, learning_rate, iterations,
weight_decay, T, k, _project_root, num_workers,
use_amp, lr, lr_min, lr_schedule, warmup_iterations,
compile_flag, compile_mode, batch, prefetch_factor, microbatch,
save_freq, log_freq, val_freq, val_y_channel, ema_rate, use_ema_val,
seed, global_seeding, normalize_input, latent_flag
)
import wandb
import os
import math
import time
from pathlib import Path
from itertools import cycle
from contextlib import nullcontext
from dotenv import load_dotenv
from metrics import compute_psnr, compute_ssim, compute_lpips
from ema import EMA
import lpips
class Trainer:
"""
Modular trainer class following the original ResShift trainer structure.
"""
def __init__(self, save_dir=None, resume_ckpt=None):
"""
Initialize trainer with config values.
Args:
save_dir: Directory to save checkpoints (defaults to _project_root / 'checkpoints')
resume_ckpt: Path to checkpoint file to resume from (optional)
"""
self.device = device
self.current_iters = 0
self.iters_start = 0
self.resume_ckpt = resume_ckpt
# Setup checkpoint directory
if save_dir is None:
save_dir = _project_root / 'checkpoints'
self.save_dir = Path(save_dir)
self.ckpt_dir = self.save_dir / 'ckpts'
self.ckpt_dir.mkdir(parents=True, exist_ok=True)
# Initialize noise schedule (eta values for ResShift)
self.eta = resshift_schedule().to(self.device)
self.eta = self.eta[:, None, None, None] # shape (T, 1, 1, 1)
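# The schedule above parameterizes ResShift's forward process in latent space:
#   q(x_t | x_0, y_0) = N(x_t; x_0 + eta_t * (y_0 - x_0), k^2 * eta_t * I)
# where x_0 is the HR latent, y_0 the LR latent, and k the kappa hyperparameter.
# training_step() draws samples from this distribution directly, and validation()
# starts sampling from the matching prior x_T = y_0 + k * sqrt(eta_T) * epsilon.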
# Loss criterion
self.criterion = nn.MSELoss()
# Timing for checkpoint saving
self.tic = None
# EMA will be initialized after model is built
self.ema = None
self.ema_model = None
# Set random seeds for reproducibility
self.setup_seed()
# Initialize WandB
self.init_wandb()
def setup_seed(self, seed_val=None, global_seeding_val=None):
"""
Set random seeds for reproducibility.
Sets seeds for:
- Python random module
- NumPy
- PyTorch (CPU and CUDA)
Args:
seed_val: Seed value (defaults to config.seed)
global_seeding_val: Whether to use global seeding (defaults to config.global_seeding)
"""
if seed_val is None:
seed_val = seed
if global_seeding_val is None:
global_seeding_val = global_seeding
# Set Python random seed
random.seed(seed_val)
# Set NumPy random seed
np.random.seed(seed_val)
# Set PyTorch random seed
torch.manual_seed(seed_val)
# Set CUDA random seeds (if available)
if torch.cuda.is_available():
if global_seeding_val:
torch.cuda.manual_seed_all(seed_val)
else:
torch.cuda.manual_seed(seed_val)
# For multi-GPU, each GPU would get seed + rank (not implemented here)
# Make deterministic (may impact performance)
# torch.backends.cudnn.deterministic = True
# torch.backends.cudnn.benchmark = False
print(f"✓ Random seeds set: seed={seed_val}, global_seeding={global_seeding_val}")
def init_wandb(self):
"""Initialize WandB logging."""
load_dotenv(os.path.join(_project_root, '.env'))
wandb.init(
project="diffusionsr",
name="reshift_training",
config={
"learning_rate": learning_rate,
"batch_size": batch_size,
"steps": iterations,
"model": "ResShift",
"T": T,
"k": k,
"optimizer": "AdamW" if weight_decay > 0 else "Adam",
"betas": (0.9, 0.999),
"grad_clip": 1.0,
"criterion": "MSE",
"device": str(device),
"training_space": "latent_64x64",
"use_amp": use_amp,
"ema_rate": 0.999 if hasattr(self, 'ema_rate') else None
}
)
def setup_optimization(self):
"""
Component 1: Setup optimizer and AMP scaler.
Sets up:
- Optimizer (AdamW with weight decay or Adam)
- AMP GradScaler if use_amp is True
"""
# Use AdamW if weight_decay > 0, otherwise Adam
if weight_decay > 0:
self.optimizer = optim.AdamW(
self.model.parameters(),
lr=learning_rate,
weight_decay=weight_decay,
betas=(0.9, 0.999)
)
else:
self.optimizer = optim.Adam(
self.model.parameters(),
lr=learning_rate,
weight_decay=weight_decay,
betas=(0.9, 0.999)
)
# AMP settings: Create GradScaler if use_amp is True and CUDA is available
if use_amp and torch.cuda.is_available():
self.amp_scaler = torch.amp.GradScaler('cuda')
else:
self.amp_scaler = None
if use_amp and not torch.cuda.is_available():
print(" ⚠ Warning: AMP requested but CUDA not available. Disabling AMP.")
# Learning rate scheduler (cosine annealing after warmup)
self.lr_scheduler = None
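# Note: the literal string 'cosin' (without a trailing 'e') is the value expected
# from config.lr_schedule; adjust_lr() and the docstrings below use the same spelling.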
if lr_schedule == 'cosin':
self.lr_scheduler = optim.lr_scheduler.CosineAnnealingLR(
optimizer=self.optimizer,
T_max=iterations - warmup_iterations,
eta_min=lr_min
)
print(f" - LR scheduler: CosineAnnealingLR (T_max={iterations - warmup_iterations}, eta_min={lr_min})")
# Load pending optimizer state if resuming
if hasattr(self, '_pending_optimizer_state'):
self.optimizer.load_state_dict(self._pending_optimizer_state)
print(f" - Loaded optimizer state from checkpoint")
delattr(self, '_pending_optimizer_state')
# Load pending LR scheduler state if resuming
if hasattr(self, '_pending_lr_scheduler_state') and self.lr_scheduler is not None:
self.lr_scheduler.load_state_dict(self._pending_lr_scheduler_state)
print(f" - Loaded LR scheduler state from checkpoint")
delattr(self, '_pending_lr_scheduler_state')
# Restore LR schedule by replaying adjust_lr for all previous iterations
# This ensures the LR is at the correct value for the resumed iteration
if hasattr(self, '_resume_iters') and self._resume_iters > 0:
print(f" - Restoring learning rate schedule to iteration {self._resume_iters}...")
for ii in range(1, self._resume_iters + 1):
self.adjust_lr(ii)
print(f" - ✓ Learning rate schedule restored")
delattr(self, '_resume_iters')
print(f"✓ Setup optimization:")
print(f" - Optimizer: {type(self.optimizer).__name__}")
print(f" - Learning rate: {learning_rate}")
print(f" - Weight decay: {weight_decay}")
print(f" - Warmup iterations: {warmup_iterations}")
print(f" - LR schedule: {lr_schedule if lr_schedule else 'None (fixed LR)'}")
print(f" - AMP enabled: {use_amp} ({'GradScaler active' if self.amp_scaler else 'disabled'})")
def build_model(self):
"""
Component 2: Build model and autoencoder (VQGAN).
Sets up:
- FullUNET model
- Model compilation (optional)
- VQGAN autoencoder for encoding/decoding
- Model info printing
"""
# Build main model
print("Building FullUNET model...")
self.model = FullUNET()
self.model = self.model.to(self.device)
# Optional: Compile model for optimization
# Model compilation can provide 20-30% speedup on modern GPUs
# but requires PyTorch 2.0+ and may have compatibility issues
self.model_compiled = False
if compile_flag:
try:
print(f"Compiling model with mode: {compile_mode}...")
self.model = torch.compile(self.model, mode=compile_mode)
self.model_compiled = True
print("✓ Model compilation done")
except Exception as e:
print(f"⚠ Warning: Model compilation failed: {e}")
print(" Continuing without compilation...")
self.model_compiled = False
# Load VQGAN autoencoder
print("Loading VQGAN autoencoder...")
self.autoencoder = get_vqgan_model()
print("✓ VQGAN autoencoder loaded")
# Initialize LPIPS model for validation
print("Loading LPIPS metric...")
self.lpips_model = lpips.LPIPS(net='vgg').to(self.device)
for params in self.lpips_model.parameters():
params.requires_grad_(False)
self.lpips_model.eval()
print("✓ LPIPS metric loaded")
# Initialize EMA if enabled
if ema_rate > 0:
print(f"Initializing EMA with rate: {ema_rate}...")
self.ema = EMA(self.model, ema_rate=ema_rate, device=self.device)
# Add Swin Transformer relative position index to ignore keys
self.ema.add_ignore_key('relative_position_index')
print("✓ EMA initialized")
else:
print("⚠ EMA disabled (ema_rate = 0)")
# Print model information
self.print_model_info()
def print_model_info(self):
"""Print model parameter count and architecture info."""
# Count parameters
total_params = sum(p.numel() for p in self.model.parameters())
trainable_params = sum(p.numel() for p in self.model.parameters() if p.requires_grad)
print(f"\n✓ Model built successfully:")
print(f" - Model: FullUNET")
print(f" - Total parameters: {total_params / 1e6:.2f}M")
print(f" - Trainable parameters: {trainable_params / 1e6:.2f}M")
print(f" - Device: {self.device}")
print(f" - Compiled: {'Yes' if getattr(self, 'model_compiled', False) else 'No'}")
if self.autoencoder is not None:
print(f" - Autoencoder: VQGAN (loaded)")
def build_dataloader(self):
"""
Component 3: Build train and validation dataloaders.
Sets up:
- Train dataloader with infinite cycle wrapper
- Validation dataloader (if validation dataset exists)
- Proper batch sizes, num_workers, pin_memory, etc.
"""
def _wrap_loader(loader):
"""Wrap dataloader to cycle infinitely."""
while True:
yield from loader
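# Unlike itertools.cycle, this re-iterates the DataLoader on every pass, so a new
# shuffle order is drawn each epoch. Consumers pull batches with next(), e.g.
#   hr_latent, lr_latent = next(self.dataloaders['train'])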
# Create datasets dictionary
datasets = {'train': train_dataset}
if valid_dataset is not None:
datasets['val'] = valid_dataset
# Print dataset sizes
for phase, dataset in datasets.items():
print(f" - {phase.capitalize()} dataset: {len(dataset)} images")
# Create train dataloader
train_batch_size = batch[0] # Use first value from batch list
train_loader = DataLoader(
datasets['train'],
batch_size=train_batch_size,
shuffle=True,
drop_last=True, # Drop last incomplete batch
num_workers=min(num_workers, 4), # Limit num_workers
pin_memory=True if torch.cuda.is_available() else False,
prefetch_factor=prefetch_factor if num_workers > 0 else None,
)
# Wrap train loader to cycle infinitely
self.dataloaders = {'train': _wrap_loader(train_loader)}
# Create validation dataloader if validation dataset exists
if 'val' in datasets:
val_batch_size = batch[1] if len(batch) > 1 else batch[0] # Use second value or fallback
val_loader = DataLoader(
datasets['val'],
batch_size=val_batch_size,
shuffle=False,
drop_last=False, # Don't drop last batch in validation
num_workers=0, # No multiprocessing for validation (safer)
pin_memory=True if torch.cuda.is_available() else False,
)
self.dataloaders['val'] = val_loader
# Store datasets
self.datasets = datasets
print(f"\n✓ Dataloaders built:")
print(f" - Train batch size: {train_batch_size}")
print(f" - Train num_workers: {min(num_workers, 4)}")
print(f" - Train drop_last: True")
if 'val' in self.dataloaders:
print(f" - Val batch size: {val_batch_size}")
print(f" - Val num_workers: 0")
def backward_step(self, loss, num_grad_accumulate=1):
"""
Component 4: Handle backward pass with AMP support and gradient accumulation.
Args:
loss: The computed loss tensor
num_grad_accumulate: Number of gradient accumulation steps (for micro-batching)
Returns:
loss: The loss tensor (for logging)
"""
# Normalize loss by gradient accumulation steps
loss = loss / num_grad_accumulate
# Backward pass: use AMP scaler if available, otherwise direct backward
if self.amp_scaler is None:
loss.backward()
else:
self.amp_scaler.scale(loss).backward()
return loss
def _scale_input(self, x_t, t):
"""
Scale input based on timestep for training stability.
Matches original GaussianDiffusion._scale_input for latent space.
For latent space: std = sqrt(etas[t] * kappa^2 + 1)
This normalizes the input variance across different timesteps.
Args:
x_t: Noisy input tensor (B, C, H, W)
t: Timestep tensor (B,)
Returns:
x_t_scaled: Scaled input tensor (B, C, H, W)
"""
if normalize_input and latent_flag:
# For latent space: std = sqrt(etas[t] * kappa^2 + 1)
# Extract eta_t for each sample in batch
eta_t = self.eta[t] # (B, 1, 1, 1)
std = torch.sqrt(eta_t * k**2 + 1)
x_t_scaled = x_t / std
else:
x_t_scaled = x_t
return x_t_scaled
def training_step(self, hr_latent, lr_latent):
"""
Component 5: Main training step with micro-batching and gradient accumulation.
Args:
hr_latent: High-resolution latent tensor (B, C, 64, 64)
lr_latent: Low-resolution latent tensor (B, C, 64, 64)
Returns:
loss: Average loss value for logging
timing_dict: Dictionary with timing information
"""
step_start = time.time()
self.model.train()
current_batchsize = hr_latent.shape[0]
micro_batchsize = microbatch
num_grad_accumulate = math.ceil(current_batchsize / micro_batchsize)
total_loss = 0.0
forward_time = 0.0
backward_time = 0.0
# Process in micro-batches for gradient accumulation
for jj in range(0, current_batchsize, micro_batchsize):
# Extract micro-batch
end_idx = min(jj + micro_batchsize, current_batchsize)
hr_micro = hr_latent[jj:end_idx].to(self.device)
lr_micro = lr_latent[jj:end_idx].to(self.device)
last_batch = (end_idx >= current_batchsize)
# Compute residual in latent space
residual = (lr_micro - hr_micro)
# Generate random timesteps for each sample in micro-batch
t = torch.randint(0, T, (hr_micro.shape[0],)).to(self.device)
# Add noise in latent space (ResShift noise schedule)
epsilon = torch.randn_like(hr_micro) # Noise in latent space
eta_t = self.eta[t] # (B, 1, 1, 1)
x_t = hr_micro + eta_t * residual + k * torch.sqrt(eta_t) * epsilon
# Forward pass with autocast if AMP is enabled
forward_start = time.time()
if use_amp and torch.cuda.is_available():
context = torch.amp.autocast('cuda')
else:
context = nullcontext()
with context:
# Scale input for training stability (normalize variance across timesteps)
x_t_scaled = self._scale_input(x_t, t)
# Forward pass: Model predicts x0 (clean HR latent), not noise
# ResShift uses predict_type = "xstart"
x0_pred = self.model(x_t_scaled, t, lq=lr_micro)
# Loss: Compare predicted x0 with ground truth HR latent
loss = self.criterion(x0_pred, hr_micro)
forward_time += time.time() - forward_start
# Store loss value for logging (before dividing for gradient accumulation)
total_loss += loss.item()
# Backward step (handles gradient accumulation and AMP)
backward_start = time.time()
self.backward_step(loss, num_grad_accumulate)
backward_time += time.time() - backward_start
# Gradient clipping before optimizer step
if self.amp_scaler is None:
torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
self.optimizer.step()
else:
# Unscale gradients before clipping when using AMP
self.amp_scaler.unscale_(self.optimizer)
torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
self.amp_scaler.step(self.optimizer)
self.amp_scaler.update()
# Zero gradients
self.model.zero_grad()
# Update EMA after optimizer step
if self.ema is not None:
self.ema.update(self.model)
# Compute total step time
step_time = time.time() - step_start
# Return average loss (average across micro-batches)
num_micro_batches = math.ceil(current_batchsize / micro_batchsize)
avg_loss = total_loss / num_micro_batches if num_micro_batches > 0 else total_loss
# Return timing information
timing_dict = {
'step_time': step_time,
'forward_time': forward_time,
'backward_time': backward_time,
'num_micro_batches': num_micro_batches
}
return avg_loss, timing_dict
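# The caller is expected to stash timing_dict on self._last_timing so that
# log_step_train() can report per-step timing (see the hasattr check there);
# this is an assumption about how train.py drives the Trainer, not enforced here.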
def adjust_lr(self, current_iters=None):
"""
Component 6: Adjust learning rate with warmup and optional cosine annealing.
Learning rate schedule:
- Warmup phase (iters <= warmup_iterations): Linear increase from 0 to base_lr
- After warmup: Use cosine annealing scheduler if lr_schedule == 'cosin', else keep base_lr
Args:
current_iters: Current iteration number (defaults to self.current_iters)
"""
base_lr = learning_rate
warmup_steps = warmup_iterations
current_iters = self.current_iters if current_iters is None else current_iters
if current_iters <= warmup_steps:
# Warmup phase: linear increase from 0 to base_lr
warmup_lr = (current_iters / warmup_steps) * base_lr
for params_group in self.optimizer.param_groups:
params_group['lr'] = warmup_lr
else:
# After warmup: use scheduler if available
if self.lr_scheduler is not None:
self.lr_scheduler.step()
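# Illustrative example (not the actual config values): with learning_rate=5e-5,
# warmup_iterations=5000 and lr_schedule='cosin', iteration 2500 runs at 2.5e-5,
# iteration 5000 reaches the full 5e-5, and later iterations follow the
# CosineAnnealingLR decay down towards lr_min.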
def save_ckpt(self):
"""
Component 7: Save checkpoint with model state, optimizer state, and training info.
Saves:
- Model state dict
- Optimizer state dict
- Current iteration number
- AMP scaler state (if AMP is enabled)
- LR scheduler state (if scheduler exists)
"""
ckpt_path = self.ckpt_dir / f'model_{self.current_iters}.pth'
# Prepare checkpoint dictionary
ckpt = {
'iters_start': self.current_iters,
'state_dict': self.model.state_dict(),
}
# Add optimizer state if available
if hasattr(self, 'optimizer'):
ckpt['optimizer'] = self.optimizer.state_dict()
# Add AMP scaler state if available
if self.amp_scaler is not None:
ckpt['amp_scaler'] = self.amp_scaler.state_dict()
# Add LR scheduler state if available
if self.lr_scheduler is not None:
ckpt['lr_scheduler'] = self.lr_scheduler.state_dict()
# Save checkpoint
torch.save(ckpt, ckpt_path)
print(f"✓ Checkpoint saved: {ckpt_path}")
# Save EMA checkpoint separately if EMA is enabled
if self.ema is not None:
ema_ckpt_path = self.ckpt_dir / f'ema_model_{self.current_iters}.pth'
torch.save(self.ema.state_dict(), ema_ckpt_path)
print(f"✓ EMA checkpoint saved: {ema_ckpt_path}")
return ckpt_path
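# Resulting on-disk layout (with the default save_dir of <project_root>/checkpoints):
#   checkpoints/ckpts/model_<iter>.pth      - model, optimizer, AMP scaler, LR scheduler state
#   checkpoints/ckpts/ema_model_<iter>.pth  - EMA weights (only written when EMA is enabled)
# resume_from_ckpt() relies on this naming scheme to locate the matching EMA checkpoint.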
def resume_from_ckpt(self, ckpt_path):
"""
Resume training from a checkpoint.
Loads:
- Model state dict
- Optimizer state dict
- AMP scaler state (if AMP is enabled)
- LR scheduler state (if scheduler exists)
- Current iteration number
- Restores LR schedule by replaying adjust_lr for previous iterations
Args:
ckpt_path: Path to checkpoint file (.pth)
"""
if not os.path.isfile(ckpt_path):
raise FileNotFoundError(f"Checkpoint file not found: {ckpt_path}")
if not str(ckpt_path).endswith('.pth'):
raise ValueError(f"Checkpoint file must have .pth extension: {ckpt_path}")
print(f"\n{'=' * 100}")
print(f"Resuming from checkpoint: {ckpt_path}")
print(f"{'=' * 100}")
# Load checkpoint
ckpt = torch.load(ckpt_path, map_location=self.device)
# Load model state dict
if 'state_dict' in ckpt:
self.model.load_state_dict(ckpt['state_dict'])
print(f"✓ Loaded model state dict")
else:
# If checkpoint is just the state dict
self.model.load_state_dict(ckpt)
print(f"✓ Loaded model state dict (direct)")
# Load optimizer state dict (must be done after optimizer is created)
if 'optimizer' in ckpt:
if hasattr(self, 'optimizer'):
self.optimizer.load_state_dict(ckpt['optimizer'])
print(f"✓ Loaded optimizer state dict")
else:
print(f"⚠ Warning: Optimizer state found in checkpoint but optimizer not yet created.")
print(f" Optimizer will be loaded after setup_optimization() is called.")
self._pending_optimizer_state = ckpt['optimizer']
# Load AMP scaler state
if 'amp_scaler' in ckpt:
if hasattr(self, 'amp_scaler') and self.amp_scaler is not None:
self.amp_scaler.load_state_dict(ckpt['amp_scaler'])
print(f"✓ Loaded AMP scaler state")
else:
print(f"⚠ Warning: AMP scaler state found but AMP not enabled or scaler not created.")
# Load LR scheduler state
if 'lr_scheduler' in ckpt:
if hasattr(self, 'lr_scheduler') and self.lr_scheduler is not None:
self.lr_scheduler.load_state_dict(ckpt['lr_scheduler'])
print(f"✓ Loaded LR scheduler state")
else:
print(f"⚠ Warning: LR scheduler state found but scheduler not yet created.")
self._pending_lr_scheduler_state = ckpt['lr_scheduler']
# Load EMA state if available (must be done after EMA is initialized)
# EMA checkpoint naming: ema_model_{iters}.pth (matches save pattern)
ckpt_path_obj = Path(ckpt_path)
# Extract iteration number from checkpoint name (e.g., "model_10000.pth" -> "10000")
if 'iters_start' in ckpt:
iters = ckpt['iters_start']
ema_ckpt_path = ckpt_path_obj.parent / f"ema_model_{iters}.pth"
else:
# Fallback: try to extract from filename
try:
iters = int(ckpt_path_obj.stem.split('_')[-1])
ema_ckpt_path = ckpt_path_obj.parent / f"ema_model_{iters}.pth"
except (ValueError, IndexError):
ema_ckpt_path = None
if ema_ckpt_path is not None and ema_ckpt_path.exists() and self.ema is not None:
ema_ckpt = torch.load(ema_ckpt_path, map_location=self.device)
self.ema.load_state_dict(ema_ckpt)
print(f"✓ Loaded EMA state from: {ema_ckpt_path}")
elif ema_ckpt_path is not None and ema_ckpt_path.exists() and self.ema is None:
print(f"⚠ Warning: EMA checkpoint found but EMA not enabled. Skipping EMA load.")
elif self.ema is not None:
print(f"⚠ Warning: EMA enabled but no EMA checkpoint found. Starting with fresh EMA.")
# Restore iteration number
if 'iters_start' in ckpt:
self.iters_start = ckpt['iters_start']
self.current_iters = ckpt['iters_start']
print(f"✓ Resuming from iteration: {self.iters_start}")
else:
print(f"⚠ Warning: No iteration number found in checkpoint. Starting from 0.")
self.iters_start = 0
self.current_iters = 0
# Note: LR schedule restoration will be done after setup_optimization()
# Store the iteration number for later restoration
self._resume_iters = self.iters_start
print(f"{'=' * 100}\n")
# Clear CUDA cache
if torch.cuda.is_available():
torch.cuda.empty_cache()
def log_step_train(self, loss, hr_latent, lr_latent, x_t, pred, phase='train'):
"""
Component 8: Log training metrics and images to WandB.
Logs:
- Loss and learning rate (at log_freq[0] intervals)
- Training images: HR, LR, and predictions (at log_freq[1] intervals)
- Elapsed time for checkpoint intervals
Args:
loss: Training loss value (float)
hr_latent: High-resolution latent tensor (B, C, 64, 64)
lr_latent: Low-resolution latent tensor (B, C, 64, 64)
x_t: Noisy input tensor (B, C, 64, 64)
pred: Model prediction (x0_pred - clean HR latent) (B, C, 64, 64)
phase: Training phase ('train' or 'val')
"""
# Log loss and learning rate at log_freq[0] intervals
if self.current_iters % log_freq[0] == 0:
current_lr = self.optimizer.param_groups[0]['lr']
# Get timing info if available (passed from training_step)
timing_info = {}
if hasattr(self, '_last_timing'):
timing_info = {
'train/step_time': self._last_timing.get('step_time', 0),
'train/forward_time': self._last_timing.get('forward_time', 0),
'train/backward_time': self._last_timing.get('backward_time', 0),
'train/iterations_per_sec': 1.0 / self._last_timing.get('step_time', 1.0) if self._last_timing.get('step_time', 0) > 0 else 0
}
wandb.log({
'loss': loss,
'learning_rate': current_lr,
'step': self.current_iters,
**timing_info
})
# Print to console
timing_str = ""
if hasattr(self, '_last_timing') and self._last_timing.get('step_time', 0) > 0:
timing_str = f", Step: {self._last_timing['step_time']:.3f}s, Forward: {self._last_timing['forward_time']:.3f}s, Backward: {self._last_timing['backward_time']:.3f}s"
print(f"Train: {self.current_iters:06d}/{iterations:06d}, "
f"Loss: {loss:.6f}, LR: {current_lr:.2e}{timing_str}")
# Log images at log_freq[1] intervals (only if x_t and pred are provided)
if self.current_iters % log_freq[1] == 0 and x_t is not None and pred is not None:
with torch.no_grad():
# Decode latents to pixel space for visualization
# Take first sample from batch
hr_pixel = self.autoencoder.decode(hr_latent[0:1]) # (1, 3, 256, 256)
lr_pixel = self.autoencoder.decode(lr_latent[0:1]) # (1, 3, 256, 256)
# Decode noisy input for visualization
x_t_pixel = self.autoencoder.decode(x_t[0:1]) # (1, 3, 256, 256)
# Decode predicted x0 (clean HR latent) for visualization
pred_pixel = self.autoencoder.decode(pred[0:1]) # (1, 3, 256, 256)
# Log images to WandB
wandb.log({
f'{phase}/hr_sample': wandb.Image(hr_pixel[0].cpu().clamp(0, 1)),
f'{phase}/lr_sample': wandb.Image(lr_pixel[0].cpu().clamp(0, 1)),
f'{phase}/noisy_input': wandb.Image(x_t_pixel[0].cpu().clamp(0, 1)),
f'{phase}/pred_sample': wandb.Image(pred_pixel[0].cpu().clamp(0, 1)),
'step': self.current_iters
})
# Track elapsed time for checkpoint intervals
if self.current_iters % save_freq == 1:
self.tic = time.time()
if self.current_iters % save_freq == 0 and self.tic is not None:
self.toc = time.time()
elapsed = self.toc - self.tic
print(f"Elapsed time for {save_freq} iterations: {elapsed:.2f}s")
print("=" * 100)
def validation(self):
"""
Run validation on validation dataset with full diffusion sampling loop.
Performs iterative denoising from t = T-1 down to t = 0, matching the
original ResShift implementation. This is slower but more accurate than
single-step prediction.
Computes:
- PSNR, SSIM, and LPIPS metrics
- Logs validation images to WandB
"""
if 'val' not in self.dataloaders:
print("No validation dataset available. Skipping validation.")
return
print("\n" + "=" * 100)
print("Running Validation")
print("=" * 100)
val_start = time.time()
# Use EMA model for validation if enabled
if use_ema_val and self.ema is not None:
# Create EMA model copy if it doesn't exist
if self.ema_model is None:
from copy import deepcopy
self.ema_model = deepcopy(self.model)
# Load EMA state into EMA model
self.ema.apply_to_model(self.ema_model)
self.ema_model.eval()
val_model = self.ema_model
print("Using EMA model for validation")
else:
self.model.eval()
val_model = self.model
if use_ema_val and self.ema is None:
print("⚠ Warning: use_ema_val=True but EMA not enabled. Using regular model.")
val_iter = iter(self.dataloaders['val'])
total_psnr = 0.0
total_ssim = 0.0
total_lpips = 0.0
total_val_loss = 0.0
num_samples = 0
total_sampling_time = 0.0
total_forward_time = 0.0
total_decode_time = 0.0
total_metric_time = 0.0
with torch.no_grad():
for batch_idx, (hr_latent, lr_latent) in enumerate(val_iter):
batch_start = time.time()
# Move to device
hr_latent = hr_latent.to(self.device)
lr_latent = lr_latent.to(self.device)
# Full diffusion sampling loop (iterative denoising)
# Start from maximum timestep and iterate backwards: T-1 → T-2 → ... → 1 → 0
sampling_start = time.time()
# Initialize x_t at maximum timestep (T-1)
# Start from LR with maximum noise (prior_sample: x_T = y + kappa * sqrt(eta_T) * noise)
epsilon_init = torch.randn_like(lr_latent)
eta_max = self.eta[T - 1]
x_t = lr_latent + k * torch.sqrt(eta_max) * epsilon_init
# Track forward pass time during sampling
sampling_forward_time = 0.0
# Iterative sampling: denoise from t = T-1 down to t = 0
for t_step in range(T - 1, -1, -1): # T-1, T-2, ..., 1, 0
t = torch.full((hr_latent.shape[0],), t_step, device=self.device, dtype=torch.long)
# Scale input for training stability (normalize variance across timesteps)
x_t_scaled = self._scale_input(x_t, t)
# Predict x0 from current noisy state x_t
forward_start = time.time()
x0_pred = val_model(x_t_scaled, t, lq=lr_latent)
sampling_forward_time += time.time() - forward_start
# If not the last step, compute x_{t-1} from predicted x0 using equation (7)
if t_step > 0:
# Equation (7) from ResShift paper:
# μ_θ = (η_{t-1}/η_t) * x_t + (α_t/η_t) * f_θ(x_t, y_0, t)
# Σ_θ = κ² * (η_{t-1}/η_t) * α_t
# x_{t-1} = μ_θ + sqrt(Σ_θ) * ε
eta_t = self.eta[t_step]
eta_t_minus_1 = self.eta[t_step - 1]
# Compute alpha_t = η_t - η_{t-1}
alpha_t = eta_t - eta_t_minus_1
# Compute mean: μ_θ = (η_{t-1}/η_t) * x_t + (α_t/η_t) * x0_pred
mean = (eta_t_minus_1 / eta_t) * x_t + (alpha_t / eta_t) * x0_pred
# Compute variance: Σ_θ = κ² * (η_{t-1}/η_t) * α_t
variance = k**2 * (eta_t_minus_1 / eta_t) * alpha_t
# Sample: x_{t-1} = μ_θ + sqrt(Σ_θ) * ε
noise = torch.randn_like(x_t)
# t_step > 0 in this branch, so the noise term is always applied
x_t = mean + torch.sqrt(variance) * noise
else:
# Final step: use predicted x0 as final output
x_t = x0_pred
# Final prediction after full sampling loop
x0_final = x_t
# Compute validation loss (MSE in latent space, same as training loss)
val_loss = self.criterion(x0_final, hr_latent).item()
total_val_loss += val_loss * hr_latent.shape[0]
sampling_time = time.time() - sampling_start
total_sampling_time += sampling_time
total_forward_time += sampling_forward_time
# Decode latents to pixel space for metrics and visualization
decode_start = time.time()
hr_pixel = self.autoencoder.decode(hr_latent)
lr_pixel = self.autoencoder.decode(lr_latent)
sr_pixel = self.autoencoder.decode(x0_final) # Final SR output after full sampling
decode_time = time.time() - decode_start
total_decode_time += decode_time
# Convert to [0, 1] range if needed
hr_pixel = hr_pixel.clamp(0, 1)
sr_pixel = sr_pixel.clamp(0, 1)
# Compute metrics using simple functions
metric_start = time.time()
batch_psnr = compute_psnr(hr_pixel, sr_pixel)
total_psnr += batch_psnr * hr_latent.shape[0]
batch_ssim = compute_ssim(hr_pixel, sr_pixel)
total_ssim += batch_ssim * hr_latent.shape[0]
batch_lpips = compute_lpips(hr_pixel, sr_pixel, self.lpips_model)
total_lpips += batch_lpips * hr_latent.shape[0]
metric_time = time.time() - metric_start
total_metric_time += metric_time
num_samples += hr_latent.shape[0]
batch_time = time.time() - batch_start
# Print timing for first batch
if batch_idx == 0:
print(f"\nValidation Batch 0 Timing:")
print(f" - Sampling loop: {sampling_time:.3f}s ({sampling_forward_time:.3f}s forward)")
print(f" - Decoding: {decode_time:.3f}s")
print(f" - Metrics: {metric_time:.3f}s")
print(f" - Total batch: {batch_time:.3f}s")
# Log validation images periodically
if batch_idx == 0:
wandb.log({
'val/hr_sample': wandb.Image(hr_pixel[0].cpu()),
'val/lr_sample': wandb.Image(lr_pixel[0].cpu()),
'val/sr_sample': wandb.Image(sr_pixel[0].cpu()),
'step': self.current_iters
})
# Compute average metrics and timing
val_total_time = time.time() - val_start
num_batches = batch_idx + 1
if num_samples > 0:
mean_psnr = total_psnr / num_samples
mean_ssim = total_ssim / num_samples
mean_lpips = total_lpips / num_samples
mean_val_loss = total_val_loss / num_samples
avg_sampling_time = total_sampling_time / num_batches
avg_forward_time = total_forward_time / num_batches
avg_decode_time = total_decode_time / num_batches
avg_metric_time = total_metric_time / num_batches
avg_batch_time = val_total_time / num_batches
print(f"\nValidation Metrics:")
print(f" - Loss: {mean_val_loss:.6f}")
print(f" - PSNR: {mean_psnr:.2f} dB")
print(f" - SSIM: {mean_ssim:.4f}")
print(f" - LPIPS: {mean_lpips:.4f}")
print(f"\nValidation Timing (Total: {val_total_time:.2f}s, {num_batches} batches):")
print(f" - Avg sampling loop: {avg_sampling_time:.3f}s/batch ({avg_forward_time:.3f}s forward)")
print(f" - Avg decoding: {avg_decode_time:.3f}s/batch")
print(f" - Avg metrics: {avg_metric_time:.3f}s/batch")
print(f" - Avg batch time: {avg_batch_time:.3f}s/batch")
wandb.log({
'val/loss': mean_val_loss,
'val/psnr': mean_psnr,
'val/ssim': mean_ssim,
'val/lpips': mean_lpips,
'val/total_time': val_total_time,
'val/avg_sampling_time': avg_sampling_time,
'val/avg_forward_time': avg_forward_time,
'val/avg_decode_time': avg_decode_time,
'val/avg_metric_time': avg_metric_time,
'val/avg_batch_time': avg_batch_time,
'val/num_batches': num_batches,
'val/num_samples': num_samples,
'step': self.current_iters
})
print("=" * 100)
# Set model back to training mode
self.model.train()
if self.ema_model is not None:
self.ema_model.train() # Keep in sync, but won't be used for training
# Note: Main training script is in train.py
# This file contains the Trainer class implementation
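# --- Illustrative usage sketch (not the project entry point) -----------------
# A minimal driver loop, assuming train.py wires the Trainer roughly like this;
# the exact call order in train.py may differ. Shown only to document how the
# components above fit together.
if __name__ == "__main__":
    trainer = Trainer()
    trainer.build_model()
    if trainer.resume_ckpt is not None:
        trainer.resume_from_ckpt(trainer.resume_ckpt)
    trainer.setup_optimization()
    trainer.build_dataloader()

    for it in range(trainer.iters_start + 1, iterations + 1):
        trainer.current_iters = it
        trainer.adjust_lr(it)
        # The infinite train loader yields (hr_latent, lr_latent) pairs
        hr_latent, lr_latent = next(trainer.dataloaders['train'])
        loss, timing = trainer.training_step(hr_latent, lr_latent)
        trainer._last_timing = timing  # consumed by log_step_train
        # Passing x_t=None / pred=None skips the image-logging branch
        trainer.log_step_train(loss, hr_latent, lr_latent, None, None)
        if it % val_freq == 0:
            trainer.validation()
        if it % save_freq == 0:
            trainer.save_ckpt()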