// supertonic / index.js
// akhaliq's picture
// akhaliq HF Staff
// Update index.js
// 644b6c1 verified
// Configuration
// Model identifier handed to pipeline() when the text-to-speech pipeline is created.
const MODEL_ID = 'onnx-community/Supertonic-TTS-ONNX';
// Base URL for voice files; the value of the voice <select> is appended to it
// to form the speaker_embeddings URL used for inference.
const VOICE_BASE_URL = 'https://huggingface.co/onnx-community/Supertonic-TTS-ONNX/resolve/main/voices/';
// DOM Elements
// Looked up once, by id, when this script is evaluated. getElementById returns
// null for an id that does not exist yet, so the script has to run after the
// markup containing these ids has been parsed.
const generateBtn = document.getElementById('generate-btn');
const inputText = document.getElementById('input-text');
const voiceSelect = document.getElementById('voice-select');
const gpuToggle = document.getElementById('gpu-toggle');
const deviceLabel = document.getElementById('device-label');
const statusContainer = document.getElementById('status-container');
const statusText = document.getElementById('status-text');
const progressBar = document.getElementById('progress-bar');
const outputCard = document.getElementById('output-card');
const audioPlayer = document.getElementById('audio-player');
const downloadLink = document.getElementById('download-link');
const errorMsg = document.getElementById('error-msg');
// State
// Cached text-to-speech pipeline. Stays null until the first generation and is
// set back to null whenever the device toggle changes, forcing a reload.
let ttsPipeline = null;
// Device name passed to pipeline(): 'wasm' (labelled "CPU" in the UI) by
// default, 'webgpu' while the GPU toggle is checked.
let currentDevice = 'wasm';
// Helper: Check WebGPU support
/**
 * Detects whether WebGPU can actually be used and, if not, locks the UI to CPU.
 *
 * The presence of `navigator.gpu` only means the API is exposed; the adapter
 * request can still resolve to null or reject. An adapter is therefore
 * requested before support is reported.
 *
 * On the unsupported path the GPU toggle is unchecked and disabled, the device
 * is forced back to 'wasm' (in case the toggle was flipped while the adapter
 * request was pending) and the label explains the limitation.
 *
 * @returns {Promise<boolean>} true when an adapter was obtained, false otherwise.
 */
async function checkWebGPU() {
  let adapter = null;

  if (navigator.gpu) {
    try {
      adapter = await navigator.gpu.requestAdapter();
    } catch (err) {
      // A failed request is treated the same as "no adapter": fall back to CPU.
      console.warn('WebGPU adapter request failed:', err);
    }
  }

  if (adapter) {
    return true;
  }

  gpuToggle.checked = false;
  gpuToggle.disabled = true;
  currentDevice = 'wasm';
  deviceLabel.innerText = "WebGPU not supported (CPU only)";
  return false;
}
// Probe WebGPU support once at startup so the toggle reflects what the browser offers.
checkWebGPU();

// UI Event Listeners

// Device toggle: switch between WebGPU and CPU (wasm) execution.
gpuToggle.addEventListener('change', ({ target }) => {
  if (target.checked) {
    currentDevice = 'webgpu';
    deviceLabel.innerText = 'Run on WebGPU';
  } else {
    currentDevice = 'wasm';
    deviceLabel.innerText = 'Run on CPU';
  }
  // Drop the cached pipeline so the next generation reloads it for the new device.
  ttsPipeline = null;
});

// Text box: keep the character counter in sync with what has been typed.
inputText.addEventListener('input', () => {
  const counter = document.querySelector('.char-count');
  const typedLength = inputText.value.length;
  counter.innerText = `${typedLength} / 500`;
});
// Blob URL of the most recently generated WAV. It is remembered so that it can
// be revoked once a newer result replaces it; otherwise every generation would
// keep its audio data alive until the page is closed.
let lastWavUrl = null;

// Generate button: lazily loads the text-to-speech pipeline, synthesizes the
// entered text with the selected voice, and exposes the result for playback
// and download. The button is disabled for the duration of a run.
generateBtn.addEventListener('click', async () => {
  const text = inputText.value.trim();
  if (!text) return;

  resetUI();
  statusContainer.classList.remove('hidden');
  generateBtn.disabled = true;

  try {
    // 1. Initialize pipeline if needed
    let tts = ttsPipeline;
    if (!tts) {
      updateStatus('Loading model... (this may take a moment)', 0);

      // pipeline() is read from window (set in HTML). Fail with a readable
      // message instead of "pipeline is not a function" if it is missing.
      const { pipeline } = window;
      if (typeof pipeline !== 'function') {
        throw new Error('The speech pipeline loader is not available yet. Please reload the page.');
      }

      // Remember which device this load was started for. The device toggle
      // clears ttsPipeline when it changes; blindly assigning after the await
      // would undo that reset and cache a pipeline built for the old device.
      const requestedDevice = currentDevice;
      tts = await pipeline('text-to-speech', MODEL_ID, {
        device: requestedDevice,
        dtype: 'fp32', // The original author notes fp32 is required for this model
        progress_callback: (data) => {
          if (data.status === 'progress') {
            updateStatus(`Downloading ${data.file}...`, data.progress);
          } else if (data.status === 'ready') {
            updateStatus('Model ready!', 100);
          }
        },
      });

      // Cache only when the selection is unchanged; otherwise use this
      // pipeline for the current run and let the next click reload.
      if (currentDevice === requestedDevice) {
        ttsPipeline = tts;
      }
    }

    // 2. Generate Audio
    updateStatus('Generating audio...', 100);
    progressBar.classList.add('pulsing'); // Animate the bar while inference runs

    const voiceFile = voiceSelect.value;
    const speaker_embeddings = `${VOICE_BASE_URL}${voiceFile}`;

    // Run inference
    const output = await tts(text, { speaker_embeddings });

    // 3. Process Output
    // output.audio is expected to be a Float32Array and output.sampling_rate a number.
    const wavUrl = createWavUrl(output.audio, output.sampling_rate);
    audioPlayer.src = wavUrl;
    downloadLink.href = wavUrl;
    outputCard.classList.remove('hidden');

    // The player and the link now point at the new URL, so the previous
    // Blob URL can be released.
    if (lastWavUrl) {
      URL.revokeObjectURL(lastWavUrl);
    }
    lastWavUrl = wavUrl;

    // Auto-play result (best effort: browsers may refuse to start playback)
    try {
      await audioPlayer.play();
    } catch (e) {
      console.log("Auto-play blocked by browser policy");
    }
  } catch (err) {
    console.error(err);
    // Non-Error values can be thrown too; never show "Error: undefined".
    const message = err && err.message ? err.message : String(err);
    showError(message);
  } finally {
    generateBtn.disabled = false;
    progressBar.classList.remove('pulsing');
    statusContainer.classList.add('hidden');
  }
});
// Helper: Update Progress UI
/**
 * Shows a status message and sets the progress bar fill.
 *
 * @param {string} text - Message written to the status line.
 * @param {number} progressPercent - Fill level, used as a CSS percentage width.
 */
