Spaces:
Runtime error
Runtime error
| # ------------------------------- | |
| # app.py — Sam-3.5: The Reasoning AI (Updated Architecture) | |
| # ------------------------------- | |
| import math | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from pathlib import Path | |
| from safetensors.torch import load_file | |
| from transformers import AutoTokenizer | |
| from dataclasses import dataclass | |
| from typing import Dict, List | |
| import gradio as gr | |
| import os | |
| from huggingface_hub import hf_hub_download | |
| import json | |
| # ------------------------------- | |
| # 1) Configuration & Special Tokens | |
| # ------------------------------- | |
# Run on GPU when available; every tensor below is created on / moved to this device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Chat-format control tokens registered on top of the stock GPT-2 vocabulary.
SPECIAL_TOKENS = {
    "bos": "<|bos|>",
    "eot": "<|eot|>",
    "user": "<|user|>",
    "assistant": "<|assistant|>",
    "system": "<|system|>",
    "think": "<|think|>",  # Keep this for reasoning display
}

tokenizer = AutoTokenizer.from_pretrained("gpt2")
if tokenizer.pad_token is None:
    # GPT-2 has no pad token by default; reuse EOS so padding calls work.
    tokenizer.pad_token = tokenizer.eos_token
tokenizer.add_special_tokens({"additional_special_tokens": list(SPECIAL_TOKENS.values())})

SPECIAL_IDS = {k: tokenizer.convert_tokens_to_ids(v) for k, v in SPECIAL_TOKENS.items()}

# convert_tokens_to_ids() does not return None for an unknown token: it falls
# back to unk_token_id. An `is None` check can therefore never detect a failed
# registration, and an `assert` is stripped entirely under `python -O`.
# Validate every special token explicitly instead.
_unregistered_tokens = [
    tok
    for key, tok in SPECIAL_TOKENS.items()
    if SPECIAL_IDS[key] is None or SPECIAL_IDS[key] == tokenizer.unk_token_id
]
if _unregistered_tokens:
    raise RuntimeError(f"Tokenizer is missing special tokens: {_unregistered_tokens}")

EOT_ID = SPECIAL_IDS["eot"]
THINK_ID = SPECIAL_IDS["think"]
# Maximum number of prompt tokens fed to the model.
MAX_LENGTH = 1024
| # ------------------------------- | |
| # 2) Model Architecture (Sam-3.5) | |
| # ------------------------------- | |
class RMSNorm(nn.Module):
    """Root-mean-square normalisation with a learned per-feature scale.

    Unlike LayerNorm there is no mean-centring and no bias: the last dimension
    is divided by sqrt(mean(x**2) + eps) and multiplied by ``weight``.

    Args:
        d: size of the normalised (last) dimension.
        eps: small constant added inside the square root for stability.
    """

    def __init__(self, d, eps=1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(d))

    def forward(self, x):
        inv_rms = torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        return self.weight * x * inv_rms
class MHA(nn.Module):
    """Multi-head causal self-attention built on ``F.scaled_dot_product_attention``.

    Args:
        d_model: model width; must be divisible by ``n_heads``.
        n_heads: number of attention heads.
        dropout: attention dropout probability, applied only in training mode.

    Raises:
        ValueError: if ``d_model`` is not divisible by ``n_heads``.

    Note: ``forward`` accepts ``attn_mask`` for interface compatibility but does
    not use it -- attention is always causal (``is_causal=True``) and no padding
    mask is applied.
    """

    def __init__(self, d_model, n_heads, dropout=0.0):
        super().__init__()
        if d_model % n_heads:
            raise ValueError("d_model must be divisible by n_heads")
        self.n_heads = n_heads
        self.head_dim = d_model // n_heads
        # Attribute names become state_dict keys, so they must stay as-is for
        # checkpoint loading.
        self.q_proj = nn.Linear(d_model, d_model, bias=False)
        self.k_proj = nn.Linear(d_model, d_model, bias=False)
        self.v_proj = nn.Linear(d_model, d_model, bias=False)
        self.out_proj = nn.Linear(d_model, d_model, bias=False)
        self.dropout = nn.Dropout(dropout)

    def _heads(self, t, batch, seq):
        """Reshape (batch, seq, d_model) into (batch, n_heads, seq, head_dim)."""
        return t.view(batch, seq, self.n_heads, self.head_dim).transpose(1, 2)

    def forward(self, x, attn_mask=None):
        batch, seq, width = x.shape
        q, k, v = (
            self._heads(proj(x), batch, seq)
            for proj in (self.q_proj, self.k_proj, self.v_proj)
        )
        drop_p = self.dropout.p if self.training else 0.0
        attended = F.scaled_dot_product_attention(q, k, v, is_causal=True, dropout_p=drop_p)
        # Merge heads back into (batch, seq, d_model) before the output projection.
        merged = attended.transpose(1, 2).contiguous().view(batch, seq, width)
        return self.out_proj(merged)
