Spaces:
Runtime error
Runtime error
| # ------------------------------- | |
| # app.py | |
| # | |
| # Sam-3: The Reasoning AI — Now Showing Its Thought Process! | |
| # Powered by Smilyai-labs/Sam-3.0-3. Trained to think before speaking. | |
| # ------------------------------- | |
| import math | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from pathlib import Path | |
| from safetensors.torch import load_file, safe_open | |
| from transformers import AutoTokenizer | |
| from dataclasses import dataclass | |
| import gradio as gr | |
| import os | |
| from huggingface_hub import hf_hub_download | |
| # ------------------------------- | |
| # 1) Sam-3.0-3 Architecture | |
| # ------------------------------- | |
class Sam3Config:
    """Hyper-parameters for the Sam-3 decoder-only language model.

    The class attributes mirror the constructor defaults. Any extra keyword
    arguments passed to the constructor are accepted and silently ignored.
    """

    vocab_size: int = 50257
    d_model: int = 384
    n_layers: int = 10
    n_heads: int = 6
    ff_mult: float = 4.0
    dropout: float = 0.1
    input_modality: str = "text"
    head_type: str = "causal_lm"
    version: str = "0.1"

    def __init__(
        self,
        vocab_size=50257,
        d_model=384,
        n_layers=10,
        n_heads=6,
        ff_mult=4.0,
        dropout=0.1,
        input_modality="text",
        head_type="causal_lm",
        version="0.1",
        **kwargs,
    ):
        # Network shape.
        self.vocab_size, self.d_model = vocab_size, d_model
        self.n_layers, self.n_heads = n_layers, n_heads
        self.ff_mult, self.dropout = ff_mult, dropout
        # Descriptive metadata (not used by the forward pass in this file).
        self.input_modality, self.head_type, self.version = input_modality, head_type, version
class RMSNorm(nn.Module):
    """Root-mean-square normalization over the last dimension with a learned gain.

    Unlike LayerNorm there is no mean-centering and no bias.
    """

    def __init__(self, d, eps=1e-6):
        super().__init__()
        self.eps = eps
        # Per-feature gain; its name is part of the checkpoint's state-dict keys.
        self.weight = nn.Parameter(torch.ones(d))

    def forward(self, x):
        # 1 / sqrt(mean(x^2) + eps), kept broadcastable against x.
        inv_rms = torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        return self.weight * x * inv_rms
class MHA(nn.Module):
    """Multi-head causal self-attention with an optional key-padding mask.

    The submodule names (q_proj, k_proj, v_proj, out_proj) are part of the
    checkpoint's state-dict keys and must not be renamed.
    """

    def __init__(self, d_model, n_heads, dropout=0.0):
        super().__init__()
        # Raise instead of assert: asserts are stripped under `python -O`.
        if d_model % n_heads != 0:
            raise ValueError(f"d_model ({d_model}) must be divisible by n_heads ({n_heads})")
        self.n_heads = n_heads
        self.head_dim = d_model // n_heads
        self.q_proj = nn.Linear(d_model, d_model, bias=False)
        self.k_proj = nn.Linear(d_model, d_model, bias=False)
        self.v_proj = nn.Linear(d_model, d_model, bias=False)
        self.out_proj = nn.Linear(d_model, d_model, bias=False)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, attn_mask=None):
        """Attend each position to itself and earlier positions.

        Args:
            x: ``(B, T, C)`` input (the ``view`` calls below fix this rank).
            attn_mask: optional ``(B, T)`` key mask; nonzero/True keeps a key,
                zero/False masks it (e.g. padding).

        Returns:
            ``(B, T, C)`` tensor.
        """
        B, T, C = x.shape
        q = self.q_proj(x).view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
        k = self.k_proj(x).view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
        v = self.v_proj(x).view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)

        # True = "may not attend": future keys, plus masked keys if a mask is given.
        blocked = torch.triu(torch.ones(T, T, device=x.device, dtype=torch.bool), diagonal=1)
        if attn_mask is not None:
            blocked = blocked | ~attn_mask.bool()[:, None, None, :]  # -> (B, 1, T, T)
        scores = scores.masked_fill(blocked, float("-inf"))
        attn = torch.softmax(scores, dim=-1)
        # A query whose every key is blocked (e.g. a left-padded position) gets an
        # all -inf row, which softmax turns into NaN; that NaN would then leak into
        # every position of the next layer through `attn @ v` (0 * NaN = NaN).
        # Give such rows zero attention instead.
        attn = attn.masked_fill(blocked.all(dim=-1, keepdim=True), 0.0)

        out = torch.matmul(self.dropout(attn), v).transpose(1, 2).contiguous().view(B, T, C)
        return self.out_proj(out)
