import gradio as gr
import torch
import os
import glob
import spaces
import numpy as np
from datetime import datetime
from PIL import Image
from diffusers.utils import load_image
from diffusers import EulerDiscreteScheduler
from pipline_StableDiffusionXL_ConsistentID import ConsistentIDStableDiffusionXLPipeline
from huggingface_hub import hf_hub_download
from models.BiSeNet.model import BiSeNet

# ====================================================================================
# Global model management for ZeroGPU compatibility
# ====================================================================================
DEVICE = "cuda"
pipe = None
bise_net = None


def load_models():
    """Load all models on CPU to avoid ZeroGPU initialization issues.

    Populates the module-level ``pipe`` (ConsistentID SDXL pipeline on top of
    ``SG161222/RealVisXL_V3.0``, with an Euler scheduler) and ``bise_net``
    (19-class BiSeNet face-parsing network). Weights are fetched with
    ``hf_hub_download`` from the ``JackAILab/ConsistentID`` repository.

    Calling this again after a successful load is a no-op. Both globals are
    assigned together only after every step succeeded, so a failure part-way
    through leaves them unset and a later call retries the whole load.
    """
    global pipe, bise_net

    if pipe is not None:
        return

    print("⏳ Loading models on CPU...")

    base_model_path = "SG161222/RealVisXL_V3.0"
    consistentID_path = hf_hub_download(
        repo_id="JackAILab/ConsistentID",
        filename="ConsistentID_SDXL-v1.bin",
        repo_type="model",
    )

    # Load pipeline on CPU (no device move here).
    new_pipe = ConsistentIDStableDiffusionXLPipeline.from_pretrained(
        base_model_path,
        torch_dtype=torch.float16,
        safety_checker=None,
        variant="fp16",
    )

    # Load BiSeNet face-parsing weights (state dict mapped onto the CPU).
    bise_net_cp_path = hf_hub_download(
        repo_id="JackAILab/ConsistentID",
        filename="face_parsing.pth",
        local_dir="./checkpoints",
    )
    new_bise_net = BiSeNet(n_classes=19)
    new_bise_net.load_state_dict(torch.load(bise_net_cp_path, map_location="cpu"))
    # The network is only used for inference here, so disable training-mode layers.
    new_bise_net.eval()

    # Load ConsistentID components; the checkpoint's directory and file name are
    # passed separately.
    new_pipe.load_ConsistentID_model(
        os.path.dirname(consistentID_path),
        new_bise_net,
        subfolder="",
        weight_name=os.path.basename(consistentID_path),
        trigger_word="img",
    )

    new_pipe.scheduler = EulerDiscreteScheduler.from_config(new_pipe.scheduler.config)

    # Publish only after everything succeeded; otherwise the early-return guard
    # above would skip a retry and leave a half-initialised pipeline behind.
    pipe, bise_net = new_pipe, new_bise_net

    print("✅ Models loaded successfully")


