Keeby-smilyai committed
Commit 579190c · verified · 1 Parent(s): c842762

Update app.py

Files changed (1)
  1. app.py +152 -468
app.py CHANGED
@@ -11,10 +11,10 @@ import time
11
  # ============================================================================
12
  # 🎊 FESTIVE MODE TOGGLE 🎊
13
  # ============================================================================
14
- FESTIVE = True # Set to False for production-only mode
15
 
16
  # ============================================================================
17
- # Configuration & Model Loading (Architecture definitions included)
18
  # ============================================================================
19
 
20
  print("πŸš€ Loading Sam-large-2 Model...")
@@ -39,14 +39,11 @@ class RotaryEmbedding(keras.layers.Layer):
39
  super().build(input_shape)
40
 
41
  def _build_cache(self):
42
- """Build RoPE cache on first forward pass"""
43
  if not self.built_cache:
44
  inv_freq = 1.0 / (self.theta ** (tf.range(0, self.dim, 2, dtype=tf.float32) / self.dim))
45
  t = tf.range(self.max_len, dtype=tf.float32)
46
  freqs = tf.einsum("i,j->ij", t, inv_freq)
47
  emb = tf.concat([freqs, freqs], axis=-1)
48
-
49
- # Store as constant tensors
50
  self.cos_cached = tf.constant(np.cos(emb.numpy()), dtype=tf.float32)
51
  self.sin_cached = tf.constant(np.sin(emb.numpy()), dtype=tf.float32)
52
  self.built_cache = True
@@ -57,16 +54,11 @@ class RotaryEmbedding(keras.layers.Layer):
57
 
58
  def call(self, q, k):
59
  self._build_cache()
60
-
61
  seq_len = tf.shape(q)[2]
62
  dtype = q.dtype
63
  cos = tf.cast(self.cos_cached[:seq_len, :], dtype)[None, None, :, :]
64
  sin = tf.cast(self.sin_cached[:seq_len, :], dtype)[None, None, :, :]
65
-
66
- q_rotated = (q * cos) + (self.rotate_half(q) * sin)
67
- k_rotated = (k * cos) + (self.rotate_half(k) * sin)
68
-
69
- return q_rotated, k_rotated
70
 
71
  def get_config(self):
72
  config = super().get_config()
@@ -108,65 +100,39 @@ class TransformerBlock(keras.layers.Layer):
108
 
109
  self.pre_attn_norm = RMSNorm()
110
  self.pre_ffn_norm = RMSNorm()
111
-
112
  self.q_proj = keras.layers.Dense(d_model, use_bias=False, name="q_proj")
113
  self.k_proj = keras.layers.Dense(d_model, use_bias=False, name="k_proj")
114
  self.v_proj = keras.layers.Dense(d_model, use_bias=False, name="v_proj")
115
  self.out_proj = keras.layers.Dense(d_model, use_bias=False, name="o_proj")
116
-
117
  self.rope = RotaryEmbedding(self.head_dim, max_len=max_len, theta=rope_theta)
118
-
119
  self.gate_proj = keras.layers.Dense(ff_dim, use_bias=False, name="gate_proj")
120
  self.up_proj = keras.layers.Dense(ff_dim, use_bias=False, name="up_proj")
121
  self.down_proj = keras.layers.Dense(d_model, use_bias=False, name="down_proj")
122
-
123
  self.dropout = keras.layers.Dropout(dropout)
124
 
125
  def call(self, x, training=None):
126
  B, T, D = tf.shape(x)[0], tf.shape(x)[1], self.d_model
127
  dtype = x.dtype
128
-
129
- # Attention
130
  res = x
131
  y = self.pre_attn_norm(x)
132
-
133
  q = tf.transpose(tf.reshape(self.q_proj(y), [B, T, self.n_heads, self.head_dim]), [0, 2, 1, 3])
134
  k = tf.transpose(tf.reshape(self.k_proj(y), [B, T, self.n_heads, self.head_dim]), [0, 2, 1, 3])
135
  v = tf.transpose(tf.reshape(self.v_proj(y), [B, T, self.n_heads, self.head_dim]), [0, 2, 1, 3])
136
-
137
  q, k = self.rope(q, k)
138
-
139
  scores = tf.matmul(q, k, transpose_b=True) / tf.sqrt(tf.cast(self.head_dim, dtype))
140
-
141
- mask = tf.where(
142
- tf.linalg.band_part(tf.ones([T, T], dtype=dtype), -1, 0) == 0,
143
- tf.constant(-1e9, dtype=dtype),
144
- tf.constant(0.0, dtype=dtype)
145
- )
146
  scores += mask
147
  attn = tf.matmul(tf.nn.softmax(scores, axis=-1), v)
148
-
149
  attn = tf.reshape(tf.transpose(attn, [0, 2, 1, 3]), [B, T, D])
150
  x = res + self.dropout(self.out_proj(attn), training=training)
151
-
152
- # FFN (SwiGLU)
153
  res = x
154
  y = self.pre_ffn_norm(x)
155
  ffn = self.down_proj(keras.activations.silu(self.gate_proj(y)) * self.up_proj(y))
156
-
157
  return res + self.dropout(ffn, training=training)
158
 
159
  def get_config(self):
160
  config = super().get_config()
161
- config.update({
162
- "d_model": self.d_model,
163
- "n_heads": self.n_heads,
164
- "ff_dim": self.ff_dim,
165
- "dropout": self.dropout_rate,
166
- "max_len": self.max_len,
167
- "rope_theta": self.rope_theta,
168
- "layer_idx": self.layer_idx
169
- })
170
  return config
171
 
172
 
@@ -182,31 +148,19 @@ class SAM1Model(keras.Model):
182
  self.cfg = kwargs.get('cfg', kwargs)
183
 
184
  self.embed = keras.layers.Embedding(self.cfg['vocab_size'], self.cfg['d_model'], name="embed_tokens")
185
-
186
  ff_dim = int(self.cfg['d_model'] * self.cfg['ff_mult'])
187
- block_args = {
188
- 'd_model': self.cfg['d_model'],
189
- 'n_heads': self.cfg['n_heads'],
190
- 'ff_dim': ff_dim,
191
- 'dropout': self.cfg['dropout'],
192
- 'max_len': self.cfg['max_len'],
193
- 'rope_theta': self.cfg['rope_theta']
194
- }
195
-
196
  self.blocks = []
197
  for i in range(self.cfg['n_layers']):
198
  block = TransformerBlock(name=f"block_{i}", layer_idx=i, **block_args)
199
  self.blocks.append(block)
200
-
201
  self.norm = RMSNorm(name="final_norm")
202
  self.lm_head = keras.layers.Dense(self.cfg['vocab_size'], use_bias=False, name="lm_head")
203
 
204
  def call(self, input_ids, training=None):
205
  x = self.embed(input_ids)
206
-
207
  for block in self.blocks:
208
  x = block(x, training=training)
209
-
210
  return self.lm_head(self.norm(x))
211
 
212
  def get_config(self):
@@ -216,10 +170,8 @@ class SAM1Model(keras.Model):
216
 
217
  # --- Model and Tokenizer Loading ---
218
 
219
- # Download model files
220
  config_path = hf_hub_download(MODEL_REPO, "config.json", cache_dir=CACHE_DIR)
221
 
222
- # Try to download checkpoint weights first (more reliable)
223
  try:
224
  weights_path = hf_hub_download(MODEL_REPO, "ckpt.weights.h5", cache_dir=CACHE_DIR)
225
  print("βœ… Found checkpoint weights (ckpt.weights.h5)")
@@ -231,14 +183,10 @@ except Exception as e:
231
  use_checkpoint = False
232
  except Exception as e_model:
233
  print(f"❌ Also failed to find model.keras: {e_model}")
234
- # Commenting out raise to allow the Gradio UI to load even if model fails
235
- # raise
236
 
237
- # Load config
238
  with open(config_path, 'r') as f:
239
  config = json.load(f)
240
 
241
- # Create tokenizer from scratch
242
  from transformers import AutoTokenizer