class SwiGLU(nn.Module):
    """Gated feed-forward layer: ``w3(dropout(silu(w1(x)) * w2(x)))``."""

    def __init__(self, d_model, d_ff, dropout=0.0):
        super().__init__()
        self.w1 = nn.Linear(d_model, d_ff, bias=False)  # gate branch
        self.w2 = nn.Linear(d_model, d_ff, bias=False)  # value branch
        self.w3 = nn.Linear(d_ff, d_model, bias=False)  # projection back to d_model
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        gate = F.silu(self.w1(x))
        hidden = gate * self.w2(x)
        return self.w3(self.dropout(hidden))
class Block(nn.Module):
    """Pre-norm transformer layer: residual attention, then residual SwiGLU feed-forward."""

    def __init__(self, d_model, n_heads, ff_mult, dropout=0.0):
        super().__init__()
        hidden_dim = int(ff_mult * d_model)
        # Submodule names are part of the checkpoint's state-dict keys.
        self.norm1 = RMSNorm(d_model)
        self.attn = MHA(d_model, n_heads, dropout=dropout)
        self.norm2 = RMSNorm(d_model)
        self.ff = SwiGLU(d_model, hidden_dim, dropout=dropout)
        self.drop = nn.Dropout(dropout)

    def forward(self, x, attn_mask=None):
        attended = self.attn(self.norm1(x), attn_mask=attn_mask)
        x = x + self.drop(attended)
        fed_forward = self.ff(self.norm2(x))
        return x + self.drop(fed_forward)
class Sam3(nn.Module):
    """Decoder-only language model: embedding -> Blocks -> RMSNorm -> tied LM head.

    No positional embedding is added to the token embeddings; position
    information reaches the model only through the causal attention mask.
    """

    def __init__(self, config: Sam3Config):
        super().__init__()
        self.config = config
        d_model, n_heads = config.d_model, config.n_heads
        self.embed = nn.Embedding(config.vocab_size, d_model)
        self.blocks = nn.ModuleList(
            Block(d_model, n_heads, config.ff_mult, dropout=config.dropout)
            for _ in range(config.n_layers)
        )
        self.norm = RMSNorm(d_model)
        self.lm_head = nn.Linear(d_model, config.vocab_size, bias=False)
        # Weight tying: the output projection shares the embedding matrix.
        self.lm_head.weight = self.embed.weight

    def forward(self, input_ids, attention_mask=None):
        """Return next-token logits of shape ``(batch, seq, vocab_size)``."""
        hidden = self.embed(input_ids)
        for block in self.blocks:
            hidden = block(hidden, attn_mask=attention_mask)
        return self.lm_head(self.norm(hidden))
| # ------------------------------- | |
| # 2) Load Tokenizer & Special Tokens | |
| # ------------------------------- | |
# Chat control tokens added on top of the GPT-2 vocabulary.
SPECIAL_TOKENS = dict(
    bos="<|bos|>",
    eot="<|eot|>",
    user="<|user|>",
    assistant="<|assistant|>",
    system="<|system|>",
    think="<|think|>",
)

tokenizer = AutoTokenizer.from_pretrained("gpt2")
if tokenizer.pad_token is None:
    # No pad token configured: reuse EOS so padding calls do not fail.
    tokenizer.pad_token = tokenizer.eos_token
tokenizer.add_special_tokens({"additional_special_tokens": [*SPECIAL_TOKENS.values()]})

