# -------------------------------
# app.py
#
# Sam-3: The Reasoning AI — Now Showing Its Thought Process!
# Powered by Smilyai-labs/Sam-3.0-3. Trained to think before speaking.
# -------------------------------
import math
import os
from dataclasses import dataclass
from pathlib import Path

import gradio as gr
import torch
import torch.nn as nn
import torch.nn.functional as F
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file, safe_open
from transformers import AutoTokenizer


# -------------------------------
# 1) Sam-3.0-3 Architecture
# -------------------------------
@dataclass
class Sam3Config:
    """Hyper-parameters of the Sam-3 transformer.

    The hand-written ``__init__`` is kept on purpose: ``@dataclass`` does not
    replace an ``__init__`` defined in the class body, and this one accepts
    (and ignores) unknown keyword arguments so a config dict carrying extra
    keys can still be splatted into the constructor.
    """

    vocab_size: int = 50257
    d_model: int = 384
    n_layers: int = 10
    n_heads: int = 6
    ff_mult: float = 4.0
    dropout: float = 0.1
    input_modality: str = "text"
    head_type: str = "causal_lm"
    version: str = "0.1"

    def __init__(
        self,
        vocab_size=50257,
        d_model=384,
        n_layers=10,
        n_heads=6,
        ff_mult=4.0,
        dropout=0.1,
        input_modality="text",
        head_type="causal_lm",
        version="0.1",
        **kwargs,
    ):
        self.vocab_size = vocab_size
        self.d_model = d_model
        self.n_layers = n_layers
        self.n_heads = n_heads
        self.ff_mult = ff_mult
        self.dropout = dropout
        self.input_modality = input_modality
        self.head_type = head_type
        self.version = version


class RMSNorm(nn.Module):
    """Root-mean-square normalisation over the last dimension with a learned scale."""

    def __init__(self, d, eps=1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(d))

    def forward(self, x):
        return self.weight * x * (x.pow(2).mean(-1, keepdim=True) + self.eps).rsqrt()


class MHA(nn.Module):
    """Causal multi-head self-attention without biases."""

    def __init__(self, d_model, n_heads, dropout=0.0):
        super().__init__()
        assert d_model % n_heads == 0
        self.n_heads = n_heads
        self.head_dim = d_model // n_heads
        self.q_proj = nn.Linear(d_model, d_model, bias=False)
        self.k_proj = nn.Linear(d_model, d_model, bias=False)
        self.v_proj = nn.Linear(d_model, d_model, bias=False)
        self.out_proj = nn.Linear(d_model, d_model, bias=False)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, attn_mask=None):
        """Attend over ``x`` of shape (B, T, C).

        ``attn_mask`` is an optional (B, T) key mask; positions that evaluate
        to False are excluded from attention for every query.
        """
        B, T, C = x.shape
        q = self.q_proj(x).view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
        k = self.k_proj(x).view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
        v = self.v_proj(x).view(B, T, self.n_heads, self.head_dim).transpose(1, 2)

        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)

        # Strictly-upper-triangular mask: a query may not look at later keys.
        causal = torch.triu(
            torch.ones(T, T, device=x.device, dtype=torch.bool), diagonal=1
        )
        scores = scores.masked_fill(causal, float("-inf"))

        if attn_mask is not None:
            # (B, T) -> (B, 1, 1, T) so it broadcasts over heads and queries.
            key_is_masked = ~attn_mask.unsqueeze(1).unsqueeze(2).bool()
            scores = scores.masked_fill(key_is_masked, float("-inf"))

        attn = torch.softmax(scores, dim=-1)
        out = torch.matmul(self.dropout(attn), v)
        out = out.transpose(1, 2).contiguous().view(B, T, C)
        return self.out_proj(out)


class SwiGLU(nn.Module):
    """Gated feed-forward layer: ``w3(dropout(silu(w1 x) * w2 x))``."""

    def __init__(self, d_model, d_ff, dropout=0.0):
        super().__init__()
        self.w1 = nn.Linear(d_model, d_ff, bias=False)
        self.w2 = nn.Linear(d_model, d_ff, bias=False)
        self.w3 = nn.Linear(d_ff, d_model, bias=False)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        return self.w3(self.dropout(torch.nn.functional.silu(self.w1(x)) * self.w2(x)))


