# DiffusionSR / src/inference.py
# Author: shekkari21 — commit 3c45764 ("Committing all the super resolution files").
#!/usr/bin/env python
# -*- coding:utf-8 -*-
"""
Inference script for ResShift diffusion model.
Performs super-resolution on LR images using full diffusion sampling.
Consistent with original ResShift inference interface.
"""
import os
import sys
import argparse
from pathlib import Path
import torch
import torch.nn as nn
from PIL import Image
import torchvision.transforms.functional as TF
import numpy as np
from tqdm import tqdm
from model import FullUNET
from autoencoder import get_vqgan
from noiseControl import resshift_schedule
from config import (
device, T, k, normalize_input, latent_flag,
autoencoder_ckpt_path, _project_root,
image_size, # Latent space size (64)
gt_size, # Pixel space size (256)
sf, # Scale factor (4)
)
def get_parser(**parser_kwargs):
    """Build the command-line parser and parse ``sys.argv``.

    Args:
        **parser_kwargs: Forwarded verbatim to ``argparse.ArgumentParser``.

    Returns:
        argparse.Namespace with the parsed options (despite the name, the
        parsed arguments are returned, not the parser).
    """
    parser = argparse.ArgumentParser(**parser_kwargs)
    parser.add_argument(
        "-i", "--in_path", type=str, required=True,
        help="Input path (image file or directory)."
    )
    parser.add_argument(
        "-o", "--out_path", type=str, default="./results",
        help="Output path (image file or directory)."
    )
    parser.add_argument(
        "--checkpoint", type=str, required=True,
        help="Path to model checkpoint (e.g., checkpoints/ckpts/model_1500.pth)."
    )
    parser.add_argument(
        "--ema_checkpoint", type=str, default=None,
        help="Path to EMA checkpoint (optional, e.g., checkpoints/ckpts/ema_model_1500.pth)."
    )
    parser.add_argument(
        "--use_ema", action="store_true",
        help="Use EMA model for inference (requires --ema_checkpoint)."
    )
    parser.add_argument(
        "--scale", type=int, default=4,
        help="Scale factor for SR (default: 4)."
    )
    parser.add_argument(
        "--seed", type=int, default=12345,
        help="Random seed for reproducibility."
    )
    parser.add_argument(
        "--bs", type=int, default=1,
        help="Batch size for inference."
    )
    parser.add_argument(
        "--chop_size", type=int, default=512,
        choices=[512, 256, 64],
        help="Chopping size for large images (default: 512)."
    )
    parser.add_argument(
        "--chop_stride", type=int, default=-1,
        help="Chopping stride (default: auto-calculated)."
    )
    parser.add_argument(
        "--chop_bs", type=int, default=1,
        help="Batch size for chopping (default: 1)."
    )
    # AMP is enabled by default. With only a `store_true` flag defaulting to
    # True there was no way to turn it off, so `--no_amp` writes False into
    # the same destination. `--use_amp` is kept for backward compatibility;
    # when both are given, the last one on the command line wins.
    parser.add_argument(
        "--use_amp", dest="use_amp", action="store_true", default=True,
        help="Use automatic mixed precision (default: True)."
    )
    parser.add_argument(
        "--no_amp", dest="use_amp", action="store_false",
        help="Disable automatic mixed precision."
    )
    return parser.parse_args()