# Ids used by the generation loop to detect end-of-turn and reasoning markers.
EOT_ID = tokenizer.convert_tokens_to_ids(SPECIAL_TOKENS["eot"]) or tokenizer.eos_token_id
THINK_ID = tokenizer.convert_tokens_to_ids(SPECIAL_TOKENS["think"])

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
| # ------------------------------- | |
| # 3) Download Model Weights from Hugging Face Hub | |
| # ------------------------------- | |
hf_repo = "Smilyai-labs/Sam-3.0-3"
weights_filename = "model.safetensors"
print(f"Loading model '{hf_repo}' from Hugging Face Hub...")
try:
    weights_path = hf_hub_download(repo_id=hf_repo, filename=weights_filename)
    print(f"✅ Downloaded weights to: {weights_path}")
    if not os.path.exists(weights_path):
        raise FileNotFoundError(f"Downloaded file not found at {weights_path}")
    file_size = os.path.getsize(weights_path)
    print(f"📄 File size: {file_size} bytes")
except Exception as e:
    # Chain the original exception so the Space logs keep the real cause
    # (authentication, 404, network, disk, ...) and its traceback.
    raise RuntimeError(f"❌ Failed to download model weights: {e}") from e
# Initialize model
cfg = Sam3Config(vocab_size=len(tokenizer))
model = Sam3(cfg).to(device)

# Load state dict safely
print("Loading state dict...")
try:
    with safe_open(weights_path, framework="pt", device="cpu") as f:
        state_dict = {key: f.get_tensor(key) for key in f.keys()}
    print("✅ Loaded via safe_open")
except Exception as e:
    print(f"⚠️ safe_open failed: {e}. Falling back to torch.load...")
    try:
        # The file was downloaded from the network: weights_only=True refuses to
        # unpickle arbitrary objects (plain torch.load would execute pickled code).
        state_dict = torch.load(weights_path, map_location="cpu", weights_only=True)
        print("✅ Loaded via torch.load")
    except Exception as torch_e:
        raise RuntimeError(f"❌ Could not load model weights: {torch_e}") from torch_e

# Filter and load.
# load_state_dict(strict=False) only tolerates missing/unexpected *keys*; a shape
# mismatch still raises. So keep only tensors whose shape fits the model, except
# for row-count mismatches (vocab-sized tensors when the tokenizer's token count
# differs from the checkpoint's), where the overlapping rows are copied and the
# remaining rows keep their fresh initialization.
_TIED_KEYS = {"embed.weight", "lm_head.weight"}  # same Parameter in Sam3
model_state_dict = model.state_dict()
filtered_state_dict = {}
skipped_keys = {}
for key, tensor in state_dict.items():
    target = model_state_dict.get(key)
    if target is None:
        continue  # reported below as an extra key
    if tensor.shape == target.shape:
        filtered_state_dict[key] = tensor
    elif tensor.dim() == target.dim() >= 1 and tensor.shape[1:] == target.shape[1:]:
        rows = min(tensor.shape[0], target.shape[0])
        merged = target.detach().clone()
        merged[:rows] = tensor[:rows].to(device=merged.device, dtype=merged.dtype)
        filtered_state_dict[key] = merged
        print(f"⚠️ Resized {key}: checkpoint {tuple(tensor.shape)} -> model {tuple(target.shape)} (copied {rows} rows)")
    else:
        skipped_keys[key] = (tuple(tensor.shape), tuple(target.shape))

missing_keys = set(model_state_dict.keys()) - set(filtered_state_dict.keys())
if _TIED_KEYS & set(filtered_state_dict):
    # Loading either tied name fills both, so neither is really missing.
    missing_keys -= _TIED_KEYS
extra_keys = set(state_dict.keys()) - set(model_state_dict.keys())
if missing_keys:
    print(f"⚠️ Missing keys: {missing_keys}")
if extra_keys:
    print(f"⚠️ Extra keys: {extra_keys}")
if skipped_keys:
    print(f"⚠️ Skipped shape-mismatched keys (checkpoint, model): {skipped_keys}")
