#@title fix import and path /content/omniasr-transcriptions/server/inference/mms_model_pipeline.py
# %%writefile /content/omniasr-transcriptions/server/inference/mms_model_pipeline.py
"""
Pipeline-based MMS Model using the official MMS library.
This implementation uses Wav2Vec2LlamaInferencePipeline to avoid Seq2SeqBatch complexity.
"""
import logging
import os
import torch
from typing import List, Dict, Any, Optional
# from omnilingual_asr.models.inference.pipeline import Wav2Vec2InferencePipeline
from omnilingual_asr.models.inference.pipeline import ASRInferencePipeline
from omnilingual_asr.models.wav2vec2_llama.lang_ids import supported_langs
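# supported_langs enumerates the language codes the pipeline accepts
# (3-letter ISO codes with script). A caller could validate a request up
# front before invoking the pipeline, e.g.:
#   if lang not in supported_langs:
#       raise ValueError(f"Unsupported language: {lang}")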
from inference.audio_reading_tools import wav_to_bytes
from env_vars import MODEL_NAME
logger = logging.getLogger(__name__)
class MMSModel:
"""Pipeline-based MMS model wrapper using the official inference pipeline."""
_instance = None
_initialized = False
def __new__(cls, *args, **kwargs):
if cls._instance is None:
logger.info("Creating new MMSModel singleton instance")
cls._instance = super().__new__(cls)
else:
logger.info("Using existing MMSModel singleton instance")
return cls._instance
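    # Note: __new__ always returns the shared instance, but Python still
    # invokes __init__ on every MMSModel(...) call; the _initialized flag
    # checked below turns repeated initialization into a no-op.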
    def __init__(self, model_card: Optional[str] = None, device=None):
"""
Initialize the MMS model with the official pipeline.
Args:
model_card: Model card to use (omniASR_LLM_1B, omniASR_LLM_300M, etc.)
If None, uses MODEL_NAME from environment variables
device: Device to use (torch.device object, "cuda", "cpu", etc.)
"""
# Only initialize once
if self._initialized:
return
# Get model name from environment variable with default fallback
self.model_card = model_card or MODEL_NAME
self.device = device
# Load the pipeline immediately during initialization
self._load_pipeline()
# Mark as initialized
self._initialized = True
def _load_pipeline(self):
"""Load the MMS pipeline during initialization."""
logger.info(f"Loading MMS pipeline: {self.model_card}")
logger.info(f"Target device: {self.device}")
        # Resolve FAIRSEQ2_CACHE_DIR, defaulting to a local ./models directory
        # fairseq2_cache_dir = os.environ.get('FAIRSEQ2_CACHE_DIR')
        fairseq2_cache_dir = os.environ.get('FAIRSEQ2_CACHE_DIR', './models')
logger.info(f"DEBUG: FAIRSEQ2_CACHE_DIR = {fairseq2_cache_dir}")
try:
            # Normalize the device to a string so the pipeline accepts both
            # torch.device objects and plain strings like "cuda" or "cpu"
            device_str = str(self.device)
# self.pipeline = Wav2Vec2InferencePipeline(
# model_card=self.model_card,
# device=device_str
# )
self.pipeline = ASRInferencePipeline(
model_card=self.model_card,
device=device_str
)
            logger.info("✓ MMS pipeline loaded successfully")
except Exception as e:
logger.error(f"Failed to load MMS pipeline: {e}")
raise
    def transcribe_audio(self, audio_tensor: torch.Tensor, batch_size: int = 1, language_with_scripts: Optional[List[str]] = None) -> List[Dict[str, Any]]:
"""
Transcribe audio tensor using the MMS pipeline.
Args:
            audio_tensor: Audio tensor (1D waveform, 16 kHz) to transcribe
            batch_size: Batch size for processing
            language_with_scripts: Language codes to transcribe with
                (3-letter ISO codes with script; see supported_langs for
                valid values). If None, the pipeline auto-detects the language.
Returns:
List of transcription results
"""
# Pipeline is already loaded during initialization, no need to check
# Convert tensor to bytes for the pipeline
logger.info(f"Converting tensor (shape: {audio_tensor.shape}) to bytes")
# Move to CPU first if on GPU
tensor_cpu = audio_tensor.cpu() if audio_tensor.is_cuda else audio_tensor
# Convert to bytes using wav_to_bytes with 16kHz sample rate
audio_bytes = wav_to_bytes(tensor_cpu, sample_rate=16000, format="wav")
logger.info(f"Transcribing audio tensor with batch_size={batch_size}, language_with_scripts={language_with_scripts}")
try:
# Use the official pipeline transcribe method with a list containing the single audio bytes
if language_with_scripts is not None:
transcriptions = self.pipeline.transcribe([audio_bytes], batch_size=batch_size, lang=language_with_scripts)
else:
transcriptions = self.pipeline.transcribe([audio_bytes], batch_size=batch_size)
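            # The pipeline returns one result per input, so for this
            # single-tensor call the caller typically reads transcriptions[0].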
            logger.info("✓ Successfully transcribed audio tensor")
return transcriptions
except Exception as e:
logger.error(f"Transcription failed: {e}")
raise
@classmethod
    def get_instance(cls, model_card: Optional[str] = None, device=None):
"""
Get the singleton instance of MMSModel.
Args:
model_card: Model card to use (omniASR_LLM_1B, omniASR_LLM_300M, etc.)
If None, uses MODEL_NAME from environment variables
device: Device to use (torch.device object, "cuda", "cpu", etc.)
Returns:
MMSModel: The singleton instance
"""
if cls._instance is None:
cls._instance = cls(model_card=model_card, device=device)
return cls._instance
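

# --- Usage sketch ---
# A minimal, hedged example of driving the singleton end to end. It assumes
# the configured model card's weights are available (downloaded or cached
# under FAIRSEQ2_CACHE_DIR) and that the input is a 16 kHz mono waveform.
# The synthetic sine wave below is only a placeholder for real speech, so
# the resulting transcription will be meaningless; "eng_Latn" is an
# illustrative language code.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = MMSModel.get_instance(device=device)

    # One second of 16 kHz audio as a stand-in waveform.
    sample_rate = 16000
    t = torch.linspace(0, 1, sample_rate)
    waveform = 0.1 * torch.sin(2 * torch.pi * 440.0 * t)

    results = model.transcribe_audio(
        waveform,
        batch_size=1,
        language_with_scripts=["eng_Latn"],
    )
    print(results)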