# Source: Ko-TTS-Arena / tts.py (Hugging Face file viewer export, 16.5 kB)
# Last commit: "change humelo voice" (e9c8adc) by blackhole1218
# Korean TTS Arena - TTS router
import os
import json
import base64
import tempfile
import requests
import urllib.request
import urllib.parse
import wave
import struct
from dotenv import load_dotenv
# Optional: scipy for high-quality resampling
try:
from scipy import signal
from scipy.io import wavfile
import numpy as np
HAS_SCIPY = True
except ImportError:
HAS_SCIPY = False
print("Warning: scipy not installed. Using basic resampling.")
# Load variables from a local .env file (if any) into the process environment
# before the os.getenv() lookups below run.
load_dotenv()
# Target sample rate for all TTS outputs (for fair comparison)
TARGET_SAMPLE_RATE = 16000
# Mapping of TTS providers that support Korean
# - Channel Talk: in-house API
# - ElevenLabs: direct API
# - OpenAI: API (gpt-4o-mini-tts)
# - Google: API
# - CLOVA: Naver Cloud API
# - Supertone: API
# - Humelo: DIVE API
CHANNEL_TTS_URL = os.getenv(
    "CHANNEL_TTS_URL",
    "https://ch-tts-streaming-demo.channel.io/v1/text-to-speech"
)
ELEVENLABS_API_KEY = os.getenv("ELEVENLABS_API_KEY")
ELEVENLABS_VOICE_ID = os.getenv("ELEVENLABS_VOICE_ID", "21m00Tcm4TlvDq8ikWAM")  # Rachel (default)
SUPERTONE_API_KEY = os.getenv("SUPERTONE_API_KEY")
SUPERTONE_VOICE_ID = os.getenv("SUPERTONE_VOICE_ID", "91992bbd4758bdcf9c9b01")  # default voice
# CLOVA TTS (Naver Cloud)
CLOVA_CLIENT_ID = os.getenv("CLOVA_CLIENT_ID")
CLOVA_API_KEY = os.getenv("CLOVA_API_KEY")
# Humelo DIVE TTS
HUMELO_API_KEY = os.getenv("HUMELO_API_KEY")
# The endpoint is hard-coded; unlike CHANNEL_TTS_URL it cannot be overridden
# through the environment.
HUMELO_API_URL = "https://agitvxptajouhvoatxio.supabase.co/functions/v1/dive-synthesize-v1"
def resample_wav_to_16khz(input_path: str) -> str:
    """
    Resample a WAV file to ``TARGET_SAMPLE_RATE`` (16 kHz) for fair comparison.

    The conversion is best-effort: if scipy is unavailable, the file is
    already at the target rate, or anything goes wrong while reading,
    resampling or writing, the original path is returned untouched.

    Multi-channel input is averaged down to mono before resampling. When the
    resampled array is not already int16 (scipy's FFT resampler normally
    yields floats) it is rescaled so that its peak equals 32767 and then cast
    to int16, i.e. the output is peak-normalised.

    Args:
        input_path: Path to the source WAV file.

    Returns:
        Path to a new temporary 16 kHz WAV file on success (the source file
        is then deleted), otherwise ``input_path``.
    """
    if not HAS_SCIPY:
        # If scipy is not available, return original file
        print(f"[Resample] scipy not available, skipping resample for {input_path}")
        return input_path

    output_path = None
    try:
        # Read the original WAV file
        original_rate, data = wavfile.read(input_path)

        # If already 16kHz, return as-is
        if original_rate == TARGET_SAMPLE_RATE:
            print(f"[Resample] Already {TARGET_SAMPLE_RATE}Hz, no resample needed")
            return input_path

        print(f"[Resample] Resampling from {original_rate}Hz to {TARGET_SAMPLE_RATE}Hz")

        # Handle stereo to mono conversion if needed
        if data.ndim > 1:
            data = data.mean(axis=1).astype(data.dtype)

        # Calculate the number of samples in the output
        num_samples = int(len(data) * TARGET_SAMPLE_RATE / original_rate)

        # Resample using scipy
        resampled_data = signal.resample(data, num_samples)

        # Convert to int16, scaling the peak to full range
        if resampled_data.dtype != np.int16:
            max_val = np.max(np.abs(resampled_data))
            if max_val > 0:
                resampled_data = (resampled_data / max_val * 32767).astype(np.int16)
            else:
                # All-zero (silent) signal: nothing to scale
                resampled_data = resampled_data.astype(np.int16)

        # Save to new temporary file
        with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
            output_path = f.name
        wavfile.write(output_path, TARGET_SAMPLE_RATE, resampled_data)
    except Exception as e:
        print(f"[Resample] Error resampling: {e}, returning original")
        # Do not leave a half-written temporary file behind.
        if output_path is not None:
            try:
                os.remove(output_path)
            except OSError:
                pass  # cleanup is best-effort
        return input_path

    # The resampled file is complete at this point; failing to delete the
    # source must not throw that work away.
    try:
        os.remove(input_path)
    except OSError as e:
        print(f"[Resample] Could not remove original {input_path}: {e}")

    print(f"[Resample] Successfully resampled to {output_path}")
    return output_path
