# Sam-Z-chat / app.py
# Source: Hugging Face Space file view — author Keeby-smilyai,
# last commit 3da6811 ("Update app.py", verified), 29.8 kB.
import os
os.environ['KERAS_BACKEND'] = 'tensorflow'
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
import gradio as gr
import tensorflow as tf
import keras
from huggingface_hub import hf_hub_download
import json
import numpy as np
from tokenizers import Tokenizer
import threading
import time
import queue
import hashlib
import sqlite3
from datetime import datetime
from dataclasses import dataclass, field
from typing import List, Dict, Optional
import uuid
# ==============================================================================
# GPU/CPU Optimization
# ==============================================================================
# Size TensorFlow's CPU thread pools before any op runs: the intra-op pool
# (threads used inside a single op) gets 4 threads, the inter-op pool
# (independent ops executed concurrently) gets 2.
tf.config.threading.set_intra_op_parallelism_threads(4)
tf.config.threading.set_inter_op_parallelism_threads(2)
# Turn on XLA auto-clustering globally.
tf.config.optimizer.set_jit(True)
# ==============================================================================
# Database Setup
# ==============================================================================
def init_db(db_path: str = 'sam_tasks.db') -> sqlite3.Connection:
    """Open the task database and make sure its schema and admin user exist.

    Creates the ``users`` and ``tasks`` tables when missing and seeds an
    ``admin`` account whose password hash is the hex SHA-256 of the
    ``SAM_ADMIN_PASSWORD`` environment variable. If that variable is unset,
    the historical default ``"admin123"`` is used. An existing admin row is
    left untouched.

    Args:
        db_path: SQLite database path (``":memory:"`` is handy for tests).

    Returns:
        A connection opened with ``check_same_thread=False`` so it can be
        shared between threads; callers are responsible for serializing
        access to it.
    """
    conn = sqlite3.connect(db_path, check_same_thread=False)
    c = conn.cursor()
    c.execute(
        "CREATE TABLE IF NOT EXISTS users\n"
        "(id INTEGER PRIMARY KEY AUTOINCREMENT,\n"
        "username TEXT UNIQUE NOT NULL,\n"
        "password_hash TEXT NOT NULL,\n"
        "created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP)"
    )
    c.execute(
        "CREATE TABLE IF NOT EXISTS tasks\n"
        "(id TEXT PRIMARY KEY,\n"
        "user_id INTEGER,\n"
        "model_name TEXT,\n"
        "prompt TEXT,\n"
        "status TEXT,\n"
        "progress INTEGER DEFAULT 0,\n"
        "result TEXT,\n"
        "created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,\n"
        "completed_at TIMESTAMP,\n"
        "tokens_generated INTEGER DEFAULT 0,\n"
        "tokens_per_sec REAL DEFAULT 0,\n"
        "FOREIGN KEY (user_id) REFERENCES users(id))"
    )
    # A publicly known default admin password is a security risk: set
    # SAM_ADMIN_PASSWORD in the deployment environment to override it.
    admin_password = os.environ.get("SAM_ADMIN_PASSWORD", "admin123")
    admin_pass = hashlib.sha256(admin_password.encode()).hexdigest()
    # OR IGNORE: on restart the admin row already exists (username is UNIQUE);
    # keep it, and whatever password it currently has, unchanged.
    c.execute(
        "INSERT OR IGNORE INTO users (username, password_hash) VALUES (?, ?)",
        ("admin", admin_pass),
    )
    conn.commit()
    return conn
# Process-wide lock intended to serialize every use of the shared SQLite
# connection (which is opened with check_same_thread=False).
db_lock = threading.Lock()
# The single shared connection to the task database.
db_conn = init_db()
# ==============================================================================
# Model Architecture (Compact)
# ==============================================================================
@keras.saving.register_keras_serializable()
class RotaryEmbedding(keras.layers.Layer):
    """Rotary position embedding (RoPE) applied to a query/key tensor pair.

    Has no trainable weights. The cos/sin tables for ``max_len`` positions are
    built lazily on the first ``call``.
    """

    def __init__(self, dim, max_len=2048, theta=10000, **kwargs):
        super().__init__(**kwargs)
        self.dim = dim  # size of the rotated (last) axis; TransformerBlock passes head_dim
        self.max_len = max_len  # number of positions covered by the cos/sin tables
        self.theta = theta  # base of the rotary frequencies
        self.built_cache = False

    def build(self, input_shape):
        super().build(input_shape)

    def _build_cache(self):
        # Computes the cos/sin tables once for positions 0..max_len-1.
        # NOTE: uses .numpy(), so the first call must run eagerly rather than
        # while a tf.function is being traced.
        if not self.built_cache:
            # Inverse frequencies theta ** (-2i / dim) for i = 0, 1, ...
            inv_freq = 1.0 / (self.theta ** (tf.range(0, self.dim, 2, dtype=tf.float32) / self.dim))
            t = tf.range(self.max_len, dtype=tf.float32)
            # Outer product position x frequency: (max_len, dim // 2) for even dim.
            freqs = tf.einsum("i,j->ij", t, inv_freq)
            # Duplicate the frequencies to width dim so they line up with rotate_half's two halves.
            emb = tf.concat([freqs, freqs], axis=-1)
            self.cos_cached = tf.constant(np.cos(emb.numpy()), dtype=tf.float32)
            self.sin_cached = tf.constant(np.sin(emb.numpy()), dtype=tf.float32)
            self.built_cache = True

    def rotate_half(self, x):
        # Split the last axis into halves (x1, x2) and return concat(-x2, x1).
        x1, x2 = tf.split(x, 2, axis=-1)
        return tf.concat([-x2, x1], axis=-1)

    def call(self, q, k):
        """Rotate ``q`` and ``k`` by their positions; returns ``(q_rotated, k_rotated)``.

        The sequence length is read from axis 2 and the tables are broadcast
        over axes 0 and 1, so q/k are presumably (batch, heads, seq, dim).
        The sequence length must not exceed ``max_len``: the sliced tables
        would otherwise be shorter than the input and fail to broadcast.
        """
        self._build_cache()
        seq_len = tf.shape(q)[2]
        dtype = q.dtype
        # Slice the first seq_len positions and add two leading broadcast axes.
        cos = tf.cast(self.cos_cached[:seq_len, :], dtype)[None, None, :, :]
        sin = tf.cast(self.sin_cached[:seq_len, :], dtype)[None, None, :, :]
        q_rotated = (q * cos) + (self.rotate_half(q) * sin)
        k_rotated = (k * cos) + (self.rotate_half(k) * sin)
        return q_rotated, k_rotated

    def get_config(self):
        # Serialize the constructor arguments so the layer can be re-created from config.
        config = super().get_config()
        config.update({"dim": self.dim, "max_len": self.max_len, "theta": self.theta})
        return config
