Keeby-smilyai committed on
Commit
cbfe110
·
verified ·
1 Parent(s): 5d1d6ad

Update app.py

Files changed (1)
  1. app.py +197 -139
app.py CHANGED
@@ -10,6 +10,7 @@ os.environ['TF_NUM_INTRAOP_THREADS'] = str(NUM_CORES)
10
  os.environ['CUDA_VISIBLE_DEVICES'] = '-1' # Force CPU only
11
  os.environ['TF_ENABLE_ONEDNN_OPTS'] = '1' # Intel optimization
12
  os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' # Reduce TF logging
 
13
 
14
  import gradio as gr
15
  import tensorflow as tf
@@ -19,12 +20,16 @@ import json
19
  from tokenizers import Tokenizer
20
  import numpy as np
21
  import time
 
22
 
23
  # Configure TF threading
24
  tf.config.threading.set_inter_op_parallelism_threads(NUM_CORES)
25
  tf.config.threading.set_intra_op_parallelism_threads(NUM_CORES)
26
 
27
- print(f"✅ CPU optimized: {NUM_CORES} threads, oneDNN enabled")
28
 
29
  # ============================================================================
30
  # 🎊 FESTIVE MODE TOGGLE 🎊
@@ -41,48 +46,59 @@ MODEL_REPO = "Smilyai-labs/Sam-large-2"
41
  CACHE_DIR = "./model_cache"
42
 
43
  # ============================================================================
44
- # Model Architecture Definitions (Optimized with KV-Cache)
45
  # ============================================================================
46
 
47
  @keras.saving.register_keras_serializable()
48
  class RotaryEmbedding(keras.layers.Layer):
 
 
49
  def __init__(self, dim, max_len=2048, theta=10000, **kwargs):
50
  super().__init__(**kwargs)
51
  self.dim = dim
52
  self.max_len = max_len
53
  self.theta = theta
54
- self.built_cache = False
55
  self.cos_cached = None
56
  self.sin_cached = None
57
 
58
  def build(self, input_shape):
59
  super().build(input_shape)
60
 
61
- def _build_cache(self):
62
- if not self.built_cache:
63
- inv_freq = 1.0 / (self.theta ** (tf.range(0, self.dim, 2, dtype=tf.float32) / self.dim))
64
- t = tf.range(self.max_len, dtype=tf.float32)
65
- freqs = tf.einsum("i,j->ij", t, inv_freq)
66
- emb = tf.concat([freqs, freqs], axis=-1)
67
- self.cos_cached = tf.constant(np.cos(emb.numpy()), dtype=tf.float32)
68
- self.sin_cached = tf.constant(np.sin(emb.numpy()), dtype=tf.float32)
69
- self.built_cache = True
70
-
71
- def rotate_half(self, x):
72
- x1, x2 = tf.split(x, 2, axis=-1)
73
- return tf.concat([-x2, x1], axis=-1)
74
-
75
  def call(self, q, k, offset=0):
76
  """Apply rotary embeddings with position offset for KV-cache."""
77
- self._build_cache()
78
  seq_len = tf.shape(q)[2]
79
  dtype = q.dtype
80
 
81
  cos = tf.cast(self.cos_cached[offset:offset + seq_len, :], dtype)[None, None, :, :]
82
  sin = tf.cast(self.sin_cached[offset:offset + seq_len, :], dtype)[None, None, :, :]
83
 
84
- q_embed = (q * cos) + (self.rotate_half(q) * sin)
85
- k_embed = (k * cos) + (self.rotate_half(k) * sin)
86
  return q_embed, k_embed
87
 
88
  def get_config(self):
@@ -93,6 +109,8 @@ class RotaryEmbedding(keras.layers.Layer):
93
 
94
  @keras.saving.register_keras_serializable()
95
  class RMSNorm(keras.layers.Layer):
 
 
96
  def __init__(self, epsilon=1e-5, **kwargs):
97
  super().__init__(**kwargs)
98
  self.epsilon = epsilon
@@ -102,7 +120,9 @@ class RMSNorm(keras.layers.Layer):
102
  self.scale = self.add_weight(name="scale", shape=(input_shape[-1],), initializer="ones")
103
  super().build(input_shape)
104
 
 
105
  def call(self, x):
 
106
  variance = tf.reduce_mean(tf.square(x), axis=-1, keepdims=True)
107
  return x * tf.math.rsqrt(variance + self.epsilon) * self.scale
108
 
@@ -114,6 +134,8 @@ class RMSNorm(keras.layers.Layer):
114
 
115
  @keras.saving.register_keras_serializable()
116
  class TransformerBlock(keras.layers.Layer):
 
 
117
  def __init__(self, d_model, n_heads, ff_dim, dropout, max_len, rope_theta, layer_idx=0, **kwargs):
118
  super().__init__(**kwargs)
119
  self.d_model = d_model
@@ -124,52 +146,40 @@ class TransformerBlock(keras.layers.Layer):
124
  self.rope_theta = rope_theta
125
  self.head_dim = d_model // n_heads
126
  self.layer_idx = layer_idx
 
127
 
128
  def build(self, input_shape):
129
  self.pre_attn_norm = RMSNorm(name="pre_attn_norm")
130
  self.pre_ffn_norm = RMSNorm(name="pre_ffn_norm")
131
- self.q_proj = keras.layers.Dense(self.d_model, use_bias=False, name="q_proj")
132
- self.k_proj = keras.layers.Dense(self.d_model, use_bias=False, name="k_proj")
133
- self.v_proj = keras.layers.Dense(self.d_model, use_bias=False, name="v_proj")
134
  self.out_proj = keras.layers.Dense(self.d_model, use_bias=False, name="o_proj")
 
135
  self.rope = RotaryEmbedding(self.head_dim, max_len=self.max_len, theta=self.rope_theta)
136
- self.gate_proj = keras.layers.Dense(self.ff_dim, use_bias=False, name="gate_proj")
137
- self.up_proj = keras.layers.Dense(self.ff_dim, use_bias=False, name="up_proj")
 
138
  self.down_proj = keras.layers.Dense(self.d_model, use_bias=False, name="down_proj")
 
139
  self.dropout = keras.layers.Dropout(self.dropout_rate)
140
  super().build(input_shape)
141
 
142
  def call(self, x, training=None, past_kv=None, use_cache=False):
143
- """
144
- Args:
145
- x: input tensor [B, T, D] (T=1 during cached generation)
146
- past_kv: tuple of (past_k, past_v) each [B, n_heads, past_len, head_dim]
147
- use_cache: whether to return updated kv cache
148
- Returns:
149
- output, (new_k, new_v) if use_cache else output, None
150
- """
151
  B = tf.shape(x)[0]
152
  T = tf.shape(x)[1]
153
- dtype = x.dtype
154
 
155
  res = x
156
  y = self.pre_attn_norm(x)
157
 
158
- # Project Q, K, V for current input
159
- q = tf.reshape(self.q_proj(y), [B, T, self.n_heads, self.head_dim])
160
- q = tf.transpose(q, [0, 2, 1, 3]) # [B, n_heads, T, head_dim]
161
-
162
- k = tf.reshape(self.k_proj(y), [B, T, self.n_heads, self.head_dim])
163
- k = tf.transpose(k, [0, 2, 1, 3])
164
-
165
- v = tf.reshape(self.v_proj(y), [B, T, self.n_heads, self.head_dim])
166
- v = tf.transpose(v, [0, 2, 1, 3])
167
 
168
  # Determine position offset for RoPE
169
- if past_kv is not None:
170
- past_len = tf.shape(past_kv[0])[2]
171
- else:
172
- past_len = 0
173
 
174
  # Apply RoPE with position offset
175
  q, k = self.rope(q, k, offset=past_len)
@@ -181,15 +191,14 @@ class TransformerBlock(keras.layers.Layer):
181
 
182
  new_kv = (k, v) if use_cache else None
183
 
184
- # Attention
185
  full_len = tf.shape(k)[2]
186
- scores = tf.matmul(q, k, transpose_b=True) / tf.sqrt(tf.cast(self.head_dim, dtype))
187
 
188
- # Causal mask
189
  q_positions = tf.range(past_len, past_len + T)
190
  k_positions = tf.range(full_len)
191
- mask = tf.cast(q_positions[:, None] >= k_positions[None, :], dtype)
192
- mask = tf.where(mask == 0, tf.constant(-1e9, dtype=dtype), tf.constant(0.0, dtype=dtype))
193
  scores = scores + mask[None, None, :, :]
194
 
195
  attn = tf.nn.softmax(scores, axis=-1)
@@ -199,10 +208,12 @@ class TransformerBlock(keras.layers.Layer):
199
 
200
  x = res + self.dropout(self.out_proj(attn_out), training=training)
201
 
202
- # FFN
203
  res = x
204
  y = self.pre_ffn_norm(x)
205
- ffn = self.down_proj(keras.activations.silu(self.gate_proj(y)) * self.up_proj(y))
 
 
206
  output = res + self.dropout(ffn, training=training)
207
 
208
  return output, new_kv
@@ -223,6 +234,8 @@ class TransformerBlock(keras.layers.Layer):
223
 
224
  @keras.saving.register_keras_serializable()
225
  class SAM1Model(keras.Model):
 
 
226
  def __init__(self, **kwargs):
227
  super().__init__()
228
  if 'config' in kwargs and isinstance(kwargs['config'], dict):
@@ -248,16 +261,11 @@ class SAM1Model(keras.Model):
248
  ]
249
  self.norm = RMSNorm(name="final_norm")
250
  self.lm_head = keras.layers.Dense(self.cfg['vocab_size'], use_bias=False, name="lm_head")
 
 
 
251
 
252
  def call(self, input_ids, training=None, past_kv=None, use_cache=False):
253
- """
254
- Args:
255
- input_ids: [B, T]
256
- past_kv: list of (k, v) tuples, one per layer
257
- use_cache: whether to return updated cache
258
- Returns:
259
- logits, new_past_kv (or None)
260
- """
261
  x = self.embed(input_ids)
262
 
263
  new_past_kv = [] if use_cache else None
@@ -271,12 +279,85 @@ class SAM1Model(keras.Model):
271
  logits = self.lm_head(self.norm(x))
272
  return logits, new_past_kv
273
 
274
  def get_config(self):
275
  base_config = super().get_config()
276
  base_config['config'] = self.cfg
277
  return base_config
278
 
279
 
280
  # --- Model and Tokenizer Loading ---
281
 
282
  config_path = hf_hub_download(MODEL_REPO, "config.json", cache_dir=CACHE_DIR)
@@ -322,7 +403,7 @@ if use_checkpoint:
322
  'n_heads': config['num_attention_heads'],
323
  'ff_mult': config['intermediate_size'] / config['hidden_size'],
324
  'max_len': config['max_position_embeddings'],
325
- 'dropout': 0.1,
326
  'rope_theta': config['rope_theta']
327
  }
328
  model = SAM1Model(config=model_config)
@@ -356,11 +437,23 @@ else:
356
  if model:
357
  print(f"✅ Model loaded: {config['num_hidden_layers']} layers, {config['vocab_size']} vocab")
358
 
359
- # Warm up the model
360
- print("🔥 Warming up model...")
361
  warmup_input = tf.constant([[1, 2, 3, 4, 5]], dtype=tf.int32)
362
- _, _ = model(warmup_input, training=False, use_cache=True)
363
- print("✅ Model warmed up")
364
 
365
  # ============================================================================
366
  # Optimized Inference Logic with KV-Cache
@@ -369,45 +462,6 @@ print("✅ Model warmed up")
369
  stop_generation = False
370
 
371
 
372
- def sample_token(logits, temperature, top_k, top_p, token_freq, repetition_penalty):
373
- """Pure NumPy sampling for speed."""
374
- # Temperature scaling
375
- scaled_logits = logits / temperature
376
-
377
- # Repetition penalty
378
- if repetition_penalty != 1.0:
379
- for token_id, freq in token_freq.items():
380
- if token_id < len(scaled_logits):
381
- scaled_logits[token_id] /= (repetition_penalty ** freq)
382
-
383
- # Top-K filtering
384
- if top_k > 0 and top_k < len(scaled_logits):
385
- top_k_indices = np.argpartition(scaled_logits, -top_k)[-top_k:]
386
- top_k_logits = scaled_logits[top_k_indices]
387
- else:
388
- top_k_indices = np.arange(len(scaled_logits))
389
- top_k_logits = scaled_logits
390
-
391
- # Softmax (numerically stable)
392
- top_k_logits = top_k_logits - np.max(top_k_logits)
393
- top_k_probs = np.exp(top_k_logits)
394
- top_k_probs /= top_k_probs.sum()
395
-
396
- # Top-P (nucleus) filtering
397
- if top_p < 1.0:
398
- sorted_idx = np.argsort(top_k_probs)[::-1]
399
- cumsum = np.cumsum(top_k_probs[sorted_idx])
400
- cutoff = np.searchsorted(cumsum, top_p) + 1
401
- nucleus_idx = sorted_idx[:cutoff]
402
- nucleus_probs = top_k_probs[nucleus_idx]
403
- nucleus_probs /= nucleus_probs.sum()
404
- sampled = np.random.choice(len(nucleus_probs), p=nucleus_probs)
405
- return int(top_k_indices[nucleus_idx[sampled]])
406
- else:
407
- sampled = np.random.choice(len(top_k_probs), p=top_k_probs)
408
- return int(top_k_indices[sampled])
409
-
410
-
411
  def generate_stream(
412
  prompt: str,
413
  max_tokens: int = 512,
@@ -440,10 +494,9 @@ def generate_stream(
440
 
441
  max_context = config['max_position_embeddings']
442
 
443
- start_time = time.time()
444
 
445
  # === PREFILL PHASE ===
446
- # Truncate if prompt is too long
447
  if len(input_ids) > max_context - max_tokens:
448
  input_ids = input_ids[-(max_context - max_tokens):]
449
 
@@ -455,22 +508,25 @@ def generate_stream(
455
  yield f"Error during prefill: {e}"
456
  return
457
 
458
- # Get logits for last position
459
  next_token_logits = logits[0, -1, :].numpy()
460
 
461
- prefill_time = time.time() - start_time
462
- print(f"⚡ Prefill: {len(input_ids)} tokens in {prefill_time:.2f}s")
463
 
464
  # === GENERATION LOOP ===
465
466
 
467
  for step in range(max_tokens):
468
  if stop_generation:
469
  yield generated_text + "\n\n*[Generation stopped]*"
470
  return
471
 
472
- # Sample next token
473
- next_token_id = sample_token(
474
  next_token_logits, temperature, top_k, top_p, token_freq, repetition_penalty
475
  )
476
 
@@ -485,7 +541,9 @@ def generate_stream(
485
  token_text = tokenizer.decode([next_token_id])
486
  generated_text += token_text
487
  token_count += 1
488
- yield generated_text
 
 
489
 
490
  # === DECODE PHASE (single token, reuse cache) ===
491
  next_input = tf.constant([[next_token_id]], dtype=tf.int32)
@@ -498,25 +556,25 @@ def generate_stream(
498
 
499
  next_token_logits = logits[0, -1, :].numpy()
500
 
501
- # Truncate cache if too long
502
- current_len = past_kv[0][0].shape[2] if past_kv and past_kv[0] is not None else 0
503
- if current_len > max_context:
504
- trim_amount = current_len - max_context + 100 # Keep some buffer
505
- past_kv = [
506
- (k[:, :, trim_amount:, :], v[:, :, trim_amount:, :])
507
- for k, v in past_kv
508
- ]
509
-
510
- decode_time = time.time() - decode_start
511
- total_time = time.time() - start_time
 
512
 
513
  if token_count > 0:
514
  decode_tps = token_count / decode_time if decode_time > 0 else 0
515
- total_tps = token_count / total_time if total_time > 0 else 0
516
 
517
  stats = (
518
  f"\n\n*[Generated {token_count} tokens in {total_time:.1f}s "
519
- f"(prefill: {prefill_time:.1f}s, decode: {decode_tps:.1f} tok/s)]*"
520
  )
521
 
522
  if not stop_generation:
@@ -531,20 +589,21 @@ def generate_stream(
531
 
532
  def format_chat_prompt(message: str, history: list, reasoning_enabled: bool) -> str:
533
  """Format message history and seed <think> if enabled."""
534
- prompt = ""
 
535
  for user_msg, assistant_msg in history:
536
- prompt += f"<|im_start|>user\n{user_msg}<|im_end|>\n"
537
  if assistant_msg:
538
- # Clean up any stats from previous messages
539
  clean_msg = assistant_msg.split("\n\n*[")[0]
540
- prompt += f"<|im_start|>assistant\n{clean_msg}<|im_end|>\n"
541
-
542
- prompt += f"<|im_start|>user\n{message}<|im_end|>\n<|im_start|>assistant\n"
543
 
 
 
 
544
  if reasoning_enabled:
545
- prompt += "<think>"
546
 
547
- return prompt
548
 
549
 
550
  def chat_stream(
@@ -583,7 +642,6 @@ def chat_stream(
583
 
584
  display_response = partial_response
585
  if should_stop:
586
- # Keep the stats portion if present
587
  stats_start = partial_response.find("\n\n*[")
588
  if stats_start > earliest_stop:
589
  display_response = partial_response[:earliest_stop] + partial_response[stats_start:]
@@ -769,7 +827,7 @@ with gr.Blocks(css=custom_css, theme=gr.themes.Soft()) as demo:
769
  **Vocab:** {config['vocab_size']:,}
770
  **Layers:** {config['num_hidden_layers']}
771
  **Context:** {config['max_position_embeddings']:,} tokens
772
- **Optimization:** KV-Cache enabled ⚡
773
  """)
774
 
775
  gr.Examples(
@@ -845,7 +903,7 @@ with gr.Blocks(css=custom_css, theme=gr.themes.Soft()) as demo:
845
 
846
  if __name__ == "__main__":
847
  print("\n" + "=" * 60)
848
- print("🚀 Starting Sam-large-2 Chat with KV-Cache Optimization")
849
  print("=" * 60 + "\n")
850
  demo.queue(max_size=20)
851
  demo.launch(server_name="0.0.0.0", server_port=7860, share=False, show_error=True)
 
10
  os.environ['CUDA_VISIBLE_DEVICES'] = '-1' # Force CPU only
11
  os.environ['TF_ENABLE_ONEDNN_OPTS'] = '1' # Intel optimization
12
  os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' # Reduce TF logging
13
+ os.environ['TF_ENABLE_AUTO_MIXED_PRECISION'] = '0' # We'll handle precision manually
14
 
15
  import gradio as gr
16
  import tensorflow as tf
 
20
  from tokenizers import Tokenizer
21
  import numpy as np
22
  import time
23
+ from functools import lru_cache
24
 
25
  # Configure TF threading
26
  tf.config.threading.set_inter_op_parallelism_threads(NUM_CORES)
27
  tf.config.threading.set_intra_op_parallelism_threads(NUM_CORES)
28
 
29
+ # Enable XLA JIT compilation for CPU
30
+ tf.config.optimizer.set_jit(True)
31
+
32
+ print(f"✅ CPU optimized: {NUM_CORES} threads, oneDNN enabled, XLA JIT enabled")
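# Quick self-check (illustration only, not part of this commit) that the
# threading and XLA settings configured above took effect in the running
# TF session.
import tensorflow as tf

print(tf.config.threading.get_inter_op_parallelism_threads())
print(tf.config.threading.get_intra_op_parallelism_threads())
print(tf.config.optimizer.get_jit())  # truthy once set_jit(True) has been called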
33
 
34
  # ============================================================================
35
  # 🎊 FESTIVE MODE TOGGLE 🎊
 
46
  CACHE_DIR = "./model_cache"
47
 
48
  # ============================================================================
49
+ # Optimized Model Architecture with KV-Cache
50
  # ============================================================================
51
 
52
  @keras.saving.register_keras_serializable()
53
  class RotaryEmbedding(keras.layers.Layer):
54
+ """Optimized RoPE with pre-computed cache."""
55
+
56
  def __init__(self, dim, max_len=2048, theta=10000, **kwargs):
57
  super().__init__(**kwargs)
58
  self.dim = dim
59
  self.max_len = max_len
60
  self.theta = theta
 
61
  self.cos_cached = None
62
  self.sin_cached = None
63
 
64
  def build(self, input_shape):
65
+ # Pre-compute RoPE cache during build
66
+ inv_freq = 1.0 / (self.theta ** (np.arange(0, self.dim, 2, dtype=np.float32) / self.dim))
67
+ t = np.arange(self.max_len, dtype=np.float32)
68
+ freqs = np.outer(t, inv_freq)
69
+ emb = np.concatenate([freqs, freqs], axis=-1)
70
+
71
+ # Store as non-trainable weights for better graph optimization
72
+ self.cos_cached = self.add_weight(
73
+ name="cos_cache",
74
+ shape=emb.shape,
75
+ initializer=keras.initializers.Constant(np.cos(emb)),
76
+ trainable=False
77
+ )
78
+ self.sin_cached = self.add_weight(
79
+ name="sin_cache",
80
+ shape=emb.shape,
81
+ initializer=keras.initializers.Constant(np.sin(emb)),
82
+ trainable=False
83
+ )
84
  super().build(input_shape)
85
 
86
+ @tf.function(reduce_retracing=True)
87
  def call(self, q, k, offset=0):
88
  """Apply rotary embeddings with position offset for KV-cache."""
 
89
  seq_len = tf.shape(q)[2]
90
  dtype = q.dtype
91
 
92
  cos = tf.cast(self.cos_cached[offset:offset + seq_len, :], dtype)[None, None, :, :]
93
  sin = tf.cast(self.sin_cached[offset:offset + seq_len, :], dtype)[None, None, :, :]
94
 
95
+ # Fused rotate_half operation
96
+ x1_q, x2_q = tf.split(q, 2, axis=-1)
97
+ x1_k, x2_k = tf.split(k, 2, axis=-1)
98
+
99
+ q_embed = (q * cos) + (tf.concat([-x2_q, x1_q], axis=-1) * sin)
100
+ k_embed = (k * cos) + (tf.concat([-x2_k, x1_k], axis=-1) * sin)
101
+
102
  return q_embed, k_embed
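# Single-head NumPy sketch (illustration only, not part of this commit) of what
# the precomputed cos/sin cache and the fused rotate_half above compute; the
# dim, max_len, and tensor sizes below are made up.
import numpy as np

def rope_tables(dim, max_len, theta=10000.0):
    inv_freq = 1.0 / (theta ** (np.arange(0, dim, 2, dtype=np.float32) / dim))
    freqs = np.outer(np.arange(max_len, dtype=np.float32), inv_freq)  # [max_len, dim/2]
    emb = np.concatenate([freqs, freqs], axis=-1)                     # [max_len, dim]
    return np.cos(emb), np.sin(emb)

def rotate_half(x):
    x1, x2 = np.split(x, 2, axis=-1)
    return np.concatenate([-x2, x1], axis=-1)

def apply_rope(x, cos, sin, offset=0):
    # offset is the number of already-cached positions, as in the KV-cache path
    c = cos[offset:offset + x.shape[0]]
    s = sin[offset:offset + x.shape[0]]
    return x * c + rotate_half(x) * s

cos, sin = rope_tables(dim=8, max_len=16)
q = np.random.randn(4, 8).astype(np.float32)    # 4 new positions, head_dim 8
print(apply_rope(q, cos, sin, offset=2).shape)  # (4, 8)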
103
 
104
  def get_config(self):
 
109
 
110
  @keras.saving.register_keras_serializable()
111
  class RMSNorm(keras.layers.Layer):
112
+ """Optimized RMSNorm."""
113
+
114
  def __init__(self, epsilon=1e-5, **kwargs):
115
  super().__init__(**kwargs)
116
  self.epsilon = epsilon
 
120
  self.scale = self.add_weight(name="scale", shape=(input_shape[-1],), initializer="ones")
121
  super().build(input_shape)
122
 
123
+ @tf.function(reduce_retracing=True)
124
  def call(self, x):
125
+ # Fused computation
126
  variance = tf.reduce_mean(tf.square(x), axis=-1, keepdims=True)
127
  return x * tf.math.rsqrt(variance + self.epsilon) * self.scale
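# NumPy cross-check (illustration only, not part of this commit) of the RMSNorm
# above: y = x / sqrt(mean(x**2) + eps) * scale, so with scale = 1 the output
# has roughly unit root-mean-square along the last axis. Shapes are made up.
import numpy as np

x = np.random.randn(2, 8).astype(np.float32)
eps = 1e-5
scale = np.ones(8, dtype=np.float32)
y = x / np.sqrt(np.mean(np.square(x), axis=-1, keepdims=True) + eps) * scale
print(np.sqrt(np.mean(np.square(y), axis=-1)))  # values close to 1.0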
128
 
 
134
 
135
  @keras.saving.register_keras_serializable()
136
  class TransformerBlock(keras.layers.Layer):
137
+ """Optimized transformer block with efficient attention."""
138
+
139
  def __init__(self, d_model, n_heads, ff_dim, dropout, max_len, rope_theta, layer_idx=0, **kwargs):
140
  super().__init__(**kwargs)
141
  self.d_model = d_model
 
146
  self.rope_theta = rope_theta
147
  self.head_dim = d_model // n_heads
148
  self.layer_idx = layer_idx
149
+ self.scale = 1.0 / np.sqrt(self.head_dim)
150
 
151
  def build(self, input_shape):
152
  self.pre_attn_norm = RMSNorm(name="pre_attn_norm")
153
  self.pre_ffn_norm = RMSNorm(name="pre_ffn_norm")
154
+
155
+ # Fused QKV projection for better memory access
156
+ self.qkv_proj = keras.layers.Dense(self.d_model * 3, use_bias=False, name="qkv_proj")
157
  self.out_proj = keras.layers.Dense(self.d_model, use_bias=False, name="o_proj")
158
+
159
  self.rope = RotaryEmbedding(self.head_dim, max_len=self.max_len, theta=self.rope_theta)
160
+
161
+ # Fused gate/up projection
162
+ self.gate_up_proj = keras.layers.Dense(self.ff_dim * 2, use_bias=False, name="gate_up_proj")
163
  self.down_proj = keras.layers.Dense(self.d_model, use_bias=False, name="down_proj")
164
+
165
  self.dropout = keras.layers.Dropout(self.dropout_rate)
166
  super().build(input_shape)
167
 
168
  def call(self, x, training=None, past_kv=None, use_cache=False):
169
  B = tf.shape(x)[0]
170
  T = tf.shape(x)[1]
 
171
 
172
  res = x
173
  y = self.pre_attn_norm(x)
174
 
175
+ # Fused QKV projection
176
+ qkv = self.qkv_proj(y)
177
+ qkv = tf.reshape(qkv, [B, T, 3, self.n_heads, self.head_dim])
178
+ qkv = tf.transpose(qkv, [2, 0, 3, 1, 4]) # [3, B, n_heads, T, head_dim]
179
+ q, k, v = qkv[0], qkv[1], qkv[2]
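# Shape-level sketch (illustration only, not part of this commit) of the fused
# QKV path above: one Dense of width 3*d_model replaces separate q/k/v
# projections. A checkpoint saved with separate q_proj/k_proj/v_proj kernels
# would need them concatenated (in q, k, v order) along the output axis to load
# into this layer. Dimensions below are made up.
import tensorflow as tf

B, T, d_model, n_heads = 2, 5, 16, 4
head_dim = d_model // n_heads
y = tf.random.normal([B, T, d_model])
qkv_proj = tf.keras.layers.Dense(3 * d_model, use_bias=False)

qkv = tf.reshape(qkv_proj(y), [B, T, 3, n_heads, head_dim])
qkv = tf.transpose(qkv, [2, 0, 3, 1, 4])  # [3, B, n_heads, T, head_dim]
q, k, v = qkv[0], qkv[1], qkv[2]
print(q.shape)                            # (2, 4, 5, 4)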
180
 
181
  # Determine position offset for RoPE
182
+ past_len = tf.shape(past_kv[0])[2] if past_kv is not None else 0
183
 
184
  # Apply RoPE with position offset
185
  q, k = self.rope(q, k, offset=past_len)
 
191
 
192
  new_kv = (k, v) if use_cache else None
193
 
194
+ # Scaled dot-product attention
195
  full_len = tf.shape(k)[2]
196
+ scores = tf.matmul(q, k, transpose_b=True) * self.scale
197
 
198
+ # Optimized causal mask
199
  q_positions = tf.range(past_len, past_len + T)
200
  k_positions = tf.range(full_len)
201
+ mask = tf.cast(q_positions[:, None] < k_positions[None, :], q.dtype) * -1e9
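# Toy sketch (illustration only, not part of this commit) of the offset causal
# mask above, with made-up lengths: 3 cached tokens plus 2 new queries. A query
# at absolute position p may attend to every key at position <= p.
import tensorflow as tf

past_len, T = 3, 2
full_len = past_len + T
q_positions = tf.range(past_len, past_len + T)  # [3, 4]
k_positions = tf.range(full_len)                # [0, 1, 2, 3, 4]
mask = tf.cast(q_positions[:, None] < k_positions[None, :], tf.float32) * -1e9
print(mask.numpy())  # position 3 blocks only key 4; position 4 blocks nothing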
 
202
  scores = scores + mask[None, None, :, :]
203
 
204
  attn = tf.nn.softmax(scores, axis=-1)
 
208
 
209
  x = res + self.dropout(self.out_proj(attn_out), training=training)
210
 
211
+ # Optimized FFN with fused gate/up
212
  res = x
213
  y = self.pre_ffn_norm(x)
214
+ gate_up = self.gate_up_proj(y)
215
+ gate, up = tf.split(gate_up, 2, axis=-1)
216
+ ffn = self.down_proj(tf.nn.silu(gate) * up)
217
  output = res + self.dropout(ffn, training=training)
218
 
219
  return output, new_kv
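# Toy sketch (illustration only, not part of this commit) of the fused SwiGLU
# feed-forward above: a single Dense emits gate and up halves, combined as
# silu(gate) * up and projected back to d_model. Equivalent to separate
# gate_proj/up_proj layers whose kernels are concatenated along the output
# axis. Sizes are made up.
import tensorflow as tf

d_model, ff_dim = 16, 64
x = tf.random.normal([2, 5, d_model])
gate_up_proj = tf.keras.layers.Dense(2 * ff_dim, use_bias=False)
down_proj = tf.keras.layers.Dense(d_model, use_bias=False)

gate, up = tf.split(gate_up_proj(x), 2, axis=-1)  # each [2, 5, ff_dim]
y = down_proj(tf.nn.silu(gate) * up)              # back to [2, 5, d_model]
print(y.shape)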
 
234
 
235
  @keras.saving.register_keras_serializable()
236
  class SAM1Model(keras.Model):
237
+ """Optimized SAM model with compiled inference."""
238
+
239
  def __init__(self, **kwargs):
240
  super().__init__()
241
  if 'config' in kwargs and isinstance(kwargs['config'], dict):
 
261
  ]
262
  self.norm = RMSNorm(name="final_norm")
263
  self.lm_head = keras.layers.Dense(self.cfg['vocab_size'], use_bias=False, name="lm_head")
264
+
265
+ self._compiled_prefill = None
266
+ self._compiled_decode = None
267
 
268
  def call(self, input_ids, training=None, past_kv=None, use_cache=False):
269
  x = self.embed(input_ids)
270
 
271
  new_past_kv = [] if use_cache else None
 
279
  logits = self.lm_head(self.norm(x))
280
  return logits, new_past_kv
281
 
282
+ @tf.function(reduce_retracing=True)
283
+ def prefill(self, input_ids):
284
+ """Compiled prefill for initial prompt processing."""
285
+ return self.call(input_ids, training=False, past_kv=None, use_cache=True)
286
+
287
+ @tf.function(reduce_retracing=True, input_signature=[
288
+ tf.TensorSpec(shape=[1, 1], dtype=tf.int32),
289
+ tf.TensorSpec(shape=[None], dtype=tf.variant) # For the list of KV tuples
290
+ ])
291
+ def decode_step(self, input_ids, past_kv):
292
+ """Compiled single-token decode step."""
293
+ return self.call(input_ids, training=False, past_kv=past_kv, use_cache=True)
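# Generation-loop sketch (illustration only, not part of this commit) showing
# how the two phases fit together: one prefill call builds the KV cache, then
# each step feeds a single token and reuses it. Greedy decoding is used here
# only to keep the example short; the app's generate_stream below samples with
# FastSampler instead.
import tensorflow as tf

def greedy_generate(model, prompt_ids, max_new_tokens=16):
    input_ids = tf.constant([prompt_ids], dtype=tf.int32)
    logits, past_kv = model(input_ids, training=False, use_cache=True)  # prefill
    out = list(prompt_ids)
    for _ in range(max_new_tokens):
        next_id = int(tf.argmax(logits[0, -1, :]))
        out.append(next_id)
        step = tf.constant([[next_id]], dtype=tf.int32)
        logits, past_kv = model(step, training=False,
                                past_kv=past_kv, use_cache=True)        # decode
    return out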
294
+
295
  def get_config(self):
296
  base_config = super().get_config()
297
  base_config['config'] = self.cfg
298
  return base_config
299
 
300
 
301
+ # ============================================================================
302
+ # Optimized Sampling Functions
303
+ # ============================================================================
304
+
305
+ @lru_cache(maxsize=128)
306
+ def get_top_k_mask(vocab_size, top_k):
307
+ """Cache top-k masks for common vocab sizes."""
308
+ return top_k
309
+
310
+
311
+ class FastSampler:
312
+ """Vectorized sampler for faster token selection."""
313
+
314
+ def __init__(self, vocab_size):
315
+ self.vocab_size = vocab_size
316
+ self.rng = np.random.default_rng()
317
+
318
+ def sample(self, logits, temperature, top_k, top_p, token_freq, repetition_penalty):
319
+ """Optimized sampling with vectorized operations."""
320
+ # Temperature scaling
321
+ if temperature != 1.0:
322
+ logits = logits / temperature
323
+
324
+ # Vectorized repetition penalty
325
+ if repetition_penalty != 1.0 and token_freq:
326
+ freq_tokens = np.array(list(token_freq.keys()), dtype=np.int32)
327
+ freq_values = np.array(list(token_freq.values()), dtype=np.float32)
328
+ valid_mask = freq_tokens < len(logits)
329
+ freq_tokens = freq_tokens[valid_mask]
330
+ freq_values = freq_values[valid_mask]
331
+ logits[freq_tokens] /= np.power(repetition_penalty, freq_values)
332
+
333
+ # Top-K filtering with partial sort
334
+ if 0 < top_k < len(logits):
335
+ top_k_indices = np.argpartition(logits, -top_k)[-top_k:]
336
+ top_k_logits = logits[top_k_indices]
337
+ else:
338
+ top_k_indices = np.arange(len(logits))
339
+ top_k_logits = logits
340
+
341
+ # Stable softmax
342
+ top_k_logits = top_k_logits - np.max(top_k_logits)
343
+ exp_logits = np.exp(top_k_logits)
344
+ top_k_probs = exp_logits / exp_logits.sum()
345
+
346
+ # Top-P (nucleus) filtering
347
+ if top_p < 1.0:
348
+ sorted_idx = np.argsort(top_k_probs)[::-1]
349
+ cumsum = np.cumsum(top_k_probs[sorted_idx])
350
+ cutoff = np.searchsorted(cumsum, top_p) + 1
351
+ nucleus_idx = sorted_idx[:cutoff]
352
+ nucleus_probs = top_k_probs[nucleus_idx]
353
+ nucleus_probs /= nucleus_probs.sum()
354
+ sampled = self.rng.choice(len(nucleus_probs), p=nucleus_probs)
355
+ return int(top_k_indices[nucleus_idx[sampled]])
356
+ else:
357
+ sampled = self.rng.choice(len(top_k_probs), p=top_k_probs)
358
+ return int(top_k_indices[sampled])
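# Usage sketch (illustration only, not part of this commit) for the FastSampler
# defined above, with dummy logits and a made-up token_freq mapping of
# token id -> count emitted so far.
import numpy as np

sampler = FastSampler(vocab_size=100)
logits = np.random.randn(100).astype(np.float32)
token_freq = {7: 3, 42: 1}
next_id = sampler.sample(logits, temperature=0.8, top_k=40, top_p=0.9,
                         token_freq=token_freq, repetition_penalty=1.1)
print(next_id)  # an int in [0, 100)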
359
+
360
+
361
  # --- Model and Tokenizer Loading ---
362
 
363
  config_path = hf_hub_download(MODEL_REPO, "config.json", cache_dir=CACHE_DIR)
 
403
  'n_heads': config['num_attention_heads'],
404
  'ff_mult': config['intermediate_size'] / config['hidden_size'],
405
  'max_len': config['max_position_embeddings'],
406
+ 'dropout': 0.0, # Disable dropout for inference
407
  'rope_theta': config['rope_theta']
408
  }
409
  model = SAM1Model(config=model_config)
 
437
  if model:
438
  print(f"✅ Model loaded: {config['num_hidden_layers']} layers, {config['vocab_size']} vocab")
439
 
440
+ # Initialize fast sampler
441
+ sampler = FastSampler(config['vocab_size'])
442
+
443
+ # Warm up with trace compilation
444
+ print("🔥 Warming up model and compiling traces...")
445
  warmup_input = tf.constant([[1, 2, 3, 4, 5]], dtype=tf.int32)
446
+
447
+ # Warm up prefill
448
+ for _ in range(3):
449
+ logits, past_kv = model(warmup_input, training=False, use_cache=True)
450
+
451
+ # Warm up decode step
452
+ single_token = tf.constant([[1]], dtype=tf.int32)
453
+ for _ in range(3):
454
+ logits, past_kv = model(single_token, training=False, past_kv=past_kv, use_cache=True)
455
+
456
+ print("✅ Model warmed up and traces compiled")
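# Timing sketch (illustration only, not part of this commit) of why the warmup
# loop above matters: the first call to a tf.function pays the tracing and XLA
# compilation cost, later calls reuse the compiled graph. The toy function
# below stands in for the model.
import time
import tensorflow as tf

@tf.function(jit_compile=True)
def f(x):
    return tf.matmul(x, x)

x = tf.random.normal([256, 256])
t0 = time.perf_counter(); f(x); t1 = time.perf_counter(); f(x); t2 = time.perf_counter()
print(f"first call (trace + compile): {t1 - t0:.3f}s, second call: {t2 - t1:.4f}s")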
457
 
458
  # ============================================================================
459
  # Optimized Inference Logic with KV-Cache
 
462
  stop_generation = False
463
 
464
 
465
  def generate_stream(
466
  prompt: str,
467
  max_tokens: int = 512,
 
494
 
495
  max_context = config['max_position_embeddings']
496
 
497
+ start_time = time.perf_counter() # More precise timing
498
 
499
  # === PREFILL PHASE ===
 
500
  if len(input_ids) > max_context - max_tokens:
501
  input_ids = input_ids[-(max_context - max_tokens):]
502
 
 
508
  yield f"Error during prefill: {e}"
509
  return
510
 
511
+ # Get logits for last position (avoid copy with indexing)
512
  next_token_logits = logits[0, -1, :].numpy()
513
 
514
+ prefill_time = time.perf_counter() - start_time
515
+ print(f"⚡ Prefill: {len(input_ids)} tokens in {prefill_time:.3f}s ({len(input_ids)/prefill_time:.1f} tok/s)")
516
 
517
  # === GENERATION LOOP ===
518
+ decode_start = time.perf_counter()
519
+
520
+ # Pre-compute constants
521
+ yield_interval = 1 # Yield every token for streaming
522
 
523
  for step in range(max_tokens):
524
  if stop_generation:
525
  yield generated_text + "\n\n*[Generation stopped]*"
526
  return
527
 
528
+ # Sample next token using optimized sampler
529
+ next_token_id = sampler.sample(
530
  next_token_logits, temperature, top_k, top_p, token_freq, repetition_penalty
531
  )
532
 
 
541
  token_text = tokenizer.decode([next_token_id])
542
  generated_text += token_text
543
  token_count += 1
544
+
545
+ if step % yield_interval == 0:
546
+ yield generated_text
547
 
548
  # === DECODE PHASE (single token, reuse cache) ===
549
  next_input = tf.constant([[next_token_id]], dtype=tf.int32)
 
556
 
557
  next_token_logits = logits[0, -1, :].numpy()
558
 
559
+ # Truncate cache if too long (less frequent check)
560
+ if step % 100 == 0:
561
+ current_len = past_kv[0][0].shape[2] if past_kv and past_kv[0] is not None else 0
562
+ if current_len > max_context:
563
+ trim_amount = current_len - max_context + 100
564
+ past_kv = [
565
+ (k[:, :, trim_amount:, :], v[:, :, trim_amount:, :])
566
+ for k, v in past_kv
567
+ ]
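# Shape sketch (illustration only, not part of this commit) of the cache trim
# above, with made-up sizes: each layer's (k, v) is [B, n_heads, seq_len,
# head_dim] and trimming drops the oldest positions along axis 2. Note that the
# RoPE offset in TransformerBlock.call() is derived from the cache length, so a
# trim also shifts the positions subsequent tokens are encoded at.
import tensorflow as tf

B, n_heads, seq_len, head_dim = 1, 4, 2100, 32
max_context = 2048
trim_amount = seq_len - max_context + 100  # 152, keeps a small buffer
past_kv = [(tf.zeros([B, n_heads, seq_len, head_dim]),
            tf.zeros([B, n_heads, seq_len, head_dim]))]
past_kv = [(k[:, :, trim_amount:, :], v[:, :, trim_amount:, :]) for k, v in past_kv]
print(past_kv[0][0].shape)  # (1, 4, 1948, 32)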
568
+
569
+ decode_time = time.perf_counter() - decode_start
570
+ total_time = time.perf_counter() - start_time
571
 
572
  if token_count > 0:
573
  decode_tps = token_count / decode_time if decode_time > 0 else 0
 
574
 
575
  stats = (
576
  f"\n\n*[Generated {token_count} tokens in {total_time:.1f}s "
577
+ f"(prefill: {prefill_time:.2f}s, decode: {decode_tps:.1f} tok/s)]*"
578
  )
579
 
580
  if not stop_generation:
 
589
 
590
  def format_chat_prompt(message: str, history: list, reasoning_enabled: bool) -> str:
591
  """Format message history and seed <think> if enabled."""
592
+ prompt_parts = []
593
+
594
  for user_msg, assistant_msg in history:
595
+ prompt_parts.append(f"<|im_start|>user\n{user_msg}<|im_end|>")
596
  if assistant_msg:
 
597
  clean_msg = assistant_msg.split("\n\n*[")[0]
598
+ prompt_parts.append(f"<|im_start|>assistant\n{clean_msg}<|im_end|>")
 
 
599
 
600
+ prompt_parts.append(f"<|im_start|>user\n{message}<|im_end|>")
601
+ prompt_parts.append("<|im_start|>assistant")
602
+
603
  if reasoning_enabled:
604
+ prompt_parts.append("<think>")
605
 
606
+ return "\n".join(prompt_parts)
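# Example (illustration only, not part of this commit) of the string the
# format_chat_prompt above produces for a made-up one-turn history with
# reasoning enabled.
history = [("Hi", "Hello! How can I help?")]
print(format_chat_prompt("What is RoPE?", history, reasoning_enabled=True))
# <|im_start|>user
# Hi<|im_end|>
# <|im_start|>assistant
# Hello! How can I help?<|im_end|>
# <|im_start|>user
# What is RoPE?<|im_end|>
# <|im_start|>assistant
# <think>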
607
 
608
 
609
  def chat_stream(
 
642
 
643
  display_response = partial_response
644
  if should_stop:
 
645
  stats_start = partial_response.find("\n\n*[")
646
  if stats_start > earliest_stop:
647
  display_response = partial_response[:earliest_stop] + partial_response[stats_start:]
 
827
  **Vocab:** {config['vocab_size']:,}
828
  **Layers:** {config['num_hidden_layers']}
829
  **Context:** {config['max_position_embeddings']:,} tokens
830
+ **Optimization:** KV-Cache + XLA JIT ⚡
831
  """)
832
 
833
  gr.Examples(
 
903
 
904
  if __name__ == "__main__":
905
  print("\n" + "=" * 60)
906
+ print("🚀 Starting Sam-large-2 Chat with Optimized Inference")
907
  print("=" * 60 + "\n")
908
  demo.queue(max_size=20)
909
  demo.launch(server_name="0.0.0.0", server_port=7860, share=False, show_error=True)