function updateStatus(text, progressPercent) {
  const barWidth = `${progressPercent}%`;
  statusText.innerText = text;
  progressBar.style.width = barWidth;
}
/**
 * Returns the UI to its pre-generation state: hides the output card and the
 * error message, and empties the progress bar.
 */
function resetUI() {
  for (const panel of [outputCard, errorMsg]) {
    panel.classList.add('hidden');
  }
  progressBar.style.width = '0%';
}
/**
 * Displays an error message to the user.
 *
 * @param {string} msg - Description of the failure; shown prefixed with "Error: ".
 */
function showError(msg) {
  const banner = errorMsg;
  const bannerText = `Error: ${msg}`;
  banner.innerText = bannerText;
  banner.classList.remove('hidden');
}
// Audio Utility: Convert Float32Array to WAV Blob URL
/**
 * Encodes audio samples as a WAV file and returns an object URL for it.
 *
 * @param {Float32Array} audioData - Samples handed to encodeWAV.
 * @param {number} sampleRate - Sample rate handed to encodeWAV.
 * @returns {string} Object URL of an 'audio/wav' Blob holding the encoded data.
 */
function createWavUrl(audioData, sampleRate) {
  const wavBytes = encodeWAV(audioData, sampleRate);
  return URL.createObjectURL(new Blob([wavBytes], { type: 'audio/wav' }));
}
/**
 * Builds a mono, 16-bit PCM WAV file (44-byte header followed by sample data).
 *
 * @param {Float32Array} samples - Audio samples; each one becomes a 16-bit value.
 * @param {number} sampleRate - Sample rate written into the header.
 * @returns {ArrayBuffer} The complete WAV file contents.
 */
function encodeWAV(samples, sampleRate) {
  const HEADER_BYTES = 44;
  const BYTES_PER_SAMPLE = 2;
  const LITTLE_ENDIAN = true;

  const dataBytes = samples.length * BYTES_PER_SAMPLE;
  const wav = new ArrayBuffer(HEADER_BYTES + dataBytes);
  const out = new DataView(wav);

  // RIFF chunk descriptor; the size field counts everything after itself.
  writeString(out, 0, 'RIFF');
  out.setUint32(4, 36 + dataBytes, LITTLE_ENDIAN);
  writeString(out, 8, 'WAVE');

  // fmt sub-chunk
  writeString(out, 12, 'fmt ');
  out.setUint32(16, 16, LITTLE_ENDIAN); // size of the fmt payload
  out.setUint16(20, 1, LITTLE_ENDIAN); // audio format: PCM
  out.setUint16(22, 1, LITTLE_ENDIAN); // channel count: mono
  out.setUint32(24, sampleRate, LITTLE_ENDIAN);
  out.setUint32(28, sampleRate * BYTES_PER_SAMPLE, LITTLE_ENDIAN); // byte rate
  out.setUint16(32, BYTES_PER_SAMPLE, LITTLE_ENDIAN); // block align
  out.setUint16(34, 16, LITTLE_ENDIAN); // bits per sample

  // data sub-chunk
  writeString(out, 36, 'data');
  out.setUint32(40, dataBytes, LITTLE_ENDIAN);

  // PCM samples start right after the header.
  floatTo16BitPCM(out, HEADER_BYTES, samples);

  return wav;
}
/**
 * Writes a string into a DataView, one byte per UTF-16 code unit.
 *
 * @param {DataView} view - Destination view.
 * @param {number} offset - Byte position of the first character.
 * @param {string} string - Text to write.
 */
function writeString(view, offset, string) {
  string.split('').forEach((unit, position) => {
    view.setUint8(offset + position, unit.charCodeAt(0));
  });
}
/**
 * Converts float samples to little-endian signed 16-bit PCM inside a DataView.
 *
 * Each sample is clamped to [-1, 1]; negative values are scaled by 0x8000 and
 * non-negative values by 0x7FFF so both ends of the 16-bit range are reachable.
 *
 * @param {DataView} view - Destination view.
 * @param {number} offset - Byte position of the first sample.
 * @param {ArrayLike<number>} input - Samples to convert; two bytes are written per sample.
 */
function floatTo16BitPCM(view, offset, input) {
  let index = 0;
  while (index < input.length) {
    const clamped = Math.min(1, Math.max(-1, input[index]));
    const scale = clamped < 0 ? 0x8000 : 0x7FFF;
    view.setInt16(offset + index * 2, clamped * scale, true);
    index += 1;
  }
}