import os

# These must be set before TensorFlow/Keras are imported, otherwise they have
# no effect on backend selection / C++ log verbosity.
os.environ['KERAS_BACKEND'] = 'tensorflow'
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"

import gradio as gr
import tensorflow as tf
import keras
from huggingface_hub import hf_hub_download
import json
import numpy as np
from tokenizers import Tokenizer
import threading
import time
import queue
import hashlib
import sqlite3
from datetime import datetime
from dataclasses import dataclass, field
from typing import List, Dict, Optional
import uuid

# ==============================================================================
# GPU/CPU Optimization
# ==============================================================================
tf.config.threading.set_inter_op_parallelism_threads(2)
tf.config.threading.set_intra_op_parallelism_threads(4)
tf.config.optimizer.set_jit(True)


# ==============================================================================
# Database Setup
# ==============================================================================
def init_db(db_path='sam_tasks.db', admin_password=None):
    """Open (creating if needed) the SQLite task database and seed an admin user.

    Creates the ``users`` and ``tasks`` tables if they do not exist and inserts
    an ``admin`` account unless one already exists.

    Args:
        db_path: Path of the SQLite database file.
        admin_password: Password for the seeded ``admin`` account. When None,
            the ``SAM_ADMIN_PASSWORD`` environment variable is used, falling
            back to the historical default ``"admin123"``. Only applied when the
            admin row is first created; an existing admin password is untouched.

    Returns:
        An open ``sqlite3.Connection`` created with ``check_same_thread=False``
        so it can be shared across threads (callers serialize access with a lock).
    """
    conn = sqlite3.connect(db_path, check_same_thread=False)
    c = conn.cursor()
    c.execute('''CREATE TABLE IF NOT EXISTS users
                 (id INTEGER PRIMARY KEY AUTOINCREMENT,
                  username TEXT UNIQUE NOT NULL,
                  password_hash TEXT NOT NULL,
                  created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP)''')
    c.execute('''CREATE TABLE IF NOT EXISTS tasks
                 (id TEXT PRIMARY KEY,
                  user_id INTEGER,
                  model_name TEXT,
                  prompt TEXT,
                  status TEXT,
                  progress INTEGER DEFAULT 0,
                  result TEXT,
                  created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                  completed_at TIMESTAMP,
                  tokens_generated INTEGER DEFAULT 0,
                  tokens_per_sec REAL DEFAULT 0,
                  FOREIGN KEY (user_id) REFERENCES users(id))''')

    # NOTE(security): a well-known default admin password is a risk on any
    # publicly reachable deployment; set SAM_ADMIN_PASSWORD to override it.
    # The unsalted SHA-256 scheme must match the app's password hashing.
    if admin_password is None:
        admin_password = os.environ.get("SAM_ADMIN_PASSWORD", "admin123")
    admin_hash = hashlib.sha256(admin_password.encode()).hexdigest()
    # OR IGNORE: keep the existing admin row (UNIQUE username) on restarts.
    c.execute("INSERT OR IGNORE INTO users (username, password_hash) VALUES (?, ?)",
              ("admin", admin_hash))
    conn.commit()
    return conn


