Spaces:
Runtime error
Runtime error
| import logging | |
| import random | |
| import warnings | |
| import os | |
| import gradio as gr | |
| import numpy as np | |
| import spaces | |
| import torch | |
| from diffusers import FluxImg2ImgPipeline | |
| from gradio_imageslider import ImageSlider | |
| from PIL import Image | |
| from huggingface_hub import snapshot_download | |
| import requests | |
| from transformers import T5TokenizerFast | |
# Optional ESRGAN backend for the initial 4x upscale; needs the `basicsr` package.
# When it cannot be imported, USE_ESRGAN is False and upscaling falls back to LANCZOS.
try:
    from basicsr.archs.rrdbnet_arch import RRDBNet
    from basicsr.utils import img2tensor, tensor2img
except ImportError:
    USE_ESRGAN = False
    warnings.warn("basicsr not installed; falling back to LANCZOS interpolation.")
else:
    USE_ESRGAN = True
# Extra CSS passed to gr.Blocks for the page layout.
css = """
#col-container {
margin: 0 auto;
max-width: 800px;
}
.main-header {
text-align: center;
margin-bottom: 2rem;
}
"""

# Device setup - Default to CPU, let runtime handle GPU
# `power_device` is a fixed label shown in the page header (not detected at runtime).
power_device = "ZeroGPU"
device = "cpu"

# Get HuggingFace token (None when the HF_TOKEN environment variable is unset).
huggingface_token = os.environ.get("HF_TOKEN")

