### This file contains impls for underlying related models (CLIP, T5, etc)

import logging
import math
import os

import torch
from torch import nn
from transformers import CLIPTokenizer, T5TokenizerFast
from einops import rearrange


#################################################################################################
### Core/Utility
#################################################################################################

def attention(q, k, v, heads, mask=None):
    """Convenience wrapper around a basic attention operation"""
    b, _, dim_head = q.shape
    dim_head //= heads
    q, k, v = map(lambda t: t.view(b, -1, heads, dim_head).transpose(1, 2), (q, k, v))
    out = torch.nn.functional.scaled_dot_product_attention(
        q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False
    )
    return out.transpose(1, 2).reshape(b, -1, heads * dim_head)

class Mlp(nn.Module):
    """MLP as used in Vision Transformer, MLP-Mixer and related networks"""

    def __init__(
        self,
        in_features,
        hidden_features=None,
        out_features=None,
        act_layer=nn.GELU,
        bias=True,
        dtype=None,
        device=None,
    ):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(
            in_features, hidden_features, bias=bias, dtype=dtype, device=device
        )
        # act_layer is stored and applied directly in forward(), so it is expected to be
        # a tensor-to-tensor callable (e.g. an entry from the ACTIVATIONS dict below).
        self.act = act_layer
        self.fc2 = nn.Linear(
            hidden_features, out_features, bias=bias, dtype=dtype, device=device
        )

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.fc2(x)
        return x

#################################################################################################
### CLIP
#################################################################################################


class CLIPAttention(torch.nn.Module):
    def __init__(self, embed_dim, heads, dtype, device):
        super().__init__()
        self.heads = heads
        self.q_proj = nn.Linear(
            embed_dim, embed_dim, bias=True, dtype=dtype, device=device
        )
        self.k_proj = nn.Linear(
            embed_dim, embed_dim, bias=True, dtype=dtype, device=device
        )
        self.v_proj = nn.Linear(
            embed_dim, embed_dim, bias=True, dtype=dtype, device=device
        )
        self.out_proj = nn.Linear(
            embed_dim, embed_dim, bias=True, dtype=dtype, device=device
        )

    def forward(self, x, mask=None):
        q = self.q_proj(x)
        k = self.k_proj(x)
        v = self.v_proj(x)
        out = attention(q, k, v, self.heads, mask)
        return self.out_proj(out)


ACTIVATIONS = {
    "quick_gelu": lambda a: a * torch.sigmoid(1.702 * a),
    "gelu": torch.nn.functional.gelu,
}

class CLIPLayer(torch.nn.Module):
    def __init__(
        self,
        embed_dim,
        heads,
        intermediate_size,
        intermediate_activation,
        dtype,
        device,
    ):
        super().__init__()
        self.layer_norm1 = nn.LayerNorm(embed_dim, dtype=dtype, device=device)
        self.self_attn = CLIPAttention(embed_dim, heads, dtype, device)
        self.layer_norm2 = nn.LayerNorm(embed_dim, dtype=dtype, device=device)
        # self.mlp = CLIPMLP(embed_dim, intermediate_size, intermediate_activation, dtype, device)
        self.mlp = Mlp(
            embed_dim,
            intermediate_size,
            embed_dim,
            act_layer=ACTIVATIONS[intermediate_activation],
            dtype=dtype,
            device=device,
        )

    def forward(self, x, mask=None):
        x += self.self_attn(self.layer_norm1(x), mask)
        x += self.mlp(self.layer_norm2(x))
        return x

class CLIPEncoder(torch.nn.Module):
    def __init__(
        self,
        num_layers,
        embed_dim,
        heads,
        intermediate_size,
        intermediate_activation,
        dtype,
        device,
    ):
        super().__init__()
        self.layers = torch.nn.ModuleList(
            [
                CLIPLayer(
                    embed_dim,
                    heads,
                    intermediate_size,
                    intermediate_activation,
                    dtype,
                    device,
                )
                for i in range(num_layers)
            ]
        )

    def forward(self, x, mask=None, intermediate_output=None):
        if intermediate_output is not None:
            if intermediate_output < 0:
                intermediate_output = len(self.layers) + intermediate_output
        intermediate = None
        for i, l in enumerate(self.layers):
            x = l(x, mask)
            if i == intermediate_output:
                intermediate = x.clone()
        return x, intermediate

class CLIPEmbeddings(torch.nn.Module):
    def __init__(
        self, embed_dim, vocab_size=49408, num_positions=77, dtype=None, device=None
    ):
        super().__init__()
        self.token_embedding = torch.nn.Embedding(
            vocab_size, embed_dim, dtype=dtype, device=device
        )
        self.position_embedding = torch.nn.Embedding(
            num_positions, embed_dim, dtype=dtype, device=device
        )

    def forward(self, input_tokens):
        return self.token_embedding(input_tokens) + self.position_embedding.weight

class CLIPTextModel_(torch.nn.Module):
    def __init__(self, config_dict, dtype, device):
        num_layers = config_dict["num_hidden_layers"]
        embed_dim = config_dict["hidden_size"]
        heads = config_dict["num_attention_heads"]
        intermediate_size = config_dict["intermediate_size"]
        intermediate_activation = config_dict["hidden_act"]
        super().__init__()
        self.embeddings = CLIPEmbeddings(embed_dim, dtype=torch.float32, device=device)
        self.encoder = CLIPEncoder(
            num_layers,
            embed_dim,
            heads,
            intermediate_size,
            intermediate_activation,
            dtype,
            device,
        )
        self.final_layer_norm = nn.LayerNorm(embed_dim, dtype=dtype, device=device)

    def forward(
        self, input_tokens, intermediate_output=None, final_layer_norm_intermediate=True
    ):
        x = self.embeddings(input_tokens)
        causal_mask = (
            torch.empty(x.shape[1], x.shape[1], dtype=x.dtype, device=x.device)
            .fill_(float("-inf"))
            .triu_(1)
        )
        x, i = self.encoder(
            x, mask=causal_mask, intermediate_output=intermediate_output
        )
        x = self.final_layer_norm(x)
        if i is not None and final_layer_norm_intermediate:
            i = self.final_layer_norm(i)
        pooled_output = x[
            torch.arange(x.shape[0], device=x.device),
            input_tokens.to(dtype=torch.int, device=x.device).argmax(dim=-1),
        ]
        return x, i, pooled_output

class CLIPTextModel(torch.nn.Module):
    def __init__(self, config_dict, dtype, device):
        super().__init__()
        self.num_layers = config_dict["num_hidden_layers"]
        self.text_model = CLIPTextModel_(config_dict, dtype, device)
        embed_dim = config_dict["hidden_size"]
        self.text_projection = nn.Linear(
            embed_dim, embed_dim, bias=False, dtype=dtype, device=device
        )
        # Initialize the projection to identity; done under no_grad so the in-place
        # copy into a Parameter does not trip autograd's leaf-variable check.
        with torch.no_grad():
            self.text_projection.weight.copy_(torch.eye(embed_dim))
        self.dtype = dtype

    def get_input_embeddings(self):
        return self.text_model.embeddings.token_embedding

    def set_input_embeddings(self, embeddings):
        self.text_model.embeddings.token_embedding = embeddings

    def forward(self, *args, **kwargs):
        x = self.text_model(*args, **kwargs)
        out = self.text_projection(x[2])
        return (x[0], x[1], out, x[2])

def parse_parentheses(string):
    result = []
    current_item = ""
    nesting_level = 0
    for char in string:
        if char == "(":
            if nesting_level == 0:
                if current_item:
                    result.append(current_item)
                    current_item = "("
                else:
                    current_item = "("
            else:
                current_item += char
            nesting_level += 1
        elif char == ")":
            nesting_level -= 1
            if nesting_level == 0:
                result.append(current_item + ")")
                current_item = ""
            else:
                current_item += char
        else:
            current_item += char
    if current_item:
        result.append(current_item)
    return result

def token_weights(string, current_weight):
    a = parse_parentheses(string)
    out = []
    for x in a:
        weight = current_weight
        if len(x) >= 2 and x[-1] == ")" and x[0] == "(":
            x = x[1:-1]
            xx = x.rfind(":")
            weight *= 1.1
            if xx > 0:
                try:
                    weight = float(x[xx + 1 :])
                    x = x[:xx]
                except ValueError:
                    pass
            out += token_weights(x, weight)
        else:
            out += [(x, current_weight)]
    return out

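# A small hand-traced sketch of the prompt-weighting syntax parsed above:
#   token_weights("a (red:1.3) car", 1.0) -> [("a ", 1.0), ("red", 1.3), (" car", 1.0)]
#   token_weights("((blue))", 1.0)        -> [("blue", ~1.21)]  # each unannotated "(...)" multiplies the weight by 1.1
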
def escape_important(text):
    text = text.replace("\\)", "\0\1")
    text = text.replace("\\(", "\0\2")
    return text


def unescape_important(text):
    text = text.replace("\0\1", ")")
    text = text.replace("\0\2", "(")
    return text

class SDTokenizer:
    def __init__(
        self,
        max_length=77,
        pad_with_end=True,
        tokenizer=None,
        has_start_token=True,
        pad_to_max_length=True,
        min_length=None,
        extra_padding_token=None,
    ):
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.min_length = min_length
        empty = self.tokenizer("")["input_ids"]
        if has_start_token:
            self.tokens_start = 1
            self.start_token = empty[0]
            self.end_token = empty[1]
        else:
            self.tokens_start = 0
            self.start_token = None
            self.end_token = empty[0]
        self.pad_with_end = pad_with_end
        self.pad_to_max_length = pad_to_max_length
        self.extra_padding_token = extra_padding_token
        vocab = self.tokenizer.get_vocab()
        self.inv_vocab = {v: k for k, v in vocab.items()}
        self.max_word_length = 8

    def tokenize_with_weights(self, text: str, return_word_ids=False):
        """
        Tokenize the text, with weight values - presume 1.0 for all and ignore other features here.
        The details aren't relevant for a reference impl, and the weights themselves have a weak effect on SD3.
        """
        if self.pad_with_end:
            pad_token = self.end_token
        else:
            pad_token = 0
        text = escape_important(text)
        parsed_weights = token_weights(text, 1.0)
        # tokenize words
        tokens = []
        for weighted_segment, weight in parsed_weights:
            to_tokenize = (
                unescape_important(weighted_segment).replace("\n", " ").split(" ")
            )
            to_tokenize = [x for x in to_tokenize if x != ""]
            for word in to_tokenize:
                # parse word
                tokens.append(
                    [
                        (t, weight)
                        for t in self.tokenizer(word)["input_ids"][
                            self.tokens_start : -1
                        ]
                    ]
                )
        # reshape token array to CLIP input size
        batched_tokens = []
        batch = []
        if self.start_token is not None:
            batch.append((self.start_token, 1.0, 0))
        batched_tokens.append(batch)
        for i, t_group in enumerate(tokens):
            # determine if we're going to try and keep the tokens in a single batch
            is_large = len(t_group) >= self.max_word_length
            while len(t_group) > 0:
                if len(t_group) + len(batch) > self.max_length - 1:
                    remaining_length = self.max_length - len(batch) - 1
                    # break word in two and add end token
                    if is_large:
                        batch.extend(
                            [(t, w, i + 1) for t, w in t_group[:remaining_length]]
                        )
                        batch.append((self.end_token, 1.0, 0))
                        t_group = t_group[remaining_length:]
                    # add end token and pad
                    else:
                        batch.append((self.end_token, 1.0, 0))
                        if self.pad_to_max_length:
                            batch.extend([(pad_token, 1.0, 0)] * (remaining_length))
                    # start new batch
                    batch = []
                    if self.start_token is not None:
                        batch.append((self.start_token, 1.0, 0))
                    batched_tokens.append(batch)
                else:
                    batch.extend([(t, w, i + 1) for t, w in t_group])
                    t_group = []
        # pad with the extra padding token first, before getting to the end token
        if self.extra_padding_token is not None:
            batch.extend(
                [(self.extra_padding_token, 1.0, 0)]
                * (self.min_length - len(batch) - 1)
            )
        # fill last batch
        batch.append((self.end_token, 1.0, 0))
        if self.pad_to_max_length:
            batch.extend([(pad_token, 1.0, 0)] * (self.max_length - len(batch)))
        if self.min_length is not None and len(batch) < self.min_length:
            batch.extend([(pad_token, 1.0, 0)] * (self.min_length - len(batch)))
        if not return_word_ids:
            batched_tokens = [[(t, w) for t, w, _ in x] for x in batched_tokens]
        return batched_tokens

    def untokenize(self, token_weight_pair):
        return list(map(lambda a: (a, self.inv_vocab[a[0]]), token_weight_pair))

class SDXLClipGTokenizer(SDTokenizer):
    def __init__(self, tokenizer):
        super().__init__(pad_with_end=False, tokenizer=tokenizer)


class SD3Tokenizer:
    def __init__(self):
        clip_tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
        self.clip_l = SDTokenizer(tokenizer=clip_tokenizer)
        self.clip_g = SDXLClipGTokenizer(clip_tokenizer)
        self.t5xxl = T5XXLTokenizer()

    def tokenize_with_weights(self, text: str):
        out = {}
        out["l"] = self.clip_l.tokenize_with_weights(text)
        out["g"] = self.clip_g.tokenize_with_weights(text)
        out["t5xxl"] = self.t5xxl.tokenize_with_weights(text[:226])
        return out


class ClipTokenWeightEncoder:
    def encode_token_weights(self, token_weight_pairs):
        tokens = list(map(lambda a: a[0], token_weight_pairs[0]))
        out, pooled = self([tokens])
        if pooled is not None:
            first_pooled = pooled[0:1].cpu()
        else:
            first_pooled = pooled
        output = [out[0:1]]
        return torch.cat(output, dim=-2).cpu(), first_pooled

class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
    """Uses the CLIP transformer encoder for text (from huggingface)"""

    LAYERS = ["last", "pooled", "hidden"]

    def __init__(
        self,
        device="cpu",
        max_length=77,
        layer="last",
        layer_idx=None,
        textmodel_json_config=None,
        dtype=None,
        model_class=CLIPTextModel,
        special_tokens={"start": 49406, "end": 49407, "pad": 49407},
        layer_norm_hidden_state=True,
        return_projected_pooled=True,
    ):
        super().__init__()
        assert layer in self.LAYERS
        self.transformer = model_class(textmodel_json_config, dtype, device)
        self.num_layers = self.transformer.num_layers
        self.max_length = max_length
        self.transformer = self.transformer.eval()
        for param in self.parameters():
            param.requires_grad = False
        self.layer = layer
        self.layer_idx = None
        self.special_tokens = special_tokens
        self.logit_scale = torch.nn.Parameter(torch.tensor(4.6055))
        self.layer_norm_hidden_state = layer_norm_hidden_state
        self.return_projected_pooled = return_projected_pooled
        if layer == "hidden":
            assert layer_idx is not None
            assert abs(layer_idx) < self.num_layers
            self.set_clip_options({"layer": layer_idx})
        self.options_default = (
            self.layer,
            self.layer_idx,
            self.return_projected_pooled,
        )

    def set_clip_options(self, options):
        layer_idx = options.get("layer", self.layer_idx)
        self.return_projected_pooled = options.get(
            "projected_pooled", self.return_projected_pooled
        )
        if layer_idx is None or abs(layer_idx) > self.num_layers:
            self.layer = "last"
        else:
            self.layer = "hidden"
            self.layer_idx = layer_idx

    def forward(self, tokens):
        backup_embeds = self.transformer.get_input_embeddings()
        device = backup_embeds.weight.device
        tokens = torch.LongTensor(tokens).to(device)
        outputs = self.transformer(
            tokens,
            intermediate_output=self.layer_idx,
            final_layer_norm_intermediate=self.layer_norm_hidden_state,
        )
        self.transformer.set_input_embeddings(backup_embeds)
        if self.layer == "last":
            z = outputs[0]
        else:
            z = outputs[1]
        pooled_output = None
        if len(outputs) >= 3:
            if (
                not self.return_projected_pooled
                and len(outputs) >= 4
                and outputs[3] is not None
            ):
                pooled_output = outputs[3].float()
            elif outputs[2] is not None:
                pooled_output = outputs[2].float()
        return z.float(), pooled_output

class SDXLClipG(SDClipModel):
    """Wraps the CLIP-G model into the SD-CLIP-Model interface"""

    def __init__(
        self, config, device="cpu", layer="penultimate", layer_idx=None, dtype=None
    ):
        if layer == "penultimate":
            layer = "hidden"
            layer_idx = -2
        super().__init__(
            device=device,
            layer=layer,
            layer_idx=layer_idx,
            textmodel_json_config=config,
            dtype=dtype,
            special_tokens={"start": 49406, "end": 49407, "pad": 0},
            layer_norm_hidden_state=False,
        )


class T5XXLModel(SDClipModel):
    """Wraps the T5-XXL model into the SD-CLIP-Model interface for convenience"""

    def __init__(self, config, device="cpu", layer="last", layer_idx=None, dtype=None):
        super().__init__(
            device=device,
            layer=layer,
            layer_idx=layer_idx,
            textmodel_json_config=config,
            dtype=dtype,
            special_tokens={"end": 1, "pad": 0},
            model_class=T5,
        )


#################################################################################################
### T5 implementation, for the T5-XXL text encoder portion, largely pulled from upstream impl
#################################################################################################


class T5XXLTokenizer(SDTokenizer):
    """Wraps the T5 Tokenizer from HF into the SDTokenizer interface"""

    def __init__(self):
        super().__init__(
            pad_with_end=False,
            tokenizer=T5TokenizerFast.from_pretrained("google/t5-v1_1-xxl"),
            has_start_token=False,
            pad_to_max_length=False,
            max_length=99999999,
            min_length=77,
        )

class T5LayerNorm(torch.nn.Module):
    def __init__(self, hidden_size, eps=1e-6, dtype=None, device=None):
        super().__init__()
        self.weight = torch.nn.Parameter(
            torch.ones(hidden_size, dtype=dtype, device=device)
        )
        self.variance_epsilon = eps

    def forward(self, x):
        variance = x.pow(2).mean(-1, keepdim=True)
        x = x * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight.to(device=x.device, dtype=x.dtype) * x


class T5DenseGatedActDense(torch.nn.Module):
    def __init__(self, model_dim, ff_dim, dtype, device):
        super().__init__()
        self.wi_0 = nn.Linear(model_dim, ff_dim, bias=False, dtype=dtype, device=device)
        self.wi_1 = nn.Linear(model_dim, ff_dim, bias=False, dtype=dtype, device=device)
        self.wo = nn.Linear(ff_dim, model_dim, bias=False, dtype=dtype, device=device)

    def forward(self, x):
        hidden_gelu = torch.nn.functional.gelu(self.wi_0(x), approximate="tanh")
        hidden_linear = self.wi_1(x)
        x = hidden_gelu * hidden_linear
        x = self.wo(x)
        return x


class T5LayerFF(torch.nn.Module):
    def __init__(self, model_dim, ff_dim, dtype, device):
        super().__init__()
        self.DenseReluDense = T5DenseGatedActDense(model_dim, ff_dim, dtype, device)
        self.layer_norm = T5LayerNorm(model_dim, dtype=dtype, device=device)

    def forward(self, x):
        forwarded_states = self.layer_norm(x)
        forwarded_states = self.DenseReluDense(forwarded_states)
        x += forwarded_states
        return x

class T5Attention(torch.nn.Module):
    def __init__(
        self, model_dim, inner_dim, num_heads, relative_attention_bias, dtype, device
    ):
        super().__init__()
        # Mesh TensorFlow initialization to avoid scaling before softmax
        self.q = nn.Linear(model_dim, inner_dim, bias=False, dtype=dtype, device=device)
        self.k = nn.Linear(model_dim, inner_dim, bias=False, dtype=dtype, device=device)
        self.v = nn.Linear(model_dim, inner_dim, bias=False, dtype=dtype, device=device)
        self.o = nn.Linear(inner_dim, model_dim, bias=False, dtype=dtype, device=device)
        self.num_heads = num_heads
        self.relative_attention_bias = None
        if relative_attention_bias:
            self.relative_attention_num_buckets = 32
            self.relative_attention_max_distance = 128
            self.relative_attention_bias = torch.nn.Embedding(
                self.relative_attention_num_buckets, self.num_heads, device=device
            )

    @staticmethod
    def _relative_position_bucket(
        relative_position, bidirectional=True, num_buckets=32, max_distance=128
    ):
        """
        Adapted from Mesh Tensorflow:
        https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593

        Translate relative position to a bucket number for relative attention. The relative position is defined as
        memory_position - query_position, i.e. the distance in tokens from the attending position to the attended-to
        position. If bidirectional=False, then positive relative positions are invalid. We use smaller buckets for
        small absolute relative_position and larger buckets for larger absolute relative_positions. All relative
        positions >= max_distance map to the same bucket. All relative positions <= -max_distance map to the same
        bucket. This should allow for more graceful generalization to longer sequences than the model has been
        trained on.

        Args:
            relative_position: an int32 Tensor
            bidirectional: a boolean - whether the attention is bidirectional
            num_buckets: an integer
            max_distance: an integer

        Returns:
            a Tensor with the same shape as relative_position, containing int32 values in the range [0, num_buckets)
        """
        relative_buckets = 0
        if bidirectional:
            num_buckets //= 2
            relative_buckets += (relative_position > 0).to(torch.long) * num_buckets
            relative_position = torch.abs(relative_position)
        else:
            relative_position = -torch.min(
                relative_position, torch.zeros_like(relative_position)
            )
        # now relative_position is in the range [0, inf)
        # half of the buckets are for exact increments in positions
        max_exact = num_buckets // 2
        is_small = relative_position < max_exact
        # The other half of the buckets are for logarithmically bigger bins in positions up to max_distance
        relative_position_if_large = max_exact + (
            torch.log(relative_position.float() / max_exact)
            / math.log(max_distance / max_exact)
            * (num_buckets - max_exact)
        ).to(torch.long)
        relative_position_if_large = torch.min(
            relative_position_if_large,
            torch.full_like(relative_position_if_large, num_buckets - 1),
        )
        relative_buckets += torch.where(
            is_small, relative_position, relative_position_if_large
        )
        return relative_buckets

    def compute_bias(self, query_length, key_length, device):
        """Compute binned relative position bias"""
        context_position = torch.arange(query_length, dtype=torch.long, device=device)[
            :, None
        ]
        memory_position = torch.arange(key_length, dtype=torch.long, device=device)[
            None, :
        ]
        relative_position = (
            memory_position - context_position
        )  # shape (query_length, key_length)
        relative_position_bucket = self._relative_position_bucket(
            relative_position,  # shape (query_length, key_length)
            bidirectional=True,
            num_buckets=self.relative_attention_num_buckets,
            max_distance=self.relative_attention_max_distance,
        )
        values = self.relative_attention_bias(
            relative_position_bucket
        )  # shape (query_length, key_length, num_heads)
        values = values.permute([2, 0, 1]).unsqueeze(
            0
        )  # shape (1, num_heads, query_length, key_length)
        return values

    def forward(self, x, past_bias=None):
        q = self.q(x)
        k = self.k(x)
        v = self.v(x)
        mask = None
        if self.relative_attention_bias is not None:
            past_bias = self.compute_bias(x.shape[1], x.shape[1], x.device)
        if past_bias is not None:
            mask = past_bias
        out = attention(
            q, k * ((k.shape[-1] / self.num_heads) ** 0.5), v, self.num_heads, mask
        )
        return self.o(out), past_bias

class T5LayerSelfAttention(torch.nn.Module):
    def __init__(
        self,
        model_dim,
        inner_dim,
        ff_dim,
        num_heads,
        relative_attention_bias,
        dtype,
        device,
    ):
        super().__init__()
        self.SelfAttention = T5Attention(
            model_dim, inner_dim, num_heads, relative_attention_bias, dtype, device
        )
        self.layer_norm = T5LayerNorm(model_dim, dtype=dtype, device=device)

    def forward(self, x, past_bias=None):
        output, past_bias = self.SelfAttention(self.layer_norm(x), past_bias=past_bias)
        x += output
        return x, past_bias


class T5Block(torch.nn.Module):
    def __init__(
        self,
        model_dim,
        inner_dim,
        ff_dim,
        num_heads,
        relative_attention_bias,
        dtype,
        device,
    ):
        super().__init__()
        self.layer = torch.nn.ModuleList()
        self.layer.append(
            T5LayerSelfAttention(
                model_dim,
                inner_dim,
                ff_dim,
                num_heads,
                relative_attention_bias,
                dtype,
                device,
            )
        )
        self.layer.append(T5LayerFF(model_dim, ff_dim, dtype, device))

    def forward(self, x, past_bias=None):
        x, past_bias = self.layer[0](x, past_bias)
        x = self.layer[-1](x)
        return x, past_bias

class T5Stack(torch.nn.Module):
    def __init__(
        self,
        num_layers,
        model_dim,
        inner_dim,
        ff_dim,
        num_heads,
        vocab_size,
        dtype,
        device,
    ):
        super().__init__()
        self.embed_tokens = torch.nn.Embedding(vocab_size, model_dim, device=device)
        self.block = torch.nn.ModuleList(
            [
                T5Block(
                    model_dim,
                    inner_dim,
                    ff_dim,
                    num_heads,
                    relative_attention_bias=(i == 0),
                    dtype=dtype,
                    device=device,
                )
                for i in range(num_layers)
            ]
        )
        self.final_layer_norm = T5LayerNorm(model_dim, dtype=dtype, device=device)

    def forward(
        self, input_ids, intermediate_output=None, final_layer_norm_intermediate=True
    ):
        intermediate = None
        x = self.embed_tokens(input_ids)
        past_bias = None
        for i, l in enumerate(self.block):
            x, past_bias = l(x, past_bias)
            if i == intermediate_output:
                intermediate = x.clone()
        x = self.final_layer_norm(x)
        if intermediate is not None and final_layer_norm_intermediate:
            intermediate = self.final_layer_norm(intermediate)
        return x, intermediate


class T5(torch.nn.Module):
    def __init__(self, config_dict, dtype, device):
        super().__init__()
        self.num_layers = config_dict["num_layers"]
        self.encoder = T5Stack(
            self.num_layers,
            config_dict["d_model"],
            config_dict["d_model"],
            config_dict["d_ff"],
            config_dict["num_heads"],
            config_dict["vocab_size"],
            dtype,
            device,
        )
        self.dtype = dtype

    def get_input_embeddings(self):
        return self.encoder.embed_tokens

    def set_input_embeddings(self, embeddings):
        self.encoder.embed_tokens = embeddings

    def forward(self, *args, **kwargs):
        return self.encoder(*args, **kwargs)

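
# Minimal usage sketch (assumes the Hugging Face tokenizer files above can be downloaded).
# The text-encoder classes (SDClipModel / SDXLClipG / T5XXLModel) additionally need a config
# dict and pretrained weights, which are loaded elsewhere in the reference pipeline, so only
# the tokenizer side is exercised here.
if __name__ == "__main__":
    tokenizer = SD3Tokenizer()
    tokens = tokenizer.tokenize_with_weights("a photo of a (red:1.3) car")
    # For a short prompt: one 77-entry batch of (token_id, weight) pairs per encoder,
    # under the keys "l", "g", and "t5xxl".
    print({name: len(batches[0]) for name, batches in tokens.items()})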