|
|
from fastapi import FastAPI, HTTPException, Response, Depends, File, Form |
|
|
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials |
|
|
import soundfile as sf |
|
|
import numpy as np |
|
|
from voxcpm import VoxCPM |
|
|
from pydantic import BaseModel |
|
|
import os |
|
|
import requests |
|
|
import zipfile |
|
|
from utils import * |
|
|
import uuid |
|
|
import queue |
|
|
import threading |
|
|
import asyncio |
|
|
|
|
|
|
|
|
# Application object that every route in this module is registered on.
app = FastAPI()

# Bearer-token extractor; consumed by ``verify_token`` through ``Depends``.
security = HTTPBearer()
|
|
|
|
|
|
|
|
def verify_token(credentials: HTTPAuthorizationCredentials = Depends(security)):
    """Validate the bearer token supplied with a request.

    The expected value is read from the ``API_TOKEN`` environment variable on
    every call, falling back to ``"my_secret_token"`` when it is unset.

    Args:
        credentials: Bearer credentials extracted by the ``security`` dependency.

    Returns:
        The token string presented by the client.

    Raises:
        HTTPException: 401 when the presented token does not match.
    """
    import hmac

    # NOTE(review): the hard-coded fallback means a deployment without
    # API_TOKEN accepts a publicly known token; consider requiring the variable.
    expected_token = os.getenv("API_TOKEN", "my_secret_token")

    # Constant-time comparison so response timing does not reveal how much of
    # the token matched. Comparing bytes keeps non-ASCII tokens from raising.
    supplied = credentials.credentials.encode("utf-8")
    if not hmac.compare_digest(supplied, expected_token.encode("utf-8")):
        raise HTTPException(status_code=401, detail="Invalid or missing token")
    return credentials.credentials
|
|
|
|
|
class GenerateRequest(BaseModel):
    # Request body of the ``/generate`` endpoint. Every field except ``text``
    # and ``voice`` is forwarded unchanged to ``model.generate``.

    # Text to synthesize; surrounding whitespace is stripped by the worker.
    text: str
    # Voice name; selects the ``<voice>.wav`` / ``<voice>.pmt`` prompt pair.
    voice: str
    # Passed to ``model.generate`` as ``cfg_value``.
    cfg_value: float = 2.0
    # Passed to ``model.generate`` as ``inference_timesteps``.
    inference_timesteps: int = 10
    # Passed to ``model.generate`` as ``normalize``.
    do_normalize: bool = True
    # Passed to ``model.generate`` as ``denoise``.
    denoise: bool = True
|
|
|
|
|
def download_voices(bForce=False):
    """Ensure the voice prompt files are present, downloading them if needed.

    Creates ``/workspace/voices`` when it is missing. When ``bForce`` is true,
    or the directory holds no ``.pmt`` file, the zip archive at the URL in the
    ``VOICE_DOWNLOAD_URL`` environment variable is fetched and extracted into
    that directory. Nothing is downloaded when the variable is unset.

    Args:
        bForce: Download even if ``.pmt`` files already exist.

    Raises:
        HTTPException: 500 when the download or the extraction fails.
    """
    voices_dir = "/workspace/voices"
    # exist_ok avoids a crash when two callers race to create the directory.
    os.makedirs(voices_dir, exist_ok=True)

    has_prompts = any(f.endswith(".pmt") for f in os.listdir(voices_dir))
    if has_prompts and not bForce:
        return

    voice_download_url = os.getenv("VOICE_DOWNLOAD_URL")
    if not voice_download_url:
        return

    zip_path = f"{voices_dir}/voices.zip"
    try:
        # (connect, read) timeouts: without them a stalled server would block
        # the caller indefinitely.
        response = requests.get(voice_download_url, timeout=(10, 60))
        response.raise_for_status()

        with open(zip_path, "wb") as f:
            f.write(response.content)

        with zipfile.ZipFile(zip_path, 'r') as zip_ref:
            zip_ref.extractall(voices_dir)
    except Exception as e:
        print_with_time(f"Failed to download and extract voices: {e}")
        raise HTTPException(status_code=500, detail="Failed to download voice files") from e
    finally:
        # Remove the temporary archive on both success and failure so a broken
        # download is not left behind in the voices directory.
        try:
            os.remove(zip_path)
        except FileNotFoundError:
            pass
        except OSError as e:
            print_with_time(f"Failed to remove temporary archive {zip_path}: {e}")