# Upper bound (inclusive) for randomly drawn seeds.
MAX_SEED = 1_000_000
# Maximum output pixel count (width * height) after upscaling; larger requests shrink the input.
MAX_PIXEL_BUDGET = 8192 ** 2  # Increased for tiling support
def process_input(input_image, upscale_factor):
    """Shrink the input image if its upscaled output would exceed MAX_PIXEL_BUDGET.

    Args:
        input_image: PIL image to be upscaled.
        upscale_factor: Integer upscale factor that will be applied later.

    Returns:
        Tuple ``(image, w_original, h_original, was_resized)``. ``image`` is the
        (possibly downscaled) input. ``w_original``/``h_original`` are the
        dimensions before any resize. ``was_resized`` tells the caller to scale
        the final output back up to ``original * upscale_factor``.
    """
    w_original, h_original = input_image.size
    if w_original * h_original * upscale_factor ** 2 <= MAX_PIXEL_BUDGET:
        return input_image, w_original, h_original, False

    warnings.warn(
        f"Requested output image is too large ({w_original * upscale_factor}x{h_original * upscale_factor}). Resizing to fit budget."
    )
    gr.Info(
        "Requested output image is too large. Resizing input to fit within pixel budget."
    )
    # Choose a uniform scale so that (scaled input pixels) * factor^2 == budget.
    target_input_pixels = MAX_PIXEL_BUDGET / (upscale_factor ** 2)
    scale = (target_input_pixels / (w_original * h_original)) ** 0.5
    # Floor to multiples of 16 for Flux compatibility, but never below one 16px cell.
    new_w = max(16, int(w_original * scale) // 16 * 16)
    new_h = max(16, int(h_original * scale) // 16 * 16)
    resized = input_image.resize((new_w, new_h), resample=Image.LANCZOS)
    return resized, w_original, h_original, True
def load_image_from_url(url):
    """Download an image from ``url`` and return it as an RGB PIL image.

    The body is fully read and decoded inside the ``try`` so that network,
    HTTP-status and decoding failures all surface as ``gr.Error``.

    Raises:
        gr.Error: if the request fails, returns an error status, or the
            payload cannot be decoded as an image.
    """
    from io import BytesIO

    # NOTE(review): this fetches an arbitrary user-supplied URL from the server
    # (SSRF surface on a public Space); consider restricting schemes/hosts.
    try:
        response = requests.get(url, timeout=30)
        response.raise_for_status()
        # response.content honours Content-Encoding (unlike response.raw), and
        # convert() forces the lazy decode to happen inside this try.
        return Image.open(BytesIO(response.content)).convert("RGB")
    except Exception as e:
        raise gr.Error(f"Failed to load image from URL: {e}") from e
def esrgan_upscale(image, scale=4):
    """Upscale ``image`` by ``scale`` with the ESRGAN model, or LANCZOS as a fallback.

    The fallback is used when basicsr is unavailable or no ``esrgan_model`` has
    been loaded. The network is built elsewhere as a fixed 4x RRDBNet, so its
    output is resized to the requested ``scale`` whenever the sizes differ.

    Returns:
        A PIL RGB image of size ``(image.width * scale, image.height * scale)``.
    """
    target_size = (image.width * scale, image.height * scale)
    # esrgan_model is a lazily created global; it may not exist (or be None) yet.
    model = globals().get("esrgan_model") if USE_ESRGAN else None
    if model is None:
        return image.resize(target_size, resample=Image.LANCZOS)

    # Run on whatever device the model currently lives on to avoid a
    # CPU-tensor / CUDA-weights mismatch.
    model_device = next(model.parameters()).device
    rgb = np.asarray(image.convert("RGB"), dtype=np.float32) / 255.0
    img = img2tensor(rgb, bgr2rgb=False, float32=True).unsqueeze(0).to(model_device)
    with torch.no_grad():
        output = model(img).squeeze(0)  # drop only the batch dim
    result = Image.fromarray(tensor2img(output, rgb2bgr=False, min_max=(0, 1)))
    if result.size != target_size:
        result = result.resize(target_size, resample=Image.LANCZOS)
    return result
def _tile_starts(length, tile_size, stride):
    """Return tile start offsets covering [0, length); the last tile ends flush with the edge."""
    if length <= tile_size:
        return [0]
    starts = list(range(0, length - tile_size, stride))
    # Consecutive starts differ by at most `stride`, so every seam overlaps by >= overlap.
    starts.append(length - tile_size)
    return starts


def _blend_mask(tile_w, tile_h, overlap, feather_left, feather_top):
    """Build an 'L' paste mask ramping 0 -> 255 over the first `overlap` columns/rows."""
    weight = np.ones((tile_h, tile_w))
    if feather_left:
        ramp_w = min(overlap, tile_w)
        ramp = np.arange(ramp_w) / overlap
        weight[:, :ramp_w] = np.minimum(weight[:, :ramp_w], ramp[np.newaxis, :])
    if feather_top:
        ramp_h = min(overlap, tile_h)
        ramp = np.arange(ramp_h) / overlap
        # np.minimum keeps the left ramp in the top-left corner instead of overwriting it.
        weight[:ramp_h, :] = np.minimum(weight[:ramp_h, :], ramp[:, np.newaxis])
    return Image.fromarray((weight * 255).astype(np.uint8))


def tiled_flux_img2img(pipe, prompt, image, strength, steps, guidance, generator, tile_size=1024, overlap=32):
    """Tiled Img2Img to mimic Ultimate SD Upscaler tiling.

    Each tile of ``image`` is refined by ``pipe`` and pasted back. Seams are
    linearly feathered over ``overlap`` pixels on the left/top edges of every
    tile that has an already-pasted neighbour there.

    Args:
        pipe: Img2img pipeline called per tile with prompt/image/strength/steps/guidance/size/generator.
        prompt: Text prompt for every tile.
        image: PIL image to refine (the already-upscaled control image).
        strength, steps, guidance, generator: Forwarded to ``pipe``.
        tile_size: Maximum tile edge length in pixels.
        overlap: Seam width in pixels; must satisfy 0 <= overlap < tile_size.

    Returns:
        A new PIL image the same size as ``image``.

    Raises:
        ValueError: if ``overlap`` is negative or not smaller than ``tile_size``.
    """
    tile_size = int(tile_size)
    overlap = int(overlap)
    if not 0 <= overlap < tile_size:
        raise ValueError(
            f"overlap must be in [0, tile_size); got overlap={overlap}, tile_size={tile_size}"
        )
    w, h = image.size
    stride = tile_size - overlap
    output = image.copy()  # Start with the control image
    for y in _tile_starts(h, tile_size, stride):
        for x in _tile_starts(w, tile_size, stride):
            tile_w = min(tile_size, w - x)
            tile_h = min(tile_size, h - y)
            tile = image.crop((x, y, x + tile_w, y + tile_h))
            gen_tile = pipe(
                prompt=prompt,
                image=tile,
                strength=strength,
                num_inference_steps=steps,
                guidance_scale=guidance,
                height=tile_h,
                width=tile_w,
                generator=generator,
            ).images[0]
            # The pipeline may round sizes (e.g. to multiples of 16); restore the tile size.
            if gen_tile.size != (tile_w, tile_h):
                gen_tile = gen_tile.resize((tile_w, tile_h), resample=Image.LANCZOS)
            if overlap > 0 and (x > 0 or y > 0):
                mask = _blend_mask(tile_w, tile_h, overlap, x > 0, y > 0)
                output.paste(gen_tile, (x, y), mask)
            else:
                output.paste(gen_tile, (x, y))
    return output
def _load_flux_pipeline(target_device):
    """Load the FLUX.1-schnell img2img pipeline for ``target_device`` ("cuda" or "cpu")."""
    dtype = torch.bfloat16 if target_device == "cuda" else torch.float32
    print(f"📥 Loading FLUX Img2Img on {target_device}...")
    tokenizer_2 = T5TokenizerFast.from_pretrained(
        "black-forest-labs/FLUX.1-schnell", subfolder="tokenizer_2", token=huggingface_token
    )
    flux_pipe = FluxImg2ImgPipeline.from_pretrained(
        "black-forest-labs/FLUX.1-schnell",
        torch_dtype=dtype,
        low_cpu_mem_usage=True,
        tokenizer_2=tokenizer_2,
        token=huggingface_token,
    )
    flux_pipe.enable_vae_tiling()
    flux_pipe.enable_vae_slicing()
    if target_device == "cuda":
        # Components stay on CPU and are moved to the GPU per model while they run.
        flux_pipe.enable_model_cpu_offload()
    return flux_pipe


def _load_esrgan_model():
    """Return the 4x RRDBNet on CPU, or None if basicsr is missing or the weights fail to load."""
    if not USE_ESRGAN:
        return None
    esrgan_path = "4x-UltraSharp.pth"
    try:
        if not os.path.exists(esrgan_path):
            url = "https://huggingface.co/uwg/upscaler/resolve/main/ESRGAN/4x-UltraSharp.pth"
            response = requests.get(url, timeout=120)
            response.raise_for_status()  # never cache an HTML error page as "weights"
            partial_path = esrgan_path + ".part"
            with open(partial_path, "wb") as f:
                f.write(response.content)
            os.replace(partial_path, esrgan_path)  # only publish a complete file
        model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)
        # weights_only=True refuses to unpickle arbitrary objects from the downloaded file.
        checkpoint = torch.load(esrgan_path, map_location="cpu", weights_only=True)
        # basicsr checkpoints nest weights under 'params_ema'; otherwise assume a raw state dict.
        # NOTE(review): checkpoints using old-ESRGAN key names will fail load_state_dict
        # here and fall back to LANCZOS — confirm the key layout of this file.
        state_dict = checkpoint.get("params_ema", checkpoint)
        model.load_state_dict(state_dict)
        model.eval()
        return model
    except Exception as e:
        print(f"ESRGAN loading error: {e}; falling back to LANCZOS upscaling")
        return None


def _ensure_models_loaded():
    """Populate the module-level ``pipe``, ``esrgan_model`` and ``device`` on first use.

    ``pipe`` is assigned last, so its presence implies the other two are set.
    An ESRGAN failure only disables ESRGAN; it never forces FLUX onto CPU.
    """
    global pipe, esrgan_model, device
    if "pipe" in globals():
        return
    target_device = "cuda" if torch.cuda.is_available() else "cpu"
    try:
        flux_pipe = _load_flux_pipeline(target_device)
    except Exception as e:
        if target_device == "cpu":
            raise
        print(f"Model loading error: {e}, falling back to CPU")
        target_device = "cpu"
        flux_pipe = _load_flux_pipeline(target_device)
    esrgan_model = _load_esrgan_model()
    device = target_device
    pipe = flux_pipe
    print("✅ Models loaded successfully!")


# ZeroGPU Spaces require at least one @spaces.GPU function; the decorator has no
# effect on other hardware. duration (seconds) may need tuning for large 4x jobs.
# NOTE(review): ZeroGPU guidance favours loading models at import time; lazy loading
# is kept here to preserve the original structure — confirm load time fits `duration`.
@spaces.GPU(duration=120)
def enhance_image(
    image_input,
    image_url,
    seed,
    randomize_seed,
    num_inference_steps,
    upscale_factor,
    denoising_strength,
    custom_prompt,
    tile_size,
    progress=gr.Progress(track_tqdm=True),
):
    """Upscale an uploaded or URL-supplied image and refine it tile-by-tile with FLUX.

    Args:
        image_input: Uploaded PIL image, or None; takes precedence over ``image_url``.
        image_url: URL to fetch when nothing is uploaded.
        seed: Seed for the torch generator (ignored when ``randomize_seed``).
        randomize_seed: Draw a fresh seed in [0, MAX_SEED].
        num_inference_steps: FLUX steps per tile.
        upscale_factor: Integer upscale factor; 4 uses ESRGAN when it loaded.
        denoising_strength: img2img strength per tile.
        custom_prompt: Optional prompt; blank means empty prompt.
        tile_size: Tile edge length for the tiled refinement.
        progress: Gradio progress tracker (tracks tqdm bars).

    Returns:
        ``[input resized to the output size, upscaled image]`` for the ImageSlider.

    Raises:
        gr.Error: if no image is supplied or the URL cannot be loaded.
    """
    _ensure_models_loaded()

    if image_input is not None:
        input_image = image_input
    elif image_url and image_url.strip():
        input_image = load_image_from_url(image_url.strip())
    else:
        raise gr.Error("Please provide an image (upload or URL)")
    if input_image.mode != "RGB":
        # The FLUX VAE and ESRGAN expect 3 channels; URL images may be RGBA/P/L.
        input_image = input_image.convert("RGB")

    # Slider values may arrive as floats; PIL sizes and manual_seed need ints.
    upscale_factor = int(upscale_factor)
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    seed = int(seed)

    true_input_image = input_image
    input_image, w_original, h_original, was_resized = process_input(input_image, upscale_factor)
    prompt = (custom_prompt or "").strip()
    generator = torch.Generator(device=device).manual_seed(seed)

    gr.Info("🚀 Upscaling image...")
    # Initial upscale: ESRGAN (a fixed 4x model) only for factor 4, else LANCZOS.
    if esrgan_model is not None and upscale_factor == 4:
        esrgan_model.to(device)
        try:
            control_image = esrgan_upscale(input_image, upscale_factor)
        finally:
            # Release GPU memory held by ESRGAN before/after the FLUX pass.
            esrgan_model.to("cpu")
    else:
        w, h = input_image.size
        control_image = input_image.resize((w * upscale_factor, h * upscale_factor), resample=Image.LANCZOS)

    # Tiled Flux Img2Img for refinement
    image = tiled_flux_img2img(
        pipe,
        prompt,
        control_image,
        denoising_strength,
        int(num_inference_steps),
        3.5,  # guidance_scale matching the reference workflow
        generator,
        tile_size=int(tile_size),
        overlap=32,
    )

    if was_resized:
        gr.Info(f"📏 Resizing output to target size: {w_original * upscale_factor}x{h_original * upscale_factor}")
        image = image.resize((w_original * upscale_factor, h_original * upscale_factor), resample=Image.LANCZOS)

    # Resize input image to match output size for slider alignment
    resized_input = true_input_image.resize(image.size, resample=Image.LANCZOS)
    # FLUX component placement is handled by the offload hooks from
    # enable_model_cpu_offload(), so the pipeline is not moved manually here.
    return [resized_input, image]
# Create Gradio interface
# Layout: a left column with input tabs, prompt and upscaling controls, and a wider
# right column holding a before/after ImageSlider for the result.
with gr.Blocks(css=css, title="🎨 AI Image Upscaler - FLUX") as demo:
    # Page header; `power_device` is a fixed label string, interpolated via str.format.
    gr.HTML("""
<div class="main-header">
<h1>🎨 AI Image Upscaler</h1>
<p>Upload an image or provide a URL to upscale it using FLUX upscaling</p>
<p>Currently running on <strong>{}</strong></p>
</div>
""".format(power_device))
    with gr.Row():
        with gr.Column(scale=1):
            gr.HTML("<h3>📤 Input</h3>")
            # Two alternative image sources. enhance_image uses the upload when present,
            # otherwise the URL; the URL box is pre-filled with an example image.
            with gr.Tabs():
                with gr.TabItem("📁 Upload Image"):
                    input_image = gr.Image(
                        label="Upload Image",
                        type="pil",
                        height=200  # Made smaller
                    )
                with gr.TabItem("🔗 Image URL"):
                    image_url = gr.Textbox(
                        label="Image URL",
                        placeholder="https://example.com/image.jpg",
                        value="https://upload.wikimedia.org/wikipedia/commons/thumb/a/a7/Example.jpg/800px-Example.jpg"
                    )
            gr.HTML("<h3>🎛️ Prompt Settings</h3>")
            # Empty/blank prompt is allowed.
            custom_prompt = gr.Textbox(
                label="Custom Prompt (optional)",
                placeholder="Enter custom prompt or leave empty",
                lines=2
            )
            gr.HTML("<h3>⚙️ Upscaling Settings</h3>")
            # Integer factor 1-4.
            upscale_factor = gr.Slider(
                label="Upscale Factor",
                minimum=1,
                maximum=4,
                step=1,
                value=2,
                info="How much to upscale the image"
            )
            # FLUX denoising steps run for every tile.
            num_inference_steps = gr.Slider(
                label="Number of Inference Steps",
                minimum=1,
                maximum=50,
                step=1,
                value=4,
                info="More steps = better quality but slower (default 4 for schnell)"
            )
            # img2img strength applied to each tile of the pre-upscaled image.
            denoising_strength = gr.Slider(
                label="Denoising Strength",
                minimum=0.0,
                maximum=1.0,
                step=0.05,
                value=0.3,
                info="Controls how much the image is transformed"
            )
            # Tile edge length forwarded to the tiled refinement (multiples of 64).
            tile_size = gr.Slider(
                label="Tile Size",
                minimum=256,
                maximum=2048,
                step=64,
                value=1024,
                info="Size of tiles for processing (larger = faster but more memory)"
            )
            with gr.Row():
                # When checked, enhance_image replaces `seed` with a random one.
                randomize_seed = gr.Checkbox(
                    label="Randomize seed",
                    value=True
                )
                seed = gr.Slider(
                    label="Seed",
                    minimum=0,
                    maximum=MAX_SEED,
                    step=1,
                    value=42,
                    interactive=True
                )
            enhance_btn = gr.Button(
                "🚀 Upscale Image",
                variant="primary",
                size="lg"
            )
        with gr.Column(scale=2):  # Larger scale for results
            gr.HTML("<h3>📊 Results</h3>")
            # Receives enhance_image's [resized input, upscaled output] pair.
            result_slider = ImageSlider(
                type="pil",
                interactive=False,  # Disable interactivity to prevent uploads
                height=600,  # Made larger
                elem_id="result_slider",
                label=None  # Remove default label
            )
    # Event handler
    # The order of `inputs` must match enhance_image's positional parameters
    # (its trailing `progress` parameter is supplied by Gradio).
    enhance_btn.click(
        fn=enhance_image,
        inputs=[
            input_image,
            image_url,
            seed,
            randomize_seed,
            num_inference_steps,
            upscale_factor,
            denoising_strength,
            custom_prompt,
            tile_size
        ],
        outputs=[result_slider]
    )
    # Licence note shown under the app.
    gr.HTML("""
<div style="margin-top: 2rem; padding: 1rem; background: #f0f0f0; border-radius: 8px;">
<p><strong>Note:</strong> This upscaler uses the Flux.1-schnell model. Users are responsible for obtaining commercial rights if used commercially under their license.</p>
</div>
""")
    # Custom CSS for slider
    # Styles scoped to the ImageSlider via its elem_id. The badge pseudo-elements
    # overlay "Before"/"After" labels, and several slider tool buttons are hidden.
    # NOTE(review): these rules could live in the `css` string passed to gr.Blocks.
    gr.HTML("""
<style>
#result_slider .slider {
width: 100% !important;
max-width: inherit !important;
}
#result_slider img {
object-fit: contain !important;
width: 100% !important;
height: auto !important;
}
#result_slider .gr-button-tool {
display: none !important;
}
#result_slider .gr-button-undo {
display: none !important;
}
#result_slider .gr-button-clear {
display: none !important;
}
#result_slider .badge-container .badge {
display: none !important;
}
#result_slider .badge-container::before {
content: "Before";
position: absolute;
top: 10px;
left: 10px;
background: rgba(0,0,0,0.5);
color: white;
padding: 5px;
border-radius: 5px;
z-index: 10;
}
#result_slider .badge-container::after {
content: "After";
position: absolute;
top: 10px;
right: 10px;
background: rgba(0,0,0,0.5);
color: white;
padding: 5px;
border-radius: 5px;
z-index: 10;
}
#result_slider .fullscreen img {
object-fit: contain !important;
width: 100vw !important;
height: 100vh !important;
position: absolute;
top: 0;
left: 0;
}
</style>
""")
    # JS to set slider default position to middle
    # NOTE(review): <script> tags inserted through gr.HTML may not be executed by the
    # frontend. DOMContentLoaded has also likely fired before this component renders,
    # and the range input only exists once a result is shown. Confirm this runs;
    # gr.Blocks(js=...)/head=... may be required instead.
    gr.HTML("""
<script>
document.addEventListener('DOMContentLoaded', function() {
const sliderInput = document.querySelector('#result_slider input[type="range"]');
if (sliderInput) {
sliderInput.value = 50;
sliderInput.dispatchEvent(new Event('input'));
}
});
</script>
""")
if __name__ == "__main__":
    # Enable the request queue (returns the same Blocks) and serve on all interfaces.
    queued_demo = demo.queue()
    queued_demo.launch(server_name="0.0.0.0", server_port=7860, share=True)