db_conn = init_db()
# Serializes every use of the shared connection across worker/UI threads.
db_lock = threading.Lock()
# ==============================================================================
# Model Architecture (Compact)
# ==============================================================================
@keras.saving.register_keras_serializable()
class RotaryEmbedding(keras.layers.Layer):
    """Rotary position embedding (RoPE) applied to query and key tensors.

    ``q``/``k`` are indexed with the sequence on axis 2, i.e. shaped
    (batch, heads, seq, dim) as produced by ``TransformerBlock``. The cos/sin
    tables cover ``max_len`` positions, so inputs longer than ``max_len`` are
    not supported.
    """

    def __init__(self, dim, max_len=2048, theta=10000, **kwargs):
        super().__init__(**kwargs)
        self.dim = dim
        self.max_len = max_len
        self.theta = theta
        self.built_cache = False

    def build(self, input_shape):
        super().build(input_shape)

    def _build_cache(self):
        """Lazily build the cos/sin tables (uses ``.numpy()``, so the first
        call must happen eagerly; the loaders below call each model eagerly
        once before any ``tf.function`` tracing)."""
        if not self.built_cache:
            inv_freq = 1.0 / (self.theta ** (tf.range(0, self.dim, 2, dtype=tf.float32) / self.dim))
            t = tf.range(self.max_len, dtype=tf.float32)
            freqs = tf.einsum("i,j->ij", t, inv_freq)
            emb = tf.concat([freqs, freqs], axis=-1)
            self.cos_cached = tf.constant(np.cos(emb.numpy()), dtype=tf.float32)
            self.sin_cached = tf.constant(np.sin(emb.numpy()), dtype=tf.float32)
            self.built_cache = True

    def rotate_half(self, x):
        x1, x2 = tf.split(x, 2, axis=-1)
        return tf.concat([-x2, x1], axis=-1)

    def call(self, q, k):
        self._build_cache()
        seq_len = tf.shape(q)[2]
        dtype = q.dtype
        cos = tf.cast(self.cos_cached[:seq_len, :], dtype)[None, None, :, :]
        sin = tf.cast(self.sin_cached[:seq_len, :], dtype)[None, None, :, :]
        q_rotated = (q * cos) + (self.rotate_half(q) * sin)
        k_rotated = (k * cos) + (self.rotate_half(k) * sin)
        return q_rotated, k_rotated

    def get_config(self):
        config = super().get_config()
        config.update({"dim": self.dim, "max_len": self.max_len, "theta": self.theta})
        return config


@keras.saving.register_keras_serializable()
class RMSNorm(keras.layers.Layer):
    """Root-mean-square norm: ``x * rsqrt(mean(x**2, -1) + eps) * scale``."""

    def __init__(self, epsilon=1e-5, **kwargs):
        super().__init__(**kwargs)
        self.epsilon = epsilon

    def build(self, input_shape):
        self.scale = self.add_weight(name="scale", shape=(input_shape[-1],), initializer="ones")

    def call(self, x):
        variance = tf.reduce_mean(tf.square(x), axis=-1, keepdims=True)
        return x * tf.math.rsqrt(variance + self.epsilon) * self.scale

    def get_config(self):
        config = super().get_config()
        config.update({"epsilon": self.epsilon})
        return config


@keras.saving.register_keras_serializable()
class TransformerBlock(keras.layers.Layer):
    """Pre-norm decoder block: causal RoPE self-attention + gated SiLU FFN.

    Layer attribute names and creation order must stay as-is: the checkpoint
    weights loaded below are matched against them.
    """

    def __init__(self, d_model, n_heads, ff_dim, dropout, max_len, rope_theta, layer_idx=0, **kwargs):
        super().__init__(**kwargs)
        self.d_model = d_model
        self.n_heads = n_heads
        self.ff_dim = ff_dim
        self.dropout_rate = dropout
        self.max_len = max_len
        self.rope_theta = rope_theta
        self.head_dim = d_model // n_heads
        self.layer_idx = layer_idx
        self.pre_attn_norm = RMSNorm()
        self.pre_ffn_norm = RMSNorm()
        self.q_proj = keras.layers.Dense(d_model, use_bias=False, name="q_proj")
        self.k_proj = keras.layers.Dense(d_model, use_bias=False, name="k_proj")
        self.v_proj = keras.layers.Dense(d_model, use_bias=False, name="v_proj")
        self.out_proj = keras.layers.Dense(d_model, use_bias=False, name="o_proj")
        self.rope = RotaryEmbedding(self.head_dim, max_len=max_len, theta=rope_theta)
        self.gate_proj = keras.layers.Dense(ff_dim, use_bias=False, name="gate_proj")
        self.up_proj = keras.layers.Dense(ff_dim, use_bias=False, name="up_proj")
        self.down_proj = keras.layers.Dense(d_model, use_bias=False, name="down_proj")
        self.dropout = keras.layers.Dropout(dropout)

    def call(self, x, training=None):
        B, T, D = tf.shape(x)[0], tf.shape(x)[1], self.d_model
        dtype = x.dtype

        # --- causal self-attention ---
        res = x
        y = self.pre_attn_norm(x)
        # (B, T, D) -> (B, heads, T, head_dim)
        q = tf.transpose(tf.reshape(self.q_proj(y), [B, T, self.n_heads, self.head_dim]), [0, 2, 1, 3])
        k = tf.transpose(tf.reshape(self.k_proj(y), [B, T, self.n_heads, self.head_dim]), [0, 2, 1, 3])
        v = tf.transpose(tf.reshape(self.v_proj(y), [B, T, self.n_heads, self.head_dim]), [0, 2, 1, 3])
        q, k = self.rope(q, k)
        scores = tf.matmul(q, k, transpose_b=True) / tf.sqrt(tf.cast(self.head_dim, dtype))
        # Lower-triangular mask: position i may only attend to positions <= i.
        mask = tf.where(
            tf.linalg.band_part(tf.ones([T, T], dtype=dtype), -1, 0) == 0,
            tf.constant(-1e9, dtype=dtype),
            tf.constant(0.0, dtype=dtype),
        )
        scores += mask
        attn = tf.matmul(tf.nn.softmax(scores, axis=-1), v)
        attn = tf.reshape(tf.transpose(attn, [0, 2, 1, 3]), [B, T, D])
        x = res + self.dropout(self.out_proj(attn), training=training)

        # --- gated feed-forward: down(silu(gate(y)) * up(y)) ---
        res = x
        y = self.pre_ffn_norm(x)
        ffn = self.down_proj(keras.activations.silu(self.gate_proj(y)) * self.up_proj(y))
        return res + self.dropout(ffn, training=training)

    def get_config(self):
        config = super().get_config()
        config.update({
            "d_model": self.d_model,
            "n_heads": self.n_heads,
            "ff_dim": self.ff_dim,
            "dropout": self.dropout_rate,
            "max_len": self.max_len,
            "rope_theta": self.rope_theta,
            "layer_idx": self.layer_idx,
        })
        return config


