Girinath11 committed (verified)
Commit 5a3c70d
1 Parent(s): d6816d4

Delete embeddings.py

Files changed (1)
  1. embeddings.py +0 -480
embeddings.py DELETED
@@ -1,480 +0,0 @@
- import torch
- import torch.nn as nn
- import torch.nn.functional as F
- import math
- from typing import Optional, Tuple, List
-
- # Constants for default configuration
- DEFAULT_MAX_SEQ_LEN = 512
- DEFAULT_DROPOUT = 0.1
- DEFAULT_BASE = 10000.0
- DEFAULT_CUTOFFS = [2000, 10000]
- DEFAULT_DIV_VAL = 4.0
- DEFAULT_PADDING_IDX = 0
-
- class PositionalEncoding(nn.Module):
-     """Sinusoidal positional encoding for transformer models."""
-
-     def __init__(self, d_model: int, max_seq_len: int = DEFAULT_MAX_SEQ_LEN, dropout: float = DEFAULT_DROPOUT):
-         """
-         Initialize sinusoidal positional encoding.
-
-         Args:
-             d_model (int): Dimension of the model embeddings.
-             max_seq_len (int): Maximum sequence length for positional encodings.
-             dropout (float): Dropout rate for regularization.
-         """
-         super().__init__()
-         self.d_model = d_model
-         self.dropout = nn.Dropout(dropout)
-
-         pe = torch.zeros(max_seq_len, d_model)
-         position = torch.arange(0, max_seq_len, dtype=torch.float).unsqueeze(1)
-         div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(DEFAULT_BASE) / d_model))
-         pe[:, 0::2] = torch.sin(position * div_term)
-         # For odd d_model the cosine half has one column fewer than div_term.
-         pe[:, 1::2] = torch.cos(position * (div_term[:-1] if d_model % 2 == 1 else div_term))
-         self.register_buffer('pe', pe.unsqueeze(0))
-
-     def forward(self, x: torch.Tensor) -> torch.Tensor:
-         """
-         Apply positional encoding to input embeddings.
-
-         Args:
-             x (torch.Tensor): Input tensor of shape (batch_size, seq_len, d_model).
-
-         Returns:
-             torch.Tensor: Tensor with positional encodings applied.
-         """
-         batch_size, seq_len, d_model = x.size()
-         if d_model != self.d_model:
-             raise ValueError(f"Input dimension {d_model} does not match d_model {self.d_model}")
-         x = x + self.pe[:, :seq_len]
-         return self.dropout(x)
-
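PositionalEncoding fills even channels with sin(pos / 10000^(2i/d_model)) and odd channels with the matching cosine, then adds the table to the input. A minimal usage sketch (the sizes below are illustrative, not taken from this file); with dropout disabled and a zero input, the output is exactly the sinusoidal table:

    import torch

    enc = PositionalEncoding(d_model=16, max_seq_len=32, dropout=0.0)
    x = torch.zeros(2, 10, 16)                     # (batch_size, seq_len, d_model)
    out = enc(x)                                   # adds pe[:, :10] to x
    print(out.shape)                               # torch.Size([2, 10, 16])
    print(torch.allclose(out[0], enc.pe[0, :10]))  # True: zero input + no dropout leaves only the table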
- class LearnedPositionalEmbedding(nn.Module):
-     """Learned positional embeddings for transformer models."""
-
-     def __init__(self, max_seq_len: int, d_model: int, dropout: float = DEFAULT_DROPOUT):
-         """
-         Initialize learned positional embeddings.
-
-         Args:
-             max_seq_len (int): Maximum sequence length.
-             d_model (int): Dimension of the model embeddings.
-             dropout (float): Dropout rate for regularization.
-         """
-         super().__init__()
-         self.max_seq_len = max_seq_len
-         self.d_model = d_model
-         self.pos_embedding = nn.Embedding(max_seq_len, d_model)
-         self.dropout = nn.Dropout(dropout)
-         nn.init.normal_(self.pos_embedding.weight, std=0.02)
-
-     def forward(self, x: torch.Tensor) -> torch.Tensor:
-         """
-         Apply learned positional embeddings to input.
-
-         Args:
-             x (torch.Tensor): Input tensor of shape (batch_size, seq_len, d_model).
-
-         Returns:
-             torch.Tensor: Tensor with positional embeddings applied.
-         """
-         batch_size, seq_len, d_model = x.size()
-         if seq_len > self.max_seq_len:
-             raise ValueError(f"Sequence length {seq_len} exceeds maximum {self.max_seq_len}")
-         if d_model != self.d_model:
-             raise ValueError(f"Input dimension {d_model} does not match d_model {self.d_model}")
-         positions = torch.arange(seq_len, device=x.device).unsqueeze(0).expand(batch_size, -1)
-         pos_emb = self.pos_embedding(positions)
-         x = x + pos_emb
-         return self.dropout(x)
-
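In contrast to the fixed sinusoidal table, LearnedPositionalEmbedding trains one vector per position and rejects sequences longer than max_seq_len. A short sketch with assumed sizes:

    import torch

    emb = LearnedPositionalEmbedding(max_seq_len=32, d_model=16, dropout=0.0)
    out = emb(torch.zeros(2, 10, 16))    # adds pos_embedding(0..9) to every batch row
    print(out.shape)                     # torch.Size([2, 10, 16])
    try:
        emb(torch.zeros(2, 40, 16))      # 40 > max_seq_len = 32
    except ValueError as err:
        print(err)                       # "Sequence length 40 exceeds maximum 32"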
- class RotaryPositionalEmbedding(nn.Module):
-     """Rotary Positional Embedding (RoPE) for transformer models."""
-
-     def __init__(self, d_model: int, max_seq_len: int = 2048, base: float = DEFAULT_BASE):
-         """
-         Initialize rotary positional embeddings.
-
-         Args:
-             d_model (int): Dimension of the model embeddings.
-             max_seq_len (int): Maximum sequence length.
-             base (float): Base for frequency calculation.
-         """
-         super().__init__()
-         self.d_model = d_model
-         self.max_seq_len = max_seq_len
-         self.base = base
-         inv_freq = 1.0 / (base ** (torch.arange(0, d_model, 2).float() / d_model))
-         self.register_buffer('inv_freq', inv_freq)
-         self._seq_len_cached = 0
-         self._cos_cached = None
-         self._sin_cached = None
-
-     def _update_cos_sin_cache(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> None:
-         """Update cached cosine and sine values for RoPE."""
-         if seq_len > self._seq_len_cached:
-             self._seq_len_cached = seq_len
-             t = torch.arange(seq_len, device=device, dtype=torch.float32)
-             freqs = torch.outer(t, self.inv_freq)
-             self._cos_cached = freqs.cos().to(dtype)
-             self._sin_cached = freqs.sin().to(dtype)
-
-     def _rotate_half(self, x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
-         """Rotate channel pairs (i, i + d/2) by the given position-dependent angles."""
-         x1, x2 = x[..., :x.shape[-1] // 2], x[..., x.shape[-1] // 2:]
-         return torch.cat([x1 * cos - x2 * sin, x1 * sin + x2 * cos], dim=-1)
-
-     def forward(self, q: torch.Tensor, k: torch.Tensor, start_pos: int = 0) -> Tuple[torch.Tensor, torch.Tensor]:
-         """
-         Apply rotary positional embeddings to query and key tensors.
-
-         Args:
-             q (torch.Tensor): Query tensor of shape (batch_size, seq_len, num_heads, head_dim).
-             k (torch.Tensor): Key tensor of shape (batch_size, seq_len, num_heads, head_dim).
-             start_pos (int): Starting position for positional encoding.
-
-         Returns:
-             Tuple[torch.Tensor, torch.Tensor]: Rotated query and key tensors.
-         """
-         batch_size, seq_len, num_heads, head_dim = q.shape
-         self._update_cos_sin_cache(start_pos + seq_len, q.device, q.dtype)
-         # Shape (1, seq_len, 1, head_dim // 2) broadcasts over batch and heads.
-         cos = self._cos_cached[start_pos:start_pos + seq_len, :head_dim // 2].view(1, seq_len, 1, -1)
-         sin = self._sin_cached[start_pos:start_pos + seq_len, :head_dim // 2].view(1, seq_len, 1, -1)
-         q_rot = self._rotate_half(q, cos, sin)
-         k_rot = self._rotate_half(k, cos, sin)
-         return q_rot, k_rot
-
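RotaryPositionalEmbedding is applied to queries and keys inside attention rather than to the embedding output; each channel pair is rotated by a position-dependent angle, so shapes and per-vector norms are preserved, and start_pos supports incremental decoding. A sketch with assumed head sizes:

    import torch

    rope = RotaryPositionalEmbedding(d_model=64, max_seq_len=128)
    q = torch.randn(2, 10, 4, 16)        # (batch, seq_len, num_heads, head_dim)
    k = torch.randn(2, 10, 4, 16)
    q_rot, k_rot = rope(q, k)
    print(q_rot.shape == q.shape)        # True
    print(torch.allclose(q.norm(dim=-1), q_rot.norm(dim=-1), atol=1e-5))  # rotation preserves norms
    # Incremental decoding: rotating only the newest position with start_pos matches the full pass.
    q_step, _ = rope(q[:, 9:10], k[:, 9:10], start_pos=9)
    print(torch.allclose(q_step, q_rot[:, 9:10], atol=1e-6))              # True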
- class TechEmbeddingLayer(nn.Module):
-     """Comprehensive embedding layer with token and positional embeddings."""
-
-     def __init__(
-         self,
-         vocab_size: int,
-         d_model: int,
-         max_seq_len: int = DEFAULT_MAX_SEQ_LEN,
-         dropout: float = DEFAULT_DROPOUT,
-         padding_idx: int = DEFAULT_PADDING_IDX,
-         pos_encoding: str = "learned",
-         layer_norm: bool = True,
-     ):
-         """
-         Initialize the embedding layer.
-
-         Args:
-             vocab_size (int): Size of the vocabulary.
-             d_model (int): Dimension of the model embeddings.
-             max_seq_len (int): Maximum sequence length.
-             dropout (float): Dropout rate.
-             padding_idx (int): Index for padding token.
-             pos_encoding (str): Type of positional encoding ('sinusoidal', 'learned', 'rope').
-             layer_norm (bool): Whether to apply layer normalization.
-         """
-         super().__init__()
-         self.d_model = d_model
-         self.vocab_size = vocab_size
-         self.padding_idx = padding_idx
-         self.pos_encoding_type = pos_encoding.lower()
-
-         self.token_embedding = nn.Embedding(vocab_size, d_model, padding_idx=padding_idx)
-         if self.pos_encoding_type == "sinusoidal":
-             self.pos_encoding = PositionalEncoding(d_model, max_seq_len, dropout)
-         elif self.pos_encoding_type == "learned":
-             self.pos_encoding = LearnedPositionalEmbedding(max_seq_len, d_model, dropout)
-         elif self.pos_encoding_type == "rope":
-             self.pos_encoding = RotaryPositionalEmbedding(d_model, max_seq_len)
-         else:
-             raise ValueError(f"Unknown positional encoding type: {pos_encoding}")
-
-         self.layer_norm = nn.LayerNorm(d_model) if layer_norm else nn.Identity()
-         self.dropout = nn.Dropout(dropout)
-         self._init_weights()
-
-     def _init_weights(self) -> None:
-         """Initialize weights for token embeddings."""
-         nn.init.normal_(self.token_embedding.weight, mean=0.0, std=0.02)
-         if self.padding_idx is not None:
-             nn.init.constant_(self.token_embedding.weight[self.padding_idx], 0.0)
-
-     def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
-         """
-         Forward pass for embedding layer.
-
-         Args:
-             input_ids (torch.Tensor): Input tensor of shape (batch_size, seq_len).
-
-         Returns:
-             torch.Tensor: Embedded tensor of shape (batch_size, seq_len, d_model).
-         """
-         if (input_ids >= self.vocab_size).any():
-             raise ValueError(f"Input IDs contain values >= vocab_size ({self.vocab_size})")
-         embeddings = self.token_embedding(input_ids)
-         if self.pos_encoding_type != "rope":
-             embeddings = self.pos_encoding(embeddings)
-         embeddings = self.layer_norm(embeddings)
-         return self.dropout(embeddings)
-
-     def get_positional_encoding(self) -> Optional[nn.Module]:
-         """Return the positional encoding module if RoPE, else None."""
-         return self.pos_encoding if self.pos_encoding_type == "rope" else None
-
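TechEmbeddingLayer ties the pieces together: token lookup, one of the three positional schemes, optional LayerNorm, and dropout. With pos_encoding="rope" the forward pass skips position injection and get_positional_encoding() hands the RoPE module to the attention layers. A sketch with assumed sizes:

    import torch

    layer = TechEmbeddingLayer(vocab_size=100, d_model=32, max_seq_len=64, pos_encoding="rope")
    ids = torch.randint(1, 100, (2, 16))
    print(layer(ids).shape)              # torch.Size([2, 16, 32])
    rope = layer.get_positional_encoding()
    print(type(rope).__name__)           # RotaryPositionalEmbedding, to be applied to q/k inside attention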
- class AdaptiveEmbedding(nn.Module):
-     """Adaptive embedding layer with variable embedding dimensions."""
-
-     def __init__(
-         self,
-         vocab_size: int,
-         d_model: int,
-         cutoffs: List[int] = DEFAULT_CUTOFFS,
-         div_val: float = DEFAULT_DIV_VAL,
-     ):
-         """
-         Initialize adaptive embedding layer.
-
-         Args:
-             vocab_size (int): Size of the vocabulary.
-             d_model (int): Dimension of the model embeddings.
-             cutoffs (List[int]): Cutoff points for vocabulary splits.
-             div_val (float): Division factor for embedding dimensions.
-         """
-         super().__init__()
-         self.vocab_size = vocab_size
-         self.d_model = d_model
-         self.cutoffs = [0] + cutoffs + [vocab_size]
-         self.div_val = div_val
-
-         self.embeddings = nn.ModuleList()
-         self.projections = nn.ModuleList()
-
-         for i in range(len(self.cutoffs) - 1):
-             l_idx, r_idx = self.cutoffs[i], self.cutoffs[i + 1]
-             d_emb = int(d_model / (div_val ** i))
-             emb = nn.Embedding(r_idx - l_idx, d_emb)
-             nn.init.normal_(emb.weight, mean=0.0, std=0.02)
-             self.embeddings.append(emb)
-             self.projections.append(
-                 nn.Linear(d_emb, d_model, bias=False) if d_emb != d_model else nn.Identity()
-             )
-             if d_emb != d_model:
-                 nn.init.normal_(self.projections[-1].weight, mean=0.0, std=0.02)
-
-     def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
-         """
-         Forward pass for adaptive embedding.
-
-         Args:
-             input_ids (torch.Tensor): Input tensor of shape (batch_size, seq_len).
-
-         Returns:
-             torch.Tensor: Embedded tensor of shape (batch_size, seq_len, d_model).
-         """
-         if (input_ids >= self.vocab_size).any():
-             raise ValueError(f"Input IDs contain values >= vocab_size ({self.vocab_size})")
-         batch_size, seq_len = input_ids.shape
-         embeddings = torch.zeros(batch_size, seq_len, self.d_model, device=input_ids.device, dtype=torch.float32)
-
-         for i in range(len(self.cutoffs) - 1):
-             l_idx, r_idx = self.cutoffs[i], self.cutoffs[i + 1]
-             mask = (input_ids >= l_idx) & (input_ids < r_idx)
-             if mask.any():
-                 indices = (input_ids[mask] - l_idx).clamp(max=r_idx - l_idx - 1)
-                 emb = self.embeddings[i](indices)
-                 embeddings[mask] = self.projections[i](emb)
-         return embeddings
-
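AdaptiveEmbedding gives frequent tokens (low IDs) full-width embeddings, while each later bucket uses width d_model / div_val**i and a projection back up to d_model. A sketch with assumed cutoffs:

    import torch

    adaptive = AdaptiveEmbedding(vocab_size=1000, d_model=64, cutoffs=[200, 500], div_val=2.0)
    for i, emb in enumerate(adaptive.embeddings):
        print(i, tuple(emb.weight.shape))  # (200, 64), (300, 32), (500, 16)
    ids = torch.randint(0, 1000, (2, 8))
    print(adaptive(ids).shape)             # torch.Size([2, 8, 64])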
- def create_padding_mask(input_ids: torch.Tensor, padding_idx: int = DEFAULT_PADDING_IDX) -> torch.Tensor:
-     """
-     Create a padding mask for input IDs.
-
-     Args:
-         input_ids (torch.Tensor): Input tensor of shape (batch_size, seq_len).
-         padding_idx (int): Index for padding token.
-
-     Returns:
-         torch.Tensor: Padding mask of shape (batch_size, seq_len), True at padding positions.
-     """
-     return input_ids == padding_idx
-
- def create_causal_mask(seq_len: int, device: torch.device) -> torch.Tensor:
-     """
-     Create a causal mask for attention.
-
-     Args:
-         seq_len (int): Sequence length.
-         device (torch.device): Device for tensor allocation.
-
-     Returns:
-         torch.Tensor: Causal mask of shape (seq_len, seq_len), True above the diagonal.
-     """
-     return torch.triu(torch.ones(seq_len, seq_len, device=device), diagonal=1).bool()
-
- def create_attention_mask(input_ids: torch.Tensor, padding_idx: int = DEFAULT_PADDING_IDX, causal: bool = True) -> torch.Tensor:
-     """
-     Create an attention mask combining padding and causal masks.
-
-     Args:
-         input_ids (torch.Tensor): Input tensor of shape (batch_size, seq_len).
-         padding_idx (int): Index for padding token.
-         causal (bool): Whether to include causal masking.
-
-     Returns:
-         torch.Tensor: Attention mask of shape (batch_size, seq_len, seq_len), True where attention is blocked.
-     """
-     batch_size, seq_len = input_ids.shape
-     device = input_ids.device
-     padding_mask = create_padding_mask(input_ids, padding_idx).unsqueeze(1).expand(batch_size, seq_len, seq_len)
-     if causal:
-         causal_mask = create_causal_mask(seq_len, device).unsqueeze(0).expand(batch_size, seq_len, seq_len)
-         return padding_mask | causal_mask
-     return padding_mask
-
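The three helpers return boolean masks in which True marks a blocked position, the convention expected by masked_fill before a softmax. A sketch with a hypothetical padded batch:

    import torch

    ids = torch.tensor([[5, 6, 7, 0, 0]])            # last two tokens are padding (padding_idx=0)
    mask = create_attention_mask(ids, padding_idx=0, causal=True)
    print(mask.shape)                                # torch.Size([1, 5, 5])
    scores = torch.randn(1, 5, 5)                    # stand-in attention scores
    scores = scores.masked_fill(mask, float("-inf"))
    attn = torch.softmax(scores, dim=-1)             # padded and future keys receive zero weight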
- class EmbeddingAnalyzer:
-     """Analyzer for inspecting embedding layer properties."""
-
-     def __init__(self, embedding_layer: nn.Module):
-         """
-         Initialize the embedding analyzer.
-
-         Args:
-             embedding_layer (nn.Module): The embedding layer to analyze.
-         """
-         self.embedding_layer = embedding_layer
-
-     def get_similarity_matrix(self, tokens: Optional[List[int]] = None) -> torch.Tensor:
-         """
-         Compute the cosine similarity matrix for embeddings.
-
-         Args:
-             tokens (Optional[List[int]]): List of token IDs to compute similarities for.
-
-         Returns:
-             torch.Tensor: Cosine similarity matrix.
-         """
-         if hasattr(self.embedding_layer, 'token_embedding'):
-             embeddings = self.embedding_layer.token_embedding.weight
-         elif hasattr(self.embedding_layer, 'embeddings'):
-             embeddings = torch.cat(
-                 [self.embedding_layer.projections[i](emb.weight) for i, emb in enumerate(self.embedding_layer.embeddings)],
-                 dim=0
-             )
-         else:
-             embeddings = self.embedding_layer.weight
-
-         if tokens is not None and len(tokens) > 0:
-             embeddings = embeddings[tokens]
-         return torch.mm(F.normalize(embeddings, p=2, dim=1), F.normalize(embeddings, p=2, dim=1).t())
-
-     def find_similar_tokens(self, token_id: int, top_k: int = 10) -> List[Tuple[int, float]]:
-         """
-         Find the top-k most similar tokens to a given token ID.
-
-         Args:
-             token_id (int): Token ID to find similar tokens for.
-             top_k (int): Number of similar tokens to return.
-
-         Returns:
-             List[Tuple[int, float]]: List of (token_id, similarity_score) pairs.
-         """
-         similarity_matrix = self.get_similarity_matrix()
-         if token_id >= similarity_matrix.shape[0]:
-             raise ValueError(f"Token ID {token_id} is out of range")
-         similarities = similarity_matrix[token_id]
-         top_similarities, top_indices = torch.topk(similarities, top_k + 1)
-         mask = top_indices != token_id
-         return list(zip(top_indices[mask][:top_k].tolist(), top_similarities[mask][:top_k].tolist()))
-
-     def analyze_embedding_distribution(self) -> dict:
-         """
-         Analyze the statistical properties of the embedding weights.
-
-         Returns:
-             dict: Dictionary containing mean, std, min, max, norm_mean, and norm_std of embeddings.
-         """
-         if hasattr(self.embedding_layer, 'token_embedding'):
-             weights = self.embedding_layer.token_embedding.weight
-         elif hasattr(self.embedding_layer, 'embeddings'):
-             # Project adaptive buckets to a common width so their rows can be concatenated.
-             weights = torch.cat(
-                 [self.embedding_layer.projections[i](emb.weight) for i, emb in enumerate(self.embedding_layer.embeddings)],
-                 dim=0
-             )
-         else:
-             weights = self.embedding_layer.weight
-         return {
-             'mean': weights.mean().item(),
-             'std': weights.std().item(),
-             'min': weights.min().item(),
-             'max': weights.max().item(),
-             'norm_mean': weights.norm(dim=1).mean().item(),
-             'norm_std': weights.norm(dim=1).std().item(),
-         }
-
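EmbeddingAnalyzer works on any of the layers above (it looks for a token_embedding attribute, then adaptive embeddings, then a bare weight). A sketch against a freshly initialised layer, where the similarity scores are only random-init noise:

    import torch

    layer = TechEmbeddingLayer(vocab_size=100, d_model=32)
    analyzer = EmbeddingAnalyzer(layer)
    print(analyzer.analyze_embedding_distribution())          # mean/std/min/max plus row-norm stats
    print(analyzer.find_similar_tokens(token_id=5, top_k=3))  # three (token_id, cosine) pairs; values depend on the init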
- def test_embeddings() -> None:
-     """Test the embedding layers and related utilities."""
-     print("Starting embedding layer tests...")
-     vocab_size = 1000
-     d_model = 512
-     max_seq_len = 128
-     batch_size = 4
-     seq_len = 64
-
-     input_ids = torch.randint(1, vocab_size, (batch_size, seq_len))
-     embedding_types = [
-         ("Learned Position", "learned"),
-         ("Sinusoidal Position", "sinusoidal"),
-         ("RoPE", "rope"),
-     ]
-
-     for name, pos_type in embedding_types:
-         print(f"\nTesting {name} Embedding:")
-         embedding_layer = TechEmbeddingLayer(
-             vocab_size=vocab_size,
-             d_model=d_model,
-             max_seq_len=max_seq_len,
-             pos_encoding=pos_type,
-         )
-         embeddings = embedding_layer(input_ids)
-         assert embeddings.shape == (batch_size, seq_len, d_model), f"Unexpected shape for {name}: {embeddings.shape}"
-         print(f"Input shape: {input_ids.shape}")
-         print(f"Output shape: {embeddings.shape}")
-         print(f"Expected shape: ({batch_size}, {seq_len}, {d_model})")
-
-         analyzer = EmbeddingAnalyzer(embedding_layer)
-         stats = analyzer.analyze_embedding_distribution()
-         print("Embedding statistics:")
-         for key, value in stats.items():
-             print(f"  {key}: {value:.4f}")
-
-         # Test similarity for a sample token
-         similar_tokens = analyzer.find_similar_tokens(token_id=0, top_k=5)
-         print(f"Top 5 similar tokens to token 0: {similar_tokens}")
-
-     print("\nTesting Adaptive Embeddings:")
-     adaptive_emb = AdaptiveEmbedding(vocab_size=vocab_size, d_model=d_model, cutoffs=[200, 500], div_val=2.0)
-     embeddings = adaptive_emb(input_ids)
-     assert embeddings.shape == (batch_size, seq_len, d_model), f"Unexpected adaptive embedding shape: {embeddings.shape}"
-     print(f"Adaptive embedding output shape: {embeddings.shape}")
-
-     print("\nTesting masking functions:")
-     input_ids_padded = input_ids.clone()
-     input_ids_padded[:, -10:] = 0
-     padding_mask = create_padding_mask(input_ids_padded, padding_idx=0)
-     causal_mask = create_causal_mask(seq_len, input_ids.device)
-     attention_mask = create_attention_mask(input_ids_padded, padding_idx=0, causal=True)
-
-     assert padding_mask.shape == (batch_size, seq_len), f"Unexpected padding mask shape: {padding_mask.shape}"
-     assert causal_mask.shape == (seq_len, seq_len), f"Unexpected causal mask shape: {causal_mask.shape}"
-     assert attention_mask.shape == (batch_size, seq_len, seq_len), f"Unexpected attention mask shape: {attention_mask.shape}"
-     print(f"Padding mask shape: {padding_mask.shape}")
-     print(f"Causal mask shape: {causal_mask.shape}")
-     print(f"Attention mask shape: {attention_mask.shape}")
-     print(f"Padding positions: {padding_mask.sum().item()}")
-     print(f"Causal mask positions: {causal_mask.sum().item()}")
-     print(f"Combined mask positions: {attention_mask.sum().item()}")
-
-     print("\nAll embedding tests completed successfully!")
-
- if __name__ == "__main__":
-     test_embeddings()