Butanium committed
Commit 510ebe9 · verified · 1 Parent(s): 6e0cced

Upload README.md with huggingface_hub

Files changed (1)
  1. README.md +24 -28
README.md CHANGED
@@ -1,35 +1,31 @@
- ---
- {}
- ---
- # Zero-Layer Simple Transformer
-
- A 0-layer transformer described in [A Mathematical Framework for Transformer Circuits](https://transformer-circuits.pub/2021/framework/index.html).
-
- ## Usage
-
  ```python
- from transformers import LlamaConfig
- from migrate_models import ZeroLayerTransformer
-
- # Load the model
- model = ZeroLayerTransformer.from_pretrained('Butanium/simple-stories-zero-layer-simple-transformer')
-
- # Or create from config
- config = LlamaConfig(vocab_size=4096, hidden_size=128, num_hidden_layers=0)
- model = ZeroLayerTransformer(config)
- ```
-
- ## Model Architecture
-
- This model consists of only:
- - Token embeddings
- - Linear output head (no transformer layers)
-
- It serves as a baseline for understanding transformer circuits and the importance of attention layers.
-
- ## Training Details
-
- - Trained on SimpleStories dataset
- - Vocabulary size: 4096
- - Hidden size: 128
- - No transformer layers (0-layer architecture)
+ A 0-layer transformer described in [A Mathematical Framework for Transformer Circuits](https://transformer-circuits.pub/2021/framework/index.html).
+ Load it with:
  ```python
+ import torch.nn as nn
+ from transformers import LlamaConfig, PreTrainedModel
+
+
+ class ZeroLayerTransformer(PreTrainedModel):
+     config_class = LlamaConfig
+
+     def __init__(self, config: LlamaConfig):
+         super().__init__(config)
+         # Only token embeddings and a linear output head -- no attention or MLP layers.
+         self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
+         self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+
+     def forward(self, input_ids=None, attention_mask=None, labels=None, **kwargs):
+         hidden_states = self.embed_tokens(input_ids)
+         logits = self.lm_head(hidden_states)
+
+         loss = None
+         if labels is not None:
+             # Next-token prediction: shift logits and labels by one position.
+             shift_logits = logits[..., :-1, :].contiguous()
+             shift_labels = labels[..., 1:].contiguous()
+             loss_fct = nn.CrossEntropyLoss()
+             loss = loss_fct(
+                 shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)
+             )
+
+         return {"loss": loss, "logits": logits}
+
+
+ model = ZeroLayerTransformer.from_pretrained('Butanium/simple-stories-zero-layer-simple-transformer')
+ ```
+ The model is trained on the SimpleStories dataset.
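
A minimal usage sketch (not part of the committed README): it assumes the `ZeroLayerTransformer` class shown above is already defined in the session, and that the repository bundles a tokenizer, which the commit does not state.

```python
# Minimal sketch, not from the commit: load the checkpoint, predict one token,
# and read off the implied bigram table.
import torch
from transformers import AutoTokenizer

repo = "Butanium/simple-stories-zero-layer-simple-transformer"
model = ZeroLayerTransformer.from_pretrained(repo)  # class defined in the README above

# Assumption: the repo ships a tokenizer; otherwise use the tokenizer that was
# used for SimpleStories training.
tokenizer = AutoTokenizer.from_pretrained(repo)

inputs = tokenizer("Once upon a time", return_tensors="pt")
with torch.no_grad():
    out = model(input_ids=inputs["input_ids"])

# Greedy next-token prediction from the last position.
next_id = out["logits"][0, -1].argmax().item()
print(tokenizer.decode([next_id]))

# With no attention or MLP layers, each position's logits depend only on the
# current token, so the model is effectively a bigram table:
# bigram_logits[i, j] = logit of next token j given current token i.
bigram_logits = model.embed_tokens.weight @ model.lm_head.weight.T  # (vocab, vocab)
```

The bigram reading follows the linked Transformer Circuits paper: with zero layers, the output is just the unembedding applied directly to the token embedding, so the best the model can do is approximate bigram statistics.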