@keras.saving.register_keras_serializable()
class SAM1Model(keras.Model):
    """Decoder-only LM: embedding -> N TransformerBlocks -> RMSNorm -> lm_head.

    Accepts its hyper-parameters as ``config=<dict>``, as flat kwargs
    (detected by ``vocab_size``), or as ``cfg=<dict>``. ``call`` takes the full
    token sequence; there is no KV-cache input.
    """

    def __init__(self, **kwargs):
        super().__init__()
        if 'config' in kwargs and isinstance(kwargs['config'], dict):
            self.cfg = kwargs['config']
        elif 'vocab_size' in kwargs:
            self.cfg = kwargs
        else:
            self.cfg = kwargs.get('cfg', kwargs)

        self.embed = keras.layers.Embedding(self.cfg['vocab_size'], self.cfg['d_model'], name="embed_tokens")
        ff_dim = int(self.cfg['d_model'] * self.cfg['ff_mult'])
        block_args = {
            'd_model': self.cfg['d_model'],
            'n_heads': self.cfg['n_heads'],
            'ff_dim': ff_dim,
            'dropout': self.cfg['dropout'],
            'max_len': self.cfg['max_len'],
            'rope_theta': self.cfg['rope_theta'],
        }
        self.blocks = [
            TransformerBlock(name=f"block_{i}", layer_idx=i, **block_args)
            for i in range(self.cfg['n_layers'])
        ]
        self.norm = RMSNorm(name="final_norm")
        self.lm_head = keras.layers.Dense(self.cfg['vocab_size'], use_bias=False, name="lm_head")

    def call(self, input_ids, training=None):
        x = self.embed(input_ids)
        for block in self.blocks:
            x = block(x, training=training)
        return self.lm_head(self.norm(x))

    def get_config(self):
        base_config = super().get_config()
        base_config['config'] = self.cfg
        return base_config


# ==============================================================================
# KV Cache for SAM-Z
# ==============================================================================
@dataclass
class KVCache:
    """Per-layer key/value accumulator (concatenates along axis 2).

    NOTE: not wired into generation; ``SAM1Model.call`` has no cache input, so
    generation re-runs the full context each step.
    """

    k_cache: List[tf.Tensor] = field(default_factory=list)
    v_cache: List[tf.Tensor] = field(default_factory=list)

    def update(self, layer_idx: int, k: tf.Tensor, v: tf.Tensor):
        if layer_idx >= len(self.k_cache):
            self.k_cache.append(k)
            self.v_cache.append(v)
        else:
            self.k_cache[layer_idx] = tf.concat([self.k_cache[layer_idx], k], axis=2)
            self.v_cache[layer_idx] = tf.concat([self.v_cache[layer_idx], v], axis=2)
        return self.k_cache[layer_idx], self.v_cache[layer_idx]

    def clear(self):
        self.k_cache.clear()
        self.v_cache.clear()