class Block(nn.Module):
    """Pre-norm transformer block: attention and SwiGLU, each with a residual."""

    def __init__(self, d_model, n_heads, ff_mult, dropout=0.0):
        super().__init__()
        self.norm1 = RMSNorm(d_model)
        self.attn = MHA(d_model, n_heads, dropout=dropout)
        self.norm2 = RMSNorm(d_model)
        self.ff = SwiGLU(d_model, int(ff_mult * d_model), dropout=dropout)
        self.drop = nn.Dropout(dropout)

    def forward(self, x, attn_mask=None):
        x = x + self.drop(self.attn(self.norm1(x), attn_mask=attn_mask))
        x = x + self.drop(self.ff(self.norm2(x)))
        return x


class Sam3(nn.Module):
    """Decoder-only language model; the output head shares the embedding matrix."""

    def __init__(self, config: Sam3Config):
        super().__init__()
        self.config = config
        self.embed = nn.Embedding(config.vocab_size, config.d_model)
        self.blocks = nn.ModuleList(
            [
                Block(config.d_model, config.n_heads, config.ff_mult, dropout=config.dropout)
                for _ in range(config.n_layers)
            ]
        )
        self.norm = RMSNorm(config.d_model)
        self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)
        # Weight tying: lm_head and embed share one parameter tensor.
        self.lm_head.weight = self.embed.weight

    def forward(self, input_ids, attention_mask=None):
        """Return logits of shape (B, T, vocab_size) for ``input_ids`` of shape (B, T)."""
        x = self.embed(input_ids)
        for blk in self.blocks:
            x = blk(x, attn_mask=attention_mask)
        x = self.norm(x)
        return self.lm_head(x)


# -------------------------------
# 2) Load Tokenizer & Special Tokens
# -------------------------------
SPECIAL_TOKENS = {
    "bos": "<|bos|>",
    "eot": "<|eot|>",
    "user": "<|user|>",
    "assistant": "<|assistant|>",
    "system": "<|system|>",
    "think": "<|think|>",
}

tokenizer = AutoTokenizer.from_pretrained("gpt2")
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
tokenizer.add_special_tokens({"additional_special_tokens": list(SPECIAL_TOKENS.values())})

EOT_ID = tokenizer.convert_tokens_to_ids("<|eot|>") or tokenizer.eos_token_id
THINK_ID = tokenizer.convert_tokens_to_ids("<|think|>")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# -------------------------------
# 3) Download Model Weights from Hugging Face Hub
# -------------------------------
hf_repo = "Smilyai-labs/Sam-3.0-3"
weights_filename = "model.safetensors"

print(f"Loading model '{hf_repo}' from Hugging Face Hub...")
try:
    weights_path = hf_hub_download(repo_id=hf_repo, filename=weights_filename)
    print(f"✅ Downloaded weights to: {weights_path}")
    if not os.path.exists(weights_path):
        raise FileNotFoundError(f"Downloaded file not found at {weights_path}")
    file_size = os.path.getsize(weights_path)
    print(f"📄 File size: {file_size} bytes")
except Exception as e:
    # Chain the cause so the underlying download error stays in the traceback.
    raise RuntimeError(f"❌ Failed to download model weights: {e}") from e

# Initialize model
cfg = Sam3Config(vocab_size=len(tokenizer))
model = Sam3(cfg).to(device)

# Load state dict safely
print("Loading state dict...")
try:
    state_dict = {}
    with safe_open(weights_path, framework="pt", device="cpu") as f:
        for key in f.keys():
            state_dict[key] = f.get_tensor(key)
    print("✅ Loaded via safe_open")
except Exception as e:
    print(f"⚠️ safe_open failed: {e}. Falling back to torch.load...")
    try:
        # NOTE(security): torch.load unpickles the file. It is only reached when
        # the safetensors parser rejects the download; consider passing
        # weights_only=True if the installed torch version supports it.
        state_dict = torch.load(weights_path, map_location="cpu")
        print("✅ Loaded via torch.load")
    except Exception as torch_e:
        raise RuntimeError(f"❌ Could not load model weights: {torch_e}") from torch_e