|
|
|
|
|
|
|
|
# Directory that receives generated audio, and how many results are retained
# there by ``cleanup_old_files``.
output_dir = "./output"
max_output_files = 10

# Make sure the output directory exists before any task is processed.
os.makedirs(output_dir, exist_ok=True)

# Pending generation tasks: filled by the ``/generate`` endpoint and drained
# by ``process_queue``.
task_queue = queue.Queue()
|
|
|
|
|
|
|
|
def cleanup_old_files():
    """Delete the oldest ``.wav`` files so at most ``max_output_files`` remain.

    Files in ``output_dir`` are ordered by ``os.path.getctime``; the ones with
    the smallest value are removed first. Any error is logged and swallowed.
    """
    try:
        wav_names = [name for name in os.listdir(output_dir) if name.endswith('.wav')]
        wav_names.sort(key=lambda name: os.path.getctime(os.path.join(output_dir, name)))

        # Number of files beyond the retention limit (never negative).
        surplus = max(len(wav_names) - max_output_files, 0)
        for stale_name in wav_names[:surplus]:
            os.remove(os.path.join(output_dir, stale_name))
    except Exception as err:
        print_with_time(f"Error cleaning up old files: {err}")
|
|
|
|
|
async def process_queue():
    """Load the VoxCPM model once, then process queued generation tasks forever.

    Each task taken from ``task_queue`` is a dict with ``"task_id"`` and
    ``"request"`` keys. The generated audio is written to
    ``<output_dir>/<task_id>.wav``. Tasks whose text is empty after stripping
    are dropped without producing a file. Errors while handling a task are
    logged and the loop continues. Every dequeued task is marked done,
    whether it succeeded, failed or was skipped.
    """
    print_with_time("Loading VoxCPM model...")
    model = VoxCPM.from_pretrained("openbmb/VoxCPM-0.5B")
    print_with_time("VoxCPM model loaded.")

    while True:
        try:
            task_data = task_queue.get_nowait()
        except queue.Empty:
            # Nothing to do; poll again shortly.
            await asyncio.sleep(0.6)
            continue

        try:
            request = task_data["request"]
            text = (request.text or "").strip()
            if len(text) == 0:
                # Skipped task; the ``finally`` below still marks it done.
                continue

            if model is None:
                raise RuntimeError("Failed to initialize model")

            download_voices()
            print_with_time(f"Generating audio for : '{text[:60]}...'")

            # Read the prompt transcript up front so the file is not held open
            # for the duration of the generation call.
            with open(f"./voices/{request.voice}.pmt", 'r', encoding='utf-8') as f:
                prompt_text = f.read()

            wav = model.generate(
                text=text,
                prompt_wav_path=f"./voices/{request.voice}.wav",
                prompt_text=prompt_text,
                cfg_value=request.cfg_value,
                inference_timesteps=request.inference_timesteps,
                normalize=request.do_normalize,
                denoise=request.denoise
            )
            # NOTE(review): 16000 is assumed to be the model's output sample
            # rate - confirm against the VoxCPM documentation.
            sf.write(os.path.join(output_dir, f"{task_data['task_id']}.wav"), wav, 16000)

            cleanup_old_files()
            print_with_time("audio generated.")
        except Exception as e:
            print_with_time(f"Error processing queue item: {e}")
        finally:
            # Balance every successful get() so that queue.join() cannot hang
            # on skipped or failed tasks.
            task_queue.task_done()

        await asyncio.sleep(0.6)
|
|
|
|
|
|
|
|
@app.post("/generate")
async def generate_tts_async(request: GenerateRequest, token: str = Depends(verify_token)):
    # Queue the request under a fresh UUID and hand that id back to the
    # caller, who uses it to fetch the audio from ``/tts/{task_id}``.
    new_task_id = str(uuid.uuid4())
    task_queue.put({"task_id": new_task_id, "request": request})
    return {"task_id": new_task_id}