# ==============================================================================
# Load Models
# ==============================================================================
def _to_model_cfg(hf_cfg):
    """Map a Hugging Face style config.json dict onto SAM1Model's config keys."""
    return {
        'vocab_size': hf_cfg['vocab_size'],
        'd_model': hf_cfg['hidden_size'],
        'n_layers': hf_cfg['num_hidden_layers'],
        'n_heads': hf_cfg['num_attention_heads'],
        'ff_mult': hf_cfg['intermediate_size'] / hf_cfg['hidden_size'],
        'max_len': hf_cfg['max_position_embeddings'],
        'dropout': 0.0,
        'rope_theta': hf_cfg['rope_theta'],
    }


def _load_sam_model(repo_id):
    """Download weights + config from ``repo_id`` and return (frozen model, model cfg)."""
    weights_path = hf_hub_download(repo_id, "ckpt.weights.h5")
    config_path = hf_hub_download(repo_id, "config.json")
    with open(config_path, 'r', encoding='utf-8') as f:
        model_cfg = _to_model_cfg(json.load(f))
    model = SAM1Model(config=model_cfg)
    # One eager call creates the weights (and the RoPE tables) before loading.
    _ = model(tf.zeros((1, 1), dtype=tf.int32))
    model.load_weights(weights_path)
    model.trainable = False
    return model, model_cfg


print("🚀 Loading SAM Models...")

# SAM-X-1 (Reasoning with thinking)
print("\n📦 Loading SAM-X-1-Large...")
samx_model, samx_model_cfg = _load_sam_model("Smilyai-labs/Sam-1x-instruct")


# reduce_retracing: the sequence length changes every decoding step; without
# it a new concrete function is traced for each length.
@tf.function(jit_compile=True, reduce_retracing=True)
def samx_predict(inputs):
    return samx_model(inputs, training=False)


print("✅ SAM-X-1 loaded")

# SAM-Z-1
print("\n📦 Loading SAM-Z-1...")
samz_model, samz_model_cfg = _load_sam_model("Smilyai-labs/Sam-Z-1-tensorflow")


@tf.function(jit_compile=True, reduce_retracing=True)
def samz_predict(inputs):
    return samz_model(inputs, training=False)


print("✅ SAM-Z-1 loaded")

# Tokenizer
tokenizer_path = hf_hub_download("Smilyai-labs/Sam-1x-instruct", "tokenizer.json")
tokenizer = Tokenizer.from_file(tokenizer_path)
eos_token_id = 50256  # hard-coded end-of-sequence id; presumably GPT-2 style vocab — confirm
print(f"✅ Tokenizer ready (vocab: {tokenizer.get_vocab_size()})")

# ==============================================================================
# Background Task Processing
# ==============================================================================
NUM_WORKERS = 2
MAX_ACTIVE_TASKS = 5

task_queue = queue.Queue()
# In-memory mirror of non-terminal task state; the DB is the source of truth.
active_tasks: Dict[str, Dict] = {}
task_lock = threading.Lock()


def create_task(user_id: int, model_name: str, prompt: str) -> str:
    """Persist a new 'queued' task, mirror it in memory, enqueue it; return its id."""
    task_id = str(uuid.uuid4())
    with db_lock:
        c = db_conn.cursor()
        c.execute("""INSERT INTO tasks (id, user_id, model_name, prompt, status)
                     VALUES (?, ?, ?, ?, ?)""",
                  (task_id, user_id, model_name, prompt, "queued"))
        db_conn.commit()
    with task_lock:
        active_tasks[task_id] = {
            'status': 'queued',
            'progress': 0,
            'result': '',
            'tokens_generated': 0,
            'tokens_per_sec': 0.0,
        }
    task_queue.put((task_id, user_id, model_name, prompt))
    return task_id


