import os
import torch
from transformers import T5EncoderModel, Gemma3ForCausalLM, AutoTokenizer
def load_text_encoder(text_encoder_dir, device, weight_dtype):
    """Load a Gemma-3 or T5 text encoder (selected by checkpoint name) plus its tokenizer."""
    os.environ["TOKENIZERS_PARALLELISM"] = "true"
    tokenizer = AutoTokenizer.from_pretrained(text_encoder_dir)
    if "gemma" in text_encoder_dir:
        tokenizer.padding_side = "right"
        text_encoder = Gemma3ForCausalLM.from_pretrained(
            text_encoder_dir,
            attn_implementation="sdpa",
            device_map="cpu",
            dtype=weight_dtype,
        )
    elif "t5" in text_encoder_dir:
        text_encoder = T5EncoderModel.from_pretrained(
            text_encoder_dir,
            attn_implementation="sdpa",
            device_map="cpu",
            dtype=weight_dtype,
        )
    else:
        raise NotImplementedError(f"Unsupported text encoder: {text_encoder_dir}")
    # Set requires_grad to False for all parameters to avoid functorch issues
    # for param in text_encoder.parameters():
    #     param.requires_grad = False
    # Move the whole module: T5EncoderModel has no `.model` attribute, so moving
    # only `text_encoder.model` would fail on the T5 branch.
    text_encoder = text_encoder.eval().to(device=device, dtype=weight_dtype)
    return text_encoder, tokenizer

def encode_prompt(
    tokenizer,
    text_encoder,
    device,
    weight_dtype,
    captions,
    use_last_hidden_state,
    max_seq_length=256,
):
    """Tokenize `captions` and return (prompt_embeds, prompt_masks) from the text encoder."""
    text_inputs = tokenizer(
        captions,
        padding="max_length",
        max_length=max_seq_length,
        truncation=True,
        return_tensors="pt",
    )
    text_input_ids = text_inputs.input_ids.to(device)
    prompt_masks = text_inputs.attention_mask.to(device)
    with torch.no_grad(), torch.autocast("cuda", dtype=weight_dtype):
        results = text_encoder(
            input_ids=text_input_ids,
            attention_mask=prompt_masks,
            output_hidden_states=True,
        )
    if use_last_hidden_state:
        # Gemma3ForCausalLM returns a CausalLMOutputWithPast, which has no
        # `last_hidden_state` field; fall back to the final entry of
        # `hidden_states`, which is the same tensor.
        prompt_embeds = getattr(results, "last_hidden_state", None)
        if prompt_embeds is None:
            prompt_embeds = results.hidden_states[-1]
    else:  # penultimate layer, as in the Imagen paper
        prompt_embeds = results.hidden_states[-2]
    return prompt_embeds, prompt_masks
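

# Minimal usage sketch. Assumptions not specified by this module: the Gemma-3
# checkpoint path below is hypothetical (any local or Hub path containing
# "gemma" or "t5" works), and a CUDA device is available for autocast.
if __name__ == "__main__":
    text_encoder_dir = "google/gemma-3-1b-it"  # assumed checkpoint path
    device = torch.device("cuda")
    weight_dtype = torch.bfloat16

    text_encoder, tokenizer = load_text_encoder(text_encoder_dir, device, weight_dtype)
    prompt_embeds, prompt_masks = encode_prompt(
        tokenizer,
        text_encoder,
        device,
        weight_dtype,
        captions=["a photo of a corgi wearing sunglasses"],
        use_last_hidden_state=False,  # penultimate layer, as in Imagen
    )
    # Shapes: (batch, max_seq_length, hidden_size) and (batch, max_seq_length)
    print(prompt_embeds.shape, prompt_masks.shape)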