def set_seed(seed):
    """Seed the Python, NumPy and PyTorch (CPU and, if present, CUDA) RNGs.

    Args:
        seed: Integer seed applied to every generator.
    """
    import random as _py_random

    for seed_fn in (_py_random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)
    # CUDA generators only need explicit seeding when a GPU is usable.
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)
def load_image(image_path):
    """
    Load an image from disk and prepare it for inference.

    The image is converted to RGB and resized (bicubic) to a
    ``gt_size`` x ``gt_size`` square, which is the resolution the pipeline
    works in. Aspect ratio is not preserved.

    Args:
        image_path: Path to input image.
    Returns:
        img_tensor: Float tensor of shape (1, 3, gt_size, gt_size) in [0, 1].
        orig_size: Original image size as reported by PIL, i.e. (W, H).
    """
    # PIL opens files lazily and keeps the handle open until the image is
    # closed; the context manager releases it. convert() yields a new,
    # fully-loaded image that stays valid after the source is closed.
    with Image.open(image_path) as src:
        img = src.convert("RGB")
        orig_size = src.size  # (W, H)
    # The model operates in a fixed gt_size pixel space (its latent is
    # produced from that), so the input is resized to gt_size regardless of
    # its original resolution; save_image later rescales using orig_size.
    target_size = gt_size
    img = img.resize((target_size, target_size), Image.BICUBIC)
    # to_tensor converts HWC uint8 [0, 255] to CHW float [0, 1].
    img_tensor = TF.to_tensor(img).unsqueeze(0)  # (1, 3, H, W)
    return img_tensor, orig_size
def save_image(tensor, save_path, orig_size=None):
    """
    Write an image tensor to disk, creating parent directories as needed.

    Args:
        tensor: Image tensor of shape (1, 3, H, W) with values in [0, 1].
        save_path: Destination file path (str or Path).
        orig_size: Optional (W, H) of the original input. When given, the
            output is resized with Lanczos to (W * sf, H * sf).
    """
    pil_img = TF.to_pil_image(tensor.squeeze(0).cpu())
    if orig_size is not None:
        width, height = orig_size[0], orig_size[1]
        pil_img = pil_img.resize((width * sf, height * sf), Image.LANCZOS)

    destination = Path(save_path)
    destination.parent.mkdir(parents=True, exist_ok=True)
    pil_img.save(destination)
    print(f"✓ Saved SR image to: {destination}")
def _scale_input(x_t, t, eta_schedule, k, normalize_input, latent_flag):
    """
    Normalize the network input by the std of x_t at timestep t.

    When both ``normalize_input`` and ``latent_flag`` are set, ``x_t`` is
    divided by ``sqrt(eta_t * k**2 + 1)``. Otherwise ``x_t`` is returned
    unchanged (the same object).

    Args:
        x_t: Noisy input tensor (B, C, H, W).
        t: Long tensor of timesteps, shape (B,).
        eta_schedule: Noise schedule of shape (T, 1, 1, 1).
        k: Noise scaling factor (kappa).
        normalize_input: Whether to normalize input.
        latent_flag: Whether working in latent space.
    Returns:
        The (possibly) scaled input tensor.
    """
    if normalize_input and latent_flag:
        eta_t = eta_schedule[t]  # (B, 1, 1, 1)
        std = torch.sqrt(eta_t * k**2 + 1)
        return x_t / std
    return x_t


