Girinath11 committed on
Commit 752c496 · verified · 1 Parent(s): fdd8947

Update model_slm.py

Files changed (1): model_slm.py +58 -0
model_slm.py CHANGED
@@ -5,6 +5,64 @@ import math
 from typing import Optional, Tuple, Union
 from embeddings import TechEmbeddingLayer, create_padding_mask, create_causal_mask
 
+# ============================================================================
+# TRANSFORMERS COMPATIBILITY - ADD THIS SECTION
+# ============================================================================
+from transformers import PretrainedConfig
+from transformers.modeling_utils import PreTrainedModel
+
+class MixtureOfRecursionsConfig(PretrainedConfig):
+    """Configuration class for MixtureOfRecursions model."""
+
+    model_type = "mixture_of_recursions"
+
+    def __init__(
+        self,
+        vocab_size=31985,
+        d_model=384,
+        n_layers=12,
+        n_heads=6,
+        max_steps=4,
+        dim_feedforward=2048,
+        dropout=0.1,
+        max_seq_len=128,
+        router_type="adaptive",
+        padding_idx=0,
+        pos_encoding="learned",
+        # Transformers standard names (for compatibility)
+        hidden_size=None,
+        num_hidden_layers=None,
+        num_attention_heads=None,
+        intermediate_size=None,
+        max_position_embeddings=None,
+        **kwargs
+    ):
+        super().__init__(**kwargs)
+
+        # Your model's parameters
+        self.vocab_size = vocab_size
+        self.d_model = d_model
+        self.n_layers = n_layers
+        self.n_heads = n_heads
+        self.max_steps = max_steps
+        self.dim_feedforward = dim_feedforward
+        self.dropout = dropout
+        self.max_seq_len = max_seq_len
+        self.router_type = router_type
+        self.padding_idx = padding_idx
+        self.pos_encoding = pos_encoding
+
+        # Transformers standard aliases (for compatibility)
+        self.hidden_size = hidden_size or d_model
+        self.num_hidden_layers = num_hidden_layers or n_layers
+        self.num_attention_heads = num_attention_heads or n_heads
+        self.intermediate_size = intermediate_size or dim_feedforward
+        self.max_position_embeddings = max_position_embeddings or max_seq_len
+
+# ============================================================================
+# END TRANSFORMERS COMPATIBILITY SECTION
+# ============================================================================
+
 # Constants for default configuration
 DEFAULT_D_MODEL = 512
 DEFAULT_N_HEADS = 8
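
For reference, a minimal usage sketch of the new config class (not part of this commit; the model_slm import path and the save directory name are assumptions). It instantiates MixtureOfRecursionsConfig, checks that the Transformers-style alias fields fall back to the native parameters, and round-trips the config through the JSON serialization inherited from PretrainedConfig.

# Hypothetical usage sketch, not part of the commit.
from model_slm import MixtureOfRecursionsConfig  # assumed import path

config = MixtureOfRecursionsConfig(d_model=384, n_layers=12, n_heads=6)

# Alias fields mirror the native parameters when not given explicitly.
assert config.hidden_size == config.d_model == 384
assert config.num_hidden_layers == config.n_layers == 12
assert config.num_attention_heads == config.n_heads == 6

# Inherited from PretrainedConfig: save/load via config.json.
config.save_pretrained("./mor_checkpoint")          # hypothetical directory
reloaded = MixtureOfRecursionsConfig.from_pretrained("./mor_checkpoint")
assert reloaded.router_type == "adaptive"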