@keras.saving.register_keras_serializable()
class RMSNorm(keras.layers.Layer):
    """Root-mean-square normalization over the last axis.

    Computes ``x * rsqrt(mean(x**2) + epsilon) * scale``. There is no mean
    subtraction and no bias. ``scale`` holds one value per feature and is
    initialized to ones.
    """

    def __init__(self, epsilon=1e-5, **kwargs):
        super().__init__(**kwargs)
        self.epsilon = epsilon  # added to the mean square before the rsqrt, for numerical safety

    def build(self, input_shape):
        # One learnable gain per feature (size of the last input axis).
        self.scale = self.add_weight(name="scale", shape=(input_shape[-1],), initializer="ones")

    def call(self, x):
        variance = tf.reduce_mean(tf.square(x), axis=-1, keepdims=True)
        return x * tf.math.rsqrt(variance + self.epsilon) * self.scale

    def get_config(self):
        config = super().get_config()
        config.update({"epsilon": self.epsilon})
        return config
@keras.saving.register_keras_serializable()
class TransformerBlock(keras.layers.Layer):
    """Pre-norm decoder block with two residual sub-layers.

    The first sub-layer is causal multi-head self-attention with rotary
    position embeddings. The second is a gated SiLU feed-forward network,
    ``down_proj(silu(gate_proj(y)) * up_proj(y))``.

    NOTE(review): the sub-layer attribute names and creation order presumably
    determine how the downloaded ``ckpt.weights.h5`` maps onto this layer. Do
    not rename or reorder them without re-checking checkpoint loading.
    """

    def __init__(self, d_model, n_heads, ff_dim, dropout, max_len, rope_theta, layer_idx=0, **kwargs):
        super().__init__(**kwargs)
        self.d_model = d_model
        self.n_heads = n_heads
        self.ff_dim = ff_dim
        self.dropout_rate = dropout
        self.max_len = max_len
        self.rope_theta = rope_theta
        # Integer division: presumably d_model is a multiple of n_heads (TODO confirm).
        self.head_dim = d_model // n_heads
        self.layer_idx = layer_idx  # stored and serialized; not used in call()
        self.pre_attn_norm = RMSNorm()
        self.pre_ffn_norm = RMSNorm()
        # Bias-free query/key/value/output projections.
        self.q_proj = keras.layers.Dense(d_model, use_bias=False, name="q_proj")
        self.k_proj = keras.layers.Dense(d_model, use_bias=False, name="k_proj")
        self.v_proj = keras.layers.Dense(d_model, use_bias=False, name="v_proj")
        self.out_proj = keras.layers.Dense(d_model, use_bias=False, name="o_proj")
        self.rope = RotaryEmbedding(self.head_dim, max_len=max_len, theta=rope_theta)
        # Gated feed-forward projections (see class docstring).
        self.gate_proj = keras.layers.Dense(ff_dim, use_bias=False, name="gate_proj")
        self.up_proj = keras.layers.Dense(ff_dim, use_bias=False, name="up_proj")
        self.down_proj = keras.layers.Dense(d_model, use_bias=False, name="down_proj")
        # A single dropout layer shared by both residual branches.
        self.dropout = keras.layers.Dropout(dropout)

    def call(self, x, training=None):
        """Apply the block. ``x`` is reshaped below as (batch, seq, d_model)."""
        B, T, D = tf.shape(x)[0], tf.shape(x)[1], self.d_model
        dtype = x.dtype
        # --- attention sub-layer ---
        res = x
        y = self.pre_attn_norm(x)
        # (B, T, D) -> (B, T, n_heads, head_dim) -> (B, n_heads, T, head_dim)
        q = tf.transpose(tf.reshape(self.q_proj(y), [B, T, self.n_heads, self.head_dim]), [0, 2, 1, 3])
        k = tf.transpose(tf.reshape(self.k_proj(y), [B, T, self.n_heads, self.head_dim]), [0, 2, 1, 3])
        v = tf.transpose(tf.reshape(self.v_proj(y), [B, T, self.n_heads, self.head_dim]), [0, 2, 1, 3])
        q, k = self.rope(q, k)
        # Scaled dot-product scores: (B, n_heads, T, T).
        scores = tf.matmul(q, k, transpose_b=True) / tf.sqrt(tf.cast(self.head_dim, dtype))
        # Causal mask: band_part keeps the lower triangle. Entries above the diagonal
        # get -1e9, so position i attends only to positions <= i.
        mask = tf.where(tf.linalg.band_part(tf.ones([T, T], dtype=dtype), -1, 0) == 0,
                        tf.constant(-1e9, dtype=dtype), tf.constant(0.0, dtype=dtype))
        scores += mask
        attn = tf.matmul(tf.nn.softmax(scores, axis=-1), v)
        # Merge heads back: (B, n_heads, T, head_dim) -> (B, T, D).
        attn = tf.reshape(tf.transpose(attn, [0, 2, 1, 3]), [B, T, D])
        x = res + self.dropout(self.out_proj(attn), training=training)
        # --- feed-forward sub-layer ---
        res = x
        y = self.pre_ffn_norm(x)
        ffn = self.down_proj(keras.activations.silu(self.gate_proj(y)) * self.up_proj(y))
        return res + self.dropout(ffn, training=training)

    def get_config(self):
        # Serialize every constructor argument (dropout is stored as dropout_rate).
        config = super().get_config()
        config.update({
            "d_model": self.d_model, "n_heads": self.n_heads, "ff_dim": self.ff_dim,
            "dropout": self.dropout_rate, "max_len": self.max_len,
            "rope_theta": self.rope_theta, "layer_idx": self.layer_idx
        })
        return config