def _discard_file(path):
    """Best-effort removal of *path*; ``None`` and OS errors are ignored."""
    if path is None:
        return
    try:
        os.remove(path)
    except OSError:
        pass  # cleanup only; nothing useful to do on failure


def convert_mp3_to_wav_16khz(input_path: str) -> str:
    """
    Convert MP3 to mono WAV at 16kHz using pydub (if available) or ffmpeg.

    The conversion is best-effort. pydub is tried first; if importing it
    raises ``ImportError`` the ``ffmpeg`` executable is invoked directly.
    On any failure the original path is returned and the temporary output
    file created for the attempt is removed.

    Args:
        input_path: Path to the source MP3 file.

    Returns:
        Path to a new temporary WAV file on success (the source file is then
        deleted), otherwise ``input_path``.
    """
    output_path = None
    try:
        from pydub import AudioSegment
        print(f"[Convert] Converting MP3 to WAV 16kHz: {input_path}")

        # Load MP3
        audio = AudioSegment.from_mp3(input_path)

        # Convert to mono and set sample rate
        audio = audio.set_channels(1).set_frame_rate(TARGET_SAMPLE_RATE)

        # Export as WAV
        with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
            output_path = f.name
        audio.export(output_path, format="wav")
    except ImportError:
        print("[Convert] pydub not available, trying ffmpeg directly")
        _discard_file(output_path)
        output_path = None
        try:
            import subprocess
            with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
                output_path = f.name
            subprocess.run([
                "ffmpeg", "-y", "-i", input_path,
                "-ar", str(TARGET_SAMPLE_RATE),
                "-ac", "1",
                output_path
            ], check=True, capture_output=True)
        except Exception as e:
            print(f"[Convert] ffmpeg conversion failed: {e}, returning original")
            # The placeholder temp file would otherwise be leaked.
            _discard_file(output_path)
            return input_path
        # Remove original MP3 (best-effort; the converted file is valid).
        _discard_file(input_path)
        return output_path
    except Exception as e:
        print(f"[Convert] Error converting: {e}, returning original")
        _discard_file(output_path)
        return input_path

    # Remove original MP3 (best-effort; the converted file is valid).
    _discard_file(input_path)
    print(f"[Convert] Successfully converted to {output_path}")
    return output_path