class SwiGLU(nn.Module):
    """Gated feed-forward layer: ``w3(dropout(silu(w1(x)) * w2(x)))``.

    Args:
        d_model: input/output width.
        d_ff: hidden width of the gated projection.
        dropout: dropout applied to the gated hidden activations.
    """

    def __init__(self, d_model, d_ff, dropout=0.0):
        super().__init__()
        self.w1 = nn.Linear(d_model, d_ff, bias=False)  # gate branch (goes through SiLU)
        self.w2 = nn.Linear(d_model, d_ff, bias=False)  # linear branch
        self.w3 = nn.Linear(d_ff, d_model, bias=False)  # projection back to d_model
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        gate = F.silu(self.w1(x))
        hidden = gate * self.w2(x)
        return self.w3(self.dropout(hidden))
class Block(nn.Module):
    """Pre-norm transformer layer: causal attention then SwiGLU, each residual.

    Computes ``x + drop(attn(norm1(x)))`` followed by ``x + drop(ff(norm2(x)))``.

    Args:
        d_model: model width.
        n_heads: attention heads (must divide ``d_model``).
        ff_mult: feed-forward hidden width as a multiple of ``d_model``.
        dropout: dropout probability used by every sub-layer.
    """

    def __init__(self, d_model, n_heads, ff_mult, dropout=0.0):
        super().__init__()
        self.norm1 = RMSNorm(d_model)
        self.attn = MHA(d_model, n_heads, dropout=dropout)
        self.norm2 = RMSNorm(d_model)
        hidden_width = int(ff_mult * d_model)
        self.ff = SwiGLU(d_model, hidden_width, dropout=dropout)
        self.drop = nn.Dropout(dropout)

    def forward(self, x, attn_mask=None):
        attn_out = self.attn(self.norm1(x), attn_mask=attn_mask)
        x = x + self.drop(attn_out)
        ff_out = self.ff(self.norm2(x))
        return x + self.drop(ff_out)
@dataclass
class Sam3Config:
    """Hyper-parameters for ``Sam3``; field names match the keys read from config.json.

    The ``@dataclass`` decorator is required: without it the class has no
    ``__init__`` accepting these fields, so ``Sam3Config(**config_dict)`` fails
    with ``TypeError: Sam3Config() takes no arguments``.
    """

    vocab_size: int
    d_model: int = 468
    n_layers: int = 14
    n_heads: int = 6
    ff_mult: float = 4.0
    dropout: float = 0.1
    input_modality: str = "text"
    head_type: str = "causal_lm"
    version: str = "0.1"
