# Scraped page header (not part of the program): Spaces status was "Runtime error".
| import os | |
| import random | |
| import warnings | |
| import gc | |
| import gradio as gr | |
| import numpy as np | |
| import spaces | |
| import torch | |
| from diffusers import FluxImg2ImgPipeline | |
| from gradio_imageslider import ImageSlider | |
| from PIL import Image | |
| from huggingface_hub import snapshot_download | |
| import requests | |
| # ESRGAN imports | |
| from basicsr.archs.rrdbnet_arch import RRDBNet | |
| from basicsr.utils import img2tensor, tensor2img | |
# Custom CSS handed to gr.Blocks(css=css) when the interface is built.
# `.main-header` styles the title block rendered by the first gr.HTML call.
# NOTE(review): no component in the visible file sets elem_id="col-container",
# so the `#col-container` rule currently has nothing to match.
css = """
#col-container {
margin: 0 auto;
max-width: 800px;
}
.main-header {
text-align: center;
margin-bottom: 2rem;
}
"""
# Get HuggingFace token (os.getenv returns None when HF_TOKEN is unset);
# it is passed to snapshot_download below.
huggingface_token = os.getenv("HF_TOKEN")

# Download FLUX model if not already cached
print("📥 Downloading FLUX model...")
model_path = snapshot_download(
    repo_id="black-forest-labs/FLUX.1-dev",
    repo_type="model",
    ignore_patterns=["*.md", "*.gitattributes"],
    local_dir="FLUX.1-dev",
    token=huggingface_token,
)

# Load FLUX pipeline on CPU initially; enhance_image moves it to the GPU
# only for the duration of a request.
print("📥 Loading FLUX Img2Img pipeline...")
pipe = FluxImg2ImgPipeline.from_pretrained(
    model_path,
    torch_dtype=torch.bfloat16,
    use_safetensors=True,
)

# Enable memory optimizations directly on the VAE.
# NOTE(review): the pipeline-level enable_vae_tiling()/enable_vae_slicing()
# calls that used to precede these are presumed to only forward to the VAE
# methods below, so they were dropped as redundant - confirm against the
# installed diffusers version.
pipe.vae.enable_tiling()
pipe.vae.enable_slicing()

# Download and load ESRGAN 4x-UltraSharp model
print("📥 Loading ESRGAN 4x-UltraSharp...")
esrgan_path = "4x-UltraSharp.pth"
if not os.path.exists(esrgan_path):
    print("Downloading ESRGAN model...")
    url = "https://huggingface.co/uwg/upscaler/resolve/main/ESRGAN/4x-UltraSharp.pth"
    # Stream into a temporary file and rename only after a complete, successful
    # download. Otherwise an HTTP error page or a truncated transfer would be
    # saved under esrgan_path and never re-downloaded because the file exists.
    partial_path = esrgan_path + ".part"
    with requests.get(url, stream=True, timeout=60) as response:
        response.raise_for_status()
        with open(partial_path, "wb") as f:
            for chunk in response.iter_content(chunk_size=1 << 20):
                f.write(chunk)
    os.replace(partial_path, esrgan_path)

# Initialize ESRGAN model
esrgan_model = RRDBNet(
    num_in_ch=3,
    num_out_ch=3,
    num_feat=64,
    num_block=23,
    num_grow_ch=32,
    scale=4,
)

# weights_only=True keeps torch.load from unpickling arbitrary objects out of
# a file that was fetched over the network.
state_dict = torch.load(esrgan_path, map_location="cpu", weights_only=True)
# Some checkpoints nest the weights under 'params_ema' or 'params'; unwrap
# whichever is present, preferring 'params_ema'.
if "params_ema" in state_dict:
    state_dict = state_dict["params_ema"]
elif "params" in state_dict:
    state_dict = state_dict["params"]
esrgan_model.load_state_dict(state_dict)
esrgan_model.eval()