@keras.saving.register_keras_serializable()
class SAM1Model(keras.Model):
    """Decoder-only language model.

    The forward pass is: token embedding, then ``n_layers`` TransformerBlocks,
    then a final RMSNorm, then a bias-free LM head producing ``vocab_size``
    logits per position.

    Hyper-parameters are accepted in three forms, checked in this order:
    ``SAM1Model(config={...})`` (the form ``get_config`` emits),
    ``SAM1Model(vocab_size=..., ...)``, and ``SAM1Model(cfg={...})``.
    Required keys: vocab_size, d_model, ff_mult, n_heads, dropout, max_len,
    rope_theta, n_layers.
    """

    def __init__(self, **kwargs):
        # NOTE(review): kwargs are not forwarded to keras.Model, so arguments such
        # as ``name`` passed here are ignored.
        super().__init__()
        if 'config' in kwargs and isinstance(kwargs['config'], dict):
            self.cfg = kwargs['config']
        elif 'vocab_size' in kwargs:
            self.cfg = kwargs
        else:
            # Falls back to kwargs itself when there is no 'cfg' key; the lookups
            # below would then raise KeyError.
            self.cfg = kwargs.get('cfg', kwargs)
        self.embed = keras.layers.Embedding(self.cfg['vocab_size'], self.cfg['d_model'], name="embed_tokens")
        # Feed-forward width = d_model * ff_mult, truncated to int.
        ff_dim = int(self.cfg['d_model'] * self.cfg['ff_mult'])
        block_args = {
            'd_model': self.cfg['d_model'], 'n_heads': self.cfg['n_heads'],
            'ff_dim': ff_dim, 'dropout': self.cfg['dropout'],
            'max_len': self.cfg['max_len'], 'rope_theta': self.cfg['rope_theta']
        }
        self.blocks = [TransformerBlock(name=f"block_{i}", layer_idx=i, **block_args)
                       for i in range(self.cfg['n_layers'])]
        self.norm = RMSNorm(name="final_norm")
        self.lm_head = keras.layers.Dense(self.cfg['vocab_size'], use_bias=False, name="lm_head")

    def call(self, input_ids, training=None):
        """Map token ids to logits over the vocabulary at every position."""
        x = self.embed(input_ids)
        for block in self.blocks:
            x = block(x, training=training)
        return self.lm_head(self.norm(x))

    def get_config(self):
        # Store the hyper-parameter dict under 'config', matching the first
        # branch of __init__.
        base_config = super().get_config()
        base_config['config'] = self.cfg
        return base_config
# ==============================================================================
# KV Cache for SAM-Z (Ultra-Fast)
# ==============================================================================
@dataclass
class KVCache:
    """Per-layer key/value tensor store (list index = layer index).

    NOTE(review): nothing in this file calls ``update``; SAM1Model.call takes
    no cache argument.
    """

    k_cache: List[tf.Tensor] = field(default_factory=list)  # keys, one entry per layer
    v_cache: List[tf.Tensor] = field(default_factory=list)  # values, one entry per layer

    def update(self, layer_idx: int, k: tf.Tensor, v: tf.Tensor):
        """Add ``k``/``v`` for ``layer_idx`` and return that layer's cached pair.

        A layer already present has the new tensors concatenated on axis 2.
        This is presumably the sequence axis of (batch, heads, seq, dim)
        tensors; TODO confirm against the caller. Otherwise the tensors are
        appended at the end of the lists regardless of ``layer_idx``, so
        layers are expected to appear first in order 0, 1, 2, ...
        """
        known_layers = len(self.k_cache)
        if layer_idx < known_layers:
            self.k_cache[layer_idx] = tf.concat([self.k_cache[layer_idx], k], axis=2)
            self.v_cache[layer_idx] = tf.concat([self.v_cache[layer_idx], v], axis=2)
        else:
            self.k_cache.append(k)
            self.v_cache.append(v)
        return self.k_cache[layer_idx], self.v_cache[layer_idx]

    def clear(self):
        """Drop every cached tensor, in place."""
        del self.k_cache[:]
        del self.v_cache[:]
