binaryMao commited on
Commit
1c8d96c
·
verified ·
1 Parent(s): bfedbb0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +98 -94
app.py CHANGED
@@ -1,6 +1,6 @@
1
  # -*- coding: utf-8 -*-
2
- """RobotsMali_ASR_Demo.ipynb - Script Final pour Démo Fluide et Stable
3
- Version corrigée du SyntaxError.
4
  """
5
  import gradio as gr
6
  import time
@@ -11,6 +11,7 @@ import numpy as np
11
 
12
  # --- IMPORTS NEMO ---
13
  import nemo.collections.asr as nemo_asr
 
14
  # --------------------
15
 
16
  # ----------------------------------------------------------------------
@@ -20,77 +21,79 @@ ROBOTSMALI_MODELS = [
20
  "RobotsMali/soloba-ctc-0.6b-v0",
21
  "RobotsMali/soloni-114m-tdt-ctc-v1",
22
  "RobotsMali/soloni-114m-tdt-ctc-V0",
23
- "RobotsMali/stt-bm-quartznet5x5-V0",
24
  "RobotsMali/stt-bm-quartznet5x5-v1",
25
  "RobotsMali/soloba-ctc-0.6b-v1"
26
  ]
27
 
28
- CHUNK_DURATION_SEC = 25 # Durée par segment (secondes) pour économiser la RAM
29
- SR_TARGET = 16000 # Taux d'échantillonnage cible pour NeMo ASR (16kHz)
30
 
31
- # Cache pour stocker les modèles NeMo chargés.
 
 
 
32
  asr_pipelines = {}
 
33
 
34
  # ----------------------------------------------------------------------
35
- # 1. FONCTIONS DE GESTION DES MODÈLES (CHARGEMENT + WARM-UP)
36
  # ----------------------------------------------------------------------
37
  def load_pipeline(model_name):
38
- """
39
- Charge le modèle NeMo, le met en cache et effectue un warm-up.
40
- """
41
  if model_name not in asr_pipelines:
42
  print(f"-> Tentative de chargement du modèle NeMo: {model_name}...")
43
  temp_warmup_file = "dummy_warmup.wav"
44
 
45
  try:
46
- # 🚀 CHARGEMENT NEMO
47
  model_instance = nemo_asr.models.ASRModel.from_pretrained(model_name=model_name)
48
  model_instance.eval()
49
-
50
  asr_pipelines[model_name] = model_instance
51
  print(f"-> Modèle NeMo {model_name} chargé avec succès.")
52
 
53
- # ----------------------------------------------------
54
- # WARM-UP (Inférence à blanc)
55
- # ----------------------------------------------------
56
  print(f" [Warmup] Exécution d'une inférence à blanc...")
57
-
58
  dummy_audio = np.random.randn(SR_TARGET).astype(np.float32)
59
  sf.write(temp_warmup_file, dummy_audio, SR_TARGET)
60
-
61
  model_instance.transcribe([temp_warmup_file], batch_size=1)
62
-
63
  print(f" [Warmup] Terminé.")
64
 
65
  except Exception as e:
66
- if model_name in asr_pipelines:
67
- del asr_pipelines[model_name]
68
  print(f"!!! Erreur de chargement NeMo pour {model_name}: {e}")
69
  raise RuntimeError(f"Impossible de charger le modèle {model_name}. Détail: {e}")
70
 
71
  finally:
72
- if os.path.exists(temp_warmup_file):
73
- os.remove(temp_warmup_file)
74
 
75
  return asr_pipelines.get(model_name)
76
 
 
 
 
 
 
 
 
 
 
 
 
 
 
77
  # ----------------------------------------------------------------------
78
- # 2. FONCTION PRINCIPALE D'INFÉRENCE AVEC STREAMING ET DÉCOUPAGE
79
  # ----------------------------------------------------------------------
80
  def transcribe_audio(model_name: str, audio_path: str):
