Girinath11 committed (verified)
Commit a17939e · 1 Parent(s): ef28d88

Create model_slm.py

Files changed (1)
  1. model_slm.py  +302 -0
model_slm.py ADDED
@@ -0,0 +1,302 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

from embeddings import TechEmbeddingLayer, create_padding_mask, create_causal_mask
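
# The layer below implements standard scaled dot-product attention,
#   Attention(Q, K, V) = softmax(Q K^T / sqrt(d_k)) V,
# split across n_heads heads. Boolean masks follow the convention that
# True marks positions to ignore: masked scores are set to -inf before
# the softmax.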


class MultiHeadAttention(nn.Module):
    """Multi-head attention mechanism optimized for technical content."""

    def __init__(self, d_model, n_heads, dropout=0.1):
        super(MultiHeadAttention, self).__init__()
        assert d_model % n_heads == 0, "d_model must be divisible by n_heads"
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_k = d_model // n_heads
        self.w_q = nn.Linear(d_model, d_model, bias=False)
        self.w_k = nn.Linear(d_model, d_model, bias=False)
        self.w_v = nn.Linear(d_model, d_model, bias=False)
        self.w_o = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout)
        self._init_weights()

    def _init_weights(self):
        """Initialize projection weights with Xavier uniform."""
        for module in [self.w_q, self.w_k, self.w_v, self.w_o]:
            nn.init.xavier_uniform_(module.weight)

    def forward(self, query, key, value, mask=None, pos_encoding=None):
        batch_size, seq_len, d_model = query.size()
        # Project and reshape to (batch, heads, seq_len, d_k)
        Q = self.w_q(query).view(batch_size, seq_len, self.n_heads, self.d_k).transpose(1, 2)
        K = self.w_k(key).view(batch_size, seq_len, self.n_heads, self.d_k).transpose(1, 2)
        V = self.w_v(value).view(batch_size, seq_len, self.n_heads, self.d_k).transpose(1, 2)
        if pos_encoding is not None:
            Q, K = pos_encoding(Q, K)
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
        if mask is not None:
            mask = mask.unsqueeze(1).expand(batch_size, self.n_heads, seq_len, seq_len)
            scores = scores.masked_fill(mask, float('-inf'))
        attention_weights = F.softmax(scores, dim=-1)
        attention_weights = self.dropout(attention_weights)
        attended = torch.matmul(attention_weights, V)
        # Merge heads back to (batch, seq_len, d_model)
        attended = attended.transpose(1, 2).contiguous().view(batch_size, seq_len, d_model)
        return self.w_o(attended)
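
# Quick shape check (illustrative values, not part of the original file):
#   mha = MultiHeadAttention(d_model=512, n_heads=8)
#   x = torch.randn(2, 16, 512)
#   out = mha(x, x, x)  # out.shape == torch.Size([2, 16, 512])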


class FeedForward(nn.Module):
    """Position-wise feed-forward network with GELU activation."""

    def __init__(self, d_model, dim_feedforward, dropout=0.1):
        super(FeedForward, self).__init__()
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.linear2 = nn.Linear(dim_feedforward, d_model)
        self.dropout = nn.Dropout(dropout)
        nn.init.xavier_uniform_(self.linear1.weight)
        nn.init.xavier_uniform_(self.linear2.weight)

    def forward(self, x):
        x = F.gelu(self.linear1(x))
        x = self.dropout(x)
        x = self.linear2(x)
        return x
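
# This is the standard position-wise transformer FFN
# (d_model -> dim_feedforward -> d_model), using GELU rather than ReLU.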


class RecursionRouter(nn.Module):
    """Router to decide recursion steps for different types of technical problems."""

    def __init__(self, d_model, max_steps=4, router_type="adaptive"):
        super(RecursionRouter, self).__init__()
        self.max_steps = max_steps
        self.router_type = router_type
        if router_type == "adaptive":
            self.complexity_classifier = nn.Sequential(
                nn.Linear(d_model, d_model // 4),
                nn.GELU(),
                nn.Dropout(0.1),
                nn.Linear(d_model // 4, max_steps + 1),
                nn.Softmax(dim=-1)
            )
        elif router_type == "fixed":
            self.fixed_steps = max_steps

    def forward(self, x):
        if self.router_type == "adaptive":
            # Mean-pool over the sequence, then classify into 0..max_steps steps
            seq_repr = x.mean(dim=1)
            step_probs = self.complexity_classifier(seq_repr)
            steps = torch.argmax(step_probs, dim=-1)
            return steps
        return self.fixed_steps
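
# Note: torch.argmax is non-differentiable, so the adaptive classifier gets
# no gradient from the discrete step choice; training the router end-to-end
# would require a soft weighting over steps or a straight-through estimator.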


class RecursiveTransformerLayer(nn.Module):
    """Transformer layer with recursive computation capability."""

    def __init__(self, d_model, n_heads, dim_feedforward, max_steps=4,
                 dropout=0.1, router_type="adaptive"):
        super(RecursiveTransformerLayer, self).__init__()
        self.max_steps = max_steps
        self.d_model = d_model
        self.attention = MultiHeadAttention(d_model, n_heads, dropout)
        self.feedforward = FeedForward(d_model, dim_feedforward, dropout)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)
        self.router = RecursionRouter(d_model, max_steps, router_type)
        # One input projection per recursion step
        self.step_projections = nn.ModuleList([
            nn.Linear(d_model, d_model) for _ in range(max_steps)
        ])

    def forward(self, x, mask=None, pos_encoding=None):
        steps = self.router(x)
        if isinstance(steps, int):
            num_steps = min(steps, self.max_steps)
            return self._recursive_forward_fixed(x, mask, num_steps, pos_encoding)
        return self._recursive_forward_adaptive(x, mask, steps, pos_encoding)

    def _recursive_forward_fixed(self, x, mask, num_steps, pos_encoding):
        device = x.device
        batch_size = x.shape[0]
        computation_loss = torch.tensor(0.0, device=device)
        for step in range(num_steps):
            step_input = self.step_projections[step](x) if step < len(self.step_projections) else x
            attended = self.attention(step_input, step_input, step_input, mask, pos_encoding)
            x = self.norm1(x + self.dropout(attended))
            fed_forward = self.feedforward(x)
            x = self.norm2(x + self.dropout(fed_forward))
            computation_loss += 0.1 * batch_size
        return x, computation_loss

    def _recursive_forward_adaptive(self, x, mask, steps, pos_encoding):
        batch_size, seq_len, d_model = x.shape
        device = x.device
        max_batch_steps = int(steps.max().item())
        computation_loss = torch.tensor(0.0, device=device)
        active_batches = torch.ones(batch_size, device=device, dtype=torch.bool)
        for step in range(max_batch_steps):
            step_mask = (steps > step) & active_batches
            if not step_mask.any():
                break
            step_input = self.step_projections[step](x) if step < len(self.step_projections) else x
            attended = self.attention(step_input, step_input, step_input, mask, pos_encoding)
            # Zero the update for sequences that have already finished routing
            attended = torch.where(step_mask.unsqueeze(-1).unsqueeze(-1), attended, torch.zeros_like(attended))
            x = self.norm1(x + self.dropout(attended))
            fed_forward = self.feedforward(x)
            fed_forward = torch.where(step_mask.unsqueeze(-1).unsqueeze(-1), fed_forward, torch.zeros_like(fed_forward))
            x = self.norm2(x + self.dropout(fed_forward))
            computation_loss += 0.1 * step_mask.sum()
            active_batches &= (steps > step)
        return x, computation_loss
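
# In the adaptive path, sequences whose routed step budget is exhausted have
# their attention/FFN updates zeroed, so they only pass through the residual
# connections and layer norms for the remaining steps.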


class MixtureOfRecursions(nn.Module):
    """Main model with a mixture of recursive transformer layers."""

    def __init__(self, vocab_size, d_model=512, n_layers=6, n_heads=8,
                 max_steps=4, dim_feedforward=2048, dropout=0.1,
                 max_seq_len=512, router_type="adaptive", padding_idx=0):
        super(MixtureOfRecursions, self).__init__()
        self.d_model = d_model
        self.vocab_size = vocab_size
        self.padding_idx = padding_idx
        self.embeddings = TechEmbeddingLayer(
            vocab_size=vocab_size,
            d_model=d_model,
            max_seq_len=max_seq_len,
            dropout=dropout,
            padding_idx=padding_idx,
            pos_encoding="learned"
        )
        self.layers = nn.ModuleList([
            RecursiveTransformerLayer(
                d_model=d_model,
                n_heads=n_heads,
                dim_feedforward=dim_feedforward,
                max_steps=max_steps,
                dropout=dropout,
                router_type=router_type
            ) for _ in range(n_layers)
        ])
        self.final_norm = nn.LayerNorm(d_model)
        self.lm_head = nn.Linear(d_model, vocab_size, bias=False)
        self._init_weights()

    def _init_weights(self):
        nn.init.xavier_uniform_(self.lm_head.weight)

    def forward(self, input_ids, attention_mask=None):
        batch_size, seq_len = input_ids.shape
        # Build a combined padding + causal mask (True = masked)
        padding_mask = create_padding_mask(input_ids, self.padding_idx) if attention_mask is None else (attention_mask == 0)
        causal_mask = create_causal_mask(seq_len, input_ids.device)
        padding_mask = padding_mask.unsqueeze(1).expand(batch_size, seq_len, seq_len)
        combined_mask = padding_mask | causal_mask.unsqueeze(0)
        x = self.embeddings(input_ids)
        pos_encoding = self.embeddings.get_positional_encoding()
        total_computation_loss = torch.tensor(0.0, device=x.device)
        for layer in self.layers:
            x, comp_loss = layer(x, combined_mask, pos_encoding)
            total_computation_loss += comp_loss
        x = self.final_norm(x)
        logits = self.lm_head(x)
        return logits, total_computation_loss

    def generate_step(self, input_ids, temperature=1.0, top_k=None, top_p=None):
        self.eval()
        with torch.no_grad():
            logits, _ = self.forward(input_ids)
            last_logits = logits[:, -1, :] / temperature
            if top_k is not None:
                # Keep only the top_k highest-scoring tokens
                indices_to_remove = last_logits < torch.topk(last_logits, top_k)[0][..., -1, None]
                last_logits[indices_to_remove] = float('-inf')
            if top_p is not None:
                # Nucleus filtering: keep the smallest set of tokens whose
                # cumulative probability exceeds top_p
                sorted_logits, sorted_indices = torch.sort(last_logits, descending=True)
                cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
                sorted_indices_to_remove = cumulative_probs > top_p
                # Shift right so the first token above the threshold is kept
                sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
                sorted_indices_to_remove[..., 0] = 0
                indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
                last_logits[indices_to_remove] = float('-inf')
            probs = F.softmax(last_logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)
        return next_token
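
# Minimal training-step sketch (an assumption, not code from this repo: the
# computation loss is presumably added to the LM loss with a small weight):
#
#   logits, comp_loss = model(input_ids, attention_mask)
#   lm_loss = F.cross_entropy(
#       logits[:, :-1].reshape(-1, model.vocab_size),
#       input_ids[:, 1:].reshape(-1),
#       ignore_index=model.padding_idx,
#   )
#   loss = lm_loss + 0.01 * comp_loss  # 0.01 is a hypothetical weight
#   loss.backward()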


class TextGenerator:
    """Text generation utility for the tech model."""

    def __init__(self, model, tokenizer, max_length=100, device=None):
        self.model = model
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.device = device if device else next(model.parameters()).device
        self.model.to(self.device)
        self.eos_token_id = tokenizer.vocab.get('<|endoftext|>', -1)
        self.assistant_token_id = tokenizer.vocab.get('<|assistant|>', -1)

    def generate(self, prompt, method="nucleus", temperature=1.0, top_k=50, top_p=0.9, max_new_tokens=None):
        if max_new_tokens is None:
            max_new_tokens = self.max_length
        input_text = f"<|user|> {prompt}"
        input_ids = self.tokenizer.encode_ids(input_text, add_special_tokens=True)
        input_tensor = torch.tensor([input_ids], device=self.device)
        self.model.eval()
        generated_ids = []
        with torch.no_grad():
            for _ in range(max_new_tokens):
                # Keep the context within the model's maximum length
                if input_tensor.size(1) > self.max_length:
                    input_tensor = input_tensor[:, -self.max_length:]
                # Generate the next token with the chosen decoding method
                if method == "greedy":
                    next_token = self._greedy_generate(input_tensor)
                elif method == "sample":
                    next_token = self._sample_generate(input_tensor, temperature)
                elif method == "top_k":
                    next_token = self._top_k_generate(input_tensor, temperature, top_k)
                elif method == "nucleus" or method == "top_p":
                    next_token = self._nucleus_generate(input_tensor, temperature, top_p)
                else:
                    raise ValueError(f"Unknown generation method: {method}")
                next_token_id = next_token.item()
                generated_ids.append(next_token_id)
                # next_token already has shape (1, 1), so it concatenates along
                # the sequence dimension directly
                input_tensor = torch.cat([input_tensor, next_token], dim=1)
                if next_token_id == self.eos_token_id or (self.assistant_token_id != -1 and next_token_id == self.assistant_token_id):
                    break
        # Decode the full sequence and extract the assistant response
        full_ids = input_ids + generated_ids
        full_text = self.tokenizer.decode_ids(full_ids, skip_special_tokens=False)
        if "<|assistant|>" in full_text:
            response = full_text.split("<|assistant|>")[-1].split("<|endoftext|>")[0].strip()
        else:
            response = full_text.split("<|endoftext|>")[0].strip()
        return response if response else "No response generated."

    def _greedy_generate(self, input_tensor):
        logits, _ = self.model(input_tensor)
        return torch.argmax(logits[:, -1, :], dim=-1, keepdim=True)

    def _sample_generate(self, input_tensor, temperature):
        logits, _ = self.model(input_tensor)
        logits = logits[:, -1, :] / temperature
        probs = F.softmax(logits, dim=-1)
        return torch.multinomial(probs, num_samples=1)

    def _top_k_generate(self, input_tensor, temperature, top_k):
        logits, _ = self.model(input_tensor)
        logits = logits[:, -1, :] / temperature
        top_k_logits, top_k_indices = torch.topk(logits, top_k)
        probs = F.softmax(top_k_logits, dim=-1)
        next_token_idx = torch.multinomial(probs, num_samples=1)
        return top_k_indices.gather(-1, next_token_idx)

    def _nucleus_generate(self, input_tensor, temperature, top_p):
        return self.model.generate_step(input_tensor, temperature, top_p=top_p)
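
# Illustrative usage (hypothetical: this file ships no tokenizer, so
# TechTokenizer and its load() method are assumptions, not project API):
#
#   tokenizer = TechTokenizer.load("tokenizer.json")
#   model = MixtureOfRecursions(vocab_size=len(tokenizer.vocab))
#   generator = TextGenerator(model, tokenizer, max_length=256)
#   print(generator.generate("How do I profile a slow SQL query?"))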


def count_parameters(model):
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    return total_params, trainable_params


def main():
    vocab_size = 10000
    d_model = 512
    n_layers = 6
    n_heads = 8
    seq_len = 128
    batch_size = 4

    print("Initializing MixtureOfRecursions model...")
    model = MixtureOfRecursions(
        vocab_size=vocab_size,
        d_model=d_model,
        n_layers=n_layers,
        n_heads=n_heads,
        max_steps=4,
        dim_feedforward=2048,
        dropout=0.1,
        router_type="adaptive"
    )
    total_params, trainable_params = count_parameters(model)
    print(f"Total parameters: {total_params:,}")
    print(f"Trainable parameters: {trainable_params:,}")

    print("\nTesting forward pass...")
    input_ids = torch.randint(0, vocab_size, (batch_size, seq_len))
    attention_mask = torch.ones_like(input_ids)
    attention_mask[:, -10:] = 0  # treat the last 10 positions as padding
    print(f"Input shape: {input_ids.shape}")
    logits, comp_loss = model(input_ids, attention_mask)
    print(f"Output logits shape: {logits.shape}")
    print(f"Computation loss: {comp_loss}")
    print(f"Expected logits shape: ({batch_size}, {seq_len}, {vocab_size})")

    print("\nTesting generation step...")
    next_token = model.generate_step(input_ids[:1], temperature=0.8, top_p=0.9)
    print(f"Generated next token: {next_token}")
    print("\nModel test completed successfully!")


if __name__ == "__main__":
    main()