Spaces:
Runtime error
Runtime error
| import os | |
| import random | |
| import sys | |
| from typing import Sequence, Mapping, Any, Union | |
| import torch | |
| import gradio as gr | |
| from huggingface_hub import hf_hub_download | |
| import spaces | |
# Download required models from Hugging Face into the local models/ folders.
# One (repo_id, filename, local_dir) triple per file, fetched in this order.
_MODEL_DOWNLOADS = (
    ("black-forest-labs/FLUX.1-dev", "ae.safetensors", "models/vae"),
    ("black-forest-labs/FLUX.1-dev", "flux1-dev.safetensors", "models/diffusion_models"),
    ("comfyanonymous/flux_text_encoders", "clip_l.safetensors", "models/text_encoders"),
    ("comfyanonymous/flux_text_encoders", "t5xxl_fp16.safetensors", "models/text_encoders"),
    ("kim2091/UltraSharp", "4x-UltraSharp.pth", "models/upscale_models"),
)
for _repo_id, _filename, _local_dir in _MODEL_DOWNLOADS:
    hf_hub_download(repo_id=_repo_id, filename=_filename, local_dir=_local_dir)
def get_value_at_index(obj: Union[Sequence, Mapping], index: int) -> Any:
    """Return ``obj[index]``, or ``obj["result"][index]`` if the direct lookup raises KeyError.

    Node methods in this file return either a plain output tuple or a mapping.
    The fallback covers mappings whose output tuple sits under a ``"result"`` key.

    Args:
        obj: A sequence of node outputs, or a mapping holding them under ``"result"``.
        index: Position of the wanted output.

    Returns:
        The selected output value.
    """
    try:
        value = obj[index]
    except KeyError:
        # Mapping without an integer key: look inside its "result" entry instead.
        value = obj["result"][index]
    return value
def find_path(name: str, path: Union[str, None] = None) -> Union[str, None]:
    """Search ``path`` and then each of its ancestors for an entry called ``name``.

    Args:
        name: File or directory name to look for. It must exactly match an
            ``os.listdir`` entry.
        path: Directory to start from. Defaults to the current working
            directory. A relative path is resolved against the cwd, so the
            walk can climb past it.

    Returns:
        The absolute path of the closest match, or ``None`` once the
        filesystem root has been checked without finding ``name``.
    """
    current = os.path.abspath(os.getcwd() if path is None else path)
    while True:
        try:
            entries = os.listdir(current)
        except OSError:
            # An unreadable or missing directory must not abort the search;
            # treat it as empty and keep walking towards the root.
            entries = ()
        if name in entries:
            path_name = os.path.join(current, name)
            print(f"{name} found: {path_name}")
            return path_name
        parent_directory = os.path.dirname(current)
        if parent_directory == current:
            # Reached the filesystem root without a match.
            return None
        current = parent_directory
def add_comfyui_directory_to_sys_path() -> None:
    """Append the nearest ``ComfyUI`` directory (cwd or an ancestor) to ``sys.path``.

    Nothing happens if no ``ComfyUI`` entry is found, or if the match is not a
    directory.
    """
    candidate = find_path("ComfyUI")
    if candidate is None or not os.path.isdir(candidate):
        return
    sys.path.append(candidate)
    print(f"'{candidate}' added to sys.path")
def add_extra_model_paths() -> None:
    """Locate ``extra_model_paths.yaml`` and pass it to ``load_extra_path_config``.

    The loader is imported from ``main`` when available, otherwise from
    ``utils.extra_config``. The YAML file is searched for from the cwd upward.
    A message is printed (no error) when the file cannot be found.
    """
    try:
        from main import load_extra_path_config
    except ImportError:
        print("Could not import load_extra_path_config from main.py. Looking in utils.extra_config instead.")
        from utils.extra_config import load_extra_path_config

    config_file = find_path("extra_model_paths.yaml")
    if config_file is None:
        print("Could not find the extra_model_paths config file.")
        return
    load_extra_path_config(config_file)
# Module-level setup; order matters. The first call appends the ComfyUI
# directory to sys.path. The second imports `main` / `utils.extra_config`,
# which presumably live in that ComfyUI directory (or in the cwd if the app
# runs from the ComfyUI root) — confirm against the deployment layout.
add_comfyui_directory_to_sys_path()
add_extra_model_paths()
def import_custom_nodes() -> None:
    """Find all custom nodes in the custom_nodes folder and add those node objects to NODE_CLASS_MAPPINGS.

    A ``PromptServer`` and ``PromptQueue`` are created first, on a fresh asyncio
    event loop that is also installed as the current loop. Presumably this is
    because some custom nodes expect a server instance to exist when they are
    imported — confirm.

    ``init_extra_nodes`` is a plain function in older ComfyUI releases and a
    coroutine function in newer ones. Both are supported.
    """
    import asyncio
    import inspect

    import execution
    from nodes import init_extra_nodes
    import server

    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    server_instance = server.PromptServer(loop)
    execution.PromptQueue(server_instance)

    result = init_extra_nodes()
    # A coroutine that is never awaited does nothing, which would silently leave
    # every custom node unregistered. Drive it to completion on the loop above.
    if inspect.isawaitable(result):
        loop.run_until_complete(result)
