import io
import logging
import os
import random
import subprocess
import sys
import tempfile
import warnings

import gradio as gr
import numpy as np
import requests
import spaces
import torch
from gradio_imageslider import ImageSlider
from huggingface_hub import hf_hub_download
from PIL import Image

# Never let git block on an interactive credential prompt during unattended setup.
os.environ["GIT_TERMINAL_PROMPT"] = "0"


# ---------------------------------------------------------------------------
# One-time environment setup: ComfyUI, custom nodes and model weights
# ---------------------------------------------------------------------------


def _clone_if_missing(url, dest):
    """Clone the git repository at ``url`` into ``dest`` unless ``dest`` exists.

    Raises subprocess.CalledProcessError if ``git clone`` fails, rather than
    continuing with a missing checkout and failing later with a confusing error.
    """
    if not os.path.exists(dest):
        subprocess.run(["git", "clone", url, dest], check=True)


def _install_requirements(repo_dir):
    """pip-install ``repo_dir/requirements.txt`` when that file exists.

    Raises subprocess.CalledProcessError if pip exits non-zero.
    """
    if os.path.exists(os.path.join(repo_dir, "requirements.txt")):
        subprocess.run(
            [sys.executable, "-m", "pip", "install", "-r", "requirements.txt"],
            cwd=repo_dir,
            check=True,
        )


def _hf_download_if_missing(repo_id, filename, local_dir):
    """Fetch ``filename`` from Hub repo ``repo_id`` into ``local_dir`` unless already present."""
    os.makedirs(local_dir, exist_ok=True)
    if not os.path.exists(os.path.join(local_dir, filename)):
        hf_hub_download(repo_id=repo_id, filename=filename, local_dir=local_dir)


def _url_download_if_missing(url, path, timeout=60):
    """Stream ``url`` to ``path`` unless ``path`` already exists.

    The body is written to a temporary file in the same directory. It is moved
    to ``path`` only after a complete, successful (2xx) download. Without this,
    an HTTP error page or a truncated transfer would be left at ``path``, and the
    existence check would then skip it on every later start.

    Raises:
        requests.HTTPError: the server answered with a non-2xx status.
        requests.RequestException: connection problems or a timeout.
    """
    if os.path.exists(path):
        return
    fd, tmp_path = tempfile.mkstemp(dir=os.path.dirname(path) or ".", suffix=".part")
    try:
        with os.fdopen(fd, "wb") as out, requests.get(url, stream=True, timeout=timeout) as response:
            response.raise_for_status()
            for chunk in response.iter_content(chunk_size=1 << 20):
                out.write(chunk)
        os.replace(tmp_path, path)
    finally:
        # After a successful os.replace the temp file no longer exists.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)


# ComfyUI itself and the custom-node packs used by the workflow below.
_clone_if_missing("https://github.com/comfyanonymous/ComfyUI", "ComfyUI")

custom_nodes_dir = os.path.join("ComfyUI", "custom_nodes")
os.makedirs(custom_nodes_dir, exist_ok=True)

usd_dir = os.path.join(custom_nodes_dir, "ComfyUI_UltimateSDUpscale")
_clone_if_missing("https://github.com/ssitu/ComfyUI_UltimateSDUpscale", usd_dir)

mtb_dir = os.path.join(custom_nodes_dir, "comfy_mtb")
_clone_if_missing("https://github.com/melMass/comfy_mtb", mtb_dir)
_install_requirements(mtb_dir)

kjn_dir = os.path.join(custom_nodes_dir, "ComfyUI-KJNodes")
_clone_if_missing("https://github.com/kijai/ComfyUI-KJNodes", kjn_dir)
_install_requirements(kjn_dir)

# Model weights, placed where the ComfyUI loader nodes look for them.
comfy_models_dir = os.path.join("ComfyUI", "models")
os.makedirs(comfy_models_dir, exist_ok=True)

# Diffusion model (Flux FP8).
diffusion_dir = os.path.join(comfy_models_dir, "diffusion_models")
_hf_download_if_missing("Kijai/flux-fp8", "flux1-dev-fp8.safetensors", diffusion_dir)

# Text encoders for DualCLIPLoader.
clip_dir = os.path.join(comfy_models_dir, "clip")
_hf_download_if_missing("comfyanonymous/flux_text_encoders", "clip_l.safetensors", clip_dir)
_hf_download_if_missing("comfyanonymous/flux_text_encoders", "t5xxl_fp8_e4m3fn.safetensors", clip_dir)

