comrender commited on
Commit
7bda99d
·
verified ·
1 Parent(s): 1332b22

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +18 -444
app.py CHANGED
@@ -1,89 +1,13 @@
1
- import logging
2
- import random
3
- import warnings
4
- import os
5
- import gradio as gr
6
- import numpy as np
7
- import spaces
8
- import torch
9
- from gradio_imageslider import ImageSlider
10
- from PIL import Image
11
- from huggingface_hub import hf_hub_download
12
- import subprocess
13
- import sys
14
- import tempfile
15
- from typing import Sequence, Mapping, Any, Union
16
- import asyncio
17
- import execution
18
- from nodes import init_extra_nodes
19
- import server
20
-
21
- # Copy functions from FluxSimpleUpscaler.txt
22
- def get_value_at_index(obj: Union[Sequence, Mapping], index: int) -> Any:
23
- try:
24
- return obj[index]
25
- except KeyError:
26
- return obj["result"][index]
27
-
28
- def find_path(name: str, path: str = None) -> str:
29
- if path is None:
30
- path = os.getcwd()
31
- if name in os.listdir(path):
32
- path_name = os.path.join(path, name)
33
- print(f"{name} found: {path_name}")
34
- return path_name
35
- parent_directory = os.path.dirname(path)
36
- if parent_directory == path:
37
- return None
38
- return find_path(name, parent_directory)
39
-
40
- def add_comfyui_directory_to_sys_path() -> None:
41
- comfyui_path = find_path("ComfyUI")
42
- if comfyui_path is not None and os.path.isdir(comfyui_path):
43
- sys.path.append(comfyui_path)
44
- print(f"'{comfyui_path}' added to sys.path")
45
-
46
- def add_extra_model_paths() -> None:
47
- try:
48
- from main import load_extra_path_config
49
- except ImportError:
50
- print("Could not import load_extra_path_config from main.py. Looking in utils.extra_config instead.")
51
- from utils.extra_config import load_extra_path_config
52
- extra_model_paths = find_path("extra_model_paths.yaml")
53
- if extra_model_paths is not None:
54
- load_extra_path_config(extra_model_paths)
55
- else:
56
- print("Could not find the extra_model_paths config file.")
57
-
58
- def import_custom_nodes() -> None:
59
- import asyncio
60
- import execution
61
- from nodes import init_extra_nodes
62
- import server
63
- loop = asyncio.new_event_loop()
64
- asyncio.set_event_loop(loop)
65
- server_instance = server.PromptServer(loop)
66
- execution.PromptQueue(server_instance)
67
- init_extra_nodes()
68
-
69
- # Setup ComfyUI and custom nodes
70
- if not os.path.exists("ComfyUI"):
71
- subprocess.run(["git", "clone", "https://github.com/comfyanonymous/ComfyUI.git"])
72
-
73
- custom_node_path = "ComfyUI/custom_nodes/ComfyUI_UltimateSDUpscale"
74
- if not os.path.exists(custom_node_path):
75
- subprocess.run(["git", "clone", "https://github.com/ssitu/ComfyUI_UltimateSDUpscale.git", custom_node_path])
76
-
77
  # Create model directories
78
- os.makedirs("ComfyUI/models/unet", exist_ok=True)
79
  os.makedirs("ComfyUI/models/clip", exist_ok=True)
80
  os.makedirs("ComfyUI/models/vae", exist_ok=True)
81
  os.makedirs("ComfyUI/models/upscale_models", exist_ok=True)
82
 
83
  # Download models if not present
84
- unet_path = "ComfyUI/models/unet/flux1-dev-fp8.safetensors"
85
- if not os.path.exists(unet_path):
86
- hf_hub_download("Kijai/flux-fp8", "flux1-dev-fp8.safetensors", local_dir="ComfyUI/models/unet")
87
 
88
  clip_l_path = "ComfyUI/models/clip/clip_l.safetensors"
