#@title fix import and path /content/omniasr-transcriptions/server/inference/mms_model_pipeline.py
# %%writefile /content/omniasr-transcriptions/server/inference/mms_model_pipeline.py
"""
Pipeline-based MMS model using the official omnilingual_asr library.

This implementation uses ASRInferencePipeline to avoid Seq2SeqBatch complexity.
"""
import logging
import os
from typing import Any, Dict, List, Optional

import torch

# from omnilingual_asr.models.inference.pipeline import Wav2Vec2InferencePipeline
from omnilingual_asr.models.inference.pipeline import ASRInferencePipeline
from omnilingual_asr.models.wav2vec2_llama.lang_ids import supported_langs

from inference.audio_reading_tools import wav_to_bytes
from env_vars import MODEL_NAME

logger = logging.getLogger(__name__)
class MMSModel:
    """Pipeline-based MMS model wrapper using the official inference pipeline."""

    _instance = None
    _initialized = False

    def __new__(cls, *args, **kwargs):
        if cls._instance is None:
            logger.info("Creating new MMSModel singleton instance")
            cls._instance = super().__new__(cls)
        else:
            logger.info("Using existing MMSModel singleton instance")
        return cls._instance
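
    # Illustrative sketch (comments only) of what the singleton __new__
    # provides; the device value is an assumed example:
    #
    #     a = MMSModel(device="cpu")
    #     b = MMSModel()      # no second object is allocated
    #     assert a is b       # both names refer to the same instance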
    def __init__(self, model_card: Optional[str] = None, device=None):
        """
        Initialize the MMS model with the official pipeline.

        Args:
            model_card: Model card to use (omniASR_LLM_1B, omniASR_LLM_300M, etc.).
                If None, uses MODEL_NAME from environment variables.
            device: Device to use (torch.device object, "cuda", "cpu", etc.).
        """
        # Only initialize once; __new__ hands back the same instance, so
        # repeated construction must not reload the pipeline.
        if self._initialized:
            return

        # Get the model name, falling back to the environment variable
        self.model_card = model_card or MODEL_NAME
        self.device = device

        # Load the pipeline immediately during initialization
        self._load_pipeline()

        # Mark as initialized
        self._initialized = True
    def _load_pipeline(self):
        """Load the MMS pipeline during initialization."""
        logger.info(f"Loading MMS pipeline: {self.model_card}")
        logger.info(f"Target device: {self.device}")

        # Debug FAIRSEQ2_CACHE_DIR environment variable
        # fairseq2_cache_dir = os.environ.get('FAIRSEQ2_CACHE_DIR')
        fairseq2_cache_dir = os.environ.get("FAIRSEQ2_CACHE_DIR", "./models")
        logger.info(f"DEBUG: FAIRSEQ2_CACHE_DIR = {fairseq2_cache_dir}")

        try:
            # Stringify the device; str() handles both torch.device objects
            # and plain strings, so no type check is needed.
            device_str = str(self.device)

            # self.pipeline = Wav2Vec2InferencePipeline(
            #     model_card=self.model_card,
            #     device=device_str,
            # )
            self.pipeline = ASRInferencePipeline(
                model_card=self.model_card,
                device=device_str,
            )
            logger.info("✓ MMS pipeline loaded successfully")
        except Exception as e:
            logger.error(f"Failed to load MMS pipeline: {e}")
            raise
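
    # Hedged note (comments only): FAIRSEQ2_CACHE_DIR controls where
    # fairseq2 stores downloaded model assets; "./models" above is this
    # module's fallback. Example shell usage, path assumed:
    #
    #     export FAIRSEQ2_CACHE_DIR=/content/models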
    def transcribe_audio(
        self,
        audio_tensor: torch.Tensor,
        batch_size: int = 1,
        language_with_scripts: Optional[List[str]] = None,
    ) -> List[Dict[str, Any]]:
        """
        Transcribe an audio tensor using the MMS pipeline.

        Args:
            audio_tensor: Audio tensor (1D waveform) to transcribe.
            batch_size: Batch size for processing.
            language_with_scripts: List of language codes for transcription
                (3-letter ISO codes with script). If None, uses auto-detection.

        Returns:
            List of transcription results.
        """
        # The pipeline is already loaded during initialization, so no check
        # is needed here.

        # Convert the tensor to bytes for the pipeline
        logger.info(f"Converting tensor (shape: {audio_tensor.shape}) to bytes")

        # Move to CPU first if on GPU
        tensor_cpu = audio_tensor.cpu() if audio_tensor.is_cuda else audio_tensor

        # Convert to bytes using wav_to_bytes with a 16 kHz sample rate
        audio_bytes = wav_to_bytes(tensor_cpu, sample_rate=16000, format="wav")

        logger.info(
            f"Transcribing audio tensor with batch_size={batch_size}, "
            f"language_with_scripts={language_with_scripts}"
        )
        try:
            # Use the official pipeline transcribe method with a list
            # containing the single audio-bytes payload
            if language_with_scripts is not None:
                transcriptions = self.pipeline.transcribe(
                    [audio_bytes], batch_size=batch_size, lang=language_with_scripts
                )
            else:
                transcriptions = self.pipeline.transcribe(
                    [audio_bytes], batch_size=batch_size
                )
            logger.info("✓ Successfully transcribed audio tensor")
            return transcriptions
        except Exception as e:
            logger.error(f"Transcription failed: {e}")
            raise
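
    # Hedged usage sketch for transcribe_audio (comments only), assuming a
    # 16 kHz mono waveform and that "eng_Latn" appears in supported_langs;
    # both are assumptions, not guarantees of this module:
    #
    #     model = MMSModel.get_instance(device="cpu")
    #     waveform = torch.randn(16000)  # 1 s of placeholder 16 kHz audio
    #     results = model.transcribe_audio(
    #         waveform, batch_size=1, language_with_scripts=["eng_Latn"]
    #     )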
    @classmethod
    def get_instance(cls, model_card: Optional[str] = None, device=None):
        """
        Get the singleton instance of MMSModel.

        Args:
            model_card: Model card to use (omniASR_LLM_1B, omniASR_LLM_300M, etc.).
                If None, uses MODEL_NAME from environment variables.
            device: Device to use (torch.device object, "cuda", "cpu", etc.).

        Returns:
            MMSModel: The singleton instance.
        """
        if cls._instance is None:
            cls._instance = cls(model_card=model_card, device=device)
        return cls._instance
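

# Minimal end-to-end sketch, assuming the environment provides MODEL_NAME
# (e.g. "omniASR_LLM_300M", an assumed card name) and that the host may or
# may not have CUDA; the random tensor stands in for real 16 kHz mono
# audio, so the transcription text itself will be meaningless.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = MMSModel.get_instance(device=device)

    # One second of placeholder audio at 16 kHz; replace with real audio.
    waveform = torch.randn(16000)
    print(model.transcribe_audio(waveform, batch_size=1))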