|
|
import io
import os
import secrets
import zipfile

import numpy as np
import requests
import soundfile as sf
from fastapi import Depends, FastAPI, File, Form, HTTPException, Response
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from pydantic import BaseModel
from voxcpm import VoxCPM
|
|
|
|
|
# Load the pretrained VoxCPM checkpoint once, at import time, so that every
# request handler shares the same model instance.
model = VoxCPM.from_pretrained("openbmb/VoxCPM-0.5B")

# Bearer-credential extractor; verify_token receives its output via Depends().
security = HTTPBearer()

# Application object that the route decorators in this module register on.
app = FastAPI()
|
|
|
|
|
|
|
|
def verify_token(credentials: HTTPAuthorizationCredentials = Depends(security)):
    """Validate the bearer token attached to the incoming request.

    The expected token is read from the ``API_TOKEN`` environment variable on
    every call.

    Args:
        credentials: Bearer credentials extracted by the ``security`` scheme.

    Returns:
        The supplied token string when it matches the expected one.

    Raises:
        HTTPException: 401 when the supplied token does not match.
    """
    # NOTE(review): the fallback below is a well-known default; deployments
    # should always set API_TOKEN. It is kept only for backward compatibility.
    expected_token = os.getenv("API_TOKEN", "my_secret_token")

    # Compare in constant time so response latency does not leak how many
    # leading characters of a guess were correct. Both sides are encoded to
    # bytes because compare_digest rejects non-ASCII str arguments.
    supplied = credentials.credentials.encode("utf-8")
    if not secrets.compare_digest(supplied, expected_token.encode("utf-8")):
        raise HTTPException(status_code=401, detail="Invalid or missing token")

    return credentials.credentials
|
|
|
|
|
class GenerateRequest(BaseModel):
    """JSON body accepted by the ``/generate`` endpoint."""

    # Text to synthesize; the handler strips it and rejects it when empty.
    text: str

    # Base name of the ``<voice>.wav`` / ``<voice>.pmt`` pair in /workspace/voices.
    voice: str

    # Forwarded to ``model.generate`` as ``cfg_value``.
    cfg_value: float = 2.0

    # Forwarded to ``model.generate`` as ``inference_timesteps``.
    inference_timesteps: int = 10

    # Forwarded to ``model.generate`` as ``normalize``.
    do_normalize: bool = True

    # Forwarded to ``model.generate`` as ``denoise``.
    denoise: bool = True
|
|
|
|
|
@app.post("/generate")
def generate_tts(request: GenerateRequest, token: str = Depends(verify_token)):
    """Synthesize speech for ``request.text`` using a stored reference voice.

    The voice is a ``<voice>.wav`` reference recording plus a ``<voice>.pmt``
    transcript, both located in ``/workspace/voices``.

    Args:
        request: Synthesis parameters (text, voice name, and model options).
        token: Bearer token, already validated by ``verify_token``.

    Returns:
        A ``Response`` whose body is the generated WAV audio.

    Raises:
        HTTPException: 400 when the text is empty or the voice name is not a
            plain file name; 404 when the voice files do not exist.
    """
    text = (request.text or "").strip()
    if not text:
        # A client error, not a server fault: report 400 instead of letting a
        # ValueError bubble up as a 500.
        raise HTTPException(status_code=400, detail="Please input text to synthesize.")

    # The voice name is interpolated into a filesystem path, so refuse
    # anything that could escape the voices directory.
    voice = request.voice
    if not voice or voice in {".", ".."} or any(ch in voice for ch in "/\\\x00"):
        raise HTTPException(status_code=400, detail="Invalid voice name")

    prompt_wav_path = f"/workspace/voices/{voice}.wav"
    prompt_text_path = f"/workspace/voices/{voice}.pmt"

    # Read the transcript up front so the file is not held open for the
    # duration of inference.
    try:
        with open(prompt_text_path, "r", encoding="utf-8") as f:
            prompt_text = f.read()
    except FileNotFoundError:
        raise HTTPException(status_code=404, detail="Voice not found") from None
    if not os.path.exists(prompt_wav_path):
        raise HTTPException(status_code=404, detail="Voice not found")

    print(f"Generating audio for text: '{text[:60]}...'")

    wav = model.generate(
        text=text,
        prompt_wav_path=prompt_wav_path,
        prompt_text=prompt_text,
        cfg_value=request.cfg_value,
        inference_timesteps=request.inference_timesteps,
        normalize=request.do_normalize,
        denoise=request.denoise,
    )

    # Encode into memory rather than a shared "output.wav" on disk: handlers
    # run concurrently in a thread pool, and a common file name would let one
    # request overwrite or read another request's audio.
    buffer = io.BytesIO()
    sf.write(buffer, wav, 16000, format="WAV")
    return Response(content=buffer.getvalue(), media_type="audio/wav")
