Spaces:
Running
Running
| # -*- coding: utf-8 -*- | |
| """RobotsMali_ASR_Demo.ipynb - Script FINAL MINIMALISTE (UPLOAD SEULEMENT) | |
| Interface utilisateur simplifiée pour les tests utilisateurs, ne gardant que l'option de téléchargement de fichier audio. | |
| """ | |
| import gradio as gr | |
| import time | |
| import os | |
| import librosa | |
| import soundfile as sf | |
| import numpy as np | |
| # --- IMPORTS NEMO --- | |
# Each NeMo collection is imported on its own so that a missing NLP
# collection does not also replace a working ASR collection with a stub.
try:
    import nemo.collections.asr as nemo_asr
except ImportError:
    class DummyASRModel:
        """Stand-in for `nemo_asr.models.ASRModel` when NeMo ASR is missing."""

        # Callers invoke `from_pretrained` on the class itself, so it has to
        # be a classmethod; as a plain method the call failed with a
        # TypeError about a missing argument instead of this RuntimeError.
        @classmethod
        def from_pretrained(cls, model_name):
            """Always raise RuntimeError: the ASR collection is unavailable."""
            raise RuntimeError("NeMo ASR not installed.")

    # Minimal namespace exposing the same attribute path as the real module:
    # `nemo_asr.models.ASRModel`.
    nemo_asr = type('nemo_asr', (object,), {'models': type('models', (object,), {'ASRModel': DummyASRModel})})

try:
    import nemo.collections.nlp as nemo_nlp
except ImportError:
    class DummyNLPModel:
        """Stand-in for `PunctuationCapitalizationModel` when NeMo NLP is missing."""

        @classmethod
        def from_pretrained(cls, model_name):
            """Always raise RuntimeError: the NLP collection is unavailable."""
            raise RuntimeError("NeMo NLP not installed.")

    # Same idea: `nemo_nlp.models.PunctuationCapitalizationModel`.
    nemo_nlp = type('nemo_nlp', (object,), {'models': type('models', (object,), {'PunctuationCapitalizationModel': DummyNLPModel})})
| # ---------------------------------------------------------------------- | |
| # CONSTANTES DE CONFIGURATION | |
| # ---------------------------------------------------------------------- | |
| # Dictionnaire des modèles : {Nom Lisible: Nom Complet (pour NeMo)} | |
# Display label -> full model name handed to NeMo's `from_pretrained`.
# Insertion order is significant: the first label is used as the default
# model at startup and as the initial dropdown value.
ROBOTSMALI_MODELS_MAP = {
    label: f"RobotsMali/{slug}"
    for label, slug in (
        ("Soloba CTC 0.6B v0", "soloba-ctc-0.6b-v0"),
        ("Soloni 114M TDT v1", "soloni-114m-tdt-ctc-v1"),
        ("Soloni 114M TDT v0", "soloni-114m-tdt-ctc-V0"),
        ("QuartzNet 15x5 V0", "stt-bm-quartznet15x5-V0"),
        ("QuartzNet 15x5 v1", "stt-bm-quartznet15x5-v1"),
        ("Soloba CTC 0.6B v1", "soloba-ctc-0.6b-v1"),
    )
}
# Labels only, in the same order as the mapping (feeds the dropdown choices).
ROBOTSMALI_SHORT_NAMES = [*ROBOTSMALI_MODELS_MAP]
# Sample rate used both when loading uploads and when writing temp WAV files.
SR_TARGET = 16_000
# NOTE(review): this does not look like a published NeMo pretrained model
# name - confirm. A failed load only disables the punctuation step.
PUNCT_MODEL_NAME = "nemo/nlp/punctuation_and_capitalization"
# Caches: loaded ASR models keyed by display label, and the punctuation model
# (None = not tried yet; the loader stores False after a failed attempt).
asr_pipelines = dict()
punct_pipeline = None
| # ---------------------------------------------------------------------- | |
| # 1. FONCTIONS DE GESTION DES MODÈLES (Pas de changement logique) | |
| # ---------------------------------------------------------------------- | |
def load_pipeline(short_name):
    """Return the NeMo ASR model for *short_name*, loading it on first use.

    On a cache miss the model is fetched with `from_pretrained`, switched to
    eval mode and warmed up by transcribing one temporary WAV file holding
    `SR_TARGET` samples of random noise. Only a model that passed the
    warm-up is stored in `asr_pipelines`.

    Args:
        short_name: Display label, a key of `ROBOTSMALI_MODELS_MAP`.

    Returns:
        The cached model instance.

    Raises:
        ValueError: If *short_name* is not a known label.
        RuntimeError: If loading or warming up the model fails; the original
            exception is chained as the cause.
    """
    # Local import keeps this block self-contained (tempfile is stdlib).
    import tempfile

    model_name = ROBOTSMALI_MODELS_MAP.get(short_name)
    if not model_name:
        raise ValueError(f"Nom de modèle inconnu: {short_name}")

    if short_name in asr_pipelines:
        return asr_pipelines[short_name]

    warmup_path = None
    try:
        model_instance = nemo_asr.models.ASRModel.from_pretrained(model_name=model_name)
        model_instance.eval()

        # A unique temp file instead of a fixed name in the working directory:
        # two concurrent loads must not overwrite/delete each other's file.
        fd, warmup_path = tempfile.mkstemp(prefix="dummy_warmup_", suffix=".wav")
        os.close(fd)
        dummy_audio = np.random.randn(SR_TARGET).astype(np.float32)
        sf.write(warmup_path, dummy_audio, SR_TARGET)
        model_instance.transcribe([warmup_path], batch_size=1)
    except Exception as e:
        raise RuntimeError(f"Impossible de charger le modèle {short_name}. Détail: {e}") from e
    finally:
        if warmup_path is not None:
            try:
                os.remove(warmup_path)
            except OSError:
                # Best-effort cleanup; a leftover temp file is not fatal.
                pass

    # Publish to the cache only after a successful warm-up so other callers
    # never observe a model that is about to be discarded.
    asr_pipelines[short_name] = model_instance
    return model_instance
