Girinath11 committed · verified · Commit 9c59aeb · Parent(s): 031a3f7

Update model_slm.py

Files changed (1): model_slm.py (+99 -119)
model_slm.py CHANGED
@@ -54,7 +54,7 @@ class MixtureOfRecursionsConfig(PretrainedConfig):
         self.max_position_embeddings = max_position_embeddings or max_seq_len
 
 # ============================================================================
-# EMBEDDINGS MODULE (merged from embeddings.py)
+# EMBEDDINGS MODULE
 # ============================================================================
 
 DEFAULT_BASE = 10000.0
@@ -400,66 +400,113 @@ class RecursiveTransformerLayer(nn.Module):
             active_batches &= (steps > step)
         return x, computation_loss
 
-class MixtureOfRecursions(nn.Module):
+# ============================================================================
+# PRETRAINED MODEL WRAPPER
+# ============================================================================
+
+class MixtureOfRecursionsPreTrainedModel(PreTrainedModel):
+    """PreTrainedModel wrapper for MixtureOfRecursions."""
+
+    config_class = MixtureOfRecursionsConfig
+    base_model_prefix = "model"
+    supports_gradient_checkpointing = True
+
+    def _init_weights(self, module):
+        """Initialize weights."""
+        if isinstance(module, nn.Linear):
+            module.weight.data.normal_(mean=0.0, std=self.config.d_model ** -0.5)
+            if module.bias is not None:
+                module.bias.data.zero_()
+        elif isinstance(module, nn.Embedding):
+            module.weight.data.normal_(mean=0.0, std=self.config.d_model ** -0.5)
+            if module.padding_idx is not None:
+                module.weight.data[module.padding_idx].zero_()
+        elif isinstance(module, nn.LayerNorm):
+            module.bias.data.zero_()
+            module.weight.data.fill_(1.0)
+
+class MixtureOfRecursions(MixtureOfRecursionsPreTrainedModel):
     """Transformer model with mixture of recursive layers for technical content."""
 
-    def __init__(
-        self,
-        vocab_size: int,
-        d_model: int = DEFAULT_D_MODEL,
-        n_layers: int = DEFAULT_N_LAYERS,
-        n_heads: int = DEFAULT_N_HEADS,
-        max_steps: int = DEFAULT_MAX_STEPS,
-        dim_feedforward: int = DEFAULT_DIM_FEEDFORWARD,
-        dropout: float = DEFAULT_DROPOUT,
-        max_seq_len: int = DEFAULT_MAX_SEQ_LEN,
-        router_type: str = DEFAULT_ROUTER_TYPE,
-        padding_idx: int = DEFAULT_PADDING_IDX,
-        pos_encoding: str = "learned"
-    ):
-        super().__init__()
-        self.d_model = d_model
-        self.vocab_size = vocab_size
-        self.padding_idx = padding_idx
+    def __init__(self, config: MixtureOfRecursionsConfig):
+        super().__init__(config)
+        self.config = config
+        self.d_model = config.d_model
+        self.vocab_size = config.vocab_size
+        self.padding_idx = config.padding_idx
+
         self.embeddings = TechEmbeddingLayer(
-            vocab_size=vocab_size,
-            d_model=d_model,
-            max_seq_len=max_seq_len,
-            dropout=dropout,
-            padding_idx=padding_idx,
-            pos_encoding=pos_encoding
+            vocab_size=config.vocab_size,
+            d_model=config.d_model,
+            max_seq_len=config.max_seq_len,
+            dropout=config.dropout,
+            padding_idx=config.padding_idx,
+            pos_encoding=config.pos_encoding
         )
+
         self.layers = nn.ModuleList([
             RecursiveTransformerLayer(
-                d_model=d_model,
-                n_heads=n_heads,
-                dim_feedforward=dim_feedforward,
-                max_steps=max_steps,
-                dropout=dropout,
-                router_type=router_type
-            ) for _ in range(n_layers)
+                d_model=config.d_model,
+                n_heads=config.n_heads,
+                dim_feedforward=config.dim_feedforward,
+                max_steps=config.max_steps,
+                dropout=config.dropout,
+                router_type=config.router_type
+            ) for _ in range(config.n_layers)
         ])
-        self.final_norm = nn.LayerNorm(d_model)
-        self.lm_head = nn.Linear(d_model, vocab_size, bias=False)
-        self._init_weights()
-
-    def _init_weights(self) -> None:
-        nn.init.xavier_uniform_(self.lm_head.weight)
+
+        self.final_norm = nn.LayerNorm(config.d_model)
+        self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)
+
+        # Initialize weights
+        self.post_init()
 
-    def forward(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor]:
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        labels: Optional[torch.Tensor] = None,
+        return_dict: bool = True
+    ):
         batch_size, seq_len = input_ids.shape
+
+        # Create masks
         padding_mask = create_padding_mask(input_ids, self.padding_idx) if attention_mask is None else (attention_mask == 0)
         causal_mask = create_causal_mask(seq_len, input_ids.device)
         combined_mask = padding_mask.unsqueeze(1).expand(batch_size, seq_len, seq_len) | causal_mask.unsqueeze(0)
+
+        # Forward pass
         x = self.embeddings(input_ids)
         pos_encoding = self.embeddings.get_positional_encoding()
+
         total_computation_loss = torch.tensor(0.0, device=x.device)
         for layer in self.layers:
             x, comp_loss = layer(x, combined_mask, pos_encoding)
             total_computation_loss += comp_loss
+
         x = self.final_norm(x)
         logits = self.lm_head(x)
-        return logits, total_computation_loss
+
+        loss = None
+        if labels is not None:
+            # Shift logits and labels for language modeling
+            shift_logits = logits[..., :-1, :].contiguous()
+            shift_labels = labels[..., 1:].contiguous()
+            loss_fct = nn.CrossEntropyLoss()
+            loss = loss_fct(shift_logits.view(-1, self.vocab_size), shift_labels.view(-1))
+            loss += 0.01 * total_computation_loss  # Add computation loss
+
+        if not return_dict:
+            output = (logits,)
+            return ((loss,) + output) if loss is not None else output
+
+        from transformers.modeling_outputs import CausalLMOutput
+        return CausalLMOutput(
+            loss=loss,
+            logits=logits,
+            hidden_states=None,
+            attentions=None,
+        )
 
     def generate_step(
         self,
@@ -470,11 +517,14 @@ class MixtureOfRecursions(nn.Module):
     ) -> torch.Tensor:
         self.eval()
         with torch.no_grad():
-            logits, _ = self.forward(input_ids)
+            outputs = self.forward(input_ids, return_dict=True)
+            logits = outputs.logits
             last_logits = logits[:, -1, :] / temperature
+
             if top_k is not None:
                 indices_to_remove = last_logits < torch.topk(last_logits, top_k)[0][..., -1, None]
                 last_logits = last_logits.masked_fill(indices_to_remove, float('-inf'))
+
             if top_p is not None:
                 sorted_logits, sorted_indices = torch.sort(last_logits, descending=True)
                 cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
@@ -483,83 +533,12 @@ class MixtureOfRecursions(nn.Module):
                 sorted_indices_to_remove[..., 0] = False
                 indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
                 last_logits = last_logits.masked_fill(indices_to_remove, float('-inf'))
+
             probs = F.softmax(last_logits, dim=-1)
             return torch.multinomial(probs, num_samples=1)
 
-class TextGenerator:
-    """Text generation utility for the MixtureOfRecursions model."""
-
-    def __init__(self, model: nn.Module, tokenizer: 'Tokenizer', max_length: int = DEFAULT_MAX_SEQ_LEN, device: Optional[torch.device] = None):
-        self.model = model
-        self.tokenizer = tokenizer
-        self.max_length = max_length
-        self.device = device if device else next(model.parameters()).device
-        self.model.to(self.device)
-        self.eos_token_id = tokenizer.vocab.get('<|endoftext|>', -1)
-        self.assistant_token_id = tokenizer.vocab.get('<|assistant|>', -1)
-
-    def generate(
-        self,
-        prompt: str,
-        method: str = "nucleus",
-        temperature: float = 1.0,
-        top_k: Optional[int] = 50,
-        top_p: Optional[float] = 0.9,
-        max_new_tokens: Optional[int] = None
-    ) -> str:
-        max_new_tokens = max_new_tokens or self.max_length
-        input_text = f"<|user|> {prompt}"
-        input_ids = self.tokenizer.encode_ids(input_text, add_special_tokens=True)
-        input_tensor = torch.tensor([input_ids], device=self.device)
-        self.model.eval()
-        generated_ids = []
-        with torch.no_grad():
-            for _ in range(max_new_tokens):
-                if input_tensor.size(1) > self.max_length:
-                    input_tensor = input_tensor[:, -self.max_length:]
-                if method == "greedy":
-                    next_token = self._greedy_generate(input_tensor)
-                elif method == "sample":
-                    next_token = self._sample_generate(input_tensor, temperature)
-                elif method == "top_k":
-                    next_token = self._top_k_generate(input_tensor, temperature, top_k)
-                elif method == "nucleus" or method == "top_p":
-                    next_token = self._nucleus_generate(input_tensor, temperature, top_p)
-                else:
-                    raise ValueError(f"Unknown generation method: {method}")
-                next_token_id = next_token.item()
-                generated_ids.append(next_token_id)
-                input_tensor = torch.cat([input_tensor, next_token.unsqueeze(0)], dim=1)
-                if next_token_id == self.eos_token_id or (self.assistant_token_id != -1 and next_token_id == self.assistant_token_id):
-                    break
-        full_ids = input_ids + generated_ids
-        full_text = self.tokenizer.decode_ids(full_ids, skip_special_tokens=False)
-        if "<|assistant|>" in full_text:
-            response = full_text.split("<|assistant|>")[-1].split("<|endoftext|>")[0].strip()
-        else:
-            response = full_text.split("<|endoftext|>")[0].strip()
-        return response if response else "No response generated."
-
-    def _greedy_generate(self, input_tensor: torch.Tensor) -> torch.Tensor:
-        logits, _ = self.model(input_tensor)
-        return torch.argmax(logits[:, -1, :], dim=-1, keepdim=True)
-
-    def _sample_generate(self, input_tensor: torch.Tensor, temperature: float) -> torch.Tensor:
-        logits, _ = self.model(input_tensor)
-        logits = logits[:, -1, :] / temperature
-        probs = F.softmax(logits, dim=-1)
-        return torch.multinomial(probs, num_samples=1)
-
-    def _top_k_generate(self, input_tensor: torch.Tensor, temperature: float, top_k: int) -> torch.Tensor:
-        logits, _ = self.model(input_tensor)
-        logits = logits[:, -1, :] / temperature
-        top_k_logits, top_k_indices = torch.topk(logits, top_k)
-        probs = F.softmax(top_k_logits, dim=-1)
-        next_token_idx = torch.multinomial(probs, num_samples=1)
-        return top_k_indices.gather(-1, next_token_idx)
-
-    def _nucleus_generate(self, input_tensor: torch.Tensor, temperature: float, top_p: float) -> torch.Tensor:
-        return self.model.generate_step(input_tensor, temperature, top_p=top_p)
+# Register the model for auto class
+MixtureOfRecursions.register_for_auto_class("AutoModelForCausalLM")
 
 def count_parameters(model: nn.Module) -> Tuple[int, int]:
     total_params = sum(p.numel() for p in model.parameters())
@@ -569,7 +548,7 @@ def count_parameters(model: nn.Module) -> Tuple[int, int]:
 def main():
     """Test the MixtureOfRecursions model and its components."""
     print("Initializing MixtureOfRecursions model...")
-    model = MixtureOfRecursions(
+    config = MixtureOfRecursionsConfig(
         vocab_size=DEFAULT_VOCAB_SIZE,
         d_model=DEFAULT_D_MODEL,
         n_layers=DEFAULT_N_LAYERS,
@@ -579,6 +558,7 @@ def main():
         dropout=DEFAULT_DROPOUT,
         router_type=DEFAULT_ROUTER_TYPE
     )
+    model = MixtureOfRecursions(config)
 
     total_params, trainable_params = count_parameters(model)
     print(f"Total parameters: {total_params:,}")
@@ -590,13 +570,13 @@ def main():
     attention_mask = torch.ones_like(input_ids)
    attention_mask[:, -10:] = 0
 
-    logits, comp_loss = model(input_ids, attention_mask)
+    outputs = model(input_ids, attention_mask, return_dict=True)
+    logits = outputs.logits
 
     assert logits.shape == (batch_size, seq_len, DEFAULT_VOCAB_SIZE), f"Unexpected logits shape: {logits.shape}"
     print(f"Input shape: {input_ids.shape}")
     print(f"Output logits shape: {logits.shape}")
     print(f"Expected logits shape: ({batch_size}, {seq_len}, {DEFAULT_VOCAB_SIZE})")
-    print(f"Computation loss: {comp_loss:.4f}")
 
     print("\nTesting generation step...")
     next_token = model.generate_step(input_ids[:1], temperature=0.8, top_p=0.9)
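
For reference, a minimal sketch of how the refactored API in this commit could be exercised, assuming the classes and DEFAULT_* constants defined in model_slm.py (MixtureOfRecursionsConfig, MixtureOfRecursions, DEFAULT_VOCAB_SIZE, DEFAULT_D_MODEL, DEFAULT_N_LAYERS); the checkpoint path is hypothetical and the AutoModelForCausalLM reload is shown only as a commented hint, since it also requires the custom code to be shipped alongside the saved config:

import torch
from model_slm import (
    MixtureOfRecursions,
    MixtureOfRecursionsConfig,
    DEFAULT_VOCAB_SIZE,
    DEFAULT_D_MODEL,
    DEFAULT_N_LAYERS,
)

# The commit replaces the old keyword-argument constructor with
# config-based construction, mirroring the smoke test in main().
config = MixtureOfRecursionsConfig(
    vocab_size=DEFAULT_VOCAB_SIZE,
    d_model=DEFAULT_D_MODEL,
    n_layers=DEFAULT_N_LAYERS,
)
model = MixtureOfRecursions(config)

# forward() now returns a CausalLMOutput; passing labels also produces a
# combined language-modeling + computation loss.
input_ids = torch.randint(0, DEFAULT_VOCAB_SIZE, (2, 16))
outputs = model(input_ids, labels=input_ids, return_dict=True)
print(outputs.loss, outputs.logits.shape)

# Sample one next token with the built-in temperature / top-p step.
next_token = model.generate_step(input_ids[:1], temperature=0.8, top_p=0.9)

# Because the class is registered for "AutoModelForCausalLM", a saved
# checkpoint could be reloaded with trust_remote_code=True (hypothetical path):
# model.save_pretrained("./mor-checkpoint")
# config.save_pretrained("./mor-checkpoint")
# from transformers import AutoModelForCausalLM
# reloaded = AutoModelForCausalLM.from_pretrained("./mor-checkpoint", trust_remote_code=True)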