# -------------------------------
# app.py — Sam-3.5: The Reasoning AI (Updated Architecture)
# -------------------------------
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from pathlib import Path
from safetensors.torch import load_file
from transformers import AutoTokenizer
from dataclasses import dataclass, fields
from typing import Dict, List
import gradio as gr
import os
from huggingface_hub import hf_hub_download
import json

# -------------------------------
# 1) Configuration & Special Tokens
# -------------------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

SPECIAL_TOKENS = {
    "bos": "<|bos|>",
    "eot": "<|eot|>",
    "user": "<|user|>",
    "assistant": "<|assistant|>",
    "system": "<|system|>",
    "think": "<|think|>",  # Keep this for reasoning display
}

tokenizer = AutoTokenizer.from_pretrained("gpt2")
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
tokenizer.add_special_tokens({"additional_special_tokens": list(SPECIAL_TOKENS.values())})

SPECIAL_IDS = {k: tokenizer.convert_tokens_to_ids(v) for k, v in SPECIAL_TOKENS.items()}
EOT_ID = SPECIAL_IDS.get("eot", tokenizer.eos_token_id)
THINK_ID = SPECIAL_IDS.get("think")
# Explicit check rather than `assert`, which is stripped when Python runs with -O.
if THINK_ID is None:
    raise RuntimeError("Tokenizer must include <|think|> token")

# Number of most-recent tokens fed to the model; older context is dropped from the left.
MAX_LENGTH = 1024
# Upper bound on tokens generated per reply (reasoning + answer).
MAX_NEW_TOKENS = 256


# -------------------------------
# 2) Model Architecture (Sam-3.5)
# -------------------------------
class RMSNorm(nn.Module):
    """Root-mean-square normalization over the last dim with a learned scale (no bias)."""

    def __init__(self, d, eps=1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(d))

    def forward(self, x):
        # x / sqrt(mean(x^2) + eps), then per-channel scale.
        return self.weight * x * (x.pow(2).mean(-1, keepdim=True) + self.eps).rsqrt()


class MHA(nn.Module):
    """Causal multi-head self-attention using `F.scaled_dot_product_attention`.

    `attn_mask` is accepted for interface compatibility but is NOT used: attention is
    always purely causal with no padding mask, so inputs must be un-padded sequences.
    """

    def __init__(self, d_model, n_heads, dropout=0.0):
        super().__init__()
        if d_model % n_heads != 0:
            raise ValueError("d_model must be divisible by n_heads")
        self.n_heads = n_heads
        self.head_dim = d_model // n_heads
        self.q_proj = nn.Linear(d_model, d_model, bias=False)
        self.k_proj = nn.Linear(d_model, d_model, bias=False)
        self.v_proj = nn.Linear(d_model, d_model, bias=False)
        self.out_proj = nn.Linear(d_model, d_model, bias=False)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, attn_mask=None):
        B, T, C = x.shape
        # (B, T, C) -> (B, n_heads, T, head_dim)
        q = self.q_proj(x).view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
        k = self.k_proj(x).view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
        v = self.v_proj(x).view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
        out = F.scaled_dot_product_attention(
            q, k, v, is_causal=True, dropout_p=self.dropout.p if self.training else 0.0
        )
        return self.out_proj(out.transpose(1, 2).contiguous().view(B, T, C))


class SwiGLU(nn.Module):
    """Gated feed-forward: w3(silu(w1 x) * w2 x), with dropout on the hidden activation."""

    def __init__(self, d_model, d_ff, dropout=0.0):
        super().__init__()
        self.w1 = nn.Linear(d_model, d_ff, bias=False)
        self.w2 = nn.Linear(d_model, d_ff, bias=False)
        self.w3 = nn.Linear(d_ff, d_model, bias=False)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        return self.w3(self.dropout(F.silu(self.w1(x)) * self.w2(x)))