def inference_single_image(
    model,
    autoencoder,
    lr_image_tensor,
    eta_schedule,
    device,
    T=15,
    k=2.0,
    normalize_input=True,
    latent_flag=True,
    use_amp=False,
):
    """
    Super-resolve one LR image with full ResShift diffusion sampling in latent space.

    Args:
        model: Trained ResShift network, called as ``model(x, t, lq=latent)``
            and expected to predict x0.
        autoencoder: Object exposing ``encode`` (pixels -> latent) and
            ``decode`` (latent -> pixels).
        lr_image_tensor: LR image tensor (expected (1, 3, 256, 256)) in [0, 1].
        eta_schedule: Noise schedule of shape (T, 1, 1, 1).
        device: Device to run inference on.
        T: Number of diffusion timesteps.
        k: Noise scaling factor (kappa).
        normalize_input: Whether to normalize the network input.
        latent_flag: Whether working in latent space.
        use_amp: Whether to run the network under CUDA autocast (ignored when
            CUDA is unavailable).
    Returns:
        SR image tensor as returned by ``autoencoder.decode``, clamped to [0, 1].
    """
    from contextlib import nullcontext

    model.eval()
    lr_image_tensor = lr_image_tensor.to(device)
    amp_enabled = use_amp and torch.cuda.is_available()

    with torch.no_grad():
        # Encode LR image to latent space (expected (1, 3, 64, 64)).
        lr_latent = autoencoder.encode(lr_image_tensor)
        batch_size = lr_latent.shape[0]

        # Prior sample at the last timestep:
        # x_{T-1} = y + k * sqrt(eta_{T-1}) * eps
        epsilon_init = torch.randn_like(lr_latent)
        eta_max = eta_schedule[T - 1]
        x_t = lr_latent + k * torch.sqrt(eta_max) * epsilon_init

        for t_step in range(T - 1, -1, -1):  # T-1, T-2, ..., 0
            t = torch.full((batch_size,), t_step, device=device, dtype=torch.long)
            x_t_scaled = _scale_input(x_t, t, eta_schedule, k, normalize_input, latent_flag)

            # A fresh context per step rather than re-entering one shared object.
            autocast_context = torch.amp.autocast("cuda") if amp_enabled else nullcontext()
            with autocast_context:
                x0_pred = model(x_t_scaled, t, lq=lr_latent)
            # Under autocast the network output can be reduced precision
            # (e.g. fp16). Cast back to the latent dtype so the final latent
            # fed to autoencoder.decode (which runs outside autocast) matches
            # the dtype the sampler started with.
            x0_pred = x0_pred.to(x_t.dtype)

            if t_step == 0:
                # Final step: use the predicted x0 directly, no noise added.
                x_t = x0_pred
                continue

            # Equation (7) from the ResShift paper:
            #   mu    = (eta_{t-1}/eta_t) * x_t + (alpha_t/eta_t) * f(x_t, y_0, t)
            #   Sigma = k^2 * (eta_{t-1}/eta_t) * alpha_t
            #   x_{t-1} = mu + sqrt(Sigma) * eps
            eta_t = eta_schedule[t_step]
            eta_prev = eta_schedule[t_step - 1]
            alpha_t = eta_t - eta_prev
            mean = (eta_prev / eta_t) * x_t + (alpha_t / eta_t) * x0_pred
            variance = k**2 * (eta_prev / eta_t) * alpha_t
            noise = torch.randn_like(x_t)
            x_t = mean + torch.sqrt(variance) * noise

        sr_image = autoencoder.decode(x_t)
        sr_image = sr_image.clamp(0, 1)
    return sr_image
def inference_with_chopping(
    model,
    autoencoder,
    lr_image_tensor,
    eta_schedule,
    device,
    chop_size=512,
    chop_stride=448,
    chop_bs=1,
    T=15,
    k=2.0,
    normalize_input=True,
    latent_flag=True,
    use_amp=False,
):
    """
    Entry point for patch-wise ("chopped") inference on large images.

    NOTE: chopping is not implemented yet. ``chop_size``, ``chop_stride`` and
    ``chop_bs`` are accepted for interface compatibility but ignored, and a
    ``RuntimeWarning`` says so. The whole image is processed in one pass by
    ``inference_single_image``, so the result has exactly the shape that
    function returns; it is NOT upscaled to (H*sf, W*sf).

    Args:
        model: Trained ResShift model.
        autoencoder: VQGAN autoencoder.
        lr_image_tensor: LR image tensor (1, 3, H, W).
        eta_schedule: Noise schedule.
        device: Device to run inference on.
        chop_size: Size of each patch (currently ignored).
        chop_stride: Stride between patches (currently ignored).
        chop_bs: Batch size for chopping (currently ignored).
        T: Number of diffusion timesteps.
        k: Noise scaling factor.
        normalize_input: Whether to normalize input.
        latent_flag: Whether working in latent space.
        use_amp: Whether to use AMP.
    Returns:
        SR image tensor as returned by ``inference_single_image``.
    """
    import warnings

    # Make the ignored configuration visible instead of silently dropping it.
    warnings.warn(
        "inference_with_chopping: chopping is not implemented; ignoring "
        f"chop_size={chop_size}, chop_stride={chop_stride}, chop_bs={chop_bs} "
        "and processing the full image in one pass.",
        RuntimeWarning,
        stacklevel=2,
    )
    # Keyword arguments so a future signature change in
    # inference_single_image cannot silently shift positional parameters.
    return inference_single_image(
        model=model,
        autoencoder=autoencoder,
        lr_image_tensor=lr_image_tensor,
        eta_schedule=eta_schedule,
        device=device,
        T=T,
        k=k,
        normalize_input=normalize_input,
        latent_flag=latent_flag,
        use_amp=use_amp,
    )
