import torch
import torch.nn as nn
from model import FullUNET
from noiseControl import resshift_schedule
from torch.utils.data import DataLoader
from data import mini_dataset, train_dataset, get_vqgan_model
import torch.optim as optim
from config import (batch_size, device, learning_rate, iterations, 
                    weight_decay, T, k, _project_root)
import wandb
import os
from dotenv import load_dotenv

# Load environment variables from .env file (looks for .env in project root)
load_dotenv(os.path.join(_project_root, '.env'))
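
# Note: wandb.init below needs authentication; the .env file is assumed to
# provide WANDB_API_KEY (an assumption, not verified against the repo)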

wandb.init(
    project="diffusionsr",
    name="reshift_training",
    config={
        "learning_rate": learning_rate,
        "batch_size": batch_size,
        "steps": iterations,
        "model": "ResShift",
        "T": T,
        "k": k,
        "optimizer": "Adam",
        "betas": (0.9, 0.999),
        "grad_clip": 1.0,
        "criterion": "MSE",
        "device": str(device),
        "training_space": "latent_64x64"
    }
)

# Load VQGAN for decoding latents for visualization
vqgan = get_vqgan_model()
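# Assumption: vqgan.decode() returns RGB images roughly in [0, 1]; if your
# decoder outputs [-1, 1], rescale before the clamp(0, 1) in the logging below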

train_dl = DataLoader(mini_dataset, batch_size=batch_size, shuffle=True)

# Grab a single fixed batch of 64x64 latents; the loop below deliberately
# reuses this one batch, resampling only the timestep and noise each step
hr_latent, lr_latent = next(iter(train_dl))

hr_latent = hr_latent.to(device)  # (B, C, 64, 64) - HR latent
lr_latent = lr_latent.to(device)  # (B, C, 64, 64) - LR latent

eta = resshift_schedule().to(device)
eta = eta[:, None, None, None]   # shape (T,1,1,1)
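# eta is assumed to be a length-T, monotonically increasing shift schedule
# (eta_1 near 0, eta_T near 1) per the ResShift formulation; check the shape
# so the fancy-indexing by t below broadcasts as intended
assert eta.shape == (T, 1, 1, 1), f"unexpected eta shape: {tuple(eta.shape)}"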
residual = (lr_latent - hr_latent)  # Residual in latent space
model = FullUNET()
model = model.to(device)
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate, betas=(0.9, 0.999), weight_decay=weight_decay)
steps = iterations

# Watch model for gradients/parameters
wandb.watch(model, log="all", log_freq=10)
for step in range(steps):
    model.train()
    # sample a random timestep in [0, T) for each example in the batch
    t = torch.randint(0, T, (batch_size,), device=device)

    # ResShift forward process in latent space:
    #   x_t = x_0 + eta_t * (y - x_0) + k * sqrt(eta_t) * eps
    # where x_0 is the HR latent, y is the LR latent, and eps ~ N(0, I)
    epsilon = torch.randn_like(hr_latent)  # Gaussian noise in latent space
    eta_t = eta[t]  # per-sample shift, shape (B, 1, 1, 1)
    x_t = hr_latent + eta_t * residual + k * torch.sqrt(eta_t) * epsilon
    # the same latent batch goes through the model at a fresh random timestep
    # each step; lr_latent conditions the UNet on the low-resolution input
    pred = model(x_t, t, lq=lr_latent)
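    # NOTE: the training target below is the injected noise epsilon
    # (epsilon-parameterization); whatever sampler consumes this model must
    # invert the same parameterization when reconstructing x_0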
    optimizer.zero_grad()
    loss = criterion(pred, epsilon)
    wandb.log({
        "loss": loss.item(),
        "step": step,
        "learning_rate": optimizer.param_groups[0]['lr']
    })
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
    optimizer.step()
    if step % 50 == 0:
        # Decode latents to pixel space for visualization
        with torch.no_grad():
            # reconstruct the model's implied clean-latent estimate from the
            # predicted noise; decoding x_t directly would only show the
            # noisy input, not a prediction
            x0_hat = x_t - eta_t * residual - k * torch.sqrt(eta_t) * pred.detach()
            hr_pixel = vqgan.decode(hr_latent[0:1])  # (1, 3, 256, 256)
            lr_pixel = vqgan.decode(lr_latent[0:1])  # (1, 3, 256, 256)
            pred_pixel = vqgan.decode(x0_hat[0:1])   # (1, 3, 256, 256)

        wandb.log({
            "hr_sample": wandb.Image(hr_pixel[0].cpu().clamp(0, 1)),
            "lr_sample": wandb.Image(lr_pixel[0].cpu().clamp(0, 1)),
            "pred_sample": wandb.Image(pred_pixel[0].cpu().clamp(0, 1))
        })
    print(f'loss at step {step + 1} is {loss.item():.6f}')

wandb.finish()
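
# Hypothetical wrap-up (not part of the original script): persist the trained
# weights so a separate sampling/eval script can reload them; the filename and
# location are assumptions
torch.save(model.state_dict(), os.path.join(_project_root, 'resshift_unet.pt'))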