# ==============================================================================
# Load Models
# ==============================================================================
print("🚀 Loading SAM Models...")
# SAM-X-1 (reasoning model that emits <think> sections)
print("\n📦 Loading SAM-X-1-Large...")
samx_weights = hf_hub_download("Smilyai-labs/Sam-1x-instruct", "ckpt.weights.h5")
samx_config_path = hf_hub_download("Smilyai-labs/Sam-1x-instruct", "config.json")
with open(samx_config_path, 'r', encoding='utf-8') as f:
    samx_cfg = json.load(f)
# Translate the Hugging Face style config.json keys into SAM1Model's keys.
samx_model_cfg = {
    'vocab_size': samx_cfg['vocab_size'],
    'd_model': samx_cfg['hidden_size'],
    'n_layers': samx_cfg['num_hidden_layers'],
    'n_heads': samx_cfg['num_attention_heads'],
    'ff_mult': samx_cfg['intermediate_size'] / samx_cfg['hidden_size'],
    'max_len': samx_cfg['max_position_embeddings'],
    'dropout': 0.0,
    'rope_theta': samx_cfg['rope_theta']
}
samx_model = SAM1Model(config=samx_model_cfg)
# One eager forward pass creates the variables (and the RoPE tables) so that
# load_weights has something to fill in.
dummy = tf.zeros((1, 1), dtype=tf.int32)
_ = samx_model(dummy)
samx_model.load_weights(samx_weights)
samx_model.trainable = False


# The generation loop feeds a sequence whose length grows every step. Without
# an input signature, tf.function retraces for each new shape, and with
# jit_compile=True every retrace also paid an XLA compilation. A single trace
# with a dynamic length axis avoids both.
@tf.function(input_signature=[tf.TensorSpec(shape=[1, None], dtype=tf.int32)])
def samx_predict(inputs):
    """Inference-mode forward pass of SAM-X-1 on a (1, seq) int32 batch; returns logits."""
    return samx_model(inputs, training=False)


print("✅ SAM-X-1 loaded")
# SAM-Z-1 (fast model; SAM1Model has no KV-cache path, so it is fed the full context)
print("\n📦 Loading SAM-Z-1...")
samz_weights = hf_hub_download("Smilyai-labs/Sam-Z-1-tensorflow", "ckpt.weights.h5")
samz_config_path = hf_hub_download("Smilyai-labs/Sam-Z-1-tensorflow", "config.json")
with open(samz_config_path, 'r', encoding='utf-8') as f:
    samz_cfg = json.load(f)
# Translate the Hugging Face style config.json keys into SAM1Model's keys.
samz_model_cfg = {
    'vocab_size': samz_cfg['vocab_size'],
    'd_model': samz_cfg['hidden_size'],
    'n_layers': samz_cfg['num_hidden_layers'],
    'n_heads': samz_cfg['num_attention_heads'],
    'ff_mult': samz_cfg['intermediate_size'] / samz_cfg['hidden_size'],
    'max_len': samz_cfg['max_position_embeddings'],
    'dropout': 0.0,
    'rope_theta': samz_cfg['rope_theta']
}
samz_model = SAM1Model(config=samz_model_cfg)
# One eager forward pass creates the variables before loading the checkpoint.
_ = samz_model(tf.zeros((1, 1), dtype=tf.int32))
samz_model.load_weights(samz_weights)
samz_model.trainable = False


# Single trace with a dynamic sequence axis. The growing generation context
# would otherwise retrace (and, with jit_compile, XLA-compile) for every new length.
@tf.function(input_signature=[tf.TensorSpec(shape=[1, None], dtype=tf.int32)])
def samz_predict(inputs):
    """Inference-mode forward pass of SAM-Z-1 on a (1, seq) int32 batch; returns logits."""
    return samz_model(inputs, training=False)


print("✅ SAM-Z-1 loaded")
# Tokenizer (shared by both models; downloaded from the SAM-X repository)
tokenizer_path = hf_hub_download("Smilyai-labs/Sam-1x-instruct", "tokenizer.json")
tokenizer = Tokenizer.from_file(tokenizer_path)
# End-of-sequence id: stripped from prompts and used to stop generation.
# NOTE(review): 50256 is GPT-2's <|endoftext|> id; confirm it matches this tokenizer.
eos_token_id = 50256
print(f"✅ Tokenizer ready (vocab: {tokenizer.get_vocab_size()})")
# ==============================================================================
# Background Task Processing
# ==============================================================================
# State shared between the Gradio handlers and the background workers.
task_lock = threading.Lock()  # guards active_tasks
active_tasks: Dict[str, Dict] = {}  # task_id -> latest in-memory status snapshot
task_queue = queue.Queue()  # FIFO of (task_id, user_id, model_name, prompt) work items
def create_task(user_id: int, model_name: str, prompt: str) -> str:
    """Register a new generation task and hand it to the workers.

    The task is written to the ``tasks`` table as ``"queued"``, mirrored in
    ``active_tasks`` with zeroed progress counters, and enqueued on
    ``task_queue``.

    Returns:
        The new task's UUID4 string.
    """
    new_id = str(uuid.uuid4())
    insert_sql = ("INSERT INTO tasks (id, user_id, model_name, prompt, status)\n"
                  "VALUES (?, ?, ?, ?, ?)")
    with db_lock:
        db_conn.cursor().execute(insert_sql, (new_id, user_id, model_name, prompt, "queued"))
        db_conn.commit()

    snapshot = dict(status='queued', progress=0, result='',
                    tokens_generated=0, tokens_per_sec=0.0)
    with task_lock:
        active_tasks[new_id] = snapshot

    task_queue.put((new_id, user_id, model_name, prompt))
    return new_id
