# Glyph-ByT5 / inference_multilingual.py
# Source: Hugging Face repository "Glyph-ByT5" (page language: English),
# uploaded by bghira via huggingface_hub, commit cd05235 (verified).
import argparse
import os
import json
import copy
import os.path as osp
import torch
from diffusers import UNet2DConditionModel, AutoencoderKL
from diffusers.models.attention import BasicTransformerBlock
from peft import LoraConfig
from peft.utils import set_peft_model_state_dict
from transformers import PretrainedConfig
from diffusers import DPMSolverMultistepScheduler
from glyph_sdxl.utils import (
parse_config,
UNET_CKPT_NAME,
huggingface_cache_dir,
load_byt5_and_byt5_tokenizer,
BYT5_MAPPER_CKPT_NAME,
INSERTED_ATTN_CKPT_NAME,
BYT5_CKPT_NAME,
MultilingualPromptFormat,
)
from glyph_sdxl.custom_diffusers import (
StableDiffusionGlyphXLPipeline,
CrossAttnInsertBasicTransformerBlock,
)
from glyph_sdxl.modules import T5EncoderBlockByT5Mapper
# Registry of the available ByT5 mapper classes, keyed by class name so the
# config's ``byt5_mapper_type`` string can select which one to instantiate.
byt5_mapper_dict = {
    mapper_cls.__name__: mapper_cls
    for mapper_cls in (T5EncoderBlockByT5Mapper,)
}
def import_model_class_from_model_name_or_path(
    pretrained_model_name_or_path: str, revision: str, subfolder: str = "text_encoder",
):
    """Resolve the transformers class used by a text encoder in a model repo.

    Reads the ``PretrainedConfig`` stored under ``subfolder`` of
    ``pretrained_model_name_or_path`` (at ``revision``) and maps the first
    entry of its ``architectures`` list to a class.

    Returns:
        ``CLIPTextModel`` or ``CLIPTextModelWithProjection`` (the class itself,
        not an instance).

    Raises:
        ValueError: if the recorded architecture is neither of those two.
    """
    architecture = PretrainedConfig.from_pretrained(
        pretrained_model_name_or_path,
        subfolder=subfolder,
        revision=revision,
    ).architectures[0]

    # Imports are deferred so only the class actually needed is imported.
    if architecture == "CLIPTextModelWithProjection":
        from transformers import CLIPTextModelWithProjection
        return CLIPTextModelWithProjection
    if architecture == "CLIPTextModel":
        from transformers import CLIPTextModel
        return CLIPTextModel
    raise ValueError(f"{architecture} is not supported.")