if not filtered_state_dict:
    # Nothing matched: serving randomly initialized weights would be meaningless.
    raise RuntimeError("❌ No checkpoint tensor matches the Sam-3 architecture.")
model.load_state_dict(filtered_state_dict, strict=False)
model.eval()
print("✅ Model loaded successfully!")
| # ------------------------------- | |
# 4) Sampling Function
| # ------------------------------- | |
def sample_next_token(
    logits,
    past_tokens,
    temperature=0.8,
    top_k=60,
    top_p=0.9,
    repetition_penalty=1.1,
    max_repeat=5,
    no_repeat_ngram_size=3,
):
    """Sample one next-token id per batch row.

    Args:
        logits: ``(batch, vocab)`` logits, or ``(batch, seq, vocab)`` model output,
            in which case only the last position is used.
        past_tokens: 1-D tensor or iterable of previously seen token ids. The same
            history is applied to every batch row.
        temperature: divisor applied to the logits before any filtering.
        top_k: keep only the ``top_k`` highest logits (``None`` or ``<= 0`` disables).
        top_p: nucleus threshold in ``(0, 1)`` (``None`` or out of range disables).
        repetition_penalty: penalty for every token already in ``past_tokens``.
            Positive logits are divided by it and negative logits multiplied, so
            the token always becomes less likely.
        max_repeat: ban the last token once it occurred this many times in a row.
        no_repeat_ngram_size: ban tokens that would recreate an n-gram of this size
            already present in ``past_tokens`` (``0`` disables).

    Returns:
        ``(batch, 1)`` long tensor of sampled ids, on the same device as ``logits``.
        If filtering removes every candidate of a row, that row falls back to its
        unfiltered logits.
    """
    logits = (logits[:, -1, :] if logits.dim() == 3 else logits).clone()
    vocab_size = logits.size(1)
    orig_logits = logits.clone()
    if temperature != 1.0:
        logits = logits / float(temperature)

    past_list = past_tokens.tolist() if isinstance(past_tokens, torch.Tensor) else list(past_tokens)

    # 1) Repetition penalty. Dividing a *negative* logit would raise it (making the
    #    repeat more likely), so negative logits are multiplied instead.
    seen = sorted({tok for tok in past_list if 0 <= tok < vocab_size})
    if seen:
        idx = torch.tensor(seen, dtype=torch.long, device=logits.device)
        picked = logits[:, idx]
        logits[:, idx] = torch.where(picked > 0, picked / repetition_penalty, picked * repetition_penalty)

    # 2) Ban the last token once it has repeated `max_repeat` times in a row.
    if past_list and len(past_list) >= max_repeat:
        last_token = past_list[-1]
        run = 0
        for tok in reversed(past_list):
            if tok != last_token:
                break
            run += 1
        if run >= max_repeat and 0 <= last_token < vocab_size:
            logits[:, last_token] = float("-inf")

    # 3) No-repeat n-gram: a token is banned if the last (n - 1) tokens followed by
    #    it already occur in the history. One pass over the history (O(len * n))
    #    instead of testing every vocabulary id against every n-gram.
    n = no_repeat_ngram_size
    if n > 0 and len(past_list) >= n:
        prefix = past_list[len(past_list) - n + 1:]
        banned = {
            past_list[i + n - 1]
            for i in range(len(past_list) - n + 1)
            if past_list[i:i + n - 1] == prefix
        }
        for tok in banned:
            if 0 <= tok < vocab_size:
                logits[:, tok] = float("-inf")

    # 4) Top-k filtering.
    if top_k is not None and top_k > 0:
        k = min(max(1, int(top_k)), vocab_size)
        kth_best = torch.topk(logits, k, dim=-1).values[:, -1:]
        logits = logits.masked_fill(logits < kth_best, float("-inf"))

    # 5) Top-p (nucleus) filtering; the single most likely token is always kept.
    if top_p is not None and 0.0 < top_p < 1.0:
        sorted_logits, sorted_idx = torch.sort(logits, descending=True, dim=-1)
        cumulative = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
        drop_sorted = cumulative > top_p
        drop_sorted[:, 0] = False
        drop = torch.zeros_like(drop_sorted).scatter(-1, sorted_idx, drop_sorted)
        logits = logits.masked_fill(drop, float("-inf"))

    # Rows where every candidate was filtered out fall back to their raw logits.
    dead_rows = torch.isneginf(logits).all(dim=-1)
    if dead_rows.any():
        logits[dead_rows] = orig_logits[dead_rows]

    probs = F.softmax(logits, dim=-1)
    if torch.isnan(probs).any():
        probs = torch.ones_like(logits) / logits.size(1)
    return torch.multinomial(probs, num_samples=1)