def update_task_status(task_id: str, status: str, progress: int = 0,
                       result: str = '', tokens: int = 0, tps: float = 0.0):
    """Record a task's latest status in memory and in the ``tasks`` table.

    For non-terminal statuses the task's ``active_tasks`` entry, if present,
    is refreshed. When the task reaches a terminal status (``'completed'`` or
    ``'failed'``), the entry is removed instead: the database row holds the
    final state, and removal stops ``active_tasks`` from growing forever.
    ``'completed'`` also stamps ``completed_at`` with the local time in ISO
    format.
    """
    terminal_statuses = ('completed', 'failed')
    with task_lock:
        if status in terminal_statuses:
            active_tasks.pop(task_id, None)
        elif task_id in active_tasks:
            active_tasks[task_id].update({
                'status': status,
                'progress': progress,
                'result': result,
                'tokens_generated': tokens,
                'tokens_per_sec': tps
            })
    with db_lock:
        c = db_conn.cursor()
        c.execute("UPDATE tasks SET status=?, progress=?, result=?,\n"
                  "tokens_generated=?, tokens_per_sec=?\n"
                  "WHERE id=?",
                  (status, progress, result, tokens, tps, task_id))
        if status == 'completed':
            c.execute("UPDATE tasks SET completed_at=? WHERE id=?",
                      (datetime.now().isoformat(), task_id))
        db_conn.commit()
def generate_with_samx(prompt: str, task_id: str, max_tokens: int = 512):
    """SAM-X-1: Reasoning model with <think> tags.

    Samples up to ``max_tokens`` new tokens after ``prompt``, using
    temperature 0.7 over the full vocabulary, and stops early at
    ``eos_token_id``. Partial output is published through
    ``update_task_status`` every 10 steps. The final text is stored with
    status ``'completed'`` and progress 100.

    Raises:
        ValueError: if the prompt encodes to no tokens once EOS ids are
            removed, since there is then no position to predict from.
    """
    input_ids = [i for i in tokenizer.encode(prompt).ids if i != eos_token_id]
    if not input_ids:
        raise ValueError("Prompt produced no tokens")
    prompt_len = len(input_ids)
    # The RoPE tables cover only max_len positions, and a longer input fails to
    # broadcast against them. So feed at most the newest max_len tokens.
    context_window = samx_model_cfg['max_len']
    generated = input_ids.copy()
    start_time = time.time()

    def _publish(status, progress):
        # Decode everything generated so far and report it with throughput stats.
        new_tokens = generated[prompt_len:]
        elapsed = time.time() - start_time
        tps = len(new_tokens) / elapsed if elapsed > 0 else 0
        update_task_status(task_id, status, progress, tokenizer.decode(new_tokens),
                           len(new_tokens), tps)

    for step in range(max_tokens):
        context = generated[-context_window:]
        logits = samx_predict(tf.constant([context], dtype=tf.int32))
        # Temperature 0.7, then softmax in float64 so the probabilities sum to 1
        # within np.random.choice's tolerance.
        scaled = logits[0, -1, :].numpy().astype(np.float64) / 0.7
        probs = np.exp(scaled - scaled.max())
        probs /= probs.sum()
        next_token = int(np.random.choice(len(probs), p=probs))
        if next_token == eos_token_id:
            break
        generated.append(next_token)
        # Publish periodically: decoding and the DB write are not free.
        if step % 10 == 0 or step == max_tokens - 1:
            _publish('processing', int((step / max_tokens) * 100))

    _publish('completed', 100)