class Block(nn.Module):
    """Pre-norm transformer block: residual attention followed by residual SwiGLU."""

    def __init__(self, d_model, n_heads, ff_mult, dropout=0.0):
        super().__init__()
        self.norm1 = RMSNorm(d_model)
        self.attn = MHA(d_model, n_heads, dropout=dropout)
        self.norm2 = RMSNorm(d_model)
        self.ff = SwiGLU(d_model, int(ff_mult * d_model), dropout=dropout)
        self.drop = nn.Dropout(dropout)

    def forward(self, x, attn_mask=None):
        x = x + self.drop(self.attn(self.norm1(x), attn_mask=attn_mask))
        x = x + self.drop(self.ff(self.norm2(x)))
        return x


@dataclass
class Sam3Config:
    """Hyper-parameters for `Sam3`; normally populated from the Hub `config.json`."""

    vocab_size: int
    d_model: int = 468
    n_layers: int = 14
    n_heads: int = 6
    ff_mult: float = 4.0
    dropout: float = 0.1
    input_modality: str = "text"
    head_type: str = "causal_lm"
    version: str = "0.1"


class Sam3(nn.Module):
    """Decoder-only causal LM: embedding -> Blocks -> RMSNorm -> tied LM head.

    No positional embedding is applied in this forward pass; token order reaches the
    model only through the causal attention mask.
    """

    def __init__(self, config: Sam3Config):
        super().__init__()
        self.config = config
        self.embed = nn.Embedding(config.vocab_size, config.d_model)
        self.blocks = nn.ModuleList([
            Block(config.d_model, config.n_heads, config.ff_mult, dropout=config.dropout)
            for _ in range(config.n_layers)
        ])
        self.norm = RMSNorm(config.d_model)
        self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)
        self.lm_head.weight = self.embed.weight  # Weight tying

    def forward(self, input_ids, attention_mask=None):
        """Return logits of shape (batch, seq, vocab) for `input_ids` (batch, seq)."""
        x = self.embed(input_ids)
        for blk in self.blocks:
            x = blk(x, attn_mask=attention_mask)
        x = self.norm(x)
        return self.lm_head(x)


# -------------------------------
# 3) Load Model from Hugging Face Hub
# -------------------------------
def load_sam3_model_from_hf(repo_id: str, filename: str = "sam3-epoch1-best.safetensors"):
    """Download `config.json` + a safetensors checkpoint from the Hub and build `Sam3`.

    Args:
        repo_id: Hugging Face Hub repository id.
        filename: Safetensors weights file inside the repository.

    Returns:
        The model on `device`, in eval mode.

    `vocab_size` is overridden with `len(tokenizer)` so the embedding matches the
    tokenizer after the special tokens were added. Config keys that `Sam3Config` does
    not define are dropped with a notice instead of crashing the constructor, and any
    weights that fail to load are reported rather than silently left random.
    """
    print(f"📥 Loading config and weights from: {repo_id}")
    config_path = hf_hub_download(repo_id=repo_id, filename="config.json")
    weights_path = hf_hub_download(repo_id=repo_id, filename=filename)

    with open(config_path, "r", encoding="utf-8") as f:
        config_dict = json.load(f)

    # Ensure vocab_size matches tokenizer after adding special tokens
    config_dict["vocab_size"] = len(tokenizer)
    known_fields = {fld.name for fld in fields(Sam3Config)}
    ignored = sorted(set(config_dict) - known_fields)
    if ignored:
        print(f"⚠️ Ignoring config keys not used by Sam3Config: {ignored}")
    config = Sam3Config(**{k: v for k, v in config_dict.items() if k in known_fields})

    model = Sam3(config).to(device)
    state_dict = load_file(weights_path)
    # strict=False tolerates checkpoints that store only one side of the tied
    # embed/lm_head weight; every other mismatch is surfaced below.
    result = model.load_state_dict(state_dict, strict=False)
    tied = {"embed.weight", "lm_head.weight"}
    tied_present = bool(state_dict.keys() & tied)
    missing = [k for k in result.missing_keys if not (k in tied and tied_present)]
    if missing:
        print(f"⚠️ Missing weights (left at random init): {missing}")
    if result.unexpected_keys:
        print(f"⚠️ Unexpected weights in checkpoint (ignored): {list(result.unexpected_keys)}")
    model.eval()
    print(f"✅ Model loaded successfully from Hugging Face Hub: {repo_id}")
    return model


