Refactor ControlNetReq class to remove unused import and add controlnets, control_images, and controlnet_conditioning_scale attributes
07dc8e6
| import gc | |
| from typing import List, Optional, Dict, Any | |
| import torch | |
| from pydantic import BaseModel | |
| from PIL import Image | |
| from diffusers.schedulers import * | |
| from controlnet_aux.processor import Processor | |
class ControlNetReq(BaseModel):
    # ControlNet settings attached to a generation request. ``controlnets`` and
    # ``control_images`` are paired by position (entry i of one goes with entry i
    # of the other).

    # ControlNet type names, e.g. ["canny", "tile", "depth"].
    # NOTE(review): get_controlnet_images() in this module only accepts "canny",
    # "depth" and "pose" and raises ValueError for anything else (including
    # "tile") — confirm whether "tile" is handled elsewhere.
    controlnets: List[str]
    # Source images to be preprocessed, one per entry in ``controlnets``.
    control_images: List[Image.Image]
    # Conditioning strengths; presumably one per entry in ``controlnets`` —
    # confirm against the pipeline call that consumes it.
    controlnet_conditioning_scale: List[float]

    class Config:
        # PIL images are not pydantic-native types, so they must be allowed
        # explicitly.
        arbitrary_types_allowed = True
class BaseReq(BaseModel):
    # Base request model for generation; the image-to-image and inpainting
    # request models in this module extend it.

    # --- model / prompt -------------------------------------------------------
    model: str = ""
    prompt: str = ""
    fast_generation: Optional[bool] = True
    # NOTE(review): the shared ``[]`` default relies on pydantic copying field
    # defaults per instance — confirm for the pinned pydantic version.
    loras: Optional[list] = []

    # --- preprocessing / sampling --------------------------------------------
    # Accepted values: resize_only, crop_and_resize, resize_and_fill.
    resize_mode: Optional[str] = "resize_and_fill"
    scheduler: Optional[str] = "euler_fl"

    # --- output geometry ------------------------------------------------------
    height: int = 1024
    width: int = 1024
    num_images_per_prompt: int = 1

    # --- denoising ------------------------------------------------------------
    num_inference_steps: int = 8
    guidance_scale: float = 3.5
    seed: Optional[int] = 0

    # --- optional components --------------------------------------------------
    refiner: bool = False
    vae: bool = True
    controlnet_config: Optional[ControlNetReq] = None
    custom_addons: Optional[Dict[Any, Any]] = None

    class Config:
        # Needed because subclasses (and ControlNetReq) carry PIL image fields.
        arbitrary_types_allowed = True
class BaseImg2ImgReq(BaseReq):
    # BaseReq plus a required input image and a strength value.

    # Required input image (no default).
    image: Image.Image
    strength: float = 1.0

    class Config:
        # ``image`` is a PIL type, which pydantic does not validate natively.
        arbitrary_types_allowed = True
class BaseInpaintReq(BaseImg2ImgReq):
    # Image-to-image request that also carries a required mask image.

    mask_image: Image.Image

    class Config:
        # ``mask_image`` is a PIL type, which pydantic does not validate natively.
        arbitrary_types_allowed = True
def resize_images(images: List[Image.Image], height: int, width: int, resize_mode: str):
    """Return the images fitted to ``width`` x ``height`` according to ``resize_mode``.

    Args:
        images: Source images. They are not modified; PIL's ``resize`` and
            ``ImageOps.fit`` return new image objects.
        height: Target height in pixels.
        width: Target width in pixels.
        resize_mode: How the target size is reached:
            - ``"resize_only"``: stretch to the target size (Pillow's default filter).
            - ``"crop_and_resize"``: center-crop to the target aspect ratio, then
              resize with LANCZOS, so the image is neither stretched nor padded.
            - ``"resize_and_fill"``: stretch to the target size with LANCZOS.
            Any other value passes the images through unchanged.

    Returns:
        A new list with one image per input, in the same order.
    """
    # Local import keeps the file's top-level imports untouched; PIL is already
    # a dependency of this module.
    from PIL import ImageOps

    resized = []
    for image in images:
        if resize_mode == "resize_only":
            image = image.resize((width, height))
        elif resize_mode == "crop_and_resize":
            # A plain crop((0, 0, width, height)) never resizes and, for images
            # smaller than the target, leaves the picture in the top-left corner
            # of a padded canvas. ImageOps.fit crops centrally to the target
            # aspect ratio and then resizes to the exact size.
            image = ImageOps.fit(image, (width, height), method=Image.Resampling.LANCZOS)
        elif resize_mode == "resize_and_fill":
            # NOTE(review): despite the name, nothing is filled/padded; the image
            # is stretched to the target size (LANCZOS) — confirm intended behavior.
            image = image.resize((width, height), Image.Resampling.LANCZOS)
        # Collect the processed image: rebinding the loop variable alone would
        # discard the result.
        resized.append(image)
    return resized
# Maps the ControlNet type names accepted by get_controlnet_images() to the
# controlnet_aux Processor id used to preprocess images for that type.
_CONTROLNET_PROCESSOR_IDS = {
    "canny": "canny",
    "depth": "depth_midas",
    "pose": "openpose_full",
}


def get_controlnet_images(controlnet_config: ControlNetReq, height: int, width: int, resize_mode: str):
    """Resize the control images and run each through its ControlNet's annotator.

    Args:
        controlnet_config: Supplies ``controlnets`` (type names) and
            ``control_images``; the two lists are paired by position.
        height: Target height passed to ``resize_images``.
        width: Target width passed to ``resize_images``.
        resize_mode: Resize mode passed to ``resize_images``.

    Returns:
        One processed image per control image, in input order (produced with
        ``to_pil=True``).

    Raises:
        ValueError: If the number of controlnets and control images differ, or
            if a controlnet name is not one of ``_CONTROLNET_PROCESSOR_IDS``.
    """
    controlnets = controlnet_config.controlnets
    images = controlnet_config.control_images

    # zip() would silently drop the unmatched tail; fail loudly instead.
    if len(controlnets) != len(images):
        raise ValueError(
            f"Expected one control image per controlnet, got {len(controlnets)} "
            f"controlnets and {len(images)} control images"
        )
    # Validate every name before any image is resized or any annotator is built.
    for controlnet in controlnets:
        if controlnet not in _CONTROLNET_PROCESSOR_IDS:
            raise ValueError(f"Invalid Controlnet: {controlnet}")

    control_images = resize_images(images, height, width, resize_mode)

    # Build each annotator at most once per call instead of once per image.
    processors = {}
    response_images = []
    for controlnet, image in zip(controlnets, control_images):
        processor_id = _CONTROLNET_PROCESSOR_IDS[controlnet]
        if processor_id not in processors:
            processors[processor_id] = Processor(processor_id)
        response_images.append(processors[processor_id](image, to_pil=True))
    return response_images
def cleanup(pipeline, loras=None):
    """Release memory after a generation run.

    Args:
        pipeline: Pipeline whose ``unload_lora_weights()`` is called, but only
            when ``loras`` is truthy.
        loras: LoRAs used for the run; any falsy value (``None``, ``[]``) skips
            the unload step.
    """
    must_unload = bool(loras)
    if must_unload:
        pipeline.unload_lora_weights()
    # Collect unreachable Python objects first, then ask PyTorch to release
    # its cached CUDA memory.
    gc.collect()
    torch.cuda.empty_cache()