Keeby-smilyai commited on
Commit
765bb8c
Β·
verified Β·
1 Parent(s): 2d42d16

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +263 -274
app.py CHANGED
@@ -18,18 +18,16 @@ from datetime import datetime
18
  import uuid
19
 
20
  # ==============================================================================
21
- # 1. System & Hardware
22
  # ==============================================================================
23
- # Optimized for CPU/GPU throughput
24
  tf.config.threading.set_inter_op_parallelism_threads(2)
25
  tf.config.threading.set_intra_op_parallelism_threads(4)
26
  tf.config.optimizer.set_jit(True)
27
 
28
- print(f"πŸš€ SmilyAI System Initializing...")
29
- print(f"πŸ“± TensorFlow Version: {tf.__version__}")
30
 
31
  # ==============================================================================
32
- # 2. Database (State Management)
33
  # ==============================================================================
34
  def init_db():
35
  conn = sqlite3.connect('sam_tasks.db', check_same_thread=False)
@@ -37,9 +35,7 @@ def init_db():
37
  c.execute('''CREATE TABLE IF NOT EXISTS users
38
  (id INTEGER PRIMARY KEY AUTOINCREMENT,
39
  username TEXT UNIQUE NOT NULL,
40
- password_hash TEXT NOT NULL,
41
- created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP)''')
42
-
43
  c.execute('''CREATE TABLE IF NOT EXISTS tasks
44
  (id TEXT PRIMARY KEY,
45
  user_id INTEGER,
@@ -49,8 +45,6 @@ def init_db():
49
  progress INTEGER DEFAULT 0,
50
  result TEXT,
51
  created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
52
- completed_at TIMESTAMP,
53
- tokens_generated INTEGER DEFAULT 0,
54
  tokens_per_sec REAL DEFAULT 0,
55
  FOREIGN KEY (user_id) REFERENCES users(id))''')
56
  conn.commit()
@@ -60,7 +54,7 @@ db_conn = init_db()
60
  db_lock = threading.Lock()
61
 
62
  # ==============================================================================
63
- # 3. Optimized Model Architecture (KV Cache Enabled)
64
  # ==============================================================================
65
  @keras.saving.register_keras_serializable()
66
  class RotaryEmbedding(keras.layers.Layer):
@@ -81,41 +75,19 @@ class RotaryEmbedding(keras.layers.Layer):
81
  self.sin_cached = tf.constant(np.sin(emb.numpy()), dtype=tf.float32)
82
  self.built_cache = True
83
 
84
- def rotate_half(self, x):
85
- x1, x2 = tf.split(x, 2, axis=-1)
86
- return tf.concat([-x2, x1], axis=-1)
87
-
88
  def call(self, q, k):
89
  self._build_cache()
90
  seq_len = tf.shape(q)[2]
91
- cos = self.cos_cached[:seq_len, :]
92
- sin = self.sin_cached[:seq_len, :]
93
 
94
- # Reshape for broadcast: [1, 1, Seq, Dim]
95
- cos = cos[None, None, :, :]
96
- sin = sin[None, None, :, :]
97
-
98
- q_rotated = (q * cos) + (self.rotate_half(q) * sin)
99
- k_rotated = (k * cos) + (self.rotate_half(k) * sin)
100
- return q_rotated, k_rotated
101
-
102
- def get_config(self):
103
- config = super().get_config()
104
- config.update({"dim": self.dim, "max_len": self.max_len, "theta": self.theta})
105
- return config
106
-
107
- @keras.saving.register_keras_serializable()
108
- class RMSNorm(keras.layers.Layer):
109
- def __init__(self, epsilon=1e-5, **kwargs):
110
- super().__init__(**kwargs)
111
- self.epsilon = epsilon
112
-
113
- def build(self, input_shape):
114
- self.scale = self.add_weight(name="scale", shape=(input_shape[-1],), initializer="ones")
115
-
116
- def call(self, x):
117
- variance = tf.reduce_mean(tf.square(x), axis=-1, keepdims=True)
118
- return x * tf.math.rsqrt(variance + self.epsilon) * self.scale
119
 
120
  @keras.saving.register_keras_serializable()
121
  class TransformerBlock(keras.layers.Layer):
@@ -124,25 +96,23 @@ class TransformerBlock(keras.layers.Layer):
124
  self.n_heads = n_heads
125
  self.head_dim = d_model // n_heads
126
  self.d_model = d_model
127
-
128
  self.rope = RotaryEmbedding(self.head_dim, max_len=max_len, theta=rope_theta)
129
- self.pre_attn_norm = RMSNorm()
130
- self.pre_ffn_norm = RMSNorm()
131
-
132
  self.q_proj = keras.layers.Dense(d_model, use_bias=False)
133
  self.k_proj = keras.layers.Dense(d_model, use_bias=False)
134
  self.v_proj = keras.layers.Dense(d_model, use_bias=False)
135
  self.out_proj = keras.layers.Dense(d_model, use_bias=False)
136
-
137
  self.gate_proj = keras.layers.Dense(ff_dim, use_bias=False)
138
  self.up_proj = keras.layers.Dense(ff_dim, use_bias=False)
139
  self.down_proj = keras.layers.Dense(d_model, use_bias=False)
140
  self.dropout = keras.layers.Dropout(dropout)
141
 
142
  def call(self, x, cache=None, training=None):
