#@title fix import and path /content/omniasr-transcriptions/server/inference/mms_model_pipeline.py
# %%writefile /content/omniasr-transcriptions/server/inference/mms_model_pipeline.py
"""
Pipeline-based MMS Model using the official MMS library.

This implementation uses the official ASRInferencePipeline to avoid
Seq2SeqBatch complexity.
"""
import logging
import os

import torch
from typing import List, Dict, Any, Optional

# from omnilingual_asr.models.inference.pipeline import Wav2Vec2InferencePipeline
from omnilingual_asr.models.inference.pipeline import ASRInferencePipeline
from omnilingual_asr.models.wav2vec2_llama.lang_ids import supported_langs
from inference.audio_reading_tools import wav_to_bytes
from env_vars import MODEL_NAME

logger = logging.getLogger(__name__)


class MMSModel:
    """Pipeline-based MMS model wrapper using the official inference pipeline."""

    _instance = None
    _initialized = False

    def __new__(cls, *args, **kwargs):
        if cls._instance is None:
            logger.info("Creating new MMSModel singleton instance")
            cls._instance = super().__new__(cls)
        else:
            logger.info("Using existing MMSModel singleton instance")
        return cls._instance

    def __init__(self, model_card: Optional[str] = None, device=None):
        """
        Initialize the MMS model with the official pipeline.

        Args:
            model_card: Model card to use (omniASR_LLM_1B, omniASR_LLM_300M,
                etc.). If None, uses MODEL_NAME from environment variables.
            device: Device to use (torch.device object, "cuda", "cpu", etc.)
        """
        # Only initialize once
        if self._initialized:
            return

        # Get model name from environment variable with default fallback
        self.model_card = model_card or MODEL_NAME
        self.device = device

        # Load the pipeline immediately during initialization
        self._load_pipeline()

        # Mark as initialized
        self._initialized = True

    def _load_pipeline(self):
        """Load the MMS pipeline during initialization."""
        logger.info(f"Loading MMS pipeline: {self.model_card}")
        logger.info(f"Target device: {self.device}")

        # Debug FAIRSEQ2_CACHE_DIR environment variable
        # fairseq2_cache_dir = os.environ.get('FAIRSEQ2_CACHE_DIR')
        fairseq2_cache_dir = os.environ.get('FAIRSEQ2_CACHE_DIR', "./models")
        logger.info(f"DEBUG: FAIRSEQ2_CACHE_DIR = {fairseq2_cache_dir}")

        try:
            # Convert a torch.device object to its string form; pass strings
            # (or None) through unchanged
            device_str = str(self.device) if isinstance(self.device, torch.device) else self.device

            # self.pipeline = Wav2Vec2InferencePipeline(
            #     model_card=self.model_card,
            #     device=device_str
            # )
            self.pipeline = ASRInferencePipeline(
                model_card=self.model_card,
                device=device_str,
            )
            logger.info("✓ MMS pipeline loaded successfully")
        except Exception as e:
            logger.error(f"Failed to load MMS pipeline: {e}")
            raise

    def transcribe_audio(
        self,
        audio_tensor: torch.Tensor,
        batch_size: int = 1,
        language_with_scripts: Optional[List[str]] = None,
    ) -> List[Dict[str, Any]]:
        """
        Transcribe audio tensor using the MMS pipeline.
        Args:
            audio_tensor: Audio tensor (1D waveform) to transcribe
            batch_size: Batch size for processing
            language_with_scripts: List of language codes for transcription
                (3-letter ISO codes with script). If None, uses auto-detection.

        Returns:
            List of transcription results
        """
        # Pipeline is already loaded during initialization, no need to check

        # Convert tensor to bytes for the pipeline
        logger.info(f"Converting tensor (shape: {audio_tensor.shape}) to bytes")

        # Move to CPU first if on GPU
        tensor_cpu = audio_tensor.cpu() if audio_tensor.is_cuda else audio_tensor

        # Convert to bytes using wav_to_bytes with 16kHz sample rate
        audio_bytes = wav_to_bytes(tensor_cpu, sample_rate=16000, format="wav")

        logger.info(
            f"Transcribing audio tensor with batch_size={batch_size}, "
            f"language_with_scripts={language_with_scripts}"
        )

        try:
            # Use the official pipeline transcribe method with a list
            # containing the single audio bytes
            if language_with_scripts is not None:
                transcriptions = self.pipeline.transcribe(
                    [audio_bytes], batch_size=batch_size, lang=language_with_scripts
                )
            else:
                transcriptions = self.pipeline.transcribe(
                    [audio_bytes], batch_size=batch_size
                )
            logger.info("✓ Successfully transcribed audio tensor")
            return transcriptions
        except Exception as e:
            logger.error(f"Transcription failed: {e}")
            raise

    @classmethod
    def get_instance(cls, model_card: Optional[str] = None, device=None):
        """
        Get the singleton instance of MMSModel.

        Args:
            model_card: Model card to use (omniASR_LLM_1B, omniASR_LLM_300M,
                etc.). If None, uses MODEL_NAME from environment variables.
            device: Device to use (torch.device object, "cuda", "cpu", etc.)

        Returns:
            MMSModel: The singleton instance
        """
        if cls._instance is None:
            cls._instance = cls(model_card=model_card, device=device)
        return cls._instance
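

# ---------------------------------------------------------------------------
# Usage sketch (an addition, not part of the original module): shows how the
# singleton wrapper is expected to be driven. The model card name, the
# "eng_Latn" language code, and the synthetic 16 kHz waveform are illustrative
# assumptions; swap in real values for your deployment.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    # First call constructs the singleton and loads the pipeline; subsequent
    # calls return the same instance regardless of the arguments passed.
    model = MMSModel.get_instance(model_card="omniASR_LLM_300M", device="cuda")

    # One second of silence at 16 kHz stands in for real input audio.
    waveform = torch.zeros(16000)

    # Passing an explicit code skips auto-detection; "eng_Latn" is an assumed
    # example of a 3-letter ISO code with script.
    results = model.transcribe_audio(
        waveform, batch_size=1, language_with_scripts=["eng_Latn"]
    )
    print(results)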