import os
import torch
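
# Prompt-encoding and scheduler helpers for a dual text-encoder (CLIP + T5)
# diffusion training setup: CLIP provides a pooled prompt embedding, T5 provides
# the per-token sequence embedding.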
def tokenize_prompt(tokenizer, prompt, max_sequence_length):
    text_inputs = tokenizer(
        prompt,
        padding="max_length",
        max_length=max_sequence_length,
        truncation=True,
        return_length=False,
        return_overflowing_tokens=False,
        return_tensors="pt",
    )
    text_input_ids = text_inputs.input_ids
    return text_input_ids
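
# Encodes `prompt` with a T5 text encoder (or pre-tokenized `text_input_ids`) and
# returns per-token embeddings of shape (batch_size * num_images_per_prompt, seq_len, dim).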
def _encode_prompt_with_t5(
    text_encoder,
    tokenizer,
    max_sequence_length=512,
    prompt=None,
    num_images_per_prompt=1,
    device=None,
    text_input_ids=None,
):
    prompt = [prompt] if isinstance(prompt, str) else prompt
    batch_size = len(prompt)

    if tokenizer is not None:
        text_inputs = tokenizer(
            prompt,
            padding="max_length",
            max_length=max_sequence_length,
            truncation=True,
            return_length=False,
            return_overflowing_tokens=False,
            return_tensors="pt",
        )
        text_input_ids = text_inputs.input_ids
    else:
        if text_input_ids is None:
            raise ValueError("text_input_ids must be provided when the tokenizer is not specified")

    prompt_embeds = text_encoder(text_input_ids.to(device))[0]

    if hasattr(text_encoder, "module"):
        dtype = text_encoder.module.dtype
    else:
        dtype = text_encoder.dtype
    prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)

    _, seq_len, _ = prompt_embeds.shape

    # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
    prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
    prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

    return prompt_embeds
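
# Encodes `prompt` with a CLIP text encoder and returns the pooled embedding of
# shape (batch_size * num_images_per_prompt, dim).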
def _encode_prompt_with_clip(
    text_encoder,
    tokenizer,
    prompt: str,
    device=None,
    text_input_ids=None,
    num_images_per_prompt: int = 1,
):
    prompt = [prompt] if isinstance(prompt, str) else prompt
    batch_size = len(prompt)

    if tokenizer is not None:
        text_inputs = tokenizer(
            prompt,
            padding="max_length",
            max_length=77,
            truncation=True,
            return_overflowing_tokens=False,
            return_length=False,
            return_tensors="pt",
        )
        text_input_ids = text_inputs.input_ids
    else:
        if text_input_ids is None:
            raise ValueError("text_input_ids must be provided when the tokenizer is not specified")

    prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=False)

    if hasattr(text_encoder, "module"):
        dtype = text_encoder.module.dtype
    else:
        dtype = text_encoder.dtype

    # Use pooled output of CLIPTextModel
    prompt_embeds = prompt_embeds.pooler_output
    prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)

    # duplicate text embeddings for each generation per prompt, using mps friendly method
    prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
    prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1)

    return prompt_embeds
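
# Runs both encoders: text_encoders[0]/tokenizers[0] are expected to be the CLIP pair,
# text_encoders[1]/tokenizers[1] the T5 pair. `text_ids` is a zero-filled (seq_len, 3)
# tensor (the text token position-id placeholder used by Flux-style transformers).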
def encode_prompt(
    text_encoders,
    tokenizers,
    prompt: str,
    max_sequence_length,
    device=None,
    num_images_per_prompt: int = 1,
    text_input_ids_list=None,
):
    prompt = [prompt] if isinstance(prompt, str) else prompt

    if hasattr(text_encoders[0], "module"):
        dtype = text_encoders[0].module.dtype
    else:
        dtype = text_encoders[0].dtype

    pooled_prompt_embeds = _encode_prompt_with_clip(
        text_encoder=text_encoders[0],
        tokenizer=tokenizers[0],
        prompt=prompt,
        device=device if device is not None else text_encoders[0].device,
        num_images_per_prompt=num_images_per_prompt,
        text_input_ids=text_input_ids_list[0] if text_input_ids_list else None,
    )

    prompt_embeds = _encode_prompt_with_t5(
        text_encoder=text_encoders[1],
        tokenizer=tokenizers[1],
        max_sequence_length=max_sequence_length,
        prompt=prompt,
        num_images_per_prompt=num_images_per_prompt,
        device=device if device is not None else text_encoders[1].device,
        text_input_ids=text_input_ids_list[1] if text_input_ids_list else None,
    )

    text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)

    return prompt_embeds, pooled_prompt_embeds, text_ids
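
# Convenience wrapper: encodes `prompt` under torch.no_grad() and moves the returned
# embeddings onto the accelerator device.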
def compute_text_embeddings(args, accelerator, prompt, text_encoders, tokenizers):
    with torch.no_grad():
        prompt_embeds, pooled_prompt_embeds, text_ids = encode_prompt(
            text_encoders, tokenizers, prompt, args.max_sequence_length
        )
        prompt_embeds = prompt_embeds.to(accelerator.device)
        pooled_prompt_embeds = pooled_prompt_embeds.to(accelerator.device)
        text_ids = text_ids.to(accelerator.device)
    return prompt_embeds, pooled_prompt_embeds, text_ids
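
# Looks up the scheduler sigma for each sampled timestep and reshapes it so it
# broadcasts against an n_dim-dimensional latent tensor (e.g. (B, C, H, W)).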
def get_sigmas(noise_scheduler_copy, accelerator, timesteps, n_dim=4, dtype=torch.float32):
    sigmas = noise_scheduler_copy.sigmas.to(device=accelerator.device, dtype=dtype)
    schedule_timesteps = noise_scheduler_copy.timesteps.to(accelerator.device)
    timesteps = timesteps.to(accelerator.device)

    step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]

    sigma = sigmas[step_indices].flatten()
    while len(sigma.shape) < n_dim:
        sigma = sigma.unsqueeze(-1)
    return sigma
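
# ---------------------------------------------------------------------------
# Illustrative usage sketch (not part of the original script). It assumes a
# Flux-style diffusers checkpoint whose CLIP pair lives under tokenizer /
# text_encoder and whose T5 pair lives under tokenizer_2 / text_encoder_2;
# the repo id and subfolder names below are assumptions, adjust to your setup.
#
# from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast
#
# repo = "black-forest-labs/FLUX.1-dev"  # assumed checkpoint layout
# tokenizer_one = CLIPTokenizer.from_pretrained(repo, subfolder="tokenizer")
# tokenizer_two = T5TokenizerFast.from_pretrained(repo, subfolder="tokenizer_2")
# text_encoder_one = CLIPTextModel.from_pretrained(repo, subfolder="text_encoder").to("cuda")
# text_encoder_two = T5EncoderModel.from_pretrained(repo, subfolder="text_encoder_2").to("cuda")
#
# prompt_embeds, pooled_prompt_embeds, text_ids = encode_prompt(
#     text_encoders=[text_encoder_one, text_encoder_two],
#     tokenizers=[tokenizer_one, tokenizer_two],
#     prompt="a photo of a cat",
#     max_sequence_length=512,
#     device="cuda",
# )
# # With those encoders the shapes would be roughly:
# # prompt_embeds (1, 512, 4096), pooled_prompt_embeds (1, 768), text_ids (512, 3)
# ---------------------------------------------------------------------------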