# Filter and load: keep only tensors the model has a slot for, report the rest.
model_state_dict = model.state_dict()
filtered_state_dict = {k: v for k, v in state_dict.items() if k in model_state_dict}
missing_keys = set(model_state_dict.keys()) - set(filtered_state_dict.keys())
extra_keys = set(state_dict.keys()) - set(model_state_dict.keys())
if missing_keys:
    print(f"⚠️ Missing keys: {missing_keys}")
if extra_keys:
    print(f"⚠️ Extra keys: {extra_keys}")
model.load_state_dict(filtered_state_dict, strict=False)
model.eval()
print("✅ Model loaded successfully!")


# -------------------------------
# 4) Sampling Function
# -------------------------------
def sample_next_token(
    logits,
    past_tokens,
    temperature=0.8,
    top_k=60,
    top_p=0.9,
    repetition_penalty=1.1,
    max_repeat=5,
    no_repeat_ngram_size=3,
):
    """Sample one next-token id per batch row from ``logits``.

    Args:
        logits: (B, T, V) or (B, V) tensor; for 3-D input only the last
            position is used. The input tensor is never modified.
        past_tokens: 1-D tensor or iterable of already generated token ids,
            shared by every batch row.
        temperature: softmax temperature; non-positive values are clamped to
            a tiny positive number (effectively greedy).
        top_k: keep only the ``top_k`` highest logits (disabled if None or <= 0).
        top_p: nucleus threshold in (0, 1); keeps the smallest set of tokens
            whose cumulative probability reaches ``top_p``.
        repetition_penalty: > 1 discourages tokens present in ``past_tokens``.
        max_repeat: forbid the last token once it has been repeated this many
            times in a row.
        no_repeat_ngram_size: forbid any token that would recreate an n-gram
            of this size already present in ``past_tokens`` (0 disables).

    Returns:
        LongTensor of shape (B, 1), moved to the module-level ``device``.
    """
    if logits.dim() == 3:
        logits = logits[:, -1, :].clone()
    else:
        logits = logits.clone()
    vocab_size = logits.size(1)
    orig_logits = logits.clone()
    neg_inf = -float("inf")

    if temperature != 1.0:
        # Clamp so temperature <= 0 degrades to near-greedy instead of inf/NaN.
        logits = logits / max(float(temperature), 1e-8)

    if isinstance(past_tokens, torch.Tensor):
        past_list = past_tokens.tolist()
    else:
        past_list = list(past_tokens)

    # Repetition penalty. Dividing a negative logit by a penalty > 1 would
    # make the token MORE likely, so negative logits are multiplied instead.
    seen = [tok for tok in set(past_list) if 0 <= tok < vocab_size]
    if seen and repetition_penalty != 1.0:
        seen_idx = torch.tensor(seen, device=logits.device, dtype=torch.long)
        picked = logits[:, seen_idx]
        logits[:, seen_idx] = torch.where(
            picked < 0, picked * repetition_penalty, picked / repetition_penalty
        )

    # Forbid extending a run of the same token beyond max_repeat.
    if past_list and len(past_list) >= max_repeat:
        last_token = past_list[-1]
        run_length = 0
        for tok in reversed(past_list):
            if tok != last_token:
                break
            run_length += 1
        if run_length >= max_repeat and 0 <= last_token < vocab_size:
            logits[:, last_token] = neg_inf

    # N-gram blocking: wherever the current (n-1)-token suffix already occurred,
    # ban the token that followed it. One pass over the history instead of a
    # scan of the whole vocabulary per n-gram.
    n = no_repeat_ngram_size
    if n > 0 and len(past_list) >= n:
        # Explicit start index: past_list[-(n - 1):] would be the WHOLE list for n == 1.
        prefix = tuple(past_list[len(past_list) - (n - 1):])
        for start in range(len(past_list) - n + 1):
            if tuple(past_list[start:start + n - 1]) == prefix:
                banned = past_list[start + n - 1]
                if 0 <= banned < vocab_size:
                    logits[:, banned] = neg_inf

    if top_k is not None and top_k > 0:
        tk = min(max(1, int(top_k)), vocab_size)
        topk_vals, _ = torch.topk(logits, tk, dim=-1)
        min_topk = topk_vals[:, -1].unsqueeze(-1)
        logits[logits < min_topk] = neg_inf

    if top_p is not None and 0.0 < top_p < 1.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
        sorted_remove = cumulative_probs > top_p
        # Shift right by one so the token that crosses the threshold is kept;
        # otherwise the retained mass can fall below top_p.
        sorted_remove[:, 1:] = sorted_remove[:, :-1].clone()
        sorted_remove[:, 0] = False
        # Map the mask from sorted order back to vocabulary order.
        remove = torch.zeros_like(sorted_remove).scatter(1, sorted_indices, sorted_remove)
        logits = logits.masked_fill(remove, neg_inf)

    # If every candidate of a row was filtered out, fall back to the raw logits.
    fully_masked = torch.isneginf(logits).all(dim=-1)
    if fully_masked.any():
        logits[fully_masked] = orig_logits[fully_masked]

    probs = F.softmax(logits, dim=-1)
    if torch.isnan(probs).any():
        # Last resort: sample uniformly rather than crash in multinomial.
        probs = torch.ones_like(logits) / logits.size(1)

    next_token = torch.multinomial(probs, num_samples=1)
    return next_token.to(device)


