Spaces:
Runtime error
Runtime error
| import torch | |
| import gradio as gr | |
| from diffusers import DiffusionPipeline, AutoencoderTiny, AutoencoderKL | |
| from diffusers.utils import load_image | |
| import os | |
| import gc | |
| from PIL import Image | |
| import time | |
# Catalogue of built-in LoRA adapters shown in the gallery. Every entry follows
# the same "Canopus-LoRA-Flux-<title>" naming scheme, so the records are derived
# from the list of titles.
_LORA_TITLES = ("Anime", "PixelArt", "Ghibli", "Realistic", "Claymation")

loras = [
    {
        "title": title,
        "repo": f"prithivMLmods/Canopus-LoRA-Flux-{title}",
        "trigger_word": f"{title} style",
        "image": f"https://huggingface.co/prithivMLmods/Canopus-LoRA-Flux-{title}/resolve/main/1.jpg",
    }
    for title in _LORA_TITLES
]

# Per-title usage counter, starting at zero for every catalogue entry.
lora_usage = dict.fromkeys((entry["title"] for entry in loras), 0)
# Device and dtype setup for CPU
device = "cpu"
dtype = torch.float32  # Use float32 for CPU compatibility

# Load the base FLUX pipeline once at import time. The tiny autoencoder (taef1)
# is installed as the pipeline's VAE; the full-size VAE is loaded alongside it
# and kept available as `good_vae`.
base_model = "black-forest-labs/FLUX.1-dev"
taef1 = AutoencoderTiny.from_pretrained("madebyollin/taef1", torch_dtype=dtype)
good_vae = AutoencoderKL.from_pretrained(base_model, subfolder="vae", torch_dtype=dtype)
pipe = DiffusionPipeline.from_pretrained(
    base_model,
    torch_dtype=dtype,
    vae=taef1,
)

# Model CPU offloading shuttles sub-models between host RAM and an accelerator,
# so it only makes sense (and only works) when an accelerator is present.
# On a CPU-only host the pipeline simply stays on `device`.
if torch.cuda.is_available():
    pipe.enable_model_cpu_offload()
else:
    pipe.to(device)
# Custom CSS handed to gr.Blocks(css=...). The id selectors correspond to the
# elem_id values assigned to components in the interface definition
# (title, gen_column, gen_btn, gallery, progress, lora_list).
# NOTE(review): ".svelte-mg0r0q" looks like a build-generated Svelte class name;
# confirm it still matches the installed Gradio frontend.
css = """
#title {
text-align: center;
}
#gen_column {
display: flex;
align-items: flex-end;
}
#gen_btn {
height: 100%;
}
#gallery img {
border-radius: 10px !important;
border: 2px solid white !important;
}
#gallery .svelte-mg0r0q.selected img {
border: 2px solid #00ff00 !important;
}
#progress {
width: 100%;
}
#lora_list {
font-size: 12px;
}
"""
| # Utility functions | |
def calculateDuration(message):
    """Return a context manager that prints how long its ``with`` body took.

    Usage::

        with calculateDuration("Generating image"):
            ...

    On exit it prints ``"<message>: <seconds> seconds"`` with two decimals.
    The duration is reported even when the body raises; the exception is not
    suppressed.

    Args:
        message: Label printed in front of the elapsed time.

    Returns:
        A context manager whose ``__enter__`` yields ``None``.
    """
    # Local import keeps this helper self-contained.
    from contextlib import contextmanager

    @contextmanager
    def _timed():
        # The clock starts when the ``with`` block is entered, not when
        # calculateDuration() is called.
        start_time = time.time()
        try:
            yield None
        finally:
            duration = time.time() - start_time
            print(f"{message}: {duration:.2f} seconds")

    return _timed()
def update_lora_info(selected_index, custom_lora):
    """Describe the current LoRA choice for the info panel.

    Args:
        selected_index: Index into ``loras`` of the chosen gallery entry, or
            ``None`` when nothing is selected.
        custom_lora: User-entered custom LoRA identifier; any truthy value
            takes precedence over ``selected_index``.

    Returns:
        A 3-tuple ``(markdown_text, custom_lora_or_None, button)`` where
        ``button`` is a ``gr.Button`` that is visible only for a custom LoRA.
    """
    # A custom LoRA wins regardless of what is selected in the gallery.
    if custom_lora:
        return f"**Custom LoRA**: {custom_lora}", custom_lora, gr.Button(visible=True)

    # Nothing chosen at all: show the onboarding hint.
    if selected_index is None:
        return "Select a LoRA to get started!🧨", None, gr.Button(visible=False)

    entry = loras[selected_index]
    summary = f"**Selected LoRA**: {entry['title']}\n**Trigger Word**: {entry['trigger_word']}"
    return summary, None, gr.Button(visible=False)