# Load model
model = load_sam3_model_from_hf("Smilyai-labs/Sam-3.5-1")


# -------------------------------
# 4) Sampling Function (Enhanced from your original)
# -------------------------------
def sample_next_token(
    logits,
    past_tokens,
    temperature=0.8,
    top_k=60,
    top_p=0.9,
    repetition_penalty=1.1,
    max_repeat=5,
    no_repeat_ngram_size=3,
):
    """Pick the next token id for every batch row.

    Args:
        logits: (batch, vocab) logits, or (batch, seq, vocab) — then the last position
            is used.
        past_tokens: 1-D tensor or sequence of previous token ids, shared by all rows,
            used by the repetition controls.
        temperature: Softmax temperature; values <= 0 select the arg-max (greedy).
        top_k: Keep only the k highest logits (None or <= 0 disables).
        top_p: Nucleus threshold, active only when 0 < top_p < 1.
        repetition_penalty: CTRL-style penalty for tokens in `past_tokens`: positive
            logits are divided and negative ones multiplied, so seen tokens always
            become less likely.
        max_repeat: Ban the last token once it already ends a run of this many copies
            (None or <= 0 disables).
        no_repeat_ngram_size: Ban any token that would recreate an n-gram already in
            `past_tokens` (0 disables).

    Returns:
        LongTensor of shape (batch, 1) on `device`.
    """
    neg_inf = -float("inf")
    logits = (logits[:, -1, :] if logits.dim() == 3 else logits).clone()
    vocab_size = logits.size(-1)
    orig_logits = logits.clone()

    greedy = temperature <= 0
    if not greedy and temperature != 1.0:
        logits = logits / float(temperature)

    past_list = past_tokens.tolist() if isinstance(past_tokens, torch.Tensor) else list(past_tokens)

    # Repetition penalty (sign-aware, so negative logits are pushed down too).
    if repetition_penalty != 1.0:
        seen = sorted({t for t in past_list if 0 <= t < vocab_size})
        if seen:
            idx = torch.tensor(seen, device=logits.device, dtype=torch.long)
            picked = logits[:, idx]
            logits[:, idx] = torch.where(
                picked < 0, picked * repetition_penalty, picked / repetition_penalty
            )

    # Forbid extending a run of the same token beyond `max_repeat`.
    if max_repeat and max_repeat > 0 and len(past_list) >= max_repeat:
        last_token = past_list[-1]
        run = 0
        for tok in reversed(past_list):
            if tok != last_token:
                break
            run += 1
        if run >= max_repeat and 0 <= last_token < vocab_size:
            logits[:, last_token] = neg_inf

    # No-repeat n-gram: ban each token that followed an earlier occurrence of the
    # current (n-1)-token suffix. Linear in len(past_list) instead of scanning the vocab.
    n = no_repeat_ngram_size
    if n and n > 0 and len(past_list) >= n:
        ctx = n - 1
        suffix = past_list[len(past_list) - ctx:]
        banned = sorted({
            past_list[i + ctx]
            for i in range(len(past_list) - ctx)
            if past_list[i:i + ctx] == suffix and 0 <= past_list[i + ctx] < vocab_size
        })
        if banned:
            logits[:, torch.tensor(banned, device=logits.device, dtype=torch.long)] = neg_inf

    if top_k is not None and top_k > 0:
        tk = min(max(1, int(top_k)), vocab_size)
        topk_vals, _ = torch.topk(logits, tk, dim=-1)
        logits[logits < topk_vals[:, -1].unsqueeze(-1)] = neg_inf

    if top_p is not None and 0.0 < top_p < 1.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
        sorted_remove = cumulative_probs > top_p
        sorted_remove[:, 0] = False  # always keep the most likely token
        remove = sorted_remove.scatter(-1, sorted_indices, sorted_remove)
        logits = logits.masked_fill(remove, neg_inf)

    # If filtering removed every candidate in a row, fall back to the raw logits.
    for b in range(logits.size(0)):
        if torch.isneginf(logits[b]).all():
            logits[b] = orig_logits[b]

    if greedy:
        return torch.argmax(logits, dim=-1, keepdim=True).to(device)

    probs = F.softmax(logits, dim=-1)
    if torch.isnan(probs).any():
        probs = torch.ones_like(logits) / logits.size(1)
    next_token = torch.multinomial(probs, num_samples=1)
    return next_token.to(device)