143
- B, T = tf.shape(x)[0], tf.shape(x)[1]
 
144
 
145
- # --- Attention ---
146
  res = x
147
  y = self.pre_attn_norm(x)
148
 
@@ -150,38 +120,43 @@ class TransformerBlock(keras.layers.Layer):
150
  k = tf.reshape(self.k_proj(y), [B, T, self.n_heads, self.head_dim])
151
  v = tf.reshape(self.v_proj(y), [B, T, self.n_heads, self.head_dim])
152
 
153
- # KV Cache Update
154
  if cache is not None:
155
  k_cache, v_cache = cache
156
  k = tf.concat([k_cache, k], axis=1)
157
  v = tf.concat([v_cache, v], axis=1)
158
-
159
  new_cache = (k, v)
160
 
161
  # RoPE
162
  q = tf.transpose(q, [0, 2, 1, 3])
163
  k_rot = tf.transpose(k, [0, 2, 1, 3])
164
  v_t = tf.transpose(v, [0, 2, 1, 3])
165
-
166
  q, k_rot = self.rope(q, k_rot)
167
 
168
- # Scaled Dot Product Attention
169
  scores = tf.matmul(q, k_rot, transpose_b=True) / tf.sqrt(tf.cast(self.head_dim, x.dtype))
170
 
171
- if T > 1: # Causal mask for prefill
 
 
172
  mask = tf.linalg.band_part(tf.ones((T, T)), -1, 0)
173
- mask = (1.0 - mask) * -1e9
174
- scores += mask
175
 
 
 
 
 
 
 
 
 
 
176
  attn = tf.nn.softmax(scores, axis=-1)
177
  out = tf.matmul(attn, v_t)
178
-
179
- out = tf.transpose(out, [0, 2, 1, 3])
180
- out = tf.reshape(out, [B, T, self.d_model])
181
-
182
  x = res + self.out_proj(out)
183
 
184
- # --- FFN ---
185
  res = x
186
  y = self.pre_ffn_norm(x)
187
  ffn = self.down_proj(keras.activations.silu(self.gate_proj(y)) * self.up_proj(y))
@@ -194,84 +169,61 @@ class SAM1Model(keras.Model):
194
  super().__init__(**kwargs)
195
  self.embed = keras.layers.Embedding(config['vocab_size'], config['d_model'])
196
  ff_dim = int(config['d_model'] * config['ff_mult'])
197
-
198
  self.blocks = [
199
  TransformerBlock(
200
  config['d_model'], config['n_heads'], ff_dim, config['dropout'],
201
- config['max_len'], config['rope_theta'], name=f"blk_{i}"
202
  ) for i in range(config['n_layers'])
203
  ]
204
- self.norm = RMSNorm()
205
  self.lm_head = keras.layers.Dense(config['vocab_size'], use_bias=False)
206
 
207
  def call(self, input_ids, cache=None, training=None):
208
  x = self.embed(input_ids)
209
  new_caches = []
210
-
211
  for i, block in enumerate(self.blocks):
212
  c_i = cache[i] if cache is not None else None
213
  x, nc_i = block(x, cache=c_i, training=training)
214
  new_caches.append(nc_i)
215
-
216
  return self.lm_head(self.norm(x)), new_caches
217
 
218
  # ==============================================================================
219
- # 4. Load Resources (Models + Tokenizers)
220
  # ==============================================================================
221
- print("\nπŸ“¦ Loading SmilyAI Resources...")
222
 
223
  dummy_in = tf.zeros((1, 1), dtype=tf.int32)
224
 
225
- # --- 1. SAM-X-1 (Reasoning) ---
226
- print("πŸ”Ή Loading SAM-X-1...")
227
- # Config & Tokenizer from: Sam-1-large-it-0002
228
- samx_cfg_path = hf_hub_download("Smilyai-labs/Sam-1-large-it-0002", "config.json")
229
- samx_tok_path = hf_hub_download("Smilyai-labs/Sam-1-large-it-0002", "tokenizer.json")
230
- # Weights from: Sam-1x-instruct
231
- samx_wgt_path = hf_hub_download("Smilyai-labs/Sam-1x-instruct", "ckpt.weights.h5")
232
-
233
- with open(samx_cfg_path) as f: cfg_x_json = json.load(f)
234
- tokenizer_x = Tokenizer.from_file(samx_tok_path)
235
-
236
- samx_model = SAM1Model({
237
- 'vocab_size': cfg_x_json['vocab_size'],
238
- 'd_model': cfg_x_json['hidden_size'],
239
- 'n_layers': cfg_x_json['num_hidden_layers'],
240
- 'n_heads': cfg_x_json['num_attention_heads'],
241
- 'ff_mult': cfg_x_json['intermediate_size'] / cfg_x_json['hidden_size'],
242
- 'max_len': cfg_x_json['max_position_embeddings'],
243
- 'dropout': 0.0,
244
- 'rope_theta': cfg_x_json['rope_theta']
245
- })
246
- _ = samx_model(dummy_in) # Build
247
- samx_model.load_weights(samx_wgt_path)
248
- print("βœ… SAM-X-1 Ready")
249
-
250
- # --- 2. SAM-Z-1 (Speed) ---
251
- print("πŸ”Ή Loading SAM-Z-1...")
252
- # Everything from: Sam-Z-1-tensorflow
253
- samz_cfg_path = hf_hub_download("Smilyai-labs/Sam-Z-1-tensorflow", "config.json")
254
- samz_tok_path = hf_hub_download("Smilyai-labs/Sam-Z-1-tensorflow", "tokenizer.json")
255
- samz_wgt_path = hf_hub_download("Smilyai-labs/Sam-Z-1-tensorflow", "ckpt.weights.h5")
256
-
257
- with open(samz_cfg_path) as f: cfg_z_json = json.load(f)
258
- tokenizer_z = Tokenizer.from_file(samz_tok_path)
259
-
260
- samz_model = SAM1Model({
261
- 'vocab_size': cfg_z_json['vocab_size'],
262
- 'd_model': cfg_z_json['hidden_size'],
263
- 'n_layers': cfg_z_json['num_hidden_layers'],
264
- 'n_heads': cfg_z_json['num_attention_heads'],
265
- 'ff_mult': cfg_z_json['intermediate_size'] / cfg_z_json['hidden_size'],
266
- 'max_len': cfg_z_json['max_position_embeddings'],
267
- 'dropout': 0.0,
268
- 'rope_theta': cfg_z_json['rope_theta']
269
- })
270
- _ = samz_model(dummy_in) # Build
271
- samz_model.load_weights(samz_wgt_path)
272
- print("βœ… SAM-Z-1 Ready")
273
-
274
- # JIT Compilation
275
  @tf.function(jit_compile=True)
