import gradio as gr
import jax
import jax.numpy as jnp
from jax import random
import flax.linen as nn
from tokenizers import Tokenizer
from safetensors.flax import load_file
from huggingface_hub import snapshot_download
import json
import os
from typing import Any, Optional
import numpy as np
# ==============================================================================
# MODEL ARCHITECTURE (from the training code)
# ==============================================================================

class RMSNorm(nn.Module):
    epsilon: float = 1e-5
    dtype: Any = jnp.bfloat16

    @nn.compact
    def __call__(self, x):
        x = x.astype(jnp.float32)
        scale = self.param('scale', nn.initializers.ones, (x.shape[-1],))
        variance = jnp.mean(jnp.square(x), axis=-1, keepdims=True)
        x = x * jax.lax.rsqrt(variance + self.epsilon) * scale
        return x.astype(self.dtype)
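
# YaRN ("Yet another RoPE extensioN") rescales the RoPE frequency table so a
# model trained at one context length can run at a longer one. With the
# default scale=1.0 used in this app, the correction branch is skipped and
# this reduces to plain RoPE: pair j rotates at frequency theta ** (-2j / dim),
# and position t contributes angle t * freq_j. The returned table concatenates
# the cos and sin halves along the last axis, plus the magnitude correction
# `mscale` (1.0 when scale <= 1.0).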
def precompute_yarn_freqs(dim: int, end: int, theta: float = 10000.0,
                          scale: float = 1.0, alpha: float = 1.0,
                          beta: float = 32.0, dtype=jnp.bfloat16):
    freqs = 1.0 / (theta ** (jnp.arange(0, dim, 2, dtype=jnp.float32) / dim))
    if scale > 1.0:
        def yarn_find_correction_dim(num_rotations, dim, base=10000, max_position_embeddings=2048):
            return (dim * jnp.log(max_position_embeddings / (num_rotations * 2 * jnp.pi))) / (2 * jnp.log(base))

        def yarn_find_correction_range(low_rot, high_rot, dim, base=10000, max_position_embeddings=2048):
            low = jnp.floor(yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings))
            high = jnp.ceil(yarn_find_correction_dim(high_rot, dim, base, max_position_embeddings))
            return jnp.maximum(low, 0).astype(jnp.int32), jnp.minimum(high, dim - 1).astype(jnp.int32)

        def yarn_linear_ramp_mask(min_val, max_val, dim):
            if min_val == max_val:
                max_val += 0.001
            linear_func = (jnp.arange(dim, dtype=jnp.float32) - min_val) / (max_val - min_val)
            return jnp.clip(linear_func, 0, 1)

        low, high = yarn_find_correction_range(beta, alpha, dim, theta, int(end * scale))
        inv_freq_mask = 1.0 - yarn_linear_ramp_mask(low, high, dim // 2)
        freqs = freqs / ((1 - inv_freq_mask) * (scale - 1) + 1)
    t = jnp.arange(end, dtype=jnp.float32)
    freqs = jnp.outer(t, freqs)
    mscale = 1.0
    if scale > 1.0:
        mscale = 0.1 * 1.0 * jnp.log(scale) + 1.0
    cos = jnp.cos(freqs) * mscale
    sin = jnp.sin(freqs) * mscale
    return jnp.concatenate([cos, sin], axis=-1).astype(dtype), mscale
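
# Applies the rotary table to queries and keys. The cos/sin halves are looked
# up for the first seq_len positions, expanded back to head_dim with
# jnp.repeat, and combined via the rotate-half identity:
#   x_rot = x * cos + rotate_half(x) * sin
# Note that `mscale` is already baked into the cos/sin table by
# precompute_yarn_freqs; it is accepted here only to keep call sites uniform.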
def apply_rotary_emb(xq, xk, freqs_cis, mscale=1.0):
    def rotate_half(x):
        x1, x2 = jnp.split(x, 2, axis=-1)
        return jnp.concatenate([-x2, x1], axis=-1)

    seq_len = xq.shape[2]
    head_dim = xq.shape[3]
    freqs = freqs_cis[:seq_len, :]
    half_dim = head_dim // 2
    cos = freqs[:, :half_dim]
    sin = freqs[:, half_dim:]
    cos = jnp.repeat(cos, 2, axis=-1)
    sin = jnp.repeat(sin, 2, axis=-1)
    cos = cos[None, None, :, :]
    sin = sin[None, None, :, :]
    xq_out = (xq * cos) + (rotate_half(xq) * sin)
    xk_out = (xk * cos) + (rotate_half(xk) * sin)
    return xq_out, xk_out
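
# A depthwise-separable 1D convolution factorizes a dense conv into a
# per-channel (depthwise) conv followed by a 1x1 (pointwise) channel mix.
# For C channels and kernel size K this costs roughly C*K + C*C multiplies
# per position instead of C*C*K, which is what makes the multi-scale stack
# below affordable.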
class DepthwiseSeparableConv1D(nn.Module):
    channels: int
    kernel_size: int = 3
    dtype: Any = jnp.bfloat16

    @nn.compact
    def __call__(self, x):
        depthwise = nn.Conv(
            features=self.channels,
            kernel_size=(self.kernel_size,),
            feature_group_count=self.channels,
            padding='SAME',
            use_bias=False,
            dtype=self.dtype,
            name='depthwise'
        )(x)
        pointwise = nn.Conv(
            features=self.channels,
            kernel_size=(1,),
            use_bias=False,
            dtype=self.dtype,
            name='pointwise'
        )(depthwise)
        return pointwise
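
# Local context via three depthwise-separable convs with kernel sizes 3, 5,
# and 7, blended per-channel by a sigmoid gate predicted from the input. The
# learned `layer_scale` starts near zero (1e-6), so the branch begins as a
# tiny perturbation of the residual stream, a common trick for stabilizing
# deep residual stacks.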
class LocalContextCNN(nn.Module):
    d_model: int
    dropout: float
    dtype: Any = jnp.bfloat16

    @nn.compact
    def __call__(self, x, training: bool = False):
        conv3 = DepthwiseSeparableConv1D(self.d_model, 3, self.dtype, name='conv3')(x)
        conv5 = DepthwiseSeparableConv1D(self.d_model, 5, self.dtype, name='conv5')(x)
        conv7 = DepthwiseSeparableConv1D(self.d_model, 7, self.dtype, name='conv7')(x)
        gate = nn.Dense(self.d_model * 3, dtype=self.dtype, name='fusion_gate')(x)
        gate = nn.sigmoid(gate)
        g3, g5, g7 = jnp.split(gate, 3, axis=-1)
        out = g3 * conv3 + g5 * conv5 + g7 * conv7
        scale = self.param('layer_scale', nn.initializers.constant(1e-6), (self.d_model,))
        out = out * scale
        return nn.Dropout(self.dropout, deterministic=not training)(out)
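
# MinGRU is a simplified GRU whose gate and candidate depend only on the
# current input, not on the previous hidden state:
#   z_t  = sigmoid(W_z x_t)
#   h~_t = tanh(W_h x_t)
#   h_t  = (1 - z_t) * h_{t-1} + z_t * h~_t
# Dropping the hidden-state dependence from the gates is what makes the
# recurrence cheap to scan.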
class MinGRUCell(nn.Module):
    hidden_size: int
    dtype: Any = jnp.bfloat16

    @nn.compact
    def __call__(self, x, h):
        z = nn.Dense(self.hidden_size, use_bias=True, dtype=self.dtype, name='gate')(x)
        h_tilde = nn.Dense(self.hidden_size, use_bias=True, dtype=self.dtype, name='candidate')(x)
        z = nn.sigmoid(z)
        h_tilde = nn.tanh(h_tilde)
        h_new = (1 - z) * h + z * h_tilde
        return h_new
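
# Runs the MinGRU over the sequence in both directions using nn.scan, which
# applies one cell step per timestep while broadcasting the shared params
# across steps. The backward pass simply scans a time-reversed copy of the
# input and flips its outputs back before concatenation.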
class BidirectionalMinGRU(nn.Module):
    hidden_size: int
    dropout: float
    dtype: Any = jnp.bfloat16

    @nn.compact
    def __call__(self, x, training: bool = False):
        batch_size, seq_len, d_model = x.shape
        x_proj = nn.Dense(self.hidden_size, dtype=self.dtype, name='input_proj')(x)

        class ScanRNNCell(nn.Module):
            hidden_size: int
            dtype: Any = jnp.bfloat16

            @nn.compact
            def __call__(self, h, x_t):
                cell = MinGRUCell(self.hidden_size, dtype=self.dtype)
                h_new = cell(x_t, h)
                return h_new, h_new

        ForwardScanner = nn.scan(
            ScanRNNCell,
            variable_broadcast='params',
            split_rngs={'params': False},
            in_axes=1,
            out_axes=1
        )
        h0_forward = jnp.zeros((batch_size, self.hidden_size), dtype=self.dtype)
        _, h_forward = ForwardScanner(
            hidden_size=self.hidden_size,
            dtype=self.dtype,
            name='forward_cell'
        )(h0_forward, x_proj)

        BackwardScanner = nn.scan(
            ScanRNNCell,
            variable_broadcast='params',
            split_rngs={'params': False},
            in_axes=1,
            out_axes=1
        )
        h0_backward = jnp.zeros((batch_size, self.hidden_size), dtype=self.dtype)
        x_proj_reversed = jnp.flip(x_proj, axis=1)
        _, h_backward = BackwardScanner(
            hidden_size=self.hidden_size,
            dtype=self.dtype,
            name='backward_cell'
        )(h0_backward, x_proj_reversed)
        h_backward = jnp.flip(h_backward, axis=1)

        h_bi = jnp.concatenate([h_forward, h_backward], axis=-1)
        out = nn.Dense(d_model, dtype=self.dtype, name='output_proj')(h_bi)
        scale = self.param('layer_scale', nn.initializers.constant(1e-6), (d_model,))
        out = out * scale
        return nn.Dropout(self.dropout, deterministic=not training)(out)
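
# Grouped-query attention (GQA): all n_heads query heads are kept, but keys
# and values use only n_kv_heads heads, each shared by n_heads // n_kv_heads
# query heads (implemented below with jnp.repeat). For example, with
# n_heads=16 and n_kv_heads=4, every KV head serves 4 query heads, shrinking
# the K/V projections by 4x.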
class GroupedQueryAttention(nn.Module):
    d_model: int
    n_heads: int
    n_kv_heads: int
    dropout: float
    freqs_cis: jnp.ndarray
    yarn_mscale: float
    alibi_bias: Optional[jnp.ndarray]
    alibi_weight: float
    dtype: Any = jnp.bfloat16

    @nn.compact
    def __call__(self, x, mask, training: bool = False):
        B, T, D = x.shape
        head_dim = self.d_model // self.n_heads
        n_rep = self.n_heads // self.n_kv_heads
        q = nn.Dense(self.d_model, use_bias=False,
                     kernel_init=nn.initializers.normal(stddev=0.02),
                     dtype=self.dtype, name='q_proj')(x)
        kv_dim = self.d_model * self.n_kv_heads // self.n_heads
        k = nn.Dense(kv_dim, use_bias=False,
                     kernel_init=nn.initializers.normal(stddev=0.02),
                     dtype=self.dtype, name='k_proj')(x)
        v = nn.Dense(kv_dim, use_bias=False,
                     kernel_init=nn.initializers.normal(stddev=0.02),
                     dtype=self.dtype, name='v_proj')(x)
        # (B, T, H, head_dim) -> (B, H, T, head_dim)
        q = q.reshape(B, T, self.n_heads, head_dim).transpose(0, 2, 1, 3)
        k = k.reshape(B, T, self.n_kv_heads, head_dim).transpose(0, 2, 1, 3)
        v = v.reshape(B, T, self.n_kv_heads, head_dim).transpose(0, 2, 1, 3)
        # Share each KV head across n_rep query heads
        k = jnp.repeat(k, n_rep, axis=1)
        v = jnp.repeat(v, n_rep, axis=1)
        q, k = apply_rotary_emb(q, k, self.freqs_cis, self.yarn_mscale)
        scores = jnp.matmul(q, k.transpose(0, 1, 3, 2)) / jnp.sqrt(head_dim).astype(self.dtype)
        if self.alibi_bias is not None:
            scores = scores * (1 - self.alibi_weight)
            alibi = self.alibi_bias[:, :, :T, :T]
            scores = scores + (alibi * self.alibi_weight)
        scores = scores + mask
        attn_weights = nn.softmax(scores, axis=-1)
        attn_weights = nn.Dropout(self.dropout, deterministic=not training)(attn_weights)
        attn_out = jnp.matmul(attn_weights, v)
        attn_out = attn_out.transpose(0, 2, 1, 3).reshape(B, T, D)
        out = nn.Dense(self.d_model, use_bias=False,
                       kernel_init=nn.initializers.normal(stddev=0.02),
                       dtype=self.dtype, name='o_proj')(attn_out)
        return nn.Dropout(self.dropout, deterministic=not training)(out)
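
# SwiGLU feed-forward: down_proj(silu(gate_proj(x)) * up_proj(x)). The
# multiplicative gate adds a third projection matrix, which is why ff_dim is
# typically set below the classic 4x d_model (this app uses 2.5x; see
# ModelWrapper).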
class SwiGLU(nn.Module):
    d_model: int
    ff_dim: int
    dropout: float
    dtype: Any = jnp.bfloat16

    @nn.compact
    def __call__(self, x, training: bool = False):
        gate = nn.Dense(self.ff_dim, use_bias=False,
                        kernel_init=nn.initializers.normal(stddev=0.02),
                        dtype=self.dtype, name='gate_proj')(x)
        up = nn.Dense(self.ff_dim, use_bias=False,
                      kernel_init=nn.initializers.normal(stddev=0.02),
                      dtype=self.dtype, name='up_proj')(x)
        hidden = nn.silu(gate) * up
        out = nn.Dense(self.d_model, use_bias=False,
                       kernel_init=nn.initializers.normal(stddev=0.02),
                       dtype=self.dtype, name='down_proj')(hidden)
        return nn.Dropout(self.dropout, deterministic=not training)(out)
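
# One hybrid block: optional bidirectional-RNN and local-CNN sub-layers,
# followed by GQA attention and a SwiGLU FFN, each wrapped in a pre-norm
# (RMSNorm) residual branch. `layer_drop_prob` is carried in the signature
# for parity with the training code but is not used in this forward pass.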
class HybridTransformerBlock(nn.Module):
    d_model: int
    n_heads: int
    n_kv_heads: int
    ff_dim: int
    dropout: float
    freqs_cis: jnp.ndarray
    yarn_mscale: float
    alibi_bias: Optional[jnp.ndarray]
    alibi_weight: float
    layer_idx: int
    layer_drop_prob: float = 0.0
    use_cnn: bool = True
    use_rnn: bool = True
    rnn_hidden: int = 512
    dtype: Any = jnp.bfloat16

    @nn.compact
    def __call__(self, x, mask, training: bool = False):
        scale = 1.0
        if self.use_rnn:
            h_rnn = RMSNorm(dtype=self.dtype, name='rnn_norm')(x)
            h_rnn = BidirectionalMinGRU(
                self.rnn_hidden, self.dropout, dtype=self.dtype, name='bidirectional_rnn'
            )(h_rnn, training)
            x = x + h_rnn * scale
        if self.use_cnn:
            h_cnn = RMSNorm(dtype=self.dtype, name='cnn_norm')(x)
            h_cnn = LocalContextCNN(
                self.d_model, self.dropout, dtype=self.dtype, name='local_cnn'
            )(h_cnn, training)
            x = x + h_cnn * scale
        h = RMSNorm(dtype=self.dtype, name='attn_norm')(x)
        h = GroupedQueryAttention(
            self.d_model, self.n_heads, self.n_kv_heads, self.dropout,
            self.freqs_cis, self.yarn_mscale, self.alibi_bias,
            self.alibi_weight, dtype=self.dtype, name='attn'
        )(h, mask, training)
        x = x + h * scale
        h = RMSNorm(dtype=self.dtype, name='ffn_norm')(x)
        h = SwiGLU(self.d_model, self.ff_dim, self.dropout,
                   dtype=self.dtype, name='ffn')(h, training)
        x = x + h * scale
        return x
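
# Full model: token embedding -> n_layers hybrid blocks -> RMSNorm -> LM head.
# CNN and RNN sub-layers are only enabled on a schedule (CNN on every 3rd
# layer, RNN on every 4th, both counted from layer 0), so e.g. with
# n_layers=12 the CNN runs in layers 0, 3, 6, 9 and the RNN in layers 0, 4, 8.
# A lower-triangular additive mask enforces causal (left-to-right) attention.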
class SAM1HybridModel(nn.Module):
    vocab_size: int
    d_model: int
    n_layers: int
    n_heads: int
    n_kv_heads: int
    ff_dim: int
    max_len: int
    dropout: float = 0.1
    layer_drop_prob: float = 0.05
    rope_theta: float = 10000.0
    yarn_scale: float = 1.0
    yarn_alpha: float = 1.0
    yarn_beta: float = 32.0
    use_alibi: bool = False
    alibi_weight: float = 0.3
    use_cnn: bool = True
    use_rnn: bool = True
    rnn_hidden: int = 384
    dtype: Any = jnp.bfloat16

    @nn.compact
    def __call__(self, input_ids, training: bool = False):
        head_dim = self.d_model // self.n_heads
        freqs_cis, yarn_mscale = precompute_yarn_freqs(
            head_dim, self.max_len, self.rope_theta,
            self.yarn_scale, self.yarn_alpha, self.yarn_beta, self.dtype
        )
        alibi_bias = None
        x = nn.Embed(self.vocab_size, self.d_model,
                     embedding_init=nn.initializers.normal(stddev=0.02),
                     dtype=self.dtype, name='embed_tokens')(input_ids)
        seq_len = input_ids.shape[1]
        # Additive causal mask: 0 on and below the diagonal, -1e9 above
        mask = jnp.tril(jnp.ones((seq_len, seq_len)))
        mask = jnp.where(mask == 0, -1e9, 0.0).astype(self.dtype)
        for i in range(self.n_layers):
            use_cnn_layer = self.use_cnn and (i % 3 == 0)
            use_rnn_layer = self.use_rnn and (i % 4 == 0)
            x = HybridTransformerBlock(
                self.d_model, self.n_heads, self.n_kv_heads, self.ff_dim,
                self.dropout, freqs_cis, yarn_mscale, alibi_bias,
                self.alibi_weight, i, self.layer_drop_prob,
                use_cnn_layer, use_rnn_layer, self.rnn_hidden,
                dtype=self.dtype, name=f'layers_{i}'
            )(x, mask, training)
        x = RMSNorm(dtype=self.dtype, name='norm')(x)
        logits = nn.Dense(self.vocab_size, use_bias=False,
                          kernel_init=nn.initializers.normal(stddev=0.02),
                          dtype=self.dtype, name='lm_head')(x)
        return logits
# ==============================================================================
# MODEL LOADING & GENERATION
# ==============================================================================

class ModelWrapper:
    def __init__(self, model_path: str):
        print("🔧 Loading model...")
        # Load config
        with open(os.path.join(model_path, "config.json"), 'r') as f:
            config = json.load(f)
        self.vocab_size = config['vocab_size']
        self.d_model = config['d_model']
        self.n_layers = config['n_layers']
        self.n_heads = config['n_heads']
        self.n_kv_heads = config['n_kv_heads']
        self.ff_dim = int(self.d_model * 2.5)
        self.max_len = config['max_len']
        self.use_cnn = config.get('use_cnn', True)
        self.use_rnn = config.get('use_rnn', True)
        self.rnn_hidden = config.get('rnn_hidden', 384)
        # Load tokenizer
        self.tokenizer = Tokenizer.from_file(os.path.join(model_path, "tokenizer.json"))
        # Initialize model
        self.model = SAM1HybridModel(
            vocab_size=self.vocab_size,
            d_model=self.d_model,
            n_layers=self.n_layers,
            n_heads=self.n_heads,
            n_kv_heads=self.n_kv_heads,
            ff_dim=self.ff_dim,
            max_len=self.max_len,
            use_cnn=self.use_cnn,
            use_rnn=self.use_rnn,
            rnn_hidden=self.rnn_hidden,
            dtype=jnp.bfloat16
        )
        # Load weights and rebuild the nested param tree from the flat,
        # dot-separated keys stored in the safetensors file
        flat_params = load_file(os.path.join(model_path, "model.safetensors"))

        def unflatten_dict(flat_dict, sep='.'):
            result = {}
            for key, value in flat_dict.items():
                parts = key.split(sep)
                d = result
                for part in parts[:-1]:
                    if part not in d:
                        d[part] = {}
                    d = d[part]
                d[parts[-1]] = jnp.array(value)
            return result

        self.params = {'params': unflatten_dict(flat_params)}
        print(f"✅ Model loaded: {self.d_model}d × {self.n_layers}L × {self.n_heads}H")
    def generate_stream(self, prompt: str, max_new_tokens: int = 200,
                        temperature: float = 0.8, top_k: int = 50, top_p: float = 0.9):
        """Generator that yields the growing response text one token at a time."""
        # Wrap a bare message in ChatML; otherwise make sure an assistant turn is open
        if not prompt.startswith("<|im_start|>"):
            prompt = f"<|im_start|>user\n{prompt}<|im_end|>\n<|im_start|>assistant\n"
        elif "<|im_start|>assistant" not in prompt:
            prompt = prompt + "<|im_start|>assistant\n"
        # Tokenize, truncating from the left if the prompt exceeds the context window
        encoding = self.tokenizer.encode(prompt)
        input_ids = jnp.array(encoding.ids)[None, :]
        if input_ids.shape[1] > self.max_len:
            input_ids = input_ids[:, -self.max_len:]
        rng = random.PRNGKey(42)
        generated_ids = input_ids
        response_text = ""
        # Generate tokens (the full sequence is re-encoded every step; no KV cache)
        for _ in range(max_new_tokens):
            logits = self.model.apply(self.params, generated_ids, training=False)
            next_logits = logits[0, -1, :] / temperature
            # Top-k filtering (jax.lax.top_k already returns values sorted descending)
            top_k_logits, top_k_indices = jax.lax.top_k(next_logits, top_k)
            # Top-p (nucleus) filtering
            sorted_logits = jnp.sort(top_k_logits)[::-1]
            sorted_indices = jnp.argsort(top_k_logits)[::-1]
            cumsum_probs = jnp.cumsum(nn.softmax(sorted_logits))
            mask = cumsum_probs <= top_p
            # Shift right so the highest-probability token always stays eligible
            mask = jnp.concatenate([jnp.array([True]), mask[:-1]])
            filtered_logits = jnp.where(mask, sorted_logits, -1e9)
            # Sample
            rng, sample_rng = random.split(rng)
            next_token_idx = random.categorical(sample_rng, filtered_logits)
            next_token = top_k_indices[sorted_indices[next_token_idx]][None, None]
            generated_ids = jnp.concatenate([generated_ids, next_token], axis=1)
            token_id = int(next_token[0, 0])
            # Stop on EOS or end-of-turn tokens
            if token_id in [
                self.tokenizer.token_to_id("<|endoftext|>"),
                self.tokenizer.token_to_id("<|im_end|>")
            ]:
                break
            # Decode and yield the accumulated response
            token_text = self.tokenizer.decode([token_id])
            response_text += token_text
            yield response_text

    def generate(self, prompt: str, max_new_tokens: int = 200,
                 temperature: float = 0.8, top_k: int = 50, top_p: float = 0.9):
        """Non-streaming generation (returns the full response)."""
        response = ""
        for partial_response in self.generate_stream(prompt, max_new_tokens, temperature, top_k, top_p):
            response = partial_response
        return response
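
# Minimal local smoke test (illustrative only; assumes the checkpoint has
# already been downloaded to ./model_cache, as done below):
#
#   wrapper = ModelWrapper("./model_cache")
#   print(wrapper.generate("Hello!", max_new_tokens=32))
#
#   # or stream partial responses as they grow:
#   for partial in wrapper.generate_stream("Hello!", max_new_tokens=32):
#       print(partial)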
# ==============================================================================
# GRADIO INTERFACE
# ==============================================================================

# Download the model snapshot from the Hugging Face Hub
print("📥 Downloading model from the Hugging Face Hub...")
model_path = snapshot_download(
    repo_id="Smilyai-labs/MixSam-exp",
    repo_type="model",
    local_dir="./model_cache"
)
print(f"✅ Model downloaded to: {model_path}")

# Load model
model = ModelWrapper(model_path)
def chat_fn(message, history, temperature, top_k, top_p, max_tokens):
    # Build the conversation context in ChatML format
    conversation = ""
    for user_msg, bot_msg in history:
        conversation += f"<|im_start|>user\n{user_msg}<|im_end|>\n<|im_start|>assistant\n{bot_msg}<|im_end|>\n"
    # Add the current message
    conversation += f"<|im_start|>user\n{message}<|im_end|>\n<|im_start|>assistant\n"
    # Stream the response token by token
    for partial_response in model.generate_stream(
        conversation,
        max_new_tokens=max_tokens,
        temperature=temperature,
        top_k=top_k,
        top_p=top_p
    ):
        # Yield the full history plus the in-progress message
        yield history + [[message, partial_response]]
# Create the Gradio interface
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("""
    # 🤖 SAM1 Hybrid Chat
    ### Transformer + CNN + RNN Architecture

    Chat with SAM1, a custom hybrid language model combining:
    - 🔷 **Transformer** attention (GQA + YaRN + RoPE)
    - 🔶 **CNN** for local context (multi-scale convolutions)
    - 🔵 **RNN** for sequential modeling (bidirectional MinGRU)
    """)
    chatbot = gr.Chatbot(height=500, show_copy_button=True)
    with gr.Row():
        msg = gr.Textbox(
            placeholder="Type your message here...",
            show_label=False,
            scale=4
        )
        submit = gr.Button("Send", scale=1, variant="primary")
    with gr.Accordion("⚙️ Generation Settings", open=False):
        with gr.Row():
            temperature = gr.Slider(0.1, 2.0, value=0.8, label="Temperature", step=0.1)
            top_k = gr.Slider(1, 100, value=50, label="Top-K", step=1)
        with gr.Row():
            top_p = gr.Slider(0.1, 1.0, value=0.9, label="Top-P", step=0.05)
            max_tokens = gr.Slider(50, 500, value=200, label="Max Tokens", step=10)
    clear = gr.Button("🗑️ Clear Chat")

    # Event handlers: run generation, then clear the textbox
    msg.submit(
        chat_fn,
        inputs=[msg, chatbot, temperature, top_k, top_p, max_tokens],
        outputs=chatbot
    ).then(lambda: "", None, msg)
    submit.click(
        chat_fn,
        inputs=[msg, chatbot, temperature, top_k, top_p, max_tokens],
        outputs=chatbot
    ).then(lambda: "", None, msg)
    clear.click(lambda: None, None, chatbot, queue=False)
| gr.Markdown(""" | |
| --- | |
| **Model Details:** | |
| - Architecture: SAM1 Hybrid (Custom) | |
| - Parameters: ~600M | |
| - Context Length: 1024 tokens | |
| - Format: `User: {query} Sam: {response}` (no newlines) | |
| **Tips:** | |
| - Lower temperature (0.3-0.5) for focused responses | |
| - Higher temperature (0.8-1.2) for creative responses | |
| - Adjust top-k/top-p for response diversity | |
| """) | |
if __name__ == "__main__":
    demo.queue().launch()