Keeby-smilyai committed
Commit d9e88a5 · verified · 1 Parent(s): 90b1095

Update app.py

Files changed (1)
  1. app.py +72 -51

app.py CHANGED
@@ -19,7 +19,6 @@ import json
  from tokenizers import Tokenizer
  import numpy as np
  import time
- from functools import lru_cache

  # Configure TF threading
  tf.config.threading.set_inter_op_parallelism_threads(NUM_CORES)
@@ -47,43 +46,40 @@ MODEL_REPO = "Smilyai-labs/Sam-large-2"
  CACHE_DIR = "./model_cache"

  # ============================================================================
- # Model Architecture - MUST MATCH CHECKPOINT STRUCTURE
+ # Model Architecture - MATCHES CHECKPOINT STRUCTURE
  # ============================================================================

  @keras.saving.register_keras_serializable()
  class RotaryEmbedding(keras.layers.Layer):
-     """RoPE with pre-computed cache (no trainable weights - compatible with checkpoint)."""
+     """RoPE with cache built during layer build phase."""

      def __init__(self, dim, max_len=2048, theta=10000, **kwargs):
          super().__init__(**kwargs)
          self.dim = dim
          self.max_len = max_len
          self.theta = theta
-         self.built_cache = False
-         self.cos_cached = None
-         self.sin_cached = None

      def build(self, input_shape):
+         # Pre-compute RoPE cache as numpy arrays during build
+         inv_freq = 1.0 / (self.theta ** (np.arange(0, self.dim, 2, dtype=np.float32) / self.dim))
+         t = np.arange(self.max_len, dtype=np.float32)
+         freqs = np.outer(t, inv_freq)
+         emb = np.concatenate([freqs, freqs], axis=-1)
+
+         # Store as numpy arrays - will be converted to tensors in call()
+         self._cos_cached = np.cos(emb).astype(np.float32)
+         self._sin_cached = np.sin(emb).astype(np.float32)
+
          super().build(input_shape)

-     def _build_cache(self):
-         if not self.built_cache:
-             inv_freq = 1.0 / (self.theta ** (np.arange(0, self.dim, 2, dtype=np.float32) / self.dim))
-             t = np.arange(self.max_len, dtype=np.float32)
-             freqs = np.outer(t, inv_freq)
-             emb = np.concatenate([freqs, freqs], axis=-1)
-             self.cos_cached = tf.constant(np.cos(emb), dtype=tf.float32)
-             self.sin_cached = tf.constant(np.sin(emb), dtype=tf.float32)
-             self.built_cache = True
-
      def call(self, q, k, offset=0):
          """Apply rotary embeddings with position offset for KV-cache."""
-         self._build_cache()
          seq_len = tf.shape(q)[2]
          dtype = q.dtype

-         cos = tf.cast(self.cos_cached[offset:offset + seq_len, :], dtype)[None, None, :, :]
-         sin = tf.cast(self.sin_cached[offset:offset + seq_len, :], dtype)[None, None, :, :]
+         # Slice the pre-computed values
+         cos = tf.cast(self._cos_cached[offset:offset + seq_len, :], dtype)[None, None, :, :]
+         sin = tf.cast(self._sin_cached[offset:offset + seq_len, :], dtype)[None, None, :, :]

          # Fused rotate_half
          x1_q, x2_q = tf.split(q, 2, axis=-1)
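For reference, a minimal NumPy sketch of the rotate-half RoPE math this hunk moves into build(): the cos/sin tables are precomputed once for max_len positions and then sliced at a position offset when a KV cache is in use. The dimensions and shapes below are illustrative, not values from app.py.

```python
import numpy as np

def build_rope_cache(dim, max_len, theta=10000.0):
    # Same cache construction as the layer: one frequency per pair of channels
    inv_freq = 1.0 / (theta ** (np.arange(0, dim, 2, dtype=np.float32) / dim))
    t = np.arange(max_len, dtype=np.float32)
    freqs = np.outer(t, inv_freq)                 # (max_len, dim/2)
    emb = np.concatenate([freqs, freqs], axis=-1) # (max_len, dim)
    return np.cos(emb), np.sin(emb)

def apply_rope(x, cos, sin, offset=0):
    # x: (batch, heads, seq, dim); slice the cache at the KV-cache offset
    seq = x.shape[2]
    c = cos[offset:offset + seq][None, None]
    s = sin[offset:offset + seq][None, None]
    x1, x2 = np.split(x, 2, axis=-1)
    rotated = np.concatenate([-x2, x1], axis=-1)  # rotate_half
    return x * c + rotated * s

cos, sin = build_rope_cache(dim=64, max_len=128)
q = np.random.randn(1, 4, 10, 64).astype(np.float32)
print(apply_rope(q, cos, sin, offset=5).shape)    # (1, 4, 10, 64)
```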
@@ -176,7 +172,10 @@ class TransformerBlock(keras.layers.Layer):
          v = tf.transpose(v, [0, 2, 1, 3])

          # Determine position offset for RoPE
-         past_len = tf.shape(past_kv[0])[2] if past_kv is not None else 0
+         if past_kv is not None:
+             past_len = tf.shape(past_kv[0])[2]
+         else:
+             past_len = 0

          # Apply RoPE with position offset
          q, k = self.rope(q, k, offset=past_len)
@@ -192,7 +191,7 @@
          full_len = tf.shape(k)[2]
          scores = tf.matmul(q, k, transpose_b=True) * self.scale

-         # Optimized causal mask
+         # Causal mask
          q_positions = tf.range(past_len, past_len + T)
          k_positions = tf.range(full_len)
          mask = tf.cast(q_positions[:, None] < k_positions[None, :], scores.dtype) * -1e9
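A small self-contained sketch of the offset-aware causal mask built above: with past_len cached tokens, the new queries start at position past_len, and each query may only attend to keys at or before its own position. The example sizes are made up.

```python
import numpy as np

past_len, T = 3, 2            # 3 cached tokens, 2 new query tokens
full_len = past_len + T       # keys cover cached + new tokens

q_positions = np.arange(past_len, past_len + T)  # positions of the new queries
k_positions = np.arange(full_len)                # positions of all keys
# A query must not attend to keys strictly in its future.
mask = (q_positions[:, None] < k_positions[None, :]).astype(np.float32) * -1e9
print(mask)  # first row masks only the last key; second row masks nothing
```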
@@ -288,10 +287,8 @@ class FastSampler:

      def sample(self, logits, temperature, top_k, top_p, token_freq, repetition_penalty):
          """Optimized sampling with vectorized operations."""
-         # Make a copy to avoid modifying original
          logits = logits.copy()

-         # Temperature scaling
          if temperature != 1.0:
              logits = logits / temperature
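FastSampler's full body is not part of this diff; as a rough illustration only, a plain-NumPy version of temperature, top-k and top-p sampling over a single logits vector might look like the following (the function and parameter names here are hypothetical, not the app's API).

```python
import numpy as np

def sample_token(logits, temperature=1.0, top_k=50, top_p=0.9, rng=None):
    rng = rng or np.random.default_rng()
    logits = logits.copy()
    if temperature != 1.0:
        logits = logits / temperature
    # Top-k: keep only the k largest logits
    if 0 < top_k < logits.shape[-1]:
        kth = np.sort(logits)[-top_k]
        logits[logits < kth] = -np.inf
    # Softmax
    probs = np.exp(logits - logits.max())
    probs /= probs.sum()
    # Top-p: smallest set of tokens whose cumulative mass reaches top_p
    order = np.argsort(-probs)
    cum = np.cumsum(probs[order])
    keep = order[:np.searchsorted(cum, top_p) + 1]
    p = probs[keep] / probs[keep].sum()
    return int(rng.choice(keep, p=p))

print(sample_token(np.random.randn(100).astype(np.float32)))
```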
@@ -415,43 +412,67 @@ if model:
      # Initialize fast sampler
      sampler = FastSampler(config['vocab_size'])

-     # Warm up with trace compilation
-     print("🔥 Warming up model and compiling traces...")
+     # Warm up the model (without tf.function first)
+     print("🔥 Warming up model...")
      warmup_input = tf.constant([[1, 2, 3, 4, 5]], dtype=tf.int32)

-     # Warm up prefill
-     for _ in range(3):
+     # Initial warmup to build all internal caches
+     for _ in range(2):
          logits, past_kv = model(warmup_input, training=False, use_cache=True)

-     # Warm up decode step
+     # Warmup decode step
      single_token = tf.constant([[1]], dtype=tf.int32)
-     for _ in range(3):
+     for _ in range(2):
          logits, past_kv = model(single_token, training=False, past_kv=past_kv, use_cache=True)

-     print("✅ Model warmed up and traces compiled")
+     print("✅ Model warmed up")
+

      # ============================================================================
-     # Compiled Inference Functions
+     # Inference wrapper class for clean tf.function usage
      # ============================================================================

-     # Create tf.function wrapped inference for speed
-     @tf.function(reduce_retracing=True)
-     def model_prefill(input_ids):
-         """Compiled prefill function."""
-         return model(input_ids, training=False, use_cache=True)
-
-     @tf.function(reduce_retracing=True)
-     def model_decode(input_ids, past_kv):
-         """Compiled single-token decode function."""
-         return model(input_ids, training=False, past_kv=past_kv, use_cache=True)
-
-     # Additional warmup for compiled functions
-     print("🔥 Compiling tf.function traces...")
-     _ = model_prefill(warmup_input)
-     _ = model_decode(single_token, past_kv)
-     print("✅ Compiled functions ready")
+     class InferenceEngine:
+         """Wrapper for compiled inference functions."""
+
+         def __init__(self, model):
+             self.model = model
+             self._prefill_fn = None
+             self._decode_fn = None
+
+         def prefill(self, input_ids):
+             """Run prefill (first call builds trace)."""
+             if self._prefill_fn is None:
+                 # First call - run eagerly to ensure all caches are built
+                 return self.model(input_ids, training=False, use_cache=True)
+             return self._prefill_fn(input_ids)
+
+         def decode(self, input_ids, past_kv):
+             """Run single-token decode."""
+             return self.model(input_ids, training=False, past_kv=past_kv, use_cache=True)
+
+         def compile_traces(self):
+             """Compile tf.function traces after warmup."""
+             print("🔥 Compiling optimized traces...")
+
+             @tf.function(reduce_retracing=True)
+             def prefill_fn(input_ids):
+                 return self.model(input_ids, training=False, use_cache=True)
+
+             self._prefill_fn = prefill_fn
+
+             # Trace with sample inputs
+             sample_input = tf.constant([[1, 2, 3, 4, 5]], dtype=tf.int32)
+             _ = self._prefill_fn(sample_input)
+
+             print("✅ Traces compiled")
+
+     # Create inference engine
+     engine = InferenceEngine(model)
+
+     # Compile traces after warmup
+     engine.compile_traces()


      # ============================================================================
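The hunk above replaces the module-level tf.function wrappers with an InferenceEngine that warms the model up eagerly before tracing. A toy sketch of that eager-warmup-then-trace pattern, using a stand-in Keras model rather than Sam-large-2:

```python
import tensorflow as tf

# Stand-in model; the real app uses the checkpointed transformer instead.
model = tf.keras.Sequential([tf.keras.layers.Embedding(100, 16),
                             tf.keras.layers.Dense(100)])

ids = tf.constant([[1, 2, 3]], dtype=tf.int32)

# 1) Eager warmup: builds the weights and any internal caches
_ = model(ids, training=False)

# 2) Only then compile a trace for the steady-state call
@tf.function(reduce_retracing=True)
def fast_forward(x):
    return model(x, training=False)

_ = fast_forward(ids)        # first call traces the graph
logits = fast_forward(ids)   # later calls reuse the compiled trace
print(logits.shape)          # (1, 3, 100)
```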
@@ -502,7 +523,7 @@ def generate_stream(
      input_tensor = tf.constant([input_ids], dtype=tf.int32)

      try:
-         logits, past_kv = model_prefill(input_tensor)
+         logits, past_kv = engine.prefill(input_tensor)
      except Exception as e:
          yield f"Error during prefill: {e}"
          return
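For orientation, a schematic of the prefill-then-decode streaming loop that generate_stream follows: the prompt is run once to seed the cache, then tokens are generated one at a time and the partial text is yielded after each step. The model call below is a stand-in, not the real engine from app.py.

```python
from typing import Iterator
import numpy as np

def fake_model_step(token_ids, past=None):
    # Stand-in for a model call: random logits over a 10-token vocab,
    # plus an opaque "cache" (here just a running token count).
    return np.random.randn(10), (past or 0) + len(token_ids)

def generate_stream_sketch(prompt_ids, max_new_tokens=5) -> Iterator[str]:
    logits, cache = fake_model_step(prompt_ids)             # prefill: whole prompt
    text = ""
    for _ in range(max_new_tokens):
        next_id = int(np.argmax(logits))                    # greedy for brevity
        text += f" tok{next_id}"
        yield text                                          # stream partial output
        logits, cache = fake_model_step([next_id], cache)   # decode: one new token

for partial in generate_stream_sketch([1, 2, 3]):
    print(partial)
```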
@@ -522,7 +543,7 @@
              yield generated_text + "\n\n*[Generation stopped]*"
              return

-         # Sample next token using optimized sampler
+         # Sample next token
          next_token_id = sampler.sample(
              next_token_logits, temperature, top_k, top_p, token_freq, repetition_penalty
          )
@@ -540,18 +561,18 @@ def generate_stream(
          token_count += 1
          yield generated_text

-         # === DECODE PHASE (single token, reuse cache) ===
+         # === DECODE PHASE ===
          next_input = tf.constant([[next_token_id]], dtype=tf.int32)

          try:
-             logits, past_kv = model_decode(next_input, past_kv)
+             logits, past_kv = engine.decode(next_input, past_kv)
          except Exception as e:
              yield generated_text + f"\n\n*[Error during generation: {e}]*"
              return

          next_token_logits = logits[0, -1, :].numpy()

-         # Truncate cache if too long (check less frequently)
+         # Truncate cache if too long
          if step % 100 == 99:
              current_len = past_kv[0][0].shape[2] if past_kv and past_kv[0] is not None else 0
              if current_len > max_context:
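A short sketch of the periodic cache-truncation check above; the cache layout (a list of (k, v) pairs shaped [batch, heads, seq, head_dim]) is inferred from past_kv[0][0].shape[2], and max_context here is an arbitrary example value.

```python
import tensorflow as tf

max_context = 8
past_kv = [(tf.zeros([1, 4, 12, 16]), tf.zeros([1, 4, 12, 16]))]

for step in range(300):
    # ... one decode step would append to past_kv here ...
    if step % 100 == 99:                       # only check every 100 steps
        current_len = past_kv[0][0].shape[2]
        if current_len > max_context:
            # Keep only the most recent max_context positions per layer
            past_kv = [(k[:, :, -max_context:, :], v[:, :, -max_context:, :])
                       for k, v in past_kv]

print(past_kv[0][0].shape)  # (1, 4, 8, 16)
```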
 