import torch
import torch.nn as nn
from typing import Optional

from transformers import AutoModel, AutoConfig


class ScalingLawForecaster(nn.Module):
    def __init__(
        self,
        base_model_name: str = "HuggingFaceTB/SmolLM2-135M",
        init_from_pretrained: bool = True,
        force_fp32: bool = False,
    ):
        super().__init__()
        self.config = AutoConfig.from_pretrained(base_model_name)
        if force_fp32:
            self.config.torch_dtype = torch.float32
        if init_from_pretrained:
            if force_fp32:
                self.base = AutoModel.from_pretrained(
                    base_model_name,
                    config=self.config,
                    torch_dtype=torch.float32,
                )
            else:
                self.base = AutoModel.from_pretrained(base_model_name, config=self.config)
        else:
            self.base = AutoModel.from_config(self.config)

        hidden_size = self.config.hidden_size

        # MLP that maps each scalar numeric value into the model's embedding
        # space; its output replaces the token embedding at numeric positions.
        act_cls = nn.ReLU
        self.num_mlp = nn.Sequential(
            nn.Linear(1, hidden_size * 2),
            act_cls(),
            nn.Linear(hidden_size * 2, hidden_size),
        )

        # Per-token regression head producing one scalar prediction per position.
        self.head = nn.Linear(hidden_size, 1)

    def forward(
        self,
        input_ids: torch.LongTensor,
        is_number_mask: torch.BoolTensor,
        number_values_filled: torch.FloatTensor,
        attention_mask: Optional[torch.BoolTensor] = None,
    ) -> torch.FloatTensor:
        """
        Args:
            input_ids:            (batch, seq_len)
            is_number_mask:       (batch, seq_len) bool mask for numeric tokens
            number_values_filled: (batch, seq_len) float values (0 for non-numeric)
            attention_mask:       (batch, seq_len) optional
        Returns:
            logits: (batch, seq_len) scalar predictions per token
        """
        # Sentinel id 49152 (the SmolLM2 vocab size, out of range for the
        # embedding table) marks numeric tokens; map it to a valid id before
        # the embedding lookup, since those positions are overwritten by the
        # numeric MLP below. Clone first so the caller's tensor is not mutated.
        input_ids = input_ids.clone()
        input_ids[input_ids == 49152] = 0
        text_emb = self.base.get_input_embeddings()(input_ids)

        # Embed each scalar value with the numeric MLP.
        flat_vals = number_values_filled.view(-1, 1)
        mlp_out = self.num_mlp(flat_vals)
        mlp_out = mlp_out.view_as(text_emb)

        # Numeric positions take the MLP embedding; all others keep the token embedding.
        mask = is_number_mask.unsqueeze(-1)
        inputs_embeds = torch.where(mask, mlp_out, text_emb)

        outputs = self.base(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            return_dict=True,
        )
        hidden = outputs.last_hidden_state

        # One scalar prediction per token position.
        logits = self.head(hidden).squeeze(-1)
        return logits
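

# --- Usage sketch (illustrative assumption, not from the original source). ---
# Numeric tokens are marked with the sentinel id 49152 in input_ids and carry
# their float values in number_values_filled; all other positions hold
# ordinary token ids.
if __name__ == "__main__":
    model = ScalingLawForecaster(init_from_pretrained=False, force_fp32=True)

    batch, seq_len = 2, 8
    input_ids = torch.randint(0, 1000, (batch, seq_len))
    is_number_mask = torch.zeros(batch, seq_len, dtype=torch.bool)
    number_values_filled = torch.zeros(batch, seq_len)

    # Pretend position 3 of each sequence is a numeric token (e.g. a loss value).
    is_number_mask[:, 3] = True
    input_ids[:, 3] = 49152
    number_values_filled[:, 3] = torch.tensor([2.31, 1.87])

    attention_mask = torch.ones(batch, seq_len, dtype=torch.bool)
    preds = model(input_ids, is_number_mask, number_values_filled, attention_mask)
    print(preds.shape)  # torch.Size([2, 8]): one scalar prediction per token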