81
  """
82
- Effectue la transcription ASR avec découpage (chunking) et streaming d'état.
83
  """
84
  if audio_path is None:
85
  yield "⚠️ Veuillez d'abord télécharger ou enregistrer un fichier audio."
86
  return
87
- if not ROBOTSMALI_MODELS:
88
- yield "Liste de modèles ASR indisponible."
89
- return
90
 
91
  start_time = time.time()
92
  model_short_name = model_name.split('/')[-1]
93
- temp_chunk_paths = []
94
 
95
  try:
96
  # ----------------------------------------------------------------
@@ -99,109 +102,109 @@ def transcribe_audio(model_name: str, audio_path: str):
99
  yield f"**[1/4] CHARGEMENT AUDIO...** Préparation du fichier original (Mono @ 16kHz). ⚙️"
100
 
101
  full_audio_data, sr = librosa.load(audio_path, sr=SR_TARGET, mono=True)
102
-
103
  total_duration = len(full_audio_data) / SR_TARGET
104
- samples_per_chunk = int(CHUNK_DURATION_SEC * SR_TARGET)
 
 
 
105
 
106
  # ----------------------------------------------------------------
107
- # ÉTAPE 2 : CHARGEMENT/VÉRIFICATION DU MODÈLE ET DÉCOUPAGE
108
  # ----------------------------------------------------------------
109
- yield f"**[2/4] PRÉ-CALCUL...** Chargement du modèle et découpage ({total_duration:.1f}s en segments de {CHUNK_DURATION_SEC}s). 🧠"
110
 
111
  asr_model = load_pipeline(model_name)
112
 
113
- # Logique de DÉCOUPAGE
114
- audio_segments = []
115
- for i in range(0, len(full_audio_data), samples_per_chunk):
116
- audio_segments.append(full_audio_data[i:i + samples_per_chunk])
117
-
118
- num_chunks = len(audio_segments)
119
- full_transcription_text = ""
120
-
121
  # ----------------------------------------------------------------
122
- # ÉTAPE 3 : TRANSCRIPTION PAR SEGMENT
123
  # ----------------------------------------------------------------
 
 
 
 
 
 
 
 
 
 
 
124
 
125
- for idx, segment_data in enumerate(audio_segments):
126
-
127
- yield f"**[3/4] TRANSCRIPTION EN COURS...** Analyse du segment {idx + 1}/{num_chunks}. ⏳"
128
-
129
- # --- CORRECTION DE LA FORME AUDIO (squeeze) ---
130
- segment_data = segment_data.squeeze()
131
-
132
- # Écriture du chunk temporaire
133
- chunk_path = f"{os.path.splitext(os.path.basename(audio_path))[0]}_chunk_{idx}.wav"
134
- sf.write(chunk_path, segment_data, SR_TARGET)
135
- temp_chunk_paths.append(chunk_path)
136
 
137
- # 🚀 INFÉRENCE NEMO
138
- transcriptions = asr_model.transcribe([chunk_path], batch_size=1)
 
 
 
 
 
139
 
140
- # --- GESTION DE L'OBJET HYPOTHESIS ---
141
- segment_text = ""
142
- if transcriptions and transcriptions[0]:
143
- hyp_object = transcriptions[0]
144
-
145
- if hasattr(hyp_object, 'text'):
146
- segment_text = hyp_object.text.strip()
147
- elif isinstance(hyp_object, str):
148
- segment_text = hyp_object.strip()
149
- elif isinstance(hyp_object, list) and hasattr(hyp_object[0], 'text'):
150
- segment_text = hyp_object[0].text.strip()
151
-
152
- if not segment_text:
153
- segment_text = "[Transcription vide]"
154
-
155
- full_transcription_text += segment_text + "\n\n"
156
 
 
157
  # ----------------------------------------------------
158
- # ÉTAPE 4 : RÉSULTAT FINAL
159
  # ----------------------------------------------------
160
  end_time = time.time()