def remove_custom_lora(selected_index):
    """Reset the custom-LoRA widgets and restore the info text.

    Args:
        selected_index: Currently selected gallery index (or ``None``), used to
            recompute the info text without a custom LoRA.

    Returns:
        A 4-tuple: ``None``, a hidden ``gr.HTML``, a hidden ``gr.Button`` and a
        ``gr.Markdown`` carrying the info text from ``update_lora_info``.
    """
    cleared_value = None
    hidden_info = gr.HTML(visible=False)
    hidden_button = gr.Button(visible=False)
    # Only the text part of update_lora_info's result is needed here.
    info_text = update_lora_info(selected_index, None)[0]
    return cleared_value, hidden_info, hidden_button, gr.Markdown(value=info_text)
| # Image generation function (combined for both text-to-image and image-to-image) | |
def generate_image(
    prompt_mash,
    image_input_path,
    image_strength,
    steps,
    seed,
    cfg_scale,
    width,
    height,
    lora_scale
):
    """Run one generation with the shared pipeline and return the first image.

    Text-to-image is used when ``image_input_path`` is falsy; otherwise the
    image is loaded and an image-to-image pipeline built from the same
    components is used.

    Args:
        prompt_mash: Final prompt text (user prompt plus trigger word).
        image_input_path: Path/URL of an input image, or a falsy value.
        image_strength: ``strength`` passed to the image-to-image call.
        steps: Number of inference steps.
        seed: Seed for the ``torch.Generator``.
        cfg_scale: Guidance scale.
        width: Output width in pixels.
        height: Output height in pixels.
        lora_scale: LoRA scale forwarded via ``joint_attention_kwargs``.

    Returns:
        The first image of the pipeline output (``output_type="pil"``).
    """
    generator = torch.Generator(device=device).manual_seed(int(seed))

    # Only arguments the pipeline call actually accepts are forwarded; UI
    # sliders may deliver floats, so integer-valued settings are coerced.
    kwargs = {
        "prompt": prompt_mash,
        "num_inference_steps": int(steps),
        "guidance_scale": cfg_scale,
        "width": int(width),
        "height": int(height),
        "generator": generator,
        "joint_attention_kwargs": {"scale": lora_scale},
        "output_type": "pil",
    }

    if image_input_path:
        # The text-to-image pipeline does not take `image`/`strength`; derive an
        # image-to-image pipeline that shares the already-loaded components.
        from diffusers import AutoPipelineForImage2Image

        active_pipe = AutoPipelineForImage2Image.from_pipe(pipe)
        kwargs.update({
            "image": load_image(image_input_path),
            "strength": image_strength,
        })
        label = "Generating image-to-image"
    else:
        active_pipe = pipe
        label = "Generating text-to-image"

    started = time.time()
    try:
        result = active_pipe(**kwargs).images[0]
    finally:
        print(f"{label}: {time.time() - started:.2f} seconds")
        # Release memory after every attempt, successful or not.
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        gc.collect()
    return result
def run_lora(prompt, image_input, image_strength, cfg_scale, steps, selected_index, randomize_seed, seed, width, height, lora_scale):
    """Load the selected LoRA, build the prompt and generate an image.

    Args:
        prompt: User prompt; when empty or whitespace-only the LoRA's trigger
            word alone is used as the prompt.
        image_input: Optional input image path for image-to-image.
        image_strength: Strength for image-to-image.
        cfg_scale: Guidance scale.
        steps: Number of inference steps.
        selected_index: Index into ``loras`` of the chosen LoRA.
        randomize_seed: When true, a time-based seed replaces ``seed``.
        seed: Seed to use when ``randomize_seed`` is false.
        width: Output width in pixels.
        height: Output height in pixels.
        lora_scale: LoRA scale passed on to generation.

    Returns:
        ``(image, seed, markdown)`` where ``markdown`` is a visible
        ``gr.Markdown`` showing the seed that was used.

    Raises:
        gr.Error: If no LoRA has been selected.
    """
    if selected_index is None:
        raise gr.Error("You must select a LoRA before proceeding.🧨")

    selected_lora = loras[selected_index]
    trigger_word = selected_lora["trigger_word"]

    # Increment the usage counter for the selected LoRA (in-place mutation of
    # the module-level dict, so no `global` statement is required).
    lora_usage[selected_lora["title"]] += 1

    # Swap adapters: drop whatever was loaded before, then load the new one.
    pipe.unload_lora_weights()
    pipe.load_lora_weights(selected_lora["repo"])

    # Always define prompt_mash; an empty prompt falls back to the trigger word.
    if prompt and prompt.strip():
        prompt_mash = f"{prompt}, {trigger_word}"
    else:
        prompt_mash = trigger_word

    # A missing seed is treated like "randomize" so generation can still run.
    if randomize_seed or seed is None:
        seed = int(time.time())
    else:
        seed = int(seed)

    final_image = generate_image(
        prompt_mash,
        image_input,
        image_strength,
        steps,
        seed,
        cfg_scale,
        width,
        height,
        lora_scale
    )
    return final_image, seed, gr.Markdown(value=f"**Seed**: {seed}", visible=True)
