# voxcpm / api.py
# zhendery · initial commit · c786b94 · 4.67 kB
import io
import os
import secrets
import zipfile

import numpy as np
import requests
import soundfile as sf
from fastapi import FastAPI, HTTPException, Response, Depends, File, Form
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from pydantic import BaseModel
from voxcpm import VoxCPM
# Load the pretrained VoxCPM model once at import time; every request handler
# below uses this single module-level instance.
model = VoxCPM.from_pretrained("openbmb/VoxCPM-0.5B")
# HTTP Bearer scheme; verify_token depends on it to obtain the request credentials.
security = HTTPBearer()
# FastAPI application onto which the route decorators below register handlers.
app = FastAPI()
# Authentication dependency
def verify_token(credentials: HTTPAuthorizationCredentials = Depends(security)):
    """Validate the bearer token presented with a request.

    Args:
        credentials: Bearer credentials extracted by the ``security`` scheme.

    Returns:
        The presented token string when it matches the expected token.

    Raises:
        HTTPException: 401 when the presented token does not match.
    """
    # Expected token comes from the environment.
    # SECURITY NOTE(review): the fallback value is hard-coded in the source and
    # therefore publicly known; API_TOKEN should always be set in deployment.
    expected_token = os.getenv("API_TOKEN", "my_secret_token")

    # Compare in constant time so response timing does not leak how much of
    # the token matched. Bytes are used because compare_digest rejects
    # non-ASCII str arguments.
    presented = credentials.credentials.encode("utf-8")
    if not secrets.compare_digest(presented, expected_token.encode("utf-8")):
        raise HTTPException(status_code=401, detail="Invalid or missing token")
    return credentials.credentials
class GenerateRequest(BaseModel):
    """JSON body accepted by the ``/generate`` endpoint."""

    # Text to synthesize; the handler strips surrounding whitespace.
    text: str
    # Voice name; resolved to /workspace/voices/<voice>.wav and <voice>.pmt.
    voice: str
    # Forwarded unchanged to model.generate(cfg_value=...).
    cfg_value: float = 2.0
    # Forwarded unchanged to model.generate(inference_timesteps=...).
    inference_timesteps: int = 10
    # Forwarded to model.generate(normalize=...).
    do_normalize: bool = True
    # Forwarded to model.generate(denoise=...).
    denoise: bool = True
@app.post("/generate")
def generate_tts(request: GenerateRequest, token: str = Depends(verify_token)):
    """Synthesize speech for the requested text using a stored voice prompt.

    Args:
        request: Text, voice name and generation parameters.
        token: Validated bearer token (authentication dependency).

    Returns:
        A ``Response`` whose body is the generated audio encoded as WAV.

    Raises:
        HTTPException: 400 for empty text or an invalid voice name,
            404 when the voice's prompt files are missing.
    """
    text = (request.text or "").strip()
    if not text:
        # A client error, so report 400 instead of letting a ValueError
        # surface as an opaque 500.
        raise HTTPException(status_code=400, detail="Please input text to synthesize.")

    # The voice name is interpolated into filesystem paths; reject anything
    # that could escape the voices directory.
    voice = request.voice
    if not voice or any(ch in voice for ch in "/\\\x00"):
        raise HTTPException(status_code=400, detail="Invalid voice name")

    prompt_wav_path = f"/workspace/voices/{voice}.wav"
    prompt_text_path = f"/workspace/voices/{voice}.pmt"

    print(f"Generating audio for text: '{text[:60]}...'")

    try:
        with open(prompt_text_path, "r", encoding="utf-8") as f:
            prompt_text = f.read()
    except FileNotFoundError:
        raise HTTPException(status_code=404, detail="Voice not found") from None
    if not os.path.isfile(prompt_wav_path):
        raise HTTPException(status_code=404, detail="Voice not found")

    wav = model.generate(
        text=text,
        prompt_wav_path=prompt_wav_path,
        prompt_text=prompt_text,
        cfg_value=request.cfg_value,
        inference_timesteps=request.inference_timesteps,
        normalize=request.do_normalize,
        denoise=request.denoise,
    )

    # Encode in memory: a shared "output.wav" on disk would be overwritten by
    # concurrent requests, and reading it back leaked an open file handle.
    buffer = io.BytesIO()
    sf.write(buffer, wav, 16000, format="WAV")
    return Response(content=buffer.getvalue(), media_type="audio/wav")
