Spaces:
Runtime error
Runtime error
import argparse
import os

import torch
from diffusers import StableDiffusionXLPipeline, AutoencoderKL

from blora_utils import BLOCKS, filter_lora, scale_lora
def parse_args(argv=None) -> argparse.Namespace:
    """Parse the command-line options of the B-LoRA inference script.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) argparse falls back to ``sys.argv[1:]``, so a
            plain ``parse_args()`` call behaves exactly as before.

    Returns:
        A namespace with the attributes ``prompt``, ``output_path``,
        ``content_B_LoRA``, ``style_B_LoRA``, ``content_alpha``,
        ``style_alpha`` and ``num_images_per_prompt``.

    Raises:
        SystemExit: Raised by argparse when a required option is missing
            or a value cannot be converted to its declared type.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--prompt", type=str, required=True, help="B-LoRA prompt"
    )
    parser.add_argument(
        "--output_path", type=str, required=True, help="path to save the images"
    )
    parser.add_argument(
        "--content_B_LoRA", type=str, default=None, help="path for the content B-LoRA"
    )
    parser.add_argument(
        "--style_B_LoRA", type=str, default=None, help="path for the style B-LoRA"
    )
    parser.add_argument(
        "--content_alpha", type=float, default=1., help="alpha parameter to scale the content B-LoRA weights"
    )
    parser.add_argument(
        "--style_alpha", type=float, default=1., help="alpha parameter to scale the style B-LoRA weights"
    )
    parser.add_argument(
        "--num_images_per_prompt", type=int, default=4, help="number of images per prompt"
    )
    # Forwarding argv lets callers and tests supply arguments explicitly;
    # None keeps the original "read the process command line" behavior.
    return parser.parse_args(argv)
if __name__ == '__main__':
    # Script entry point: build the SDXL pipeline, optionally merge a content
    # and/or a style B-LoRA into its UNet, generate images for the prompt and
    # write them to the output directory.
    args = parse_args()

    # Load the VAE and the SDXL base pipeline in float16 and move the
    # pipeline to the GPU.
    vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16)
    pipeline = StableDiffusionXLPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0",
                                                         vae=vae,
                                                         torch_dtype=torch.float16).to("cuda")

    # Get Content B-LoRA SD
    # NOTE(review): filter_lora / scale_lora live in blora_utils; presumably
    # they keep only the entries for the given blocks and scale the weights
    # by alpha -- confirm against that module.
    if args.content_B_LoRA is not None:
        content_B_LoRA_sd, _ = pipeline.lora_state_dict(args.content_B_LoRA)
        content_B_LoRA = filter_lora(content_B_LoRA_sd, BLOCKS['content'])
        content_B_LoRA = scale_lora(content_B_LoRA, args.content_alpha)
    else:
        content_B_LoRA = {}

    # Get Style B-LoRA SD
    if args.style_B_LoRA is not None:
        style_B_LoRA_sd, _ = pipeline.lora_state_dict(args.style_B_LoRA)
        style_B_LoRA = filter_lora(style_B_LoRA_sd, BLOCKS['style'])
        style_B_LoRA = scale_lora(style_B_LoRA, args.style_alpha)
    else:
        style_B_LoRA = {}

    # Merge B-LoRAs SD (on a key collision the style entry wins because it
    # is unpacked last).
    res_lora = {**content_B_LoRA, **style_B_LoRA}

    # Load -- skipped when neither B-LoRA was supplied, so the loader is
    # never handed an empty state dict and the plain base model is used.
    if res_lora:
        pipeline.load_lora_into_unet(res_lora, None, pipeline.unet)

    # Generate
    images = pipeline(args.prompt, num_images_per_prompt=args.num_images_per_prompt).images

    # Save
    # Create the target directory up front; saving into a missing directory
    # would otherwise fail only after the expensive generation step.
    output_dir = args.output_path or '.'
    os.makedirs(output_dir, exist_ok=True)
    # A path separator inside the prompt would be interpreted as a
    # sub-directory of the output path, so neutralise it in the file name.
    file_stem = args.prompt.replace('/', '_').replace(os.sep, '_')
    for i, img in enumerate(images):
        img.save(os.path.join(output_dir, f'{file_stem}_{i}.jpg'))