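"""Single-image inference script for MangaNinjia: colorizes a target line-art
image from a color reference image, optionally guided by user-provided point
maps (summary inferred from the argument parser and pipeline call below)."""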
import argparse
import logging
import os

import numpy as np
import torch
from PIL import Image
from transformers import CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
from diffusers import (
    ControlNetModel,
    DDIMScheduler,
    AutoencoderKL,
)

from inference.manganinjia_pipeline import MangaNinjiaPipeline
from src.models.unet_2d_condition import UNet2DConditionModel
from src.models.refunet_2d_condition import RefUNet2DConditionModel
from src.point_network import PointNet
from src.annotator.lineart import BatchLineartDetector
| if "__main__" == __name__: | |
| logging.basicConfig(level=logging.INFO) | |
    # -------------------- Arguments --------------------
    parser = argparse.ArgumentParser(
        description="Run single-image MangaNinjia"
    )
    parser.add_argument(
        "--output_dir", type=str, required=True, help="Output directory."
    )
    # inference settings
    parser.add_argument(
        "--denoise_steps",
        type=int,
        default=50,  # quantitative evaluation uses 50 steps
        help="Number of diffusion denoising steps; more steps improve accuracy at the cost of slower inference.",
    )
    parser.add_argument("--seed", type=int, default=None, help="Random seed.")
    parser.add_argument(
        "--pretrained_model_name_or_path",
        type=str,
        required=True,
        help="Path to the pretrained base diffusion model or model identifier from huggingface.co/models.",
    )
    parser.add_argument(
        "--image_encoder_path",
        type=str,
        required=True,
        help="Path to the pretrained CLIP model that supplies the tokenizer, text encoder, and image encoder.",
    )
    parser.add_argument(
        "--controlnet_model_name_or_path", type=str, required=True, help="Path to the original ControlNet."
    )
    parser.add_argument(
        "--annotator_ckpts_path", type=str, required=True, help="Path to the line-art annotator checkpoints."
    )
    parser.add_argument(
        "--manga_reference_unet_path", type=str, required=True, help="Path to the MangaNinjia reference UNet weights."
    )
    parser.add_argument(
        "--manga_main_model_path", type=str, required=True, help="Path to the MangaNinjia denoising UNet weights."
    )
    parser.add_argument(
        "--manga_controlnet_model_path", type=str, required=True, help="Path to the MangaNinjia ControlNet weights."
    )
    parser.add_argument(
        "--point_net_path", type=str, required=True, help="Path to the PointNet weights."
    )
    parser.add_argument(
        "--input_reference_paths",
        type=str,
        nargs="+",
        default=None,
        help="Paths to the color reference images.",
    )
    parser.add_argument(
        "--input_lineart_paths",
        type=str,
        nargs="+",
        default=None,
        help="Paths to the target line-art images.",
    )
    parser.add_argument(
        "--point_ref_paths",
        type=str,
        nargs="+",
        default=None,
        help="Paths to .npy point maps for the reference images (optional).",
    )
    parser.add_argument(
        "--point_lineart_paths",
        type=str,
        nargs="+",
        default=None,
        help="Paths to .npy point maps for the line-art images (optional).",
    )
    parser.add_argument(
        "--is_lineart",
        action="store_true",
        help="Set if the target images are already line art.",
    )
    parser.add_argument(
        "--guidance_scale_ref",
        type=float,
        default=1e-4,
        help="Guidance scale for the reference image.",
    )
    parser.add_argument(
        "--guidance_scale_point",
        type=float,
        default=1e-4,
        help="Guidance scale for the points.",
    )
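    # Example invocation (illustrative only: "infer.py" stands in for this
    # script's filename, and every path below is a placeholder):
    #   python infer.py \
    #     --pretrained_model_name_or_path <stable-diffusion-base> \
    #     --image_encoder_path <clip-checkpoint-dir> \
    #     --controlnet_model_name_or_path <controlnet-dir> \
    #     --annotator_ckpts_path <annotator-ckpts-dir> \
    #     --manga_reference_unet_path <reference_unet.pth> \
    #     --manga_main_model_path <denoising_unet.pth> \
    #     --manga_controlnet_model_path <controlnet.pth> \
    #     --point_net_path <point_net.pth> \
    #     --input_reference_paths ref.png \
    #     --input_lineart_paths target.png \
    #     --output_dir ./output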
    args = parser.parse_args()
    output_dir = args.output_dir
    denoise_steps = args.denoise_steps
    seed = args.seed
    is_lineart = args.is_lineart
    os.makedirs(output_dir, exist_ok=True)
    logging.info(f"output dir = {output_dir}")
    # Unpack the path lists unconditionally so the `is not None` checks in the
    # inference loop below cannot hit an undefined name.
    input_reference_paths = args.input_reference_paths
    input_lineart_paths = args.input_lineart_paths
    point_ref_paths = args.point_ref_paths
    point_lineart_paths = args.point_lineart_paths
    if input_reference_paths is not None:
        assert len(input_reference_paths) == len(input_lineart_paths)
    if point_ref_paths is not None:
        assert len(point_ref_paths) == len(point_lineart_paths)
    print(f"arguments: {args}")
    if seed is None:
        import time
        seed = int(time.time())
    # -------------------- Device --------------------
    if torch.cuda.is_available():
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")
        logging.warning("CUDA is not available. Running on CPU will be slow.")
    logging.info(f"device = {device}")
    # Build a seeded generator on the chosen device for reproducible sampling.
    # (torch.cuda.manual_seed returns None, so it cannot serve as a generator.)
    generator = torch.Generator(device=device).manual_seed(seed)
    # -------------------- Model --------------------
    preprocessor = BatchLineartDetector(args.annotator_ckpts_path)
    preprocessor.to(device, dtype=torch.float32)
    in_channels_reference_unet = 4
    in_channels_denoising_unet = 4
    in_channels_controlnet = 4
    noise_scheduler = DDIMScheduler.from_pretrained(
        args.pretrained_model_name_or_path, subfolder="scheduler"
    )
    vae = AutoencoderKL.from_pretrained(
        args.pretrained_model_name_or_path, subfolder="vae"
    )
    # Both UNets start from the same base weights: the reference UNet encodes
    # the color reference while the denoising UNet generates the output.
    denoising_unet = UNet2DConditionModel.from_pretrained(
        args.pretrained_model_name_or_path, subfolder="unet",
        in_channels=in_channels_denoising_unet,
        low_cpu_mem_usage=False,
        ignore_mismatched_sizes=True,
    )
    reference_unet = RefUNet2DConditionModel.from_pretrained(
        args.pretrained_model_name_or_path, subfolder="unet",
        in_channels=in_channels_reference_unet,
        low_cpu_mem_usage=False,
        ignore_mismatched_sizes=True,
    )
    # The same CLIP checkpoint supplies the tokenizer, text encoder, and image
    # encoder for both the reference branch and the ControlNet branch.
    refnet_tokenizer = CLIPTokenizer.from_pretrained(args.image_encoder_path)
    refnet_text_encoder = CLIPTextModel.from_pretrained(args.image_encoder_path)
    refnet_image_encoder = CLIPVisionModelWithProjection.from_pretrained(args.image_encoder_path)
    controlnet = ControlNetModel.from_pretrained(
        args.controlnet_model_name_or_path,
        in_channels=in_channels_controlnet,
        low_cpu_mem_usage=False,
        ignore_mismatched_sizes=True,
    )
    controlnet_tokenizer = CLIPTokenizer.from_pretrained(args.image_encoder_path)
    controlnet_text_encoder = CLIPTextModel.from_pretrained(args.image_encoder_path)
    controlnet_image_encoder = CLIPVisionModelWithProjection.from_pretrained(args.image_encoder_path)
    point_net = PointNet()
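    # Load the MangaNinjia fine-tuned weights over the base models just built.
    # map_location="cpu" avoids a GPU memory spike during loading (everything
    # is moved to the target device with the pipeline below); strict=False
    # skips any keys the checkpoints and module definitions do not share.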
    controlnet.load_state_dict(
        torch.load(args.manga_controlnet_model_path, map_location="cpu"),
        strict=False,
    )
    point_net.load_state_dict(
        torch.load(args.point_net_path, map_location="cpu"),
        strict=False,
    )
    reference_unet.load_state_dict(
        torch.load(args.manga_reference_unet_path, map_location="cpu"),
        strict=False,
    )
    denoising_unet.load_state_dict(
        torch.load(args.manga_main_model_path, map_location="cpu"),
        strict=False,
    )
    pipe = MangaNinjiaPipeline(
        reference_unet=reference_unet,
        controlnet=controlnet,
        denoising_unet=denoising_unet,
        vae=vae,
        refnet_tokenizer=refnet_tokenizer,
        refnet_text_encoder=refnet_text_encoder,
        refnet_image_encoder=refnet_image_encoder,
        controlnet_tokenizer=controlnet_tokenizer,
        controlnet_text_encoder=controlnet_text_encoder,
        controlnet_image_encoder=controlnet_image_encoder,
        scheduler=noise_scheduler,
        point_net=point_net,
    )
    pipe = pipe.to(device)
    # -------------------- Inference and saving --------------------
    with torch.no_grad():
        for i in range(len(input_reference_paths)):
            input_reference_path = input_reference_paths[i]
            input_lineart_path = input_lineart_paths[i]
            # save paths
            rgb_name_base = os.path.splitext(os.path.basename(input_reference_path))[0]
            pred_name_base = rgb_name_base + "_colorized"
            lineart_name_base = rgb_name_base + "_lineart"
            colored_save_path = os.path.join(output_dir, f"{pred_name_base}.png")
            lineart_save_path = os.path.join(output_dir, f"{lineart_name_base}.png")
            if point_ref_paths is not None:
                point_ref_path = point_ref_paths[i]
                point_lineart_path = point_lineart_paths[i]
                point_ref = torch.from_numpy(np.load(point_ref_path)).unsqueeze(0).unsqueeze(0)
                point_main = torch.from_numpy(np.load(point_lineart_path)).unsqueeze(0).unsqueeze(0)
            else:
                # No point annotations given: fall back to all-zero maps.
                point_ref = torch.from_numpy(np.zeros((512, 512), dtype=np.uint8)).unsqueeze(0).unsqueeze(0)
                point_main = torch.from_numpy(np.zeros((512, 512), dtype=np.uint8)).unsqueeze(0).unsqueeze(0)
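            # point_ref / point_main end up as (1, 1, 512, 512) tensors of
            # matching-point annotations; an all-zero map (the fallback above)
            # carries no point guidance.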
            # Force 3-channel input in case of grayscale or RGBA images.
            ref_image = Image.open(input_reference_path).convert("RGB")
            ref_image = ref_image.resize((512, 512))
            target_image = Image.open(input_lineart_path).convert("RGB")
            target_image = target_image.resize((512, 512))
            pipe_out = pipe(
                is_lineart,
                ref_image,
                target_image,
                target_image,
                denosing_steps=denoise_steps,  # (sic) keyword spelling matches the pipeline's signature
                processing_res=512,
                match_input_res=True,
                batch_size=1,
                show_progress_bar=True,
                guidance_scale_ref=args.guidance_scale_ref,
                guidance_scale_point=args.guidance_scale_point,
                preprocessor=preprocessor,
                generator=generator,
                point_ref=point_ref,
                point_main=point_main,
            )
            if os.path.exists(colored_save_path):
                logging.warning(f"Existing file: '{colored_save_path}' will be overwritten")
            image = pipe_out.img_pil  # colorized result
            lineart = pipe_out.to_save_dict["edge2_black"]  # extracted line art
            image.save(colored_save_path)
            lineart.save(lineart_save_path)