|
|
|
|
|
|
|
|
def download_voices(voices_dir="/workspace/voices"):
    """Populate the voices directory from ``VOICE_DOWNLOAD_URL`` when it is empty.

    The directory is created if missing. If it already contains at least one
    ``.pmt`` file, or the ``VOICE_DOWNLOAD_URL`` environment variable is
    unset/empty, nothing is downloaded. Otherwise the URL is fetched as a zip
    archive and extracted into the directory.

    Args:
        voices_dir: Directory that holds the voice files.

    Raises:
        HTTPException: 500 when the download or extraction fails.
    """
    # exist_ok avoids the check-then-create race between concurrent requests.
    os.makedirs(voices_dir, exist_ok=True)

    if any(entry.endswith(".pmt") for entry in os.listdir(voices_dir)):
        return

    voice_download_url = os.getenv("VOICE_DOWNLOAD_URL")
    if not voice_download_url:
        return

    try:
        # (connect, read) timeouts: without them a stalled server would block
        # the calling request forever.
        response = requests.get(voice_download_url, timeout=(10, 300))
        response.raise_for_status()

        # Extract straight from memory so no partial voices.zip is left on
        # disk when extraction fails midway.
        with zipfile.ZipFile(io.BytesIO(response.content)) as archive:
            archive.extractall(voices_dir)
    except Exception as e:
        print(f"Failed to download and extract voices: {e}")
        raise HTTPException(status_code=500, detail="Failed to download voice files") from e
|
|
|
|
|
@app.post("/upload_voice")
def upload_voice(name: str = Form(...), wav: bytes = File(...), prompt: str = Form(...), token: str = Depends(verify_token)):
    """Store a new reference voice as ``<name>.wav`` plus ``<name>.pmt``.

    Args:
        name: Voice name used as the base file name.
        wav: Raw bytes of the reference recording.
        prompt: Transcript text saved alongside the recording.
        token: Bearer token, already validated by ``verify_token``.

    Returns:
        ``{"status": "success"}`` once both files are written.

    Raises:
        HTTPException: 400 when ``name`` is not a plain file name.
    """
    # The name becomes part of a filesystem path; reject anything that could
    # write outside the voices directory.
    if not name or name in {".", ".."} or any(ch in name for ch in "/\\\x00"):
        raise HTTPException(status_code=400, detail="Invalid voice name")

    # The directory may not exist yet on a fresh deployment.
    os.makedirs("/workspace/voices", exist_ok=True)

    with open(f"/workspace/voices/{name}.wav", "wb") as f:
        f.write(wav)

    with open(f"/workspace/voices/{name}.pmt", "w", encoding="utf-8") as f:
        f.write(prompt)

    return {"status": "success"}
|
|
|
|
|
@app.delete("/delete_voice")
def delete_voice(name: str, token: str = Depends(verify_token)):
    """Delete the ``<name>.wav`` / ``<name>.pmt`` files of a stored voice.

    Whichever of the two files exists is removed, so a half-uploaded voice
    does not leave an orphan behind.

    Args:
        name: Voice name used as the base file name.
        token: Bearer token, already validated by ``verify_token``.

    Returns:
        ``{"status": "success"}`` when at least one file was removed,
        otherwise ``{"status": "不存在"}``.

    Raises:
        HTTPException: 400 when ``name`` is not a plain file name.
    """
    # The name becomes part of a filesystem path; reject anything that could
    # delete files outside the voices directory.
    if not name or name in {".", ".."} or any(ch in name for ch in "/\\\x00"):
        raise HTTPException(status_code=400, detail="Invalid voice name")

    removed = False
    for suffix in (".wav", ".pmt"):
        # Attempt the removal directly instead of exists()+remove(), which
        # can race with a concurrent delete of the same voice.
        try:
            os.remove(f"/workspace/voices/{name}{suffix}")
        except FileNotFoundError:
            continue
        removed = True

    return {"status": "success" if removed else "不存在"}
|
|
|
|
|
@app.get("/voices")
def get_voices(token: str = Depends(verify_token)):
    """List the voices that have both a transcript and a reference recording.

    ``download_voices()`` is called first so an empty directory can be
    populated before listing.

    Args:
        token: Bearer token, already validated by ``verify_token``.

    Returns:
        ``{"voices": [...]}`` with the sorted names of every voice for which
        both ``<name>.pmt`` and ``<name>.wav`` exist.
    """
    download_voices()

    suffix = ".pmt"
    # Strip only the trailing ".pmt": splitting on the first dot would turn
    # "my.voice.pmt" into "my" and the voice would never be listed.
    names = (
        entry[: -len(suffix)]
        for entry in os.listdir("/workspace/voices")
        if entry.endswith(suffix)
    )

    # Sorted so the response order does not depend on directory order.
    valid_voices = sorted(
        name
        for name in names
        if name and os.path.exists(f"/workspace/voices/{name}.wav")
    )
    return {"voices": valid_voices}
|
|
|
|
|
|
|
|
|
|
|
@app.get("/")
def health_check():
    """Liveness probe: returns a fixed status payload and requires no token."""
    payload = {"status": "health"}
    return payload
|
|
|
|
|
if __name__ == "__main__":
    # Imported here so uvicorn is only needed when the module is run as a script.
    import uvicorn

    # Listen on all interfaces, port 7860, with a single worker process.
    server_options = {"host": "0.0.0.0", "port": 7860, "workers": 1}
    uvicorn.run(app, **server_options)
|
|
|