# FLUX-LoRA-DLC / app.py
# Last commit: "Update app.py" (781d14b, verified) by HIRO12121212
import gc
import os
import time
from contextlib import contextmanager

import gradio as gr
import torch
from diffusers import DiffusionPipeline, AutoencoderTiny, AutoencoderKL
from diffusers.utils import load_image
from PIL import Image
# Catalogue of selectable LoRAs. Each entry has a display title, the Hugging
# Face repo id passed to `pipe.load_lora_weights`, the trigger word that is
# appended to the user's prompt, and a preview image URL for the gallery.
loras = [
    {
        "title": "Anime",
        "repo": "prithivMLmods/Canopus-LoRA-Flux-Anime",
        "trigger_word": "Anime style",
        "image": "https://huggingface.co/prithivMLmods/Canopus-LoRA-Flux-Anime/resolve/main/1.jpg",
    },
    {
        "title": "PixelArt",
        "repo": "prithivMLmods/Canopus-LoRA-Flux-PixelArt",
        "trigger_word": "PixelArt style",
        "image": "https://huggingface.co/prithivMLmods/Canopus-LoRA-Flux-PixelArt/resolve/main/1.jpg",
    },
    {
        "title": "Ghibli",
        "repo": "prithivMLmods/Canopus-LoRA-Flux-Ghibli",
        "trigger_word": "Ghibli style",
        "image": "https://huggingface.co/prithivMLmods/Canopus-LoRA-Flux-Ghibli/resolve/main/1.jpg",
    },
    {
        "title": "Realistic",
        "repo": "prithivMLmods/Canopus-LoRA-Flux-Realistic",
        "trigger_word": "Realistic style",
        "image": "https://huggingface.co/prithivMLmods/Canopus-LoRA-Flux-Realistic/resolve/main/1.jpg",
    },
    {
        "title": "Claymation",
        "repo": "prithivMLmods/Canopus-LoRA-Flux-Claymation",
        "trigger_word": "Claymation style",
        "image": "https://huggingface.co/prithivMLmods/Canopus-LoRA-Flux-Claymation/resolve/main/1.jpg",
    },
]

# Per-title usage counters (in catalogue order), all starting at zero.
lora_usage = dict.fromkeys((entry["title"] for entry in loras), 0)
# Device and dtype setup. This app targets CPU inference, so float32 is used
# for compatibility with CPU kernels.
device = "cpu"
dtype = torch.float32  # Use float32 for CPU compatibility

# Load the base FLUX model once at import time; every request reuses `pipe`.
base_model = "black-forest-labs/FLUX.1-dev"
# Tiny autoencoder, installed below as the pipeline's VAE (`vae=taef1`).
taef1 = AutoencoderTiny.from_pretrained("madebyollin/taef1", torch_dtype=dtype)
# Full-size VAE from the same base model repo.
# NOTE(review): not attached to `pipe` here; confirm whether it is still needed.
good_vae = AutoencoderKL.from_pretrained(base_model, subfolder="vae", torch_dtype=dtype)
pipe = DiffusionPipeline.from_pretrained(
    base_model,
    torch_dtype=dtype,
    vae=taef1,
)
# Model CPU offload keeps sub-models on the CPU and moves each onto an
# accelerator (CUDA by default) only while it runs. On a CPU-only host there is
# nothing to offload to and generation would fail, so enable it only when a
# CUDA device actually exists.
if torch.cuda.is_available():
    pipe.enable_model_cpu_offload()
