# chatterbox-api / app.py
# (Hugging Face upload header: HFHash789 — "Upload folder using huggingface_hub" — commit d5e7e0d, verified)
import os,time,shutil,sys
#os.environ['htts_proxy']='http://127.0.0.1:10808'
#os.environ['htt_proxy']='http://127.0.0.1:10808'
from pathlib import Path
import inspect
import threading
import warnings
warnings.filterwarnings("ignore", category=FutureWarning)
#from chatterbox.tts import ChatterboxTTS
def _env_int(name: str, default: int) -> int:
val = os.environ.get(name)
if val is None or val == "":
return default
try:
return int(val)
except ValueError:
return default
# Server bind address/port and waitress worker-thread count, overridable via env vars.
host = os.environ.get("HOST", "127.0.0.1")
port = _env_int("PORT", 5093)
threads = _env_int("THREADS", 4)
# Project root as a POSIX-style path (forward slashes even on Windows).
ROOT_DIR=Path(os.getcwd()).as_posix()
# For users in mainland China, a Hugging Face mirror can significantly speed up downloads.
# Keep the model cache inside the project tree and quiet hub warnings/progress output.
os.environ.setdefault('HF_HOME', ROOT_DIR + "/models")
os.environ.setdefault('HF_HUB_DISABLE_SYMLINKS_WARNING', 'true')
os.environ.setdefault('HF_HUB_DISABLE_PROGRESS_BARS', 'true')
os.environ.setdefault('HF_HUB_DOWNLOAD_TIMEOUT', "1200")
import subprocess,traceback
import io
import uuid
import tempfile
from flask import Flask, request, jsonify, send_file, render_template, make_response
from waitress import serve
import torch
# Hard dependencies: fail fast at startup with an actionable install hint.
try:
    import soundfile as sf
except ImportError:
    print('No soundfile, exec cmd ` runtime\\\\python -m pip install soundfile`')
    # BUG FIX: sys.exit() exits with status 0; a missing dependency is a failure.
    sys.exit(1)
try:
    from pydub import AudioSegment
except ImportError:
    # BUG FIX: this message previously said "No soundfile" for a missing pydub.
    print('No pydub, exec cmd ` runtime\\\\python -m pip install pydub`')
    sys.exit(1)
if sys.platform == 'win32':
    # Make the bundled ffmpeg/tools directories win over anything else on PATH.
    os.environ['PATH'] = ROOT_DIR + f';{ROOT_DIR}/ffmpeg;{ROOT_DIR}/tools;' + os.environ['PATH']
from chatterbox.mtl_tts import ChatterboxMultilingualTTS as ChatterboxTTS
# Check whether ffmpeg is installed
def check_ffmpeg():
    """Verify that an `ffmpeg` binary is reachable; terminate the process if not."""
    probe = ["ffmpeg", "-version"]
    try:
        subprocess.run(probe, check=True, capture_output=True)
    except (subprocess.CalledProcessError, FileNotFoundError):
        print("ERROR: 不存在ffmpeg,请先安装ffmpeg.")
        # MP3 conversion is a required feature, so the server cannot run without ffmpeg.
        sys.exit(1)
    print("FFmpeg 已安装.")
    return True