def generate_with_samz(prompt: str, task_id: str, max_tokens: int = 512):
    """SAM-Z-1: fast model with top-k (k=40) sampling at temperature 0.8.

    SAM1Model has no key/value-cache input, so every step re-feeds the whole
    context (trimmed to the model's ``max_len``). Feeding only the newest
    token would make each prediction ignore all earlier text. Partial output
    is published every 15 steps. The final text is stored with status
    ``'completed'`` and progress 100.

    Raises:
        ValueError: if the prompt encodes to no tokens once EOS ids are removed.
    """
    input_ids = [i for i in tokenizer.encode(prompt).ids if i != eos_token_id]
    if not input_ids:
        raise ValueError("Prompt produced no tokens")
    prompt_len = len(input_ids)
    # Stay within the RoPE tables (max_len positions).
    context_window = samz_model_cfg['max_len']
    generated = input_ids.copy()
    start_time = time.time()

    def _publish(status, progress):
        # Decode everything generated so far and report it with throughput stats.
        new_tokens = generated[prompt_len:]
        elapsed = time.time() - start_time
        tps = len(new_tokens) / elapsed if elapsed > 0 else 0
        update_task_status(task_id, status, progress, tokenizer.decode(new_tokens),
                           len(new_tokens), tps)

    for step in range(max_tokens):
        context = generated[-context_window:]
        logits = samz_predict(tf.constant([context], dtype=tf.int32))
        next_logits = logits[0, -1, :].numpy()
        # Top-k candidates; k is capped at the vocabulary size.
        top_k = min(40, next_logits.shape[0])
        candidates = np.argpartition(next_logits, -top_k)[-top_k:]
        # Temperature 0.8 and a float64 softmax over the candidates only.
        scaled = next_logits[candidates].astype(np.float64) / 0.8
        probs = np.exp(scaled - scaled.max())
        probs /= probs.sum()
        next_token = int(candidates[np.random.choice(top_k, p=probs)])
        if next_token == eos_token_id:
            break
        generated.append(next_token)
        if step % 15 == 0 or step == max_tokens - 1:
            _publish('processing', int((step / max_tokens) * 100))

    _publish('completed', 100)
def task_worker():
    """Background worker loop: take tasks from ``task_queue`` forever.

    A task goes to SAM-X-1 when its model name contains "SAM-X" or "Large",
    and to SAM-Z-1 otherwise. Any exception raised while processing a task is
    recorded as status ``'failed'``, and ``task_done`` is always called, so
    one bad task (or a database hiccup) cannot kill the worker thread.
    """
    print("🔧 Task worker started")
    while True:
        try:
            # The timeout keeps the loop responsive instead of blocking forever.
            task_id, user_id, model_name, prompt = task_queue.get(timeout=1)
        except queue.Empty:
            continue
        try:
            print(f"⚙️ Processing task {task_id[:8]}... ({model_name})")
            update_task_status(task_id, 'processing', 0)
            if 'SAM-X' in model_name or 'Large' in model_name:
                generate_with_samx(prompt, task_id)
            else:
                generate_with_samz(prompt, task_id)
            print(f"✅ Task {task_id[:8]} completed")
        except Exception as e:  # worker boundary: record the failure, keep the thread alive
            print(f"❌ Task {task_id[:8]} failed: {e}")
            try:
                update_task_status(task_id, 'failed', 0, f"Error: {str(e)}")
            except Exception as status_err:
                # Recording the failure itself failed (e.g. a DB error); log it and move on.
                print(f"❌ Could not record failure for task {task_id[:8]}: {status_err}")
        finally:
            task_queue.task_done()
# Launch two daemon worker threads so two queued tasks can run at once;
# as daemons they do not keep the process alive at shutdown.
for worker in [threading.Thread(target=task_worker, daemon=True) for _ in range(2)]:
    worker.start()
# ==============================================================================
# User Management
# ==============================================================================
def hash_password(password: str) -> str:
    """Return the hex SHA-256 digest of *password* (encoded as UTF-8).

    NOTE(review): unsalted SHA-256 is cheap to brute-force. Stored
    ``users.password_hash`` values depend on this exact format, so moving to
    a salted KDF (e.g. ``hashlib.pbkdf2_hmac``) needs a compatibility path
    for existing accounts.
    """
    digest = hashlib.sha256()
    digest.update(password.encode())
    return digest.hexdigest()
def create_user(username: str, password: str):
    """Insert a new user row with the hashed password.

    Returns:
        ``(True, "Account created!")`` on success, or
        ``(False, "Username exists!")`` if the UNIQUE username constraint fires.
    """
    with db_lock:
        try:
            db_conn.cursor().execute(
                "INSERT INTO users (username, password_hash) VALUES (?, ?)",
                (username, hash_password(password)),
            )
            db_conn.commit()
        except sqlite3.IntegrityError:
            return False, "Username exists!"
        return True, "Account created!"
def authenticate(username: str, password: str):
    """Check *username* / *password* against the ``users`` table.

    Returns:
        ``(True, user_id)`` when the password hash matches, otherwise
        ``(False, None)``. Unknown users also give ``(False, None)``.
    """
    import hmac  # function-scope: only needed for the constant-time comparison

    # Hash before taking the lock; hashing needs no database access.
    candidate = hash_password(password)
    with db_lock:
        c = db_conn.cursor()
        c.execute("SELECT id, password_hash FROM users WHERE username=?", (username,))
        result = c.fetchone()
    # compare_digest avoids leaking, through timing, how much of the hash matched.
    if result and hmac.compare_digest(result[1], candidate):
        return True, result[0]
    return False, None
def get_user_tasks(user_id: int, limit: int = 50):
    """Return up to *limit* of the user's tasks, newest first.

    Each row is ``(id, model_name, prompt, status, progress,
    tokens_generated, tokens_per_sec, created_at)``.
    """
    with db_lock:
        c = db_conn.cursor()
        # created_at defaults to CURRENT_TIMESTAMP (one-second resolution), so
        # tasks submitted within the same second tie. rowid (insertion order)
        # breaks the tie deterministically.
        c.execute(
            "SELECT id, model_name, prompt, status, progress, "
            "tokens_generated, tokens_per_sec, created_at "
            "FROM tasks WHERE user_id=? "
            "ORDER BY created_at DESC, rowid DESC LIMIT ?",
            (user_id, limit),
        )
        return c.fetchall()