load_models()
# ====================================================================================
# Inference function with GPU management
# ====================================================================================
@spaces.GPU(duration=180)  # Extended duration for SDXL
def generate_image(
    selected_template_images,
    custom_image,
    prompt,
    negative_prompt,
    prompt_selected,
    model_selected_tab,
    prompt_selected_tab,
    width,
    height,
    merge_steps,
    seed,
    num_steps,
):
    """Generate an image using ConsistentID-SDXL.

    The pipeline sub-modules and BiSeNet are moved to ``DEVICE`` for the call
    and moved back to the CPU afterwards, including when generation fails.

    Args:
        selected_template_images: Path of the chosen template image; used when
            ``model_selected_tab == 0``.
        custom_image: Uploaded image as a numpy array; used otherwise.
        prompt: Free-form prompt from the custom-prompt tab.
        negative_prompt: Free-form negative prompt from the custom-prompt tab.
        prompt_selected: Prompt resolved from the template dropdown; used when
            ``prompt_selected_tab == 0``.
        model_selected_tab: 0 for the template-image tab, otherwise upload tab.
        prompt_selected_tab: 0 for the template-prompt tab, otherwise custom tab.
        width: Output width in pixels.
        height: Output height in pixels.
        merge_steps: Forwarded to the pipeline as ``start_merge_step``.
        seed: Seed for the ``torch.Generator`` created on ``DEVICE``.
        num_steps: Number of inference steps.

    Returns:
        The first generated image, converted to a numpy array.

    Raises:
        gr.Error: If the active input tab has no image to use.
    """
    use_template = model_selected_tab == 0

    # Reject missing inputs up front with a readable UI message; otherwise
    # Image.open("") / Image.fromarray(None) fail with an opaque error.
    if use_template and not selected_template_images:
        raise gr.Error("Please select a template image first.")
    if not use_template and custom_image is None:
        raise gr.Error("Please upload an image first.")

    try:
        print("🚀 Moving models to GPU...")
        # Done inside the try so a partial move is still undone by `finally`.
        pipe.to(DEVICE)
        pipe.image_encoder.to(DEVICE)
        pipe.image_proj_model.to(DEVICE)
        pipe.FacialEncoder.to(DEVICE)
        bise_net.to(DEVICE)

        # Select input image
        if use_template:
            # Close the file handle once load_image has produced its RGB copy.
            with Image.open(selected_template_images) as template_image:
                input_image = load_image(template_image)
        else:
            input_image = load_image(Image.fromarray(custom_image))

        # Select prompt: the template tab drops any custom negative prompt and
        # disables the pipeline's safety check.
        if prompt_selected_tab == 0:
            prompt = prompt_selected
            negative_prompt = ""
            need_safetycheck = False
        else:
            need_safetycheck = True

        # Default prompts
        if not prompt or not prompt.strip():
            prompt = "A person, professional portrait"
        if not negative_prompt or not negative_prompt.strip():
            negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality, blurry"

        # Enhance prompt
        enhanced_prompt = f"cinematic photo, {prompt}, 50mm photograph, half-length portrait, film, bokeh, professional, 4k, highly detailed"

        # Negative prompt enhancement
        negative_enhancement = "((cross-eye)), ((cross-eyed)), (((NSFW))), (nipple), ((((ugly)))), (((duplicate))), ((morbid)), ((mutilated)), [out of frame], extra fingers, mutated hands, ((poorly drawn hands)), ((poorly drawn face)), (((mutation))), (((deformed))), ((ugly)), blurry, ((bad anatomy)), (((bad proportions))), ((extra limbs)), cloned face, (((disfigured))), out of frame, ugly, extra limbs, (bad anatomy), gross proportions, (malformed limbs), ((missing arms)), ((missing legs)), (((extra arms))), (((extra legs))), mutated hands, (fused fingers), (too many fingers), (((long neck)))"
        final_negative_prompt = negative_prompt + ", " + negative_enhancement

        # Slider values may arrive as floats; manual_seed and the size/step
        # arguments need integers.
        generator = torch.Generator(device=DEVICE).manual_seed(int(seed))

        print(f"🎨 Generating with prompt: {enhanced_prompt[:100]}...")

        image = pipe(
            prompt=enhanced_prompt,
            width=int(width),
            height=int(height),
            input_id_images=input_image,
            input_image_path=selected_template_images if use_template else None,
            negative_prompt=final_negative_prompt,
            num_images_per_prompt=1,
            num_inference_steps=int(num_steps),
            start_merge_step=int(merge_steps),
            generator=generator,
            retouching=False,
            need_safetycheck=need_safetycheck,
        ).images[0]

        print("✅ Generation completed")
        return np.array(image)

    except Exception as e:
        print(f"❌ Error: {e}")
        raise

    finally:
        # Clean up GPU
        print("🧹 Releasing GPU memory...")
        pipe.to("cpu")
        pipe.image_encoder.to("cpu")
        pipe.image_proj_model.to("cpu")
        pipe.FacialEncoder.to("cpu")
        bise_net.to("cpu")
        if torch.cuda.is_available():
            torch.cuda.empty_cache()


# ====================================================================================
# Beautiful Gradio Interface
# ====================================================================================

# Get template images; sorted so the gallery order does not depend on
# filesystem listing order.
preset_templates = sorted(
    glob.glob("./images/templates/*.png") + glob.glob("./images/templates/*.jpg")
)

