nitinvig committed
Commit 89dca07 · verified · 1 Parent(s): 67ece0a

Upload 4 files

Files changed (4)
  1. app.py +126 -0
  2. model.pt +3 -0
  3. model.py +311 -0
  4. requirements.txt +4 -0
app.py ADDED
@@ -0,0 +1,126 @@
+ import torch
+ import gradio as gr
+ from transformers import AutoTokenizer
+ from model import SmolLMForCausalLM, SmolLMConfig
+ import os
+
+ # 1. Configuration constants
+ MODEL_CHECKPOINT = "model.pt"  # Expects the model weights to be in this file
+ TOKENIZER_ID = "HuggingFaceTB/SmolLM-135M"  # Using the standard tokenizer
+ DEVICE = "cpu"  # The HF Spaces free tier is usually CPU-only. Change to 'cuda' if a GPU is available.
+
+ # 2. Load Model and Tokenizer
+ print("Loading tokenizer...")
+ tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_ID)
+
+ print("Initializing model...")
+ config = SmolLMConfig()
+ model = SmolLMForCausalLM(config)
+
+ # 3. Load Weights
+ if os.path.exists(MODEL_CHECKPOINT):
+     print(f"Loading weights from {MODEL_CHECKPOINT}...")
+     try:
+         # Map location to CPU to be safe
+         checkpoint = torch.load(MODEL_CHECKPOINT, map_location=torch.device('cpu'))
+
+         # Check if it's a full checkpoint (dict with 'model_state_dict') or just weights
+         if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
+             state_dict = checkpoint['model_state_dict']
+         else:
+             state_dict = checkpoint
+
+         # Strip the '_orig_mod.' prefix added when the model was saved from a torch.compile'd module
+         new_state_dict = {}
+         for k, v in state_dict.items():
+             if k.startswith("_orig_mod."):
+                 new_state_dict[k[len("_orig_mod."):]] = v
+             else:
+                 new_state_dict[k] = v
+
+         model.load_state_dict(new_state_dict)
+         print("Weights loaded successfully.")
+     except Exception as e:
+         print(f"Error loading weights: {e}")
+         print("Running with initialized (random) weights for demonstration.")
+ else:
+     print(f"Warning: {MODEL_CHECKPOINT} not found! Running with random weights.")
+
+ model.to(DEVICE)
+ model.eval()
+
+ # 4. Generation Function
+ def generate_text(prompt, max_new_tokens, temperature, top_k):
+     if not prompt:
+         return "Please enter a prompt."
+
+     input_ids = tokenizer.encode(prompt, return_tensors="pt").to(DEVICE)
+
+     # Text generation loop: a simple loop similar to the training script's
+     # generate function, with temperature and top-k sampling added for more
+     # variety in the demo.
+
+     curr_input_ids = input_ids
+
+     with torch.no_grad():
+         for _ in range(int(max_new_tokens)):
+             # Get logits for the last position
+             logits = model(curr_input_ids)
+             next_token_logits = logits[:, -1, :]
+
+             # Apply temperature
+             if temperature > 0:
+                 next_token_logits = next_token_logits / temperature
+             else:
+                 # Greedy decoding if temperature is 0: take the argmax token and skip sampling
+                 next_token_id = torch.argmax(next_token_logits, dim=-1).unsqueeze(0)
+                 curr_input_ids = torch.cat([curr_input_ids, next_token_id], dim=1)
+                 continue
+
+             # Apply top-k: keep only the k highest-scoring tokens
+             if top_k > 0:
+                 v, _ = torch.topk(next_token_logits, min(top_k, next_token_logits.size(-1)))
+                 next_token_logits[next_token_logits < v[:, [-1]]] = float('-inf')
+
+             probs = torch.nn.functional.softmax(next_token_logits, dim=-1)
+
+             # Sample the next token
+             next_token_id = torch.multinomial(probs, num_samples=1)
+             curr_input_ids = torch.cat([curr_input_ids, next_token_id], dim=1)
+
+             # Optional: stop if an EOS token is generated (if one was defined and used during training)
+             # if next_token_id == tokenizer.eos_token_id:
+             #     break
+
+     output_text = tokenizer.decode(curr_input_ids[0].tolist(), skip_special_tokens=True)
+     return output_text
+
+ # 5. Build Gradio Interface
+ with gr.Blocks() as demo:
+     gr.Markdown("# SmolLM-135M Implementation Demo")
+     gr.Markdown("This is a demo of the 135M-parameter transformer model trained from scratch.")
+
+     with gr.Row():
+         with gr.Column():
+             prompt_input = gr.Textbox(label="Prompt", placeholder="Once upon a time...", lines=3)
+             with gr.Row():
+                 max_tokens = gr.Slider(minimum=10, maximum=500, value=100, step=10, label="Max New Tokens")
+                 temperature = gr.Slider(minimum=0.1, maximum=2.0, value=0.8, step=0.1, label="Temperature")
+                 top_k = gr.Slider(minimum=1, maximum=100, value=40, step=1, label="Top-K")
+             generate_btn = gr.Button("Generate", variant="primary")
+
+         with gr.Column():
+             output = gr.Textbox(label="Generated Text", lines=10)
+
+     generate_btn.click(
+         fn=generate_text,
+         inputs=[prompt_input, max_tokens, temperature, top_k],
+         outputs=output
+     )
+
+     gr.Markdown("### Note on inputs")
+     gr.Markdown("Because this model is small (135M) and trained on a specific dataset, it may not follow instructions like ChatGPT. It is best at completing text/stories.")
+
+ if __name__ == "__main__":
+     demo.launch()
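
The decoding logic in generate_text above is simple enough to test in isolation. Below is a minimal sketch of just the temperature/top-k step; the helper name sample_next_token and the quick check at the bottom are illustrative additions, not part of the committed app.py.

import torch

def sample_next_token(next_token_logits, temperature=0.8, top_k=40):
    # next_token_logits: [batch, vocab_size] logits for the last position
    if temperature <= 0:
        # Greedy decoding: just take the highest-scoring token
        return torch.argmax(next_token_logits, dim=-1, keepdim=True)
    logits = next_token_logits / temperature
    if top_k > 0:
        v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
        logits[logits < v[:, [-1]]] = float("-inf")  # mask everything outside the top-k
    probs = torch.softmax(logits, dim=-1)
    return torch.multinomial(probs, num_samples=1)   # shape [batch, 1]

fake_logits = torch.randn(1, 49152)                  # vocab size from SmolLMConfig
print(sample_next_token(fake_logits).shape)          # torch.Size([1, 1])
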
model.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:aba04700adc3a6ad2a4a89943dc7507c4393b5c9bab5eea07a6f8615278951e0
+ size 538148429
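
model.pt is stored through Git LFS, so the three lines above are only a pointer: the actual ~538 MB checkpoint is identified by its SHA-256 OID and byte size. Below is a minimal sketch for verifying a downloaded copy against that pointer; the oid and size constants are copied from the pointer, while the script itself is an editor's illustration, not part of the commit.

import hashlib

EXPECTED_OID = "aba04700adc3a6ad2a4a89943dc7507c4393b5c9bab5eea07a6f8615278951e0"
EXPECTED_SIZE = 538148429  # bytes, from the LFS pointer

def verify_lfs_object(path="model.pt"):
    sha = hashlib.sha256()
    size = 0
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(1 << 20), b""):  # read in 1 MiB chunks
            sha.update(chunk)
            size += len(chunk)
    return sha.hexdigest() == EXPECTED_OID and size == EXPECTED_SIZE

print("checkpoint OK" if verify_lfs_object() else "checkpoint mismatch")
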
model.py ADDED
@@ -0,0 +1,311 @@
+ import math
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from dataclasses import dataclass
+ from typing import Optional, Tuple
+
+ @dataclass
+ class SmolLMConfig:
+     """
+     Configuration class for SmolLM.
+     This holds all the hyperparameters that define the model architecture.
+     """
+     vocab_size: int = 49152              # Size of the vocabulary (number of unique tokens)
+     hidden_size: int = 576               # Dimension of the embedding vectors
+     intermediate_size: int = 1536        # Dimension of the inner layer in the MLP
+     num_hidden_layers: int = 30          # Number of Transformer blocks (depth)
+     num_attention_heads: int = 9         # Number of query heads
+     num_key_value_heads: int = 3         # Number of key/value heads (GQA)
+     hidden_act: str = "silu"             # Activation function
+     max_position_embeddings: int = 2048  # Maximum sequence length
+     initializer_range: float = 0.02
+     rms_norm_eps: float = 1e-05
+     use_cache: bool = True
+     tie_word_embeddings: bool = True     # Share weights between input embedding and output layer
+     rope_theta: float = 10000.0
+
+     def __post_init__(self):
+         # Dimension per attention head
+         self.head_dim = self.hidden_size // self.num_attention_heads
+         # How many query heads share one key/value head (Grouped Query Attention)
+         self.num_key_value_groups = self.num_attention_heads // self.num_key_value_heads
+
+ class RMSNorm(nn.Module):
+     """
+     Root Mean Square Layer Normalization (RMSNorm).
+     A simpler version of LayerNorm that re-scales inputs based on their RMS.
+     It stabilizes training and is used in Llama-based models instead of standard LayerNorm.
+     """
+     def __init__(self, dim: int, eps: float = 1e-5):
+         super().__init__()
+         self.eps = eps
+         self.weight = nn.Parameter(torch.ones(dim))
+
+     def _norm(self, x):
+         # Divide x by its RMS: x / sqrt(mean(x^2) + eps)
+         return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
+
+     def forward(self, x):
+         # Normalize and then scale by a learnable parameter
+         output = self._norm(x.float()).type_as(x)
+         return output * self.weight
+
+ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
+     """
+     Applies Rotary Positional Embeddings (RoPE) to queries and keys.
+     RoPE rotates the query and key vectors to inject relative positional information.
+     """
+     # q, k: [bs, num_heads, seq_len, head_dim]
+     # cos, sin: [seq_len, head_dim]
+
+     # Rotate function: [x1, x2] -> [-x2, x1]
+     def rotate_half(x):
+         x1 = x[..., : x.shape[-1] // 2]
+         x2 = x[..., x.shape[-1] // 2 :]
+         return torch.cat((-x2, x1), dim=-1)
+
+     cos = cos.unsqueeze(0).unsqueeze(unsqueeze_dim)  # [1, 1, seq_len, head_dim]
+     sin = sin.unsqueeze(0).unsqueeze(unsqueeze_dim)
+
+     # Apply rotation: (x * cos) + (rotate_half(x) * sin)
+     q_embed = (q * cos) + (rotate_half(q) * sin)
+     k_embed = (k * cos) + (rotate_half(k) * sin)
+     return q_embed, k_embed
+
+ class LlamaRotaryEmbedding(nn.Module):
+     """
+     Pre-computes the cosine and sine values for RoPE.
+     These are fixed values based on position indices, used to modulate Q and K.
+     """
+     def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
+         super().__init__()
+         self.dim = dim
+         self.max_position_embeddings = max_position_embeddings
+         self.base = base
+         # Calculate inverse frequencies for the rotations
+         inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
+         self.register_buffer("inv_freq", inv_freq, persistent=False)
+         self._set_cos_sin_cache(max_position_embeddings, device=device)
+
+     def _set_cos_sin_cache(self, seq_len, device):
+         self.max_seq_len_cached = seq_len
+         t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
+         freqs = torch.outer(t, self.inv_freq)
+         # Unlike standard position embeddings, the frequencies are concatenated with themselves to cover both halves of head_dim
+         emb = torch.cat((freqs, freqs), dim=-1)
+         self.register_buffer("cos_cached", emb.cos().to(dtype=torch.float32), persistent=False)
+         self.register_buffer("sin_cached", emb.sin().to(dtype=torch.float32), persistent=False)
+
+     def forward(self, x, seq_len):
+         if seq_len > self.max_seq_len_cached:
+             self._set_cos_sin_cache(seq_len=seq_len, device=x.device)
+         return (
+             self.cos_cached[:seq_len].to(dtype=x.dtype),
+             self.sin_cached[:seq_len].to(dtype=x.dtype),
+         )
+
+ class LlamaMLP(nn.Module):
+     """
+     Feed-Forward Network (FFN) using the SwiGLU activation.
+     Structure:
+         x -> GateProj -> SiLU \
+                                -> Multiply -> DownProj -> output
+         x -> UpProj __________/
+     """
+     def __init__(self, config: SmolLMConfig):
+         super().__init__()
+         self.hidden_size = config.hidden_size
+         self.intermediate_size = config.intermediate_size
+         self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
+         self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
+         self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
+         self.act_fn = nn.SiLU()
+
+     def forward(self, x):
+         # SwiGLU: (SiLU(Gate(x)) * Up(x)) -> Down(x)
+         down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
+         return down_proj
+
+ class LlamaAttention(nn.Module):
+     """
+     Multi-Head Attention with Grouped Query Attention (GQA).
+     GQA uses fewer key/value heads than query heads, which shrinks the KV cache and saves memory during inference.
+     """
+     def __init__(self, config: SmolLMConfig):
+         super().__init__()
+         self.config = config
+         self.hidden_size = config.hidden_size
+         self.num_heads = config.num_attention_heads
+         self.head_dim = config.head_dim
+         self.num_key_value_heads = config.num_key_value_heads
+         self.num_key_value_groups = config.num_key_value_groups
+         self.max_position_embeddings = config.max_position_embeddings
+
+         self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
+         self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
+         self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
+         self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
+
+         self.rotary_emb = LlamaRotaryEmbedding(self.head_dim, max_position_embeddings=self.max_position_embeddings)
+
+     def forward(self, x, position_ids=None, attention_mask=None):
+         bsz, q_len, _ = x.size()
+
+         # 1. Project inputs to Q, K, V
+         q = self.q_proj(x)
+         k = self.k_proj(x)
+         v = self.v_proj(x)
+
+         # 2. Reshape for multi-head attention
+         q = q.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+         k = k.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+         v = v.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+
+         # 3. Apply Rotary Embeddings
+         cos, sin = self.rotary_emb(v, seq_len=q_len)
+         q, k = apply_rotary_pos_emb(q, k, cos, sin)
+
+         # 4. Handle GQA (Grouped Query Attention)
+         # With fewer KV heads than Q heads, K and V are repeated to match Q's head count
+         if self.num_key_value_groups > 1:
+             k = k[:, :, None, :, :].expand(bsz, self.num_key_value_heads, self.num_key_value_groups, q_len, self.head_dim).reshape(bsz, self.num_heads, q_len, self.head_dim)
+             v = v[:, :, None, :, :].expand(bsz, self.num_key_value_heads, self.num_key_value_groups, q_len, self.head_dim).reshape(bsz, self.num_heads, q_len, self.head_dim)
+
+         # 5. Scaled Dot Product Attention (Flash Attention / memory-efficient attention)
+         # PyTorch's optimized implementation selects the best backend (FlashAttention, etc.).
+         # A custom attention_mask falls back to the manual path below; plain causal masking uses is_causal=True.
+
+         # NOTE: F.scaled_dot_product_attention expects 4D input: [batch, heads, seq, head_dim].
+         # Our q, k, v are already in that format after the transpose.
+
+         # A mask that is NOT the causal mask (e.g. a padding mask) needs explicit handling,
+         # but for training a standard causal LM from scratch the causal mask is usually enough.
+
+         dropout_p = 0.0  # Could be added to the config if desired
+
+         if attention_mask is not None:
+             # Manual implementation when a custom mask is provided (rare for basic causal LM training)
+             attn_weights = torch.matmul(q, k.transpose(2, 3)) / math.sqrt(self.head_dim)
+             attn_weights = attn_weights + attention_mask
+             attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(q.dtype)
+             attn_output = torch.matmul(attn_weights, v)
+         else:
+             # Optimized path
+             attn_output = F.scaled_dot_product_attention(
+                 q, k, v,
+                 attn_mask=None,
+                 dropout_p=dropout_p,
+                 is_causal=True
+             )
+
+         # 6. Reshape back and apply the output projection
+         attn_output = attn_output.transpose(1, 2).contiguous().reshape(bsz, q_len, self.hidden_size)
+         attn_output = self.o_proj(attn_output)
+
+         return attn_output
+
+ class LlamaDecoderLayer(nn.Module):
+     """
+     A single Transformer block.
+     Consists of:
+       1. Pre-Norm -> Attention -> Add Residual
+       2. Pre-Norm -> MLP (Feed Forward) -> Add Residual
+     """
+     def __init__(self, config: SmolLMConfig):
+         super().__init__()
+         self.hidden_size = config.hidden_size
+         self.self_attn = LlamaAttention(config)
+         self.mlp = LlamaMLP(config)
+         self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+         self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+     def forward(self, x, position_ids=None, attention_mask=None):
+         residual = x
+         x = self.input_layernorm(x)
+         # Self-attention block
+         x = self.self_attn(x, position_ids=position_ids, attention_mask=attention_mask)
+         x = residual + x  # Residual connection
+
+         residual = x
+         x = self.post_attention_layernorm(x)
+         # MLP block
+         x = self.mlp(x)
+         x = residual + x  # Residual connection
+         return x
+
+ class SmolLMModel(nn.Module):
+     """
+     Main Transformer model (the "trunk").
+     Embeddings -> N x Decoder Layers -> Final Norm
+     """
+     def __init__(self, config: SmolLMConfig):
+         super().__init__()
+         self.config = config
+         self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
+         self.layers = nn.ModuleList([LlamaDecoderLayer(config) for _ in range(config.num_hidden_layers)])
+         self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+     def forward(self, input_ids):
+         # 1. Look up embeddings
+         x = self.embed_tokens(input_ids)
+
+         seq_len = x.shape[1]
+
+         # 2. Key concept: causal mask.
+         # The model predicts the NEXT token, so it must not see future tokens.
+         # With F.scaled_dot_product_attention(is_causal=True) no explicit mask is needed
+         # unless padding is involved, so None is passed and the optimized attention handles it.
+         mask = None
+         # mask = torch.full((seq_len, seq_len), float("-inf"), device=x.device)
+         # mask = torch.triu(mask, diagonal=1)
+
+         # 3. Pass through all Transformer layers
+         for layer in self.layers:
+             x = layer(x, attention_mask=mask)
+
+         x = self.norm(x)
+         return x
+
+ class SmolLMForCausalLM(nn.Module):
+     """
+     The full Causal Language Model.
+     Wraps the trunk (SmolLMModel) and adds the Language Model head (a Linear layer) to project hidden states to vocabulary logits.
+     """
+     def __init__(self, config: SmolLMConfig):
+         super().__init__()
+         self.model = SmolLMModel(config)
+         self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+
+         # Weight tying
+         if config.tie_word_embeddings:
+             self.lm_head.weight = self.model.embed_tokens.weight
+
+     def forward(self, input_ids):
+         x = self.model(input_ids)
+         logits = self.lm_head(x)
+         return logits
+
+ def test_model():
+     config = SmolLMConfig()
+     print(f"Initializing SmolLM-135M with config: {config}")
+
+     model = SmolLMForCausalLM(config)
+     print(f"Model keys: {len(model.state_dict())}")
+
+     # Test forward pass
+     dummy_input = torch.randint(0, config.vocab_size, (1, 32))  # Batch size 1, seq len 32
+     print(f"Running forward pass with input shape {dummy_input.shape}")
+
+     logits = model(dummy_input)
+     print(f"Output shape: {logits.shape}")  # Should be [1, 32, 49152]
+
+     assert logits.shape == (1, 32, config.vocab_size)
+     print("Test passed!")
+
+ if __name__ == "__main__":
+     test_model()
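
A quick way to confirm that this configuration really lands near the advertised 135M parameters is to count them. The sketch below is an editor's addition, not part of model.py; with tied input/output embeddings the total works out to roughly 134.5M.

from model import SmolLMForCausalLM, SmolLMConfig

config = SmolLMConfig()
model = SmolLMForCausalLM(config)

# nn.Module.parameters() yields each tensor once, so the tied
# lm_head/embedding weight is not double-counted.
total = sum(p.numel() for p in model.parameters())
print(f"Parameters: {total / 1e6:.1f}M")  # ~134.5M

# The same number from the config arithmetic:
emb = config.vocab_size * config.hidden_size            # token embeddings (shared with lm_head)
attn = (2 * config.hidden_size * config.hidden_size     # q_proj + o_proj
        + 2 * config.hidden_size * config.num_key_value_heads * config.head_dim)  # k_proj + v_proj
mlp = 3 * config.hidden_size * config.intermediate_size # gate, up, down
norms = 2 * config.hidden_size
per_layer = attn + mlp + norms
print(f"From config: {(emb + config.num_hidden_layers * per_layer + config.hidden_size) / 1e6:.1f}M")
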
requirements.txt ADDED
@@ -0,0 +1,4 @@
+ torch
+ transformers
+ gradio
+ numpy
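
The requirements above are unpinned, so a Space build pulls whatever versions pip resolves at the time. If reproducibility ever becomes an issue, a small startup check like the sketch below (an editor's illustration, not part of the commit) will at least log what actually got installed:

from importlib.metadata import version, PackageNotFoundError

for pkg in ("torch", "transformers", "gradio", "numpy"):
    try:
        print(f"{pkg}=={version(pkg)}")
    except PackageNotFoundError:
        print(f"{pkg}: not installed")
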