def _strip_compiled_prefix(state_dict, prefix="_orig_mod."):
    """
    Remove the compiled-module wrapper prefix from state-dict keys.

    Args:
        state_dict: Mapping of parameter names to tensors.
        prefix: Key prefix to strip (``_orig_mod.`` for compiled modules).
    Returns:
        (state_dict, stripped): the cleaned mapping and whether any key had
        the prefix. When no key has it, the input mapping is returned as is.
    """
    if not any(key.startswith(prefix) for key in state_dict):
        return state_dict, False
    cleaned = {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }
    return cleaned, True


def load_model(checkpoint_path, ema_checkpoint_path=None, use_ema=False, device=device):
    """
    Build a FullUNET and load its weights, optionally applying EMA weights.

    Args:
        checkpoint_path: Path to model checkpoint. Either a raw state dict or
            a dict holding it under ``'state_dict'``.
        ema_checkpoint_path: Path to EMA checkpoint (optional).
        use_ema: If True and ``ema_checkpoint_path`` is given, EMA weights are
            loaded and applied to the model. If True without a path, a
            warning is printed and the regular weights are kept.
        device: Device to load the model and tensors on.
    Returns:
        Loaded model.
    """
    print(f"Loading model from: {checkpoint_path}")
    model = FullUNET()
    model = model.to(device)

    # NOTE(security): torch.load may unpickle arbitrary objects (depending on
    # the torch version's weights_only default); only load trusted checkpoints.
    ckpt = torch.load(checkpoint_path, map_location=device)
    state_dict = ckpt['state_dict'] if 'state_dict' in ckpt else ckpt
    # Checkpoints saved from a compiled model carry an `_orig_mod.` key prefix.
    state_dict, stripped = _strip_compiled_prefix(state_dict)
    if stripped:
        print(" Detected compiled model checkpoint, stripping _orig_mod. prefix...")
    model.load_state_dict(state_dict)
    print("✓ Model loaded")

    if not use_ema:
        return model
    if not ema_checkpoint_path:
        # This combination used to be ignored silently; surface it.
        print("Warning: use_ema=True but no EMA checkpoint path was given; "
              "using the regular model weights.")
        return model

    print(f"Loading EMA model from: {ema_checkpoint_path}")
    from ema import EMA
    ema = EMA(model, ema_rate=0.999, device=device)
    ema_ckpt = torch.load(ema_checkpoint_path, map_location=device)
    ema_ckpt, stripped = _strip_compiled_prefix(ema_ckpt)
    if stripped:
        print(" Detected compiled model in EMA checkpoint, stripping _orig_mod. prefix...")
    ema.load_state_dict(ema_ckpt)
    ema.apply_to_model(model)
    print("✓ EMA model loaded and applied")
    return model
def _resolve_chop_params(scale, chop_size, chop_stride):
    """Return (chop_size, chop_stride) scaled by 4 // scale, defaulting the stride when chop_stride < 0."""
    factor = 4 // scale
    if chop_stride < 0:
        # Default overlap between neighbouring patches for each chop size.
        default_overlaps = {512: 64, 256: 32, 64: 16}
        if chop_size not in default_overlaps:
            raise ValueError("Chop size must be in [512, 256, 64]")
        chop_stride = (chop_size - default_overlaps[chop_size]) * factor
    else:
        chop_stride = chop_stride * factor
    return chop_size * factor, chop_stride