def generate_usage_chart():
    """Build a bar-chart configuration dict for the most-used LoRAs.

    Reads the module-level ``lora_usage`` counter, keeps the five entries with
    the highest counts (ties keep their original order) and returns a nested
    dict with ``type``, ``data`` and ``options`` keys.
    """
    ranked = sorted(lora_usage.items(), key=lambda pair: pair[1], reverse=True)
    top_five = ranked[:5]
    labels = [title for title, _ in top_five]
    data = [count for _, count in top_five]

    # Indigo, emerald, orange, red, blue: one colour per bar.
    palette = ["#4f46e5", "#10b981", "#f97316", "#ef4444", "#3b82f6"]

    dataset = {
        "label": "LoRA Usage Count",
        "data": data,
        "backgroundColor": list(palette),
        "borderColor": list(palette),
        "borderWidth": 1,
    }

    y_axis = {
        "beginAtZero": True,
        "title": {"display": True, "text": "Usage Count"},
    }
    x_axis = {
        "title": {"display": True, "text": "LoRA Title"},
    }
    plugins = {
        "legend": {"display": False},
        "title": {"display": True, "text": "Top 5 Most Used LoRAs"},
    }

    return {
        "type": "bar",
        "data": {"labels": labels, "datasets": [dataset]},
        "options": {"scales": {"y": y_axis, "x": x_axis}, "plugins": plugins},
    }
| # Gradio interface | |
def _on_gallery_select(evt: gr.SelectData):
    """Store the clicked gallery index and refresh the info text.

    Gradio supplies the event payload through the ``gr.SelectData`` annotated
    parameter, so the listener needs no explicit inputs.
    """
    return evt.index, update_lora_info(evt.index, None)[0]


def _render_usage_chart():
    """Render the usage-chart configuration as an HTML string for gr.HTML."""
    import html  # local import: only this helper needs it

    chart = generate_usage_chart()
    dataset = chart["data"]["datasets"][0]
    counts = dataset["data"]
    peak = max(counts, default=0) or 1  # avoid dividing by zero before first use
    bars = "".join(
        '<div style="display:flex;align-items:center;gap:8px;margin:4px 0;">'
        f'<span style="width:7em;">{html.escape(str(label))}</span>'
        f'<span style="display:inline-block;height:14px;width:{60 * count / peak:.0f}%;background:{color};"></span>'
        f"<span>{count}</span>"
        "</div>"
        for label, count, color in zip(chart["data"]["labels"], counts, dataset["backgroundColor"])
    )
    heading = html.escape(chart["options"]["plugins"]["title"]["text"])
    return f"<div><strong>{heading}</strong>{bars}</div>"


