Manjunath Kudlur committed
Commit 82b533a · 1 Parent(s): 25b818c

Add download progress tracking

Files changed (4)
  1. decoder_worker.js +353 -0
  2. encoder_worker.js +74 -6
  3. index.html +44 -0
  4. streaming_asr.js +92 -13
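
This commit streams each ONNX model download inside the Web Workers and surfaces byte-level progress in the loading overlay. Each worker posts `model_start` when a download begins, repeated `progress` messages of the shape `{ type: 'progress', model, loaded, total, done }` while bytes arrive, and `model_done` once the inference session is created; `streaming_asr.js` folds these into an overall 7-model count that `index.html` renders as a progress bar with per-model byte details. A minimal sketch of consuming that message protocol from the main thread (the handler wiring here is illustrative, not part of this diff):

```js
// Sketch: consuming the progress protocol added in this commit.
// The message shapes match the workers below; this handler is illustrative.
const worker = new Worker('./decoder_worker.js');
worker.onmessage = (e) => {
  const msg = e.data;
  if (msg.type === 'model_start') {
    console.log(`starting ${msg.model}`);
  } else if (msg.type === 'progress') {
    const pct = msg.total ? Math.round((msg.loaded / msg.total) * 100) : 0;
    console.log(`${msg.model}: ${msg.loaded} / ${msg.total} bytes (${pct}%)`);
  } else if (msg.type === 'model_done') {
    console.log(`${msg.model} ready`);
  }
};
```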
decoder_worker.js ADDED
@@ -0,0 +1,353 @@
+ /**
+  * Decoder Worker - Runs adapter + decoder in a separate thread
+  */
+
+ importScripts('https://cdn.jsdelivr.net/npm/onnxruntime-web@1.17.0/dist/ort.min.js');
+
+ // Configure ONNX Runtime to find WASM files from CDN
+ ort.env.wasm.wasmPaths = 'https://cdn.jsdelivr.net/npm/onnxruntime-web@1.17.0/dist/';
+
+ // Helper to fetch model with progress reporting
+ async function fetchModelWithProgress(url, modelName) {
+   const response = await fetch(url);
+   if (!response.ok) {
+     throw new Error(`Failed to fetch ${modelName}: ${response.status}`);
+   }
+
+   const contentLength = response.headers.get('Content-Length');
+   const total = contentLength ? parseInt(contentLength, 10) : 0;
+
+   if (!response.body || !total) {
+     // No streaming support or unknown size - just download
+     const buffer = await response.arrayBuffer();
+     self.postMessage({
+       type: 'progress',
+       model: modelName,
+       loaded: buffer.byteLength,
+       total: buffer.byteLength,
+       done: true
+     });
+     return buffer;
+   }
+
+   const reader = response.body.getReader();
+   const chunks = [];
+   let loaded = 0;
+
+   while (true) {
+     const { done, value } = await reader.read();
+     if (done) break;
+
+     chunks.push(value);
+     loaded += value.length;
+
+     self.postMessage({
+       type: 'progress',
+       model: modelName,
+       loaded,
+       total,
+       done: false
+     });
+   }
+
+   self.postMessage({
+     type: 'progress',
+     model: modelName,
+     loaded: total,
+     total,
+     done: true
+   });
+
+   // Combine chunks into single ArrayBuffer
+   const result = new Uint8Array(loaded);
+   let offset = 0;
+   for (const chunk of chunks) {
+     result.set(chunk, offset);
+     offset += chunk.length;
+   }
+
+   return result.buffer;
+ }
+
+ // Model config
+ let cfg = null;
+ let tailLatency = 0;
+
+ // Sessions
+ let adapterSession = null;
+ let decoderInitSession = null;
+ let decoderStepSession = null;
+
+ // Decoder state
+ let crossCache = null;
+ let selfCache = null;
+
+ // Tokenizer
+ let tokenizer = null;
+
+ // Accumulated features
+ let accumulatedFeatures = null;
+ let currentSegmentId = null;
+
+ class MoonshineTokenizer {
+   constructor() {
+     this.decoder = null;
+     this.vocab = null;
+   }
+
+   load(tokenizerJson) {
+     this.vocab = tokenizerJson.model.vocab;
+     this.decoder = Object.fromEntries(
+       Object.entries(this.vocab).map(([k, v]) => [v, k])
+     );
+   }
+
+   decode(tokenIds, skipSpecial = true) {
+     const specialTokens = new Set([0, 1, 2]);
+     let text = '';
+
+     for (const id of tokenIds) {
+       if (skipSpecial && specialTokens.has(id)) continue;
+       const token = this.decoder[id] || '';
+       text += token;
+     }
+
+     // Handle various space placeholder representations
+     text = text.replace(/\u0120/g, ' '); // Ġ (GPT-2 style)
+     text = text.replace(/Ġ/g, ' '); // Literal Ġ character
+     text = text.replace(/▁/g, ' '); // SentencePiece style (U+2581)
+     text = text.replace(/\u010a/g, '\n'); // Newline marker
+
+     return text.trim();
+   }
+ }
+
+ async function runAdapter(features, dims) {
+   const feeds = {
+     'encoder_output': new ort.Tensor('float32', features, dims)
+   };
+   const results = await adapterSession.run(feeds);
+   return results.context;
+ }
+
+ async function initDecoderCache(context) {
+   const feeds = { 'context': context };
+   const results = await decoderInitSession.run(feeds);
+
+   // Store cross-attention cache (even-indexed layers)
+   crossCache = [];
+   for (let i = 0; i < cfg.depth * 2; i++) {
+     if ((i + 1) % 2 === 0) {
+       crossCache.push({
+         k: results[`cache_${i}_k`],
+         v: results[`cache_${i}_v`]
+       });
+     }
+   }
+
+   // Initialize empty self-attention cache
+   selfCache = [];
+   for (let i = 0; i < cfg.depth; i++) {
+     selfCache.push({
+       k: new ort.Tensor('float32', new Float32Array(0), [1, cfg.nheads, 0, cfg.head_dim]),
+       v: new ort.Tensor('float32', new Float32Array(0), [1, cfg.nheads, 0, cfg.head_dim])
+     });
+   }
+ }
+
+ async function decodeStep(tokenId, position) {
+   const feeds = {
+     'token_id': new ort.Tensor('int64', BigInt64Array.from([BigInt(tokenId)]), [1, 1]),
+     'position': new ort.Tensor('int64', BigInt64Array.from([BigInt(position)]), [1])
+   };
+
+   // Add cache inputs
+   let selfIdx = 0;
+   let crossIdx = 0;
+   for (let i = 0; i < cfg.depth * 2; i++) {
+     if ((i + 1) % 2 !== 0) {
+       feeds[`in_cache_${i}_k`] = selfCache[selfIdx].k;
+       feeds[`in_cache_${i}_v`] = selfCache[selfIdx].v;
+       selfIdx++;
+     } else {
+       feeds[`in_cache_${i}_k`] = crossCache[crossIdx].k;
+       feeds[`in_cache_${i}_v`] = crossCache[crossIdx].v;
+       crossIdx++;
+     }
+   }
+
+   const results = await decoderStepSession.run(feeds);
+
+   // Update self-attention cache
+   selfIdx = 0;
+   for (let i = 0; i < cfg.depth * 2; i++) {
+     if ((i + 1) % 2 !== 0) {
+       selfCache[selfIdx] = {
+         k: results[`out_cache_${i}_k`],
+         v: results[`out_cache_${i}_v`]
+       };
+       selfIdx++;
+     }
+   }
+
+   return results.logits;
+ }
+
+ async function decodeAccumulated() {
+   if (!accumulatedFeatures || accumulatedFeatures.dims[1] === 0) {
+     return '';
+   }
+
+   try {
+     const context = await runAdapter(accumulatedFeatures.data, accumulatedFeatures.dims);
+     await initDecoderCache(context);
+
+     const numFrames = accumulatedFeatures.dims[1];
+     const maxTokens = Math.max(10, Math.floor(numFrames * 1.5));
+
+     const tokens = [1]; // BOS
+     for (let step = 0; step < maxTokens; step++) {
+       const logits = await decodeStep(tokens[tokens.length - 1], step);
+
+       let maxIdx = 0;
+       let maxVal = logits.data[0];
+       for (let i = 1; i < cfg.vocab_size; i++) {
+         if (logits.data[i] > maxVal) {
+           maxVal = logits.data[i];
+           maxIdx = i;
+         }
+       }
+
+       tokens.push(maxIdx);
+       if (maxIdx === 2) break; // EOS
+     }
+
+     return tokenizer.decode(tokens, true);
+   } catch (e) {
+     console.error('Decode error:', e);
+     return '';
+   }
+ }
+
+ self.onmessage = async function(e) {
+   const { type, data } = e.data;
+
+   switch (type) {
+     case 'init': {
+       try {
+         cfg = data.cfg;
+         const onnxUrl = data.onnxUrl;
+         const modelName = data.modelName;
+         const dtype = 'fp32';
+
+         tailLatency = cfg.n_future * cfg.encoder_depth;
+
+         // Load tokenizer
+         self.postMessage({ type: 'status', message: 'Loading tokenizer...' });
+         self.postMessage({ type: 'model_start', model: 'Tokenizer' });
+         const tokenizerResponse = await fetch(`${onnxUrl}/tokenizer.json`);
+         const tokenizerJson = await tokenizerResponse.json();
+         tokenizer = new MoonshineTokenizer();
+         tokenizer.load(tokenizerJson);
+         self.postMessage({ type: 'model_done', model: 'Tokenizer' });
+
+         // Initialize adapter
+         const adapterUrl = `${onnxUrl}/adapter_${modelName}_${dtype}.onnx`;
+         self.postMessage({ type: 'status', message: 'Loading adapter...' });
+         self.postMessage({ type: 'model_start', model: 'Adapter' });
+         const adapterBuffer = await fetchModelWithProgress(adapterUrl, 'Adapter');
+         adapterSession = await ort.InferenceSession.create(adapterBuffer);
+         self.postMessage({ type: 'model_done', model: 'Adapter' });
+
+         // Initialize decoder init
+         const decInitUrl = `${onnxUrl}/decoder_init_${modelName}_${dtype}.onnx`;
+         self.postMessage({ type: 'status', message: 'Loading decoder (init)...' });
+         self.postMessage({ type: 'model_start', model: 'Decoder Init' });
+         const decInitBuffer = await fetchModelWithProgress(decInitUrl, 'Decoder Init');
+         decoderInitSession = await ort.InferenceSession.create(decInitBuffer);
+         self.postMessage({ type: 'model_done', model: 'Decoder Init' });
+
+         // Initialize decoder step
+         const decStepUrl = `${onnxUrl}/decoder_step_${modelName}_${dtype}.onnx`;
+         self.postMessage({ type: 'status', message: 'Loading decoder (step)...' });
+         self.postMessage({ type: 'model_start', model: 'Decoder Step' });
+         const decStepBuffer = await fetchModelWithProgress(decStepUrl, 'Decoder Step');
+         decoderStepSession = await ort.InferenceSession.create(decStepBuffer);
+         self.postMessage({ type: 'model_done', model: 'Decoder Step' });
+
+         self.postMessage({ type: 'ready' });
+       } catch (err) {
+         self.postMessage({ type: 'error', message: err.message });
+       }
+       break;
+     }
+
+     case 'segment_start': {
+       accumulatedFeatures = null;
+       currentSegmentId = data.segmentId;
+       self.postMessage({ type: 'live_caption', text: '' });
+       break;
+     }
+
+     case 'segment_end': {
+       if (data.segmentId !== currentSegmentId) break;
+
+       const text = await decodeAccumulated();
+       self.postMessage({
+         type: 'transcript',
+         segmentId: data.segmentId,
+         text: text
+       });
+
+       accumulatedFeatures = null;
+       currentSegmentId = null;
+       self.postMessage({ type: 'live_caption', text: '' });
+       break;
+     }
+
+     case 'features': {
+       if (data.segmentId !== currentSegmentId) break;
+
+       const newFeatures = {
+         data: new Float32Array(data.features),
+         dims: data.dims
+       };
+
+       console.log(`Decoder received ${data.dims[1]} frames, accumulated: ${accumulatedFeatures ? accumulatedFeatures.dims[1] : 0}`);
+
+       if (accumulatedFeatures === null) {
+         accumulatedFeatures = newFeatures;
+       } else {
+         // Trim last tailLatency frames from accumulated
+         const numFrames = accumulatedFeatures.dims[1];
+         const keepFrames = Math.max(0, numFrames - tailLatency);
+
+         if (keepFrames > 0) {
+           const totalFrames = keepFrames + newFeatures.dims[1];
+           const combined = new Float32Array(totalFrames * cfg.dim);
+
+           // Copy kept frames
+           for (let f = 0; f < keepFrames; f++) {
+             for (let d = 0; d < cfg.dim; d++) {
+               combined[f * cfg.dim + d] = accumulatedFeatures.data[f * cfg.dim + d];
+             }
+           }
+           // Copy new frames
+           combined.set(newFeatures.data, keepFrames * cfg.dim);
+
+           accumulatedFeatures = {
+             data: combined,
+             dims: [1, totalFrames, cfg.dim]
+           };
+         } else {
+           accumulatedFeatures = newFeatures;
+         }
+       }
+
+       // Live caption
+       const partialText = await decodeAccumulated();
+       self.postMessage({ type: 'live_caption', text: partialText });
+       break;
+     }
+   }
+ };
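
`fetchModelWithProgress` reads `response.body` through a `ReadableStream` reader so it can report bytes as they arrive, and falls back to a plain `arrayBuffer()` download when the body is not streamable or `Content-Length` is missing. One caveat worth noting: on compressed (e.g. gzip) responses, `Content-Length` reflects compressed bytes while the reader yields decompressed bytes, so `loaded` can overshoot `total`. For reference, the same pattern reduced to a callback-based helper usable outside a worker (a sketch, not part of the diff; `fetchWithProgress` and `onProgress` are illustrative names):

```js
// Sketch: the streaming-download pattern from decoder_worker.js with a plain
// callback instead of postMessage. Assumes fetch/ReadableStream support.
async function fetchWithProgress(url, onProgress) {
  const response = await fetch(url);
  if (!response.ok) throw new Error(`Failed to fetch ${url}: ${response.status}`);
  const total = parseInt(response.headers.get('Content-Length') || '0', 10);
  if (!response.body || !total) return response.arrayBuffer(); // fallback path
  const reader = response.body.getReader();
  const chunks = [];
  let loaded = 0;
  for (;;) {
    const { done, value } = await reader.read();
    if (done) break;
    chunks.push(value);
    loaded += value.length;
    onProgress?.(loaded, total); // report bytes as they arrive
  }
  // Combine chunks into a single ArrayBuffer
  const out = new Uint8Array(loaded);
  let offset = 0;
  for (const c of chunks) { out.set(c, offset); offset += c.length; }
  return out.buffer;
}
```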
encoder_worker.js CHANGED
@@ -7,6 +7,68 @@ importScripts('https://cdn.jsdelivr.net/npm/onnxruntime-web@1.17.0/dist/ort.min.
 // Configure ONNX Runtime to find WASM files from CDN
 ort.env.wasm.wasmPaths = 'https://cdn.jsdelivr.net/npm/onnxruntime-web@1.17.0/dist/';
 
+// Helper to fetch model with progress reporting
+async function fetchModelWithProgress(url, modelName) {
+  const response = await fetch(url);
+  if (!response.ok) {
+    throw new Error(`Failed to fetch ${modelName}: ${response.status}`);
+  }
+
+  const contentLength = response.headers.get('Content-Length');
+  const total = contentLength ? parseInt(contentLength, 10) : 0;
+
+  if (!response.body || !total) {
+    // No streaming support or unknown size - just download
+    const buffer = await response.arrayBuffer();
+    self.postMessage({
+      type: 'progress',
+      model: modelName,
+      loaded: buffer.byteLength,
+      total: buffer.byteLength,
+      done: true
+    });
+    return buffer;
+  }
+
+  const reader = response.body.getReader();
+  const chunks = [];
+  let loaded = 0;
+
+  while (true) {
+    const { done, value } = await reader.read();
+    if (done) break;
+
+    chunks.push(value);
+    loaded += value.length;
+
+    self.postMessage({
+      type: 'progress',
+      model: modelName,
+      loaded,
+      total,
+      done: false
+    });
+  }
+
+  self.postMessage({
+    type: 'progress',
+    model: modelName,
+    loaded: total,
+    total,
+    done: true
+  });
+
+  // Combine chunks into single ArrayBuffer
+  const result = new Uint8Array(loaded);
+  let offset = 0;
+  for (const chunk of chunks) {
+    result.set(chunk, offset);
+    offset += chunk.length;
+  }
+
+  return result.buffer;
+}
+
 // Model config
 let cfg = null;
 let preprocessor = null;
@@ -139,20 +201,26 @@ self.onmessage = async function(e) {
       tailLatency = cfg.n_future * cfg.encoder_depth;
 
       // Initialize preprocessor
+      const prepUrl = `${onnxUrl}/preprocessor_streaming_${modelName}_${dtype}.onnx`;
       self.postMessage({ type: 'status', message: 'Loading preprocessor...' });
-      prepSession = await ort.InferenceSession.create(
-        `${onnxUrl}/preprocessor_streaming_${modelName}_${dtype}.onnx`
-      );
+      self.postMessage({ type: 'model_start', model: 'Preprocessor' });
+      const prepBuffer = await fetchModelWithProgress(prepUrl, 'Preprocessor');
+      prepSession = await ort.InferenceSession.create(prepBuffer);
+      self.postMessage({ type: 'model_done', model: 'Preprocessor' });
+
       prepDim = cfg.dim;
       prepC1 = 2 * cfg.dim;
       prepStateC1 = new Float32Array(4 * cfg.dim);
       prepStateC2 = new Float32Array(4 * prepC1);
 
       // Initialize encoder
+      const encUrl = `${onnxUrl}/encoder_${modelName}_${dtype}.onnx`;
       self.postMessage({ type: 'status', message: 'Loading encoder...' });
-      encSession = await ort.InferenceSession.create(
-        `${onnxUrl}/encoder_${modelName}_${dtype}.onnx`
-      );
+      self.postMessage({ type: 'model_start', model: 'Encoder' });
+      const encBuffer = await fetchModelWithProgress(encUrl, 'Encoder');
+      encSession = await ort.InferenceSession.create(encBuffer);
+      self.postMessage({ type: 'model_done', model: 'Encoder' });
+
       encDim = cfg.dim;
       encNPast = cfg.n_past;
       encNFuture = cfg.n_future;
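
Both workers now repeat the same per-model bracketing: post `model_start`, stream `progress` from `fetchModelWithProgress`, create the session, then post `model_done`. A hypothetical wrapper that captures the sequence the workers inline (`loadModel` does not exist in this diff; it is illustrative only):

```js
// Sketch: the load sequence each worker repeats per ONNX model.
async function loadModel(url, name) {
  self.postMessage({ type: 'model_start', model: name });
  const buffer = await fetchModelWithProgress(url, name); // emits 'progress' messages
  const session = await ort.InferenceSession.create(buffer);
  self.postMessage({ type: 'model_done', model: name });
  return session;
}
```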
index.html CHANGED
@@ -325,6 +325,43 @@
 
 .loading-text {
   color: #00d4ff;
+  font-size: 18px;
+  margin-bottom: 20px;
+}
+
+.loading-progress {
+  width: 300px;
+  margin: 0 auto;
+}
+
+.loading-progress-bar {
+  height: 8px;
+  background: #333;
+  border-radius: 4px;
+  overflow: hidden;
+  margin-bottom: 10px;
+}
+
+.loading-progress-fill {
+  height: 100%;
+  background: linear-gradient(90deg, #00d4ff, #00ff88);
+  width: 0%;
+  transition: width 0.3s ease;
+  border-radius: 4px;
+}
+
+.loading-progress-text {
+  font-size: 13px;
+  color: #888;
+  margin-bottom: 15px;
+}
+
+.loading-details {
+  font-size: 12px;
+  color: #666;
+  font-family: monospace;
+  max-height: 60px;
+  overflow: hidden;
 }
 
 .error-message {
@@ -346,6 +383,13 @@
   <div class="loading-content">
     <div class="loading-spinner"></div>
     <div class="loading-text" id="loadingText">Loading models...</div>
+    <div class="loading-progress">
+      <div class="loading-progress-bar">
+        <div class="loading-progress-fill" id="loadingProgressFill"></div>
+      </div>
+      <div class="loading-progress-text" id="loadingProgressText">0 / 7 models</div>
+    </div>
+    <div class="loading-details" id="loadingDetails"></div>
   </div>
 </div>
 
streaming_asr.js CHANGED
@@ -231,36 +231,73 @@ class PipelinedStreamingASR {
     this.onQueueUpdate = null;
   }
 
-  async loadModels(progressCallback) {
+  async loadModels(progressCallback, detailedProgressCallback) {
+    // Track overall progress
+    const totalModels = 7; // VAD, Preprocessor, Encoder, Tokenizer, Adapter, Decoder Init, Decoder Step
+    let completedModels = 0;
+    let currentModel = '';
+    let currentProgress = { loaded: 0, total: 0 };
+
+    const updateProgress = () => {
+      const overallPercent = (completedModels / totalModels) * 100;
+      detailedProgressCallback?.({
+        completedModels,
+        totalModels,
+        overallPercent,
+        currentModel,
+        currentProgress
+      });
+    };
+
     // Initialize VAD
     try {
+      currentModel = 'VAD';
       progressCallback?.('Loading TenVAD...');
+      updateProgress();
       this.vad = new TenVAD(VAD_CHUNK_SAMPLES, 0.5);
       await this.vad.init('./ten_vad.js');
       console.log('Using TenVAD');
+      completedModels++;
+      updateProgress();
     } catch (e) {
       console.warn('TenVAD failed, using SimpleVAD:', e.message);
       this.vad = new SimpleVAD(SAMPLE_RATE, VAD_CHUNK_SAMPLES);
      await this.vad.init();
+      completedModels++;
+      updateProgress();
     }
 
-    // Initialize Encoder Worker
-    progressCallback?.('Loading encoder...');
-    await this.initEncoderWorker();
+    // Initialize Encoder Worker with progress tracking
+    progressCallback?.('Loading encoder models...');
+    await this.initEncoderWorker((model, progress) => {
+      currentModel = model;
+      currentProgress = progress;
+      updateProgress();
+    }, () => {
+      completedModels++;
+      updateProgress();
+    });
 
-    // Initialize Decoder Worker
-    progressCallback?.('Loading decoder...');
-    await this.initDecoderWorker();
+    // Initialize Decoder Worker with progress tracking
+    progressCallback?.('Loading decoder models...');
+    await this.initDecoderWorker((model, progress) => {
+      currentModel = model;
+      currentProgress = progress;
+      updateProgress();
+    }, () => {
+      completedModels++;
+      updateProgress();
+    });
 
     progressCallback?.('Ready!');
   }
 
-  initEncoderWorker() {
+  initEncoderWorker(onProgress, onModelDone) {
     return new Promise((resolve, reject) => {
       this.encoderWorker = new Worker('./encoder_worker.js');
 
       this.encoderWorker.onmessage = (e) => {
-        const { type, data } = e.data;
+        const { type } = e.data;
 
         switch (type) {
           case 'ready':
@@ -273,6 +310,12 @@ class PipelinedStreamingASR {
           case 'status':
             // Progress update from worker
            break;
+          case 'progress':
+            onProgress?.(e.data.model, { loaded: e.data.loaded, total: e.data.total });
+            break;
+          case 'model_done':
+            onModelDone?.(e.data.model);
+            break;
           case 'segment_start':
            this.decoderWorker?.postMessage({ type: 'segment_start', data: { segmentId: e.data.segmentId } });
            break;
@@ -304,7 +347,7 @@ class PipelinedStreamingASR {
     });
   }
 
-  initDecoderWorker() {
+  initDecoderWorker(onProgress, onModelDone) {
     return new Promise((resolve, reject) => {
       this.decoderWorker = new Worker('./decoder_worker.js');
 
@@ -321,6 +364,12 @@ class PipelinedStreamingASR {
            break;
          case 'status':
            break;
+          case 'progress':
+            onProgress?.(e.data.model, { loaded: e.data.loaded, total: e.data.total });
+            break;
+          case 'model_done':
+            onModelDone?.(e.data.model);
+            break;
          case 'transcript':
            this.onTranscript?.(e.data.text, e.data.segmentId);
            break;
@@ -610,6 +659,9 @@ class ASRDemoUI {
   initElements() {
     this.loadingOverlay = document.getElementById('loadingOverlay');
     this.loadingText = document.getElementById('loadingText');
+    this.loadingProgressFill = document.getElementById('loadingProgressFill');
+    this.loadingProgressText = document.getElementById('loadingProgressText');
+    this.loadingDetails = document.getElementById('loadingDetails');
     this.errorMessage = document.getElementById('errorMessage');
     this.statusDot = document.getElementById('statusDot');
     this.statusText = document.getElementById('statusText');
@@ -662,9 +714,14 @@ class ASRDemoUI {
     this.asr.onLiveCaption = (text) => this.updateLiveCaption(text);
     this.asr.onStatusUpdate = (status, text) => this.updateStatus(status, text);
 
-    await this.asr.loadModels((text) => {
-      this.loadingText.textContent = text;
-    });
+    await this.asr.loadModels(
+      (text) => {
+        this.loadingText.textContent = text;
+      },
+      (progress) => {
+        this.updateLoadingProgress(progress);
+      }
+    );
 
     await this.asr.start();
 
@@ -769,6 +826,9 @@ class ASRDemoUI {
 
   showLoading(text) {
     this.loadingText.textContent = text;
+    this.loadingProgressFill.style.width = '0%';
+    this.loadingProgressText.textContent = '0 / 7 models';
+    this.loadingDetails.textContent = '';
     this.loadingOverlay.classList.remove('hidden');
   }
 
@@ -776,6 +836,25 @@ class ASRDemoUI {
     this.loadingOverlay.classList.add('hidden');
   }
 
+  updateLoadingProgress(progress) {
+    const { completedModels, totalModels, currentModel, currentProgress } = progress;
+
+    // Update overall progress bar
+    const overallPercent = (completedModels / totalModels) * 100;
+    this.loadingProgressFill.style.width = `${overallPercent}%`;
+    this.loadingProgressText.textContent = `${completedModels} / ${totalModels} models`;
+
+    // Update details with current model and byte progress
+    if (currentModel && currentProgress.total > 0) {
+      const loadedMB = (currentProgress.loaded / (1024 * 1024)).toFixed(1);
+      const totalMB = (currentProgress.total / (1024 * 1024)).toFixed(1);
+      const percent = Math.round((currentProgress.loaded / currentProgress.total) * 100);
+      this.loadingDetails.textContent = `${currentModel}: ${loadedMB} / ${totalMB} MB (${percent}%)`;
+    } else if (currentModel) {
+      this.loadingDetails.textContent = `Loading ${currentModel}...`;
+    }
+  }
+
   showError(message) {
     this.errorMessage.textContent = message;
     this.errorMessage.classList.add('visible');
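
A quick way to sanity-check `updateLoadingProgress` is to feed it a synthetic progress object and inspect the DOM, assuming the `index.html` elements above are present (the numbers here are made up for illustration):

```js
// Hypothetical smoke test; 'ui' is an existing ASRDemoUI instance built by the page.
ui.updateLoadingProgress({
  completedModels: 3,
  totalModels: 7,
  currentModel: 'Adapter',
  currentProgress: { loaded: 5 * 1024 * 1024, total: 20 * 1024 * 1024 }
});
// loadingProgressText -> "3 / 7 models"
// loadingDetails      -> "Adapter: 5.0 / 20.0 MB (25%)"
```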