def get_user_active_tasks(user_id: int):
    """Count the user's tasks whose status is still 'queued' or 'processing'."""
    count_sql = ("SELECT COUNT(*) FROM tasks\n"
                 "WHERE user_id=? AND status IN ('queued', 'processing')")
    with db_lock:
        (active,) = db_conn.cursor().execute(count_sql, (user_id,)).fetchone()
    return active
# ==============================================================================
# Gradio UI
# ==============================================================================
# Stylesheet injected into the Gradio page: layout container, task cards,
# per-status text colours, and the progress bar. It is assembled one line at
# a time; the resulting string begins and ends with a newline.
css = "\n".join([
    "",
    ".container { max-width: 1400px; margin: 0 auto; }",
    ".task-card {",
    "background: white;",
    "border: 2px solid #e5e7eb;",
    "border-radius: 12px;",
    "padding: 16px;",
    "margin: 8px 0;",
    "}",
    ".status-queued { color: #f59e0b; }",
    ".status-processing { color: #3b82f6; }",
    ".status-completed { color: #10b981; }",
    ".status-failed { color: #ef4444; }",
    ".progress-bar {",
    "height: 8px;",
    "background: #e5e7eb;",
    "border-radius: 4px;",
    "overflow: hidden;",
    "margin: 8px 0;",
    "}",
    ".progress-fill {",
    "height: 100%;",
    "background: linear-gradient(90deg, #10b981, #059669);",
    "transition: width 0.3s;",
    "}",
    "",
])
with gr.Blocks(css=css, title="SAM Background Processor") as demo:
    # Gradio UI. It has a sign-in panel, a task submission form, a task list
    # refreshed by a 3-second timer, and a lookup panel for one task's full result.
    user_id_state = gr.State(None)  # signed-in user's DB id; None until sign-in succeeds
    gr.Markdown("# 🚀 SAM Multi-Task Processor")
    gr.Markdown("Submit up to 5 background tasks. No need to stay on page!")

    # Auth
    with gr.Group(visible=True) as auth_group:
        gr.Markdown("### 🔐 Sign In / Sign Up")
        auth_username = gr.Textbox(label="Username", placeholder="username")
        auth_password = gr.Textbox(label="Password", type="password")
        auth_btn = gr.Button("Continue", variant="primary")
    # Kept outside auth_group: that group is hidden on successful sign-in,
    # and the welcome message written here must stay visible.
    auth_msg = gr.Markdown("")

    # Main UI (revealed after sign-in)
    with gr.Group(visible=False) as main_group:
        with gr.Row():
            gr.Markdown("### 🤖 Create Task")
            user_display = gr.Markdown("")
        with gr.Row():
            with gr.Column(scale=2):
                model_choice = gr.Radio(
                    choices=["SAM-X-1-Large (Reasoning)", "SAM-Z-1 (Fast)"],
                    value="SAM-Z-1 (Fast)",
                    label="Model",
                )
                prompt_input = gr.Textbox(
                    label="Prompt",
                    placeholder="Enter your prompt...",
                    lines=4,
                )
                submit_btn = gr.Button("🚀 Submit Task", variant="primary", size="lg")
                task_msg = gr.Markdown("")
            with gr.Column(scale=1):
                gr.Markdown("### ℹ️ Info")
                gr.Markdown(
                    "\n"
                    "- **SAM-X-1**: Reasoning model with `<think>` tags\n"
                    "- **SAM-Z-1**: Ultra-fast direct responses\n"
                    "- Max 5 concurrent tasks\n"
                    "- Results saved to database\n"
                    "- Background processing\n"
                )
        gr.Markdown("---")
        with gr.Row():
            gr.Markdown("### 📋 Your Tasks")
            refresh_btn = gr.Button("🔄 Refresh", size="sm")
        tasks_display = gr.HTML("")
        auto_refresh = gr.Checkbox(label="Auto-refresh every 3 seconds", value=True)

    def handle_auth(username, password):
        """Sign in, or create the account if the username is unknown.

        Returns ``(user_id, message, auth_group update, main_group update)``.
        On success the auth panel is hidden and the main panel shown.
        """
        if len(username) < 3 or len(password) < 6:
            return None, "❌ Invalid credentials", gr.update(), gr.update()
        success, user_id = authenticate(username, password)
        if not success:
            # Unknown username: sign up. If the name exists (wrong password),
            # creation fails and we fall through to the failure message.
            success, _msg = create_user(username, password)
            if success:
                success, user_id = authenticate(username, password)
        if success:
            return (
                user_id,
                f"✅ Welcome, **{username}**!",
                gr.update(visible=False),
                gr.update(visible=True),
            )
        return None, "❌ Authentication failed", gr.update(), gr.update()

    def submit_task(user_id, model, prompt):
        """Queue a task. Returns ``(status message, prompt box update)``.

        The prompt box is cleared only after a successful submission (or when
        it was blank). On other errors it is left untouched so the user does
        not lose the text.
        """
        if not user_id:
            return "❌ Please sign in", gr.update()
        if not prompt.strip():
            return "❌ Prompt required", ""
        active_count = get_user_active_tasks(user_id)
        if active_count >= 5:
            return f"❌ Max 5 active tasks (you have {active_count})", gr.update()
        task_id = create_task(user_id, model, prompt)
        return f"✅ Task submitted! ID: `{task_id[:8]}...`", ""

    def render_tasks(user_id):
        """Render the user's recent tasks as HTML cards."""
        import html  # function-scope: escape user-controlled text before embedding it in markup

        if not user_id:
            return ""
        tasks = get_user_tasks(user_id)
        if not tasks:
            return "<div style='text-align: center; padding: 40px; color: #9ca3af;'>No tasks yet</div>"
        cards = []
        for task_id, model, prompt, status, progress, tokens, tps, created in tasks:
            # Truncate before escaping so an HTML entity is never cut in half.
            prompt_preview = prompt[:100] + ('...' if len(prompt) > 100 else '')
            safe_status = html.escape(str(status))
            cards.append(
                '<div class="task-card">'
                '<div style="display: flex; justify-content: space-between; margin-bottom: 8px;">'
                f"<strong>Task: {task_id[:8]}...</strong>"
                f'<span class="status-{safe_status}">●{safe_status.upper()}</span>'
                "</div>"
                f"<div><strong>Model:</strong> {html.escape(str(model))}</div>"
                f"<div><strong>Prompt:</strong> {html.escape(prompt_preview)}</div>"
                '<div class="progress-bar">'
                f'<div class="progress-fill" style="width: {progress}%"></div>'
                "</div>"
                '<div style="font-size: 12px; color: #6b7280;">'
                f"Progress: {progress}% | Tokens: {tokens} | Speed: {(tps or 0.0):.1f} tok/s"
                "</div>"
                "</div>"
            )
        return "".join(cards)

    def get_task_result(user_id, task_id_short):
        """Look up one of the user's tasks by ID prefix and format its result as Markdown."""
        if not user_id or not task_id_short:
            return "❌ Invalid request"
        # Accept the "xxxxxxxx..." label shown on task cards as well as a bare prefix.
        prefix = task_id_short.strip().rstrip(".")
        if not prefix:
            return "❌ Invalid request"
        # Escape LIKE wildcards so the input is matched as a literal prefix.
        pattern = prefix.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_") + "%"
        with db_lock:
            c = db_conn.cursor()
            c.execute("SELECT result, status FROM tasks "
                      "WHERE user_id=? AND id LIKE ? ESCAPE '\\'",
                      (user_id, pattern))
            result = c.fetchone()
        if result:
            if result[1] == 'completed':
                return f"### ✅ Result\n\n{result[0]}"
            elif result[1] == 'failed':
                return f"### ❌ Failed\n\n{result[0]}"
            else:
                return f"### ⏳ Status: {result[1]}"
        return "❌ Task not found"

    def update_user_display(user_id):
        """Return the "User | Active" header line for the signed-in user (or "")."""
        if user_id:
            with db_lock:
                c = db_conn.cursor()
                c.execute("SELECT username FROM users WHERE id=?", (user_id,))
                result = c.fetchone()
            if result:
                active = get_user_active_tasks(user_id)
                return f"**User:** {result[0]} | **Active:** {active}/5"
        return ""

    def timer_refresh(user_id, auto_enabled):
        """Timer callback: re-render tasks and header when signed in and auto-refresh is on."""
        if auto_enabled and user_id:
            return render_tasks(user_id), update_user_display(user_id)
        return gr.update(), gr.update()

    def on_auth_success(user_id):
        """Populate the task list and header whenever the signed-in user changes."""
        if user_id:
            return render_tasks(user_id), update_user_display(user_id)
        return "", ""

    # Event handlers
    auth_btn.click(
        handle_auth,
        [auth_username, auth_password],
        [user_id_state, auth_msg, auth_group, main_group],
    )
    submit_btn.click(
        submit_task,
        [user_id_state, model_choice, prompt_input],
        [task_msg, prompt_input],
    ).then(
        render_tasks,
        [user_id_state],
        [tasks_display],
    )
    refresh_btn.click(
        render_tasks,
        [user_id_state],
        [tasks_display],
    )

    # Periodic refresh: registered exactly once, with explicit inputs/outputs.
    refresh_timer = gr.Timer(3)
    refresh_timer.tick(
        timer_refresh,
        [user_id_state, auto_refresh],
        [tasks_display, user_display],
    )

    # View full result (expandable)
    with gr.Accordion("🔍 View Task Result", open=False):
        result_task_id = gr.Textbox(
            label="Task ID (first 8 chars)",
            placeholder="e.g., 3f7a9b2c",
        )
        view_result_btn = gr.Button("View Result", variant="primary")
        result_display = gr.Markdown("")
    view_result_btn.click(
        get_task_result,
        [user_id_state, result_task_id],
        [result_display],
    )

    # Initial load after sign-in
    user_id_state.change(
        on_auth_success,
        [user_id_state],
        [tasks_display, user_display],
    )
if __name__ == "__main__":
    # Startup banner. The worker count and the per-user limit shown here
    # mirror values hard-coded elsewhere in this file.
    print("\n" + "=" * 80)
    print("🚀 SAM BACKGROUND PROCESSOR".center(80))
    print("=" * 80)
    print("✅ 2 worker threads active")
    print("✅ Max 5 tasks per user")
    print("✅ Background processing enabled")
    print("✅ Database: sam_tasks.db")
    print("=" * 80 + "\n")
    # Bound the Gradio request queue, then serve on all interfaces.
    demo.queue(max_size=50)
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False,
        show_error=True,
    )