JackAILab committed on
Commit 40290c7 · verified · 1 Parent(s): f56aad5

Update app.py

Files changed (1)
  1. app.py +313 -185
app.py CHANGED
@@ -11,222 +11,350 @@ from diffusers.utils import load_image
  from diffusers import EulerDiscreteScheduler
  from pipline_StableDiffusionXL_ConsistentID import ConsistentIDStableDiffusionXLPipeline
  from huggingface_hub import hf_hub_download
- ### Model can be imported from https://github.com/zllrunning/face-parsing.PyTorch?tab=readme-ov-file
- ### We use the ckpt of 79999_iter.pth: https://drive.google.com/open?id=154JgKpzCPW82qINcVieuPH3fZ2e0P812
- ### Thanks for the open source of face-parsing model.
  from models.BiSeNet.model import BiSeNet
 
- # zero = torch.Tensor([0]).cuda()
- # print(zero.device) # <-- 'cpu' 🤔
- # device = zero.device # "cuda"
- device = "cuda"
-
- # Gets the absolute path of the current script
- script_directory = os.path.dirname(os.path.realpath(__file__))
-
- # download ConsistentID checkpoint to cache
- base_model_path = "SG161222/RealVisXL_V3.0"
- consistentID_path = hf_hub_download(repo_id="JackAILab/ConsistentID", filename="ConsistentID_SDXL-v1.bin", repo_type="model")
-
- ### Load base model
- pipe = ConsistentIDStableDiffusionXLPipeline.from_pretrained(
-     base_model_path,
-     torch_dtype=torch.float16,
-     safety_checker=None, # use_safetensors=True,
-     variant="fp16"
- ).to(device)
-
- ### Load other pretrained models
- ## BiSenet
- bise_net_cp_path = hf_hub_download(repo_id="JackAILab/ConsistentID", filename="face_parsing.pth", local_dir="./checkpoints")
- bise_net = BiSeNet(n_classes = 19)
- bise_net.load_state_dict(torch.load(bise_net_cp_path, map_location="cpu")) # device fail
- bise_net.cuda()
-
- # import sys
- # sys.path.append("./models/LLaVA1.5/LLaVA/")
- # from llava_infer.model.builder import load_pretrained_model
- # from llava_infer.mm_utils import get_model_name_from_path
- # from llava_infer.eval.run_llava import eval_model
-
- ### Load Llava for prompt enhancement
- # llva_model_path = "liuhaotian/llava-v1.5-7b"
- # llva_tokenizer, llva_model, llva_image_processor, llva_context_len = load_pretrained_model(
- #     model_path=llva_model_path,
- #     model_base=None,
- #     model_name=get_model_name_from_path(llva_model_path),)
- # llva_model.to(device)
-
- ### Load consistentID_model checkpoint
- pipe.load_ConsistentID_model(
-     os.path.dirname(consistentID_path),
-     bise_net,
-     subfolder="",
-     weight_name=os.path.basename(consistentID_path),
-     trigger_word="img",
- )
- pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config)
-
- ### Load to cuda
- pipe.to(device)
- pipe.image_encoder.to(device)
- pipe.image_proj_model.to(device)
- pipe.FacialEncoder.to(device)
-
- @spaces.GPU
- def process(selected_template_images,costum_image,prompt
-     ,negative_prompt,prompt_selected,retouching,model_selected_tab,prompt_selected_tab,width,height,merge_steps,seed_set):
-     if model_selected_tab==0:
-         select_images = load_image(Image.open(selected_template_images))
-     else:
-         select_images = load_image(Image.fromarray(costum_image))
-
-     if prompt_selected_tab==0:
-         prompt = prompt_selected
-         negative_prompt = ""
-         need_safetycheck = False
-     else:
-         need_safetycheck = True
-
-     # hyper-parameter
-     num_steps = 50
-     seed_set = torch.randint(0, 1000, (1,)).item()
-     # merge_steps = 30
-
-     @torch.inference_mode()
-     def Enhance_prompt(prompt,select_images):
-         llva_prompt = f'Please ignore the image. Enhance the following text prompt for me. You can associate more details with the character\'s gesture, environment, and decent clothing:"{prompt}".'
-         args = type('Args', (), {
-             "model_path": llva_model_path,
-             "model_base": None,
-             "model_name": get_model_name_from_path(llva_model_path),
-             "query": llva_prompt,
-             "conv_mode": None,
-             "image_file": select_images,
-             "sep": ",",
-             "temperature": 0,
-             "top_p": None,
-             "num_beams": 1,
-             "max_new_tokens": 512
-         })()
-         Enhanced_prompt = eval_model(args, llva_tokenizer, llva_model, llva_image_processor)
-         return Enhanced_prompt
-
-     if prompt == "":
-         prompt = "A woman, in a forest"
-         prompt = "A man, with backpack, in a raining tropical forest, adventuring, holding a flashlight, in mist, seeking animals"
-         prompt = "A person, in a sowm, wearing santa hat and a scarf, with a cottage behind"
-     else:
-         # prompt=Enhance_prompt(prompt,Image.new('RGB', (200, 200), color = 'white'))
-         # prompt = "cinematic photo," + prompt + ", 50mm photograph, half-length portrait, film, bokeh, professional, 4k, highly detailed"
-         print(prompt)
-
-     if negative_prompt == "":
-         negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality, blurry"
-
-     #Extend Prompt
-     prompt = "cinematic photo," + prompt + ", 50mm photograph, half-length portrait, film, bokeh, professional, 4k, highly detailed"
-
-     negtive_prompt_group="((cross-eye)),((cross-eyed)),(((NFSW))),(nipple),((((ugly)))), (((duplicate))), ((morbid)), ((mutilated)), [out of frame], extra fingers, mutated hands, ((poorly drawn hands)), ((poorly drawn face)), (((mutation))), (((deformed))), ((ugly)), blurry, ((bad anatomy)), (((bad proportions))), ((extra limbs)), cloned face, (((disfigured))). out of frame, ugly, extra limbs, (bad anatomy), gross proportions, (malformed limbs), ((missing arms)), ((missing legs)), (((extra arms))), (((extra legs))), mutated hands, (fused fingers), (too many fingers), (((long neck)))"
-     negative_prompt = negative_prompt + negtive_prompt_group
-
-     # seed = torch.randint(0, 1000, (1,)).item()
-     generator = torch.Generator(device=device).manual_seed(seed_set)
-
-     images = pipe(
-         prompt=prompt,
-         width=width,
-         height=height,
-         input_id_images=select_images,
-         input_image_path=selected_template_images, ### path maybe not right, do not use
-         negative_prompt=negative_prompt,
-         num_images_per_prompt=1,
-         num_inference_steps=num_steps,
-         start_merge_step=merge_steps,
-         generator=generator,
-         retouching=retouching,
-         need_safetycheck=need_safetycheck,
-     ).images[0]
-
-     current_date = datetime.today()
-     return np.array(images)
-
- # Gets the templates
-
  script_directory = os.path.dirname(os.path.realpath(__file__))
  preset_template = glob.glob("./images/templates/*.png")
  preset_template = preset_template + glob.glob("./images/templates/*.jpg")
 
-
  with gr.Blocks(title="ConsistentID_SDXL Demo") as demo:
      gr.Markdown("# ConsistentID_SDXL Demo")
-     gr.Markdown("\
-         Put the reference figure to be redrawn into the box below (There is a small probability of referensing failure. You can submit it repeatedly)")
-     gr.Markdown("\
-         If you find our work interesting, please leave a star in GitHub for us!<br>\
-         https://github.com/JackAILab/ConsistentID")
      with gr.Row():
          with gr.Column():
-             model_selected_tab = gr.State(0)
-             with gr.TabItem("template images") as template_images_tab:
-                 template_gallery_list = [(i, i) for i in preset_template]
-                 gallery = gr.Gallery(template_gallery_list,columns=[4], rows=[2], object_fit="contain", height="auto",show_label=False)
-
-                 def select_function(evt: gr.SelectData):
-                     return preset_template[evt.index]
-
-                 selected_template_images = gr.Text(show_label=False, visible=False, placeholder="Selected")
-                 print(f"=========selected_template_images : {selected_template_images}=============== \r\n ")
-                 gallery.select(select_function, None, selected_template_images)
-             with gr.TabItem("Upload Image") as upload_image_tab:
-                 costum_image = gr.Image(label="Upload Image")
-
-             model_selected_tabs = [template_images_tab, upload_image_tab]
-             for i, tab in enumerate(model_selected_tabs):
-                 tab.select(fn=lambda tabnum=i: tabnum, inputs=[], outputs=[model_selected_tab])
 
          with gr.Column():
-             prompt_selected_tab = gr.State(0)
-             with gr.TabItem("template prompts") as template_prompts_tab:
-                 prompt_selected = gr.Dropdown(value="A person, police officer, half body shot", elem_id='dropdown', choices=[
-                     "A woman in a wedding dress",
-                     "A woman, queen, in a gorgeous palace",
-                     "A man sitting at the beach with sunset",
-                     "A person, police officer, half body shot",
-                     "A man, sailor, in a boat above ocean",
-                     "A women wearing headphone, listening music",
-                     "A man, firefighter, half body shot"], label=f"prepared prompts")
-
-             with gr.TabItem("custom prompt") as custom_prompt_tab:
-                 prompt = gr.Textbox(label="prompt",placeholder="A man/woman wearing a santa hat")
-                 nagetive_prompt = gr.Textbox(label="negative prompt",placeholder="monochrome, lowres, bad anatomy, worst quality, low quality, blurry")
-
-             prompt_selected_tabs = [template_prompts_tab, custom_prompt_tab]
-             for i, tab in enumerate(prompt_selected_tabs):
-                 tab.select(fn=lambda tabnum=i: tabnum, inputs=[], outputs=[prompt_selected_tab])
-
-             retouching = gr.Checkbox(label="face retouching",value=False,visible=False)
-             width = gr.Slider(label="image width",minimum=512,maximum=1280,value=864,step=8)
-             height = gr.Slider(label="image height",minimum=512,maximum=1280,value=1152,step=8)
-             width.release(lambda x,y: min(1280-x,y), inputs=[width,height], outputs=[height])
-             height.release(lambda x,y: min(1280-y,x), inputs=[width,height], outputs=[width])
-             merge_steps = gr.Slider(label="step starting to merge facial details(30 is recommended)",minimum=10,maximum=50,value=30,step=1)
-             seed_set = gr.Slider(label="set the random seed for different results",minimum=1,maximum=2147483647,value=2024,step=1)
          btn = gr.Button("Run")
      with gr.Column():
          out = gr.Image(label="Output")
          gr.Markdown('''
          N.B.:<br/>
-         - If the proportion of face in the image is too small, the probability of an error will be slightly higher, and the similarity will also significantly decrease.)
-         - At the same time, use prompt with \"man\" or \"woman\" instead of \"person\" as much as possible, as that may cause the model to be confused whether the protagonist is male or female.
-         - Due to insufficient graphics memory on the demo server, there is an upper limit on the resolution for generating samples. We will support the generation of SDXL as soon as possible<br/><br/>
          ''')
-     btn.click(fn=process, inputs=[selected_template_images,costum_image,prompt,nagetive_prompt,prompt_selected,retouching
-         ,model_selected_tab,prompt_selected_tab,width,height,merge_steps,seed_set], outputs=out)
 
- demo.launch()
 
  from diffusers import EulerDiscreteScheduler
  from pipline_StableDiffusionXL_ConsistentID import ConsistentIDStableDiffusionXLPipeline
  from huggingface_hub import hf_hub_download
  from models.BiSeNet.model import BiSeNet
 
+ # ====================================================================================
+ # CRITICAL: Global variables for model management with ZeroGPU
+ # Models are loaded on CPU at startup and moved to GPU only during inference
+ # ====================================================================================
+ DEVICE = "cuda"   # Device to use during inference
+ pipe = None       # Will hold the main pipeline
+ bise_net = None   # Will hold the face parsing model
+
+ # ====================================================================================
+ # Model loading function - loads all models on CPU to avoid ZeroGPU startup issues
+ # ====================================================================================
+ def load_models():
+     """
+     Load all models on CPU at startup.
+     This prevents CUDA initialization errors with ZeroGPU.
+     Models will be moved to GPU only during inference.
+     """
+     global pipe, bise_net
+
+     if pipe is not None:
+         return  # Models already loaded
+
+     print("Loading models on CPU...")
+
+     # Download and prepare model paths
+     base_model_path = "SG161222/RealVisXL_V3.0"
+     consistentID_path = hf_hub_download(
+         repo_id="JackAILab/ConsistentID",
+         filename="ConsistentID_SDXL-v1.bin",
+         repo_type="model"
+     )
+
+     # Load main pipeline on CPU with fp16 precision
+     pipe = ConsistentIDStableDiffusionXLPipeline.from_pretrained(
+         base_model_path,
+         torch_dtype=torch.float16,
+         safety_checker=None,
+         variant="fp16"
+     )
+
+     # Load BiSeNet face parsing model
+     bise_net_cp_path = hf_hub_download(
+         repo_id="JackAILab/ConsistentID",
+         filename="face_parsing.pth",
+         local_dir="./checkpoints"
+     )
+     bise_net = BiSeNet(n_classes=19)
+     bise_net.load_state_dict(torch.load(bise_net_cp_path, map_location="cpu"))
+
+     # Load ConsistentID model components
+     pipe.load_ConsistentID_model(
+         os.path.dirname(consistentID_path),
+         bise_net,
+         subfolder="",
+         weight_name=os.path.basename(consistentID_path),
+         trigger_word="img",
+     )
+     pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config)
+
+     print("Successfully loaded all models on CPU")
+
+ # Initialize models at startup
+ load_models()
+
+ # ====================================================================================
+ # Main inference function with ZeroGPU decorator
+ # ====================================================================================
+ @spaces.GPU(duration=120)  # Request GPU for 120 seconds
+ def process(selected_template_images, costum_image, prompt,
+             negative_prompt, prompt_selected, retouching, model_selected_tab,
+             prompt_selected_tab, width, height, merge_steps, seed_set):
+     """
+     Main inference function that generates images using ConsistentID.
+     Models are moved to GPU at the start and back to CPU at the end.
+
+     Args:
+         selected_template_images: Path to template image
+         costum_image: User uploaded image
+         prompt: Text prompt for generation
+         negative_prompt: Negative prompt
+         prompt_selected: Selected template prompt
+         retouching: Whether to apply face retouching
+         model_selected_tab: Which image source tab is selected
+         prompt_selected_tab: Which prompt tab is selected
+         width: Output image width
+         height: Output image height
+         merge_steps: Step to start merging facial details
+         seed_set: Random seed for generation
+
+     Returns:
+         numpy.ndarray: Generated image
+     """
+     global pipe, bise_net
+
+     print(f"Starting inference, moving models to {DEVICE}")
+
+     # Move all model components to GPU
+     pipe.to(DEVICE)
+     pipe.image_encoder.to(DEVICE)
+     pipe.image_proj_model.to(DEVICE)
+     pipe.FacialEncoder.to(DEVICE)
+     bise_net.to(DEVICE)
+
+     try:
+         # Process input image based on selected tab
+         if model_selected_tab == 0:
+             select_images = load_image(Image.open(selected_template_images))
+         else:
+             select_images = load_image(Image.fromarray(costum_image))
+
+         # Process prompt based on selected tab
+         if prompt_selected_tab == 0:
+             prompt = prompt_selected
+             negative_prompt = ""
+             need_safetycheck = False
+         else:
+             need_safetycheck = True
+
+         # Generation parameters
+         num_steps = 50
+
+         # Default prompt if empty
+         if prompt == "":
+             prompt = "A person, in a forest"
+
+         # Default negative prompt if empty
+         if negative_prompt == "":
+             negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality, blurry"
+
+         # Extend prompt with quality tags
+         prompt = "cinematic photo," + prompt + ", 50mm photograph, half-length portrait, film, bokeh, professional, 4k, highly detailed"
+
+         # Add negative prompt group
+         negtive_prompt_group = "((cross-eye)),((cross-eyed)),(((NFSW))),(nipple),((((ugly)))), (((duplicate))), ((morbid)), ((mutilated)), [out of frame], extra fingers, mutated hands, ((poorly drawn hands)), ((poorly drawn face)), (((mutation))), (((deformed))), ((ugly)), blurry, ((bad anatomy)), (((bad proportions))), ((extra limbs)), cloned face, (((disfigured))). out of frame, ugly, extra limbs, (bad anatomy), gross proportions, (malformed limbs), ((missing arms)), ((missing legs)), (((extra arms))), (((extra legs))), mutated hands, (fused fingers), (too many fingers), (((long neck)))"
+         negative_prompt = negative_prompt + negtive_prompt_group
+
+         # Create generator with seed
+         generator = torch.Generator(device=DEVICE).manual_seed(seed_set)
+
+         print("Generating image...")
+
+         # Run the pipeline
+         images = pipe(
+             prompt=prompt,
+             width=width,
+             height=height,
+             input_id_images=select_images,
+             input_image_path=selected_template_images,
+             negative_prompt=negative_prompt,
+             num_images_per_prompt=1,
+             num_inference_steps=num_steps,
+             start_merge_step=merge_steps,
+             generator=generator,
+             retouching=retouching,
+             need_safetycheck=need_safetycheck,
+         ).images[0]
+
+         print("Image generated successfully")
+         return np.array(images)
+
+     except Exception as e:
+         print(f"Error during inference: {e}")
+         raise
+
+     finally:
+         # Always move models back to CPU to free GPU memory
+         print("Cleaning up GPU memory")
+         pipe.to("cpu")
+         pipe.image_encoder.to("cpu")
+         pipe.image_proj_model.to("cpu")
+         pipe.FacialEncoder.to("cpu")
+         bise_net.to("cpu")
+
+         # Clear CUDA cache
+         if torch.cuda.is_available():
+             torch.cuda.empty_cache()
+
+ # ====================================================================================
+ # Gradio Interface
+ # ====================================================================================
+
+ # Get template images
  script_directory = os.path.dirname(os.path.realpath(__file__))
  preset_template = glob.glob("./images/templates/*.png")
  preset_template = preset_template + glob.glob("./images/templates/*.jpg")
 
+ # Build Gradio interface
  with gr.Blocks(title="ConsistentID_SDXL Demo") as demo:
      gr.Markdown("# ConsistentID_SDXL Demo")
+     gr.Markdown(
+         "Put the reference figure to be redrawn into the box below "
+         "(There is a small probability of referencing failure. You can submit it repeatedly)"
+     )
+     gr.Markdown(
+         "If you find our work interesting, please leave a star in GitHub for us!<br>"
+         "https://github.com/JackAILab/ConsistentID"
+     )
+
      with gr.Row():
          with gr.Column():
+             # Hidden state for tracking which image source tab is selected
+             model_selected_tab = gr.Number(value=0, visible=False)
+
+             # Image source tabs
+             with gr.Tabs() as image_tabs:
+                 with gr.Tab("template images") as template_images_tab:
+                     template_gallery_list = [(i, i) for i in preset_template]
+                     gallery = gr.Gallery(
+                         template_gallery_list,
+                         columns=4,
+                         rows=2,
+                         object_fit="contain",
+                         height="auto",
+                         show_label=False
+                     )
+
+                     def select_function(evt: gr.SelectData):
+                         return preset_template[evt.index]
+
+                     selected_template_images = gr.Textbox(
+                         show_label=False,
+                         visible=False,
+                         placeholder="Selected"
+                     )
+                     gallery.select(select_function, None, selected_template_images)
+
+                 with gr.Tab("Upload Image") as upload_image_tab:
+                     costum_image = gr.Image(label="Upload Image")
+
+             # Update model_selected_tab when tab changes
+             template_images_tab.select(fn=lambda: 0, inputs=[], outputs=[model_selected_tab])
+             upload_image_tab.select(fn=lambda: 1, inputs=[], outputs=[model_selected_tab])
+
+         # Prompt section
          with gr.Column():
+             # Hidden state for tracking which prompt tab is selected
+             prompt_selected_tab = gr.Number(value=0, visible=False)
+
+             # Prompt tabs
+             with gr.Tabs() as prompt_tabs:
+                 with gr.Tab("template prompts") as template_prompts_tab:
+                     prompt_selected = gr.Dropdown(
+                         value="A person, police officer, half body shot",
+                         choices=[
+                             "A woman in a wedding dress",
+                             "A woman, queen, in a gorgeous palace",
+                             "A man sitting at the beach with sunset",
+                             "A person, police officer, half body shot",
+                             "A man, sailor, in a boat above ocean",
+                             "A women wearing headphone, listening music",
+                             "A man, firefighter, half body shot"
+                         ],
+                         label="prepared prompts"
+                     )
+
+                 with gr.Tab("custom prompt") as custom_prompt_tab:
+                     prompt = gr.Textbox(
+                         label="prompt",
+                         placeholder="A man/woman wearing a santa hat"
+                     )
+                     nagetive_prompt = gr.Textbox(
+                         label="negative prompt",
+                         placeholder="monochrome, lowres, bad anatomy, worst quality, low quality, blurry"
+                     )
+
+             # Update prompt_selected_tab when tab changes
+             template_prompts_tab.select(fn=lambda: 0, inputs=[], outputs=[prompt_selected_tab])
+             custom_prompt_tab.select(fn=lambda: 1, inputs=[], outputs=[prompt_selected_tab])
+
+             # Generation parameters
+             retouching = gr.Checkbox(label="face retouching", value=False, visible=False)
+
+             width = gr.Slider(
+                 label="image width",
+                 minimum=512,
+                 maximum=1280,
+                 value=864,
+                 step=8
+             )
+
+             height = gr.Slider(
+                 label="image height",
+                 minimum=512,
+                 maximum=1280,
+                 value=1152,
+                 step=8
+             )
+
+             # Ensure width + height doesn't exceed 1280
+             width.release(lambda x, y: min(1280-x, y), inputs=[width, height], outputs=[height])
+             height.release(lambda x, y: min(1280-y, x), inputs=[width, height], outputs=[width])
+
+             merge_steps = gr.Slider(
+                 label="step starting to merge facial details (30 is recommended)",
+                 minimum=10,
+                 maximum=50,
+                 value=30,
+                 step=1
+             )
+
+             seed_set = gr.Slider(
+                 label="set the random seed for different results",
+                 minimum=1,
+                 maximum=2147483647,
+                 value=2024,
+                 step=1
+             )
 
          btn = gr.Button("Run")
+
      with gr.Column():
          out = gr.Image(label="Output")
          gr.Markdown('''
          N.B.:<br/>
+         - If the proportion of face in the image is too small, the probability of an error will be slightly higher, and the similarity will also significantly decrease.<br/>
+         - At the same time, use a prompt with "man" or "woman" instead of "person" as much as possible, as "person" may confuse the model about whether the protagonist is male or female.<br/>
+         - Due to ZeroGPU limitations, generation may take 1-2 minutes. Please be patient.<br/>
          ''')
+
+     # Connect the button to the processing function
+     btn.click(
+         fn=process,
+         inputs=[
+             selected_template_images,
+             costum_image,
+             prompt,
+             nagetive_prompt,
+             prompt_selected,
+             retouching,
+             model_selected_tab,
+             prompt_selected_tab,
+             width,
+             height,
+             merge_steps,
+             seed_set
+         ],
+         outputs=out
+     )
 
+ # Launch the interface
+ if __name__ == "__main__":
+     demo.launch()
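
For readers adapting other Spaces, the core of this commit is the ZeroGPU pattern: load weights on CPU at import time, then attach the GPU only inside a @spaces.GPU-decorated function and hand it back when the call ends. A minimal sketch of that pattern, assuming the `spaces` package injected into Hugging Face ZeroGPU Spaces and a generic `diffusers` pipeline; the model id and the `run` function below are illustrative, not part of this commit:

    import torch
    import spaces
    from diffusers import DiffusionPipeline

    # CUDA is unavailable at import time on ZeroGPU, so load on CPU.
    pipe = DiffusionPipeline.from_pretrained(
        "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
    )

    @spaces.GPU(duration=120)  # GPU is attached only while this call runs
    def run(prompt: str):
        pipe.to("cuda")  # move weights onto the freshly attached GPU
        try:
            return pipe(prompt).images[0]
        finally:
            # Hand the GPU back clean for the next queued request.
            pipe.to("cpu")
            torch.cuda.empty_cache()

The try/finally in the commit's process() serves exactly this purpose: even if inference raises, the weights return to CPU and the cache is cleared before the GPU is released.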
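The commit also swaps the tab bookkeeping: instead of gr.State(0) fed by a loop of tab.select(fn=lambda tabnum=i: tabnum, ...) closures, the new code keeps a hidden gr.Number and wires one explicit handler per tab. A minimal sketch of that pattern, with illustrative component names and assuming a recent Gradio release (the exact version pinned by the Space is not shown in this diff):

    import gradio as gr

    with gr.Blocks() as demo:
        # Hidden numeric component recording which tab is active.
        active_tab = gr.Number(value=0, visible=False)

        with gr.Tabs():
            with gr.Tab("first") as first_tab:
                gr.Markdown("content of the first tab")
            with gr.Tab("second") as second_tab:
                gr.Markdown("content of the second tab")

        # One explicit handler per tab; each writes its own index.
        first_tab.select(fn=lambda: 0, inputs=[], outputs=[active_tab])
        second_tab.select(fn=lambda: 1, inputs=[], outputs=[active_tab])

    demo.launch()

The explicit per-tab handlers avoid the classic late-binding pitfall of loop-created lambdas, which the old code had to work around with the tabnum=i default argument.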