89
  if not os.path.exists(clip_l_path):
@@ -105,375 +29,25 @@ esrgan_x4_path = "ComfyUI/models/upscale_models/RealESRGAN_x4.pth"
105
  if not os.path.exists(esrgan_x4_path):
106
  hf_hub_download("ai-forever/Real-ESRGAN", "RealESRGAN_x4.pth", local_dir="ComfyUI/models/upscale_models")
107
 
108
- # Add ComfyUI to path and import custom nodes
 
109
  add_comfyui_directory_to_sys_path()
110
  add_extra_model_paths()
111
- import_custom_nodes()
112
-
113
- from nodes import NODE_CLASS_MAPPINGS
114
-
115
- css = """
116
- #col-container {
117
- margin: 0 auto;
118
- max-width: 800px;
119
- }
120
- .main-header {
121
- text-align: center;
122
- margin-bottom: 2rem;
123
- }
124
- """
125
-
126
- MAX_SEED = 1000000
127
- MAX_PIXEL_BUDGET = 8192 * 8192
128
 
129
- def make_divisible_by_16(size):
130
- return ((size // 16) * 16) if (size % 16) < 8 else ((size // 16 + 1) * 16)
131
 
132
- def process_input(input_image, upscale_factor):
133
- w, h = input_image.size
134
- w_original, h_original = w, h
135
 
136
- was_resized = False
137
-
138
- if w * h * upscale_factor**2 > MAX_PIXEL_BUDGET:
139
- gr.Info(f"Requested output image is too large. Resizing input to fit within pixel budget.")
140
- target_input_pixels = MAX_PIXEL_BUDGET / (upscale_factor ** 2)
141
- scale = (target_input_pixels / (w * h)) ** 0.5
142
- new_w = max(16, int(w * scale) // 16 * 16)
143
- new_h = max(16, int(h * scale) // 16 * 16)
144
- input_image = input_image.resize((new_w, new_h), resample=Image.LANCZOS)
145
- was_resized = True
146
-
147
- return input_image, w_original, h_original, was_resized
148
-
149
- import requests
150
- def load_image_from_url(url):
151
- try:
152
- response = requests.get(url, stream=True)
153
- response.raise_for_status()
154
- return Image.open(response.raw)
155
- except Exception as e:
156
- raise gr.Error(f"Failed to load image from URL: {e}")
157
-
158
- def tensor_to_pil(tensor):
159
- tensor = tensor.cpu().clamp(0, 1) * 255
160
- img = tensor.numpy().astype(np.uint8)[0]
161
- return Image.fromarray(img)
162
-
163
- @spaces.GPU(duration=120)
164
- def enhance_image(
165
- image_input,
166
- image_url,
167
- seed,
168
- randomize_seed,
169
- num_inference_steps,
170
- upscale_factor,
171
- denoising_strength,
172
- custom_prompt,
173
- tile_size,
174
- progress=gr.Progress(track_tqdm=True),
175
- ):
176
- if image_input is not None:
177
- true_input_image = image_input
178
- elif image_url:
179
- true_input_image = load_image_from_url(image_url)
180
- else:
181
- raise gr.Error("Please provide an image (upload or URL)")
182
-
183
- if randomize_seed:
184
- seed = random.randint(0, MAX_SEED)
185
-
186
- input_image, w_original, h_original, was_resized = process_input(true_input_image, upscale_factor)
187
-
188
- if upscale_factor == 2:
189
- upscale_model_name = "RealESRGAN_x2.pth"
190
- else:
191
- upscale_model_name = "RealESRGAN_x4.pth"
192
-
193
- with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp:
194
- input_image.save(tmp.name)
195
- image_path = tmp.name
196
-
197
- with torch.inference_mode():
198
- dualcliploader = NODE_CLASS_MAPPINGS["DualCLIPLoader"]()
199
- dualcliploader_res = dualcliploader.load_clip(
200
- clip_name1="clip_l.safetensors",
201
- clip_name2="t5xxl_fp8_e4m3fn.safetensors",
202
- type="flux",
203
  )
204
- clip = get_value_at_index(dualcliploader_res, 0)
205
-
206
- cliptextencode = NODE_CLASS_MAPPINGS["CLIPTextEncode"]()
207
- positive_res = cliptextencode.encode(
208
- text=custom_prompt,
209
- clip=clip
210
- )
211
- negative_res = cliptextencode.encode(
212
- text="",
213
- clip=clip
214
- )
215
-
216
- upscalemodelloader = NODE_CLASS_MAPPINGS["UpscaleModelLoader"]()
217
- upscalemodelloader_res = upscalemodelloader.load_model(
218
- model_name=upscale_model_name
219
- )
220
-
221
- vaeloader = NODE_CLASS_MAPPINGS["VAELoader"]()
222
- vaeloader_res = vaeloader.load_vae(vae_name="ae.safetensors")
223
-
224
- unetloader = NODE_CLASS_MAPPINGS["UNETLoader"]()
225
- unetloader_res = unetloader.load_unet(
226
- unet_name="flux1-dev-fp8.safetensors", weight_dtype="fp8_e4m3fn"
227
- )
228
-
229
- loadimage = NODE_CLASS_MAPPINGS["LoadImage"]()
230
- loadimage_res = loadimage.load_image(image=os.path.basename(image_path))
231
-
232
- fluxguidance = NODE_CLASS_MAPPINGS["FluxGuidance"]()
233
- fluxguidance_res = fluxguidance.append(
234
- guidance=30, conditioning=get_value_at_index(positive_res, 0)
235
- )
236
-
237
- ultimatesdupscale = NODE_CLASS_MAPPINGS["UltimateSDUpscale"]()
238
- usd_res = ultimatesdupscale.upscale(
239
- upscale_by=upscale_factor,
240
- seed=seed,
241
- steps=num_inference_steps,
242
- cfg=1,
243
- sampler_name="euler",
244
- scheduler="normal",
245
- denoise=denoising_strength,
246
- mode_type="Linear",
247
- tile_width=tile_size,
248
- tile_height=tile_size,
249
- mask_blur=8,
250
- tile_padding=32,
251
- seam_fix_mode="None",
252
- seam_fix_denoise=1,
253
- seam_fix_width=64,
254
- seam_fix_mask_blur=8,
255
- seam_fix_padding=16,
256
- force_uniform_tiles=True,
257
- tiled_decode=False,
258
- image=get_value_at_index(loadimage_res, 0),
259
- model=get_value_at_index(unetloader_res, 0),
260
- positive=get_value_at_index(fluxguidance_res, 0),
261
- negative=get_value_at_index(negative_res, 0),
262
- vae=get_value_at_index(vaeloader_res, 0),
263
- upscale_model=get_value_at_index(upscalemodelloader_res, 0),
264
- )
265
-
266
- output_tensor = get_value_at_index(usd_res, 0)
267
- image = tensor_to_pil(output_tensor)
268
-
269
- os.unlink(image_path)
270
-
271
- target_w, target_h = w_original * upscale_factor, h_original * upscale_factor
272
- if image.size != (target_w, target_h):
273
- image = image.resize((target_w, target_h), resample=Image.LANCZOS)
274
-
275
- if was_resized:
276
- gr.Info(f"Resizing output to target size: {target_w}x{target_h}")
277
- image = image.resize((target_w, target_h), resample=Image.LANCZOS)
278
-
279
- resized_input = true_input_image.resize(image.size, resample=Image.LANCZOS)
280
-
281
- return [resized_input, image]
282
-
283
- with gr.Blocks(css=css, title="🎨 AI Image Upscaler - FLUX ComfyUI") as demo:
284
- gr.HTML("""
285
- <div class="main-header">
286
- <h1>🎨 AI Image Upscaler (ComfyUI Workflow)</h1>
287
- <p>Upload an image or provide a URL to upscale it using FLUX FP8 with ComfyUI Ultimate SD Upscale</p>
288
- <p>Using FLUX.1-dev FP8 model</p>
289
- </div>
290
- """)
291
-
292
- with gr.Row():
293
- with gr.Column(scale=1):
294
- gr.HTML("<h3>📤 Input</h3>")
295
-
296
- with gr.Tabs():
297
- with gr.TabItem("📁 Upload Image"):
298
- input_image = gr.Image(
299
- label="Upload Image",
300
- type="pil",
301
- height=200
302
- )
303
-
304
- with gr.TabItem("🔗 Image URL"):
305
- image_url = gr.Textbox(
306
- label="Image URL",
307
- placeholder="https://example.com/image.jpg",
308
- value="https://upload.wikimedia.org/wikipedia/commons/thumb/a/a7/Example.jpg/800px-Example.jpg"
309
- )
310
-
311
- gr.HTML("<h3>🎛️ Prompt Settings</h3>")
312
-
313
- custom_prompt = gr.Textbox(
314
- label="Custom Prompt (optional)",
315
- placeholder="Enter custom prompt or leave empty",
316
- lines=2
317
- )
318
-
319
- gr.HTML("<h3>⚙️ Upscaling Settings</h3>")
320
-
321
- upscale_factor = gr.Slider(
322
- label="Upscale Factor",
323
- minimum=1,
324
- maximum=4,
325
- step=1,
326
- value=2,
327
- info="How much to upscale the image"
328
- )
329
-
330
- num_inference_steps = gr.Slider(
331
- label="Number of Inference Steps",
332
- minimum=1,
333
- maximum=50,
334
- step=1,
335
- value=25,
336
- info="More steps = better quality but slower"
337
- )
338
-
339
- denoising_strength = gr.Slider(
340
- label="Denoising Strength",
341
- minimum=0.0,
342
- maximum=1.0,
343
- step=0.05,
344
- value=0.3,
345
- info="Controls how much the image is transformed"
346
- )
347
-
348
- tile_size = gr.Slider(
349
- label="Tile Size",
350
- minimum=256,
351
- maximum=2048,
352
- step=64,
353
- value=1024,
354
- info="Size of tiles for processing (larger = faster but more memory)"
355
- )
356
-
357
- with gr.Row():
358
- randomize_seed = gr.Checkbox(
359
- label="Randomize seed",
360
- value=True
361
- )
362
- seed = gr.Slider(
363
- label="Seed",
364
- minimum=0,
365
- maximum=MAX_SEED,
366
- step=1,
367
- value=42,
368
- interactive=True
369
- )
370
-
371
- enhance_btn = gr.Button(
372
- "🚀 Upscale Image",
373
- variant="primary",
374
- size="lg"
375
- )
376
-
377
- with gr.Column(scale=2):
378
- gr.HTML("<h3>📊 Results</h3>")
379
-
380
- result_slider = ImageSlider(
381
- type="pil",
382
- interactive=False,
383
- height=600,
384
- elem_id="result_slider",
385
- label=None
386
- )
387
 
388
- enhance_btn.click(
389
- fn=enhance_image,
390
- inputs=[
391
- input_image,
392
- image_url,
393
- seed,
394
- randomize_seed,
395
- num_inference_steps,
396
- upscale_factor,
397
- denoising_strength,
398
- custom_prompt,
399
- tile_size
400
- ],
401
- outputs=[result_slider]
402
- )
403
-
404
- gr.HTML("""
405
- <div style="margin-top: 2rem; padding: 1rem; background: #f0f0f0; border-radius: 8px;">
406
- <p><strong>Note:</strong> This upscaler uses the Flux.1-dev model. Users are responsible for obtaining commercial rights if used commercially under their license.</p>
407
- </div>
408
- """)
409
-
410
- gr.HTML("""
411
- <style>
412
- #result_slider .slider {
413
- width: 100% !important;
414
- max-width: inherit !important;
415
- }
416
- #result_slider img {
417
- object-fit: contain !important;
418
- width: 100% !important;
419
- height: auto !important;
420
- }
421
- #result_slider .gr-button-tool {
422
- display: none !important;
423
- }
424
- #result_slider .gr-button-undo {
425
- display: none !important;
426
- }
427
- #result_slider .gr-button-clear {
428
- display: none !important;
429
- }
430
- #result_slider .badge-container .badge {
431
- display: none !important;
432
- }
433
- #result_slider .badge-container::before {
434
- content: "Before";
435
- position: absolute;
436
- top: 10px;
437
- left: 10px;
438
- background: rgba(0,0,0,0.5);
439
- color: white;
440
- padding: 5px;
441
- border-radius: 5px;
442
- z-index: 10;
443
- }
444
- #result_slider .badge-container::after {
445
- content: "After";
446
- position: absolute;
447
- top: 10px;
448
- right: 10px;
449
- background: rgba(0,0,0,0.5);
450
- color: white;
451
- padding: 5px;
452
- border-radius: 5px;
453
- z-index: 10;
454
- }
455
- #result_slider .fullscreen img {
456
- object-fit: contain !important;
457
- width: 100vw !important;
458
- height: 100vh !important;
459
- position: absolute;
460
- top: 0;
461
- left: 0;
462
- }
463
- </style>
464
- """)
465
-
466
- gr.HTML("""
467
- <script>
468
- document.addEventListener('DOMContentLoaded', function() {
469
- const sliderInput = document.querySelector('#result_slider input[type="range"]');
470
- if (sliderInput) {
471
- sliderInput.value = 50;
472
- sliderInput.dispatchEvent(new Event('input'));
473
- }
474
- });
475
- </script>
476
- """)
477
 
478
- if __name__ == "__main__":
479
- demo.queue().launch(share=True, server_name="0.0.0.0", server_port=7860)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  # Create model directories
2
+ os.makedirs("ComfyUI/models/diffusion_models", exist_ok=True)
3
  os.makedirs("ComfyUI/models/clip", exist_ok=True)
4
  os.makedirs("ComfyUI/models/vae", exist_ok=True)
5
  os.makedirs("ComfyUI/models/upscale_models", exist_ok=True)
6
 
7
  # Download models if not present
8
+ diffusion_path = "ComfyUI/models/diffusion_models/flux1-dev-fp8.safetensors"
9
+ if not os.path.exists(diffusion_path):
10
+ hf_hub_download("Kijai/flux-fp8", "flux1-dev-fp8.safetensors", local_dir="ComfyUI/models/diffusion_models")
11
 
12
  clip_l_path = "ComfyUI/models/clip/clip_l.safetensors"
13
  if not os.path.exists(clip_l_path):
 
29
  if not os.path.exists(esrgan_x4_path):
30
  hf_hub_download("ai-forever/Real-ESRGAN", "RealESRGAN_x4.pth", local_dir="ComfyUI/models/upscale_models")
31
 
32
+ # ...
33
+
34
  add_comfyui_directory_to_sys_path()
35
  add_extra_model_paths()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
36
 
37
+ from folder_paths import add_model_folder_path
38
+ add_model_folder_path("checkpoints", "ComfyUI/models/diffusion_models")
39
 
40
+ # ...
 
 
41
 
42
+ checkpointloader = NODE_CLASS_MAPPINGS["CheckpointLoaderSimple"]()
43
+ checkpointloader_res = checkpointloader.load_checkpoint(
44
+ ckpt_name="flux1-dev-fp8.safetensors"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
45
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
46
 
47
+ # ...
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
48
 
49
+ ultimatesdupscale_50 = ultimatesdupscale.upscale(
50
+ # ...
51
+ model=get_value_at_index(checkpointloader_res, 0),
52
+ # ...
53
+ )