#!/usr/bin/env python
# -*- coding:utf-8 -*-
"""
Inference script for the ResShift diffusion model.
Performs super-resolution on LR images using full diffusion sampling.
Consistent with the original ResShift inference interface.
"""
import os
import sys
import argparse
from pathlib import Path

import torch
import torch.nn as nn
from PIL import Image
import torchvision.transforms.functional as TF
import numpy as np
from tqdm import tqdm

from model import FullUNET
from autoencoder import get_vqgan
from noiseControl import resshift_schedule
from config import (
    device, T, k, normalize_input, latent_flag,
    autoencoder_ckpt_path, _project_root,
    image_size,  # Latent space size (64)
    gt_size,     # Pixel space size (256)
    sf,          # Scale factor (4)
)

def get_parser(**parser_kwargs):
    """Parse command-line arguments."""
    parser = argparse.ArgumentParser(**parser_kwargs)
    parser.add_argument(
        "-i", "--in_path", type=str, required=True,
        help="Input path (image file or directory)."
    )
    parser.add_argument(
        "-o", "--out_path", type=str, default="./results",
        help="Output path (image file or directory)."
    )
    parser.add_argument(
        "--checkpoint", type=str, required=True,
        help="Path to model checkpoint (e.g., checkpoints/ckpts/model_1500.pth)."
    )
    parser.add_argument(
        "--ema_checkpoint", type=str, default=None,
        help="Path to EMA checkpoint (optional, e.g., checkpoints/ckpts/ema_model_1500.pth)."
    )
    parser.add_argument(
        "--use_ema", action="store_true",
        help="Use EMA model for inference (requires --ema_checkpoint)."
    )
    parser.add_argument(
        "--scale", type=int, default=4,
        help="Scale factor for SR (default: 4)."
    )
    parser.add_argument(
        "--seed", type=int, default=12345,
        help="Random seed for reproducibility."
    )
    parser.add_argument(
        "--bs", type=int, default=1,
        help="Batch size for inference."
    )
    parser.add_argument(
        "--chop_size", type=int, default=512,
        choices=[512, 256, 64],
        help="Chopping size for large images (default: 512)."
    )
    parser.add_argument(
        "--chop_stride", type=int, default=-1,
        help="Chopping stride (default: auto-calculated)."
    )
    parser.add_argument(
        "--chop_bs", type=int, default=1,
        help="Batch size for chopping (default: 1)."
    )
    # NOTE: because this is a store_true flag with default=True, AMP is
    # effectively always enabled regardless of whether --use_amp is passed.
    parser.add_argument(
        "--use_amp", action="store_true", default=True,
        help="Use automatic mixed precision (default: True)."
    )
    return parser.parse_args()

def set_seed(seed):
    """Set random seed for reproducibility."""
    import random
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

def load_image(image_path):
    """
    Load and preprocess an image for inference.

    Args:
        image_path: Path to input image

    Returns:
        Preprocessed image tensor (1, 3, H, W) in [0, 1] range
        Original image size (W, H)
    """
    # Load image
    img = Image.open(image_path).convert("RGB")
    orig_size = img.size  # (W, H)

    # The pipeline works in a fixed 256x256 pixel space: the LR input is resized
    # to gt_size regardless of its original resolution (non-square inputs are
    # stretched), and the SR result is produced at the same 256x256 size before
    # the optional upscaling in save_image().
    target_size = gt_size  # 256x256

    # Resize to target size (bicubic interpolation)
    img = img.resize((target_size, target_size), Image.BICUBIC)

    # Convert to tensor and normalize to [0, 1]
    img_tensor = TF.to_tensor(img).unsqueeze(0)  # (1, 3, H, W)

    return img_tensor, orig_size

def save_image(tensor, save_path, orig_size=None):
    """
    Save a tensor image to file.

    Args:
        tensor: Image tensor (1, 3, H, W) in [0, 1]
        save_path: Path to save image
        orig_size: Original image size (W, H) for optional resize
    """
    # Convert to PIL Image
    img = TF.to_pil_image(tensor.squeeze(0).cpu())

    # Optionally resize to the original size scaled by the scale factor
    if orig_size is not None:
        target_size = (orig_size[0] * sf, orig_size[1] * sf)
        img = img.resize(target_size, Image.LANCZOS)

    # Save image
    save_path = Path(save_path)
    save_path.parent.mkdir(parents=True, exist_ok=True)
    img.save(save_path)
    print(f"✓ Saved SR image to: {save_path}")

def _scale_input(x_t, t, eta_schedule, k, normalize_input, latent_flag):
    """
    Scale the input based on timestep, matching the normalization used during training.

    Args:
        x_t: Noisy input tensor (B, C, H, W)
        t: Timestep tensor (B,)
        eta_schedule: Noise schedule (T, 1, 1, 1)
        k: Noise scaling factor
        normalize_input: Whether to normalize input
        latent_flag: Whether working in latent space

    Returns:
        Scaled input tensor
    """
    if normalize_input and latent_flag:
        eta_t = eta_schedule[t]  # (B, 1, 1, 1)
        std = torch.sqrt(eta_t * k**2 + 1)
        x_t_scaled = x_t / std
    else:
        x_t_scaled = x_t
    return x_t_scaled

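# Why std = sqrt(eta_t * k^2 + 1) in _scale_input: under the ResShift forward
# process x_t = x_0 + eta_t * (y_0 - x_0) + k * sqrt(eta_t) * eps, and assuming
# the VQGAN latents have roughly unit variance, Var(x_t) ≈ eta_t * k^2 + 1, so
# dividing by its square root keeps the network input close to unit variance at
# every timestep. This mirrors the input scaling used by the original ResShift
# implementation for latent-space models.
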
def inference_single_image(
    model,
    autoencoder,
    lr_image_tensor,
    eta_schedule,
    device,
    T=15,
    k=2.0,
    normalize_input=True,
    latent_flag=True,
    use_amp=False,
):
    """
    Perform inference on a single LR image using full diffusion sampling.

    Args:
        model: Trained ResShift model
        autoencoder: VQGAN autoencoder for encoding/decoding
        lr_image_tensor: LR image tensor (1, 3, 256, 256) in [0, 1]
        eta_schedule: Noise schedule (T, 1, 1, 1)
        device: Device to run inference on
        T: Number of diffusion timesteps
        k: Noise scaling factor
        normalize_input: Whether to normalize input
        latent_flag: Whether working in latent space
        use_amp: Whether to use automatic mixed precision

    Returns:
        SR image tensor (1, 3, 256, 256) in [0, 1]
    """
    model.eval()

    # Move to device
    lr_image_tensor = lr_image_tensor.to(device)

    # Autocast context
    if use_amp and torch.cuda.is_available():
        autocast_context = torch.amp.autocast('cuda')
    else:
        from contextlib import nullcontext
        autocast_context = nullcontext()

    with torch.no_grad():
        # Encode LR image to latent space
        lr_latent = autoencoder.encode(lr_image_tensor)  # (1, 3, 64, 64)

        # Initialize x_t at the maximum timestep (T-1): start from the LR latent
        # with maximum noise
        epsilon_init = torch.randn_like(lr_latent)
        eta_max = eta_schedule[T - 1]
        x_t = lr_latent + k * torch.sqrt(eta_max) * epsilon_init
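        # For reference, the ResShift forward marginal is
        #     q(x_t | x_0, y_0) = N(x_t; x_0 + eta_t * (y_0 - x_0), k^2 * eta_t * I),
        # so at the last timestep, where eta_{T-1} is close to 1, x_T is
        # approximately y_0 + k * sqrt(eta_{T-1}) * eps — which is the
        # initialization above, with the LR latent playing the role of y_0.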
        # Full diffusion sampling loop
        for t_step in range(T - 1, -1, -1):  # T-1, T-2, ..., 1, 0
            t = torch.full((lr_latent.shape[0],), t_step, device=device, dtype=torch.long)

            # Scale input if needed
            x_t_scaled = _scale_input(x_t, t, eta_schedule, k, normalize_input, latent_flag)

            # Predict x0 from the current noisy state
            with autocast_context:
                x0_pred = model(x_t_scaled, t, lq=lr_latent)

            # If not the last step, compute x_{t-1} from the predicted x0 using equation (7)
            if t_step > 0:
                # Equation (7) from the ResShift paper:
                # μ_θ = (η_{t-1}/η_t) * x_t + (α_t/η_t) * f_θ(x_t, y_0, t)
                # Σ_θ = κ² * (η_{t-1}/η_t) * α_t
                # x_{t-1} = μ_θ + sqrt(Σ_θ) * ε
                eta_t = eta_schedule[t_step]
                eta_t_minus_1 = eta_schedule[t_step - 1]

                # Compute alpha_t = η_t - η_{t-1}
                alpha_t = eta_t - eta_t_minus_1

                # Compute mean: μ_θ = (η_{t-1}/η_t) * x_t + (α_t/η_t) * x0_pred
                mean = (eta_t_minus_1 / eta_t) * x_t + (alpha_t / eta_t) * x0_pred

                # Compute variance: Σ_θ = κ² * (η_{t-1}/η_t) * α_t
                variance = k**2 * (eta_t_minus_1 / eta_t) * alpha_t

                # Sample: x_{t-1} = μ_θ + sqrt(Σ_θ) * ε
                # (nonzero_mask is always 1.0 inside this branch; kept for parity
                # with the general DDPM-style formulation)
                noise = torch.randn_like(x_t)
                nonzero_mask = torch.tensor(
                    1.0 if t_step > 0 else 0.0, device=x_t.device
                ).view(-1, *([1] * (len(x_t.shape) - 1)))
                x_t = mean + nonzero_mask * torch.sqrt(variance) * noise
            else:
                # Final step: use the predicted x0 directly
                x_t = x0_pred

        # Final prediction
        sr_latent = x_t

        # Decode back to pixel space
        sr_image = autoencoder.decode(sr_latent)  # (1, 3, 256, 256)

        # Clamp to [0, 1]
        sr_image = sr_image.clamp(0, 1)

    return sr_image

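# Minimal standalone usage sketch for inference_single_image (illustrative;
# main() below wires this up from the CLI, and the paths are placeholders):
#   model = load_model("checkpoints/ckpts/model_1500.pth")
#   autoencoder = get_vqgan()
#   eta = resshift_schedule().to(device)[:, None, None, None]
#   lr, orig_size = load_image("some_lr_image.png")
#   sr = inference_single_image(model, autoencoder, lr, eta, device, T=T, k=k)
#   save_image(sr, "sr_output.png", orig_size=orig_size)
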
def inference_with_chopping(
    model,
    autoencoder,
    lr_image_tensor,
    eta_schedule,
    device,
    chop_size=512,
    chop_stride=448,
    chop_bs=1,
    T=15,
    k=2.0,
    normalize_input=True,
    latent_flag=True,
    use_amp=False,
):
    """
    Perform inference with chopping for large images.

    Args:
        model: Trained ResShift model
        autoencoder: VQGAN autoencoder
        lr_image_tensor: LR image tensor (1, 3, H, W)
        eta_schedule: Noise schedule
        device: Device to run inference on
        chop_size: Size of each patch
        chop_stride: Stride between patches
        chop_bs: Batch size for chopping
        T: Number of diffusion timesteps
        k: Noise scaling factor
        normalize_input: Whether to normalize input
        latent_flag: Whether working in latent space
        use_amp: Whether to use AMP

    Returns:
        SR image tensor (1, 3, H*sf, W*sf)
    """
    # For now, this is a placeholder that processes the full image without
    # chopping; a full implementation would tile the input into overlapping
    # patches and blend the results (see the sketch after this function).
    return inference_single_image(
        model, autoencoder, lr_image_tensor, eta_schedule,
        device, T, k, normalize_input, latent_flag, use_amp
    )

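# The helper below is an illustrative sketch of how patch-based chopping could
# be implemented on top of inference_single_image; it is not part of the
# original ResShift code and is untested here. It assumes tiles of size
# chop_size x chop_size match what the model expects (i.e. chop_size = gt_size
# for this pipeline) and that the SR output has the same spatial size as the
# input, which is the case for inference_single_image above. Overlapping tiles
# are averaged with a per-pixel weight map.
def _chopping_sketch(
    model, autoencoder, lr_image_tensor, eta_schedule, device,
    chop_size=256, chop_stride=224, **sampling_kwargs
):
    _, _, h, w = lr_image_tensor.shape
    sr_acc = torch.zeros_like(lr_image_tensor)
    weight = torch.zeros_like(lr_image_tensor)

    # Tile start positions; ensure the right/bottom borders are covered
    ys = list(range(0, max(h - chop_size, 0) + 1, chop_stride))
    xs = list(range(0, max(w - chop_size, 0) + 1, chop_stride))
    if ys[-1] != max(h - chop_size, 0):
        ys.append(max(h - chop_size, 0))
    if xs[-1] != max(w - chop_size, 0):
        xs.append(max(w - chop_size, 0))

    for y in ys:
        for x in xs:
            patch = lr_image_tensor[:, :, y:y + chop_size, x:x + chop_size]
            sr_patch = inference_single_image(
                model, autoencoder, patch, eta_schedule, device, **sampling_kwargs
            )
            # Accumulate the SR tile and its coverage count
            sr_acc[:, :, y:y + chop_size, x:x + chop_size] += sr_patch.to(sr_acc.device)
            weight[:, :, y:y + chop_size, x:x + chop_size] += 1.0

    # Average overlapping regions
    return sr_acc / weight.clamp(min=1.0)
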
def load_model(checkpoint_path, ema_checkpoint_path=None, use_ema=False, device=device):
    """
    Load model from checkpoint.

    Args:
        checkpoint_path: Path to model checkpoint
        ema_checkpoint_path: Path to EMA checkpoint (optional)
        use_ema: Whether to use EMA model
        device: Device to load model on

    Returns:
        Loaded model
    """
    print(f"Loading model from: {checkpoint_path}")
    model = FullUNET()
    model = model.to(device)

    # Load checkpoint
    ckpt = torch.load(checkpoint_path, map_location=device)
    if 'state_dict' in ckpt:
        state_dict = ckpt['state_dict']
    else:
        state_dict = ckpt

    # Handle compiled model checkpoints (strip the _orig_mod. prefix added by torch.compile)
    if any(k.startswith('_orig_mod.') for k in state_dict.keys()):
        print("  Detected compiled model checkpoint, stripping _orig_mod. prefix...")
        new_state_dict = {}
        for k, v in state_dict.items():
            if k.startswith('_orig_mod.'):
                new_state_dict[k[10:]] = v  # Remove '_orig_mod.' prefix
            else:
                new_state_dict[k] = v
        state_dict = new_state_dict

    model.load_state_dict(state_dict)
    print("✓ Model loaded")

    # Load EMA weights if requested
    if use_ema and ema_checkpoint_path:
        print(f"Loading EMA model from: {ema_checkpoint_path}")
        from ema import EMA
        ema = EMA(model, ema_rate=0.999, device=device)
        ema_ckpt = torch.load(ema_checkpoint_path, map_location=device)

        # Handle compiled model checkpoints (strip the _orig_mod. prefix)
        if any(k.startswith('_orig_mod.') for k in ema_ckpt.keys()):
            print("  Detected compiled model in EMA checkpoint, stripping _orig_mod. prefix...")
            new_ema_ckpt = {}
            for k, v in ema_ckpt.items():
                if k.startswith('_orig_mod.'):
                    new_ema_ckpt[k[10:]] = v  # Remove '_orig_mod.' prefix
                else:
                    new_ema_ckpt[k] = v
            ema_ckpt = new_ema_ckpt

        ema.load_state_dict(ema_ckpt)
        ema.apply_to_model(model)
        print("✓ EMA model loaded and applied")

    return model

def main():
    args = get_parser()

    print("=" * 80)
    print("ResShift Inference")
    print("=" * 80)

    # Set random seed
    set_seed(args.seed)

    # Validate scale factor
    assert args.scale == 4, "We only support 4x super-resolution now!"

    # Calculate chopping stride if not provided
    if args.chop_stride < 0:
        if args.chop_size == 512:
            chop_stride = (512 - 64) * (4 // args.scale)
        elif args.chop_size == 256:
            chop_stride = (256 - 32) * (4 // args.scale)
        elif args.chop_size == 64:
            chop_stride = (64 - 16) * (4 // args.scale)
        else:
            raise ValueError("Chop size must be in [512, 256, 64]")
    else:
        chop_stride = args.chop_stride * (4 // args.scale)
    chop_size = args.chop_size * (4 // args.scale)
    print(f"Chopping size/stride: {chop_size}/{chop_stride}")

    # Load model
    model = load_model(
        args.checkpoint,
        args.ema_checkpoint,
        args.use_ema,
        device
    )

    # Load VQGAN autoencoder
    print("\nLoading VQGAN autoencoder...")
    autoencoder = get_vqgan()
    print("✓ VQGAN autoencoder loaded")

    # Initialize noise schedule
    print("\nInitializing noise schedule...")
    eta = resshift_schedule().to(device)
    eta = eta[:, None, None, None]  # (T, 1, 1, 1)
    print("✓ Noise schedule initialized")

    # Prepare input/output paths
    in_path = Path(args.in_path)
    out_path = Path(args.out_path)

    # Determine whether the input is a file or a directory
    if in_path.is_file():
        input_files = [in_path]
        if out_path.suffix:  # Output is a file
            output_files = [out_path]
        else:  # Output is a directory
            output_files = [out_path / in_path.name]
    elif in_path.is_dir():
        # Get all image files from the directory
        image_extensions = {'.png', '.jpg', '.jpeg', '.bmp', '.tiff', '.tif'}
        input_files = [f for f in in_path.iterdir() if f.suffix.lower() in image_extensions]
        output_files = [out_path / f.name for f in input_files]
        out_path.mkdir(parents=True, exist_ok=True)
    else:
        raise ValueError(f"Input path does not exist: {in_path}")

    if not input_files:
        raise ValueError(f"No image files found in: {in_path}")

    print(f"\nFound {len(input_files)} image(s) to process")

    # Process each image
    print("\n" + "=" * 80)
    print("Running Inference")
    print("=" * 80)
    for idx, (input_file, output_file) in enumerate(zip(input_files, output_files), 1):
        print(f"\n[{idx}/{len(input_files)}] Processing: {input_file.name}")

        # Load input image
        lr_image, orig_size = load_image(input_file)

        # Run inference (the chopping path is currently a placeholder that
        # falls back to whole-image processing)
        if args.chop_size < 512:  # Chopping path when a smaller chop size is requested
            sr_image = inference_with_chopping(
                model=model,
                autoencoder=autoencoder,
                lr_image_tensor=lr_image,
                eta_schedule=eta,
                device=device,
                chop_size=chop_size,
                chop_stride=chop_stride,
                chop_bs=args.chop_bs,
                T=T,
                k=k,
                normalize_input=normalize_input,
                latent_flag=latent_flag,
                use_amp=args.use_amp,
            )
        else:
            sr_image = inference_single_image(
                model=model,
                autoencoder=autoencoder,
                lr_image_tensor=lr_image,
                eta_schedule=eta,
                device=device,
                T=T,
                k=k,
                normalize_input=normalize_input,
                latent_flag=latent_flag,
                use_amp=args.use_amp,
            )

        # Save output
        save_image(sr_image, output_file, orig_size=orig_size)
| print("\n" + "=" * 80) | |
| print("Inference Complete!") | |
| print("=" * 80) | |
| if __name__ == "__main__": | |
| main() | |