"""
Gradio app for ResShift Super-Resolution
Hosted on Hugging Face Spaces
"""
import gradio as gr
import torch
from PIL import Image
import torchvision.transforms.functional as TF
from pathlib import Path
import sys
from huggingface_hub import hf_hub_download
# Add src to path
sys.path.insert(0, str(Path(__file__).parent / "src"))
from model import FullUNET
from autoencoder import get_vqgan
from noiseControl import resshift_schedule
from config import device, T, k, normalize_input, latent_flag, gt_size
# Hugging Face repo ID for weights
HF_WEIGHTS_REPO_ID = "shekkari21/DiffusionSR-weights"
# Global variables for loaded models
model = None
autoencoder = None
eta_schedule = None
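# These globals are populated once by load_models() (invoked via demo.load at
# startup) and then reused for every inference request.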
def load_models():
"""Load models on startup."""
global model, autoencoder, eta_schedule
print("Loading models...")
# Load model checkpoint
checkpoint_path = "checkpoints/ckpts/model_3200.pth"
checkpoint_file = Path(checkpoint_path)
# Download from Hugging Face if not found locally
    if not checkpoint_file.exists():
        ckpt_dir = Path("checkpoints/ckpts")
        # Prefer any checkpoint already present locally, sorted by training step
        # so "latest" means the highest step number rather than arbitrary glob order
        local_ckpts = []
        if ckpt_dir.exists():
            local_ckpts = sorted(
                ckpt_dir.glob("model_*.pth"),
                key=lambda p: int(p.stem.split("_")[-1]),
            )
        if local_ckpts:
            checkpoint_path = str(local_ckpts[-1])
            print(f"Using checkpoint: {checkpoint_path}")
        else:
            # Download from Hugging Face (weight files sit in the root of the weights repo)
            print("Checkpoint not found locally. Downloading from Hugging Face...")
            ckpt_dir.mkdir(parents=True, exist_ok=True)
            try:
                # hf_hub_download returns the local path of the downloaded file
                checkpoint_path = hf_hub_download(
                    repo_id=HF_WEIGHTS_REPO_ID,
                    filename="model_3200.pth",
                    local_dir=str(ckpt_dir),
                    local_dir_use_symlinks=False,  # deprecated (ignored) in recent huggingface_hub
                )
                print(f"✓ Downloaded checkpoint: {checkpoint_path}")
            except Exception as e:
                raise FileNotFoundError(
                    f"Could not download checkpoint from Hugging Face: {e}\n"
                    f"Please ensure model_3200.pth exists in the repository."
                ) from e
model = FullUNET()
model = model.to(device)
ckpt = torch.load(checkpoint_path, map_location=device)
if 'state_dict' in ckpt:
state_dict = ckpt['state_dict']
else:
state_dict = ckpt
# Handle compiled model checkpoints
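    # (torch.compile wraps the module, so checkpoints saved from a compiled model
    # carry an "_orig_mod." prefix on every parameter name; strip it before loading)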
if any(key.startswith('_orig_mod.') for key in state_dict.keys()):
new_state_dict = {}
for key, val in state_dict.items():
if key.startswith('_orig_mod.'):
                new_state_dict[key[len("_orig_mod."):]] = val
else:
new_state_dict[key] = val
state_dict = new_state_dict
model.load_state_dict(state_dict)
model.eval()
print("✓ Model loaded")
# Load VQGAN autoencoder
autoencoder = get_vqgan()
print("✓ VQGAN autoencoder loaded")
# Initialize noise schedule
eta_schedule = resshift_schedule().to(device)
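    # Reshape to (T, 1, 1, 1) so per-timestep values broadcast over (B, C, H, W) tensors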
eta_schedule = eta_schedule[:, None, None, None]
print("✓ Noise schedule initialized")
return "Models loaded successfully!"
def _scale_input(x_t, t, eta_schedule, k, normalize_input, latent_flag):
"""Scale input based on timestep."""
if normalize_input and latent_flag:
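        # Rationale (sketch, assuming roughly unit-variance latents): under the
        # ResShift forward process Var[x_t] ≈ η_t·κ² + 1, so dividing by its
        # square root keeps the network input at a stable scale across timesteps.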
eta_t = eta_schedule[t]
std = torch.sqrt(eta_t * k**2 + 1)
x_t_scaled = x_t / std
else:
x_t_scaled = x_t
return x_t_scaled
def super_resolve(input_image):
"""
Perform super-resolution on input image.
Args:
input_image: PIL Image or numpy array
Returns:
PIL Image of super-resolved output
"""
if input_image is None:
return None
if model is None or autoencoder is None:
return None
try:
# Convert to PIL Image if needed
if isinstance(input_image, Image.Image):
img = input_image
else:
img = Image.fromarray(input_image)
# Resize to target size (256x256)
img = img.resize((gt_size, gt_size), Image.BICUBIC)
# Convert to tensor
img_tensor = TF.to_tensor(img).unsqueeze(0).to(device) # (1, 3, 256, 256)
# Run inference
with torch.no_grad():
# Encode to latent space
lr_latent = autoencoder.encode(img_tensor) # (1, 3, 64, 64)
# Initialize x_t at maximum timestep
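            # ResShift forward marginal: q(x_t | x_0, y_0) = N(x_0 + η_t(y_0 - x_0), κ²η_t I).
            # At t = T, η_T ≈ 1, so x_T ≈ y_0 + κ·sqrt(η_T)·ε, i.e. sampling starts
            # from the LR latent plus scaled noise rather than from pure Gaussian noise.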
epsilon_init = torch.randn_like(lr_latent)
eta_max = eta_schedule[T - 1]
x_t = lr_latent + k * torch.sqrt(eta_max) * epsilon_init
# Full diffusion sampling loop
for t_step in range(T - 1, -1, -1):
t = torch.full((lr_latent.shape[0],), t_step, device=device, dtype=torch.long)
# Scale input
x_t_scaled = _scale_input(x_t, t, eta_schedule, k, normalize_input, latent_flag)
# Predict x0
x0_pred = model(x_t_scaled, t, lq=lr_latent)
# Compute x_{t-1} using equation (7)
if t_step > 0:
# Equation (7) from ResShift paper:
# μ_θ = (η_{t-1}/η_t) * x_t + (α_t/η_t) * f_θ(x_t, y_0, t)
# Σ_θ = κ² * (η_{t-1}/η_t) * α_t
# x_{t-1} = μ_θ + sqrt(Σ_θ) * ε
eta_t = eta_schedule[t_step]
eta_t_minus_1 = eta_schedule[t_step - 1]
# Compute alpha_t = η_t - η_{t-1}
alpha_t = eta_t - eta_t_minus_1
# Compute mean: μ_θ = (η_{t-1}/η_t) * x_t + (α_t/η_t) * x0_pred
mean = (eta_t_minus_1 / eta_t) * x_t + (alpha_t / eta_t) * x0_pred
# Compute variance: Σ_θ = κ² * (η_{t-1}/η_t) * α_t
variance = k**2 * (eta_t_minus_1 / eta_t) * alpha_t
                    # Sample: x_{t-1} = μ_θ + sqrt(Σ_θ) * ε
                    # (this branch only runs for t_step > 0, so noise is always added;
                    # the t = 0 case below returns x0_pred without noise)
                    noise = torch.randn_like(x_t)
                    x_t = mean + torch.sqrt(variance) * noise
else:
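                        # Final step (t = 0): take the model's x0 prediction directly, no noise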
x_t = x0_pred
# Decode back to pixel space
sr_latent = x_t
sr_image = autoencoder.decode(sr_latent) # (1, 3, 256, 256)
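            # to_pil_image expects float pixels in [0, 1], so clamp the decoder output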
sr_image = sr_image.clamp(0, 1)
# Convert to PIL Image
sr_pil = TF.to_pil_image(sr_image.squeeze(0).cpu())
return sr_pil
except Exception as e:
print(f"Error during inference: {str(e)}")
import traceback
traceback.print_exc()
return None
# Create Gradio interface
with gr.Blocks(title="ResShift Super-Resolution") as demo:
gr.Markdown(
"""
# ResShift Super-Resolution
        Upload a low-resolution image to get a super-resolved version using the ResShift diffusion model.
        **Note**: Restoration runs in a 4x-downsampled latent space: the input is resized to 256x256, enhanced, and decoded back to a 256x256 image with improved detail.
"""
)
with gr.Row():
with gr.Column():
input_image = gr.Image(
label="Input Image (Low Resolution)",
type="pil",
height=300
)
submit_btn = gr.Button("Super-Resolve", variant="primary")
with gr.Column():
output_image = gr.Image(
label="Super-Resolved Output",
type="pil",
height=300
)
status = gr.Textbox(label="Status", value="Loading models...", interactive=False)
# Load models on startup
demo.load(
fn=load_models,
outputs=status,
show_progress=True
)
# Process on button click
submit_btn.click(
fn=super_resolve,
inputs=input_image,
outputs=output_image,
show_progress=True
)
# Also process on image upload
input_image.change(
fn=super_resolve,
inputs=input_image,
outputs=output_image,
show_progress=True
)
if __name__ == "__main__":
    # share=True is ignored on Hugging Face Spaces (the Space URL is already public)
    demo.launch()