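"""
Chatterbox TTS API server.

A Flask application, served with waitress, that wraps ChatterboxMultilingualTTS
behind an OpenAI-style speech endpoint and a second endpoint that accepts a
reference audio prompt.
"""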
import os, time, shutil, sys
# os.environ['https_proxy'] = 'http://127.0.0.1:10808'
# os.environ['http_proxy'] = 'http://127.0.0.1:10808'
from pathlib import Path
import inspect
import threading
import warnings

warnings.filterwarnings("ignore", category=FutureWarning)

# from chatterbox.tts import ChatterboxTTS
def _env_int(name: str, default: int) -> int:
    val = os.environ.get(name)
    if val is None or val == "":
        return default
    try:
        return int(val)
    except ValueError:
        return default


host = os.environ.get("HOST", "127.0.0.1")
port = _env_int("PORT", 5093)
threads = _env_int("THREADS", 4)

ROOT_DIR = Path(os.getcwd()).as_posix()
# For users in mainland China, a Hugging Face mirror can significantly speed up downloads
os.environ.setdefault('HF_HOME', ROOT_DIR + "/models")
os.environ.setdefault('HF_HUB_DISABLE_SYMLINKS_WARNING', 'true')
os.environ.setdefault('HF_HUB_DISABLE_PROGRESS_BARS', 'true')
os.environ.setdefault('HF_HUB_DOWNLOAD_TIMEOUT', "1200")
import subprocess, traceback
import io
import uuid
import tempfile
from flask import Flask, request, jsonify, send_file, render_template, make_response
from waitress import serve
import torch

try:
    import soundfile as sf
except ImportError:
    print('No soundfile, exec cmd ` runtime\\python -m pip install soundfile`')
    sys.exit()

try:
    from pydub import AudioSegment
except ImportError:
    print('No pydub, exec cmd ` runtime\\python -m pip install pydub`')
    sys.exit()

if sys.platform == 'win32':
    os.environ['PATH'] = ROOT_DIR + f';{ROOT_DIR}/ffmpeg;{ROOT_DIR}/tools;' + os.environ['PATH']

from chatterbox.mtl_tts import ChatterboxMultilingualTTS as ChatterboxTTS
# Verify that ffmpeg is installed
def check_ffmpeg():
    """Check whether ffmpeg is available on this system."""
    try:
        subprocess.run(["ffmpeg", "-version"], check=True, capture_output=True)
        print("FFmpeg is installed.")
        return True
    except (subprocess.CalledProcessError, FileNotFoundError):
        print("ERROR: ffmpeg not found, please install ffmpeg first.")
        sys.exit(1)  # Exit immediately: MP3 conversion is a required feature
# Load the Chatterbox TTS model
def load_tts_model():
    """Load the TTS model onto the detected device."""
    print("⏳ Loading the ChatterboxTTS model... please wait.")
    try:
        # Auto-detect the available device (CUDA > CPU)
        device = "cuda" if torch.cuda.is_available() else "cpu"
        print(f"Using device: {device}")
        if device == "cpu":
            # The free HF CPU tier has few cores; capping the thread counts is usually faster and more stable
            torch_threads = _env_int("TORCH_NUM_THREADS", min(os.cpu_count() or 1, 2))
            torch_interop_threads = _env_int("TORCH_NUM_INTEROP_THREADS", 1)
            try:
                torch.set_num_threads(max(1, torch_threads))
                torch.set_num_interop_threads(max(1, torch_interop_threads))
            except Exception:
                pass
        def _try_set_attr(obj, name, value) -> bool:
            if obj is None or not hasattr(obj, name):
                return False
            try:
                setattr(obj, name, value)
                return True
            except Exception:
                return False

        def _get_by_path(root, path: str):
            cur = root
            for part in path.split("."):
                cur = getattr(cur, part, None)
                if cur is None:
                    return None
            return cur
        def _tune_transformers(model_obj):
            # These settings avoid the attention fallback triggered by `output_attentions=True`
            # on CPU, which makes inference extremely slow or appear stuck
            candidates = []
            for p in (
                "config",
                "generation_config",
                "t3",
                "t3.config",
                "t3.generation_config",
                "t3.model",
                "t3.model.config",
                "t3.model.generation_config",
                "model",
                "model.config",
                "model.generation_config",
            ):
                obj = _get_by_path(model_obj, p)
                if obj is not None:
                    candidates.append(obj)
            changed = False
            for cfg in candidates:
                changed |= _try_set_attr(cfg, "output_attentions", False)
                changed |= _try_set_attr(cfg, "return_dict_in_generate", True)
                changed |= _try_set_attr(cfg, "use_cache", True)
            return changed
        # Load from the pretrained checkpoint.
        # On CPU, some checkpoints may contain CUDA tensors; force map_location so deserialization does not fail.
        if device == "cpu":
            original_torch_load = torch.load

            def _cpu_safe_load(*args, **kwargs):
                kwargs.setdefault("map_location", "cpu")
                return original_torch_load(*args, **kwargs)

            torch.load = _cpu_safe_load
            try:
                tts_model = ChatterboxTTS.from_pretrained(device=device)
            finally:
                torch.load = original_torch_load
        else:
            tts_model = ChatterboxTTS.from_pretrained(device=device)
        if _tune_transformers(tts_model):
            print("✅ Disabled output_attentions and enabled return_dict_in_generate (speeds up CPU inference).")
        print("Model loaded.")
        return tts_model
    except Exception as e:
        print(f"FATAL: failed to load the model: {e}")
        sys.exit(1)
# --- Global state ---
check_ffmpeg()
model = None
model_lock = threading.Lock()
app = Flask(__name__)
def generate_tts(tts_model, text, *, steps=None, **kwargs):
    """Call tts_model.generate(), passing only the kwargs its signature actually accepts."""
    sig = None
    try:
        sig = inspect.signature(tts_model.generate)
    except Exception:
        sig = None
    if steps is not None and sig is not None:
        # Map the generic `steps` value onto whichever step parameter this model exposes
        for name in ("steps", "num_steps", "n_steps", "sampling_steps", "num_inference_steps"):
            if name in sig.parameters:
                kwargs[name] = int(steps)
                break
    if sig is not None:
        kwargs = {k: v for k, v in kwargs.items() if k in sig.parameters}
    return tts_model.generate(text, **kwargs)
def get_model():
    """Lazily load the model once; the lock keeps concurrent requests from loading it twice."""
    global model
    if model is not None:
        return model
    with model_lock:
        if model is None:
            model = load_tts_model()
    return model
def convert_to_wav(input_path, output_path, sample_rate=16000):
    """
    Converts any audio file to a standardized WAV format using ffmpeg.
    - 16-bit PCM
    - Specified sample rate (default 16kHz, common for TTS)
    - Mono channel
    """
    print(f" - Converting '{input_path}' to WAV at {sample_rate}Hz...")
    command = [
        'ffmpeg',
        '-i', input_path,         # Input file
        '-y',                     # Overwrite output file if it exists
        '-acodec', 'pcm_s16le',   # Use 16-bit PCM encoding
        '-ar', str(sample_rate),  # Set audio sample rate
        '-ac', '1',               # Set to 1 audio channel (mono)
        output_path               # Output file
    ]
    try:
        subprocess.run(
            command,
            check=True,           # Raise an exception if ffmpeg fails
            capture_output=True,  # Capture stdout and stderr
            text=True,            # Decode stdout/stderr as text
            encoding='utf-8',     # Explicitly decode as UTF-8
            errors='replace'      # Replace undecodable bytes with '�' instead of crashing
        )
        print(" - FFmpeg conversion successful.")
    except subprocess.CalledProcessError as e:
        # If ffmpeg fails, print its error output for easier debugging
        print("FFmpeg conversion failed!")
        print(f" - Command: {' '.join(command)}")
        print(f" - Stderr: {e.stderr}")
        raise e  # Re-raise the exception to be caught by the caller's try...except block
# --- API endpoints ---
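# Serve the front-end page at the site root ('/' is an assumed route path for this handler).
@app.route('/')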
def index():
    """Serve the front-end page."""
    return render_template('index.html')


# Endpoint 1: OpenAI-compatible TTS API
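# OpenAI clients POST speech requests to /v1/audio/speech; that path is an assumption here
# and can be adjusted to match your client configuration.
@app.route('/v1/audio/speech', methods=['POST'])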
def tts_openai_compatible():
    """
    OpenAI-compatible TTS endpoint.
    Accepts JSON: {"input": "text", "model": "chatterbox", "voice": "en", ...}
    The `model` field is accepted but ignored by the current implementation.
    """
    if not request.is_json:
        return jsonify({"error": "Request must be JSON"}), 400
    data = request.get_json()
    text = data.get('input')
    # `voice` carries the language code
    lang = data.get('voice', 'en')
    # `speed` carries cfg_weight
    cfg_weight = float(data.get('speed', 0.5))
    # `instructions` carries exaggeration
    exaggeration = float(data.get('instructions', 0.5))
    # if lang != 'en':
    #     return jsonify({"error": "Only support English"}), 400
    if not text:
        return jsonify({"error": "Missing 'input' field in request body"}), 400
    print(f"[APIv1] Received text: '{text[:50]}...'")
    try:
        # Generate the WAV audio
        tts_model = get_model()
        steps = data.get("steps", data.get("num_steps", None))
        if steps is None:
            steps = _env_int("DEFAULT_STEPS", 200 if (not torch.cuda.is_available()) else 1000)
        print(f"[APIv1] steps={steps}")
        t0 = time.time()
        wav_tensor = generate_tts(
            tts_model,
            text,
            exaggeration=exaggeration,
            cfg_weight=cfg_weight,
            language_id=lang,
            steps=steps,
        )
        print(f"[APIv1] generate() done in {time.time()-t0:.2f}s")
        # Check the requested response format; default is mp3
        response_format = data.get('response_format', 'mp3').lower()
        download_name = f'{time.time()}'
        wav_buffer = io.BytesIO()
        wav_tensor = wav_tensor.detach().cpu()
        if wav_tensor.ndim == 2:
            wav_np = wav_tensor.transpose(0, 1).numpy()
        else:
            wav_np = wav_tensor.numpy()
        # Write WAV data into the in-memory buffer
        sf.write(wav_buffer, wav_np, tts_model.sr, format='wav')
        wav_buffer.seek(0)
        if response_format == 'mp3':
            mp3_buffer = io.BytesIO()
            AudioSegment.from_file(wav_buffer, format="wav").export(mp3_buffer, format="mp3")
            mp3_buffer.seek(0)
            return send_file(
                mp3_buffer,
                mimetype='audio/mpeg',
                as_attachment=False,
                download_name=f'{download_name}.mp3'
            )
        # Other formats (e.g. wav) are returned directly
        return send_file(
            wav_buffer,
            mimetype='audio/wav',
            as_attachment=False,
            download_name=f'{download_name}.wav'
        )
    except Exception as e:
        print(f"[APIv1] Error during TTS generation: {e}")
        return jsonify({"error": f"An internal error occurred: {str(e)}"}), 500
# Endpoint 2: TTS with a reference audio prompt
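# The URL for this handler is not fixed by the handler itself; '/v2/audio/speech' is only an
# assumed path for illustration and should be adjusted to match the deployed front end.
@app.route('/v2/audio/speech', methods=['POST'])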
def tts_with_prompt():
    """
    TTS endpoint that takes a reference audio prompt.
    Accepts multipart/form-data:
    - 'input': (string) the text to synthesize
    - 'audio_prompt': (file) the reference audio file
    """
    if 'input' not in request.form:
        return jsonify({"error": "Missing 'input' field in form data"}), 400
    if 'audio_prompt' not in request.files:
        return jsonify({"error": "Missing 'audio_prompt' file in form data"}), 400
    text = request.form['input']
    audio_file = request.files['audio_prompt']
    response_format = request.form.get('response_format', 'wav').lower()
    cfg_weight = float(request.form.get('cfg_weight', 0.5))
    exaggeration = float(request.form.get('exaggeration', 0.5))
    lang = request.form.get('language', 'en')
    # if lang != 'en':
    #     return jsonify({"error": "Only support English"}), 400
    print(f"[APIv2] Received text: '{text[:50]}...' with audio prompt '{audio_file.filename}'")
    temp_upload_path = None
    temp_wav_path = None
    try:
        # --- Stage 1 & 2: Save and convert the uploaded file ---
        temp_dir = tempfile.gettempdir()
        upload_suffix = os.path.splitext(audio_file.filename)[1]
        temp_upload_path = os.path.join(temp_dir, f"{uuid.uuid4()}{upload_suffix}")
        audio_file.save(temp_upload_path)
        print(f" - Uploaded audio saved to: {temp_upload_path}")
        temp_wav_path = os.path.join(temp_dir, f"{uuid.uuid4()}.wav")
        convert_to_wav(temp_upload_path, temp_wav_path)
        # --- Stage 3: Generate TTS using the converted WAV file ---
        print(f" - Generating TTS with prompt: {temp_wav_path}")
        tts_model = get_model()
        steps = request.form.get("steps") or request.form.get("num_steps")
        if steps is None or str(steps).strip() == "":
            steps = _env_int("DEFAULT_STEPS", 200 if (not torch.cuda.is_available()) else 1000)
        print(f"[APIv2] steps={steps}")
        t0 = time.time()
        wav_tensor = generate_tts(
            tts_model,
            text,
            audio_prompt_path=temp_wav_path,
            exaggeration=exaggeration,
            cfg_weight=cfg_weight,
            language_id=lang,
            steps=steps,
        )
        print(f"[APIv2] generate() done in {time.time()-t0:.2f}s")
        # --- Stage 4: Format and return the response based on the request ---
        download_name = f'{time.time()}'
        print(" - Formatting response as WAV.")
        wav_buffer = io.BytesIO()
        wav_tensor = wav_tensor.detach().cpu()
        if wav_tensor.ndim == 2:
            wav_np = wav_tensor.transpose(0, 1).numpy()
        else:
            wav_np = wav_tensor.numpy()
        # Write WAV data into the in-memory buffer
        sf.write(wav_buffer, wav_np, tts_model.sr, format='wav')
        wav_buffer.seek(0)
        if response_format == 'mp3':
            mp3_buffer = io.BytesIO()
            AudioSegment.from_file(wav_buffer, format="wav").export(mp3_buffer, format="mp3")
            mp3_buffer.seek(0)
            return send_file(
                mp3_buffer,
                mimetype='audio/mpeg',
                as_attachment=False,
                download_name=f'{download_name}.mp3'
            )
        return send_file(
            wav_buffer,
            mimetype='audio/wav',
            as_attachment=False,
            download_name=f'{download_name}.wav'
        )
    except Exception as e:
        print(f"[APIv2] An error occurred: {e}")
        traceback.print_exc()
        return jsonify({"error": f"An internal error occurred: {str(e)}"}), 500
    finally:
        # --- Stage 5: Cleanup ---
        if temp_upload_path and os.path.exists(temp_upload_path):
            try:
                os.remove(temp_upload_path)
                print(f" - Cleaned up upload file: {temp_upload_path}")
            except OSError as e:
                print(f" - Error cleaning up upload file {temp_upload_path}: {e}")
        if temp_wav_path and os.path.exists(temp_wav_path):
            try:
                os.remove(temp_wav_path)
                print(f" - Cleaned up WAV file: {temp_wav_path}")
            except OSError as e:
                print(f" - Error cleaning up WAV file {temp_wav_path}: {e}")
# --- Server startup ---
if __name__ == '__main__':
    print(f"\nServer started. HTTP address: http://{host}:{port}\n")
    serve(app, host=host, port=port, threads=threads)
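# Example requests (a sketch assuming the route paths registered above and the default
# host/port; adjust the URLs to match your deployment):
#
#   # OpenAI-compatible endpoint: 'voice' carries the language code, 'speed' maps to
#   # cfg_weight, and 'instructions' maps to exaggeration.
#   curl -X POST http://127.0.0.1:5093/v1/audio/speech \
#        -H "Content-Type: application/json" \
#        -d '{"input": "Hello world", "voice": "en", "speed": 0.5, "instructions": 0.5, "response_format": "mp3"}' \
#        -o out.mp3
#
#   # Reference-audio endpoint: multipart form with the text and a prompt file.
#   curl -X POST http://127.0.0.1:5093/v2/audio/speech \
#        -F "input=Hello world" \
#        -F "audio_prompt=@reference.wav" \
#        -F "language=en" -F "response_format=wav" \
#        -o out.wav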