# Custom CSS for beautiful interface
custom_css = """
.gradio-container {
    font-family: 'IBM Plex Sans', sans-serif;
}
.main-title {
    text-align: center;
    font-size: 2.5em;
    font-weight: 700;
    background: linear-gradient(45deg, #667eea 0%, #764ba2 100%);
    -webkit-background-clip: text;
    -webkit-text-fill-color: transparent;
    margin-bottom: 1em;
}
.subtitle {
    text-align: center;
    font-size: 1.1em;
    color: #666;
    margin-bottom: 2em;
}
.section-header {
    font-size: 1.3em;
    font-weight: 600;
    margin: 1em 0 0.5em 0;
    color: #333;
}
.info-box {
    background: #f8f9fa;
    border-left: 4px solid #667eea;
    padding: 1em;
    margin: 1em 0;
    border-radius: 4px;
}
.generate-btn {
    background: linear-gradient(45deg, #667eea 0%, #764ba2 100%) !important;
    border: none !important;
    color: white !important;
    font-size: 1.1em !important;
    font-weight: 600 !important;
    padding: 0.8em 2em !important;
    border-radius: 8px !important;
}
.gallery-item {
    border-radius: 8px;
    overflow: hidden;
}
"""

# Template prompts as (dropdown label, full prompt) pairs.
template_prompts = [
    ("👰 Wedding", "A woman in an elegant wedding dress, professional photography"),
    ("👑 Royalty", "A person as royalty, sitting on throne in gorgeous palace, regal attire"),
    ("🏖️ Beach", "A person sitting at the beach with beautiful sunset, relaxed atmosphere"),
    ("👮 Officer", "A person as police officer, professional uniform, half body shot"),
    ("⛵ Sailor", "A person as sailor, on boat deck above ocean, nautical uniform"),
    ("🎧 Music", "A person wearing headphones, listening to music, modern setting"),
    ("🚒 Firefighter", "A person as firefighter, professional gear, half body shot"),
    ("💼 Business", "A person in business attire, professional corporate environment"),
    ("🎨 Artist", "A person as artist in studio, creative atmosphere, artistic clothing"),
    ("🔬 Scientist", "A person as scientist in laboratory, lab coat, professional setting"),
]

# Dropdown label -> full prompt text.
prompt_by_label = dict(template_prompts)
default_prompt_label = template_prompts[3][0]  # "👮 Officer"

