Keeby-smilyai committed on
Commit
0a071aa
·
verified ·
1 Parent(s): 8f29b30

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +611 -0
app.py ADDED
@@ -0,0 +1,611 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import jax
3
+ import jax.numpy as jnp
4
+ from jax import random
5
+ import flax.linen as nn
6
+ from tokenizers import Tokenizer
7
+ from safetensors.flax import load_file
8
+ import json
9
+ import os
10
+ from typing import Any, Optional
11
+ import numpy as np
12
+
13
+ # ==============================================================================
14
+ # MODEL ARCHITECTURE (from your training code)
15
+ # ==============================================================================
16
+
17
class RMSNorm(nn.Module):
    """Root-mean-square normalisation over the last axis with a learned gain.

    The input is upcast to float32 for the statistics, multiplied by the
    learned ``scale`` vector (initialised to ones), and cast to ``dtype``
    on the way out.
    """

    epsilon: float = 1e-5
    dtype: Any = jnp.bfloat16

    @nn.compact
    def __call__(self, x):
        x32 = x.astype(jnp.float32)
        feature_dim = x32.shape[-1]
        gain = self.param('scale', nn.initializers.ones, (feature_dim,))
        mean_square = jnp.mean(jnp.square(x32), axis=-1, keepdims=True)
        # epsilon keeps rsqrt finite for all-zero rows.
        inv_rms = jax.lax.rsqrt(mean_square + self.epsilon)
        normalised = x32 * inv_rms * gain
        return normalised.astype(self.dtype)