def _collect_io_files(in_path, out_path):
    """
    Pair each input image with the path its result should be written to.

    A single input file maps to ``out_path`` if it has a suffix (treated as a
    file name), otherwise to ``out_path / <input name>``. A directory input
    maps every image file directly inside it (non-recursive, sorted by path)
    to ``out_path / <name>`` and creates ``out_path``.

    Raises:
        ValueError: if ``in_path`` does not exist.
    """
    if in_path.is_file():
        output_file = out_path if out_path.suffix else out_path / in_path.name
        return [in_path], [output_file]
    if in_path.is_dir():
        image_extensions = {'.png', '.jpg', '.jpeg', '.bmp', '.tiff', '.tif'}
        # Sorted so the processing order (and hence RNG consumption under a
        # fixed seed) does not depend on filesystem enumeration order.
        input_files = sorted(
            f for f in in_path.iterdir()
            if f.is_file() and f.suffix.lower() in image_extensions
        )
        out_path.mkdir(parents=True, exist_ok=True)
        return input_files, [out_path / f.name for f in input_files]
    raise ValueError(f"Input path does not exist: {in_path}")


def main():
    """Command-line entry point: super-resolve an image file or a directory of images."""
    args = get_parser()
    print("=" * 80)
    print("ResShift Inference")
    print("=" * 80)

    set_seed(args.seed)

    # Explicit check rather than `assert`, which is stripped under `python -O`.
    if args.scale != 4:
        raise ValueError("We only support 4x super-resolution now!")

    chop_size, chop_stride = _resolve_chop_params(args.scale, args.chop_size, args.chop_stride)
    print(f"Chopping size/stride: {chop_size}/{chop_stride}")

    model = load_model(
        args.checkpoint,
        args.ema_checkpoint,
        args.use_ema,
        device
    )

    print("\nLoading VQGAN autoencoder...")
    autoencoder = get_vqgan()
    print("✓ VQGAN autoencoder loaded")

    print("\nInitializing noise schedule...")
    eta = resshift_schedule().to(device)
    eta = eta[:, None, None, None]  # (T, 1, 1, 1)
    print("✓ Noise schedule initialized")

    in_path = Path(args.in_path)
    out_path = Path(args.out_path)
    input_files, output_files = _collect_io_files(in_path, out_path)
    if not input_files:
        raise ValueError(f"No image files found in: {in_path}")
    print(f"\nFound {len(input_files)} image(s) to process")

    print("\n" + "=" * 80)
    print("Running Inference")
    print("=" * 80)
    # Arguments shared by both inference entry points.
    common_kwargs = dict(
        model=model,
        autoencoder=autoencoder,
        eta_schedule=eta,
        device=device,
        T=T,
        k=k,
        normalize_input=normalize_input,
        latent_flag=latent_flag,
        use_amp=args.use_amp,
    )
    for idx, (input_file, output_file) in enumerate(zip(input_files, output_files), 1):
        print(f"\n[{idx}/{len(input_files)}] Processing: {input_file.name}")
        lr_image, orig_size = load_image(input_file)
        # Chop sizes below 512 go through the chopping entry point.
        if args.chop_size < 512:
            sr_image = inference_with_chopping(
                lr_image_tensor=lr_image,
                chop_size=chop_size,
                chop_stride=chop_stride,
                chop_bs=args.chop_bs,
                **common_kwargs,
            )
        else:
            sr_image = inference_single_image(lr_image_tensor=lr_image, **common_kwargs)
        save_image(sr_image, output_file, orig_size=orig_size)

    print("\n" + "=" * 80)
    print("Inference Complete!")
    print("=" * 80)


if __name__ == "__main__":
    main()