"""
Main training script for ResShift diffusion model.

This script initializes the Trainer class and runs the main training loop.
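
Usage (assuming this file is saved as train.py):
    python train.py
    python train.py --resume checkpoints/ckpts/model_10000.pth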
"""

import multiprocessing
# Use the 'spawn' start method so CUDA works correctly in DataLoader worker processes.
# This is required when using DataLoader with num_workers > 0 on systems where CUDA is
# initialized before the worker processes are created (Colab, some Linux setups that
# default to 'fork'). Must be set before any CUDA initialization or DataLoader creation.
try:
    multiprocessing.set_start_method('spawn', force=True)
except RuntimeError:
    # Start method already set (e.g., in another module), ignore
    pass

from trainer import Trainer
from config import (
    iterations, batch_size, microbatch, learning_rate,
    warmup_iterations, save_freq, log_freq, T, k, val_freq
)
import torch
import wandb


def train(resume_ckpt=None):
    """
    Main training loop that integrates all components.
    
    Training flow:
    1. Build model and dataloader
    2. Setup optimization
    3. Training loop:
       - Get batch from dataloader
       - Training step (forward, backward, optimizer step)
       - Adjust learning rate
       - Log metrics and images
       - Save checkpoints
    
    Args:
        resume_ckpt: Path to checkpoint file to resume from (optional)
    """
    # Initialize trainer
    trainer = Trainer(resume_ckpt=resume_ckpt)
    
    print("=" * 100)
    if resume_ckpt:
        print("Resuming Training")
    else:
        print("Starting Training")
    print("=" * 100)
    
    # Build model (Component 2)
    trainer.build_model()
    
    # Resume from checkpoint if provided (must be after model is built)
    if resume_ckpt:
        trainer.resume_from_ckpt(resume_ckpt)
    
    # Setup optimization (Component 1)
    trainer.setup_optimization()
    
    # Build dataloader (Component 3)
    trainer.build_dataloader()
    
    # Initialize training
    trainer.model.train()
    train_iter = iter(trainer.dataloaders['train'])
    
    print(f"\nTraining Configuration:")
    print(f"  - Total iterations: {iterations}")
    print(f"  - Batch size: {batch_size}")
    print(f"  - Micro-batch size: {microbatch}")
    print(f"  - Learning rate: {learning_rate}")
    print(f"  - Warmup iterations: {warmup_iterations}")
    print(f"  - Save frequency: {save_freq}")
    print(f"  - Log frequency: {log_freq}")
    print(f"  - Device: {trainer.device}")
    print("=" * 100)
    print("\nStarting training loop...\n")
    
    # Training loop
    for step in range(trainer.iters_start, iterations):
        trainer.current_iters = step + 1
        
        # Get batch from dataloader
        try:
            hr_latent, lr_latent = next(train_iter)
        except StopIteration:
            # Restart the iterator if it is exhausted (should not happen with an infinite cycle; kept as a safeguard)
            train_iter = iter(trainer.dataloaders['train'])
            hr_latent, lr_latent = next(train_iter)
        
        # Move to device
        hr_latent = hr_latent.to(trainer.device)
        lr_latent = lr_latent.to(trainer.device)
        
        # Training step (Component 5)
        # This handles: forward pass, backward pass, optimizer step, gradient accumulation
        loss, timing_dict = trainer.training_step(hr_latent, lr_latent)
        
        # Adjust learning rate (Component 6)
        trainer.adjust_lr()
        
        # Run validation (Component 9)
        if 'val' in trainer.dataloaders and trainer.current_iters % val_freq == 0:
            trainer.validation()
        
        # Store timing info for logging
        trainer._last_timing = timing_dict
        
        # Only recompute for logging if we're actually logging images
        # This avoids unnecessary computation when only logging loss
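        # log_freq is assumed to be an indexable pair from config; log_freq[1] is the
        # image-logging interval (log_step_train itself is called every iteration)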
        if trainer.current_iters % log_freq[1] == 0:
            # Prepare data for logging (need x_t and pred for visualization)
            with torch.no_grad():
                residual = (lr_latent - hr_latent)
                t_log = torch.randint(0, T, (hr_latent.shape[0],)).to(trainer.device)
                epsilon_log = torch.randn_like(hr_latent)
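                # ResShift forward process: x_t = x_0 + eta_t * (y - x_0) + k * sqrt(eta_t) * eps,
                # where x_0 is the HR latent, y the LR latent, and eps ~ N(0, I)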
                # Reshape so eta_t broadcasts against the (B, C, H, W) latent tensors
                eta_t_log = trainer.eta[t_log].view(-1, 1, 1, 1)
                x_t_log = hr_latent + eta_t_log * residual + k * torch.sqrt(eta_t_log) * epsilon_log
                
                trainer.model.eval()
                # Model predicts x0 (clean HR latent), not noise
                x0_pred_log = trainer.model(x_t_log[0:1], t_log[0:1], lq=lr_latent[0:1])
                trainer.model.train()
            
            # Log training metrics and images (Component 8)
            trainer.log_step_train(
                loss=loss,
                hr_latent=hr_latent[0:1],
                lr_latent=lr_latent[0:1],
                x_t=x_t_log[0:1],
                pred=x0_pred_log,  # x0 prediction (clean HR latent)
                phase='train'
            )
        else:
            # Only log loss/metrics, no images
            trainer.log_step_train(
                loss=loss,
                hr_latent=hr_latent[0:1],
                lr_latent=lr_latent[0:1],
                x_t=None,  # Not needed when not logging images
                pred=None,  # Not needed when not logging images
                phase='train'
            )
        
        # Save checkpoint (Component 7)
        if trainer.current_iters % save_freq == 0:
            trainer.save_ckpt()
    
    # Final checkpoint
    print("\n" + "=" * 100)
    print("Training completed!")
    print("=" * 100)
    trainer.save_ckpt()
    print(f"Final checkpoint saved at iteration {trainer.current_iters}")
    
    # Finish WandB
    wandb.finish()


if __name__ == "__main__":
    import argparse
    
    parser = argparse.ArgumentParser(description='Train ResShift diffusion model')
    parser.add_argument('--resume', type=str, default=None,
                        help='Path to checkpoint file to resume from (e.g., checkpoints/ckpts/model_10000.pth)')
    
    args = parser.parse_args()
    
    train(resume_ckpt=args.resume)