161
  duration = end_time - start_time
162
-
163
- transcription_text_final = full_transcription_text.strip()
164
-
 
 
 
 
 
 
 
 
 
 
 
 
 
165
  # 1. EN-TÊTE D'INFORMATION
166
  output = f"**Modèle Utilisé :** `{model_short_name}` (NeMo)\n"
167
  output += f"**Durée de l'Audio :** {total_duration:.1f} secondes\n"
168
  output += f"**Temps de Traitement Total :** {duration:.2f} secondes\n"
169
- output += f"**DÉCOUPAGE :** {CHUNK_DURATION_SEC} secondes ({num_chunks} segments)\n"
170
  output += f"***\n"
171
 
172
  # 2. PRÉSENTATION LYRICS PROPRE
173
  output += "**RÉSULTAT DE LA TRANSCRIPTION (Lyrics) :**\n"
174
- # Préparation du texte pour le Markdown (Remplacement avant le yield)
175
- formatted_lyrics = transcription_text_final.replace('\n\n', '\n>>> ')
176
- output += f">>> {formatted_lyrics}"
 
 
 
177
 
178
  # 3. NOTE FINALE
179
- output += "\n\n*Note : Audio converti en **Mono @ 16kHz** pour la transcription.*"
180
-
181
  yield output
182
 
183
  except RuntimeError as e:
184
  yield f"❌ Erreur critique lors du chargement : {str(e)}"
185
 
186
  except Exception as e:
187
- # --- CORRECTION DE SYNTAXE APPLIQUÉE ICI ---
188
- # Affiche le texte partiel en cas d'erreur
189
- if 'full_transcription_text' in locals() and full_transcription_text:
190
- partial_text = full_transcription_text.strip().replace('\n\n', '\n>>> ')
191
- yield f"❌ Erreur lors de la transcription, le traitement s'est arrêté. Texte partiel:\n>>> {partial_text}"
192
-
193
- yield f"❌ Erreur générale : {e}"
194
 
195
  finally:
196
  # Nettoyage
197
- for chunk_path in temp_chunk_paths:
198
- if os.path.exists(chunk_path):
199
- os.remove(chunk_path)
200
- print(f"-> {len(temp_chunk_paths)} fichiers temporaires de segments supprimés.")
201
 
202
 
203
  # ----------------------------------------------------------------------
204
- # 4. PRÉ-CHARGEMENT ET INTERFACE GRADIO
205
  # ----------------------------------------------------------------------
206
 
207
  INITIAL_DESCRIPTION = "Sélectionnez un modèle ASR de RobotsMali, puis enregistrez ou téléchargez un fichier audio pour obtenir la transcription."
@@ -213,6 +216,7 @@ if ROBOTSMALI_MODELS:
213
  default_model_short_name = default_model.split('/')[-1]
214
  INITIAL_DESCRIPTION = (
215
  f"✅ Le modèle par défaut `{default_model_short_name}` (NeMo) a été **préchargé et réchauffé** avec succès. "
 
216
  f"Téléchargez ou enregistrez votre audio pour transcrire."
217
  )
218
  except RuntimeError as e:
@@ -243,9 +247,9 @@ interface = gr.Interface(
243
  fn=transcribe_audio,
244
  inputs=[model_dropdown, audio_input],
245
  outputs=text_output,
246
- title="🤖 RobotsMali ASR Multi-Modèles (Démo NeMo Fluide)",
247
  description=INITIAL_DESCRIPTION,
248
  allow_flagging="never")
249
 
250
  print("Lancement de l'interface Gradio...")
251
- interface.launch(share=True)
 
1
  # -*- coding: utf-8 -*-
2
+ """RobotsMali_ASR_Demo.ipynb - Script FINAL
3
+ Traitement complet de l'audio sans découpage, avec barre de progression Gradio et post-correction.
4
  """
5
  import gradio as gr
6
  import time
 
11
 
