JackAILab commited on
Commit
d8e6a5c
·
verified ·
1 Parent(s): 40290c7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +287 -206
app.py CHANGED
@@ -14,30 +14,21 @@ from huggingface_hub import hf_hub_download
14
  from models.BiSeNet.model import BiSeNet
15
 
16
  # ====================================================================================
17
- # CRITICAL: Global variables for model management with ZeroGPU
18
- # Models are loaded on CPU at startup and moved to GPU only during inference
19
  # ====================================================================================
20
- DEVICE = "cuda" # Device to use during inference
21
- pipe = None # Will hold the main pipeline
22
- bise_net = None # Will hold the face parsing model
23
 
24
- # ====================================================================================
25
- # Model loading function - loads all models on CPU to avoid ZeroGPU startup issues
26
- # ====================================================================================
27
  def load_models():
28
- """
29
- Load all models on CPU at startup.
30
- This prevents CUDA initialization errors with ZeroGPU.
31
- Models will be moved to GPU only during inference.
32
- """
33
  global pipe, bise_net
34
 
35
  if pipe is not None:
36
- return # Models already loaded
37
 
38
- print("Loading models on CPU...")
39
 
40
- # Download and prepare model paths
41
  base_model_path = "SG161222/RealVisXL_V3.0"
42
  consistentID_path = hf_hub_download(
43
  repo_id="JackAILab/ConsistentID",
@@ -45,7 +36,7 @@ def load_models():
45
  repo_type="model"
46
  )
47
 
48
- # Load main pipeline on CPU with fp16 precision
49
  pipe = ConsistentIDStableDiffusionXLPipeline.from_pretrained(
50
  base_model_path,
51
  torch_dtype=torch.float16,
@@ -53,7 +44,7 @@ def load_models():
53
  variant="fp16"
54
  )
55
 
56
- # Load BiSeNet face parsing model
57
  bise_net_cp_path = hf_hub_download(
58
  repo_id="JackAILab/ConsistentID",
59
  filename="face_parsing.pth",
@@ -62,7 +53,7 @@ def load_models():
62
  bise_net = BiSeNet(n_classes=19)
63
  bise_net.load_state_dict(torch.load(bise_net_cp_path, map_location="cpu"))
64
 
65
- # Load ConsistentID model components
66
  pipe.load_ConsistentID_model(
67
  os.path.dirname(consistentID_path),
68
  bise_net,
@@ -72,44 +63,36 @@ def load_models():
72
  )
73
  pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config)
74
 
75
- print("Successfully loaded all models on CPU")
76
 
77
- # Initialize models at startup
78
  load_models()
79
 
80
  # ====================================================================================
81
- # Main inference function with ZeroGPU decorator
82
  # ====================================================================================
83
- @spaces.GPU(duration=120) # Request GPU for 120 seconds
84
- def process(selected_template_images, costum_image, prompt,
85
- negative_prompt, prompt_selected, retouching, model_selected_tab,
86
- prompt_selected_tab, width, height, merge_steps, seed_set):
 
 
 
 
 
 
 
 
 
 
 
87
  """
88
- Main inference function that generates images using ConsistentID.
89
- Models are moved to GPU at the start and back to CPU at the end.
90
-
91
- Args:
92
- selected_template_images: Path to template image
93
- costum_image: User uploaded image
94
- prompt: Text prompt for generation
95
- negative_prompt: Negative prompt
96
- prompt_selected: Selected template prompt
97
- retouching: Whether to apply face retouching
98
- model_selected_tab: Which image source tab is selected
99
- prompt_selected_tab: Which prompt tab is selected
100
- width: Output image width
101
- height: Output image height
102
- merge_steps: Step to start merging facial details
103
- seed_set: Random seed for generation
104
-
105
- Returns:
106
- numpy.ndarray: Generated image
107
  """
108
  global pipe, bise_net
109
 
110
- print(f"Starting inference, moving models to {DEVICE}")
111
 
112
- # Move all model components to GPU
113
  pipe.to(DEVICE)
114
  pipe.image_encoder.to(DEVICE)
115
  pipe.image_proj_model.to(DEVICE)