class Sam3(nn.Module):
    """Decoder-only LM: token embedding -> ``n_layers`` Blocks -> RMSNorm -> tied LM head.

    The forward pass only embeds token ids; no positional encoding is added
    here. ``lm_head`` shares its weight Parameter with ``embed``, so loading
    ``embed.weight`` also sets the output projection.
    """

    def __init__(self, config: Sam3Config):
        super().__init__()
        self.config = config
        width, vocab = config.d_model, config.vocab_size
        self.embed = nn.Embedding(vocab, width)
        self.blocks = nn.ModuleList(
            [
                Block(width, config.n_heads, config.ff_mult, dropout=config.dropout)
                for _ in range(config.n_layers)
            ]
        )
        self.norm = RMSNorm(width)
        self.lm_head = nn.Linear(width, vocab, bias=False)
        self.lm_head.weight = self.embed.weight  # Weight tying

    def forward(self, input_ids, attention_mask=None):
        """Return logits of shape (batch, seq, vocab_size) for ``input_ids``."""
        hidden = self.embed(input_ids)
        for layer in self.blocks:
            hidden = layer(hidden, attn_mask=attention_mask)
        return self.lm_head(self.norm(hidden))
| # ------------------------------- | |
| # 3) Load Model from Hugging Face Hub | |
| # ------------------------------- | |
def load_sam3_model_from_hf(repo_id: str, filename: str = "sam3-epoch1-best.safetensors"):
    """Download ``config.json`` and a safetensors checkpoint from the Hub and build ``Sam3``.

    Args:
        repo_id: Hugging Face Hub repository to download from.
        filename: safetensors weights file inside that repository.

    Returns:
        A ``Sam3`` model on the module-level ``device``, in eval mode.

    ``vocab_size`` is always overridden with ``len(tokenizer)`` so the embedding
    matches the tokenizer after the chat special tokens were added. Config keys
    that ``Sam3Config`` does not declare are ignored (and reported). Missing or
    unexpected checkpoint weights are reported instead of being silently dropped.
    """
    print(f"π₯ Loading config and weights from: {repo_id}")
    config_path = hf_hub_download(repo_id=repo_id, filename="config.json")
    weights_path = hf_hub_download(repo_id=repo_id, filename=filename)
    with open(config_path, "r", encoding="utf-8") as f:
        config_dict = json.load(f)
    # Ensure vocab_size matches tokenizer after adding special tokens
    config_dict["vocab_size"] = len(tokenizer)

    # Passing a keyword Sam3Config does not declare raises TypeError, so keep
    # only declared fields and report anything dropped.
    known_fields = set(getattr(Sam3Config, "__annotations__", {}))
    extra_keys = sorted(set(config_dict) - known_fields)
    if extra_keys:
        print(f"Ignoring config.json keys not used by Sam3Config: {extra_keys}")
    config = Sam3Config(**{k: v for k, v in config_dict.items() if k in known_fields})

    model = Sam3(config).to(device)
    state_dict = load_file(weights_path)
    # strict=False is kept (presumably so a checkpoint that omits the tied
    # lm_head weight still loads). However, a silent mismatch would leave layers
    # randomly initialised, so report what did not line up.
    result = model.load_state_dict(state_dict, strict=False)
    missing = [
        key
        for key in result.missing_keys
        # lm_head.weight is the same Parameter as embed.weight (weight tying).
        if not (key == "lm_head.weight" and "embed.weight" in state_dict)
    ]
    if missing:
        print(f"Warning: checkpoint is missing weights for: {missing}")
    if result.unexpected_keys:
        print(f"Warning: checkpoint contains unused weights: {result.unexpected_keys}")
    model.eval()
    print(f"β Model loaded successfully from Hugging Face Hub: {repo_id}")
    return model