12
  # --- IMPORTS NEMO ---
13
  import nemo.collections.asr as nemo_asr
14
+ import nemo.collections.nlp as nemo_nlp
15
  # --------------------
16
 
17
  # ----------------------------------------------------------------------
 
21
  "RobotsMali/soloba-ctc-0.6b-v0",
22
  "RobotsMali/soloni-114m-tdt-ctc-v1",
23
  "RobotsMali/soloni-114m-tdt-ctc-V0",
24
+ "RobotsMali/stt-bm-quartznet5x5-V0", # Modèles souvent en erreur (selon les logs), mais inclus.
25
  "RobotsMali/stt-bm-quartznet5x5-v1",
26
  "RobotsMali/soloba-ctc-0.6b-v1"
27
  ]
28
 
29
+ SR_TARGET = 16000 # Taux d'échantillonnage cible pour NeMo ASR (16kHz)
 
30
 
31
+ # Modèle de post-traitement pour restaurer la ponctuation et la casse
32
+ PUNCT_MODEL_NAME = "nemo/nlp/punctuation_and_capitalization"
33
+
34
+ # Caches
35
  asr_pipelines = {}
36
+ punct_pipeline = None
37
 
38
  # ----------------------------------------------------------------------
39
+ # 1. FONCTIONS DE GESTION DES MODÈLES (CHARGEMENT & CACHE)
40
  # ----------------------------------------------------------------------
41
  def load_pipeline(model_name):
42
+ """Charge un modèle ASR NeMo, le met en cache et effectue un warm-up."""
 
 
43
  if model_name not in asr_pipelines:
44
  print(f"-> Tentative de chargement du modèle NeMo: {model_name}...")
45
  temp_warmup_file = "dummy_warmup.wav"
46
 
47
  try:
 
48
  model_instance = nemo_asr.models.ASRModel.from_pretrained(model_name=model_name)
49
  model_instance.eval()
 
50
  asr_pipelines[model_name] = model_instance
51
  print(f"-> Modèle NeMo {model_name} chargé avec succès.")
52
 
53
+ # WARM-UP
 
 
54
  print(f" [Warmup] Exécution d'une inférence à blanc...")
 
55
  dummy_audio = np.random.randn(SR_TARGET).astype(np.float32)
56
  sf.write(temp_warmup_file, dummy_audio, SR_TARGET)
 
57
  model_instance.transcribe([temp_warmup_file], batch_size=1)
 
58
  print(f" [Warmup] Terminé.")
59
 
60
  except Exception as e:
61
+ if model_name in asr_pipelines: del asr_pipelines[model_name]
 
62
  print(f"!!! Erreur de chargement NeMo pour {model_name}: {e}")
63
  raise RuntimeError(f"Impossible de charger le modèle {model_name}. Détail: {e}")
64
 
65
  finally:
66
+ if os.path.exists(temp_warmup_file): os.remove(temp_warmup_file)
 
67
 
68
  return asr_pipelines.get(model_name)
69
 
70
+ def load_punct_model():
71
+ """Charge le modèle de ponctuation/casse et le met en cache."""
72
+ global punct_pipeline
73
+ if punct_pipeline is None:
74
+ print(f"-> Tentative de chargement du modèle de ponctuation: {PUNCT_MODEL_NAME}...")
75
+ try:
76
+ punct_pipeline = nemo_nlp.models.PunctuationCapitalizationModel.from_pretrained(model_name=PUNCT_MODEL_NAME)
77
+ punct_pipeline.eval()
78
+ print("-> Modèle de ponctuation chargé avec succès.")
79
+ except Exception as e:
80
+ print(f"!!! AVERTISSEMENT: Échec du chargement du modèle de ponctuation {PUNCT_MODEL_NAME}. La sortie restera brute. Détail: {e}")
81
+ return punct_pipeline
82
+
83
  # ----------------------------------------------------------------------