@@ -117,13 +100,13 @@ def process(selected_template_images, costum_image, prompt,
117
  bise_net.to(DEVICE)
118
 
119
  try:
120
- # Process input image based on selected tab
121
  if model_selected_tab == 0:
122
- select_images = load_image(Image.open(selected_template_images))
123
  else:
124
- select_images = load_image(Image.fromarray(costum_image))
125
 
126
- # Process prompt based on selected tab
127
  if prompt_selected_tab == 0:
128
  prompt = prompt_selected
129
  negative_prompt = ""
@@ -131,230 +114,328 @@ def process(selected_template_images, costum_image, prompt,
131
  else:
132
  need_safetycheck = True
133
 
134
- # Generation parameters
135
- num_steps = 50
 
136
 
137
- # Default prompt if empty
138
- if prompt == "":
139
- prompt = "A person, in a forest"
140
-
141
- # Default negative prompt if empty
142
- if negative_prompt == "":
143
  negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality, blurry"
144
 
145
- # Extend prompt with quality tags
146
- prompt = "cinematic photo," + prompt + ", 50mm photograph, half-length portrait, film, bokeh, professional, 4k, highly detailed"
147
 
148
- # Add negative prompt group
149
- negtive_prompt_group = "((cross-eye)),((cross-eyed)),(((NFSW))),(nipple),((((ugly)))), (((duplicate))), ((morbid)), ((mutilated)), [out of frame], extra fingers, mutated hands, ((poorly drawn hands)), ((poorly drawn face)), (((mutation))), (((deformed))), ((ugly)), blurry, ((bad anatomy)), (((bad proportions))), ((extra limbs)), cloned face, (((disfigured))). out of frame, ugly, extra limbs, (bad anatomy), gross proportions, (malformed limbs), ((missing arms)), ((missing legs)), (((extra arms))), (((extra legs))), mutated hands, (fused fingers), (too many fingers), (((long neck)))"
150
- negative_prompt = negative_prompt + negtive_prompt_group
151
 
152
- # Create generator with seed
153
- generator = torch.Generator(device=DEVICE).manual_seed(seed_set)
154
 
155
- print("Generating image...")
156
 
157
- # Run the pipeline
158
  images = pipe(
159
- prompt=prompt,
160
  width=width,
161
  height=height,
162
- input_id_images=select_images,
163
- input_image_path=selected_template_images,
164
- negative_prompt=negative_prompt,
165
  num_images_per_prompt=1,
166
  num_inference_steps=num_steps,
167
  start_merge_step=merge_steps,
168
  generator=generator,
169
- retouching=retouching,
170
  need_safetycheck=need_safetycheck,
171
  ).images[0]
172
 
173
- print("Image generated successfully")
174
  return np.array(images)
175
 
176
  except Exception as e:
177
- print(f"Error during inference: {e}")
178
  raise
179
 
180
  finally:
181
- # Always move models back to CPU to free GPU memory
182
- print("Cleaning up GPU memory")
183
  pipe.to("cpu")
184
  pipe.image_encoder.to("cpu")
185
  pipe.image_proj_model.to("cpu")
186
  pipe.FacialEncoder.to("cpu")
187
  bise_net.to("cpu")
188
 
189
- # Clear CUDA cache
190
  if torch.cuda.is_available():
191
  torch.cuda.empty_cache()
192
 
193
  # ====================================================================================
194
- # Gradio Interface
195
  # ====================================================================================
196
 
197
  # Get template images
198
- script_directory = os.path.dirname(os.path.realpath(__file__))
199
- preset_template = glob.glob("./images/templates/*.png")
200
- preset_template = preset_template + glob.glob("./images/templates/*.jpg")
201
 
202
- # Build Gradio interface
203
- with gr.Blocks(title="ConsistentID_SDXL Demo") as demo:
204
- gr.Markdown("# ConsistentID_SDXL Demo")
205
- gr.Markdown(
206
- "Put the reference figure to be redrawn into the box below "
207
- "(There is a small probability of referencing failure. You can submit it repeatedly)"
208
- )
209
- gr.Markdown(
210
- "If you find our work interesting, please leave a star in GitHub for us!<br>"
211
- "https://github.com/JackAILab/ConsistentID"
212
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
213
 
214
  with gr.Row():
215
- with gr.Column():
216
- # Hidden state for tracking which image source tab is selected
 
 
217
  model_selected_tab = gr.Number(value=0, visible=False)
218
 
219
- # Image source tabs
220
  with gr.Tabs() as image_tabs:
221
- with gr.Tab("template images") as template_images_tab:
222
- template_gallery_list = [(i, i) for i in preset_template]
223
- gallery = gr.Gallery(
224
- template_gallery_list,
225
  columns=4,
226
  rows=2,
227
- object_fit="contain",
228
- height="auto",
229
- show_label=False
230
- )
231
-
232
- def select_function(evt: gr.SelectData):
233
- return preset_template[evt.index]
234
-
235
- selected_template_images = gr.Textbox(
236
  show_label=False,
237
- visible=False,
238
- placeholder="Selected"
239
  )
240
- gallery.select(select_function, None, selected_template_images)
241
 
242
- with gr.Tab("Upload Image") as upload_image_tab:
243
- costum_image = gr.Image(label="Upload Image")
244
-
245
- # Update model_selected_tab when tab changes
246
- def update_image_tab(tab_index):
247
- return tab_index
 
 
 
 
 
 
 
248
 
249
- template_images_tab.select(fn=lambda: 0, inputs=[], outputs=[model_selected_tab])
250
- upload_image_tab.select(fn=lambda: 1, inputs=[], outputs=[model_selected_tab])
251
-
252
- # Prompt section
253
- with gr.Column():
254
- # Hidden state for tracking which prompt tab is selected
255
- prompt_selected_tab = gr.Number(value=0, visible=False)
256
-
257
- # Prompt tabs
258
- with gr.Tabs() as prompt_tabs:
259
- with gr.Tab("template prompts") as template_prompts_tab:
260
- prompt_selected = gr.Dropdown(
261
- value="A person, police officer, half body shot",
262
- choices=[
263
- "A woman in a wedding dress",
264
- "A woman, queen, in a gorgeous palace",
265
- "A man sitting at the beach with sunset",
266
- "A person, police officer, half body shot",
267
- "A man, sailor, in a boat above ocean",
268
- "A women wearing headphone, listening music",
269
- "A man, firefighter, half body shot"
270
- ],
271
- label="prepared prompts"
272
- )
273
-
274
- with gr.Tab("custom prompt") as custom_prompt_tab:
275
- prompt = gr.Textbox(
276
- label="prompt",
277
- placeholder="A man/woman wearing a santa hat"
278
- )
279
- nagetive_prompt = gr.Textbox(
280
- label="negative prompt",
281
- placeholder="monochrome, lowres, bad anatomy, worst quality, low quality, blurry"
282
- )
283
-
284
- # Update prompt_selected_tab when tab changes
285
- template_prompts_tab.select(fn=lambda: 0, inputs=[], outputs=[prompt_selected_tab])
286
- custom_prompt_tab.select(fn=lambda: 1, inputs=[], outputs=[prompt_selected_tab])
287
 
288
- # Generation parameters
289
- retouching = gr.Checkbox(label="face retouching", value=False, visible=False)
290
 
291
- width = gr.Slider(
292
- label="image width",
293
- minimum=512,
294
- maximum=1280,
295
- value=864,
296
- step=8
297
- )
298
 
299
- height = gr.Slider(
300
- label="image height",
301
- minimum=512,
302
- maximum=1280,
303
- value=1152,
304
- step=8
305
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
306
 
307
- # Ensure width + height doesn't exceed 1280
308
- width.release(lambda x, y: min(1280-x, y), inputs=[width, height], outputs=[height])
309
- height.release(lambda x, y: min(1280-y, x), inputs=[width, height], outputs=[width])
310
 
311
- merge_steps = gr.Slider(
312
- label="step starting to merge facial details (30 is recommended)",
313
- minimum=10,
314
- maximum=50,
315
- value=30,
316
- step=1
317
- )
 
 
 
 
 
 
 
 
 
 
318
 
319
- seed_set = gr.Slider(
320
- label="set the random seed for different results",
321
- minimum=1,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
322
  maximum=2147483647,
323
- value=2024,
324
  step=1
325
  )
326
 
327
- btn = gr.Button("Run")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
328
 
329
- with gr.Column():
330
- out = gr.Image(label="Output")
331
- gr.Markdown('''
332
- N.B.:<br/>
333
- - If the proportion of face in the image is too small, the probability of an error will be slightly higher, and the similarity will also significantly decrease.<br/>
334
- - At the same time, use prompt with "man" or "woman" instead of "person" as much as possible, as that may cause the model to be confused whether the protagonist is male or female.<br/>
335
- - Due to ZeroGPU limitations, generation may take 1-2 minutes. Please be patient.<br/>
336
- ''')
337
 
338
- # Connect the button to the processing function
339
- btn.click(
340
- fn=process,
341
  inputs=[
342
- selected_template_images,
343
- costum_image,
344
- prompt,
345
- nagetive_prompt,
346
  prompt_selected,
347
- retouching,
348
  model_selected_tab,
349
  prompt_selected_tab,
350
  width,
351
  height,
352
  merge_steps,
353
- seed_set
 
354
  ],
355
- outputs=out
356
  )
357
 
358
- # Launch the interface
359
  if __name__ == "__main__":
 
360
  demo.launch()
 
14
  from models.BiSeNet.model import BiSeNet
15
 
16
  # ====================================================================================
17
+ # Global model management for ZeroGPU compatibility
 
18
  # ====================================================================================
19
+ DEVICE = "cuda"
20
+ pipe = None
21
+ bise_net = None
22
 
 
 
 
23
  def load_models():
24
+ """Load all models on CPU to avoid ZeroGPU initialization issues"""
 
 
 
 
25
  global pipe, bise_net
26
 
27
  if pipe is not None:
28
+ return
29
 
30
+ print("Loading models on CPU...")
31
 
 
32
  base_model_path = "SG161222/RealVisXL_V3.0"
33
  consistentID_path = hf_hub_download(
34
  repo_id="JackAILab/ConsistentID",
 
36
  repo_type="model"
37
  )
38
 
39
+ # Load pipeline on CPU
40
  pipe = ConsistentIDStableDiffusionXLPipeline.from_pretrained(
41
  base_model_path,
42
  torch_dtype=torch.float16,
 
44
  variant="fp16"
45
  )
46
 
47
+ # Load BiSeNet
48
  bise_net_cp_path = hf_hub_download(
49
  repo_id="JackAILab/ConsistentID",
50
  filename="face_parsing.pth",
 
53
  bise_net = BiSeNet(n_classes=19)
54
  bise_net.load_state_dict(torch.load(bise_net_cp_path, map_location="cpu"))
55
 
56
+ # Load ConsistentID components
57
  pipe.load_ConsistentID_model(
58
  os.path.dirname(consistentID_path),
59
  bise_net,
 
63
  )
64
  pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config)
65
 
66
+ print(" Models loaded successfully")
67
 
 
68
  load_models()
69
 
70
  # ====================================================================================
71
+ # Inference function with GPU management
72
  # ====================================================================================
73
+ @spaces.GPU(duration=180) # Extended duration for SDXL
74
+ def generate_image(
75
+ selected_template_images,
76
+ custom_image,
77
+ prompt,
78
+ negative_prompt,
79
+ prompt_selected,
80
+ model_selected_tab,
81
+ prompt_selected_tab,
82
+ width,
83
+ height,
84
+ merge_steps,
85
+ seed,
86
+ num_steps
87
+ ):
88
  """
89
+ Generate image using ConsistentID-SDXL
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
90
  """
91
  global pipe, bise_net
92
 
93
+ print("🚀 Moving models to GPU...")
94
 
95
+ # Move to GPU
96
  pipe.to(DEVICE)
97
  pipe.image_encoder.to(DEVICE)
98
  pipe.image_proj_model.to(DEVICE)
 
100
  bise_net.to(DEVICE)
101
 
102
  try:
103
+ # Select input image
104
  if model_selected_tab == 0:
105
+ input_image = load_image(Image.open(selected_template_images))
106
  else:
107
+ input_image = load_image(Image.fromarray(custom_image))
108
 
109
+ # Select prompt
110
  if prompt_selected_tab == 0:
111
  prompt = prompt_selected
112
  negative_prompt = ""
 
114
  else:
115
  need_safetycheck = True
116
 
117
+ # Default prompts
118
+ if not prompt or prompt.strip() == "":
119
+ prompt = "A person, professional portrait"
120
 
121
+ if not negative_prompt or negative_prompt.strip() == "":
 
 
 
 
 
122
  negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality, blurry"
123
 
124
+ # Enhance prompt
125
+ enhanced_prompt = f"cinematic photo, {prompt}, 50mm photograph, half-length portrait, film, bokeh, professional, 4k, highly detailed"
126
 
127
+ # Negative prompt enhancement
128
+ negative_enhancement = "((cross-eye)), ((cross-eyed)), (((NSFW))), (nipple), ((((ugly)))), (((duplicate))), ((morbid)), ((mutilated)), [out of frame], extra fingers, mutated hands, ((poorly drawn hands)), ((poorly drawn face)), (((mutation))), (((deformed))), ((ugly)), blurry, ((bad anatomy)), (((bad proportions))), ((extra limbs)), cloned face, (((disfigured))), out of frame, ugly, extra limbs, (bad anatomy), gross proportions, (malformed limbs), ((missing arms)), ((missing legs)), (((extra arms))), (((extra legs))), mutated hands, (fused fingers), (too many fingers), (((long neck)))"
129
+ final_negative_prompt = negative_prompt + ", " + negative_enhancement
130
 
131
+ generator = torch.Generator(device=DEVICE).manual_seed(seed)
 
132
 
133
+ print(f"🎨 Generating with prompt: {enhanced_prompt[:100]}...")
134
 
 
135
  images = pipe(
136
+ prompt=enhanced_prompt,
137
  width=width,
138
  height=height,
139
+ input_id_images=input_image,
140
+ input_image_path=selected_template_images if model_selected_tab == 0 else None,
141
+ negative_prompt=final_negative_prompt,
142
  num_images_per_prompt=1,
143
  num_inference_steps=num_steps,
144
  start_merge_step=merge_steps,
145
  generator=generator,
146
+ retouching=False,
147
  need_safetycheck=need_safetycheck,
148
  ).images[0]
149
 
150
+ print(" Generation completed")
151
  return np.array(images)
152
 
153
  except Exception as e:
154
+ print(f"Error: {str(e)}")
155
  raise
156
 
157
  finally:
158
+ # Clean up GPU
159
+ print("🧹 Releasing GPU memory...")
160
  pipe.to("cpu")
161
  pipe.image_encoder.to("cpu")
162
  pipe.image_proj_model.to("cpu")
163
  pipe.FacialEncoder.to("cpu")
164
  bise_net.to("cpu")
165
 
 
166
  if torch.cuda.is_available():
167
  torch.cuda.empty_cache()
168
 
169
  # ====================================================================================
170
+ # Beautiful Gradio Interface
171
  # ====================================================================================
172
 
173
  # Get template images
174
+ preset_templates = glob.glob("./images/templates/*.png") + glob.glob("./images/templates/*.jpg")
 
 
175
 
176
+ # Custom CSS for beautiful interface
177
+ custom_css = """
178
+ .gradio-container {
179
+ font-family: 'IBM Plex Sans', sans-serif;
180
+ }
181
+
182
+ .main-title {
183
+ text-align: center;
184
+ font-size: 2.5em;
185
+ font-weight: 700;
186
+ background: linear-gradient(45deg, #667eea 0%, #764ba2 100%);
187
+ -webkit-background-clip: text;
188
+ -webkit-text-fill-color: transparent;
189
+ margin-bottom: 1em;
190
+ }
191
+
192
+ .subtitle {
193
+ text-align: center;
194
+ font-size: 1.1em;
195
+ color: #666;
196
+ margin-bottom: 2em;
197
+ }
198
+
199
+ .section-header {
200
+ font-size: 1.3em;
201
+ font-weight: 600;
202
+ margin: 1em 0 0.5em 0;
203
+ color: #333;
204
+ }
205
+
206
+ .info-box {
207
+ background: #f8f9fa;
208
+ border-left: 4px solid #667eea;
209
+ padding: 1em;
210
+ margin: 1em 0;
211
+ border-radius: 4px;
212
+ }
213
+
214
+ .generate-btn {
215
+ background: linear-gradient(45deg, #667eea 0%, #764ba2 100%) !important;
216
+ border: none !important;
217
+ color: white !important;
218
+ font-size: 1.1em !important;
219
+ font-weight: 600 !important;
220
+ padding: 0.8em 2em !important;
221
+ border-radius: 8px !important;
222
+ }
223
+
224
+ .gallery-item {
225
+ border-radius: 8px;
226
+ overflow: hidden;
227
+ }
228
+ """
229
+
230
+ # Template prompts with better organization
231
+ template_prompts = [
232
+ ("👰 Wedding", "A woman in an elegant wedding dress, professional photography"),
233
+ ("👑 Royalty", "A person as royalty, sitting on throne in gorgeous palace, regal attire"),
234
+ ("🏖️ Beach", "A person sitting at the beach with beautiful sunset, relaxed atmosphere"),
235
+ ("👮 Officer", "A person as police officer, professional uniform, half body shot"),
236
+ ("⛵ Sailor", "A person as sailor, on boat deck above ocean, nautical uniform"),
237
+ ("🎧 Music", "A person wearing headphones, listening to music, modern setting"),
238
+ ("🚒 Firefighter", "A person as firefighter, professional gear, half body shot"),
239
+ ("💼 Business", "A person in business attire, professional corporate environment"),
240
+ ("🎨 Artist", "A person as artist in studio, creative atmosphere, artistic clothing"),
241
+ ("🔬 Scientist", "A person as scientist in laboratory, lab coat, professional setting"),
242
+ ]
243
+
244
+ with gr.Blocks(css=custom_css, theme=gr.themes.Soft(), title="ConsistentID-SDXL") as demo:
245
+
246
+ # Header
247
+ gr.HTML("""
248
+ <div class="main-title">✨ ConsistentID-SDXL Demo ✨</div>
249
+ <div class="subtitle">
250
+ High-fidelity portrait generation with consistent identity preservation
251
+ </div>
252
+ """)
253
+
254
+ gr.Markdown("""
255
+ <div style='text-align: center; margin-bottom: 2em;'>
256
+ <a href='https://github.com/JackAILab/ConsistentID' target='_blank' style='text-decoration: none;'>
257
+ ⭐ Star us on GitHub
258
+ </a> |
259
+ <a href='https://arxiv.org/abs/2404.16771' target='_blank' style='text-decoration: none;'>
260
+ 📄 Read the Paper
261
+ </a>
262
+ </div>
263
+ """)
264
 
265
  with gr.Row():
266
+ # Left column - Inputs
267
+ with gr.Column(scale=1):
268
+ gr.HTML("<div class='section-header'>📸 Input Image</div>")
269
+
270
  model_selected_tab = gr.Number(value=0, visible=False)
271
 
 
272
  with gr.Tabs() as image_tabs:
273
+ with gr.Tab("🖼️ Templates") as template_tab:
274
+ template_gallery = gr.Gallery(
275
+ value=[(img, img) for img in preset_templates],
 
276
  columns=4,
277
  rows=2,
278
+ height=300,
279
+ object_fit="cover",
 
 
 
 
 
 
 
280
  show_label=False,
281
+ elem_classes="gallery-item"
 
282
  )
 
283
 
284
+ selected_template = gr.Textbox(visible=False)
285
+
286
+ def select_template(evt: gr.SelectData):
287
+ return preset_templates[evt.index]
288
+
289
+ template_gallery.select(select_template, None, selected_template)
290
+
291
+ with gr.Tab("📤 Upload") as upload_tab:
292
+ custom_image = gr.Image(
293
+ label="Upload your image",
294
+ type="numpy",
295
+ height=300
296
+ )
297
 
298
+ template_tab.select(fn=lambda: 0, inputs=[], outputs=[model_selected_tab])
299
+ upload_tab.select(fn=lambda: 1, inputs=[], outputs=[model_selected_tab])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
300
 
301
+ gr.HTML("<div class='section-header'>✍️ Prompt</div>")
 
302
 
303
+ prompt_selected_tab = gr.Number(value=0, visible=False)
 
 
 
 
 
 
304
 
305
+ with gr.Tabs() as prompt_tabs:
306
+ with gr.Tab("📋 Templates") as template_prompt_tab:
307
+ prompt_dropdown = gr.Dropdown(
308
+ choices=[f"{icon} {name}" for icon, name in template_prompts],
309
+ value="👮 Officer",
310
+ label="Choose a style",
311
+ scale=1
312
+ )
313
+
314
+ # Hidden textbox to store actual prompt
315
+ prompt_mapping = {f"{icon} {name}": prompt for (icon, name), (_, prompt) in zip([(icon, name) for icon, name in template_prompts], template_prompts)}
316
+ prompt_selected = gr.Textbox(value=template_prompts[3][1], visible=False)
317
+
318
+ def update_prompt(choice):
319
+ for (icon, name), (_, prompt) in zip([(icon, name) for icon, name in template_prompts], template_prompts):
320
+ if f"{icon} {name}" == choice:
321
+ return prompt
322
+ return template_prompts[0][1]
323
+
324
+ prompt_dropdown.change(update_prompt, inputs=[prompt_dropdown], outputs=[prompt_selected])
325
+
326
+ with gr.Tab("✏️ Custom") as custom_prompt_tab:
327
+ custom_prompt = gr.Textbox(
328
+ label="Your prompt",
329
+ placeholder="A person wearing a santa hat, festive atmosphere...",
330
+ lines=3
331
+ )
332
+ custom_negative = gr.Textbox(
333
+ label="Negative prompt (optional)",
334
+ placeholder="blurry, low quality...",
335
+ lines=2
336
+ )
337
 
338
+ template_prompt_tab.select(fn=lambda: 0, inputs=[], outputs=[prompt_selected_tab])
339
+ custom_prompt_tab.select(fn=lambda: 1, inputs=[], outputs=[prompt_selected_tab])
 
340
 
341
+ gr.HTML("<div class='section-header'>⚙️ Generation Settings</div>")
342
+
343
+ with gr.Row():
344
+ width = gr.Slider(
345
+ label="Width",
346
+ minimum=512,
347
+ maximum=1280,
348
+ value=896,
349
+ step=64
350
+ )
351
+ height = gr.Slider(
352
+ label="Height",
353
+ minimum=512,
354
+ maximum=1280,
355
+ value=1152,
356
+ step=64
357
+ )
358
 
359
+ with gr.Row():
360
+ num_steps = gr.Slider(
361
+ label="Steps",
362
+ minimum=20,
363
+ maximum=50,
364
+ value=30,
365
+ step=1
366
+ )
367
+ merge_steps = gr.Slider(
368
+ label="Merge Step",
369
+ minimum=10,
370
+ maximum=40,
371
+ value=20,
372
+ step=1
373
+ )
374
+
375
+ seed = gr.Slider(
376
+ label="🎲 Seed",
377
+ minimum=0,
378
  maximum=2147483647,
379
+ value=42,
380
  step=1
381
  )
382
 
383
+ generate_btn = gr.Button(
384
+ "🎨 Generate Image",
385
+ variant="primary",
386
+ size="lg",
387
+ elem_classes="generate-btn"
388
+ )
389
+
390
+ # Right column - Output
391
+ with gr.Column(scale=1):
392
+ gr.HTML("<div class='section-header'>🖼️ Generated Result</div>")
393
+
394
+ output_image = gr.Image(
395
+ label="Output",
396
+ height=600,
397
+ show_label=False
398
+ )
399
+
400
+ gr.HTML("""
401
+ <div class='info-box'>
402
+ <h4>💡 Tips for Best Results:</h4>
403
+ <ul>
404
+ <li>✅ Use clear face images with good lighting</li>
405
+ <li>✅ Faces should be clearly visible and not too small</li>
406
+ <li>✅ Use "man" or "woman" instead of "person" in prompts</li>
407
+ <li>⏱️ Generation takes 1-3 minutes with ZeroGPU</li>
408
+ </ul>
409
+ </div>
410
+ """)
411
 
412
+ gr.Markdown("""
413
+ <div style='text-align: center; margin-top: 2em; color: #666; font-size: 0.9em;'>
414
+ Powered by ConsistentID-SDXL |
415
+ <a href='https://huggingface.co/JackAILab/ConsistentID' target='_blank'>Model Card</a>
416
+ </div>
417
+ """)
 
 
418
 
419
+ # Connect the button
420
+ generate_btn.click(
421
+ fn=generate_image,
422
  inputs=[
423
+ selected_template,
424
+ custom_image,
425
+ custom_prompt,
426
+ custom_negative,
427
  prompt_selected,
 
428
  model_selected_tab,
429
  prompt_selected_tab,
430
  width,
431
  height,
432
  merge_steps,
433
+ seed,
434
+ num_steps
435
  ],
436
+ outputs=output_image
437
  )
438
 
 
439
  if __name__ == "__main__":
440
+ demo.queue(max_size=20)
441
  demo.launch()