276
  def predict_x(ids, cache): return samx_model(ids, cache=cache, training=False)
277
 
@@ -279,199 +231,236 @@ def predict_x(ids, cache): return samx_model(ids, cache=cache, training=False)
279
  def predict_z(ids, cache): return samz_model(ids, cache=cache, training=False)
280
 
281
  # ==============================================================================
282
- # 5. Task Processing
283
  # ==============================================================================
284
  task_queue = queue.Queue()
285
- db_lock = threading.Lock()
286
-
287
- def create_task(uid, model, prompt):
288
- tid = str(uuid.uuid4())
289
- with db_lock:
290
- c = db_conn.cursor()
291
- c.execute("INSERT INTO tasks (id, user_id, model_name, prompt, status) VALUES (?,?,?,?,?)",
292
- (tid, uid, model, prompt, 'queued'))
293
- db_conn.commit()
294
- task_queue.put((tid, model, prompt))
295
- return tid
296
-
297
- def update_task(tid, status, progress, result, tokens, tps):
298
- with db_lock:
299
- c = db_conn.cursor()
300
- c.execute("UPDATE tasks SET status=?, progress=?, result=?, tokens_generated=?, tokens_per_sec=? WHERE id=?",
301
- (status, progress, result, tokens, tps, tid))
302
- if status in ['completed', 'failed']:
303
- c.execute("UPDATE tasks SET completed_at=? WHERE id=?", (datetime.now().isoformat(), tid))
304
- db_conn.commit()
305
-
306
- def run_inference(tid, model_tag, prompt):
307
- # Select Resources
308
- if "SAM-X" in model_tag:
309
- predict_fn = predict_x
310
- tok = tokenizer_x
311
- else:
312
- predict_fn = predict_z
313
- tok = tokenizer_z
314
-
315
- try:
316
- start_time = time.time()
317
- ids = [i for i in tok.encode(prompt).ids]
318
- generated = []
319
-
320
- # 1. Prefill
321
- curr_ids = tf.constant([ids], dtype=tf.int32)
322
- logits, cache = predict_fn(curr_ids, cache=None)
323
- next_token = np.argmax(logits[0, -1, :])
324
- generated.append(next_token)
325
-
326
- # 2. Decode
327
- for step in range(1024):
328
- curr_ids = tf.constant([[generated[-1]]], dtype=tf.int32)
329
- logits, cache = predict_fn(curr_ids, cache=cache)
330
-
331
- # Simple sampling
332
- logits_np = logits[0, -1, :].numpy()
333
- next_token = np.argmax(logits_np) # Greedy for speed
334
-
335
- if next_token == 50256: # EOS
336
- break
337
-
338
- generated.append(next_token)
339
-
340
- # Stream Update (every 4 tokens)
341
- if step % 4 == 0:
342
- txt = tok.decode(generated)
343
- elapsed = time.time() - start_time
344
- tps = len(generated) / elapsed if elapsed > 0 else 0
345
- prog = int((step/1024)*100)
346
- update_task(tid, 'processing', prog, txt, len(generated), tps)
347
-
348
- # Final
349
- txt = tok.decode(generated)
350
- elapsed = time.time() - start_time
351
- update_task(tid, 'completed', 100, txt, len(generated), len(generated)/elapsed)
352
-
353
- except Exception as e:
354
- print(f"❌ Task {tid} failed: {e}")
355
- update_task(tid, 'failed', 0, str(e), 0, 0)
356
 
357
  def worker():
358
  while True:
359
  try:
360
  tid, model, prompt = task_queue.get(timeout=1)
361
- print(f"βš™οΈ Processing {tid} [{model}]")
362
- run_inference(tid, model, prompt)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
363
  task_queue.task_done()
364
- except queue.Empty:
365
- continue
366
 
367
- # Start Workers
368
- for _ in range(2):
369
- threading.Thread(target=worker, daemon=True).start()
370
 
