from other_impls import SD3Tokenizer, SDClipModel, SDXLClipG, T5XXLModel
from safetensors import safe_open
from huggingface_hub import hf_hub_download
import torch


def load_into(ckpt, model, prefix, device, dtype=None, remap=None):
    """Just a debugging-friendly hack to apply the weights in a safetensors file to the pytorch module."""
    for key in ckpt.keys():
        model_key = key
        if remap is not None and key in remap:
            model_key = remap[key]
        if model_key.startswith(prefix) and not model_key.startswith("loss."):
            # Walk the dotted key (e.g. "text_model.encoder.layers.0.self_attn.k_proj.weight")
            # down the module tree to the matching parameter or buffer.
            path = model_key[len(prefix):].split(".")
            obj = model
            for p in path:
                if isinstance(obj, list):
                    obj = obj[int(p)]
                else:
                    obj = getattr(obj, p, None)
                    if obj is None:
                        print(
                            f"Skipping key '{model_key}' in safetensors file as '{p}' does not exist in python model"
                        )
                        break
            if obj is None:
                continue
            try:
                tensor = ckpt.get_tensor(key).to(device=device)
                # Keep integer tensors (e.g. position ids) in their stored dtype.
                if dtype is not None and tensor.dtype != torch.int32:
                    tensor = tensor.to(dtype=dtype)
                obj.requires_grad_(False)
                if obj.shape != tensor.shape:
                    print(
                        f"W: shape mismatch for key {model_key}, {obj.shape} != {tensor.shape}"
                    )
                obj.set_(tensor)
            except Exception as e:
                print(f"Failed to load key '{key}' in safetensors file: {e}")
                raise
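
# A minimal sketch of load_into on a toy module (an assumed example, not part
# of the original app; the file name and modules are hypothetical). It saves a
# small net's weights with safetensors, then copies them into a fresh net.
def _demo_load_into():
    from safetensors.torch import save_file

    net = torch.nn.Sequential(torch.nn.Linear(4, 8), torch.nn.Linear(8, 2))
    # Keys in the saved file look like "0.weight", "0.bias", "1.weight", ...
    save_file(dict(net.state_dict()), "toy.safetensors")
    fresh = torch.nn.Sequential(torch.nn.Linear(4, 8), torch.nn.Linear(8, 2))
    with safe_open("toy.safetensors", framework="pt", device="cpu") as f:
        load_into(f, fresh, "", "cpu", torch.float32)
    # The fresh module now carries the saved weights.
    assert torch.equal(net[0].weight, fresh[0].weight)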
# OpenCLIP bigG text-encoder dimensions.
CLIPG_CONFIG = {
    "hidden_act": "gelu",
    "hidden_size": 1280,
    "intermediate_size": 5120,
    "num_attention_heads": 20,
    "num_hidden_layers": 32,
}
class ClipG:
    def __init__(self, model_folder: str, device: str = "cpu"):
        safetensors_path = hf_hub_download(
            repo_id=model_folder,
            filename="clip_g.safetensors",
            cache_dir=None,
        )
        with safe_open(safetensors_path, framework="pt", device="cpu") as f:
            self.model = SDXLClipG(CLIPG_CONFIG, device=device, dtype=torch.float32)
            load_into(f, self.model.transformer, "", device, torch.float32)
# OpenAI CLIP ViT-L text-encoder dimensions.
CLIPL_CONFIG = {
    "hidden_act": "quick_gelu",
    "hidden_size": 768,
    "intermediate_size": 3072,
    "num_attention_heads": 12,
    "num_hidden_layers": 12,
}
class ClipL:
    def __init__(self, model_folder: str):
        safetensors_path = hf_hub_download(
            repo_id=model_folder,
            filename="clip_l.safetensors",
            cache_dir=None,
        )
        with safe_open(safetensors_path, framework="pt", device="cpu") as f:
            self.model = SDClipModel(
                layer="hidden",
                layer_idx=-2,  # use the penultimate hidden state, not the final layer
                device="cpu",
                dtype=torch.float32,
                layer_norm_hidden_state=False,
                return_projected_pooled=False,
                textmodel_json_config=CLIPL_CONFIG,
            )
            load_into(f, self.model.transformer, "", "cpu", torch.float32)
# Google T5-v1-XXL encoder dimensions.
T5_CONFIG = {
    "d_ff": 10240,
    "d_model": 4096,
    "num_heads": 64,
    "num_layers": 24,
    "vocab_size": 32128,
}
class T5XXL:
    def __init__(self, model_folder: str, device: str = "cpu", dtype=torch.float32):
        safetensors_path = hf_hub_download(
            repo_id=model_folder,
            filename="t5xxl.safetensors",
            cache_dir=None,
        )
        with safe_open(safetensors_path, framework="pt", device="cpu") as f:
            self.model = T5XXLModel(T5_CONFIG, device=device, dtype=dtype)
            load_into(f, self.model.transformer, "", device, dtype)
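
# The three loaders above share one pattern: hf_hub_download fetches a single
# weight file from the repo, safe_open memory-maps it, and load_into copies
# each tensor into a freshly built module. Since dtype is a parameter, the
# T5-XXL encoder (by far the largest) could be loaded in half precision to
# roughly halve host memory, e.g.:
#
#     t5xxl_fp16 = T5XXL(model_folder, "cpu", torch.float16)
#
# This is a hedged suggestion, not part of the original app: T5 activations
# are known to overflow in float16 on some inputs, so verify outputs first.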
tokenizer = SD3Tokenizer()
text_encoder_device = "cpu"
model_folder = "stabilityai/stable-diffusion-3.5-medium"

print("Loading Google T5-v1-XXL...")
t5xxl = T5XXL(model_folder, text_encoder_device, torch.float32)
print("Loading OpenAI CLIP L...")
clip_l = ClipL(model_folder)
print("Loading OpenCLIP bigG...")
clip_g = ClipG(model_folder, text_encoder_device)

def get_cond(prompt):
    """Encode a prompt into SD3 conditioning: a (sequence, pooled) pair."""
    print("Encode prompt...")
    tokens = tokenizer.tokenize_with_weights(prompt)
    l_out, l_pooled = clip_l.model.encode_token_weights(tokens["l"])
    g_out, g_pooled = clip_g.model.encode_token_weights(tokens["g"])
    t5_out, t5_pooled = t5xxl.model.encode_token_weights(tokens["t5xxl"])
    # Concatenate CLIP-L (768) and CLIP-G (1280) features, then zero-pad the
    # last dim from 2048 up to 4096 so they stack with the T5 sequence.
    lg_out = torch.cat([l_out, g_out], dim=-1)
    lg_out = torch.nn.functional.pad(lg_out, (0, 4096 - lg_out.shape[-1]))
    return torch.cat([lg_out, t5_out], dim=-2), torch.cat(
        (l_pooled, g_pooled), dim=-1
    )
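
# A hedged usage sketch (not part of the original app). Assuming the default
# 77-token padding on both the CLIP and T5 branches, the sequence should be
# (1, 154, 4096): 77 CLIP tokens plus 77 T5 tokens at the 4096-dim T5 width;
# the pooled vector should be (1, 2048): 768 (CLIP-L) + 1280 (CLIP-G).
def _demo_get_cond():
    cond, pooled = get_cond("a photo of a cat holding a sign that says hello")
    print(cond.shape, pooled.shape)
    # expected: torch.Size([1, 154, 4096]) torch.Size([1, 2048])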