Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
|
@@ -1,8 +1,9 @@
|
|
| 1 |
import torch
|
| 2 |
import gradio as gr
|
| 3 |
-
from diffusers import DiffusionPipeline,
|
| 4 |
from diffusers.utils import load_image
|
| 5 |
import os
|
|
|
|
| 6 |
from PIL import Image
|
| 7 |
import time
|
| 8 |
|
|
@@ -21,23 +22,20 @@ lora_usage = {lora["title"]: 0 for lora in loras}
|
|
| 21 |
device = "cpu"
|
| 22 |
dtype = torch.float32 # Use float32 for CPU compatibility
|
| 23 |
|
| 24 |
-
# Initialize
|
| 25 |
base_model = "black-forest-labs/FLUX.1-dev"
|
| 26 |
taef1 = AutoencoderTiny.from_pretrained("madebyollin/taef1", torch_dtype=dtype)
|
| 27 |
good_vae = AutoencoderKL.from_pretrained("black-forest-labs/FLUX.1-dev", subfolder="vae", torch_dtype=dtype)
|
| 28 |
|
| 29 |
-
pipe = DiffusionPipeline.from_pretrained(
|
| 30 |
-
pipe_i2i = AutoPipelineForImage2Image.from_pretrained(
|
| 31 |
base_model,
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
text_encoder=pipe.text_encoder,
|
| 35 |
-
tokenizer=pipe.tokenizer,
|
| 36 |
-
text_encoder_2=pipe.text_encoder_2,
|
| 37 |
-
tokenizer_2=pipe.tokenizer_2,
|
| 38 |
-
torch_dtype=dtype
|
| 39 |
)
|
| 40 |
|
|
|
|
|
|
|
|
|
|
| 41 |
# Custom CSS
|
| 42 |
css = """
|
| 43 |
#title {
|
|
@@ -84,39 +82,50 @@ def update_lora_info(selected_index, custom_lora):
|
|
| 84 |
def remove_custom_lora(selected_index):
|
| 85 |
return None, gr.HTML(visible=False), gr.Button(visible=False), gr.Markdown(value=update_lora_info(selected_index, None)[0])
|
| 86 |
|
| 87 |
-
# Image generation
|
| 88 |
-
def generate_image(
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
|
| 92 |
-
|
| 93 |
-
|
| 94 |
-
|
| 95 |
-
|
| 96 |
-
|
| 97 |
-
|
| 98 |
-
|
| 99 |
-
output_type="pil",
|
| 100 |
-
good_vae=good_vae,
|
| 101 |
-
):
|
| 102 |
-
yield img
|
| 103 |
-
|
| 104 |
-
def generate_image_to_image(prompt_mash, image_input_path, image_strength, steps, cfg_scale, width, height, lora_scale, seed):
|
| 105 |
generator = torch.Generator(device=device).manual_seed(seed)
|
| 106 |
-
|
| 107 |
-
|
| 108 |
-
|
| 109 |
-
|
| 110 |
-
|
| 111 |
-
|
| 112 |
-
|
| 113 |
-
|
| 114 |
-
|
| 115 |
-
|
| 116 |
-
|
| 117 |
-
|
| 118 |
-
|
| 119 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 120 |
|
| 121 |
def run_lora(prompt, image_input, image_strength, cfg_scale, steps, selected_index, randomize_seed, seed, width, height, lora_scale):
|
| 122 |
global lora_usage
|
|
@@ -129,19 +138,27 @@ def run_lora(prompt, image_input, image_strength, cfg_scale, steps, selected_ind
|
|
| 129 |
lora_usage[selected_lora["title"]] += 1
|
| 130 |
pipe.unload_lora_weights()
|
| 131 |
pipe.load_lora_weights(lora_path)
|
| 132 |
-
pipe_i2i.load_lora_weights(lora_path)
|
| 133 |
if prompt == "":
|
| 134 |
prompt = trigger_word
|
| 135 |
else:
|
| 136 |
prompt_mash = f"{prompt}, {trigger_word}"
|
| 137 |
if randomize_seed:
|
| 138 |
seed = int(time.time())
|
| 139 |
-
|
| 140 |
-
|
| 141 |
-
|
| 142 |
-
|
| 143 |
-
|
| 144 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 145 |
|
| 146 |
def generate_usage_chart():
|
| 147 |
sorted_usage = sorted(lora_usage.items(), key=lambda x: x[1], reverse=True)[:5]
|
|
@@ -239,11 +256,11 @@ with gr.Blocks(theme="YTheme/Minecraft", css=css, delete_cache=(60, 60)) as app:
|
|
| 239 |
refresh_chart_button = gr.Button("Refresh Usage Chart")
|
| 240 |
with gr.Accordion("Advanced Settings", open=False):
|
| 241 |
with gr.Row():
|
| 242 |
-
steps = gr.Slider(label="Steps", minimum=1, maximum=50, value=
|
| 243 |
cfg_scale = gr.Slider(label="CFG Scale", minimum=1, maximum=20, value=7, step=0.1)
|
| 244 |
with gr.Row():
|
| 245 |
-
width = gr.Slider(label="Width", minimum=256, maximum=1024, value=
|
| 246 |
-
height = gr.Slider(label="Height", minimum=256, maximum=1024, value=
|
| 247 |
with gr.Row():
|
| 248 |
lora_scale = gr.Slider(label="LoRA Scale", minimum=0, maximum=1, value=0.8, step=0.1)
|
| 249 |
image_strength = gr.Slider(label="Image Strength", minimum=0, maximum=1, value=0.5, step=0.1, visible=False)
|
|
|
|
| 1 |
import torch
|
| 2 |
import gradio as gr
|
| 3 |
+
from diffusers import DiffusionPipeline, AutoencoderTiny, AutoencoderKL
|
| 4 |
from diffusers.utils import load_image
|
| 5 |
import os
|
| 6 |
+
import gc
|
| 7 |
from PIL import Image
|
| 8 |
import time
|
| 9 |
|
|
|
|
| 22 |
device = "cpu"
|
| 23 |
dtype = torch.float32 # Use float32 for CPU compatibility
|
| 24 |
|
| 25 |
+
# Initialize a single pipeline with CPU offloading
|
| 26 |
base_model = "black-forest-labs/FLUX.1-dev"
|
| 27 |
taef1 = AutoencoderTiny.from_pretrained("madebyollin/taef1", torch_dtype=dtype)
|
| 28 |
good_vae = AutoencoderKL.from_pretrained("black-forest-labs/FLUX.1-dev", subfolder="vae", torch_dtype=dtype)
|
| 29 |
|
| 30 |
+
pipe = DiffusionPipeline.from_pretrained(
|
|
|
|
| 31 |
base_model,
|
| 32 |
+
torch_dtype=dtype,
|
| 33 |
+
vae=taef1,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 34 |
)
|
| 35 |
|
| 36 |
+
# Enable CPU offloading to reduce memory usage
|
| 37 |
+
pipe.enable_model_cpu_offload()
|
| 38 |
+
|
| 39 |
# Custom CSS
|
| 40 |
css = """
|
| 41 |
#title {
|
|
|
|
| 82 |
def remove_custom_lora(selected_index):
|
| 83 |
return None, gr.HTML(visible=False), gr.Button(visible=False), gr.Markdown(value=update_lora_info(selected_index, None)[0])
|
| 84 |
|
| 85 |
+
# Image generation function (combined for both text-to-image and image-to-image)
|
| 86 |
+
def generate_image(
|
| 87 |
+
prompt_mash,
|
| 88 |
+
image_input_path,
|
| 89 |
+
image_strength,
|
| 90 |
+
steps,
|
| 91 |
+
seed,
|
| 92 |
+
cfg_scale,
|
| 93 |
+
width,
|
| 94 |
+
height,
|
| 95 |
+
lora_scale
|
| 96 |
+
):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 97 |
generator = torch.Generator(device=device).manual_seed(seed)
|
| 98 |
+
|
| 99 |
+
# Configure pipeline for text-to-image or image-to-image
|
| 100 |
+
kwargs = {
|
| 101 |
+
"prompt": prompt_mash,
|
| 102 |
+
"num_inference_steps": steps,
|
| 103 |
+
"guidance_scale": cfg_scale,
|
| 104 |
+
"width": width,
|
| 105 |
+
"height": height,
|
| 106 |
+
"generator": generator,
|
| 107 |
+
"joint_attention_kwargs": {"scale": lora_scale},
|
| 108 |
+
"output_type": "pil",
|
| 109 |
+
"good_vae": good_vae,
|
| 110 |
+
}
|
| 111 |
+
|
| 112 |
+
if image_input_path:
|
| 113 |
+
image_input = load_image(image_input_path)
|
| 114 |
+
kwargs.update({
|
| 115 |
+
"image": image_input,
|
| 116 |
+
"strength": image_strength,
|
| 117 |
+
})
|
| 118 |
+
with calculateDuration("Generating image-to-image"):
|
| 119 |
+
result = pipe(**kwargs).images[0]
|
| 120 |
+
else:
|
| 121 |
+
with calculateDuration("Generating text-to-image"):
|
| 122 |
+
result = pipe(**kwargs).images[0]
|
| 123 |
+
|
| 124 |
+
# Clear memory after generation
|
| 125 |
+
torch.cuda.empty_cache() # No effect on CPU, but harmless
|
| 126 |
+
gc.collect()
|
| 127 |
+
|
| 128 |
+
return result
|
| 129 |
|
| 130 |
def run_lora(prompt, image_input, image_strength, cfg_scale, steps, selected_index, randomize_seed, seed, width, height, lora_scale):
|
| 131 |
global lora_usage
|
|
|
|
| 138 |
lora_usage[selected_lora["title"]] += 1
|
| 139 |
pipe.unload_lora_weights()
|
| 140 |
pipe.load_lora_weights(lora_path)
|
|
|
|
| 141 |
if prompt == "":
|
| 142 |
prompt = trigger_word
|
| 143 |
else:
|
| 144 |
prompt_mash = f"{prompt}, {trigger_word}"
|
| 145 |
if randomize_seed:
|
| 146 |
seed = int(time.time())
|
| 147 |
+
|
| 148 |
+
# Generate the image
|
| 149 |
+
final_image = generate_image(
|
| 150 |
+
prompt_mash,
|
| 151 |
+
image_input,
|
| 152 |
+
image_strength,
|
| 153 |
+
steps,
|
| 154 |
+
seed,
|
| 155 |
+
cfg_scale,
|
| 156 |
+
width,
|
| 157 |
+
height,
|
| 158 |
+
lora_scale
|
| 159 |
+
)
|
| 160 |
+
|
| 161 |
+
return final_image, seed, gr.Markdown(value=f"**Seed**: {seed}", visible=True)
|
| 162 |
|
| 163 |
def generate_usage_chart():
|
| 164 |
sorted_usage = sorted(lora_usage.items(), key=lambda x: x[1], reverse=True)[:5]
|
|
|
|
| 256 |
refresh_chart_button = gr.Button("Refresh Usage Chart")
|
| 257 |
with gr.Accordion("Advanced Settings", open=False):
|
| 258 |
with gr.Row():
|
| 259 |
+
steps = gr.Slider(label="Steps", minimum=1, maximum=50, value=10, step=1) # Reduced default steps
|
| 260 |
cfg_scale = gr.Slider(label="CFG Scale", minimum=1, maximum=20, value=7, step=0.1)
|
| 261 |
with gr.Row():
|
| 262 |
+
width = gr.Slider(label="Width", minimum=256, maximum=1024, value=256, step=64) # Reduced default resolution
|
| 263 |
+
height = gr.Slider(label="Height", minimum=256, maximum=1024, value=256, step=64) # Reduced default resolution
|
| 264 |
with gr.Row():
|
| 265 |
lora_scale = gr.Slider(label="LoRA Scale", minimum=0, maximum=1, value=0.8, step=0.1)
|
| 266 |
image_strength = gr.Slider(label="Image Strength", minimum=0, maximum=1, value=0.5, step=0.1, visible=False)
|