import math

from einops import rearrange, repeat

import torch
from torch import nn, einsum
import torch.nn.functional as F

from axial_positional_embedding import AxialPositionalEmbedding
from text2punks.transformer import Transformer

# helper fns

def exists(val):
    return val is not None

def default(val, d):
    return val if exists(val) else d

def set_requires_grad(model, value):
    for param in model.parameters():
        param.requires_grad = value

def eval_decorator(fn):
    # run `fn` with the model in eval mode, then restore the previous training state
    def inner(model, *args, **kwargs):
        was_training = model.training
        model.eval()
        out = fn(model, *args, **kwargs)
        model.train(was_training)
        return out
    return inner

# sampling helper fn

def top_k(logits, thres = 0.5):
    # keep the top (1 - thres) fraction of logits and push the rest to -inf
    num_logits = logits.shape[-1]
    k = max(int((1 - thres) * num_logits), 1)
    val, ind = torch.topk(logits, k)
    probs = torch.full_like(logits, float('-inf'))
    probs.scatter_(1, ind, val)
    return probs
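
# a minimal usage sketch (not part of the original file) of the top_k filter above:
# with thres = 0.5 only the larger half of the logits survive, so sampling from the
# softmax of the filtered logits can never pick a masked-out token.
#
#   logits = torch.randn(1, 10)
#   filtered = top_k(logits, thres = 0.5)
#   probs = F.softmax(filtered, dim = -1)   # masked entries get probability 0
#   sample = torch.multinomial(probs, 1)    # index of one of the 5 largest logits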

# main CLIP class

class CLIP(nn.Module):
    def __init__(
        self,
        *,
        dim_text = 512,
        dim_image = 512,
        dim_latent = 512,
        num_text_tokens = 10000,
        text_enc_depth = 6,
        text_seq_len = 256,
        text_heads = 8,
        num_visual_tokens = 256,
        visual_enc_depth = 6,
        visual_image_seq_len = 256,
        visual_image_size = 24,
        visual_heads = 8,
        attn_pdrop = 0.1,
        resid_pdrop = 0.1,
        embd_pdrop = 0.1,
        ff_dropout = 0.1,
        attn_types = None
    ):
        super().__init__()

        # Texts

        self.text_emb = nn.Embedding(num_text_tokens, dim_text)
        self.text_pos_emb = nn.Embedding(text_seq_len, dim_text)

        self.text_transformer = Transformer(
            dim = dim_text,
            causal = False,
            seq_len = text_seq_len,
            depth = text_enc_depth,
            heads = text_heads,
            dim_head = dim_text // text_heads,
            attn_dropout = attn_pdrop,
            resid_dropout = resid_pdrop,
            embd_dropout = embd_pdrop,
            ff_dropout = ff_dropout,
            attn_types = attn_types
        )

        self.text_ln = nn.LayerNorm(dim_text)
        self.to_text_latent = nn.Linear(dim_text, dim_latent, bias = False)

        # Images

        self.image_emb = nn.Embedding(num_visual_tokens, dim_image)
        self.image_pos_emb = nn.Embedding(visual_image_seq_len, dim_image)

        self.visual_transformer = Transformer(
            dim = dim_image,
            causal = False,
            seq_len = visual_image_seq_len,
            depth = visual_enc_depth,
            heads = visual_heads,
            dim_head = dim_image // visual_heads,
            attn_dropout = attn_pdrop,
            resid_dropout = resid_pdrop,
            embd_dropout = embd_pdrop,
            ff_dropout = ff_dropout,
            attn_types = attn_types,
            image_size = visual_image_size,
        )

        self.image_ln = nn.LayerNorm(dim_image)
        self.to_visual_latent = nn.Linear(dim_image, dim_latent, bias = False)

        self.temperature = nn.Parameter(torch.ones([]) * math.log(1 / 0.07))

        self.apply(self._init_weights)

    def _init_weights(self, module):
        if isinstance(module, (nn.Linear, nn.Embedding)):
            module.weight.data.normal_(mean=0.0, std=0.02)
            if isinstance(module, nn.Linear) and module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

    def forward(
        self,
        text,
        image,
        return_loss = False
    ):
        b, device = text.shape[0], text.device

        text_emb = self.text_emb(text)
        text_emb += self.text_pos_emb(torch.arange(text.shape[1], device = device))

        image_emb = self.image_emb(image)
        image_emb += self.image_pos_emb(torch.arange(image.shape[1], device = device))

        enc_text = self.text_transformer(text_emb)
        enc_image = self.visual_transformer(image_emb)

        # mean-pool over the sequence dimension to get one latent per sample
        text_latents = enc_text.mean(dim = 1)
        image_latents = enc_image.mean(dim = 1)

        text_latents = self.text_ln(text_latents)
        image_latents = self.image_ln(image_latents)

        text_latents = self.to_text_latent(text_latents)
        image_latents = self.to_visual_latent(image_latents)

        text_latents, image_latents = map(lambda t: F.normalize(t, p = 2, dim = -1), (text_latents, image_latents))

        temp = self.temperature.exp()

        if not return_loss:
            # similarity score for each aligned text-image pair in the batch
            sim = einsum('n d, n d -> n', text_latents, image_latents) * temp
            return sim

        # symmetric contrastive loss over the full batch similarity matrix
        sim = einsum('i d, j d -> i j', text_latents, image_latents) * temp
        labels = torch.arange(b, device = device)
        loss = (F.cross_entropy(sim, labels) + F.cross_entropy(sim.t(), labels)) / 2
        return loss
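
# a minimal usage sketch (not part of the original file), assuming the Transformer from
# text2punks.transformer is importable and accepts the kwargs used in the constructor above.
# Both encoders take integer token sequences of shape (batch, seq_len).
#
#   clip_model = CLIP()                                   # defaults: text_seq_len = 256, visual_image_seq_len = 256
#   text  = torch.randint(0, 10000, (4, 256))
#   image = torch.randint(0, 256, (4, 256))
#   loss  = clip_model(text, image, return_loss = True)   # symmetric contrastive loss
#   sims  = clip_model(text, image, return_loss = False)  # per-pair similarities, shape (4,)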

# main Text2Punks class

class Text2Punks(nn.Module):
    def __init__(
        self,
        *,
        n_embd,
        n_layer = 12,
        n_head = 12,
        d_head = 64,
        num_text_tokens = 10000,
        text_seq_len = 256,
        num_image_tokens = 222,
        image_seq_len = 576,
        image_size = 24,
        attn_pdrop = 0.1,
        resid_pdrop = 0.1,
        embd_pdrop = 0.1,
        ff_dropout = 0.1,
        attn_types = None,
        loss_img_weight = 7,
        loss_txt_weight = 7,
    ):
        super().__init__()

        num_text_tokens = num_text_tokens + text_seq_len # reserve a unique padding token for each text position

        self.text_emb = nn.Embedding(num_text_tokens, n_embd)
        self.image_emb = nn.Embedding(num_image_tokens, n_embd)

        self.text_pos_emb = nn.Embedding(text_seq_len + 1, n_embd) # +1 for <bos> a.k.a. <sos>
        # self.image_pos_emb = nn.Embedding(image_seq_len, n_embd)
        self.image_pos_emb = nn.Parameter(torch.zeros(1, image_seq_len, n_embd))
        # self.image_pos_emb = AxialPositionalEmbedding(n_embd, axial_shape=(image_size, image_size))

        self.num_text_tokens = num_text_tokens # for offsetting logits index and calculating cross entropy loss
        self.num_image_tokens = num_image_tokens
        self.text_seq_len = text_seq_len
        self.image_seq_len = image_seq_len

        seq_len = text_seq_len + image_seq_len
        total_tokens = num_text_tokens + num_image_tokens
        self.total_seq_len = seq_len
        self.total_tokens = total_tokens

        self.transformer = Transformer(
            dim = n_embd,
            causal = True,
            seq_len = seq_len,
            depth = n_layer,
            heads = n_head,
            dim_head = d_head,
            attn_dropout = attn_pdrop,
            resid_dropout = resid_pdrop,
            embd_dropout = embd_pdrop,
            ff_dropout = ff_dropout,
            attn_types = attn_types,
            image_size = image_size,
        )

        self.to_logits = nn.Sequential(
            nn.LayerNorm(n_embd),
            nn.Linear(n_embd, self.total_tokens),
        )

        seq_range = torch.arange(seq_len)
        logits_range = torch.arange(total_tokens)

        seq_range = rearrange(seq_range, 'n -> () n ()')
        logits_range = rearrange(logits_range, 'd -> () () d')

        # mask that forbids text positions from predicting image tokens and vice versa
        logits_mask = (
            ((seq_range >= text_seq_len) & (logits_range < num_text_tokens)) |
            ((seq_range < text_seq_len) & (logits_range >= num_text_tokens))
        )

        self.register_buffer('logits_mask', logits_mask, persistent=False)

        self.loss_img_weight = loss_img_weight
        self.loss_txt_weight = loss_txt_weight

        self.apply(self._init_weights)

    def _init_weights(self, module):
        if isinstance(module, (nn.Linear, nn.Embedding)):
            module.weight.data.normal_(mean=0.0, std=0.02)
            if isinstance(module, nn.Linear) and module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

    # sampling should not track gradients or use dropout; eval_decorator (defined above) restores the training state afterwards
    @torch.no_grad()
    @eval_decorator
    def generate_images(
        self,
        text,
        decoder,
        *,
        clip = None,
        filter_thres = 0.5,
        temperature = 1.,
        img = None,
        num_init_img_tokens = None
    ):
        text_seq_len, image_seq_len, num_text_tokens = self.text_seq_len, self.image_seq_len, self.num_text_tokens
        total_len = text_seq_len + image_seq_len
        batch = text.shape[0]

        text = text[:, :text_seq_len] # make sure text is within bounds
        out = text

        if exists(img):
            assert img.shape[1] == image_seq_len, f'input image must have the correct image size {image_seq_len}'

            num_img_tokens = default(num_init_img_tokens, int(0.4375 * image_seq_len)) # OpenAI used 14 * 32 initial tokens to prime
            assert num_img_tokens < image_seq_len, 'number of initial image tokens for priming must be less than the total image token sequence length'

            trunc_img = img[:, :num_img_tokens]
            out = torch.cat((out, trunc_img), dim = -1)

        for cur_len in range(out.shape[1], total_len):
            is_image = cur_len >= text_seq_len

            text, image = out[:, :text_seq_len], out[:, text_seq_len:]

            logits = self(text, image)[:, -1, :]

            filtered_logits = top_k(logits, thres = filter_thres)
            probs = F.softmax(filtered_logits / temperature, dim = -1)
            sample = torch.multinomial(probs, 1)

            sample -= (num_text_tokens if is_image else 0) # offset sampled token if it is an image token, since logit space is composed of text and then image tokens
            out = torch.cat((out, sample), dim = -1)

        text_seq = out[:, :text_seq_len]
        img_seq = out[:, -image_seq_len:]

        scores = None
        if exists(clip):
            scores = clip(text_seq, img_seq, return_loss = False)

        # map each image token to its RGB value via the `decoder` codebook, then reshape to a 24 x 24 image
        img_seq = repeat(img_seq, 'b p -> b p c', c = 3)
        decoder = repeat(decoder, 'p c -> b p c', b = batch)
        images = torch.gather(decoder, 1, img_seq)
        images = rearrange(images, 'b (h w) c -> b c h w', h = 24, w = 24)
        images = images.float()

        return images, scores

    def forward(
        self,
        text,
        image = None,
        return_loss = False
    ):
        assert text.shape[-1] == self.text_seq_len, f'the length {text.shape[-1]} of the text tokens you passed in does not match the required length ({self.text_seq_len})'
        device, total_seq_len = text.device, self.total_seq_len

        # give padding positions (token id 0) a unique, position-dependent padding token
        text_range = torch.arange(self.text_seq_len, device = device) + (self.num_text_tokens - self.text_seq_len)
        text = torch.where(text == 0, text_range, text)

        text = F.pad(text, (1, 0), value = 0) # add <bos>

        tokens = self.text_emb(text)
        tokens += self.text_pos_emb(torch.arange(text.shape[1], device = device))

        seq_len = tokens.shape[1]

        image_len = image.shape[1]
        image_emb = self.image_emb(image)
        # image_emb += self.image_pos_emb(torch.arange(image_len, device = device))
        image_emb += self.image_pos_emb[:, :image_len, :]
        # image_emb += self.image_pos_emb(image_emb)

        tokens = torch.cat((tokens, image_emb), dim = 1)
        seq_len += image_len

        # when training, if the length exceeds the total text + image length,
        # remove the last token, since it does not need to be predicted
        if tokens.shape[1] > total_seq_len:
            seq_len -= 1
            tokens = tokens[:, :-1]

        out = self.transformer(tokens)
        logits = self.to_logits(out)

        # mask logits to make sure text predicts text (except the last text position) and image predicts image
        logits_mask = self.logits_mask[:, :seq_len]
        max_neg_value = -torch.finfo(logits.dtype).max
        logits.masked_fill_(logits_mask, max_neg_value)

        if not return_loss:
            return logits

        assert exists(image), 'when training, image must be supplied'

        offsetted_image = image + self.num_text_tokens
        labels = torch.cat((text[:, 1:], offsetted_image), dim = 1)

        logits = rearrange(logits, 'b n c -> b c n')

        loss_text = F.cross_entropy(logits[:, :, :self.text_seq_len], labels[:, :self.text_seq_len])
        loss_img = F.cross_entropy(logits[:, :, self.text_seq_len:], labels[:, self.text_seq_len:])

        loss = (self.loss_txt_weight * loss_text + self.loss_img_weight * loss_img) / (self.loss_img_weight + self.loss_txt_weight)

        return loss, loss_text, loss_img
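
# a minimal smoke test (not part of the original file); it is only a sketch and assumes
# that text2punks.transformer.Transformer is importable and accepts the kwargs used by
# the constructors above. The tiny hyperparameters here are illustrative, not the ones
# used for training.
if __name__ == '__main__':
    text_seq_len, image_seq_len = 32, 16

    model = Text2Punks(
        n_embd = 64,
        n_layer = 2,
        n_head = 4,
        d_head = 16,
        num_text_tokens = 100,
        text_seq_len = text_seq_len,
        num_image_tokens = 16,
        image_seq_len = image_seq_len,  # kept equal to image_size ** 2
        image_size = 4,
    )

    text = torch.randint(0, 100, (2, text_seq_len))
    image = torch.randint(0, 16, (2, image_seq_len))

    loss, loss_text, loss_img = model(text, image, return_loss = True)
    print(loss.item(), loss_text.item(), loss_img.item())

    # generation sketch (commented out): generate_images hard-codes a 24 x 24 output, so it
    # expects the full-size configuration (image_seq_len = 576, image_size = 24) together with
    # a codebook `decoder` of shape (num_image_tokens, 3) mapping image tokens to RGB values.
    #
    #   images, scores = model.generate_images(text, decoder, clip = clip_model)  # clip_model: an optional trained CLIP instance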