371
  # ==============================================================================
372
- # 6. Gradio UI (Streaming Enabled)
373
  # ==============================================================================
374
  css = """
375
- .thought-box { background: #f0fdf4; border-left: 4px solid #22c55e; padding: 10px; margin: 10px 0; font-size: 0.9em; }
376
- .task-row { padding: 10px; border-bottom: 1px solid #eee; cursor: pointer; transition: background 0.2s; }
377
- .task-row:hover { background: #f9fafb; }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
378
  """
379
 
380
- def format_text(text):
381
  if not text: return ""
382
- # Render <think> tags
383
  if "<think>" in text:
384
  parts = text.split("<think>")
385
  pre = parts[0]
386
  rest = parts[1]
387
  if "</think>" in rest:
388
  thought, ans = rest.split("</think>")
389
- return f"{pre}<div class='thought-box'>🧠 <b>Thinking:</b><br>{thought}</div>{ans}"
390
- else:
391
- return f"{pre}<div class='thought-box'>🧠 <b>Thinking...</b><br>{rest}</div>"
392
  return text.replace("\n", "<br>")
393
 
394
- with gr.Blocks(css=css, title="SmilyAI Studio") as demo:
395
- uid_state = gr.State()
396
-
397
- gr.Markdown("## 🧠 SmilyAI Studio")
398
 
399
- with gr.Row():
400
- with gr.Column(scale=1):
401
- u_in = gr.Textbox(label="User")
402
- p_in = gr.Textbox(label="Pass", type="password")
403
- login_btn = gr.Button("Login")
 
404
 
405
- with gr.Column(scale=2):
406
- model_sel = gr.Radio(["SAM-X-1 (Reasoning)", "SAM-Z-1 (Fast)"], label="Model", value="SAM-Z-1 (Fast)")
407
- prompt_in = gr.Textbox(label="Prompt", lines=3)
408
- gen_btn = gr.Button("Generate", variant="primary")
409
-
410
- gr.Markdown("### πŸ“‘ Live Monitor (Click a task to watch)")
411
- with gr.Row():
412
- task_list = gr.HTML(label="History", elem_id="task-list")
413
- with gr.Column():
414
- monitor_id = gr.Textbox(label="Watching Task ID", interactive=False)
415
- stream_view = gr.HTML(label="Live Output", min_height=400)
416
-
417
- # Logic
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
418
  def login(u, p):
419
- hashed = hashlib.sha256(p.encode()).hexdigest()
420
  with db_lock:
421
  c = db_conn.cursor()
422
- c.execute("SELECT id FROM users WHERE username=? AND password_hash=?", (u, hashed))
423
- res = c.fetchone()
424
- if not res:
425
- try:
426
- c.execute("INSERT INTO users (username, password_hash) VALUES (?,?)", (u, hashed))
427
- db_conn.commit()
428
- c.execute("SELECT id FROM users WHERE username=?", (u,))
429
- res = c.fetchone()
430
- except: return None
431
- return res[0]
432
-
433
- def submit(uid, m, p):
434
- if not uid: return None
435
- tid = create_task(uid, m, p)
436
- return tid
437
-
438
- def get_history(uid):
439
- if not uid: return ""
440
  with db_lock:
441
- c = db_conn.cursor()
442
- c.execute("SELECT id, model_name, status, progress FROM tasks WHERE user_id=? ORDER BY created_at DESC LIMIT 5", (uid,))
443
- rows = c.fetchall()
 
 
 
 
 
 
 
444
 
445
  html = ""
446
  for r in rows:
447
- # Add onclick to set the monitor_id
448
- html += f"""<div class='task-row' onclick="
449
- const ta = document.querySelector('#component-14 textarea');
450
- ta.value = '{r[0]}';
451
- ta.dispatchEvent(new Event('input'));
452
- ">
453
- <b>{r[1]}</b> | {r[2]} ({r[3]}%) <br><small>{r[0]}</small>
454
- </div>"""
 
 
 
 
455
  return html
456
 
457
- # Stream Timer
458
- timer = gr.Timer(0.5)
459
-
460
- def update_monitor(tid):
461
- if not tid: return ""
462
  with db_lock:
463
- c = db_conn.cursor()
464
- c.execute("SELECT result FROM tasks WHERE id=?", (tid,))
465
- res = c.fetchone()
466
- return format_text(res[0]) if res else "Task not found"
467
-
468
- # Events
469
- login_btn.click(login, [u_in, p_in], [uid_state])
470
- gen_btn.click(submit, [uid_state, model_sel, prompt_in], [monitor_id])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
471
 
472
- # Auto-refresh history & stream
473
- timer.tick(get_history, [uid_state], [task_list])
474
- timer.tick(update_monitor, [monitor_id], [stream_view])
 
 
 
 
 
 
475
 
476
  if __name__ == "__main__":
477
  demo.queue().launch(server_name="0.0.0.0", server_port=7860)
 
18
  import uuid
19
 
20
  # ==============================================================================
21
+ # 1. Hardware Optimization & Setup
22
  # ==============================================================================
 
23
  tf.config.threading.set_inter_op_parallelism_threads(2)
24
  tf.config.threading.set_intra_op_parallelism_threads(4)
25
  tf.config.optimizer.set_jit(True)
26
 
27
+ print(f"πŸš€ SmilyAI Pro System Initializing...")
 