# Registry of arena model ids. Each entry names the "provider" used for
# dispatch in predict_tts() plus the provider-specific options (voice,
# model, speaker, emotion) that are forwarded to that provider's function.
model_mapping = {
    # Channel Talk TTS (Korean-focused)
    "channel-hana": {
        "provider": "channel",
        "voice": "hana",
    },
    # ElevenLabs (multilingual) - direct API call
    "eleven-multilingual-v2": {
        "provider": "elevenlabs",
        "model": "eleven_multilingual_v2",
    },
    # OpenAI TTS (gpt-4o-mini-tts)
    "openai-gpt-4o-mini-tts": {
        "provider": "openai",
        "model": "gpt-4o-mini-tts",
        "voice": "coral",
    },
    # Google Cloud TTS
    "google-wavenet": {
        "provider": "google",
        "voice": "ko-KR-Wavenet-A",
    },
    "google-neural2": {
        "provider": "google",
        "voice": "ko-KR-Neural2-A",
    },
    # CLOVA TTS (Naver Cloud - Korean-focused)
    "clova-nara": {
        "provider": "clova",
        "speaker": "nara",
    },
    # Supertone TTS (Korean-focused)
    "supertone-sona": {
        "provider": "supertone",
        "model": "sona_speech_1",
    },
    # Humelo DIVE TTS (Korean-focused)
    # NOTE(review): the id says "sia" while the configured preset voice is
    # "리아"; presumably the id is kept stable for existing records - confirm.
    "humelo-sia": {
        "provider": "humelo",
        "voice": "리아",
        "emotion": "neutral",
    },
}
def predict_channel_tts(text: str, voice: str = "hana") -> str:
    """Synthesize *text* with the Channel Talk TTS API.

    Posts the text to ``{CHANNEL_TTS_URL}/{voice}`` asking for the
    ``wav_24000`` output format and stores the response body in a temporary
    ``.wav`` file that the caller is responsible for deleting.

    Args:
        text: Text to synthesize.
        voice: Voice name appended to the endpoint URL.

    Returns:
        Path of the temporary WAV file.

    Raises:
        requests.HTTPError: If the API answers with an error status.
    """
    endpoint = f"{CHANNEL_TTS_URL}/{voice}"
    body = {"text": text, "output_format": "wav_24000"}

    reply = requests.post(
        endpoint,
        headers={"Content-Type": "application/json"},
        json=body,
        timeout=30,
    )
    reply.raise_for_status()

    # Persist the audio to a temporary file that outlives this function.
    wav_file = tempfile.NamedTemporaryFile(delete=False, suffix=".wav")
    with wav_file:
        wav_file.write(reply.content)
    return wav_file.name
def predict_elevenlabs_tts(text: str, model: str = "eleven_multilingual_v2") -> str:
    """Synthesize *text* by calling the ElevenLabs TTS API directly.

    Uses the module-level ``ELEVENLABS_API_KEY`` and ``ELEVENLABS_VOICE_ID``
    and requests ``audio/mpeg`` output, which is written to a temporary
    ``.mp3`` file.

    Args:
        text: Text to synthesize.
        model: ElevenLabs model id sent as ``model_id``.

    Returns:
        Path of the temporary MP3 file.

    Raises:
        ValueError: If no API key is configured.
        requests.HTTPError: If the API answers with an error status.
    """
    if not ELEVENLABS_API_KEY:
        raise ValueError("ELEVENLABS_API_KEY 환경 변수가 설정되지 않았습니다.")

    endpoint = f"https://api.elevenlabs.io/v1/text-to-speech/{ELEVENLABS_VOICE_ID}"
    request_headers = {
        "xi-api-key": ELEVENLABS_API_KEY,
        "Content-Type": "application/json",
        "Accept": "audio/mpeg",
    }
    request_body = {
        "text": text,
        "model_id": model,
        "voice_settings": {
            "stability": 0.5,
            "similarity_boost": 0.75,
        },
    }

    reply = requests.post(
        endpoint,
        headers=request_headers,
        json=request_body,
        timeout=60,
    )
    reply.raise_for_status()

    mp3_file = tempfile.NamedTemporaryFile(delete=False, suffix=".mp3")
    with mp3_file:
        mp3_file.write(reply.content)
    return mp3_file.name