def update_task_status(task_id: str, status: str, progress: int = 0,
                       result: str = '', tokens: int = 0, tps: float = 0.0):
    """Write task state to memory and the DB; stamp completed_at on 'completed'.

    Terminal states ('completed', 'failed') drop the in-memory entry so
    ``active_tasks`` does not grow for the lifetime of the process.
    """
    with task_lock:
        if task_id in active_tasks:
            if status in ('completed', 'failed'):
                del active_tasks[task_id]
            else:
                active_tasks[task_id].update({
                    'status': status,
                    'progress': progress,
                    'result': result,
                    'tokens_generated': tokens,
                    'tokens_per_sec': tps,
                })
    with db_lock:
        c = db_conn.cursor()
        c.execute("""UPDATE tasks SET status=?, progress=?, result=?,
                     tokens_generated=?, tokens_per_sec=? WHERE id=?""",
                  (status, progress, result, tokens, tps, task_id))
        if status == 'completed':
            c.execute("UPDATE tasks SET completed_at=? WHERE id=?",
                      (datetime.now().isoformat(), task_id))
        db_conn.commit()


def _sample_next_token(logits, temperature, top_k=None):
    """Sample a token id from a 1-D logits vector.

    Softmax is computed in float64 and renormalized: a float32 softmax can
    miss the sum-to-1 tolerance of ``np.random.choice`` and raise.
    ``top_k`` (if given and smaller than the vocab) restricts sampling to the
    highest-scoring ids.
    """
    scaled = np.asarray(logits, dtype=np.float64) / temperature
    if top_k is not None and top_k < scaled.shape[-1]:
        candidates = np.argpartition(scaled, -top_k)[-top_k:]
    else:
        candidates = np.arange(scaled.shape[-1])
    cand = scaled[candidates]
    probs = np.exp(cand - cand.max())
    probs /= probs.sum()
    return int(candidates[np.random.choice(len(probs), p=probs)])


def _run_generation(predict_fn, context_len, prompt, task_id, max_tokens,
                    temperature, top_k, report_every):
    """Autoregressively sample up to ``max_tokens`` and report progress.

    Every step feeds the full context (capped to the last ``context_len``
    tokens, since the RoPE tables only cover that many positions).
    SAM1Model has no KV-cache path, so passing only the newest token would
    discard all context.

    Raises:
        ValueError: if the prompt encodes to no (non-EOS) tokens.
    """
    input_ids = [i for i in tokenizer.encode(prompt).ids if i != eos_token_id]
    if not input_ids:
        raise ValueError("Prompt produced no tokens")
    prompt_len = len(input_ids)
    generated = list(input_ids)
    start_time = time.time()

    def snapshot():
        new_tokens = generated[prompt_len:]
        elapsed = time.time() - start_time
        tps = len(new_tokens) / elapsed if elapsed > 0 else 0
        return tokenizer.decode(new_tokens), len(new_tokens), tps

    for step in range(max_tokens):
        window = generated[-context_len:]
        logits = predict_fn(tf.constant([window], dtype=tf.int32))
        next_token = _sample_next_token(logits[0, -1, :].numpy(), temperature, top_k)
        if next_token == eos_token_id:
            break
        generated.append(next_token)

        # Decode periodically (decoding every token would be wasteful).
        if step % report_every == 0 or step == max_tokens - 1:
            result, n_tokens, tps = snapshot()
            progress = int((step / max_tokens) * 100)
            update_task_status(task_id, 'processing', progress, result, n_tokens, tps)

    result, n_tokens, tps = snapshot()
    update_task_status(task_id, 'completed', 100, result, n_tokens, tps)


def generate_with_samx(prompt: str, task_id: str, max_tokens: int = 512):
    """SAM-X-1: reasoning model; temperature 0.7 sampling over the full vocab."""
    _run_generation(samx_predict, samx_model_cfg['max_len'], prompt, task_id,
                    max_tokens, temperature=0.7, top_k=None, report_every=10)


def generate_with_samz(prompt: str, task_id: str, max_tokens: int = 512):
    """SAM-Z-1: fast model; temperature 0.8 with top-40 sampling."""
    _run_generation(samz_predict, samz_model_cfg['max_len'], prompt, task_id,
                    max_tokens, temperature=0.8, top_k=40, report_every=15)