from nodes import NODE_CLASS_MAPPINGS
# Pre-load models outside the decorated function for ZeroGPU efficiency.
# The custom-node classes looked up below (e.g. "DownloadAndLoadFlorence2Model")
# are expected to appear in this same mapping after import_custom_nodes().
# Presumably it registers them by mutating the dict in place — confirm.
import_custom_nodes()
# Initialize model loaders.
# Each *_NN name keeps the loader node's output tuple at module level.
# enhance_image() reads element 0 of these via get_value_at_index() on every request.
dualcliploader = NODE_CLASS_MAPPINGS["DualCLIPLoader"]()
dualcliploader_54 = dualcliploader.load_clip(
    clip_name1="clip_l.safetensors",
    clip_name2="t5xxl_fp16.safetensors",
    type="flux",
    device="default",
)
upscalemodelloader = NODE_CLASS_MAPPINGS["UpscaleModelLoader"]()
upscalemodelloader_44 = upscalemodelloader.load_model(model_name="4x-UltraSharp.pth")
vaeloader = NODE_CLASS_MAPPINGS["VAELoader"]()
vaeloader_55 = vaeloader.load_vae(vae_name="ae.safetensors")
unetloader = NODE_CLASS_MAPPINGS["UNETLoader"]()
unetloader_58 = unetloader.load_unet(
    unet_name="flux1-dev.safetensors", weight_dtype="default"
)
# Florence-2 is not in the hf_hub_download list at the top of the file.
# Judging by the node name, this custom node presumably downloads the weights
# itself — confirm.
downloadandloadflorence2model = NODE_CLASS_MAPPINGS["DownloadAndLoadFlorence2Model"]()
downloadandloadflorence2model_52 = downloadandloadflorence2model.loadmodel(
    model="microsoft/Florence-2-large", precision="fp16", attention="sdpa"
)
# Pre-load models to GPU for efficiency
from comfy import model_management
# The upscale model (upscalemodelloader_44) is not in this list, so it is not pre-loaded here.
model_loaders = [dualcliploader_54, vaeloader_55, unetloader_58, downloadandloadflorence2model_52]
# Take each loader's first output, unwrapped to its `.patcher` attribute when it
# has one. Outputs that are dicts, or whose `.patcher` is a dict, are skipped.
# Presumably the Florence-2 loader returns a dict rather than a model patcher —
# TODO confirm.
valid_models = [
    getattr(loader[0], 'patcher', loader[0])
    for loader in model_loaders
    if not isinstance(loader[0], dict) and not isinstance(getattr(loader[0], 'patcher', None), dict)
]
model_management.load_models_gpu(valid_models)
def _random_seed() -> int:
    """Return a random sampler seed in ``[1, 2**64 - 1]``.

    ``random.randint`` includes its upper bound. The previous ``randint(1, 2**64)``
    could therefore (rarely) return 2**64, one past the largest unsigned 64-bit
    value — the documented maximum of ``torch.manual_seed``.
    """
    return random.randint(1, 2**64 - 1)


def _load_input_image(image_input):
    """Run the ComfyUI node that loads ``image_input`` and return its output tuple.

    Strings that start with http:// or https:// (after trimming surrounding
    whitespace) go to the mtb "Load Image From Url" node. Anything else, such as
    a local file path from a Gradio upload, goes to ``LoadImage`` unchanged.
    """
    if isinstance(image_input, str):
        url = image_input.strip()
        if url.startswith(("http://", "https://")):
            return NODE_CLASS_MAPPINGS["Load Image From Url (mtb)"]().load(url=url)
    return NODE_CLASS_MAPPINGS["LoadImage"]().load_image(image=image_input)


def _saved_image_path(save_result) -> str:
    """Return the path of the first image reported by a ``SaveImage`` result.

    The path is built from ComfyUI's configured output directory and the
    reported ``subfolder``. A hard-coded ``output/`` relative to the cwd breaks
    if ComfyUI's output directory lives elsewhere.
    """
    import folder_paths  # ComfyUI module; importable once ComfyUI is on sys.path

    info = save_result["ui"]["images"][0]
    return os.path.join(
        folder_paths.get_output_directory(),
        info.get("subfolder", ""),
        info["filename"],
    )