# Load the Chatterbox TTS model
def load_tts_model():
    """Load the TTS model onto the best available device (CUDA if present, else CPU).

    Returns the ready model instance; exits the process on any load failure.
    """
    print("⏳ 开始加载模型 ChatterboxTTS model... 请耐心等待.")
    try:
        # Auto-detect the available device (CUDA > CPU).
        device = "cuda" if torch.cuda.is_available() else "cpu"
        print(f"Using device: {device}")
        if device == "cpu":
            # Free-tier CPU hosts have few cores; capping torch thread counts is
            # usually both faster and more stable there.
            torch_threads = _env_int("TORCH_NUM_THREADS", min(os.cpu_count() or 1, 2))
            torch_interop_threads = _env_int("TORCH_NUM_INTEROP_THREADS", 1)
            try:
                torch.set_num_threads(max(1, torch_threads))
                torch.set_num_interop_threads(max(1, torch_interop_threads))
            except Exception:
                # set_num_interop_threads can raise if called after parallel work started.
                pass
        def _try_set_attr(obj, name, value) -> bool:
            # Best-effort setattr: True only if the attribute existed and was set.
            if obj is None or not hasattr(obj, name):
                return False
            try:
                setattr(obj, name, value)
                return True
            except Exception:
                return False
        def _get_by_path(root, path: str):
            # Resolve a dotted attribute path (e.g. "t3.model.config");
            # returns None as soon as any hop is missing.
            cur = root
            for part in path.split("."):
                cur = getattr(cur, part, None)
                if cur is None:
                    return None
            return cur
        def _tune_transformers(model_obj):
            # These settings avoid the attention fallback triggered on CPU by
            # `output_attentions=True`, which can make inference extremely slow
            # or appear to hang.
            candidates = []
            # Probe every config/generation_config object this model wrapper may carry.
            for p in (
                "config",
                "generation_config",
                "t3",
                "t3.config",
                "t3.generation_config",
                "t3.model",
                "t3.model.config",
                "t3.model.generation_config",
                "model",
                "model.config",
                "model.generation_config",
            ):
                obj = _get_by_path(model_obj, p)
                if obj is not None:
                    candidates.append(obj)
            changed = False
            for cfg in candidates:
                changed |= _try_set_attr(cfg, "output_attentions", False)
                changed |= _try_set_attr(cfg, "return_dict_in_generate", True)
                changed |= _try_set_attr(cfg, "use_cache", True)
            return changed
        # Load from the pretrained checkpoint.
        # On CPU-only hosts some checkpoints may contain CUDA tensors; force
        # map_location="cpu" so deserialization does not fail.
        if device == "cpu":
            original_torch_load = torch.load
            def _cpu_safe_load(*args, **kwargs):
                kwargs.setdefault("map_location", "cpu")
                return original_torch_load(*args, **kwargs)
            # Temporary monkeypatch; restored in the finally below.
            torch.load = _cpu_safe_load
            try:
                tts_model = ChatterboxTTS.from_pretrained(device=device)
            finally:
                torch.load = original_torch_load
        else:
            tts_model = ChatterboxTTS.from_pretrained(device=device)
        if _tune_transformers(tts_model):
            print("✅ 已禁用 output_attentions 并启用 return_dict_in_generate(提升 CPU 推理速度)。")
        print("模型加载完成.")
        return tts_model
    except Exception as e:
        # Model load failure is fatal: the server is useless without a model.
        print(f"FATAL: 模型加载失败: {e}")
        sys.exit(1)
# --- Global state initialization ---
check_ffmpeg()
model = None  # lazily-loaded singleton TTS model; populated by get_model()
model_lock = threading.Lock()  # guards the one-time model initialization
app = Flask(__name__)
def generate_tts(tts_model, text, *, steps=None, **kwargs):
    """Invoke `tts_model.generate`, forwarding only kwargs its signature accepts.

    If *steps* is given, it is mapped onto whichever step-count parameter name
    the model's generate() exposes (different model versions spell it differently).
    When the signature cannot be introspected, kwargs are passed through untouched.
    """
    try:
        params = inspect.signature(tts_model.generate).parameters
    except Exception:
        params = None
    if params is not None:
        if steps is not None:
            step_aliases = ("steps", "num_steps", "n_steps", "sampling_steps", "num_inference_steps")
            alias = next((a for a in step_aliases if a in params), None)
            if alias is not None:
                kwargs[alias] = int(steps)
        # Drop anything generate() would reject.
        kwargs = {key: val for key, val in kwargs.items() if key in params}
    return tts_model.generate(text, **kwargs)
def get_model():
    """Return the process-wide TTS model, loading it exactly once (thread-safe)."""
    global model
    if model is None:
        # Double-checked locking: the fast path above avoids lock contention
        # once the model is loaded; re-check under the lock for correctness.
        with model_lock:
            if model is None:
                model = load_tts_model()
    return model
def convert_to_wav(input_path, output_path, sample_rate=16000):
    """
    Normalize any audio file to a standard WAV via ffmpeg:
    16-bit PCM, mono, resampled to *sample_rate* Hz (16 kHz default, common for TTS).

    Raises subprocess.CalledProcessError when ffmpeg exits non-zero.
    """
    print(f" - Converting '{input_path}' to WAV at {sample_rate}Hz...")
    cmd = [
        'ffmpeg',
        '-i', input_path,
        '-y',                     # overwrite the output file if it exists
        '-acodec', 'pcm_s16le',   # 16-bit PCM encoding
        '-ar', str(sample_rate),  # target sample rate
        '-ac', '1',               # downmix to a single (mono) channel
        output_path,
    ]
    try:
        subprocess.run(
            cmd,
            check=True,           # raise if ffmpeg fails
            capture_output=True,  # capture stdout/stderr for diagnostics
            text=True,
            encoding='utf-8',     # decode ffmpeg output explicitly as UTF-8
            errors='replace',     # never crash on undecodable bytes
        )
    except subprocess.CalledProcessError as e:
        # Surface ffmpeg's own diagnostics before propagating the failure
        # to the caller's error handler.
        print("FFmpeg conversion failed!")
        print(f" - Command: {' '.join(cmd)}")
        print(f" - Stderr: {e.stderr}")
        raise e
    print(" - FFmpeg conversion successful.")