28
 
29
  # ==============================================================================
30
+ # 2. Database
31
  # ==============================================================================
32
  def init_db():
33
  conn = sqlite3.connect('sam_tasks.db', check_same_thread=False)
 
35
  c.execute('''CREATE TABLE IF NOT EXISTS users
36
  (id INTEGER PRIMARY KEY AUTOINCREMENT,
37
  username TEXT UNIQUE NOT NULL,
38
+ password_hash TEXT NOT NULL)''')
 
 
39
  c.execute('''CREATE TABLE IF NOT EXISTS tasks
40
  (id TEXT PRIMARY KEY,
41
  user_id INTEGER,
 
45
  progress INTEGER DEFAULT 0,
46
  result TEXT,
47
  created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
 
 
48
  tokens_per_sec REAL DEFAULT 0,
49
  FOREIGN KEY (user_id) REFERENCES users(id))''')
50
  conn.commit()
 
54
  db_lock = threading.Lock()
55
 
56
  # ==============================================================================
57
+ # 3. Model (Fixed with tf.cond)
58
  # ==============================================================================
59
  @keras.saving.register_keras_serializable()
60
  class RotaryEmbedding(keras.layers.Layer):
 
75
  self.sin_cached = tf.constant(np.sin(emb.numpy()), dtype=tf.float32)
76
  self.built_cache = True
77
 
 
 
 
 
78
  def call(self, q, k):
79
  self._build_cache()
80
  seq_len = tf.shape(q)[2]
81
+ cos = self.cos_cached[:seq_len, :][None, None, :, :]
82
+ sin = self.sin_cached[:seq_len, :][None, None, :, :]
83
 
84
+ def rotate_half(x):
85
+ x1, x2 = tf.split(x, 2, axis=-1)
86
+ return tf.concat([-x2, x1], axis=-1)
87
+
88
+ q_rot = (q * cos) + (rotate_half(q) * sin)
89
+ k_rot = (k * cos) + (rotate_half(k) * sin)
90
+ return q_rot, k_rot
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
91
 
92
  @keras.saving.register_keras_serializable()
93
  class TransformerBlock(keras.layers.Layer):
 
96
  self.n_heads = n_heads
97
  self.head_dim = d_model // n_heads
98
  self.d_model = d_model
 
99
  self.rope = RotaryEmbedding(self.head_dim, max_len=max_len, theta=rope_theta)
100
+ self.pre_attn_norm = keras.layers.LayerNormalization(epsilon=1e-5)
101
+ self.pre_ffn_norm = keras.layers.LayerNormalization(epsilon=1e-5)
 
102
  self.q_proj = keras.layers.Dense(d_model, use_bias=False)
103
  self.k_proj = keras.layers.Dense(d_model, use_bias=False)
104
  self.v_proj = keras.layers.Dense(d_model, use_bias=False)
105
  self.out_proj = keras.layers.Dense(d_model, use_bias=False)
 
106
  self.gate_proj = keras.layers.Dense(ff_dim, use_bias=False)
107
  self.up_proj = keras.layers.Dense(ff_dim, use_bias=False)
108
  self.down_proj = keras.layers.Dense(d_model, use_bias=False)
109
  self.dropout = keras.layers.Dropout(dropout)
110
 
111
  def call(self, x, cache=None, training=None):
112
+ B = tf.shape(x)[0]
113
+ T = tf.shape(x)[1]
114
 
115
+ # 1. Attention
116
  res = x
117
  y = self.pre_attn_norm(x)
118
 
 
120
  k = tf.reshape(self.k_proj(y), [B, T, self.n_heads, self.head_dim])
121
  v = tf.reshape(self.v_proj(y), [B, T, self.n_heads, self.head_dim])
122
 
123
+ # KV Cache
124
  if cache is not None:
125
  k_cache, v_cache = cache
126
  k = tf.concat([k_cache, k], axis=1)
127
  v = tf.concat([v_cache, v], axis=1)
 
128
  new_cache = (k, v)
129
 
130
  # RoPE
131
  q = tf.transpose(q, [0, 2, 1, 3])
132
  k_rot = tf.transpose(k, [0, 2, 1, 3])
133
  v_t = tf.transpose(v, [0, 2, 1, 3])
 
134
  q, k_rot = self.rope(q, k_rot)
135
 
136
+ # Attention Scores
137
  scores = tf.matmul(q, k_rot, transpose_b=True) / tf.sqrt(tf.cast(self.head_dim, x.dtype))
138
 
139
+ # --- πŸ› οΈ FIX: Graph-Safe Causal Mask ---
140
+ def apply_mask():
141
+ # Create triangular mask for prefill (T > 1)
142
  mask = tf.linalg.band_part(tf.ones((T, T)), -1, 0)
143
+ return (1.0 - mask) * -1e9
 
144
 
145
+ def no_mask():
146
+ # No mask needed for decoding step (T=1 attends to all past)
147
+ return tf.zeros((1, 1)) # Broadcastable 0
148
+
149
+ # Use tf.cond instead of python 'if'
150
+ mask_offset = tf.cond(tf.greater(T, 1), apply_mask, no_mask)
151
+ scores = scores + mask_offset
152
+ # -----------------------------------------
153
+
154
  attn = tf.nn.softmax(scores, axis=-1)
