Spaces:
Runtime error
Runtime error
| from typing import List, Optional, Union | |
| import torch | |
| from transformers import T5EncoderModel, T5Tokenizer | |
def _get_t5_prompt_embeds(
    tokenizer: Optional["T5Tokenizer"],
    text_encoder: "T5EncoderModel",
    prompt: Union[str, List[str], None],
    num_videos_per_prompt: int = 1,
    max_sequence_length: int = 226,
    device: Optional[torch.device] = None,
    dtype: Optional[torch.dtype] = None,
    text_input_ids: Optional[torch.Tensor] = None,
):
    """Encode prompts with a T5 encoder and repeat each embedding per video.

    Args:
        tokenizer: Tokenizer applied to ``prompt``. When ``None``, the
            pre-tokenized ``text_input_ids`` are used instead and ``prompt``
            is ignored.
        text_encoder: Encoder called on the token ids; element ``[0]`` of its
            output is taken as the embeddings, which must be 3-dimensional.
        prompt: A single prompt or a list of prompts. Only read when
            ``tokenizer`` is given.
        num_videos_per_prompt: Number of copies of each prompt's embedding in
            the output.
        max_sequence_length: Length every prompt is padded / truncated to by
            the tokenizer.
        device: Device the token ids and the returned embeddings are moved to.
        dtype: Dtype the returned embeddings are cast to.
        text_input_ids: Pre-tokenized ids. Required when ``tokenizer`` is
            ``None``; ignored (overwritten) when a tokenizer is given.

    Returns:
        Tensor of shape ``(batch_size * num_videos_per_prompt, seq_len, D)``,
        where ``batch_size`` and ``seq_len`` are the encoder output's first
        two dimensions. The ``num_videos_per_prompt`` copies of each prompt
        are contiguous along the first axis.

    Raises:
        ValueError: If neither ``tokenizer`` nor ``text_input_ids`` is given.
    """
    if tokenizer is not None:
        prompt = [prompt] if isinstance(prompt, str) else prompt
        text_input_ids = tokenizer(
            prompt,
            padding="max_length",
            max_length=max_sequence_length,
            truncation=True,
            add_special_tokens=True,
            return_tensors="pt",
        ).input_ids
    elif text_input_ids is None:
        raise ValueError("`text_input_ids` must be provided when the tokenizer is not specified.")

    prompt_embeds = text_encoder(text_input_ids.to(device))[0]
    prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)

    # Take the batch size from the encoder output rather than from `prompt`:
    # on the pre-tokenized path `prompt` is never used, so it may be None or
    # have a different length than `text_input_ids`, which would make the
    # `view` below fail or silently produce a wrongly shaped tensor.
    batch_size, seq_len, _ = prompt_embeds.shape

    # Duplicate text embeddings for each generation per prompt, using the
    # mps-friendly repeat + view method: (B, S, D) -> (B, N*S, D) -> (B*N, S, D).
    prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
    return prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
def encode_prompt(
    tokenizer: T5Tokenizer,
    text_encoder: T5EncoderModel,
    prompt: Union[str, List[str]],
    num_videos_per_prompt: int = 1,
    max_sequence_length: int = 226,
    device: Optional[torch.device] = None,
    dtype: Optional[torch.dtype] = None,
    text_input_ids=None,
):
    """Return T5 prompt embeddings, repeated ``num_videos_per_prompt`` times.

    Public wrapper around ``_get_t5_prompt_embeds``. It wraps a single string
    prompt in a one-element list and forwards every other argument unchanged.

    Args:
        tokenizer: Tokenizer for ``prompt``. ``_get_t5_prompt_embeds`` accepts
            ``None`` here as long as ``text_input_ids`` is supplied.
        text_encoder: T5 encoder used to embed the token ids.
        prompt: A single prompt or a list of prompts.
        num_videos_per_prompt: Copies of each prompt embedding to produce.
        max_sequence_length: Padding / truncation length for the tokenizer.
        device: Target device for the ids and the resulting embeddings.
        dtype: Target dtype for the resulting embeddings.
        text_input_ids: Optional pre-tokenized ids, used when ``tokenizer``
            is ``None``.

    Returns:
        The embedding tensor produced by ``_get_t5_prompt_embeds``.
    """
    if isinstance(prompt, str):
        prompt = [prompt]

    return _get_t5_prompt_embeds(
        tokenizer=tokenizer,
        text_encoder=text_encoder,
        prompt=prompt,
        num_videos_per_prompt=num_videos_per_prompt,
        max_sequence_length=max_sequence_length,
        device=device,
        dtype=dtype,
        text_input_ids=text_input_ids,
    )
def compute_prompt_embeddings(
    tokenizer: T5Tokenizer,
    text_encoder: T5EncoderModel,
    prompt: str,
    max_sequence_length: int,
    device: torch.device,
    dtype: torch.dtype,
    requires_grad: bool = False,
):
    """Encode ``prompt`` into T5 embeddings (one copy per prompt).

    Args:
        tokenizer: Tokenizer passed through to ``encode_prompt``.
        text_encoder: T5 encoder passed through to ``encode_prompt``.
        prompt: Prompt text to encode.
        max_sequence_length: Padding / truncation length for the tokenizer.
        device: Device for the ids and the resulting embeddings.
        dtype: Dtype for the resulting embeddings.
        requires_grad: When false (the default), the encoder runs under
            ``torch.no_grad()``. When true, it runs in whatever autograd mode
            the caller is currently in.

    Returns:
        The embedding tensor returned by ``encode_prompt`` with
        ``num_videos_per_prompt=1``.
    """

    def _run_encoder():
        # Single call site shared by both autograd modes below.
        return encode_prompt(
            tokenizer,
            text_encoder,
            prompt,
            num_videos_per_prompt=1,
            max_sequence_length=max_sequence_length,
            device=device,
            dtype=dtype,
        )

    if requires_grad:
        # Do not force grad on; inherit the caller's current grad mode.
        return _run_encoder()

    with torch.no_grad():
        return _run_encoder()