# VAE.
# NOTE(review): black-forest-labs/FLUX.1-dev is presumably a gated repo, so this
# likely needs an HF token with access granted — confirm in the Space settings.
vae_dir = os.path.join(comfy_models_dir, "vae")
_hf_download_if_missing("black-forest-labs/FLUX.1-dev", "ae.safetensors", vae_dir)

# Real-ESRGAN pre-upscalers (x2 for factor 2, x4 otherwise; see _run_upscale_workflow).
upscale_dir = os.path.join(comfy_models_dir, "upscale_models")
os.makedirs(upscale_dir, exist_ok=True)
for model_name in ("RealESRGAN_x2.pth", "RealESRGAN_x4.pth"):
    _url_download_if_missing(
        f"https://huggingface.co/ai-forever/Real-ESRGAN/resolve/main/{model_name}",
        os.path.join(upscale_dir, model_name),
    )

# ComfyUI is not an installed package; make its modules importable.
sys.path.append(os.path.abspath("ComfyUI"))

# NOTE(review): the custom-node init entry point has been renamed across ComfyUI
# versions; confirm `init_custom_nodes` exists in the checked-out revision.
from nodes import NODE_CLASS_MAPPINGS, init_custom_nodes  # noqa: E402

init_custom_nodes()


def get_value_at_index(obj, index):
    """Return output ``index`` of a ComfyUI node result.

    Node methods return either an indexable sequence of outputs or a dict that
    keeps those outputs under its ``"result"`` key. This handles both shapes.
    """
    try:
        return obj[index]
    except KeyError:
        return obj["result"][index]


# ---------------------------------------------------------------------------
# UI constants
# ---------------------------------------------------------------------------

css = """
#col-container {
    margin: 0 auto;
    max-width: 800px;
}
.main-header {
    text-align: center;
    margin-bottom: 2rem;
}
"""

power_device = "ZeroGPU"
MAX_SEED = 1000000
# Upper bound on output pixels (width * height) after upscaling.
MAX_PIXEL_BUDGET = 8192 * 8192