28
+
29
def precompute_yarn_freqs(dim: int, end: int, theta: float = 10000.0,
                          scale: float = 1.0, alpha: float = 1.0,
                          beta: float = 32.0, dtype=jnp.bfloat16):
    """Build the rotary cos/sin table, optionally with YaRN frequency scaling.

    Args:
        dim: Per-head width; ``dim // 2`` rotary frequencies are produced.
        end: Number of positions (rows) in the table.
        theta: Base of the geometric frequency progression.
        scale: Context extension factor. Values ``<= 1.0`` disable the YaRN
            correction and leave ``mscale`` at ``1.0``.
        alpha: Rotation count used for the upper end of the correction range.
        beta: Rotation count used for the lower end of the correction range.
        dtype: dtype of the returned table.

    Returns:
        ``(table, mscale)`` where ``table`` has shape ``(end, dim)`` and holds
        ``cos`` in its first ``dim // 2`` columns and ``sin`` in the rest,
        both already multiplied by ``mscale``.
    """
    exponents = jnp.arange(0, dim, 2, dtype=jnp.float32) / dim
    inv_freq = 1.0 / (theta ** exponents)
    yarn_enabled = scale > 1.0

    if yarn_enabled:
        extended_len = int(end * scale)

        def correction_dim(num_rotations):
            # Dimension index at which a frequency completes `num_rotations`
            # turns over the extended context.
            return (dim * jnp.log(extended_len / (num_rotations * 2 * jnp.pi))) / (2 * jnp.log(theta))

        low = jnp.maximum(jnp.floor(correction_dim(beta)), 0).astype(jnp.int32)
        high = jnp.minimum(jnp.ceil(correction_dim(alpha)), dim - 1).astype(jnp.int32)

        # Avoid a zero-width ramp (division by zero) when both ends coincide.
        if low == high:
            high += 0.001

        ramp = (jnp.arange(dim // 2, dtype=jnp.float32) - low) / (high - low)
        inv_freq_mask = 1.0 - jnp.clip(ramp, 0, 1)
        inv_freq = inv_freq / ((1 - inv_freq_mask) * (scale - 1) + 1)

    positions = jnp.arange(end, dtype=jnp.float32)
    angles = jnp.outer(positions, inv_freq)

    mscale = 0.1 * jnp.log(scale) + 1.0 if yarn_enabled else 1.0

    table = jnp.concatenate(
        [jnp.cos(angles) * mscale, jnp.sin(angles) * mscale], axis=-1
    )
    return table.astype(dtype), mscale
64
+
65
def apply_rotary_emb(xq, xk, freqs_cis, mscale=1.0):
    """Apply the precomputed rotary table to query and key tensors.

    Args:
        xq: Queries; axis 2 is read as sequence length and axis 3 as head width.
        xk: Keys, broadcast-compatible with ``xq``.
        freqs_cis: Table whose rows are positions, with ``cos`` values in the
            first ``head_dim // 2`` columns and ``sin`` values in the rest.
        mscale: Accepted for interface compatibility; not used here.

    Returns:
        ``(xq_rotated, xk_rotated)``.

    NOTE(review): cos/sin are widened with an element-wise ``repeat`` while the
    rotation pairs the two halves of the last axis. Checkpoints depend on this
    exact layout, so it is reproduced unchanged — confirm against the training
    code before altering it.
    """
    n_positions, width = xq.shape[2], xq.shape[3]
    table = freqs_cis[:n_positions, :]
    split_at = width // 2

    # Widen each half-table to the full head width and add broadcast axes.
    cos = jnp.repeat(table[:, :split_at], 2, axis=-1)[None, None, :, :]
    sin = jnp.repeat(table[:, split_at:], 2, axis=-1)[None, None, :, :]

    def rotated(t):
        first, second = jnp.split(t, 2, axis=-1)
        return jnp.concatenate([-second, first], axis=-1)

    q_rot = (xq * cos) + (rotated(xq) * sin)
    k_rot = (xk * cos) + (rotated(xk) * sin)
    return q_rot, k_rot
88
+
89
class DepthwiseSeparableConv1D(nn.Module):
    """1-D depthwise convolution followed by a 1x1 pointwise convolution.

    Both convolutions are bias-free and keep the channel count at
    ``channels``; the depthwise stage uses 'SAME' padding.
    """

    channels: int
    kernel_size: int = 3
    dtype: Any = jnp.bfloat16

    @nn.compact
    def __call__(self, x):
        # One filter per channel: feature_group_count == channels.
        per_channel = nn.Conv(
            features=self.channels,
            kernel_size=(self.kernel_size,),
            feature_group_count=self.channels,
            padding='SAME',
            use_bias=False,
            dtype=self.dtype,
            name='depthwise',
        )
        # Kernel of width 1 mixes information across channels only.
        mix_channels = nn.Conv(
            features=self.channels,
            kernel_size=(1,),
            use_bias=False,
            dtype=self.dtype,
            name='pointwise',
        )
        return mix_channels(per_channel(x))
115
+
116
class LocalContextCNN(nn.Module):
    """Gated mix of three depthwise-separable convolutions (kernels 3, 5, 7).

    A sigmoid gate computed from the input weights each branch per feature.
    The mixed output is multiplied by a learned ``layer_scale`` vector
    (initialised to 1e-6) and passed through dropout.

    NOTE(review): the branches use 'SAME' padding, so each position also sees
    later positions — confirm this matches how the checkpoint was trained.
    """

    d_model: int
    dropout: float
    dtype: Any = jnp.bfloat16

    @nn.compact
    def __call__(self, x, training: bool = False):
        branch3, branch5, branch7 = [
            DepthwiseSeparableConv1D(
                self.d_model, width, self.dtype, name=f'conv{width}'
            )(x)
            for width in (3, 5, 7)
        ]

        gate_logits = nn.Dense(
            self.d_model * 3, dtype=self.dtype, name='fusion_gate'
        )(x)
        w3, w5, w7 = jnp.split(nn.sigmoid(gate_logits), 3, axis=-1)

        mixed = w3 * branch3 + w5 * branch5 + w7 * branch7

        layer_scale = self.param(
            'layer_scale', nn.initializers.constant(1e-6), (self.d_model,)
        )
        mixed = mixed * layer_scale

        drop = nn.Dropout(self.dropout, deterministic=not training)
        return drop(mixed)
137
+
138
class MinGRUCell(nn.Module):
    """Minimal GRU step whose gate and candidate depend on the input only.

    Returns ``(1 - z) * h + z * h_tilde`` with ``z = sigmoid(gate(x))`` and
    ``h_tilde = tanh(candidate(x))``.
    """

    hidden_size: int
    dtype: Any = jnp.bfloat16

    @nn.compact
    def __call__(self, x, h):
        gate_layer = nn.Dense(
            self.hidden_size, use_bias=True, dtype=self.dtype, name='gate'
        )
        candidate_layer = nn.Dense(
            self.hidden_size, use_bias=True, dtype=self.dtype, name='candidate'
        )

        update = nn.sigmoid(gate_layer(x))
        proposal = nn.tanh(candidate_layer(x))

        # Convex blend of the previous state and the new proposal.
        return (1 - update) * h + update * proposal
152
+
153
class BidirectionalMinGRU(nn.Module):
    """Runs a MinGRU over the sequence in both directions and fuses the results.

    The input is projected to ``hidden_size``, scanned left-to-right
    (``forward_cell``) and right-to-left (``backward_cell``), concatenated,
    projected back to the input width, multiplied by a learned
    ``layer_scale`` (initialised to 1e-6) and passed through dropout.

    NOTE(review): the backward pass reads later positions of the sequence —
    confirm this is intended for the way the model is used at inference.
    """

    hidden_size: int
    dropout: float
    dtype: Any = jnp.bfloat16

    @nn.compact
    def __call__(self, x, training: bool = False):
        n_batch, _, model_dim = x.shape

        projected = nn.Dense(
            self.hidden_size, dtype=self.dtype, name='input_proj'
        )(x)

        class ScanRNNCell(nn.Module):
            """Adapter giving MinGRUCell the (carry, x) -> (carry, y) scan signature."""

            hidden_size: int
            dtype: Any = jnp.bfloat16

            @nn.compact
            def __call__(self, h, x_t):
                h_next = MinGRUCell(self.hidden_size, dtype=self.dtype)(x_t, h)
                return h_next, h_next

        def run_scan(sequence, cell_name):
            # Scan over axis 1 (time) with parameters shared across steps.
            scanner_cls = nn.scan(
                ScanRNNCell,
                variable_broadcast='params',
                split_rngs={'params': False},
                in_axes=1,
                out_axes=1,
            )
            initial_state = jnp.zeros(
                (n_batch, self.hidden_size), dtype=self.dtype
            )
            scanner = scanner_cls(
                hidden_size=self.hidden_size,
                dtype=self.dtype,
                name=cell_name,
            )
            _, states = scanner(initial_state, sequence)
            return states

        forward_states = run_scan(projected, 'forward_cell')

        # Reverse time, scan, then restore the original ordering.
        reversed_states = run_scan(jnp.flip(projected, axis=1), 'backward_cell')
        backward_states = jnp.flip(reversed_states, axis=1)

        both = jnp.concatenate([forward_states, backward_states], axis=-1)
        fused = nn.Dense(model_dim, dtype=self.dtype, name='output_proj')(both)

        layer_scale = self.param(
            'layer_scale', nn.initializers.constant(1e-6), (model_dim,)
        )
        fused = fused * layer_scale

        drop = nn.Dropout(self.dropout, deterministic=not training)
        return drop(fused)
213
+
214
class GroupedQueryAttention(nn.Module):
    """Multi-head attention with fewer key/value heads than query heads.

    Keys and values are projected to ``n_kv_heads`` heads and repeated
    ``n_heads // n_kv_heads`` times to line up with the queries. Rotary
    embeddings are applied to queries and keys, ``mask`` is added to the
    scaled scores, and dropout is applied to both the attention weights and
    the output projection. When ``alibi_bias`` is given, the scores are
    blended with it using ``alibi_weight``.
    """

    d_model: int
    n_heads: int
    n_kv_heads: int
    dropout: float
    freqs_cis: jnp.ndarray
    yarn_mscale: float
    alibi_bias: Optional[jnp.ndarray]
    alibi_weight: float
    dtype: Any = jnp.bfloat16

    @nn.compact
    def __call__(self, x, mask, training: bool = False):
        B, T, D = x.shape
        head_dim = self.d_model // self.n_heads
        n_rep = self.n_heads // self.n_kv_heads
        kv_dim = self.d_model * self.n_kv_heads // self.n_heads

        def linear(features, name):
            return nn.Dense(
                features,
                use_bias=False,
                kernel_init=nn.initializers.normal(stddev=0.02),
                dtype=self.dtype,
                name=name,
            )

        def to_heads(t, n):
            # (B, T, n * head_dim) -> (B, n, T, head_dim)
            return t.reshape(B, T, n, head_dim).transpose(0, 2, 1, 3)

        q = to_heads(linear(self.d_model, 'q_proj')(x), self.n_heads)
        k = to_heads(linear(kv_dim, 'k_proj')(x), self.n_kv_heads)
        v = to_heads(linear(kv_dim, 'v_proj')(x), self.n_kv_heads)

        # Duplicate each key/value head so every query head has a partner.
        k = jnp.repeat(k, n_rep, axis=1)
        v = jnp.repeat(v, n_rep, axis=1)

        q, k = apply_rotary_emb(q, k, self.freqs_cis, self.yarn_mscale)

        denom = jnp.sqrt(head_dim).astype(self.dtype)
        scores = jnp.matmul(q, k.transpose(0, 1, 3, 2)) / denom

        if self.alibi_bias is not None:
            scores = scores * (1 - self.alibi_weight)
            alibi = self.alibi_bias[:, :, :T, :T]
            scores = scores + (alibi * self.alibi_weight)

        scores = scores + mask

        weights = nn.softmax(scores, axis=-1)
        weights = nn.Dropout(self.dropout, deterministic=not training)(weights)

        context = jnp.matmul(weights, v)
        # (B, heads, T, head_dim) -> (B, T, D)
        context = context.transpose(0, 2, 1, 3).reshape(B, T, D)

        projected = linear(self.d_model, 'o_proj')(context)
        return nn.Dropout(self.dropout, deterministic=not training)(projected)
272
+
273
class SwiGLU(nn.Module):
    """Gated feed-forward block: ``down_proj(silu(gate_proj(x)) * up_proj(x))``.

    All three projections are bias-free; dropout is applied to the output.
    """

    d_model: int
    ff_dim: int
    dropout: float
    dtype: Any = jnp.bfloat16

    @nn.compact
    def __call__(self, x, training: bool = False):
        def linear(features, name):
            return nn.Dense(
                features,
                use_bias=False,
                kernel_init=nn.initializers.normal(stddev=0.02),
                dtype=self.dtype,
                name=name,
            )

        gated = nn.silu(linear(self.ff_dim, 'gate_proj')(x))
        lifted = linear(self.ff_dim, 'up_proj')(x)
        projected = linear(self.d_model, 'down_proj')(gated * lifted)

        drop = nn.Dropout(self.dropout, deterministic=not training)
        return drop(projected)
292
+
293
class HybridTransformerBlock(nn.Module):
    """Pre-norm residual block: optional RNN, optional CNN, attention, then FFN.

    Each enabled sub-layer is applied as ``x = x + sublayer(RMSNorm(x))``.
    ``layer_idx`` and ``layer_drop_prob`` are stored on the module but are
    not read by this forward pass.
    """

    d_model: int
    n_heads: int
    n_kv_heads: int
    ff_dim: int
    dropout: float
    freqs_cis: jnp.ndarray
    yarn_mscale: float
    alibi_bias: Optional[jnp.ndarray]
    alibi_weight: float
    layer_idx: int
    layer_drop_prob: float = 0.0
    use_cnn: bool = True
    use_rnn: bool = True
    rnn_hidden: int = 512
    dtype: Any = jnp.bfloat16

    @nn.compact
    def __call__(self, x, mask, training: bool = False):
        # Residual branches are added at full strength.
        residual_scale = 1.0

        if self.use_rnn:
            rnn_in = RMSNorm(dtype=self.dtype, name='rnn_norm')(x)
            rnn_out = BidirectionalMinGRU(
                hidden_size=self.rnn_hidden,
                dropout=self.dropout,
                dtype=self.dtype,
                name='bidirectional_rnn',
            )(rnn_in, training)
            x = x + rnn_out * residual_scale

        if self.use_cnn:
            cnn_in = RMSNorm(dtype=self.dtype, name='cnn_norm')(x)
            cnn_out = LocalContextCNN(
                d_model=self.d_model,
                dropout=self.dropout,
                dtype=self.dtype,
                name='local_cnn',
            )(cnn_in, training)
            x = x + cnn_out * residual_scale

        attn_in = RMSNorm(dtype=self.dtype, name='attn_norm')(x)
        attn_out = GroupedQueryAttention(
            d_model=self.d_model,
            n_heads=self.n_heads,
            n_kv_heads=self.n_kv_heads,
            dropout=self.dropout,
            freqs_cis=self.freqs_cis,
            yarn_mscale=self.yarn_mscale,
            alibi_bias=self.alibi_bias,
            alibi_weight=self.alibi_weight,
            dtype=self.dtype,
            name='attn',
        )(attn_in, mask, training)
        x = x + attn_out * residual_scale

        ffn_in = RMSNorm(dtype=self.dtype, name='ffn_norm')(x)
        ffn_out = SwiGLU(
            d_model=self.d_model,
            ff_dim=self.ff_dim,
            dropout=self.dropout,
            dtype=self.dtype,
            name='ffn',
        )(ffn_in, training)
        x = x + ffn_out * residual_scale

        return x
342
+
343
class SAM1HybridModel(nn.Module):
    """Decoder-only language model built from ``HybridTransformerBlock`` layers.

    Token ids are embedded, passed through ``n_layers`` blocks under a
    lower-triangular additive mask, normalised, and projected to vocabulary
    logits. The CNN branch is enabled on every third layer and the RNN branch
    on every fourth layer (counting from layer 0) when the respective flag is
    set. ``use_alibi`` is declared but no ALiBi bias is built here.
    """

    vocab_size: int
    d_model: int
    n_layers: int
    n_heads: int
    n_kv_heads: int
    ff_dim: int
    max_len: int
    dropout: float = 0.1
    layer_drop_prob: float = 0.05
    rope_theta: float = 10000.0
    yarn_scale: float = 1.0
    yarn_alpha: float = 1.0
    yarn_beta: float = 32.0
    use_alibi: bool = False
    alibi_weight: float = 0.3
    use_cnn: bool = True
    use_rnn: bool = True
    rnn_hidden: int = 384
    dtype: Any = jnp.bfloat16

    @nn.compact
    def __call__(self, input_ids, training: bool = False):
        head_dim = self.d_model // self.n_heads

        # Rotary table covers max_len positions; longer inputs are not supported.
        freqs_cis, yarn_mscale = precompute_yarn_freqs(
            head_dim, self.max_len, self.rope_theta,
            self.yarn_scale, self.yarn_alpha, self.yarn_beta, self.dtype
        )

        alibi_bias = None

        embed = nn.Embed(
            self.vocab_size,
            self.d_model,
            embedding_init=nn.initializers.normal(stddev=0.02),
            dtype=self.dtype,
            name='embed_tokens',
        )
        hidden = embed(input_ids)

        n_tokens = input_ids.shape[1]
        lower_triangle = jnp.tril(jnp.ones((n_tokens, n_tokens)))
        # 0 where attention is allowed, -1e9 where it is blocked.
        mask = jnp.where(lower_triangle == 0, -1e9, 0.0).astype(self.dtype)

        for layer_idx in range(self.n_layers):
            block = HybridTransformerBlock(
                d_model=self.d_model,
                n_heads=self.n_heads,
                n_kv_heads=self.n_kv_heads,
                ff_dim=self.ff_dim,
                dropout=self.dropout,
                freqs_cis=freqs_cis,
                yarn_mscale=yarn_mscale,
                alibi_bias=alibi_bias,
                alibi_weight=self.alibi_weight,
                layer_idx=layer_idx,
                layer_drop_prob=self.layer_drop_prob,
                use_cnn=self.use_cnn and (layer_idx % 3 == 0),
                use_rnn=self.use_rnn and (layer_idx % 4 == 0),
                rnn_hidden=self.rnn_hidden,
                dtype=self.dtype,
                name=f'layers_{layer_idx}',
            )
            hidden = block(hidden, mask, training)

        hidden = RMSNorm(dtype=self.dtype, name='norm')(hidden)

        lm_head = nn.Dense(
            self.vocab_size,
            use_bias=False,
            kernel_init=nn.initializers.normal(stddev=0.02),
            dtype=self.dtype,
            name='lm_head',
        )
        return lm_head(hidden)
402
+
403
+ # ==============================================================================
404
+ # MODEL LOADING & GENERATION
405
+ # ==============================================================================
406
+
407
class ModelWrapper:
    """Load a SAM1 hybrid checkpoint from a directory and sample text from it.

    ``model_path`` is joined with ``config.json``, ``tokenizer.json`` and
    ``model.safetensors`` using ``os.path.join``, so it must be a directory
    that contains those three files.
    """

    # Marker that ends a generated response.
    EOS_TOKEN = "<|endoftext|>"

    def __init__(self, model_path: str):
        """Read config, tokenizer and weights from ``model_path``.

        Raises:
            OSError: If one of the expected files cannot be opened.
            KeyError: If a required key is missing from ``config.json``.
        """
        print("πŸ”§ Loading model...")

        # Load config
        with open(os.path.join(model_path, "config.json"), 'r', encoding="utf-8") as f:
            config = json.load(f)

        self.vocab_size = config['vocab_size']
        self.d_model = config['d_model']
        self.n_layers = config['n_layers']
        self.n_heads = config['n_heads']
        self.n_kv_heads = config['n_kv_heads']
        # Feed-forward width is derived from d_model, not read from the config.
        self.ff_dim = int(self.d_model * 2.5)
        self.max_len = config['max_len']
        self.use_cnn = config.get('use_cnn', True)
        self.use_rnn = config.get('use_rnn', True)
        self.rnn_hidden = config.get('rnn_hidden', 384)

        # Load tokenizer
        self.tokenizer = Tokenizer.from_file(os.path.join(model_path, "tokenizer.json"))

        # Initialize model
        self.model = SAM1HybridModel(
            vocab_size=self.vocab_size,
            d_model=self.d_model,
            n_layers=self.n_layers,
            n_heads=self.n_heads,
            n_kv_heads=self.n_kv_heads,
            ff_dim=self.ff_dim,
            max_len=self.max_len,
            use_cnn=self.use_cnn,
            use_rnn=self.use_rnn,
            rnn_hidden=self.rnn_hidden,
            dtype=jnp.bfloat16
        )

        # Load weights and rebuild the nested parameter tree Flax expects.
        flat_params = load_file(os.path.join(model_path, "model.safetensors"))
        self.params = {'params': self._unflatten(flat_params)}

        print(f"βœ… Model loaded: {self.d_model}d Γ— {self.n_layers}L Γ— {self.n_heads}H")

    @staticmethod
    def _unflatten(flat_dict, sep='.'):
        """Turn ``{"a.b.c": v}`` into ``{"a": {"b": {"c": v}}}`` with jnp leaves."""
        tree = {}
        for key, value in flat_dict.items():
            *parents, leaf = key.split(sep)
            node = tree
            for part in parents:
                node = node.setdefault(part, {})
            node[leaf] = jnp.array(value)
        return tree

    def generate(self, prompt: str, max_new_tokens: int = 200,
                 temperature: float = 0.8, top_k: int = 50, top_p: float = 0.9,
                 seed: int = 42):
        """Sample a reply to ``prompt`` with temperature, top-k and top-p filtering.

        Args:
            prompt: Raw user text, or an already formatted ``User: ... Sam:``
                conversation. Missing markers are added (no newline between
                ``User:`` and ``Sam:``).
            max_new_tokens: Upper bound on sampled tokens; coerced to ``int``.
            temperature: Logit divisor; clamped to a small positive value.
            top_k: Number of highest-logit candidates kept; coerced to ``int``
                and clamped to ``[1, vocab size]``.
            top_p: Cumulative probability threshold for nucleus filtering.
            seed: PRNG seed. The default reproduces the previously fixed key.

        Returns:
            The decoded text of the newly sampled tokens, cut at the end
            marker and stripped of surrounding whitespace.
        """
        # UI sliders can deliver floats; range() and lax.top_k need ints.
        max_new_tokens = int(max_new_tokens)
        top_k = max(1, int(top_k))
        # Guard against division by zero / negative temperatures.
        temperature = max(float(temperature), 1e-5)

        # Format prompt correctly (NO newline between User: and Sam:)
        if not prompt.startswith("User:"):
            prompt = f"User: {prompt} Sam:"
        elif " Sam:" not in prompt:
            prompt = prompt + " Sam:"

        # Tokenize, keeping only the most recent max_len tokens.
        input_ids = jnp.array(self.tokenizer.encode(prompt).ids)[None, :]
        if input_ids.shape[1] > self.max_len:
            input_ids = input_ids[:, -self.max_len:]
        prompt_len = input_ids.shape[1]

        # Looked up once; may be None if the tokenizer lacks the marker.
        eos_id = self.tokenizer.token_to_id(self.EOS_TOKEN)

        rng = random.PRNGKey(seed)
        generated_ids = input_ids

        for _ in range(max_new_tokens):
            # The model's rotary table only covers max_len positions, so feed
            # a sliding window instead of the ever-growing sequence.
            context = generated_ids[:, -self.max_len:]
            logits = self.model.apply(self.params, context, training=False)
            # float32 keeps the cumulative probabilities below accurate.
            next_logits = logits[0, -1, :].astype(jnp.float32) / temperature

            # Top-k filtering (k cannot exceed the vocabulary size).
            k = min(top_k, next_logits.shape[-1])
            top_k_logits, top_k_indices = jax.lax.top_k(next_logits, k)

            # Top-p (nucleus) filtering over the candidates, best first.
            order = jnp.argsort(top_k_logits)[::-1]
            sorted_logits = top_k_logits[order]
            cumsum_probs = jnp.cumsum(jax.nn.softmax(sorted_logits))
            keep = cumsum_probs <= top_p
            # Shift right so the token crossing the threshold is kept and the
            # best candidate is always available.
            keep = jnp.concatenate([jnp.array([True]), keep[:-1]])
            filtered_logits = jnp.where(keep, sorted_logits, -1e9)

            # Sample
            rng, sample_rng = random.split(rng)
            choice = random.categorical(sample_rng, filtered_logits)
            next_token_id = top_k_indices[order[choice]]

            generated_ids = jnp.concatenate(
                [generated_ids, next_token_id[None, None]], axis=1
            )

            # Stop on EOS
            if eos_id is not None and int(next_token_id) == eos_id:
                break

        # Decode only what was sampled so a "Sam:" inside the reply (or in
        # earlier turns of the prompt) cannot truncate the response.
        new_ids = generated_ids[0, prompt_len:].tolist()
        response = self.tokenizer.decode(new_ids)

        if self.EOS_TOKEN in response:
            response = response.split(self.EOS_TOKEN)[0]
        return response.strip()
522
+
523
+ # ==============================================================================
524
+ # GRADIO INTERFACE
525
+ # ==============================================================================
526
+
527
+ # Load the model once at import time. NOTE(review): ModelWrapper joins this string with file names via os.path.join, so it is read as a local directory - confirm the checkpoint files exist at this relative path.
528
+ model = ModelWrapper("Smilyai-labs/MixSam-exp")
529
+
530
def chat_fn(message, history, temperature, top_k, top_p, max_tokens):
    """Generate a reply and return the chat history with the new turn appended.

    Args:
        message: The user's new message.
        history: Previous turns as ``(user_msg, bot_msg)`` pairs; ``None`` or
            empty for a fresh chat.
        temperature: Sampling temperature from the UI.
        top_k: Top-k cutoff from the UI (cast to ``int``).
        top_p: Nucleus threshold from the UI.
        max_tokens: Maximum number of new tokens (cast to ``int``).

    Returns:
        A new list of ``[user_msg, bot_msg]`` pairs, which is the value the
        chatbot output component displays.
    """
    # The clear button resets the chatbot value to None.
    turns = list(history or [])

    # Build conversation context: "User: ... Sam: ... " per past turn,
    # followed by the current message with an open "Sam:" for the model.
    past = "".join(
        f"User: {user_msg} Sam: {bot_msg} " for user_msg, bot_msg in turns
    )
    conversation = past + f"User: {message} Sam:"

    # Generate response; sliders may deliver floats for integer settings.
    response = model.generate(
        conversation,
        max_new_tokens=int(max_tokens),
        temperature=float(temperature),
        top_k=int(top_k),
        top_p=float(top_p)
    )

    # The output is bound to the chatbot, which expects the full history,
    # not just the latest reply.
    return turns + [[message, response]]
549
+
550
+ # Build the Gradio Blocks UI: intro text, chat window, input row, sampling controls and a clear button.
551
+ with gr.Blocks(theme=gr.themes.Soft()) as demo:
552
+ gr.Markdown("""
553
+ # πŸ€– SAM1 Hybrid Chat
554
+ ### Transformer + CNN + RNN Architecture
555
+ Chat with SAM1, a custom hybrid language model combining:
556
+ - πŸ”· **Transformer** attention (GQA + YARN + RoPE)
557
+ - πŸ”Ά **CNN** for local context (multi-scale convolutions)
558
+ - πŸ”΅ **RNN** for sequential modeling (bidirectional MinGRU)
559
+ """)
560
+
561
+ chatbot = gr.Chatbot(height=500, show_copy_button=True)
562
+
563
+ with gr.Row():
564
+ msg = gr.Textbox(
565
+ placeholder="Type your message here...",
566
+ show_label=False,
567
+ scale=4
568
+ )
569
+ submit = gr.Button("Send", scale=1, variant="primary")
570
+
571
+ with gr.Accordion("βš™οΈ Generation Settings", open=False):
572
+ with gr.Row():
573
+ temperature = gr.Slider(0.1, 2.0, value=0.8, label="Temperature", step=0.1)
574
+ top_k = gr.Slider(1, 100, value=50, label="Top-K", step=1)
575
+ with gr.Row():
576
+ top_p = gr.Slider(0.1, 1.0, value=0.9, label="Top-P", step=0.05)
577
+ max_tokens = gr.Slider(50, 500, value=200, label="Max Tokens", step=10)
578
+
579
+ clear = gr.Button("πŸ—‘οΈ Clear Chat")
580
+
581
+ # Event handlers: pressing Enter in the textbox and clicking Send both call chat_fn with the same inputs, then empty the textbox; Clear resets the chatbot value to None.
582
+ msg.submit(
583
+ chat_fn,
584
+ inputs=[msg, chatbot, temperature, top_k, top_p, max_tokens],
585
+ outputs=chatbot
586
+ ).then(lambda: "", None, msg)
587
+
588
+ submit.click(
589
+ chat_fn,
590
+ inputs=[msg, chatbot, temperature, top_k, top_p, max_tokens],
591
+ outputs=chatbot
592
+ ).then(lambda: "", None, msg)
593
+
594
+ clear.click(lambda: None, None, chatbot, queue=False)
595
+
596
+ gr.Markdown("""
597
+ ---
598
+ **Model Details:**
599
+ - Architecture: SAM1 Hybrid (Custom)
600
+ - Parameters: ~600M
601
+ - Context Length: 1024 tokens
602
+ - Format: `User: {query} Sam: {response}` (no newlines)
603
+
604
+ **Tips:**
605
+ - Lower temperature (0.3-0.5) for focused responses
606
+ - Higher temperature (0.8-1.2) for creative responses
607
+ - Adjust top-k/top-p for response diversity
608
+ """)
609
+
610
+ if __name__ == "__main__":
611
+ demo.queue().launch()