// Streaming feature-extraction worker: downloads the streaming preprocessor
// and encoder ONNX models, runs them with onnxruntime-web, and streams
// encoder features for each audio segment back to the main thread.
importScripts('https://cdn.jsdelivr.net/npm/onnxruntime-web@1.21.0/dist/ort.all.min.js');

// Resolve the ORT WASM binaries from the same CDN build as the script above.
ort.env.wasm.wasmPaths = 'https://cdn.jsdelivr.net/npm/onnxruntime-web@1.21.0/dist/';

// Cache API bucket for downloaded models; bump the version to invalidate old entries.
const MODEL_CACHE_NAME = 'moonshine-models-v1';

/**
 * Fetch a model file as an ArrayBuffer, posting 'progress' messages to the
 * main thread and caching the bytes with the Cache API so later loads can
 * skip the network.
 */
async function fetchModelWithProgress(url, modelName) {
  // Serve from the cache when possible. The Cache API may be unavailable
  // (e.g. in some private-browsing modes), so failures fall through to fetch.
  try {
    const cache = await caches.open(MODEL_CACHE_NAME);
    const cachedResponse = await cache.match(url);
    if (cachedResponse) {
      const buffer = await cachedResponse.arrayBuffer();
      self.postMessage({
        type: 'progress',
        model: modelName,
        loaded: buffer.byteLength,
        total: buffer.byteLength,
        done: true,
        cached: true
      });
      console.log(`${modelName} loaded from cache`);
      return buffer;
    }
  } catch (e) {
    console.warn('Cache API not available:', e.message);
  }

  const response = await fetch(url);
  if (!response.ok) {
    throw new Error(`Failed to fetch ${modelName}: ${response.status}`);
  }

  const contentLength = response.headers.get('Content-Length');
  const total = contentLength ? parseInt(contentLength, 10) : 0;

  // Without a readable body or a known length, incremental progress is not
  // possible: download in one shot and emit a single completed update.
  if (!response.body || !total) {
    const buffer = await response.arrayBuffer();
    self.postMessage({
      type: 'progress',
      model: modelName,
      loaded: buffer.byteLength,
      total: buffer.byteLength,
      done: true
    });

    // Best-effort caching of the downloaded bytes.
    try {
      const cache = await caches.open(MODEL_CACHE_NAME);
      await cache.put(url, new Response(buffer.slice(0)));
    } catch (e) {
      console.warn('Failed to cache model:', e.message);
    }
    return buffer;
  }

  // Stream the response body, posting progress as each chunk arrives.
  const reader = response.body.getReader();
  const chunks = [];
  let loaded = 0;

  while (true) {
    const { done, value } = await reader.read();
    if (done) break;

    chunks.push(value);
    loaded += value.length;

    self.postMessage({
      type: 'progress',
      model: modelName,
      loaded,
      total,
      done: false
    });
  }

  self.postMessage({
    type: 'progress',
    model: modelName,
    loaded: total,
    total,
    done: true
  });

  // Concatenate the downloaded chunks into one contiguous buffer.
  const result = new Uint8Array(loaded);
  let offset = 0;
  for (const chunk of chunks) {
    result.set(chunk, offset);
    offset += chunk.length;
  }

  // Best-effort caching; a copy is stored so the returned buffer stays usable.
  try {
    const cache = await caches.open(MODEL_CACHE_NAME);
    await cache.put(url, new Response(result.slice(0)));
    console.log(`${modelName} cached`);
  } catch (e) {
    console.warn('Failed to cache model:', e.message);
  }

  return result.buffer;
}
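
// The 'progress' messages above drive a download indicator on the main
// thread. A minimal handler sketch (the worker filename and the way the
// percentage is shown are illustrative assumptions, not part of this file):
//
//   const worker = new Worker('encoder-worker.js');
//   worker.onmessage = (e) => {
//     if (e.data.type === 'progress') {
//       const { model, loaded, total, done, cached } = e.data;
//       const pct = total ? Math.round(100 * loaded / total) : 0;
//       console.log(`${model}: ${cached ? 'from cache' : pct + '%'}${done ? ' (done)' : ''}`);
//     }
//   };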

// Model configuration and session state.
let cfg = null;
let preprocessor = null;
let encoder = null;
let tailLatency = 0;          // mel frames of lookahead the encoder needs (n_future * encoder_depth)

// Streaming preprocessor (audio -> mel features) session and recurrent state.
let prepSession = null;
let prepDim = 0;
let prepC1 = 0;
let prepStateC1 = null;
let prepStateC2 = null;

// Encoder (mel features -> encoder features) session and streaming buffers.
let encSession = null;
let encDim = 0;
let encNPast = 0;
let encNFuture = 0;
let encEncoderDepth = 0;
let encContextSize = 0;       // mel frames of context to retain between encoder calls
let encInputBuffer = [];      // retained mel frames, one Float32Array per frame
let encTotalInputFrames = 0;  // mel frames received so far in the current segment
let encTotalOutputFrames = 0; // frames whose encoder output has already been emitted

// Clear the preprocessor's recurrent state between segments.
function resetPreprocessor() {
  if (prepStateC1) prepStateC1.fill(0);
  if (prepStateC2) prepStateC2.fill(0);
}

// Drop the encoder's accumulated context between segments.
function resetEncoder() {
  encInputBuffer = [];
  encTotalInputFrames = 0;
  encTotalOutputFrames = 0;
}

// Run one audio chunk through the streaming preprocessor, carrying its
// recurrent state across calls, and return the resulting mel features.
async function processPreprocessor(audioChunk) {
  const feeds = {
    'audio_chunk': new ort.Tensor('float32', audioChunk, [1, audioChunk.length]),
    'state_c1': new ort.Tensor('float32', prepStateC1, [1, 4, prepDim]),
    'state_c2': new ort.Tensor('float32', prepStateC2, [1, 4, prepC1])
  };

  const results = await prepSession.run(feeds);

  // Carry the updated recurrent state over to the next chunk.
  prepStateC1.set(results.new_state_c1.data);
  prepStateC2.set(results.new_state_c2.data);

  return {
    data: results.features.data,
    dims: results.features.dims
  };
}

// Feed new mel frames to the encoder and return only the frames that have not
// been emitted yet. The encoder needs future context, so the trailing
// tailLatency frames are withheld unless `flush` is set; recent frames are
// kept in encInputBuffer so each call re-runs the encoder over a sliding window.
async function processEncoder(melData, melDims, flush = true) {
  const newFrames = melDims[1];

  // Append the new mel frames to the retained context buffer.
  for (let f = 0; f < newFrames; f++) {
    const frame = new Float32Array(encDim);
    for (let d = 0; d < encDim; d++) {
      frame[d] = melData[f * encDim + d];
    }
    encInputBuffer.push(frame);
  }

  encTotalInputFrames += newFrames;

  // Frames that may be emitted now: everything when flushing, otherwise
  // everything except the trailing lookahead region.
  const canOutput = flush
    ? encTotalInputFrames
    : Math.max(0, encTotalInputFrames - tailLatency);

  // First frame to emit: when flushing, back up by tailLatency so the
  // previously withheld region is recomputed with its full context.
  const outputFrom = flush
    ? Math.max(0, encTotalOutputFrames - tailLatency)
    : encTotalOutputFrames;

  const newOutputCount = canOutput - outputFrom;

  if (newOutputCount <= 0) {
    return { data: new Float32Array(0), dims: [1, 0, encDim] };
  }

  // Flatten the retained frames into a single [1, frames, dim] tensor.
  const bufferFrames = encInputBuffer.length;
  const bufferData = new Float32Array(bufferFrames * encDim);
  for (let f = 0; f < bufferFrames; f++) {
    bufferData.set(encInputBuffer[f], f * encDim);
  }

  const feeds = {
    'input': new ort.Tensor('float32', bufferData, [1, bufferFrames, encDim])
  };

  const results = await encSession.run(feeds);
  const fullOutput = results.output;

  // The encoder output only covers the frames currently in the buffer, so map
  // the absolute output position onto buffer-relative indices.
  const bufStartFrame = encTotalInputFrames - bufferFrames;
  const outputStart = outputFrom - bufStartFrame;

  // Copy out just the newly available frames.
  const resultData = new Float32Array(newOutputCount * encDim);
  for (let f = 0; f < newOutputCount; f++) {
    for (let d = 0; d < encDim; d++) {
      resultData[f * encDim + d] = fullOutput.data[(outputStart + f) * encDim + d];
    }
  }

  // Trim the buffer so only the context the encoder needs is retained.
  if (encInputBuffer.length > encContextSize) {
    encInputBuffer = encInputBuffer.slice(-encContextSize);
  }

  encTotalOutputFrames = canOutput;
  return { data: resultData, dims: [1, newOutputCount, encDim] };
}
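
// Worked example of the bookkeeping above, with purely hypothetical numbers
// (the real values come from cfg at init time): if encoder_depth = 6 and
// n_future = 3, then tailLatency = 18. With flush = false after 100 mel
// frames, canOutput = 100 - 18 = 82, so frames 0..81 are emitted and the last
// 18 are withheld. A later call with flush = true backs outputFrom up to
// 82 - 18 = 64 and emits frames 64..99, so the tail is recomputed with full context.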

// Messages are handled strictly in arrival order through a simple queue,
// since each handler may await ONNX Runtime calls.
const messageQueue = [];
let isProcessing = false;
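
// The 'init' message is expected to carry the fields read by the handler
// below. A sketch of the main-thread call (the `worker` handle and all values
// shown are placeholders, not defaults baked into this file):
//
//   worker.postMessage({
//     type: 'init',
//     data: {
//       onnxUrl: 'https://example.com/models',  // base URL for the .onnx files
//       modelName: 'base',
//       backend: 'wasm',                        // or 'webgpu'
//       cfg: { dim: 416, n_past: 4, n_future: 4, encoder_depth: 6 }
//     }
//   });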

async function processMessage(e) {
  const { type, data } = e.data;

  switch (type) {
    case 'init': {
      try {
        cfg = data.cfg;
        const onnxUrl = data.onnxUrl;
        const modelName = data.modelName;
        const backend = data.backend || 'wasm';
        const dtype = 'fp32';

        // Verify WebGPU support up front so a clear error reaches the main
        // thread instead of an opaque session-creation failure.
        if (backend === 'webgpu') {
          if (typeof navigator !== 'undefined' && navigator.gpu) {
            console.log('WebGPU navigator.gpu is available');
            const adapter = await navigator.gpu.requestAdapter();
            if (adapter) {
              console.log('WebGPU adapter obtained:', adapter);
            } else {
              throw new Error('WebGPU adapter not available');
            }
          } else {
            throw new Error('WebGPU not supported (navigator.gpu is undefined)');
          }
        }

        const sessionOptions = { executionProviders: [backend] };
        console.log(`Creating sessions with backend: ${backend}`);

        // Number of mel frames the encoder output lags behind the input.
        tailLatency = cfg.n_future * cfg.encoder_depth;

        // Load and initialize the streaming preprocessor.
        const prepUrl = `${onnxUrl}/preprocessor_streaming_${modelName}_${dtype}.onnx`;
        self.postMessage({ type: 'status', message: 'Loading preprocessor...' });
        self.postMessage({ type: 'model_start', model: 'Preprocessor' });
        const prepBuffer = await fetchModelWithProgress(prepUrl, 'Preprocessor');
        prepSession = await ort.InferenceSession.create(prepBuffer, sessionOptions);
        self.postMessage({ type: 'model_done', model: 'Preprocessor' });

        prepDim = cfg.dim;
        prepC1 = 2 * cfg.dim;
        prepStateC1 = new Float32Array(4 * cfg.dim);
        prepStateC2 = new Float32Array(4 * prepC1);

        // Load and initialize the encoder.
        const encUrl = `${onnxUrl}/encoder_${modelName}_${dtype}.onnx`;
        self.postMessage({ type: 'status', message: 'Loading encoder...' });
        self.postMessage({ type: 'model_start', model: 'Encoder' });
        const encBuffer = await fetchModelWithProgress(encUrl, 'Encoder');
        encSession = await ort.InferenceSession.create(encBuffer, sessionOptions);
        self.postMessage({ type: 'model_done', model: 'Encoder' });

        encDim = cfg.dim;
        encNPast = cfg.n_past;
        encNFuture = cfg.n_future;
        encEncoderDepth = cfg.encoder_depth;
        encContextSize = cfg.encoder_depth * (cfg.n_past + cfg.n_future);

        self.postMessage({ type: 'ready', backend: backend });
      } catch (err) {
        self.postMessage({ type: 'error', message: err.message });
      }
      break;
    }

    case 'segment_start': {
      resetPreprocessor();
      resetEncoder();
      self.postMessage({
        type: 'segment_start',
        segmentId: data.segmentId
      });
      break;
    }

    case 'segment_end': {
      self.postMessage({
        type: 'segment_end',
        segmentId: data.segmentId
      });
      break;
    }
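
    // The 'audio' payload carries data.segmentId plus data.audio, an
    // array-like chunk of mono float samples (wrapped in a Float32Array
    // below); the duration log assumes a 16 kHz sample rate.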

    case 'audio': {
      try {
        // Audio chunk -> mel features (the preprocessor carries its own state).
        const mel = await processPreprocessor(new Float32Array(data.audio));

        const audioMs = (data.audio.length / 16000 * 1000).toFixed(0);
        console.log(`Audio ${data.audio.length} samples (${audioMs}ms) → Mel ${mel.dims[1]} frames`);

        // Mel features -> encoder features.
        const enc = await processEncoder(mel.data, mel.dims, true);

        console.log(`Mel ${mel.dims[1]} frames → Encoder ${enc.dims[1]} frames (accumulated: ${encTotalOutputFrames})`);

        if (enc.dims[1] > 0) {
          // Transfer the feature buffer instead of copying it across threads.
          self.postMessage({
            type: 'features',
            segmentId: data.segmentId,
            features: enc.data,
            dims: enc.dims
          }, [enc.data.buffer]);
        }
      } catch (err) {
        console.error('Encoder error:', err);
      }
      break;
    }
  }
}

// Drain the queue one message at a time so async handlers never interleave.
async function processQueue() {
  if (isProcessing) return;
  isProcessing = true;

  while (messageQueue.length > 0) {
    const msg = messageQueue.shift();
    await processMessage(msg);
  }

  isProcessing = false;
}

self.onmessage = function(e) {
  messageQueue.push(e);
  processQueue();
};
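
// End-to-end usage sketch (main-thread side, for illustration only; the file
// name, segment id, and `pcmChunk` are assumptions, not part of this worker):
//
//   const worker = new Worker('encoder-worker.js');
//   worker.postMessage({ type: 'init', data: { onnxUrl, modelName, backend: 'wasm', cfg } });
//   worker.onmessage = (e) => {
//     if (e.data.type === 'ready') {
//       worker.postMessage({ type: 'segment_start', data: { segmentId: 1 } });
//       worker.postMessage({ type: 'audio', data: { segmentId: 1, audio: pcmChunk } });
//       worker.postMessage({ type: 'segment_end', data: { segmentId: 1 } });
//     } else if (e.data.type === 'features') {
//       // e.data.features is a Float32Array; e.data.dims is [1, frames, dim].
//     }
//   };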