Update app.py
app.py
CHANGED
@@ -24,12 +24,7 @@ import time
 tf.config.threading.set_inter_op_parallelism_threads(NUM_CORES)
 tf.config.threading.set_intra_op_parallelism_threads(NUM_CORES)
 
-
-try:
-    tf.config.optimizer.set_jit(True)
-    print(f"✅ CPU optimized: {NUM_CORES} threads, oneDNN enabled, XLA JIT enabled")
-except:
-    print(f"✅ CPU optimized: {NUM_CORES} threads, oneDNN enabled")
 
 # ============================================================================
 # 🎄 FESTIVE MODE TOGGLE 🎄
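The removed block only wrapped `tf.config.optimizer.set_jit(True)` in a bare `try/except`. For reference, a minimal sketch of the same CPU setup with the optional XLA toggle kept; deriving `NUM_CORES` from `os.cpu_count()` is an assumption, not taken from the original file:

import os
import tensorflow as tf

NUM_CORES = os.cpu_count() or 2  # assumption: size the thread pools to the visible cores

# Cap TensorFlow's two thread pools so the CPU is not oversubscribed.
tf.config.threading.set_inter_op_parallelism_threads(NUM_CORES)
tf.config.threading.set_intra_op_parallelism_threads(NUM_CORES)

# XLA JIT is optional; guard it because not every TF build benefits on CPU.
try:
    tf.config.optimizer.set_jit(True)
    print(f"CPU optimized: {NUM_CORES} threads, XLA JIT enabled")
except Exception:
    print(f"CPU optimized: {NUM_CORES} threads")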
@@ -46,48 +41,48 @@ MODEL_REPO = "Smilyai-labs/Sam-large-2"
 CACHE_DIR = "./model_cache"
 
 # ============================================================================
-# Model Architecture
 # ============================================================================
 
 @keras.saving.register_keras_serializable()
 class RotaryEmbedding(keras.layers.Layer):
-    """RoPE with cache built during layer build phase."""
-
     def __init__(self, dim, max_len=2048, theta=10000, **kwargs):
         super().__init__(**kwargs)
         self.dim = dim
         self.max_len = max_len
         self.theta = theta
 
     def build(self, input_shape):
-        # Pre-compute RoPE cache as numpy arrays during build
-        inv_freq = 1.0 / (self.theta ** (np.arange(0, self.dim, 2, dtype=np.float32) / self.dim))
-        t = np.arange(self.max_len, dtype=np.float32)
-        freqs = np.outer(t, inv_freq)
-        emb = np.concatenate([freqs, freqs], axis=-1)
-
-        # Store as numpy arrays - will be converted to tensors in call()
-        self._cos_cached = np.cos(emb).astype(np.float32)
-        self._sin_cached = np.sin(emb).astype(np.float32)
-
         super().build(input_shape)
 
     def call(self, q, k, offset=0):
         """Apply rotary embeddings with position offset for KV-cache."""
         seq_len = tf.shape(q)[2]
         dtype = q.dtype
 
-        cos = tf.cast(self._cos_cached[offset:offset + seq_len, :], dtype)[None, None, :, :]
-        sin = tf.cast(self._sin_cached[offset:offset + seq_len, :], dtype)[None, None, :, :]
-
-        x1_q, x2_q = tf.split(q, 2, axis=-1)
-        x1_k, x2_k = tf.split(k, 2, axis=-1)
-
-        q_embed = (q * cos) + (tf.concat([-x2_q, x1_q], axis=-1) * sin)
-        k_embed = (k * cos) + (tf.concat([-x2_k, x1_k], axis=-1) * sin)
-
         return q_embed, k_embed
 
     def get_config(self):
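For intuition, the removed cache is the standard RoPE table of per-position cosines and sines. A small NumPy sketch with illustrative sizes (dim=8 and max_len=16 are placeholders, not values from the model config):

import numpy as np

dim, max_len, theta = 8, 16, 10000  # illustrative values

# Per-pair inverse frequencies, one per pair of channels.
inv_freq = 1.0 / (theta ** (np.arange(0, dim, 2, dtype=np.float32) / dim))
# Angle for every (position, frequency) pair, duplicated to cover both halves.
freqs = np.outer(np.arange(max_len, dtype=np.float32), inv_freq)
emb = np.concatenate([freqs, freqs], axis=-1)

cos_cached = np.cos(emb)  # shape [max_len, dim]
sin_cached = np.sin(emb)

# Rotating a vector x at position p: x * cos[p] + rotate_half(x) * sin[p]
def rotate_half(x):
    x1, x2 = np.split(x, 2, axis=-1)
    return np.concatenate([-x2, x1], axis=-1)

x = np.random.randn(dim).astype(np.float32)
p = 3
x_rot = x * cos_cached[p] + rotate_half(x) * sin_cached[p]
print(x_rot.shape)  # (8,)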
@@ -119,8 +114,6 @@ class RMSNorm(keras.layers.Layer):
 
 @keras.saving.register_keras_serializable()
 class TransformerBlock(keras.layers.Layer):
-    """Transformer block - MATCHES ORIGINAL CHECKPOINT STRUCTURE."""
-
     def __init__(self, d_model, n_heads, ff_dim, dropout, max_len, rope_theta, layer_idx=0, **kwargs):
         super().__init__(**kwargs)
         self.d_model = d_model
@@ -131,37 +124,38 @@ class TransformerBlock(keras.layers.Layer):
         self.rope_theta = rope_theta
         self.head_dim = d_model // n_heads
         self.layer_idx = layer_idx
-        self.scale = 1.0 / np.sqrt(self.head_dim)
 
     def build(self, input_shape):
-        # MUST use same layer names as checkpoint
         self.pre_attn_norm = RMSNorm(name="pre_attn_norm")
         self.pre_ffn_norm = RMSNorm(name="pre_ffn_norm")
-
-        # Separate Q, K, V projections (matches checkpoint)
         self.q_proj = keras.layers.Dense(self.d_model, use_bias=False, name="q_proj")
         self.k_proj = keras.layers.Dense(self.d_model, use_bias=False, name="k_proj")
         self.v_proj = keras.layers.Dense(self.d_model, use_bias=False, name="v_proj")
         self.out_proj = keras.layers.Dense(self.d_model, use_bias=False, name="o_proj")
-
         self.rope = RotaryEmbedding(self.head_dim, max_len=self.max_len, theta=self.rope_theta)
-
-        # Separate gate, up, down projections (matches checkpoint)
         self.gate_proj = keras.layers.Dense(self.ff_dim, use_bias=False, name="gate_proj")
         self.up_proj = keras.layers.Dense(self.ff_dim, use_bias=False, name="up_proj")
         self.down_proj = keras.layers.Dense(self.d_model, use_bias=False, name="down_proj")
-
         self.dropout = keras.layers.Dropout(self.dropout_rate)
         super().build(input_shape)
 
     def call(self, x, training=None, past_kv=None, use_cache=False):
         B = tf.shape(x)[0]
         T = tf.shape(x)[1]
 
         res = x
         y = self.pre_attn_norm(x)
 
-        #
         q = tf.reshape(self.q_proj(y), [B, T, self.n_heads, self.head_dim])
         q = tf.transpose(q, [0, 2, 1, 3])  # [B, n_heads, T, head_dim]
@@ -187,14 +181,15 @@ class TransformerBlock(keras.layers.Layer):
 
         new_kv = (k, v) if use_cache else None
 
-        #
         full_len = tf.shape(k)[2]
-        scores = tf.matmul(q, k, transpose_b=True)
 
         # Causal mask
         q_positions = tf.range(past_len, past_len + T)
         k_positions = tf.range(full_len)
-        mask = tf.cast(q_positions[:, None]
         scores = scores + mask[None, None, :, :]
 
         attn = tf.nn.softmax(scores, axis=-1)
|
|
| 204 |
|
| 205 |
x = res + self.dropout(self.out_proj(attn_out), training=training)
|
| 206 |
|
| 207 |
-
# FFN
|
| 208 |
res = x
|
| 209 |
y = self.pre_ffn_norm(x)
|
| 210 |
-
ffn = self.down_proj(
|
| 211 |
output = res + self.dropout(ffn, training=training)
|
| 212 |
|
| 213 |
return output, new_kv
|
|
@@ -255,6 +250,14 @@ class SAM1Model(keras.Model):
         self.lm_head = keras.layers.Dense(self.cfg['vocab_size'], use_bias=False, name="lm_head")
 
     def call(self, input_ids, training=None, past_kv=None, use_cache=False):
         x = self.embed(input_ids)
 
         new_past_kv = [] if use_cache else None
@@ -274,62 +277,6 @@ class SAM1Model(keras.Model):
         return base_config
 
 
-# ============================================================================
-# Optimized Sampling
-# ============================================================================
-
-class FastSampler:
-    """Vectorized sampler for faster token selection."""
-
-    def __init__(self, vocab_size):
-        self.vocab_size = vocab_size
-        self.rng = np.random.default_rng()
-
-    def sample(self, logits, temperature, top_k, top_p, token_freq, repetition_penalty):
-        """Optimized sampling with vectorized operations."""
-        logits = logits.copy()
-
-        if temperature != 1.0:
-            logits = logits / temperature
-
-        # Vectorized repetition penalty
-        if repetition_penalty != 1.0 and token_freq:
-            freq_tokens = np.array(list(token_freq.keys()), dtype=np.int32)
-            freq_values = np.array(list(token_freq.values()), dtype=np.float32)
-            valid_mask = freq_tokens < len(logits)
-            freq_tokens = freq_tokens[valid_mask]
-            freq_values = freq_values[valid_mask]
-            if len(freq_tokens) > 0:
-                logits[freq_tokens] /= np.power(repetition_penalty, freq_values)
-
-        # Top-K filtering with partial sort
-        if 0 < top_k < len(logits):
-            top_k_indices = np.argpartition(logits, -top_k)[-top_k:]
-            top_k_logits = logits[top_k_indices]
-        else:
-            top_k_indices = np.arange(len(logits))
-            top_k_logits = logits
-
-        # Stable softmax
-        top_k_logits = top_k_logits - np.max(top_k_logits)
-        exp_logits = np.exp(top_k_logits)
-        top_k_probs = exp_logits / exp_logits.sum()
-
-        # Top-P (nucleus) filtering
-        if top_p < 1.0:
-            sorted_idx = np.argsort(top_k_probs)[::-1]
-            cumsum = np.cumsum(top_k_probs[sorted_idx])
-            cutoff = np.searchsorted(cumsum, top_p) + 1
-            nucleus_idx = sorted_idx[:cutoff]
-            nucleus_probs = top_k_probs[nucleus_idx]
-            nucleus_probs /= nucleus_probs.sum()
-            sampled = self.rng.choice(len(nucleus_probs), p=nucleus_probs)
-            return int(top_k_indices[nucleus_idx[sampled]])
-        else:
-            sampled = self.rng.choice(len(top_k_probs), p=top_k_probs)
-            return int(top_k_indices[sampled])
-
-
 # --- Model and Tokenizer Loading ---
 
 config_path = hf_hub_download(MODEL_REPO, "config.json", cache_dir=CACHE_DIR)
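A standalone NumPy sketch of the top-k / top-p path that the removed FastSampler implements; the toy logits and hyperparameters below are illustrative only:

import numpy as np

rng = np.random.default_rng(0)
logits = rng.normal(size=50).astype(np.float32)  # toy vocabulary of 50 tokens
temperature, top_k, top_p = 0.8, 10, 0.9

scaled = logits / temperature

# Keep only the k largest logits (argpartition avoids a full sort).
top_k_idx = np.argpartition(scaled, -top_k)[-top_k:]
probs = np.exp(scaled[top_k_idx] - scaled[top_k_idx].max())
probs /= probs.sum()

# Nucleus: keep the smallest prefix of the sorted distribution with mass >= top_p.
order = np.argsort(probs)[::-1]
cutoff = np.searchsorted(np.cumsum(probs[order]), top_p) + 1
nucleus = order[:cutoff]
nucleus_probs = probs[nucleus] / probs[nucleus].sum()

token_id = int(top_k_idx[nucleus[rng.choice(len(nucleus_probs), p=nucleus_probs)]])
print(token_id)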
@@ -375,7 +322,7 @@ if use_checkpoint:
         'n_heads': config['num_attention_heads'],
         'ff_mult': config['intermediate_size'] / config['hidden_size'],
         'max_len': config['max_position_embeddings'],
-        'dropout': 0.
         'rope_theta': config['rope_theta']
     }
     model = SAM1Model(config=model_config)
@@ -409,77 +356,56 @@ else:
 if model:
     print(f"✅ Model loaded: {config['num_hidden_layers']} layers, {config['vocab_size']} vocab")
 
-    #
-    sampler = FastSampler(config['vocab_size'])
-
-    # Warm up the model (without tf.function first)
     print("🔥 Warming up model...")
     warmup_input = tf.constant([[1, 2, 3, 4, 5]], dtype=tf.int32)
-
-    # Initial warmup to build all internal caches
-    for _ in range(2):
-        logits, past_kv = model(warmup_input, training=False, use_cache=True)
-
-    # Warmup decode step
-    single_token = tf.constant([[1]], dtype=tf.int32)
-    for _ in range(2):
-        logits, past_kv = model(single_token, training=False, past_kv=past_kv, use_cache=True)
-
     print("✅ Model warmed up")
 
-
 # ============================================================================
-# Inference
 # ============================================================================
 
-class InferenceEngine:
-    """Wrapper for compiled inference functions."""
-
-    def __init__(self, model):
-        self.model = model
-        self._prefill_fn = None
-        self._decode_fn = None
-
-    def prefill(self, input_ids):
-        """Run prefill (first call builds trace)."""
-        if self._prefill_fn is None:
-            # First call - run eagerly to ensure all caches are built
-            return self.model(input_ids, training=False, use_cache=True)
-        return self._prefill_fn(input_ids)
-
-    def decode(self, input_ids, past_kv):
-        """Run single-token decode."""
-        return self.model(input_ids, training=False, past_kv=past_kv, use_cache=True)
-
-    def compile_traces(self):
-        """Compile tf.function traces after warmup."""
-        print("🔥 Compiling optimized traces...")
-
-        @tf.function(reduce_retracing=True)
-        def prefill_fn(input_ids):
-            return self.model(input_ids, training=False, use_cache=True)
-
-        self._prefill_fn = prefill_fn
-
-        # Trace with sample inputs
-        sample_input = tf.constant([[1, 2, 3, 4, 5]], dtype=tf.int32)
-        _ = self._prefill_fn(sample_input)
-
-        print("✅ Traces compiled")
-
-
-# Create inference engine
-engine = InferenceEngine(model)
 
-# Compile traces after warmup
-engine.compile_traces()
 
-#
-
-
 
-
 
 
 def generate_stream(
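A minimal sketch of the idea behind the removed InferenceEngine, assuming a model callable with the `model(ids, training=..., past_kv=..., use_cache=...)` signature used above; the class name and structure here are hypothetical:

import tensorflow as tf

class TracedPrefill:
    """Trace the prompt (prefill) pass once; keep single-token decode eager."""

    def __init__(self, model):
        self.model = model
        self._prefill_fn = tf.function(
            lambda ids: model(ids, training=False, use_cache=True),
            reduce_retracing=True,
        )

    def prefill(self, input_ids):
        # Compiled graph call; variable-length prompts may still retrace.
        return self._prefill_fn(input_ids)

    def decode(self, input_ids, past_kv):
        # Eager call: the growing KV-cache shapes would otherwise force retracing.
        return self.model(input_ids, training=False, past_kv=past_kv, use_cache=True)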
@@ -514,16 +440,17 @@ def generate_stream(
 
     max_context = config['max_position_embeddings']
 
-    start_time = time.
 
     # === PREFILL PHASE ===
     if len(input_ids) > max_context - max_tokens:
         input_ids = input_ids[-(max_context - max_tokens):]
 
     input_tensor = tf.constant([input_ids], dtype=tf.int32)
 
     try:
-        logits, past_kv =
     except Exception as e:
         yield f"Error during prefill: {e}"
         return
@@ -531,12 +458,11 @@ def generate_stream(
     # Get logits for last position
     next_token_logits = logits[0, -1, :].numpy()
 
-    prefill_time = time.
-
-    print(f"⚡ Prefill: {len(input_ids)} tokens in {prefill_time:.3f}s ({prefill_tps:.1f} tok/s)")
 
     # === GENERATION LOOP ===
-    decode_start = time.
 
     for step in range(max_tokens):
         if stop_generation:
@@ -544,7 +470,7 @@ def generate_stream(
             return
 
         # Sample next token
-        next_token_id =
             next_token_logits, temperature, top_k, top_p, token_freq, repetition_penalty
         )
 
@@ -561,11 +487,11 @@ def generate_stream(
         token_count += 1
         yield generated_text
 
-        # === DECODE PHASE ===
         next_input = tf.constant([[next_token_id]], dtype=tf.int32)
 
         try:
-            logits, past_kv =
         except Exception as e:
             yield generated_text + f"\n\n*[Error during generation: {e}]*"
             return
@@ -573,24 +499,24 @@ def generate_stream(
         next_token_logits = logits[0, -1, :].numpy()
 
         # Truncate cache if too long
-        if
-
-
-
-
-
-
-
-
-    total_time = time.perf_counter() - start_time
 
     if token_count > 0:
         decode_tps = token_count / decode_time if decode_time > 0 else 0
 
         stats = (
             f"\n\n*[Generated {token_count} tokens in {total_time:.1f}s "
-            f"(prefill: {prefill_time:.
         )
 
     if not stop_generation:
@@ -605,21 +531,20 @@ def generate_stream(
 
 def format_chat_prompt(message: str, history: list, reasoning_enabled: bool) -> str:
     """Format message history and seed <think> if enabled."""
-
-
     for user_msg, assistant_msg in history:
-
         if assistant_msg:
             clean_msg = assistant_msg.split("\n\n*[")[0]
-
 
-    prompt_parts.append(f"<|im_start|>user\n{message}<|im_end|>")
-    prompt_parts.append("<|im_start|>assistant")
-
     if reasoning_enabled:
-
 
-    return
 
 
 def chat_stream(
@@ -658,6 +583,7 @@ def chat_stream(
 
         display_response = partial_response
         if should_stop:
            stats_start = partial_response.find("\n\n*[")
            if stats_start > earliest_stop:
                display_response = partial_response[:earliest_stop] + partial_response[stats_start:]
@@ -843,7 +769,7 @@ with gr.Blocks(css=custom_css, theme=gr.themes.Soft()) as demo:
     **Vocab:** {config['vocab_size']:,}
     **Layers:** {config['num_hidden_layers']}
     **Context:** {config['max_position_embeddings']:,} tokens
-    **Optimization:** KV-Cache
     """)
 
     gr.Examples(
@@ -919,7 +845,7 @@ with gr.Blocks(css=custom_css, theme=gr.themes.Soft()) as demo:
 
 if __name__ == "__main__":
     print("\n" + "=" * 60)
-    print("🚀 Starting Sam-large-2 Chat with
     print("=" * 60 + "\n")
     demo.queue(max_size=20)
     demo.launch(server_name="0.0.0.0", server_port=7860, share=False, show_error=True)

 tf.config.threading.set_inter_op_parallelism_threads(NUM_CORES)
 tf.config.threading.set_intra_op_parallelism_threads(NUM_CORES)
 
+print(f"✅ CPU optimized: {NUM_CORES} threads, oneDNN enabled")
 
 # ============================================================================
 # 🎄 FESTIVE MODE TOGGLE 🎄
 CACHE_DIR = "./model_cache"
 
 # ============================================================================
+# Model Architecture Definitions (Optimized with KV-Cache)
 # ============================================================================
 
 @keras.saving.register_keras_serializable()
 class RotaryEmbedding(keras.layers.Layer):
     def __init__(self, dim, max_len=2048, theta=10000, **kwargs):
         super().__init__(**kwargs)
         self.dim = dim
         self.max_len = max_len
         self.theta = theta
+        self.built_cache = False
+        self.cos_cached = None
+        self.sin_cached = None
 
     def build(self, input_shape):
         super().build(input_shape)
 
+    def _build_cache(self):
+        if not self.built_cache:
+            inv_freq = 1.0 / (self.theta ** (tf.range(0, self.dim, 2, dtype=tf.float32) / self.dim))
+            t = tf.range(self.max_len, dtype=tf.float32)
+            freqs = tf.einsum("i,j->ij", t, inv_freq)
+            emb = tf.concat([freqs, freqs], axis=-1)
+            self.cos_cached = tf.constant(np.cos(emb.numpy()), dtype=tf.float32)
+            self.sin_cached = tf.constant(np.sin(emb.numpy()), dtype=tf.float32)
+            self.built_cache = True
+
+    def rotate_half(self, x):
+        x1, x2 = tf.split(x, 2, axis=-1)
+        return tf.concat([-x2, x1], axis=-1)
+
     def call(self, q, k, offset=0):
         """Apply rotary embeddings with position offset for KV-cache."""
+        self._build_cache()
         seq_len = tf.shape(q)[2]
         dtype = q.dtype
 
+        cos = tf.cast(self.cos_cached[offset:offset + seq_len, :], dtype)[None, None, :, :]
+        sin = tf.cast(self.sin_cached[offset:offset + seq_len, :], dtype)[None, None, :, :]
 
+        q_embed = (q * cos) + (self.rotate_half(q) * sin)
+        k_embed = (k * cos) + (self.rotate_half(k) * sin)
 
         return q_embed, k_embed
 
     def get_config(self):
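A quick usage sketch of the rewritten layer, assuming the `RotaryEmbedding` class above is in scope; the shapes (batch 1, 2 heads, head_dim 8) are illustrative:

import tensorflow as tf

rope = RotaryEmbedding(dim=8, max_len=64)

# q, k: [batch, n_heads, seq_len, head_dim]
q = tf.random.normal([1, 2, 5, 8])
k = tf.random.normal([1, 2, 5, 8])

# Prefill: positions 0..4
q_rot, k_rot = rope(q, k, offset=0)

# Cached decode step: one new token at absolute position 5
q_new = tf.random.normal([1, 2, 1, 8])
k_new = tf.random.normal([1, 2, 1, 8])
q_rot1, k_rot1 = rope(q_new, k_new, offset=5)
print(q_rot.shape, q_rot1.shape)  # (1, 2, 5, 8) (1, 2, 1, 8)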
 
 @keras.saving.register_keras_serializable()
 class TransformerBlock(keras.layers.Layer):
     def __init__(self, d_model, n_heads, ff_dim, dropout, max_len, rope_theta, layer_idx=0, **kwargs):
         super().__init__(**kwargs)
         self.d_model = d_model

         self.rope_theta = rope_theta
         self.head_dim = d_model // n_heads
         self.layer_idx = layer_idx
 
     def build(self, input_shape):
         self.pre_attn_norm = RMSNorm(name="pre_attn_norm")
         self.pre_ffn_norm = RMSNorm(name="pre_ffn_norm")
         self.q_proj = keras.layers.Dense(self.d_model, use_bias=False, name="q_proj")
         self.k_proj = keras.layers.Dense(self.d_model, use_bias=False, name="k_proj")
         self.v_proj = keras.layers.Dense(self.d_model, use_bias=False, name="v_proj")
         self.out_proj = keras.layers.Dense(self.d_model, use_bias=False, name="o_proj")
         self.rope = RotaryEmbedding(self.head_dim, max_len=self.max_len, theta=self.rope_theta)
         self.gate_proj = keras.layers.Dense(self.ff_dim, use_bias=False, name="gate_proj")
         self.up_proj = keras.layers.Dense(self.ff_dim, use_bias=False, name="up_proj")
         self.down_proj = keras.layers.Dense(self.d_model, use_bias=False, name="down_proj")
         self.dropout = keras.layers.Dropout(self.dropout_rate)
         super().build(input_shape)
 
     def call(self, x, training=None, past_kv=None, use_cache=False):
+        """
+        Args:
+            x: input tensor [B, T, D] (T=1 during cached generation)
+            past_kv: tuple of (past_k, past_v) each [B, n_heads, past_len, head_dim]
+            use_cache: whether to return updated kv cache
+        Returns:
+            output, (new_k, new_v) if use_cache else output, None
+        """
         B = tf.shape(x)[0]
         T = tf.shape(x)[1]
+        dtype = x.dtype
 
         res = x
         y = self.pre_attn_norm(x)
 
+        # Project Q, K, V for current input
         q = tf.reshape(self.q_proj(y), [B, T, self.n_heads, self.head_dim])
         q = tf.transpose(q, [0, 2, 1, 3])  # [B, n_heads, T, head_dim]
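The reshape/transpose pair above only splits the model dimension into heads; a tiny standalone sketch with illustrative sizes:

import tensorflow as tf

B, T, d_model, n_heads = 1, 4, 16, 2
head_dim = d_model // n_heads

y = tf.random.normal([B, T, d_model])          # stand-in for q_proj(y) in the layer
q = tf.reshape(y, [B, T, n_heads, head_dim])   # [B, T, n_heads, head_dim]
q = tf.transpose(q, [0, 2, 1, 3])              # [B, n_heads, T, head_dim]
print(q.shape)  # (1, 2, 4, 8)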
 
         new_kv = (k, v) if use_cache else None
 
+        # Attention
         full_len = tf.shape(k)[2]
+        scores = tf.matmul(q, k, transpose_b=True) / tf.sqrt(tf.cast(self.head_dim, dtype))
 
         # Causal mask
         q_positions = tf.range(past_len, past_len + T)
         k_positions = tf.range(full_len)
+        mask = tf.cast(q_positions[:, None] >= k_positions[None, :], dtype)
+        mask = tf.where(mask == 0, tf.constant(-1e9, dtype=dtype), tf.constant(0.0, dtype=dtype))
         scores = scores + mask[None, None, :, :]
 
         attn = tf.nn.softmax(scores, axis=-1)
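The mask compares absolute query positions (shifted by past_len) against key positions, so a single decoded token still attends to the whole cache. A NumPy sketch with illustrative values past_len=3, T=2:

import numpy as np

past_len, T = 3, 2
full_len = past_len + T

q_positions = np.arange(past_len, past_len + T)   # [3, 4]
k_positions = np.arange(full_len)                 # [0, 1, 2, 3, 4]

allowed = q_positions[:, None] >= k_positions[None, :]
mask = np.where(allowed, 0.0, -1e9).astype(np.float32)
print(mask)
# Row for q=3 masks only the future key k=4; row for q=4 attends to everything.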
 
         x = res + self.dropout(self.out_proj(attn_out), training=training)
 
+        # FFN
         res = x
         y = self.pre_ffn_norm(x)
+        ffn = self.down_proj(keras.activations.silu(self.gate_proj(y)) * self.up_proj(y))
         output = res + self.dropout(ffn, training=training)
 
         return output, new_kv
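The feed-forward line above is the SwiGLU form: silu(gate) multiplied elementwise by up, then projected back down. A minimal standalone sketch with illustrative sizes:

import tensorflow as tf
from tensorflow import keras

d_model, ff_dim = 16, 32  # illustrative
gate_proj = keras.layers.Dense(ff_dim, use_bias=False)
up_proj = keras.layers.Dense(ff_dim, use_bias=False)
down_proj = keras.layers.Dense(d_model, use_bias=False)

y = tf.random.normal([1, 4, d_model])
ffn = down_proj(keras.activations.silu(gate_proj(y)) * up_proj(y))
print(ffn.shape)  # (1, 4, 16)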
         self.lm_head = keras.layers.Dense(self.cfg['vocab_size'], use_bias=False, name="lm_head")
 
     def call(self, input_ids, training=None, past_kv=None, use_cache=False):
+        """
+        Args:
+            input_ids: [B, T]
+            past_kv: list of (k, v) tuples, one per layer
+            use_cache: whether to return updated cache
+        Returns:
+            logits, new_past_kv (or None)
+        """
         x = self.embed(input_ids)
 
         new_past_kv = [] if use_cache else None
         return base_config
 
 
 # --- Model and Tokenizer Loading ---
 
 config_path = hf_hub_download(MODEL_REPO, "config.json", cache_dir=CACHE_DIR)
         'n_heads': config['num_attention_heads'],
         'ff_mult': config['intermediate_size'] / config['hidden_size'],
         'max_len': config['max_position_embeddings'],
+        'dropout': 0.1,
         'rope_theta': config['rope_theta']
     }
     model = SAM1Model(config=model_config)
 if model:
     print(f"✅ Model loaded: {config['num_hidden_layers']} layers, {config['vocab_size']} vocab")
 
+    # Warm up the model
     print("🔥 Warming up model...")
     warmup_input = tf.constant([[1, 2, 3, 4, 5]], dtype=tf.int32)
+    _, _ = model(warmup_input, training=False, use_cache=True)
     print("✅ Model warmed up")
 
 # ============================================================================
+# Optimized Inference Logic with KV-Cache
 # ============================================================================
 
+stop_generation = False
 
+def sample_token(logits, temperature, top_k, top_p, token_freq, repetition_penalty):
+    """Pure NumPy sampling for speed."""
+    # Temperature scaling
+    scaled_logits = logits / temperature
 
+    # Repetition penalty
+    if repetition_penalty != 1.0:
+        for token_id, freq in token_freq.items():
+            if token_id < len(scaled_logits):
+                scaled_logits[token_id] /= (repetition_penalty ** freq)
 
+    # Top-K filtering
+    if top_k > 0 and top_k < len(scaled_logits):
+        top_k_indices = np.argpartition(scaled_logits, -top_k)[-top_k:]
+        top_k_logits = scaled_logits[top_k_indices]
+    else:
+        top_k_indices = np.arange(len(scaled_logits))
+        top_k_logits = scaled_logits
+
+    # Softmax (numerically stable)
+    top_k_logits = top_k_logits - np.max(top_k_logits)
+    top_k_probs = np.exp(top_k_logits)
+    top_k_probs /= top_k_probs.sum()
+
+    # Top-P (nucleus) filtering
+    if top_p < 1.0:
+        sorted_idx = np.argsort(top_k_probs)[::-1]
+        cumsum = np.cumsum(top_k_probs[sorted_idx])
+        cutoff = np.searchsorted(cumsum, top_p) + 1
+        nucleus_idx = sorted_idx[:cutoff]
+        nucleus_probs = top_k_probs[nucleus_idx]
+        nucleus_probs /= nucleus_probs.sum()
+        sampled = np.random.choice(len(nucleus_probs), p=nucleus_probs)
+        return int(top_k_indices[nucleus_idx[sampled]])
+    else:
+        sampled = np.random.choice(len(top_k_probs), p=top_k_probs)
+        return int(top_k_indices[sampled])
 
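A usage sketch of sample_token on toy logits; the values and the token_freq contents (ids of previously generated tokens mapped to their counts) are illustrative:

import numpy as np

rng = np.random.default_rng(0)
toy_logits = rng.normal(size=100).astype(np.float32)  # pretend vocabulary of 100 tokens

token_freq = {7: 2, 42: 1}  # token 7 was generated twice, token 42 once
next_id = sample_token(
    toy_logits,
    temperature=0.8,
    top_k=40,
    top_p=0.9,
    token_freq=token_freq,
    repetition_penalty=1.1,
)
print(next_id)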
 def generate_stream(
 
     max_context = config['max_position_embeddings']
 
+    start_time = time.time()
 
     # === PREFILL PHASE ===
+    # Truncate if prompt is too long
     if len(input_ids) > max_context - max_tokens:
         input_ids = input_ids[-(max_context - max_tokens):]
 
     input_tensor = tf.constant([input_ids], dtype=tf.int32)
 
     try:
+        logits, past_kv = model(input_tensor, training=False, use_cache=True)
     except Exception as e:
         yield f"Error during prefill: {e}"
         return
     # Get logits for last position
     next_token_logits = logits[0, -1, :].numpy()
 
+    prefill_time = time.time() - start_time
+    print(f"⚡ Prefill: {len(input_ids)} tokens in {prefill_time:.2f}s")
 
     # === GENERATION LOOP ===
+    decode_start = time.time()
 
     for step in range(max_tokens):
         if stop_generation:
             return
 
         # Sample next token
+        next_token_id = sample_token(
             next_token_logits, temperature, top_k, top_p, token_freq, repetition_penalty
         )
 
         token_count += 1
         yield generated_text
 
+        # === DECODE PHASE (single token, reuse cache) ===
         next_input = tf.constant([[next_token_id]], dtype=tf.int32)
 
         try:
+            logits, past_kv = model(next_input, training=False, past_kv=past_kv, use_cache=True)
         except Exception as e:
             yield generated_text + f"\n\n*[Error during generation: {e}]*"
             return
         next_token_logits = logits[0, -1, :].numpy()
 
         # Truncate cache if too long
+        current_len = past_kv[0][0].shape[2] if past_kv and past_kv[0] is not None else 0
+        if current_len > max_context:
+            trim_amount = current_len - max_context + 100  # Keep some buffer
+            past_kv = [
+                (k[:, :, trim_amount:, :], v[:, :, trim_amount:, :])
+                for k, v in past_kv
+            ]
+
+    decode_time = time.time() - decode_start
+    total_time = time.time() - start_time
 
     if token_count > 0:
         decode_tps = token_count / decode_time if decode_time > 0 else 0
+        total_tps = token_count / total_time if total_time > 0 else 0
 
         stats = (
             f"\n\n*[Generated {token_count} tokens in {total_time:.1f}s "
+            f"(prefill: {prefill_time:.1f}s, decode: {decode_tps:.1f} tok/s)]*"
         )
 
     if not stop_generation:
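The truncation above drops the oldest cached positions along the sequence axis of every layer's (k, v) pair. A small standalone sketch with illustrative shapes (2 layers, batch 1, 2 heads, 12 cached positions, head_dim 8):

import tensorflow as tf

past_kv = [
    (tf.zeros([1, 2, 12, 8]), tf.zeros([1, 2, 12, 8]))  # (k, v) per layer
    for _ in range(2)
]

max_keep = 10
current_len = past_kv[0][0].shape[2]
if current_len > max_keep:
    trim = current_len - max_keep
    past_kv = [(k[:, :, trim:, :], v[:, :, trim:, :]) for k, v in past_kv]

print(past_kv[0][0].shape)  # (1, 2, 10, 8)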
 
 def format_chat_prompt(message: str, history: list, reasoning_enabled: bool) -> str:
     """Format message history and seed <think> if enabled."""
+    prompt = ""
     for user_msg, assistant_msg in history:
+        prompt += f"<|im_start|>user\n{user_msg}<|im_end|>\n"
         if assistant_msg:
+            # Clean up any stats from previous messages
             clean_msg = assistant_msg.split("\n\n*[")[0]
+            prompt += f"<|im_start|>assistant\n{clean_msg}<|im_end|>\n"
+
+    prompt += f"<|im_start|>user\n{message}<|im_end|>\n<|im_start|>assistant\n"
 
     if reasoning_enabled:
+        prompt += "<think>"
 
+    return prompt
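For illustration, the prompt this function produces for a one-turn history (the messages are hypothetical):

history = [("Hi!", "Hello there.")]
print(format_chat_prompt("How are you?", history, reasoning_enabled=True))
# <|im_start|>user
# Hi!<|im_end|>
# <|im_start|>assistant
# Hello there.<|im_end|>
# <|im_start|>user
# How are you?<|im_end|>
# <|im_start|>assistant
# <think>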
 
 def chat_stream(
 
         display_response = partial_response
         if should_stop:
+            # Keep the stats portion if present
            stats_start = partial_response.find("\n\n*[")
            if stats_start > earliest_stop:
                display_response = partial_response[:earliest_stop] + partial_response[stats_start:]
     **Vocab:** {config['vocab_size']:,}
     **Layers:** {config['num_hidden_layers']}
     **Context:** {config['max_position_embeddings']:,} tokens
+    **Optimization:** KV-Cache enabled ⚡
     """)
 
     gr.Examples(
 
 if __name__ == "__main__":
     print("\n" + "=" * 60)
+    print("🚀 Starting Sam-large-2 Chat with KV-Cache Optimization")
     print("=" * 60 + "\n")
     demo.queue(max_size=20)
     demo.launch(server_name="0.0.0.0", server_port=7860, share=False, show_error=True)