def make_divisible_by_16(size):
    """Round ``size`` to the nearest multiple of 16 (a remainder of exactly 8 rounds up)."""
    return ((size // 16) * 16) if (size % 16) < 8 else ((size // 16 + 1) * 16)


def process_input(input_image, upscale_factor):
    """Shrink ``input_image`` if its upscaled size would exceed MAX_PIXEL_BUDGET.

    When shrinking is needed, each side is scaled by the same factor, then
    floored to a multiple of 16 (minimum 16).

    Returns:
        ``(image, original_width, original_height, was_resized)``.
    """
    w, h = input_image.size
    w_original, h_original = w, h
    was_resized = False

    if w * h * upscale_factor**2 > MAX_PIXEL_BUDGET:
        gr.Info("Requested output too large. Resizing input.")
        target_input_pixels = MAX_PIXEL_BUDGET / (upscale_factor**2)
        scale = (target_input_pixels / (w * h)) ** 0.5
        new_w = max(16, int(w * scale) // 16 * 16)
        new_h = max(16, int(h * scale) // 16 * 16)
        input_image = input_image.resize((new_w, new_h), resample=Image.LANCZOS)
        was_resized = True

    return input_image, w_original, h_original, was_resized


def load_image_from_url(url, timeout=30):
    """Download ``url`` and return it as an RGB PIL image.

    The whole body is buffered in memory before decoding. This avoids
    ``response.raw``, which would hand PIL the undecoded (possibly
    gzip-encoded) stream. The image is converted to RGB so that modes PNG
    cannot store (e.g. CMYK JPEGs) save cleanly later.

    Raises:
        gr.Error: for any network, HTTP or decoding failure.
    """
    try:
        response = requests.get(
            url,
            timeout=timeout,
            headers={"User-Agent": "Mozilla/5.0 (compatible; flux-fp8-upscaler)"},
        )
        response.raise_for_status()
        return Image.open(io.BytesIO(response.content)).convert("RGB")
    except Exception as e:
        raise gr.Error(f"Failed to load image: {e}") from e


def _run_upscale_workflow(image_filename, prompt, seed, steps, upscale_factor, denoise, tile_size):
    """Run the Flux + UltimateSDUpscale ComfyUI graph and return its IMAGE output.

    ``image_filename`` is a file name inside ComfyUI's input directory, which
    is where the LoadImage node resolves it.
    """
    load_image_node = NODE_CLASS_MAPPINGS["LoadImage"]()
    image = get_value_at_index(load_image_node.load_image(image=image_filename), 0)

    # NOTE(review): "Text Multiline" must be registered by one of the installed
    # custom-node packs — confirm which pack provides it.
    text_multiline = NODE_CLASS_MAPPINGS["Text Multiline"]()
    prompt_text = get_value_at_index(text_multiline.text_multiline(text=prompt), 0)

    dualcliploader = NODE_CLASS_MAPPINGS["DualCLIPLoader"]()
    clip = get_value_at_index(
        dualcliploader.load_clip(
            clip_name1="clip_l.safetensors",
            clip_name2="t5xxl_fp8_e4m3fn.safetensors",
            type="flux",
        ),
        0,
    )

    cliptextencode = NODE_CLASS_MAPPINGS["CLIPTextEncode"]()
    conditioning = get_value_at_index(cliptextencode.encode(text=prompt_text, clip=clip), 0)

    # Guidance (3.5, as in the original app) is attached to the positive
    # conditioning. The negative is the same conditioning zeroed out, and the
    # sampler below runs at cfg=1.0.
    fluxguidance = NODE_CLASS_MAPPINGS["FluxGuidance"]()
    positive = get_value_at_index(fluxguidance.append(guidance=3.5, conditioning=conditioning), 0)

    conditioningzeroout = NODE_CLASS_MAPPINGS["ConditioningZeroOut"]()
    negative = get_value_at_index(conditioningzeroout.zero_out(conditioning=conditioning), 0)

    upscale_name = "RealESRGAN_x2.pth" if upscale_factor == 2 else "RealESRGAN_x4.pth"
    upscalemodelloader = NODE_CLASS_MAPPINGS["UpscaleModelLoader"]()
    upscale_model = get_value_at_index(upscalemodelloader.load_model(model_name=upscale_name), 0)

    vaeloader = NODE_CLASS_MAPPINGS["VAELoader"]()
    vae = get_value_at_index(vaeloader.load_vae(vae_name="ae.safetensors"), 0)

    # ComfyUI's "Load Diffusion Model" node is registered under the key
    # "UNETLoader" with method load_unet; there is no "LoadDiffusionModel" key.
    unetloader = NODE_CLASS_MAPPINGS["UNETLoader"]()
    model = get_value_at_index(
        unetloader.load_unet(unet_name="flux1-dev-fp8.safetensors", weight_dtype="fp8_e4m3fn"),
        0,
    )

    ultimatesdupscale = NODE_CLASS_MAPPINGS["UltimateSDUpscale"]()
    upscale_out = ultimatesdupscale.upscale(
        upscale_by=float(upscale_factor),
        seed=seed,
        steps=steps,
        cfg=1.0,
        sampler_name="euler",
        scheduler="normal",
        denoise=denoise,
        mode_type="Linear",
        tile_width=tile_size,
        tile_height=tile_size,
        mask_blur=8,
        tile_padding=32,
        seam_fix_mode="None",
        seam_fix_denoise=1.0,
        seam_fix_width=64,
        seam_fix_mask_blur=8,
        seam_fix_padding=16,
        force_uniform_tiles=True,
        tiled_decode=False,
        image=image,
        model=model,
        positive=positive,
        negative=negative,
        vae=vae,
        upscale_model=upscale_model,
    )
    return get_value_at_index(upscale_out, 0)


def _tensor_to_pil(image_batch):
    """Convert the first image of a ComfyUI IMAGE batch to an 8-bit PIL image.

    Values are scaled by 255 and clipped to [0, 255] before the uint8 cast.
    Slightly out-of-range floats therefore saturate instead of wrapping around.
    """
    # assumes ComfyUI's IMAGE layout (batch, height, width, channels) with floats in [0, 1] — TODO confirm
    pixels = image_batch[0].cpu().numpy() * 255.0
    return Image.fromarray(np.clip(pixels, 0, 255).astype(np.uint8))


@spaces.GPU(duration=120)
def enhance_image(
    image_input,
    image_url,
    seed,
    randomize_seed,
    num_inference_steps,
    upscale_factor,
    denoising_strength,
    custom_prompt,
    tile_size,
    progress=gr.Progress(track_tqdm=True),
):
    """Upscale an uploaded image (or one fetched from ``image_url``).

    The upload takes precedence over the URL. The result is always resized to
    ``original_size * upscale_factor``, even when the input was shrunk to fit
    MAX_PIXEL_BUDGET.

    Returns:
        ``[input_resized_to_output_size, upscaled_image]`` for the ImageSlider.

    Raises:
        gr.Error: when neither an image nor a URL is given, or the URL fails.
    """
    # Gradio sliders may deliver floats; PIL sizes and sampler steps need ints.
    upscale_factor = int(upscale_factor)
    num_inference_steps = int(num_inference_steps)
    tile_size = int(tile_size)
    seed = int(seed)

    with torch.inference_mode():
        if image_input is not None:
            true_input_image = image_input
        elif image_url:
            true_input_image = load_image_from_url(image_url)
        else:
            raise gr.Error("Provide an image or URL")

        input_image, w_original, h_original, _was_resized = process_input(true_input_image, upscale_factor)

        if randomize_seed:
            seed = random.randint(0, MAX_SEED)

        prompt = custom_prompt if custom_prompt and custom_prompt.strip() else ""

        # Stage the input where LoadImage resolves file names. mkstemp gives a
        # unique name, so concurrent requests cannot overwrite each other's file.
        input_dir = os.path.join("ComfyUI", "input")
        os.makedirs(input_dir, exist_ok=True)
        fd, input_path = tempfile.mkstemp(prefix="input_", suffix=".png", dir=input_dir)
        os.close(fd)
        try:
            input_image.save(input_path)
            upscaled_tensor = _run_upscale_workflow(
                os.path.basename(input_path),
                prompt,
                seed,
                num_inference_steps,
                upscale_factor,
                denoising_strength,
                tile_size,
            )
            upscaled_img = _tensor_to_pil(upscaled_tensor)
        finally:
            # Clean up the staged input even if the workflow raised.
            os.remove(input_path)

    target_size = (w_original * upscale_factor, h_original * upscale_factor)
    if upscaled_img.size != target_size:
        upscaled_img = upscaled_img.resize(target_size, resample=Image.LANCZOS)

    resized_input = true_input_image.resize(upscaled_img.size, resample=Image.LANCZOS)
    return [resized_input, upscaled_img]


# ---------------------------------------------------------------------------
# Gradio interface
# ---------------------------------------------------------------------------

with gr.Blocks(css=css, title="🎨 AI Image Upscaler - Flux FP8") as demo:
    gr.HTML(
        f"""
        <div class="main-header">
            <h1>🎨 AI Image Upscaler - Flux FP8</h1>
            <p>Upscale images using Flux FP8 with ComfyUI workflow</p>
            <p>Running on <b>{power_device}</b></p>
        </div>
        """
    )

    with gr.Row():
        with gr.Column(scale=1):
            gr.HTML("<h3>📤 Input</h3>")
            with gr.Tabs():
                with gr.TabItem("📁 Upload Image"):
                    input_image = gr.Image(label="Upload Image", type="pil", height=200)
                with gr.TabItem("🔗 Image URL"):
                    image_url = gr.Textbox(
                        label="Image URL",
                        placeholder="https://example.com/image.jpg",
                        value="https://upload.wikimedia.org/wikipedia/commons/thumb/a/a7/Example.jpg/800px-Example.jpg",
                    )

            gr.HTML("<h3>🎛️ Prompt Settings</h3>")
            custom_prompt = gr.Textbox(
                label="Custom Prompt (optional)",
                placeholder="Enter custom prompt or leave empty",
                lines=2,
            )

            gr.HTML("<h3>⚙️ Upscaling Settings</h3>")
            upscale_factor = gr.Slider(label="Upscale Factor", minimum=1, maximum=4, step=1, value=2)
            num_inference_steps = gr.Slider(label="Inference Steps", minimum=1, maximum=50, step=1, value=25)
            denoising_strength = gr.Slider(
                label="Denoising Strength", minimum=0.0, maximum=1.0, step=0.05, value=0.3
            )
            tile_size = gr.Slider(label="Tile Size", minimum=256, maximum=2048, step=64, value=1024)

            with gr.Row():
                randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
                seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=42)

            enhance_btn = gr.Button("🚀 Upscale Image", variant="primary", size="lg")

        with gr.Column(scale=2):
            gr.HTML("<h3>📊 Results</h3>")
            result_slider = ImageSlider(type="pil", interactive=False, height=600, label=None)

    enhance_btn.click(
        fn=enhance_image,
        inputs=[
            input_image,
            image_url,
            seed,
            randomize_seed,
            num_inference_steps,
            upscale_factor,
            denoising_strength,
            custom_prompt,
            tile_size,
        ],
        outputs=[result_slider],
    )

    gr.HTML(
        """
        <p><b>Note:</b> Uses Flux FP8 model. Ensure compliance with licenses for commercial use.</p>
        """
    )


if __name__ == "__main__":
    demo.queue().launch(share=True, server_name="0.0.0.0", server_port=7860)