Keeby-smilyai committed on
Commit c7eedeb · verified · 1 Parent(s): d9e88a5

Update app.py

Files changed (1)
  1. app.py +120 -194
app.py CHANGED
@@ -24,12 +24,7 @@ import time
24
  tf.config.threading.set_inter_op_parallelism_threads(NUM_CORES)
25
  tf.config.threading.set_intra_op_parallelism_threads(NUM_CORES)
26
 
27
- # Enable XLA JIT compilation for CPU
28
- try:
29
- tf.config.optimizer.set_jit(True)
30
- print(f"βœ… CPU optimized: {NUM_CORES} threads, oneDNN enabled, XLA JIT enabled")
31
- except:
32
- print(f"βœ… CPU optimized: {NUM_CORES} threads, oneDNN enabled")
33
 
34
  # ============================================================================
35
  # 🎊 FESTIVE MODE TOGGLE 🎊
@@ -46,48 +41,48 @@ MODEL_REPO = "Smilyai-labs/Sam-large-2"
46
  CACHE_DIR = "./model_cache"
47
 
48
  # ============================================================================
49
- # Model Architecture - MATCHES CHECKPOINT STRUCTURE
50
  # ============================================================================
51
 
52
  @keras.saving.register_keras_serializable()
53
  class RotaryEmbedding(keras.layers.Layer):
54
- """RoPE with cache built during layer build phase."""
55
-
56
  def __init__(self, dim, max_len=2048, theta=10000, **kwargs):
57
  super().__init__(**kwargs)
58
  self.dim = dim
59
  self.max_len = max_len
60
  self.theta = theta
61
 
62
  def build(self, input_shape):
63
- # Pre-compute RoPE cache as numpy arrays during build
64
- inv_freq = 1.0 / (self.theta ** (np.arange(0, self.dim, 2, dtype=np.float32) / self.dim))
65
- t = np.arange(self.max_len, dtype=np.float32)
66
- freqs = np.outer(t, inv_freq)
67
- emb = np.concatenate([freqs, freqs], axis=-1)
68
-
69
- # Store as numpy arrays - will be converted to tensors in call()
70
- self._cos_cached = np.cos(emb).astype(np.float32)
71
- self._sin_cached = np.sin(emb).astype(np.float32)
72
-
73
  super().build(input_shape)
74
75
  def call(self, q, k, offset=0):
76
  """Apply rotary embeddings with position offset for KV-cache."""
 
77
  seq_len = tf.shape(q)[2]
78
  dtype = q.dtype
79
 
80
- # Slice the pre-computed values
81
- cos = tf.cast(self._cos_cached[offset:offset + seq_len, :], dtype)[None, None, :, :]
82
- sin = tf.cast(self._sin_cached[offset:offset + seq_len, :], dtype)[None, None, :, :]
83
 
84
- # Fused rotate_half
85
- x1_q, x2_q = tf.split(q, 2, axis=-1)
86
- x1_k, x2_k = tf.split(k, 2, axis=-1)
87
-
88
- q_embed = (q * cos) + (tf.concat([-x2_q, x1_q], axis=-1) * sin)
89
- k_embed = (k * cos) + (tf.concat([-x2_k, x1_k], axis=-1) * sin)
90
-
91
  return q_embed, k_embed
92
 
93
  def get_config(self):
@@ -119,8 +114,6 @@ class RMSNorm(keras.layers.Layer):
119
 
120
  @keras.saving.register_keras_serializable()
121
  class TransformerBlock(keras.layers.Layer):
122
- """Transformer block - MATCHES ORIGINAL CHECKPOINT STRUCTURE."""
123
-
124
  def __init__(self, d_model, n_heads, ff_dim, dropout, max_len, rope_theta, layer_idx=0, **kwargs):
125
  super().__init__(**kwargs)
126
  self.d_model = d_model
@@ -131,37 +124,38 @@ class TransformerBlock(keras.layers.Layer):
131
  self.rope_theta = rope_theta
132
  self.head_dim = d_model // n_heads
133
  self.layer_idx = layer_idx
134
- self.scale = 1.0 / np.sqrt(self.head_dim)
135
 
136
  def build(self, input_shape):
137
- # MUST use same layer names as checkpoint
138
  self.pre_attn_norm = RMSNorm(name="pre_attn_norm")
139
  self.pre_ffn_norm = RMSNorm(name="pre_ffn_norm")
140
-
141
- # Separate Q, K, V projections (matches checkpoint)
142
  self.q_proj = keras.layers.Dense(self.d_model, use_bias=False, name="q_proj")
143
  self.k_proj = keras.layers.Dense(self.d_model, use_bias=False, name="k_proj")
144
  self.v_proj = keras.layers.Dense(self.d_model, use_bias=False, name="v_proj")
145
  self.out_proj = keras.layers.Dense(self.d_model, use_bias=False, name="o_proj")
146
-
147
  self.rope = RotaryEmbedding(self.head_dim, max_len=self.max_len, theta=self.rope_theta)
148
-
149
- # Separate gate, up, down projections (matches checkpoint)
150
  self.gate_proj = keras.layers.Dense(self.ff_dim, use_bias=False, name="gate_proj")
151
  self.up_proj = keras.layers.Dense(self.ff_dim, use_bias=False, name="up_proj")
152
  self.down_proj = keras.layers.Dense(self.d_model, use_bias=False, name="down_proj")
153
-
154
  self.dropout = keras.layers.Dropout(self.dropout_rate)
155
  super().build(input_shape)
156
 
157
  def call(self, x, training=None, past_kv=None, use_cache=False):
158
  B = tf.shape(x)[0]
159
  T = tf.shape(x)[1]
 
160
 
161
  res = x
162
  y = self.pre_attn_norm(x)
163
 
164
- # Separate Q, K, V projections
165
  q = tf.reshape(self.q_proj(y), [B, T, self.n_heads, self.head_dim])
166
  q = tf.transpose(q, [0, 2, 1, 3]) # [B, n_heads, T, head_dim]
167
 
@@ -187,14 +181,15 @@ class TransformerBlock(keras.layers.Layer):
187
 
188
  new_kv = (k, v) if use_cache else None
189
 
190
- # Scaled dot-product attention
191
  full_len = tf.shape(k)[2]
192
- scores = tf.matmul(q, k, transpose_b=True) * self.scale
193
 
194
  # Causal mask
195
  q_positions = tf.range(past_len, past_len + T)
196
  k_positions = tf.range(full_len)
197
- mask = tf.cast(q_positions[:, None] < k_positions[None, :], scores.dtype) * -1e9
 
198
  scores = scores + mask[None, None, :, :]
199
 
200
  attn = tf.nn.softmax(scores, axis=-1)
@@ -204,10 +199,10 @@ class TransformerBlock(keras.layers.Layer):
204
 
205
  x = res + self.dropout(self.out_proj(attn_out), training=training)
206
 
207
- # FFN with SwiGLU
208
  res = x
209
  y = self.pre_ffn_norm(x)
210
- ffn = self.down_proj(tf.nn.silu(self.gate_proj(y)) * self.up_proj(y))
211
  output = res + self.dropout(ffn, training=training)
212
 
213
  return output, new_kv
@@ -255,6 +250,14 @@ class SAM1Model(keras.Model):
255
  self.lm_head = keras.layers.Dense(self.cfg['vocab_size'], use_bias=False, name="lm_head")
256
 
257
  def call(self, input_ids, training=None, past_kv=None, use_cache=False):
258
  x = self.embed(input_ids)
259
 
260
  new_past_kv = [] if use_cache else None
@@ -274,62 +277,6 @@ class SAM1Model(keras.Model):
274
  return base_config
275
 
276
 
277
- # ============================================================================
278
- # Optimized Sampling
279
- # ============================================================================
280
-
281
- class FastSampler:
282
- """Vectorized sampler for faster token selection."""
283
-
284
- def __init__(self, vocab_size):
285
- self.vocab_size = vocab_size
286
- self.rng = np.random.default_rng()
287
-
288
- def sample(self, logits, temperature, top_k, top_p, token_freq, repetition_penalty):
289
- """Optimized sampling with vectorized operations."""
290
- logits = logits.copy()
291
-
292
- if temperature != 1.0:
293
- logits = logits / temperature
294
-
295
- # Vectorized repetition penalty
296
- if repetition_penalty != 1.0 and token_freq:
297
- freq_tokens = np.array(list(token_freq.keys()), dtype=np.int32)
298
- freq_values = np.array(list(token_freq.values()), dtype=np.float32)
299
- valid_mask = freq_tokens < len(logits)
300
- freq_tokens = freq_tokens[valid_mask]
301
- freq_values = freq_values[valid_mask]
302
- if len(freq_tokens) > 0:
303
- logits[freq_tokens] /= np.power(repetition_penalty, freq_values)
304
-
305
- # Top-K filtering with partial sort
306
- if 0 < top_k < len(logits):
307
- top_k_indices = np.argpartition(logits, -top_k)[-top_k:]
308
- top_k_logits = logits[top_k_indices]
309
- else:
310
- top_k_indices = np.arange(len(logits))
311
- top_k_logits = logits
312
-
313
- # Stable softmax
314
- top_k_logits = top_k_logits - np.max(top_k_logits)
315
- exp_logits = np.exp(top_k_logits)
316
- top_k_probs = exp_logits / exp_logits.sum()
317
-
318
- # Top-P (nucleus) filtering
319
- if top_p < 1.0:
320
- sorted_idx = np.argsort(top_k_probs)[::-1]
321
- cumsum = np.cumsum(top_k_probs[sorted_idx])
322
- cutoff = np.searchsorted(cumsum, top_p) + 1
323
- nucleus_idx = sorted_idx[:cutoff]
324
- nucleus_probs = top_k_probs[nucleus_idx]
325
- nucleus_probs /= nucleus_probs.sum()
326
- sampled = self.rng.choice(len(nucleus_probs), p=nucleus_probs)
327
- return int(top_k_indices[nucleus_idx[sampled]])
328
- else:
329
- sampled = self.rng.choice(len(top_k_probs), p=top_k_probs)
330
- return int(top_k_indices[sampled])
331
-
332
-
333
  # --- Model and Tokenizer Loading ---
334
 
335
  config_path = hf_hub_download(MODEL_REPO, "config.json", cache_dir=CACHE_DIR)
@@ -375,7 +322,7 @@ if use_checkpoint:
375
  'n_heads': config['num_attention_heads'],
376
  'ff_mult': config['intermediate_size'] / config['hidden_size'],
377
  'max_len': config['max_position_embeddings'],
378
- 'dropout': 0.0, # Disable dropout for inference
379
  'rope_theta': config['rope_theta']
380
  }
381
  model = SAM1Model(config=model_config)
@@ -409,77 +356,56 @@ else:
409
  if model:
410
  print(f"βœ… Model loaded: {config['num_hidden_layers']} layers, {config['vocab_size']} vocab")
411
 
412
- # Initialize fast sampler
413
- sampler = FastSampler(config['vocab_size'])
414
-
415
- # Warm up the model (without tf.function first)
416
  print("πŸ”₯ Warming up model...")
417
  warmup_input = tf.constant([[1, 2, 3, 4, 5]], dtype=tf.int32)
418
-
419
- # Initial warmup to build all internal caches
420
- for _ in range(2):
421
- logits, past_kv = model(warmup_input, training=False, use_cache=True)
422
-
423
- # Warmup decode step
424
- single_token = tf.constant([[1]], dtype=tf.int32)
425
- for _ in range(2):
426
- logits, past_kv = model(single_token, training=False, past_kv=past_kv, use_cache=True)
427
-
428
  print("βœ… Model warmed up")
429
 
430
-
431
  # ============================================================================
432
- # Inference wrapper class for clean tf.function usage
433
  # ============================================================================
434
 
435
- class InferenceEngine:
436
- """Wrapper for compiled inference functions."""
437
-
438
- def __init__(self, model):
439
- self.model = model
440
- self._prefill_fn = None
441
- self._decode_fn = None
442
-
443
- def prefill(self, input_ids):
444
- """Run prefill (first call builds trace)."""
445
- if self._prefill_fn is None:
446
- # First call - run eagerly to ensure all caches are built
447
- return self.model(input_ids, training=False, use_cache=True)
448
- return self._prefill_fn(input_ids)
449
-
450
- def decode(self, input_ids, past_kv):
451
- """Run single-token decode."""
452
- return self.model(input_ids, training=False, past_kv=past_kv, use_cache=True)
453
-
454
- def compile_traces(self):
455
- """Compile tf.function traces after warmup."""
456
- print("πŸ”₯ Compiling optimized traces...")
457
-
458
- @tf.function(reduce_retracing=True)
459
- def prefill_fn(input_ids):
460
- return self.model(input_ids, training=False, use_cache=True)
461
-
462
- self._prefill_fn = prefill_fn
463
-
464
- # Trace with sample inputs
465
- sample_input = tf.constant([[1, 2, 3, 4, 5]], dtype=tf.int32)
466
- _ = self._prefill_fn(sample_input)
467
-
468
- print("βœ… Traces compiled")
469
-
470
-
471
- # Create inference engine
472
- engine = InferenceEngine(model)
473
 
474
- # Compile traces after warmup
475
- engine.compile_traces()
476
 
 
 
 
 
477
 
478
- # ============================================================================
479
- # Optimized Inference Logic with KV-Cache
480
- # ============================================================================
 
 
481
 
482
- stop_generation = False
 
 
483
 
484
 
485
  def generate_stream(
@@ -514,16 +440,17 @@ def generate_stream(
514
 
515
  max_context = config['max_position_embeddings']
516
 
517
- start_time = time.perf_counter()
518
 
519
  # === PREFILL PHASE ===
 
520
  if len(input_ids) > max_context - max_tokens:
521
  input_ids = input_ids[-(max_context - max_tokens):]
522
 
523
  input_tensor = tf.constant([input_ids], dtype=tf.int32)
524
 
525
  try:
526
- logits, past_kv = engine.prefill(input_tensor)
527
  except Exception as e:
528
  yield f"Error during prefill: {e}"
529
  return
@@ -531,12 +458,11 @@ def generate_stream(
531
  # Get logits for last position
532
  next_token_logits = logits[0, -1, :].numpy()
533
 
534
- prefill_time = time.perf_counter() - start_time
535
- prefill_tps = len(input_ids) / prefill_time if prefill_time > 0 else 0
536
- print(f"⚑ Prefill: {len(input_ids)} tokens in {prefill_time:.3f}s ({prefill_tps:.1f} tok/s)")
537
 
538
  # === GENERATION LOOP ===
539
- decode_start = time.perf_counter()
540
 
541
  for step in range(max_tokens):
542
  if stop_generation:
@@ -544,7 +470,7 @@ def generate_stream(
544
  return
545
 
546
  # Sample next token
547
- next_token_id = sampler.sample(
548
  next_token_logits, temperature, top_k, top_p, token_freq, repetition_penalty
549
  )
550
 
@@ -561,11 +487,11 @@ def generate_stream(
561
  token_count += 1
562
  yield generated_text
563
 
564
- # === DECODE PHASE ===
565
  next_input = tf.constant([[next_token_id]], dtype=tf.int32)
566
 
567
  try:
568
- logits, past_kv = engine.decode(next_input, past_kv)
569
  except Exception as e:
570
  yield generated_text + f"\n\n*[Error during generation: {e}]*"
571
  return
@@ -573,24 +499,24 @@ def generate_stream(
573
  next_token_logits = logits[0, -1, :].numpy()
574
 
575
  # Truncate cache if too long
576
- if step % 100 == 99:
577
- current_len = past_kv[0][0].shape[2] if past_kv and past_kv[0] is not None else 0
578
- if current_len > max_context:
579
- trim_amount = current_len - max_context + 100
580
- past_kv = [
581
- (k[:, :, trim_amount:, :], v[:, :, trim_amount:, :])
582
- for k, v in past_kv
583
- ]
584
-
585
- decode_time = time.perf_counter() - decode_start
586
- total_time = time.perf_counter() - start_time
587
 
588
  if token_count > 0:
589
  decode_tps = token_count / decode_time if decode_time > 0 else 0
 
590
 
591
  stats = (
592
  f"\n\n*[Generated {token_count} tokens in {total_time:.1f}s "
593
- f"(prefill: {prefill_time:.2f}s, decode: {decode_tps:.1f} tok/s)]*"
594
  )
595
 
596
  if not stop_generation:
@@ -605,21 +531,20 @@ def generate_stream(
605
 
606
  def format_chat_prompt(message: str, history: list, reasoning_enabled: bool) -> str:
607
  """Format message history and seed <think> if enabled."""
608
- prompt_parts = []
609
-
610
  for user_msg, assistant_msg in history:
611
- prompt_parts.append(f"<|im_start|>user\n{user_msg}<|im_end|>")
612
  if assistant_msg:
 
613
  clean_msg = assistant_msg.split("\n\n*[")[0]
614
- prompt_parts.append(f"<|im_start|>assistant\n{clean_msg}<|im_end|>")
 
 
615
 
616
- prompt_parts.append(f"<|im_start|>user\n{message}<|im_end|>")
617
- prompt_parts.append("<|im_start|>assistant")
618
-
619
  if reasoning_enabled:
620
- prompt_parts.append("<think>")
621
 
622
- return "\n".join(prompt_parts)
623
 
624
 
625
  def chat_stream(
@@ -658,6 +583,7 @@ def chat_stream(
658
 
659
  display_response = partial_response
660
  if should_stop:
 
661
  stats_start = partial_response.find("\n\n*[")
662
  if stats_start > earliest_stop:
663
  display_response = partial_response[:earliest_stop] + partial_response[stats_start:]
@@ -843,7 +769,7 @@ with gr.Blocks(css=custom_css, theme=gr.themes.Soft()) as demo:
843
  **Vocab:** {config['vocab_size']:,}
844
  **Layers:** {config['num_hidden_layers']}
845
  **Context:** {config['max_position_embeddings']:,} tokens
846
- **Optimization:** KV-Cache + XLA ⚡
847
  """)
848
 
849
  gr.Examples(
@@ -919,7 +845,7 @@ with gr.Blocks(css=custom_css, theme=gr.themes.Soft()) as demo:
919
 
920
  if __name__ == "__main__":
921
  print("\n" + "=" * 60)
922
- print("πŸš€ Starting Sam-large-2 Chat with Optimized Inference")
923
  print("=" * 60 + "\n")
924
  demo.queue(max_size=20)
925
  demo.launch(server_name="0.0.0.0", server_port=7860, share=False, show_error=True)
 
24
  tf.config.threading.set_inter_op_parallelism_threads(NUM_CORES)
25
  tf.config.threading.set_intra_op_parallelism_threads(NUM_CORES)
26
 
27
+ print(f"βœ… CPU optimized: {NUM_CORES} threads, oneDNN enabled")
 
 
 
 
 
28
 
29
  # ============================================================================
30
  # 🎊 FESTIVE MODE TOGGLE 🎊
 
41
  CACHE_DIR = "./model_cache"
42
 
43
  # ============================================================================
44
+ # Model Architecture Definitions (Optimized with KV-Cache)
45
  # ============================================================================
46
 
47
  @keras.saving.register_keras_serializable()
48
  class RotaryEmbedding(keras.layers.Layer):
 
 
49
  def __init__(self, dim, max_len=2048, theta=10000, **kwargs):
50
  super().__init__(**kwargs)
51
  self.dim = dim
52
  self.max_len = max_len
53
  self.theta = theta
54
+ self.built_cache = False
55
+ self.cos_cached = None
56
+ self.sin_cached = None
57
 
58
  def build(self, input_shape):
 
 
59
  super().build(input_shape)
60
 
61
+ def _build_cache(self):
62
+ if not self.built_cache:
63
+ inv_freq = 1.0 / (self.theta ** (tf.range(0, self.dim, 2, dtype=tf.float32) / self.dim))
64
+ t = tf.range(self.max_len, dtype=tf.float32)
65
+ freqs = tf.einsum("i,j->ij", t, inv_freq)
66
+ emb = tf.concat([freqs, freqs], axis=-1)
67
+ self.cos_cached = tf.constant(np.cos(emb.numpy()), dtype=tf.float32)
68
+ self.sin_cached = tf.constant(np.sin(emb.numpy()), dtype=tf.float32)
69
+ self.built_cache = True
70
+
71
+ def rotate_half(self, x):
72
+ x1, x2 = tf.split(x, 2, axis=-1)
73
+ return tf.concat([-x2, x1], axis=-1)
74
+
75
  def call(self, q, k, offset=0):
76
  """Apply rotary embeddings with position offset for KV-cache."""
77
+ self._build_cache()
78
  seq_len = tf.shape(q)[2]
79
  dtype = q.dtype
80
 
81
+ cos = tf.cast(self.cos_cached[offset:offset + seq_len, :], dtype)[None, None, :, :]
82
+ sin = tf.cast(self.sin_cached[offset:offset + seq_len, :], dtype)[None, None, :, :]
 
83
 
84
+ q_embed = (q * cos) + (self.rotate_half(q) * sin)
85
+ k_embed = (k * cos) + (self.rotate_half(k) * sin)
 
86
  return q_embed, k_embed
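For reference, a minimal NumPy sketch of the rotation that the new rotate_half/call pair applies; the sizes below are toy values chosen for illustration, not the model's real dimensions.

import numpy as np

dim, max_len, theta = 4, 8, 10000.0                      # toy sizes, only for illustration
inv_freq = 1.0 / (theta ** (np.arange(0, dim, 2, dtype=np.float32) / dim))
freqs = np.outer(np.arange(max_len, dtype=np.float32), inv_freq)
emb = np.concatenate([freqs, freqs], axis=-1)            # [max_len, dim]
cos, sin = np.cos(emb), np.sin(emb)                      # pre-computed caches

def rotate_half(x):
    x1, x2 = np.split(x, 2, axis=-1)
    return np.concatenate([-x2, x1], axis=-1)

q = np.random.randn(1, 1, 3, dim).astype(np.float32)     # [B, heads, T, head_dim]
offset = 2                                               # pretend 2 positions already sit in the KV-cache
c = cos[offset:offset + 3][None, None]                   # slice the cache at the query's offset
s = sin[offset:offset + 3][None, None]
q_rot = q * c + rotate_half(q) * s                       # same formula as q_embed above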
87
 
88
  def get_config(self):
 
114
 
115
  @keras.saving.register_keras_serializable()
116
  class TransformerBlock(keras.layers.Layer):
 
 
117
  def __init__(self, d_model, n_heads, ff_dim, dropout, max_len, rope_theta, layer_idx=0, **kwargs):
118
  super().__init__(**kwargs)
119
  self.d_model = d_model
 
124
  self.rope_theta = rope_theta
125
  self.head_dim = d_model // n_heads
126
  self.layer_idx = layer_idx
 
127
 
128
  def build(self, input_shape):
 
129
  self.pre_attn_norm = RMSNorm(name="pre_attn_norm")
130
  self.pre_ffn_norm = RMSNorm(name="pre_ffn_norm")
 
 
131
  self.q_proj = keras.layers.Dense(self.d_model, use_bias=False, name="q_proj")
132
  self.k_proj = keras.layers.Dense(self.d_model, use_bias=False, name="k_proj")
133
  self.v_proj = keras.layers.Dense(self.d_model, use_bias=False, name="v_proj")
134
  self.out_proj = keras.layers.Dense(self.d_model, use_bias=False, name="o_proj")
 
135
  self.rope = RotaryEmbedding(self.head_dim, max_len=self.max_len, theta=self.rope_theta)
 
 
136
  self.gate_proj = keras.layers.Dense(self.ff_dim, use_bias=False, name="gate_proj")
137
  self.up_proj = keras.layers.Dense(self.ff_dim, use_bias=False, name="up_proj")
138
  self.down_proj = keras.layers.Dense(self.d_model, use_bias=False, name="down_proj")
 
139
  self.dropout = keras.layers.Dropout(self.dropout_rate)
140
  super().build(input_shape)
141
 
142
  def call(self, x, training=None, past_kv=None, use_cache=False):
143
+ """
144
+ Args:
145
+ x: input tensor [B, T, D] (T=1 during cached generation)
146
+ past_kv: tuple of (past_k, past_v) each [B, n_heads, past_len, head_dim]
147
+ use_cache: whether to return updated kv cache
148
+ Returns:
149
+ output, (new_k, new_v) if use_cache else output, None
150
+ """
151
  B = tf.shape(x)[0]
152
  T = tf.shape(x)[1]
153
+ dtype = x.dtype
154
 
155
  res = x
156
  y = self.pre_attn_norm(x)
157
 
158
+ # Project Q, K, V for current input
159
  q = tf.reshape(self.q_proj(y), [B, T, self.n_heads, self.head_dim])
160
  q = tf.transpose(q, [0, 2, 1, 3]) # [B, n_heads, T, head_dim]
161
 
 
181
 
182
  new_kv = (k, v) if use_cache else None
183
 
184
+ # Attention
185
  full_len = tf.shape(k)[2]
186
+ scores = tf.matmul(q, k, transpose_b=True) / tf.sqrt(tf.cast(self.head_dim, dtype))
187
 
188
  # Causal mask
189
  q_positions = tf.range(past_len, past_len + T)
190
  k_positions = tf.range(full_len)
191
+ mask = tf.cast(q_positions[:, None] >= k_positions[None, :], dtype)
192
+ mask = tf.where(mask == 0, tf.constant(-1e9, dtype=dtype), tf.constant(0.0, dtype=dtype))
193
  scores = scores + mask[None, None, :, :]
194
 
195
  attn = tf.nn.softmax(scores, axis=-1)
 
199
 
200
  x = res + self.dropout(self.out_proj(attn_out), training=training)
201
 
202
+ # FFN
203
  res = x
204
  y = self.pre_ffn_norm(x)
205
+ ffn = self.down_proj(keras.activations.silu(self.gate_proj(y)) * self.up_proj(y))
206
  output = res + self.dropout(ffn, training=training)
207
 
208
  return output, new_kv
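As a quick sanity check on the masking logic above, a small standalone sketch (toy lengths, assumed values) of the additive causal mask when a single new token attends to a cached prefix:

import tensorflow as tf

past_len, T = 3, 1                                   # 3 cached positions, 1 new query token (toy values)
full_len = past_len + T
q_positions = tf.range(past_len, past_len + T)
k_positions = tf.range(full_len)
allowed = tf.cast(q_positions[:, None] >= k_positions[None, :], tf.float32)
mask = tf.where(allowed == 0, tf.constant(-1e9), tf.constant(0.0))
print(mask.numpy())                                  # [[0. 0. 0. 0.]] -- the new token may attend to every earlier position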
 
250
  self.lm_head = keras.layers.Dense(self.cfg['vocab_size'], use_bias=False, name="lm_head")
251
 
252
  def call(self, input_ids, training=None, past_kv=None, use_cache=False):
253
+ """
254
+ Args:
255
+ input_ids: [B, T]
256
+ past_kv: list of (k, v) tuples, one per layer
257
+ use_cache: whether to return updated cache
258
+ Returns:
259
+ logits, new_past_kv (or None)
260
+ """
261
  x = self.embed(input_ids)
262
 
263
  new_past_kv = [] if use_cache else None
 
277
  return base_config
278

279
 
 
 
 
 
 
280
  # --- Model and Tokenizer Loading ---
281
 
282
  config_path = hf_hub_download(MODEL_REPO, "config.json", cache_dir=CACHE_DIR)
 
322
  'n_heads': config['num_attention_heads'],
323
  'ff_mult': config['intermediate_size'] / config['hidden_size'],
324
  'max_len': config['max_position_embeddings'],
325
+ 'dropout': 0.1,
326
  'rope_theta': config['rope_theta']
327
  }
328
  model = SAM1Model(config=model_config)
 
356
  if model:
357
  print(f"βœ… Model loaded: {config['num_hidden_layers']} layers, {config['vocab_size']} vocab")
358
 
359
+ # Warm up the model
360
  print("πŸ”₯ Warming up model...")
361
  warmup_input = tf.constant([[1, 2, 3, 4, 5]], dtype=tf.int32)
362
+ _, _ = model(warmup_input, training=False, use_cache=True)
 
 
363
  print("βœ… Model warmed up")
364
 
 
365
  # ============================================================================
366
+ # Optimized Inference Logic with KV-Cache
367
  # ============================================================================
368
 
369
+ stop_generation = False
 
 
 
 
 
 
370
 
 
 
371
 
372
+ def sample_token(logits, temperature, top_k, top_p, token_freq, repetition_penalty):
373
+ """Pure NumPy sampling for speed."""
374
+ # Temperature scaling
375
+ scaled_logits = logits / temperature
376
 
377
+ # Repetition penalty
378
+ if repetition_penalty != 1.0:
379
+ for token_id, freq in token_freq.items():
380
+ if token_id < len(scaled_logits):
381
+ scaled_logits[token_id] /= (repetition_penalty ** freq)
382
 
383
+ # Top-K filtering
384
+ if top_k > 0 and top_k < len(scaled_logits):
385
+ top_k_indices = np.argpartition(scaled_logits, -top_k)[-top_k:]
386
+ top_k_logits = scaled_logits[top_k_indices]
387
+ else:
388
+ top_k_indices = np.arange(len(scaled_logits))
389
+ top_k_logits = scaled_logits
390
+
391
+ # Softmax (numerically stable)
392
+ top_k_logits = top_k_logits - np.max(top_k_logits)
393
+ top_k_probs = np.exp(top_k_logits)
394
+ top_k_probs /= top_k_probs.sum()
395
+
396
+ # Top-P (nucleus) filtering
397
+ if top_p < 1.0:
398
+ sorted_idx = np.argsort(top_k_probs)[::-1]
399
+ cumsum = np.cumsum(top_k_probs[sorted_idx])
400
+ cutoff = np.searchsorted(cumsum, top_p) + 1
401
+ nucleus_idx = sorted_idx[:cutoff]
402
+ nucleus_probs = top_k_probs[nucleus_idx]
403
+ nucleus_probs /= nucleus_probs.sum()
404
+ sampled = np.random.choice(len(nucleus_probs), p=nucleus_probs)
405
+ return int(top_k_indices[nucleus_idx[sampled]])
406
+ else:
407
+ sampled = np.random.choice(len(top_k_probs), p=top_k_probs)
408
+ return int(top_k_indices[sampled])
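A usage sketch for the sampler above, assuming the surrounding app.py context (numpy already imported); the logits, vocabulary size, and token_freq counts here are invented:

fake_logits = np.random.randn(32000).astype(np.float32)   # pretend vocab-sized logits from the model
token_freq = {42: 3, 7: 1}                                 # running counts of already-generated token ids
next_token_id = sample_token(
    fake_logits,
    temperature=0.7,
    top_k=50,
    top_p=0.9,
    token_freq=token_freq,
    repetition_penalty=1.1,
)
print(next_token_id)                                       # an int index into the vocabulary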
409
 
410
 
411
  def generate_stream(
 
440
 
441
  max_context = config['max_position_embeddings']
442
 
443
+ start_time = time.time()
444
 
445
  # === PREFILL PHASE ===
446
+ # Truncate if prompt is too long
447
  if len(input_ids) > max_context - max_tokens:
448
  input_ids = input_ids[-(max_context - max_tokens):]
449
 
450
  input_tensor = tf.constant([input_ids], dtype=tf.int32)
451
 
452
  try:
453
+ logits, past_kv = model(input_tensor, training=False, use_cache=True)
454
  except Exception as e:
455
  yield f"Error during prefill: {e}"
456
  return
 
458
  # Get logits for last position
459
  next_token_logits = logits[0, -1, :].numpy()
460
 
461
+ prefill_time = time.time() - start_time
462
+ print(f"⚑ Prefill: {len(input_ids)} tokens in {prefill_time:.2f}s")
 
463
 
464
  # === GENERATION LOOP ===
465
+ decode_start = time.time()
466
 
467
  for step in range(max_tokens):
468
  if stop_generation:
 
470
  return
471
 
472
  # Sample next token
473
+ next_token_id = sample_token(
474
  next_token_logits, temperature, top_k, top_p, token_freq, repetition_penalty
475
  )
476
 
 
487
  token_count += 1
488
  yield generated_text
489
 
490
+ # === DECODE PHASE (single token, reuse cache) ===
491
  next_input = tf.constant([[next_token_id]], dtype=tf.int32)
492
 
493
  try:
494
+ logits, past_kv = model(next_input, training=False, past_kv=past_kv, use_cache=True)
495
  except Exception as e:
496
  yield generated_text + f"\n\n*[Error during generation: {e}]*"
497
  return
 
499
  next_token_logits = logits[0, -1, :].numpy()
500
 
501
  # Truncate cache if too long
502
+ current_len = past_kv[0][0].shape[2] if past_kv and past_kv[0] is not None else 0
503
+ if current_len > max_context:
504
+ trim_amount = current_len - max_context + 100 # Keep some buffer
505
+ past_kv = [
506
+ (k[:, :, trim_amount:, :], v[:, :, trim_amount:, :])
507
+ for k, v in past_kv
508
+ ]
509
+
510
+ decode_time = time.time() - decode_start
511
+ total_time = time.time() - start_time
 
512
 
513
  if token_count > 0:
514
  decode_tps = token_count / decode_time if decode_time > 0 else 0
515
+ total_tps = token_count / total_time if total_time > 0 else 0
516
 
517
  stats = (
518
  f"\n\n*[Generated {token_count} tokens in {total_time:.1f}s "
519
+ f"(prefill: {prefill_time:.1f}s, decode: {decode_tps:.1f} tok/s)]*"
520
  )
521
 
522
  if not stop_generation:
 
531
 
532
  def format_chat_prompt(message: str, history: list, reasoning_enabled: bool) -> str:
533
  """Format message history and seed <think> if enabled."""
534
+ prompt = ""
 
535
  for user_msg, assistant_msg in history:
536
+ prompt += f"<|im_start|>user\n{user_msg}<|im_end|>\n"
537
  if assistant_msg:
538
+ # Clean up any stats from previous messages
539
  clean_msg = assistant_msg.split("\n\n*[")[0]
540
+ prompt += f"<|im_start|>assistant\n{clean_msg}<|im_end|>\n"
541
+
542
+ prompt += f"<|im_start|>user\n{message}<|im_end|>\n<|im_start|>assistant\n"
543
 
 
 
 
544
  if reasoning_enabled:
545
+ prompt += "<think>"
546
 
547
+ return prompt
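For illustration, the string this builder produces for a one-turn history with reasoning enabled (the messages are invented):

history = [("Hi", "Hello! How can I help?")]
print(format_chat_prompt("What is RoPE?", history, reasoning_enabled=True))
# <|im_start|>user
# Hi<|im_end|>
# <|im_start|>assistant
# Hello! How can I help?<|im_end|>
# <|im_start|>user
# What is RoPE?<|im_end|>
# <|im_start|>assistant
# <think>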
548
 
549
 
550
  def chat_stream(
 
583
 
584
  display_response = partial_response
585
  if should_stop:
586
+ # Keep the stats portion if present
587
  stats_start = partial_response.find("\n\n*[")
588
  if stats_start > earliest_stop:
589
  display_response = partial_response[:earliest_stop] + partial_response[stats_start:]
 
769
  **Vocab:** {config['vocab_size']:,}
770
  **Layers:** {config['num_hidden_layers']}
771
  **Context:** {config['max_position_embeddings']:,} tokens
772
+ **Optimization:** KV-Cache enabled ⚡
773
  """)
774
 
775
  gr.Examples(
 
845
 
846
  if __name__ == "__main__":
847
  print("\n" + "=" * 60)
848
+ print("πŸš€ Starting Sam-large-2 Chat with KV-Cache Optimization")
849
  print("=" * 60 + "\n")
850
  demo.queue(max_size=20)
851
  demo.launch(server_name="0.0.0.0", server_port=7860, share=False, show_error=True)