def load_punct_model():
    """Return the punctuation/capitalization model, loading it at most once.

    The module-level `punct_pipeline` cache has three states: `None` (never
    tried), a model instance (loaded), or `False` (a previous attempt failed,
    so it is not retried).

    Returns:
        The loaded model, or `False` when it could not be loaded.
    """
    global punct_pipeline
    if punct_pipeline is None:
        try:
            model = nemo_nlp.models.PunctuationCapitalizationModel.from_pretrained(model_name=PUNCT_MODEL_NAME)
            model.eval()
        except Exception as e:
            # Punctuation is optional: keep going without it, but report why
            # it is disabled instead of discarding the error.
            print(f"Modèle de ponctuation indisponible, post-traitement désactivé. Détail: {e}")
            punct_pipeline = False
        else:
            punct_pipeline = model
    return punct_pipeline
| # ---------------------------------------------------------------------- | |
| # 2. FONCTION PRINCIPALE D'INFÉRENCE | |
| # ---------------------------------------------------------------------- | |
def _extract_transcript_text(hypothesis):
    """Return the stripped text of one transcription result, or None.

    Accepts a plain string, an object exposing a string `.text` attribute, or
    a non-empty list/tuple whose first element is one of those.
    NOTE(review): the exact return shape of NeMo's `transcribe` is assumed to
    vary between model types/versions - confirm against the installed NeMo.
    """
    if isinstance(hypothesis, (list, tuple)):
        if not hypothesis:
            return None
        hypothesis = hypothesis[0]
    text = hypothesis if isinstance(hypothesis, str) else getattr(hypothesis, "text", None)
    if not isinstance(text, str):
        return None
    return text.strip()


def transcribe_audio(model_short_name: str, audio_path: str, progress=gr.Progress()):
    """Transcribe a whole audio file and stream Markdown status updates.

    This is a generator used as a Gradio event handler: every yielded string
    replaces the content of the output Markdown component.

    Args:
        model_short_name: Display label of the ASR model to use.
        audio_path: Path of the uploaded audio file, or None if none was given.
        progress: Gradio progress tracker. It must be declared as a default
            argument for Gradio to attach it to the running event.

    Yields:
        Status messages, then either the final formatted transcription or an
        error message.
    """
    # Local import keeps this block self-contained (tempfile is stdlib).
    import tempfile

    if audio_path is None:
        # In a generator, `return <value>` is never shown to the user; the
        # message has to be yielded.
        yield "⚠️ **Erreur :** Veuillez **télécharger** un fichier audio pour commencer."
        return

    empty_marker = "[Transcription vide ou échec ASR]"
    start_time = time.time()
    temp_full_path = None
    raw_transcription = empty_marker
    try:
        # Step 1: decode/resample the upload and make sure the model is loaded.
        yield f"**Statut :** 🔄 Préparation et chargement du modèle `{model_short_name}`..."
        full_audio_data, _ = librosa.load(audio_path, sr=SR_TARGET, mono=True)
        total_duration = len(full_audio_data) / SR_TARGET
        segment_data = full_audio_data.squeeze()

        # Unique temp file: a name derived from the upload's basename in the
        # working directory would collide between concurrent requests.
        fd, temp_full_path = tempfile.mkstemp(prefix="temp_nemo_input_", suffix=".wav")
        os.close(fd)
        sf.write(temp_full_path, segment_data, SR_TARGET)
        asr_model = load_pipeline(model_short_name)

        # Step 2: simulated progress (0% to 90%) shown before the blocking
        # inference call.
        yield f"**Statut :** ⏳ Transcription en cours (Durée audio: {total_duration:.1f}s)..."
        for progress_percent in range(0, 91, 10):
            time.sleep(0.3)
            progress(progress_percent / 100, desc=f"Progression ASR ({progress_percent}%)")

        # Inference.
        transcriptions = asr_model.transcribe([temp_full_path], batch_size=1)
        if transcriptions and transcriptions[0]:
            extracted = _extract_transcript_text(transcriptions[0])
            if extracted:
                raw_transcription = extracted

        # The reported time covers preparation + ASR; it is measured before
        # the optional punctuation step.
        duration = time.time() - start_time
        processed_text = raw_transcription
        punct_status = ""
        punct_model = load_punct_model()
        if punct_model and raw_transcription != empty_marker:
            yield "**Statut :** ✨ Finalisation et correction de la ponctuation..."
            progress(1.0, desc="Progression ASR (100%)")
            try:
                corrected_list = punct_model.add_punctuation_capitalization([raw_transcription])
                if corrected_list:
                    processed_text = corrected_list[0].strip()
                    punct_status = " (Correction Ponctuation OK)"
            except Exception:
                # Punctuation is best-effort: keep the raw transcription.
                punct_status = " (Correction Ponctuation ÉCHOUÉE)"

        # Final result: one-line header, transcription block, footer.
        header = "### ✅ Transcription Terminée "
        header += f"*(Modèle: **{model_short_name}** | Temps: {duration:.2f}s)*{punct_status}\n\n"
        header += "--- \n"

        # One sentence per paragraph inside the fenced text block.
        clean_text = processed_text.replace('\n', ' ').strip()
        formatted_lyrics = clean_text.replace('. ', '.\n\n').replace('? ', '?\n\n').replace('! ', '!\n\n')
        lyrics_output = "\n**Transcription Finale :**\n```text\n" + formatted_lyrics + "\n```\n"

        footer = f"\n\n*Traitement basé sur le lien `{ROBOTSMALI_MODELS_MAP.get(model_short_name)}`.*"
        yield header + lyrics_output + footer
    except RuntimeError as e:
        yield f"❌ **Erreur Critique :** Impossible de procéder. Détails : {str(e)}"
    except Exception as e:
        yield f"❌ **Erreur Générale :** Une erreur inattendue est survenue : {e}"
    finally:
        if temp_full_path is not None and os.path.exists(temp_full_path):
            os.remove(temp_full_path)
