/**
 * Decoder Worker - Runs adapter + decoder in a separate thread
 */
importScripts('https://cdn.jsdelivr.net/npm/onnxruntime-web@1.21.0/dist/ort.all.min.js');

// Configure ONNX Runtime paths
ort.env.wasm.wasmPaths = 'https://cdn.jsdelivr.net/npm/onnxruntime-web@1.21.0/dist/';

const MODEL_CACHE_NAME = 'moonshine-models-v1';

// Helper to fetch model with progress reporting and caching
async function fetchModelWithProgress(url, modelName) {
  // Try to get from cache first
  try {
    const cache = await caches.open(MODEL_CACHE_NAME);
    const cachedResponse = await cache.match(url);
    if (cachedResponse) {
      const buffer = await cachedResponse.arrayBuffer();
      self.postMessage({
        type: 'progress', model: modelName,
        loaded: buffer.byteLength, total: buffer.byteLength,
        done: true, cached: true
      });
      console.log(`${modelName} loaded from cache`);
      return buffer;
    }
  } catch (e) {
    console.warn('Cache API not available:', e.message);
  }

  // Fetch from network
  const response = await fetch(url);
  if (!response.ok) {
    throw new Error(`Failed to fetch ${modelName}: ${response.status}`);
  }

  const contentLength = response.headers.get('Content-Length');
  const total = contentLength ? parseInt(contentLength, 10) : 0;

  if (!response.body || !total) {
    // No streaming support or unknown size - just download
    const buffer = await response.arrayBuffer();
    self.postMessage({
      type: 'progress', model: modelName,
      loaded: buffer.byteLength, total: buffer.byteLength,
      done: true
    });
    // Cache the response
    try {
      const cache = await caches.open(MODEL_CACHE_NAME);
      await cache.put(url, new Response(buffer.slice(0)));
    } catch (e) {
      console.warn('Failed to cache model:', e.message);
    }
    return buffer;
  }

  const reader = response.body.getReader();
  const chunks = [];
  let loaded = 0;

  while (true) {
    const { done, value } = await reader.read();
    if (done) break;
    chunks.push(value);
    loaded += value.length;
    self.postMessage({ type: 'progress', model: modelName, loaded, total, done: false });
  }

  self.postMessage({ type: 'progress', model: modelName, loaded: total, total, done: true });

  // Combine chunks into single ArrayBuffer
  const result = new Uint8Array(loaded);
  let offset = 0;
  for (const chunk of chunks) {
    result.set(chunk, offset);
    offset += chunk.length;
  }

  // Cache the result
  try {
    const cache = await caches.open(MODEL_CACHE_NAME);
    await cache.put(url, new Response(result.slice(0)));
    console.log(`${modelName} cached`);
  } catch (e) {
    console.warn('Failed to cache model:', e.message);
  }

  return result.buffer;
}

// Model config
let cfg = null;
let tailLatency = 0;

// Decoding config
const TOKENS_PER_SECOND = 6.5; // Max tokens per second of audio
const FRAME_DURATION_MS = 20;  // Each encoder frame is 20ms

// Check for repetitive token patterns that indicate decoding should stop
function hasRepetition(tokens) {
  const len = tokens.length;
  if (len < 5) return false;

  // Check if the last 5 tokens are all the same
  const last5 = tokens.slice(-5);
  if (last5.every(t => t === last5[0])) {
    return true;
  }

  // Check for three repetitions of the same pair (e.g., [A,B,A,B,A,B])
  if (len >= 6) {
    const pair1 = [tokens[len - 6], tokens[len - 5]];
    const pair2 = [tokens[len - 4], tokens[len - 3]];
    const pair3 = [tokens[len - 2], tokens[len - 1]];
    if (pair1[0] === pair2[0] && pair2[0] === pair3[0] &&
        pair1[1] === pair2[1] && pair2[1] === pair3[1]) {
      return true;
    }
  }

  // Check for two repetitions of the same triple (e.g., [A,B,C,A,B,C])
  if (len >= 6) {
    const triple1 = [tokens[len - 6], tokens[len - 5], tokens[len - 4]];
    const triple2 = [tokens[len - 3], tokens[len - 2], tokens[len - 1]];
    if (triple1[0] === triple2[0] &&
        triple1[1] === triple2[1] &&
        triple1[2] === triple2[2]) {
      return true;
    }
  }

  return false;
}

// Sessions
let adapterSession = null;
let decoderInitSession = null;
let decoderStepSession = null;

// Decoder state
let crossCache = null;
let selfCache = null;

// Tokenizer
let tokenizer = null;

// Accumulated features
let accumulatedFeatures = null;
let currentSegmentId = null;

// Live caption throttling to prevent pipeline backup
let isDecoding = false;
let lastDecodeTime = 0;
let pendingDecode = false;
const MIN_DECODE_INTERVAL_MS = 500; // Don't decode more often than every 500ms for live captions

class MoonshineTokenizer {
  constructor() {
    this.decoder = null;
    this.vocab = null;
  }

  load(tokenizerJson) {
    this.vocab = tokenizerJson.model.vocab;
    this.decoder = Object.fromEntries(
      Object.entries(this.vocab).map(([k, v]) => [v, k])
    );
  }

  decode(tokenIds, skipSpecial = true) {
    const specialTokens = new Set([0, 1, 2]);
    let text = '';
    for (const id of tokenIds) {
      if (skipSpecial && specialTokens.has(id)) continue;
      const token = this.decoder[id] || '';
      text += token;
    }
    // Handle various space placeholder representations
    text = text.replace(/\u0120/g, ' ');  // Ġ (GPT-2 style)
    text = text.replace(/Ġ/g, ' ');       // Literal Ġ character
    text = text.replace(/▁/g, ' ');       // SentencePiece style (U+2581)
    text = text.replace(/\u010a/g, '\n'); // Newline marker
    return text.trim();
  }
}

async function runAdapter(features, dims) {
  const feeds = { 'encoder_output': new ort.Tensor('float32', features, dims) };
  const results = await adapterSession.run(feeds);
  return results.context;
}

async function initDecoderCache(context) {
  const feeds = { 'context': context };
  const results = await decoderInitSession.run(feeds);

  // Store cross-attention cache (every second cache slot: indices 1, 3, 5, ...)
  crossCache = [];
  for (let i = 0; i < cfg.depth * 2; i++) {
    if ((i + 1) % 2 === 0) {
      crossCache.push({ k: results[`cache_${i}_k`], v: results[`cache_${i}_v`] });
    }
  }

  // Initialize empty self-attention cache
  selfCache = [];
  for (let i = 0; i < cfg.depth; i++) {
    selfCache.push({
      k: new ort.Tensor('float32', new Float32Array(0), [1, cfg.nheads, 0, cfg.head_dim]),
      v: new ort.Tensor('float32', new Float32Array(0), [1, cfg.nheads, 0, cfg.head_dim])
    });
  }
}

async function decodeStep(tokenId, position) {
  const feeds = {
    'token_id': new ort.Tensor('int64', BigInt64Array.from([BigInt(tokenId)]), [1, 1]),
    'position': new ort.Tensor('int64', BigInt64Array.from([BigInt(position)]), [1])
  };

  // Add cache inputs (self-attention at even cache indices, cross-attention at odd)
  let selfIdx = 0;
  let crossIdx = 0;
  for (let i = 0; i < cfg.depth * 2; i++) {
    if ((i + 1) % 2 !== 0) {
      feeds[`in_cache_${i}_k`] = selfCache[selfIdx].k;
      feeds[`in_cache_${i}_v`] = selfCache[selfIdx].v;
      selfIdx++;
    } else {
      feeds[`in_cache_${i}_k`] = crossCache[crossIdx].k;
      feeds[`in_cache_${i}_v`] = crossCache[crossIdx].v;
      crossIdx++;
    }
  }

  const results = await decoderStepSession.run(feeds);

  // Update self-attention cache
  selfIdx = 0;
  for (let i = 0; i < cfg.depth * 2; i++) {
    if ((i + 1) % 2 !== 0) {
      selfCache[selfIdx] = { k: results[`out_cache_${i}_k`], v: results[`out_cache_${i}_v`] };
      selfIdx++;
    }
  }

  return results.logits;
}

async function decodeAccumulated() {
  if (!accumulatedFeatures || accumulatedFeatures.dims[1] === 0) {
    return '';
  }
  try {
    const context = await runAdapter(accumulatedFeatures.data, accumulatedFeatures.dims);
    await initDecoderCache(context);

    const numFrames = accumulatedFeatures.dims[1];
    // Calculate duration in seconds and max tokens based on that
    const durationSeconds = (numFrames * FRAME_DURATION_MS) / 1000;
    const maxTokens = Math.max(10,
      Math.floor(durationSeconds * TOKENS_PER_SECOND));

    const tokens = [1]; // BOS
    for (let step = 0; step < maxTokens; step++) {
      const logits = await decodeStep(tokens[tokens.length - 1], step);

      // Greedy argmax over the vocabulary
      let maxIdx = 0;
      let maxVal = logits.data[0];
      for (let i = 1; i < cfg.vocab_size; i++) {
        if (logits.data[i] > maxVal) {
          maxVal = logits.data[i];
          maxIdx = i;
        }
      }
      tokens.push(maxIdx);

      // Stop on EOS
      if (maxIdx === 2) break;

      // Stop on repetitive patterns
      if (hasRepetition(tokens)) {
        console.log('Stopping decode due to repetition detected');
        break;
      }
    }

    return tokenizer.decode(tokens, true);
  } catch (e) {
    console.error('Decode error:', e);
    return '';
  }
}

// Helper to accumulate features data
function accumulateFeaturesData(data) {
  const newFeatures = { data: new Float32Array(data.features), dims: data.dims };
  if (accumulatedFeatures === null) {
    accumulatedFeatures = newFeatures;
  } else {
    // Trim the last tailLatency frames from the accumulated features
    const numFrames = accumulatedFeatures.dims[1];
    const keepFrames = Math.max(0, numFrames - tailLatency);
    if (keepFrames > 0) {
      const totalFrames = keepFrames + newFeatures.dims[1];
      const combined = new Float32Array(totalFrames * cfg.dim);
      // Copy kept frames
      for (let f = 0; f < keepFrames; f++) {
        for (let d = 0; d < cfg.dim; d++) {
          combined[f * cfg.dim + d] = accumulatedFeatures.data[f * cfg.dim + d];
        }
      }
      // Copy new frames
      combined.set(newFeatures.data, keepFrames * cfg.dim);
      accumulatedFeatures = { data: combined, dims: [1, totalFrames, cfg.dim] };
    } else {
      accumulatedFeatures = newFeatures;
    }
  }
}

// Message queue for sequential processing
const messageQueue = [];
let isProcessingQueue = false;

async function processMessage(e) {
  const { type, data } = e.data;

  switch (type) {
    case 'init': {
      try {
        cfg = data.cfg;
        const onnxUrl = data.onnxUrl;
        const modelName = data.modelName;
        const backend = data.backend || 'wasm';
        const dtype = 'fp32';
        const sessionOptions = { executionProviders: [backend] };
        tailLatency = cfg.n_future * cfg.encoder_depth;

        // Load tokenizer
        self.postMessage({ type: 'status', message: 'Loading tokenizer...' });
        self.postMessage({ type: 'model_start', model: 'Tokenizer' });
        const tokenizerResponse = await fetch(`${onnxUrl}/tokenizer.json`);
        const tokenizerJson = await tokenizerResponse.json();
        tokenizer = new MoonshineTokenizer();
        tokenizer.load(tokenizerJson);
        self.postMessage({ type: 'model_done', model: 'Tokenizer' });

        // Initialize adapter session
        const adapterUrl = `${onnxUrl}/adapter_${modelName}_${dtype}.onnx`;
        self.postMessage({ type: 'status', message: 'Loading adapter...' });
        self.postMessage({ type: 'model_start', model: 'Adapter' });
        const adapterBuffer = await fetchModelWithProgress(adapterUrl, 'Adapter');
        adapterSession = await ort.InferenceSession.create(adapterBuffer, sessionOptions);
        self.postMessage({ type: 'model_done', model: 'Adapter' });

        // Initialize decoder (init) session
        const decInitUrl = `${onnxUrl}/decoder_init_${modelName}_${dtype}.onnx`;
        self.postMessage({ type: 'status', message: 'Loading decoder (init)...' });
        self.postMessage({ type: 'model_start', model: 'Decoder Init' });
        const decInitBuffer = await fetchModelWithProgress(decInitUrl, 'Decoder Init');
        decoderInitSession = await ort.InferenceSession.create(decInitBuffer, sessionOptions);
        self.postMessage({ type: 'model_done', model: 'Decoder Init' });

        // Initialize decoder (step) session
        const decStepUrl = `${onnxUrl}/decoder_step_${modelName}_${dtype}.onnx`;
        self.postMessage({ type: 'status', message: 'Loading decoder (step)...' });
        self.postMessage({ type: 'model_start', model: 'Decoder Step' });
        const decStepBuffer = await fetchModelWithProgress(decStepUrl, 'Decoder Step');
        decoderStepSession = await ort.InferenceSession.create(decStepBuffer, sessionOptions);
        self.postMessage({ type: 'model_done', model: 'Decoder Step' });

        self.postMessage({ type: 'ready', backend: backend });
      } catch (err) {
        self.postMessage({ type: 'error', message: err.message });
      }
      break;
    }

    case 'segment_start': {
      accumulatedFeatures = null;
      currentSegmentId = data.segmentId;
      isDecoding = false;
      lastDecodeTime = 0;
      pendingDecode = false;
      self.postMessage({ type: 'live_caption', text: '' });
      break;
    }

    case 'segment_end': {
      if (data.segmentId !== currentSegmentId) break;
      // Wait for any in-progress decode to finish before final decode
      while (isDecoding) {
        await new Promise(resolve => setTimeout(resolve, 50));
      }
      isDecoding = true;
      const text = await decodeAccumulated();
      isDecoding = false;
      self.postMessage({ type: 'transcript', segmentId: data.segmentId, text: text });
      accumulatedFeatures = null;
      currentSegmentId = null;
      self.postMessage({ type: 'live_caption', text: '' });
      break;
    }

    case 'features': {
      if (data.segmentId !== currentSegmentId) break;

      // Accumulate this message's features
      accumulateFeaturesData(data);

      // Drain all pending features messages from the queue and accumulate them too
      while (messageQueue.length > 0 && messageQueue[0].data.type === 'features') {
        const nextMsg = messageQueue.shift();
        const nextData = nextMsg.data.data;
        if (nextData.segmentId === currentSegmentId) {
          accumulateFeaturesData(nextData);
        }
      }

      console.log(`Decoder accumulated features, total: ${accumulatedFeatures ? accumulatedFeatures.dims[1] : 0} frames`);

      // Live caption with throttling
      const now = Date.now();
      const timeSinceLastDecode = now - lastDecodeTime;

      if (isDecoding) {
        // Already decoding, mark that we need another decode when done
        pendingDecode = true;
      } else if (timeSinceLastDecode >= MIN_DECODE_INTERVAL_MS) {
        // Enough time has passed, decode now
        isDecoding = true;
        lastDecodeTime = now;
        try {
          const partialText = await decodeAccumulated();
          self.postMessage({ type: 'live_caption', text: partialText });
        } finally {
          isDecoding = false;
          // If there was a pending decode request, schedule it
          if (pendingDecode) {
            pendingDecode = false;
            setTimeout(async () => {
              if (!isDecoding && currentSegmentId !== null) {
                isDecoding = true;
                lastDecodeTime = Date.now();
                try {
                  const text = await decodeAccumulated();
                  self.postMessage({ type: 'live_caption', text: text });
                } finally {
                  isDecoding = false;
                }
              }
            }, MIN_DECODE_INTERVAL_MS);
          }
        }
      } else {
        // Too soon since last decode, mark pending
        pendingDecode = true;
      }
      break;
    }
  }
}

async function processQueue() {
  if (isProcessingQueue) return;
  isProcessingQueue = true;
  while (messageQueue.length > 0) {
    const msg = messageQueue.shift();
    await processMessage(msg);
  }
  isProcessingQueue = false;
}

self.onmessage = function(e) {
  messageQueue.push(e);
  processQueue();
};
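
// ---------------------------------------------------------------------------
// Illustrative sketch (kept commented out so it never executes inside the
// worker): how a main-thread caller might drive the message protocol handled
// above. The worker file name, model base URL, cfg field values, modelName,
// and the helpers updateCaption()/appendTranscript() are assumptions for the
// example only; they are not defined by this file.
//
// const worker = new Worker('decoder-worker.js'); // hypothetical file name
// worker.onmessage = (e) => {
//   const msg = e.data;
//   if (msg.type === 'ready') console.log('Decoder ready on backend:', msg.backend);
//   else if (msg.type === 'live_caption') updateCaption(msg.text);                  // app-defined
//   else if (msg.type === 'transcript') appendTranscript(msg.segmentId, msg.text);  // app-defined
//   else if (msg.type === 'error') console.error('Worker error:', msg.message);
// };
//
// // Load tokenizer + ONNX sessions. The cfg values below are placeholders,
// // not the real model config; pass the config matching your exported models.
// worker.postMessage({
//   type: 'init',
//   data: {
//     cfg: { depth: 6, nheads: 8, head_dim: 36, dim: 288, vocab_size: 32768,
//            n_future: 4, encoder_depth: 6 },
//     onnxUrl: 'https://example.com/moonshine-onnx', // hypothetical URL
//     modelName: 'tiny',
//     backend: 'wasm'
//   }
// });
//
// // Per utterance: start a segment, stream encoder features as they arrive,
// // then end the segment to receive the final transcript.
// worker.postMessage({ type: 'segment_start', data: { segmentId: 1 } });
// worker.postMessage({ type: 'features', data: { segmentId: 1, features, dims: [1, numFrames, 288] } });
// worker.postMessage({ type: 'segment_end', data: { segmentId: 1 } });
// ---------------------------------------------------------------------------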