# Load the Sam-3.5 checkpoint once at import time; `predict` uses this global.
MODEL_REPO_ID = "Smilyai-labs/Sam-3.5-1"
model = load_sam3_model_from_hf(repo_id=MODEL_REPO_ID)
| # ------------------------------- | |
| # 4) Sampling Function (Enhanced from your original) | |
| # ------------------------------- | |
def sample_next_token(
    logits,
    past_tokens,
    temperature=0.8,
    top_k=60,
    top_p=0.9,
    repetition_penalty=1.1,
    max_repeat=5,
    no_repeat_ngram_size=3
):
    """Choose the next token id from model logits using penalties and top-k/top-p filtering.

    Args:
        logits: ``(B, T, V)`` model output (only the last position is used) or
            ``(B, V)`` next-token logits. The input tensor is not modified.
        past_tokens: 1-D tensor or sequence of previously seen token ids. The
            same history is applied to every batch row.
        temperature: softmax temperature. Values ``<= 0`` select greedy
            (argmax) decoding instead of dividing by zero.
        top_k: keep only the ``top_k`` highest logits (``None`` or ``0`` disables).
        top_p: nucleus threshold in ``(0, 1)``; any other value disables it.
        repetition_penalty: penalty for every token already in ``past_tokens``.
            Positive logits are divided by it and negative logits multiplied,
            so a repeated token always becomes less likely.
        max_repeat: ban the last token once it has occurred this many times in a row.
        no_repeat_ngram_size: ban any token that would recreate an n-gram of this
            size already present in ``past_tokens`` (``0`` disables).

    Returns:
        ``(B, 1)`` int64 tensor of chosen ids, on the same device as ``logits``.
    """
    neg_inf = -float("inf")
    logits = (logits[:, -1, :] if logits.dim() == 3 else logits).clone()
    batch_size, vocab_size = logits.size(0), logits.size(1)
    # Unfiltered copy, used as a fallback if every token in a row gets banned.
    orig_logits = logits.clone()

    greedy = temperature <= 0
    if not greedy and temperature != 1.0:
        logits = logits / float(temperature)

    past_list = past_tokens.tolist() if isinstance(past_tokens, torch.Tensor) else list(past_tokens)

    # Repetition penalty. Dividing a negative logit would *raise* it, so
    # negative logits are multiplied instead.
    seen = sorted({tok for tok in past_list if 0 <= tok < vocab_size})
    if seen:
        idx = torch.tensor(seen, dtype=torch.long, device=logits.device)
        vals = logits[:, idx]
        logits[:, idx] = torch.where(vals > 0, vals / repetition_penalty, vals * repetition_penalty)

    # Ban the last token once it has repeated max_repeat times in a row.
    if past_list and len(past_list) >= max_repeat:
        last_token = past_list[-1]
        run = 1
        for tok in reversed(past_list[:-1]):
            if tok != last_token:
                break
            run += 1
        if run >= max_repeat and 0 <= last_token < vocab_size:
            logits[:, last_token] = neg_inf

    # No-repeat n-gram: any earlier n-gram whose first n-1 tokens equal the
    # current suffix bans its final token. This is O(len(past) * n) rather than
    # a scan over the whole vocabulary for every n-gram.
    n = no_repeat_ngram_size
    if n > 0 and len(past_list) >= n:
        # Explicit start index so that n == 1 gives an empty prefix (a `-0`
        # slice would return the whole list instead).
        prefix = tuple(past_list[len(past_list) - (n - 1):])
        banned = {
            past_list[i + n - 1]
            for i in range(len(past_list) - n + 1)
            if tuple(past_list[i:i + n - 1]) == prefix
        }
        for tok in banned:
            if 0 <= tok < vocab_size:
                logits[:, tok] = neg_inf

    if top_k is not None and top_k > 0:
        k = min(max(1, int(top_k)), vocab_size)
        kth_best = torch.topk(logits, k, dim=-1).values[:, -1:]
        logits[logits < kth_best] = neg_inf

    if top_p is not None and 0.0 < top_p < 1.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
        cumulative = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
        # Remove a token only if the tokens ranked above it already exceed
        # top_p. The token that crosses the threshold is kept (standard nucleus
        # rule), and the best token is never removed.
        sorted_remove = torch.zeros_like(cumulative, dtype=torch.bool)
        sorted_remove[:, 1:] = cumulative[:, :-1] > top_p
        remove = sorted_remove.scatter(1, sorted_indices, sorted_remove)
        logits = logits.masked_fill(remove, neg_inf)

    # If the filters banned everything in a row, fall back to the raw logits.
    for b in range(batch_size):
        if torch.isneginf(logits[b]).all():
            logits[b] = orig_logits[b]

    if greedy:
        next_token = torch.argmax(logits, dim=-1, keepdim=True)
    else:
        probs = F.softmax(logits, dim=-1)
        if torch.isnan(probs).any():
            probs = torch.ones_like(logits) / logits.size(1)
        next_token = torch.multinomial(probs, num_samples=1)
    return next_token.to(logits.device)