# --- API endpoints ---
@app.route('/')
def index():
    """Serve the single-page web UI."""
    template_name = 'index.html'
    return render_template(template_name)
# Endpoint 1: OpenAI-compatible TTS API
@app.route('/v1/audio/speech', methods=['POST'])
def tts_openai_compatible():
    """
    OpenAI-compatible TTS endpoint.

    JSON body:
      - input:           (str, required) text to synthesize
      - voice:           (str) repurposed here as the language code, default "en"
      - speed:           (float) repurposed as cfg_weight, default 0.5
      - instructions:    (float) repurposed as exaggeration, default 0.5
      - response_format: "mp3" (default) or anything else for WAV
      - steps/num_steps: (int) optional sampling step count

    Returns the synthesized audio stream, or a JSON error with 400/500 status.
    """
    if not request.is_json:
        return jsonify({"error": "Request must be JSON"}), 400
    data = request.get_json()
    text = data.get('input')
    if not text:
        return jsonify({"error": "Missing 'input' field in request body"}), 400
    # 'voice' carries the language code in this API.
    lang = data.get('voice', 'en')
    # 'speed' carries cfg_weight; 'instructions' carries exaggeration.
    # BUG FIX: a non-numeric value used to raise outside the handler below and
    # surface as an unexplained 500; reject it explicitly as a 400 instead.
    try:
        cfg_weight = float(data.get('speed', 0.5))
        exaggeration = float(data.get('instructions', 0.5))
    except (TypeError, ValueError):
        return jsonify({"error": "'speed' and 'instructions' must be numeric"}), 400
    print(f"[APIv1] Received text: '{text[:50]}...'")
    try:
        # Generate the WAV audio (model is lazily loaded on first use).
        tts_model = get_model()
        steps = data.get("steps", data.get("num_steps", None))
        if steps is None:
            # Fewer default steps on CPU keep latency tolerable.
            steps = _env_int("DEFAULT_STEPS", 200 if (not torch.cuda.is_available()) else 1000)
        print(f"[APIv1] steps={steps}")
        t0 = time.time()
        wav_tensor = generate_tts(
            tts_model,
            text,
            exaggeration=exaggeration,
            cfg_weight=cfg_weight,
            language_id=lang,
            steps=steps,
        )
        print(f"[APIv1] generate() done in {time.time()-t0:.2f}s")
        # Requested response format, defaulting to mp3.
        response_format = data.get('response_format', 'mp3').lower()
        download_name = f'{time.time()}'
        # Serialize the tensor as WAV into memory.
        wav_buffer = io.BytesIO()
        wav_tensor = wav_tensor.detach().cpu()
        if wav_tensor.ndim == 2:
            # soundfile expects (samples, channels); 2-D tensors appear to be
            # (channels, samples) here — transpose before writing.
            wav_np = wav_tensor.transpose(0, 1).numpy()
        else:
            wav_np = wav_tensor.numpy()
        sf.write(wav_buffer, wav_np, tts_model.sr, format='wav')
        wav_buffer.seek(0)
        if response_format == 'mp3':
            # Transcode the in-memory WAV to MP3 via pydub/ffmpeg.
            mp3_buffer = io.BytesIO()
            AudioSegment.from_file(wav_buffer, format="wav").export(mp3_buffer, format="mp3")
            mp3_buffer.seek(0)
            return send_file(
                mp3_buffer,
                mimetype='audio/mpeg',
                as_attachment=False,
                download_name=f'{download_name}.mp3'
            )
        # Any other requested format (e.g. wav) is returned as WAV directly.
        return send_file(
            wav_buffer,
            mimetype='audio/wav',
            as_attachment=False,
            download_name=f'{download_name}.wav'
        )
    except Exception as e:
        print(f"[APIv1] Error during TTS generation: {e}")
        return jsonify({"error": f"An internal error occurred: {str(e)}"}), 500