print("✅ All models loaded successfully!")
# Upper bound (inclusive) for random seeds and for the seed slider in the UI.
MAX_SEED = 1000000
# Longest allowed input edge in pixels before the 4x upscale.
MAX_INPUT_SIZE = 512  # Max input size before upscaling
def make_multiple_16(n):
    """Round *n* up to the next multiple of 16 (FLUX size requirement).

    Values that are already a multiple of 16 are returned unchanged; every
    other value is rounded up, never down.

    >>> make_multiple_16(17)
    32
    """
    return ((n + 15) // 16) * 16
def truncate_prompt(prompt, max_tokens=75):
    """Shorten *prompt* to stay near the CLIP token limit (77 tokens).

    Tokens are not counted exactly. The budget is approximated by character
    count at about 10 characters per 3 tokens, so the default of 75 tokens
    allows 250 characters.

    Args:
        prompt: Prompt text; ``None`` or an empty string yields ``""``.
        max_tokens: Approximate token budget used to derive the character
            limit.

    Returns:
        The prompt unchanged when it fits, otherwise its first
        ``max_tokens * 10 // 3`` characters followed by ``"..."``.
    """
    if not prompt:
        return ""
    # Simple truncation by character count (rough approximation of tokens).
    max_chars = max(0, int(max_tokens * 10 // 3))
    if len(prompt) > max_chars:
        return prompt[:max_chars] + "..."
    return prompt
def prepare_image(image, max_size=MAX_INPUT_SIZE):
    """Return an RGB copy of *image* whose sides do not exceed *max_size*.

    Args:
        image: PIL image to prepare. It is not modified.
        max_size: Maximum width and height in pixels.

    Returns:
        A new RGB PIL image, downscaled with LANCZOS (aspect ratio kept) when
        either side was larger than *max_size*.
    """
    # convert() returns a new image, so the in-place thumbnail() below never
    # touches the caller's object. RGB is required because the ESRGAN network
    # is constructed with num_in_ch=3.
    prepared = image.convert("RGB")
    w, h = prepared.size
    # Limit input size
    if w > max_size or h > max_size:
        prepared.thumbnail((max_size, max_size), Image.LANCZOS)
    return prepared
def esrgan_upscale(image):
    """Upscale a PIL image 4x using the module-level ESRGAN model.

    Args:
        image: PIL image; converted to RGB first when it is in another mode.

    Returns:
        The upscaled PIL image.
    """
    # The network is built with num_in_ch=3, so feed it exactly three channels.
    rgb_image = image if image.mode == "RGB" else image.convert("RGB")

    # Convert PIL to tensor (values scaled to 0..1)
    img_np = np.array(rgb_image).astype(np.float32) / 255.0
    img_tensor = img2tensor(img_np, bgr2rgb=False, float32=True)

    # Run on whichever device currently holds the weights, so the input and
    # the model can never end up on different devices.
    device = next(esrgan_model.parameters()).device

    # Upscale
    with torch.no_grad():
        output = esrgan_model(img_tensor.unsqueeze(0).to(device))

    # Convert back to PIL
    output_np = tensor2img(output.squeeze(0).cpu(), rgb2bgr=False, min_max=(0, 1))
    return Image.fromarray(output_np)
# NOTE(review): the ZeroGPU decorator was missing. The orphaned comment about
# 60 seconds and the otherwise unused `spaces` import indicate it belongs here.
@spaces.GPU(duration=60)  # 60 seconds should be enough
def enhance_image(
    input_image,
    prompt,
    seed,
    randomize_seed,
    num_inference_steps,
    denoising_strength,
    progress=gr.Progress(track_tqdm=True),
):
    """Upscale an image 4x with ESRGAN, then refine it with FLUX img2img.

    Args:
        input_image: PIL image from the UI, or None when nothing was uploaded.
        prompt: Text prompt for FLUX. A missing or blank prompt falls back to
            "high quality, detailed".
        seed: Seed for the FLUX generator; ignored when *randomize_seed* is set.
        randomize_seed: When true, draw a random seed in [0, MAX_SEED].
        num_inference_steps: Step count passed to the FLUX pipeline.
        denoising_strength: Passed to the FLUX pipeline as ``strength``.
        progress: Gradio progress tracker (tracks tqdm bars).

    Returns:
        A tuple ``([before, after], seed)``: the input resized to the output
        size, the enhanced image, and the seed that was actually used.

    Raises:
        gr.Error: If no image was provided or any processing step fails.
    """
    if input_image is None:
        raise gr.Error("Please upload an image")

    # Clear memory
    torch.cuda.empty_cache()
    gc.collect()

    try:
        # Randomize seed if needed; sliders may deliver floats, and
        # manual_seed() needs an int.
        seed = random.randint(0, MAX_SEED) if randomize_seed else int(seed)
        num_inference_steps = int(num_inference_steps)

        # Prepare and validate prompt. Strip first so a whitespace-only prompt
        # also falls back to the default instead of becoming "".
        prompt = truncate_prompt((prompt or "").strip() or "high quality, detailed")

        # Prepare input image
        input_image = prepare_image(input_image)
        if input_image.mode != "RGB":
            # ESRGAN is built with num_in_ch=3; RGBA or grayscale would not fit.
            input_image = input_image.convert("RGB")

        # Step 1: ESRGAN upscale (4x) on the GPU
        gr.Info("🔄 Upscaling with ESRGAN 4x...")
        with torch.no_grad():
            # Move ESRGAN to GPU for faster processing
            esrgan_model.to("cuda")

            # Convert image for ESRGAN
            img_np = np.array(input_image).astype(np.float32) / 255.0
            img_tensor = img2tensor(img_np, bgr2rgb=False, float32=True)
            img_tensor = img_tensor.unsqueeze(0).to("cuda")

            # Upscale
            output_tensor = esrgan_model(img_tensor)

            # Convert back to PIL
            output_np = tensor2img(
                output_tensor.squeeze(0).cpu(), rgb2bgr=False, min_max=(0, 1)
            )
            upscaled_image = Image.fromarray(output_np)

        # Drop the GPU tensors and move ESRGAN back to CPU so empty_cache()
        # can actually release that memory before FLUX is loaded.
        del img_tensor, output_tensor
        esrgan_model.to("cpu")
        torch.cuda.empty_cache()

        # Ensure dimensions are multiples of 16 for FLUX
        w, h = upscaled_image.size
        new_w = make_multiple_16(w)
        new_h = make_multiple_16(h)
        was_padded = new_w != w or new_h != h
        if was_padded:
            # Pad image (black, bottom/right) to meet requirements
            padded = Image.new("RGB", (new_w, new_h))
            padded.paste(upscaled_image, (0, 0))
            upscaled_image = padded

        # Step 2: FLUX enhancement
        gr.Info("🎨 Enhancing with FLUX...")

        # Move pipeline to GPU
        pipe.to("cuda")

        # Generate with FLUX
        generator = torch.Generator(device="cuda").manual_seed(seed)
        with torch.inference_mode():
            result = pipe(
                prompt=prompt,
                image=upscaled_image,
                strength=denoising_strength,
                num_inference_steps=num_inference_steps,
                guidance_scale=1.0,  # Fixed at 1.0 for FLUX dev
                height=new_h,
                width=new_w,
                generator=generator,
            ).images[0]

        # Crop back if we padded
        if was_padded:
            result = result.crop((0, 0, w, h))

        # Prepare images for slider (before/after)
        input_resized = input_image.resize(result.size, Image.LANCZOS)

        gr.Info("✅ Enhancement complete!")
        return [input_resized, result], seed
    except Exception as e:
        # Surface the failure in the UI while keeping the original traceback.
        raise gr.Error(f"Enhancement failed: {str(e)}") from e
    finally:
        # Runs on success and on error: return both models to CPU and release
        # cached GPU memory.
        pipe.to("cpu")
        esrgan_model.to("cpu")
        torch.cuda.empty_cache()
        gc.collect()
# Create Gradio interface
# NOTE(review): the source had lost its indentation, so the nesting of the
# `with` blocks below is reconstructed from the order of the statements -
# confirm the layout (in particular that the seed row sits inside the
# "Advanced Settings" accordion).
with gr.Blocks(css=css) as demo:
    gr.HTML("""
    <div class="main-header">
        <h1>🚀 ESRGAN 4x + FLUX Enhancement</h1>
        <p>Upload an image to upscale 4x with ESRGAN and enhance with FLUX</p>
        <p>Optimized for <strong>ZeroGPU</strong> | Max input: 512x512 → Output: 2048x2048</p>
    </div>
    """)

    with gr.Row():
        with gr.Column(scale=1):
            # Input section
            input_image = gr.Image(
                label="Input Image",
                type="pil",
                height=256,
            )
            prompt = gr.Textbox(
                label="Enhancement Prompt",
                placeholder="Describe the desired enhancement (e.g., 'high quality, sharp details, vibrant colors')",
                value="high quality, ultra detailed, sharp",
                lines=2,
            )

            with gr.Accordion("Advanced Settings", open=False):
                num_inference_steps = gr.Slider(
                    label="Enhancement Steps",
                    minimum=10,
                    maximum=25,
                    step=1,
                    value=18,
                    info="More steps = better quality but slower",
                )
                denoising_strength = gr.Slider(
                    label="Enhancement Strength",
                    minimum=0.1,
                    maximum=0.6,
                    step=0.05,
                    value=0.35,
                    info="Higher = more changes to the image",
                )
                with gr.Row():
                    randomize_seed = gr.Checkbox(
                        label="Randomize seed",
                        value=True,
                    )
                    seed = gr.Slider(
                        label="Seed",
                        minimum=0,
                        maximum=MAX_SEED,
                        step=1,
                        value=42,
                    )

            enhance_btn = gr.Button(
                "🎨 Enhance Image (4x Upscale)",
                variant="primary",
                size="lg",
            )

        with gr.Column(scale=2):
            # Output section
            result_slider = ImageSlider(
                type="pil",
                label="Before / After",
                interactive=False,
                height=512,
            )
            # Receives the seed returned by enhance_image; hidden from the page.
            used_seed = gr.Number(
                label="Seed Used",
                interactive=False,
                visible=False,
            )

    # Examples (each one only fills the prompt textbox)
    gr.Examples(
        examples=[
            ["high quality, ultra detailed, sharp"],
            ["cinematic, professional photography, enhanced details"],
            ["vibrant colors, high contrast, sharp focus"],
            ["photorealistic, 8k quality, fine details"],
        ],
        inputs=[prompt],
        label="Example Prompts",
    )

    # Event handler: input order must match enhance_image's parameters.
    enhance_btn.click(
        fn=enhance_image,
        inputs=[
            input_image,
            prompt,
            seed,
            randomize_seed,
            num_inference_steps,
            denoising_strength,
        ],
        outputs=[result_slider, used_seed],
    )

    gr.HTML("""
    <div style="margin-top: 2rem; text-align: center; color: #666;">
        <p>🔄 Pipeline: ESRGAN 4x-UltraSharp → FLUX Dev Enhancement</p>
        <p>⚡ Optimized for ZeroGPU with automatic memory management</p>
    </div>
    """)
if __name__ == "__main__":
    # Allow at most 3 queued requests and listen on all interfaces, port 7860.
    demo.queue(max_size=3).launch(
        share=False,
        server_name="0.0.0.0",
        server_port=7860,
    )