| # ------------------------------- | |
| # 5) Gradio Chat Interface — WITH STYLED THINKING STEPS | |
| # ------------------------------- | |
# Opening tag of the styled reasoning panel. It is also used to strip panels
# from earlier assistant turns before they are fed back into the prompt.
_THINK_PANEL_OPEN = (
    "<div style='background-color:#f8f9fa; padding:12px; border-left:4px solid #007bff; "
    "border-radius:0 8px 8px 0; margin:10px 0; font-style:italic; color:#495057; font-size:0.95em;'>"
)
_MAX_NEW_TOKENS = 256


def _strip_thinking_panel(text):
    """Remove rendered thinking panels from a previous assistant reply."""
    import re

    pattern = re.escape(_THINK_PANEL_OPEN) + r".*?</div>"
    return re.sub(pattern, "", text, flags=re.DOTALL).strip()


def _render_reply(thinking_ids, answer_ids):
    """Build the displayed message: optional reasoning panel, then the answer so far.

    Whole id lists are decoded (not token by token) so that byte-level BPE
    pieces of multi-byte characters are reassembled correctly.
    """
    import html

    reply = tokenizer.decode(answer_ids, skip_special_tokens=True)
    thinking = tokenizer.decode(thinking_ids, skip_special_tokens=True).strip()
    if not thinking:
        return reply
    # NOTE(review): the panel icon was mis-encoded in the source; 💡 is a reconstruction.
    panel = f"{_THINK_PANEL_OPEN}💡 <strong>Thinking:</strong> {html.escape(thinking)}</div>"
    return panel + reply


def _build_prompt(message, history):
    """Serialise the system prompt, prior turns and the new message into Sam's chat format.

    Accepts Gradio "tuples" history (``[user, assistant]`` pairs) and, as an
    assumption to confirm against the installed Gradio, "messages" history
    (dicts with ``role``/``content``).
    """
    user_tok, asst_tok, eot_tok = SPECIAL_TOKENS["user"], SPECIAL_TOKENS["assistant"], SPECIAL_TOKENS["eot"]
    turns = []
    for item in history or []:
        if isinstance(item, dict):
            pairs = [(item.get("role"), item.get("content"))]
        else:
            human, reply = item
            pairs = [("user", human), ("assistant", reply)]
        for role, content in pairs:
            if role == "user":
                turns.append(f"{user_tok} {content} {eot_tok}")
            elif role == "assistant" and content:
                # Do not feed our own HTML reasoning panel back to the model.
                if isinstance(content, str):
                    content = _strip_thinking_panel(content)
                if content:
                    turns.append(f"{asst_tok} {content} {eot_tok}")
    turns.append(f"{user_tok} {message} {eot_tok}")
    system_prompt = "You are Sam-3.5, an advanced reasoning AI. You think step-by-step, analyze deeply, and respond with precision. You do not guess — you deduce. Avoid medical or legal advice."
    return (
        f"{SPECIAL_TOKENS['system']} {system_prompt} {eot_tok}\n"
        + "\n".join(turns)
        + f"\n{asst_tok} {SPECIAL_TOKENS['think']}"
    )