def task_worker():
    """Worker loop: take tasks off ``task_queue`` forever and run them.

    Any failure is recorded on the task; a failure while recording it is only
    logged, so a DB hiccup cannot kill the worker thread.
    """
    print("🔧 Task worker started")
    while True:
        task_id, user_id, model_name, prompt = task_queue.get()
        print(f"⚙️ Processing task {task_id[:8]}... ({model_name})")
        try:
            update_task_status(task_id, 'processing', 0)
            if 'SAM-X' in model_name or 'Large' in model_name:
                generate_with_samx(prompt, task_id)
            else:
                generate_with_samz(prompt, task_id)
            print(f"✅ Task {task_id[:8]} completed")
        except Exception as e:
            print(f"❌ Task {task_id[:8]} failed: {e}")
            try:
                update_task_status(task_id, 'failed', 0, f"Error: {str(e)}")
            except Exception as status_err:
                print(f"❌ Could not record failure for {task_id[:8]}: {status_err}")
        finally:
            task_queue.task_done()


def _fail_orphaned_tasks():
    """Mark tasks left 'queued'/'processing' by a previous run as failed.

    The queue lives in memory, so such tasks can never run; left alone they
    would count against the per-user active-task limit forever.
    """
    with db_lock:
        c = db_conn.cursor()
        c.execute("""UPDATE tasks SET status='failed', result=?
                     WHERE status IN ('queued', 'processing')""",
                  ("Error: server restarted before the task finished",))
        db_conn.commit()


_fail_orphaned_tasks()

# Start worker threads (daemon: they must not block interpreter exit).
for _ in range(NUM_WORKERS):
    worker = threading.Thread(target=task_worker, daemon=True)
    worker.start()


# ==============================================================================
# User Management
# ==============================================================================
def hash_password(password: str) -> str:
    """Return the hex SHA-256 of ``password``.

    NOTE(security): unsalted fast hash, kept for compatibility with stored
    hashes; migrating to a salted KDF (e.g. hashlib.pbkdf2_hmac) is advisable.
    """
    return hashlib.sha256(password.encode()).hexdigest()


def create_user(username: str, password: str):
    """Insert a user; return ``(ok, message)``. Fails if the username exists."""
    with db_lock:
        try:
            c = db_conn.cursor()
            c.execute("INSERT INTO users (username, password_hash) VALUES (?, ?)",
                      (username, hash_password(password)))
            db_conn.commit()
            return True, "Account created!"
        except sqlite3.IntegrityError:
            return False, "Username exists!"


def authenticate(username: str, password: str):
    """Return ``(True, user_id)`` if the credentials match, else ``(False, None)``."""
    import hmac  # constant-time comparison of password hashes

    with db_lock:
        c = db_conn.cursor()
        c.execute("SELECT id, password_hash FROM users WHERE username=?", (username,))
        row = c.fetchone()
    if row and hmac.compare_digest(row[1], hash_password(password)):
        return True, row[0]
    return False, None


def get_user_tasks(user_id: int):
    """Return the user's 50 most recent task rows (newest first)."""
    with db_lock:
        c = db_conn.cursor()
        c.execute("""SELECT id, model_name, prompt, status, progress,
                            tokens_generated, tokens_per_sec, created_at
                     FROM tasks WHERE user_id=?
                     ORDER BY created_at DESC LIMIT 50""", (user_id,))
        return c.fetchall()


def get_user_active_tasks(user_id: int):
    """Return how many of the user's tasks are 'queued' or 'processing'.

    Acquires ``db_lock`` (non-reentrant): never call while holding it.
    """
    with db_lock:
        c = db_conn.cursor()
        c.execute("""SELECT COUNT(*) FROM tasks
                     WHERE user_id=? AND status IN ('queued', 'processing')""",
                  (user_id,))
        return c.fetchone()[0]


# ==============================================================================
# Gradio UI
# ==============================================================================
css = """
.container { max-width: 1400px; margin: 0 auto; }
.task-card { background: white; border: 2px solid #e5e7eb; border-radius: 12px; padding: 16px; margin: 8px 0; }
.status-queued { color: #f59e0b; }
.status-processing { color: #3b82f6; }
.status-completed { color: #10b981; }
.status-failed { color: #ef4444; }
.progress-bar { height: 8px; background: #e5e7eb; border-radius: 4px; overflow: hidden; margin: 8px 0; }
.progress-fill { height: 100%; background: linear-gradient(90deg, #10b981, #059669); transition: width 0.3s; }
"""

