"""
Main training script for ResShift diffusion model.
This script initializes the Trainer class and runs the main training loop.
"""
import multiprocessing
# Use the 'spawn' start method so CUDA works correctly in DataLoader worker processes.
# This is required when using DataLoader with num_workers > 0 on systems where CUDA is
# initialized before the worker processes are created (e.g., Colab and some Linux setups,
# where the default start method is 'fork'). Must be set before any CUDA initialization
# or DataLoader creation.
try:
    multiprocessing.set_start_method('spawn', force=True)
except RuntimeError:
    # Start method already set (e.g., in another module), ignore
    pass

from trainer import Trainer
from config import (
    iterations, batch_size, microbatch, learning_rate,
    warmup_iterations, save_freq, log_freq, T, k, val_freq
)
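# NOTE: log_freq is indexed as log_freq[1] further down, so it is assumed to be a pair of
# intervals defined in config: (scalar-metric logging interval, image logging interval).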
import torch
import wandb


def train(resume_ckpt=None):
    """
    Main training loop that integrates all components.

    Training flow:
        1. Build the model (and optionally resume from a checkpoint)
        2. Set up optimization
        3. Build the dataloader
        4. Training loop:
            - Get a batch from the dataloader
            - Training step (forward, backward, optimizer step)
            - Adjust the learning rate
            - Run validation, log metrics and images
            - Save checkpoints

    Args:
        resume_ckpt: Path to a checkpoint file to resume from (optional)
    """
    # Initialize trainer
    trainer = Trainer(resume_ckpt=resume_ckpt)

    print("=" * 100)
    if resume_ckpt:
        print("Resuming Training")
    else:
        print("Starting Training")
    print("=" * 100)

    # Build model (Component 2)
    trainer.build_model()

    # Resume from checkpoint if provided (must be after model is built)
    if resume_ckpt:
        trainer.resume_from_ckpt(resume_ckpt)

    # Setup optimization (Component 1)
    trainer.setup_optimization()

    # Build dataloader (Component 3)
    trainer.build_dataloader()
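    # The 'train' dataloader is assumed to yield (hr_latent, lr_latent) pairs that are
    # already encoded to latent space (see Trainer.build_dataloader); this script only
    # moves them to the device and feeds them to the model.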
    # Initialize training
    trainer.model.train()
    train_iter = iter(trainer.dataloaders['train'])

    print("\nTraining Configuration:")
    print(f" - Total iterations: {iterations}")
    print(f" - Batch size: {batch_size}")
    print(f" - Micro-batch size: {microbatch}")
    print(f" - Learning rate: {learning_rate}")
    print(f" - Warmup iterations: {warmup_iterations}")
    print(f" - Save frequency: {save_freq}")
    print(f" - Log frequency: {log_freq}")
    print(f" - Device: {trainer.device}")
    print("=" * 100)
    print("\nStarting training loop...\n")
    # Training loop
    for step in range(trainer.iters_start, iterations):
        trainer.current_iters = step + 1

        # Get batch from dataloader
        try:
            hr_latent, lr_latent = next(train_iter)
        except StopIteration:
            # Restart iterator if exhausted (shouldn't happen with infinite cycle, but safety)
            train_iter = iter(trainer.dataloaders['train'])
            hr_latent, lr_latent = next(train_iter)

        # Move to device
        hr_latent = hr_latent.to(trainer.device)
        lr_latent = lr_latent.to(trainer.device)
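        # training_step is expected to split the batch into micro-batches of size
        # `microbatch` and accumulate gradients, i.e. roughly batch_size / microbatch
        # accumulation steps per optimizer update (an assumption based on the config
        # names; the details live inside Trainer).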
        # Training step (Component 5)
        # This handles: forward pass, backward pass, optimizer step, gradient accumulation
        loss, timing_dict = trainer.training_step(hr_latent, lr_latent)

        # Adjust learning rate (Component 6)
        trainer.adjust_lr()
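        # adjust_lr() presumably ramps the learning rate up over `warmup_iterations`,
        # e.g. lr = learning_rate * min(1.0, current_iters / warmup_iterations);
        # the actual schedule is implemented inside Trainer.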
        # Run validation (Component 9)
        if 'val' in trainer.dataloaders and trainer.current_iters % val_freq == 0:
            trainer.validation()

        # Store timing info for logging
        trainer._last_timing = timing_dict

        # Only recompute for logging if we're actually logging images
        # This avoids unnecessary computation when only logging loss
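        # ResShift forward process (restated here for reference; the code below implements it
        # only to visualize a sample, not to train):
        #   x_t = x_0 + eta_t * (y_0 - x_0) + k * sqrt(eta_t) * eps,   eps ~ N(0, I)
        # where x_0 is the HR latent, y_0 the LR latent, eta_t the shifting schedule
        # (assumed to increase from ~0 at t=0 to ~1 at t=T-1), and k the noise-scale
        # hyperparameter from config.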
        if trainer.current_iters % log_freq[1] == 0:
            # Prepare data for logging (need x_t and pred for visualization)
            with torch.no_grad():
                residual = lr_latent - hr_latent
                t_log = torch.randint(0, T, (hr_latent.shape[0],), device=trainer.device)
                epsilon_log = torch.randn_like(hr_latent)
                # Reshape to (B, 1, 1, 1) so eta_t broadcasts over the latent's
                # channel and spatial dimensions
                eta_t_log = trainer.eta[t_log].view(-1, 1, 1, 1)
                x_t_log = hr_latent + eta_t_log * residual + k * torch.sqrt(eta_t_log) * epsilon_log
                trainer.model.eval()
                # Model predicts x0 (clean HR latent), not noise
                x0_pred_log = trainer.model(x_t_log[0:1], t_log[0:1], lq=lr_latent[0:1])
                trainer.model.train()
            # Log training metrics and images (Component 8)
            trainer.log_step_train(
                loss=loss,
                hr_latent=hr_latent[0:1],
                lr_latent=lr_latent[0:1],
                x_t=x_t_log[0:1],
                pred=x0_pred_log,  # x0 prediction (clean HR latent)
                phase='train'
            )
        else:
            # Only log loss/metrics, no images
            trainer.log_step_train(
                loss=loss,
                hr_latent=hr_latent[0:1],
                lr_latent=lr_latent[0:1],
                x_t=None,   # Not needed when not logging images
                pred=None,  # Not needed when not logging images
                phase='train'
            )
        # Save checkpoint (Component 7)
        if trainer.current_iters % save_freq == 0:
            trainer.save_ckpt()
    # Final checkpoint
    print("\n" + "=" * 100)
    print("Training completed!")
    print("=" * 100)
    trainer.save_ckpt()
    print(f"Final checkpoint saved at iteration {trainer.current_iters}")

    # Finish WandB
    wandb.finish()


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description='Train ResShift diffusion model')
    parser.add_argument('--resume', type=str, default=None,
                        help='Path to checkpoint file to resume from (e.g., checkpoints/ckpts/model_10000.pth)')
    args = parser.parse_args()

    train(resume_ckpt=args.resume)