Spaces:
Running
Running
| import os | |
| import time | |
| import tempfile | |
| import subprocess | |
| import threading | |
| import json | |
| import base64 | |
| import io | |
| import random | |
| import logging | |
| from queue import Queue | |
| from threading import Thread | |
| import gradio as gr | |
| import torch | |
| import librosa | |
| import soundfile as sf | |
| import requests | |
| import numpy as np | |
| from scipy import signal | |
| from transformers import pipeline, AutoTokenizer, AutoModel | |
| # Thiết lập logging | |
| logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s') | |
| logger = logging.getLogger(__name__) | |
| # Tạo các thư mục cần thiết | |
| os.makedirs("data", exist_ok=True) | |
| os.makedirs("data/audio", exist_ok=True) | |
| os.makedirs("data/reports", exist_ok=True) | |
| os.makedirs("data/models", exist_ok=True) | |
| class AsyncProcessor: | |
| """Xử lý các tác vụ nặng trong thread riêng để không làm 'đơ' giao diện.""" | |
| def __init__(self): | |
| self.task_queue = Queue() | |
| self.result_queue = Queue() | |
| self.running = True | |
| self.worker_thread = Thread(target=self._worker) | |
| self.worker_thread.daemon = True | |
| self.worker_thread.start() | |
| def _worker(self): | |
| while self.running: | |
| if not self.task_queue.empty(): | |
| task_id, func, args, kwargs = self.task_queue.get() | |
| try: | |
| result = func(*args, **kwargs) | |
| self.result_queue.put((task_id, result, None)) | |
| except Exception as e: | |
| logger.error(f"Lỗi trong xử lý tác vụ {task_id}: {str(e)}") | |
| self.result_queue.put((task_id, None, str(e))) | |
| self.task_queue.task_done() | |
| time.sleep(0.1) | |
| def add_task(self, task_id, func, *args, **kwargs): | |
| self.task_queue.put((task_id, func, args, kwargs)) | |
| def get_result(self): | |
| if not self.result_queue.empty(): | |
| return self.result_queue.get() | |
| return None | |
| def stop(self): | |
| self.running = False | |
| if self.worker_thread.is_alive(): | |
| self.worker_thread.join(timeout=1) | |
| class VietSpeechTrainer: | |
| def __init__(self): | |
| # Đọc cấu hình từ file config.json và từ biến môi trường | |
| self.config = self._load_config() | |
| # Khởi tạo bộ xử lý bất đồng bộ | |
| self.async_processor = AsyncProcessor() | |
| # Lưu trữ lịch sử phiên làm việc | |
| self.session_history = [] | |
| self.current_session_id = int(time.time()) | |
| # Các biến trạng thái hội thoại | |
| self.current_scenario = None | |
| self.current_prompt_index = 0 | |
| # Khởi tạo các mô hình (STT, TTS và phân tích LLM) | |
| logger.info("Đang tải các mô hình...") | |
| self._initialize_models() | |
| def _load_config(self): | |
| """Đọc file config.json và cập nhật từ biến môi trường (Secrets khi deploy)""" | |
| config = { | |
| "stt_model": "nguyenvulebinh/wav2vec2-base-vietnamese-250h", | |
| "use_phowhisper": False, | |
| "use_phobert": False, | |
| "use_vncorenlp": False, | |
| "llm_provider": "none", # openai, gemini, local hoặc none | |
| "openai_api_key": "", | |
| "gemini_api_key": "", | |
| "local_llm_endpoint": "", | |
| "use_viettts": False, | |
| "default_dialect": "Bắc", | |
| "enable_pronunciation_eval": False, | |
| "preprocess_audio": True, | |
| "save_history": True, | |
| "enable_english_tts": False | |
| } | |
| if os.path.exists("config.json"): | |
| try: | |
| with open("config.json", "r", encoding="utf-8") as f: | |
| file_config = json.load(f) | |
| config.update(file_config) | |
| except Exception as e: | |
| logger.error(f"Lỗi đọc config.json: {e}") | |
| # Cập nhật từ biến môi trường | |
| if os.environ.get("LLM_PROVIDER"): | |
| config["llm_provider"] = os.environ.get("LLM_PROVIDER").lower() | |
| if os.environ.get("OPENAI_API_KEY"): | |
| config["openai_api_key"] = os.environ.get("OPENAI_API_KEY") | |
| if os.environ.get("GEMINI_API_KEY"): | |
| config["gemini_api_key"] = os.environ.get("GEMINI_API_KEY") | |
| if os.environ.get("LOCAL_LLM_ENDPOINT"): | |
| config["local_llm_endpoint"] = os.environ.get("LOCAL_LLM_ENDPOINT") | |
| if os.environ.get("ENABLE_ENGLISH_TTS") and os.environ.get("ENABLE_ENGLISH_TTS").lower() == "true": | |
| config["enable_english_tts"] = True | |
| return config | |
| def _initialize_models(self): | |
| """Khởi tạo mô hình STT và thiết lập CSM cho TTS tiếng Anh nếu được bật.""" | |
| try: | |
| # Khởi tạo STT | |
| if self.config["use_phowhisper"]: | |
| logger.info("Loading PhoWhisper...") | |
| self.stt_model = pipeline("automatic-speech-recognition", | |
| model="vinai/PhoWhisper-small", | |
| device=0 if torch.cuda.is_available() else -1) | |
| else: | |
| logger.info(f"Loading STT model: {self.config['stt_model']}") | |
| self.stt_model = pipeline("automatic-speech-recognition", | |
| model=self.config["stt_model"], | |
| device=0 if torch.cuda.is_available() else -1) | |
| except Exception as e: | |
| logger.error(f"Lỗi khởi tạo STT: {e}") | |
| self.stt_model = None | |
| # Các mô hình NLP (PhoBERT, VnCoreNLP) nếu cần. | |
| # ... | |
| # Nếu bật TTS tiếng Anh thì thiết lập CSM | |
| if self.config.get("enable_english_tts", False): | |
| self._setup_csm() | |
| else: | |
| self.csm_ready = False | |
| def _setup_csm(self): | |
| """Cài đặt mô hình CSM (Conversational Speech Generation Model) cho TTS tiếng Anh.""" | |
| try: | |
| csm_dir = os.path.join(os.getcwd(), "csm") | |
| if not os.path.exists(csm_dir): | |
| logger.info("Cloning CSM repo...") | |
| subprocess.run(["git", "clone", "https://github.com/SesameAILabs/csm", csm_dir], check=True) | |
| logger.info("Installing CSM requirements...") | |
| subprocess.run(["pip", "install", "-r", os.path.join(csm_dir, "requirements.txt")], check=True) | |
| self.csm_ready = True | |
| logger.info("CSM đã được thiết lập thành công!") | |
| except Exception as e: | |
| logger.error(f"Failed to set up CSM: {e}") | |
| self.csm_ready = False | |
| def text_to_speech(self, text, language="vi", dialect="Bắc"): | |
| """ | |
| Chuyển văn bản thành giọng nói: | |
| - Nếu language == "en": sử dụng CSM để tạo TTS tiếng Anh. | |
| - Nếu language == "vi": sử dụng API hoặc logic TTS tiếng Việt. | |
| """ | |
| if language == "en": | |
| if not self.csm_ready: | |
| logger.error("CSM chưa được thiết lập hoặc không được bật.") | |
| return None | |
| output_file = f"data/audio/csm_{int(time.time())}.wav" | |
| csm_script_path = os.path.join(os.getcwd(), "csm", "run_csm.py") | |
| cmd = [ | |
| "python", | |
| csm_script_path, | |
| "--text", text, | |
| "--speaker_id", "0", # Mặc định, có thể cho phép người dùng chọn | |
| "--output", output_file | |
| ] | |
| try: | |
| subprocess.run(cmd, check=True) | |
| return output_file | |
| except subprocess.CalledProcessError as e: | |
| logger.error(f"CSM generation failed: {e}") | |
| return None | |
| else: | |
| # Ví dụ: Nếu có API TTS tiếng Việt, gọi API đó. | |
| tts_api_url = self.config.get("tts_api_url", "") | |
| if tts_api_url: | |
| try: | |
| resp = requests.post(tts_api_url, json={"text": text, "dialect": dialect.lower()}) | |
| if resp.status_code == 200: | |
| output_file = f"data/audio/tts_{int(time.time())}.wav" | |
| with open(output_file, "wb") as f: | |
| f.write(resp.content) | |
| return output_file | |
| else: | |
| logger.error(f"Error calling TTS API: {resp.text}") | |
| return None | |
| except Exception as e: | |
| logger.error(f"Lỗi gọi TTS API: {e}") | |
| return None | |
| else: | |
| # Nếu không có API TTS, bạn có thể tích hợp VietTTS hoặc khác. | |
| return None | |
| def transcribe_audio(self, audio_path): | |
| """Chuyển đổi giọng nói thành văn bản (STT).""" | |
| if not self.stt_model: | |
| return "STT model not available." | |
| try: | |
| result = self.stt_model(audio_path) | |
| if isinstance(result, dict) and "text" in result: | |
| return result["text"] | |
| elif isinstance(result, list): | |
| return " ".join([chunk.get("text", "") for chunk in result]) | |
| else: | |
| return str(result) | |
| except Exception as e: | |
| logger.error(f"Lỗi chuyển giọng nói: {e}") | |
| return f"Lỗi: {str(e)}" | |
| def analyze_text(self, transcript, dialect="Bắc"): | |
| """ | |
| Phân tích văn bản sử dụng LLM: | |
| - Nếu LLM_PROVIDER là "openai", "gemini" hay "local" thì gọi API tương ứng. | |
| - Nếu LLM_PROVIDER là "none", sử dụng phân tích rule-based. | |
| """ | |
| llm_provider = self.config["llm_provider"] | |
| if llm_provider == "openai" and self.config["openai_api_key"]: | |
| return self._analyze_with_openai(transcript) | |
| elif llm_provider == "gemini" and self.config["gemini_api_key"]: | |
| return self._analyze_with_gemini(transcript) | |
| elif llm_provider == "local" and self.config["local_llm_endpoint"]: | |
| return self._analyze_with_local_llm(transcript) | |
| else: | |
| return self._rule_based_analysis(transcript, dialect) | |
| def _analyze_with_openai(self, transcript): | |
| headers = { | |
| "Authorization": f"Bearer {self.config['openai_api_key']}", | |
| "Content-Type": "application/json" | |
| } | |
| data = { | |
| "model": "gpt-3.5-turbo", | |
| "messages": [ | |
| {"role": "system", "content": "Bạn là trợ lý dạy tiếng Việt."}, | |
| {"role": "user", "content": transcript} | |
| ], | |
| "temperature": 0.5, | |
| "max_tokens": 150 | |
| } | |
| try: | |
| response = requests.post("https://api.openai.com/v1/chat/completions", headers=headers, json=data) | |
| if response.status_code == 200: | |
| result = response.json() | |
| return result["choices"][0]["message"]["content"] | |
| else: | |
| return "Lỗi khi gọi OpenAI API." | |
| except Exception as e: | |
| logger.error(f"Lỗi OpenAI: {e}") | |
| return "Lỗi phân tích với OpenAI." | |
| def _analyze_with_gemini(self, transcript): | |
| # Ví dụ minh họa: Gọi Gemini API (chi tiết phụ thuộc vào tài liệu của Gemini) | |
| return "Gemini analysis..." | |
| def _analyze_with_local_llm(self, transcript): | |
| # Giả sử gọi một endpoint local (nếu có) cho LLM cục bộ. | |
| headers = {"Content-Type": "application/json"} | |
| data = { | |
| "model": "local-model", | |
| "messages": [ | |
| {"role": "system", "content": "Bạn là trợ lý dạy tiếng Việt."}, | |
| {"role": "user", "content": transcript} | |
| ], | |
| "temperature": 0.5, | |
| "max_tokens": 150 | |
| } | |
| try: | |
| response = requests.post(self.config["local_llm_endpoint"] + "/chat/completions", headers=headers, json=data) | |
| if response.status_code == 200: | |
| result = response.json() | |
| return result["choices"][0]["message"]["content"] | |
| else: | |
| return "Lỗi khi gọi Local LLM." | |
| except Exception as e: | |
| logger.error(f"Lỗi local LLM: {e}") | |
| return "Lỗi phân tích với LLM local." | |
| def _rule_based_analysis(self, transcript, dialect): | |
| # Phân tích đơn giản không dùng LLM | |
| return "Phân tích rule-based: " + transcript | |
| def clean_up(self): | |
| self.async_processor.stop() | |
| if torch.cuda.is_available(): | |
| torch.cuda.empty_cache() | |
| logger.info("Clean up done.") | |
| def create_demo(): | |
| trainer = VietSpeechTrainer() | |
| with gr.Blocks(title="Ứng dụng Luyện Nói & TTS", theme=gr.themes.Soft(primary_hue="blue")) as demo: | |
| gr.Markdown("## Ứng dụng Luyện Nói & TTS (Tiếng Việt & Tiếng Anh)") | |
| with gr.Tabs(): | |
| # Tab 1: TTS Tiếng Việt | |
| with gr.Tab("TTS Tiếng Việt"): | |
| vi_text_input = gr.Textbox(label="Nhập văn bản tiếng Việt") | |
| vi_audio_output = gr.Audio(label="Kết quả âm thanh") | |
| gen_vi_btn = gr.Button("Chuyển thành giọng nói") | |
| def gen_vi_tts(txt): | |
| return trainer.text_to_speech(txt, language="vi", dialect=trainer.config["default_dialect"]) | |
| gen_vi_btn.click(fn=gen_vi_tts, inputs=vi_text_input, outputs=vi_audio_output) | |
| # Tab 2: TTS Tiếng Anh (sử dụng CSM) | |
| with gr.Tab("TTS Tiếng Anh"): | |
| en_text_input = gr.Textbox(label="Enter English text") | |
| en_audio_output = gr.Audio(label="Generated English Audio (CSM)") | |
| gen_en_btn = gr.Button("Generate English Speech") | |
| def gen_en_tts(txt): | |
| return trainer.text_to_speech(txt, language="en") | |
| gen_en_btn.click(fn=gen_en_tts, inputs=en_text_input, outputs=en_audio_output) | |
| # Tab 3: Luyện phát âm (Tiếng Việt) | |
| with gr.Tab("Luyện phát âm"): | |
| audio_input = gr.Audio(source="microphone", type="filepath", label="Giọng nói của bạn") | |
| transcript_output = gr.Textbox(label="Transcript") | |
| analysis_output = gr.Markdown(label="Phân tích") | |
| analyze_btn = gr.Button("Phân tích") | |
| def process_audio(audio_path): | |
| transcript = trainer.transcribe_audio(audio_path) | |
| analysis = trainer.analyze_text(transcript, dialect=trainer.config["default_dialect"]) | |
| return transcript, analysis | |
| analyze_btn.click(fn=process_audio, inputs=audio_input, outputs=[transcript_output, analysis_output]) | |
| # Tab 4: Thông tin & Hướng dẫn | |
| with gr.Tab("Thông tin"): | |
| gr.Markdown(""" | |
| ### Hướng dẫn sử dụng: | |
| - **TTS Tiếng Việt:** Nhập văn bản tiếng Việt và nhấn "Chuyển thành giọng nói". | |
| - **TTS Tiếng Anh (CSM):** Nhập English text và nhấn "Generate English Speech". | |
| - **Luyện phát âm:** Thu âm giọng nói, sau đó nhấn "Phân tích" để xem transcript và phân tích. | |
| ### Cấu hình LLM: | |
| - **OpenAI:** Đặt biến môi trường `LLM_PROVIDER=openai` và `OPENAI_API_KEY` với key của bạn. | |
| - **Gemini:** Đặt `LLM_PROVIDER=gemini` và `GEMINI_API_KEY`. | |
| - **Local LLM:** Đặt `LLM_PROVIDER=local` và `LOCAL_LLM_ENDPOINT` với URL của server LLM nếu bạn có. | |
| - **None:** Đặt `LLM_PROVIDER=none` để sử dụng phân tích rule-based. | |
| ### Lưu ý: | |
| - Để sử dụng TTS tiếng Anh (CSM), hãy bật biến `ENABLE_ENGLISH_TTS` (hoặc đặt `"enable_english_tts": true` trong config.json). | |
| """) | |
| return demo | |
| def main(): | |
| demo = create_demo() | |
| # Sử dụng hàng đợi Gradio để xử lý tác vụ dài (ví dụ TTS CSM) | |
| demo.queue() | |
| demo.launch(server_name="0.0.0.0", server_port=7860) | |
| if __name__ == "__main__": | |
| main() | |