with gr.Blocks(css=css, title="SAM Background Processor") as demo:
    user_id_state = gr.State(None)

    gr.Markdown("# 🚀 SAM Multi-Task Processor")
    gr.Markdown(f"Submit up to {MAX_ACTIVE_TASKS} background tasks. No need to stay on page!")

    # Auth
    with gr.Group(visible=True) as auth_group:
        gr.Markdown("### 🔐 Sign In / Sign Up")
        auth_username = gr.Textbox(label="Username", placeholder="username")
        auth_password = gr.Textbox(label="Password", type="password")
        auth_btn = gr.Button("Continue", variant="primary")
        auth_msg = gr.Markdown("")

    # Main UI (shown after sign-in)
    with gr.Group(visible=False) as main_group:
        with gr.Row():
            gr.Markdown("### 🤖 Create Task")
            user_display = gr.Markdown("")

        with gr.Row():
            with gr.Column(scale=2):
                model_choice = gr.Radio(
                    choices=["SAM-X-1-Large (Reasoning)", "SAM-Z-1 (Fast)"],
                    value="SAM-Z-1 (Fast)",
                    label="Model",
                )
                prompt_input = gr.Textbox(
                    label="Prompt",
                    placeholder="Enter your prompt...",
                    lines=4,
                )
                submit_btn = gr.Button("🚀 Submit Task", variant="primary", size="lg")
                task_msg = gr.Markdown("")

            with gr.Column(scale=1):
                gr.Markdown("### ℹ️ Info")
                gr.Markdown(f"""
- **SAM-X-1**: Reasoning model with thinking tags
- **SAM-Z-1**: Ultra-fast direct responses
- Max {MAX_ACTIVE_TASKS} concurrent tasks
- Results saved to database
- Background processing
""")

        gr.Markdown("---")

        with gr.Row():
            gr.Markdown("### 📋 Your Tasks")
            refresh_btn = gr.Button("🔄 Refresh", size="sm")

        tasks_display = gr.HTML("")
        auto_refresh = gr.Checkbox(label="Auto-refresh every 3 seconds", value=True)

        # View full result (expandable)
        with gr.Accordion("🔍 View Task Result", open=False):
            result_task_id = gr.Textbox(
                label="Task ID (first 8 chars)",
                placeholder="e.g., 3f7a9b2c",
            )
            view_result_btn = gr.Button("View Result", variant="primary")
            result_display = gr.Markdown("")

    # Auth handler: signs in, or auto-registers an unknown username.
    def handle_auth(username, password):
        username = username or ""
        password = password or ""
        if len(username) < 3 or len(password) < 6:
            return None, "❌ Invalid credentials", gr.update(), gr.update()

        success, user_id = authenticate(username, password)
        if not success:
            created, _msg = create_user(username, password)
            if created:
                success, user_id = authenticate(username, password)

        if success:
            return (
                user_id,
                f"✅ Welcome, **{username}**!",
                gr.update(visible=False),
                gr.update(visible=True),
            )
        return None, "❌ Authentication failed", gr.update(), gr.update()

    # Submit task: returns (status message, cleared prompt).
    def submit_task(user_id, model, prompt):
        if not user_id:
            return "❌ Please sign in", ""
        if not prompt or not prompt.strip():
            return "❌ Prompt required", ""
        active_count = get_user_active_tasks(user_id)
        if active_count >= MAX_ACTIVE_TASKS:
            return f"❌ Max {MAX_ACTIVE_TASKS} active tasks (you have {active_count})", ""
        task_id = create_task(user_id, model, prompt)
        return f"✅ Task submitted! ID: `{task_id[:8]}...`", ""

    # Render the user's tasks as HTML cards (user text is HTML-escaped).
    def render_tasks(user_id):
        import html as html_lib

        if not user_id:
            return ""
        tasks = get_user_tasks(user_id)
        if not tasks:
            return "<div class='task-card'>No tasks yet</div>"

        cards = []
        for task_id, model, prompt, status, progress, tokens, tps, _created in tasks:
            prompt = prompt or ""
            status = status or ""
            progress = progress or 0
            shown_prompt = prompt[:100] + ('...' if len(prompt) > 100 else '')
            status_class = html_lib.escape(f"status-{status}", quote=True)
            cards.append(f"""
<div class="task-card">
  <div><strong>Task:</strong> <code>{html_lib.escape(task_id[:8])}...</code>
    <span class="{status_class}">●{html_lib.escape(status.upper())}</span></div>
  <div><strong>Model:</strong> {html_lib.escape(model or '')}</div>
  <div><strong>Prompt:</strong> {html_lib.escape(shown_prompt)}</div>
  <div class="progress-bar"><div class="progress-fill" style="width: {int(progress)}%"></div></div>
  <div>Progress: {progress}% | Tokens: {tokens} | Speed: {(tps or 0):.1f} tok/s</div>
</div>
""")
        return "".join(cards)

    # Get the result of one of the user's tasks by id prefix.
    def get_task_result(user_id, task_id_short):
        prefix = (task_id_short or "").strip()
        if not user_id or not prefix:
            return "❌ Invalid request"
        # Escape LIKE wildcards so the input is matched literally as a prefix.
        pattern = prefix.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_") + "%"
        with db_lock:
            c = db_conn.cursor()
            c.execute("""SELECT result, status FROM tasks
                         WHERE user_id=? AND id LIKE ? ESCAPE '\\'
                         ORDER BY created_at DESC LIMIT 1""",
                      (user_id, pattern))
            row = c.fetchone()

        if not row:
            return "❌ Task not found"
        result, status = row
        if status == 'completed':
            return f"### ✅ Result\n\n{result}"
        if status == 'failed':
            return f"### ❌ Failed\n\n{result}"
        return f"### ⏳ Status: {status}"

    # Header line with username and active-task count.
    def update_user_display(user_id):
        if not user_id:
            return ""
        with db_lock:
            c = db_conn.cursor()
            c.execute("SELECT username FROM users WHERE id=?", (user_id,))
            row = c.fetchone()
        # Outside db_lock: get_user_active_tasks takes the (non-reentrant) lock.
        if row:
            active = get_user_active_tasks(user_id)
            return f"**User:** {row[0]} | **Active:** {active}/{MAX_ACTIVE_TASKS}"
        return ""

    # Event handlers
    auth_btn.click(
        handle_auth,
        [auth_username, auth_password],
        [user_id_state, auth_msg, auth_group, main_group],
    )

    submit_btn.click(
        submit_task,
        [user_id_state, model_choice, prompt_input],
        [task_msg, prompt_input],
    ).then(
        render_tasks,
        [user_id_state],
        [tasks_display],
    )

    refresh_btn.click(
        render_tasks,
        [user_id_state],
        [tasks_display],
    )

    # Periodic refresh (every 3 seconds) when the checkbox is ticked.
    refresh_timer = gr.Timer(3)

    def timer_refresh(user_id, auto_enabled):
        if auto_enabled and user_id:
            return render_tasks(user_id), update_user_display(user_id)
        return gr.update(), gr.update()

    refresh_timer.tick(
        timer_refresh,
        [user_id_state, auto_refresh],
        [tasks_display, user_display],
    )

    view_result_btn.click(
        get_task_result,
        [user_id_state, result_task_id],
        [result_display],
    )

    # Initial load after sign-in.
    def on_auth_success(user_id):
        if user_id:
            return render_tasks(user_id), update_user_display(user_id)
        return "", ""

    user_id_state.change(
        on_auth_success,
        [user_id_state],
        [tasks_display, user_display],
    )


if __name__ == "__main__":
    print("\n" + "=" * 80)
    print("🚀 SAM BACKGROUND PROCESSOR".center(80))
    print("=" * 80)
    print(f"✅ {NUM_WORKERS} worker threads active")
    print(f"✅ Max {MAX_ACTIVE_TASKS} tasks per user")
    print("✅ Background processing enabled")
    print("✅ Database: sam_tasks.db")
    print("=" * 80 + "\n")

    demo.queue(max_size=50)
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False,
        show_error=True,
    )