def predict_openai_tts(text: str, model: str = "gpt-4o-mini-tts", voice: str = "coral") -> str:
    """Synthesize *text* with the OpenAI speech API (supports gpt-4o-mini-tts).

    The API key is read from the ``OPENAI_API_KEY`` environment variable at
    call time. WAV output is requested and written to a temporary ``.wav``
    file. Style instructions are attached only for ``gpt-4o-mini-tts``.

    Args:
        text: Text to synthesize.
        model: OpenAI TTS model name.
        voice: OpenAI voice name.

    Returns:
        Path of the temporary WAV file.

    Raises:
        ValueError: If ``OPENAI_API_KEY`` is unset or empty.
        requests.HTTPError: If the API answers with an error status.
    """
    api_key = os.getenv("OPENAI_API_KEY")
    if not api_key:
        raise ValueError("OPENAI_API_KEY 환경 변수가 설정되지 않았습니다.")

    # Instructions for gpt-4o-mini-tts (tuned for Korean TTS).
    # NOTE(review): reproduced with unindented continuation lines, as shown
    # in the reviewed source - confirm against the original file.
    instructions = "\n".join((
        "Voice: Natural and clear Korean voice, with appropriate intonation and rhythm.",
        "Punctuation: Well-structured with natural pauses for clarity.",
        "Delivery: Calm, professional, and easy to understand.",
        "Phrasing: Clear pronunciation with proper Korean phonetics.",
        "Tone: Friendly yet professional, suitable for various contexts.",
    ))

    # Only the gpt-4o-mini-tts model is sent the instructions field.
    extras = {"instructions": instructions} if model == "gpt-4o-mini-tts" else {}
    payload = {
        "model": model,
        "input": text,
        "voice": voice,
        "response_format": "wav",
        **extras,
    }

    reply = requests.post(
        "https://api.openai.com/v1/audio/speech",
        headers={
            "Authorization": f"Bearer {api_key}",
            "Content-Type": "application/json",
        },
        json=payload,
        timeout=60,
    )
    reply.raise_for_status()

    wav_file = tempfile.NamedTemporaryFile(delete=False, suffix=".wav")
    with wav_file:
        wav_file.write(reply.content)
    return wav_file.name
def predict_clova_tts(text: str, speaker: str = "nara") -> str:
    """Synthesize *text* with the Naver Cloud CLOVA TTS API.

    Sends a form-encoded POST authenticated with the module-level
    ``CLOVA_CLIENT_ID`` / ``CLOVA_API_KEY`` and stores the MP3 response in a
    temporary ``.mp3`` file.

    Args:
        text: Text to synthesize.
        speaker: CLOVA speaker name.

    Returns:
        Path of the temporary MP3 file.

    Raises:
        ValueError: If the credentials are missing, or the API answers with
            a success status other than 200.
        urllib.error.URLError: On network failures or HTTP error statuses
            (raised by ``urlopen``).
    """
    client_id = CLOVA_CLIENT_ID
    client_secret = CLOVA_API_KEY
    if not client_id or not client_secret:
        raise ValueError("CLOVA_CLIENT_ID 또는 CLOVA_API_KEY 환경 변수가 설정되지 않았습니다.")

    # Percent-encode every field (not just the text) so that a speaker value
    # containing '&' or '=' cannot corrupt the form body.
    form = urllib.parse.urlencode(
        {
            "speaker": speaker,
            "volume": 0,
            "speed": 0,
            "pitch": 0,
            "format": "mp3",
            "text": text,
        },
        quote_via=urllib.parse.quote,
    )

    url = "https://naveropenapi.apigw.ntruss.com/tts-premium/v1/tts"
    request = urllib.request.Request(url)
    request.add_header("X-NCP-APIGW-API-KEY-ID", client_id)
    request.add_header("X-NCP-APIGW-API-KEY", client_secret)

    # The context manager closes the HTTP response (and its socket) even
    # when the status check or the read raises.
    with urllib.request.urlopen(request, data=form.encode("utf-8"), timeout=60) as response:
        status = response.getcode()
        if status != 200:
            raise ValueError(f"CLOVA TTS API 오류: {status}")
        audio_bytes = response.read()

    with tempfile.NamedTemporaryFile(delete=False, suffix=".mp3") as f:
        f.write(audio_bytes)
        return f.name
def predict_supertone_tts(text: str, model: str = "sona_speech_1") -> str:
    """Synthesize *text* with the Supertone TTS API.

    Uses the module-level ``SUPERTONE_API_KEY`` and ``SUPERTONE_VOICE_ID``,
    requests Korean (``"ko"``) WAV output and writes the response body to a
    temporary ``.wav`` file.

    Args:
        text: Text to synthesize.
        model: Supertone model name.

    Returns:
        Path of the temporary WAV file.

    Raises:
        ValueError: If no API key is configured.
        requests.HTTPError: If the API answers with an error status.
    """
    if not SUPERTONE_API_KEY:
        raise ValueError("SUPERTONE_API_KEY 환경 변수가 설정되지 않았습니다.")

    endpoint = f"https://supertoneapi.com/v1/text-to-speech/{SUPERTONE_VOICE_ID}"
    request_headers = {
        "x-sup-api-key": SUPERTONE_API_KEY,
        "Content-Type": "application/json",
    }
    request_body = {
        "text": text,
        "language": "ko",
        "model": model,
        "output_format": "wav",
        "voice_settings": {
            "pitch_shift": 0,
            "pitch_variance": 1,
            "speed": 1,
        },
    }

    reply = requests.post(
        endpoint,
        headers=request_headers,
        json=request_body,
        timeout=60,
    )
    reply.raise_for_status()

    wav_file = tempfile.NamedTemporaryFile(delete=False, suffix=".wav")
    with wav_file:
        wav_file.write(reply.content)
    return wav_file.name
