Quang Long committed
Commit a3fd3c7 · 1 Parent(s): e578b02

update progress, save cache audio

Files changed (3)
  1. .gitignore +1 -0
  2. app.py +25 -5
  3. app_tts.py +24 -6
.gitignore CHANGED
@@ -160,3 +160,4 @@ checkpoints/
 gradio_cached_examples/
 gfpgan/
 start.sh
+tts_cache/
app.py CHANGED
@@ -76,11 +76,26 @@ def generate_voice_and_video(
     length_of_audio,
     blink_every,
 ):
+    import gradio as gr
+    # Start: show that audio generation is in progress
+    yield (
+        gr.update(value=None, visible=True, interactive=False),
+        gr.update(value=None, visible=True, interactive=False),
+        gr.update(value="⏳ Đang tạo âm thanh...", visible=True)
+    )
+
     # 1. Generate the audio with TTS
     (final_sample_rate, final_wave), _ = infer_tts(ref_audio, ref_text, gen_text, speed)
-    # Save to a temporary file
     tmp_audio = tempfile.NamedTemporaryFile(suffix=".wav", delete=False)
+    import soundfile as sf
     sf.write(tmp_audio.name, final_wave, final_sample_rate)
+    # Audio done, move on to video generation
+    yield (
+        gr.update(value=tmp_audio.name, visible=True, interactive=True),
+        gr.update(value=None, visible=True, interactive=False),
+        gr.update(value="⏳ Đang tạo video...", visible=True)
+    )
+
     # 2. Call SadTalker with the freshly generated audio
     sad_talker = SadTalker(lazy_load=True)
     video_path = sad_talker.test(
@@ -101,14 +116,18 @@ def generate_voice_and_video(
         length_of_audio,
         blink_every,
     )
-    return tmp_audio.name, video_path
+    # Both audio and video are finished
+    yield (
+        gr.update(value=tmp_audio.name, visible=True, interactive=True),
+        gr.update(value=video_path, visible=True, interactive=True),
+        gr.update(value="✅ Hoàn thành!", visible=True)
+    )
 
 
 def sadtalker_demo():
     download_model()
     with gr.Blocks(
         analytics_enabled=False,
-        css="src/assets/css/atalink_theme.css",
     ) as sadtalker_interface:
         gr.Markdown(
             f"""
@@ -207,8 +226,9 @@ def sadtalker_demo():
         with gr.Row(elem_classes="gr-row"):
            output_audio = gr.Audio(label="🎧 Audio đã tạo", type="filepath")
            gen_video = gr.Video(
-               label="Video đã tạo", format="mp4", scale=1, height=180, width=180
+               label="Video đã tạo", format="mp4", scale=1, width=180
            )
+           status_box = gr.Textbox(label="Trạng thái tiến trình", interactive=False, value="", visible=True)
 
        def enable_generate(audio, text, image):
            return gr.update(interactive=bool(audio and text and image))
@@ -246,7 +266,7 @@ def sadtalker_demo():
               length_of_audio,
               blink_every,
            ],
-           outputs=[output_audio, gen_video],
+           outputs=[output_audio, gen_video, status_box],
        )
        with gr.Tab("Lịch sử video"):
            with gr.Row(elem_classes="gr-row"):
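
Note: the app.py change turns generate_voice_and_video into a generator. Gradio treats a handler that yields as a streaming event, and each yielded tuple updates the wired outputs (output_audio, gen_video, status_box) in order. Below is a minimal, self-contained sketch of that pattern under assumed names (long_job, demo, inp, out, status) that are not part of this commit:

# Minimal sketch (assumed names, not from the commit): a Gradio generator handler
# that yields one gr.update per output on every progress step, the same pattern
# the new generate_voice_and_video uses for its three outputs.
import time
import gradio as gr

def long_job(text):
    # Step 1: clear the result and announce the first stage
    yield gr.update(value=None), gr.update(value="⏳ working...")
    time.sleep(1)  # placeholder for the TTS / SadTalker work
    # Step 2: publish the result and the final status
    yield gr.update(value=text.upper()), gr.update(value="✅ done")

with gr.Blocks() as demo:
    inp = gr.Textbox(label="input")
    out = gr.Textbox(label="output")
    status = gr.Textbox(label="status", interactive=False)
    gr.Button("Run").click(long_job, inputs=inp, outputs=[out, status])

if __name__ == "__main__":
    demo.launch()

Because the .click() call in this commit registers three outputs, every yield in generate_voice_and_video must supply exactly three gr.update values, which is why each progress step updates the audio, the video, and the status text together.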
app_tts.py CHANGED
@@ -8,8 +8,6 @@ from cached_path import cached_path
 import tempfile
 from vinorm import TTSnorm
 from importlib.resources import files
-# import sys
-# sys.path.append(os.path.join(os.path.dirname(__file__), 'src'))
 from f5_tts.model import DiT
 from f5_tts.infer.utils_infer import (
     preprocess_ref_audio_text,
@@ -35,6 +33,7 @@ from f5_tts.infer.utils_infer import (
 from pathlib import Path
 from omegaconf import OmegaConf
 from datetime import datetime
+import hashlib
 # Retrieve token from secrets
 hf_token = os.getenv("HUGGINGFACEHUB_API_TOKEN")
 
@@ -43,6 +42,13 @@ hf_token = os.getenv("HUGGINGFACEHUB_API_TOKEN")
 if hf_token:
     login(token=hf_token)
 
+# Build the cache file path from the generation text, reference audio, and model
+def get_audio_cache_path(text, ref_audio_path, model, cache_dir="tts_cache"):
+    os.makedirs(cache_dir, exist_ok=True)
+    hash_input = f"{text}|{ref_audio_path}|{model}"
+    hash_val = hashlib.sha256(hash_input.encode("utf-8")).hexdigest()
+    return os.path.join(cache_dir, f"{hash_val}.wav")
+
 def post_process(text):
     text = " " + text + " "
     text = text.replace(" . . ", " . ")
@@ -168,12 +174,24 @@ def infer_tts(ref_audio_orig: str, ref_text_input: str, gen_text: str, speed: fl
         # If the user provided ref_text, use it; otherwise leave it empty for automatic detection
         ref_audio, ref_text = preprocess_ref_audio_text(ref_audio_orig, ref_text_input or "")
         gen_text_ = gen_text.strip()
-        final_wave, final_sample_rate, spectrogram = infer_process(
-            ref_audio, ref_text.lower(), gen_text_, ema_model, vocoder, speed=speed
-        )
+        # --- BEGIN: audio cache logic ---
+        cache_path = get_audio_cache_path(gen_text_, ref_audio_orig, model)
+        import soundfile as sf
+        if os.path.exists(cache_path):
+            print(f"Using cached audio: {cache_path}")
+            final_wave, final_sample_rate = sf.read(cache_path)
+            spectrogram = None
+        else:
+            final_wave, final_sample_rate, spectrogram = infer_process(
+                ref_audio, ref_text.lower(), gen_text_, ema_model, vocoder, speed=speed
+            )
+            print(f"[CACHE] Saved new audio to: {cache_path}")
+            sf.write(cache_path, final_wave, final_sample_rate)
+        # --- END: audio cache logic ---
         with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp_spectrogram:
             spectrogram_path = tmp_spectrogram.name
-            save_spectrogram(spectrogram, spectrogram_path)
+            if spectrogram is not None:
+                save_spectrogram(spectrogram, spectrogram_path)
         return (final_sample_rate, final_wave), spectrogram_path
     except Exception as e:
         raise gr.Error(f"Error generating voice: {e}")
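
Note: the app_tts.py change caches synthesized audio under tts_cache/ (now git-ignored), keyed by a SHA-256 hash of the generation text, the reference-audio path, and the model name; a repeat request reads the .wav back with soundfile instead of rerunning infer_process. Below is a minimal stand-alone sketch of that round trip, where synthesize and its test tone are illustrative stand-ins and not the commit's actual inference call:

# Minimal sketch (assumed stand-ins, not from the commit) of the cache round trip
# that get_audio_cache_path enables: hash the request, reuse the .wav next time.
import hashlib
import os

import numpy as np
import soundfile as sf

def get_audio_cache_path(text, ref_audio_path, model, cache_dir="tts_cache"):
    os.makedirs(cache_dir, exist_ok=True)
    hash_input = f"{text}|{ref_audio_path}|{model}"
    return os.path.join(cache_dir, hashlib.sha256(hash_input.encode("utf-8")).hexdigest() + ".wav")

def synthesize(text, ref_audio_path, model="F5-TTS"):
    cache_path = get_audio_cache_path(text, ref_audio_path, model)
    if os.path.exists(cache_path):
        wave, sample_rate = sf.read(cache_path)  # cache hit: skip inference
        return sample_rate, wave
    # Cache miss: stand-in for infer_process(); here just a 1-second 440 Hz tone
    sample_rate = 24000
    wave = 0.1 * np.sin(2 * np.pi * 440 * np.arange(sample_rate) / sample_rate)
    sf.write(cache_path, wave, sample_rate)
    return sample_rate, wave

if __name__ == "__main__":
    synthesize("xin chào", "ref.wav")  # first call writes tts_cache/<sha256>.wav
    synthesize("xin chào", "ref.wav")  # second call reads it back

One consequence of a cache hit is that no spectrogram is produced, which is why the diff also guards save_spectrogram with if spectrogram is not None.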