# Adjust duration based on your workflow speed.
# On ZeroGPU a GPU is attached only while a @spaces.GPU-decorated call runs.
# 120 s is a starting estimate for a tiled FLUX upscale; tune it as needed.
@spaces.GPU(duration=120)
def enhance_image(image_input, upscale_factor, steps, cfg_scale, denoise_strength, guidance_scale):
    """
    Main function to enhance and upscale images using Florence-2 captioning and FLUX upscaling.

    Florence-2 writes a detailed caption of the input image. That caption, with
    FLUX guidance applied, is the positive prompt for an UltimateSDUpscale pass
    over 1024x1024 tiles.

    Args:
        image_input: An http(s) URL or a local file path of the source image.
        upscale_factor: Scale factor forwarded to UltimateSDUpscale as ``upscale_by``.
        steps: Sampling steps per tile. Coerced to int because slider values may
            arrive as floats.
        cfg_scale: Classifier-free guidance scale.
        denoise_strength: Denoise strength for each tile.
        guidance_scale: FLUX guidance applied to the caption conditioning.

    Returns:
        A ``(saved_path, caption)`` tuple: the path of the image written by
        SaveImage, and the Florence-2 caption.

    Raises:
        gr.Error: Wraps any failure so Gradio shows the message in the UI.
    """
    try:
        with torch.inference_mode():
            input_image = get_value_at_index(_load_input_image(image_input), 0)

            # Generate detailed caption using Florence-2
            florence_result = NODE_CLASS_MAPPINGS["Florence2Run"]().encode(
                text_input="",
                task="more_detailed_caption",
                fill_mask=True,
                keep_model_loaded=False,
                max_new_tokens=1024,
                num_beams=3,
                do_sample=True,
                output_mask_select="",
                seed=_random_seed(),
                image=input_image,
                florence2_model=get_value_at_index(downloadandloadflorence2model_52, 0),
            )
            # Output index 2 of Florence2Run is the generated caption text.
            caption = get_value_at_index(florence_result, 2)

            # Encode the caption as the positive prompt and "" as the negative prompt
            clip = get_value_at_index(dualcliploader_54, 0)
            cliptextencode = NODE_CLASS_MAPPINGS["CLIPTextEncode"]()
            positive = cliptextencode.encode(text=caption, clip=clip)
            negative = cliptextencode.encode(text="", clip=clip)

            # Set up upscale factor
            upscale_by = NODE_CLASS_MAPPINGS["PrimitiveFloat"]().execute(value=float(upscale_factor))

            # Apply FLUX guidance to the positive conditioning
            guided = NODE_CLASS_MAPPINGS["FluxGuidance"]().append(
                guidance=float(guidance_scale),
                conditioning=get_value_at_index(positive, 0),
            )

            # Perform ultimate upscaling
            upscaled = NODE_CLASS_MAPPINGS["UltimateSDUpscale"]().upscale(
                upscale_by=get_value_at_index(upscale_by, 0),
                seed=_random_seed(),
                steps=int(steps),
                cfg=float(cfg_scale),
                sampler_name="euler",
                scheduler="normal",
                denoise=float(denoise_strength),
                mode_type="Linear",
                tile_width=1024,
                tile_height=1024,
                mask_blur=8,
                tile_padding=32,
                seam_fix_mode="None",
                seam_fix_denoise=1,
                seam_fix_width=64,
                seam_fix_mask_blur=8,
                seam_fix_padding=16,
                force_uniform_tiles=True,
                tiled_decode=False,
                image=input_image,
                model=get_value_at_index(unetloader_58, 0),
                positive=get_value_at_index(guided, 0),
                negative=get_value_at_index(negative, 0),
                vae=get_value_at_index(vaeloader_55, 0),
                upscale_model=get_value_at_index(upscalemodelloader_44, 0),
            )

            # Save the result and report where it was written
            saved = NODE_CLASS_MAPPINGS["SaveImage"]().save_images(
                filename_prefix="enhanced_image",
                images=get_value_at_index(upscaled, 0),
            )
            return _saved_image_path(saved), caption
    except Exception as e:
        print(f"Error in enhance_image: {str(e)}")
        # Chain the cause so the server log keeps the original traceback.
        raise gr.Error(f"Enhancement failed: {str(e)}") from e
