# tim/models/utils/text_encoders.py
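"""Utilities for loading a text encoder (Gemma-3 or T5) and encoding prompts."""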
import os
import torch
from transformers import T5EncoderModel, Gemma3ForCausalLM, AutoTokenizer
# load text-encoder
def load_text_encoder(text_encoder_dir, device, weight_dtype):
    os.environ["TOKENIZERS_PARALLELISM"] = "true"
    tokenizer = AutoTokenizer.from_pretrained(text_encoder_dir)
    if "gemma" in text_encoder_dir:
        # Right-pad so prompt tokens stay left-aligned in the embedding sequence.
        tokenizer.padding_side = "right"
        text_encoder = Gemma3ForCausalLM.from_pretrained(
            text_encoder_dir,
            attn_implementation="sdpa",
            device_map="cpu",
            dtype=weight_dtype,
        )
    elif "t5" in text_encoder_dir:
        text_encoder = T5EncoderModel.from_pretrained(
            text_encoder_dir,
            attn_implementation="sdpa",
            device_map="cpu",
            dtype=weight_dtype,
        )
    else:
        raise NotImplementedError(f"Unsupported text encoder: {text_encoder_dir}")
    # Set requires_grad to False for all parameters to avoid functorch issues
    # for param in text_encoder.parameters():
    #     param.requires_grad = False
    # Move the whole module: T5EncoderModel has no `.model` attribute, so
    # `text_encoder.model = ...` would raise AttributeError on the T5 branch
    # (Gemma's lm_head is tied to the embeddings, so nothing extra moves).
    text_encoder = text_encoder.eval().to(device=device, dtype=weight_dtype)
    return text_encoder, tokenizer
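
# Example usage (a sketch; the checkpoint below is illustrative, not
# necessarily the one TiM ships with):
#   text_encoder, tokenizer = load_text_encoder(
#       "google/gemma-3-1b-it", device="cuda", weight_dtype=torch.bfloat16
#   )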


def encode_prompt(
    tokenizer,
    text_encoder,
    device,
    weight_dtype,
    captions,
    use_last_hidden_state,
    max_seq_length=256,
):
    text_inputs = tokenizer(
        captions,
        padding="max_length",
        max_length=max_seq_length,
        truncation=True,
        return_tensors="pt",
    )
    text_input_ids = text_inputs.input_ids.to(device)
    prompt_masks = text_inputs.attention_mask.to(device)
    with torch.no_grad(), torch.autocast("cuda", dtype=weight_dtype):
        results = text_encoder(
            input_ids=text_input_ids,
            attention_mask=prompt_masks,
            output_hidden_states=True,
        )
    if use_last_hidden_state:
        # hidden_states[-1] is the final layer's output; unlike
        # `last_hidden_state`, it is also exposed by causal-LM outputs
        # (e.g. Gemma), whose forward returns no such attribute.
        prompt_embeds = results.hidden_states[-1]
    else:  # penultimate layer, following the Imagen paper
        prompt_embeds = results.hidden_states[-2]
    return prompt_embeds, prompt_masks
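

if __name__ == "__main__":
    # Minimal smoke test (a sketch, not part of the TiM pipeline); the
    # checkpoint, device, and dtype below are assumptions for illustration.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    text_encoder, tokenizer = load_text_encoder(
        "google/gemma-3-1b-it", device=device, weight_dtype=torch.bfloat16
    )
    prompt_embeds, prompt_masks = encode_prompt(
        tokenizer,
        text_encoder,
        device,
        torch.bfloat16,
        captions=["a photo of a corgi wearing sunglasses"],
        use_last_hidden_state=False,  # penultimate layer, as in Imagen
    )
    print(prompt_embeds.shape, prompt_masks.shape)  # (1, 256, hidden) and (1, 256)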