# -------------------------------
# 5) Gradio Chat Interface — WITH STYLED THINKING STEPS
# -------------------------------
SPECIAL_TOKENS_CHAT = {
    "bos": "<|bos|>",
    "eot": "<|eot|>",
    "user": "<|user|>",
    "assistant": "<|assistant|>",
    "system": "<|system|>",
    "think": "<|think|>",
}

MAX_NEW_TOKENS = 256

SYSTEM_PROMPT = (
    "You are Sam-3, an advanced reasoning AI. You think step-by-step, analyze "
    "deeply, and respond with precision. You do not guess — you deduce. "
    "Avoid medical or legal advice."
)


def _history_to_turns(history):
    """Normalise Gradio history into a list of ``(role, text)`` pairs.

    Accepts the pair format ``[(user, assistant), ...]`` as well as the
    messages format ``[{"role": ..., "content": ...}, ...]``.
    """
    turns = []
    for entry in history or []:
        if isinstance(entry, dict):
            role = entry.get("role")
            content = entry.get("content")
            if role in ("user", "assistant") and content:
                turns.append((role, str(content)))
            continue
        human, assistant = entry
        turns.append(("user", human))
        if assistant:
            turns.append(("assistant", assistant))
    return turns


def _format_thinking(thought):
    """Render the model's reasoning as the block shown above the answer."""
    return f"\n💡 Thinking: {thought}\n"