def download_voices():
    """Populate the voices directory from a remote zip archive when it is empty.

    Creates ``/workspace/voices`` if needed. When that directory contains no
    ``.pmt`` files and the ``VOICE_DOWNLOAD_URL`` environment variable is set,
    the archive at that URL is downloaded and extracted into the directory.

    Raises:
        HTTPException: 500 when the download or extraction fails.
    """
    voices_dir = "/workspace/voices"
    # exist_ok avoids the race between checking for and creating the directory.
    os.makedirs(voices_dir, exist_ok=True)

    # Nothing to do if at least one voice prompt is already present.
    if any(entry.endswith(".pmt") for entry in os.listdir(voices_dir)):
        return

    voice_download_url = os.getenv("VOICE_DOWNLOAD_URL")
    if not voice_download_url:
        return

    try:
        # Explicit (connect, read) timeouts: without them a stalled server
        # would block the calling request indefinitely.
        response = requests.get(voice_download_url, timeout=(10, 120))
        response.raise_for_status()
        # The payload is already fully in memory, so extract straight from it;
        # no temporary zip is left behind in the voices directory on failure.
        with zipfile.ZipFile(io.BytesIO(response.content)) as zip_ref:
            zip_ref.extractall(voices_dir)
    except Exception as e:
        print(f"Failed to download and extract voices: {e}")
        raise HTTPException(status_code=500, detail="Failed to download voice files") from e
@app.post("/upload_voice")
def upload_voice(name: str = Form(...), wav: bytes = File(...), prompt: str = Form(...), token: str = Depends(verify_token)):
    """Store a new voice as a ``.wav`` audio prompt plus a ``.pmt`` text prompt.

    Args:
        name: Voice name used as the file stem inside the voices directory.
        wav: Raw bytes of the prompt audio file.
        prompt: Prompt text stored alongside the audio.
        token: Validated bearer token (authentication dependency).

    Returns:
        ``{"status": "success"}`` once both files are written.

    Raises:
        HTTPException: 400 when ``name`` is empty or contains path separators.
    """
    # The name is client-supplied and becomes part of a filesystem path;
    # reject anything that could write outside the voices directory.
    if not name or any(ch in name for ch in "/\\\x00"):
        raise HTTPException(status_code=400, detail="Invalid voice name")

    voices_dir = "/workspace/voices"
    # The directory may not exist yet if no voices were ever listed/downloaded.
    os.makedirs(voices_dir, exist_ok=True)

    # Save the audio prompt.
    with open(f"{voices_dir}/{name}.wav", "wb") as f:
        f.write(wav)
    # Save the text prompt.
    with open(f"{voices_dir}/{name}.pmt", "w", encoding="utf-8") as f:
        f.write(prompt)
    return {"status": "success"}
@app.delete("/delete_voice")
def delete_voice(name: str, token: str = Depends(verify_token)):
    """Delete the ``.wav`` and ``.pmt`` files belonging to a voice.

    Args:
        name: Voice name (file stem inside the voices directory).
        token: Validated bearer token (authentication dependency).

    Returns:
        ``{"status": "success"}`` if at least one file was removed, otherwise
        ``{"status": "不存在"}``.

    Raises:
        HTTPException: 400 when ``name`` is empty or contains path separators.
    """
    # The name is client-supplied and becomes part of a filesystem path;
    # reject anything that could delete files outside the voices directory.
    if not name or any(ch in name for ch in "/\\\x00"):
        raise HTTPException(status_code=400, detail="Invalid voice name")

    # Remove each file independently so an orphaned .pmt (no matching .wav) is
    # still cleaned up, and so a concurrent delete cannot cause a crash
    # between an existence check and the removal.
    removed_any = False
    for suffix in (".wav", ".pmt"):
        try:
            os.remove(f"/workspace/voices/{name}{suffix}")
        except FileNotFoundError:
            continue
        removed_any = True

    if removed_any:
        return {"status": "success"}
    return {"status": "不存在"}
@app.get("/voices")
def get_voices(token: str = Depends(verify_token)):
    """List the voices that have both a ``.pmt`` and a ``.wav`` file.

    Calls ``download_voices()`` first so an empty voices directory can be
    populated from the configured remote archive.

    Args:
        token: Validated bearer token (authentication dependency).

    Returns:
        ``{"voices": [...]}`` with the usable voice names in sorted order.
    """
    download_voices()

    voices_dir = "/workspace/voices"
    suffix = ".pmt"

    valid_voices = []
    for entry in os.listdir(voices_dir):
        if not entry.endswith(suffix):
            continue
        # Strip only the trailing ".pmt": splitting on the first "." would
        # truncate names that themselves contain dots (e.g. "v1.2.pmt").
        voice = entry[: -len(suffix)]
        # Only report voices whose audio prompt is present as well.
        if voice and os.path.exists(f"{voices_dir}/{voice}.wav"):
            valid_voices.append(voice)

    # os.listdir order is arbitrary; sort for a stable response.
    valid_voices.sort()
    return {"voices": valid_voices}
# ↓↓↓↓↓↓↓↓↓ No authentication required below ↓↓↓↓↓↓↓↓
@app.get("/")
def health_check():
    """Unauthenticated endpoint returning a fixed status payload."""
    payload = {"status": "health"}
    return payload
if __name__ == "__main__":
    # uvicorn is imported only when the file is executed as a script.
    import uvicorn

    # Serve the app on all interfaces, port 7860, with a single worker.
    uvicorn.run(app, host="0.0.0.0", port=7860, workers=1)