import torch
import torch.nn as nn
from typing import Optional, Tuple
import logging
from torch import Tensor
from diffusers import ModelMixin
from transformers.models.t5.modeling_t5 import T5LayerSelfAttention, T5LayerFF, T5LayerNorm

logger = logging.getLogger(__name__)


class T5EncoderBlock(nn.Module):
    def __init__(self, config, has_relative_attention_bias=False):
        super().__init__()
        self.layer = nn.ModuleList()
        self.layer.append(T5LayerSelfAttention(config, has_relative_attention_bias=has_relative_attention_bias))
        self.layer.append(T5LayerFF(config))

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        position_bias=None,
        layer_head_mask=None,
        output_attentions=False,
    ):
        # caching is disabled (use_cache=False), so no past key/values are passed in
        self_attn_past_key_value = None
        self_attention_outputs = self.layer[0](
            hidden_states,
            attention_mask=attention_mask,
            position_bias=position_bias,
            layer_head_mask=layer_head_mask,
            past_key_value=self_attn_past_key_value,
            use_cache=False,
            output_attentions=output_attentions,
        )
        hidden_states, present_key_value_state = self_attention_outputs[:2]
        # keep the relative position bias and, if requested, the self-attention weights
        attention_outputs = self_attention_outputs[2:]
        # clamp inf values to enable fp16 training
        if hidden_states.dtype == torch.float16:
            clamp_value = torch.where(
                torch.isinf(hidden_states).any(),
                torch.finfo(hidden_states.dtype).max - 1000,
                torch.finfo(hidden_states.dtype).max,
            )
            hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)

        # apply the feed-forward layer
        hidden_states = self.layer[-1](hidden_states)

        # clamp inf values to enable fp16 training
        if hidden_states.dtype == torch.float16:
            clamp_value = torch.where(
                torch.isinf(hidden_states).any(),
                torch.finfo(hidden_states.dtype).max - 1000,
                torch.finfo(hidden_states.dtype).max,
            )
            hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
        outputs = (hidden_states,) + attention_outputs
        # hidden-states, (self-attention position bias), and optionally (self-attention weights)
        return outputs
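
# Minimal usage sketch for a single encoder block (illustrative only; the config values
# below are assumptions, not taken from this repository):
#
#     from transformers import T5Config
#     cfg = T5Config(d_model=64, d_kv=16, d_ff=128, num_heads=4)
#     block = T5EncoderBlock(cfg, has_relative_attention_bias=True)
#     hidden_states, position_bias = block(torch.randn(1, 8, cfg.d_model))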


class T5EncoderBlockByT5Mapper(ModelMixin):
    def __init__(self, byt5_config, num_layers, sdxl_channels=None):
        super().__init__()
        if num_layers > 0:
            self.blocks = nn.ModuleList(
                [
                    T5EncoderBlock(
                        byt5_config,
                        has_relative_attention_bias=bool(i == 0),
                    )
                    for i in range(num_layers)
                ]
            )
        else:
            self.blocks = None
        self.layer_norm = T5LayerNorm(byt5_config.d_model, eps=byt5_config.layer_norm_epsilon)
        if sdxl_channels is not None:
            self.channel_mapper = nn.Linear(byt5_config.d_model, sdxl_channels)
            self.final_layer_norm = T5LayerNorm(sdxl_channels, eps=byt5_config.layer_norm_epsilon)
        else:
            self.channel_mapper = None
            self.final_layer_norm = None

    def get_extended_attention_mask(
        self,
        attention_mask: Tensor,
        input_shape: Tuple[int],
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
    ) -> Tensor:
        """
        Makes a broadcastable attention mask so that masked (padding) tokens are ignored.

        Arguments:
            attention_mask (`torch.Tensor`):
                Mask with ones indicating tokens to attend to, zeros for tokens to ignore.
            input_shape (`Tuple[int]`):
                The shape of the input to the model.

        Returns:
            `torch.Tensor`: The extended attention mask, cast to `dtype` (the model's dtype if `dtype` is None).
        """
        if dtype is None:
            dtype = self.dtype

        # A self-attention mask of shape [batch_size, from_seq_length, to_seq_length] can be
        # provided directly, in which case it just needs to be made broadcastable to all heads.
        if attention_mask.dim() == 3:
            extended_attention_mask = attention_mask[:, None, :, :]
        elif attention_mask.dim() == 2:
            # A padding mask of shape [batch_size, seq_length] is made broadcastable to
            # [batch_size, num_heads, seq_length, seq_length].
            extended_attention_mask = attention_mask[:, None, None, :]
        else:
            raise ValueError(
                f"Wrong shape for input_ids (shape {input_shape}) or attention_mask (shape {attention_mask.shape})"
            )

        # Since attention_mask is 1.0 for positions we want to attend to and 0.0 for masked
        # positions, this creates a tensor that is 0.0 for attended positions and the dtype's
        # smallest value for masked positions. Because it is added to the raw scores before
        # the softmax, this is effectively the same as removing the masked positions entirely.
        extended_attention_mask = extended_attention_mask.to(dtype=dtype)  # fp16 compatibility
        extended_attention_mask = (1.0 - extended_attention_mask) * torch.finfo(dtype).min
        return extended_attention_mask
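
    # Worked example for the mask rescaling above (illustrative values, not from the
    # original file): a 2-D padding mask
    #     attention_mask = [[1, 1, 0]]                      # shape [batch=1, seq_len=3]
    # becomes
    #     extended_attention_mask = [[[[0.0, 0.0, min]]]]   # shape [1, 1, 1, 3], min = torch.finfo(dtype).min
    # which is added to the raw attention scores, pushing masked positions to ~-inf before the softmax.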

    def forward(self, inputs_embeds, attention_mask):
        input_shape = inputs_embeds.size()[:-1]
        extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape)
        hidden_states = inputs_embeds
        position_bias = None
        if self.blocks is not None:
            for layer_module in self.blocks:
                layer_outputs = layer_module(
                    hidden_states,
                    attention_mask=extended_attention_mask,
                    position_bias=position_bias,
                )
                # With output_attentions=False each block returns (hidden_states, position_bias);
                # the bias computed by the first block is reused by the later ones.
                hidden_states, position_bias = layer_outputs
        hidden_states = self.layer_norm(hidden_states)
        if self.channel_mapper is not None:
            hidden_states = self.channel_mapper(hidden_states)
            hidden_states = self.final_layer_norm(hidden_states)
        return hidden_states
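

if __name__ == "__main__":
    # Illustrative smoke test, not part of the original module. The config values are
    # assumptions chosen to keep the example small; a real setup would load the config
    # of an actual ByT5 checkpoint instead.
    from transformers import T5Config

    byt5_config = T5Config(d_model=64, d_kv=16, d_ff=128, num_heads=4, num_layers=2)
    mapper = T5EncoderBlockByT5Mapper(byt5_config, num_layers=2, sdxl_channels=32)

    batch_size, seq_len = 2, 16
    inputs_embeds = torch.randn(batch_size, seq_len, byt5_config.d_model)
    attention_mask = torch.ones(batch_size, seq_len, dtype=torch.long)

    with torch.no_grad():
        mapped = mapper(inputs_embeds, attention_mask)
    print(mapped.shape)  # expected: torch.Size([2, 16, 32])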