243
 
244
  hf_tokenizer = AutoTokenizer.from_pretrained("gpt2")
@@ -249,19 +197,14 @@ hf_tokenizer.save_pretrained("./temp_tokenizer")
249
  tokenizer = Tokenizer.from_file("./temp_tokenizer/tokenizer.json")
250
 
251
  print(f"βœ… Tokenizer created with vocab size: {tokenizer.get_vocab_size()}")
252
-
253
  eos_token_id = config.get('eos_token_id', 50256)
254
 
255
- # ==============================================================================
256
- # Load Model - Priority: checkpoint weights > saved model
257
- # ==============================================================================
258
  print("\nπŸ”„ Loading model...")
259
 
260
  model = None
261
 
262
  if use_checkpoint:
263
  print("πŸ“¦ Building model from config and loading checkpoint weights...")
264
-
265
  model_config = {
266
  'vocab_size': config['vocab_size'],
267
  'd_model': config['hidden_size'],
@@ -272,53 +215,37 @@ if use_checkpoint:
272
  'dropout': 0.1,
273
  'rope_theta': config['rope_theta']
274
  }
275
-
276
  model = SAM1Model(config=model_config)
277
-
278
- # Dummy call to build the model graph
279
  dummy_input = tf.zeros((1, config['max_position_embeddings']), dtype=tf.int32)
280
  _ = model(dummy_input, training=False)
281
-
282
  print(f"βœ… Model architecture built: {model.count_params():,} parameters")
283
-
284
  try:
285
  model.load_weights(weights_path)
286
  print("βœ… Checkpoint weights loaded successfully!")
287
  except Exception as e:
288
  print(f"❌ Failed to load checkpoint weights: {e}")
289
- # Continue with un-initialized model, which will likely fail on inference
290
  else:
291
  print("πŸ“¦ Loading full saved model...")
292
  try:
293
- # Custom objects needed for loading
294
- custom_objects = {
295
- 'SAM1Model': SAM1Model,
296
- 'TransformerBlock': TransformerBlock,
297
- 'RMSNorm': RMSNorm,
298
- 'RotaryEmbedding': RotaryEmbedding
299
- }
300
  model = keras.models.load_model(model_path, compile=False, custom_objects=custom_objects)
301
  print("βœ… Model loaded successfully")
302
  except Exception as e:
303
  print(f"❌ Failed to load model: {e}")
304
- # Commenting out raise to allow the Gradio UI to load even if model fails
305
- # raise
306
 
307
  if model:
308
  print(f"βœ… Model loaded: {config['num_hidden_layers']} layers, {config['vocab_size']} vocab")
309
 
310
- # Global stop flag
311
- stop_generation = False
312
-
313
  # ============================================================================
314
- # Generation Function with Streaming & Stop Button
315
  # ============================================================================
316
 
317
- # Dummy/Simulated generation logic for safety when running without full TF environment
318
- @tf.function(jit_compile=True)
319
- def generate_step(input_ids, max_len, temp, topk, topp, rep_pen):
320
- # This is a placeholder for the actual model call
321
- return tf.constant([50256], dtype=tf.int32), tf.constant(0.9, dtype=tf.float32)
 
322
 