|
|
|
|
|
@app.get("/tts/{task_id}")
async def get_generate_result(task_id: str, token: str = Depends(verify_token)):
    # Return the generated WAV for ``task_id`` as an ``audio/wav`` response.
    # Raises HTTPException 404 when ``<output_dir>/<task_id>.wav`` does not
    # exist and HTTPException 500 when it exists but cannot be read.
    filepath = os.path.join(output_dir, f"{task_id}.wav")

    # Open directly instead of checking existence first: the file may be
    # removed by ``cleanup_old_files`` between a check and the open, which
    # would otherwise surface as a 500 rather than a 404.
    try:
        with open(filepath, 'rb') as f:
            content = f.read()
    except FileNotFoundError:
        raise HTTPException(status_code=404, detail="Result file not found") from None
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Failed to read result file: {str(e)}") from e

    return Response(content=content, media_type="audio/wav")
|
|
|
|
|
@app.post("/upload_voice")
def upload_voice(name: str = Form(...), wav: bytes = File(...), prompt: str = Form(...), token: str = Depends(verify_token)):
    # Store a voice as ``/workspace/voices/<name>.wav`` (the reference audio)
    # and ``/workspace/voices/<name>.pmt`` (its transcript, UTF-8).
    # Raises HTTPException 400 when ``name`` is empty or could escape the
    # voices directory.
    #
    # ``name`` comes straight from the client and is interpolated into a
    # filesystem path, so reject anything containing a path separator or NUL.
    if not name or any(ch in name for ch in ("/", "\\", "\x00")):
        raise HTTPException(status_code=400, detail="Invalid voice name")

    voices_dir = "/workspace/voices"
    # The directory is otherwise only created by ``download_voices``.
    os.makedirs(voices_dir, exist_ok=True)

    with open(f"{voices_dir}/{name}.wav", 'wb') as f:
        f.write(wav)

    with open(f"{voices_dir}/{name}.pmt", 'w', encoding='utf-8') as f:
        f.write(prompt)

    return {"status": "success"}
|
|
|
|
|
@app.delete("/delete_voice")
def delete_voice(name: str, token: str = Depends(verify_token)):
    # Delete ``/workspace/voices/<name>.wav`` and ``<name>.pmt``.
    # Returns {"status": "success"} when at least one of the two files was
    # removed, otherwise {"status": "不存在"} ("does not exist").
    # Raises HTTPException 400 when ``name`` is empty or could escape the
    # voices directory.
    #
    # ``name`` comes straight from the client and is interpolated into a
    # filesystem path, so reject anything containing a path separator or NUL.
    if not name or any(ch in name for ch in ("/", "\\", "\x00")):
        raise HTTPException(status_code=400, detail="Invalid voice name")

    wav_file = f"/workspace/voices/{name}.wav"
    pmt_file = f"/workspace/voices/{name}.pmt"

    # Remove whichever halves of the pair exist, so a voice with only one of
    # its files is still cleaned up and reported consistently.
    removed_any = False
    for path in (wav_file, pmt_file):
        if os.path.exists(path):
            os.remove(path)
            removed_any = True

    if removed_any:
        return {"status": "success"}
    return {"status": "不存在"}
|
|
|
|
|
@app.get("/voices")
def get_voices(token: str = Depends(verify_token)):
    # List the voices that have both a ``.pmt`` transcript and a matching
    # ``.wav`` file in ``./voices``. ``download_voices`` is called first so
    # the default voices are fetched when none are present.
    # Returns {"voices": [<name>, ...]}.
    download_voices()

    suffix = ".pmt"
    valid_voices = []
    for filename in os.listdir("./voices"):
        if not filename.endswith(suffix):
            continue
        # Strip only the trailing suffix; splitting on the first "." would
        # truncate voice names that themselves contain dots.
        voice = filename[:-len(suffix)]
        if os.path.exists(f"./voices/{voice}.wav"):
            valid_voices.append(voice)
    return {"voices": valid_voices}
|
|
|
|
|
|
|
|
|
|
|
@app.get("/")
@app.get("/health")
def health_check():
    # Unauthenticated probe served on both "/" and "/health"; always answers
    # with the same fixed payload.
    return dict(status="health")
|
|
|
|
|
def start_api_server():
    """Serve ``app`` with uvicorn on 0.0.0.0:7860."""
    # Imported here, as in the original, so uvicorn is only loaded when the
    # server is actually started.
    from uvicorn import run as run_server

    run_server(app, host="0.0.0.0", port=7860)
|
|
|
|
|
# Serve HTTP from a daemon thread so it does not keep the process alive on
# its own.
_api_thread = threading.Thread(target=start_api_server, daemon=True)
_api_thread.start()

# Run the queue worker (model loading and generation) in the main thread.
asyncio.run(process_queue())