Girinath11 committed · Commit e329b2c · verified · 1 parent: 02aea9b

Update embeddings.py

Files changed (1): embeddings.py (+343, -140)
embeddings.py CHANGED
Old version (removed lines are prefixed with -):

@@ -3,92 +3,186 @@ import torch.nn as nn
 import torch.nn.functional as F
 import math
 from typing import Optional, Tuple, List
 class PositionalEncoding(nn.Module):
-    def __init__(self, d_model: int, max_seq_len: int = 5000, dropout: float = 0.1):
-        super(PositionalEncoding, self).__init__()
         self.d_model = d_model
-        self.dropout = nn.Dropout(dropout)
         pe = torch.zeros(max_seq_len, d_model)
         position = torch.arange(0, max_seq_len, dtype=torch.float).unsqueeze(1)
-        div_term = torch.exp(torch.arange(0, d_model, 2).float() *
-                             -(math.log(10000.0) / d_model))
         pe[:, 0::2] = torch.sin(position * div_term)
-        if d_model % 2 == 1:
-            pe[:, 1::2] = torch.cos(position * div_term[:-1])
-        else:
-            pe[:, 1::2] = torch.cos(position * div_term)
-        self.register_buffer('pe', pe.unsqueeze(0))
-    def forward(self, x):
         batch_size, seq_len, d_model = x.size()
-        x = x + self.pe[:, :seq_len, :d_model]
         return self.dropout(x)
 class LearnedPositionalEmbedding(nn.Module):
-    def __init__(self, max_seq_len: int, d_model: int, dropout: float = 0.1):
-        super(LearnedPositionalEmbedding, self).__init__()
         self.max_seq_len = max_seq_len
-        self.d_model = d_model
         self.pos_embedding = nn.Embedding(max_seq_len, d_model)
         self.dropout = nn.Dropout(dropout)
-        nn.init.normal_(self.pos_embedding.weight, std=0.02)
-    def forward(self, x):
         batch_size, seq_len, d_model = x.size()
         if seq_len > self.max_seq_len:
-            raise ValueError(f"Sequence length {seq_len} exceeds maximum {self.max_seq_len}")
         positions = torch.arange(seq_len, device=x.device).unsqueeze(0).expand(batch_size, -1)
         pos_emb = self.pos_embedding(positions)
         x = x + pos_emb
         return self.dropout(x)
 class RotaryPositionalEmbedding(nn.Module):
-    def __init__(self, d_model: int, max_seq_len: int = 2048, base: float = 10000.0):
-        super(RotaryPositionalEmbedding, self).__init__()
         self.d_model = d_model
         self.max_seq_len = max_seq_len
-        self.base = base
         inv_freq = 1.0 / (base ** (torch.arange(0, d_model, 2).float() / d_model))
-        self.register_buffer('inv_freq', inv_freq)
         self._seq_len_cached = 0
         self._cos_cached = None
-        self._sin_cached = None
-    def _update_cos_sin_cache(self, seq_len: int, device: torch.device, dtype: torch.dtype):
         if seq_len > self._seq_len_cached:
             self._seq_len_cached = seq_len
             t = torch.arange(seq_len, device=device, dtype=torch.float32)
             freqs = torch.outer(t, self.inv_freq)
             self._cos_cached = freqs.cos().to(dtype)
-            self._sin_cached = freqs.sin().to(dtype)
     def forward(self, q: torch.Tensor, k: torch.Tensor, start_pos: int = 0) -> Tuple[torch.Tensor, torch.Tensor]:
         batch_size, seq_len, num_heads, head_dim = q.shape
-        self._update_cos_sin_cache(start_pos + seq_len, q.device, q.dtype)
-        cos = self._cos_cached[start_pos:start_pos + seq_len, :head_dim // 2]
-        sin = self._sin_cached[start_pos:start_pos + seq_len, :head_dim // 2]
-        cos = cos.view(1, seq_len, 1, -1)
-        sin = sin.view(1, seq_len, 1, -1)
         q = q.transpose(1, 2).reshape(batch_size * num_heads, seq_len, head_dim)
-        k = k.transpose(1, 2).reshape(batch_size * num_heads, seq_len, head_dim)
         q_rot = self._rotate_half(q, cos, sin)
-        k_rot = self._rotate_half(k, cos, sin)
         q_rot = q_rot.reshape(batch_size, num_heads, seq_len, head_dim).transpose(1, 2)
-        k_rot = k_rot.reshape(batch_size, num_heads, seq_len, head_dim).transpose(1, 2)
-        return q_rot, k_rot
-    def _rotate_half(self, x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
-        x1 = x[..., :x.shape[-1] // 2]
-        x2 = x[..., x.shape[-1] // 2:]
-        return torch.cat([x1 * cos - x2 * sin, x1 * sin + x2 * cos], dim=-1)
 class TechEmbeddingLayer(nn.Module):
-    def __init__(self,
-                 vocab_size: int,
-                 d_model: int,
-                 max_seq_len: int = 512,
-                 dropout: float = 0.1,
-                 padding_idx: int = 0,
-                 pos_encoding: str = "learned",
-                 layer_norm: bool = True):
-        super(TechEmbeddingLayer, self).__init__()
         self.d_model = d_model
         self.vocab_size = vocab_size
-        self.padding_idx = padding_idx
-        self.token_embedding = nn.Embedding(vocab_size, d_model, padding_idx=padding_idx)
-        self.pos_encoding_type = pos_encoding
         if pos_encoding == "sinusoidal":
             self.pos_encoding = PositionalEncoding(d_model, max_seq_len, dropout)
         elif pos_encoding == "learned":
@@ -96,182 +190,291 @@ class TechEmbeddingLayer(nn.Module):
         elif pos_encoding == "rope":
             self.pos_encoding = RotaryPositionalEmbedding(d_model, max_seq_len)
         else:
-            raise ValueError(f"Unknown positional encoding type: {pos_encoding}")
         self.layer_norm = nn.LayerNorm(d_model) if layer_norm else nn.Identity()
         self.dropout = nn.Dropout(dropout)
-        self._init_weights()
-    def _init_weights(self):
         nn.init.normal_(self.token_embedding.weight, mean=0.0, std=0.02)
         if self.padding_idx is not None:
-            nn.init.constant_(self.token_embedding.weight[self.padding_idx], 0.0)
     def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
         if (input_ids >= self.vocab_size).any():
-            raise ValueError(f"Input IDs contain values >= vocab_size ({self.vocab_size})")
         embeddings = self.token_embedding(input_ids)
         if self.pos_encoding_type != "rope":
-            embeddings = self.pos_encoding(embeddings)
         embeddings = self.layer_norm(embeddings)
-        embeddings = self.dropout(embeddings)
-        return embeddings
-    def get_positional_encoding(self):
         return self.pos_encoding if self.pos_encoding_type == "rope" else None
 class AdaptiveEmbedding(nn.Module):
-    def __init__(self,
-                 vocab_size: int,
-                 d_model: int,
-                 cutoffs: list = [2000, 10000],
-                 div_val: float = 4.0):
-        super(AdaptiveEmbedding, self).__init__()
         self.vocab_size = vocab_size
         self.d_model = d_model
         self.cutoffs = [0] + cutoffs + [vocab_size]
-        self.div_val = div_val
         self.embeddings = nn.ModuleList()
-        self.projections = nn.ModuleList()
         for i in range(len(self.cutoffs) - 1):
-            l_idx = self.cutoffs[i]
-            r_idx = self.cutoffs[i + 1]
-            d_emb = int(d_model / (div_val ** i))
             emb = nn.Embedding(r_idx - l_idx, d_emb)
             nn.init.normal_(emb.weight, mean=0.0, std=0.02)
-            self.embeddings.append(emb)
             if d_emb != d_model:
-                proj = nn.Linear(d_emb, d_model, bias=False)
-                nn.init.normal_(proj.weight, mean=0.0, std=0.02)
-                self.projections.append(proj)
-            else:
-                self.projections.append(nn.Identity())
     def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
         if (input_ids >= self.vocab_size).any():
-            raise ValueError(f"Input IDs contain values >= vocab_size ({self.vocab_size})")
         batch_size, seq_len = input_ids.shape
-        embeddings = torch.zeros(batch_size, seq_len, self.d_model,
-                                 device=input_ids.device, dtype=torch.float32)
         for i in range(len(self.cutoffs) - 1):
-            l_idx = self.cutoffs[i]
-            r_idx = self.cutoffs[i + 1]
             mask = (input_ids >= l_idx) & (input_ids < r_idx)
             if mask.any():
-                indices = input_ids[mask] - l_idx
-                indices = indices.clamp(max=r_idx - l_idx - 1)
                 emb = self.embeddings[i](indices)
-                emb = self.projections[i](emb)
-                embeddings[mask] = emb
         return embeddings
-def create_padding_mask(input_ids: torch.Tensor, padding_idx: int = 0) -> torch.Tensor:
     return input_ids == padding_idx
 def create_causal_mask(seq_len: int, device: torch.device) -> torch.Tensor:
     return torch.triu(torch.ones(seq_len, seq_len, device=device), diagonal=1).bool()
-def create_attention_mask(input_ids: torch.Tensor,
-                          padding_idx: int = 0,
-                          causal: bool = True) -> torch.Tensor:
     batch_size, seq_len = input_ids.shape
     device = input_ids.device
-
-    padding_mask = create_padding_mask(input_ids, padding_idx)
-    padding_mask = padding_mask.unsqueeze(1).expand(batch_size, seq_len, seq_len)
     if causal:
-        causal_mask = create_causal_mask(seq_len, device)
-        causal_mask = causal_mask.unsqueeze(0).expand(batch_size, seq_len, seq_len)
-        combined_mask = padding_mask | causal_mask
-    else:
-        combined_mask = padding_mask
-
-    return combined_mask
 class EmbeddingAnalyzer:
     def __init__(self, embedding_layer: nn.Module):
-        self.embedding_layer = embedding_layer
-    def get_similarity_matrix(self, tokens: List[int] = None) -> torch.Tensor:
         if hasattr(self.embedding_layer, 'token_embedding'):
             embeddings = self.embedding_layer.token_embedding.weight
         elif hasattr(self.embedding_layer, 'embeddings'):
-            weights = [emb.weight for emb in self.embedding_layer.embeddings]
-            embeddings = []
-            for i, w in enumerate(weights):
-                proj = self.embedding_layer.projections[i]
-                embeddings.append(proj(w))
-            embeddings = torch.cat(embeddings, dim=0)
         else:
-            embeddings = self.embedding_layer.weight
         if tokens is not None and len(tokens) > 0:
-            embeddings = embeddings[tokens]
-        normalized_embeddings = F.normalize(embeddings, p=2, dim=1)
-        return torch.mm(normalized_embeddings, normalized_embeddings.t())
     def find_similar_tokens(self, token_id: int, top_k: int = 10) -> List[Tuple[int, float]]:
         similarity_matrix = self.get_similarity_matrix()
         similarities = similarity_matrix[token_id]
         top_similarities, top_indices = torch.topk(similarities, top_k + 1)
         mask = top_indices != token_id
-        top_similarities = top_similarities[mask][:top_k]
-        top_indices = top_indices[mask][:top_k]
-        return list(zip(top_indices.tolist(), top_similarities.tolist()))
-    def analyze_embedding_distribution(self):
         if hasattr(self.embedding_layer, 'token_embedding'):
             weights = self.embedding_layer.token_embedding.weight
         elif hasattr(self.embedding_layer, 'embeddings'):
             weights = torch.cat([emb.weight for emb in self.embedding_layer.embeddings], dim=0)
         else:
-            weights = self.embedding_layer.weight
-        stats = {
             'mean': weights.mean().item(),
             'std': weights.std().item(),
             'min': weights.min().item(),
             'max': weights.max().item(),
             'norm_mean': weights.norm(dim=1).mean().item(),
-            'norm_std': weights.norm(dim=1).std().item()
         }
-        return stats
-def test_embeddings():
-    print("Testing embedding layers...")
     vocab_size = 1000
     d_model = 512
     max_seq_len = 128
     batch_size = 4
-    seq_len = 64
-    input_ids = torch.randint(1, vocab_size, (batch_size, seq_len))
     embedding_types = [
         ("Learned Position", "learned"),
         ("Sinusoidal Position", "sinusoidal"),
-        ("RoPE", "rope")
-    ]
     for name, pos_type in embedding_types:
         print(f"\nTesting {name} Embedding:")
         embedding_layer = TechEmbeddingLayer(
             vocab_size=vocab_size,
             d_model=d_model,
             max_seq_len=max_seq_len,
-            pos_encoding=pos_type
-        )
         embeddings = embedding_layer(input_ids)
         print(f"Input shape: {input_ids.shape}")
         print(f"Output shape: {embeddings.shape}")
-        print(f"Expected shape: ({batch_size}, {seq_len}, {d_model})")
         analyzer = EmbeddingAnalyzer(embedding_layer)
         stats = analyzer.analyze_embedding_distribution()
         print(f"Embedding statistics:")
         for key, value in stats.items():
-            print(f" {key}: {value:.4f}")
-    print(f"\nTesting Adaptive Embeddings:")
-    adaptive_emb = AdaptiveEmbedding(
-        vocab_size=vocab_size,
-        d_model=d_model,
-        cutoffs=[200, 500],
-        div_val=2.0
-    )
     embeddings = adaptive_emb(input_ids)
-    print(f"Adaptive embedding output shape: {embeddings.shape}")
-    print(f"\nTesting masking functions:")
     input_ids_padded = input_ids.clone()
     input_ids_padded[:, -10:] = 0
     padding_mask = create_padding_mask(input_ids_padded, padding_idx=0)
     causal_mask = create_causal_mask(seq_len, input_ids.device)
-    attention_mask = create_attention_mask(input_ids_padded, padding_idx=0, causal=True)
     print(f"Padding mask shape: {padding_mask.shape}")
     print(f"Causal mask shape: {causal_mask.shape}")
     print(f"Attention mask shape: {attention_mask.shape}")
     print(f"Padding positions: {padding_mask.sum().item()}")
     print(f"Causal mask positions: {causal_mask.sum().item()}")
-    print(f"Combined mask positions: {attention_mask.sum().item()}")
     print("\nAll embedding tests completed successfully!")
 if __name__ == "__main__":
     test_embeddings()
New version (added lines are prefixed with +):

 import torch.nn.functional as F
 import math
 from typing import Optional, Tuple, List
+
+# Constants for default configuration
+DEFAULT_MAX_SEQ_LEN = 512
+DEFAULT_DROPOUT = 0.1
+DEFAULT_BASE = 10000.0
+DEFAULT_CUTOFFS = [2000, 10000]
+DEFAULT_DIV_VAL = 4.0
+DEFAULT_PADDING_IDX = 0
+
 class PositionalEncoding(nn.Module):
+    """Sinusoidal positional encoding for transformer models."""
+
+    def __init__(self, d_model: int, max_seq_len: int = DEFAULT_MAX_SEQ_LEN, dropout: float = DEFAULT_DROPOUT):
+        """
+        Initialize sinusoidal positional encoding.
+
+        Args:
+            d_model (int): Dimension of the model embeddings.
+            max_seq_len (int): Maximum sequence length for positional encodings.
+            dropout (float): Dropout rate for regularization.
+        """
+        super().__init__()
         self.d_model = d_model
+        self.dropout = nn.Dropout(dropout)
+
         pe = torch.zeros(max_seq_len, d_model)
         position = torch.arange(0, max_seq_len, dtype=torch.float).unsqueeze(1)
+        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(DEFAULT_BASE) / d_model))
         pe[:, 0::2] = torch.sin(position * div_term)
+        pe[:, 1::2] = torch.cos(position * (div_term[:-1] if d_model % 2 == 1 else div_term))
+        self.register_buffer('pe', pe.unsqueeze(0))
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """
+        Apply positional encoding to input embeddings.
+
+        Args:
+            x (torch.Tensor): Input tensor of shape (batch_size, seq_len, d_model).
+
+        Returns:
+            torch.Tensor: Tensor with positional encodings applied.
+        """
         batch_size, seq_len, d_model = x.size()
+        if d_model != self.d_model:
+            raise ValueError(f"Input dimension {d_model} does not match d_model {self.d_model}")
+        x = x + self.pe[:, :seq_len]
         return self.dropout(x)
+
 class LearnedPositionalEmbedding(nn.Module):
+    """Learned positional embeddings for transformer models."""
+
+    def __init__(self, max_seq_len: int, d_model: int, dropout: float = DEFAULT_DROPOUT):
+        """
+        Initialize learned positional embeddings.
+
+        Args:
+            max_seq_len (int): Maximum sequence length.
+            d_model (int): Dimension of the model embeddings.
+            dropout (float): Dropout rate for regularization.
+        """
+        super().__init__()
         self.max_seq_len = max_seq_len
+        self.d_model = d_model
         self.pos_embedding = nn.Embedding(max_seq_len, d_model)
         self.dropout = nn.Dropout(dropout)
+        nn.init.normal_(self.pos_embedding.weight, std=0.02)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """
+        Apply learned positional embeddings to input.
+
+        Args:
+            x (torch.Tensor): Input tensor of shape (batch_size, seq_len, d_model).
+
+        Returns:
+            torch.Tensor: Tensor with positional embeddings applied.
+        """
         batch_size, seq_len, d_model = x.size()
         if seq_len > self.max_seq_len:
+            raise ValueError(f"Sequence length {seq_len} exceeds maximum {self.max_seq_len}")
+        if d_model != self.d_model:
+            raise ValueError(f"Input dimension {d_model} does not match d_model {self.d_model}")
         positions = torch.arange(seq_len, device=x.device).unsqueeze(0).expand(batch_size, -1)
         pos_emb = self.pos_embedding(positions)
         x = x + pos_emb
         return self.dropout(x)
+
 class RotaryPositionalEmbedding(nn.Module):
+    """Rotary Positional Embedding (RoPE) for transformer models."""
+
+    def __init__(self, d_model: int, max_seq_len: int = 2048, base: float = DEFAULT_BASE):
+        """
+        Initialize rotary positional embeddings.
+
+        Args:
+            d_model (int): Dimension of the model embeddings.
+            max_seq_len (int): Maximum sequence length.
+            base (float): Base for frequency calculation.
+        """
+        super().__init__()
         self.d_model = d_model
         self.max_seq_len = max_seq_len
+        self.base = base
         inv_freq = 1.0 / (base ** (torch.arange(0, d_model, 2).float() / d_model))
+        self.register_buffer('inv_freq', inv_freq)
         self._seq_len_cached = 0
         self._cos_cached = None
+        self._sin_cached = None
+
+    def _update_cos_sin_cache(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> None:
+        """Update cached cosine and sine values for RoPE."""
         if seq_len > self._seq_len_cached:
             self._seq_len_cached = seq_len
             t = torch.arange(seq_len, device=device, dtype=torch.float32)
             freqs = torch.outer(t, self.inv_freq)
             self._cos_cached = freqs.cos().to(dtype)
+            self._sin_cached = freqs.sin().to(dtype)
+
+    def _rotate_half(self, x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
+        """Apply rotary transformation to half of the tensor."""
+        x1, x2 = x[..., :x.shape[-1] // 2], x[..., x.shape[-1] // 2:]
+        return torch.cat([x1 * cos - x2 * sin, x1 * sin + x2 * cos], dim=-1)
+
     def forward(self, q: torch.Tensor, k: torch.Tensor, start_pos: int = 0) -> Tuple[torch.Tensor, torch.Tensor]:
+        """
+        Apply rotary positional embeddings to query and key tensors.
+
+        Args:
+            q (torch.Tensor): Query tensor of shape (batch_size, seq_len, num_heads, head_dim).
+            k (torch.Tensor): Key tensor of shape (batch_size, seq_len, num_heads, head_dim).
+            start_pos (int): Starting position for positional encoding.
+
+        Returns:
+            Tuple[torch.Tensor, torch.Tensor]: Rotated query and key tensors.
+        """
         batch_size, seq_len, num_heads, head_dim = q.shape
+        self._update_cos_sin_cache(start_pos + seq_len, q.device, q.dtype)
+        cos = self._cos_cached[start_pos:start_pos + seq_len, :head_dim // 2].reshape(1, seq_len, -1)
+        sin = self._sin_cached[start_pos:start_pos + seq_len, :head_dim // 2].reshape(1, seq_len, -1)
+
         q = q.transpose(1, 2).reshape(batch_size * num_heads, seq_len, head_dim)
+        k = k.transpose(1, 2).reshape(batch_size * num_heads, seq_len, head_dim)
         q_rot = self._rotate_half(q, cos, sin)
+        k_rot = self._rotate_half(k, cos, sin)
         q_rot = q_rot.reshape(batch_size, num_heads, seq_len, head_dim).transpose(1, 2)
+        k_rot = k_rot.reshape(batch_size, num_heads, seq_len, head_dim).transpose(1, 2)
+        return q_rot, k_rot
+
 class TechEmbeddingLayer(nn.Module):
+    """Comprehensive embedding layer with token and positional embeddings."""
+
+    def __init__(
+        self,
+        vocab_size: int,
+        d_model: int,
+        max_seq_len: int = DEFAULT_MAX_SEQ_LEN,
+        dropout: float = DEFAULT_DROPOUT,
+        padding_idx: int = DEFAULT_PADDING_IDX,
+        pos_encoding: str = "learned",
+        layer_norm: bool = True,
+    ):
+        """
+        Initialize the embedding layer.
+
+        Args:
+            vocab_size (int): Size of the vocabulary.
+            d_model (int): Dimension of the model embeddings.
+            max_seq_len (int): Maximum sequence length.
+            dropout (float): Dropout rate.
+            padding_idx (int): Index for padding token.
+            pos_encoding (str): Type of positional encoding ('sinusoidal', 'learned', 'rope').
+            layer_norm (bool): Whether to apply layer normalization.
+        """
+        super().__init__()
         self.d_model = d_model
         self.vocab_size = vocab_size
+        self.padding_idx = padding_idx
+        self.pos_encoding_type = pos_encoding.lower()
+
+        self.token_embedding = nn.Embedding(vocab_size, d_model, padding_idx=padding_idx)
         if pos_encoding == "sinusoidal":
             self.pos_encoding = PositionalEncoding(d_model, max_seq_len, dropout)
         elif pos_encoding == "learned":
         elif pos_encoding == "rope":
             self.pos_encoding = RotaryPositionalEmbedding(d_model, max_seq_len)
         else:
+            raise ValueError(f"Unknown positional encoding type: {pos_encoding}")
+
         self.layer_norm = nn.LayerNorm(d_model) if layer_norm else nn.Identity()
         self.dropout = nn.Dropout(dropout)
+        self._init_weights()
+
+    def _init_weights(self) -> None:
+        """Initialize weights for token embeddings."""
         nn.init.normal_(self.token_embedding.weight, mean=0.0, std=0.02)
         if self.padding_idx is not None:
+            nn.init.constant_(self.token_embedding.weight[self.padding_idx], 0.0)
+
     def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
+        """
+        Forward pass for embedding layer.
+
+        Args:
+            input_ids (torch.Tensor): Input tensor of shape (batch_size, seq_len).
+
+        Returns:
+            torch.Tensor: Embedded tensor of shape (batch_size, seq_len, d_model).
+        """
         if (input_ids >= self.vocab_size).any():
+            raise ValueError(f"Input IDs contain values >= vocab_size ({self.vocab_size})")
         embeddings = self.token_embedding(input_ids)
         if self.pos_encoding_type != "rope":
+            embeddings = self.pos_encoding(embeddings)
         embeddings = self.layer_norm(embeddings)
+        return self.dropout(embeddings)
+
+    def get_positional_encoding(self) -> Optional[nn.Module]:
+        """Return the positional encoding module if RoPE, else None."""
         return self.pos_encoding if self.pos_encoding_type == "rope" else None
+
 class AdaptiveEmbedding(nn.Module):
+    """Adaptive embedding layer with variable embedding dimensions."""
+
+    def __init__(
+        self,
+        vocab_size: int,
+        d_model: int,
+        cutoffs: List[int] = DEFAULT_CUTOFFS,
+        div_val: float = DEFAULT_DIV_VAL,
+    ):
+        """
+        Initialize adaptive embedding layer.
+
+        Args:
+            vocab_size (int): Size of the vocabulary.
+            d_model (int): Dimension of the model embeddings.
+            cutoffs (List[int]): Cutoff points for vocabulary splits.
+            div_val (float): Division factor for embedding dimensions.
+        """
+        super().__init__()
         self.vocab_size = vocab_size
         self.d_model = d_model
         self.cutoffs = [0] + cutoffs + [vocab_size]
+        self.div_val = div_val
+
         self.embeddings = nn.ModuleList()
+        self.projections = nn.ModuleList()
+
         for i in range(len(self.cutoffs) - 1):
+            l_idx, r_idx = self.cutoffs[i], self.cutoffs[i + 1]
+            d_emb = int(d_model / (div_val ** i))
             emb = nn.Embedding(r_idx - l_idx, d_emb)
             nn.init.normal_(emb.weight, mean=0.0, std=0.02)
+            self.embeddings.append(emb)
+            self.projections.append(
+                nn.Linear(d_emb, d_model, bias=False) if d_emb != d_model else nn.Identity()
+            )
             if d_emb != d_model:
+                nn.init.normal_(self.projections[-1].weight, mean=0.0, std=0.02)
+
     def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
+        """
+        Forward pass for adaptive embedding.
+
+        Args:
+            input_ids (torch.Tensor): Input tensor of shape (batch_size, seq_len).
+
+        Returns:
+            torch.Tensor: Embedded tensor of shape (batch_size, seq_len, d_model).
+        """
         if (input_ids >= self.vocab_size).any():
+            raise ValueError(f"Input IDs contain values >= vocab_size ({self.vocab_size})")
         batch_size, seq_len = input_ids.shape
+        embeddings = torch.zeros(batch_size, seq_len, self.d_model, device=input_ids.device, dtype=torch.float32)
+
         for i in range(len(self.cutoffs) - 1):
+            l_idx, r_idx = self.cutoffs[i], self.cutoffs[i + 1]
             mask = (input_ids >= l_idx) & (input_ids < r_idx)
             if mask.any():
+                indices = (input_ids[mask] - l_idx).clamp(max=r_idx - l_idx - 1)
                 emb = self.embeddings[i](indices)
+                embeddings[mask] = self.projections[i](emb)
         return embeddings
+
+def create_padding_mask(input_ids: torch.Tensor, padding_idx: int = DEFAULT_PADDING_IDX) -> torch.Tensor:
+    """
+    Create a padding mask for input IDs.
+
+    Args:
+        input_ids (torch.Tensor): Input tensor of shape (batch_size, seq_len).
+        padding_idx (int): Index for padding token.
+
+    Returns:
+        torch.Tensor: Padding mask of shape (batch_size, seq_len).
+    """
     return input_ids == padding_idx
+
 def create_causal_mask(seq_len: int, device: torch.device) -> torch.Tensor:
+    """
+    Create a causal mask for attention.
+
+    Args:
+        seq_len (int): Sequence length.
+        device (torch.device): Device for tensor allocation.
+
+    Returns:
+        torch.Tensor: Causal mask of shape (seq_len, seq_len).
+    """
     return torch.triu(torch.ones(seq_len, seq_len, device=device), diagonal=1).bool()
+
+def create_attention_mask(input_ids: torch.Tensor, padding_idx: int = DEFAULT_PADDING_IDX, causal: bool = True) -> torch.Tensor:
+    """
+    Create an attention mask combining padding and causal masks.
+
+    Args:
+        input_ids (torch.Tensor): Input tensor of shape (batch_size, seq_len).
+        padding_idx (int): Index for padding token.
+        causal (bool): Whether to include causal masking.
+
+    Returns:
+        torch.Tensor: Attention mask of shape (batch_size, seq_len, seq_len).
+    """
     batch_size, seq_len = input_ids.shape
     device = input_ids.device
+    padding_mask = create_padding_mask(input_ids, padding_idx).unsqueeze(1).expand(batch_size, seq_len, seq_len)
     if causal:
+        causal_mask = create_causal_mask(seq_len, device).unsqueeze(0).expand(batch_size, seq_len, seq_len)
+        return padding_mask | causal_mask
+    return padding_mask
+
 class EmbeddingAnalyzer:
+    """Analyzer for inspecting embedding layer properties."""
+
     def __init__(self, embedding_layer: nn.Module):
+        """
+        Initialize the embedding analyzer.
+
+        Args:
+            embedding_layer (nn.Module): The embedding layer to analyze.
+        """
+        self.embedding_layer = embedding_layer
+
+    def get_similarity_matrix(self, tokens: Optional[List[int]] = None) -> torch.Tensor:
+        """
+        Compute the cosine similarity matrix for embeddings.
+
+        Args:
+            tokens (Optional[List[int]]): List of token IDs to compute similarities for.
+
+        Returns:
+            torch.Tensor: Cosine similarity matrix.
+        """
         if hasattr(self.embedding_layer, 'token_embedding'):
             embeddings = self.embedding_layer.token_embedding.weight
         elif hasattr(self.embedding_layer, 'embeddings'):
+            embeddings = torch.cat(
+                [self.embedding_layer.projections[i](emb.weight) for i, emb in enumerate(self.embedding_layer.embeddings)],
+                dim=0
+            )
         else:
+            embeddings = self.embedding_layer.weight
+
         if tokens is not None and len(tokens) > 0:
+            embeddings = embeddings[tokens]
+        return torch.mm(F.normalize(embeddings, p=2, dim=1), F.normalize(embeddings, p=2, dim=1).t())
+
     def find_similar_tokens(self, token_id: int, top_k: int = 10) -> List[Tuple[int, float]]:
+        """
+        Find the top-k most similar tokens to a given token ID.
+
+        Args:
+            token_id (int): Token ID to find similar tokens for.
+            top_k (int): Number of similar tokens to return.
+
+        Returns:
+            List[Tuple[int, float]]: List of (token_id, similarity_score) pairs.
+        """
         similarity_matrix = self.get_similarity_matrix()
+        if token_id >= similarity_matrix.shape[0]:
+            raise ValueError(f"Token ID {token_id} is out of range")
         similarities = similarity_matrix[token_id]
         top_similarities, top_indices = torch.topk(similarities, top_k + 1)
         mask = top_indices != token_id
+        return list(zip(top_indices[mask][:top_k].tolist(), top_similarities[mask][:top_k].tolist()))
+
+    def analyze_embedding_distribution(self) -> dict:
+        """
+        Analyze the statistical properties of the embedding weights.
+
+        Returns:
+            dict: Dictionary containing mean, std, min, max, norm_mean, and norm_std of embeddings.
+        """
         if hasattr(self.embedding_layer, 'token_embedding'):
             weights = self.embedding_layer.token_embedding.weight
         elif hasattr(self.embedding_layer, 'embeddings'):
             weights = torch.cat([emb.weight for emb in self.embedding_layer.embeddings], dim=0)
         else:
+            weights = self.embedding_layer.weight
+        return {
             'mean': weights.mean().item(),
             'std': weights.std().item(),
             'min': weights.min().item(),
             'max': weights.max().item(),
             'norm_mean': weights.norm(dim=1).mean().item(),
+            'norm_std': weights.norm(dim=1).std().item(),
         }
+
+def test_embeddings() -> None:
+    """Test the embedding layers and related utilities."""
+    print("Starting embedding layer tests...")
     vocab_size = 1000
     d_model = 512
     max_seq_len = 128
     batch_size = 4
+    seq_len = 64
+
+    input_ids = torch.randint(1, vocab_size, (batch_size, seq_len))
     embedding_types = [
         ("Learned Position", "learned"),
         ("Sinusoidal Position", "sinusoidal"),
+        ("RoPE", "rope"),
+    ]
+
     for name, pos_type in embedding_types:
         print(f"\nTesting {name} Embedding:")
         embedding_layer = TechEmbeddingLayer(
             vocab_size=vocab_size,
             d_model=d_model,
             max_seq_len=max_seq_len,
+            pos_encoding=pos_type,
+        )
         embeddings = embedding_layer(input_ids)
+        assert embeddings.shape == (batch_size, seq_len, d_model), f"Unexpected shape for {name}: {embeddings.shape}"
         print(f"Input shape: {input_ids.shape}")
         print(f"Output shape: {embeddings.shape}")
+        print(f"Expected shape: ({batch_size}, {seq_len}, {d_model})")
+
         analyzer = EmbeddingAnalyzer(embedding_layer)
         stats = analyzer.analyze_embedding_distribution()
         print(f"Embedding statistics:")
         for key, value in stats.items():
+            print(f" {key}: {value:.4f}")
+
+        # Test similarity for a sample token
+        similar_tokens = analyzer.find_similar_tokens(token_id=0, top_k=5)
+        print(f"Top 5 similar tokens to token 0: {similar_tokens}")
+
+    print("\nTesting Adaptive Embeddings:")
+    adaptive_emb = AdaptiveEmbedding(vocab_size=vocab_size, d_model=d_model, cutoffs=[200, 500], div_val=2.0)
     embeddings = adaptive_emb(input_ids)
+    assert embeddings.shape == (batch_size, seq_len, d_model), f"Unexpected adaptive embedding shape: {embeddings.shape}"
+    print(f"Adaptive embedding output shape: {embeddings.shape}")
+
+    print("\nTesting masking functions:")
     input_ids_padded = input_ids.clone()
     input_ids_padded[:, -10:] = 0
     padding_mask = create_padding_mask(input_ids_padded, padding_idx=0)
     causal_mask = create_causal_mask(seq_len, input_ids.device)
+    attention_mask = create_attention_mask(input_ids_padded, padding_idx=0, causal=True)
+
+    assert padding_mask.shape == (batch_size, seq_len), f"Unexpected padding mask shape: {padding_mask.shape}"
+    assert causal_mask.shape == (seq_len, seq_len), f"Unexpected causal mask shape: {causal_mask.shape}"
+    assert attention_mask.shape == (batch_size, seq_len, seq_len), f"Unexpected attention mask shape: {attention_mask.shape}"
     print(f"Padding mask shape: {padding_mask.shape}")
     print(f"Causal mask shape: {causal_mask.shape}")
     print(f"Attention mask shape: {attention_mask.shape}")
     print(f"Padding positions: {padding_mask.sum().item()}")
     print(f"Causal mask positions: {causal_mask.sum().item()}")
+    print(f"Combined mask positions: {attention_mask.sum().item()}")
+
     print("\nAll embedding tests completed successfully!")
+
 if __name__ == "__main__":
     test_embeddings()
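
As a quick orientation for reviewers, here is a minimal usage sketch of the updated module. The module name (importing embeddings.py as embeddings) and the sizes are assumptions for illustration; TechEmbeddingLayer, create_attention_mask, and get_positional_encoding are defined in this commit as shown above.

# Minimal usage sketch (assumes embeddings.py is importable as `embeddings`; sizes are illustrative).
import torch
from embeddings import TechEmbeddingLayer, create_attention_mask

vocab_size, d_model, batch_size, seq_len = 1000, 512, 4, 64
input_ids = torch.randint(1, vocab_size, (batch_size, seq_len))
input_ids[:, -5:] = 0  # simulate right-padding with padding_idx 0

# Token + learned positional embeddings; layer norm and dropout are applied inside forward().
embedder = TechEmbeddingLayer(vocab_size=vocab_size, d_model=d_model,
                              max_seq_len=128, pos_encoding="learned")
hidden = embedder(input_ids)  # (batch_size, seq_len, d_model)

# Combined padding + causal mask; True marks positions that attention should ignore.
attn_mask = create_attention_mask(input_ids, padding_idx=0, causal=True)
print(hidden.shape, attn_mask.shape)  # torch.Size([4, 64, 512]) torch.Size([4, 64, 64])

# With pos_encoding="rope", positions are not added here; retrieve the RoPE module and
# apply it to the per-head query/key tensors inside the attention block instead.
rope = TechEmbeddingLayer(vocab_size, d_model, max_seq_len=128,
                          pos_encoding="rope").get_positional_encoding()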