#@title fix import and path /content/omniasr-transcriptions/server/inference/mms_model_pipeline.py
# %%writefile /content/omniasr-transcriptions/server/inference/mms_model_pipeline.py
"""
Pipeline-based MMS Model using the official MMS library.
This implementation uses Wav2Vec2LlamaInferencePipeline to avoid Seq2SeqBatch complexity.
"""
import logging
import os
import torch
from typing import List, Dict, Any, Optional
# from omnilingual_asr.models.inference.pipeline import Wav2Vec2InferencePipeline
from omnilingual_asr.models.inference.pipeline import ASRInferencePipeline
from omnilingual_asr.models.wav2vec2_llama.lang_ids import supported_langs
from inference.audio_reading_tools import wav_to_bytes
from env_vars import MODEL_NAME
logger = logging.getLogger(__name__)
class MMSModel:
    """Pipeline-based MMS model wrapper using the official inference pipeline."""

    _instance = None
    _initialized = False

    def __new__(cls, *args, **kwargs):
        if cls._instance is None:
            logger.info("Creating new MMSModel singleton instance")
            cls._instance = super().__new__(cls)
        else:
            logger.info("Using existing MMSModel singleton instance")
        return cls._instance

    def __init__(self, model_card: Optional[str] = None, device: Any = None):
        """
        Initialize the MMS model with the official pipeline.

        Args:
            model_card: Model card to use (omniASR_LLM_1B, omniASR_LLM_300M, etc.).
                If None, uses MODEL_NAME from the environment variables.
            device: Device to use (torch.device object, "cuda", "cpu", etc.).
        """
        # __init__ runs on every MMSModel() call, so only initialize once.
        if self._initialized:
            return
        # Fall back to the MODEL_NAME environment variable when no card is given.
        self.model_card = model_card or MODEL_NAME
        self.device = device
        # Load the pipeline eagerly so the first transcription request is not delayed.
        self._load_pipeline()
        self._initialized = True

    def _load_pipeline(self):
        """Load the MMS pipeline during initialization."""
        logger.info(f"Loading MMS pipeline: {self.model_card}")
        logger.info(f"Target device: {self.device}")
        # fairseq2 resolves model assets via FAIRSEQ2_CACHE_DIR; default to ./models.
        fairseq2_cache_dir = os.environ.get("FAIRSEQ2_CACHE_DIR", "./models")
        logger.info(f"DEBUG: FAIRSEQ2_CACHE_DIR = {fairseq2_cache_dir}")
        try:
            # The pipeline expects the device as a string, so normalize
            # torch.device objects and plain strings alike.
            device_str = str(self.device)
            # self.pipeline = Wav2Vec2InferencePipeline(
            #     model_card=self.model_card,
            #     device=device_str,
            # )
            self.pipeline = ASRInferencePipeline(
                model_card=self.model_card,
                device=device_str,
            )
            logger.info("✓ MMS pipeline loaded successfully")
        except Exception as e:
            logger.error(f"Failed to load MMS pipeline: {e}")
            raise

    def transcribe_audio(
        self,
        audio_tensor: torch.Tensor,
        batch_size: int = 1,
        language_with_scripts: Optional[List[str]] = None,
    ) -> List[Dict[str, Any]]:
        """
        Transcribe an audio tensor using the MMS pipeline.

        Args:
            audio_tensor: Audio tensor (1D waveform) to transcribe.
            batch_size: Batch size for processing.
            language_with_scripts: List of language codes for transcription
                (3-letter ISO codes with script). If None, the pipeline
                auto-detects the language.

        Returns:
            List of transcription results.
        """
        # The pipeline is loaded during initialization, so no lazy-load check is needed.
        logger.info(f"Converting tensor (shape: {audio_tensor.shape}) to bytes")
        # Move the tensor to the CPU first if it lives on the GPU.
        tensor_cpu = audio_tensor.cpu() if audio_tensor.is_cuda else audio_tensor
        # Serialize the waveform to WAV bytes at a 16 kHz sample rate.
        audio_bytes = wav_to_bytes(tensor_cpu, sample_rate=16000, format="wav")
        logger.info(
            f"Transcribing audio tensor with batch_size={batch_size}, "
            f"language_with_scripts={language_with_scripts}"
        )
        try:
            # The pipeline consumes a list of audio byte buffers; wrap the single clip.
            if language_with_scripts is not None:
                transcriptions = self.pipeline.transcribe(
                    [audio_bytes], batch_size=batch_size, lang=language_with_scripts
                )
            else:
                transcriptions = self.pipeline.transcribe([audio_bytes], batch_size=batch_size)
            logger.info("✓ Successfully transcribed audio tensor")
            return transcriptions
        except Exception as e:
            logger.error(f"Transcription failed: {e}")
            raise

    @classmethod
    def get_instance(cls, model_card: Optional[str] = None, device: Any = None):
        """
        Get the singleton instance of MMSModel.

        Args:
            model_card: Model card to use (omniASR_LLM_1B, omniASR_LLM_300M, etc.).
                If None, uses MODEL_NAME from the environment variables.
            device: Device to use (torch.device object, "cuda", "cpu", etc.).
                Both arguments are ignored once the singleton already exists.

        Returns:
            MMSModel: The singleton instance.
        """
        if cls._instance is None:
            cls._instance = cls(model_card=model_card, device=device)
        return cls._instance
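

# A minimal usage sketch, not part of the server code path. It assumes the
# omnilingual_asr package and the model weights are available locally and that
# wav_to_bytes accepts a 16 kHz mono tensor. The one-second sine wave below is
# placeholder audio, so the transcription text is meaningless; the block only
# demonstrates the singleton access pattern and the transcribe_audio call.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    # Both calls return the same object thanks to __new__/get_instance.
    model = MMSModel.get_instance(device="cpu")
    assert model is MMSModel.get_instance()

    # One second of a 440 Hz tone sampled at 16 kHz as stand-in audio.
    t = torch.linspace(0, 1, 16000)
    dummy_audio = 0.1 * torch.sin(2 * torch.pi * 440.0 * t)

    # "eng_Latn" is an assumed language-with-script code; consult the imported
    # supported_langs list for the identifiers your installed version accepts.
    results = model.transcribe_audio(dummy_audio, language_with_scripts=["eng_Latn"])
    print(results)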