155
  out = tf.matmul(attn, v_t)
156
+ out = tf.reshape(tf.transpose(out, [0, 2, 1, 3]), [B, T, self.d_model])
 
 
 
157
  x = res + self.out_proj(out)
158
 
159
+ # 2. FFN
160
  res = x
161
  y = self.pre_ffn_norm(x)
162
  ffn = self.down_proj(keras.activations.silu(self.gate_proj(y)) * self.up_proj(y))
 
169
  super().__init__(**kwargs)
170
  self.embed = keras.layers.Embedding(config['vocab_size'], config['d_model'])
171
  ff_dim = int(config['d_model'] * config['ff_mult'])
 
172
  self.blocks = [
173
  TransformerBlock(
174
  config['d_model'], config['n_heads'], ff_dim, config['dropout'],
175
+ config['max_len'], config['rope_theta']
176
  ) for i in range(config['n_layers'])
177
  ]
178
+ self.norm = keras.layers.LayerNormalization(epsilon=1e-5)
179
  self.lm_head = keras.layers.Dense(config['vocab_size'], use_bias=False)
180
 
181
  def call(self, input_ids, cache=None, training=None):
182
  x = self.embed(input_ids)
183
  new_caches = []
 
184
  for i, block in enumerate(self.blocks):
185
  c_i = cache[i] if cache is not None else None
186
  x, nc_i = block(x, cache=c_i, training=training)
187
  new_caches.append(nc_i)
 
188
  return self.lm_head(self.norm(x)), new_caches
189
 
190
  # ==============================================================================
191
+ # 4. Load Models
192
  # ==============================================================================
193
+ print("\nπŸ“¦ Loading Resources...")
194
 
195
  dummy_in = tf.zeros((1, 1), dtype=tf.int32)
196
 
197
+ # SAM-X (Reasoning)
198
+ print("πŸ”Ή SAM-X-1 (Reasoning)")
199
+ try:
200
+ samx_cfg = json.load(open(hf_hub_download("Smilyai-labs/Sam-1-large-it-0002", "config.json")))
201
+ samx_model = SAM1Model({
202
+ 'vocab_size': samx_cfg['vocab_size'], 'd_model': samx_cfg['hidden_size'],
203
+ 'n_layers': samx_cfg['num_hidden_layers'], 'n_heads': samx_cfg['num_attention_heads'],
204
+ 'ff_mult': samx_cfg['intermediate_size']/samx_cfg['hidden_size'],
205
+ 'max_len': samx_cfg['max_position_embeddings'], 'rope_theta': samx_cfg['rope_theta'], 'dropout':0.0
206
+ })
207
+ _ = samx_model(dummy_in)
208
+ samx_model.load_weights(hf_hub_download("Smilyai-labs/Sam-1x-instruct", "ckpt.weights.h5"))
209
+ tokenizer_x = Tokenizer.from_file(hf_hub_download("Smilyai-labs/Sam-1-large-it-0002", "tokenizer.json"))
210
+ except Exception as e: print(f"⚠️ Failed to load SAM-X: {e}")
211
+
212
+ # SAM-Z (Speed)
213
+ print("πŸ”Ή SAM-Z-1 (Fast)")
214
+ try:
215
+ samz_cfg = json.load(open(hf_hub_download("Smilyai-labs/Sam-Z-1-tensorflow", "config.json")))
216
+ samz_model = SAM1Model({
217
+ 'vocab_size': samz_cfg['vocab_size'], 'd_model': samz_cfg['hidden_size'],
218
+ 'n_layers': samz_cfg['num_hidden_layers'], 'n_heads': samz_cfg['num_attention_heads'],
219
+ 'ff_mult': samz_cfg['intermediate_size']/samz_cfg['hidden_size'],
220
+ 'max_len': samz_cfg['max_position_embeddings'], 'rope_theta': samz_cfg['rope_theta'], 'dropout':0.0
221
+ })
222
+ _ = samz_model(dummy_in)
223
+ samz_model.load_weights(hf_hub_download("Smilyai-labs/Sam-Z-1-tensorflow", "ckpt.weights.h5"))
224
+ tokenizer_z = Tokenizer.from_file(hf_hub_download("Smilyai-labs/Sam-Z-1-tensorflow", "tokenizer.json"))
225
+ except Exception as e: print(f"⚠️ Failed to load SAM-Z: {e}")
226
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
227
  @tf.function(jit_compile=True)
228
  def predict_x(ids, cache): return samx_model(ids, cache=cache, training=False)
229
 
 
231
  def predict_z(ids, cache): return samz_model(ids, cache=cache, training=False)
232
 
233
  # ==============================================================================
234
+ # 5. Backend Workers
235
  # ==============================================================================
236
  task_queue = queue.Queue()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
237
 
238
  def worker():
239
  while True:
240
  try:
241
  tid, model, prompt = task_queue.get(timeout=1)