# Custom CSS passed to gr.Blocks(css=...) for the interface below. The ids it
# targets (#title, #gen_column, #gen_btn, #gallery, #progress, #lora_list) are
# assigned to components via elem_id=....
# NOTE(review): `.svelte-mg0r0q` is a Svelte-generated class hash from one
# particular Gradio build; that rule may silently stop matching after an upgrade.
css = """
#title {
text-align: center;
}
#gen_column {
display: flex;
align-items: flex-end;
}
#gen_btn {
height: 100%;
}
#gallery img {
border-radius: 10px !important;
border: 2px solid white !important;
}
#gallery .svelte-mg0r0q.selected img {
border: 2px solid #00ff00 !important;
}
#progress {
width: 100%;
}
#lora_list {
font-size: 12px;
}
"""
# Utility functions
@contextmanager
def calculateDuration(message):
    """Time the enclosed ``with`` block and print how long it took.

    Usage::

        with calculateDuration("Generating text-to-image"):
            ...

    Prints ``"<message>: <seconds> seconds"`` (two decimals) when the block
    exits, including when it exits by raising; the exception still propagates.

    Without ``@contextmanager`` this was a plain generator function, and a
    generator object cannot be used in a ``with`` statement.
    """
    # perf_counter is monotonic and high-resolution, suited to measuring durations.
    start_time = time.perf_counter()
    try:
        yield
    finally:
        duration = time.perf_counter() - start_time
        print(f"{message}: {duration:.2f} seconds")
def update_lora_info(selected_index, custom_lora):
    """Describe the active LoRA for the info panel.

    A non-empty ``custom_lora`` takes precedence over the gallery selection.

    Args:
        selected_index: Index into ``loras`` of the clicked gallery item, or None.
        custom_lora: Text typed into the custom-LoRA box (may be empty/None).

    Returns:
        tuple: (markdown text, the custom LoRA value or None,
        a ``gr.Button`` that is visible only for a custom LoRA).
    """
    if custom_lora:
        return f"**Custom LoRA**: {custom_lora}", custom_lora, gr.Button(visible=True)

    if selected_index is None:
        return "Select a LoRA to get started!🧨", None, gr.Button(visible=False)

    chosen = loras[selected_index]
    summary = f"**Selected LoRA**: {chosen['title']}\n**Trigger Word**: {chosen['trigger_word']}"
    return summary, None, gr.Button(visible=False)
def remove_custom_lora(selected_index):
    """Reset the custom-LoRA widgets and fall back to the gallery selection.

    Returns, in order: a cleared textbox value, a hidden ``gr.HTML``, a hidden
    ``gr.Button`` and a ``gr.Markdown`` showing the gallery-based info text.
    """
    info_text, _unused_custom, _unused_button = update_lora_info(selected_index, None)
    cleared_value = None
    hidden_info = gr.HTML(visible=False)
    hidden_button = gr.Button(visible=False)
    return cleared_value, hidden_info, hidden_button, gr.Markdown(value=info_text)
# Image generation function (combined for both text-to-image and image-to-image)
def generate_image(
    prompt_mash,
    image_input_path,
    image_strength,
    steps,
    seed,
    cfg_scale,
    width,
    height,
    lora_scale
):
    """Run the shared FLUX pipeline once and return the generated image.

    Args:
        prompt_mash: Final prompt (user text plus the LoRA trigger word).
        image_input_path: Optional path/URL of an init image. When truthy the
            request is routed through an image-to-image pipeline.
        image_strength: Img2img strength; only used when an init image is given.
        steps: Number of denoising steps.
        seed: Seed for the torch generator (makes results reproducible).
        cfg_scale: Forwarded as ``guidance_scale``.
        width: Output width in pixels.
        height: Output height in pixels.
        lora_scale: LoRA weight, forwarded via ``joint_attention_kwargs``.

    Returns:
        The first image of the pipeline output (``output_type="pil"``).

    Raises:
        gr.Error: if image-to-image is requested but the installed diffusers
            release has no ``FluxImg2ImgPipeline``.
    """
    # Gradio numbers/sliders may deliver floats; the generator seed and the
    # pipeline's size/step arguments must be integers.
    generator = torch.Generator(device=device).manual_seed(int(seed))
    # Do not pass `good_vae` here: the FLUX pipeline's __call__ has no such
    # parameter, so including it raised TypeError on every request.
    kwargs = {
        "prompt": prompt_mash,
        "num_inference_steps": int(steps),
        "guidance_scale": cfg_scale,
        "width": int(width),
        "height": int(height),
        "generator": generator,
        "joint_attention_kwargs": {"scale": lora_scale},
        "output_type": "pil",
    }
    try:
        if image_input_path:
            # The text-to-image pipeline loaded from FLUX.1-dev does not take
            # `image`/`strength`. Build an img2img view that shares the already
            # loaded components (including the LoRA in the transformer).
            try:
                from diffusers import FluxImg2ImgPipeline
            except ImportError as err:
                raise gr.Error(
                    "Image-to-image needs a diffusers release that provides FluxImg2ImgPipeline."
                ) from err
            runner = FluxImg2ImgPipeline.from_pipe(pipe)
            kwargs["image"] = load_image(image_input_path)
            kwargs["strength"] = image_strength
            label = "Generating image-to-image"
        else:
            runner = pipe
            label = "Generating text-to-image"

        with calculateDuration(label):
            result = runner(**kwargs).images[0]
    finally:
        # Release memory between requests, even when generation failed.
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        gc.collect()
    return result