| # ------------------------------- | |
| # 5) Gradio Chat Interface — WITH STYLED THINKING STEPS | |
| # ------------------------------- | |
SPECIAL_TOKENS_CHAT = {"bos": "<|bos|>", "eot": "<|eot|>", "user": "<|user|>", "assistant": "<|assistant|>", "system": "<|system|>", "think": "<|think|>"}

# Inline CSS for the rendered reasoning block.
_THINK_STYLE = (
    "background-color:#f8f9fa; padding:12px; border-left:4px solid #ccc; "
    "border-radius:0 8px 8px 0; margin:10px 0; font-style:italic; color:#555;"
)


def _render_thought(text):
    """Wrap reasoning *text* (HTML-escaped) in the styled 'Thinking' block; '' if blank."""
    import html

    text = text.strip()
    if not text:
        return ""
    return f"<div style='{_THINK_STYLE}'>💡 Thinking: {html.escape(text)}</div>"


def _strip_thoughts(text):
    """Remove rendered thinking blocks from a previously displayed assistant message."""
    import re

    # Escaped thought text cannot contain a literal "</div>", so the lazy match
    # stops at the block's own closing tag.
    return re.sub(r"<div style='[^']*'>💡 Thinking:.*?</div>", "", text, flags=re.DOTALL).strip()


def _history_turns(history):
    """Yield ``(role, text)`` for each earlier turn, oldest first.

    Accepts Gradio's tuples format (``[[user, assistant], ...]``) and the messages
    format (``[{"role": ..., "content": ...}, ...]``); empty turns are skipped.
    """
    for item in history or []:
        if isinstance(item, dict):
            role, content = item.get("role"), item.get("content")
            if role in ("user", "assistant") and isinstance(content, str) and content:
                yield role, content
        else:
            human, assistant = item
            if human:
                yield "user", human
            if assistant:
                yield "assistant", assistant


def _build_prompt(message, history):
    """Assemble the chat prompt; it ends with ``<|assistant|> <|think|>``."""
    tags = SPECIAL_TOKENS_CHAT
    lines = []
    for role, text in _history_turns(history):
        if role == "assistant" and isinstance(text, str):
            # Feed back only the answer, not the HTML reasoning block shown in the UI.
            text = _strip_thoughts(text)
            if not text:
                continue
        lines.append(f"{tags[role]} {text} {tags['eot']}")
    lines.append(f"{tags['user']} {message} {tags['eot']}")
    system_prompt = "You are Sam-3, an advanced reasoning AI. You think step-by-step, analyze deeply, and respond with precision. You do not guess — you deduce. Avoid medical or legal advice."
    return f"{tags['system']} {system_prompt} {tags['eot']}\n" + "\n".join(lines) + f"\n{tags['assistant']} {tags['think']}"


