Girinath11 committed on
Commit bea7101 · verified · 1 Parent(s): e329b2c

Update model_slm.py

Files changed (1):
  1. model_slm.py +433 -136
model_slm.py CHANGED
@@ -2,86 +2,211 @@ import torch
2
  import torch.nn as nn
3
  import torch.nn.functional as F
4
  import math
 
5
  from embeddings import TechEmbeddingLayer, create_padding_mask, create_causal_mask
6
  class MultiHeadAttention(nn.Module):
7
- """Multi-head attention mechanism optimized for technical content"""
8
- def __init__(self, d_model, n_heads, dropout=0.1):
9
- super(MultiHeadAttention, self).__init__()
10
- assert d_model % n_heads == 0
11
  self.d_model = d_model
12
  self.n_heads = n_heads
13
- self.d_k = d_model // n_heads
 
14
  self.w_q = nn.Linear(d_model, d_model, bias=False)
15
  self.w_k = nn.Linear(d_model, d_model, bias=False)
16
  self.w_v = nn.Linear(d_model, d_model, bias=False)
17
- self.w_o = nn.Linear(d_model, d_model)
18
  self.dropout = nn.Dropout(dropout)
19
- self._init_weights()
20
- def _init_weights(self):
21
- """Initialize weights with Xavier uniform"""
 
22
  for module in [self.w_q, self.w_k, self.w_v, self.w_o]:
23
- nn.init.xavier_uniform_(module.weight)
24
- def forward(self, query, key, value, mask=None, pos_encoding=None):
25
- batch_size, seq_len, d_model = query.size()
26
  Q = self.w_q(query).view(batch_size, seq_len, self.n_heads, self.d_k).transpose(1, 2)
27
  K = self.w_k(key).view(batch_size, seq_len, self.n_heads, self.d_k).transpose(1, 2)
28
- V = self.w_v(value).view(batch_size, seq_len, self.n_heads, self.d_k).transpose(1, 2)
 
29
  if pos_encoding is not None:
30
- Q, K = pos_encoding(Q, K)
31
- scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
32
  if mask is not None:
33
  mask = mask.unsqueeze(1).expand(batch_size, self.n_heads, seq_len, seq_len)
34
- scores.masked_fill_(mask, float('-inf'))
 
35
  attention_weights = F.softmax(scores, dim=-1)
36
  attention_weights = self.dropout(attention_weights)
37
  attended = torch.matmul(attention_weights, V)
38
- attended = attended.transpose(1, 2).contiguous().view(batch_size, seq_len, d_model)
39
- output = self.w_o(attended)
40
- return output
41
  class FeedForward(nn.Module):
42
- """Position-wise feed forward network with GELU activation"""
43
- def __init__(self, d_model, dim_feedforward, dropout=0.1):
44
- super(FeedForward, self).__init__()
45
  self.linear1 = nn.Linear(d_model, dim_feedforward)
46
  self.linear2 = nn.Linear(dim_feedforward, d_model)
47
- self.dropout = nn.Dropout(dropout)
 
48
  nn.init.xavier_uniform_(self.linear1.weight)
49
- nn.init.xavier_uniform_(self.linear2.weight)
50
- def forward(self, x):
51
  x = F.gelu(self.linear1(x))
52
  x = self.dropout(x)
53
- x = self.linear2(x)
54
- return x
55
  class RecursionRouter(nn.Module):
56
- """Router to decide recursion steps for different types of technical problems"""
57
- def __init__(self, d_model, max_steps=4, router_type="adaptive"):
58
- super(RecursionRouter, self).__init__()
59
  self.max_steps = max_steps
60
- self.router_type = router_type
61
- if router_type == "adaptive":
 
62
  self.complexity_classifier = nn.Sequential(
63
  nn.Linear(d_model, d_model // 4),
64
  nn.GELU(),
65
- nn.Dropout(0.1),
66
  nn.Linear(d_model // 4, max_steps + 1),
67
  nn.Softmax(dim=-1)
68
  )
69
- elif router_type == "fixed":
70
- self.fixed_steps = max_steps
71
- def forward(self, x):
72
  if self.router_type == "adaptive":
73
  seq_repr = x.mean(dim=1)
74
  step_probs = self.complexity_classifier(seq_repr)
75
- steps = torch.argmax(step_probs, dim=-1)
76
- return steps
77
- return self.fixed_steps
78
  class RecursiveTransformerLayer(nn.Module):
79
- """Transformer layer with recursive computation capability"""
80
- def __init__(self, d_model, n_heads, dim_feedforward, max_steps=4,
81
- dropout=0.1, router_type="adaptive"):
82
- super(RecursiveTransformerLayer, self).__init__()
83
  self.max_steps = max_steps
84
- self.d_model = d_model
 
85
  self.attention = MultiHeadAttention(d_model, n_heads, dropout)
86
  self.feedforward = FeedForward(d_model, dim_feedforward, dropout)
87
  self.norm1 = nn.LayerNorm(d_model)
@@ -90,35 +215,75 @@ class RecursiveTransformerLayer(nn.Module):
90
  self.router = RecursionRouter(d_model, max_steps, router_type)
91
  self.step_projections = nn.ModuleList([
92
  nn.Linear(d_model, d_model) for _ in range(max_steps)
93
- ])
94
- def forward(self, x, mask=None, pos_encoding=None):
95
  steps = self.router(x)
96
- if isinstance(steps, int):
97
- num_steps = min(steps, self.max_steps)
98
- return self._recursive_forward_fixed(x, mask, num_steps, pos_encoding)
99
- return self._recursive_forward_adaptive(x, mask, steps, pos_encoding)
100
- def _recursive_forward_fixed(self, x, mask, num_steps, pos_encoding):
101
  device = x.device
102
  batch_size = x.shape[0]
103
- computation_loss = torch.tensor(0.0, device=device)
104
- for step in range(num_steps):
 
105
  step_input = self.step_projections[step](x) if step < len(self.step_projections) else x
106
  attended = self.attention(step_input, step_input, step_input, mask, pos_encoding)
107
  x = self.norm1(x + self.dropout(attended))
108
  fed_forward = self.feedforward(x)
109
  x = self.norm2(x + self.dropout(fed_forward))
110
- computation_loss += torch.tensor(0.1, device=device) * batch_size
111
- return x, computation_loss
112
- def _recursive_forward_adaptive(self, x, mask, steps, pos_encoding):
113
  batch_size, seq_len, d_model = x.shape
114
  device = x.device
115
  max_batch_steps = int(steps.max().item())
116
- computation_loss = torch.tensor(0.0, device=device)
117
  active_batches = torch.ones(batch_size, device=device, dtype=torch.bool)
118
- for step in range(max_batch_steps):
 
119
  step_mask = (steps > step) & active_batches
120
  if not step_mask.any():
121
- break
 
122
  step_input = self.step_projections[step](x) if step < len(self.step_projections) else x
123
  attended = self.attention(step_input, step_input, step_input, mask, pos_encoding)
124
  attended = torch.where(step_mask.unsqueeze(-1).unsqueeze(-1), attended, torch.zeros_like(attended))
@@ -127,25 +292,57 @@ class RecursiveTransformerLayer(nn.Module):
127
  fed_forward = torch.where(step_mask.unsqueeze(-1).unsqueeze(-1), fed_forward, torch.zeros_like(fed_forward))
128
  x = self.norm2(x + self.dropout(fed_forward))
129
  computation_loss += torch.tensor(0.1, device=device) * step_mask.sum()
130
- active_batches &= (steps > step)
 
131
  return x, computation_loss
 
132
  class MixtureOfRecursions(nn.Module):
133
- """Main model with mixture of recursive transformer layers"""
134
- def __init__(self, vocab_size, d_model=512, n_layers=6, n_heads=8,
135
- max_steps=4, dim_feedforward=2048, dropout=0.1,
136
- max_seq_len=512, router_type="adaptive", padding_idx=0):
137
- super(MixtureOfRecursions, self).__init__()
138
  self.d_model = d_model
139
  self.vocab_size = vocab_size
140
- self.padding_idx = padding_idx
 
141
  self.embeddings = TechEmbeddingLayer(
142
  vocab_size=vocab_size,
143
  d_model=d_model,
144
  max_seq_len=max_seq_len,
145
  dropout=dropout,
146
  padding_idx=padding_idx,
147
- pos_encoding="learned"
148
- )
 
149
  self.layers = nn.ModuleList([
150
  RecursiveTransformerLayer(
151
  d_model=d_model,
@@ -155,70 +352,144 @@ class MixtureOfRecursions(nn.Module):
155
  dropout=dropout,
156
  router_type=router_type
157
  ) for _ in range(n_layers)
158
- ])
 
159
  self.final_norm = nn.LayerNorm(d_model)
160
  self.lm_head = nn.Linear(d_model, vocab_size, bias=False)
161
- self._init_weights()
162
- def _init_weights(self):
163
- nn.init.xavier_uniform_(self.lm_head.weight)
164
- def forward(self, input_ids, attention_mask=None):
165
- batch_size, seq_len = input_ids.shape
166
  padding_mask = create_padding_mask(input_ids, self.padding_idx) if attention_mask is None else (attention_mask == 0)
167
  causal_mask = create_causal_mask(seq_len, input_ids.device)
168
- padding_mask = padding_mask.unsqueeze(1).expand(batch_size, seq_len, seq_len)
169
- combined_mask = padding_mask | causal_mask.unsqueeze(0)
170
  x = self.embeddings(input_ids)
171
- pos_encoding = self.embeddings.get_positional_encoding()
172
- device = x.device
173
- total_computation_loss = torch.tensor(0.0, device=device)
174
  for layer in self.layers:
175
  x, comp_loss = layer(x, combined_mask, pos_encoding)
176
- total_computation_loss += comp_loss
 
177
  x = self.final_norm(x)
178
- logits = self.lm_head(x)
179
- return logits, total_computation_loss
180
- def generate_step(self, input_ids, temperature=1.0, top_k=None, top_p=None):
181
  self.eval()
182
  with torch.no_grad():
183
  logits, _ = self.forward(input_ids)
184
- last_logits = logits[:, -1, :] / temperature
 
185
  if top_k is not None:
186
  indices_to_remove = last_logits < torch.topk(last_logits, top_k)[0][..., -1, None]
187
- last_logits[indices_to_remove] = float('-inf')
 
188
  if top_p is not None:
189
  sorted_logits, sorted_indices = torch.sort(last_logits, descending=True)
190
  cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
191
  sorted_indices_to_remove = cumulative_probs > top_p
192
  sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
193
- sorted_indices_to_remove[..., 0] = 0
194
  indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
195
- last_logits[indices_to_remove] = float('-inf')
 
196
  probs = F.softmax(last_logits, dim=-1)
197
- next_token = torch.multinomial(probs, num_samples=1)
198
- return next_token
199
  class TextGenerator:
200
- """Text generation utility for the tech model"""
201
- def __init__(self, model, tokenizer, max_length=100, device=None):
202
  self.model = model
203
  self.tokenizer = tokenizer
204
  self.max_length = max_length
205
  self.device = device if device else next(model.parameters()).device
206
  self.model.to(self.device)
207
  self.eos_token_id = tokenizer.vocab.get('<|endoftext|>', -1)
208
- self.assistant_token_id = tokenizer.vocab.get('<|assistant|>', -1)
209
- def generate(self, prompt, method="nucleus", temperature=1.0, top_k=50, top_p=0.9, max_new_tokens=None):
210
- if max_new_tokens is None:
211
- max_new_tokens = self.max_length
212
  input_text = f"<|user|> {prompt}"
213
  input_ids = self.tokenizer.encode_ids(input_text, add_special_tokens=True)
214
- input_tensor = torch.tensor([input_ids], device=self.device)
 
215
  self.model.eval()
216
- generated_ids = []
 
217
  with torch.no_grad():
218
  for _ in range(max_new_tokens):
219
  if input_tensor.size(1) > self.max_length:
220
- input_tensor = input_tensor[:, -self.max_length:]
221
- # Generate next token
222
  if method == "greedy":
223
  next_token = self._greedy_generate(input_tensor)
224
  elif method == "sample":
@@ -228,75 +499,101 @@ class TextGenerator:
228
  elif method == "nucleus" or method == "top_p":
229
  next_token = self._nucleus_generate(input_tensor, temperature, top_p)
230
  else:
231
- raise ValueError(f"Unknown generation method: {method}")
 
232
  next_token_id = next_token.item()
233
  generated_ids.append(next_token_id)
234
  input_tensor = torch.cat([input_tensor, next_token], dim=1)  # next_token already has shape (1, 1)
 
235
  if next_token_id == self.eos_token_id or (self.assistant_token_id != -1 and next_token_id == self.assistant_token_id):
236
- break
237
- # Decode the full sequence
238
  full_ids = input_ids + generated_ids
239
- full_text = self.tokenizer.decode_ids(full_ids, skip_special_tokens=False)
240
- # Extract assistant response
241
  if "<|assistant|>" in full_text:
242
  response = full_text.split("<|assistant|>")[-1].split("<|endoftext|>")[0].strip()
243
  else:
244
- response = full_text.split("<|endoftext|>")[0].strip()
245
- return response if response else "No response generated."
246
- def _greedy_generate(self, input_tensor):
247
  logits, _ = self.model(input_tensor)
248
- return torch.argmax(logits[:, -1, :], dim=-1, keepdim=True)
249
- def _sample_generate(self, input_tensor, temperature):
250
  logits, _ = self.model(input_tensor)
251
  logits = logits[:, -1, :] / temperature
252
  probs = F.softmax(logits, dim=-1)
253
- return torch.multinomial(probs, num_samples=1)
254
- def _top_k_generate(self, input_tensor, temperature, top_k):
255
  logits, _ = self.model(input_tensor)
256
  logits = logits[:, -1, :] / temperature
257
  top_k_logits, top_k_indices = torch.topk(logits, top_k)
258
  probs = F.softmax(top_k_logits, dim=-1)
259
  next_token_idx = torch.multinomial(probs, num_samples=1)
260
- return top_k_indices.gather(-1, next_token_idx)
261
- def _nucleus_generate(self, input_tensor, temperature, top_p):
262
  return self.model.generate_step(input_tensor, temperature, top_p=top_p)
263
- def count_parameters(model):
264
  total_params = sum(p.numel() for p in model.parameters())
265
  trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
266
  return total_params, trainable_params
 
267
  def main():
268
- vocab_size = 10000
269
- d_model = 512
270
- n_layers = 6
271
- n_heads = 8
272
- seq_len = 128
273
- batch_size = 4
274
  print("Initializing MixtureOfRecursions model...")
275
  model = MixtureOfRecursions(
276
- vocab_size=vocab_size,
277
- d_model=d_model,
278
- n_layers=n_layers,
279
- n_heads=n_heads,
280
- max_steps=4,
281
- dim_feedforward=2048,
282
- dropout=0.1,
283
- router_type="adaptive"
284
- )
 
285
  total_params, trainable_params = count_parameters(model)
286
  print(f"Total parameters: {total_params:,}")
287
- print(f"Trainable parameters: {trainable_params:,}")
 
288
  print("\nTesting forward pass...")
289
- input_ids = torch.randint(0, vocab_size, (batch_size, seq_len))
 
290
  attention_mask = torch.ones_like(input_ids)
291
- attention_mask[:, -10:] = 0
292
  print(f"Input shape: {input_ids.shape}")
293
- logits, comp_loss = model(input_ids, attention_mask)
294
  print(f"Output logits shape: {logits.shape}")
295
- print(f"Computation loss: {comp_loss}")
296
- print(f"Expected logits shape: ({batch_size}, {seq_len}, {vocab_size})")
 
297
  print("\nTesting generation step...")
298
  next_token = model.generate_step(input_ids[:1], temperature=0.8, top_p=0.9)
299
- print(f"Generated next token: {next_token}")
 
300
  print("\nModel test completed successfully!")
 
301
  if __name__ == "__main__":
302
  main()
 
2
  import torch.nn as nn
3
  import torch.nn.functional as F
4
  import math
5
+ from typing import Optional, Tuple, Union
6
  from embeddings import TechEmbeddingLayer, create_padding_mask, create_causal_mask
7
+
8
+ # Constants for default configuration
9
+ DEFAULT_D_MODEL = 512
10
+ DEFAULT_N_HEADS = 8
11
+ DEFAULT_N_LAYERS = 6
12
+ DEFAULT_MAX_STEPS = 4
13
+ DEFAULT_DIM_FEEDFORWARD = 2048
14
+ DEFAULT_DROPOUT = 0.1
15
+ DEFAULT_MAX_SEQ_LEN = 512
16
+ DEFAULT_PADDING_IDX = 0
17
+ DEFAULT_ROUTER_TYPE = "adaptive"
18
+ DEFAULT_VOCAB_SIZE = 10000
19
+
20
  class MultiHeadAttention(nn.Module):
21
+ """Multi-head attention mechanism optimized for technical content."""
22
+
23
+ def __init__(self, d_model: int, n_heads: int, dropout: float = DEFAULT_DROPOUT):
24
+ """
25
+ Initialize multi-head attention.
26
+
27
+ Args:
28
+ d_model (int): Dimension of the model embeddings.
29
+ n_heads (int): Number of attention heads.
30
+ dropout (float): Dropout rate for regularization.
31
+
32
+ Raises:
33
+ ValueError: If d_model is not divisible by n_heads.
34
+ """
35
+ super().__init__()
36
+ if d_model % n_heads != 0:
37
+ raise ValueError(f"d_model ({d_model}) must be divisible by n_heads ({n_heads})")
38
+
39
  self.d_model = d_model
40
  self.n_heads = n_heads
41
+ self.d_k = d_model // n_heads
42
+
43
  self.w_q = nn.Linear(d_model, d_model, bias=False)
44
  self.w_k = nn.Linear(d_model, d_model, bias=False)
45
  self.w_v = nn.Linear(d_model, d_model, bias=False)
46
+ self.w_o = nn.Linear(d_model, d_model)
47
  self.dropout = nn.Dropout(dropout)
48
+ self._init_weights()
49
+
50
+ def _init_weights(self) -> None:
51
+ """Initialize weights with Xavier uniform initialization."""
52
  for module in [self.w_q, self.w_k, self.w_v, self.w_o]:
53
+ nn.init.xavier_uniform_(module.weight)
54
+ if hasattr(module, 'bias') and module.bias is not None:
55
+ nn.init.zeros_(module.bias)
56
+
57
+ def forward(
58
+ self,
59
+ query: torch.Tensor,
60
+ key: torch.Tensor,
61
+ value: torch.Tensor,
62
+ mask: Optional[torch.Tensor] = None,
63
+ pos_encoding: Optional[nn.Module] = None
64
+ ) -> torch.Tensor:
65
+ """
66
+ Forward pass for multi-head attention.
67
+
68
+ Args:
69
+ query (torch.Tensor): Query tensor of shape (batch_size, seq_len, d_model).
70
+ key (torch.Tensor): Key tensor of shape (batch_size, seq_len, d_model).
71
+ value (torch.Tensor): Value tensor of shape (batch_size, seq_len, d_model).
72
+ mask (Optional[torch.Tensor]): Attention mask of shape (batch_size, seq_len, seq_len).
73
+ pos_encoding (Optional[nn.Module]): Positional encoding module (e.g., RoPE).
74
+
75
+ Returns:
76
+ torch.Tensor: Output tensor of shape (batch_size, seq_len, d_model).
77
+ """
78
+ batch_size, seq_len, _ = query.size()
79
+
80
  Q = self.w_q(query).view(batch_size, seq_len, self.n_heads, self.d_k).transpose(1, 2)
81
  K = self.w_k(key).view(batch_size, seq_len, self.n_heads, self.d_k).transpose(1, 2)
82
+ V = self.w_v(value).view(batch_size, seq_len, self.n_heads, self.d_k).transpose(1, 2)
83
+
84
  if pos_encoding is not None:
85
+ Q, K = pos_encoding(Q, K)
86
+
87
+ scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
88
+
89
  if mask is not None:
90
  mask = mask.unsqueeze(1).expand(batch_size, self.n_heads, seq_len, seq_len)
91
+ scores = scores.masked_fill(mask, float('-inf'))
92
+
93
  attention_weights = F.softmax(scores, dim=-1)
94
  attention_weights = self.dropout(attention_weights)
95
  attended = torch.matmul(attention_weights, V)
96
+ attended = attended.transpose(1, 2).contiguous().view(batch_size, seq_len, self.d_model)
97
+ return self.w_o(attended)
98
+
99
  class FeedForward(nn.Module):
100
+ """Position-wise feed-forward network with GELU activation."""
101
+
102
+ def __init__(self, d_model: int, dim_feedforward: int, dropout: float = DEFAULT_DROPOUT):
103
+ """
104
+ Initialize feed-forward network.
105
+
106
+ Args:
107
+ d_model (int): Dimension of the model embeddings.
108
+ dim_feedforward (int): Dimension of the feed-forward layer.
109
+ dropout (float): Dropout rate for regularization.
110
+ """
111
+ super().__init__()
112
  self.linear1 = nn.Linear(d_model, dim_feedforward)
113
  self.linear2 = nn.Linear(dim_feedforward, d_model)
114
+ self.dropout = nn.Dropout(dropout)
115
+
116
  nn.init.xavier_uniform_(self.linear1.weight)
117
+ nn.init.zeros_(self.linear1.bias)
118
+ nn.init.xavier_uniform_(self.linear2.weight)
119
+ nn.init.zeros_(self.linear2.bias)
120
+
121
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
122
+ """
123
+ Forward pass for feed-forward network.
124
+
125
+ Args:
126
+ x (torch.Tensor): Input tensor of shape (batch_size, seq_len, d_model).
127
+
128
+ Returns:
129
+ torch.Tensor: Output tensor of shape (batch_size, seq_len, d_model).
130
+ """
131
  x = F.gelu(self.linear1(x))
132
  x = self.dropout(x)
133
+ return self.linear2(x)
134
+
135
  class RecursionRouter(nn.Module):
136
+ """Router to determine recursion steps for technical problem processing."""
137
+
138
+ def __init__(self, d_model: int, max_steps: int = DEFAULT_MAX_STEPS, router_type: str = DEFAULT_ROUTER_TYPE):
139
+ """
140
+ Initialize recursion router.
141
+
142
+ Args:
143
+ d_model (int): Dimension of the model embeddings.
144
+ max_steps (int): Maximum number of recursion steps.
145
+ router_type (str): Type of router ('adaptive' or 'fixed').
146
+
147
+ Raises:
148
+ ValueError: If router_type is invalid.
149
+ """
150
+ super().__init__()
151
  self.max_steps = max_steps
152
+ self.router_type = router_type.lower()
153
+
154
+ if self.router_type == "adaptive":
155
  self.complexity_classifier = nn.Sequential(
156
  nn.Linear(d_model, d_model // 4),
157
  nn.GELU(),
158
+ nn.Dropout(DEFAULT_DROPOUT),
159
  nn.Linear(d_model // 4, max_steps + 1),
160
  nn.Softmax(dim=-1)
161
  )
162
+ elif self.router_type == "fixed":
163
+ self.register_buffer('fixed_steps', torch.tensor(max_steps, dtype=torch.long))
164
+ else:
165
+ raise ValueError(f"Invalid router_type: {router_type}. Choose 'adaptive' or 'fixed'.")
166
+
167
+ def forward(self, x: torch.Tensor) -> Union[torch.Tensor, int]:
168
+ """
169
+ Determine the number of recursion steps.
170
+
171
+ Args:
172
+ x (torch.Tensor): Input tensor of shape (batch_size, seq_len, d_model).
173
+
174
+ Returns:
175
+ Union[torch.Tensor, int]: Number of steps (tensor for adaptive, int for fixed).
176
+ """
177
  if self.router_type == "adaptive":
178
  seq_repr = x.mean(dim=1)
179
  step_probs = self.complexity_classifier(seq_repr)
180
+ return torch.argmax(step_probs, dim=-1)
181
+ return self.fixed_steps.item()
182
+
183
  class RecursiveTransformerLayer(nn.Module):
184
+ """Transformer layer with recursive computation capability."""
185
+
186
+ def __init__(
187
+ self,
188
+ d_model: int,
189
+ n_heads: int,
190
+ dim_feedforward: int,
191
+ max_steps: int = DEFAULT_MAX_STEPS,
192
+ dropout: float = DEFAULT_DROPOUT,
193
+ router_type: str = DEFAULT_ROUTER_TYPE
194
+ ):
195
+ """
196
+ Initialize recursive transformer layer.
197
+
198
+ Args:
199
+ d_model (int): Dimension of the model embeddings.
200
+ n_heads (int): Number of attention heads.
201
+ dim_feedforward (int): Dimension of the feed-forward layer.
202
+ max_steps (int): Maximum number of recursion steps.
203
+ dropout (float): Dropout rate for regularization.
204
+ router_type (str): Type of router ('adaptive' or 'fixed').
205
+ """
206
+ super().__init__()
207
  self.max_steps = max_steps
208
+ self.d_model = d_model
209
+
210
  self.attention = MultiHeadAttention(d_model, n_heads, dropout)
211
  self.feedforward = FeedForward(d_model, dim_feedforward, dropout)
212
  self.norm1 = nn.LayerNorm(d_model)
 
215
  self.router = RecursionRouter(d_model, max_steps, router_type)
216
  self.step_projections = nn.ModuleList([
217
  nn.Linear(d_model, d_model) for _ in range(max_steps)
218
+ ])
219
+
220
+ for proj in self.step_projections:
221
+ nn.init.xavier_uniform_(proj.weight)
222
+ nn.init.zeros_(proj.bias)
223
+
224
+ def forward(
225
+ self,
226
+ x: torch.Tensor,
227
+ mask: Optional[torch.Tensor] = None,
228
+ pos_encoding: Optional[nn.Module] = None
229
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
230
+ """
231
+ Forward pass for recursive transformer layer.
232
+
233
+ Args:
234
+ x (torch.Tensor): Input tensor of shape (batch_size, seq_len, d_model).
235
+ mask (Optional[torch.Tensor]): Attention mask of shape (batch_size, seq_len, seq_len).
236
+ pos_encoding (Optional[nn.Module]): Positional encoding module (e.g., RoPE).
237
+
238
+ Returns:
239
+ Tuple[torch.Tensor, torch.Tensor]: Output tensor and computation loss.
240
+ """
241
  steps = self.router(x)
242
+ if not torch.is_tensor(steps):  # fixed router returns a plain int
243
+ return self._recursive_forward_fixed(x, mask, steps, pos_encoding)
244
+ return self._recursive_forward_adaptive(x, mask, steps, pos_encoding)
245
+
246
+ def _recursive_forward_fixed(
247
+ self,
248
+ x: torch.Tensor,
249
+ mask: Optional[torch.Tensor],
250
+ num_steps: int,
251
+ pos_encoding: Optional[nn.Module]
252
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
253
+ """Fixed recursion forward pass."""
254
  device = x.device
255
  batch_size = x.shape[0]
256
+ computation_loss = torch.tensor(0.0, device=device)
257
+
258
+ for step in range(min(num_steps, self.max_steps)):
259
  step_input = self.step_projections[step](x) if step < len(self.step_projections) else x
260
  attended = self.attention(step_input, step_input, step_input, mask, pos_encoding)
261
  x = self.norm1(x + self.dropout(attended))
262
  fed_forward = self.feedforward(x)
263
  x = self.norm2(x + self.dropout(fed_forward))
264
+ computation_loss += torch.tensor(0.1, device=device) * batch_size
265
+
266
+ return x, computation_loss
267
+
268
+ def _recursive_forward_adaptive(
269
+ self,
270
+ x: torch.Tensor,
271
+ mask: Optional[torch.Tensor],
272
+ steps: torch.Tensor,
273
+ pos_encoding: Optional[nn.Module]
274
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
275
+ """Adaptive recursion forward pass."""
276
  batch_size, seq_len, d_model = x.shape
277
  device = x.device
278
  max_batch_steps = int(steps.max().item())
279
+ computation_loss = torch.tensor(0.0, device=device)
280
  active_batches = torch.ones(batch_size, device=device, dtype=torch.bool)
281
+
282
+ for step in range(min(max_batch_steps, self.max_steps)):
283
  step_mask = (steps > step) & active_batches
284
  if not step_mask.any():
285
+ break
286
+
287
  step_input = self.step_projections[step](x) if step < len(self.step_projections) else x
288
  attended = self.attention(step_input, step_input, step_input, mask, pos_encoding)
289
  attended = torch.where(step_mask.unsqueeze(-1).unsqueeze(-1), attended, torch.zeros_like(attended))
 
292
  fed_forward = torch.where(step_mask.unsqueeze(-1).unsqueeze(-1), fed_forward, torch.zeros_like(fed_forward))
293
  x = self.norm2(x + self.dropout(fed_forward))
294
  computation_loss += torch.tensor(0.1, device=device) * step_mask.sum()
295
+ active_batches &= (steps > step)
296
+
297
  return x, computation_loss
298
+
299
  class MixtureOfRecursions(nn.Module):
300
+ """Transformer model with mixture of recursive layers for technical content."""
301
+
302
+ def __init__(
303
+ self,
304
+ vocab_size: int,
305
+ d_model: int = DEFAULT_D_MODEL,
306
+ n_layers: int = DEFAULT_N_LAYERS,
307
+ n_heads: int = DEFAULT_N_HEADS,
308
+ max_steps: int = DEFAULT_MAX_STEPS,
309
+ dim_feedforward: int = DEFAULT_DIM_FEEDFORWARD,
310
+ dropout: float = DEFAULT_DROPOUT,
311
+ max_seq_len: int = DEFAULT_MAX_SEQ_LEN,
312
+ router_type: str = DEFAULT_ROUTER_TYPE,
313
+ padding_idx: int = DEFAULT_PADDING_IDX,
314
+ pos_encoding: str = "learned"
315
+ ):
316
+ """
317
+ Initialize the Mixture of Recursions model.
318
+
319
+ Args:
320
+ vocab_size (int): Size of the vocabulary.
321
+ d_model (int): Dimension of the model embeddings.
322
+ n_layers (int): Number of transformer layers.
323
+ n_heads (int): Number of attention heads.
324
+ max_steps (int): Maximum number of recursion steps.
325
+ dim_feedforward (int): Dimension of the feed-forward layer.
326
+ dropout (float): Dropout rate for regularization.
327
+ max_seq_len (int): Maximum sequence length.
328
+ router_type (str): Type of router ('adaptive' or 'fixed').
329
+ padding_idx (int): Index for padding token.
330
+ pos_encoding (str): Type of positional encoding ('learned', 'sinusoidal', 'rope').
331
+ """
332
+ super().__init__()
333
  self.d_model = d_model
334
  self.vocab_size = vocab_size
335
+ self.padding_idx = padding_idx
336
+
337
  self.embeddings = TechEmbeddingLayer(
338
  vocab_size=vocab_size,
339
  d_model=d_model,
340
  max_seq_len=max_seq_len,
341
  dropout=dropout,
342
  padding_idx=padding_idx,
343
+ pos_encoding=pos_encoding
344
+ )
345
+
346
  self.layers = nn.ModuleList([
347
  RecursiveTransformerLayer(
348
  d_model=d_model,
 
352
  dropout=dropout,
353
  router_type=router_type
354
  ) for _ in range(n_layers)
355
+ ])
356
+
357
  self.final_norm = nn.LayerNorm(d_model)
358
  self.lm_head = nn.Linear(d_model, vocab_size, bias=False)
359
+ self._init_weights()
360
+
361
+ def _init_weights(self) -> None:
362
+ """Initialize weights for the language model head."""
363
+ nn.init.xavier_uniform_(self.lm_head.weight)
364
+
365
+ def forward(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor]:
366
+ """
367
+ Forward pass for the model.
368
+
369
+ Args:
370
+ input_ids (torch.Tensor): Input tensor of shape (batch_size, seq_len).
371
+ attention_mask (Optional[torch.Tensor]): Attention mask of shape (batch_size, seq_len).
372
+
373
+ Returns:
374
+ Tuple[torch.Tensor, torch.Tensor]: Logits and total computation loss.
375
+ """
376
+ batch_size, seq_len = input_ids.shape
377
  padding_mask = create_padding_mask(input_ids, self.padding_idx) if attention_mask is None else (attention_mask == 0)
378
  causal_mask = create_causal_mask(seq_len, input_ids.device)
379
+ combined_mask = padding_mask.unsqueeze(1).expand(batch_size, seq_len, seq_len) | causal_mask.unsqueeze(0)
380
+
381
  x = self.embeddings(input_ids)
382
+ pos_encoding = self.embeddings.get_positional_encoding()
383
+
384
+ total_computation_loss = torch.tensor(0.0, device=x.device)
385
  for layer in self.layers:
386
  x, comp_loss = layer(x, combined_mask, pos_encoding)
387
+ total_computation_loss += comp_loss
388
+
389
  x = self.final_norm(x)
390
+ logits = self.lm_head(x)
391
+ return logits, total_computation_loss
392
+
393
+ def generate_step(
394
+ self,
395
+ input_ids: torch.Tensor,
396
+ temperature: float = 1.0,
397
+ top_k: Optional[int] = None,
398
+ top_p: Optional[float] = None
399
+ ) -> torch.Tensor:
400
+ """
401
+ Generate the next token for a given input sequence.
402
+
403
+ Args:
404
+ input_ids (torch.Tensor): Input tensor of shape (batch_size, seq_len).
405
+ temperature (float): Temperature for softmax scaling.
406
+ top_k (Optional[int]): Number of top-k tokens to sample from.
407
+ top_p (Optional[float]): Cumulative probability for nucleus sampling.
408
+
409
+ Returns:
410
+ torch.Tensor: Next token IDs of shape (batch_size, 1).
411
+ """
412
  self.eval()
413
  with torch.no_grad():
414
  logits, _ = self.forward(input_ids)
415
+ last_logits = logits[:, -1, :] / temperature
416
+
417
  if top_k is not None:
418
  indices_to_remove = last_logits < torch.topk(last_logits, top_k)[0][..., -1, None]
419
+ last_logits = last_logits.masked_fill(indices_to_remove, float('-inf'))
420
+
421
  if top_p is not None:
422
  sorted_logits, sorted_indices = torch.sort(last_logits, descending=True)
423
  cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
424
  sorted_indices_to_remove = cumulative_probs > top_p
425
  sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
426
+ sorted_indices_to_remove[..., 0] = False
427
  indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
428
+ last_logits = last_logits.masked_fill(indices_to_remove, float('-inf'))
429
+
430
  probs = F.softmax(last_logits, dim=-1)
431
+ return torch.multinomial(probs, num_samples=1)
432
+
433
  class TextGenerator:
434
+ """Text generation utility for the MixtureOfRecursions model."""
435
+
436
+ def __init__(self, model: nn.Module, tokenizer: 'Tokenizer', max_length: int = DEFAULT_MAX_SEQ_LEN, device: Optional[torch.device] = None):
437
+ """
438
+ Initialize the text generator.
439
+
440
+ Args:
441
+ model (nn.Module): The transformer model.
442
+ tokenizer (Tokenizer): Tokenizer for encoding/decoding text.
443
+ max_length (int): Maximum sequence length for generation.
444
+ device (Optional[torch.device]): Device to run the model on.
445
+ """
446
  self.model = model
447
  self.tokenizer = tokenizer
448
  self.max_length = max_length
449
  self.device = device if device else next(model.parameters()).device
450
  self.model.to(self.device)
451
  self.eos_token_id = tokenizer.vocab.get('<|endoftext|>', -1)
452
+ self.assistant_token_id = tokenizer.vocab.get('<|assistant|>', -1)
453
+
454
+ def generate(
455
+ self,
456
+ prompt: str,
457
+ method: str = "nucleus",
458
+ temperature: float = 1.0,
459
+ top_k: Optional[int] = 50,
460
+ top_p: Optional[float] = 0.9,
461
+ max_new_tokens: Optional[int] = None
462
+ ) -> str:
463
+ """
464
+ Generate text based on a prompt.
465
+
466
+ Args:
467
+ prompt (str): Input prompt for generation.
468
+ method (str): Generation method ('greedy', 'sample', 'top_k', 'nucleus').
469
+ temperature (float): Temperature for softmax scaling.
470
+ top_k (Optional[int]): Number of top-k tokens to sample from.
471
+ top_p (Optional[float]): Cumulative probability for nucleus sampling.
472
+ max_new_tokens (Optional[int]): Maximum number of new tokens to generate.
473
+
474
+ Returns:
475
+ str: Generated text response.
476
+
477
+ Raises:
478
+ ValueError: If the generation method is invalid.
479
+ """
480
+ max_new_tokens = max_new_tokens or self.max_length
481
  input_text = f"<|user|> {prompt}"
482
  input_ids = self.tokenizer.encode_ids(input_text, add_special_tokens=True)
483
+ input_tensor = torch.tensor([input_ids], device=self.device)
484
+
485
  self.model.eval()
486
+ generated_ids = []
487
+
488
  with torch.no_grad():
489
  for _ in range(max_new_tokens):
490
  if input_tensor.size(1) > self.max_length:
491
+ input_tensor = input_tensor[:, -self.max_length:]
492
+
493
  if method == "greedy":
494
  next_token = self._greedy_generate(input_tensor)
495
  elif method == "sample":
 
499
  elif method == "nucleus" or method == "top_p":
500
  next_token = self._nucleus_generate(input_tensor, temperature, top_p)
501
  else:
502
+ raise ValueError(f"Unknown generation method: {method}")
503
+
504
  next_token_id = next_token.item()
505
  generated_ids.append(next_token_id)
506
  input_tensor = torch.cat([input_tensor, next_token], dim=1)  # next_token already has shape (1, 1)
507
+
508
  if next_token_id == self.eos_token_id or (self.assistant_token_id != -1 and next_token_id == self.assistant_token_id):
509
+ break
510
+
511
  full_ids = input_ids + generated_ids
512
+ full_text = self.tokenizer.decode_ids(full_ids, skip_special_tokens=False)
513
+
514
  if "<|assistant|>" in full_text:
515
  response = full_text.split("<|assistant|>")[-1].split("<|endoftext|>")[0].strip()
516
  else:
517
+ response = full_text.split("<|endoftext|>")[0].strip()
518
+
519
+ return response if response else "No response generated."
520
+
521
+ def _greedy_generate(self, input_tensor: torch.Tensor) -> torch.Tensor:
522
+ """Generate the next token using greedy decoding."""
523
  logits, _ = self.model(input_tensor)
524
+ return torch.argmax(logits[:, -1, :], dim=-1, keepdim=True)
525
+
526
+ def _sample_generate(self, input_tensor: torch.Tensor, temperature: float) -> torch.Tensor:
527
+ """Generate the next token using random sampling."""
528
  logits, _ = self.model(input_tensor)
529
  logits = logits[:, -1, :] / temperature
530
  probs = F.softmax(logits, dim=-1)
531
+ return torch.multinomial(probs, num_samples=1)
532
+
533
+ def _top_k_generate(self, input_tensor: torch.Tensor, temperature: float, top_k: int) -> torch.Tensor:
534
+ """Generate the next token using top-k sampling."""
535
  logits, _ = self.model(input_tensor)
536
  logits = logits[:, -1, :] / temperature
537
  top_k_logits, top_k_indices = torch.topk(logits, top_k)
538
  probs = F.softmax(top_k_logits, dim=-1)
539
  next_token_idx = torch.multinomial(probs, num_samples=1)
540
+ return top_k_indices.gather(-1, next_token_idx)
541
+
542
+ def _nucleus_generate(self, input_tensor: torch.Tensor, temperature: float, top_p: float) -> torch.Tensor:
543
+ """Generate the next token using nucleus (top-p) sampling."""
544
  return self.model.generate_step(input_tensor, temperature, top_p=top_p)
545
+
546
+ def count_parameters(model: nn.Module) -> Tuple[int, int]:
547
+ """
548
+ Count total and trainable parameters in the model.
549
+
550
+ Args:
551
+ model (nn.Module): The model to analyze.
552
+
553
+ Returns:
554
+ Tuple[int, int]: Total and trainable parameter counts.
555
+ """
556
  total_params = sum(p.numel() for p in model.parameters())
557
  trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
558
  return total_params, trainable_params
559
+
560
  def main():
561
+ """Test the MixtureOfRecursions model and its components."""
562
  print("Initializing MixtureOfRecursions model...")
563
  model = MixtureOfRecursions(
564
+ vocab_size=DEFAULT_VOCAB_SIZE,
565
+ d_model=DEFAULT_D_MODEL,
566
+ n_layers=DEFAULT_N_LAYERS,
567
+ n_heads=DEFAULT_N_HEADS,
568
+ max_steps=DEFAULT_MAX_STEPS,
569
+ dim_feedforward=DEFAULT_DIM_FEEDFORWARD,
570
+ dropout=DEFAULT_DROPOUT,
571
+ router_type=DEFAULT_ROUTER_TYPE
572
+ )
573
+
574
  total_params, trainable_params = count_parameters(model)
575
  print(f"Total parameters: {total_params:,}")
576
+ print(f"Trainable parameters: {trainable_params:,}")
577
+
578
  print("\nTesting forward pass...")
579
+ batch_size, seq_len = 4, 128
580
+ input_ids = torch.randint(0, DEFAULT_VOCAB_SIZE, (batch_size, seq_len))
581
  attention_mask = torch.ones_like(input_ids)
582
+ attention_mask[:, -10:] = 0
583
+
584
+ logits, comp_loss = model(input_ids, attention_mask)
585
+
586
+ assert logits.shape == (batch_size, seq_len, DEFAULT_VOCAB_SIZE), f"Unexpected logits shape: {logits.shape}"
587
  print(f"Input shape: {input_ids.shape}")
 
588
  print(f"Output logits shape: {logits.shape}")
589
+ print(f"Expected logits shape: ({batch_size}, {seq_len}, {DEFAULT_VOCAB_SIZE})")
590
+ print(f"Computation loss: {comp_loss:.4f}")
591
+
592
  print("\nTesting generation step...")
593
  next_token = model.generate_step(input_ids[:1], temperature=0.8, top_p=0.9)
594
+ print(f"Generated next token: {next_token.item()}")
595
+
596
  print("\nModel test completed successfully!")
597
+
598
  if __name__ == "__main__":
599
  main()
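
A minimal smoke-test sketch for the updated file, assuming a tokenizer with the interface TextGenerator expects (a vocab dict plus encode_ids and decode_ids). The ToyTokenizer below is a hypothetical whitespace stand-in, not the project's tokenizer, and with an untrained model the generated text is meaningless; it only exercises the generation path end to end.

# Smoke-test sketch for model_slm.py. ToyTokenizer is NOT the project's
# tokenizer; it is a hypothetical whitespace stub exposing only what
# TextGenerator touches: .vocab, .encode_ids(), .decode_ids().
from model_slm import MixtureOfRecursions, TextGenerator

class ToyTokenizer:
    def __init__(self, words):
        # index 0 doubles as the padding id (DEFAULT_PADDING_IDX = 0)
        specials = ['<pad>', '<|user|>', '<|assistant|>', '<|endoftext|>']
        self.vocab = {tok: i for i, tok in enumerate(specials + sorted(words))}
        self.inv_vocab = {i: tok for tok, i in self.vocab.items()}

    def encode_ids(self, text, add_special_tokens=True):
        # unknown words fall back to the padding id in this toy stub
        return [self.vocab.get(tok, 0) for tok in text.split()]

    def decode_ids(self, ids, skip_special_tokens=False):
        toks = [self.inv_vocab.get(i, '<pad>') for i in ids]
        if skip_special_tokens:
            toks = [t for t in toks if not t.startswith('<')]
        return ' '.join(toks)

if __name__ == "__main__":
    tokenizer = ToyTokenizer(words=['a', 'is', 'lock', 'mutex'])
    model = MixtureOfRecursions(vocab_size=len(tokenizer.vocab))
    generator = TextGenerator(model, tokenizer, max_length=64)
    reply = generator.generate("what is a mutex", method="nucleus",
                               temperature=0.8, top_p=0.9, max_new_tokens=8)
    print(reply)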