"""Inference script: assemble the Glyph-SDXL pipeline from a fine-tuned
checkpoint and render one design image from a JSON annotation."""
import argparse
import copy
import json
import os
import os.path as osp

import torch
from diffusers import AutoencoderKL, DPMSolverMultistepScheduler, UNet2DConditionModel
from diffusers.models.attention import BasicTransformerBlock
from peft import LoraConfig
from peft.utils import set_peft_model_state_dict
from transformers import PretrainedConfig

from glyph_sdxl.utils import (
    parse_config,
    UNET_CKPT_NAME,
    huggingface_cache_dir,
    load_byt5_and_byt5_tokenizer,
    BYT5_MAPPER_CKPT_NAME,
    INSERTED_ATTN_CKPT_NAME,
    BYT5_CKPT_NAME,
    MultilingualPromptFormat,
)
from glyph_sdxl.custom_diffusers import (
    StableDiffusionGlyphXLPipeline,
    CrossAttnInsertBasicTransformerBlock,
)
from glyph_sdxl.modules import T5EncoderBlockByT5Mapper
# Registry of ByT5 mapper classes, keyed by class name as given in the config.
byt5_mapper_dict = {T5EncoderBlockByT5Mapper.__name__: T5EncoderBlockByT5Mapper}


def import_model_class_from_model_name_or_path(
    pretrained_model_name_or_path: str, revision: str, subfolder: str = "text_encoder",
):
    """Resolve the concrete CLIP text-encoder class declared in the model config."""
    text_encoder_config = PretrainedConfig.from_pretrained(
        pretrained_model_name_or_path,
        subfolder=subfolder,
        revision=revision,
    )
    model_class = text_encoder_config.architectures[0]

    if model_class == "CLIPTextModel":
        from transformers import CLIPTextModel
        return CLIPTextModel
    elif model_class == "CLIPTextModelWithProjection":
        from transformers import CLIPTextModelWithProjection
        return CLIPTextModelWithProjection
    else:
        raise ValueError(f"{model_class} is not supported.")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("config_dir", type=str)
    parser.add_argument("ckpt_dir", type=str)
    # nargs='?' makes the default effective; a required positional ignores it.
    parser.add_argument("ann_path", type=str, nargs='?', default='examples/shower.json')
    parser.add_argument("--out_folder", type=str, default='output')
    parser.add_argument("--device", type=str, default='cuda')
    # Only 'dpm' swaps the scheduler below; 'euler' keeps the base model's default.
    parser.add_argument("--sampler", type=str, choices=['euler', 'dpm'])
    parser.add_argument("--cfg", type=float, default=5.0)
    args = parser.parse_args()
    config = parse_config(args.config_dir)

    # Resolve and load the two CLIP text encoders used by SDXL.
    text_encoder_cls_one = import_model_class_from_model_name_or_path(
        config.pretrained_model_name_or_path, config.revision,
    )
    text_encoder_cls_two = import_model_class_from_model_name_or_path(
        config.pretrained_model_name_or_path, config.revision, subfolder="text_encoder_2",
    )
    text_encoder_one = text_encoder_cls_one.from_pretrained(
        config.pretrained_model_name_or_path, subfolder="text_encoder", revision=config.revision,
        cache_dir=huggingface_cache_dir,
    )
    text_encoder_two = text_encoder_cls_two.from_pretrained(
        config.pretrained_model_name_or_path, subfolder="text_encoder_2", revision=config.revision,
        cache_dir=huggingface_cache_dir,
    )
    unet = UNet2DConditionModel.from_pretrained(
        config.pretrained_model_name_or_path,
        subfolder="unet",
        revision=config.revision,
        cache_dir=huggingface_cache_dir,
    )

    # Use a standalone VAE checkpoint if one is configured, otherwise the
    # "vae" subfolder of the base model.
    vae_path = (
        config.pretrained_model_name_or_path
        if config.pretrained_vae_model_name_or_path is None
        else config.pretrained_vae_model_name_or_path
    )
    vae = AutoencoderKL.from_pretrained(
        vae_path,
        subfolder="vae" if config.pretrained_vae_model_name_or_path is None else None,
        revision=config.revision,
        cache_dir=huggingface_cache_dir,
    )

    # ByT5 encoder and tokenizer for the glyph (text-rendering) branch.
    byt5_model, byt5_tokenizer = load_byt5_and_byt5_tokenizer(
        **config.byt5_config,
        huggingface_cache_dir=huggingface_cache_dir,
    )
    inference_dtype = torch.float32
    if config.inference_dtype == "fp16":
        inference_dtype = torch.float16
    elif config.inference_dtype == "bf16":
        inference_dtype = torch.bfloat16

    # The stock SDXL VAE is numerically unstable in fp16, so it stays in fp32
    # unless a dedicated (fp16-safe) VAE checkpoint was configured.
    if config.pretrained_vae_model_name_or_path is None:
        vae.to(args.device, dtype=torch.float32)
    else:
        vae.to(args.device, dtype=inference_dtype)
    text_encoder_one.to(args.device, dtype=inference_dtype)
    text_encoder_two.to(args.device, dtype=inference_dtype)
    byt5_model.to(args.device)  # ByT5 keeps its default (fp32) dtype
    unet.to(args.device, dtype=inference_dtype)
    # Swap the configured transformer blocks for variants that insert an extra
    # cross-attention layer over the ByT5 glyph embeddings, and track the names
    # of the newly inserted parameters.
    inserted_new_modules_para_set = set()
    for name, module in unet.named_modules():
        if isinstance(module, BasicTransformerBlock) and name in config.attn_block_to_modify:
            parent_module = unet
            for n in name.split(".")[:-1]:
                parent_module = getattr(parent_module, n)
            new_block = CrossAttnInsertBasicTransformerBlock.from_transformer_block(
                module,
                byt5_model.config.d_model
                if config.byt5_mapper_config.sdxl_channels is None
                else config.byt5_mapper_config.sdxl_channels,
            )
            new_block.requires_grad_(False)
            for inserted_module_name, inserted_module in zip(
                new_block.get_inserted_modules_names(),
                new_block.get_inserted_modules(),
            ):
                inserted_module.requires_grad_(True)
                for para_name, para in inserted_module.named_parameters():
                    para_key = name + '.' + inserted_module_name + '.' + para_name
                    assert para_key not in inserted_new_modules_para_set
                    inserted_new_modules_para_set.add(para_key)
            for origin_module in new_block.get_origin_modules():
                origin_module.to(args.device, dtype=inference_dtype)
            parent_module.register_module(name.split(".")[-1], new_block)
            print(f"inserted cross attn block to {name}")
    # Mapper that projects ByT5 token embeddings into the SDXL cross-attention space.
    byt5_mapper = byt5_mapper_dict[config.byt5_mapper_type](
        byt5_model.config,
        **config.byt5_mapper_config,
    )

    # Attach LoRA adapters to the UNet's self- (attn1) and cross- (attn2)
    # attention projections.
    unet_lora_target_modules = [
        "attn1.to_k", "attn1.to_q", "attn1.to_v", "attn1.to_out.0",
        "attn2.to_k", "attn2.to_q", "attn2.to_v", "attn2.to_out.0",
    ]
    unet_lora_config = LoraConfig(
        r=config.unet_lora_rank,
        lora_alpha=config.unet_lora_rank,
        init_lora_weights="gaussian",
        target_modules=unet_lora_target_modules,
    )
    unet.add_adapter(unet_lora_config)
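    # The checkpoint directory is expected to hold four files, named by the
    # constants imported from glyph_sdxl.utils: the UNet LoRA weights
    # (UNET_CKPT_NAME), the inserted cross-attention weights
    # (INSERTED_ATTN_CKPT_NAME), the ByT5-to-SDXL mapper (BYT5_MAPPER_CKPT_NAME),
    # and the fine-tuned ByT5 encoder (BYT5_CKPT_NAME).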
    unet_lora_layers_para = torch.load(osp.join(args.ckpt_dir, UNET_CKPT_NAME), map_location='cpu')
    incompatible_keys = set_peft_model_state_dict(unet, unet_lora_layers_para, adapter_name="default")
    unexpected = getattr(incompatible_keys, 'unexpected_keys', None)
    if not unexpected:
        print("loaded unet_lora_layers_para")
    else:
        print(f"unet_lora_layers has unexpected_keys: {unexpected}")

    inserted_attn_module_paras = torch.load(osp.join(args.ckpt_dir, INSERTED_ATTN_CKPT_NAME), map_location='cpu')
    missing_keys, unexpected_keys = unet.load_state_dict(inserted_attn_module_paras, strict=False)
    assert len(unexpected_keys) == 0, unexpected_keys

    byt5_mapper_para = torch.load(osp.join(args.ckpt_dir, BYT5_MAPPER_CKPT_NAME), map_location='cpu')
    byt5_mapper.load_state_dict(byt5_mapper_para)

    byt5_model_para = torch.load(osp.join(args.ckpt_dir, BYT5_CKPT_NAME), map_location='cpu')
    byt5_model.load_state_dict(byt5_model_para)
    # Assemble the full pipeline around the customized components.
    pipeline = StableDiffusionGlyphXLPipeline.from_pretrained(
        config.pretrained_model_name_or_path,
        vae=vae,
        text_encoder=text_encoder_one,
        text_encoder_2=text_encoder_two,
        byt5_text_encoder=byt5_model,
        byt5_tokenizer=byt5_tokenizer,
        byt5_mapper=byt5_mapper,
        unet=unet,
        byt5_max_length=config.byt5_max_length,
        revision=config.revision,
        torch_dtype=inference_dtype,
        safety_checker=None,
        cache_dir=huggingface_cache_dir,
    )
    if args.sampler == 'dpm':
        pipeline.scheduler = DPMSolverMultistepScheduler.from_pretrained(
            config.pretrained_model_name_or_path,
            subfolder="scheduler",
            use_karras_sigmas=True,
        )
    pipeline = pipeline.to(args.device)
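    # Expected annotation layout (inferred from the keys read below; the exact
    # bbox and style value formats follow the repo's bundled examples):
    #   {
    #     "bg_prompt": "...",       # background / global scene prompt
    #     "texts": ["Title", ...],  # one string per text box
    #     "bbox": [[...], ...],     # one box per text entry
    #     "styles": [{...}, ...],   # per-text font/color style dicts
    #     "seed": 42                # optional, fixes the generator seed
    #   }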
    with open(args.ann_path, 'r') as f:
        ann = json.load(f)
    os.makedirs(args.out_folder, exist_ok=True)

    prompt_format = MultilingualPromptFormat()
    texts = copy.deepcopy(ann['texts'])
    bboxes = copy.deepcopy(ann['bbox'])
    styles = copy.deepcopy(ann['styles'])
    text_prompt = prompt_format.format_prompt(texts, styles)

    # Seed the generator when the annotation pins one, for reproducibility.
    if 'seed' not in ann:
        generator = torch.Generator(device=args.device)
    else:
        generator = torch.Generator(device=args.device).manual_seed(ann['seed'])

    with torch.cuda.amp.autocast():
        image = pipeline(
            prompt=ann['bg_prompt'],
            text_prompt=text_prompt,
            texts=texts,
            bboxes=bboxes,
            num_inference_steps=50,
            generator=generator,
            text_attn_mask=None,
            guidance_scale=args.cfg,
        ).images[0]
    image.save(osp.join(args.out_folder, 'result.png'))
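
# Example invocation (file and checkpoint paths are illustrative, not shipped
# with this script):
#   python inference.py configs/glyph_sdxl_config.py checkpoints/glyph-sdxl \
#       examples/shower.json --out_folder output --sampler dpm --cfg 5.0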