def predict(message, history):
    """Stream Sam-3.5's reply to ``message`` given the Gradio chat ``history``.

    Each yield is the full message so far, because Gradio replaces the displayed
    message on every yield. The message consists of the styled reasoning panel
    (once reasoning has been produced) followed by the answer streamed so far.
    Generation stops at the ``<|eot|>`` that closes the answer or after
    ``_MAX_NEW_TOKENS`` tokens.
    """
    prompt_ids = tokenizer.encode(_build_prompt(message, history))
    # Keep the END of an over-long prompt. Right-truncation would cut off the
    # newest user message and the trailing assistant/think cue.
    input_ids = torch.tensor([prompt_ids[-MAX_LENGTH:]], dtype=torch.long, device=device)

    # The prompt already ends with <|think|>, so generation starts inside the
    # reasoning block; its <|eot|> closes the reasoning and the next <|eot|>
    # ends the answer.
    thinking_mode = True
    thinking_ids, answer_ids = [], []
    for _ in range(_MAX_NEW_TOKENS):
        # Attention in this model is causal (is_causal=True) and ignores masks,
        # and a single unpadded sequence needs none, so no attention_mask is passed.
        with torch.no_grad():
            logits = model(input_ids[:, -MAX_LENGTH:])
        next_token = sample_next_token(
            logits,
            input_ids[0],
            temperature=0.4,
            top_k=50,
            top_p=0.9,
            repetition_penalty=1.1,
        )
        token_id = int(next_token[0, 0])
        input_ids = torch.cat([input_ids, next_token], dim=1)

        if token_id == THINK_ID:
            thinking_mode = True
            continue
        if token_id == EOT_ID:
            if thinking_mode:
                thinking_mode = False  # reasoning finished; the answer follows
                continue
            break  # end of the assistant turn; the token itself is not displayed
        (thinking_ids if thinking_mode else answer_ids).append(token_id)
        yield _render_reply(thinking_ids, answer_ids)
    # Always emit the final state (also covers replies with no visible tokens).
    yield _render_reply(thinking_ids, answer_ids)
| # ------------------------------- | |
| # 6) Launch Gradio Interface | |
| # ------------------------------- | |
# Chat bubble styling passed to gr.ChatInterface(css=...).
CSS = """
.gradio-container .message-bubble {
    border-radius: 12px !important;
    padding: 10px 14px !important;
    font-size: 16px !important;
}
.gradio-container .message-bubble.user {
    background-color: #007bff !important;
    color: white !important;
}
.gradio-container .message-bubble.assistant {
    background-color: #f8f9fa !important;
    color: #212529 !important;
    border: 1px solid #e9ecef;
}
"""

# User-facing text below had mis-encoded punctuation/emoji in the source;
# it is restored here. Where the original glyph could not be recovered
# unambiguously (the chatbot label), it is omitted.
# The description is Markdown: items are on their own lines as a "- " list
# so they are not merged into one paragraph.
demo = gr.ChatInterface(
    fn=predict,
    title="🧠 Sam-3.5: The Reasoning AI",
    description="""
Sam-3.5 doesn’t just answer — it **thinks first**.
Watch its internal reasoning unfold in real time — step by step, clearly shown.
No guessing. No fluff. Just pure deduction.

Try asking:

- “Why does a mirror reverse left and right but not up and down?”
- “If I have 3 apples and give away half, then buy 5 more, how many do I have?”
- “Explain quantum entanglement like I’m 10.”
- “What’s wrong with this argument: ‘All birds fly; penguins are birds; therefore penguins can fly’?”
""",
    theme=gr.themes.Soft(
        primary_hue="indigo",
        secondary_hue="blue",
    ),
    chatbot=gr.Chatbot(
        label="Sam-3.5",
        # NOTE(review): bubble_full_width is deprecated in newer Gradio releases;
        # confirm against the pinned Gradio version.
        bubble_full_width=False,
        height=600,
        avatar_images=(
            "https://huggingface.co/datasets/huggingface/branding/resolve/main/avatar-bot.jpg",
            "https://huggingface.co/datasets/huggingface/branding/resolve/main/avatar-user.jpg",
        ),
    ),
    examples=[
        "What is the capital of France?",
        "Explain why the sky is blue.",
        "If a train leaves at 2 PM going 60 mph, and another leaves 30 minutes later at 80 mph, when does the second catch up?",
        "What are the ethical implications of AI making medical diagnoses?",
    ],
    css=CSS,
    cache_examples=False,
)

if __name__ == "__main__":
    demo.launch(show_api=True)