tianfengping.tfp committed on
Commit
a723d72
·
1 Parent(s): 1bd43c9

load model after init

Browse files
Files changed (1) hide show
  1. app.py +54 -42
app.py CHANGED
@@ -36,27 +36,41 @@ os.system('export PYTHONPATH=third_party/Matcha-TTS')
36
 
37
  from huggingface_hub import hf_hub_download
38
 
39
- # Download assets and logos first (these are small files)
40
- try:
41
- assets_dir = snapshot_download(
42
- repo_id="tienfeng/prompt",
43
- repo_type="dataset",
44
- )
45
- logo_path = hf_hub_download(
46
- repo_id="tienfeng/prompt",
47
- filename="logo2.png",
48
- repo_type="dataset",
49
- )
50
- logo_path2 = hf_hub_download(
51
- repo_id="tienfeng/prompt",
52
- filename="logo.png",
53
- repo_type="dataset",
54
- )
55
- except Exception as e:
56
- print(f"Warning: Failed to download assets/logos: {e}")
57
- assets_dir = None
58
- logo_path = None
59
- logo_path2 = None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
60
 
61
  # Delay model download to avoid blocking startup
62
  model_repo_id = "AIDC-AI/Marco-Voice"
@@ -157,16 +171,20 @@ os.makedirs("./tmp", exist_ok=True)
157
  def generate_speech_speakerminus(tts_text, speed, speaker, key, ref_audio, ref_text):
158
  # import pdb;pdb.set_trace()
159
  global tts_speakerminus_global, local_model_path
160
- # Ensure models are downloaded
161
  if local_model_path is None:
 
162
  load_models()
163
  if 'tts_speakerminus_global' not in globals() or tts_speakerminus_global is None:
164
  print("Loading CosyVoice (speakerminus) model...")
165
  tts_speakerminus_global = CosyVoiceTTS_speakerminus(model_dir=local_model_path)
166
 
167
  if not ref_audio and not ref_text:
168
- if audio_prompt_path is None:
169
- raise ValueError("Audio prompt path is not available. Please provide reference audio and text.")
 
 
 
170
  ref_text = text_prompt.get(speaker, "")
171
  speaker_audio_name = audio_prompt.get(speaker)
172
  if speaker_audio_name:
@@ -241,15 +259,19 @@ def generate_speech_speakerminus(tts_text, speed, speaker, key, ref_audio, ref_t
241
  def generate_speech_sft(tts_text, speed, speaker, key, ref_audio, ref_text):
242
  # import pdb;pdb.set_trace()
243
  global tts_sft_global, local_model_path_enhenced
244
- # Ensure models are downloaded
245
  if local_model_path_enhenced is None:
 
246
  load_models()
247
  if 'tts_sft_global' not in globals() or tts_sft_global is None:
248
  print("Loading CosyVoice (SFT enhanced) model...")
249
  tts_sft_global = CosyVoiceTTS_speakerminus(model_dir=local_model_path_enhenced)
250
  if not ref_audio and not ref_text:
251
- if audio_prompt_path is None:
252
- raise ValueError("Audio prompt path is not available. Please provide reference audio and text.")
 
 
 
253
  ref_text = text_prompt.get(speaker, "")
254
  speaker_audio_name = audio_prompt.get(speaker)
255
  if speaker_audio_name:
@@ -638,7 +660,9 @@ input[type="text"]:focus, textarea:focus {
638
  with gr.Blocks(css=custom_css, theme=gr.themes.Soft()) as demo:
639
  with gr.Column(elem_classes="header"):
640
  with gr.Row(elem_id="header-row", variant="compact"):
641
- gr.Image(value=logo_path,
 
 
642
  elem_id="logo-container",
643
  show_label=False,
644
  show_download_button=False,
@@ -823,20 +847,8 @@ with gr.Blocks(css=custom_css, theme=gr.themes.Soft()) as demo:
823
  outputs=tts_v2_output
824
  )
825
 
826
- def preload_models():
827
- """Pre-download models to cache (non-blocking for launch)"""
828
- import threading
829
- def _download():
830
- try:
831
- print("Pre-downloading models to cache...")
832
- load_models()
833
- print("Model pre-download completed.")
834
- except Exception as e:
835
- print(f"Warning: Model pre-download failed: {e}. Models will be loaded on first use.")
836
- threading.Thread(target=_download, daemon=True).start()
837
-
838
- # Start preloading models in background (non-blocking)
839
- preload_models()
840
 
841
  if __name__ == "__main__":
842
  # Use environment variable for port (Hugging Face Spaces uses 7860 by default)
 
36
 
37
  from huggingface_hub import hf_hub_download
38
 
39
# Download assets and logos in the background to avoid blocking startup.
#
# The three globals below stay None until a download attempt has fully
# succeeded. ``assets_dir`` is published last, so "assets_dir is not None"
# implies that both logo paths have been set as well.
import threading

assets_dir = None
logo_path = None
logo_path2 = None

# Serializes load_assets() so the startup thread and an on-demand call from a
# request handler never run the same downloads concurrently.
_assets_lock = threading.Lock()


def load_assets():
    """Download the prompt assets and both logos, at most once per success.

    Fetches a snapshot of the ``tienfeng/prompt`` dataset repo plus the files
    ``logo2.png`` and ``logo.png`` from it, and stores the returned values in
    the module globals ``assets_dir``, ``logo_path`` and ``logo_path2``.

    The call is best-effort: any exception is printed as a warning and the
    globals are left as None, so a later call will retry. Concurrent callers
    block on an internal lock until the in-flight attempt finishes and then
    return immediately if it succeeded.
    """
    global assets_dir, logo_path, logo_path2
    with _assets_lock:
        # Another caller may have finished the download while we waited.
        if assets_dir is not None:
            return
        try:
            print("Downloading assets and logos...")
            # NOTE(review): only the import of hf_hub_download is visible in
            # this part of the file; snapshot_download is assumed to be
            # imported elsewhere. If it is not, the resulting NameError is
            # swallowed by the handler below -- confirm.
            snapshot_dir = snapshot_download(
                repo_id="tienfeng/prompt",
                repo_type="dataset",
            )
            primary_logo = hf_hub_download(
                repo_id="tienfeng/prompt",
                filename="logo2.png",
                repo_type="dataset",
            )
            secondary_logo = hf_hub_download(
                repo_id="tienfeng/prompt",
                filename="logo.png",
                repo_type="dataset",
            )
        except Exception as e:
            # Deliberately broad: a failed download must never take the app
            # down; callers check the globals for None instead.
            print(f"Warning: Failed to download assets/logos: {e}")
            assets_dir = None
            logo_path = None
            logo_path2 = None
            return
        # Publish the logos first and assets_dir last, so readers that only
        # test ``assets_dir is None`` never observe a half-initialized state.
        logo_path = primary_logo
        logo_path2 = secondary_logo
        assets_dir = snapshot_dir
        print("Assets downloaded successfully")


# Start downloading assets in the background (non-blocking). The thread is a
# daemon so it never keeps the process alive on shutdown.
threading.Thread(target=load_assets, daemon=True).start()
74
 
75
  # Delay model download to avoid blocking startup
76
  model_repo_id = "AIDC-AI/Marco-Voice"
 
171
  def generate_speech_speakerminus(tts_text, speed, speaker, key, ref_audio, ref_text):
172
  # import pdb;pdb.set_trace()
173
  global tts_speakerminus_global, local_model_path
174
+ # Ensure models are downloaded (this may take time on first use)
175
  if local_model_path is None:
176
+ print("Downloading models (this may take a few minutes on first use)...")
177
  load_models()
178
  if 'tts_speakerminus_global' not in globals() or tts_speakerminus_global is None:
179
  print("Loading CosyVoice (speakerminus) model...")
180
  tts_speakerminus_global = CosyVoiceTTS_speakerminus(model_dir=local_model_path)
181
 
182
  if not ref_audio and not ref_text:
183
+ # Ensure assets are loaded
184
+ if assets_dir is None:
185
+ load_assets()
186
+ if audio_prompt_path is None or assets_dir is None:
187
+ raise ValueError("Audio prompt path is not available. Please wait a moment and try again, or provide reference audio and text.")
188
  ref_text = text_prompt.get(speaker, "")
189
  speaker_audio_name = audio_prompt.get(speaker)
190
  if speaker_audio_name:
 
259
  def generate_speech_sft(tts_text, speed, speaker, key, ref_audio, ref_text):
260
  # import pdb;pdb.set_trace()
261
  global tts_sft_global, local_model_path_enhenced
262
+ # Ensure models are downloaded (this may take time on first use)
263
  if local_model_path_enhenced is None:
264
+ print("Downloading models (this may take a few minutes on first use)...")
265
  load_models()
266
  if 'tts_sft_global' not in globals() or tts_sft_global is None:
267
  print("Loading CosyVoice (SFT enhanced) model...")
268
  tts_sft_global = CosyVoiceTTS_speakerminus(model_dir=local_model_path_enhenced)
269
  if not ref_audio and not ref_text:
270
+ # Ensure assets are loaded
271
+ if assets_dir is None:
272
+ load_assets()
273
+ if audio_prompt_path is None or assets_dir is None:
274
+ raise ValueError("Audio prompt path is not available. Please wait a moment and try again, or provide reference audio and text.")
275
  ref_text = text_prompt.get(speaker, "")
276
  speaker_audio_name = audio_prompt.get(speaker)
277
  if speaker_audio_name:
 
660
  with gr.Blocks(css=custom_css, theme=gr.themes.Soft()) as demo:
661
  with gr.Column(elem_classes="header"):
662
  with gr.Row(elem_id="header-row", variant="compact"):
663
+ # Use the logo if it has already been downloaded; otherwise pass None (no image is shown)
664
+ logo_value = logo_path if logo_path is not None else None
665
+ gr.Image(value=logo_value,
666
  elem_id="logo-container",
667
  show_label=False,
668
  show_download_button=False,
 
847
  outputs=tts_v2_output
848
  )
849
 
850
+ # Don't preload models - let them download on first use to avoid startup timeout
851
+ # Models will be downloaded and loaded lazily when first requested by user
 
 
 
 
 
 
 
 
 
 
 
 
852
 
853
  if __name__ == "__main__":
854
  # Use environment variable for port (Hugging Face Spaces uses 7860 by default)