323
  def generate_stream(
324
  prompt: str,
@@ -328,57 +255,88 @@ def generate_stream(
328
  top_p: float = 0.9,
329
  repetition_penalty: float = 1.1
330
  ):
331
- """Generate text with streaming output and stop support"""
332
  global stop_generation
333
  stop_generation = False
334
 
 
335
  prompt_ids = tokenizer.encode(prompt).ids
336
  input_ids = [i for i in prompt_ids if i != eos_token_id]
337
 
 
338
  generated_text = ""
339
  token_count = 0
340
- start_time = time.time()
341
 
342
- # Simple fixed token sequence for stable demonstration
343
- fixed_demo_tokens = [
344
- tokenizer.token_to_id("Hello"),
345
- tokenizer.token_to_id(" world"),
346
- tokenizer.token_to_id("."),
347
- tokenizer.token_to_id(" I"),
348
- tokenizer.token_to_id(" am"),
349
- tokenizer.token_to_id(" Sam"),
350
- tokenizer.token_to_id("-"),
351
- tokenizer.token_to_id("large"),
352
- tokenizer.token_to_id("-"),
353
- tokenizer.token_to_id("2")
354
- ]
355
 
356
- for i in range(max_tokens):
 
357
  if stop_generation:
 
358
  break
359
 
360
- # SIMULATION: Use fixed tokens
361
- if i < len(fixed_demo_tokens):
362
- next_token_id_val = fixed_demo_tokens[i]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
363
  else:
364
- next_token_id_val = eos_token_id
 
365
 
366
- if next_token_id_val == eos_token_id or next_token_id_val == tokenizer.token_to_id("<|im_end|>") or next_token_id_val == tokenizer.token_to_id("<im end for model tun>"):
 
 
 
367
  break
 
 
 
368
 
369
- input_ids.append(next_token_id_val)
 
370
  token_count += 1
371
 
372
- try:
373
- generated_text = tokenizer.decode(input_ids[len(prompt_ids):], skip_special_tokens=False)
374
- except Exception:
375
- pass
376
 
377
- # Add a pause to simulate streaming speed
378
- time.sleep(0.02)
379
 
380
- yield generated_text
381
-
 
 
382
  elapsed = time.time() - start_time
383
  tokens_per_sec = token_count / elapsed if elapsed > 0 else 0
384
 
@@ -392,19 +350,16 @@ def generate_stream(
392
  # ============================================================================
393
 
394
  def format_chat_prompt(message: str, history: list, reasoning_enabled: bool) -> str:
395
- """Format message history into chat prompt and prepend <think> if enabled (Model turn)"""
396
  prompt = ""
397
-
398
- # Add history
399
  for user_msg, assistant_msg in history:
400
  prompt += f"<|im_start|>user\n{user_msg}<|im_end|>\n"
401
  if assistant_msg:
402
  prompt += f"<|im_start|>assistant\n{assistant_msg}<|im_end|>\n"
403
 
404
- # Add current message
405
  prompt += f"<|im_start|>user\n{message}<|im_end|>\n<|im_start|>assistant\n"
406
 
407
- # Add <think> tag if enabled (Model Turn)
408
  if reasoning_enabled:
409
  prompt += "<think>"
410
 
@@ -420,7 +375,6 @@ def chat_stream(
420
  repetition_penalty: float,
421
  reasoning_enabled: bool
422
  ):
423
- """Streaming chat response"""
424
  if not message.strip():
425
  yield history
426
  return
@@ -428,21 +382,14 @@ def chat_stream(
428
  prompt = format_chat_prompt(message, history, reasoning_enabled)
429
  partial_response = ""
430
 
431
- # SIMULATION: If reasoning is enabled, prepend a simulated thought
432
- if reasoning_enabled:
433
- simulated_thought = (
434
- "Deciding the response requires an introduction and answering the user's implicit query. "
435
- "I will start with a friendly greeting and state my identity."
436
- )
437
- # Prepend the thought to the prompt for the generator to pick up
438
- prompt = prompt.replace("<think>", f"<think>{simulated_thought}</think>")
439
-
440
  for generated in generate_stream(
441
  prompt, max_tokens, temperature, top_k, top_p, repetition_penalty
442
  ):
443
  partial_response = generated
444
 
445
- # Robust End-of-Turn Detection Logic
446
  stop_tags = ["<|im_end|>", "<im end for model tun>"]
447
  earliest_stop = len(partial_response)
448
  should_stop = False
@@ -455,295 +402,119 @@ def chat_stream(
455
  if should_stop:
456
  partial_response = partial_response[:earliest_stop]
457
 
458
- # Post-process reasoning tags for display (collapsible)
459
  if reasoning_enabled:
460
- # Look for the simulated thought or any generated thought
461
  if '<think>' in partial_response and '</think>' in partial_response:
462
  start_idx = partial_response.find('<think>')
463
  end_idx = partial_response.find('</think>')
464
  if start_idx != -1 and end_idx != -1 and end_idx > start_idx:
465
  thought_content = partial_response[start_idx + len('<think>'):end_idx].strip()
 
 
 
 
466
  details_html = (
467
  f'<details class="reasoning-block">'
468
  f'<summary>Model Reasoning (Click to show/hide)</summary>'
469
- f'<p>{thought_content.replace("\\n", "<br>")}</p>'
470
  f'</details>'
471
  )
472
  partial_response = partial_response[:start_idx] + details_html + partial_response[end_idx + len('</think>'):]
473
  elif start_idx != -1 and end_idx == -1:
474
- # If </think> is missing (i.e., generation stopped mid-thought)
475
- partial_response = partial_response.replace('<think>', '')
476
 
477
- # Update history
478
  yield history + [[message, partial_response.strip()]]
479
 
480
  def stop_gen():
481
- """Stop generation callback"""
482
  global stop_generation
483
  stop_generation = True
484
  return None
485
 
486
  # ============================================================================
487
- # Gradio UI & CSS (Modal and Styling)
488
  # ============================================================================
489
 
490
  custom_css = """
491
- .gradio-container {
492
- max-width: 1200px !important;
493
- margin: auto !important;
494
- }
495
-
496
  .header {
497
- text-align: center;
498
- padding: 2rem;
499
- background: linear-gradient(135deg, #f093fb 0%, #f5576c 100%);
500
- color: white;
501
- border-radius: 12px;
502
- margin-bottom: 2rem;
503
- box-shadow: 0 8px 32px rgba(240, 147, 251, 0.3);
504
  animation: pulse 2s ease-in-out infinite;
505
  }
506
-
507
- @keyframes pulse {
508
- 0%, 100% { transform: scale(1); }
509
- 50% { transform: scale(1.02); }
510
- }
511
-
512
- .header h1 {
513
- font-size: 2.8rem;
514
- margin-bottom: 0.5rem;
515
- font-weight: 700;
516
- text-shadow: 2px 2px 4px rgba(0,0,0,0.2);
517
- }
518
-
519
- .header p {
520
- font-size: 1.1rem;
521
- opacity: 0.95;
522
- }
523
-
524
- .celebration {
525
- font-size: 2rem;
526
- margin: 0.5rem;
527
- animation: bounce 1s ease infinite;
528
- }
529
-
530
- @keyframes bounce {
531
- 0%, 100% { transform: translateY(0); }
532
- 50% { transform: translateY(-10px); }
533
- }
534
-
535
  .twin-badge {
536
- display: inline-block;
537
- background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
538
- color: white;
539
- padding: 0.5rem 1rem;
540
- border-radius: 20px;
541
- font-weight: bold;
542
- margin: 0.5rem;
543
  box-shadow: 0 4px 12px rgba(102, 126, 234, 0.3);
544
  }
545
-
546
- footer {
547
- text-align: center;
548
- padding: 2rem;
549
- color: #666;
550
- border-top: 1px solid #eee;
551
- margin-top: 2rem;
552
- }
553
-
554
- /* Reasoning Toggle */
555
- #reasoning-control-group {
556
- position: relative;
557
- display: flex;
558
- align-items: center;
559
- justify-content: center;
560
- margin-right: 10px;
561
- }
562
-
563
  #reasoning-toggle-btn {
564
- font-size: 1.5rem;
565
- border-radius: 50%;
566
- width: 40px;
567
- height: 40px;
568
- padding: 0;
569
- min-width: 0 !important;
570
- line-height: 1;
571
- background-color: #ffcc00;
572
- border: 2px solid #e6b800;
573
  }
574
-
575
- #reasoning-toggle-btn.off {
576
- background-color: #e0e0e0;
577
- border: 2px solid #ccc;
578
- }
579
-
580
  .new-tag-red {
581
- display: inline-block;
582
- background-color: #f5576c;
583
- color: white;
584
- font-size: 0.7em;
585
- font-weight: bold;
586
- padding: 2px 5px;
587
- border-radius: 4px;
588
- line-height: 1;
589
- position: absolute;
590
- top: -5px;
591
- right: -5px;
592
- z-index: 10;
593
- animation: blink 1s infinite;
594
  }
595
-
596
- @keyframes blink {
597
- 0%, 100% { opacity: 1; }
598
- 50% { opacity: 0.5; }
599
- }
600
-
601
- /* Reasoning block styling inside chatbot */
602
  .gradio-html details.reasoning-block {
603
- border: 1px solid #ddd;
604
- border-left: 5px solid #667eea;
605
- padding: 5px 10px;
606
- margin: 10px 0;
607
- border-radius: 4px;
608
- background-color: #f9f9ff;
609
  }
610
-
611
- .gradio-html details.reasoning-block summary {
612
- font-weight: bold;
613
- cursor: pointer;
614
- outline: none;
615
- color: #667eea;
616
- }
617
-
618
- .gradio-html details.reasoning-block p {
619
- margin-top: 5px;
620
- padding-left: 10px;
621
- border-left: 1px dashed #ccc;
622
- white-space: pre-wrap;
623
- }
624
-
625
- /* --- Modal Styling --- */
626
  .modal-overlay {
627
- position: fixed;
628
- top: 0;
629
- left: 0;
630
- right: 0;
631
- bottom: 0;
632
- background: rgba(0, 0, 0, 0.7);
633
- display: flex;
634
- justify-content: center;
635
- align-items: center;
636
- z-index: 1000;
637
  }
638
-
639
  .modal-content {
640
- background: white;
641
- padding: 30px;
642
- border-radius: 15px;
643
- width: 90%;
644
- max-width: 900px;
645
- box-shadow: 0 10px 50px rgba(0, 0, 0, 0.5);
646
- animation: slide-in 0.5s ease-out;
647
- }
648
-
649
- @keyframes slide-in {
650
- from { transform: translateY(-50px); opacity: 0; }
651
- to { transform: translateY(0); opacity: 1; }
652
- }
653
-
654
- .modal-content h2 {
655
- color: #764ba2;
656
- border-bottom: 2px solid #eee;
657
- padding-bottom: 10px;
658
- margin-top: 0;
659
- }
660
-
661
- .comparison-box {
662
- display: flex;
663
- gap: 20px;
664
- margin-top: 20px;
665
- }
666
-
667
- .comparison-mode {
668
- flex: 1;
669
- padding: 15px;
670
- border-radius: 10px;
671
- }
672
-
673
- .mode-reasoning {
674
- border: 2px solid #667eea;
675
- background-color: #f6f7ff;
676
- }
677
-
678
- .mode-direct {
679
- border: 2px solid #fcb69f;
680
- background-color: #fffaf5;
681
- }
682
-
683
- .comparison-mode h3 {
684
- margin-top: 0;
685
- font-size: 1.3rem;
686
- }
687
-
688
- .comparison-mode pre {
689
- background-color: #eef;
690
- padding: 10px;
691
- border-radius: 5px;
692
- overflow-x: auto;
693
  }
694
-
 
 
 
 
 
 
 
695
  .close-btn {
696
- margin-top: 20px;
697
- padding: 10px 20px;
698
- background-color: #764ba2;
699
- color: white;
700
- border: none;
701
- border-radius: 8px;
702
- cursor: pointer;
703
- font-size: 1rem;
704
- transition: background-color 0.3s;
705
- }
706
-
707
- .close-btn:hover {
708
- background-color: #5d3a84;
709
  }
 
710
  """
711
 
712
  festive_css = custom_css
713
  custom_css = festive_css
714
 
715
- # Build interface
716
  with gr.Blocks(css=custom_css, theme=gr.themes.Soft()) as demo:
717
  reasoning_enabled = gr.State(False)
718
  modal_shown = gr.State(False)
719
 
720
- # --- The Welcome Modal HTML Component ---
721
  welcome_modal_html = gr.HTML(
722
  """
723
  <div id="welcome-modal" class="modal-overlay" style="display:none;">
724
  <div class="modal-content">
725
  <h2>🧠 Welcome to Sam-large-2: Dual-Mode Reasoning Demo</h2>
726
  <p>Our latest model, **Sam-large-2**, features **Chain-of-Thought (CoT)** functionality. You can toggle this feature using the 💡 button next to the input field.</p>
727
- <p>Here is how the two modes affect the output:</p>
728
  <div class="comparison-box">
729
  <div class="comparison-mode mode-reasoning">
730
  <h3>💡 Reasoning Mode (ON)</h3>
731
- <p>The model performs a **CoT step** first. The internal thought process is contained within the <code>&lt;think>...&lt;/think></code> tags (which are shown in a collapsible box).</p>
732
- <pre>
733
- &lt;think>
734
- 1. Identify the user's request.
735
- 2. Formulate a plan...
736
- &lt;/think>
737
- [Collapsible Box]
738
- This is the final, reasoned answer.
739
- </pre>
740
  </div>
741
  <div class="comparison-mode mode-direct">
742
  <h3>⚪ Direct Mode (OFF)</h3>
743
- <p>The model generates the final answer immediately, maximizing speed but potentially reducing accuracy for complex tasks.</p>
744
- <pre>
745
- This is the final, direct answer.
746
- </pre>
747
  </div>
748
  </div>
749
  <button class="close-btn" onclick="document.getElementById('welcome-modal').style.display='none'">Got it! Start Chatting</button>
@@ -752,37 +523,20 @@ This is the final, direct answer.
752
  """
753
  )
754
 
755
- # Header
756
  if FESTIVE:
757
  gr.HTML("""
758
  <div class="header">
759
  <div class="celebration">πŸŽ‰ 🎊 ✨ 🎈 πŸŽ†</div>
760
  <img src="https://cdn-uploads.huggingface.co/production/uploads/64e3486b82fb6ae7a06c749c/yBUDdaTze1L84NaDSpZGf.jpeg"
761
- alt="Sam-large-2"
762
- style="max-width: 400px; border-radius: 12px; margin: 1rem auto; display: block; box-shadow: 0 8px 32px rgba(240, 147, 251, 0.3);">
763
  <h1>🤖 Sam-large-2 Chat 🤖</h1>
764
  <p><strong>LATEST RELEASE!</strong> Our **BEST Reasoning Model** - Full Chain-of-Thought!</p>
765
  <div class="twin-badge">Reasoning Model</div>
766
- <p style="font-size: 0.9rem; margin-top: 1rem;">
767
- 768D • 16 Layers • 12 Heads • ~313M Parameters • **Trained for Reasoning**
768
- </p>
769
  <div class="celebration">πŸš€ πŸ’« 🎯 ⚑ πŸ”₯</div>
770
  </div>
771
  """)
772
  else:
773
- gr.HTML("""
774
- <div class="header">
775
- <img src="https://cdn-uploads.huggingface.co/production/uploads/64e3486b82fb6ae7a06c749c/yBUDdaTze1L84NaDSpZGf.jpeg"
776
- alt="Sam-large-2"
777
- style="max-width: 300px; border-radius: 12px; margin: 1rem auto; display: block; box-shadow: 0 4px 16px rgba(0,0,0,0.15);">
778
- <h1>🤖 Sam-large-2 Chat</h1>
779
- <p>Advanced Reasoning Model with Chain-of-Thought support.</p>
780
- <p style="font-size: 0.9rem; margin-top: 0.5rem;">
781
- 768D • 16 Layers • 12 Heads • Trained on TPU v5e-8
782
- </p>
783
- </div>
784
- """)
785
-
786
 
787
  with gr.Row():
788
  with gr.Column(scale=4):
@@ -791,144 +545,74 @@ This is the final, direct answer.
791
  avatar_images=(None, "https://cdn-uploads.huggingface.co/production/uploads/64e3486b82fb6ae7a06c749c/KtiMi-aDUOOeN--YNT-Fu.jpeg"),
792
  bubble_full_width=False
793
  )
794
-
795
  with gr.Row():
796
  with gr.Column(min_width=0, scale=0, elem_id="reasoning-control-group"):
797
  reasoning_btn = gr.Button("💡", size="sm", elem_id="reasoning-toggle-btn", elem_classes=["off"])
798
  gr.HTML('<span class="new-tag-red">NEW</span>')
799
-
800
  msg = gr.Textbox(placeholder="Type your message here...", show_label=False, scale=8, container=False)
801
  submit_btn = gr.Button("Send πŸš€" if FESTIVE else "Send", variant="primary", scale=1)
802
  stop_btn = gr.Button("⏹️ Stop", variant="stop", scale=1)
803
-
804
  with gr.Row():
805
  clear_btn = gr.Button("🗑️ Clear Chat", size="sm")
806
  retry_btn = gr.Button("🔄 Retry", size="sm")
807
 
808
  with gr.Column(scale=1):
809
  gr.Markdown("### βš™οΈ Generation Settings")
810
- max_tokens = gr.Slider(minimum=50, maximum=1024, value=512, step=50, label="Max Tokens", info="Maximum length of response")
811
- temperature = gr.Slider(minimum=0.1, maximum=2.0, value=0.8, step=0.1, label="Temperature", info="Higher = more creative")
812
- top_k = gr.Slider(minimum=1, maximum=100, value=40, step=1, label="Top-K", info="Sample from top K tokens")
813
- top_p = gr.Slider(minimum=0.1, maximum=1.0, value=0.9, step=0.05, label="Top-P", info="Nucleus sampling threshold")
814
- repetition_penalty = gr.Slider(minimum=1.0, maximum=2.0, value=1.1, step=0.1, label="Repetition Penalty", info="Penalize repeated tokens")
815
  gr.Markdown("---")
816
- gr.Markdown(f"""
817
- ### 🎊 Sam-large-2 Model Info
818
- **🎯 The Reasoning Core!**
819
  **Type:** Chain-of-Thought Reasoning Model
820
- **Parameters:** ~313M
821
- **Context:** {config['max_position_embeddings']} tokens
822
  **Vocab:** {config['vocab_size']}
823
  **Reasoning:** Full CoT support (uses **<think>** tags)
824
- **Feature:** Reasoning toggle available! (Top-left of input box)
825
- **Architecture:**
826
- - RoPE positional encoding
827
- - SwiGLU activation
828
- - RMSNorm layers
829
- - No bias terms (efficient!)
830
  """)
831
 
832
- # Example prompts
833
- gr.Examples(
834
- examples=[
835
- "Hi! What can you do?",
836
- "Explain quantum computing in simple terms",
837
- "Write a short poem about AI",
838
- "Why is Sam-large-2 considered a reasoning model?",
839
- "Tell me a step-by-step method for solving a math problem.",
840
- ],
841
- inputs=msg,
842
- label="🎯 Try these examples!"
843
- )
844
 
845
- # Footer - Ensure this is a clean multi-line string
846
  gr.HTML("""
847
  <footer>
848
- <p style="font-size: 1.2rem;"><strong>πŸŽ‰ Sam-large-2 - LATEST RELEASE! πŸŽ‰</strong></p>
849
- <p><strong>The Reasoning Core</strong> - Chain-of-Thought Enabled</p>
850
- <p style="font-size: 0.9rem; color: #999; margin-top: 0.5rem;">
851
- Trained from scratch on TPU v5e-8 • Built by Smily studios with TensorFlow & Gradio
852
- </p>
853
- <p style="font-size: 0.9rem; color: #999;">
854
- Uses **<think>** tags for reasoning when enabled.
855
- </p>
856
- <div style="margin-top: 1rem; font-size: 1.5rem;">
857
- ⚑ πŸš€ πŸ’« ✨ 🎯
858
- </div>
859
  </footer>
860
  """)
861
 
862
- # --- JavaScript to show modal on first load ---
863
  def show_modal_js():
864
  return """
865
  (function() {
866
  if (sessionStorage.getItem('sam2_modal_shown') !== 'true') {
867
  const modal = document.getElementById('welcome-modal');
868
- if (modal) {
869
- modal.style.display = 'flex';
870
- sessionStorage.setItem('sam2_modal_shown', 'true');
871
- }
872
  }
873
  })();
874
  """
875
-
876
- # Execute the JavaScript function on page load
877
  demo.load(None, inputs=None, outputs=None, js=show_modal_js())
878
 
879
-
880
- # Reasoning Toggle function
881
  def toggle_reasoning(current_state):
882
  new_state = not current_state
883
- btn_class = "" if new_state else "off"
884
- return new_state, gr.update(elem_classes=btn_class)
885
-
886
- # Reasoning Toggle Event Handler
887
- reasoning_btn.click(
888
- fn=toggle_reasoning,
889
- inputs=[reasoning_enabled],
890
- outputs=[reasoning_enabled, reasoning_btn],
891
- preprocess=False
892
- )
893
 
894
- # Event handlers for chat
895
- submit_event = msg.submit(
896
- chat_stream,
897
- inputs=[msg, chatbot, max_tokens, temperature, top_k, top_p, repetition_penalty, reasoning_enabled],
898
- outputs=[chatbot]
899
- ).then(lambda: "", outputs=[msg])
900
 
901
- click_event = submit_btn.click(
902
- chat_stream,
903
- inputs=[msg, chatbot, max_tokens, temperature, top_k, top_p, repetition_penalty, reasoning_enabled],
904
- outputs=[chatbot]
905
- ).then(lambda: "", outputs=[msg])
906
 
907
  stop_btn.click(fn=stop_gen, inputs=None, outputs=None, cancels=[submit_event, click_event])
908
  clear_btn.click(lambda: ([], ""), outputs=[chatbot, msg])
909
 
910
  def retry_last(history, max_tok, temp, topk, topp, rep_pen, reasoning_en):
911
- if not history:
912
- return history
913
  last_user_msg = history[-1][0]
914
- history = history[:-1]
915
- for update in chat_stream(last_user_msg, history, max_tok, temp, topk, topp, rep_pen, reasoning_en):
916
  yield update
917
-
918
- retry_event = retry_btn.click(
919
- retry_last,
920
- inputs=[chatbot, max_tokens, temperature, top_k, top_p, repetition_penalty, reasoning_enabled],
921
- outputs=[chatbot]
922
- )
923
-
924
  stop_btn.click(fn=stop_gen, inputs=None, outputs=None, cancels=[retry_event])
925
 
926
- # Launch
927
  if __name__ == "__main__":
928
  demo.queue(max_size=20)
929
- demo.launch(
930
- server_name="0.0.0.0",
931
- server_port=7860,
932
- share=False,
933
- show_error=True
934
- )
 
11
  # ============================================================================
12
  # 🎊 FESTIVE MODE TOGGLE 🎊
13
  # ============================================================================
14
+ FESTIVE = True
15
 
16
  # ============================================================================
17
+ # Configuration & Model Loading
18
  # ============================================================================
19
 
20
  print("πŸš€ Loading Sam-large-2 Model...")
 
39
  super().build(input_shape)
40
 
41
  def _build_cache(self):
 
42
  if not self.built_cache:
43
  inv_freq = 1.0 / (self.theta ** (tf.range(0, self.dim, 2, dtype=tf.float32) / self.dim))
44
  t = tf.range(self.max_len, dtype=tf.float32)
45
  freqs = tf.einsum("i,j->ij", t, inv_freq)
46
  emb = tf.concat([freqs, freqs], axis=-1)
 
 
47
  self.cos_cached = tf.constant(np.cos(emb.numpy()), dtype=tf.float32)
48
  self.sin_cached = tf.constant(np.sin(emb.numpy()), dtype=tf.float32)
49
  self.built_cache = True
 
54
 
55
  def call(self, q, k):
56
  self._build_cache()
 
57
  seq_len = tf.shape(q)[2]
58
  dtype = q.dtype
59
  cos = tf.cast(self.cos_cached[:seq_len, :], dtype)[None, None, :, :]
60
  sin = tf.cast(self.sin_cached[:seq_len, :], dtype)[None, None, :, :]
61
+ return (q * cos) + (self.rotate_half(q) * sin), (k * cos) + (self.rotate_half(k) * sin)
 
 
 
 
62
 
63
  def get_config(self):
64
  config = super().get_config()
 
100
 
101
  self.pre_attn_norm = RMSNorm()
102
  self.pre_ffn_norm = RMSNorm()
 
103
  self.q_proj = keras.layers.Dense(d_model, use_bias=False, name="q_proj")
104
  self.k_proj = keras.layers.Dense(d_model, use_bias=False, name="k_proj")
105
  self.v_proj = keras.layers.Dense(d_model, use_bias=False, name="v_proj")
106
  self.out_proj = keras.layers.Dense(d_model, use_bias=False, name="o_proj")
 
107
  self.rope = RotaryEmbedding(self.head_dim, max_len=max_len, theta=rope_theta)
 
108
  self.gate_proj = keras.layers.Dense(ff_dim, use_bias=False, name="gate_proj")
109
  self.up_proj = keras.layers.Dense(ff_dim, use_bias=False, name="up_proj")
110
  self.down_proj = keras.layers.Dense(d_model, use_bias=False, name="down_proj")
 
111
  self.dropout = keras.layers.Dropout(dropout)
112
 
113
  def call(self, x, training=None):
114
  B, T, D = tf.shape(x)[0], tf.shape(x)[1], self.d_model
115
  dtype = x.dtype
 
 
116
  res = x
117
  y = self.pre_attn_norm(x)
 
118
  q = tf.transpose(tf.reshape(self.q_proj(y), [B, T, self.n_heads, self.head_dim]), [0, 2, 1, 3])
119
  k = tf.transpose(tf.reshape(self.k_proj(y), [B, T, self.n_heads, self.head_dim]), [0, 2, 1, 3])
120
  v = tf.transpose(tf.reshape(self.v_proj(y), [B, T, self.n_heads, self.head_dim]), [0, 2, 1, 3])
 
121
  q, k = self.rope(q, k)
 
122
  scores = tf.matmul(q, k, transpose_b=True) / tf.sqrt(tf.cast(self.head_dim, dtype))
123
+ mask = tf.where(tf.linalg.band_part(tf.ones([T, T], dtype=dtype), -1, 0) == 0, tf.constant(-1e9, dtype=dtype), tf.constant(0.0, dtype=dtype))
 
 
 
 
 
124
  scores += mask
125
  attn = tf.matmul(tf.nn.softmax(scores, axis=-1), v)
 
126
  attn = tf.reshape(tf.transpose(attn, [0, 2, 1, 3]), [B, T, D])
127
  x = res + self.dropout(self.out_proj(attn), training=training)
 
 
128
  res = x
129
  y = self.pre_ffn_norm(x)
130
  ffn = self.down_proj(keras.activations.silu(self.gate_proj(y)) * self.up_proj(y))
 
131
  return res + self.dropout(ffn, training=training)
132
 
133
  def get_config(self):
134
  config = super().get_config()
135
+ config.update({"d_model": self.d_model, "n_heads": self.n_heads, "ff_dim": self.ff_dim, "dropout": self.dropout_rate, "max_len": self.max_len, "rope_theta": self.rope_theta, "layer_idx": self.layer_idx})
 
 
 
 
 
 
 
 
136
  return config
137
 
138
 
 
148
  self.cfg = kwargs.get('cfg', kwargs)
149
 
150
  self.embed = keras.layers.Embedding(self.cfg['vocab_size'], self.cfg['d_model'], name="embed_tokens")
 
151
  ff_dim = int(self.cfg['d_model'] * self.cfg['ff_mult'])
152
+ block_args = {'d_model': self.cfg['d_model'], 'n_heads': self.cfg['n_heads'], 'ff_dim': ff_dim, 'dropout': self.cfg['dropout'], 'max_len': self.cfg['max_len'], 'rope_theta': self.cfg['rope_theta']}
 
 
 
 
 
 
 
 
153
  self.blocks = []
154
  for i in range(self.cfg['n_layers']):
155
  block = TransformerBlock(name=f"block_{i}", layer_idx=i, **block_args)
156
  self.blocks.append(block)
 
157
  self.norm = RMSNorm(name="final_norm")
158
  self.lm_head = keras.layers.Dense(self.cfg['vocab_size'], use_bias=False, name="lm_head")
159
 
160
  def call(self, input_ids, training=None):
161
  x = self.embed(input_ids)
 
162
  for block in self.blocks:
163
  x = block(x, training=training)
 
164
  return self.lm_head(self.norm(x))
165
 
166
  def get_config(self):
 
170
 
171
  # --- Model and Tokenizer Loading ---
172
 
 
173
  config_path = hf_hub_download(MODEL_REPO, "config.json", cache_dir=CACHE_DIR)
174
 
 
175
  try:
176
  weights_path = hf_hub_download(MODEL_REPO, "ckpt.weights.h5", cache_dir=CACHE_DIR)
177
  print("βœ… Found checkpoint weights (ckpt.weights.h5)")
 
183
  use_checkpoint = False
184
  except Exception as e_model:
185
  print(f"❌ Also failed to find model.keras: {e_model}")
 
 
186
 
 
187
  with open(config_path, 'r') as f:
188
  config = json.load(f)
189
 
 
190
  from transformers import AutoTokenizer
191
 
192
  hf_tokenizer = AutoTokenizer.from_pretrained("gpt2")
 
197
  tokenizer = Tokenizer.from_file("./temp_tokenizer/tokenizer.json")
198
 
199
  print(f"βœ… Tokenizer created with vocab size: {tokenizer.get_vocab_size()}")
 
200
  eos_token_id = config.get('eos_token_id', 50256)
201
 
 
 
 
202
  print("\nπŸ”„ Loading model...")
203
 
204
  model = None
205
 
206
  if use_checkpoint:
207
  print("πŸ“¦ Building model from config and loading checkpoint weights...")
 
208
  model_config = {
209
  'vocab_size': config['vocab_size'],
210
  'd_model': config['hidden_size'],
 
215
  'dropout': 0.1,
216
  'rope_theta': config['rope_theta']
217
  }
 
218
  model = SAM1Model(config=model_config)
 
 
219
  dummy_input = tf.zeros((1, config['max_position_embeddings']), dtype=tf.int32)
220
  _ = model(dummy_input, training=False)
 
221
  print(f"βœ… Model architecture built: {model.count_params():,} parameters")
 
222
  try:
223
  model.load_weights(weights_path)
224
  print("βœ… Checkpoint weights loaded successfully!")
225
  except Exception as e:
226
  print(f"❌ Failed to load checkpoint weights: {e}")
 
227
  else:
228
  print("πŸ“¦ Loading full saved model...")
229
  try:
230
+ custom_objects = {'SAM1Model': SAM1Model, 'TransformerBlock': TransformerBlock, 'RMSNorm': RMSNorm, 'RotaryEmbedding': RotaryEmbedding}
 
 
 
 
 
 
231
  model = keras.models.load_model(model_path, compile=False, custom_objects=custom_objects)
232
  print("βœ… Model loaded successfully")
233
  except Exception as e:
234
  print(f"❌ Failed to load model: {e}")
 
 
235
 
236
  if model:
237
  print(f"βœ… Model loaded: {config['num_hidden_layers']} layers, {config['vocab_size']} vocab")
238
 
 
 
 
239
  # ============================================================================
240
+ # Optimized Inference Logic (TF Functions)
241
  # ============================================================================
242
 
243
+ # Define fast forward for real generation
244
+ @tf.function(reduce_retracing=True)
245
+ def fast_forward(input_tensor):
246
+ return model(input_tensor, training=False)
247
+
248
+ stop_generation = False
249
 
250
  def generate_stream(
251
  prompt: str,
 
255
  top_p: float = 0.9,
256
  repetition_penalty: float = 1.1
257
  ):
258
+ """Generate text with streaming output using REAL model inference"""
259
  global stop_generation
260
  stop_generation = False
261
 
262
+ # Tokenize prompt
263
  prompt_ids = tokenizer.encode(prompt).ids
264
  input_ids = [i for i in prompt_ids if i != eos_token_id]
265
 
266
+ input_tensor = tf.constant([input_ids], dtype=tf.int32)
267
  generated_text = ""
268
  token_count = 0
269
+ token_freq = {}
270
 
271
+ start_time = time.time()
 
 
 
 
 
 
 
 
 
 
 
 
272
 
273
+ # --- REAL INFERENCE LOOP ---
274
+ for step in range(max_tokens):
275
  if stop_generation:
276
+ yield generated_text + "\n\n*[Generation stopped]*"
277
  break
278
 
279
+ # 1. Forward Pass (Real Model)
280
+ logits = fast_forward(input_tensor)
281
+ next_token_logits = logits[0, -1, :].numpy()
282
+
283
+ # 2. Temperature
284
+ next_token_logits = next_token_logits / temperature
285
+
286
+ # 3. Repetition Penalty
287
+ if repetition_penalty != 1.0:
288
+ for token_id, freq in token_freq.items():
289
+ if token_id < len(next_token_logits):
290
+ next_token_logits[token_id] /= (repetition_penalty ** freq)
291
+
292
+ # 4. Sampling (Top-K / Top-P)
293
+ # Top-K
294
+ if top_k > 0:
295
+ top_k_indices = np.argpartition(next_token_logits, -top_k)[-top_k:]
296
+ top_k_logits = next_token_logits[top_k_indices]
297
+ top_k_probs = tf.nn.softmax(top_k_logits).numpy()
298
+
299
+ # Top-P (Nucleus)
300
+ if top_p < 1.0:
301
+ sorted_indices = np.argsort(top_k_probs)[::-1]
302
+ cumsum = np.cumsum(top_k_probs[sorted_indices])
303
+ cutoff_idx = np.searchsorted(cumsum, top_p)
304
+ nucleus_indices = sorted_indices[:cutoff_idx + 1]
305
+
306
+ nucleus_logits = top_k_logits[nucleus_indices]
307
+ nucleus_probs = tf.nn.softmax(nucleus_logits).numpy()
308
+
309
+ sampled_idx = np.random.choice(len(nucleus_probs), p=nucleus_probs)
310
+ next_token_id = int(top_k_indices[nucleus_indices[sampled_idx]])
311
+ else:
312
+ sampled_idx = np.random.choice(len(top_k_probs), p=top_k_probs)
313
+ next_token_id = int(top_k_indices[sampled_idx])
314
  else:
315
+ probs = tf.nn.softmax(next_token_logits).numpy()
316
+ next_token_id = np.random.choice(len(probs), p=probs)
317
 
318
+ # 5. Stop Conditions
319
+ if next_token_id == eos_token_id or \
320
+ next_token_id == tokenizer.token_to_id("<|im_end|>") or \
321
+ next_token_id == tokenizer.token_to_id("<im end for model tun>"):
322
  break
323
+
324
+ # 6. Update Input & History
325
+ token_freq[next_token_id] = token_freq.get(next_token_id, 0) + 1
326
 
327
+ token_text = tokenizer.decode([next_token_id])
328
+ generated_text += token_text
329
  token_count += 1
330
 
331
+ yield generated_text
 
 
 
332
 
333
+ # Prepare next input
334
+ input_tensor = tf.concat([input_tensor, [[next_token_id]]], axis=1)
335
 
336
+ # Truncate if exceeding context
337
+ if input_tensor.shape[1] > config['max_position_embeddings']:
338
+ input_tensor = input_tensor[:, -config['max_position_embeddings']:]
339
+
340
  elapsed = time.time() - start_time
341
  tokens_per_sec = token_count / elapsed if elapsed > 0 else 0
342
 
 
350
  # ============================================================================
351
 
352
  def format_chat_prompt(message: str, history: list, reasoning_enabled: bool) -> str:
353
+ """Format message history and SEED <think> if enabled"""
354
  prompt = ""
 
 
355
  for user_msg, assistant_msg in history:
356
  prompt += f"<|im_start|>user\n{user_msg}<|im_end|>\n"
357
  if assistant_msg:
358
  prompt += f"<|im_start|>assistant\n{assistant_msg}<|im_end|>\n"
359
 
 
360
  prompt += f"<|im_start|>user\n{message}<|im_end|>\n<|im_start|>assistant\n"
361
 
362
+ # 🧠 REAL REASONING: Just add the tag. The model will do the rest.
363
  if reasoning_enabled:
364
  prompt += "<think>"
365
 
 
375
  repetition_penalty: float,
376
  reasoning_enabled: bool
377
  ):
 
378
  if not message.strip():
379
  yield history
380
  return
 
382
  prompt = format_chat_prompt(message, history, reasoning_enabled)
383
  partial_response = ""
384
 
385
+ # ⚡ NO FAKE REASONING HERE. We trust the model.
386
+
 
 
 
 
 
 
 
387
  for generated in generate_stream(
388
  prompt, max_tokens, temperature, top_k, top_p, repetition_penalty
389
  ):
390
  partial_response = generated
391
 
392
+ # Robust End-of-Turn Detection
393
  stop_tags = ["<|im_end|>", "<im end for model tun>"]
394
  earliest_stop = len(partial_response)
395
  should_stop = False
 
402
  if should_stop:
403
  partial_response = partial_response[:earliest_stop]
404
 
405
+ # Post-process reasoning tags for display (Collapsing the REAL thought)
406
  if reasoning_enabled:
 
407
  if '<think>' in partial_response and '</think>' in partial_response:
408
  start_idx = partial_response.find('<think>')
409
  end_idx = partial_response.find('</think>')
410
  if start_idx != -1 and end_idx != -1 and end_idx > start_idx:
411
  thought_content = partial_response[start_idx + len('<think>'):end_idx].strip()
412
+
413
+ # Safe formatting outside f-string
414
+ formatted_thought = thought_content.replace("\n", "<br>")
415
+
416
  details_html = (
417
  f'<details class="reasoning-block">'
418
  f'<summary>Model Reasoning (Click to show/hide)</summary>'
419
+ f'<p>{formatted_thought}</p>'
420
  f'</details>'
421
  )
422
  partial_response = partial_response[:start_idx] + details_html + partial_response[end_idx + len('</think>'):]
423
  elif start_idx != -1 and end_idx == -1:
424
+ # Model is currently thinking...
425
+ partial_response = partial_response.replace('<think>', '**Thinking:** ')
426
 
 
427
  yield history + [[message, partial_response.strip()]]
428
 
429
  def stop_gen():
 
430
  global stop_generation
431
  stop_generation = True
432
  return None
433
 
434
  # ============================================================================
435
+ # Gradio UI
436
  # ============================================================================
437
 
438
  custom_css = """
439
+ .gradio-container { max-width: 1200px !important; margin: auto !important; }
 
 
 
 
440
  .header {
441
+ text-align: center; padding: 2rem; background: linear-gradient(135deg, #f093fb 0%, #f5576c 100%);
442
+ color: white; border-radius: 12px; margin-bottom: 2rem; box-shadow: 0 8px 32px rgba(240, 147, 251, 0.3);
 
 
 
 
 
443
  animation: pulse 2s ease-in-out infinite;
444
  }
445
+ @keyframes pulse { 0%, 100% { transform: scale(1); } 50% { transform: scale(1.02); } }
446
+ .header h1 { font-size: 2.8rem; margin-bottom: 0.5rem; font-weight: 700; text-shadow: 2px 2px 4px rgba(0,0,0,0.2); }
447
+ .header p { font-size: 1.1rem; opacity: 0.95; }
448
+ .celebration { font-size: 2rem; margin: 0.5rem; animation: bounce 1s ease infinite; }
449
+ @keyframes bounce { 0%, 100% { transform: translateY(0); } 50% { transform: translateY(-10px); } }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
450
  .twin-badge {
451
+ display: inline-block; background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
452
+ color: white; padding: 0.5rem 1rem; border-radius: 20px; font-weight: bold; margin: 0.5rem;
 
 
 
 
 
453
  box-shadow: 0 4px 12px rgba(102, 126, 234, 0.3);
454
  }
455
+ footer { text-align: center; padding: 2rem; color: #666; border-top: 1px solid #eee; margin-top: 2rem; }
456
+ #reasoning-control-group { position: relative; display: flex; align-items: center; justify-content: center; margin-right: 10px; }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
457
  #reasoning-toggle-btn {
458
+ font-size: 1.5rem; border-radius: 50%; width: 40px; height: 40px; padding: 0;
459
+ min-width: 0 !important; line-height: 1; background-color: #ffcc00; border: 2px solid #e6b800;
 
 
 
 
 
 
 
460
  }
461
+ #reasoning-toggle-btn.off { background-color: #e0e0e0; border: 2px solid #ccc; }
 
 
 
 
 
462
  .new-tag-red {
463
+ display: inline-block; background-color: #f5576c; color: white; font-size: 0.7em;
464
+ font-weight: bold; padding: 2px 5px; border-radius: 4px; line-height: 1;
465
+ position: absolute; top: -5px; right: -5px; z-index: 10; animation: blink 1s infinite;
 
 
 
 
 
 
 
 
 
 
466
  }
467
+ @keyframes blink { 0%, 100% { opacity: 1; } 50% { opacity: 0.5; } }
 
 
 
 
 
 
468
  .gradio-html details.reasoning-block {
469
+ border: 1px solid #ddd; border-left: 5px solid #667eea; padding: 5px 10px;
470
+ margin: 10px 0; border-radius: 4px; background-color: #f9f9ff;
 
 
 
 
471
  }
472
+ .gradio-html details.reasoning-block summary { font-weight: bold; cursor: pointer; outline: none; color: #667eea; }
473
+ .gradio-html details.reasoning-block p { margin-top: 5px; padding-left: 10px; border-left: 1px dashed #ccc; white-space: pre-wrap; }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
474
  .modal-overlay {
475
+ position: fixed; top: 0; left: 0; right: 0; bottom: 0; background: rgba(0, 0, 0, 0.7);
476
+ display: flex; justify-content: center; align-items: center; z-index: 1000;
 
 
 
 
 
 
 
 
477
  }
 
478
  .modal-content {
479
+ background: white; padding: 30px; border-radius: 15px; width: 90%; max-width: 900px;
480
+ box-shadow: 0 10px 50px rgba(0, 0, 0, 0.5); animation: slide-in 0.5s ease-out;
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
481
  }
482
+ @keyframes slide-in { from { transform: translateY(-50px); opacity: 0; } to { transform: translateY(0); opacity: 1; } }
483
+ .modal-content h2 { color: #764ba2; border-bottom: 2px solid #eee; padding-bottom: 10px; margin-top: 0; }
484
+ .comparison-box { display: flex; gap: 20px; margin-top: 20px; }
485
+ .comparison-mode { flex: 1; padding: 15px; border-radius: 10px; }
486
+ .mode-reasoning { border: 2px solid #667eea; background-color: #f6f7ff; }
487
+ .mode-direct { border: 2px solid #fcb69f; background-color: #fffaf5; }
488
+ .comparison-mode h3 { margin-top: 0; font-size: 1.3rem; }
489
+ .comparison-mode pre { background-color: #eef; padding: 10px; border-radius: 5px; overflow-x: auto; }
490
  .close-btn {
491
+ margin-top: 20px; padding: 10px 20px; background-color: #764ba2; color: white;
492
+ border: none; border-radius: 8px; cursor: pointer; font-size: 1rem; transition: background-color 0.3s;
 
 
 
 
 
 
 
 
 
 
 
493
  }
494
+ .close-btn:hover { background-color: #5d3a84; }
495
  """
496
 
497
  festive_css = custom_css
498
  custom_css = festive_css
499
 
 
500
  with gr.Blocks(css=custom_css, theme=gr.themes.Soft()) as demo:
501
  reasoning_enabled = gr.State(False)
502
  modal_shown = gr.State(False)
503
 
 
504
  welcome_modal_html = gr.HTML(
505
  """
506
  <div id="welcome-modal" class="modal-overlay" style="display:none;">
507
  <div class="modal-content">
508
  <h2>🧠 Welcome to Sam-large-2: Dual-Mode Reasoning Demo</h2>
509
  <p>Our latest model, **Sam-large-2**, features **Chain-of-Thought (CoT)** functionality. You can toggle this feature using the 💡 button next to the input field.</p>
 
510
  <div class="comparison-box">
511
  <div class="comparison-mode mode-reasoning">
512
  <h3>💡 Reasoning Mode (ON)</h3>
513
+ <p>The model performs a **CoT step** first. The internal thought process is contained within the <code>&lt;think>...&lt;/think></code> tags.</p>
 
 
 
 
 
 
 
 
514
  </div>
515
  <div class="comparison-mode mode-direct">
516
  <h3>⚪ Direct Mode (OFF)</h3>
517
+ <p>The model generates the final answer immediately, maximizing speed.</p>
 
 
 
518
  </div>
519
  </div>
520
  <button class="close-btn" onclick="document.getElementById('welcome-modal').style.display='none'">Got it! Start Chatting</button>
 
523
  """
524
  )
525
 
 
526
  if FESTIVE:
527
  gr.HTML("""
528
  <div class="header">
529
  <div class="celebration">πŸŽ‰ 🎊 ✨ 🎈 πŸŽ†</div>
530
  <img src="https://cdn-uploads.huggingface.co/production/uploads/64e3486b82fb6ae7a06c749c/yBUDdaTze1L84NaDSpZGf.jpeg"
531
+ alt="Sam-large-2" style="max-width: 400px; border-radius: 12px; margin: 1rem auto; display: block; box-shadow: 0 8px 32px rgba(240, 147, 251, 0.3);">
 
532
  <h1>🤖 Sam-large-2 Chat 🤖</h1>
533
  <p><strong>LATEST RELEASE!</strong> Our **BEST Reasoning Model** - Full Chain-of-Thought!</p>
534
  <div class="twin-badge">Reasoning Model</div>
 
 
 
535
  <div class="celebration">πŸš€ πŸ’« 🎯 ⚑ πŸ”₯</div>
536
  </div>
537
  """)
538
  else:
539
+ gr.HTML("""<div class="header"><h1>πŸ€– Sam-large-2 Chat</h1><p>Advanced Reasoning Model</p></div>""")
 
 
 
 
 
 
 
 
 
 
 
 
540
 
541
  with gr.Row():
542
  with gr.Column(scale=4):
 
545
  avatar_images=(None, "https://cdn-uploads.huggingface.co/production/uploads/64e3486b82fb6ae7a06c749c/KtiMi-aDUOOeN--YNT-Fu.jpeg"),
546
  bubble_full_width=False
547
  )
 
548
  with gr.Row():
549
  with gr.Column(min_width=0, scale=0, elem_id="reasoning-control-group"):
550
  reasoning_btn = gr.Button("💡", size="sm", elem_id="reasoning-toggle-btn", elem_classes=["off"])
551
  gr.HTML('<span class="new-tag-red">NEW</span>')
 
552
  msg = gr.Textbox(placeholder="Type your message here...", show_label=False, scale=8, container=False)
553
  submit_btn = gr.Button("Send πŸš€" if FESTIVE else "Send", variant="primary", scale=1)
554
  stop_btn = gr.Button("⏹️ Stop", variant="stop", scale=1)
 
555
  with gr.Row():
556
  clear_btn = gr.Button("🗑️ Clear Chat", size="sm")
557
  retry_btn = gr.Button("🔄 Retry", size="sm")
558
 
559
  with gr.Column(scale=1):
560
  gr.Markdown("### βš™οΈ Generation Settings")
561
+ max_tokens = gr.Slider(minimum=50, maximum=1024, value=512, step=50, label="Max Tokens")
562
+ temperature = gr.Slider(minimum=0.1, maximum=2.0, value=0.8, step=0.1, label="Temperature")
563
+ top_k = gr.Slider(minimum=1, maximum=100, value=40, step=1, label="Top-K")
564
+ top_p = gr.Slider(minimum=0.1, maximum=1.0, value=0.9, step=0.05, label="Top-P")
565
+ repetition_penalty = gr.Slider(minimum=1.0, maximum=2.0, value=1.1, step=0.1, label="Repetition Penalty")
566
  gr.Markdown("---")
567
+ gr.Markdown(f"""### 🎊 Sam-large-2 Model Info
 
 
568
  **Type:** Chain-of-Thought Reasoning Model
 
 
569
  **Vocab:** {config['vocab_size']}
570
  **Reasoning:** Full CoT support (uses **<think>** tags)
 
 
 
 
 
 
571
  """)
572
 
573
+ gr.Examples(examples=["Explain quantum computing", "Write a short poem about AI", "Solve 24*12 with reasoning"], inputs=msg)
 
 
 
 
 
 
 
 
 
 
 
574
 
 
575
  gr.HTML("""
576
  <footer>
577
+ <p><strong>🎉 Sam-large-2 - LATEST RELEASE! 🎉</strong></p>
578
+ <p style="font-size: 0.9rem; color: #999;">Trained from scratch on TPU v5e-8 β€’ Built by Smily studios with TensorFlow & Gradio</p>
 
 
 
 
 
 
 
 
 
579
  </footer>
580
  """)
581
 
 
582
  def show_modal_js():
583
  return """
584
  (function() {
585
  if (sessionStorage.getItem('sam2_modal_shown') !== 'true') {
586
  const modal = document.getElementById('welcome-modal');
587
+ if (modal) { modal.style.display = 'flex'; sessionStorage.setItem('sam2_modal_shown', 'true'); }
 
 
 
588
  }
589
  })();
590
  """
 
 
591
  demo.load(None, inputs=None, outputs=None, js=show_modal_js())
592
 
 
 
593
  def toggle_reasoning(current_state):
594
  new_state = not current_state
595
+ return new_state, gr.update(elem_classes="" if new_state else "off")
 
 
 
 
 
 
 
 
 
596
 
597
+ reasoning_btn.click(fn=toggle_reasoning, inputs=[reasoning_enabled], outputs=[reasoning_enabled, reasoning_btn], preprocess=False)
598
+
599
+ common_inputs = [msg, chatbot, max_tokens, temperature, top_k, top_p, repetition_penalty, reasoning_enabled]
 
 
 
600
 
601
+ submit_event = msg.submit(chat_stream, inputs=common_inputs, outputs=[chatbot]).then(lambda: "", outputs=[msg])
602
+ click_event = submit_btn.click(chat_stream, inputs=common_inputs, outputs=[chatbot]).then(lambda: "", outputs=[msg])
 
 
 
603
 
604
  stop_btn.click(fn=stop_gen, inputs=None, outputs=None, cancels=[submit_event, click_event])
605
  clear_btn.click(lambda: ([], ""), outputs=[chatbot, msg])
606
 
607
  def retry_last(history, max_tok, temp, topk, topp, rep_pen, reasoning_en):
608
+ if not history: return history
 
609
  last_user_msg = history[-1][0]
610
+ for update in chat_stream(last_user_msg, history[:-1], max_tok, temp, topk, topp, rep_pen, reasoning_en):
 
611
  yield update
612
+
613
+ retry_event = retry_btn.click(retry_last, inputs=[chatbot, max_tokens, temperature, top_k, top_p, repetition_penalty, reasoning_enabled], outputs=[chatbot])
 
 
 
 
 
614
  stop_btn.click(fn=stop_gen, inputs=None, outputs=None, cancels=[retry_event])
615
 
 
616
  if __name__ == "__main__":
617
  demo.queue(max_size=20)
618
+ demo.launch(server_name="0.0.0.0", server_port=7860, share=False, show_error=True)