flopml committed on
Commit 2fc11ed · 1 Parent(s): 897c036

adding sources

Files changed (4)
  1. model.py +162 -0
  2. requirements.txt +3 -0
  3. train.py +113 -0
  4. util.py +45 -0
model.py ADDED
@@ -0,0 +1,162 @@
+ #@title Architecture implementation
+ # TODO: comment and rename variables / clean code
+
+
+ # https://arxiv.org/abs/2410.01201v1
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+
+ # appendix B
+ # https://github.com/glassroom/heinsen_sequence
+
+ def heinsen_associative_scan_log(log_coeffs, log_values):
+     a_star = log_coeffs.cumsum(dim = 1)
+     log_h0_plus_b_star = (log_values - a_star).logcumsumexp(dim = 1)
+     log_h = a_star + log_h0_plus_b_star
+     return log_h.exp()
+
+ # appendix B.3
+
+ def g(x): return torch.where(x >= 0, x + 0.5, x.sigmoid())
+ def log_g(x): return torch.where(x >= 0, (F.relu(x) + 0.5).log(), -F.softplus(-x))
+
+ # log-space version of minGRU - B.3.1
+ # they enforce the hidden states to be positive
+
+ class minGRU(nn.Module):
+     def __init__(self, d_model, d_inner):
+         super().__init__()
+
+         self.hidden_proj = nn.Linear(d_model, d_inner, bias=False)
+         self.gate_proj = nn.Linear(d_model, d_inner, bias=False)
+         self.out_proj = nn.Linear(d_inner, d_model, bias=False)
+
+
+     def step(self, x, h_prev=None):
+         hidden = self.hidden_proj(x)
+         gate = self.gate_proj(x)
+
+         h_prev = h_prev.detach() if h_prev is not None else None
+
+         hidden = g(hidden)
+         gate = gate.sigmoid()
+         out = torch.lerp(h_prev, hidden, gate) if h_prev is not None else (hidden * gate)
+
+         h_next = out[:, -1:]
+         out = self.out_proj(out)
+
+         return out, h_next
+
+
+     def forward(self, x, h_prev=None):
+         seq_len = x.shape[1]
+         hidden = self.hidden_proj(x)
+         gate = self.gate_proj(x)
+
+         h_prev = h_prev.detach() if h_prev is not None else None
+
+         log_coeffs = -F.softplus(gate)
+         log_z = -F.softplus(-gate)
+         log_tilde_h = log_g(hidden)
+         log_values = log_z + log_tilde_h
+
+         if h_prev is not None:
+             log_values = torch.cat((h_prev.log(), log_values), dim=1)
+             log_coeffs = F.pad(log_coeffs, (0, 0, 1, 0))
+
+         out = heinsen_associative_scan_log(log_coeffs, log_values)
+         out = out[:, -seq_len:]
+
+         h_next = out[:, -1:]
+         out = self.out_proj(out)
+
+         return out, h_next
+
+
+
+
+
+
+ class RMSNorm(nn.Module):
+     def __init__(self, d_model: int, eps: float=1e-5):
+         super().__init__()
+         self.eps = eps
+         self.weight = nn.Parameter(torch.ones(d_model))
+
+     def _norm(self, x):
+         return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
+
+     def forward(self, x):
+         output = self._norm(x.float()).type_as(x)
+         return output * self.weight
+
+
+
+
+
+ class minGRULM(nn.Module):
+     def __init__(self, vocab_size, d_model, d_inner, n_layers):
+         super().__init__()
+         self.embed = nn.Embedding(vocab_size, d_model)
+
+         self.layers = nn.ModuleList([])
+         for _ in range(n_layers):
+             self.layers.append(nn.ModuleList([
+                 RMSNorm(d_model),
+                 minGRU(d_model, d_inner)
+             ]))
+
+         self.norm_f = RMSNorm(d_model)
+         self.lm_head = nn.Linear(d_model, vocab_size, bias = False)
+
+
+
+     # A single step of the LM: takes one token and returns logits for the next token
+     def step(self, x, h_states=None):
+         x = self.embed(x)
+
+         h_next = []
+         h_states = iter(h_states if h_states is not None else [])
+
+         for norm, mingru in self.layers:
+             h_prev = next(h_states, None)
+             residual = x
+
+             x = norm(x)
+             x, h_t = mingru.step(x, h_prev)
+             x = x + residual
+
+             h_next.append(h_t)
+
+         x = self.norm_f(x)
+         logits = self.lm_head(x)
+
+         return logits, h_next
+
+
+
+     def forward(self, x, h_states=None):
+         x, labels = x[:, :-1], x[:, 1:]
+         x = self.embed(x)
+
+         h_next = []
+         h_states = iter(h_states if h_states is not None else [])
+
+         for norm, mingru in self.layers:
+             h_prev = next(h_states, None)
+             residual = x
+
+             x = norm(x)
+             x, h_t = mingru.forward(x, h_prev)
+             x = x + residual
+
+             h_next.append(h_t)
+
+         x = self.norm_f(x)
+         logits = self.lm_head(x)
+         loss = F.cross_entropy(logits.transpose(1, 2), labels)
+
+         return logits, h_next, loss
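
A quick equivalence check for this file: the parallel log-space forward() and the sequential step() implement the same recurrence h_t = (1 - z_t) * h_{t-1} + z_t * g(h~_t), so they should agree on random input up to numerical error. The sketch below is not part of the commit and assumes model.py is importable from the working directory.

# Sketch (assumes model.py is on the path): compare the parallel scan against the sequential step.
import torch
from model import minGRU

torch.manual_seed(0)
m = minGRU(d_model=8, d_inner=16)
x = torch.randn(1, 5, 8)

out_parallel, _ = m.forward(x)            # log-space associative scan over the whole sequence

h = None
outs = []
for t in range(x.shape[1]):               # one token at a time, carrying the hidden state
    o, h = m.step(x[:, t:t+1], h)
    outs.append(o)
out_sequential = torch.cat(outs, dim=1)

print(torch.allclose(out_parallel, out_sequential, atol=1e-5))  # expected: True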
requirements.txt ADDED
@@ -0,0 +1,3 @@
+ transformers>=4.44.2
+ datasets>=3.0.1
+ wandb>=0.18.3
train.py ADDED
@@ -0,0 +1,113 @@
+ #@title Training script
+
+ import torch
+ import math
+ from transformers import GPT2Tokenizer
+ from datasets import load_dataset
+ import numpy as np
+
+ from model import minGRULM
+ from util import generate_text
+
+
+
+
+
+
+
+ dataset_path = 'flpelerin/tinystories-100k'
+
+ num_epochs = 1
+ batch_size = 4
+ seq_length = 256
+ learning_rate = 1e-4
+ infer_step = 50
+
+ input_len = 50
+ num_predict = 250
+
+ reset_state_every = 16
+
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ print(f"total context size is {batch_size * seq_length} tokens")
+
+
+
+
+ tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
+ tokenizer.pad_token = tokenizer.eos_token
+ vocab_size = tokenizer.vocab_size
+ print(f"tokenizer has {vocab_size} unique tokens")
+
+
+
+
+ dataset = load_dataset(dataset_path)
+
+ def process_function(examples):
+     return tokenizer(examples['text'], padding='longest', truncation=True)
+
+ tokenized_datasets = dataset.map(process_function, batched=True)
+ print(f"dataset has {tokenized_datasets['train'].num_rows} rows of {batch_size} times {seq_length} tokens")
+
+
+
+
+ #model = minGRULM(
+ #    vocab_size = vocab_size,
+ #    d_model = 768,
+ #    d_inner = 1536,
+ #    n_layers = 12
+ #)
+
+ model = minGRULM(
+     vocab_size = vocab_size,
+     d_model = 384,
+     d_inner = 768,
+     n_layers = 6
+ )
+
+
+ model.to(device)
+ print(f"model has {sum(p.numel() for p in model.parameters()):,} parameters")
+
+
+
+
+ optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
+
+ h_states = None
+
+ step = 0
+ for epoch in range(num_epochs):
+     for i in range(0, len(tokenized_datasets['train']), batch_size):
+         batch = tokenized_datasets['train'][i:i + batch_size]
+         input_ids = torch.tensor(batch['input_ids']).to(device)
+
+         #if (i / batch_size) % reset_state_every == 0:
+         #    print(f"resetting state, {(i / batch_size)} % {reset_state_every} == 0")
+         #    h_states = None
+
+         h_states = h_states if (i / batch_size) % reset_state_every != 0 else None
+         str_states = ''.join(['{:.3f}, '.format(h_states[0][0][0][i].item()) for i in range(10)]) if h_states is not None else 'None'
+
+         optimizer.zero_grad()
+         _, h_states, loss = model.forward(input_ids, h_states)
+         loss.backward()
+         optimizer.step()
+
+         step += 1
+         print(f"Epoch: {epoch} / {num_epochs}, Step: {step}, Loss: {loss.item():.4f}, Hidden State: {str_states}")
+
+         if step % infer_step == 0:
+             model.eval()
+
+             ids = input_ids[0][:input_len]
+             text = tokenizer.decode(ids)
+             print(f"input: {text}")
+
+             prompt = ids[None, ...]
+             text = generate_text(model, tokenizer, prompt, num_predict)
+             print(f"output: {text}")
+
+             model.train()
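
Since the loop above depends on the Hugging Face dataset, the model/optimizer interface can be smoke-tested on random token ids first. The sketch below is not part of the commit; it assumes model.py is importable and uses GPT-2's 50257-token vocabulary with smaller layer sizes than the script above.

# Sketch: one optimizer step on random token ids, mirroring the loop body in train.py.
import torch
from model import minGRULM

model = minGRULM(vocab_size=50257, d_model=64, d_inner=128, n_layers=2)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

input_ids = torch.randint(0, 50257, (4, 257))  # 257 ids -> 256 input/label pairs after the shift in forward()
h_states = None

optimizer.zero_grad()
_, h_states, loss = model.forward(input_ids, h_states)  # each layer detaches any incoming carried state
loss.backward()
optimizer.step()

print(loss.item())        # scalar training loss
print(h_states[0].shape)  # per-layer carried state: torch.Size([4, 1, 128])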
util.py ADDED
@@ -0,0 +1,45 @@
+ import torch
+ import math
+
+
+
+ def log(t, eps = 1e-20):
+     return torch.log(t.clamp(min = eps))
+
+ def gumbel_noise(t):
+     noise = torch.zeros_like(t).uniform_(0, 1)
+     return -log(-log(noise))
+
+ def gumbel_sample(t, temperature = 1., dim = -1, keepdim = True):
+     return ((t / max(temperature, 1e-10)) + gumbel_noise(t)).argmax(dim = dim, keepdim = keepdim)
+
+ def top_k(logits, thres = 0.9):
+     k = math.ceil((1 - thres) * logits.shape[-1])
+     val, ind = torch.topk(logits, k)
+     probs = torch.full_like(logits, float('-inf'))
+     probs.scatter_(-1, ind, val)
+     return probs
+
+
+
+ def generate_text(model, tokenizer, prompt: torch.Tensor, seq_len: int):
+     prompt_seq_len = prompt.shape[-1]
+
+     h_states = None
+     logits = None
+     text = ""
+
+     for i in range(prompt_seq_len):
+         tok = prompt[:, i:i+1]  # (1, 1)
+         logits, h_states = model.step(tok, h_states)
+
+     for _ in range(seq_len):
+         logits = top_k(logits, thres=.9)
+         token = gumbel_sample(logits, temperature=.7, dim=-1)[0]
+
+         logits, h_states = model.step(token, h_states)
+
+         token = tokenizer.decode(token.item())
+         text += token
+
+     return text
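
To exercise the sampling path end to end (model.step, top_k, gumbel_sample) without training, generate_text can be driven by a freshly initialized model; the output will be gibberish, but the shapes and the hidden-state hand-off are checked. The sketch below is not part of the commit and assumes model.py and util.py are importable and that the GPT-2 tokenizer can be downloaded.

# Sketch: run the sampling utilities with a small untrained model (output is random text).
import torch
from transformers import GPT2Tokenizer

from model import minGRULM
from util import generate_text

tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
model = minGRULM(vocab_size=tokenizer.vocab_size, d_model=64, d_inner=128, n_layers=2)
model.eval()

prompt = tokenizer("Once upon a time", return_tensors='pt').input_ids  # shape (1, prompt_len)
with torch.no_grad():
    text = generate_text(model, tokenizer, prompt, seq_len=20)
print(text)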