LPX55 commited on
Commit
b524f8f
·
verified ·
1 Parent(s): 7374c5e

Update app_inpaint2.py

Browse files
Files changed (1) hide show
  1. app_inpaint2.py +24 -36
app_inpaint2.py CHANGED
@@ -58,38 +58,35 @@ def load_default_pipeline():
58
  vae=vae,
59
  controlnet=model,
60
  ).to("cuda")
61
- # 此函数原用于 Misc 选项卡,现在无需返回 gr.update
62
  print("Default pipeline loaded!")
63
 
64
  @spaces.GPU(duration=7)
65
  def fill_image(prompt, image, model_selection, paste_back):
66
  """
67
- 处理 ImageMaskgr.ImageMask)输入的 fill/repair 流程。
68
- 在这里对用户绘制的 mask 做默认 5% 膨胀。
69
  """
70
  print(f"Received image: {image}")
71
  if image is None:
72
  yield None, None
73
  return
74
 
75
- # 如果用户选择了不同的模型 key,则加载对应预训练仓库
76
- # 注意:此逻辑原 Outpaint 中有,Inpaint 中缺失,现在补充以支持模型切换
77
- global pipe
78
- if model_selection in MODELS and pipe.config.model_name != MODELS[model_selection]:
79
- # 释放旧模型显存
80
- del pipe
81
- torch.cuda.empty_cache()
82
- pipe = StableDiffusionXLFillPipeline.from_pretrained(
83
- MODELS[model_selection],
84
- torch_dtype=torch.float16,
85
- vae=vae,
86
- controlnet=model,
87
- variant="fp16", # 保持 variant 设置
88
- )
89
- pipe.scheduler = TCDScheduler.from_config(pipe.scheduler.config)
90
- pipe.to("cuda")
91
- print(f"Loaded new SDXL model: {pipe.config.model_name}")
92
-
93
 
94
  (
95
  prompt_embeds,
@@ -157,16 +154,10 @@ def clear_result():
157
 
158
  def use_output_as_input(output_image):
159
  """
160
- 接收 ImageSlider 的输出 (image_out, cnet_image)
161
- 返回 cnet_image 作为新的输入。
162
  """
163
- # ImageSlider 的 value 是一个 tuple (image1, image2)
164
- # 这里的 image2 (即 cnet_image) 是包含修复结果的图像 (如果 paste_back 为 True)
165
- # 或者只是 ControlNet 输入(如果 fill_image 逻辑有变)
166
- # 假设我们想要将修复后的结果图作为新的输入图像(新的 source)
167
- # 在 fill_image 中,最终 yield 的是 (final_output, cnet_image)
168
- # 我们应该使用 final_output 作为新的背景图。
169
- return gr.update(value=output_image[0]) # output_image[0] 是最终修复图像
170
 
171
  css = """
172
  .nulgradio-container {
@@ -176,7 +167,7 @@ css = """
176
  overflow-y: scroll !important;
177
  padding: 10px 40px !important;
178
  }
179
- div#component-17 { /* 这是一个动态 ID,可能需要调整或移除 */
180
  height: auto !important;
181
  }
182
 
@@ -185,7 +176,6 @@ div#component-17 { /* 这是一个动态 ID,可能需要调整或移除 */
185
  display: block !important;
186
  margin-bottom: 20px !important;
187
  }
188
- /* 移除掉 component-16 的引用,因为它是动态的 */
189
  }
190
 
191
  """
@@ -202,9 +192,8 @@ title = """<h1 align="center">Diffusers Image Inpaint</h1>
202
  """
203
 
204
  with gr.Blocks(css=css, fill_height=True) as demo:
205
- gr.Markdown(title) # 使用更简洁的 Markdown 标题
206
 
207
- # 只保留 Inpaint 选项卡的内容,移除 Tabs 结构,让 Inpaint 成为主界面
208
  with gr.Column():
209
  with gr.Row():
210
  with gr.Column():
@@ -238,7 +227,7 @@ with gr.Blocks(css=css, fill_height=True) as demo:
238
  fn=use_output_as_input,
239
  inputs=[result],
240
  outputs=[input_image],
241
- queue=False # 这是一个快速操作
242
  )
243
 
244
  # Generates image on button click
@@ -285,5 +274,4 @@ with gr.Blocks(css=css, fill_height=True) as demo:
285
  queue=False,
286
  )
287
 
288
- # 将 queue 和 launch 保持不变
289
  demo.queue(max_size=10).launch(show_error=True)
 
58
  vae=vae,
59
  controlnet=model,
60
  ).to("cuda")
 
61
  print("Default pipeline loaded!")
62
 
63
  @spaces.GPU(duration=7)
64
  def fill_image(prompt, image, model_selection, paste_back):
65
  """
66
+ Handles the fill/repair process for inputs from ImageMask (gr. ImageMask). Applies a default 5% expansion to user-drawn masks here.
 
67
  """
68
  print(f"Received image: {image}")
69
  if image is None:
70
  yield None, None
71
  return
72
 
73
+ if model_selection in MODELS:
74
+ current_model = pipe.config.get("_name_or_path", "")
75
+ target_model = MODELS[model_selection]
76
+ if current_model != target_model:
77
+ # 释放旧模型显存
78
+ del pipe
79
+ torch.cuda.empty_cache()
80
+ pipe = StableDiffusionXLFillPipeline.from_pretrained(
81
+ target_model,
82
+ torch_dtype=torch.float16,
83
+ vae=vae,
84
+ controlnet=model,
85
+ variant="fp16",
86
+ )
87
+ pipe.scheduler = TCDScheduler.from_config(pipe.scheduler.config)
88
+ pipe.to("cuda")
89
+ print(f"Loaded new SDXL model: {target_model}")
 
90
 
91
  (
92
  prompt_embeds,
 
154
 
155
  def use_output_as_input(output_image):
156
  """
157
+ Receives the output of ImageSlider (image_out, cnet_image) and returns cnet_image as the new input.
 
158
  """
159
+
160
+ return gr.update(value=output_image[0])
 
 
 
 
 
161
 
162
  css = """
163
  .nulgradio-container {
 
167
  overflow-y: scroll !important;
168
  padding: 10px 40px !important;
169
  }
170
+ div#component-17 {
171
  height: auto !important;
172
  }
173
 
 
176
  display: block !important;
177
  margin-bottom: 20px !important;
178
  }
 
179
  }
180
 
181
  """
 
192
  """
193
 
194
  with gr.Blocks(css=css, fill_height=True) as demo:
195
+ gr.Markdown(title)
196
 
 
197
  with gr.Column():
198
  with gr.Row():
199
  with gr.Column():
 
227
  fn=use_output_as_input,
228
  inputs=[result],
229
  outputs=[input_image],
230
+ queue=False
231
  )
232
 
233
  # Generates image on button click
 
274
  queue=False,
275
  )
276
 
 
277
  demo.queue(max_size=10).launch(show_error=True)