with gr.Blocks(css=custom_css, theme=gr.themes.Soft(), title="ConsistentID-SDXL") as demo:
    # Header
    gr.HTML(
        """
        <div class="main-title">✨ ConsistentID-SDXL Demo ✨</div>
        <div class="subtitle">High-fidelity portrait generation with consistent identity preservation</div>
        """
    )
    gr.Markdown(
        '<div style="text-align: center;">'
        '⭐ <a href="https://github.com/JackAILab/ConsistentID" target="_blank">Star us on GitHub</a>'
        ' | 📄 <a href="https://arxiv.org/abs/2404.16771" target="_blank">Read the Paper</a>'
        "</div>"
    )

    with gr.Row():
        # Left column - Inputs
        with gr.Column(scale=1):
            gr.HTML("<div class='section-header'>📸 Input Image</div>")

            # Hidden flag read by generate_image: 0 = template tab, 1 = upload tab.
            model_selected_tab = gr.Number(value=0, visible=False)

            with gr.Tabs():
                with gr.Tab("🖼️ Templates") as template_tab:
                    template_gallery = gr.Gallery(
                        value=[(img, img) for img in preset_templates],
                        columns=4,
                        rows=2,
                        height=300,
                        object_fit="cover",
                        show_label=False,
                        elem_classes="gallery-item",
                    )
                    # Hidden holder for the clicked template's file path.
                    selected_template = gr.Textbox(visible=False)

                    def select_template(evt: gr.SelectData):
                        """Return the file path of the clicked gallery item."""
                        return preset_templates[evt.index]

                    template_gallery.select(select_template, None, selected_template)

                with gr.Tab("📤 Upload") as upload_tab:
                    custom_image = gr.Image(
                        label="Upload your image",
                        type="numpy",
                        height=300,
                    )

            template_tab.select(fn=lambda: 0, inputs=[], outputs=[model_selected_tab])
            upload_tab.select(fn=lambda: 1, inputs=[], outputs=[model_selected_tab])

            gr.HTML("<div class='section-header'>✏️ Prompt</div>")

            # Hidden flag read by generate_image: 0 = template prompt, 1 = custom prompt.
            prompt_selected_tab = gr.Number(value=0, visible=False)

            with gr.Tabs():
                with gr.Tab("📋 Templates") as template_prompt_tab:
                    # Choices are the short labels only, so the default value is a
                    # valid choice.
                    prompt_dropdown = gr.Dropdown(
                        choices=[label for label, _ in template_prompts],
                        value=default_prompt_label,
                        label="Choose a style",
                        scale=1,
                    )
                    # Hidden textbox to store actual prompt
                    prompt_selected = gr.Textbox(
                        value=prompt_by_label[default_prompt_label],
                        visible=False,
                    )

                    def update_prompt(choice):
                        """Map a dropdown label to its full prompt (first template if unknown)."""
                        return prompt_by_label.get(choice, template_prompts[0][1])

                    prompt_dropdown.change(
                        update_prompt, inputs=[prompt_dropdown], outputs=[prompt_selected]
                    )

                with gr.Tab("✏️ Custom") as custom_prompt_tab:
                    custom_prompt = gr.Textbox(
                        label="Your prompt",
                        placeholder="A person wearing a santa hat, festive atmosphere...",
                        lines=3,
                    )
                    custom_negative = gr.Textbox(
                        label="Negative prompt (optional)",
                        placeholder="blurry, low quality...",
                        lines=2,
                    )

            template_prompt_tab.select(fn=lambda: 0, inputs=[], outputs=[prompt_selected_tab])
            custom_prompt_tab.select(fn=lambda: 1, inputs=[], outputs=[prompt_selected_tab])

            gr.HTML("<div class='section-header'>⚙️ Generation Settings</div>")

            with gr.Row():
                width = gr.Slider(label="Width", minimum=512, maximum=1280, value=896, step=64)
                height = gr.Slider(label="Height", minimum=512, maximum=1280, value=1152, step=64)

            with gr.Row():
                num_steps = gr.Slider(label="Steps", minimum=20, maximum=50, value=30, step=1)
                merge_steps = gr.Slider(label="Merge Step", minimum=10, maximum=40, value=20, step=1)

            seed = gr.Slider(label="🎲 Seed", minimum=0, maximum=2147483647, value=42, step=1)

            generate_btn = gr.Button(
                "🎨 Generate Image",
                variant="primary",
                size="lg",
                elem_classes="generate-btn",
            )

        # Right column - Output
        with gr.Column(scale=1):
            gr.HTML("<div class='section-header'>🖼️ Generated Result</div>")
            output_image = gr.Image(label="Output", height=600, show_label=False)

            # NOTE(review): the original tip list did not survive (only its heading
            # remained); these items are a reconstruction — confirm wording.
            gr.HTML(
                """
                <div class="info-box">
                    <strong>💡 Tips for Best Results:</strong>
                    <ul>
                        <li>Use a clear, well-lit photo in which the face is fully visible.</li>
                        <li>Pick a style under 📋 Templates, or write your own prompt under ✏️ Custom.</li>
                        <li>Change the seed to get a different variation.</li>
                    </ul>
                </div>
                """
            )

    gr.Markdown(
        '<div style="text-align: center;">'
        "Powered by ConsistentID-SDXL | "
        '<a href="https://huggingface.co/JackAILab/ConsistentID" target="_blank">Model Card</a>'
        "</div>"
    )

    # Connect the button
    generate_btn.click(
        fn=generate_image,
        inputs=[
            selected_template,
            custom_image,
            custom_prompt,
            custom_negative,
            prompt_selected,
            model_selected_tab,
            prompt_selected_tab,
            width,
            height,
            merge_steps,
            seed,
            num_steps,
        ],
        outputs=output_image,
    )

if __name__ == "__main__":
    demo.queue(max_size=20)
    demo.launch()