Commit 4e13a1e by telcom (verified) · Parent: def5e5a

Create app.py

Files changed (1): app.py ADDED (+300, -0)
import os
import random
import gc
import contextlib

import gradio as gr
import numpy as np
from PIL import Image, ImageDraw

import torch
from diffusers import (
    StableDiffusionXLPipeline,
    StableDiffusionXLImg2ImgPipeline,
    EulerAncestralDiscreteScheduler,
)
from huggingface_hub import login

# ============================================================
# GPU decorator (optional)
# ============================================================
try:
    import spaces
    GPU_DECORATOR = spaces.GPU
except Exception:
    def GPU_DECORATOR(fn):
        return fn

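# `compel` adds prompt weighting (e.g. "a forest+" or "(misty)1.3"). For SDXL
# it needs both text encoders and must return pooled embeddings; see the
# Compel(...) construction below.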
from compel import Compel, ReturnedEmbeddingsType

MODEL_ID = "telcom/dee-unlearning-tiny-sd"
REVISION = "main"

HF_TOKEN = os.getenv("HF_TOKEN", "").strip()
if HF_TOKEN:
    login(token=HF_TOKEN)

# ============================================================
# Detect device
# ============================================================
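# fp16 halves memory and speeds up CUDA inference; CPU falls back to fp32,
# where half precision is poorly supported.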
cuda_available = torch.cuda.is_available()
device = torch.device("cuda" if cuda_available else "cpu")
dtype = torch.float16 if cuda_available else torch.float32

MAX_SEED = np.iinfo(np.int32).max
MAX_IMAGE_SIZE = 1216 if cuda_available else 768  # smaller cap on CPU

pipe_txt2img = None
pipe_img2img = None
compel = None
model_loaded = False
load_error = None
fallback_msg = ""


# ============================================================
# Load model (txt2img + img2img sharing weights)
# ============================================================
try:
    from_pretrained_kwargs = dict(
        torch_dtype=dtype,
        use_safetensors=True,
    )

    if cuda_available:
        from_pretrained_kwargs["variant"] = "fp16"

    if HF_TOKEN:
        from_pretrained_kwargs["token"] = HF_TOKEN

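    # The "fp16" variant expects fp16 weight files in the model repo, and the
    # token is only needed for gated/private repos; if loading fails for any
    # reason, the except branch below records the error for the UI.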
    # Base txt2img pipeline, pinned to REVISION
    pipe_txt2img = StableDiffusionXLPipeline.from_pretrained(
        MODEL_ID, revision=REVISION, **from_pretrained_kwargs
    )
    pipe_txt2img.scheduler = EulerAncestralDiscreteScheduler.from_config(
        pipe_txt2img.scheduler.config
    )
    pipe_txt2img = pipe_txt2img.to(device)

    # Memory opts
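    # Each optimization is best-effort: VAE/attention slicing trade speed for
    # lower peak memory, and xFormers attention applies only if the package
    # is installed; failures are ignored so loading can proceed.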
    try:
        pipe_txt2img.enable_vae_slicing()
    except Exception:
        pass
    try:
        pipe_txt2img.enable_attention_slicing()
    except Exception:
        pass
    try:
        pipe_txt2img.enable_xformers_memory_efficient_attention()
    except Exception:
        pass

    pipe_txt2img.set_progress_bar_config(disable=True)

    # Create img2img pipeline from txt2img components (no extra weights)
    pipe_img2img = StableDiffusionXLImg2ImgPipeline(**pipe_txt2img.components)
    pipe_img2img.scheduler = EulerAncestralDiscreteScheduler.from_config(
        pipe_img2img.scheduler.config
    )
    pipe_img2img = pipe_img2img.to(device)

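    # Standard Compel setup for SDXL: both tokenizers/text encoders go in,
    # and pooled embeddings come from the second encoder only.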
    compel = Compel(
        tokenizer=[pipe_txt2img.tokenizer, pipe_txt2img.tokenizer_2],
        text_encoder=[pipe_txt2img.text_encoder, pipe_txt2img.text_encoder_2],
        returned_embeddings_type=ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED,
        requires_pooled=[False, True],
    )

    model_loaded = True

except Exception as e:
    load_error = repr(e)
    model_loaded = False


if not cuda_available:
    fallback_msg = "GPU unavailable. Running in CPU fallback mode (slower, smaller images)."


# ============================================================
# Error image
# ============================================================
def _make_error_image(w, h, text):
    img = Image.new("RGB", (w, h), (18, 18, 22))
    ImageDraw.Draw(img).text((10, 10), text, fill=(235, 235, 235))
    return img


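# Returning a placeholder image keeps the Gradio output well-formed on
# failure; the accompanying status string carries the same message.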
# ============================================================
# Inference (txt2img or img2img depending on init_image)
# ============================================================
@GPU_DECORATOR
def infer(
    prompt,
    negative_prompt,
    seed,
    randomize_seed,
    width,
    height,
    guidance_scale,
    num_inference_steps,
    init_image,  # new: optional image
    strength,  # new: img2img strength
):
    width = int(width)
    height = int(height)
    seed = int(seed)

    if not model_loaded or pipe_txt2img is None or pipe_img2img is None or compel is None:
        msg = "Model failed to load."
        if load_error:
            msg += f" (details: {load_error})"
        return _make_error_image(width, height, msg), msg

    # Randomize seed if requested
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)

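    # Seed a generator on the matching device so a fixed seed reproduces
    # the same image on reruns.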
    if device.type == "cuda":
        generator = torch.Generator(device=device).manual_seed(seed)
    else:
        generator = torch.Generator().manual_seed(seed)

    status = f"Seed: {seed}"
    if fallback_msg:
        status += f" | {fallback_msg}"

    try:
        with torch.inference_mode():
            # Build weighted embeddings; pad positive/negative tensors to the
            # same sequence length so the pipeline accepts them together.
            conditioning, pooled = compel(prompt)
            negative_conditioning, negative_pooled = compel(negative_prompt)
            [conditioning, negative_conditioning] = compel.pad_conditioning_tensors_to_same_length(
                [conditioning, negative_conditioning]
            )

            common_kwargs = dict(
                prompt_embeds=conditioning,
                pooled_prompt_embeds=pooled,
                negative_prompt_embeds=negative_conditioning,
                negative_pooled_prompt_embeds=negative_pooled,
                guidance_scale=float(guidance_scale),
                num_inference_steps=int(num_inference_steps),
                generator=generator,
            )

            # Autocast only on CUDA; CPU runs in full precision.
            amp_ctx = (
                torch.autocast("cuda", dtype=dtype)
                if device.type == "cuda"
                else contextlib.nullcontext()
            )
            with amp_ctx:
                # If init_image is provided, use img2img
                if init_image is not None:
                    image = pipe_img2img(
                        image=init_image,
                        strength=float(strength),
                        **common_kwargs,
                    ).images[0]
                else:
                    image = pipe_txt2img(
                        width=width,
                        height=height,
                        **common_kwargs,
                    ).images[0]

        return image, status

    except Exception as e:
        msg = f"Error during generation: {type(e).__name__}: {e}"
        return _make_error_image(width, height, msg), msg

    finally:
        gc.collect()
        if device.type == "cuda":
            torch.cuda.empty_cache()


# ============================================================
# UI
# ============================================================

CSS = """
body{
    background:#000;
    color:#fff;
}
"""

with gr.Blocks(title="Text to Image / Image to Image") as demo:

    gr.HTML(f"<style>{CSS}</style>")

    with gr.Column():

        # banner first
        if fallback_msg:
            gr.Markdown(f"**{fallback_msg}**")

        if not model_loaded:
            gr.Markdown(
                f"⚠️ **Model failed to load.**\n\nDetails: {load_error}",
                elem_classes=["small-note"],
            )

        gr.Markdown("## SDXL Generator (txt2img + img2img)")

        prompt = gr.Textbox(
            label="Prompt",
            placeholder="Enter your prompt...",
            lines=2,
        )

        # NEW: optional initial image for img2img
        init_image = gr.Image(
            label="Initial image (optional)",
            type="pil",
        )
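        # When an initial image is supplied, the img2img path is used and the
        # output follows the input image's size; the width/height sliders
        # below apply to txt2img only.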

        run_button = gr.Button("Generate")
        result = gr.Image(label="Result")
        status = gr.Markdown("")

        with gr.Accordion("Advanced Settings", open=False):
            negative_prompt = gr.Textbox(label="Negative prompt", value="")
            seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0)
            randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
            width = gr.Slider(label="Width", minimum=256, maximum=MAX_IMAGE_SIZE, step=32, value=512)
            height = gr.Slider(label="Height", minimum=256, maximum=MAX_IMAGE_SIZE, step=32, value=512)
            guidance_scale = gr.Slider(label="Guidance scale", minimum=0, maximum=20, step=0.1, value=7)
            num_inference_steps = gr.Slider(label="Steps", minimum=1, maximum=40, step=1, value=20)

            # NEW: strength for img2img
            strength = gr.Slider(
                label="Image strength (for img2img)",
                minimum=0.0,
                maximum=1.0,
                step=0.05,
                value=0.7,
            )
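            # Higher strength adds more noise, straying further from the init
            # image; effectively about int(steps * strength) denoising steps
            # are run.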

    run_button.click(
        fn=infer,
        inputs=[
            prompt,
            negative_prompt,
            seed,
            randomize_seed,
            width,
            height,
            guidance_scale,
            num_inference_steps,
            init_image,
            strength,
        ],
        outputs=[result, status],
    )

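# queue() serializes requests so concurrent users don't contend for the GPU;
# ssr_mode=False opts out of Gradio 5's server-side rendering.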
demo.queue().launch(ssr_mode=False)