# -------------------------------
# 5) Gradio Chat Interface — WITH STYLED THINKING STEPS
# -------------------------------
# First line of the rendered reasoning section; also used to strip it from history.
_THINKING_HEADER = "> 💡 **Thinking:**"


def _format_reply(thinking: str, answer: str) -> str:
    """Render reasoning as a Markdown blockquote, followed by the answer."""
    thinking, answer = thinking.strip(), answer.strip()
    if not thinking:
        return answer
    quoted = "\n".join(f"> {line}" if line.strip() else ">" for line in thinking.split("\n"))
    return f"{_THINKING_HEADER}\n{quoted}\n\n{answer}".rstrip()


def _strip_thinking(reply: str) -> str:
    """Remove a leading reasoning blockquote produced by `_format_reply`, keep the answer."""
    if not reply.startswith(_THINKING_HEADER):
        return reply
    lines = reply.split("\n")
    end = 0
    while end < len(lines) and lines[end].startswith(">"):
        end += 1
    return "\n".join(lines[end:]).strip()


def _history_pairs(history):
    """Yield (user, assistant) pairs from Gradio history in tuple or messages format."""
    pending_user = None
    for item in history or []:
        if isinstance(item, dict):
            role, content = item.get("role"), item.get("content")
            if not isinstance(content, str):
                continue  # non-text entries (files, components) are not replayed
            if role == "user":
                if pending_user is not None:
                    yield pending_user, None
                pending_user = content
            elif role == "assistant":
                yield pending_user, content
                pending_user = None
        else:
            human, assistant = item
            yield human, assistant
    if pending_user is not None:
        yield pending_user, None


