zhendery commited on
Commit
c786b94
·
1 Parent(s): 7efcbef

initial commit

Browse files
Files changed (2) hide show
  1. Dockerfile +21 -0
  2. api.py +134 -0
Dockerfile ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ FROM pytorch/pytorch:2.7.0-cuda12.8-cudnn9-devel
2
+ # 设置时区
3
+ RUN ln -sf /share/zoneinfo/Asia/Shanghai /etc/localtime && \
4
+ echo "Asia/Shanghai" > /etc/timezone
5
+
6
+ RUN apt-get update && apt-get install -y \
7
+ curl wget unzip git git-lfs ffmpeg && \
8
+ apt-get clean && rm -rf /var/lib/apt/lists/*
9
+
10
+ RUN pip install voxcpm && pip cache purge
11
+
12
+ WORKDIR /workspace
13
+ COPY . .
14
+
15
+ ENV API_TOKEN my_secret_token
16
+ ENV VOICE_DOWNLOAD_URL http://localhost/voices.zip
17
+
18
+ VOLUME /workspace/voices
19
+ EXPOSE 7860
20
+
21
+ CMD ["python", "api.py"]
api.py ADDED
@@ -0,0 +1,134 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI, HTTPException, Response, Depends, File, Form
2
+ from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
3
+ import soundfile as sf
4
+ import numpy as np
5
+ from voxcpm import VoxCPM
6
+ from pydantic import BaseModel
7
+ import os
8
+ import requests
9
+ import zipfile
10
+
11
+ model = VoxCPM.from_pretrained("openbmb/VoxCPM-0.5B")
12
+
13
+ security = HTTPBearer()
14
+ app = FastAPI()
15
+
16
+ # 验证函数
17
+ def verify_token(credentials: HTTPAuthorizationCredentials = Depends(security)):
18
+ # 从环境变量获取token
19
+ expected_token = os.getenv("API_TOKEN", "my_secret_token")
20
+ if credentials.credentials != expected_token:
21
+ raise HTTPException(status_code=401, detail="Invalid or missing token")
22
+ return credentials.credentials
23
+
24
+ class GenerateRequest(BaseModel):
25
+ text: str
26
+ voice: str
27
+ cfg_value: float = 2.0
28
+ inference_timesteps: int = 10
29
+ do_normalize: bool = True
30
+ denoise: bool = True
31
+
32
+ @app.post("/generate")
33
+ def generate_tts(request: GenerateRequest, token: str = Depends(verify_token)):
34
+ text = (request.text or "").strip()
35
+ if len(text) == 0:
36
+ raise ValueError("Please input text to synthesize.")
37
+ print(f"Generating audio for text: '{text[:60]}...'")
38
+
39
+ with open(f"/workspace/voices/{request.voice}.pmt", 'r', encoding='utf-8') as f:
40
+ wav = model.generate(
41
+ text=text,
42
+ prompt_wav_path=f"/workspace/voices/{request.voice}.wav",
43
+ prompt_text=f.read(),
44
+ cfg_value=request.cfg_value,
45
+ inference_timesteps=request.inference_timesteps,
46
+ normalize=request.do_normalize,
47
+ denoise=request.denoise
48
+ )
49
+
50
+ sf.write("output.wav", wav, 16000)
51
+ return Response(content=open("output.wav", 'rb').read(), media_type="audio/wav")
52
+
53
+
54
+ def download_voices():
55
+ # 检查 /workspace/voices/ 目录中是否有 .pmt 文件
56
+ voices_dir = "/workspace/voices"
57
+ if not os.path.exists(voices_dir):
58
+ os.makedirs(voices_dir)
59
+
60
+ pmt_files = [f for f in os.listdir(voices_dir) if f.endswith(".pmt")]
61
+ if not pmt_files:
62
+ # 如果没有 .pmt 文件,尝试从远程下载
63
+ voice_download_url = os.getenv("VOICE_DOWNLOAD_URL")
64
+
65
+ if voice_download_url:
66
+ try:
67
+ response = requests.get(voice_download_url)
68
+ response.raise_for_status()
69
+
70
+ # 保存下载的zip文件
71
+ zip_path = f"{voices_dir}/voices.zip"
72
+ with open(zip_path, "wb") as f:
73
+ f.write(response.content)
74
+
75
+ # 解压zip文件
76
+ with zipfile.ZipFile(zip_path, 'r') as zip_ref:
77
+ zip_ref.extractall(voices_dir)
78
+
79
+ # 删除临时zip文件
80
+ os.remove(zip_path)
81
+
82
+ except Exception as e:
83
+ print(f"Failed to download and extract voices: {e}")
84
+ raise HTTPException(status_code=500, detail="Failed to download voice files")
85
+
86
+ @app.post("/upload_voice")
87
+ def upload_voice(name: str = Form(...), wav: bytes = File(...), prompt: str = Form(...), token: str = Depends(verify_token)):
88
+ # 保存wav文件
89
+ with open(f"/workspace/voices/{name}.wav", 'wb') as f:
90
+ f.write(wav)
91
+
92
+ # 保存pmt文件
93
+ with open(f"/workspace/voices/{name}.pmt", 'w', encoding='utf-8') as f:
94
+ f.write(prompt)
95
+
96
+ return {"status": "success"}
97
+
98
+ @app.delete("/delete_voice")
99
+ def delete_voice(name: str, token: str = Depends(verify_token)):
100
+ wav_file = f"/workspace/voices/{name}.wav"
101
+ pmt_file = f"/workspace/voices/{name}.pmt"
102
+
103
+ # 检查文件是否存在
104
+ if os.path.exists(wav_file):
105
+ os.remove(wav_file)
106
+ if os.path.exists(pmt_file):
107
+ os.remove(pmt_file)
108
+ return {"status": "success"}
109
+ else:
110
+ return {"status": "不存在"}
111
+
112
+ @app.get("/voices")
113
+ def get_voices(token: str = Depends(verify_token)):
114
+ download_voices()
115
+ # 获取所有 .pmt 文件
116
+ pmt_files = [f for f in os.listdir("/workspace/voices") if f.endswith(".pmt")]
117
+ # 提取文件名(去掉 .pmt 后缀)
118
+ voices = [f.split(".")[0] for f in pmt_files]
119
+ # 确保对应的 .wav 文件也存在
120
+ valid_voices = []
121
+ for voice in voices:
122
+ if os.path.exists(f"/workspace/voices/{voice}.wav"):
123
+ valid_voices.append(voice)
124
+ return {"voices": valid_voices}
125
+
126
+
127
+ # ↓↓↓↓↓↓↓↓↓无需验证↓↓↓↓↓↓↓↓
128
+ @app.get("/")
129
+ def health_check():
130
+ return {"status": "health"}
131
+
132
+ if __name__ == "__main__":
133
+ import uvicorn
134
+ uvicorn.run(app, host="0.0.0.0", port=7860, workers=1)