Update app.py
app.py CHANGED

@@ -8,11 +8,11 @@ from cnocr import CnOcr
 import numpy as np
 import openai
 from llama_index import GPTVectorStoreIndex, SimpleDirectoryReader, Prompt
-from transformers import pipeline
+from transformers import pipeline, BarkModel, BarkProcessor
 import opencc
 import scipy
 import torch
-import
+import hashlib
 
 converter = opencc.OpenCC('t2s')  # create an OpenCC instance for Traditional-to-Simplified Chinese conversion
 ocr = CnOcr()  # initialize the OCR model
@@ -21,8 +21,9 @@ all_max_len = 2000  # maximum input length
 asr_model_id = "openai/whisper-tiny"  # update this to your model ID
 device = "cuda:0" if torch.cuda.is_available() else "cpu"
 asr_pipe = pipeline("automatic-speech-recognition", model=asr_model_id, device=device)
-
-
+bark_model = BarkModel.from_pretrained("suno/bark-small")
+bark_processor = BarkProcessor.from_pretrained("suno/bark-small")
+sampling_rate = bark_model.generation_config.sample_rate
 
 def get_text_emb(open_ai_key, text):  # embed the text
     openai.api_key = open_ai_key  # set the OpenAI key
@@ -145,14 +146,16 @@ def get_response_by_llama_index(open_ai_key, msg, bot, query_engine):  # get
     return bot[max(0, len(bot) - 3):]  # return the last 3 rounds of history
 
 
-
-
-
-
-
-
-
-
+def get_audio_answer(bot):  # generate a spoken version of the latest answer
+    answer = bot[-1][1]
+    inputs = bark_processor(
+        text=[answer],
+        return_tensors="pt",
+    )
+    speech_values = bark_model.generate(**inputs, do_sample=True)
+    au_dir = hashlib.md5(answer.encode('utf-8')).hexdigest() + '.wav'  # md5 of the answer as the file name
+    scipy.io.wavfile.write(au_dir, rate=sampling_rate, data=speech_values.cpu().numpy().squeeze())
+    return gr.Audio().update(au_dir, autoplay=True)
 
 
 def get_response(open_ai_key, msg, bot, doc_text_list, doc_embeddings, query_engine, index_type):  # get the bot's reply
@@ -160,8 +163,7 @@ def get_response(open_ai_key, msg, bot, doc_text_list, doc_embeddings, query_eng
         bot = get_response_by_self(open_ai_key, msg, bot, doc_text_list, doc_embeddings)
     else:  # if using the llama_index index
         bot = get_response_by_llama_index(open_ai_key, msg, bot, query_engine)
-
-    return bot, gr.Audio(audio_answer_dir)
+    return bot
 
 
 def up_file(files):  # upload files
@@ -268,7 +270,7 @@ with gr.Blocks() as demo:
     audio_inputs.change(transcribe_speech, [open_ai_key, audio_inputs, asr_type], [msg_txt])  # voice input
     chat_bu.click(get_response,
                   [open_ai_key, msg_txt, chat_bot, doc_text_state, doc_emb_state, query_engine, index_type],
-                  [chat_bot, audio_answer])  # send the message
+                  [chat_bot])  # .then(get_audio_answer, [chat_bot], [audio_answer])  # send the message
 
 if __name__ == "__main__":
     demo.queue(concurrency_count=4).launch()
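
For context, the Bark pieces of this commit can be exercised outside the Gradio app. The sketch below mirrors the module-level setup and the body of get_audio_answer, but wraps them in a hypothetical synthesize() helper that returns the path of the generated .wav file instead of a Gradio component; it assumes a transformers release that ships BarkModel/BarkProcessor (4.31+) and the suno/bark-small checkpoint used in the commit.

import hashlib

import scipy.io.wavfile
import torch
from transformers import BarkModel, BarkProcessor

# Load the small Bark checkpoint and its processor once at start-up,
# mirroring the module-level setup added in app.py.
bark_model = BarkModel.from_pretrained("suno/bark-small")
bark_processor = BarkProcessor.from_pretrained("suno/bark-small")
sampling_rate = bark_model.generation_config.sample_rate


def synthesize(answer: str) -> str:
    """Hypothetical helper (not in app.py): speak `answer` and return the .wav path."""
    inputs = bark_processor(text=[answer], return_tensors="pt")
    with torch.no_grad():
        speech_values = bark_model.generate(**inputs, do_sample=True)
    # Name the file after the MD5 of the text so identical answers map to the same file.
    wav_path = hashlib.md5(answer.encode("utf-8")).hexdigest() + ".wav"
    scipy.io.wavfile.write(wav_path, rate=sampling_rate, data=speech_values.cpu().numpy().squeeze())
    return wav_path


if __name__ == "__main__":
    print(synthesize("Hello, this is a test of the spoken answer."))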
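The updated click handler keeps the audio step disabled (the .then(...) call is commented out). If the spoken reply is wanted again, one plausible wiring is to chain get_audio_answer after the chat update, as sketched below; this assumes Gradio 3.x event chaining with the components already defined in app.py and is not part of the committed change.

# Sketch only: re-enable the spoken answer by chaining the Bark step after the chat update.
# Assumes a Gradio 3.x Blocks app with the components used in app.py (chat_bu, chat_bot, audio_answer, ...).
chat_bu.click(
    get_response,
    [open_ai_key, msg_txt, chat_bot, doc_text_state, doc_emb_state, query_engine, index_type],
    [chat_bot],
).then(
    get_audio_answer,  # runs only after the chatbot has been updated
    [chat_bot],
    [audio_answer],
)

Chaining keeps speech synthesis off the critical path of the text reply: the chatbot component updates first, and Bark generation starts only afterwards.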