def _humelo_audio_suffix(content_type: str, audio_url: str, payload: bytes) -> str:
    """Pick ``".mp3"`` or ``".wav"`` for downloaded Humelo audio.

    The audio bytes are inspected first, then the Content-Type header, then
    the path component of the URL. ``".wav"`` is the fallback.
    """
    # Container signatures are the most reliable signal.
    if payload[:4] == b"RIFF":
        return ".wav"
    if payload[:3] == b"ID3":
        return ".mp3"
    # MPEG audio frame sync: eleven set bits at the start of the stream.
    if len(payload) >= 2 and payload[0] == 0xFF and (payload[1] & 0xE0) == 0xE0:
        return ".mp3"

    # The registered MIME type for MP3 is "audio/mpeg", which does not
    # contain the substring "mp3".
    lowered_type = content_type.lower()
    if "mp3" in lowered_type or "mpeg" in lowered_type:
        return ".mp3"

    # Look only at the URL path so a query string (e.g. a signed-URL token)
    # does not hide the extension.
    if urllib.parse.urlparse(audio_url).path.lower().endswith(".mp3"):
        return ".mp3"
    return ".wav"


def predict_humelo_tts(text: str, voice: str = "리아", emotion: str = "neutral") -> str:
    """Synthesize *text* with the Humelo DIVE TTS API.

    Posts a preset-voice request to ``HUMELO_API_URL`` using the module-level
    ``HUMELO_API_KEY``, reads ``audio_url`` from the JSON reply, downloads
    that URL and stores the audio in a temporary ``.mp3`` or ``.wav`` file.

    Args:
        text: Text to synthesize.
        voice: Preset voice name sent as ``voiceName``.
        emotion: Emotion label sent with the request.

    Returns:
        Path of the temporary audio file. Callers use the suffix to decide
        between MP3 conversion and WAV resampling.

    Raises:
        ValueError: If no API key is configured or the reply has no
            ``audio_url``.
        requests.HTTPError: If either HTTP call answers with an error status.
    """
    api_key = HUMELO_API_KEY
    if not api_key:
        raise ValueError("HUMELO_API_KEY 환경 변수가 설정되지 않았습니다.")

    response = requests.post(
        HUMELO_API_URL,
        headers={
            "Content-Type": "application/json",
            "X-API-Key": api_key,
        },
        json={
            "text": text,
            "mode": "preset",
            "voiceName": voice,
            "emotion": emotion,
            "lang": "ko",
        },
        timeout=60,
    )
    response.raise_for_status()

    data = response.json()
    audio_url = data.get("audio_url")
    if not audio_url:
        raise ValueError("Humelo API가 오디오 URL을 반환하지 않았습니다.")

    # Download audio from URL
    audio_response = requests.get(audio_url, timeout=60)
    audio_response.raise_for_status()

    # Determine the file extension from the content, content-type or URL
    suffix = _humelo_audio_suffix(
        audio_response.headers.get("Content-Type", ""),
        audio_url,
        audio_response.content,
    )

    with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as f:
        f.write(audio_response.content)
        return f.name
