LPX55 commited on
Commit
ac8f97a
·
verified ·
1 Parent(s): d06b0fa

Create app_inpaint2.py

Browse files
Files changed (1) hide show
  1. app_inpaint2.py +289 -0
app_inpaint2.py ADDED
@@ -0,0 +1,289 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import spaces
2
+ import gradio as gr
3
+ import torch
4
+ from diffusers import AutoencoderKL, TCDScheduler
5
+ from diffusers.models.model_loading_utils import load_state_dict
6
+ from gradio_imageslider import ImageSlider
7
+ from huggingface_hub import hf_hub_download
8
+ from controlnet_union import ControlNetModel_Union
9
+ from pipeline_fill_sd_xl import StableDiffusionXLFillPipeline
10
+ from PIL import Image, ImageFilter
11
+ import numpy as np
12
+ # from gradio.sketch.run import create
13
+
14
+ MODELS = {
15
+ "RealVisXL V5.0 Lightning": "SG161222/RealVisXL_V5.0_Lightning",
16
+ "Lustify Lightning": "GraydientPlatformAPI/lustify-lightning",
17
+ "Juggernaut XL Lightning": "RunDiffusion/Juggernaut-XL-Lightning",
18
+ "Juggernaut-XL-V9-GE-RDPhoto2": "AiWise/Juggernaut-XL-V9-GE-RDPhoto2-Lightning_4S",
19
+ "SatPony-Lightning": "John6666/satpony-lightning-v2-sdxl"
20
+ }
21
+
22
+ # --- ControlNet and Pipeline Setup (Retained) ---
23
+ config_file = hf_hub_download(
24
+ "xinsir/controlnet-union-sdxl-1.0",
25
+ filename="config_promax.json",
26
+ )
27
+ config = ControlNetModel_Union.load_config(config_file)
28
+ controlnet_model = ControlNetModel_Union.from_config(config)
29
+ model_file = hf_hub_download(
30
+ "xinsir/controlnet-union-sdxl-1.0",
31
+ filename="diffusion_pytorch_model_promax.safetensors",
32
+ )
33
+ state_dict = load_state_dict(model_file)
34
+ model, _, _, _, _ = ControlNetModel_Union._load_pretrained_model(
35
+ controlnet_model, state_dict, model_file, "xinsir/controlnet-union-sdxl-1.0"
36
+ )
37
+ model.to(device="cuda", dtype=torch.float16)
38
+ vae = AutoencoderKL.from_pretrained(
39
+ "madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16
40
+ ).to("cuda")
41
+ pipe = StableDiffusionXLFillPipeline.from_pretrained(
42
+ "SG161222/RealVisXL_V5.0_Lightning",
43
+ torch_dtype=torch.float16,
44
+ vae=vae,
45
+ controlnet=model,
46
+ variant="fp16",
47
+ )
48
+ pipe.scheduler = TCDScheduler.from_config(pipe.scheduler.config)
49
+ pipe.to("cuda")
50
+ print(pipe)
51
+
52
+ def load_default_pipeline():
53
+ """仅保留,但当前 Inpaint 逻辑未直接使用,可以删除,但保留以防将来扩展。"""
54
+ global pipe
55
+ pipe = StableDiffusionXLFillPipeline.from_pretrained(
56
+ "SG161222/RealVisXL_V5.0_Lightning",
57
+ torch_dtype=torch.float16,
58
+ vae=vae,
59
+ controlnet=model,
60
+ ).to("cuda")
61
+ # 此函数原用于 Misc 选项卡,现在无需返回 gr.update
62
+ print("Default pipeline loaded!")
63
+
64
+ @spaces.GPU(duration=7)
65
+ def fill_image(prompt, image, model_selection, paste_back):
66
+ """
67
+ 处理 ImageMask(gr.ImageMask)输入的 fill/repair 流程。
68
+ 在这里对用户绘制的 mask 做默认 5% 膨胀。
69
+ """
70
+ print(f"Received image: {image}")
71
+ if image is None:
72
+ yield None, None
73
+ return
74
+
75
+ # 如果用户选择了不同的模型 key,则加载对应预训练仓库
76
+ # 注意:此逻辑原 Outpaint 中有,Inpaint 中缺失,现在补充以支持模型切换
77
+ global pipe
78
+ if model_selection in MODELS and pipe.config.model_name != MODELS[model_selection]:
79
+ # 释放旧模型显存
80
+ del pipe
81
+ torch.cuda.empty_cache()
82
+ pipe = StableDiffusionXLFillPipeline.from_pretrained(
83
+ MODELS[model_selection],
84
+ torch_dtype=torch.float16,
85
+ vae=vae,
86
+ controlnet=model,
87
+ variant="fp16", # 保持 variant 设置
88
+ )
89
+ pipe.scheduler = TCDScheduler.from_config(pipe.scheduler.config)
90
+ pipe.to("cuda")
91
+ print(f"Loaded new SDXL model: {pipe.config.model_name}")
92
+
93
+
94
+ (
95
+ prompt_embeds,
96
+ negative_prompt_embeds,
97
+ pooled_prompt_embeds,
98
+ negative_pooled_prompt_embeds,
99
+ ) = pipe.encode_prompt(prompt, "cuda", True)
100
+ source = image["background"]
101
+ # 用户绘制的 mask layer(通常是 RGBA)
102
+ mask = image["layers"][0]
103
+ # 取 alpha 通道并转为二值 mask(255 表示 mask 区域)
104
+ alpha_channel = mask.split()[3]
105
+ binary_mask = alpha_channel.point(lambda p: 255 if p > 0 else 0).convert("L")
106
+
107
+ # ==== 扩大 5%(针对 fill_image 的二值 mask) ====
108
+ expand_px = max(1, int(min(binary_mask.width, binary_mask.height) * 0.05))
109
+ kernel_size = expand_px * 2 + 1
110
+ binary_mask = binary_mask.filter(ImageFilter.MaxFilter(kernel_size))
111
+ # ==== END 扩大 ====
112
+
113
+ cnet_image = source.copy()
114
+ # 在控制网络输入图上把 mask 区域填黑(以便 ControlNet/pipe 根据此区域生成)
115
+ cnet_image.paste(0, (0, 0), binary_mask)
116
+
117
+ # 调用管线(通常是生成若干中间结果,这里按原逻辑 yield)
118
+ for image_out in pipe(
119
+ prompt_embeds=prompt_embeds,
120
+ negative_prompt_embeds=negative_prompt_embeds,
121
+ pooled_prompt_embeds=pooled_prompt_embeds,
122
+ negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
123
+ image=cnet_image,
124
+ # Inpaint 流程使用 image=cnet_image(原图 masked with black),
125
+ # 管道内部应该处理了 mask,但如果 StableDiffusionXLFillPipeline
126
+ # 需要显式 mask,这里可能需要调整。根据原代码的��名和逻辑,
127
+ # 假定 pipe(image=cnet_image) 适用于此填充流程。
128
+ ):
129
+ yield image_out, cnet_image # 这里的 yield 是为了流式输出
130
+
131
+ print(f"{model_selection=}")
132
+ print(f"{paste_back=}")
133
+ # 最后 paste 回原图(如用户选择)
134
+ if paste_back:
135
+ # image_out 是生成的修复部分
136
+ # cnet_image 在循环中已被用作 ControlNet 输入图(黑块版)
137
+ # 这里的 cnet_image 应该更新为 source.copy() 以避免和输入混淆,
138
+ # 但遵循原代码逻辑,使用 image_out + source/binary_mask
139
+
140
+ # 最终结果是 image_out(修复结果),我们将其粘贴回原图 source
141
+ # 的非 mask 区域(即只替换 mask 区域)
142
+ final_output = source.copy()
143
+ image_out_rgba = image_out.convert("RGBA")
144
+ # 使用二值 mask 的反转作为 paste 的 mask
145
+ inverted_mask = binary_mask.point(lambda p: 255 if p == 0 else 0).convert("L")
146
+
147
+ # 将 image_out 粘贴到 final_output 中,仅在 binary_mask 为 255 的区域(即修复区域)
148
+ final_output.paste(image_out_rgba, (0, 0), binary_mask)
149
+
150
+ yield final_output, cnet_image
151
+ else:
152
+ # 如果不 paste back,只返回生成的修复图像
153
+ yield image_out, cnet_image
154
+
155
+ def clear_result():
156
+ return gr.update(value=None)
157
+
158
+ def use_output_as_input(output_image):
159
+ """
160
+ 接收 ImageSlider 的输出 (image_out, cnet_image)
161
+ 返回 cnet_image 作为新的输入。
162
+ """
163
+ # ImageSlider 的 value 是一个 tuple (image1, image2)
164
+ # 这里的 image2 (即 cnet_image) 是包含修复结果的图像 (如果 paste_back 为 True)
165
+ # 或者只是 ControlNet 输入(如果 fill_image 逻辑有变)
166
+ # 假设我们想要将修复后的结果图作为新的输入图像(新的 source)
167
+ # 在 fill_image 中,最终 yield 的是 (final_output, cnet_image)
168
+ # 我们应该使用 final_output 作为新的背景图。
169
+ return gr.update(value=output_image[0]) # output_image[0] 是最终修复图像
170
+
171
+ css = """
172
+ .nulgradio-container {
173
+ width: 86vw !important;
174
+ }
175
+ .nulcontain {
176
+ overflow-y: scroll !important;
177
+ padding: 10px 40px !important;
178
+ }
179
+ div#component-17 { /* 这是一个动态 ID,可能需要调整或移除 */
180
+ height: auto !important;
181
+ }
182
+
183
+ @media screen and (max-width: 600px) {
184
+ .img-row{
185
+ display: block !important;
186
+ margin-bottom: 20px !important;
187
+ }
188
+ /* 移除掉 component-16 的引用,因为它是动态的 */
189
+ }
190
+
191
+ """
192
+
193
+ title = """<h1 align="center">Diffusers Image Inpaint</h1>
194
+ <div align="center">Upload an image, draw a mask, and enter a prompt to repair/inpaint the masked area.</div>
195
+ <div style="display: flex; justify-content: center; align-items: center; text-align: center;">
196
+ <p style="display: flex;gap: 6px;">
197
+ <a href="https://huggingface.co/spaces/fffiloni/diffusers-image-outpout?duplicate=true">
198
+ <img src="https://huggingface.co/datasets/huggingface/badges/resolve/main/duplicate-this-space-md.svg" alt="Duplicate this Space">
199
+ </a> to skip the queue and enjoy faster inference on the GPU of your choice
200
+ </p>
201
+ </div>
202
+ """
203
+
204
+ with gr.Blocks(css=css, fill_height=True) as demo:
205
+ gr.Markdown(title) # 使用更简洁的 Markdown 标题
206
+
207
+ # 只保留 Inpaint 选项卡的内容,移除 Tabs 结构,让 Inpaint 成为主界面
208
+ with gr.Column():
209
+ with gr.Row():
210
+ with gr.Column():
211
+ prompt = gr.Textbox(
212
+ label="Prompt",
213
+ info="Describe what to inpaint the mask with",
214
+ lines=3,
215
+ )
216
+ with gr.Column():
217
+ model_selection = gr.Dropdown(
218
+ choices=list(MODELS.keys()),
219
+ value="RealVisXL V5.0 Lightning",
220
+ label="Model",
221
+ )
222
+ with gr.Row():
223
+ run_button = gr.Button("Generate")
224
+ paste_back = gr.Checkbox(True, label="Paste back original")
225
+ with gr.Row(equal_height=False):
226
+ input_image = gr.ImageMask(
227
+ type="pil", label="Input Image", layers=True, elem_classes="img-row"
228
+ )
229
+ result = ImageSlider(
230
+ interactive=False,
231
+ label="Generated Image",
232
+ elem_classes="img-row"
233
+ )
234
+ use_as_input_button = gr.Button("Use as Input Image", visible=False)
235
+
236
+ # --- Event Handlers for Inpaint ---
237
+ use_as_input_button.click(
238
+ fn=use_output_as_input,
239
+ inputs=[result],
240
+ outputs=[input_image],
241
+ queue=False # 这是一个快速操作
242
+ )
243
+
244
+ # Generates image on button click
245
+ run_button.click(
246
+ fn=clear_result,
247
+ inputs=None,
248
+ outputs=result,
249
+ queue=False,
250
+ ).then(
251
+ fn=lambda: gr.update(visible=False),
252
+ inputs=None,
253
+ outputs=use_as_input_button,
254
+ queue=False,
255
+ ).then(
256
+ fn=fill_image,
257
+ inputs=[prompt, input_image, model_selection, paste_back],
258
+ outputs=[result],
259
+ ).then(
260
+ fn=lambda: gr.update(visible=True),
261
+ inputs=None,
262
+ outputs=use_as_input_button,
263
+ queue=False,
264
+ )
265
+
266
+ # Generates image on prompt submit
267
+ prompt.submit(
268
+ fn=clear_result,
269
+ inputs=None,
270
+ outputs=result,
271
+ queue=False,
272
+ ).then(
273
+ fn=lambda: gr.update(visible=False),
274
+ inputs=None,
275
+ outputs=use_as_input_button,
276
+ queue=False,
277
+ ).then(
278
+ fn=fill_image,
279
+ inputs=[prompt, input_image, model_selection, paste_back],
280
+ outputs=[result],
281
+ ).then(
282
+ fn=lambda: gr.update(visible=True),
283
+ inputs=None,
284
+ outputs=use_as_input_button,
285
+ queue=False,
286
+ )
287
+
288
+ # 将 queue 和 launch 保持不变
289
+ demo.queue(max_size=10).launch(show_error=True)