84
+ # 2. FONCTION PRINCIPALE D'INFÉRENCE (TRAITEMENT COMPLET AVEC PROGRESSION)
85
  # ----------------------------------------------------------------------
86
  def transcribe_audio(model_name: str, audio_path: str):
87
  """
88
+ Effectue la transcription ASR de l'audio complet avec une barre de progression simulée.
89
  """
90
  if audio_path is None:
91
  yield "⚠️ Veuillez d'abord télécharger ou enregistrer un fichier audio."
92
  return
 
 
 
93
 
94
  start_time = time.time()
95
  model_short_name = model_name.split('/')[-1]
96
+ temp_full_path = f"temp_nemo_input_{os.path.basename(audio_path)}.wav"
97
 
98
  try:
99
  # ----------------------------------------------------------------
 
102
  yield f"**[1/4] CHARGEMENT AUDIO...** Préparation du fichier original (Mono @ 16kHz). ⚙️"
103
 
104
  full_audio_data, sr = librosa.load(audio_path, sr=SR_TARGET, mono=True)
 
105
  total_duration = len(full_audio_data) / SR_TARGET
106
+
107
+ # Correction de la forme audio (squeeze) pour éviter l'erreur de "Output shape mismatch"
108
+ segment_data = full_audio_data.squeeze()
109
+ sf.write(temp_full_path, segment_data, SR_TARGET)
110
 
111
  # ----------------------------------------------------------------
112
+ # ÉTAPE 2 : CHARGEMENT/VÉRIFICATION DU MODÈLE
113
  # ----------------------------------------------------------------
114
+ yield f"**[2/4] PRÉ-CALCUL...** Chargement du modèle. Durée de l'audio : {total_duration:.1f}s. 🧠"
115
 
116
  asr_model = load_pipeline(model_name)
117
 
 
 
 
 
 
 
 
 
118
  # ----------------------------------------------------------------
119
+ # ÉTAPE 3 : TRANSCRIPTION COMPLÈTE (AVEC BARRE DE PROGRESSION)
120
  # ----------------------------------------------------------------
121
+ yield f"**[3/4] TRANSCRIPTION EN COURS...** Démarrage de l'inférence. ⏳"
122
+
123
+ # --- BARRE DE PROGRESSION SIMULÉE ---
124
+ # Affiche une progression visuelle pendant l'attente de l'inférence GPU
125
+ for progress_percent in range(0, 91, 10):
126
+ time.sleep(0.3)
127
+ # Utilise gr.Progress pour une barre stylée en haut de l'interface
128
+ yield gr.Progress(progress_percent, total=100, desc=f"Progression ASR ({progress_percent}%)")
129
+
130
+ yield f"**[3/4] FINALISATION...** Inférence en cours sur le GPU. 🚀"
131
+ # ---------------------------------------------
132
 
133
+ # 🚀 INFÉRENCE NEMO
134
+ transcriptions = asr_model.transcribe([temp_full_path], batch_size=1)
135
+
136
+ # --- GESTION DE L'OBJET HYPOTHESIS ---
137
+ transcription_text_final = ""
138
+ if transcriptions and transcriptions[0]:
139
+ hyp_object = transcriptions[0]
 
 
 
 
140
 
141
+ # Gère les différents formats de sortie de NeMo
142
+ if hasattr(hyp_object, 'text'):
143
+ transcription_text_final = hyp_object.text.strip()
144
+ elif isinstance(hyp_object, str):
145
+ transcription_text_final = hyp_object.strip()
146
+ elif isinstance(hyp_object, list) and hasattr(hyp_object[0], 'text'):
147
+ transcription_text_final = hyp_object[0].text.strip()
148
 
149
+ if not transcription_text_final:
150
+ transcription_text_final = "[Transcription vide ou échec ASR]"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
151
 
152
+
153
  # ----------------------------------------------------
154
+ # ÉTAPE 4 : POST-TRAITEMENT ET AFFICHAGE FINAL
155
  # ----------------------------------------------------
