Spaces:
Sleeping
Sleeping
Create app.py
Browse files
app.py
ADDED
|
@@ -0,0 +1,611 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import gradio as gr
|
| 2 |
+
import jax
|
| 3 |
+
import jax.numpy as jnp
|
| 4 |
+
from jax import random
|
| 5 |
+
import flax.linen as nn
|
| 6 |
+
from tokenizers import Tokenizer
|
| 7 |
+
from safetensors.flax import load_file
|
| 8 |
+
import json
|
| 9 |
+
import os
|
| 10 |
+
from typing import Any, Optional
|
| 11 |
+
import numpy as np
|
| 12 |
+
|
| 13 |
+
# ==============================================================================
|
| 14 |
+
# MODEL ARCHITECTURE (from your training code)
|
| 15 |
+
# ==============================================================================
|
| 16 |
+
|
| 17 |
+
class RMSNorm(nn.Module):
|
| 18 |
+
epsilon: float = 1e-5
|
| 19 |
+
dtype: Any = jnp.bfloat16
|
| 20 |
+
|
| 21 |
+
@nn.compact
|
| 22 |
+
def __call__(self, x):
|
| 23 |
+
x = x.astype(jnp.float32)
|
| 24 |
+
scale = self.param('scale', nn.initializers.ones, (x.shape[-1],))
|
| 25 |
+
variance = jnp.mean(jnp.square(x), axis=-1, keepdims=True)
|
| 26 |
+
x = x * jax.lax.rsqrt(variance + self.epsilon) * scale
|
| 27 |
+
return x.astype(self.dtype)
|
| 28 |
+
|
| 29 |
+
def precompute_yarn_freqs(dim: int, end: int, theta: float = 10000.0,
|
| 30 |
+
scale: float = 1.0, alpha: float = 1.0,
|
| 31 |
+
beta: float = 32.0, dtype=jnp.bfloat16):
|
| 32 |
+
freqs = 1.0 / (theta ** (jnp.arange(0, dim, 2, dtype=jnp.float32) / dim))
|
| 33 |
+
|
| 34 |
+
if scale > 1.0:
|
| 35 |
+
def yarn_find_correction_dim(num_rotations, dim, base=10000, max_position_embeddings=2048):
|
| 36 |
+
return (dim * jnp.log(max_position_embeddings / (num_rotations * 2 * jnp.pi))) / (2 * jnp.log(base))
|
| 37 |
+
|
| 38 |
+
def yarn_find_correction_range(low_rot, high_rot, dim, base=10000, max_position_embeddings=2048):
|
| 39 |
+
low = jnp.floor(yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings))
|
| 40 |
+
high = jnp.ceil(yarn_find_correction_dim(high_rot, dim, base, max_position_embeddings))
|
| 41 |
+
return jnp.maximum(low, 0).astype(jnp.int32), jnp.minimum(high, dim - 1).astype(jnp.int32)
|
| 42 |
+
|
| 43 |
+
def yarn_linear_ramp_mask(min_val, max_val, dim):
|
| 44 |
+
if min_val == max_val:
|
| 45 |
+
max_val += 0.001
|
| 46 |
+
linear_func = (jnp.arange(dim, dtype=jnp.float32) - min_val) / (max_val - min_val)
|
| 47 |
+
return jnp.clip(linear_func, 0, 1)
|
| 48 |
+
|
| 49 |
+
low, high = yarn_find_correction_range(beta, alpha, dim, theta, int(end * scale))
|
| 50 |
+
inv_freq_mask = 1.0 - yarn_linear_ramp_mask(low, high, dim // 2)
|
| 51 |
+
freqs = freqs / ((1 - inv_freq_mask) * (scale - 1) + 1)
|
| 52 |
+
|
| 53 |
+
t = jnp.arange(end, dtype=jnp.float32)
|
| 54 |
+
freqs = jnp.outer(t, freqs)
|
| 55 |
+
|
| 56 |
+
mscale = 1.0
|
| 57 |
+
if scale > 1.0:
|
| 58 |
+
mscale = 0.1 * 1.0 * jnp.log(scale) + 1.0
|
| 59 |
+
|
| 60 |
+
cos = jnp.cos(freqs) * mscale
|
| 61 |
+
sin = jnp.sin(freqs) * mscale
|
| 62 |
+
|
| 63 |
+
return jnp.concatenate([cos, sin], axis=-1).astype(dtype), mscale
|
| 64 |
+
|
| 65 |
+
def apply_rotary_emb(xq, xk, freqs_cis, mscale=1.0):
|
| 66 |
+
def rotate_half(x):
|
| 67 |
+
x1, x2 = jnp.split(x, 2, axis=-1)
|
| 68 |
+
return jnp.concatenate([-x2, x1], axis=-1)
|
| 69 |
+
|
| 70 |
+
seq_len = xq.shape[2]
|
| 71 |
+
head_dim = xq.shape[3]
|
| 72 |
+
|
| 73 |
+
freqs = freqs_cis[:seq_len, :]
|
| 74 |
+
half_dim = head_dim // 2
|
| 75 |
+
cos = freqs[:, :half_dim]
|
| 76 |
+
sin = freqs[:, half_dim:]
|
| 77 |
+
|
| 78 |
+
cos = jnp.repeat(cos, 2, axis=-1)
|
| 79 |
+
sin = jnp.repeat(sin, 2, axis=-1)
|
| 80 |
+
|
| 81 |
+
cos = cos[None, None, :, :]
|
| 82 |
+
sin = sin[None, None, :, :]
|
| 83 |
+
|
| 84 |
+
xq_out = (xq * cos) + (rotate_half(xq) * sin)
|
| 85 |
+
xk_out = (xk * cos) + (rotate_half(xk) * sin)
|
| 86 |
+
|
| 87 |
+
return xq_out, xk_out
|
| 88 |
+
|
| 89 |
+
class DepthwiseSeparableConv1D(nn.Module):
|
| 90 |
+
channels: int
|
| 91 |
+
kernel_size: int = 3
|
| 92 |
+
dtype: Any = jnp.bfloat16
|
| 93 |
+
|
| 94 |
+
@nn.compact
|
| 95 |
+
def __call__(self, x):
|
| 96 |
+
depthwise = nn.Conv(
|
| 97 |
+
features=self.channels,
|
| 98 |
+
kernel_size=(self.kernel_size,),
|
| 99 |
+
feature_group_count=self.channels,
|
| 100 |
+
padding='SAME',
|
| 101 |
+
use_bias=False,
|
| 102 |
+
dtype=self.dtype,
|
| 103 |
+
name='depthwise'
|
| 104 |
+
)(x)
|
| 105 |
+
|
| 106 |
+
pointwise = nn.Conv(
|
| 107 |
+
features=self.channels,
|
| 108 |
+
kernel_size=(1,),
|
| 109 |
+
use_bias=False,
|
| 110 |
+
dtype=self.dtype,
|
| 111 |
+
name='pointwise'
|
| 112 |
+
)(depthwise)
|
| 113 |
+
|
| 114 |
+
return pointwise
|
| 115 |
+
|
| 116 |
+
class LocalContextCNN(nn.Module):
|
| 117 |
+
d_model: int
|
| 118 |
+
dropout: float
|
| 119 |
+
dtype: Any = jnp.bfloat16
|
| 120 |
+
|
| 121 |
+
@nn.compact
|
| 122 |
+
def __call__(self, x, training: bool = False):
|
| 123 |
+
conv3 = DepthwiseSeparableConv1D(self.d_model, 3, self.dtype, name='conv3')(x)
|
| 124 |
+
conv5 = DepthwiseSeparableConv1D(self.d_model, 5, self.dtype, name='conv5')(x)
|
| 125 |
+
conv7 = DepthwiseSeparableConv1D(self.d_model, 7, self.dtype, name='conv7')(x)
|
| 126 |
+
|
| 127 |
+
gate = nn.Dense(self.d_model * 3, dtype=self.dtype, name='fusion_gate')(x)
|
| 128 |
+
gate = nn.sigmoid(gate)
|
| 129 |
+
g3, g5, g7 = jnp.split(gate, 3, axis=-1)
|
| 130 |
+
|
| 131 |
+
out = g3 * conv3 + g5 * conv5 + g7 * conv7
|
| 132 |
+
|
| 133 |
+
scale = self.param('layer_scale', nn.initializers.constant(1e-6), (self.d_model,))
|
| 134 |
+
out = out * scale
|
| 135 |
+
|
| 136 |
+
return nn.Dropout(self.dropout, deterministic=not training)(out)
|
| 137 |
+
|
| 138 |
+
class MinGRUCell(nn.Module):
|
| 139 |
+
hidden_size: int
|
| 140 |
+
dtype: Any = jnp.bfloat16
|
| 141 |
+
|
| 142 |
+
@nn.compact
|
| 143 |
+
def __call__(self, x, h):
|
| 144 |
+
z = nn.Dense(self.hidden_size, use_bias=True, dtype=self.dtype, name='gate')(x)
|
| 145 |
+
h_tilde = nn.Dense(self.hidden_size, use_bias=True, dtype=self.dtype, name='candidate')(x)
|
| 146 |
+
|
| 147 |
+
z = nn.sigmoid(z)
|
| 148 |
+
h_tilde = nn.tanh(h_tilde)
|
| 149 |
+
h_new = (1 - z) * h + z * h_tilde
|
| 150 |
+
|
| 151 |
+
return h_new
|
| 152 |
+
|
| 153 |
+
class BidirectionalMinGRU(nn.Module):
|
| 154 |
+
hidden_size: int
|
| 155 |
+
dropout: float
|
| 156 |
+
dtype: Any = jnp.bfloat16
|
| 157 |
+
|
| 158 |
+
@nn.compact
|
| 159 |
+
def __call__(self, x, training: bool = False):
|
| 160 |
+
batch_size, seq_len, d_model = x.shape
|
| 161 |
+
|
| 162 |
+
x_proj = nn.Dense(self.hidden_size, dtype=self.dtype, name='input_proj')(x)
|
| 163 |
+
|
| 164 |
+
class ScanRNNCell(nn.Module):
|
| 165 |
+
hidden_size: int
|
| 166 |
+
dtype: Any = jnp.bfloat16
|
| 167 |
+
|
| 168 |
+
@nn.compact
|
| 169 |
+
def __call__(self, h, x_t):
|
| 170 |
+
cell = MinGRUCell(self.hidden_size, dtype=self.dtype)
|
| 171 |
+
h_new = cell(x_t, h)
|
| 172 |
+
return h_new, h_new
|
| 173 |
+
|
| 174 |
+
ForwardScanner = nn.scan(
|
| 175 |
+
ScanRNNCell,
|
| 176 |
+
variable_broadcast='params',
|
| 177 |
+
split_rngs={'params': False},
|
| 178 |
+
in_axes=1,
|
| 179 |
+
out_axes=1
|
| 180 |
+
)
|
| 181 |
+
|
| 182 |
+
h0_forward = jnp.zeros((batch_size, self.hidden_size), dtype=self.dtype)
|
| 183 |
+
_, h_forward = ForwardScanner(
|
| 184 |
+
hidden_size=self.hidden_size,
|
| 185 |
+
dtype=self.dtype,
|
| 186 |
+
name='forward_cell'
|
| 187 |
+
)(h0_forward, x_proj)
|
| 188 |
+
|
| 189 |
+
BackwardScanner = nn.scan(
|
| 190 |
+
ScanRNNCell,
|
| 191 |
+
variable_broadcast='params',
|
| 192 |
+
split_rngs={'params': False},
|
| 193 |
+
in_axes=1,
|
| 194 |
+
out_axes=1
|
| 195 |
+
)
|
| 196 |
+
|
| 197 |
+
h0_backward = jnp.zeros((batch_size, self.hidden_size), dtype=self.dtype)
|
| 198 |
+
x_proj_reversed = jnp.flip(x_proj, axis=1)
|
| 199 |
+
_, h_backward = BackwardScanner(
|
| 200 |
+
hidden_size=self.hidden_size,
|
| 201 |
+
dtype=self.dtype,
|
| 202 |
+
name='backward_cell'
|
| 203 |
+
)(h0_backward, x_proj_reversed)
|
| 204 |
+
h_backward = jnp.flip(h_backward, axis=1)
|
| 205 |
+
|
| 206 |
+
h_bi = jnp.concatenate([h_forward, h_backward], axis=-1)
|
| 207 |
+
out = nn.Dense(d_model, dtype=self.dtype, name='output_proj')(h_bi)
|
| 208 |
+
|
| 209 |
+
scale = self.param('layer_scale', nn.initializers.constant(1e-6), (d_model,))
|
| 210 |
+
out = out * scale
|
| 211 |
+
|
| 212 |
+
return nn.Dropout(self.dropout, deterministic=not training)(out)
|
| 213 |
+
|
| 214 |
+
class GroupedQueryAttention(nn.Module):
|
| 215 |
+
d_model: int
|
| 216 |
+
n_heads: int
|
| 217 |
+
n_kv_heads: int
|
| 218 |
+
dropout: float
|
| 219 |
+
freqs_cis: jnp.ndarray
|
| 220 |
+
yarn_mscale: float
|
| 221 |
+
alibi_bias: Optional[jnp.ndarray]
|
| 222 |
+
alibi_weight: float
|
| 223 |
+
dtype: Any = jnp.bfloat16
|
| 224 |
+
|
| 225 |
+
@nn.compact
|
| 226 |
+
def __call__(self, x, mask, training: bool = False):
|
| 227 |
+
B, T, D = x.shape
|
| 228 |
+
head_dim = self.d_model // self.n_heads
|
| 229 |
+
n_rep = self.n_heads // self.n_kv_heads
|
| 230 |
+
|
| 231 |
+
q = nn.Dense(self.d_model, use_bias=False,
|
| 232 |
+
kernel_init=nn.initializers.normal(stddev=0.02),
|
| 233 |
+
dtype=self.dtype, name='q_proj')(x)
|
| 234 |
+
|
| 235 |
+
kv_dim = self.d_model * self.n_kv_heads // self.n_heads
|
| 236 |
+
k = nn.Dense(kv_dim, use_bias=False,
|
| 237 |
+
kernel_init=nn.initializers.normal(stddev=0.02),
|
| 238 |
+
dtype=self.dtype, name='k_proj')(x)
|
| 239 |
+
v = nn.Dense(kv_dim, use_bias=False,
|
| 240 |
+
kernel_init=nn.initializers.normal(stddev=0.02),
|
| 241 |
+
dtype=self.dtype, name='v_proj')(x)
|
| 242 |
+
|
| 243 |
+
q = q.reshape(B, T, self.n_heads, head_dim).transpose(0, 2, 1, 3)
|
| 244 |
+
k = k.reshape(B, T, self.n_kv_heads, head_dim).transpose(0, 2, 1, 3)
|
| 245 |
+
v = v.reshape(B, T, self.n_kv_heads, head_dim).transpose(0, 2, 1, 3)
|
| 246 |
+
|
| 247 |
+
k = jnp.repeat(k, n_rep, axis=1)
|
| 248 |
+
v = jnp.repeat(v, n_rep, axis=1)
|
| 249 |
+
|
| 250 |
+
q, k = apply_rotary_emb(q, k, self.freqs_cis, self.yarn_mscale)
|
| 251 |
+
|
| 252 |
+
scores = jnp.matmul(q, k.transpose(0, 1, 3, 2)) / jnp.sqrt(head_dim).astype(self.dtype)
|
| 253 |
+
|
| 254 |
+
if self.alibi_bias is not None:
|
| 255 |
+
scores = scores * (1 - self.alibi_weight)
|
| 256 |
+
alibi = self.alibi_bias[:, :, :T, :T]
|
| 257 |
+
scores = scores + (alibi * self.alibi_weight)
|
| 258 |
+
|
| 259 |
+
scores = scores + mask
|
| 260 |
+
|
| 261 |
+
attn_weights = nn.softmax(scores, axis=-1)
|
| 262 |
+
attn_weights = nn.Dropout(self.dropout, deterministic=not training)(attn_weights)
|
| 263 |
+
|
| 264 |
+
attn_out = jnp.matmul(attn_weights, v)
|
| 265 |
+
attn_out = attn_out.transpose(0, 2, 1, 3).reshape(B, T, D)
|
| 266 |
+
|
| 267 |
+
out = nn.Dense(self.d_model, use_bias=False,
|
| 268 |
+
kernel_init=nn.initializers.normal(stddev=0.02),
|
| 269 |
+
dtype=self.dtype, name='o_proj')(attn_out)
|
| 270 |
+
|
| 271 |
+
return nn.Dropout(self.dropout, deterministic=not training)(out)
|
| 272 |
+
|
| 273 |
+
class SwiGLU(nn.Module):
|
| 274 |
+
d_model: int
|
| 275 |
+
ff_dim: int
|
| 276 |
+
dropout: float
|
| 277 |
+
dtype: Any = jnp.bfloat16
|
| 278 |
+
|
| 279 |
+
@nn.compact
|
| 280 |
+
def __call__(self, x, training: bool = False):
|
| 281 |
+
gate = nn.Dense(self.ff_dim, use_bias=False,
|
| 282 |
+
kernel_init=nn.initializers.normal(stddev=0.02),
|
| 283 |
+
dtype=self.dtype, name='gate_proj')(x)
|
| 284 |
+
up = nn.Dense(self.ff_dim, use_bias=False,
|
| 285 |
+
kernel_init=nn.initializers.normal(stddev=0.02),
|
| 286 |
+
dtype=self.dtype, name='up_proj')(x)
|
| 287 |
+
hidden = nn.silu(gate) * up
|
| 288 |
+
out = nn.Dense(self.d_model, use_bias=False,
|
| 289 |
+
kernel_init=nn.initializers.normal(stddev=0.02),
|
| 290 |
+
dtype=self.dtype, name='down_proj')(hidden)
|
| 291 |
+
return nn.Dropout(self.dropout, deterministic=not training)(out)
|
| 292 |
+
|
| 293 |
+
class HybridTransformerBlock(nn.Module):
|
| 294 |
+
d_model: int
|
| 295 |
+
n_heads: int
|
| 296 |
+
n_kv_heads: int
|
| 297 |
+
ff_dim: int
|
| 298 |
+
dropout: float
|
| 299 |
+
freqs_cis: jnp.ndarray
|
| 300 |
+
yarn_mscale: float
|
| 301 |
+
alibi_bias: Optional[jnp.ndarray]
|
| 302 |
+
alibi_weight: float
|
| 303 |
+
layer_idx: int
|
| 304 |
+
layer_drop_prob: float = 0.0
|
| 305 |
+
use_cnn: bool = True
|
| 306 |
+
use_rnn: bool = True
|
| 307 |
+
rnn_hidden: int = 512
|
| 308 |
+
dtype: Any = jnp.bfloat16
|
| 309 |
+
|
| 310 |
+
@nn.compact
|
| 311 |
+
def __call__(self, x, mask, training: bool = False):
|
| 312 |
+
scale = 1.0
|
| 313 |
+
|
| 314 |
+
if self.use_rnn:
|
| 315 |
+
h_rnn = RMSNorm(dtype=self.dtype, name='rnn_norm')(x)
|
| 316 |
+
h_rnn = BidirectionalMinGRU(
|
| 317 |
+
self.rnn_hidden, self.dropout, dtype=self.dtype, name='bidirectional_rnn'
|
| 318 |
+
)(h_rnn, training)
|
| 319 |
+
x = x + h_rnn * scale
|
| 320 |
+
|
| 321 |
+
if self.use_cnn:
|
| 322 |
+
h_cnn = RMSNorm(dtype=self.dtype, name='cnn_norm')(x)
|
| 323 |
+
h_cnn = LocalContextCNN(
|
| 324 |
+
self.d_model, self.dropout, dtype=self.dtype, name='local_cnn'
|
| 325 |
+
)(h_cnn, training)
|
| 326 |
+
x = x + h_cnn * scale
|
| 327 |
+
|
| 328 |
+
h = RMSNorm(dtype=self.dtype, name='attn_norm')(x)
|
| 329 |
+
h = GroupedQueryAttention(
|
| 330 |
+
self.d_model, self.n_heads, self.n_kv_heads, self.dropout,
|
| 331 |
+
self.freqs_cis, self.yarn_mscale, self.alibi_bias,
|
| 332 |
+
self.alibi_weight, dtype=self.dtype, name='attn'
|
| 333 |
+
)(h, mask, training)
|
| 334 |
+
x = x + h * scale
|
| 335 |
+
|
| 336 |
+
h = RMSNorm(dtype=self.dtype, name='ffn_norm')(x)
|
| 337 |
+
h = SwiGLU(self.d_model, self.ff_dim, self.dropout,
|
| 338 |
+
dtype=self.dtype, name='ffn')(h, training)
|
| 339 |
+
x = x + h * scale
|
| 340 |
+
|
| 341 |
+
return x
|
| 342 |
+
|
| 343 |
+
class SAM1HybridModel(nn.Module):
|
| 344 |
+
vocab_size: int
|
| 345 |
+
d_model: int
|
| 346 |
+
n_layers: int
|
| 347 |
+
n_heads: int
|
| 348 |
+
n_kv_heads: int
|
| 349 |
+
ff_dim: int
|
| 350 |
+
max_len: int
|
| 351 |
+
dropout: float = 0.1
|
| 352 |
+
layer_drop_prob: float = 0.05
|
| 353 |
+
rope_theta: float = 10000.0
|
| 354 |
+
yarn_scale: float = 1.0
|
| 355 |
+
yarn_alpha: float = 1.0
|
| 356 |
+
yarn_beta: float = 32.0
|
| 357 |
+
use_alibi: bool = False
|
| 358 |
+
alibi_weight: float = 0.3
|
| 359 |
+
use_cnn: bool = True
|
| 360 |
+
use_rnn: bool = True
|
| 361 |
+
rnn_hidden: int = 384
|
| 362 |
+
dtype: Any = jnp.bfloat16
|
| 363 |
+
|
| 364 |
+
@nn.compact
|
| 365 |
+
def __call__(self, input_ids, training: bool = False):
|
| 366 |
+
head_dim = self.d_model // self.n_heads
|
| 367 |
+
|
| 368 |
+
freqs_cis, yarn_mscale = precompute_yarn_freqs(
|
| 369 |
+
head_dim, self.max_len, self.rope_theta,
|
| 370 |
+
self.yarn_scale, self.yarn_alpha, self.yarn_beta, self.dtype
|
| 371 |
+
)
|
| 372 |
+
|
| 373 |
+
alibi_bias = None
|
| 374 |
+
|
| 375 |
+
x = nn.Embed(self.vocab_size, self.d_model,
|
| 376 |
+
embedding_init=nn.initializers.normal(stddev=0.02),
|
| 377 |
+
dtype=self.dtype, name='embed_tokens')(input_ids)
|
| 378 |
+
|
| 379 |
+
seq_len = input_ids.shape[1]
|
| 380 |
+
mask = jnp.tril(jnp.ones((seq_len, seq_len)))
|
| 381 |
+
mask = jnp.where(mask == 0, -1e9, 0.0).astype(self.dtype)
|
| 382 |
+
|
| 383 |
+
for i in range(self.n_layers):
|
| 384 |
+
use_cnn_layer = self.use_cnn and (i % 3 == 0)
|
| 385 |
+
use_rnn_layer = self.use_rnn and (i % 4 == 0)
|
| 386 |
+
|
| 387 |
+
x = HybridTransformerBlock(
|
| 388 |
+
self.d_model, self.n_heads, self.n_kv_heads, self.ff_dim,
|
| 389 |
+
self.dropout, freqs_cis, yarn_mscale, alibi_bias,
|
| 390 |
+
self.alibi_weight, i, self.layer_drop_prob,
|
| 391 |
+
use_cnn_layer, use_rnn_layer, self.rnn_hidden,
|
| 392 |
+
dtype=self.dtype, name=f'layers_{i}'
|
| 393 |
+
)(x, mask, training)
|
| 394 |
+
|
| 395 |
+
x = RMSNorm(dtype=self.dtype, name='norm')(x)
|
| 396 |
+
|
| 397 |
+
logits = nn.Dense(self.vocab_size, use_bias=False,
|
| 398 |
+
kernel_init=nn.initializers.normal(stddev=0.02),
|
| 399 |
+
dtype=self.dtype, name='lm_head')(x)
|
| 400 |
+
|
| 401 |
+
return logits
|
| 402 |
+
|
| 403 |
+
# ==============================================================================
|
| 404 |
+
# MODEL LOADING & GENERATION
|
| 405 |
+
# ==============================================================================
|
| 406 |
+
|
| 407 |
+
class ModelWrapper:
|
| 408 |
+
def __init__(self, model_path: str):
|
| 409 |
+
print("π§ Loading model...")
|
| 410 |
+
|
| 411 |
+
# Load config
|
| 412 |
+
with open(os.path.join(model_path, "config.json"), 'r') as f:
|
| 413 |
+
config = json.load(f)
|
| 414 |
+
|
| 415 |
+
self.vocab_size = config['vocab_size']
|
| 416 |
+
self.d_model = config['d_model']
|
| 417 |
+
self.n_layers = config['n_layers']
|
| 418 |
+
self.n_heads = config['n_heads']
|
| 419 |
+
self.n_kv_heads = config['n_kv_heads']
|
| 420 |
+
self.ff_dim = int(self.d_model * 2.5)
|
| 421 |
+
self.max_len = config['max_len']
|
| 422 |
+
self.use_cnn = config.get('use_cnn', True)
|
| 423 |
+
self.use_rnn = config.get('use_rnn', True)
|
| 424 |
+
self.rnn_hidden = config.get('rnn_hidden', 384)
|
| 425 |
+
|
| 426 |
+
# Load tokenizer
|
| 427 |
+
self.tokenizer = Tokenizer.from_file(os.path.join(model_path, "tokenizer.json"))
|
| 428 |
+
|
| 429 |
+
# Initialize model
|
| 430 |
+
self.model = SAM1HybridModel(
|
| 431 |
+
vocab_size=self.vocab_size,
|
| 432 |
+
d_model=self.d_model,
|
| 433 |
+
n_layers=self.n_layers,
|
| 434 |
+
n_heads=self.n_heads,
|
| 435 |
+
n_kv_heads=self.n_kv_heads,
|
| 436 |
+
ff_dim=self.ff_dim,
|
| 437 |
+
max_len=self.max_len,
|
| 438 |
+
use_cnn=self.use_cnn,
|
| 439 |
+
use_rnn=self.use_rnn,
|
| 440 |
+
rnn_hidden=self.rnn_hidden,
|
| 441 |
+
dtype=jnp.bfloat16
|
| 442 |
+
)
|
| 443 |
+
|
| 444 |
+
# Load weights
|
| 445 |
+
flat_params = load_file(os.path.join(model_path, "model.safetensors"))
|
| 446 |
+
|
| 447 |
+
# Unflatten parameters
|
| 448 |
+
def unflatten_dict(flat_dict, sep='.'):
|
| 449 |
+
result = {}
|
| 450 |
+
for key, value in flat_dict.items():
|
| 451 |
+
parts = key.split(sep)
|
| 452 |
+
d = result
|
| 453 |
+
for part in parts[:-1]:
|
| 454 |
+
if part not in d:
|
| 455 |
+
d[part] = {}
|
| 456 |
+
d = d[part]
|
| 457 |
+
d[parts[-1]] = jnp.array(value)
|
| 458 |
+
return result
|
| 459 |
+
|
| 460 |
+
self.params = {'params': unflatten_dict(flat_params)}
|
| 461 |
+
|
| 462 |
+
print(f"β
Model loaded: {self.d_model}d Γ {self.n_layers}L Γ {self.n_heads}H")
|
| 463 |
+
|
| 464 |
+
def generate(self, prompt: str, max_new_tokens: int = 200,
|
| 465 |
+
temperature: float = 0.8, top_k: int = 50, top_p: float = 0.9):
|
| 466 |
+
# Format prompt correctly (NO newline between User: and Sam:)
|
| 467 |
+
if not prompt.startswith("User:"):
|
| 468 |
+
prompt = f"User: {prompt} Sam:"
|
| 469 |
+
else:
|
| 470 |
+
if " Sam:" not in prompt:
|
| 471 |
+
prompt = prompt + " Sam:"
|
| 472 |
+
|
| 473 |
+
# Tokenize
|
| 474 |
+
encoding = self.tokenizer.encode(prompt)
|
| 475 |
+
input_ids = jnp.array(encoding.ids)[None, :]
|
| 476 |
+
|
| 477 |
+
if input_ids.shape[1] > self.max_len:
|
| 478 |
+
input_ids = input_ids[:, -self.max_len:]
|
| 479 |
+
|
| 480 |
+
rng = random.PRNGKey(42)
|
| 481 |
+
generated_ids = input_ids
|
| 482 |
+
|
| 483 |
+
# Generate tokens
|
| 484 |
+
for _ in range(max_new_tokens):
|
| 485 |
+
logits = self.model.apply(self.params, generated_ids, training=False)
|
| 486 |
+
next_logits = logits[0, -1, :] / temperature
|
| 487 |
+
|
| 488 |
+
# Top-k filtering
|
| 489 |
+
top_k_logits, top_k_indices = jax.lax.top_k(next_logits, top_k)
|
| 490 |
+
|
| 491 |
+
# Top-p (nucleus) filtering
|
| 492 |
+
sorted_logits = jnp.sort(top_k_logits)[::-1]
|
| 493 |
+
sorted_indices = jnp.argsort(top_k_logits)[::-1]
|
| 494 |
+
cumsum_probs = jnp.cumsum(nn.softmax(sorted_logits))
|
| 495 |
+
mask = cumsum_probs <= top_p
|
| 496 |
+
mask = jnp.concatenate([jnp.array([True]), mask[:-1]])
|
| 497 |
+
|
| 498 |
+
filtered_logits = jnp.where(mask, sorted_logits, -1e9)
|
| 499 |
+
|
| 500 |
+
# Sample
|
| 501 |
+
rng, sample_rng = random.split(rng)
|
| 502 |
+
next_token_idx = random.categorical(sample_rng, filtered_logits)
|
| 503 |
+
next_token = top_k_indices[sorted_indices[next_token_idx]][None, None]
|
| 504 |
+
|
| 505 |
+
generated_ids = jnp.concatenate([generated_ids, next_token], axis=1)
|
| 506 |
+
|
| 507 |
+
# Stop on EOS
|
| 508 |
+
if next_token[0, 0] == self.tokenizer.token_to_id("<|endoftext|>"):
|
| 509 |
+
break
|
| 510 |
+
|
| 511 |
+
generated_text = self.tokenizer.decode(generated_ids[0].tolist())
|
| 512 |
+
|
| 513 |
+
# Extract response after "Sam:"
|
| 514 |
+
if "Sam:" in generated_text:
|
| 515 |
+
response = generated_text.split("Sam:")[-1].strip()
|
| 516 |
+
# Clean up
|
| 517 |
+
if "<|endoftext|>" in response:
|
| 518 |
+
response = response.split("<|endoftext|>")[0].strip()
|
| 519 |
+
return response
|
| 520 |
+
else:
|
| 521 |
+
return generated_text
|
| 522 |
+
|
| 523 |
+
# ==============================================================================
|
| 524 |
+
# GRADIO INTERFACE
|
| 525 |
+
# ==============================================================================
|
| 526 |
+
|
| 527 |
+
# Load model
|
| 528 |
+
model = ModelWrapper("Smilyai-labs/MixSam-exp")
|
| 529 |
+
|
| 530 |
+
def chat_fn(message, history, temperature, top_k, top_p, max_tokens):
|
| 531 |
+
# Build conversation context
|
| 532 |
+
conversation = ""
|
| 533 |
+
for user_msg, bot_msg in history:
|
| 534 |
+
conversation += f"User: {user_msg} Sam: {bot_msg} "
|
| 535 |
+
|
| 536 |
+
# Add current message
|
| 537 |
+
conversation += f"User: {message} Sam:"
|
| 538 |
+
|
| 539 |
+
# Generate response
|
| 540 |
+
response = model.generate(
|
| 541 |
+
conversation,
|
| 542 |
+
max_new_tokens=max_tokens,
|
| 543 |
+
temperature=temperature,
|
| 544 |
+
top_k=top_k,
|
| 545 |
+
top_p=top_p
|
| 546 |
+
)
|
| 547 |
+
|
| 548 |
+
return response
|
| 549 |
+
|
| 550 |
+
# Create Gradio interface
|
| 551 |
+
with gr.Blocks(theme=gr.themes.Soft()) as demo:
|
| 552 |
+
gr.Markdown("""
|
| 553 |
+
# π€ SAM1 Hybrid Chat
|
| 554 |
+
### Transformer + CNN + RNN Architecture
|
| 555 |
+
Chat with SAM1, a custom hybrid language model combining:
|
| 556 |
+
- π· **Transformer** attention (GQA + YARN + RoPE)
|
| 557 |
+
- πΆ **CNN** for local context (multi-scale convolutions)
|
| 558 |
+
- π΅ **RNN** for sequential modeling (bidirectional MinGRU)
|
| 559 |
+
""")
|
| 560 |
+
|
| 561 |
+
chatbot = gr.Chatbot(height=500, show_copy_button=True)
|
| 562 |
+
|
| 563 |
+
with gr.Row():
|
| 564 |
+
msg = gr.Textbox(
|
| 565 |
+
placeholder="Type your message here...",
|
| 566 |
+
show_label=False,
|
| 567 |
+
scale=4
|
| 568 |
+
)
|
| 569 |
+
submit = gr.Button("Send", scale=1, variant="primary")
|
| 570 |
+
|
| 571 |
+
with gr.Accordion("βοΈ Generation Settings", open=False):
|
| 572 |
+
with gr.Row():
|
| 573 |
+
temperature = gr.Slider(0.1, 2.0, value=0.8, label="Temperature", step=0.1)
|
| 574 |
+
top_k = gr.Slider(1, 100, value=50, label="Top-K", step=1)
|
| 575 |
+
with gr.Row():
|
| 576 |
+
top_p = gr.Slider(0.1, 1.0, value=0.9, label="Top-P", step=0.05)
|
| 577 |
+
max_tokens = gr.Slider(50, 500, value=200, label="Max Tokens", step=10)
|
| 578 |
+
|
| 579 |
+
clear = gr.Button("ποΈ Clear Chat")
|
| 580 |
+
|
| 581 |
+
# Event handlers
|
| 582 |
+
msg.submit(
|
| 583 |
+
chat_fn,
|
| 584 |
+
inputs=[msg, chatbot, temperature, top_k, top_p, max_tokens],
|
| 585 |
+
outputs=chatbot
|
| 586 |
+
).then(lambda: "", None, msg)
|
| 587 |
+
|
| 588 |
+
submit.click(
|
| 589 |
+
chat_fn,
|
| 590 |
+
inputs=[msg, chatbot, temperature, top_k, top_p, max_tokens],
|
| 591 |
+
outputs=chatbot
|
| 592 |
+
).then(lambda: "", None, msg)
|
| 593 |
+
|
| 594 |
+
clear.click(lambda: None, None, chatbot, queue=False)
|
| 595 |
+
|
| 596 |
+
gr.Markdown("""
|
| 597 |
+
---
|
| 598 |
+
**Model Details:**
|
| 599 |
+
- Architecture: SAM1 Hybrid (Custom)
|
| 600 |
+
- Parameters: ~600M
|
| 601 |
+
- Context Length: 1024 tokens
|
| 602 |
+
- Format: `User: {query} Sam: {response}` (no newlines)
|
| 603 |
+
|
| 604 |
+
**Tips:**
|
| 605 |
+
- Lower temperature (0.3-0.5) for focused responses
|
| 606 |
+
- Higher temperature (0.8-1.2) for creative responses
|
| 607 |
+
- Adjust top-k/top-p for response diversity
|
| 608 |
+
""")
|
| 609 |
+
|
| 610 |
+
if __name__ == "__main__":
|
| 611 |
+
demo.queue().launch()
|