Update app.py
app.py CHANGED
@@ -19,7 +19,6 @@ import json
 from tokenizers import Tokenizer
 import numpy as np
 import time
-from functools import lru_cache
 
 # Configure TF threading
 tf.config.threading.set_inter_op_parallelism_threads(NUM_CORES)
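For context, the threading call above is TensorFlow's knob for op-level parallelism. A minimal sketch of a CPU threading setup, assuming NUM_CORES comes from os.cpu_count() (its actual derivation sits outside this hunk):

```python
import os
import tensorflow as tf

# Hypothetical stand-in; app.py defines NUM_CORES elsewhere.
NUM_CORES = os.cpu_count() or 1

# inter-op: how many independent ops may run concurrently.
tf.config.threading.set_inter_op_parallelism_threads(NUM_CORES)
# intra-op: how many threads a single op (e.g. a large matmul) may use.
tf.config.threading.set_intra_op_parallelism_threads(NUM_CORES)
```

Note that TensorFlow requires both setters to run before the first op executes, which is why this configuration sits right after the imports.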
@@ -47,43 +46,40 @@ MODEL_REPO = "Smilyai-labs/Sam-large-2"
 CACHE_DIR = "./model_cache"
 
 # ============================================================================
-# Model Architecture -
+# Model Architecture - MATCHES CHECKPOINT STRUCTURE
 # ============================================================================
 
 @keras.saving.register_keras_serializable()
 class RotaryEmbedding(keras.layers.Layer):
-    """RoPE with
+    """RoPE with cache built during layer build phase."""
 
     def __init__(self, dim, max_len=2048, theta=10000, **kwargs):
         super().__init__(**kwargs)
         self.dim = dim
         self.max_len = max_len
         self.theta = theta
-        self.built_cache = False
-        self.cos_cached = None
-        self.sin_cached = None
 
     def build(self, input_shape):
+        # Pre-compute RoPE cache as numpy arrays during build
+        inv_freq = 1.0 / (self.theta ** (np.arange(0, self.dim, 2, dtype=np.float32) / self.dim))
+        t = np.arange(self.max_len, dtype=np.float32)
+        freqs = np.outer(t, inv_freq)
+        emb = np.concatenate([freqs, freqs], axis=-1)
+
+        # Store as numpy arrays - will be converted to tensors in call()
+        self._cos_cached = np.cos(emb).astype(np.float32)
+        self._sin_cached = np.sin(emb).astype(np.float32)
+
         super().build(input_shape)
 
-    def _build_cache(self):
-        if not self.built_cache:
-            inv_freq = 1.0 / (self.theta ** (np.arange(0, self.dim, 2, dtype=np.float32) / self.dim))
-            t = np.arange(self.max_len, dtype=np.float32)
-            freqs = np.outer(t, inv_freq)
-            emb = np.concatenate([freqs, freqs], axis=-1)
-            self.cos_cached = tf.constant(np.cos(emb), dtype=tf.float32)
-            self.sin_cached = tf.constant(np.sin(emb), dtype=tf.float32)
-            self.built_cache = True
-
     def call(self, q, k, offset=0):
         """Apply rotary embeddings with position offset for KV-cache."""
-        self._build_cache()
         seq_len = tf.shape(q)[2]
         dtype = q.dtype
 
-        cos = tf.cast(self.cos_cached[offset:offset + seq_len, :], dtype)[None, None, :, :]
-        sin = tf.cast(self.sin_cached[offset:offset + seq_len, :], dtype)[None, None, :, :]
+        # Slice the pre-computed values
+        cos = tf.cast(self._cos_cached[offset:offset + seq_len, :], dtype)[None, None, :, :]
+        sin = tf.cast(self._sin_cached[offset:offset + seq_len, :], dtype)[None, None, :, :]
 
         # Fused rotate_half
         x1_q, x2_q = tf.split(q, 2, axis=-1)
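As a sanity check on the rotary math this hunk moves into build(), here is a minimal self-contained numpy sketch of the same cache construction and the rotate-half application (names like `build_rope_cache` and `apply_rope` are illustrative, not from app.py; `dim` must be even):

```python
import numpy as np

def build_rope_cache(dim, max_len=16, theta=10000.0):
    # Same construction as RotaryEmbedding.build(): one frequency per
    # pair of channels, duplicated so cos/sin cover all `dim` channels.
    inv_freq = 1.0 / (theta ** (np.arange(0, dim, 2, dtype=np.float32) / dim))
    freqs = np.outer(np.arange(max_len, dtype=np.float32), inv_freq)
    emb = np.concatenate([freqs, freqs], axis=-1)   # (max_len, dim)
    return np.cos(emb), np.sin(emb)

def apply_rope(x, cos, sin, offset=0):
    # x: (seq, dim). Slice the cache at `offset`, exactly like call() does.
    seq = x.shape[0]
    c, s = cos[offset:offset + seq], sin[offset:offset + seq]
    x1, x2 = np.split(x, 2, axis=-1)
    rotated = np.concatenate([-x2, x1], axis=-1)    # "rotate_half"
    return x * c + rotated * s

cos, sin = build_rope_cache(dim=8)
q = np.ones((3, 8), dtype=np.float32)
# Decode-time usage: a single new token at absolute position 5.
q_step = apply_rope(q[:1], cos, sin, offset=5)
```

Storing the cache as plain numpy arrays in build() avoids creating eager tensors before the layer is fully constructed, which is what the old lazy `_build_cache()` path did on first call.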
@@ -176,7 +172,10 @@ class TransformerBlock(keras.layers.Layer):
         v = tf.transpose(v, [0, 2, 1, 3])
 
         # Determine position offset for RoPE
-
+        if past_kv is not None:
+            past_len = tf.shape(past_kv[0])[2]
+        else:
+            past_len = 0
 
         # Apply RoPE with position offset
         q, k = self.rope(q, k, offset=past_len)
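The offset logic matters because, with a KV-cache, the incoming q/k cover only the new tokens. A toy illustration of the shape bookkeeping, assuming the (batch, heads, seq, head_dim) cache layout used above:

```python
import tensorflow as tf

# Cached keys from a 7-token prefill: (batch, heads, seq, head_dim)
past_k = tf.zeros((1, 4, 7, 16))
past_kv_entry = (past_k, tf.zeros_like(past_k))

# One new decode token arrives; its absolute position is past_len = 7,
# so RoPE must be applied with offset=7, not offset=0.
past_len = tf.shape(past_kv_entry[0])[2]
print(int(past_len))  # 7
```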
@@ -192,7 +191,7 @@ class TransformerBlock(keras.layers.Layer):
         full_len = tf.shape(k)[2]
         scores = tf.matmul(q, k, transpose_b=True) * self.scale
 
-        #
+        # Causal mask
         q_positions = tf.range(past_len, past_len + T)
         k_positions = tf.range(full_len)
         mask = tf.cast(q_positions[:, None] < k_positions[None, :], scores.dtype) * -1e9
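A quick demonstration of why the mask uses absolute positions: during decode (past_len > 0, T = 1) the single query row may attend to every cached key, so its mask row is all zeros. A minimal sketch of the same comparison:

```python
import tensorflow as tf

def causal_mask(past_len, T, full_len):
    # A query at absolute position p may not attend to keys at positions > p.
    q_pos = tf.range(past_len, past_len + T)
    k_pos = tf.range(full_len)
    return tf.cast(q_pos[:, None] < k_pos[None, :], tf.float32) * -1e9

print(causal_mask(past_len=0, T=3, full_len=3).numpy())  # strictly upper-triangular -1e9
print(causal_mask(past_len=3, T=1, full_len=4).numpy())  # [[0. 0. 0. 0.]]
```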
@@ -288,10 +287,8 @@ class FastSampler:
 
     def sample(self, logits, temperature, top_k, top_p, token_freq, repetition_penalty):
         """Optimized sampling with vectorized operations."""
-        # Make a copy to avoid modifying original
         logits = logits.copy()
 
-        # Temperature scaling
         if temperature != 1.0:
             logits = logits / temperature
 
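For readers unfamiliar with the knobs in this signature: FastSampler's exact vectorization is not shown in this diff, but the following hedged numpy sketch shows one common way temperature, top-k, top-p, and a frequency-aware repetition penalty combine. Treat it as illustrative only:

```python
import numpy as np

def sample_token(logits, temperature=0.8, top_k=50, top_p=0.9,
                 token_freq=None, repetition_penalty=1.1,
                 rng=np.random.default_rng()):
    logits = logits.astype(np.float64).copy()

    # CTRL-style repetition penalty on tokens already generated.
    for tok in (token_freq or {}):
        logits[tok] = (logits[tok] / repetition_penalty if logits[tok] > 0
                       else logits[tok] * repetition_penalty)

    if temperature != 1.0:
        logits = logits / temperature

    # Top-k: drop everything below the k-th largest logit.
    if 0 < top_k < logits.size:
        cutoff = np.partition(logits, -top_k)[-top_k]
        logits[logits < cutoff] = -np.inf

    # Top-p (nucleus): drop tokens whose preceding cumulative mass
    # already exceeds top_p; the top token is always kept.
    probs = np.exp(logits - logits.max())
    probs /= probs.sum()
    order = np.argsort(probs)[::-1]
    cum = np.cumsum(probs[order])
    probs[order[(cum - probs[order]) > top_p]] = 0.0
    probs /= probs.sum()

    return int(rng.choice(probs.size, p=probs))

vocab_logits = np.random.default_rng(0).normal(size=100)
print(sample_token(vocab_logits, token_freq={3: 2, 7: 1}))
```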
@@ -415,43 +412,67 @@ if model:
     # Initialize fast sampler
     sampler = FastSampler(config['vocab_size'])
 
-    # Warm up
-    print("🔥 Warming up model
+    # Warm up the model (without tf.function first)
+    print("🔥 Warming up model...")
     warmup_input = tf.constant([[1, 2, 3, 4, 5]], dtype=tf.int32)
 
-    #
-    for _ in range(
+    # Initial warmup to build all internal caches
+    for _ in range(2):
         logits, past_kv = model(warmup_input, training=False, use_cache=True)
 
-    #
+    # Warmup decode step
     single_token = tf.constant([[1]], dtype=tf.int32)
-    for _ in range(
+    for _ in range(2):
         logits, past_kv = model(single_token, training=False, past_kv=past_kv, use_cache=True)
 
-    print("✅ Model warmed up
+    print("✅ Model warmed up")
+
 
     # ============================================================================
-    #
+    # Inference wrapper class for clean tf.function usage
    # ============================================================================
 
-
-
-
-    @tf.function(reduce_retracing=True)
-    def model_prefill(input_ids):
-        return model(input_ids, training=False, use_cache=True)
-
-    @tf.function(reduce_retracing=True)
-    def model_decode(input_ids, past_kv):
-        """Compiled single-token decode function."""
-        return model(input_ids, training=False, past_kv=past_kv, use_cache=True)
-
-    #
-
-    _ = model_prefill(warmup_input)
-    _ = model_decode(single_token, past_kv)
-    print("✅ Compiled functions ready")
+    class InferenceEngine:
+        """Wrapper for compiled inference functions."""
+
+        def __init__(self, model):
+            self.model = model
+            self._prefill_fn = None
+            self._decode_fn = None
+
+        def prefill(self, input_ids):
+            """Run prefill (first call builds trace)."""
+            if self._prefill_fn is None:
+                # First call - run eagerly to ensure all caches are built
+                return self.model(input_ids, training=False, use_cache=True)
+            return self._prefill_fn(input_ids)
+
+        def decode(self, input_ids, past_kv):
+            """Run single-token decode."""
+            return self.model(input_ids, training=False, past_kv=past_kv, use_cache=True)
+
+        def compile_traces(self):
+            """Compile tf.function traces after warmup."""
+            print("🔥 Compiling optimized traces...")
+
+            @tf.function(reduce_retracing=True)
+            def prefill_fn(input_ids):
+                return self.model(input_ids, training=False, use_cache=True)
+
+            self._prefill_fn = prefill_fn
+
+            # Trace with sample inputs
+            sample_input = tf.constant([[1, 2, 3, 4, 5]], dtype=tf.int32)
+            _ = self._prefill_fn(sample_input)
+
+            print("✅ Traces compiled")
 
+    # Create inference engine
+    engine = InferenceEngine(model)
+
+    # Compile traces after warmup
+    engine.compile_traces()
 
 
     # ============================================================================
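The eager-warmup-then-trace ordering above is the key idea of this change: build() side effects (like the RoPE cache) happen eagerly first, so the later tf.function trace captures a fully constructed model. A minimal standalone sketch of the pattern, with a stand-in Keras model (`toy_model` and `fast_call` are illustrative, not from app.py):

```python
import tensorflow as tf

toy_model = tf.keras.Sequential([tf.keras.layers.Dense(4)])

# 1) Eager warmup: runs build() and any lazy initialization.
x = tf.constant([[1.0, 2.0, 3.0]])
_ = toy_model(x, training=False)

# 2) Only then wrap in tf.function; reduce_retracing avoids a fresh
#    trace for every distinct input shape.
@tf.function(reduce_retracing=True)
def fast_call(inputs):
    return toy_model(inputs, training=False)

_ = fast_call(x)           # first call compiles the trace
print(fast_call(x).shape)  # subsequent calls reuse it
```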
@@ -502,7 +523,7 @@ def generate_stream(
     input_tensor = tf.constant([input_ids], dtype=tf.int32)
 
     try:
-        logits, past_kv = model_prefill(input_tensor)
+        logits, past_kv = engine.prefill(input_tensor)
     except Exception as e:
         yield f"Error during prefill: {e}"
         return
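Zooming out, generate_stream follows the standard prefill-then-decode loop: run the whole prompt through the model once, then feed back one token at a time while threading the cache through each step. A self-contained toy sketch of that control flow (ToyEngine is a hypothetical stand-in with the same prefill/decode contract; app.py's real engine returns TF tensors, not numpy arrays):

```python
import numpy as np

class ToyEngine:
    """Stand-in with the same prefill/decode contract (illustrative)."""
    def __init__(self, vocab=10):
        self.vocab = vocab
    def prefill(self, ids):
        rng = np.random.default_rng(sum(ids))
        return rng.normal(size=self.vocab), list(ids)   # (logits, cache)
    def decode(self, tok, cache):
        cache = cache + [tok]
        rng = np.random.default_rng(sum(cache))
        return rng.normal(size=self.vocab), cache

def stream(prompt_ids, engine, max_new_tokens=5):
    # Prefill once over the whole prompt, then decode token by token,
    # threading the cache through each step - same shape as app.py's loop.
    logits, cache = engine.prefill(prompt_ids)
    for _ in range(max_new_tokens):
        tok = int(np.argmax(logits))   # greedy stand-in for the sampler
        yield tok
        logits, cache = engine.decode(tok, cache)

print(list(stream([1, 2, 3], ToyEngine())))
```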
@@ -522,7 +543,7 @@ def generate_stream(
             yield generated_text + "\n\n*[Generation stopped]*"
             return
 
-        # Sample next token
+        # Sample next token
         next_token_id = sampler.sample(
             next_token_logits, temperature, top_k, top_p, token_freq, repetition_penalty
         )
@@ -540,18 +561,18 @@ def generate_stream(
         token_count += 1
         yield generated_text
 
-        # === DECODE PHASE
+        # === DECODE PHASE ===
         next_input = tf.constant([[next_token_id]], dtype=tf.int32)
 
         try:
-            logits, past_kv = model_decode(next_input, past_kv)
+            logits, past_kv = engine.decode(next_input, past_kv)
         except Exception as e:
             yield generated_text + f"\n\n*[Error during generation: {e}]*"
             return
 
         next_token_logits = logits[0, -1, :].numpy()
 
-        # Truncate cache if too long
+        # Truncate cache if too long
         if step % 100 == 99:
             current_len = past_kv[0][0].shape[2] if past_kv and past_kv[0] is not None else 0
             if current_len > max_context:
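The periodic truncation check above guards against unbounded cache growth. A small sketch of trimming a per-layer (k, v) cache list to the most recent max_context positions along the sequence axis (axis 2), assuming the layout used throughout this file (`truncate_kv` is illustrative, not from app.py):

```python
import tensorflow as tf

def truncate_kv(past_kv, max_context):
    # Keep only the last `max_context` positions of every layer's (k, v).
    trimmed = []
    for k, v in past_kv:
        trimmed.append((k[:, :, -max_context:, :], v[:, :, -max_context:, :]))
    return trimmed

# Toy cache: 2 layers, (batch=1, heads=4, seq=130, head_dim=16)
cache = [(tf.zeros((1, 4, 130, 16)), tf.zeros((1, 4, 130, 16)))] * 2
cache = truncate_kv(cache, max_context=128)
print(cache[0][0].shape)  # (1, 4, 128, 16)
```

Checking only every 100th step (`step % 100 == 99`) keeps the cost of the length inspection negligible relative to the decode itself.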