| # ---------------------------------------------------------------------- | |
| # 3. INTERFACE GRADIO (Minimaliste) | |
| # ---------------------------------------------------------------------- | |
| # Statut initial de l'application | |
# Startup status line shown at the top of the page. The first configured
# model is preloaded eagerly so the first user request does not pay the
# loading cost.
APP_STATUS = "Chargement en cours..."
if ROBOTSMALI_SHORT_NAMES:
    default_short_name = ROBOTSMALI_SHORT_NAMES[0]
    try:
        load_pipeline(default_short_name)
    except Exception as e:
        # The app still starts without a preloaded model, but the cause is
        # written to the console instead of being discarded.
        print(f"Échec du chargement initial du modèle `{default_short_name}` : {e}")
        APP_STATUS = "❌ **Échec au Démarrage :** Vérifiez la configuration NeMo/CUDA."
    else:
        APP_STATUS = f"✅ **Prêt :** Modèle de base `{default_short_name}` chargé."
| # Composants | |
# Components are created up front and placed into the layout later through
# their `.render()` method.

# Step 1: model selector, preselecting the first configured label (or
# nothing when no model is configured).
model_dropdown = gr.Dropdown(
    choices=ROBOTSMALI_SHORT_NAMES,
    value=next(iter(ROBOTSMALI_SHORT_NAMES), None),
    label="Étape 1 : Choisir le Modèle ASR",
    interactive=True,
)

# Step 2: audio input restricted to file upload (no microphone source); the
# handler receives a file path.
audio_input = gr.Audio(
    sources=["upload"],
    type="filepath",
    format="mp3",
    label="Étape 2 : Télécharger l'Audio (MP3, WAV, etc.)",
)

# Output area: Markdown showing status updates and the final transcription.
text_output = gr.Markdown(
    value="Commencez par choisir un modèle et télécharger votre audio. Le résultat s'affichera ici. 💡",
    label="Résultat",
)
| # Mise en page Blocks (Deux colonnes simples et bien définies) | |
# Page layout: a title banner, then a row with the inputs on the left and
# the transcription output on the right.
with gr.Blocks(theme=gr.themes.Soft(), title="RobotsMali ASR") as demo:
    # Banner, including the startup status computed at import time.
    gr.Markdown(
        "\n"
        "# 🤖 RobotsMali ASR Demo\n"
        "### **Transcription Vocale pour les langues maliennes. Minimaliste et rapide.**\n"
        f"{APP_STATUS}\n"
        "---\n"
    )

    with gr.Row():
        # Left column: model choice, audio upload and the submit button.
        with gr.Column(scale=1, min_width=300):
            model_dropdown.render()
            audio_input.render()
            submit_btn = gr.Button("▶️ ÉTAPE 3 : LANCER LA TRANSCRIPTION", variant="primary")
            gr.Markdown(
                "\n"
                "*Rappel : L'audio doit être court (moins de 5 minutes) pour éviter une erreur de mémoire.*\n"
            )

        # Right column: transcription output.
        with gr.Column(scale=2, min_width=500):
            text_output.render()

    # Clicking the button streams the generator's yields into the output.
    submit_btn.click(
        transcribe_audio,
        [model_dropdown, audio_input],
        text_output,
    )

print("Lancement de l'interface Gradio Blocks...")
demo.launch(share=True)