# Endpoint 2: TTS with a reference audio prompt
@app.route('/v2/audio/speech_with_prompt', methods=['POST'])
def tts_with_prompt():
    """
    Voice-prompted TTS endpoint.

    multipart/form-data fields:
      - input:           (str, required) text to synthesize
      - audio_prompt:    (file, required) reference audio in any ffmpeg-readable format
      - language:        (str) language code, default "en"
      - cfg_weight:      (float) default 0.5
      - exaggeration:    (float) default 0.5
      - response_format: "wav" (default) or "mp3"
      - steps/num_steps: (int) optional sampling step count

    The upload is normalized to 16 kHz mono WAV before conditioning the model;
    all temporary files are removed in the finally block.
    """
    if 'input' not in request.form:
        return jsonify({"error": "Missing 'input' field in form data"}), 400
    if 'audio_prompt' not in request.files:
        return jsonify({"error": "Missing 'audio_prompt' file in form data"}), 400
    text = request.form['input']
    audio_file = request.files['audio_prompt']
    response_format = request.form.get('response_format', 'wav').lower()
    # BUG FIX: non-numeric form values used to raise before the handler below
    # and surface as an unexplained 500; reject them explicitly as a 400.
    try:
        cfg_weight = float(request.form.get('cfg_weight', 0.5))
        exaggeration = float(request.form.get('exaggeration', 0.5))
    except (TypeError, ValueError):
        return jsonify({"error": "'cfg_weight' and 'exaggeration' must be numeric"}), 400
    lang = request.form.get('language', 'en')
    print(f"[APIv2] Received text: '{text[:50]}...' with audio prompt '{audio_file.filename}'")
    temp_upload_path = None
    temp_wav_path = None
    try:
        # --- Stage 1 & 2: save the upload, then normalize it to WAV ---
        temp_dir = tempfile.gettempdir()
        upload_suffix = os.path.splitext(audio_file.filename)[1]
        # Random UUID filenames avoid collisions and any path traversal via the
        # client-supplied filename (only its extension is reused).
        temp_upload_path = os.path.join(temp_dir, f"{uuid.uuid4()}{upload_suffix}")
        audio_file.save(temp_upload_path)
        print(f" - Uploaded audio saved to: {temp_upload_path}")
        temp_wav_path = os.path.join(temp_dir, f"{uuid.uuid4()}.wav")
        convert_to_wav(temp_upload_path, temp_wav_path)
        # --- Stage 3: generate TTS conditioned on the converted WAV ---
        print(f" - Generating TTS with prompt: {temp_wav_path}")
        tts_model = get_model()
        steps = request.form.get("steps") or request.form.get("num_steps")
        if steps is None or str(steps).strip() == "":
            steps = _env_int("DEFAULT_STEPS", 200 if (not torch.cuda.is_available()) else 1000)
        print(f"[APIv2] steps={steps}")
        t0 = time.time()
        wav_tensor = generate_tts(
            tts_model,
            text,
            audio_prompt_path=temp_wav_path,
            exaggeration=exaggeration,
            cfg_weight=cfg_weight,
            language_id=lang,
            steps=steps,
        )
        print(f"[APIv2] generate() done in {time.time()-t0:.2f}s")
        # --- Stage 4: serialize and return in the requested format ---
        download_name = f'{time.time()}'
        print(" - Formatting response as WAV.")
        wav_buffer = io.BytesIO()
        wav_tensor = wav_tensor.detach().cpu()
        if wav_tensor.ndim == 2:
            # soundfile expects (samples, channels); 2-D tensors appear to be
            # (channels, samples) here — transpose before writing.
            wav_np = wav_tensor.transpose(0, 1).numpy()
        else:
            wav_np = wav_tensor.numpy()
        sf.write(wav_buffer, wav_np, tts_model.sr, format='wav')
        wav_buffer.seek(0)
        if response_format == 'mp3':
            mp3_buffer = io.BytesIO()
            AudioSegment.from_file(wav_buffer, format="wav").export(mp3_buffer, format="mp3")
            mp3_buffer.seek(0)
            return send_file(
                mp3_buffer,
                mimetype='audio/mpeg',
                as_attachment=False,
                download_name=f'{download_name}.mp3'
            )
        return send_file(
            wav_buffer,
            mimetype='audio/wav',
            as_attachment=False,
            download_name=f'{download_name}.wav'
        )
    except Exception as e:
        print(f"[APIv2] An error occurred: {e}")
        traceback.print_exc()
        return jsonify({"error": f"An internal error occurred: {str(e)}"}), 500
    finally:
        # --- Stage 5: best-effort cleanup of both temp files ---
        for tmp_path, label in ((temp_upload_path, "upload"), (temp_wav_path, "WAV")):
            if tmp_path and os.path.exists(tmp_path):
                try:
                    os.remove(tmp_path)
                    print(f" - Cleaned up {label} file: {tmp_path}")
                except OSError as e:
                    print(f" - Error cleaning up {label} file {tmp_path}: {e}")
# --- Server startup ---
if __name__ == '__main__':
    # Announce the listen address, then hand off to waitress's blocking serve loop.
    banner = f"\n服务启动完成,http地址是: http://{host}:{port} \n"
    print(banner)
    serve(app, host=host, port=port, threads=threads)