Spaces:
Runtime error
Runtime error
| from __future__ import annotations | |
| import gc | |
| import pathlib | |
| import sys | |
| import gradio as gr | |
| import PIL.Image | |
| import numpy as np | |
| import torch | |
| from diffusers import StableDiffusionPipeline | |
| sys.path.insert(0, './ReVersion') | |
| # below are original | |
| import os | |
| # import argparse | |
| # import torch | |
| from PIL import Image | |
| # from diffusers import StableDiffusionPipeline | |
| # sys.path.insert(0, './ReVersion') | |
| # from templates.templates import inference_templates | |
| import math | |
| """ | |
| Inference script for generating batch results | |
| """ | |
def make_image_grid(imgs, rows, cols):
    """Paste ``imgs`` row by row into a single ``rows`` x ``cols`` RGB image.

    Every cell has the size of the first image. Images are placed left to
    right, top to bottom. When fewer images than cells are given, the
    remaining cells keep the fill colour that ``Image.new`` gives a fresh
    canvas.

    Args:
        imgs: Non-empty sequence of images exposing ``.size``; each one is
            handed to ``Image.paste``.
        rows: Number of grid rows (>= 1).
        cols: Number of grid columns (>= 1).

    Returns:
        The assembled grid image, ``cols * w`` wide and ``rows * h`` high,
        where ``(w, h)`` is the size of ``imgs[0]``.

    Raises:
        ValueError: If ``imgs`` is empty, ``rows``/``cols`` is not positive,
            or there are more images than grid cells.
    """
    # Explicit checks instead of ``assert`` so they survive ``python -O``
    if len(imgs) == 0:
        raise ValueError("imgs must contain at least one image")
    if rows < 1 or cols < 1:
        raise ValueError(f"rows and cols must be positive, got {rows}x{cols}")
    if len(imgs) > rows * cols:
        raise ValueError(
            f"{len(imgs)} images do not fit into a {rows}x{cols} grid"
        )

    # The first image defines the cell size for the whole grid
    w, h = imgs[0].size
    grid = Image.new('RGB', size=(cols * w, rows * h))
    for i, img in enumerate(imgs):
        row, col = divmod(i, cols)
        grid.paste(img, box=(col * w, row * h))
    return grid
def _grid_layout(count: int) -> tuple[int, int]:
    """Return ``(rows, cols)`` such that ``rows * cols == count``."""
    # Two rows only tile exactly for an even count; otherwise use one row
    rows = 2 if count % 2 == 0 else 1
    return rows, count // rows


def _prompt_folder_name(prompt: str) -> str:
    """Turn ``prompt`` into a single, safe path component."""
    name = prompt
    for sep in (os.sep, os.altsep):
        if sep:
            name = name.replace(sep, '_')
    # '', '.' and '..' would resolve to the root folder itself or its parent
    if name in ('', '.', '..'):
        name = 'unnamed_prompt'
    return name


def inference_fn(
    model_id: str,
    prompt: str,
    num_samples: int,
    guidance_scale: float,
) -> PIL.Image.Image:
    """Generate ``num_samples`` images for ``prompt`` and return them as a grid.

    The pipeline is loaded from ``experiments/<model_id>``. Each sample is
    saved as ``experiments/<model_id>/inference/<prompt>/samples/NNNN.png``
    and the grid as ``experiments/<model_id>/inference/<prompt>/<prompt>.png``,
    where ``<prompt>`` is the processed prompt with path separators replaced
    by ``_``.

    Args:
        model_id: Name of the sub-folder of ``experiments`` to load from.
        prompt: Text prompt. It is lower-cased, ``<r>`` is restored to
            ``<R>``, and a ``{}`` placeholder is filled with ``<R>``.
        num_samples: Number of images to generate (>= 1).
        guidance_scale: Guidance scale forwarded to the pipeline call.

    Returns:
        The grid image built from all generated samples: two rows for an
        even number of samples, one row otherwise.

    Raises:
        ValueError: If ``prompt`` is ``None`` or ``num_samples`` is below 1.
    """
    # Validate before loading the model so bad input fails fast
    if prompt is None:
        raise ValueError("please provide a text prompt.")
    if num_samples < 1:
        raise ValueError(f"num_samples must be at least 1, got {num_samples}")

    # create inference pipeline
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    # Half precision is only requested on GPU; on CPU fall back to float32
    dtype = torch.float16 if device == 'cuda' else torch.float32
    pipe = StableDiffusionPipeline.from_pretrained(
        os.path.join('experiments', model_id),
        torch_dtype=dtype,
    ).to(device)

    # make directory to save images
    image_root_folder = os.path.join('experiments', model_id, 'inference')
    os.makedirs(image_root_folder, exist_ok=True)

    # insert relation prompt <R>
    # NOTE: str.format interprets every brace in the prompt, not only "{}"
    prompt = prompt.lower().replace("<r>", "<R>").format("<R>")

    # make sub-folder; the prompt must stay a single path component
    folder_name = _prompt_folder_name(prompt)
    prompt_folder = os.path.join(image_root_folder, folder_name)
    image_folder = os.path.join(prompt_folder, 'samples')
    os.makedirs(image_folder, exist_ok=True)

    # batch generation
    images = pipe(
        prompt,
        num_inference_steps=50,
        guidance_scale=guidance_scale,
        num_images_per_prompt=num_samples,
    ).images

    # save generated images
    for idx, image in enumerate(images):
        image.save(os.path.join(image_folder, f"{idx:04d}.png"))

    # save a grid of images; the layout always holds exactly len(images) cells
    rows, cols = _grid_layout(len(images))
    image_grid = make_image_grid(images, rows=rows, cols=cols)
    image_grid.save(os.path.join(prompt_folder, f'{folder_name}.png'))
    return image_grid
if __name__ == "__main__":
    # Script entry point: inference_fn has four required parameters, so they
    # are read from the command line instead of calling it with no arguments.
    # argparse is imported here because it is only needed when run as a script.
    import argparse

    parser = argparse.ArgumentParser(
        description="Generate a grid of samples from a pipeline stored under experiments/<model_id>."
    )
    parser.add_argument('--model_id', required=True,
                        help="sub-folder of 'experiments' holding the pipeline")
    parser.add_argument('--prompt', required=True,
                        help="text prompt; use <R> for the relation token")
    parser.add_argument('--num_samples', type=int, default=10,
                        help="number of images to generate")
    parser.add_argument('--guidance_scale', type=float, default=7.5,
                        help="guidance scale passed to the pipeline")
    args = parser.parse_args()

    inference_fn(
        model_id=args.model_id,
        prompt=args.prompt,
        num_samples=args.num_samples,
        guidance_scale=args.guidance_scale,
    )