242
+
243
+ # Select Model
244
+ if "SAM-X" in model: pred_fn, tok = predict_x, tokenizer_x
245
+ else: pred_fn, tok = predict_z, tokenizer_z
246
+
247
+ # Inference
248
+ try:
249
+ ids = [i for i in tok.encode(prompt).ids]
250
+ gen = []
251
+
252
+ # Prefill
253
+ curr = tf.constant([ids], dtype=tf.int32)
254
+ logits, cache = pred_fn(curr, cache=None)
255
+ next_t = np.argmax(logits[0,-1,:])
256
+ gen.append(next_t)
257
+
258
+ # Decode
259
+ start = time.time()
260
+ for i in range(1024):
261
+ curr = tf.constant([[gen[-1]]], dtype=tf.int32)
262
+ logits, cache = pred_fn(curr, cache=cache)
263
+ next_t = np.argmax(logits[0,-1,:])
264
+ if next_t == 50256: break
265
+ gen.append(next_t)
266
+
267
+ if i % 5 == 0:
268
+ txt = tok.decode(gen)
269
+ with db_lock:
270
+ db_conn.execute("UPDATE tasks SET status='processing', result=?, progress=? WHERE id=?",
271
+ (txt, int(i/10.24), tid))
272
+ db_conn.commit()
273
+
274
+ # Done
275
+ txt = tok.decode(gen)
276
+ with db_lock:
277
+ db_conn.execute("UPDATE tasks SET status='completed', result=?, progress=100, completed_at=? WHERE id=?",
278
+ (txt, datetime.now().isoformat(), tid))
279
+ db_conn.commit()
280
+
281
+ except Exception as e:
282
+ print(f"Error {tid}: {e}")
283
+ with db_lock:
284
+ db_conn.execute("UPDATE tasks SET status='failed', result=? WHERE id=?", (str(e), tid))
285
+ db_conn.commit()
286
+
287
  task_queue.task_done()
288
+ except queue.Empty: continue
 
289
 
290
+ threading.Thread(target=worker, daemon=True).start()
 
 
291
 
292
  # ==============================================================================
293
+ # 6. "More Better" UI (Custom CSS + Chat Layout)
294
  # ==============================================================================
295
  css = """
296
+ body { background-color: #0b0f19; color: #e5e7eb; }
297
+ .sidebar { background-color: #111827; border-right: 1px solid #374151; height: 100vh; overflow-y: auto; padding: 20px; }
298
+ .main-content { padding: 20px; max-width: 900px; margin: 0 auto; }
299
+ .task-card {
300
+ background: #1f2937; border: 1px solid #374151; border-radius: 8px;
301
+ padding: 12px; margin-bottom: 8px; cursor: pointer; transition: all 0.2s;
302
+ }
303
+ .task-card:hover { background: #374151; border-color: #60a5fa; }
304
+ .status-badge {
305
+ font-size: 10px; padding: 2px 6px; border-radius: 4px; text-transform: uppercase; font-weight: bold;
306
+ }
307
+ .status-queued { background: #f59e0b20; color: #f59e0b; }
308
+ .status-processing { background: #3b82f620; color: #3b82f6; animation: pulse 2s infinite; }
309
+ .status-completed { background: #10b98120; color: #10b981; }
310
+ .status-failed { background: #ef444420; color: #ef4444; }
311
+
312
+ /* Message Bubbles */
313
+ .chat-container { display: flex; flex-direction: column; gap: 20px; margin-top: 20px; }
314
+ .message { padding: 16px; border-radius: 12px; max-width: 85%; line-height: 1.6; }
315
+ .user-msg { align-self: flex-end; background: #2563eb; color: white; }
316
+ .bot-msg { align-self: flex-start; background: #1f2937; border: 1px solid #374151; color: #e5e7eb; width: 100%; }
317
+
318
+ /* Thought Block */
319
+ details.think {
320
+ background: #172554; border-left: 3px solid #3b82f6; border-radius: 4px;
321
+ padding: 8px; margin-bottom: 12px; font-size: 0.9em; color: #93c5fd;
322
+ }
323
+ details.think summary { cursor: pointer; font-weight: bold; opacity: 0.8; }
324
+ details.think[open] summary { margin-bottom: 8px; border-bottom: 1px solid #3b82f640; padding-bottom: 4px; }
325
+
326
+ @keyframes pulse { 0% { opacity: 1; } 50% { opacity: 0.6; } 100% { opacity: 1; } }
327
  """
328
 
329
+ def format_chat(text):
330
  if not text: return ""
331
+ # Beautiful formatted thought blocks
332
  if "<think>" in text:
333
  parts = text.split("<think>")
334
  pre = parts[0]
335
  rest = parts[1]
336
  if "</think>" in rest:
337
  thought, ans = rest.split("</think>")
338
+ return f"{pre}<details class='think'><summary>🧠 Thought Process</summary>{thought}</details>{ans}"
339
+ return f"{pre}<details class='think' open><summary>🧠 Thinking...</summary>{rest} <span class='status-processing'>●</span></details>"
 
340
  return text.replace("\n", "<br>")
341
 
342
+ with gr.Blocks(css=css, title="SmilyAI Studio", theme=gr.themes.Soft(primary_hue="blue", neutral_hue="slate")) as demo:
343
+ user_id = gr.State(value=None)
344
+ current_task = gr.State(value=None)
 
345
 
346
+ with gr.Row(elem_classes="container"):
347
+ # --- Left Sidebar (History) ---
348
+ with gr.Column(scale=1, elem_classes="sidebar"):
349
+ gr.Markdown("### πŸ—‚οΈ History")
350
+ refresh_btn = gr.Button("πŸ”„ Refresh", size="sm", variant="secondary")
351
+ history_list = gr.HTML("Log in to see tasks")
352
 