# Create the Gradio interface
def create_interface():
    """Build and return the Gradio Blocks app for the image enhancer.

    The left column holds the image input (upload tab or URL tab) and the
    sampling sliders. The right column shows the enhanced image and the
    Florence-2 caption. The button calls enhance_image() with the uploaded file
    when present, otherwise with the (trimmed) URL.
    """
    with gr.Blocks(
        title="🚀 AI Image Enhancer - Florence-2 + FLUX",
        theme=gr.themes.Soft(),
        css="""
        .gradio-container {
            max-width: 1200px !important;
        }
        .main-header {
            text-align: center;
            margin-bottom: 2rem;
        }
        .result-gallery {
            min-height: 400px;
        }
        """
    ) as app:
        gr.HTML("""
        <div class="main-header">
            <h1>🎨 AI Image Enhancer</h1>
            <p>Upload an image or provide a URL to enhance it using Florence-2 captioning and FLUX upscaling</p>
        </div>
        """)

        with gr.Row():
            with gr.Column(scale=1):
                gr.HTML("<h3>📤 Input Settings</h3>")
                with gr.Tabs():
                    with gr.TabItem("📁 Upload Image"):
                        image_upload = gr.Image(
                            label="Upload Image",
                            type="filepath",
                            height=300
                        )
                    with gr.TabItem("🔗 Image URL"):
                        image_url = gr.Textbox(
                            label="Image URL",
                            placeholder="https://example.com/image.jpg",
                            value="https://upload.wikimedia.org/wikipedia/commons/thumb/a/a7/Example.jpg/800px-Example.jpg"
                        )

                gr.HTML("<h3>⚙️ Enhancement Settings</h3>")
                upscale_factor = gr.Slider(
                    minimum=1.0,
                    maximum=4.0,
                    value=2.0,
                    step=0.5,
                    label="Upscale Factor",
                    info="How much to upscale the image"
                )
                steps = gr.Slider(
                    minimum=10,
                    maximum=50,
                    value=25,
                    step=5,
                    label="Steps",
                    info="Number of denoising steps"
                )
                cfg_scale = gr.Slider(
                    minimum=0.5,
                    maximum=10.0,
                    value=1.0,
                    step=0.5,
                    label="CFG Scale",
                    info="Classifier-free guidance scale"
                )
                denoise_strength = gr.Slider(
                    minimum=0.1,
                    maximum=1.0,
                    value=0.3,
                    step=0.1,
                    label="Denoise Strength",
                    info="How much to denoise the image"
                )
                guidance_scale = gr.Slider(
                    minimum=1.0,
                    maximum=10.0,
                    value=3.5,
                    step=0.5,
                    label="Guidance Scale",
                    info="FLUX guidance strength"
                )
                enhance_btn = gr.Button(
                    "🚀 Enhance Image",
                    variant="primary",
                    size="lg"
                )

            with gr.Column(scale=1):
                gr.HTML("<h3>📊 Results</h3>")
                output_image = gr.Image(
                    label="Enhanced Image",
                    type="filepath",
                    height=400,
                    interactive=False
                )
                generated_caption = gr.Textbox(
                    label="Generated Caption",
                    placeholder="The AI-generated caption will appear here...",
                    lines=3,
                    interactive=False
                )
                gr.HTML("""
                <div style="margin-top: 1rem; padding: 1rem; background: #f0f0f0; border-radius: 8px;">
                    <h4>💡 How it works:</h4>
                    <ol>
                        <li>Florence-2 analyzes your image and generates a detailed caption</li>
                        <li>FLUX uses this caption to guide the upscaling process</li>
                        <li>The result is an enhanced, higher-resolution image</li>
                    </ol>
                </div>
                """)

        # Event handlers
        def process_image(img_upload, img_url, upscale_f, steps_val, cfg_val, denoise_val, guidance_val):
            """Pick the image source (upload wins over URL) and run enhance_image()."""
            # Trim the URL so a whitespace-only entry counts as "no image"
            # instead of failing later inside the pipeline.
            url = (img_url or "").strip()
            image_input = img_upload if img_upload is not None else url
            if not image_input:
                raise gr.Error("Please provide an image (upload or URL)")
            return enhance_image(image_input, upscale_f, steps_val, cfg_val, denoise_val, guidance_val)

        enhance_btn.click(
            fn=process_image,
            inputs=[
                image_upload,
                image_url,
                upscale_factor,
                steps,
                cfg_scale,
                denoise_strength,
                guidance_scale
            ],
            outputs=[output_image, generated_caption]
        )

        # Example inputs
        gr.Examples(
            examples=[
                [None, "https://upload.wikimedia.org/wikipedia/commons/thumb/a/a7/Example.jpg/800px-Example.jpg", 2.0, 25, 1.0, 0.3, 3.5],
                [None, "https://picsum.photos/512/512", 2.0, 20, 1.5, 0.4, 4.0],
            ],
            inputs=[
                image_upload,
                image_url,
                upscale_factor,
                steps,
                cfg_scale,
                denoise_strength,
                guidance_scale
            ]
        )

    return app
if __name__ == "__main__":
    # Build the UI, then serve it on all interfaces at port 7860 with a share link requested.
    app = create_interface()
    app.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=True,
    )