Keeby-smilyai committed · Commit 230b53a · verified · Parent: 3f68529

Create app.py

Files changed (1): app.py (+507 −0)
app.py ADDED
# ==============================================================================
# HuggingFace Space - Sam Model Chat Interface with Streaming
# ==============================================================================
# Loads the model directly from the HuggingFace Hub: Smilyai-labs/Sam-1-large
# ==============================================================================

import gradio as gr
import tensorflow as tf
import keras
import numpy as np
from tokenizers import Tokenizer
from huggingface_hub import hf_hub_download

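# Stack: Gradio serves the chat UI, TensorFlow/Keras runs the model,
# `tokenizers` provides the BPE tokenizer, and `huggingface_hub` downloads
# both artifacts from the Hub at startup.
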
# ==============================================================================
# Model Configuration
# ==============================================================================

MODEL_REPO = "Smilyai-labs/Sam-1-large"  # Your HuggingFace model repo
MAX_NEW_TOKENS = 512
TEMPERATURE = 0.8
TOP_P = 0.9
TOP_K = 50

# ==============================================================================
# Custom Keras Layers (must match the training code)
# ==============================================================================

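# RotaryEmbedding implements rotary position embeddings (RoPE): rather than
# adding position vectors to the input, each half-split pair of head dimensions
# in q and k is rotated by a position-dependent angle, so attention logits
# depend only on relative positions. For position m and frequency theta_i:
#   q' = q * cos(m * theta_i) + rotate_half(q) * sin(m * theta_i)
# and identically for k. The cos/sin tables are precomputed up to max_len.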
@keras.saving.register_keras_serializable()
class RotaryEmbedding(keras.layers.Layer):
    def __init__(self, dim, max_len=2048, theta=10000, **kwargs):
        super().__init__(**kwargs)
        self.dim = dim
        self.max_len = max_len
        self.theta = theta
        self.built_cache = False

    def build(self, input_shape):
        if not self.built_cache:
            inv_freq = 1.0 / (self.theta ** (tf.range(0, self.dim, 2, dtype=tf.float32) / self.dim))
            t = tf.range(self.max_len, dtype=tf.float32)
            freqs = tf.einsum("i,j->ij", t, inv_freq)
            emb = tf.concat([freqs, freqs], axis=-1)

            self.cos_cached = tf.constant(tf.cos(emb), dtype=tf.float32)
            self.sin_cached = tf.constant(tf.sin(emb), dtype=tf.float32)
            self.built_cache = True

        super().build(input_shape)

    def rotate_half(self, x):
        x1, x2 = tf.split(x, 2, axis=-1)
        return tf.concat([-x2, x1], axis=-1)

    def call(self, q, k):
        seq_len = tf.shape(q)[2]
        dtype = q.dtype
        cos = tf.cast(self.cos_cached[:seq_len, :], dtype)[None, None, :, :]
        sin = tf.cast(self.sin_cached[:seq_len, :], dtype)[None, None, :, :]

        q_rotated = (q * cos) + (self.rotate_half(q) * sin)
        k_rotated = (k * cos) + (self.rotate_half(k) * sin)

        return q_rotated, k_rotated

    def get_config(self):
        config = super().get_config()
        config.update({"dim": self.dim, "max_len": self.max_len, "theta": self.theta})
        return config

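# RMSNorm normalizes by the root-mean-square of the features only (no mean
# subtraction, no bias), then applies a learned per-feature scale:
#   y = x * rsqrt(mean(x^2) + eps) * scale
# It is cheaper than LayerNorm and standard in Llama-style decoders.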
@keras.saving.register_keras_serializable()
class RMSNorm(keras.layers.Layer):
    def __init__(self, epsilon=1e-5, **kwargs):
        super().__init__(**kwargs)
        self.epsilon = epsilon

    def build(self, input_shape):
        self.scale = self.add_weight(name="scale", shape=(input_shape[-1],), initializer="ones")

    def call(self, x):
        variance = tf.reduce_mean(tf.square(x), axis=-1, keepdims=True)
        return x * tf.math.rsqrt(variance + self.epsilon) * self.scale

    def get_config(self):
        config = super().get_config()
        config.update({"epsilon": self.epsilon})
        return config

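# TransformerBlock is a pre-norm decoder block. Attention path:
#   RMSNorm -> multi-head causal self-attention with RoPE -> residual add.
# Feed-forward path:
#   RMSNorm -> SwiGLU MLP, down_proj(silu(gate_proj(y)) * up_proj(y)) -> residual add.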
@keras.saving.register_keras_serializable()
class TransformerBlock(keras.layers.Layer):
    def __init__(self, d_model, n_heads, ff_dim, dropout, max_len, rope_theta, layer_idx=0, **kwargs):
        super().__init__(**kwargs)
        self.d_model = d_model
        self.n_heads = n_heads
        self.ff_dim = ff_dim
        self.dropout_rate = dropout
        self.max_len = max_len
        self.rope_theta = rope_theta
        self.head_dim = d_model // n_heads
        self.layer_idx = layer_idx

        self.pre_attn_norm = RMSNorm()
        self.pre_ffn_norm = RMSNorm()

        self.q_proj = keras.layers.Dense(d_model, use_bias=False, name="q_proj")
        self.k_proj = keras.layers.Dense(d_model, use_bias=False, name="k_proj")
        self.v_proj = keras.layers.Dense(d_model, use_bias=False, name="v_proj")
        self.out_proj = keras.layers.Dense(d_model, use_bias=False, name="o_proj")

        self.rope = RotaryEmbedding(self.head_dim, max_len=max_len, theta=rope_theta)

        self.gate_proj = keras.layers.Dense(ff_dim, use_bias=False, name="gate_proj")
        self.up_proj = keras.layers.Dense(ff_dim, use_bias=False, name="up_proj")
        self.down_proj = keras.layers.Dense(d_model, use_bias=False, name="down_proj")

        self.dropout = keras.layers.Dropout(dropout)

    def call(self, x, training=None):
        B, T, D = tf.shape(x)[0], tf.shape(x)[1], self.d_model
        dtype = x.dtype

        res = x
        y = self.pre_attn_norm(x)

        q = tf.transpose(tf.reshape(self.q_proj(y), [B, T, self.n_heads, self.head_dim]), [0, 2, 1, 3])
        k = tf.transpose(tf.reshape(self.k_proj(y), [B, T, self.n_heads, self.head_dim]), [0, 2, 1, 3])
        v = tf.transpose(tf.reshape(self.v_proj(y), [B, T, self.n_heads, self.head_dim]), [0, 2, 1, 3])

        q, k = self.rope(q, k)

        scores = tf.matmul(q, k, transpose_b=True) / tf.sqrt(tf.cast(self.head_dim, dtype))

        mask = tf.where(
            tf.linalg.band_part(tf.ones([T, T], dtype=dtype), -1, 0) == 0,
            tf.constant(-1e9, dtype=dtype),
            tf.constant(0.0, dtype=dtype)
        )
        scores += mask
        attn = tf.matmul(tf.nn.softmax(scores, axis=-1), v)

        attn = tf.reshape(tf.transpose(attn, [0, 2, 1, 3]), [B, T, D])
        x = res + self.dropout(self.out_proj(attn), training=training)

        res = x
        y = self.pre_ffn_norm(x)
        ffn = self.down_proj(keras.activations.silu(self.gate_proj(y)) * self.up_proj(y))

        return res + self.dropout(ffn, training=training)

    def get_config(self):
        config = super().get_config()
        config.update({
            "d_model": self.d_model,
            "n_heads": self.n_heads,
            "ff_dim": self.ff_dim,
            "dropout": self.dropout_rate,
            "max_len": self.max_len,
            "rope_theta": self.rope_theta,
            "layer_idx": self.layer_idx
        })
        return config

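# SAM1Model: token embedding -> n_layers TransformerBlocks -> final RMSNorm ->
# an untied lm_head Dense back to vocab logits. The constructor accepts its
# hyperparameters either directly or under a 'config' key, so Keras can rebuild
# the model from get_config() when deserializing.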
@keras.saving.register_keras_serializable()
class SAM1Model(keras.Model):
    def __init__(self, **kwargs):
        super().__init__()
        if 'config' in kwargs and isinstance(kwargs['config'], dict):
            self.cfg = kwargs['config']
        elif 'vocab_size' in kwargs:
            self.cfg = kwargs
        else:
            self.cfg = kwargs.get('cfg', kwargs)

        self.embed = keras.layers.Embedding(self.cfg['vocab_size'], self.cfg['d_model'], name="embed_tokens")

        ff_dim = int(self.cfg['d_model'] * self.cfg['ff_mult'])
        block_args = {
            'd_model': self.cfg['d_model'],
            'n_heads': self.cfg['n_heads'],
            'ff_dim': ff_dim,
            'dropout': self.cfg['dropout'],
            'max_len': self.cfg['max_len'],
            'rope_theta': self.cfg['rope_theta']
        }

        self.blocks = []
        for i in range(self.cfg['n_layers']):
            block = TransformerBlock(name=f"block_{i}", layer_idx=i, **block_args)
            self.blocks.append(block)

        self.norm = RMSNorm(name="final_norm")
        self.lm_head = keras.layers.Dense(self.cfg['vocab_size'], use_bias=False, name="lm_head")

    def call(self, input_ids, training=None):
        x = self.embed(input_ids)

        for block in self.blocks:
            x = block(x, training=training)

        return self.lm_head(self.norm(x))

    def get_config(self):
        base_config = super().get_config()
        base_config['config'] = self.cfg
        return base_config

# ==============================================================================
# Load Model and Tokenizer from HuggingFace Hub
# ==============================================================================

print("🔥 Loading Sam model from HuggingFace Hub...")
print(f"   Repository: {MODEL_REPO}")

try:
    # Download model file
    print("📥 Downloading model.keras...")
    model_path = hf_hub_download(
        repo_id=MODEL_REPO,
        filename="model.keras",
        cache_dir="./model_cache"
    )
    print(f"✅ Model downloaded to: {model_path}")

    # Download tokenizer
    print("📥 Downloading tokenizer.json...")
    tokenizer_path = hf_hub_download(
        repo_id=MODEL_REPO,
        filename="tokenizer.json",
        cache_dir="./model_cache"
    )
    print(f"✅ Tokenizer downloaded to: {tokenizer_path}")

    # Load tokenizer
    tokenizer = Tokenizer.from_file(tokenizer_path)
    eos_token = "<|endoftext|>"
    eos_token_id = tokenizer.token_to_id(eos_token)
    print(f"✅ Tokenizer loaded (vocab_size={tokenizer.get_vocab_size()})")

    # Load model
    print("🔄 Loading model into memory...")
    model = keras.models.load_model(model_path)
    print("✅ Model loaded successfully!")

except Exception as e:
    print(f"❌ Error loading model: {e}")
    print("\n💡 Troubleshooting:")
    print("1. Make sure the model repo exists: https://huggingface.co/Smilyai-labs/Sam-1-large")
    print("2. Check that model.keras and tokenizer.json are in the repo")
    print("3. If the repo is private, log in first: huggingface-cli login")
    raise

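# The custom classes above deserialize without a custom_objects argument
# because each is registered with @keras.saving.register_keras_serializable()
# and exposes get_config().
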
# ==============================================================================
# Generation Functions
# ==============================================================================

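# Decoding pipeline: logits are divided by `temperature`, restricted to the
# `top_k` highest-scoring tokens, then restricted further to the smallest set
# whose cumulative probability exceeds `top_p` (nucleus sampling), and finally
# one token is drawn from the renormalized distribution.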
def sample_token(logits, temperature=1.0, top_p=0.9, top_k=50):
    """Sample the next token with temperature, top-k, and top-p filtering."""
    logits = tf.cast(logits, tf.float32) / temperature

    # Top-k filtering: mask out everything below the k-th best logit
    if top_k > 0:
        k = min(top_k, int(logits.shape[-1]))
        kth_best = tf.nn.top_k(logits, k=k).values[-1]
        logits = tf.where(logits < kth_best, -1e10, logits)

    # Top-p (nucleus) filtering
    if top_p < 1.0:
        sorted_logits = tf.sort(logits, direction='DESCENDING')
        sorted_probs = tf.nn.softmax(sorted_logits)
        cumsum_probs = tf.cumsum(sorted_probs)

        # Mark tokens past the nucleus, shifted right so the first token
        # that crosses the threshold is still kept
        sorted_indices_to_remove = cumsum_probs > top_p
        sorted_indices_to_remove = tf.concat([
            [False],
            sorted_indices_to_remove[:-1]
        ], axis=0)

        # Scatter the sorted-order mask back to vocabulary order
        sorted_indices = tf.argsort(logits, direction='DESCENDING')
        indices_to_remove = tf.gather(sorted_indices_to_remove, tf.argsort(sorted_indices))

        logits = tf.where(indices_to_remove, -1e10, logits)

    # Sample from the filtered distribution (categorical expects logits)
    next_token = tf.random.categorical(logits[None, :], num_samples=1)[0, 0]

    return int(next_token.numpy())

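# NOTE: generation below re-runs the full forward pass over the whole sequence
# for every new token (there is no KV cache). That is quadratic overall but
# keeps the code simple, which is acceptable for a ~100M-parameter model at a
# 1024-token context.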
def generate_stream(formatted_prompt, max_new_tokens=512, temperature=0.8, top_p=0.9, top_k=50):
    """Generate text with streaming, yielding the partial response as it grows.

    `formatted_prompt` must already follow the model's turn format and end
    with "Sam:" so the model continues as Sam.
    """
    # Tokenize
    encoding = tokenizer.encode(formatted_prompt)
    input_ids = np.array([encoding.ids], dtype=np.int32)

    # Reject prompts that leave no room for a response
    if input_ids.shape[1] > 1000:
        yield "❌ Error: Prompt is too long (max 1000 tokens)"
        return

    generated_text = ""

    for _ in range(max_new_tokens):
        # Forward pass; keep only the logits for the last position
        logits = model(input_ids, training=False)
        next_token_logits = logits[0, -1, :].numpy()

        # Sample next token
        next_token = sample_token(next_token_logits, temperature, top_p, top_k)

        # Stop if EOS
        if next_token == eos_token_id:
            break

        # Decode and stream the partial response
        generated_text += tokenizer.decode([next_token])
        yield generated_text

        # Append to input
        input_ids = np.concatenate([input_ids, [[next_token]]], axis=1)

        # Stop at the model's 1024-token context limit
        if input_ids.shape[1] >= 1024:
            break

def chat_interface(message, history, temperature, top_p, top_k, max_tokens):
    """Gradio chat handler: streams the reply into the chat history."""

    if not message.strip():
        yield history
        return

    # Build conversation context from history (last 3 turns to save tokens)
    conversation = ""
    for user_msg, bot_msg in history[-3:]:
        conversation += f"User: {user_msg}\nSam: {bot_msg}\n"

    # Wrap only the current message in the User:/Sam: template
    full_prompt = conversation + f"User: {message}\nSam:"

    # Stream the growing response into a fresh history entry
    history = history + [[message, ""]]
    for response_chunk in generate_stream(
        full_prompt,
        max_new_tokens=max_tokens,
        temperature=temperature,
        top_p=top_p,
        top_k=top_k
    ):
        history[-1][1] = response_chunk
        yield history

# ==============================================================================
# Gradio Interface
# ==============================================================================

with gr.Blocks(theme=gr.themes.Soft(), title="Chat with Sam") as demo:
    gr.Markdown("""
    # 🤖 Chat with Sam

    **Sam** is a fine-tuned language model trained on math, code, reasoning, and conversational data.

    ### ✨ Capabilities:
    - 🧮 **Math**: Solve arithmetic and word problems (trained on GSM8K)
    - 💻 **Code**: Write Python, JavaScript, and more (trained on CodeAlpaca)
    - 🤔 **Reasoning**: Show step-by-step thinking with `<think>` tags
    - 💬 **Chat**: Natural conversations on any topic

    ### 📊 Model Info:
    - **Architecture**: 768d, 16 layers, 12 heads (~100M parameters)
    - **Context**: 1024 tokens
    - **Training**: TPU v5e-8 on a multi-dataset mix
    """)

    chatbot = gr.Chatbot(
        label="💬 Conversation",
        height=450,
        show_copy_button=True,
        avatar_images=(None, "🤖"),
    )

    with gr.Row():
        msg = gr.Textbox(
            label="Your message",
            placeholder="Ask Sam anything... (e.g., 'What is 127 * 43?' or 'Write a function to sort a list')",
            lines=2,
            scale=4,
            autofocus=True
        )
        submit = gr.Button("Send 🚀", scale=1, variant="primary")

    with gr.Accordion("⚙️ Generation Settings", open=False):
        with gr.Row():
            temperature = gr.Slider(
                minimum=0.1,
                maximum=2.0,
                value=TEMPERATURE,
                step=0.1,
                label="Temperature",
                info="Higher = more creative/random"
            )
            top_p = gr.Slider(
                minimum=0.1,
                maximum=1.0,
                value=TOP_P,
                step=0.05,
                label="Top-p",
                info="Nucleus sampling threshold"
            )
        with gr.Row():
            top_k = gr.Slider(
                minimum=1,
                maximum=100,
                value=TOP_K,
                step=1,
                label="Top-k",
                info="Sample only from the k most likely tokens"
            )
            max_tokens = gr.Slider(
                minimum=50,
                maximum=512,
                value=MAX_NEW_TOKENS,
                step=50,
                label="Max tokens",
                info="Maximum response length"
            )

    with gr.Row():
        clear = gr.Button("🗑️ Clear Chat")

    with gr.Accordion("💡 Example Prompts", open=False):
        gr.Examples(
            examples=[
                ["What is 127 * 43?"],
                ["Write a Python function to reverse a string"],
                ["Explain how photosynthesis works"],
                ["What's the capital of France?"],
                ["Write a haiku about coding"],
                ["How do I sort a list in Python?"],
            ],
            inputs=msg,
            label="Click to try:"
        )

    gr.Markdown("""
    ---
    ### 📝 Tips:
    - Sam uses the conversational format `User: ... Sam: ...`
    - Watch for `<think>` tags showing the reasoning process
    - Raise the temperature for more creative responses, lower it for more focused ones
    - The model remembers the last 3 conversation turns for context

    ### 🔗 Links:
    - Model: [Smilyai-labs/Sam-1-large](https://huggingface.co/Smilyai-labs/Sam-1-large)
    - Training: TPU v5e-8 on Kaggle
    - Framework: TensorFlow/Keras
    """)

    # Event handlers: stream the updated history into the chatbot, then clear
    # the textbox
    msg.submit(
        chat_interface,
        inputs=[msg, chatbot, temperature, top_p, top_k, max_tokens],
        outputs=chatbot,
    ).then(
        lambda: gr.update(value=""),
        None,
        msg
    )

    submit.click(
        chat_interface,
        inputs=[msg, chatbot, temperature, top_p, top_k, max_tokens],
        outputs=chatbot,
    ).then(
        lambda: gr.update(value=""),
        None,
        msg
    )

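    # Note: Gradio streams a generator's yields only when the queue is enabled
    # (demo.queue() below); each yielded history update repaints the chatbot.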
    clear.click(lambda: None, None, chatbot, queue=False)

# Launch
if __name__ == "__main__":
    print("\n" + "=" * 70)
    print("🚀 STARTING SAM CHAT INTERFACE".center(70))
    print("=" * 70)
    print(f"\n✅ Model loaded from: {MODEL_REPO}")
    print(f"✅ Vocab size: {tokenizer.get_vocab_size()}")
    print("✅ Ready to chat!\n")

    demo.queue()  # Enable streaming
    demo.launch()