353
+ gr.Markdown("---")
354
+ gr.Markdown("### πŸ‘€ Account")
355
+ u_in = gr.Textbox(placeholder="Username", show_label=False)
356
+ p_in = gr.Textbox(placeholder="Password", show_label=False, type="password")
357
+ login_btn = gr.Button("Login", size="sm")
358
+
359
+ # --- Main Content (Chat & Monitor) ---
360
+ with gr.Column(scale=3, elem_classes="main-content"):
361
+ gr.Markdown("# ✨ SmilyAI Studio")
362
+
363
+ with gr.Group():
364
+ with gr.Row():
365
+ model_sel = gr.Dropdown(
366
+ ["SAM-X-1 (Reasoning)", "SAM-Z-1 (Fast)"],
367
+ value="SAM-Z-1 (Fast)", label="Select Model", interactive=True
368
+ )
369
+ prompt_in = gr.Textbox(
370
+ placeholder="Ask anything... (e.g. 'Explain quantum physics')",
371
+ lines=3, show_label=False
372
+ )
373
+ with gr.Row():
374
+ generate_btn = gr.Button("πŸš€ Generate", variant="primary", size="lg")
375
+
376
+ # Live View
377
+ gr.Markdown("### πŸ“‘ Live Monitor")
378
+ with gr.Group():
379
+ stream_display = gr.HTML(
380
+ "<div style='padding:20px; text-align:center; color:#6b7280'>Select a task to watch</div>",
381
+ elem_id="stream-box"
382
+ )
383
+
384
+ # --- Logic Functions ---
385
  def login(u, p):
386
+ h = hashlib.sha256(p.encode()).hexdigest()
387
  with db_lock:
388
  c = db_conn.cursor()
389
+ c.execute("SELECT id FROM users WHERE username=?", (u,))
390
+ row = c.fetchone()
391
+ if not row: # Auto-register for demo
392
+ c.execute("INSERT INTO users (username, password_hash) VALUES (?,?)", (u, h))
393
+ db_conn.commit()
394
+ row = (c.lastrowid,)
395
+ return row[0], load_history(row[0])
396
+
397
+ def create_task(uid, model, text):
398
+ if not uid: return None, "Please login first"
399
+ tid = str(uuid.uuid4())
 
 
 
 
 
 
 
400
  with db_lock:
401
+ db_conn.execute("INSERT INTO tasks (id, user_id, model_name, prompt, status) VALUES (?,?,?,?,?)",
402
+ (tid, uid, model, text, 'queued'))
403
+ db_conn.commit()
404
+ task_queue.put((tid, model, text))
405
+ return tid, tid # Set current task
406
+
407
+ def load_history(uid):
408
+ if not uid: return "Please Login"
409
+ with db_lock:
410
+ rows = db_conn.execute("SELECT id, model_name, status, prompt FROM tasks WHERE user_id=? ORDER BY created_at DESC LIMIT 10", (uid,)).fetchall()
411
 
412
  html = ""
413
  for r in rows:
414
+ tid, mod, stat, p = r
415
+ short_mod = "Reasoning" if "SAM-X" in mod else "Fast"
416
+ html += f"""
417
+ <div class='task-card' onclick="setTask('{tid}')">
418
+ <div style='display:flex; justify-content:space-between; margin-bottom:4px'>
419
+ <span style='font-weight:bold; color:#e5e7eb'>{short_mod}</span>
420
+ <span class='status-badge status-{stat}'>{stat}</span>
421
+ </div>
422
+ <div style='font-size:12px; color:#9ca3af; white-space:nowrap; overflow:hidden; text-overflow:ellipsis'>{p}</div>
423
+ <div style='font-size:10px; color:#4b5563; margin-top:4px'>ID: {tid[:8]}</div>
424
+ </div>
425
+ """
426
  return html
427
 
428
+ def watch_stream(tid):
429
+ if not tid: return "Select a task..."
 
 
 
430
  with db_lock:
431
+ row = db_conn.execute("SELECT result, status FROM tasks WHERE id=?", (tid,)).fetchone()
432
+ if not row: return "Task not found"
433
+
434
+ text, status = row
435
+ formatted = format_chat(text)
436
+
437
+ container = f"""
438
+ <div class='chat-container'>
439
+ <div class='message bot-msg'>
440
+ {formatted}
441
+ </div>
442
+ </div>
443
+ """
444
+ return container
445
+
446
+ # --- Wiring ---
447
+ login_btn.click(login, [u_in, p_in], [user_id, history_list])
448
+
449
+ generate_btn.click(
450
+ create_task, [user_id, model_sel, prompt_in], [current_task, current_task]
451
+ ).then(
452
+ load_history, [user_id], [history_list]
453
+ )
454
 
455
+ refresh_btn.click(load_history, [user_id], [history_list])
456
+
457
+ # Helper to handle Javascript click on HTML cards
458
+ # Requires a hidden text input to bridge JS -> Python (omitted for brevity, polling works fine)
459
+
460
+ # Auto-refresh stream
461
+ timer = gr.Timer(0.5)
462
+ timer.tick(watch_stream, [current_task], [stream_display])
463
+ timer.tick(load_history, [user_id], [history_list])
464
 
465
  if __name__ == "__main__":
466
  demo.queue().launch(server_name="0.0.0.0", server_port=7860)