# Command-line entry point: build the Glyph-ByT5 SDXL pipeline (base SDXL
# components + ByT5 glyph text encoder + inserted cross-attention blocks +
# UNet LoRA), render one image from a JSON annotation, and save it as
# ``<out_folder>/result.png``.
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("config_dir", type=str)
    parser.add_argument("ckpt_dir", type=str)
    # nargs='?' is required for argparse to honor the default of a positional.
    parser.add_argument("ann_path", type=str, nargs='?', default='examples/shower.json')
    # Previously defaulted to the *string* 'None', which silently wrote into a
    # directory literally named "None"; now falls back to the working directory.
    parser.add_argument("--out_folder", type=str, default=None)
    parser.add_argument("--device", type=str, default='cuda')
    # None / 'euler' keep the scheduler loaded with the base pipeline;
    # 'dpm' swaps in DPMSolverMultistepScheduler with Karras sigmas.
    parser.add_argument("--sampler", type=str, choices=['euler', 'dpm'])
    parser.add_argument("--cfg", type=float, default=5.0)
    parser.add_argument("--steps", type=int, default=50,
                        help="number of denoising steps (default: 50)")
    args = parser.parse_args()
    config = parse_config(args.config_dir)

    # ---- Base SDXL components -------------------------------------------
    text_encoder_cls_one = import_model_class_from_model_name_or_path(
        config.pretrained_model_name_or_path, config.revision,
    )
    text_encoder_cls_two = import_model_class_from_model_name_or_path(
        config.pretrained_model_name_or_path, config.revision, subfolder="text_encoder_2",
    )
    text_encoder_one = text_encoder_cls_one.from_pretrained(
        config.pretrained_model_name_or_path, subfolder="text_encoder", revision=config.revision,
        cache_dir=huggingface_cache_dir,
    )
    text_encoder_two = text_encoder_cls_two.from_pretrained(
        config.pretrained_model_name_or_path, subfolder="text_encoder_2", revision=config.revision,
        cache_dir=huggingface_cache_dir,
    )
    unet = UNet2DConditionModel.from_pretrained(
        config.pretrained_model_name_or_path,
        subfolder="unet",
        revision=config.revision,
        cache_dir=huggingface_cache_dir,
    )
    # A separately configured VAE is loaded from its own repo root; otherwise
    # the base model's "vae" subfolder is used.
    vae_path = (
        config.pretrained_model_name_or_path
        if config.pretrained_vae_model_name_or_path is None
        else config.pretrained_vae_model_name_or_path
    )
    vae = AutoencoderKL.from_pretrained(
        vae_path, subfolder="vae" if config.pretrained_vae_model_name_or_path is None else None,
        revision=config.revision,
        cache_dir=huggingface_cache_dir,
    )
    byt5_model, byt5_tokenizer = load_byt5_and_byt5_tokenizer(
        **config.byt5_config,
        huggingface_cache_dir=huggingface_cache_dir,
    )

    # ---- Precision / device placement -----------------------------------
    inference_dtype = torch.float32
    if config.inference_dtype == "fp16":
        inference_dtype = torch.float16
    elif config.inference_dtype == "bf16":
        inference_dtype = torch.bfloat16
    # The base model's own VAE is always kept in fp32; only an explicitly
    # configured VAE follows inference_dtype (presumably because the stock
    # SDXL VAE is unstable in half precision -- confirm).
    if config.pretrained_vae_model_name_or_path is None:
        vae.to(args.device, dtype=torch.float32)
    else:
        vae.to(args.device, dtype=inference_dtype)
    text_encoder_one.to(args.device, dtype=inference_dtype)
    text_encoder_two.to(args.device, dtype=inference_dtype)
    # ByT5 is moved to the device but deliberately left at its loaded dtype.
    byt5_model.to(args.device)
    unet.to(args.device, dtype=inference_dtype)

    # ---- Insert glyph cross-attention into selected transformer blocks ---
    inserted_new_modules_para_set = set()
    # Snapshot the module list: the loop replaces modules inside `unet`, and
    # mutating the tree while walking the lazy generator is fragile.
    for name, module in list(unet.named_modules()):
        if not (isinstance(module, BasicTransformerBlock) and name in config.attn_block_to_modify):
            continue
        *parent_path, child_name = name.split(".")
        parent_module = unet
        for n in parent_path:
            parent_module = getattr(parent_module, n)
        cross_attn_dim = (
            byt5_model.config.d_model
            if config.byt5_mapper_config.sdxl_channels is None
            else config.byt5_mapper_config.sdxl_channels
        )
        new_block = CrossAttnInsertBasicTransformerBlock.from_transformer_block(
            module,
            cross_attn_dim,
        )
        new_block.requires_grad_(False)
        for inserted_module_name, inserted_module in zip(
            new_block.get_inserted_modules_names(),
            new_block.get_inserted_modules()
        ):
            inserted_module.requires_grad_(True)
            # Guard against two inserted modules mapping to the same parameter key.
            for para_name, para in inserted_module.named_parameters():
                para_key = name + '.' + inserted_module_name + '.' + para_name
                assert para_key not in inserted_new_modules_para_set
                inserted_new_modules_para_set.add(para_key)
        # Only the original sub-modules are cast here; inserted modules keep
        # their construction dtype (mixed precision is presumably reconciled by
        # the autocast context at generation time -- confirm).
        for origin_module in new_block.get_origin_modules():
            origin_module.to(args.device, dtype=inference_dtype)
        parent_module.register_module(child_name, new_block)
        print(f"inserted cross attn block to {name}")

    byt5_mapper = byt5_mapper_dict[config.byt5_mapper_type](
        byt5_model.config,
        **config.byt5_mapper_config,
    )

    # ---- UNet LoRA adapter ----------------------------------------------
    unet_lora_target_modules = [
        "attn1.to_k", "attn1.to_q", "attn1.to_v", "attn1.to_out.0",
        "attn2.to_k", "attn2.to_q", "attn2.to_v", "attn2.to_out.0",
    ]
    unet_lora_config = LoraConfig(
        r=config.unet_lora_rank,
        lora_alpha=config.unet_lora_rank,
        init_lora_weights="gaussian",
        target_modules=unet_lora_target_modules,
    )
    unet.add_adapter(unet_lora_config)

    # ---- Load fine-tuned weights ----------------------------------------
    # NOTE(review): torch.load unpickles arbitrary objects -- only use trusted
    # checkpoint directories (or pass weights_only=True once it is confirmed
    # the files contain only tensors).
    unet_lora_layers_para = torch.load(osp.join(args.ckpt_dir, UNET_CKPT_NAME), map_location='cpu')
    incompatible_keys = set_peft_model_state_dict(unet, unet_lora_layers_para, adapter_name="default")
    unexpected_lora_keys = getattr(incompatible_keys, 'unexpected_keys', [])
    if not unexpected_lora_keys:
        print("loaded unet_lora_layers_para")
    else:
        print(f"unet_lora_layers has unexpected_keys: {unexpected_lora_keys}")

    # The checkpoint holds only the inserted modules, so missing keys are
    # expected (strict=False); unexpected keys indicate a mismatched checkpoint.
    inserted_attn_module_paras = torch.load(osp.join(args.ckpt_dir, INSERTED_ATTN_CKPT_NAME), map_location='cpu')
    _, unexpected_keys = unet.load_state_dict(inserted_attn_module_paras, strict=False)
    if unexpected_keys:
        # Explicit raise instead of assert so the check survives `python -O`.
        raise RuntimeError(f"unexpected keys in inserted attention checkpoint: {unexpected_keys}")

    byt5_mapper_para = torch.load(osp.join(args.ckpt_dir, BYT5_MAPPER_CKPT_NAME), map_location='cpu')
    byt5_mapper.load_state_dict(byt5_mapper_para)
    byt5_model_para = torch.load(osp.join(args.ckpt_dir, BYT5_CKPT_NAME), map_location='cpu')
    byt5_model.load_state_dict(byt5_model_para)

    # ---- Assemble pipeline ----------------------------------------------
    pipeline = StableDiffusionGlyphXLPipeline.from_pretrained(
        config.pretrained_model_name_or_path,
        vae=vae,
        text_encoder=text_encoder_one,
        text_encoder_2=text_encoder_two,
        byt5_text_encoder=byt5_model,
        byt5_tokenizer=byt5_tokenizer,
        byt5_mapper=byt5_mapper,
        unet=unet,
        byt5_max_length=config.byt5_max_length,
        revision=config.revision,
        torch_dtype=inference_dtype,
        safety_checker=None,
        cache_dir=huggingface_cache_dir,
    )
    if args.sampler == 'dpm':
        # Same revision / cache as every other component of the base model.
        pipeline.scheduler = DPMSolverMultistepScheduler.from_pretrained(
            config.pretrained_model_name_or_path,
            subfolder="scheduler",
            use_karras_sigmas=True,
            revision=config.revision,
            cache_dir=huggingface_cache_dir,
        )
    pipeline = pipeline.to(args.device)

    # ---- Read annotation & generate -------------------------------------
    # Explicit UTF-8: annotations carry multilingual text and the platform
    # default encoding is not guaranteed to be UTF-8.
    with open(args.ann_path, 'r', encoding='utf-8') as f:
        ann = json.load(f)
    out_folder = args.out_folder if args.out_folder is not None else '.'
    os.makedirs(out_folder, exist_ok=True)

    prompt_format = MultilingualPromptFormat()
    texts = copy.deepcopy(ann['texts'])
    bboxes = copy.deepcopy(ann['bbox'])
    styles = copy.deepcopy(ann['styles'])
    text_prompt = prompt_format.format_prompt(texts, styles)

    # Reproducible only when the annotation supplies a seed.
    if 'seed' not in ann:
        generator = torch.Generator(device=args.device)
    else:
        generator = torch.Generator(device=args.device).manual_seed(ann['seed'])

    # torch.autocast("cuda") is the non-deprecated equivalent of torch.cuda.amp.autocast().
    with torch.autocast(device_type="cuda"):
        image = pipeline(
            prompt=ann['bg_prompt'],
            text_prompt=text_prompt,
            texts=texts,
            bboxes=bboxes,
            num_inference_steps=args.steps,
            generator=generator,
            text_attn_mask=None,
            guidance_scale=args.cfg,
        ).images[0]
    image.save(osp.join(out_folder, 'result.png'))