wangpangintsig committed · Commit 839e5a0 · verified · 1 Parent(s): 34c1280

Create other_impls.py

Files changed (1): other_impls.py (+868 -0)
other_impls.py ADDED
@@ -0,0 +1,868 @@
### This file contains impls for underlying related models (CLIP, T5, etc)

import logging
import math
import os

import torch
from torch import nn
from transformers import CLIPTokenizer, T5TokenizerFast
from einops import rearrange

#################################################################################################
### Core/Utility
#################################################################################################


def attention(q, k, v, heads, mask=None):
    """Convenience wrapper around a basic attention operation"""
    b, _, dim_head = q.shape
    dim_head //= heads
    q, k, v = map(lambda t: t.view(b, -1, heads, dim_head).transpose(1, 2), (q, k, v))
    out = torch.nn.functional.scaled_dot_product_attention(
        q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False
    )
    return out.transpose(1, 2).reshape(b, -1, heads * dim_head)

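
# --- Illustrative sketch (not part of the upstream reference): a minimal check of
# the shape convention used by attention(). Inputs and output are
# (batch, seq_len, heads * dim_head); the sizes below are arbitrary placeholders.
def _demo_attention_shapes():
    q = torch.randn(2, 7, 8 * 64)
    k = torch.randn(2, 7, 8 * 64)
    v = torch.randn(2, 7, 8 * 64)
    out = attention(q, k, v, heads=8)
    assert out.shape == (2, 7, 8 * 64)
    return out.shape
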
class Mlp(nn.Module):
    """MLP as used in Vision Transformer, MLP-Mixer and related networks"""

    def __init__(
        self,
        in_features,
        hidden_features=None,
        out_features=None,
        act_layer=nn.GELU,
        bias=True,
        dtype=None,
        device=None,
    ):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features

        self.fc1 = nn.Linear(
            in_features, hidden_features, bias=bias, dtype=dtype, device=device
        )
        self.act = act_layer
        self.fc2 = nn.Linear(
            hidden_features, out_features, bias=bias, dtype=dtype, device=device
        )

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.fc2(x)
        return x

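
# --- Illustrative sketch (not part of the upstream reference): Mlp stores
# act_layer as-is and calls it directly in forward(), so pass a plain callable
# such as torch.nn.functional.gelu (or an entry of the ACTIVATIONS table below)
# rather than a module class. The feature sizes here are arbitrary.
def _demo_mlp():
    mlp = Mlp(16, hidden_features=64, act_layer=torch.nn.functional.gelu)
    return mlp(torch.randn(2, 3, 16)).shape  # -> torch.Size([2, 3, 16])
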
#################################################################################################
### CLIP
#################################################################################################


class CLIPAttention(torch.nn.Module):
    def __init__(self, embed_dim, heads, dtype, device):
        super().__init__()
        self.heads = heads
        self.q_proj = nn.Linear(
            embed_dim, embed_dim, bias=True, dtype=dtype, device=device
        )
        self.k_proj = nn.Linear(
            embed_dim, embed_dim, bias=True, dtype=dtype, device=device
        )
        self.v_proj = nn.Linear(
            embed_dim, embed_dim, bias=True, dtype=dtype, device=device
        )
        self.out_proj = nn.Linear(
            embed_dim, embed_dim, bias=True, dtype=dtype, device=device
        )

    def forward(self, x, mask=None):
        q = self.q_proj(x)
        k = self.k_proj(x)
        v = self.v_proj(x)
        out = attention(q, k, v, self.heads, mask)
        return self.out_proj(out)


ACTIVATIONS = {
    "quick_gelu": lambda a: a * torch.sigmoid(1.702 * a),
    "gelu": torch.nn.functional.gelu,
}


class CLIPLayer(torch.nn.Module):
    def __init__(
        self,
        embed_dim,
        heads,
        intermediate_size,
        intermediate_activation,
        dtype,
        device,
    ):
        super().__init__()
        self.layer_norm1 = nn.LayerNorm(embed_dim, dtype=dtype, device=device)
        self.self_attn = CLIPAttention(embed_dim, heads, dtype, device)
        self.layer_norm2 = nn.LayerNorm(embed_dim, dtype=dtype, device=device)
        # self.mlp = CLIPMLP(embed_dim, intermediate_size, intermediate_activation, dtype, device)
        self.mlp = Mlp(
            embed_dim,
            intermediate_size,
            embed_dim,
            act_layer=ACTIVATIONS[intermediate_activation],
            dtype=dtype,
            device=device,
        )

    def forward(self, x, mask=None):
        x += self.self_attn(self.layer_norm1(x), mask)
        x += self.mlp(self.layer_norm2(x))
        return x


class CLIPEncoder(torch.nn.Module):
    def __init__(
        self,
        num_layers,
        embed_dim,
        heads,
        intermediate_size,
        intermediate_activation,
        dtype,
        device,
    ):
        super().__init__()
        self.layers = torch.nn.ModuleList(
            [
                CLIPLayer(
                    embed_dim,
                    heads,
                    intermediate_size,
                    intermediate_activation,
                    dtype,
                    device,
                )
                for i in range(num_layers)
            ]
        )

    def forward(self, x, mask=None, intermediate_output=None):
        if intermediate_output is not None:
            if intermediate_output < 0:
                intermediate_output = len(self.layers) + intermediate_output
        intermediate = None
        for i, l in enumerate(self.layers):
            x = l(x, mask)
            if i == intermediate_output:
                intermediate = x.clone()
        return x, intermediate


class CLIPEmbeddings(torch.nn.Module):
    def __init__(
        self, embed_dim, vocab_size=49408, num_positions=77, dtype=None, device=None
    ):
        super().__init__()
        self.token_embedding = torch.nn.Embedding(
            vocab_size, embed_dim, dtype=dtype, device=device
        )
        self.position_embedding = torch.nn.Embedding(
            num_positions, embed_dim, dtype=dtype, device=device
        )

    def forward(self, input_tokens):
        return self.token_embedding(input_tokens) + self.position_embedding.weight


class CLIPTextModel_(torch.nn.Module):
    def __init__(self, config_dict, dtype, device):
        num_layers = config_dict["num_hidden_layers"]
        embed_dim = config_dict["hidden_size"]
        heads = config_dict["num_attention_heads"]
        intermediate_size = config_dict["intermediate_size"]
        intermediate_activation = config_dict["hidden_act"]
        super().__init__()
        self.embeddings = CLIPEmbeddings(embed_dim, dtype=torch.float32, device=device)
        self.encoder = CLIPEncoder(
            num_layers,
            embed_dim,
            heads,
            intermediate_size,
            intermediate_activation,
            dtype,
            device,
        )
        self.final_layer_norm = nn.LayerNorm(embed_dim, dtype=dtype, device=device)

    def forward(
        self, input_tokens, intermediate_output=None, final_layer_norm_intermediate=True
    ):
        x = self.embeddings(input_tokens)
        causal_mask = (
            torch.empty(x.shape[1], x.shape[1], dtype=x.dtype, device=x.device)
            .fill_(float("-inf"))
            .triu_(1)
        )
        x, i = self.encoder(
            x, mask=causal_mask, intermediate_output=intermediate_output
        )
        x = self.final_layer_norm(x)
        if i is not None and final_layer_norm_intermediate:
            i = self.final_layer_norm(i)
        pooled_output = x[
            torch.arange(x.shape[0], device=x.device),
            input_tokens.to(dtype=torch.int, device=x.device).argmax(dim=-1),
        ]
        return x, i, pooled_output


class CLIPTextModel(torch.nn.Module):
    def __init__(self, config_dict, dtype, device):
        super().__init__()
        self.num_layers = config_dict["num_hidden_layers"]
        self.text_model = CLIPTextModel_(config_dict, dtype, device)
        embed_dim = config_dict["hidden_size"]
        self.text_projection = nn.Linear(
            embed_dim, embed_dim, bias=False, dtype=dtype, device=device
        )
        self.text_projection.weight.copy_(torch.eye(embed_dim))
        self.dtype = dtype

    def get_input_embeddings(self):
        return self.text_model.embeddings.token_embedding

    def set_input_embeddings(self, embeddings):
        self.text_model.embeddings.token_embedding = embeddings

    def forward(self, *args, **kwargs):
        x = self.text_model(*args, **kwargs)
        out = self.text_projection(x[2])
        return (x[0], x[1], out, x[2])

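
# --- Illustrative sketch (not part of the upstream reference): the shape of the
# config dict CLIPTextModel_ reads. The values below match the published CLIP-L
# (ViT-L/14) text encoder but are stated as an assumption; use the config that
# ships with your checkpoint when loading real weights. Construction is wrapped
# in no_grad() because text_projection.weight.copy_() is an in-place op on a
# parameter that requires grad.
def _demo_build_clip_l(device="cpu"):
    clip_l_config = {
        "num_hidden_layers": 12,
        "hidden_size": 768,
        "num_attention_heads": 12,
        "intermediate_size": 3072,
        "hidden_act": "quick_gelu",
    }
    with torch.no_grad():
        return CLIPTextModel(clip_l_config, dtype=torch.float32, device=device)
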
def parse_parentheses(string):
    result = []
    current_item = ""
    nesting_level = 0
    for char in string:
        if char == "(":
            if nesting_level == 0:
                if current_item:
                    result.append(current_item)
                    current_item = "("
                else:
                    current_item = "("
            else:
                current_item += char
            nesting_level += 1
        elif char == ")":
            nesting_level -= 1
            if nesting_level == 0:
                result.append(current_item + ")")
                current_item = ""
            else:
                current_item += char
        else:
            current_item += char
    if current_item:
        result.append(current_item)
    return result


def token_weights(string, current_weight):
    a = parse_parentheses(string)
    out = []
    for x in a:
        weight = current_weight
        if len(x) >= 2 and x[-1] == ")" and x[0] == "(":
            x = x[1:-1]
            xx = x.rfind(":")
            weight *= 1.1
            if xx > 0:
                try:
                    weight = float(x[xx + 1 :])
                    x = x[:xx]
                except ValueError:
                    pass
            out += token_weights(x, weight)
        else:
            out += [(x, current_weight)]
    return out


def escape_important(text):
    text = text.replace("\\)", "\0\1")
    text = text.replace("\\(", "\0\2")
    return text


def unescape_important(text):
    text = text.replace("\0\1", ")")
    text = text.replace("\0\2", "(")
    return text

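
# --- Illustrative sketch (not part of the upstream reference) of what the prompt
# weighting helpers produce: "(...)" multiplies the weight by 1.1, "(text:1.3)"
# sets it explicitly, and "\(" / "\)" escape literal parentheses. The prompt is
# arbitrary.
def _demo_token_weights():
    text = escape_important("a (red:1.3) fox and a ((blue)) bird \\(sketch\\)")
    parsed = token_weights(text, 1.0)
    return [(unescape_important(t), w) for t, w in parsed]
    # -> roughly [("a ", 1.0), ("red", 1.3), (" fox and a ", 1.0),
    #             ("blue", 1.21), (" bird (sketch)", 1.0)]
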
class SDTokenizer:
    def __init__(
        self,
        max_length=77,
        pad_with_end=True,
        tokenizer=None,
        has_start_token=True,
        pad_to_max_length=True,
        min_length=None,
        extra_padding_token=None,
    ):
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.min_length = min_length

        empty = self.tokenizer("")["input_ids"]
        if has_start_token:
            self.tokens_start = 1
            self.start_token = empty[0]
            self.end_token = empty[1]
        else:
            self.tokens_start = 0
            self.start_token = None
            self.end_token = empty[0]
        self.pad_with_end = pad_with_end
        self.pad_to_max_length = pad_to_max_length
        self.extra_padding_token = extra_padding_token

        vocab = self.tokenizer.get_vocab()
        self.inv_vocab = {v: k for k, v in vocab.items()}
        self.max_word_length = 8

    def tokenize_with_weights(self, text: str, return_word_ids=False):
        """
        Tokenize the text, with weight values - presume 1.0 for all and ignore other features here.
        The details aren't relevant for a reference impl, and weights themselves have a weak effect on SD3.
        """
        if self.pad_with_end:
            pad_token = self.end_token
        else:
            pad_token = 0

        text = escape_important(text)
        parsed_weights = token_weights(text, 1.0)

        # tokenize words
        tokens = []
        for weighted_segment, weight in parsed_weights:
            to_tokenize = (
                unescape_important(weighted_segment).replace("\n", " ").split(" ")
            )
            to_tokenize = [x for x in to_tokenize if x != ""]
            for word in to_tokenize:
                # parse word
                tokens.append(
                    [
                        (t, weight)
                        for t in self.tokenizer(word)["input_ids"][
                            self.tokens_start : -1
                        ]
                    ]
                )

        # reshape token array to CLIP input size
        batched_tokens = []
        batch = []
        if self.start_token is not None:
            batch.append((self.start_token, 1.0, 0))
        batched_tokens.append(batch)
        for i, t_group in enumerate(tokens):
            # determine if we're going to try and keep the tokens in a single batch
            is_large = len(t_group) >= self.max_word_length

            while len(t_group) > 0:
                if len(t_group) + len(batch) > self.max_length - 1:
                    remaining_length = self.max_length - len(batch) - 1
                    # break word in two and add end token
                    if is_large:
                        batch.extend(
                            [(t, w, i + 1) for t, w in t_group[:remaining_length]]
                        )
                        batch.append((self.end_token, 1.0, 0))
                        t_group = t_group[remaining_length:]
                    # add end token and pad
                    else:
                        batch.append((self.end_token, 1.0, 0))
                        if self.pad_to_max_length:
                            batch.extend([(pad_token, 1.0, 0)] * (remaining_length))
                    # start new batch
                    batch = []
                    if self.start_token is not None:
                        batch.append((self.start_token, 1.0, 0))
                    batched_tokens.append(batch)
                else:
                    batch.extend([(t, w, i + 1) for t, w in t_group])
                    t_group = []

        # pad extra padding token first before getting to the end token
        if self.extra_padding_token is not None:
            batch.extend(
                [(self.extra_padding_token, 1.0, 0)]
                * (self.min_length - len(batch) - 1)
            )
        # fill last batch
        batch.append((self.end_token, 1.0, 0))
        if self.pad_to_max_length:
            batch.extend([(pad_token, 1.0, 0)] * (self.max_length - len(batch)))
        if self.min_length is not None and len(batch) < self.min_length:
            batch.extend([(pad_token, 1.0, 0)] * (self.min_length - len(batch)))

        if not return_word_ids:
            batched_tokens = [[(t, w) for t, w, _ in x] for x in batched_tokens]

        return batched_tokens

    def untokenize(self, token_weight_pair):
        return list(map(lambda a: (a, self.inv_vocab[a[0]]), token_weight_pair))


class SDXLClipGTokenizer(SDTokenizer):
    def __init__(self, tokenizer):
        super().__init__(pad_with_end=False, tokenizer=tokenizer)


class SD3Tokenizer:
    def __init__(self):
        clip_tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
        self.clip_l = SDTokenizer(tokenizer=clip_tokenizer)
        self.clip_g = SDXLClipGTokenizer(clip_tokenizer)
        self.t5xxl = T5XXLTokenizer()

    def tokenize_with_weights(self, text: str):
        out = {}
        out["l"] = self.clip_l.tokenize_with_weights(text)
        out["g"] = self.clip_g.tokenize_with_weights(text)
        out["t5xxl"] = self.t5xxl.tokenize_with_weights(text[:226])
        return out

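
# --- Illustrative sketch (not part of the upstream reference): tokenize one
# prompt into the three parallel streams SD3 feeds its text encoders (CLIP-L,
# CLIP-G, T5-XXL). The first call needs network access to download the
# tokenizers from Hugging Face (the T5 side also needs sentencepiece); the
# prompt is arbitrary.
def _demo_sd3_tokenize(prompt="a photo of a (red:1.3) fox"):
    streams = SD3Tokenizer().tokenize_with_weights(prompt)
    # "l" and "g" are padded to 77 tokens; "t5xxl" is padded up to at least 77
    # but has no fixed upper length.
    return {name: len(batches[0]) for name, batches in streams.items()}
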
class ClipTokenWeightEncoder:
    def encode_token_weights(self, token_weight_pairs):
        tokens = list(map(lambda a: a[0], token_weight_pairs[0]))
        out, pooled = self([tokens])
        if pooled is not None:
            first_pooled = pooled[0:1].cpu()
        else:
            first_pooled = pooled
        output = [out[0:1]]
        return torch.cat(output, dim=-2).cpu(), first_pooled


class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
    """Uses the CLIP transformer encoder for text (from huggingface)"""

    LAYERS = ["last", "pooled", "hidden"]

    def __init__(
        self,
        device="cpu",
        max_length=77,
        layer="last",
        layer_idx=None,
        textmodel_json_config=None,
        dtype=None,
        model_class=CLIPTextModel,
        special_tokens={"start": 49406, "end": 49407, "pad": 49407},
        layer_norm_hidden_state=True,
        return_projected_pooled=True,
    ):
        super().__init__()
        assert layer in self.LAYERS
        self.transformer = model_class(textmodel_json_config, dtype, device)
        self.num_layers = self.transformer.num_layers
        self.max_length = max_length
        self.transformer = self.transformer.eval()
        for param in self.parameters():
            param.requires_grad = False
        self.layer = layer
        self.layer_idx = None
        self.special_tokens = special_tokens
        self.logit_scale = torch.nn.Parameter(torch.tensor(4.6055))
        self.layer_norm_hidden_state = layer_norm_hidden_state
        self.return_projected_pooled = return_projected_pooled
        if layer == "hidden":
            assert layer_idx is not None
            assert abs(layer_idx) < self.num_layers
            self.set_clip_options({"layer": layer_idx})
        self.options_default = (
            self.layer,
            self.layer_idx,
            self.return_projected_pooled,
        )

    def set_clip_options(self, options):
        layer_idx = options.get("layer", self.layer_idx)
        self.return_projected_pooled = options.get(
            "projected_pooled", self.return_projected_pooled
        )
        if layer_idx is None or abs(layer_idx) > self.num_layers:
            self.layer = "last"
        else:
            self.layer = "hidden"
            self.layer_idx = layer_idx

    def forward(self, tokens):
        backup_embeds = self.transformer.get_input_embeddings()
        device = backup_embeds.weight.device
        tokens = torch.LongTensor(tokens).to(device)
        outputs = self.transformer(
            tokens,
            intermediate_output=self.layer_idx,
            final_layer_norm_intermediate=self.layer_norm_hidden_state,
        )
        self.transformer.set_input_embeddings(backup_embeds)
        if self.layer == "last":
            z = outputs[0]
        else:
            z = outputs[1]
        pooled_output = None
        if len(outputs) >= 3:
            if (
                not self.return_projected_pooled
                and len(outputs) >= 4
                and outputs[3] is not None
            ):
                pooled_output = outputs[3].float()
            elif outputs[2] is not None:
                pooled_output = outputs[2].float()
        return z.float(), pooled_output


class SDXLClipG(SDClipModel):
    """Wraps the CLIP-G model into the SD-CLIP-Model interface"""

    def __init__(
        self, config, device="cpu", layer="penultimate", layer_idx=None, dtype=None
    ):
        if layer == "penultimate":
            layer = "hidden"
            layer_idx = -2
        super().__init__(
            device=device,
            layer=layer,
            layer_idx=layer_idx,
            textmodel_json_config=config,
            dtype=dtype,
            special_tokens={"start": 49406, "end": 49407, "pad": 0},
            layer_norm_hidden_state=False,
        )


class T5XXLModel(SDClipModel):
    """Wraps the T5-XXL model into the SD-CLIP-Model interface for convenience"""

    def __init__(self, config, device="cpu", layer="last", layer_idx=None, dtype=None):
        super().__init__(
            device=device,
            layer=layer,
            layer_idx=layer_idx,
            textmodel_json_config=config,
            dtype=dtype,
            special_tokens={"end": 1, "pad": 0},
            model_class=T5,
        )


#################################################################################################
### T5 implementation, for the T5-XXL text encoder portion, largely pulled from upstream impl
#################################################################################################


class T5XXLTokenizer(SDTokenizer):
    """Wraps the T5 Tokenizer from HF into the SDTokenizer interface"""

    def __init__(self):
        super().__init__(
            pad_with_end=False,
            tokenizer=T5TokenizerFast.from_pretrained("google/t5-v1_1-xxl"),
            has_start_token=False,
            pad_to_max_length=False,
            max_length=99999999,
            min_length=77,
        )


class T5LayerNorm(torch.nn.Module):
    def __init__(self, hidden_size, eps=1e-6, dtype=None, device=None):
        super().__init__()
        self.weight = torch.nn.Parameter(
            torch.ones(hidden_size, dtype=dtype, device=device)
        )
        self.variance_epsilon = eps

    def forward(self, x):
        variance = x.pow(2).mean(-1, keepdim=True)
        x = x * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight.to(device=x.device, dtype=x.dtype) * x


class T5DenseGatedActDense(torch.nn.Module):
    def __init__(self, model_dim, ff_dim, dtype, device):
        super().__init__()
        self.wi_0 = nn.Linear(model_dim, ff_dim, bias=False, dtype=dtype, device=device)
        self.wi_1 = nn.Linear(model_dim, ff_dim, bias=False, dtype=dtype, device=device)
        self.wo = nn.Linear(ff_dim, model_dim, bias=False, dtype=dtype, device=device)

    def forward(self, x):
        hidden_gelu = torch.nn.functional.gelu(self.wi_0(x), approximate="tanh")
        hidden_linear = self.wi_1(x)
        x = hidden_gelu * hidden_linear
        x = self.wo(x)
        return x


class T5LayerFF(torch.nn.Module):
    def __init__(self, model_dim, ff_dim, dtype, device):
        super().__init__()
        self.DenseReluDense = T5DenseGatedActDense(model_dim, ff_dim, dtype, device)
        self.layer_norm = T5LayerNorm(model_dim, dtype=dtype, device=device)

    def forward(self, x):
        forwarded_states = self.layer_norm(x)
        forwarded_states = self.DenseReluDense(forwarded_states)
        x += forwarded_states
        return x


class T5Attention(torch.nn.Module):
    def __init__(
        self, model_dim, inner_dim, num_heads, relative_attention_bias, dtype, device
    ):
        super().__init__()
        # Mesh TensorFlow initialization to avoid scaling before softmax
        self.q = nn.Linear(model_dim, inner_dim, bias=False, dtype=dtype, device=device)
        self.k = nn.Linear(model_dim, inner_dim, bias=False, dtype=dtype, device=device)
        self.v = nn.Linear(model_dim, inner_dim, bias=False, dtype=dtype, device=device)
        self.o = nn.Linear(inner_dim, model_dim, bias=False, dtype=dtype, device=device)
        self.num_heads = num_heads
        self.relative_attention_bias = None
        if relative_attention_bias:
            self.relative_attention_num_buckets = 32
            self.relative_attention_max_distance = 128
            self.relative_attention_bias = torch.nn.Embedding(
                self.relative_attention_num_buckets, self.num_heads, device=device
            )

    @staticmethod
    def _relative_position_bucket(
        relative_position, bidirectional=True, num_buckets=32, max_distance=128
    ):
        """
        Adapted from Mesh Tensorflow:
        https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593

        Translate relative position to a bucket number for relative attention. The relative position is defined as
        memory_position - query_position, i.e. the distance in tokens from the attending position to the attended-to
        position. If bidirectional=False, then positive relative positions are invalid. We use smaller buckets for
        small absolute relative_position and larger buckets for larger absolute relative_positions. All relative
        positions >=max_distance map to the same bucket. All relative positions <=-max_distance map to the same bucket.
        This should allow for more graceful generalization to longer sequences than the model has been trained on

        Args:
            relative_position: an int32 Tensor
            bidirectional: a boolean - whether the attention is bidirectional
            num_buckets: an integer
            max_distance: an integer

        Returns:
            a Tensor with the same shape as relative_position, containing int32 values in the range [0, num_buckets)
        """
        relative_buckets = 0
        if bidirectional:
            num_buckets //= 2
            relative_buckets += (relative_position > 0).to(torch.long) * num_buckets
            relative_position = torch.abs(relative_position)
        else:
            relative_position = -torch.min(
                relative_position, torch.zeros_like(relative_position)
            )
        # now relative_position is in the range [0, inf)
        # half of the buckets are for exact increments in positions
        max_exact = num_buckets // 2
        is_small = relative_position < max_exact
        # The other half of the buckets are for logarithmically bigger bins in positions up to max_distance
        relative_position_if_large = max_exact + (
            torch.log(relative_position.float() / max_exact)
            / math.log(max_distance / max_exact)
            * (num_buckets - max_exact)
        ).to(torch.long)
        relative_position_if_large = torch.min(
            relative_position_if_large,
            torch.full_like(relative_position_if_large, num_buckets - 1),
        )
        relative_buckets += torch.where(
            is_small, relative_position, relative_position_if_large
        )
        return relative_buckets

    def compute_bias(self, query_length, key_length, device):
        """Compute binned relative position bias"""
        context_position = torch.arange(query_length, dtype=torch.long, device=device)[
            :, None
        ]
        memory_position = torch.arange(key_length, dtype=torch.long, device=device)[
            None, :
        ]
        relative_position = (
            memory_position - context_position
        )  # shape (query_length, key_length)
        relative_position_bucket = self._relative_position_bucket(
            relative_position,  # shape (query_length, key_length)
            bidirectional=True,
            num_buckets=self.relative_attention_num_buckets,
            max_distance=self.relative_attention_max_distance,
        )
        values = self.relative_attention_bias(
            relative_position_bucket
        )  # shape (query_length, key_length, num_heads)
        values = values.permute([2, 0, 1]).unsqueeze(
            0
        )  # shape (1, num_heads, query_length, key_length)
        return values

    def forward(self, x, past_bias=None):
        q = self.q(x)
        k = self.k(x)
        v = self.v(x)
        if self.relative_attention_bias is not None:
            past_bias = self.compute_bias(x.shape[1], x.shape[1], x.device)
        mask = past_bias  # None when no relative bias is computed or passed in
        out = attention(
            q, k * ((k.shape[-1] / self.num_heads) ** 0.5), v, self.num_heads, mask
        )
        return self.o(out), past_bias

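
# --- Illustrative sketch (not part of the upstream reference):
# _relative_position_bucket is a staticmethod, so the bucketing scheme can be
# probed without building a model. Nearby offsets get their own buckets, while
# distant offsets share logarithmically sized ones; the offsets below are
# arbitrary.
def _demo_relative_buckets():
    offsets = torch.tensor([[-64, -8, -1, 0, 1, 8, 64]])
    return T5Attention._relative_position_bucket(
        offsets, bidirectional=True, num_buckets=32, max_distance=128
    )
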
class T5LayerSelfAttention(torch.nn.Module):
    def __init__(
        self,
        model_dim,
        inner_dim,
        ff_dim,
        num_heads,
        relative_attention_bias,
        dtype,
        device,
    ):
        super().__init__()
        self.SelfAttention = T5Attention(
            model_dim, inner_dim, num_heads, relative_attention_bias, dtype, device
        )
        self.layer_norm = T5LayerNorm(model_dim, dtype=dtype, device=device)

    def forward(self, x, past_bias=None):
        output, past_bias = self.SelfAttention(self.layer_norm(x), past_bias=past_bias)
        x += output
        return x, past_bias


class T5Block(torch.nn.Module):
    def __init__(
        self,
        model_dim,
        inner_dim,
        ff_dim,
        num_heads,
        relative_attention_bias,
        dtype,
        device,
    ):
        super().__init__()
        self.layer = torch.nn.ModuleList()
        self.layer.append(
            T5LayerSelfAttention(
                model_dim,
                inner_dim,
                ff_dim,
                num_heads,
                relative_attention_bias,
                dtype,
                device,
            )
        )
        self.layer.append(T5LayerFF(model_dim, ff_dim, dtype, device))

    def forward(self, x, past_bias=None):
        x, past_bias = self.layer[0](x, past_bias)
        x = self.layer[-1](x)
        return x, past_bias


class T5Stack(torch.nn.Module):
    def __init__(
        self,
        num_layers,
        model_dim,
        inner_dim,
        ff_dim,
        num_heads,
        vocab_size,
        dtype,
        device,
    ):
        super().__init__()
        self.embed_tokens = torch.nn.Embedding(vocab_size, model_dim, device=device)
        self.block = torch.nn.ModuleList(
            [
                T5Block(
                    model_dim,
                    inner_dim,
                    ff_dim,
                    num_heads,
                    relative_attention_bias=(i == 0),
                    dtype=dtype,
                    device=device,
                )
                for i in range(num_layers)
            ]
        )
        self.final_layer_norm = T5LayerNorm(model_dim, dtype=dtype, device=device)

    def forward(
        self, input_ids, intermediate_output=None, final_layer_norm_intermediate=True
    ):
        intermediate = None
        x = self.embed_tokens(input_ids)
        past_bias = None
        for i, l in enumerate(self.block):
            x, past_bias = l(x, past_bias)
            if i == intermediate_output:
                intermediate = x.clone()
        x = self.final_layer_norm(x)
        if intermediate is not None and final_layer_norm_intermediate:
            intermediate = self.final_layer_norm(intermediate)
        return x, intermediate


class T5(torch.nn.Module):
    def __init__(self, config_dict, dtype, device):
        super().__init__()
        self.num_layers = config_dict["num_layers"]
        self.encoder = T5Stack(
            self.num_layers,
            config_dict["d_model"],
            config_dict["d_model"],
            config_dict["d_ff"],
            config_dict["num_heads"],
            config_dict["vocab_size"],
            dtype,
            device,
        )
        self.dtype = dtype

    def get_input_embeddings(self):
        return self.encoder.embed_tokens

    def set_input_embeddings(self, embeddings):
        self.encoder.embed_tokens = embeddings

    def forward(self, *args, **kwargs):
        return self.encoder(*args, **kwargs)
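

# --- Illustrative end-to-end sketch (not part of the upstream reference): wires
# the pieces above together on CPU with randomly initialized weights, so it only
# exercises the plumbing, not real text encoding. The CLIP-L-like config values
# are an assumption; load real SD3 checkpoints for meaningful embeddings. The
# tokenizer download needs network access on first use.
def _demo_encode_prompt(prompt="a photo of a cat"):
    clip_l_config = {
        "num_hidden_layers": 12,
        "hidden_size": 768,
        "num_attention_heads": 12,
        "intermediate_size": 3072,
        "hidden_act": "quick_gelu",
    }
    with torch.no_grad():
        model = SDClipModel(
            layer="hidden",
            layer_idx=-2,
            device="cpu",
            dtype=torch.float32,
            layer_norm_hidden_state=False,
            return_projected_pooled=False,
            textmodel_json_config=clip_l_config,
        )
        tokenizer = SDTokenizer(
            tokenizer=CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
        )
        tokens = tokenizer.tokenize_with_weights(prompt)
        embeddings, pooled = model.encode_token_weights(tokens)
    return embeddings.shape, pooled.shape  # ~ ((1, 77, 768), (1, 768))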