def predict(message, history):
    """Stream a Sam-3.5 reply to `message` given the prior chat `history`.

    Each yielded value is the full reply so far (Gradio replaces the displayed message):
    a quoted "Thinking" section followed by the answer.
    """
    # Build prompt
    chat_history = []
    for human, assistant in _history_pairs(history):
        if human is not None:
            chat_history.append(f"{SPECIAL_TOKENS['user']} {human} {SPECIAL_TOKENS['eot']}")
        if isinstance(assistant, str) and assistant:
            # Only the visible answer is replayed; earlier reasoning is not reconstructed.
            answer_text = _strip_thinking(assistant)
            chat_history.append(f"{SPECIAL_TOKENS['assistant']} {answer_text} {SPECIAL_TOKENS['eot']}")
    chat_history.append(f"{SPECIAL_TOKENS['user']} {message} {SPECIAL_TOKENS['eot']}")

    system_prompt = "You are Sam-3.5, an advanced reasoning AI. You think step-by-step, analyze deeply, and respond with precision. You do not guess — you deduce. Avoid medical or legal advice."
    prompt = f"{SPECIAL_TOKENS['system']} {system_prompt} {SPECIAL_TOKENS['eot']}\n" + "\n".join(chat_history) + f"\n{SPECIAL_TOKENS['assistant']} {SPECIAL_TOKENS['think']}"

    # Keep the END of the prompt: right-side truncation would drop the newest user
    # message and the trailing assistant/think cue.
    prompt_ids = tokenizer(prompt)["input_ids"][-MAX_LENGTH:]
    input_ids = torch.tensor([prompt_ids], dtype=torch.long, device=device)

    # Tokens that mean the model has finished or started a new turn.
    stop_ids = {SPECIAL_IDS["user"], SPECIAL_IDS["system"], tokenizer.eos_token_id}
    thinking_ids: List[int] = []
    answer_ids: List[int] = []
    # The prompt ends with <|think|>, so generation begins inside the reasoning block;
    # the next <|eot|> closes it and the following <|eot|> ends the answer.
    thinking_mode = True

    def render() -> str:
        # Decode whole id lists so multi-byte characters split across tokens stay intact.
        return _format_reply(
            tokenizer.decode(thinking_ids, skip_special_tokens=True),
            tokenizer.decode(answer_ids, skip_special_tokens=True),
        )

    for _ in range(MAX_NEW_TOKENS):
        context = input_ids[:, -MAX_LENGTH:]
        with torch.no_grad():
            logits = model(context)
        next_token = sample_next_token(
            logits, context[0], temperature=0.4, top_k=50, top_p=0.9, repetition_penalty=1.1
        )
        token_id = int(next_token.item())
        input_ids = torch.cat([input_ids, next_token], dim=1)

        if token_id == THINK_ID:
            thinking_mode = True
            continue
        if token_id == EOT_ID:
            if thinking_mode:
                thinking_mode = False  # reasoning finished; the answer follows
                continue
            break  # end of the answer — never displayed as text
        if token_id in stop_ids:
            break

        (thinking_ids if thinking_mode else answer_ids).append(token_id)
        yield render()

    yield render()


# -------------------------------
# 6) Launch Gradio Interface
# -------------------------------
CSS = """
.gradio-container .message-bubble {
    border-radius: 12px !important;
    padding: 10px 14px !important;
    font-size: 16px !important;
}
.gradio-container .message-bubble.user {
    background-color: #007bff !important;
    color: white !important;
}
.gradio-container .message-bubble.assistant {
    background-color: #f8f9fa !important;
    color: #212529 !important;
    border: 1px solid #e9ecef;
}
"""

demo = gr.ChatInterface(
    fn=predict,
    title="🧠 Sam-3.5: The Reasoning AI",
    description="""
Sam-3.5 doesn’t just answer — it **thinks first**.
Watch its internal reasoning unfold in real time — step by step, clearly shown.
No guessing. No fluff. Just pure deduction.

Try asking:

→ “Why does a mirror reverse left and right but not up and down?”

→ “If I have 3 apples and give away half, then buy 5 more, how many do I have?”

→ “Explain quantum entanglement like I’m 10.”

→ “What’s wrong with this argument: ‘All birds fly; penguins are birds; therefore penguins can fly’?”
""",
    theme=gr.themes.Soft(
        primary_hue="indigo",
        secondary_hue="blue",
    ),
    chatbot=gr.Chatbot(
        label="Sam-3.5 🤔",
        bubble_full_width=False,
        height=600,
        avatar_images=(
            "https://huggingface.co/datasets/huggingface/branding/resolve/main/avatar-bot.jpg",
            "https://huggingface.co/datasets/huggingface/branding/resolve/main/avatar-user.jpg",
        ),
    ),
    examples=[
        "What is the capital of France?",
        "Explain why the sky is blue.",
        "If a train leaves at 2 PM going 60 mph, and another leaves 30 minutes later at 80 mph, when does the second catch up?",
        "What are the ethical implications of AI making medical diagnoses?",
    ],
    css=CSS,
    cache_examples=False,
)

if __name__ == "__main__":
    demo.launch(show_api=True)