def predict_google_tts(text: str, voice: str = "ko-KR-Wavenet-A") -> str:
    """Synthesize *text* with the Google Cloud Text-to-Speech REST API.

    The API key is read from the ``GOOGLE_API_KEY`` environment variable at
    call time. LINEAR16 audio at 24000 Hz is requested for the ``ko-KR``
    language; the base64 ``audioContent`` of the reply is decoded and written
    to a temporary ``.wav`` file.

    Args:
        text: Text to synthesize.
        voice: Google voice name.

    Returns:
        Path of the temporary WAV file.

    Raises:
        ValueError: If ``GOOGLE_API_KEY`` is unset or empty, or the reply
            contains no ``audioContent``.
        requests.HTTPError: If the API answers with an error status.
    """
    api_key = os.getenv("GOOGLE_API_KEY")
    if not api_key:
        raise ValueError("GOOGLE_API_KEY 환경 변수가 설정되지 않았습니다.")

    response = requests.post(
        "https://texttospeech.googleapis.com/v1/text:synthesize",
        headers={
            "Content-Type": "application/json",
            # Send the key as a header rather than a "?key=" query parameter:
            # requests includes the full URL in HTTPError messages, so a key
            # in the URL would end up in logs whenever a call fails.
            "X-Goog-Api-Key": api_key,
        },
        json={
            "input": {"text": text},
            "voice": {
                "languageCode": "ko-KR",
                "name": voice,
            },
            "audioConfig": {
                "audioEncoding": "LINEAR16",
                "sampleRateHertz": 24000,
            },
        },
        timeout=30,
    )
    response.raise_for_status()

    audio_content = response.json().get("audioContent")
    if not audio_content:
        raise ValueError("Google TTS API가 오디오를 반환하지 않았습니다.")

    audio_bytes = base64.b64decode(audio_content)
    with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
        f.write(audio_bytes)
        return f.name
def predict_tts(text: str, model: str) -> str:
    """
    Main entry point for TTS generation.

    Looks the model up in ``model_mapping``, calls the matching provider
    function and then standardizes the result for fair comparison: MP3
    output goes through ``convert_mp3_to_wav_16khz`` and WAV output through
    ``resample_wav_to_16khz``.

    Args:
        text: Text to synthesize.
        model: Model id (a key of ``model_mapping``).

    Returns:
        Path of the generated audio file (unified to 16kHz WAV).

    Raises:
        ValueError: If the model id or its provider is unknown.
    """
    print(f"[TTS] Predicting for model: {model}")

    if model not in model_mapping:
        raise ValueError(f"지원하지 않는 모델입니다: {model}")

    settings = model_mapping[model]
    backend = settings["provider"]

    # One zero-argument callable per provider; only the selected one runs.
    synthesizers = {
        # Channel TTS returns WAV at 24kHz
        "channel": lambda: predict_channel_tts(text, settings.get("voice", "hana")),
        # OpenAI returns WAV
        "openai": lambda: predict_openai_tts(
            text,
            settings.get("model", "gpt-4o-mini-tts"),
            settings.get("voice", "coral"),
        ),
        # Google returns WAV at 24kHz
        "google": lambda: predict_google_tts(text, settings.get("voice", "ko-KR-Wavenet-A")),
        # ElevenLabs returns MP3
        "elevenlabs": lambda: predict_elevenlabs_tts(
            text, settings.get("model", "eleven_multilingual_v2")
        ),
        # Supertone returns WAV
        "supertone": lambda: predict_supertone_tts(text, settings.get("model", "sona_speech_1")),
        # CLOVA returns MP3
        "clova": lambda: predict_clova_tts(text, settings.get("speaker", "nara")),
        # Humelo might return MP3 or WAV
        "humelo": lambda: predict_humelo_tts(
            text,
            settings.get("voice", "리아"),
            settings.get("emotion", "neutral"),
        ),
    }

    synthesize = synthesizers.get(backend)
    if synthesize is None:
        raise ValueError(f"알 수 없는 provider: {backend}")
    audio_path = synthesize()

    if backend == "humelo":
        # Humelo's format is only known from the saved file's extension.
        is_mp3 = audio_path.endswith(".mp3")
    else:
        is_mp3 = backend in ("elevenlabs", "clova")

    # Nothing to standardize when the provider produced no path.
    if not audio_path:
        return audio_path

    # Standardize to 16kHz WAV for fair comparison
    standardize = convert_mp3_to_wav_16khz if is_mp3 else resample_wav_to_16khz
    return standardize(audio_path)
if __name__ == "__main__":
    # Manual smoke test: synthesize one Korean sentence with Channel TTS and
    # report either the resulting file path or the error.
    test_text = "안녕하세요, 채널톡 TTS 테스트입니다."
    print("Testing Channel TTS...")
    try:
        result_path = predict_channel_tts(test_text)
    except Exception as err:
        print(f" Error: {err}")
    else:
        print(f" Success: {result_path}")