# Gradio interface: prompt row, LoRA gallery + result column, advanced settings,
# then the event wiring.
with gr.Blocks(theme="YTheme/Minecraft", css=css, delete_cache=(60, 60)) as app:
    title = gr.HTML(
        """<h1>FLUX LoRA DLC🥳</h1>""",
        elem_id="title",
    )
    selected_index = gr.State(None)
    lora_usage_state = gr.State(lora_usage)

    with gr.Row():
        with gr.Column(scale=3):
            prompt = gr.Textbox(label="Prompt", lines=1, placeholder=":/ choose the LoRA and type the prompt ")
        with gr.Column(scale=1, elem_id="gen_column"):
            generate_button = gr.Button("Generate", variant="primary", elem_id="gen_btn")

    with gr.Row():
        with gr.Column():
            selected_info = gr.Markdown("")
            gallery = gr.Gallery(
                [(item["image"], item["title"]) for item in loras],
                label="LoRA DLC's",
                allow_preview=False,
                columns=3,
                elem_id="gallery",
                show_share_button=False
            )
            with gr.Group():
                custom_lora = gr.Textbox(label="Enter Custom LoRA", placeholder="prithivMLmods/Canopus-LoRA-Flux-Anime")
                gr.Markdown("[Check the list of FLUX LoRA's](https://huggingface.co/models?other=base_model:adapter:black-forest-labs/FLUX.1-dev)", elem_id="lora_list")
            custom_lora_info = gr.HTML(visible=False)
            custom_lora_button = gr.Button("Remove custom LoRA", visible=False)
        with gr.Column():
            progress_bar = gr.Markdown(elem_id="progress", visible=False)
            result = gr.Image(label="Generated Image")
            with gr.Accordion("LoRA Usage Statistics", open=False):
                usage_chart = gr.HTML(label="LoRA Usage Chart")
                refresh_chart_button = gr.Button("Refresh Usage Chart")

    with gr.Accordion("Advanced Settings", open=False):
        with gr.Row():
            steps = gr.Slider(label="Steps", minimum=1, maximum=50, value=10, step=1)  # Reduced default steps
            cfg_scale = gr.Slider(label="CFG Scale", minimum=1, maximum=20, value=7, step=0.1)
        with gr.Row():
            width = gr.Slider(label="Width", minimum=256, maximum=1024, value=256, step=64)  # Reduced default resolution
            height = gr.Slider(label="Height", minimum=256, maximum=1024, value=256, step=64)  # Reduced default resolution
        with gr.Row():
            lora_scale = gr.Slider(label="LoRA Scale", minimum=0, maximum=1, value=0.8, step=0.1)
            image_strength = gr.Slider(label="Image Strength", minimum=0, maximum=1, value=0.5, step=0.1, visible=False)
        with gr.Row():
            randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
            seed = gr.Number(label="Seed", value=42, precision=0, visible=False)
        input_image = gr.Image(label="Input Image", type="filepath")

    # Gallery click -> remember the index and show the LoRA's details.
    gallery.select(
        fn=_on_gallery_select,
        inputs=None,
        outputs=[selected_index, selected_info],
    )

    # Custom LoRA entry -> update the info widgets, then clear the textbox.
    # NOTE(review): run_lora receives no custom-LoRA input, so this flow only
    # changes what is displayed; confirm whether generation should use it.
    custom_lora.submit(
        fn=lambda custom_lora: (None, *update_lora_info(None, custom_lora)),
        inputs=custom_lora,
        outputs=[selected_index, selected_info, custom_lora_info, custom_lora_button]
    ).then(
        fn=lambda: gr.update(value=""),
        inputs=None,
        outputs=custom_lora
    )
    custom_lora_button.click(
        fn=remove_custom_lora,
        inputs=selected_index,
        outputs=[custom_lora, custom_lora_info, custom_lora_button, selected_info]
    )

    # The strength slider is only relevant while an input image is present.
    # The uploader itself stays visible so the image can be cleared again.
    input_image.upload(
        fn=lambda: gr.update(visible=True),
        inputs=None,
        outputs=image_strength
    )
    input_image.clear(
        fn=lambda: gr.update(visible=False),
        inputs=None,
        outputs=image_strength
    )

    # The manual seed field is shown only when randomisation is switched off.
    randomize_seed.change(
        fn=lambda randomize: gr.update(visible=not randomize),
        inputs=randomize_seed,
        outputs=seed
    )

    # gr.HTML expects markup, so the chart config is rendered to HTML first.
    refresh_chart_button.click(
        fn=_render_usage_chart,
        inputs=[],
        outputs=[usage_chart],
    )

    # Generate on button click or Enter in the prompt, then refresh the chart.
    gr.on(
        triggers=[generate_button.click, prompt.submit],
        fn=run_lora,
        inputs=[prompt, input_image, image_strength, cfg_scale, steps, selected_index, randomize_seed, seed, width, height, lora_scale],
        outputs=[result, seed, progress_bar]
    ).then(
        fn=_render_usage_chart,
        inputs=[],
        outputs=[usage_chart],
    )

# Launch the app
app.launch(server_port=7860)