Spaces:
Runtime error
Runtime error
| import argparse | |
| import time | |
| import gradio as gr | |
| import torch | |
| from diffusers import DiffusionPipeline, UNet2DConditionModel | |
| from scheduling_dmd import DMDScheduler | |
# ---------------------------------------------------------------------------
# Start-up: parse command-line options, then build the diffusion pipeline.
# ---------------------------------------------------------------------------
# Both options take a checkpoint identifier (hub repo id or local path) that is
# handed to the corresponding ``from_pretrained`` call below.  The identifiers
# are *defaults*; ``type`` must be a callable converter, so it is ``str``.
# NOTE(review): the two default ids look swapped relative to the option names
# (``--unet-path`` defaults to what appears to be a full base checkpoint and
# ``--model-path`` to what appears to be the distilled weights) -- confirm
# against the repositories before changing them.
parser = argparse.ArgumentParser()
parser.add_argument("--unet-path", type=str, default="Lykon/dreamshaper-8")
parser.add_argument(
    "--model-path",
    type=str,
    default="aaronb/dreamshaper-8-dmd-kl-only-6kstep",
)
args = parser.parse_args()

# Prefer the GPU when one is visible to torch; otherwise run on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Half precision is only used on CUDA; on the CPU fallback keep float32, since
# float16 is not reliably supported there.
pipe_dtype = torch.float16 if device.type == "cuda" else torch.float32

# Load the UNet separately and inject it into the pipeline, then replace the
# pipeline's scheduler with a DMDScheduler built from the existing scheduler
# configuration.
unet = UNet2DConditionModel.from_pretrained(args.unet_path)
pipe = DiffusionPipeline.from_pretrained(args.model_path, unet=unet)
pipe.scheduler = DMDScheduler.from_config(pipe.scheduler.config)
pipe.to(device=device, dtype=pipe_dtype)
def predict(prompt, seed=1231231):
    """Generate one image for *prompt* with the module-level ``pipe``.

    The pipeline is invoked with a single inference step and a guidance
    scale of 0.0.  The wall-clock duration of the pipeline call is printed
    to stdout.

    Args:
        prompt: Text prompt forwarded to the pipeline.
        seed: Value passed to ``torch.manual_seed``; the generator it returns
            is handed to the pipeline.

    Returns:
        The first entry of the pipeline output's ``images``.
    """
    # Seed before starting the clock so seeding is not part of the timing.
    rng = torch.manual_seed(seed)
    started_at = time.time()
    result = pipe(
        prompt,
        num_inference_steps=1,
        guidance_scale=0.0,
        generator=rng,
    )
    first_image = result.images[0]
    print(f"Pipe took {time.time() - started_at} seconds")
    return first_image
# Stylesheet for the page: ``#container`` targets the outer column and
# ``#intro`` the heading, matching the ``elem_id`` values used below.
css = """
#container{
    margin: 0 auto;
    max-width: 40rem;
}
#intro{
    max-width: 100%;
    text-align: center;
    margin: 0 auto;
}
"""

with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="container"):
        # Page heading.
        gr.Markdown(
            """# Distribution Matching Distillation
            """,
            elem_id="intro",
        )

        # Prompt box and "Generate" button side by side.
        with gr.Row():
            with gr.Row():
                prompt = gr.Textbox(
                    placeholder="Insert your prompt here:",
                    scale=5,
                    container=False,
                )
                generate_bt = gr.Button("Generate", scale=1)

        # Output image, followed by a collapsed panel holding the seed slider.
        image = gr.Image(type="filepath")
        with gr.Accordion("Advanced options", open=False):
            seed = gr.Slider(
                randomize=True,
                minimum=0,
                maximum=12013012031030,
                label="Seed",
                step=1,
            )

        # Re-run ``predict`` on a button click, on prompt input, and on a seed
        # change; all three listeners share the same inputs and output.
        inputs = [prompt, seed]
        for register in (generate_bt.click, prompt.input, seed.change):
            register(fn=predict, inputs=inputs, outputs=image, show_progress=False)

# Enable the request queue and start the server with the API hidden.
demo.queue(api_open=False)
demo.launch(show_api=False)