def predict(message, history):
    """Stream Sam-3's reply: styled reasoning block(s) followed by the answer.

    The prompt already ends with ``<|think|>``, so generation starts in thinking
    mode. Tokens are reasoning until ``<|eot|>``; after that they form the answer,
    which ends at the next ``<|eot|>`` (never shown) or at GPT-2's EOS. A
    ``<|think|>`` emitted during the answer opens a new reasoning block.

    Yields:
        The full message rendered so far (finished segments + current segment).
    """
    enc = tokenizer(_build_prompt(message, history), return_tensors="pt").to(device)
    input_ids = enc["input_ids"]
    attention_mask = enc["attention_mask"]

    def decode(ids):
        # Decode the whole segment at once: byte-level BPE splits multi-byte
        # characters across tokens, so per-token decoding garbles them.
        return tokenizer.decode(ids, skip_special_tokens=False)

    def render(done, ids, is_thinking):
        text = decode(ids)
        return done + (_render_thought(text) if is_thinking else text)

    finished = ""     # rendered output of completed segments
    segment_ids = []  # token ids of the segment currently being generated
    thinking = True   # the prompt already opened a <|think|> block

    with torch.no_grad():
        for _ in range(256):
            logits = model(input_ids, attention_mask=attention_mask)
            next_token = sample_next_token(
                logits, input_ids[0], temperature=0.4, top_k=50, top_p=0.9, repetition_penalty=1.1
            ).to(input_ids.device)
            token_id = int(next_token.item())
            input_ids = torch.cat([input_ids, next_token], dim=1)
            attention_mask = torch.cat(
                [attention_mask, attention_mask.new_ones((attention_mask.size(0), 1))], dim=1
            )

            if token_id == tokenizer.eos_token_id and token_id != EOT_ID:
                break  # tokenizer EOS: hard stop in any mode
            if thinking and token_id == EOT_ID:
                # End of reasoning: freeze the thinking block and switch to the answer.
                finished += _render_thought(decode(segment_ids))
                segment_ids, thinking = [], False
            elif token_id == EOT_ID:
                break  # end of the answer; the raw <|eot|> is never shown
            elif not thinking and token_id == THINK_ID:
                # New reasoning block opened mid-answer.
                finished += decode(segment_ids)
                segment_ids, thinking = [], True
                continue
            else:
                segment_ids.append(token_id)
            yield render(finished, segment_ids, thinking)

    # Flush the in-progress segment (e.g. the token budget ran out mid-thought).
    yield render(finished, segment_ids, thinking)
# Custom CSS for the user / assistant chat bubbles.
# NOTE(review): `.message-bubble` class names depend on the installed Gradio
# version — confirm they still match the rendered DOM.
CSS = """
.gradio-container .message-bubble {
border-radius: 12px !important;
}
.gradio-container .message-bubble.user {
background-color: #1f7bff !important;
color: white !important;
}
.gradio-container .message-bubble.assistant {
background-color: #e9ecef !important;
color: #212529 !important;
}
"""

# Gradio Interface. `demo` is bound to the interface itself (the original bound
# it to the return value of `.launch()`), so tools that import `demo` work.
demo = gr.ChatInterface(
    fn=predict,
    title="🌟 Sam-3: The Reasoning AI",
    description="""
Sam-3 doesn’t just answer — it **thinks first**.
Watch its internal reasoning unfold in real time — step by step, clearly shown.
No guessing. No fluff. Just pure deduction.
Try asking:
→ “Why does a mirror reverse left and right but not up and down?”
→ “If I have 3 apples and give away half, then buy 5 more, how many do I have?”
→ “Explain quantum entanglement like I’m 10.”
→ “What’s wrong with this argument: ‘All birds fly; penguins are birds; therefore penguins can fly’?”
""",
    theme=gr.themes.Soft(
        primary_hue="indigo",
        secondary_hue="blue",
    ),
    # `bubble_full_width` was dropped: it is deprecated (no effect) in recent Gradio.
    chatbot=gr.Chatbot(
        label="Sam-3 🤔",
        height=600,
        # Gradio expects (user_avatar, bot_avatar) in that order.
        avatar_images=(
            "https://huggingface.co/datasets/huggingface/branding/resolve/main/avatar-user.jpg",
            "https://huggingface.co/datasets/huggingface/branding/resolve/main/avatar-bot.jpg",
        ),
    ),
    examples=[
        "What is the capital of France?",
        "Explain why the sky is blue.",
        "If a train leaves at 2 PM going 60 mph, and another leaves 30 minutes later at 80 mph, when does the second catch up?",
        "What are the ethical implications of AI making medical diagnoses?",
    ],
    css=CSS,
    cache_examples=False,
)

if __name__ == "__main__":
    demo.launch(show_api=True)