def run_lora(prompt, image_input, image_strength, cfg_scale, steps, selected_index, randomize_seed, seed, width, height, lora_scale):
    """Gradio handler: load the selected LoRA, build the prompt and generate.

    Args:
        prompt: User prompt. If empty or whitespace-only, the LoRA's trigger
            word alone is used as the prompt.
        image_input: Optional init-image path (enables image-to-image).
        image_strength: Img2img strength.
        cfg_scale: Guidance scale.
        steps: Denoising steps.
        selected_index: Index into ``loras`` of the chosen gallery item.
        randomize_seed: When true, ``seed`` is replaced by ``int(time.time())``.
        seed: Seed used when not randomizing.
        width: Output width in pixels.
        height: Output height in pixels.
        lora_scale: LoRA weight.

    Returns:
        tuple: (generated image, seed actually used, ``gr.Markdown`` showing it).

    Raises:
        gr.Error: if no LoRA has been selected.
    """
    if selected_index is None:
        raise gr.Error("You must select a LoRA before proceeding.🧨")
    selected_lora = loras[selected_index]
    lora_path = selected_lora["repo"]
    trigger_word = selected_lora["trigger_word"]

    # The pipeline is shared across requests: drop the previous adapter first.
    pipe.unload_lora_weights()
    pipe.load_lora_weights(lora_path)

    # Previously the empty-prompt branch assigned `prompt`, leaving
    # `prompt_mash` unbound and raising NameError.
    if not prompt or not prompt.strip():
        prompt_mash = trigger_word
    else:
        prompt_mash = f"{prompt}, {trigger_word}"

    if randomize_seed:
        seed = int(time.time())

    final_image = generate_image(
        prompt_mash,
        image_input,
        image_strength,
        steps,
        seed,
        cfg_scale,
        width,
        height,
        lora_scale
    )
    # Count the LoRA only once a run has actually produced an image.
    # (Mutating the dict needs no `global` declaration.)
    lora_usage[selected_lora["title"]] += 1
    return final_image, seed, gr.Markdown(value=f"**Seed**: {seed}", visible=True)
