// moonshine-streaming-demo / encoder_worker.js
/**
* Encoder Worker - Runs preprocessor + encoder in a separate thread
*/
importScripts('https://cdn.jsdelivr.net/npm/onnxruntime-web@1.21.0/dist/ort.all.min.js');
// Configure ONNX Runtime paths
ort.env.wasm.wasmPaths = 'https://cdn.jsdelivr.net/npm/onnxruntime-web@1.21.0/dist/';
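// Note: ort.all.min.js is the full onnxruntime-web bundle, which includes the
// WebGPU execution provider in addition to wasm, so the 'webgpu' backend
// requested during init is actually available.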
const MODEL_CACHE_NAME = 'moonshine-models-v1';
// Helper to fetch a model with progress reporting and caching
async function fetchModelWithProgress(url, modelName) {
// Try to get from cache first
try {
const cache = await caches.open(MODEL_CACHE_NAME);
const cachedResponse = await cache.match(url);
if (cachedResponse) {
const buffer = await cachedResponse.arrayBuffer();
self.postMessage({
type: 'progress',
model: modelName,
loaded: buffer.byteLength,
total: buffer.byteLength,
done: true,
cached: true
});
console.log(`${modelName} loaded from cache`);
return buffer;
}
} catch (e) {
console.warn('Cache API not available:', e.message);
}
// Fetch from network
const response = await fetch(url);
if (!response.ok) {
throw new Error(`Failed to fetch ${modelName}: ${response.status}`);
}
const contentLength = response.headers.get('Content-Length');
const total = contentLength ? parseInt(contentLength, 10) : 0;
if (!response.body || !total) {
// No streaming support or unknown size - just download
const buffer = await response.arrayBuffer();
self.postMessage({
type: 'progress',
model: modelName,
loaded: buffer.byteLength,
total: buffer.byteLength,
done: true
});
// Cache a copy of the downloaded bytes (the response body itself is already consumed)
try {
const cache = await caches.open(MODEL_CACHE_NAME);
await cache.put(url, new Response(buffer.slice(0)));
} catch (e) {
console.warn('Failed to cache model:', e.message);
}
return buffer;
}
const reader = response.body.getReader();
const chunks = [];
let loaded = 0;
while (true) {
const { done, value } = await reader.read();
if (done) break;
chunks.push(value);
loaded += value.length;
self.postMessage({
type: 'progress',
model: modelName,
loaded,
total,
done: false
});
}
self.postMessage({
type: 'progress',
model: modelName,
loaded: total,
total,
done: true
});
// Combine chunks into a single ArrayBuffer
const result = new Uint8Array(loaded);
let offset = 0;
for (const chunk of chunks) {
result.set(chunk, offset);
offset += chunk.length;
}
// Cache the result
try {
const cache = await caches.open(MODEL_CACHE_NAME);
await cache.put(url, new Response(result.slice(0)));
console.log(`${modelName} cached`);
} catch (e) {
console.warn('Failed to cache model:', e.message);
}
return result.buffer;
}
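// Hypothetical helper (not wired to any message type in this worker): drops
// the whole model cache so the next 'init' re-downloads the ONNX files.
// Uses only the standard Cache API available in worker scopes.
async function clearModelCache() {
  try {
    const deleted = await caches.delete(MODEL_CACHE_NAME);
    console.log(deleted ? 'Model cache cleared' : 'Model cache did not exist');
  } catch (e) {
    console.warn('Failed to clear model cache:', e.message);
  }
}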
// Model config
let cfg = null;
let preprocessor = null;
let encoder = null;
let tailLatency = 0;
// Preprocessor state
let prepSession = null;
let prepDim = 0;
let prepC1 = 0;
let prepStateC1 = null;
let prepStateC2 = null;
// Encoder state
let encSession = null;
let encDim = 0;
let encNPast = 0;
let encNFuture = 0;
let encEncoderDepth = 0;
let encContextSize = 0;
let encInputBuffer = [];
let encTotalInputFrames = 0;
let encTotalOutputFrames = 0;
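// Sizing sketch with hypothetical config values (the real ones arrive in cfg
// via the 'init' message): with n_past = 80, n_future = 4 and
// encoder_depth = 6, tailLatency = 4 * 6 = 24 output frames still depend on
// unseen future context at any moment, and encContextSize = 6 * (80 + 4) =
// 504 input frames are retained so that re-encoded frames keep their full
// receptive field.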
function resetPreprocessor() {
if (prepStateC1) prepStateC1.fill(0);
if (prepStateC2) prepStateC2.fill(0);
}
function resetEncoder() {
encInputBuffer = [];
encTotalInputFrames = 0;
encTotalOutputFrames = 0;
}
async function processPreprocessor(audioChunk) {
const feeds = {
'audio_chunk': new ort.Tensor('float32', audioChunk, [1, audioChunk.length]),
'state_c1': new ort.Tensor('float32', prepStateC1, [1, 4, prepDim]),
'state_c2': new ort.Tensor('float32', prepStateC2, [1, 4, prepC1])
};
const results = await prepSession.run(feeds);
// Update states
prepStateC1.set(results.new_state_c1.data);
prepStateC2.set(results.new_state_c2.data);
return {
data: results.features.data,
dims: results.features.dims
};
}
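// Design note: the preprocessor is stateful across chunks. The two state
// tensors returned by each run are copied back into prepStateC1/prepStateC2
// above, so consecutive audio chunks are featurized as one continuous stream
// rather than as independent clips.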
async function processEncoder(melData, melDims, flush = true) {
const newFrames = melDims[1];
// Append new frames to buffer
for (let f = 0; f < newFrames; f++) {
const frame = new Float32Array(encDim);
for (let d = 0; d < encDim; d++) {
frame[d] = melData[f * encDim + d];
}
encInputBuffer.push(frame);
}
encTotalInputFrames += newFrames;
// Calculate the output range. Without flush, the last tailLatency frames are
// withheld because they still depend on future context; with flush,
// everything is emitted, including a re-emission of the previously
// provisional tail now that it has full context.
const canOutput = flush
? encTotalInputFrames
: Math.max(0, encTotalInputFrames - tailLatency);
const outputFrom = flush
? Math.max(0, encTotalOutputFrames - tailLatency)
: encTotalOutputFrames;
const newOutputCount = canOutput - outputFrom;
if (newOutputCount <= 0) {
return { data: new Float32Array(0), dims: [1, 0, encDim] };
}
// Prepare input buffer tensor
const bufferFrames = encInputBuffer.length;
const bufferData = new Float32Array(bufferFrames * encDim);
for (let f = 0; f < bufferFrames; f++) {
bufferData.set(encInputBuffer[f], f * encDim);
}
const feeds = {
'input': new ort.Tensor('float32', bufferData, [1, bufferFrames, encDim])
};
const results = await encSession.run(feeds);
const fullOutput = results.output;
// Calculate which frames to return
const bufStartFrame = encTotalInputFrames - bufferFrames;
const outputStart = outputFrom - bufStartFrame;
// Extract the subset of output
const resultData = new Float32Array(newOutputCount * encDim);
for (let f = 0; f < newOutputCount; f++) {
for (let d = 0; d < encDim; d++) {
resultData[f * encDim + d] = fullOutput.data[(outputStart + f) * encDim + d];
}
}
// Trim input buffer to context size
if (encInputBuffer.length > encContextSize) {
encInputBuffer = encInputBuffer.slice(-encContextSize);
}
encTotalOutputFrames = canOutput;
return { data: resultData, dims: [1, newOutputCount, encDim] };
}
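// Worked example of the output-window arithmetic above, with hypothetical
// numbers: tailLatency = 24, encTotalInputFrames = 100,
// encTotalOutputFrames = 60.
//   flush = false: canOutput = 100 - 24 = 76, outputFrom = 60,
//                  newOutputCount = 16 (the 24-frame tail stays withheld).
//   flush = true:  canOutput = 100, outputFrom = max(0, 60 - 24) = 36,
//                  newOutputCount = 64 (the previously provisional tail is
//                  re-emitted with full context, plus all remaining frames).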
// Message queue for sequential processing
const messageQueue = [];
let isProcessing = false;
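// Message protocol, as implemented by the cases below:
//   inbound:  init { cfg, onnxUrl, modelName, backend },
//             segment_start { segmentId }, segment_end { segmentId },
//             audio { segmentId, audio }
//   outbound: status, progress, model_start, model_done, ready, error,
//             segment_start, segment_end, features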
async function processMessage(e) {
const { type, data } = e.data;
switch (type) {
case 'init': {
try {
cfg = data.cfg;
const onnxUrl = data.onnxUrl;
const modelName = data.modelName;
const backend = data.backend || 'wasm';
const dtype = 'fp32';
// Check WebGPU availability
if (backend === 'webgpu') {
if (typeof navigator !== 'undefined' && navigator.gpu) {
console.log('WebGPU navigator.gpu is available');
const adapter = await navigator.gpu.requestAdapter();
if (adapter) {
console.log('WebGPU adapter obtained:', adapter);
} else {
throw new Error('WebGPU adapter not available');
}
} else {
throw new Error('WebGPU not supported (navigator.gpu is undefined)');
}
}
const sessionOptions = { executionProviders: [backend] };
console.log(`Creating sessions with backend: ${backend}`);
tailLatency = cfg.n_future * cfg.encoder_depth;
// Initialize preprocessor
const prepUrl = `${onnxUrl}/preprocessor_streaming_${modelName}_${dtype}.onnx`;
self.postMessage({ type: 'status', message: 'Loading preprocessor...' });
self.postMessage({ type: 'model_start', model: 'Preprocessor' });
const prepBuffer = await fetchModelWithProgress(prepUrl, 'Preprocessor');
prepSession = await ort.InferenceSession.create(prepBuffer, sessionOptions);
self.postMessage({ type: 'model_done', model: 'Preprocessor' });
prepDim = cfg.dim;
prepC1 = 2 * cfg.dim;
prepStateC1 = new Float32Array(4 * cfg.dim);
prepStateC2 = new Float32Array(4 * prepC1);
// Initialize encoder
const encUrl = `${onnxUrl}/encoder_${modelName}_${dtype}.onnx`;
self.postMessage({ type: 'status', message: 'Loading encoder...' });
self.postMessage({ type: 'model_start', model: 'Encoder' });
const encBuffer = await fetchModelWithProgress(encUrl, 'Encoder');
encSession = await ort.InferenceSession.create(encBuffer, sessionOptions);
self.postMessage({ type: 'model_done', model: 'Encoder' });
encDim = cfg.dim;
encNPast = cfg.n_past;
encNFuture = cfg.n_future;
encEncoderDepth = cfg.encoder_depth;
encContextSize = cfg.encoder_depth * (cfg.n_past + cfg.n_future);
self.postMessage({ type: 'ready', backend: backend });
} catch (err) {
self.postMessage({ type: 'error', message: err.message });
}
break;
}
case 'segment_start': {
resetPreprocessor();
resetEncoder();
self.postMessage({
type: 'segment_start',
segmentId: data.segmentId
});
break;
}
case 'segment_end': {
self.postMessage({
type: 'segment_end',
segmentId: data.segmentId
});
break;
}
case 'audio': {
try {
// Process through preprocessor
const mel = await processPreprocessor(new Float32Array(data.audio));
const audioMs = (data.audio.length / 16000 * 1000).toFixed(0);
console.log(`Audio ${data.audio.length} samples (${audioMs}ms) β†’ Mel ${mel.dims[1]} frames`);
// Process through encoder with flush=true
const enc = await processEncoder(mel.data, mel.dims, true);
console.log(`Mel ${mel.dims[1]} frames β†’ Encoder ${enc.dims[1]} frames (accumulated: ${encTotalOutputFrames})`);
if (enc.dims[1] > 0) {
self.postMessage({
type: 'features',
segmentId: data.segmentId,
features: enc.data,
dims: enc.dims
}, [enc.data.buffer]); // Transfer ownership
}
} catch (err) {
console.error('Encoder error:', err);
}
break;
}
}
}
async function processQueue() {
if (isProcessing) return;
isProcessing = true;
while (messageQueue.length > 0) {
const msg = messageQueue.shift();
await processMessage(msg);
}
isProcessing = false;
}
self.onmessage = function(e) {
messageQueue.push(e);
processQueue();
};
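/*
 * Main-thread usage sketch (illustrative only: the worker file name and the
 * cfg values shown are assumptions, not defined in this file):
 *
 *   const worker = new Worker('encoder_worker.js');
 *   worker.onmessage = (e) => {
 *     if (e.data.type === 'features') {
 *       // e.data.features is a Float32Array with shape e.data.dims
 *     }
 *   };
 *   worker.postMessage({
 *     type: 'init',
 *     data: {
 *       cfg: { dim: 288, n_past: 80, n_future: 4, encoder_depth: 6 },
 *       onnxUrl: 'https://example.com/models',
 *       modelName: 'tiny',
 *       backend: 'webgpu'
 *     }
 *   });
 *   // After the 'ready' message, stream 16 kHz float32 audio per segment:
 *   worker.postMessage({ type: 'segment_start', data: { segmentId: 0 } });
 *   worker.postMessage({ type: 'audio', data: { segmentId: 0, audio: chunk } });
 *   worker.postMessage({ type: 'segment_end', data: { segmentId: 0 } });
 */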