Girinath11 committed
Commit 031a3f7 · verified · 1 Parent(s): 5a3c70d

Update model_slm.py

Files changed (1)
  1. model_slm.py +0 -607
model_slm.py CHANGED
@@ -604,612 +604,5 @@ def main():
 
     print("\nModel test completed successfully!")
 
-if __name__ == "__main__":
-    main()import torch
-import torch.nn as nn
-import torch.nn.functional as F
-import math
-from typing import Optional, Tuple, Union, List
-
-# ============================================================================
-# TRANSFORMERS COMPATIBILITY
-# ============================================================================
-from transformers import PretrainedConfig
-from transformers.modeling_utils import PreTrainedModel
-
-class MixtureOfRecursionsConfig(PretrainedConfig):
-    """Configuration class for MixtureOfRecursions model."""
-
-    model_type = "mixture_of_recursions"
-
-    def __init__(
-        self,
-        vocab_size=31985,
-        d_model=384,
-        n_layers=12,
-        n_heads=6,
-        max_steps=4,
-        dim_feedforward=2048,
-        dropout=0.1,
-        max_seq_len=128,
-        router_type="adaptive",
-        padding_idx=0,
-        pos_encoding="learned",
-        hidden_size=None,
-        num_hidden_layers=None,
-        num_attention_heads=None,
-        intermediate_size=None,
-        max_position_embeddings=None,
-        **kwargs
-    ):
-        super().__init__(**kwargs)
-        self.vocab_size = vocab_size
-        self.d_model = d_model
-        self.n_layers = n_layers
-        self.n_heads = n_heads
-        self.max_steps = max_steps
-        self.dim_feedforward = dim_feedforward
-        self.dropout = dropout
-        self.max_seq_len = max_seq_len
-        self.router_type = router_type
-        self.padding_idx = padding_idx
-        self.pos_encoding = pos_encoding
-        self.hidden_size = hidden_size or d_model
-        self.num_hidden_layers = num_hidden_layers or n_layers
-        self.num_attention_heads = num_attention_heads or n_heads
-        self.intermediate_size = intermediate_size or dim_feedforward
-        self.max_position_embeddings = max_position_embeddings or max_seq_len
-
-# ============================================================================
-# EMBEDDINGS MODULE (merged from embeddings.py)
-# ============================================================================
-
-DEFAULT_BASE = 10000.0
-DEFAULT_CUTOFFS = [2000, 10000]
-DEFAULT_DIV_VAL = 4.0
-
-class PositionalEncoding(nn.Module):
-    """Sinusoidal positional encoding for transformer models."""
-
-    def __init__(self, d_model: int, max_seq_len: int = 512, dropout: float = 0.1):
-        super().__init__()
-        self.d_model = d_model
-        self.dropout = nn.Dropout(dropout)
-        pe = torch.zeros(max_seq_len, d_model)
-        position = torch.arange(0, max_seq_len, dtype=torch.float).unsqueeze(1)
-        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(DEFAULT_BASE) / d_model))
-        pe[:, 0::2] = torch.sin(position * div_term)
-        pe[:, 1::2] = torch.cos(position * div_term[:, :-1] if d_model % 2 == 1 else div_term)
-        self.register_buffer('pe', pe.unsqueeze(0))
-
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        batch_size, seq_len, d_model = x.size()
-        if d_model != self.d_model:
-            raise ValueError(f"Input dimension {d_model} does not match d_model {self.d_model}")
-        x = x + self.pe[:, :seq_len]
-        return self.dropout(x)
-
-class LearnedPositionalEmbedding(nn.Module):
-    """Learned positional embeddings for transformer models."""
-
-    def __init__(self, max_seq_len: int, d_model: int, dropout: float = 0.1):
-        super().__init__()
-        self.max_seq_len = max_seq_len
-        self.d_model = d_model
-        self.pos_embedding = nn.Embedding(max_seq_len, d_model)
-        self.dropout = nn.Dropout(dropout)
-        nn.init.normal_(self.pos_embedding.weight, std=0.02)
-
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        batch_size, seq_len, d_model = x.size()
-        if seq_len > self.max_seq_len:
-            raise ValueError(f"Sequence length {seq_len} exceeds maximum {self.max_seq_len}")
-        if d_model != self.d_model:
-            raise ValueError(f"Input dimension {d_model} does not match d_model {self.d_model}")
-        positions = torch.arange(seq_len, device=x.device).unsqueeze(0).expand(batch_size, -1)
-        pos_emb = self.pos_embedding(positions)
-        x = x + pos_emb
-        return self.dropout(x)
-
-class RotaryPositionalEmbedding(nn.Module):
-    """Rotary Positional Embedding (RoPE) for transformer models."""
-
-    def __init__(self, d_model: int, max_seq_len: int = 2048, base: float = DEFAULT_BASE):
-        super().__init__()
-        self.d_model = d_model
-        self.max_seq_len = max_seq_len
-        self.base = base
-        inv_freq = 1.0 / (base ** (torch.arange(0, d_model, 2).float() / d_model))
-        self.register_buffer('inv_freq', inv_freq)
-        self._seq_len_cached = 0
-        self._cos_cached = None
-        self._sin_cached = None
-
-    def _update_cos_sin_cache(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> None:
-        if seq_len > self._seq_len_cached:
-            self._seq_len_cached = seq_len
-            t = torch.arange(seq_len, device=device, dtype=torch.float32)
-            freqs = torch.outer(t, self.inv_freq)
-            self._cos_cached = freqs.cos().to(dtype)
-            self._sin_cached = freqs.sin().to(dtype)
-
-    def _rotate_half(self, x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
-        x1, x2 = x[..., :x.shape[-1] // 2], x[..., x.shape[-1] // 2:]
-        return torch.cat([x1 * cos - x2 * sin, x1 * sin + x2 * cos], dim=-1)
-
-    def forward(self, q: torch.Tensor, k: torch.Tensor, start_pos: int = 0) -> Tuple[torch.Tensor, torch.Tensor]:
-        batch_size, seq_len, num_heads, head_dim = q.shape
-        self._update_cos_sin_cache(start_pos + seq_len, q.device, q.dtype)
-        cos = self._cos_cached[start_pos:start_pos + seq_len, :head_dim // 2].view(1, seq_len, 1, -1)
-        sin = self._sin_cached[start_pos:start_pos + seq_len, :head_dim // 2].view(1, seq_len, 1, -1)
-        q = q.transpose(1, 2).reshape(batch_size * num_heads, seq_len, head_dim)
-        k = k.transpose(1, 2).reshape(batch_size * num_heads, seq_len, head_dim)
-        q_rot = self._rotate_half(q, cos, sin)
-        k_rot = self._rotate_half(k, cos, sin)
-        q_rot = q_rot.reshape(batch_size, num_heads, seq_len, head_dim).transpose(1, 2)
-        k_rot = k_rot.reshape(batch_size, num_heads, seq_len, head_dim).transpose(1, 2)
-        return q_rot, k_rot
-
-class TechEmbeddingLayer(nn.Module):
-    """Comprehensive embedding layer with token and positional embeddings."""
-
-    def __init__(
-        self,
-        vocab_size: int,
-        d_model: int,
-        max_seq_len: int = 512,
-        dropout: float = 0.1,
-        padding_idx: int = 0,
-        pos_encoding: str = "learned",
-        layer_norm: bool = True,
-    ):
-        super().__init__()
-        self.d_model = d_model
-        self.vocab_size = vocab_size
-        self.padding_idx = padding_idx
-        self.pos_encoding_type = pos_encoding.lower()
-        self.token_embedding = nn.Embedding(vocab_size, d_model, padding_idx=padding_idx)
-
-        if pos_encoding == "sinusoidal":
-            self.pos_encoding = PositionalEncoding(d_model, max_seq_len, dropout)
-        elif pos_encoding == "learned":
-            self.pos_encoding = LearnedPositionalEmbedding(max_seq_len, d_model, dropout)
-        elif pos_encoding == "rope":
-            self.pos_encoding = RotaryPositionalEmbedding(d_model, max_seq_len)
-        else:
-            raise ValueError(f"Unknown positional encoding type: {pos_encoding}")
-
-        self.layer_norm = nn.LayerNorm(d_model) if layer_norm else nn.Identity()
-        self.dropout = nn.Dropout(dropout)
-        self._init_weights()
-
-    def _init_weights(self) -> None:
-        nn.init.normal_(self.token_embedding.weight, mean=0.0, std=0.02)
-        if self.padding_idx is not None:
-            nn.init.constant_(self.token_embedding.weight[self.padding_idx], 0.0)
-
-    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
-        if (input_ids >= self.vocab_size).any():
-            raise ValueError(f"Input IDs contain values >= vocab_size ({self.vocab_size})")
-        embeddings = self.token_embedding(input_ids)
-        if self.pos_encoding_type != "rope":
-            embeddings = self.pos_encoding(embeddings)
-        embeddings = self.layer_norm(embeddings)
-        return self.dropout(embeddings)
-
-    def get_positional_encoding(self) -> Optional[nn.Module]:
-        return self.pos_encoding if self.pos_encoding_type == "rope" else None
-
-def create_padding_mask(input_ids: torch.Tensor, padding_idx: int = 0) -> torch.Tensor:
-    return input_ids == padding_idx
-
-def create_causal_mask(seq_len: int, device: torch.device) -> torch.Tensor:
-    return torch.triu(torch.ones(seq_len, seq_len, device=device), diagonal=1).bool()
-
-# ============================================================================
-# MODEL CONSTANTS
-# ============================================================================
-
-DEFAULT_D_MODEL = 512
-DEFAULT_N_HEADS = 8
-DEFAULT_N_LAYERS = 6
-DEFAULT_MAX_STEPS = 4
-DEFAULT_DIM_FEEDFORWARD = 2048
-DEFAULT_DROPOUT = 0.1
-DEFAULT_MAX_SEQ_LEN = 512
-DEFAULT_PADDING_IDX = 0
-DEFAULT_ROUTER_TYPE = "adaptive"
-DEFAULT_VOCAB_SIZE = 10000
-
-# ============================================================================
-# MODEL COMPONENTS
-# ============================================================================
-
-class MultiHeadAttention(nn.Module):
-    """Multi-head attention mechanism optimized for technical content."""
-
-    def __init__(self, d_model: int, n_heads: int, dropout: float = DEFAULT_DROPOUT):
-        super().__init__()
-        if d_model % n_heads != 0:
-            raise ValueError(f"d_model ({d_model}) must be divisible by n_heads ({n_heads})")
-        self.d_model = d_model
-        self.n_heads = n_heads
-        self.d_k = d_model // n_heads
-        self.w_q = nn.Linear(d_model, d_model, bias=False)
-        self.w_k = nn.Linear(d_model, d_model, bias=False)
-        self.w_v = nn.Linear(d_model, d_model, bias=False)
-        self.w_o = nn.Linear(d_model, d_model)
-        self.dropout = nn.Dropout(dropout)
-        self._init_weights()
-
-    def _init_weights(self) -> None:
-        for module in [self.w_q, self.w_k, self.w_v, self.w_o]:
-            nn.init.xavier_uniform_(module.weight)
-            if hasattr(module, 'bias') and module.bias is not None:
-                nn.init.zeros_(module.bias)
-
-    def forward(
-        self,
-        query: torch.Tensor,
-        key: torch.Tensor,
-        value: torch.Tensor,
-        mask: Optional[torch.Tensor] = None,
-        pos_encoding: Optional[nn.Module] = None
-    ) -> torch.Tensor:
-        batch_size, seq_len, _ = query.size()
-        Q = self.w_q(query).view(batch_size, seq_len, self.n_heads, self.d_k).transpose(1, 2)
-        K = self.w_k(key).view(batch_size, seq_len, self.n_heads, self.d_k).transpose(1, 2)
-        V = self.w_v(value).view(batch_size, seq_len, self.n_heads, self.d_k).transpose(1, 2)
-
-        if pos_encoding is not None:
-            Q, K = pos_encoding(Q, K)
-
-        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
-
-        if mask is not None:
-            mask = mask.unsqueeze(1).expand(batch_size, self.n_heads, seq_len, seq_len)
-            scores = scores.masked_fill(mask, float('-inf'))
-
-        attention_weights = F.softmax(scores, dim=-1)
-        attention_weights = self.dropout(attention_weights)
-        attended = torch.matmul(attention_weights, V)
-        attended = attended.transpose(1, 2).contiguous().view(batch_size, seq_len, self.d_model)
-        return self.w_o(attended)
-
-class FeedForward(nn.Module):
-    """Position-wise feed-forward network with GELU activation."""
-
-    def __init__(self, d_model: int, dim_feedforward: int, dropout: float = DEFAULT_DROPOUT):
-        super().__init__()
-        self.linear1 = nn.Linear(d_model, dim_feedforward)
-        self.linear2 = nn.Linear(dim_feedforward, d_model)
-        self.dropout = nn.Dropout(dropout)
-        nn.init.xavier_uniform_(self.linear1.weight)
-        nn.init.zeros_(self.linear1.bias)
-        nn.init.xavier_uniform_(self.linear2.weight)
-        nn.init.zeros_(self.linear2.bias)
-
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        x = F.gelu(self.linear1(x))
-        x = self.dropout(x)
-        return self.linear2(x)
-
-class RecursionRouter(nn.Module):
-    """Router to determine recursion steps for technical problem processing."""
-
-    def __init__(self, d_model: int, max_steps: int = DEFAULT_MAX_STEPS, router_type: str = DEFAULT_ROUTER_TYPE):
-        super().__init__()
-        self.max_steps = max_steps
-        self.router_type = router_type.lower()
-
-        if self.router_type == "adaptive":
-            self.complexity_classifier = nn.Sequential(
-                nn.Linear(d_model, d_model // 4),
-                nn.GELU(),
-                nn.Dropout(DEFAULT_DROPOUT),
-                nn.Linear(d_model // 4, max_steps + 1),
-                nn.Softmax(dim=-1)
-            )
-        elif self.router_type == "fixed":
-            self.register_buffer('fixed_steps', torch.tensor(max_steps, dtype=torch.long))
-        else:
-            raise ValueError(f"Invalid router_type: {router_type}. Choose 'adaptive' or 'fixed'.")
-
-    def forward(self, x: torch.Tensor) -> Union[torch.Tensor, int]:
-        if self.router_type == "adaptive":
-            seq_repr = x.mean(dim=1)
-            step_probs = self.complexity_classifier(seq_repr)
-            return torch.argmax(step_probs, dim=-1)
-        return self.fixed_steps.item()
-
-class RecursiveTransformerLayer(nn.Module):
-    """Transformer layer with recursive computation capability."""
-
-    def __init__(
-        self,
-        d_model: int,
-        n_heads: int,
-        dim_feedforward: int,
-        max_steps: int = DEFAULT_MAX_STEPS,
-        dropout: float = DEFAULT_DROPOUT,
-        router_type: str = DEFAULT_ROUTER_TYPE
-    ):
-        super().__init__()
-        self.max_steps = max_steps
-        self.d_model = d_model
-        self.attention = MultiHeadAttention(d_model, n_heads, dropout)
-        self.feedforward = FeedForward(d_model, dim_feedforward, dropout)
-        self.norm1 = nn.LayerNorm(d_model)
-        self.norm2 = nn.LayerNorm(d_model)
-        self.dropout = nn.Dropout(dropout)
-        self.router = RecursionRouter(d_model, max_steps, router_type)
-        self.step_projections = nn.ModuleList([
-            nn.Linear(d_model, d_model) for _ in range(max_steps)
-        ])
-        for proj in self.step_projections:
-            nn.init.xavier_uniform_(proj.weight)
-            nn.init.zeros_(proj.bias)
-
-    def forward(
-        self,
-        x: torch.Tensor,
-        mask: Optional[torch.Tensor] = None,
-        pos_encoding: Optional[nn.Module] = None
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
-        steps = self.router(x)
-        if isinstance(steps, (int, torch.Tensor)) and not torch.is_tensor(steps):
-            return self._recursive_forward_fixed(x, mask, steps, pos_encoding)
-        return self._recursive_forward_adaptive(x, mask, steps, pos_encoding)
-
-    def _recursive_forward_fixed(
-        self,
-        x: torch.Tensor,
-        mask: Optional[torch.Tensor],
-        num_steps: int,
-        pos_encoding: Optional[nn.Module]
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
-        device = x.device
-        batch_size = x.shape[0]
-        computation_loss = torch.tensor(0.0, device=device)
-        for step in range(min(num_steps, self.max_steps)):
-            step_input = self.step_projections[step](x) if step < len(self.step_projections) else x
-            attended = self.attention(step_input, step_input, step_input, mask, pos_encoding)
-            x = self.norm1(x + self.dropout(attended))
-            fed_forward = self.feedforward(x)
-            x = self.norm2(x + self.dropout(fed_forward))
-            computation_loss += torch.tensor(0.1, device=device) * batch_size
-        return x, computation_loss
-
-    def _recursive_forward_adaptive(
-        self,
-        x: torch.Tensor,
-        mask: Optional[torch.Tensor],
-        steps: torch.Tensor,
-        pos_encoding: Optional[nn.Module]
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
-        batch_size, seq_len, d_model = x.shape
-        device = x.device
-        max_batch_steps = int(steps.max().item())
-        computation_loss = torch.tensor(0.0, device=device)
-        active_batches = torch.ones(batch_size, device=device, dtype=torch.bool)
-        for step in range(min(max_batch_steps, self.max_steps)):
-            step_mask = (steps > step) & active_batches
-            if not step_mask.any():
-                break
-            step_input = self.step_projections[step](x) if step < len(self.step_projections) else x
-            attended = self.attention(step_input, step_input, step_input, mask, pos_encoding)
-            attended = torch.where(step_mask.unsqueeze(-1).unsqueeze(-1), attended, torch.zeros_like(attended))
-            x = self.norm1(x + self.dropout(attended))
-            fed_forward = self.feedforward(x)
-            fed_forward = torch.where(step_mask.unsqueeze(-1).unsqueeze(-1), fed_forward, torch.zeros_like(fed_forward))
-            x = self.norm2(x + self.dropout(fed_forward))
-            computation_loss += torch.tensor(0.1, device=device) * step_mask.sum()
-            active_batches &= (steps > step)
-        return x, computation_loss
-
-class MixtureOfRecursions(nn.Module):
-    """Transformer model with mixture of recursive layers for technical content."""
-
-    def __init__(
-        self,
-        vocab_size: int,
-        d_model: int = DEFAULT_D_MODEL,
-        n_layers: int = DEFAULT_N_LAYERS,
-        n_heads: int = DEFAULT_N_HEADS,
-        max_steps: int = DEFAULT_MAX_STEPS,
-        dim_feedforward: int = DEFAULT_DIM_FEEDFORWARD,
-        dropout: float = DEFAULT_DROPOUT,
-        max_seq_len: int = DEFAULT_MAX_SEQ_LEN,
-        router_type: str = DEFAULT_ROUTER_TYPE,
-        padding_idx: int = DEFAULT_PADDING_IDX,
-        pos_encoding: str = "learned"
-    ):
-        super().__init__()
-        self.d_model = d_model
-        self.vocab_size = vocab_size
-        self.padding_idx = padding_idx
-        self.embeddings = TechEmbeddingLayer(
-            vocab_size=vocab_size,
-            d_model=d_model,
-            max_seq_len=max_seq_len,
-            dropout=dropout,
-            padding_idx=padding_idx,
-            pos_encoding=pos_encoding
-        )
-        self.layers = nn.ModuleList([
-            RecursiveTransformerLayer(
-                d_model=d_model,
-                n_heads=n_heads,
-                dim_feedforward=dim_feedforward,
-                max_steps=max_steps,
-                dropout=dropout,
-                router_type=router_type
-            ) for _ in range(n_layers)
-        ])
-        self.final_norm = nn.LayerNorm(d_model)
-        self.lm_head = nn.Linear(d_model, vocab_size, bias=False)
-        self._init_weights()
-
-    def _init_weights(self) -> None:
-        nn.init.xavier_uniform_(self.lm_head.weight)
-
-    def forward(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor]:
-        batch_size, seq_len = input_ids.shape
-        padding_mask = create_padding_mask(input_ids, self.padding_idx) if attention_mask is None else (attention_mask == 0)
-        causal_mask = create_causal_mask(seq_len, input_ids.device)
-        combined_mask = padding_mask.unsqueeze(1).expand(batch_size, seq_len, seq_len) | causal_mask.unsqueeze(0)
-        x = self.embeddings(input_ids)
-        pos_encoding = self.embeddings.get_positional_encoding()
-        total_computation_loss = torch.tensor(0.0, device=x.device)
-        for layer in self.layers:
-            x, comp_loss = layer(x, combined_mask, pos_encoding)
-            total_computation_loss += comp_loss
-        x = self.final_norm(x)
-        logits = self.lm_head(x)
-        return logits, total_computation_loss
-
-    def generate_step(
-        self,
-        input_ids: torch.Tensor,
-        temperature: float = 1.0,
-        top_k: Optional[int] = None,
-        top_p: Optional[float] = None
-    ) -> torch.Tensor:
-        self.eval()
-        with torch.no_grad():
-            logits, _ = self.forward(input_ids)
-            last_logits = logits[:, -1, :] / temperature
-            if top_k is not None:
-                indices_to_remove = last_logits < torch.topk(last_logits, top_k)[0][..., -1, None]
-                last_logits = last_logits.masked_fill(indices_to_remove, float('-inf'))
-            if top_p is not None:
-                sorted_logits, sorted_indices = torch.sort(last_logits, descending=True)
-                cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
-                sorted_indices_to_remove = cumulative_probs > top_p
-                sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
-                sorted_indices_to_remove[..., 0] = False
-                indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
-                last_logits = last_logits.masked_fill(indices_to_remove, float('-inf'))
-            probs = F.softmax(last_logits, dim=-1)
-            return torch.multinomial(probs, num_samples=1)
-
-class TextGenerator:
-    """Text generation utility for the MixtureOfRecursions model."""
-
-    def __init__(self, model: nn.Module, tokenizer: 'Tokenizer', max_length: int = DEFAULT_MAX_SEQ_LEN, device: Optional[torch.device] = None):
-        self.model = model
-        self.tokenizer = tokenizer
-        self.max_length = max_length
-        self.device = device if device else next(model.parameters()).device
-        self.model.to(self.device)
-        self.eos_token_id = tokenizer.vocab.get('<|endoftext|>', -1)
-        self.assistant_token_id = tokenizer.vocab.get('<|assistant|>', -1)
-
-    def generate(
-        self,
-        prompt: str,
-        method: str = "nucleus",
-        temperature: float = 1.0,
-        top_k: Optional[int] = 50,
-        top_p: Optional[float] = 0.9,
-        max_new_tokens: Optional[int] = None
-    ) -> str:
-        max_new_tokens = max_new_tokens or self.max_length
-        input_text = f"<|user|> {prompt}"
-        input_ids = self.tokenizer.encode_ids(input_text, add_special_tokens=True)
-        input_tensor = torch.tensor([input_ids], device=self.device)
-        self.model.eval()
-        generated_ids = []
-        with torch.no_grad():
-            for _ in range(max_new_tokens):
-                if input_tensor.size(1) > self.max_length:
-                    input_tensor = input_tensor[:, -self.max_length:]
-                if method == "greedy":
-                    next_token = self._greedy_generate(input_tensor)
-                elif method == "sample":
-                    next_token = self._sample_generate(input_tensor, temperature)
-                elif method == "top_k":
-                    next_token = self._top_k_generate(input_tensor, temperature, top_k)
-                elif method == "nucleus" or method == "top_p":
-                    next_token = self._nucleus_generate(input_tensor, temperature, top_p)
-                else:
-                    raise ValueError(f"Unknown generation method: {method}")
-                next_token_id = next_token.item()
-                generated_ids.append(next_token_id)
-                input_tensor = torch.cat([input_tensor, next_token.unsqueeze(0)], dim=1)
-                if next_token_id == self.eos_token_id or (self.assistant_token_id != -1 and next_token_id == self.assistant_token_id):
-                    break
-        full_ids = input_ids + generated_ids
-        full_text = self.tokenizer.decode_ids(full_ids, skip_special_tokens=False)
-        if "<|assistant|>" in full_text:
-            response = full_text.split("<|assistant|>")[-1].split("<|endoftext|>")[0].strip()
-        else:
-            response = full_text.split("<|endoftext|>")[0].strip()
-        return response if response else "No response generated."
-
-    def _greedy_generate(self, input_tensor: torch.Tensor) -> torch.Tensor:
-        logits, _ = self.model(input_tensor)
-        return torch.argmax(logits[:, -1, :], dim=-1, keepdim=True)
-
-    def _sample_generate(self, input_tensor: torch.Tensor, temperature: float) -> torch.Tensor:
-        logits, _ = self.model(input_tensor)
-        logits = logits[:, -1, :] / temperature
-        probs = F.softmax(logits, dim=-1)
-        return torch.multinomial(probs, num_samples=1)
-
-    def _top_k_generate(self, input_tensor: torch.Tensor, temperature: float, top_k: int) -> torch.Tensor:
-        logits, _ = self.model(input_tensor)
-        logits = logits[:, -1, :] / temperature
-        top_k_logits, top_k_indices = torch.topk(logits, top_k)
-        probs = F.softmax(top_k_logits, dim=-1)
-        next_token_idx = torch.multinomial(probs, num_samples=1)
-        return top_k_indices.gather(-1, next_token_idx)
-
-    def _nucleus_generate(self, input_tensor: torch.Tensor, temperature: float, top_p: float) -> torch.Tensor:
-        return self.model.generate_step(input_tensor, temperature, top_p=top_p)
-
-def count_parameters(model: nn.Module) -> Tuple[int, int]:
-    total_params = sum(p.numel() for p in model.parameters())
-    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
-    return total_params, trainable_params
-
-def main():
-    """Test the MixtureOfRecursions model and its components."""
-    print("Initializing MixtureOfRecursions model...")
-    model = MixtureOfRecursions(
-        vocab_size=DEFAULT_VOCAB_SIZE,
-        d_model=DEFAULT_D_MODEL,
-        n_layers=DEFAULT_N_LAYERS,
-        n_heads=DEFAULT_N_HEADS,
-        max_steps=DEFAULT_MAX_STEPS,
-        dim_feedforward=DEFAULT_DIM_FEEDFORWARD,
-        dropout=DEFAULT_DROPOUT,
-        router_type=DEFAULT_ROUTER_TYPE
-    )
-
-    total_params, trainable_params = count_parameters(model)
-    print(f"Total parameters: {total_params:,}")
-    print(f"Trainable parameters: {trainable_params:,}")
-
-    print("\nTesting forward pass...")
-    batch_size, seq_len = 4, 128
-    input_ids = torch.randint(0, DEFAULT_VOCAB_SIZE, (batch_size, seq_len))
-    attention_mask = torch.ones_like(input_ids)
-    attention_mask[:, -10:] = 0
-
-    logits, comp_loss = model(input_ids, attention_mask)
-
-    assert logits.shape == (batch_size, seq_len, DEFAULT_VOCAB_SIZE), f"Unexpected logits shape: {logits.shape}"
-    print(f"Input shape: {input_ids.shape}")
-    print(f"Output logits shape: {logits.shape}")
-    print(f"Expected logits shape: ({batch_size}, {seq_len}, {DEFAULT_VOCAB_SIZE})")
-    print(f"Computation loss: {comp_loss:.4f}")
-
-    print("\nTesting generation step...")
-    next_token = model.generate_step(input_ids[:1], temperature=0.8, top_p=0.9)
-    print(f"Generated next token: {next_token.item()}")
-
-    print("\nModel test completed successfully!")
-
 if __name__ == "__main__":
     main()
 