Spaces: Running on Zero
Update app.py
app.py CHANGED
@@ -11,222 +11,350 @@ from diffusers.utils import load_image
 from diffusers import EulerDiscreteScheduler
 from pipline_StableDiffusionXL_ConsistentID import ConsistentIDStableDiffusionXLPipeline
 from huggingface_hub import hf_hub_download
-### Model can be imported from https://github.com/zllrunning/face-parsing.PyTorch?tab=readme-ov-file
-### We use the ckpt of 79999_iter.pth: https://drive.google.com/open?id=154JgKpzCPW82qINcVieuPH3fZ2e0P812
-### Thanks for the open source of face-parsing model.
 from models.BiSeNet.model import BiSeNet
-[removed: the old module-level model setup (eager pipeline construction); only fragments of it survive the page extraction, e.g. "bise_net,", ")", "pipe."]
-num_steps = 50
-seed_set = torch.randint(0, 1000, (1,)).item()
-# merge_steps = 30
-@torch.inference_mode()
-def Enhance_prompt(prompt,select_images):
-    llva_prompt = f'Please ignore the image. Enhance the following text prompt for me. You can associate more details with the character\'s gesture, environment, and decent clothing:"{prompt}".'
-    args = type('Args', (), {
-        "model_path": llva_model_path,
-        "model_base": None,
-        "model_name": get_model_name_from_path(llva_model_path),
-        "query": llva_prompt,
-        "conv_mode": None,
-        "image_file": select_images,
-        "sep": ",",
-        "temperature": 0,
-        "top_p": None,
-        "num_beams": 1,
-        "max_new_tokens": 512
-    })()
-    Enhanced_prompt = eval_model(args, llva_tokenizer, llva_model, llva_image_processor)
-[removed: the old process() implementation and parts of the Gradio layout; the extraction preserves only fragments, including "print(prompt)", the old intro line "Put the reference figure to be redrawn into the box below", and the truncated N.B. bullets "- At the same time, use prompt with" and "- Due to"]
+# ====================================================================================
+# CRITICAL: Global variables for model management with ZeroGPU
+# Models are loaded on CPU at startup and moved to GPU only during inference
+# ====================================================================================
+DEVICE = "cuda"    # Device to use during inference
+pipe = None        # Will hold the main pipeline
+bise_net = None    # Will hold the face parsing model

+# ====================================================================================
+# Model loading function - loads all models on CPU to avoid ZeroGPU startup issues
+# ====================================================================================
+def load_models():
+    """
+    Load all models on CPU at startup.
+    This prevents CUDA initialization errors with ZeroGPU.
+    Models will be moved to GPU only during inference.
+    """
+    global pipe, bise_net
+
+    if pipe is not None:
+        return  # Models already loaded
+
+    print("Loading models on CPU...")
+
+    # Download and prepare model paths
+    base_model_path = "SG161222/RealVisXL_V3.0"
+    consistentID_path = hf_hub_download(
+        repo_id="JackAILab/ConsistentID",
+        filename="ConsistentID_SDXL-v1.bin",
+        repo_type="model"
+    )
+
+    # Load main pipeline on CPU with fp16 precision
+    pipe = ConsistentIDStableDiffusionXLPipeline.from_pretrained(
+        base_model_path,
+        torch_dtype=torch.float16,
+        safety_checker=None,
+        variant="fp16"
+    )
+
+    # Load BiSeNet face parsing model
+    bise_net_cp_path = hf_hub_download(
+        repo_id="JackAILab/ConsistentID",
+        filename="face_parsing.pth",
+        local_dir="./checkpoints"
+    )
+    bise_net = BiSeNet(n_classes=19)
+    bise_net.load_state_dict(torch.load(bise_net_cp_path, map_location="cpu"))
+
+    # Load ConsistentID model components
+    pipe.load_ConsistentID_model(
+        os.path.dirname(consistentID_path),
+        bise_net,
+        subfolder="",
+        weight_name=os.path.basename(consistentID_path),
+        trigger_word="img",
+    )
+    pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config)
+
+    print("Successfully loaded all models on CPU")

+# Initialize models at startup
+load_models()
+# ====================================================================================
+# Main inference function with ZeroGPU decorator
+# ====================================================================================
+@spaces.GPU(duration=120)  # Request GPU for 120 seconds
+def process(selected_template_images, costum_image, prompt,
+            negative_prompt, prompt_selected, retouching, model_selected_tab,
+            prompt_selected_tab, width, height, merge_steps, seed_set):
+    """
+    Main inference function that generates images using ConsistentID.
+    Models are moved to GPU at the start and back to CPU at the end.
+
+    Args:
+        selected_template_images: Path to template image
+        costum_image: User uploaded image
+        prompt: Text prompt for generation
+        negative_prompt: Negative prompt
+        prompt_selected: Selected template prompt
+        retouching: Whether to apply face retouching
+        model_selected_tab: Which image source tab is selected
+        prompt_selected_tab: Which prompt tab is selected
+        width: Output image width
+        height: Output image height
+        merge_steps: Step to start merging facial details
+        seed_set: Random seed for generation
+
+    Returns:
+        numpy.ndarray: Generated image
+    """
+    global pipe, bise_net
+
+    print(f"Starting inference, moving models to {DEVICE}")
+
+    # Move all model components to GPU
+    pipe.to(DEVICE)
+    pipe.image_encoder.to(DEVICE)
+    pipe.image_proj_model.to(DEVICE)
+    pipe.FacialEncoder.to(DEVICE)
+    bise_net.to(DEVICE)

+    try:
+        # Process input image based on selected tab
+        if model_selected_tab == 0:
+            select_images = load_image(Image.open(selected_template_images))
+        else:
+            select_images = load_image(Image.fromarray(costum_image))

+        # Process prompt based on selected tab
+        if prompt_selected_tab == 0:
+            prompt = prompt_selected
+            negative_prompt = ""
+            need_safetycheck = False
+        else:
+            need_safetycheck = True

+        # Generation parameters
+        num_steps = 50

+        # Default prompt if empty
+        if prompt == "":
+            prompt = "A person, in a forest"

+        # Default negative prompt if empty
+        if negative_prompt == "":
+            negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality, blurry"
+
+        # Extend prompt with quality tags
+        prompt = "cinematic photo," + prompt + ", 50mm photograph, half-length portrait, film, bokeh, professional, 4k, highly detailed"
+
+        # Add negative prompt group
+        negtive_prompt_group = "((cross-eye)),((cross-eyed)),(((NFSW))),(nipple),((((ugly)))), (((duplicate))), ((morbid)), ((mutilated)), [out of frame], extra fingers, mutated hands, ((poorly drawn hands)), ((poorly drawn face)), (((mutation))), (((deformed))), ((ugly)), blurry, ((bad anatomy)), (((bad proportions))), ((extra limbs)), cloned face, (((disfigured))). out of frame, ugly, extra limbs, (bad anatomy), gross proportions, (malformed limbs), ((missing arms)), ((missing legs)), (((extra arms))), (((extra legs))), mutated hands, (fused fingers), (too many fingers), (((long neck)))"
+        negative_prompt = negative_prompt + negtive_prompt_group
+
+        # Create generator with seed
+        generator = torch.Generator(device=DEVICE).manual_seed(seed_set)
+
+        print("Generating image...")
+
+        # Run the pipeline
+        images = pipe(
+            prompt=prompt,
+            width=width,
+            height=height,
+            input_id_images=select_images,
+            input_image_path=selected_template_images,
+            negative_prompt=negative_prompt,
+            num_images_per_prompt=1,
+            num_inference_steps=num_steps,
+            start_merge_step=merge_steps,
+            generator=generator,
+            retouching=retouching,
+            need_safetycheck=need_safetycheck,
+        ).images[0]
+
+        print("Image generated successfully")
+        return np.array(images)
+
+    except Exception as e:
+        print(f"Error during inference: {e}")
+        raise
+
+    finally:
+        # Always move models back to CPU to free GPU memory
+        print("Cleaning up GPU memory")
+        pipe.to("cpu")
+        pipe.image_encoder.to("cpu")
+        pipe.image_proj_model.to("cpu")
+        pipe.FacialEncoder.to("cpu")
+        bise_net.to("cpu")
+
+        # Clear CUDA cache
+        if torch.cuda.is_available():
+            torch.cuda.empty_cache()
+
+# ====================================================================================
+# Gradio Interface
+# ====================================================================================
+
+# Get template images
 script_directory = os.path.dirname(os.path.realpath(__file__))
 preset_template = glob.glob("./images/templates/*.png")
 preset_template = preset_template + glob.glob("./images/templates/*.jpg")

+# Build Gradio interface
 with gr.Blocks(title="ConsistentID_SDXL Demo") as demo:
     gr.Markdown("# ConsistentID_SDXL Demo")
+    gr.Markdown(
+        "Put the reference figure to be redrawn into the box below "
+        "(There is a small probability of referencing failure. You can submit it repeatedly)"
+    )
+    gr.Markdown(
+        "If you find our work interesting, please leave a star in GitHub for us!<br>"
+        "https://github.com/JackAILab/ConsistentID"
+    )
+
     with gr.Row():
         with gr.Column():
+            # Hidden state for tracking which image source tab is selected
+            model_selected_tab = gr.Number(value=0, visible=False)
+
+            # Image source tabs
+            with gr.Tabs() as image_tabs:
+                with gr.Tab("template images") as template_images_tab:
+                    template_gallery_list = [(i, i) for i in preset_template]
+                    gallery = gr.Gallery(
+                        template_gallery_list,
+                        columns=4,
+                        rows=2,
+                        object_fit="contain",
+                        height="auto",
+                        show_label=False
+                    )
+
+                    def select_function(evt: gr.SelectData):
+                        return preset_template[evt.index]

+                    selected_template_images = gr.Textbox(
+                        show_label=False,
+                        visible=False,
+                        placeholder="Selected"
+                    )
+                    gallery.select(select_function, None, selected_template_images)
+
+                with gr.Tab("Upload Image") as upload_image_tab:
+                    costum_image = gr.Image(label="Upload Image")

+            # Update model_selected_tab when tab changes
+            def update_image_tab(tab_index):
+                return tab_index
+
+            template_images_tab.select(fn=lambda: 0, inputs=[], outputs=[model_selected_tab])
+            upload_image_tab.select(fn=lambda: 1, inputs=[], outputs=[model_selected_tab])

+        # Prompt section
         with gr.Column():
+            # Hidden state for tracking which prompt tab is selected
+            prompt_selected_tab = gr.Number(value=0, visible=False)
+
+            # Prompt tabs
+            with gr.Tabs() as prompt_tabs:
+                with gr.Tab("template prompts") as template_prompts_tab:
+                    prompt_selected = gr.Dropdown(
+                        value="A person, police officer, half body shot",
+                        choices=[
+                            "A woman in a wedding dress",
+                            "A woman, queen, in a gorgeous palace",
+                            "A man sitting at the beach with sunset",
+                            "A person, police officer, half body shot",
+                            "A man, sailor, in a boat above ocean",
+                            "A women wearing headphone, listening music",
+                            "A man, firefighter, half body shot"
+                        ],
+                        label="prepared prompts"
+                    )
+
+                with gr.Tab("custom prompt") as custom_prompt_tab:
+                    prompt = gr.Textbox(
+                        label="prompt",
+                        placeholder="A man/woman wearing a santa hat"
+                    )
+                    nagetive_prompt = gr.Textbox(
+                        label="negative prompt",
+                        placeholder="monochrome, lowres, bad anatomy, worst quality, low quality, blurry"
+                    )
+
+            # Update prompt_selected_tab when tab changes
+            template_prompts_tab.select(fn=lambda: 0, inputs=[], outputs=[prompt_selected_tab])
+            custom_prompt_tab.select(fn=lambda: 1, inputs=[], outputs=[prompt_selected_tab])
+
+            # Generation parameters
+            retouching = gr.Checkbox(label="face retouching", value=False, visible=False)
+
+            width = gr.Slider(
+                label="image width",
+                minimum=512,
+                maximum=1280,
+                value=864,
+                step=8
+            )

+            height = gr.Slider(
+                label="image height",
+                minimum=512,
+                maximum=1280,
+                value=1152,
+                step=8
+            )

+            # Ensure width + height doesn't exceed 1280
+            width.release(lambda x, y: min(1280-x, y), inputs=[width, height], outputs=[height])
+            height.release(lambda x, y: min(1280-y, x), inputs=[width, height], outputs=[width])
+
+            merge_steps = gr.Slider(
+                label="step starting to merge facial details (30 is recommended)",
+                minimum=10,
+                maximum=50,
+                value=30,
+                step=1
+            )
+
+            seed_set = gr.Slider(
+                label="set the random seed for different results",
+                minimum=1,
+                maximum=2147483647,
+                value=2024,
+                step=1
+            )

             btn = gr.Button("Run")
+
         with gr.Column():
             out = gr.Image(label="Output")
             gr.Markdown('''
             N.B.:<br/>
+            - If the proportion of face in the image is too small, the probability of an error will be slightly higher, and the similarity will also significantly decrease.<br/>
+            - At the same time, use prompt with "man" or "woman" instead of "person" as much as possible, as that may cause the model to be confused whether the protagonist is male or female.<br/>
+            - Due to ZeroGPU limitations, generation may take 1-2 minutes. Please be patient.<br/>
             ''')
+
+    # Connect the button to the processing function
+    btn.click(
+        fn=process,
+        inputs=[
+            selected_template_images,
+            costum_image,
+            prompt,
+            nagetive_prompt,
+            prompt_selected,
+            retouching,
+            model_selected_tab,
+            prompt_selected_tab,
+            width,
+            height,
+            merge_steps,
+            seed_set
+        ],
+        outputs=out
+    )

+# Launch the interface
+if __name__ == "__main__":
+    demo.launch()
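
The core of this change is the standard ZeroGPU usage pattern: load weights on CPU at startup (no CUDA is available in the Space's main process), then attach the GPU only inside a function decorated with `@spaces.GPU`. A minimal sketch of that pattern follows; the `DiffusionPipeline` class, the model id, and the `generate` function are illustrative placeholders, not the ones this Space uses.

```python
# Minimal sketch of the ZeroGPU pattern this commit adopts.
# Assumes a Hugging Face ZeroGPU Space with the `spaces` package installed;
# the model id and function name are illustrative.
import spaces
import torch
from diffusers import DiffusionPipeline

# 1) Load on CPU at import time -- ZeroGPU Spaces expose no CUDA at startup.
pipe = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=torch.float16,
)

@spaces.GPU(duration=120)  # the GPU is attached only while this function runs
def generate(prompt: str):
    pipe.to("cuda")      # 2) move weights to the GPU inside the decorated call
    try:
        return pipe(prompt).images[0]
    finally:
        pipe.to("cpu")   # 3) move back so the GPU can be released between calls
        torch.cuda.empty_cache()
```

The try/finally mirrors the app's `process()` above: cleanup runs even when inference fails, so a crashed request does not pin GPU memory for the next one.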