import gradio as gr
import torch
import os
import glob
import spaces
import numpy as np
from datetime import datetime
from PIL import Image
from diffusers.utils import load_image
from diffusers import EulerDiscreteScheduler
from pipline_StableDiffusionXL_ConsistentID import ConsistentIDStableDiffusionXLPipeline
from huggingface_hub import hf_hub_download
from models.BiSeNet.model import BiSeNet
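# Note: "pipline_StableDiffusionXL_ConsistentID" is a local module (keeping the
# upstream repo's spelling), and "spaces" is the Hugging Face helper that
# provides the ZeroGPU @spaces.GPU decorator used below.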
# ====================================================================================
# CRITICAL: Global variables for model management with ZeroGPU
# Models are loaded on CPU at startup and moved to GPU only during inference
# ====================================================================================
DEVICE = "cuda" # Device to use during inference
pipe = None # Will hold the main pipeline
bise_net = None # Will hold the face parsing model
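# The ZeroGPU contract, as a minimal sketch (illustrative, not this app's code):
#
#     import spaces
#     model = load_model()   # runs on the CPU at import time
#
#     @spaces.GPU            # a GPU is attached only while infer() runs
#     def infer(x):
#         model.to("cuda")   # safe to touch CUDA inside the decorated call
#         return model(x)
#
# This file follows the same shape: load_models() stays on the CPU, and only
# process() is decorated with @spaces.GPU.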
# ====================================================================================
# Model loading function - loads all models on CPU to avoid ZeroGPU startup issues
# ====================================================================================
def load_models():
    """
    Load all models on CPU at startup.
    This prevents CUDA initialization errors with ZeroGPU.
    Models are moved to the GPU only during inference.
    """
    global pipe, bise_net
    if pipe is not None:
        return  # Models already loaded
    print("Loading models on CPU...")

    # Download and prepare model paths
    base_model_path = "SG161222/RealVisXL_V3.0"
    consistentID_path = hf_hub_download(
        repo_id="JackAILab/ConsistentID",
        filename="ConsistentID_SDXL-v1.bin",
        repo_type="model"
    )

    # Load the main pipeline on CPU with fp16 weights
    pipe = ConsistentIDStableDiffusionXLPipeline.from_pretrained(
        base_model_path,
        torch_dtype=torch.float16,
        safety_checker=None,
        variant="fp16"
    )

    # Load the BiSeNet face-parsing model (19 facial-region classes)
    bise_net_cp_path = hf_hub_download(
        repo_id="JackAILab/ConsistentID",
        filename="face_parsing.pth",
        local_dir="./checkpoints"
    )
    bise_net = BiSeNet(n_classes=19)
    bise_net.load_state_dict(torch.load(bise_net_cp_path, map_location="cpu"))

    # Attach the ConsistentID components to the pipeline
    pipe.load_ConsistentID_model(
        os.path.dirname(consistentID_path),
        bise_net,
        subfolder="",
        weight_name=os.path.basename(consistentID_path),
        trigger_word="img",
    )
    pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config)
    print("Successfully loaded all models on CPU")

# Initialize models at startup
load_models()
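# Lightweight sanity check (optional): load_ConsistentID_model() should have
# attached the components that process() moves to the GPU below. If this
# raises, the checkpoint download or attach step failed.
assert hasattr(pipe, "image_proj_model") and hasattr(pipe, "FacialEncoder"), \
    "ConsistentID components were not attached to the pipeline"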
# ====================================================================================
# Main inference function with ZeroGPU decorator
# ====================================================================================
@spaces.GPU(duration=120)  # Request the GPU for up to 120 seconds per call
def process(selected_template_images, custom_image, prompt,
            negative_prompt, prompt_selected, retouching, model_selected_tab,
            prompt_selected_tab, width, height, merge_steps, seed_set):
    """
    Main inference function that generates images using ConsistentID.
    Models are moved to the GPU at the start and back to the CPU at the end.

    Args:
        selected_template_images: Path to the selected template image
        custom_image: User-uploaded image
        prompt: Text prompt for generation
        negative_prompt: Negative prompt
        prompt_selected: Selected template prompt
        retouching: Whether to apply face retouching
        model_selected_tab: Index of the selected image-source tab
        prompt_selected_tab: Index of the selected prompt tab
        width: Output image width
        height: Output image height
        merge_steps: Step at which to start merging facial details
        seed_set: Random seed for generation

    Returns:
        numpy.ndarray: Generated image
    """
    global pipe, bise_net
    print(f"Starting inference, moving models to {DEVICE}")

    # Move all model components to the GPU
    pipe.to(DEVICE)
    pipe.image_encoder.to(DEVICE)
    pipe.image_proj_model.to(DEVICE)
    pipe.FacialEncoder.to(DEVICE)
    bise_net.to(DEVICE)

    try:
        # Pick the input image based on the selected tab
        if model_selected_tab == 0:
            select_images = load_image(Image.open(selected_template_images))
        else:
            select_images = load_image(Image.fromarray(custom_image))

        # Pick the prompt based on the selected tab; template prompts are
        # curated, so the safety check is skipped for them
        if prompt_selected_tab == 0:
            prompt = prompt_selected
            negative_prompt = ""
            need_safetycheck = False
        else:
            need_safetycheck = True

        # Generation parameters
        num_steps = 50

        # Fall back to defaults when either prompt is empty
        if prompt == "":
            prompt = "A person, in a forest"
        if negative_prompt == "":
            negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality, blurry"

        # Extend the prompt with quality tags
        prompt = "cinematic photo," + prompt + ", 50mm photograph, half-length portrait, film, bokeh, professional, 4k, highly detailed"

        # Append the standard negative-prompt group
        negative_prompt_group = "((cross-eye)),((cross-eyed)),(((NSFW))),(nipple),((((ugly)))), (((duplicate))), ((morbid)), ((mutilated)), [out of frame], extra fingers, mutated hands, ((poorly drawn hands)), ((poorly drawn face)), (((mutation))), (((deformed))), ((ugly)), blurry, ((bad anatomy)), (((bad proportions))), ((extra limbs)), cloned face, (((disfigured))), out of frame, ugly, extra limbs, (bad anatomy), gross proportions, (malformed limbs), ((missing arms)), ((missing legs)), (((extra arms))), (((extra legs))), mutated hands, (fused fingers), (too many fingers), (((long neck)))"
        negative_prompt = negative_prompt + negative_prompt_group

        # Create a seeded generator; cast in case the slider delivers a float
        generator = torch.Generator(device=DEVICE).manual_seed(int(seed_set))

        print("Generating image...")
        # Run the pipeline
        images = pipe(
            prompt=prompt,
            width=width,
            height=height,
            input_id_images=select_images,
            input_image_path=selected_template_images,
            negative_prompt=negative_prompt,
            num_images_per_prompt=1,
            num_inference_steps=num_steps,
            start_merge_step=merge_steps,
            generator=generator,
            retouching=retouching,
            need_safetycheck=need_safetycheck,
        ).images[0]
        print("Image generated successfully")
        return np.array(images)

    except Exception as e:
        print(f"Error during inference: {e}")
        raise

    finally:
        # Always move the models back to the CPU to free GPU memory
        print("Cleaning up GPU memory")
        pipe.to("cpu")
        pipe.image_encoder.to("cpu")
        pipe.image_proj_model.to("cpu")
        pipe.FacialEncoder.to("cpu")
        bise_net.to("cpu")
        # Clear the CUDA cache
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
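# Local smoke test (hypothetical values, not wired into the UI). Assumes a
# template image exists at the given path -- adjust before uncommenting:
#
#   result = process(
#       "./images/templates/example.png",  # hypothetical template path
#       None, "", "", "A man, sailor, in a boat above ocean",
#       False, 0, 0, 864, 1152, 30, 2024,
#   )
#   Image.fromarray(result).save("smoke_test.png")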
# ====================================================================================
# Gradio Interface
# ====================================================================================
# Get template images
script_directory = os.path.dirname(os.path.realpath(__file__))
preset_template = glob.glob("./images/templates/*.png")
preset_template = preset_template + glob.glob("./images/templates/*.jpg")
# Build Gradio interface
with gr.Blocks(title="ConsistentID_SDXL Demo") as demo:
    gr.Markdown("# ConsistentID_SDXL Demo")
    gr.Markdown(
        "Place the reference face to be redrawn in the box below. "
        "(Reference extraction occasionally fails; if it does, simply submit again.)"
    )
    gr.Markdown(
        "If you find our work interesting, please leave us a star on GitHub!<br>"
        "https://github.com/JackAILab/ConsistentID"
    )

    with gr.Row():
        with gr.Column():
            # Hidden state tracking which image-source tab is selected
            model_selected_tab = gr.Number(value=0, visible=False)

            # Image-source tabs
            with gr.Tabs() as image_tabs:
                with gr.Tab("template images") as template_images_tab:
                    template_gallery_list = [(i, i) for i in preset_template]
                    gallery = gr.Gallery(
                        template_gallery_list,
                        columns=4,
                        rows=2,
                        object_fit="contain",
                        height="auto",
                        show_label=False
                    )

                    def select_function(evt: gr.SelectData):
                        # Map the clicked gallery index back to a template path
                        return preset_template[evt.index]

                    selected_template_images = gr.Textbox(
                        show_label=False,
                        visible=False,
                        placeholder="Selected"
                    )
                    gallery.select(select_function, None, selected_template_images)

                with gr.Tab("Upload Image") as upload_image_tab:
                    custom_image = gr.Image(label="Upload Image")

            # Record which image tab is active in the hidden state
            template_images_tab.select(fn=lambda: 0, inputs=[], outputs=[model_selected_tab])
            upload_image_tab.select(fn=lambda: 1, inputs=[], outputs=[model_selected_tab])
            # Prompt section
            with gr.Column():
                # Hidden state tracking which prompt tab is selected
                prompt_selected_tab = gr.Number(value=0, visible=False)

                # Prompt tabs
                with gr.Tabs() as prompt_tabs:
                    with gr.Tab("template prompts") as template_prompts_tab:
                        prompt_selected = gr.Dropdown(
                            value="A person, police officer, half body shot",
                            choices=[
                                "A woman in a wedding dress",
                                "A woman, queen, in a gorgeous palace",
                                "A man sitting at the beach with sunset",
                                "A person, police officer, half body shot",
                                "A man, sailor, in a boat above ocean",
                                "A woman wearing headphones, listening to music",
                                "A man, firefighter, half body shot"
                            ],
                            label="prepared prompts"
                        )

                    with gr.Tab("custom prompt") as custom_prompt_tab:
                        prompt = gr.Textbox(
                            label="prompt",
                            placeholder="A man/woman wearing a santa hat"
                        )
                        negative_prompt = gr.Textbox(
                            label="negative prompt",
                            placeholder="monochrome, lowres, bad anatomy, worst quality, low quality, blurry"
                        )

                # Record which prompt tab is active in the hidden state
                template_prompts_tab.select(fn=lambda: 0, inputs=[], outputs=[prompt_selected_tab])
                custom_prompt_tab.select(fn=lambda: 1, inputs=[], outputs=[prompt_selected_tab])
            # Generation parameters
            retouching = gr.Checkbox(label="face retouching", value=False, visible=False)
            width = gr.Slider(
                label="image width",
                minimum=512,
                maximum=1280,
                value=864,
                step=8
            )
            height = gr.Slider(
                label="image height",
                minimum=512,
                maximum=1280,
                value=1152,
                step=8
            )
            # Keep width + height at or below 1280: releasing one slider
            # clamps the other to the remaining budget
            width.release(lambda x, y: min(1280 - x, y), inputs=[width, height], outputs=[height])
            height.release(lambda x, y: min(1280 - y, x), inputs=[width, height], outputs=[width])
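            # Worked example of the clamp above: releasing width at 864 sets
            # height = min(1280 - 864, height) = min(416, height), so the
            # default height of 1152 snaps down to 416. Note the defaults
            # (864 x 1152) exceed the 1280 budget until a slider is touched.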
            merge_steps = gr.Slider(
                label="step at which to start merging facial details (30 is recommended)",
                minimum=10,
                maximum=50,
                value=30,
                step=1
            )
            seed_set = gr.Slider(
                label="random seed (change for different results)",
                minimum=1,
                maximum=2147483647,
                value=2024,
                step=1
            )
            btn = gr.Button("Run")

        with gr.Column():
            out = gr.Image(label="Output")
            gr.Markdown('''
            N.B.:<br/>
            - If the face occupies too small a portion of the image, errors become slightly more likely and identity similarity drops noticeably.<br/>
            - Prefer "man" or "woman" over "person" in the prompt; otherwise the model may be unsure whether the subject is male or female.<br/>
            - Due to ZeroGPU limits, generation may take 1-2 minutes. Please be patient.<br/>
            ''')
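    # Note: btn.click passes inputs positionally, so the component order below
    # must match the parameter order of process() exactly.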
    # Connect the button to the processing function
    btn.click(
        fn=process,
        inputs=[
            selected_template_images,
            custom_image,
            prompt,
            negative_prompt,
            prompt_selected,
            retouching,
            model_selected_tab,
            prompt_selected_tab,
            width,
            height,
            merge_steps,
            seed_set
        ],
        outputs=out
    )
# Launch the interface
if __name__ == "__main__":
    demo.launch()