HIRO12121212 committed
Commit 781d14b · verified · 1 Parent(s): 697fc7a

Update app.py

Files changed (1)
app.py +70 -53
app.py CHANGED
@@ -1,8 +1,9 @@
 import torch
 import gradio as gr
-from diffusers import DiffusionPipeline, AutoPipelineForImage2Image, AutoencoderTiny, AutoencoderKL
+from diffusers import DiffusionPipeline, AutoencoderTiny, AutoencoderKL
 from diffusers.utils import load_image
 import os
+import gc
 from PIL import Image
 import time
 
@@ -21,23 +22,20 @@ lora_usage = {lora["title"]: 0 for lora in loras}
 device = "cpu"
 dtype = torch.float32 # Use float32 for CPU compatibility
 
-# Initialize pipelines
+# Initialize a single pipeline with CPU offloading
 base_model = "black-forest-labs/FLUX.1-dev"
 taef1 = AutoencoderTiny.from_pretrained("madebyollin/taef1", torch_dtype=dtype)
 good_vae = AutoencoderKL.from_pretrained("black-forest-labs/FLUX.1-dev", subfolder="vae", torch_dtype=dtype)
 
-pipe = DiffusionPipeline.from_pretrained(base_model, torch_dtype=dtype, vae=taef1)
-pipe_i2i = AutoPipelineForImage2Image.from_pretrained(
+pipe = DiffusionPipeline.from_pretrained(
     base_model,
-    vae=good_vae,
-    transformer=pipe.transformer,
-    text_encoder=pipe.text_encoder,
-    tokenizer=pipe.tokenizer,
-    text_encoder_2=pipe.text_encoder_2,
-    tokenizer_2=pipe.tokenizer_2,
-    torch_dtype=dtype
+    torch_dtype=dtype,
+    vae=taef1,
 )
 
+# Enable CPU offloading to reduce memory usage
+pipe.enable_model_cpu_offload()
+
 # Custom CSS
 css = """
 #title {
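As an aside, the offloading pattern this hunk introduces can be reduced to a few lines. This is a minimal sketch, assuming diffusers and accelerate are installed; enable_model_cpu_offload() needs an accelerator to offload to, so CPU-only hosts fall back to a plain .to("cpu"):

import torch
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    torch_dtype=torch.float32,
)

if torch.cuda.is_available():
    # Keep each sub-model (text encoders, transformer, VAE) in RAM and move it
    # to the GPU only for its forward pass, trading speed for peak GPU memory.
    pipe.enable_model_cpu_offload()
else:
    # No accelerator to offload to; run everything on the CPU.
    pipe = pipe.to("cpu")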
@@ -84,39 +82,50 @@ def update_lora_info(selected_index, custom_lora):
 def remove_custom_lora(selected_index):
     return None, gr.HTML(visible=False), gr.Button(visible=False), gr.Markdown(value=update_lora_info(selected_index, None)[0])
 
-# Image generation functions
-def generate_image(prompt_mash, steps, seed, cfg_scale, width, height, lora_scale):
-    generator = torch.Generator(device=device).manual_seed(seed)
-    with calculateDuration("Generating image"):
-        for img in pipe.flux_pipe_call_that_returns_an_iterable_of_images(
-            prompt=prompt_mash,
-            num_inference_steps=steps,
-            guidance_scale=cfg_scale,
-            width=width,
-            height=height,
-            generator=generator,
-            joint_attention_kwargs={"scale": lora_scale},
-            output_type="pil",
-            good_vae=good_vae,
-        ):
-            yield img
-
-def generate_image_to_image(prompt_mash, image_input_path, image_strength, steps, cfg_scale, width, height, lora_scale, seed):
+# Image generation function (combined for both text-to-image and image-to-image)
+def generate_image(
+    prompt_mash,
+    image_input_path,
+    image_strength,
+    steps,
+    seed,
+    cfg_scale,
+    width,
+    height,
+    lora_scale
+):
     generator = torch.Generator(device=device).manual_seed(seed)
-    image_input = load_image(image_input_path)
-    final_image = pipe_i2i(
-        prompt=prompt_mash,
-        image=image_input,
-        strength=image_strength,
-        num_inference_steps=steps,
-        guidance_scale=cfg_scale,
-        width=width,
-        height=height,
-        generator=generator,
-        joint_attention_kwargs={"scale": lora_scale},
-        output_type="pil",
-    ).images[0]
-    return final_image
+
+    # Configure pipeline for text-to-image or image-to-image
+    kwargs = {
+        "prompt": prompt_mash,
+        "num_inference_steps": steps,
+        "guidance_scale": cfg_scale,
+        "width": width,
+        "height": height,
+        "generator": generator,
+        "joint_attention_kwargs": {"scale": lora_scale},
+        "output_type": "pil",
+        # good_vae is omitted: it is only accepted by the custom flux_pipe_call helper, not the standard pipeline call
+    }
+
+    if image_input_path:
+        image_input = load_image(image_input_path)
+        kwargs.update({  # assumes the loaded pipeline accepts image/strength (an img2img-capable pipeline)
+            "image": image_input,
+            "strength": image_strength,
+        })
+        with calculateDuration("Generating image-to-image"):
+            result = pipe(**kwargs).images[0]
+    else:
+        with calculateDuration("Generating text-to-image"):
+            result = pipe(**kwargs).images[0]
+
+    # Clear memory after generation
+    torch.cuda.empty_cache()  # No effect on CPU, but harmless
+    gc.collect()
+
+    return result
 
 def run_lora(prompt, image_input, image_strength, cfg_scale, steps, selected_index, randomize_seed, seed, width, height, lora_scale):
     global lora_usage
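One caveat on the combined function: routing image-to-image through a plain pipe(**kwargs) call only works if the loaded pipeline accepts image and strength. When it does not, recent diffusers versions can derive an img2img pipeline that shares the already-loaded weights instead of duplicating them. A sketch, assuming pipe is the pipeline built above and input.png is a hypothetical local file:

from diffusers import AutoPipelineForImage2Image
from diffusers.utils import load_image

# Reuses pipe's components, so no second copy of the model is loaded.
pipe_i2i = AutoPipelineForImage2Image.from_pipe(pipe)

result = pipe_i2i(
    prompt="a watercolor landscape",  # example prompt
    image=load_image("input.png"),
    strength=0.5,
    num_inference_steps=10,
).images[0]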
@@ -129,19 +138,27 @@ def run_lora(prompt, image_input, image_strength, cfg_scale, steps, selected_index, randomize_seed, seed, width, height, lora_scale):
         lora_usage[selected_lora["title"]] += 1
         pipe.unload_lora_weights()
         pipe.load_lora_weights(lora_path)
-        pipe_i2i.load_lora_weights(lora_path)
     if prompt == "":
         prompt_mash = trigger_word
     else:
         prompt_mash = f"{prompt}, {trigger_word}"
     if randomize_seed:
         seed = int(time.time())
-    if image_input:
-        final_image = generate_image_to_image(prompt_mash, image_input, image_strength, steps, cfg_scale, width, height, lora_scale, seed)
-        return final_image, seed, gr.Markdown(value=f"**Seed**: {seed}", visible=True)
-    else:
-        for img in generate_image(prompt_mash, steps, seed, cfg_scale, width, height, lora_scale):
-            yield img, seed, gr.Markdown(value=f"**Seed**: {seed}", visible=True)
+
+    # Generate the image
+    final_image = generate_image(
+        prompt_mash,
+        image_input,
+        image_strength,
+        steps,
+        seed,
+        cfg_scale,
+        width,
+        height,
+        lora_scale
+    )
+
+    return final_image, seed, gr.Markdown(value=f"**Seed**: {seed}", visible=True)
 
 def generate_usage_chart():
     sorted_usage = sorted(lora_usage.items(), key=lambda x: x[1], reverse=True)[:5]
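The LoRA handling above is the standard diffusers hot-swap pattern, now applied to a single pipeline instead of two. A minimal sketch, assuming pipe from above and a placeholder adapter id:

# Swap adapters between requests: drop whatever is attached, then load the new one.
pipe.unload_lora_weights()
pipe.load_lora_weights("some-user/some-flux-lora")  # placeholder repo id

image = pipe(
    prompt="a portrait, trigger_word",      # trigger word depends on the adapter
    num_inference_steps=10,
    joint_attention_kwargs={"scale": 0.8},  # LoRA strength for FLUX pipelines
).images[0]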
@@ -239,11 +256,11 @@ with gr.Blocks(theme="YTheme/Minecraft", css=css, delete_cache=(60, 60)) as app:
     refresh_chart_button = gr.Button("Refresh Usage Chart")
     with gr.Accordion("Advanced Settings", open=False):
         with gr.Row():
-            steps = gr.Slider(label="Steps", minimum=1, maximum=50, value=25, step=1)
+            steps = gr.Slider(label="Steps", minimum=1, maximum=50, value=10, step=1)  # Reduced default steps
             cfg_scale = gr.Slider(label="CFG Scale", minimum=1, maximum=20, value=7, step=0.1)
         with gr.Row():
-            width = gr.Slider(label="Width", minimum=256, maximum=1024, value=512, step=64)
-            height = gr.Slider(label="Height", minimum=256, maximum=1024, value=512, step=64)
+            width = gr.Slider(label="Width", minimum=256, maximum=1024, value=256, step=64)  # Reduced default resolution
+            height = gr.Slider(label="Height", minimum=256, maximum=1024, value=256, step=64)  # Reduced default resolution
         with gr.Row():
             lora_scale = gr.Slider(label="LoRA Scale", minimum=0, maximum=1, value=0.8, step=0.1)
             image_strength = gr.Slider(label="Image Strength", minimum=0, maximum=1, value=0.5, step=0.1, visible=False)
 
1
  import torch
2
  import gradio as gr
3
+ from diffusers import DiffusionPipeline, AutoencoderTiny, AutoencoderKL
4
  from diffusers.utils import load_image
5
  import os
6
+ import gc
7
  from PIL import Image
8
  import time
9
 
 
22
  device = "cpu"
23
  dtype = torch.float32 # Use float32 for CPU compatibility
24
 
25
+ # Initialize a single pipeline with CPU offloading
26
  base_model = "black-forest-labs/FLUX.1-dev"
27
  taef1 = AutoencoderTiny.from_pretrained("madebyollin/taef1", torch_dtype=dtype)
28
  good_vae = AutoencoderKL.from_pretrained("black-forest-labs/FLUX.1-dev", subfolder="vae", torch_dtype=dtype)
29
 
30
+ pipe = DiffusionPipeline.from_pretrained(
 
31
  base_model,
32
+ torch_dtype=dtype,
33
+ vae=taef1,
 
 
 
 
 
34
  )
35
 
36
+ # Enable CPU offloading to reduce memory usage
37
+ pipe.enable_model_cpu_offload()
38
+
39
  # Custom CSS
40
  css = """
41
  #title {
 
82
  def remove_custom_lora(selected_index):
83
  return None, gr.HTML(visible=False), gr.Button(visible=False), gr.Markdown(value=update_lora_info(selected_index, None)[0])
84
 
85
+ # Image generation function (combined for both text-to-image and image-to-image)
86
+ def generate_image(
87
+ prompt_mash,
88
+ image_input_path,
89
+ image_strength,
90
+ steps,
91
+ seed,
92
+ cfg_scale,
93
+ width,
94
+ height,
95
+ lora_scale
96
+ ):
 
 
 
 
 
 
97
  generator = torch.Generator(device=device).manual_seed(seed)
98
+
99
+ # Configure pipeline for text-to-image or image-to-image
100
+ kwargs = {
101
+ "prompt": prompt_mash,
102
+ "num_inference_steps": steps,
103
+ "guidance_scale": cfg_scale,
104
+ "width": width,
105
+ "height": height,
106
+ "generator": generator,
107
+ "joint_attention_kwargs": {"scale": lora_scale},
108
+ "output_type": "pil",
109
+ "good_vae": good_vae,
110
+ }
111
+
112
+ if image_input_path:
113
+ image_input = load_image(image_input_path)
114
+ kwargs.update({
115
+ "image": image_input,
116
+ "strength": image_strength,
117
+ })
118
+ with calculateDuration("Generating image-to-image"):
119
+ result = pipe(**kwargs).images[0]
120
+ else:
121
+ with calculateDuration("Generating text-to-image"):
122
+ result = pipe(**kwargs).images[0]
123
+
124
+ # Clear memory after generation
125
+ torch.cuda.empty_cache() # No effect on CPU, but harmless
126
+ gc.collect()
127
+
128
+ return result
129
 
130
  def run_lora(prompt, image_input, image_strength, cfg_scale, steps, selected_index, randomize_seed, seed, width, height, lora_scale):
131
  global lora_usage
 
138
  lora_usage[selected_lora["title"]] += 1
139
  pipe.unload_lora_weights()
140
  pipe.load_lora_weights(lora_path)
 
141
  if prompt == "":
142
  prompt = trigger_word
143
  else:
144
  prompt_mash = f"{prompt}, {trigger_word}"
145
  if randomize_seed:
146
  seed = int(time.time())
147
+
148
+ # Generate the image
149
+ final_image = generate_image(
150
+ prompt_mash,
151
+ image_input,
152
+ image_strength,
153
+ steps,
154
+ seed,
155
+ cfg_scale,
156
+ width,
157
+ height,
158
+ lora_scale
159
+ )
160
+
161
+ return final_image, seed, gr.Markdown(value=f"**Seed**: {seed}", visible=True)
162
 
163
  def generate_usage_chart():
164
  sorted_usage = sorted(lora_usage.items(), key=lambda x: x[1], reverse=True)[:5]
 
256
  refresh_chart_button = gr.Button("Refresh Usage Chart")
257
  with gr.Accordion("Advanced Settings", open=False):
258
  with gr.Row():
259
+ steps = gr.Slider(label="Steps", minimum=1, maximum=50, value=10, step=1) # Reduced default steps
260
  cfg_scale = gr.Slider(label="CFG Scale", minimum=1, maximum=20, value=7, step=0.1)
261
  with gr.Row():
262
+ width = gr.Slider(label="Width", minimum=256, maximum=1024, value=256, step=64) # Reduced default resolution
263
+ height = gr.Slider(label="Height", minimum=256, maximum=1024, value=256, step=64) # Reduced default resolution
264
  with gr.Row():
265
  lora_scale = gr.Slider(label="LoRA Scale", minimum=0, maximum=1, value=0.8, step=0.1)
266
  image_strength = gr.Slider(label="Image Strength", minimum=0, maximum=1, value=0.5, step=0.1, visible=False)
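Finally, the cleanup step in generate_image is safe as written, since torch.cuda.empty_cache() is a no-op when CUDA was never initialized, but a small guarded helper makes the intent explicit. A sketch:

import gc
import torch

def free_memory():
    # Reclaim Python-level garbage first so dangling tensors become collectable.
    gc.collect()
    if torch.cuda.is_available():
        # Return cached allocator blocks to the driver; pointless but harmless on CPU.
        torch.cuda.empty_cache()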