def predict(message, history):
    """Stream the assistant reply for ``message`` given the chat ``history``.

    The prompt ends with ``<|think|>``, so generation starts in thinking mode:
    tokens are buffered until the first ``<|eot|>`` and then shown as a single
    "Thinking" block. Tokens after that are streamed as the answer, below the
    thinking block, until a second ``<|eot|>`` or ``MAX_NEW_TOKENS``.
    NOTE(review): this assumes the model emits thought, ``<|eot|>``, answer,
    ``<|eot|>`` — confirm against the model's training format.

    Yields:
        The full message text so far (each yield replaces the previous one).
    """
    eot = SPECIAL_TOKENS_CHAT["eot"]
    chat_history = [
        f"{SPECIAL_TOKENS_CHAT[role]} {text} {eot}"
        for role, text in _history_to_turns(history)
    ]
    chat_history.append(f"{SPECIAL_TOKENS_CHAT['user']} {message} {eot}")

    prompt = (
        f"{SPECIAL_TOKENS_CHAT['system']} {SYSTEM_PROMPT} {eot}\n"
        + "\n".join(chat_history)
        + f"\n{SPECIAL_TOKENS_CHAT['assistant']} {SPECIAL_TOKENS_CHAT['think']}"
    )

    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    input_ids = inputs["input_ids"]
    attention_mask = inputs["attention_mask"]

    # The prompt already opened a <|think|> section.
    thinking_mode = True
    thinking_ids = []
    answer_ids = []
    thinking_block = ""

    for _ in range(MAX_NEW_TOKENS):
        with torch.no_grad():
            logits = model(input_ids, attention_mask=attention_mask)
        next_token = sample_next_token(
            logits,
            input_ids[0],
            temperature=0.4,
            top_k=50,
            top_p=0.9,
            repetition_penalty=1.1,
        )
        token_id = int(next_token.squeeze().item())

        input_ids = torch.cat([input_ids, next_token], dim=1)
        attention_mask = torch.cat(
            [attention_mask, attention_mask.new_ones((attention_mask.size(0), 1))],
            dim=1,
        )

        # Special tokens are matched by id, never shown to the user.
        if token_id == THINK_ID:
            if not thinking_mode:
                thinking_mode = True
                thinking_ids = []
            continue

        if token_id == EOT_ID:
            if not thinking_mode:
                break  # end of the answer
            # End of thought: emit the whole reasoning block at once.
            thinking_mode = False
            thought = tokenizer.decode(thinking_ids, skip_special_tokens=False).strip()
            if thought:
                thinking_block = _format_thinking(thought)
                yield thinking_block
            continue

        if thinking_mode:
            thinking_ids.append(token_id)
            continue

        # Decode the accumulated ids each step: byte-level BPE can split one
        # character across tokens, so per-token decoding produces broken text.
        answer_ids.append(token_id)
        generated_text = tokenizer.decode(answer_ids, skip_special_tokens=False)
        if thinking_block:
            yield thinking_block + "\n" + generated_text
        else:
            yield generated_text

    if thinking_mode:
        # Token budget ran out mid-thought: show what was produced, not nothing.
        thought = tokenizer.decode(thinking_ids, skip_special_tokens=False).strip()
        if thought:
            yield _format_thinking(thought)


# Custom CSS for the chat message bubbles
CSS = """
.gradio-container .message-bubble {
    border-radius: 12px !important;
}
.gradio-container .message-bubble.user {
    background-color: #1f7bff !important;
    color: white !important;
}
.gradio-container .message-bubble.assistant {
    background-color: #e9ecef !important;
    color: #212529 !important;
}
"""

DESCRIPTION = (
    "Sam-3 doesn’t just answer — it **thinks first**. "
    "Watch its internal reasoning unfold in real time — step by step, clearly shown. "
    "No guessing. No fluff. Just pure deduction.\n\n"
    "Try asking:\n\n"
    "→ “Why does a mirror reverse left and right but not up and down?”\n\n"
    "→ “If I have 3 apples and give away half, then buy 5 more, how many do I have?”\n\n"
    "→ “Explain quantum entanglement like I’m 10.”\n\n"
    "→ “What’s wrong with this argument: ‘All birds fly; penguins are birds; "
    "therefore penguins can fly’?”\n"
)

# Gradio Interface. `demo` is the interface object itself (not the return value
# of launch()), so it can be imported or used by Gradio's reload mode.
demo = gr.ChatInterface(
    fn=predict,
    title="🌟 Sam-3: The Reasoning AI",
    description=DESCRIPTION,
    theme=gr.themes.Soft(primary_hue="indigo", secondary_hue="blue"),
    chatbot=gr.Chatbot(
        label="Sam-3 🤔",
        bubble_full_width=False,
        height=600,
        # Order is (user avatar, bot avatar).
        avatar_images=(
            "https://huggingface.co/datasets/huggingface/branding/resolve/main/avatar-user.jpg",
            "https://huggingface.co/datasets/huggingface/branding/resolve/main/avatar-bot.jpg",
        ),
    ),
    examples=[
        "What is the capital of France?",
        "Explain why the sky is blue.",
        "If a train leaves at 2 PM going 60 mph, and another leaves 30 minutes later at 80 mph, when does the second catch up?",
        "What are the ethical implications of AI making medical diagnoses?",
    ],
    css=CSS,
    cache_examples=False,
)

if __name__ == "__main__":
    demo.launch(show_api=True)