156
  end_time = time.time()
157
  duration = end_time - start_time
158
+ processed_text = transcription_text_final
159
+
160
+ # --- POST-TRAITEMENT (PONCTUATION & CASSE) ---
161
+ punct_model = load_punct_model()
162
+ if punct_model and transcription_text_final != "[Transcription vide ou échec ASR]":
163
+ yield f"**[4/4] POST-TRAITEMENT...** Correction de la ponctuation et de la casse pour la lisibilité. ✨"
164
+ yield gr.Progress(100, total=100, desc="Progression ASR (100%)") # Termine la barre
165
+
166
+ try:
167
+ corrected_list = punct_model.add_punctuation_capitalization([transcription_text_final])
168
+ if corrected_list:
169
+ processed_text = corrected_list[0].strip()
170
+ except Exception as pc_error:
171
+ print(f"!!! Échec du post-traitement de ponctuation : {pc_error}")
172
+ yield "⚠️ Échec de la correction de ponctuation. Affichage du texte brut."
173
+
174
  # 1. EN-TÊTE D'INFORMATION
175
  output = f"**Modèle Utilisé :** `{model_short_name}` (NeMo)\n"
176
  output += f"**Durée de l'Audio :** {total_duration:.1f} secondes\n"
177
  output += f"**Temps de Traitement Total :** {duration:.2f} secondes\n"
 
178
  output += f"***\n"
179
 
180
  # 2. PRÉSENTATION LYRICS PROPRE
181
  output += "**RÉSULTAT DE LA TRANSCRIPTION (Lyrics) :**\n"
182
+ # Formatage du texte pour l'affichage Markdown
183
+ formatted_lyrics = processed_text.replace('\n', ' ').strip().replace('. ', '.\n\n>>> ').replace('? ', '?\n\n>>> ')
184
+ if not formatted_lyrics.startswith('>>> '):
185
+ formatted_lyrics = '>>> ' + formatted_lyrics
186
+
187
+ output += formatted_lyrics
188
 
189
  # 3. NOTE FINALE
190
+ output += "\n\n*Traitement complet de l'audio sans découpage (chunking).* "
191
+
192
  yield output
193
 
194
  except RuntimeError as e:
195
  yield f"❌ Erreur critique lors du chargement : {str(e)}"
196
 
197
  except Exception as e:
198
+ yield f"❌ Erreur générale lors de la transcription complète : {e}"
 
 
 
 
 
 
199
 
200
  finally:
201
  # Nettoyage
202
+ if os.path.exists(temp_full_path):
203
+ os.remove(temp_full_path)
 
 
204
 
205
 
206
  # ----------------------------------------------------------------------
207
+ # 3. PRÉ-CHARGEMENT ET INTERFACE GRADIO
208
  # ----------------------------------------------------------------------
209
 
210
  INITIAL_DESCRIPTION = "Sélectionnez un modèle ASR de RobotsMali, puis enregistrez ou téléchargez un fichier audio pour obtenir la transcription."
 
216
  default_model_short_name = default_model.split('/')[-1]
217
  INITIAL_DESCRIPTION = (
218
  f"✅ Le modèle par défaut `{default_model_short_name}` (NeMo) a été **préchargé et réchauffé** avec succès. "
219
+ f"**Attention :** Le traitement se fait sur l'audio complet. Les longs fichiers peuvent planter la RAM. "
220
  f"Téléchargez ou enregistrez votre audio pour transcrire."
221
  )
222
  except RuntimeError as e:
 
247
  fn=transcribe_audio,
248
  inputs=[model_dropdown, audio_input],
249
  outputs=text_output,
250
+ title="🤖 RobotsMali ASR Multi-Modèles (Traitement Complet)",
251
  description=INITIAL_DESCRIPTION,
252
  allow_flagging="never")
253
 
254
  print("Lancement de l'interface Gradio...")
255
+ interface.launch(share=True)