Girinath11 committed (verified)
Commit: 2c7c331
Parent(s): 820de40

Create modeling_mixture_of_recursions.py

Files changed (1):
  modeling_mixture_of_recursions.py  +182 -0
modeling_mixture_of_recursions.py ADDED
@@ -0,0 +1,182 @@
+ # modeling_mixture_of_recursions.py
+ # Create this file in your repository root
+
+ import torch
+ from transformers import PreTrainedModel
+ from transformers.modeling_outputs import CausalLMOutputWithPast
+ from typing import Optional, Tuple
+
+ # Import your existing model
+ try:
+     from model_slm import *  # Import everything from your existing model file
+ except ImportError:
+     pass  # Will work when uploaded to HF
+
+ from .configuration_mixture_of_recursions import MixtureOfRecursionsConfig
+
+
+ class MixtureOfRecursionsPreTrainedModel(PreTrainedModel):
+     """
+     An abstract class to handle weights initialization and a simple interface
+     for downloading and loading pretrained models.
+     """
+     config_class = MixtureOfRecursionsConfig
+     base_model_prefix = "model"
+     supports_gradient_checkpointing = False
+     _no_split_modules = []
+
+     def _init_weights(self, module):
+         """Initialize the weights"""
+         if isinstance(module, torch.nn.Linear):
+             module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+             if module.bias is not None:
+                 module.bias.data.zero_()
+         elif isinstance(module, torch.nn.Embedding):
+             module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+             if module.padding_idx is not None:
+                 module.weight.data[module.padding_idx].zero_()
+         elif isinstance(module, torch.nn.LayerNorm):
+             module.bias.data.zero_()
+             module.weight.data.fill_(1.0)
+
+
+ class MixtureOfRecursionsModel(MixtureOfRecursionsPreTrainedModel):
+     """
+     Wrapper around your existing model to make it compatible with Transformers.
+     """
+     def __init__(self, config):
+         super().__init__(config)
+         self.config = config
+
+         # This should match your actual model initialization from model_slm.py.
+         # Replace this with your actual model class name,
+         # for example: self.model = YourModelClass(config)
+
+         # Placeholder - update with your actual model architecture
+         self.vocab_size = config.vocab_size
+         self.hidden_size = config.hidden_size
+
+         self.post_init()
+
+     def forward(
+         self,
+         input_ids: Optional[torch.LongTensor] = None,
+         attention_mask: Optional[torch.FloatTensor] = None,
+         position_ids: Optional[torch.LongTensor] = None,
+         past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
+         inputs_embeds: Optional[torch.FloatTensor] = None,
+         labels: Optional[torch.LongTensor] = None,
+         use_cache: Optional[bool] = None,
+         output_attentions: Optional[bool] = None,
+         output_hidden_states: Optional[bool] = None,
+         return_dict: Optional[bool] = None,
+     ):
+         # Your forward pass logic from model_slm.py goes here.
+         # This is a placeholder - replace it with your actual forward implementation.
+         raise NotImplementedError("Fill in the forward pass from model_slm.py")
+
+
+ class MixtureOfRecursionsForCausalLM(MixtureOfRecursionsPreTrainedModel):
+     """
+     Causal LM head wrapper for your model.
+     """
+     _tied_weights_keys = ["lm_head.weight"]
+
+     def __init__(self, config):
+         super().__init__(config)
+         self.model = MixtureOfRecursionsModel(config)
+         self.vocab_size = config.vocab_size
+         self.lm_head = torch.nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+
+         self.post_init()
+
+     def get_input_embeddings(self):
+         return self.model.embed_tokens if hasattr(self.model, "embed_tokens") else None
+
+     def set_input_embeddings(self, value):
+         if hasattr(self.model, "embed_tokens"):
+             self.model.embed_tokens = value
+
+     def get_output_embeddings(self):
+         return self.lm_head
+
+     def set_output_embeddings(self, new_embeddings):
+         self.lm_head = new_embeddings
+
+     def forward(
+         self,
+         input_ids: Optional[torch.LongTensor] = None,
+         attention_mask: Optional[torch.FloatTensor] = None,
+         position_ids: Optional[torch.LongTensor] = None,
+         past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
+         inputs_embeds: Optional[torch.FloatTensor] = None,
+         labels: Optional[torch.LongTensor] = None,
+         use_cache: Optional[bool] = None,
+         output_attentions: Optional[bool] = None,
+         output_hidden_states: Optional[bool] = None,
+         return_dict: Optional[bool] = None,
+     ):
+         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+         # Forward pass through the base model
+         outputs = self.model(
+             input_ids=input_ids,
+             attention_mask=attention_mask,
+             position_ids=position_ids,
+             past_key_values=past_key_values,
+             inputs_embeds=inputs_embeds,
+             use_cache=use_cache,
+             output_attentions=output_attentions,
+             output_hidden_states=output_hidden_states,
+             return_dict=return_dict,
+         )
+
+         hidden_states = outputs[0] if isinstance(outputs, tuple) else outputs.last_hidden_state
+         logits = self.lm_head(hidden_states)
+
+         loss = None
+         if labels is not None:
+             # Shift logits and labels for causal language modeling
+             shift_logits = logits[..., :-1, :].contiguous()
+             shift_labels = labels[..., 1:].contiguous()
+             loss_fct = torch.nn.CrossEntropyLoss()
+             loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
+
+         if not return_dict:
+             output = (logits,) + outputs[1:]
+             return ((loss,) + output) if loss is not None else output
+
+         return CausalLMOutputWithPast(
+             loss=loss,
+             logits=logits,
+             past_key_values=outputs.past_key_values if hasattr(outputs, "past_key_values") else None,
+             hidden_states=outputs.hidden_states if hasattr(outputs, "hidden_states") else None,
+             attentions=outputs.attentions if hasattr(outputs, "attentions") else None,
+         )
+
+     def prepare_inputs_for_generation(
+         self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
+     ):
+         if past_key_values:
+             # With a cache, only the last token needs to be fed to the model
+             input_ids = input_ids[:, -1:]
+
+         position_ids = kwargs.get("position_ids", None)
+         if attention_mask is not None and position_ids is None:
+             # Create position_ids on the fly from the attention mask
+             position_ids = attention_mask.long().cumsum(-1) - 1
+             position_ids.masked_fill_(attention_mask == 0, 1)
+             if past_key_values:
+                 position_ids = position_ids[:, -1].unsqueeze(-1)
+
+         # inputs_embeds are only used on the first generation step
+         if inputs_embeds is not None and past_key_values is None:
+             model_inputs = {"inputs_embeds": inputs_embeds}
+         else:
+             model_inputs = {"input_ids": input_ids}
+
+         model_inputs.update(
+             {
+                 "position_ids": position_ids,
+                 "past_key_values": past_key_values,
+                 "use_cache": kwargs.get("use_cache"),
+                 "attention_mask": attention_mask,
+             }
+         )
+         return model_inputs
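
For context, a minimal usage sketch of how a remote-code file like this is typically consumed once it sits in the Hub repository next to configuration_mixture_of_recursions.py. The repo id below is a placeholder, loading with trust_remote_code=True assumes the repository's config.json has an auto_map entry pointing at MixtureOfRecursionsConfig and MixtureOfRecursionsForCausalLM, and generation will only work once the placeholder forward pass above is filled in.

# Usage sketch (not part of the commit). Assumptions: the repo id is a placeholder,
# and config.json maps AutoModelForCausalLM to MixtureOfRecursionsForCausalLM.
from transformers import AutoModelForCausalLM, AutoTokenizer

repo_id = "Girinath11/mixture-of-recursions"  # placeholder, replace with the real repo
tokenizer = AutoTokenizer.from_pretrained(repo_id)
model = AutoModelForCausalLM.from_pretrained(repo_id, trust_remote_code=True)

inputs = tokenizer("Hello, world", return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=20)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))

For purely local use, the classes can instead be registered with AutoConfig.register("mixture_of_recursions", MixtureOfRecursionsConfig) and AutoModelForCausalLM.register(MixtureOfRecursionsConfig, MixtureOfRecursionsForCausalLM), assuming the config's model_type is "mixture_of_recursions".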