Spaces: Running on Zero

Commit 6138235 · switch model
Parent(s): 14638d9

Files changed:
- app.py (+44 -55)
- pipeline/util.py (+29 -1)
app.py CHANGED

@@ -1,49 +1,59 @@
 import torch
 import spaces
-from diffusers import ControlNetUnionModel, AutoencoderKL
+from diffusers import ControlNetUnionModel, AutoencoderKL, UNet2DConditionModel
 import gradio as gr
 
 from pipeline.mod_controlnet_tile_sr_sdxl import StableDiffusionXLControlNetTileSRPipeline, calculate_overlap
 from pipeline.util import (
     SAMPLERS,
     create_hdr_effect,
+    optionally_disable_offloading,
     progressive_upscale,
+    quantize_8bit,
     select_scheduler,
-    torch_gc,
+    torch_gc,
 )
 
 device = "cuda"
 pipe = None
 last_loaded_model = None
-
-
-controlnet = ControlNetUnionModel.from_pretrained(
-    "brad-twinkl/controlnet-union-sdxl-1.0-promax", torch_dtype=torch.float16
-).to(device=device)
-vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16).to(device=device)
+MODELS = {"RealVisXL 5 Lightning": "SG161222/RealVisXL_V5.0_Lightning",
+          "RealVisXL 5": "SG161222/RealVisXL_V5.0"
+          }
 
 def load_model(model_id):
     global pipe, last_loaded_model
-
+
     if model_id != last_loaded_model:
-
-
+
+        # Initialize the models and pipeline
+        controlnet = ControlNetUnionModel.from_pretrained(
+            "brad-twinkl/controlnet-union-sdxl-1.0-promax", torch_dtype=torch.float16
+        )
+        vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16)
+        if pipe is not None:
+            optionally_disable_offloading(pipe)
+            torch_gc()
         pipe = StableDiffusionXLControlNetTileSRPipeline.from_pretrained(
-            model_id, controlnet=controlnet, vae=vae, torch_dtype=torch.float16, use_safetensors=True, variant="fp16"
-        )
-
-        #pipe.enable_model_cpu_offload() # << Enable this if you have limited VRAM
+            MODELS[model_id], controlnet=controlnet, vae=vae, torch_dtype=torch.float16, use_safetensors=True, variant="fp16"
+        )
+        pipe.enable_model_cpu_offload() # << Enable this if you have limited VRAM
         pipe.enable_vae_tiling() # << Enable this if you have limited VRAM
         pipe.enable_vae_slicing() # << Enable this if you have limited VRAM
+
+        unet = UNet2DConditionModel.from_pretrained(MODELS[model_id], subfolder="unet", variant="fp16", use_safetensors=True)
+        quantize_8bit(unet) # << Enable this if you have limited VRAM
+        pipe.unet = unet
+
         last_loaded_model = model_id
 
-load_model("
+load_model("RealVisXL 5 Lightning")
 
 # region functions
 @spaces.GPU(duration=120)
 def predict(
-    model_id,
     image,
+    model_id,
     prompt,
     negative_prompt,
     resolution,
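Note: app.py now imports `quantize_8bit` from `pipeline/util.py` and applies it to a separately loaded UNet before attaching it with `pipe.unet = unet`, but the helper's body is not part of this diff. Purely as an illustration of the idea, here is a minimal sketch of 8-bit weight quantization for a diffusers UNet, assuming optimum-quanto is installed; the Space's actual `quantize_8bit` may be implemented differently.

```python
# Illustrative sketch only -- not the actual quantize_8bit from pipeline/util.py.
# Assumes the optimum-quanto package is available.
from optimum.quanto import freeze, qint8, quantize


def quantize_8bit(unet):
    """Quantize a UNet's weights to int8 in place to reduce VRAM usage."""
    if unet is None:
        return
    quantize(unet, weights=qint8)  # replace weights with 8-bit quantized tensors
    freeze(unet)                   # materialize the quantized weights


# In load_model() above, the quantized UNet is then handed back to the pipeline:
#     pipe.unet = unet
```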
@@ -124,38 +134,6 @@ def set_maximum_resolution(max_tile_size, current_value):
 def select_tile_weighting_method(tile_weighting_method):
     return gr.update(visible=True if tile_weighting_method=="Gaussian" else False)
 
-@spaces.GPU(duration=120)
-def run_for_examples(image,
-    prompt,
-    negative_prompt,
-    resolution,
-    hdr,
-    num_inference_steps,
-    denoising_strenght,
-    controlnet_strength,
-    tile_gaussian_sigma,
-    scheduler,
-    guidance_scale,
-    max_tile_size,
-    tile_weighting_method):
-
-    predict(
-        model.value,
-        image,
-        prompt,
-        negative_prompt,
-        resolution,
-        hdr,
-        num_inference_steps,
-        denoising_strenght,
-        controlnet_strength,
-        tile_gaussian_sigma,
-        scheduler,
-        guidance_scale,
-        max_tile_size,
-        tile_weighting_method)
-
-
 # endregion
 
 css = """
@@ -174,7 +152,7 @@ body {
     text-shadow: 1px 1px 2px rgba(0, 0, 0, 0.2);
 }
 .fillable {
-    width:
+    width: 100% !important;
     max-width: unset !important;
 }
 #examples_container {
@@ -279,7 +257,7 @@ with gr.Blocks(css=css, theme=gr.themes.Ocean(), title="MoD ControlNet Tile Upsc
                 with gr.Row(elem_id="parameters_row"):
                     gr.Markdown("### General parameters")
                     model = gr.Dropdown(
-                        label="Model", choices=
+                        label="Model", choices=MODELS.keys(), value=list(MODELS.keys())[0]
                     )
                     tile_weighting_method = gr.Dropdown(
                         label="Tile Weighting Method", choices=["Cosine", "Gaussian"], value="Cosine"
@@ -303,9 +281,10 @@ with gr.Blocks(css=css, theme=gr.themes.Ocean(), title="MoD ControlNet Tile Upsc
             with gr.Accordion(label="Example Images", open=True):
                 with gr.Row(elem_id="examples_row"):
                     with gr.Column(scale=12, elem_id="examples_container"):
-                        gr.Examples(
+                        eg = gr.Examples(
                             examples=[
                                 [ "./examples/1.jpg",
+                                  "RealVisXL 5 Lightning",
                                   prompt.value,
                                   negative_prompt.value,
                                   4096,
@@ -320,6 +299,7 @@ with gr.Blocks(css=css, theme=gr.themes.Ocean(), title="MoD ControlNet Tile Upsc
                                   "Cosine"
                                 ],
                                 [ "./examples/1.jpg",
+                                  "RealVisXL 5",
                                   prompt.value,
                                   negative_prompt.value,
                                   4096,
@@ -334,6 +314,7 @@ with gr.Blocks(css=css, theme=gr.themes.Ocean(), title="MoD ControlNet Tile Upsc
                                   "Cosine"
                                 ],
                                 [ "./examples/2.jpg",
+                                  "RealVisXL 5 Lightning",
                                   prompt.value,
                                   negative_prompt.value,
                                   4096,
@@ -348,6 +329,7 @@ with gr.Blocks(css=css, theme=gr.themes.Ocean(), title="MoD ControlNet Tile Upsc
                                   "Cosine"
                                 ],
                                 [ "./examples/2.jpg",
+                                  "RealVisXL 5",
                                   prompt.value,
                                   negative_prompt.value,
                                   4096,
@@ -362,6 +344,7 @@ with gr.Blocks(css=css, theme=gr.themes.Ocean(), title="MoD ControlNet Tile Upsc
                                   "Cosine"
                                 ],
                                 [ "./examples/3.jpg",
+                                  "RealVisXL 5 Lightning",
                                   prompt.value,
                                   negative_prompt.value,
                                   5120,
@@ -376,6 +359,7 @@ with gr.Blocks(css=css, theme=gr.themes.Ocean(), title="MoD ControlNet Tile Upsc
                                   "Gaussian"
                                 ],
                                 [ "./examples/3.jpg",
+                                  "RealVisXL 5",
                                   prompt.value,
                                   negative_prompt.value,
                                   5120,
@@ -390,6 +374,7 @@ with gr.Blocks(css=css, theme=gr.themes.Ocean(), title="MoD ControlNet Tile Upsc
                                   "Gaussian"
                                 ],
                                 [ "./examples/4.jpg",
+                                  "RealVisXL 5 Lightning",
                                   prompt.value,
                                   negative_prompt.value,
                                   8192,
@@ -404,6 +389,7 @@ with gr.Blocks(css=css, theme=gr.themes.Ocean(), title="MoD ControlNet Tile Upsc
                                   "Gaussian"
                                 ],
                                 [ "./examples/4.jpg",
+                                  "RealVisXL 5",
                                   prompt.value,
                                   negative_prompt.value,
                                   8192,
@@ -418,6 +404,7 @@ with gr.Blocks(css=css, theme=gr.themes.Ocean(), title="MoD ControlNet Tile Upsc
                                   "Gaussian"
                                 ],
                                 [ "./examples/5.jpg",
+                                  "RealVisXL 5 Lightning",
                                   prompt.value,
                                   negative_prompt.value,
                                   8192,
@@ -432,6 +419,7 @@ with gr.Blocks(css=css, theme=gr.themes.Ocean(), title="MoD ControlNet Tile Upsc
                                   "Cosine"
                                 ],
                                 [ "./examples/5.jpg",
+                                  "RealVisXL 5",
                                   prompt.value,
                                   negative_prompt.value,
                                   8192,
@@ -448,6 +436,7 @@ with gr.Blocks(css=css, theme=gr.themes.Ocean(), title="MoD ControlNet Tile Upsc
                             ],
                             inputs=[
                                 input_image,
+                                model,
                                 prompt,
                                 negative_prompt,
                                 resolution,
@@ -461,13 +450,13 @@ with gr.Blocks(css=css, theme=gr.themes.Ocean(), title="MoD ControlNet Tile Upsc
                                 max_tile_size,
                                 tile_weighting_method,
                             ],
-                            fn=
+                            fn=predict,
                             outputs=result,
                             cache_examples=False,
                         )
 
         max_tile_size.select(fn=set_maximum_resolution, inputs=[max_tile_size, resolution], outputs=resolution)
-        tile_weighting_method.
+        tile_weighting_method.change(fn=select_tile_weighting_method, inputs=tile_weighting_method, outputs=tile_gaussian_sigma)
         generate_button.click(
             fn=clear_result,
             inputs=None,
@@ -475,8 +464,8 @@ with gr.Blocks(css=css, theme=gr.themes.Ocean(), title="MoD ControlNet Tile Upsc
         ).then(
             fn=predict,
             inputs=[
-                model,
                 input_image,
+                model,
                 prompt,
                 negative_prompt,
                 resolution,
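Note: with `model` now a regular input of `predict`, the `run_for_examples` wrapper is no longer needed; `gr.Examples` passes each example row (including the model name) straight to `predict` via `fn=predict`. A small, self-contained sketch of this wiring pattern follows; the components and the placeholder `predict` body are hypothetical, not the Space's actual UI.

```python
# Minimal sketch of the Examples-to-function wiring pattern used above
# (hypothetical demo, not the Space's actual UI).
import gradio as gr

MODELS = {"RealVisXL 5 Lightning": "SG161222/RealVisXL_V5.0_Lightning",
          "RealVisXL 5": "SG161222/RealVisXL_V5.0"}

def predict(image_path, model_name, prompt):
    # Placeholder: the real app runs the tiled SDXL upscaling pipeline here.
    return f"{model_name} ({MODELS[model_name]}): '{prompt}' on {image_path}"

with gr.Blocks() as demo:
    image = gr.Textbox(label="Image path")
    model = gr.Dropdown(label="Model", choices=list(MODELS.keys()),
                        value=list(MODELS.keys())[0])
    prompt = gr.Textbox(label="Prompt")
    result = gr.Textbox(label="Result")
    gr.Examples(
        examples=[["./examples/1.jpg", "RealVisXL 5 Lightning", "high detail"]],
        inputs=[image, model, prompt],  # example rows map one-to-one onto predict's inputs
        fn=predict,                     # same callable as the main button; no wrapper needed
        outputs=result,
        cache_examples=False,           # do not pre-compute example outputs at startup
    )

if __name__ == "__main__":
    demo.launch()
```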
pipeline/util.py CHANGED

@@ -16,6 +16,8 @@
 import gc
 import cv2
 import numpy as np
+from torch import nn
+from accelerate.hooks import AlignDevicesHook, CpuOffload, remove_hook_from_module
 import torch
 from PIL import Image
 
@@ -96,6 +98,32 @@ def select_scheduler(pipe, selected_sampler):
 
     return scheduler.from_config(config, **add_kwargs)
 
+def optionally_disable_offloading(_pipeline):
+    """
+    Optionally removes offloading in case the pipeline has been already sequentially offloaded to CPU.
+
+    Args:
+        _pipeline (`DiffusionPipeline`):
+            The pipeline to disable offloading for.
+
+    Returns:
+        tuple:
+            A tuple indicating if `is_model_cpu_offload` or `is_sequential_cpu_offload` is True.
+    """
+    is_model_cpu_offload = False
+    is_sequential_cpu_offload = False
+    if _pipeline is not None:
+        for _, component in _pipeline.components.items():
+            if isinstance(component, nn.Module) and hasattr(component, "_hf_hook"):
+                if not is_model_cpu_offload:
+                    is_model_cpu_offload = isinstance(component._hf_hook, CpuOffload)
+                if not is_sequential_cpu_offload:
+                    is_sequential_cpu_offload = isinstance(component._hf_hook, AlignDevicesHook)
+
+
+                remove_hook_from_module(component, recurse=True)
+
+    return (is_model_cpu_offload, is_sequential_cpu_offload)
 
 # This function was copied and adapted from https://huggingface.co/spaces/gokaygokay/TileUpscalerV2, licensed under Apache 2.0.
 def progressive_upscale(input_image, target_resolution, steps=3):
@@ -185,7 +213,7 @@ def torch_gc():
     if torch.cuda.is_available():
         with torch.cuda.device("cuda"):
             torch.cuda.empty_cache()
-            torch.cuda.ipc_collect()
+            #torch.cuda.ipc_collect()
 
     gc.collect()
 
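For context, a hypothetical usage sketch of the new `optionally_disable_offloading` helper: before swapping in a different checkpoint, app.py tears down any accelerate offloading hooks on the old pipeline and releases cached GPU memory via `torch_gc`. The snippet below illustrates the same pattern in isolation; everything other than the two pipeline/util.py helpers is illustrative.

```python
# Hypothetical reload pattern built on the helpers in pipeline/util.py;
# the specific pipeline class and checkpoints here are illustrative.
import torch
from diffusers import DiffusionPipeline

from pipeline.util import optionally_disable_offloading, torch_gc

pipe = DiffusionPipeline.from_pretrained(
    "SG161222/RealVisXL_V5.0_Lightning", torch_dtype=torch.float16
)
pipe.enable_model_cpu_offload()

# ... later, when switching to another checkpoint:
was_model_offload, was_sequential_offload = optionally_disable_offloading(pipe)
torch_gc()  # empty the CUDA cache and run the garbage collector

pipe = DiffusionPipeline.from_pretrained(
    "SG161222/RealVisXL_V5.0", torch_dtype=torch.float16
)
if was_model_offload:
    pipe.enable_model_cpu_offload()        # restore the previous offloading mode
elif was_sequential_offload:
    pipe.enable_sequential_cpu_offload()
```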