Girinath11 committed
Commit ef28d88 · verified · 1 parent: 613f2bb

Create embeddings.py

Files changed (1): embeddings.py +277 -0
embeddings.py ADDED

import math
from typing import Optional, Tuple, List

import torch
import torch.nn as nn
import torch.nn.functional as F


class PositionalEncoding(nn.Module):
    """Fixed sinusoidal positional encoding added to token embeddings."""

    def __init__(self, d_model: int, max_seq_len: int = 5000, dropout: float = 0.1):
        super(PositionalEncoding, self).__init__()
        self.d_model = d_model
        self.dropout = nn.Dropout(dropout)
        # Precompute the sin/cos table once; stored as a (1, max_seq_len, d_model) buffer.
        pe = torch.zeros(max_seq_len, d_model)
        position = torch.arange(0, max_seq_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() *
                             -(math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        if d_model % 2 == 1:
            # Odd d_model: the cosine half has one fewer column than div_term.
            pe[:, 1::2] = torch.cos(position * div_term[:-1])
        else:
            pe[:, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x):
        batch_size, seq_len, d_model = x.size()
        x = x + self.pe[:, :seq_len, :d_model]
        return self.dropout(x)


class LearnedPositionalEmbedding(nn.Module):
    """Trainable absolute position embeddings (BERT/GPT-2 style)."""

    def __init__(self, max_seq_len: int, d_model: int, dropout: float = 0.1):
        super(LearnedPositionalEmbedding, self).__init__()
        self.max_seq_len = max_seq_len
        self.d_model = d_model
        self.pos_embedding = nn.Embedding(max_seq_len, d_model)
        self.dropout = nn.Dropout(dropout)
        nn.init.normal_(self.pos_embedding.weight, std=0.02)

    def forward(self, x):
        batch_size, seq_len, d_model = x.size()
        if seq_len > self.max_seq_len:
            raise ValueError(f"Sequence length {seq_len} exceeds maximum {self.max_seq_len}")
        positions = torch.arange(seq_len, device=x.device).unsqueeze(0).expand(batch_size, -1)
        pos_emb = self.pos_embedding(positions)
        x = x + pos_emb
        return self.dropout(x)


class RotaryPositionalEmbedding(nn.Module):
    """Rotary position embedding (RoPE): encodes positions by rotating query/key pairs."""

    def __init__(self, d_model: int, max_seq_len: int = 2048, base: float = 10000.0):
        super(RotaryPositionalEmbedding, self).__init__()
        self.d_model = d_model
        self.max_seq_len = max_seq_len
        self.base = base
        inv_freq = 1.0 / (base ** (torch.arange(0, d_model, 2).float() / d_model))
        self.register_buffer('inv_freq', inv_freq)
        self._seq_len_cached = 0
        self._cos_cached = None
        self._sin_cached = None

    def _update_cos_sin_cache(self, seq_len: int, device: torch.device, dtype: torch.dtype):
        # Grow the cos/sin cache lazily whenever a longer sequence is seen.
        if seq_len > self._seq_len_cached:
            self._seq_len_cached = seq_len
            t = torch.arange(seq_len, device=device, dtype=torch.float32)
            freqs = torch.outer(t, self.inv_freq)
            self._cos_cached = freqs.cos().to(dtype)
            self._sin_cached = freqs.sin().to(dtype)

    def forward(self, q: torch.Tensor, k: torch.Tensor, start_pos: int = 0) -> Tuple[torch.Tensor, torch.Tensor]:
        # q, k: (batch, seq_len, num_heads, head_dim); returns tensors of the same shape.
        batch_size, seq_len, num_heads, head_dim = q.shape
        self._update_cos_sin_cache(start_pos + seq_len, q.device, q.dtype)
        cos = self._cos_cached[start_pos:start_pos + seq_len, :head_dim // 2]
        sin = self._sin_cached[start_pos:start_pos + seq_len, :head_dim // 2]
        # Reshape for broadcasting over batch and heads: (1, seq_len, 1, head_dim // 2).
        cos = cos.view(1, seq_len, 1, -1)
        sin = sin.view(1, seq_len, 1, -1)
        q_rot = self._rotate_half(q, cos, sin)
        k_rot = self._rotate_half(k, cos, sin)
        return q_rot, k_rot

    def _rotate_half(self, x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
        # Rotate each (x1, x2) half-pair by the position-dependent angle.
        x1 = x[..., :x.shape[-1] // 2]
        x2 = x[..., x.shape[-1] // 2:]
        return torch.cat([x1 * cos - x2 * sin, x1 * sin + x2 * cos], dim=-1)
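

# Usage sketch: a minimal example of how the rotary embedding above would be
# applied to per-head query/key projections inside an attention block. The
# tensors and sizes here are hypothetical; only the call pattern is the point.
def _rope_usage_sketch() -> Tuple[torch.Size, torch.Size]:
    rope = RotaryPositionalEmbedding(d_model=64, max_seq_len=128)
    q = torch.randn(2, 16, 4, 64)  # (batch, seq_len, num_heads, head_dim)
    k = torch.randn(2, 16, 4, 64)
    q_rot, k_rot = rope(q, k, start_pos=0)  # shapes are unchanged by the rotation
    return q_rot.shape, k_rot.shape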


class TechEmbeddingLayer(nn.Module):
    """Token embedding plus a configurable positional encoding, layer norm, and dropout."""

    def __init__(self,
                 vocab_size: int,
                 d_model: int,
                 max_seq_len: int = 512,
                 dropout: float = 0.1,
                 padding_idx: int = 0,
                 pos_encoding: str = "learned",
                 layer_norm: bool = True):
        super(TechEmbeddingLayer, self).__init__()
        self.d_model = d_model
        self.vocab_size = vocab_size
        self.padding_idx = padding_idx
        self.token_embedding = nn.Embedding(vocab_size, d_model, padding_idx=padding_idx)
        self.pos_encoding_type = pos_encoding
        if pos_encoding == "sinusoidal":
            self.pos_encoding = PositionalEncoding(d_model, max_seq_len, dropout)
        elif pos_encoding == "learned":
            self.pos_encoding = LearnedPositionalEmbedding(max_seq_len, d_model, dropout)
        elif pos_encoding == "rope":
            # RoPE is applied to q/k inside attention, not added to the embeddings here.
            self.pos_encoding = RotaryPositionalEmbedding(d_model, max_seq_len)
        else:
            raise ValueError(f"Unknown positional encoding type: {pos_encoding}")
        self.layer_norm = nn.LayerNorm(d_model) if layer_norm else nn.Identity()
        self.dropout = nn.Dropout(dropout)
        self._init_weights()

    def _init_weights(self):
        nn.init.normal_(self.token_embedding.weight, mean=0.0, std=0.02)
        if self.padding_idx is not None:
            nn.init.constant_(self.token_embedding.weight[self.padding_idx], 0.0)

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        if (input_ids >= self.vocab_size).any():
            raise ValueError(f"Input IDs contain values >= vocab_size ({self.vocab_size})")
        embeddings = self.token_embedding(input_ids)
        if self.pos_encoding_type != "rope":
            embeddings = self.pos_encoding(embeddings)
        embeddings = self.layer_norm(embeddings)
        embeddings = self.dropout(embeddings)
        return embeddings

    def get_positional_encoding(self):
        # Expose the RoPE module so downstream attention layers can rotate q/k.
        return self.pos_encoding if self.pos_encoding_type == "rope" else None
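

# Usage sketch: wiring the embedding layer into a model when RoPE is selected.
# The embedding output feeds the transformer blocks as usual, while the module
# returned by get_positional_encoding() would be handed to the attention layers
# so they can rotate their q/k projections. All sizes here are hypothetical and
# only illustrate the intended call pattern.
def _tech_embedding_rope_sketch():
    layer = TechEmbeddingLayer(vocab_size=1000, d_model=256, max_seq_len=128,
                               pos_encoding="rope")
    ids = torch.randint(1, 1000, (2, 32))
    hidden = layer(ids)                     # (2, 32, 256); no positions added yet
    rope = layer.get_positional_encoding()  # RotaryPositionalEmbedding instance
    return hidden, rope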


class AdaptiveEmbedding(nn.Module):
    """Adaptive input embeddings: rarer vocabulary bands get smaller embedding dims."""

    def __init__(self,
                 vocab_size: int,
                 d_model: int,
                 cutoffs: Optional[List[int]] = None,
                 div_val: float = 4.0):
        super(AdaptiveEmbedding, self).__init__()
        # Avoid a mutable default argument; the default bands stay [2000, 10000].
        cutoffs = [2000, 10000] if cutoffs is None else list(cutoffs)
        self.vocab_size = vocab_size
        self.d_model = d_model
        self.cutoffs = [0] + cutoffs + [vocab_size]
        self.div_val = div_val
        self.embeddings = nn.ModuleList()
        self.projections = nn.ModuleList()
        for i in range(len(self.cutoffs) - 1):
            l_idx = self.cutoffs[i]
            r_idx = self.cutoffs[i + 1]
            # Each successive band is div_val times narrower than the previous one.
            d_emb = int(d_model / (div_val ** i))
            emb = nn.Embedding(r_idx - l_idx, d_emb)
            nn.init.normal_(emb.weight, mean=0.0, std=0.02)
            self.embeddings.append(emb)
            if d_emb != d_model:
                proj = nn.Linear(d_emb, d_model, bias=False)
                nn.init.normal_(proj.weight, mean=0.0, std=0.02)
                self.projections.append(proj)
            else:
                self.projections.append(nn.Identity())

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        if (input_ids >= self.vocab_size).any():
            raise ValueError(f"Input IDs contain values >= vocab_size ({self.vocab_size})")
        batch_size, seq_len = input_ids.shape
        embeddings = torch.zeros(batch_size, seq_len, self.d_model,
                                 device=input_ids.device, dtype=torch.float32)
        for i in range(len(self.cutoffs) - 1):
            l_idx = self.cutoffs[i]
            r_idx = self.cutoffs[i + 1]
            mask = (input_ids >= l_idx) & (input_ids < r_idx)
            if mask.any():
                indices = input_ids[mask] - l_idx
                indices = indices.clamp(max=r_idx - l_idx - 1)
                emb = self.embeddings[i](indices)
                emb = self.projections[i](emb)
                embeddings[mask] = emb
        return embeddings


def create_padding_mask(input_ids: torch.Tensor, padding_idx: int = 0) -> torch.Tensor:
    # True where the token is padding.
    return input_ids == padding_idx


def create_causal_mask(seq_len: int, device: torch.device) -> torch.Tensor:
    # True above the diagonal, i.e. for future positions that must not be attended to.
    return torch.triu(torch.ones(seq_len, seq_len, device=device), diagonal=1).bool()


def create_attention_mask(input_ids: torch.Tensor,
                          padding_idx: int = 0,
                          causal: bool = True) -> torch.Tensor:
    """Combine padding and (optionally) causal masks into a (batch, seq, seq) boolean mask."""
    batch_size, seq_len = input_ids.shape
    device = input_ids.device

    padding_mask = create_padding_mask(input_ids, padding_idx)
    padding_mask = padding_mask.unsqueeze(1).expand(batch_size, seq_len, seq_len)
    if causal:
        causal_mask = create_causal_mask(seq_len, device)
        causal_mask = causal_mask.unsqueeze(0).expand(batch_size, seq_len, seq_len)
        combined_mask = padding_mask | causal_mask
    else:
        combined_mask = padding_mask

    return combined_mask
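

# Usage sketch: a minimal example of how the boolean mask above might be applied
# to raw attention scores before the softmax. The `scores` tensor is hypothetical;
# True entries in the mask mark positions that must not receive attention.
def _attention_mask_sketch() -> torch.Tensor:
    ids = torch.tensor([[5, 8, 3, 0, 0]])                          # one sequence, two padding tokens
    mask = create_attention_mask(ids, padding_idx=0, causal=True)  # (1, 5, 5), True = blocked
    scores = torch.randn(1, 5, 5)                                  # hypothetical raw attention scores
    scores = scores.masked_fill(mask, float('-inf'))
    return F.softmax(scores, dim=-1)                               # blocked positions get zero weight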


class EmbeddingAnalyzer:
    """Small utility for inspecting token-embedding weights (similarity, statistics)."""

    def __init__(self, embedding_layer: nn.Module):
        self.embedding_layer = embedding_layer

    def get_similarity_matrix(self, tokens: Optional[List[int]] = None) -> torch.Tensor:
        if hasattr(self.embedding_layer, 'token_embedding'):
            embeddings = self.embedding_layer.token_embedding.weight
        elif hasattr(self.embedding_layer, 'embeddings'):
            # Adaptive embeddings: project every band back to d_model before comparing.
            weights = [emb.weight for emb in self.embedding_layer.embeddings]
            embeddings = []
            for i, w in enumerate(weights):
                proj = self.embedding_layer.projections[i]
                embeddings.append(proj(w))
            embeddings = torch.cat(embeddings, dim=0)
        else:
            embeddings = self.embedding_layer.weight
        if tokens is not None and len(tokens) > 0:
            embeddings = embeddings[tokens]
        normalized_embeddings = F.normalize(embeddings, p=2, dim=1)
        return torch.mm(normalized_embeddings, normalized_embeddings.t())

    def find_similar_tokens(self, token_id: int, top_k: int = 10) -> List[Tuple[int, float]]:
        # Cosine nearest neighbours of token_id, excluding the token itself.
        similarity_matrix = self.get_similarity_matrix()
        similarities = similarity_matrix[token_id]
        top_similarities, top_indices = torch.topk(similarities, top_k + 1)
        mask = top_indices != token_id
        top_similarities = top_similarities[mask][:top_k]
        top_indices = top_indices[mask][:top_k]
        return list(zip(top_indices.tolist(), top_similarities.tolist()))

    def analyze_embedding_distribution(self):
        if hasattr(self.embedding_layer, 'token_embedding'):
            weights = self.embedding_layer.token_embedding.weight
        elif hasattr(self.embedding_layer, 'embeddings'):
            # Bands of an adaptive embedding have different widths, so project each
            # band to d_model (as in get_similarity_matrix) before concatenating.
            weights = torch.cat(
                [proj(emb.weight) for emb, proj in
                 zip(self.embedding_layer.embeddings, self.embedding_layer.projections)],
                dim=0)
        else:
            weights = self.embedding_layer.weight
        stats = {
            'mean': weights.mean().item(),
            'std': weights.std().item(),
            'min': weights.min().item(),
            'max': weights.max().item(),
            'norm_mean': weights.norm(dim=1).mean().item(),
            'norm_std': weights.norm(dim=1).std().item()
        }
        return stats
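

# Usage sketch: nearest-neighbour lookup with the analyzer above. The vocabulary
# size and token id are arbitrary; with freshly initialised weights the returned
# neighbours are essentially random, so this only illustrates the call pattern.
def _similarity_sketch() -> List[Tuple[int, float]]:
    layer = TechEmbeddingLayer(vocab_size=1000, d_model=64, max_seq_len=128)
    analyzer = EmbeddingAnalyzer(layer)
    return analyzer.find_similar_tokens(token_id=42, top_k=5)  # [(token_id, cosine_sim), ...]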


def test_embeddings():
    print("Testing embedding layers...")
    vocab_size = 1000
    d_model = 512
    max_seq_len = 128
    batch_size = 4
    seq_len = 64
    input_ids = torch.randint(1, vocab_size, (batch_size, seq_len))
    embedding_types = [
        ("Learned Position", "learned"),
        ("Sinusoidal Position", "sinusoidal"),
        ("RoPE", "rope")
    ]
    for name, pos_type in embedding_types:
        print(f"\nTesting {name} Embedding:")
        embedding_layer = TechEmbeddingLayer(
            vocab_size=vocab_size,
            d_model=d_model,
            max_seq_len=max_seq_len,
            pos_encoding=pos_type
        )
        embeddings = embedding_layer(input_ids)
        print(f"Input shape: {input_ids.shape}")
        print(f"Output shape: {embeddings.shape}")
        print(f"Expected shape: ({batch_size}, {seq_len}, {d_model})")
        analyzer = EmbeddingAnalyzer(embedding_layer)
        stats = analyzer.analyze_embedding_distribution()
        print("Embedding statistics:")
        for key, value in stats.items():
            print(f"  {key}: {value:.4f}")
    print("\nTesting Adaptive Embeddings:")
    adaptive_emb = AdaptiveEmbedding(
        vocab_size=vocab_size,
        d_model=d_model,
        cutoffs=[200, 500],
        div_val=2.0
    )
    embeddings = adaptive_emb(input_ids)
    print(f"Adaptive embedding output shape: {embeddings.shape}")
    print("\nTesting masking functions:")
    input_ids_padded = input_ids.clone()
    input_ids_padded[:, -10:] = 0
    padding_mask = create_padding_mask(input_ids_padded, padding_idx=0)
    causal_mask = create_causal_mask(seq_len, input_ids.device)
    attention_mask = create_attention_mask(input_ids_padded, padding_idx=0, causal=True)
    print(f"Padding mask shape: {padding_mask.shape}")
    print(f"Causal mask shape: {causal_mask.shape}")
    print(f"Attention mask shape: {attention_mask.shape}")
    print(f"Padding positions: {padding_mask.sum().item()}")
    print(f"Causal mask positions: {causal_mask.sum().item()}")
    print(f"Combined mask positions: {attention_mask.sum().item()}")
    print("\nAll embedding tests completed successfully!")


if __name__ == "__main__":
    test_embeddings()