def generate_usage_chart():
    """Build a Chart.js-style bar-chart config for the five most-used LoRAs.

    Reads the module-level ``lora_usage`` counters and orders them by count,
    highest first. ``sorted`` is stable, so ties keep their insertion order.
    At most five entries are kept.

    Returns:
        dict: ``{"type": "bar", "data": {...}, "options": {...}}`` with the
        LoRA titles as labels and their counts as the single dataset.
    """
    ranked = sorted(lora_usage.items(), key=lambda pair: pair[1], reverse=True)
    labels = []
    counts = []
    for title, count in ranked[:5]:
        labels.append(title)
        counts.append(count)

    # One colour per bar position: indigo, emerald, orange, red, blue.
    palette = ["#4f46e5", "#10b981", "#f97316", "#ef4444", "#3b82f6"]
    dataset = {
        "label": "LoRA Usage Count",
        "data": counts,
        "backgroundColor": list(palette),
        "borderColor": list(palette),
        "borderWidth": 1,
    }

    def shown_text(text):
        return {"display": True, "text": text}

    options = {
        "scales": {
            "y": {"beginAtZero": True, "title": shown_text("Usage Count")},
            "x": {"title": shown_text("LoRA Title")},
        },
        "plugins": {
            "legend": {"display": False},
            "title": shown_text("Top 5 Most Used LoRAs"),
        },
    }
    return {
        "type": "bar",
        "data": {"labels": labels, "datasets": [dataset]},
        "options": options,
    }
# Gradio interface


def _usage_chart_html(chart_config):
    """Render a bar-chart config from ``generate_usage_chart`` as static HTML.

    ``usage_chart`` is a ``gr.HTML`` component, which displays a markup
    string. A raw config dict is not markup, and inline scripts are generally
    not executed when injected as HTML, so no Chart.js canvas is used. Each bar
    is a plain ``<div>`` whose width is proportional to its count.
    """
    import html

    data = chart_config["data"]
    dataset = data["datasets"][0]
    labels = data["labels"]
    counts = dataset["data"]
    colors = dataset.get("backgroundColor") or ["#4f46e5"]
    heading = chart_config["options"]["plugins"]["title"]["text"]
    # Scale bars against the largest count; `or 1` avoids dividing by zero
    # before any LoRA has been used.
    peak = max(counts, default=0) or 1
    rows = []
    for i, (label, count) in enumerate(zip(labels, counts)):
        color = html.escape(str(colors[i % len(colors)]), quote=True)
        width_pct = 100.0 * count / peak
        rows.append(
            '<div style="display:flex;align-items:center;gap:8px;margin:4px 0;">'
            f'<span style="flex:0 0 110px;">{html.escape(str(label))}</span>'
            '<div style="flex:1;">'
            f'<div style="width:{width_pct:.1f}%;min-width:2px;height:16px;'
            f'background:{color};border-radius:3px;"></div>'
            "</div>"
            f'<span style="flex:0 0 40px;text-align:right;">{count}</span>'
            "</div>"
        )
    return f"<h4>{html.escape(str(heading))}</h4>" + "".join(rows)


def _refresh_usage_chart():
    """Return the current LoRA usage chart as HTML for the `usage_chart` component."""
    return _usage_chart_html(generate_usage_chart())


def _on_gallery_select(evt: gr.SelectData):
    """Gallery click handler: remember the clicked index and describe that LoRA.

    Gradio supplies ``evt`` because of the ``gr.SelectData`` annotation, which
    is why the listener is registered with ``inputs=None``.
    """
    return evt.index, update_lora_info(evt.index, None)[0]


