import gradio as gr
import jax
import jax.numpy as jnp
from jax import random
import flax.linen as nn
from tokenizers import Tokenizer
from safetensors.flax import load_file
import json
import os
from typing import Any, Optional
import numpy as np

# ==============================================================================
# MODEL ARCHITECTURE (from your training code)
# ==============================================================================

class RMSNorm(nn.Module):
    epsilon: float = 1e-5
    dtype: Any = jnp.bfloat16

    @nn.compact
    def __call__(self, x):
        x = x.astype(jnp.float32)
        scale = self.param('scale', nn.initializers.ones, (x.shape[-1],))
        variance = jnp.mean(jnp.square(x), axis=-1, keepdims=True)
        x = x * jax.lax.rsqrt(variance + self.epsilon) * scale
        return x.astype(self.dtype)


def precompute_yarn_freqs(dim: int, end: int, theta: float = 10000.0, scale: float = 1.0,
                          alpha: float = 1.0, beta: float = 32.0, dtype=jnp.bfloat16):
    freqs = 1.0 / (theta ** (jnp.arange(0, dim, 2, dtype=jnp.float32) / dim))

    if scale > 1.0:
        def yarn_find_correction_dim(num_rotations, dim, base=10000, max_position_embeddings=2048):
            return (dim * jnp.log(max_position_embeddings / (num_rotations * 2 * jnp.pi))) / (2 * jnp.log(base))

        def yarn_find_correction_range(low_rot, high_rot, dim, base=10000, max_position_embeddings=2048):
            low = jnp.floor(yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings))
            high = jnp.ceil(yarn_find_correction_dim(high_rot, dim, base, max_position_embeddings))
            return jnp.maximum(low, 0).astype(jnp.int32), jnp.minimum(high, dim - 1).astype(jnp.int32)

        def yarn_linear_ramp_mask(min_val, max_val, dim):
            if min_val == max_val:
                max_val += 0.001
            linear_func = (jnp.arange(dim, dtype=jnp.float32) - min_val) / (max_val - min_val)
            return jnp.clip(linear_func, 0, 1)

        low, high = yarn_find_correction_range(beta, alpha, dim, theta, int(end * scale))
        inv_freq_mask = 1.0 - yarn_linear_ramp_mask(low, high, dim // 2)
        freqs = freqs / ((1 - inv_freq_mask) * (scale - 1) + 1)

    t = jnp.arange(end, dtype=jnp.float32)
    freqs = jnp.outer(t, freqs)

    mscale = 1.0
    if scale > 1.0:
        mscale = 0.1 * 1.0 * jnp.log(scale) + 1.0

    cos = jnp.cos(freqs) * mscale
    sin = jnp.sin(freqs) * mscale
    return jnp.concatenate([cos, sin], axis=-1).astype(dtype), mscale


def apply_rotary_emb(xq, xk, freqs_cis, mscale=1.0):
    def rotate_half(x):
        x1, x2 = jnp.split(x, 2, axis=-1)
        return jnp.concatenate([-x2, x1], axis=-1)

    seq_len = xq.shape[2]
    head_dim = xq.shape[3]
    freqs = freqs_cis[:seq_len, :]
    half_dim = head_dim // 2
    cos = freqs[:, :half_dim]
    sin = freqs[:, half_dim:]
    cos = jnp.repeat(cos, 2, axis=-1)
    sin = jnp.repeat(sin, 2, axis=-1)
    cos = cos[None, None, :, :]
    sin = sin[None, None, :, :]
    xq_out = (xq * cos) + (rotate_half(xq) * sin)
    xk_out = (xk * cos) + (rotate_half(xk) * sin)
    return xq_out, xk_out
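# Shape sketch for the two helpers above (illustrative numbers, not necessarily
# the checkpoint's actual config): with head_dim=64 and max_len=1024,
# precompute_yarn_freqs(64, 1024) returns a (1024, 64) table whose first 32
# columns are cos(t * freq) and last 32 are sin(t * freq), both pre-scaled by
# the YaRN mscale (which stays 1.0 whenever yarn_scale <= 1.0).
# apply_rotary_emb slices the first seq_len rows of that table and broadcasts
# them over the (batch, heads) axes of q and k.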
class DepthwiseSeparableConv1D(nn.Module):
    channels: int
    kernel_size: int = 3
    dtype: Any = jnp.bfloat16

    @nn.compact
    def __call__(self, x):
        depthwise = nn.Conv(
            features=self.channels,
            kernel_size=(self.kernel_size,),
            feature_group_count=self.channels,
            padding='SAME',
            use_bias=False,
            dtype=self.dtype,
            name='depthwise'
        )(x)
        pointwise = nn.Conv(
            features=self.channels,
            kernel_size=(1,),
            use_bias=False,
            dtype=self.dtype,
            name='pointwise'
        )(depthwise)
        return pointwise


class LocalContextCNN(nn.Module):
    d_model: int
    dropout: float
    dtype: Any = jnp.bfloat16

    @nn.compact
    def __call__(self, x, training: bool = False):
        conv3 = DepthwiseSeparableConv1D(self.d_model, 3, self.dtype, name='conv3')(x)
        conv5 = DepthwiseSeparableConv1D(self.d_model, 5, self.dtype, name='conv5')(x)
        conv7 = DepthwiseSeparableConv1D(self.d_model, 7, self.dtype, name='conv7')(x)
        gate = nn.Dense(self.d_model * 3, dtype=self.dtype, name='fusion_gate')(x)
        gate = nn.sigmoid(gate)
        g3, g5, g7 = jnp.split(gate, 3, axis=-1)
        out = g3 * conv3 + g5 * conv5 + g7 * conv7
        scale = self.param('layer_scale', nn.initializers.constant(1e-6), (self.d_model,))
        out = out * scale
        return nn.Dropout(self.dropout, deterministic=not training)(out)


class MinGRUCell(nn.Module):
    hidden_size: int
    dtype: Any = jnp.bfloat16

    @nn.compact
    def __call__(self, x, h):
        z = nn.Dense(self.hidden_size, use_bias=True, dtype=self.dtype, name='gate')(x)
        h_tilde = nn.Dense(self.hidden_size, use_bias=True, dtype=self.dtype, name='candidate')(x)
        z = nn.sigmoid(z)
        h_tilde = nn.tanh(h_tilde)
        h_new = (1 - z) * h + z * h_tilde
        return h_new


class BidirectionalMinGRU(nn.Module):
    hidden_size: int
    dropout: float
    dtype: Any = jnp.bfloat16

    @nn.compact
    def __call__(self, x, training: bool = False):
        batch_size, seq_len, d_model = x.shape
        x_proj = nn.Dense(self.hidden_size, dtype=self.dtype, name='input_proj')(x)

        class ScanRNNCell(nn.Module):
            hidden_size: int
            dtype: Any = jnp.bfloat16

            @nn.compact
            def __call__(self, h, x_t):
                cell = MinGRUCell(self.hidden_size, dtype=self.dtype)
                h_new = cell(x_t, h)
                return h_new, h_new

        ForwardScanner = nn.scan(
            ScanRNNCell,
            variable_broadcast='params',
            split_rngs={'params': False},
            in_axes=1,
            out_axes=1
        )
        h0_forward = jnp.zeros((batch_size, self.hidden_size), dtype=self.dtype)
        _, h_forward = ForwardScanner(
            hidden_size=self.hidden_size,
            dtype=self.dtype,
            name='forward_cell'
        )(h0_forward, x_proj)

        BackwardScanner = nn.scan(
            ScanRNNCell,
            variable_broadcast='params',
            split_rngs={'params': False},
            in_axes=1,
            out_axes=1
        )
        h0_backward = jnp.zeros((batch_size, self.hidden_size), dtype=self.dtype)
        x_proj_reversed = jnp.flip(x_proj, axis=1)
        _, h_backward = BackwardScanner(
            hidden_size=self.hidden_size,
            dtype=self.dtype,
            name='backward_cell'
        )(h0_backward, x_proj_reversed)
        h_backward = jnp.flip(h_backward, axis=1)

        h_bi = jnp.concatenate([h_forward, h_backward], axis=-1)
        out = nn.Dense(d_model, dtype=self.dtype, name='output_proj')(h_bi)
        scale = self.param('layer_scale', nn.initializers.constant(1e-6), (d_model,))
        out = out * scale
        return nn.Dropout(self.dropout, deterministic=not training)(out)
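# The minGRU recurrence implemented above, written out as equations (this is a
# description of the code itself, not an external reference):
#   z_t   = sigmoid(Dense_gate(x_t))
#   h~_t  = tanh(Dense_candidate(x_t))
#   h_t   = (1 - z_t) * h_{t-1} + z_t * h~_t
# The gate and candidate depend only on x_t, so the per-step cell is cheap and
# nn.scan unrolls it along the time axis (axis 1). The backward pass runs the
# same cell over the time-flipped sequence, is flipped back, and the forward
# and backward states are concatenated before the output projection.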
class GroupedQueryAttention(nn.Module):
    d_model: int
    n_heads: int
    n_kv_heads: int
    dropout: float
    freqs_cis: jnp.ndarray
    yarn_mscale: float
    alibi_bias: Optional[jnp.ndarray]
    alibi_weight: float
    dtype: Any = jnp.bfloat16

    @nn.compact
    def __call__(self, x, mask, training: bool = False):
        B, T, D = x.shape
        head_dim = self.d_model // self.n_heads
        n_rep = self.n_heads // self.n_kv_heads

        q = nn.Dense(self.d_model, use_bias=False, kernel_init=nn.initializers.normal(stddev=0.02),
                     dtype=self.dtype, name='q_proj')(x)
        kv_dim = self.d_model * self.n_kv_heads // self.n_heads
        k = nn.Dense(kv_dim, use_bias=False, kernel_init=nn.initializers.normal(stddev=0.02),
                     dtype=self.dtype, name='k_proj')(x)
        v = nn.Dense(kv_dim, use_bias=False, kernel_init=nn.initializers.normal(stddev=0.02),
                     dtype=self.dtype, name='v_proj')(x)

        q = q.reshape(B, T, self.n_heads, head_dim).transpose(0, 2, 1, 3)
        k = k.reshape(B, T, self.n_kv_heads, head_dim).transpose(0, 2, 1, 3)
        v = v.reshape(B, T, self.n_kv_heads, head_dim).transpose(0, 2, 1, 3)

        k = jnp.repeat(k, n_rep, axis=1)
        v = jnp.repeat(v, n_rep, axis=1)

        q, k = apply_rotary_emb(q, k, self.freqs_cis, self.yarn_mscale)

        scores = jnp.matmul(q, k.transpose(0, 1, 3, 2)) / jnp.sqrt(head_dim).astype(self.dtype)

        if self.alibi_bias is not None:
            scores = scores * (1 - self.alibi_weight)
            alibi = self.alibi_bias[:, :, :T, :T]
            scores = scores + (alibi * self.alibi_weight)

        scores = scores + mask
        attn_weights = nn.softmax(scores, axis=-1)
        attn_weights = nn.Dropout(self.dropout, deterministic=not training)(attn_weights)
        attn_out = jnp.matmul(attn_weights, v)
        attn_out = attn_out.transpose(0, 2, 1, 3).reshape(B, T, D)
        out = nn.Dense(self.d_model, use_bias=False, kernel_init=nn.initializers.normal(stddev=0.02),
                       dtype=self.dtype, name='o_proj')(attn_out)
        return nn.Dropout(self.dropout, deterministic=not training)(out)


class SwiGLU(nn.Module):
    d_model: int
    ff_dim: int
    dropout: float
    dtype: Any = jnp.bfloat16

    @nn.compact
    def __call__(self, x, training: bool = False):
        gate = nn.Dense(self.ff_dim, use_bias=False, kernel_init=nn.initializers.normal(stddev=0.02),
                        dtype=self.dtype, name='gate_proj')(x)
        up = nn.Dense(self.ff_dim, use_bias=False, kernel_init=nn.initializers.normal(stddev=0.02),
                      dtype=self.dtype, name='up_proj')(x)
        hidden = nn.silu(gate) * up
        out = nn.Dense(self.d_model, use_bias=False, kernel_init=nn.initializers.normal(stddev=0.02),
                       dtype=self.dtype, name='down_proj')(hidden)
        return nn.Dropout(self.dropout, deterministic=not training)(out)


class HybridTransformerBlock(nn.Module):
    d_model: int
    n_heads: int
    n_kv_heads: int
    ff_dim: int
    dropout: float
    freqs_cis: jnp.ndarray
    yarn_mscale: float
    alibi_bias: Optional[jnp.ndarray]
    alibi_weight: float
    layer_idx: int
    layer_drop_prob: float = 0.0
    use_cnn: bool = True
    use_rnn: bool = True
    rnn_hidden: int = 512
    dtype: Any = jnp.bfloat16

    @nn.compact
    def __call__(self, x, mask, training: bool = False):
        scale = 1.0

        if self.use_rnn:
            h_rnn = RMSNorm(dtype=self.dtype, name='rnn_norm')(x)
            h_rnn = BidirectionalMinGRU(
                self.rnn_hidden, self.dropout, dtype=self.dtype, name='bidirectional_rnn'
            )(h_rnn, training)
            x = x + h_rnn * scale

        if self.use_cnn:
            h_cnn = RMSNorm(dtype=self.dtype, name='cnn_norm')(x)
            h_cnn = LocalContextCNN(
                self.d_model, self.dropout, dtype=self.dtype, name='local_cnn'
            )(h_cnn, training)
            x = x + h_cnn * scale

        h = RMSNorm(dtype=self.dtype, name='attn_norm')(x)
        h = GroupedQueryAttention(
            self.d_model, self.n_heads, self.n_kv_heads, self.dropout,
            self.freqs_cis, self.yarn_mscale, self.alibi_bias, self.alibi_weight,
            dtype=self.dtype, name='attn'
        )(h, mask, training)
        x = x + h * scale

        h = RMSNorm(dtype=self.dtype, name='ffn_norm')(x)
        h = SwiGLU(self.d_model, self.ff_dim, self.dropout, dtype=self.dtype, name='ffn')(h, training)
        x = x + h * scale

        return x
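# GQA bookkeeping in the attention module above (example numbers are purely
# illustrative, not the checkpoint's actual config): with d_model=1024,
# n_heads=16, n_kv_heads=4, q_proj keeps 1024 features while k_proj/v_proj
# shrink to kv_dim = 1024 * 4 // 16 = 256; after the head reshape, each of the
# 4 KV heads is repeated n_rep = 16 // 4 = 4 times so every query head has a
# matching key/value head before the scaled dot-product.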
class SAM1HybridModel(nn.Module):
    vocab_size: int
    d_model: int
    n_layers: int
    n_heads: int
    n_kv_heads: int
    ff_dim: int
    max_len: int
    dropout: float = 0.1
    layer_drop_prob: float = 0.05
    rope_theta: float = 10000.0
    yarn_scale: float = 1.0
    yarn_alpha: float = 1.0
    yarn_beta: float = 32.0
    use_alibi: bool = False
    alibi_weight: float = 0.3
    use_cnn: bool = True
    use_rnn: bool = True
    rnn_hidden: int = 384
    dtype: Any = jnp.bfloat16

    @nn.compact
    def __call__(self, input_ids, training: bool = False):
        head_dim = self.d_model // self.n_heads
        freqs_cis, yarn_mscale = precompute_yarn_freqs(
            head_dim, self.max_len, self.rope_theta, self.yarn_scale,
            self.yarn_alpha, self.yarn_beta, self.dtype
        )
        alibi_bias = None

        x = nn.Embed(self.vocab_size, self.d_model,
                     embedding_init=nn.initializers.normal(stddev=0.02),
                     dtype=self.dtype, name='embed_tokens')(input_ids)

        seq_len = input_ids.shape[1]
        mask = jnp.tril(jnp.ones((seq_len, seq_len)))
        mask = jnp.where(mask == 0, -1e9, 0.0).astype(self.dtype)

        for i in range(self.n_layers):
            use_cnn_layer = self.use_cnn and (i % 3 == 0)
            use_rnn_layer = self.use_rnn and (i % 4 == 0)
            x = HybridTransformerBlock(
                self.d_model, self.n_heads, self.n_kv_heads, self.ff_dim, self.dropout,
                freqs_cis, yarn_mscale, alibi_bias, self.alibi_weight,
                i, self.layer_drop_prob, use_cnn_layer, use_rnn_layer, self.rnn_hidden,
                dtype=self.dtype, name=f'layers_{i}'
            )(x, mask, training)

        x = RMSNorm(dtype=self.dtype, name='norm')(x)
        logits = nn.Dense(self.vocab_size, use_bias=False,
                          kernel_init=nn.initializers.normal(stddev=0.02),
                          dtype=self.dtype, name='lm_head')(x)
        return logits


# ==============================================================================
# MODEL LOADING & GENERATION
# ==============================================================================

class ModelWrapper:
    def __init__(self, model_path: str):
        print("🔧 Loading model...")

        # Load config
        with open(os.path.join(model_path, "config.json"), 'r') as f:
            config = json.load(f)

        self.vocab_size = config['vocab_size']
        self.d_model = config['d_model']
        self.n_layers = config['n_layers']
        self.n_heads = config['n_heads']
        self.n_kv_heads = config['n_kv_heads']
        self.ff_dim = int(self.d_model * 2.5)
        self.max_len = config['max_len']
        self.use_cnn = config.get('use_cnn', True)
        self.use_rnn = config.get('use_rnn', True)
        self.rnn_hidden = config.get('rnn_hidden', 384)

        # Load tokenizer
        self.tokenizer = Tokenizer.from_file(os.path.join(model_path, "tokenizer.json"))

        # Initialize model
        self.model = SAM1HybridModel(
            vocab_size=self.vocab_size,
            d_model=self.d_model,
            n_layers=self.n_layers,
            n_heads=self.n_heads,
            n_kv_heads=self.n_kv_heads,
            ff_dim=self.ff_dim,
            max_len=self.max_len,
            use_cnn=self.use_cnn,
            use_rnn=self.use_rnn,
            rnn_hidden=self.rnn_hidden,
            dtype=jnp.bfloat16
        )

        # Load weights
        flat_params = load_file(os.path.join(model_path, "model.safetensors"))

        # Unflatten parameters
        def unflatten_dict(flat_dict, sep='.'):
            result = {}
            for key, value in flat_dict.items():
                parts = key.split(sep)
                d = result
                for part in parts[:-1]:
                    if part not in d:
                        d[part] = {}
                    d = d[part]
                d[parts[-1]] = jnp.array(value)
            return result

        self.params = {'params': unflatten_dict(flat_params)}
        print(f"✅ Model loaded: {self.d_model}d × {self.n_layers}L × {self.n_heads}H")
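    # How unflatten_dict rebuilds the nested Flax parameter tree, assuming the
    # checkpoint was flattened with '.'-joined module paths (the key below is an
    # illustrative example, not a verified entry in the released file):
    #   "layers_0.attn.q_proj.kernel" -> params['layers_0']['attn']['q_proj']['kernel']
    # Each dotted segment becomes one level of nesting, leaves are converted to
    # jnp arrays, and the whole tree is wrapped in {'params': ...} for apply().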
    def generate_stream(self, prompt: str, max_new_tokens: int = 200, temperature: float = 0.8,
                        top_k: int = 50, top_p: float = 0.9):
        """Generator that yields the accumulated response as each new token is decoded."""

        # Format prompt in ChatML format
        if not prompt.startswith("<|im_start|>"):
            prompt = f"<|im_start|>user\n{prompt}<|im_end|>\n<|im_start|>assistant\n"
        else:
            if "<|im_start|>assistant" not in prompt:
                prompt = prompt + "<|im_start|>assistant\n"

        # Tokenize
        encoding = self.tokenizer.encode(prompt)
        input_ids = jnp.array(encoding.ids)[None, :]

        if input_ids.shape[1] > self.max_len:
            input_ids = input_ids[:, -self.max_len:]

        rng = random.PRNGKey(42)
        generated_ids = input_ids
        response_text = ""

        # Generate tokens
        for _ in range(max_new_tokens):
            # Keep the context within the model's maximum length; the rotary
            # table is only precomputed for max_len positions.
            if generated_ids.shape[1] > self.max_len:
                generated_ids = generated_ids[:, -self.max_len:]

            logits = self.model.apply(self.params, generated_ids, training=False)
            next_logits = logits[0, -1, :] / temperature

            # Top-k filtering
            top_k_logits, top_k_indices = jax.lax.top_k(next_logits, top_k)

            # Top-p (nucleus) filtering over the top-k candidates
            sorted_logits = jnp.sort(top_k_logits)[::-1]
            sorted_indices = jnp.argsort(top_k_logits)[::-1]
            cumsum_probs = jnp.cumsum(nn.softmax(sorted_logits))
            mask = cumsum_probs <= top_p
            mask = jnp.concatenate([jnp.array([True]), mask[:-1]])  # always keep the top candidate
            filtered_logits = jnp.where(mask, sorted_logits, -1e9)

            # Sample
            rng, sample_rng = random.split(rng)
            next_token_idx = random.categorical(sample_rng, filtered_logits)
            next_token = top_k_indices[sorted_indices[next_token_idx]][None, None]
            generated_ids = jnp.concatenate([generated_ids, next_token], axis=1)

            # Decode the new token
            token_id = int(next_token[0, 0])

            # Stop on EOS or end tokens
            if token_id in [
                self.tokenizer.token_to_id("<|endoftext|>"),
                self.tokenizer.token_to_id("<|im_end|>")
            ]:
                break

            # Decode and yield the accumulated response
            token_text = self.tokenizer.decode([token_id])
            response_text += token_text
            yield response_text

    def generate(self, prompt: str, max_new_tokens: int = 200, temperature: float = 0.8,
                 top_k: int = 50, top_p: float = 0.9):
        """Non-streaming generation (returns the full response)."""
        response = ""
        for partial_response in self.generate_stream(prompt, max_new_tokens, temperature, top_k, top_p):
            response = partial_response
        return response


# ==============================================================================
# GRADIO INTERFACE
# ==============================================================================

# Download and load model from HuggingFace Hub
from huggingface_hub import snapshot_download

print("📥 Downloading model from HuggingFace Hub...")
model_path = snapshot_download(
    repo_id="Smilyai-labs/MixSam-exp",
    repo_type="model",
    local_dir="./model_cache"
)
print(f"✅ Model downloaded to: {model_path}")

# Load model
model = ModelWrapper(model_path)


def chat_fn(message, history, temperature, top_k, top_p, max_tokens):
    # Build conversation context in ChatML format
    conversation = ""
    for user_msg, bot_msg in history:
        conversation += f"<|im_start|>user\n{user_msg}<|im_end|>\n<|im_start|>assistant\n{bot_msg}<|im_end|>\n"

    # Add current message
    conversation += f"<|im_start|>user\n{message}<|im_end|>\n<|im_start|>assistant\n"

    # Stream response token by token
    partial_response = ""
    for response in model.generate_stream(
        conversation,
        max_new_tokens=max_tokens,
        temperature=temperature,
        top_k=top_k,
        top_p=top_p
    ):
        partial_response = response
        # Yield the full history + current streaming message
        yield history + [[message, partial_response]]
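# For a single prior exchange, chat_fn hands generate_stream a prompt of this
# shape (placeholder messages, newlines written as \n for readability):
#   <|im_start|>user\nHi there<|im_end|>\n
#   <|im_start|>assistant\nHello!<|im_end|>\n
#   <|im_start|>user\nHow are you?<|im_end|>\n
#   <|im_start|>assistant\n
# Because the string already starts with <|im_start|> and contains an
# "<|im_start|>assistant" turn, generate_stream does not wrap it a second time.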
# Create Gradio interface
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("""
    # 🤖 SAM1 Hybrid Chat
    ### Transformer + CNN + RNN Architecture

    Chat with SAM1, a custom hybrid language model combining:
    - 🔷 **Transformer** attention (GQA + YARN + RoPE)
    - 🔶 **CNN** for local context (multi-scale convolutions)
    - 🔵 **RNN** for sequential modeling (bidirectional MinGRU)
    """)

    chatbot = gr.Chatbot(height=500, show_copy_button=True)

    with gr.Row():
        msg = gr.Textbox(
            placeholder="Type your message here...",
            show_label=False,
            scale=4
        )
        submit = gr.Button("Send", scale=1, variant="primary")

    with gr.Accordion("⚙️ Generation Settings", open=False):
        with gr.Row():
            temperature = gr.Slider(0.1, 2.0, value=0.8, label="Temperature", step=0.1)
            top_k = gr.Slider(1, 100, value=50, label="Top-K", step=1)
        with gr.Row():
            top_p = gr.Slider(0.1, 1.0, value=0.9, label="Top-P", step=0.05)
            max_tokens = gr.Slider(50, 500, value=200, label="Max Tokens", step=10)

    clear = gr.Button("🗑️ Clear Chat")

    # Event handlers
    msg.submit(
        chat_fn,
        inputs=[msg, chatbot, temperature, top_k, top_p, max_tokens],
        outputs=chatbot
    ).then(lambda: "", None, msg)

    submit.click(
        chat_fn,
        inputs=[msg, chatbot, temperature, top_k, top_p, max_tokens],
        outputs=chatbot
    ).then(lambda: "", None, msg)

    clear.click(lambda: None, None, chatbot, queue=False)

    gr.Markdown("""
    ---
    **Model Details:**
    - Architecture: SAM1 Hybrid (Custom)
    - Parameters: ~600M
    - Context Length: 1024 tokens
    - Prompt Format: ChatML (`<|im_start|>user ... <|im_end|>`)

    **Tips:**
    - Lower temperature (0.3-0.5) for focused responses
    - Higher temperature (0.8-1.2) for creative responses
    - Adjust top-k/top-p for response diversity
    """)

if __name__ == "__main__":
    demo.queue().launch()