with gr.Blocks(theme="YTheme/Minecraft", css=css, delete_cache=(60, 60)) as app:
    title = gr.HTML(
        """<h1>FLUX LoRA DLC🥳</h1>""",
        elem_id="title",
    )
    # Index into `loras` of the gallery item the user picked (None = nothing yet).
    selected_index = gr.State(None)

    with gr.Row():
        with gr.Column(scale=3):
            prompt = gr.Textbox(label="Prompt", lines=1, placeholder=":/ choose the LoRA and type the prompt ")
        with gr.Column(scale=1, elem_id="gen_column"):
            generate_button = gr.Button("Generate", variant="primary", elem_id="gen_btn")

    with gr.Row():
        with gr.Column():
            selected_info = gr.Markdown("")
            gallery = gr.Gallery(
                [(item["image"], item["title"]) for item in loras],
                label="LoRA DLC's",
                allow_preview=False,
                columns=3,
                elem_id="gallery",
                show_share_button=False
            )
            with gr.Group():
                custom_lora = gr.Textbox(label="Enter Custom LoRA", placeholder="prithivMLmods/Canopus-LoRA-Flux-Anime")
                gr.Markdown("[Check the list of FLUX LoRA's](https://huggingface.co/models?other=base_model:adapter:black-forest-labs/FLUX.1-dev)", elem_id="lora_list")
            custom_lora_info = gr.HTML(visible=False)
            custom_lora_button = gr.Button("Remove custom LoRA", visible=False)

        with gr.Column():
            progress_bar = gr.Markdown(elem_id="progress", visible=False)
            result = gr.Image(label="Generated Image")
            with gr.Accordion("LoRA Usage Statistics", open=False):
                usage_chart = gr.HTML(label="LoRA Usage Chart")
                refresh_chart_button = gr.Button("Refresh Usage Chart")
            with gr.Accordion("Advanced Settings", open=False):
                with gr.Row():
                    steps = gr.Slider(label="Steps", minimum=1, maximum=50, value=10, step=1)  # Reduced default steps
                    cfg_scale = gr.Slider(label="CFG Scale", minimum=1, maximum=20, value=7, step=0.1)
                with gr.Row():
                    width = gr.Slider(label="Width", minimum=256, maximum=1024, value=256, step=64)  # Reduced default resolution
                    height = gr.Slider(label="Height", minimum=256, maximum=1024, value=256, step=64)  # Reduced default resolution
                with gr.Row():
                    lora_scale = gr.Slider(label="LoRA Scale", minimum=0, maximum=1, value=0.8, step=0.1)
                    # Only meaningful for image-to-image; revealed once an image is uploaded.
                    image_strength = gr.Slider(label="Image Strength", minimum=0, maximum=1, value=0.5, step=0.1, visible=False)
                with gr.Row():
                    randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
                    seed = gr.Number(label="Seed", value=42, precision=0, visible=False)
            input_image = gr.Image(label="Input Image", type="filepath")

    # Gallery selection. A plain lambda with inputs=None received no arguments
    # and failed; the `_js` keyword no longer exists in Gradio 4.
    gallery.select(
        fn=_on_gallery_select,
        inputs=None,
        outputs=[selected_index, selected_info],
    )
    custom_lora.submit(
        fn=lambda custom_lora: (None, *update_lora_info(None, custom_lora)),
        inputs=custom_lora,
        outputs=[selected_index, selected_info, custom_lora_info, custom_lora_button]
    ).then(
        fn=lambda: gr.update(value=""),
        inputs=None,
        outputs=custom_lora
    )
    custom_lora_button.click(
        fn=remove_custom_lora,
        inputs=selected_index,
        outputs=[custom_lora, custom_lora_info, custom_lora_button, selected_info]
    )
    # Show the strength slider while an init image is present. The image itself
    # stays visible so the user can see it and clear it again.
    input_image.upload(
        fn=lambda: gr.update(visible=True),
        inputs=None,
        outputs=image_strength
    )
    input_image.clear(
        fn=lambda: gr.update(visible=False),
        inputs=None,
        outputs=image_strength
    )
    randomize_seed.change(
        fn=lambda randomize: gr.update(visible=not randomize),
        inputs=randomize_seed,
        outputs=seed
    )
    refresh_chart_button.click(
        fn=_refresh_usage_chart,
        inputs=None,
        outputs=[usage_chart],
    )
    gr.on(
        triggers=[generate_button.click, prompt.submit],
        fn=run_lora,
        inputs=[prompt, input_image, image_strength, cfg_scale, steps, selected_index, randomize_seed, seed, width, height, lora_scale],
        outputs=[result, seed, progress_bar]
    ).then(
        fn=_refresh_usage_chart,
        inputs=None,
        outputs=[usage_chart],
    )
# Launch the app only when this file is run as a script (e.g. `python app.py`),
# not when the module is merely imported.
if __name__ == "__main__":
    app.launch(server_port=7860)