#@title fix import and path /content/omniasr-transcriptions/server/inference/mms_model_pipeline.py
# %%writefile /content/omniasr-transcriptions/server/inference/mms_model_pipeline.py
"""

Pipeline-based MMS Model using the official MMS library.

This implementation uses Wav2Vec2LlamaInferencePipeline to avoid Seq2SeqBatch complexity.

"""

import logging
import os
import torch
from typing import List, Dict, Any, Optional
from omnilingual_asr.models.inference.pipeline import ASRInferencePipeline

from inference.audio_reading_tools import wav_to_bytes
from env_vars import MODEL_NAME

logger = logging.getLogger(__name__)


class MMSModel:
    """Pipeline-based MMS model wrapper using the official inference pipeline."""
    _instance = None
    _initialized = False

    def __new__(cls, *args, **kwargs):
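        # Enforce the singleton pattern: construct the instance once, then reuse it.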
        if cls._instance is None:
            logger.info("Creating new MMSModel singleton instance")
            cls._instance = super().__new__(cls)
        else:
            logger.info("Using existing MMSModel singleton instance")
        return cls._instance

    def __init__(self, model_card: Optional[str] = None, device=None):
        """

        Initialize the MMS model with the official pipeline.



        Args:

            model_card: Model card to use (omniASR_LLM_1B, omniASR_LLM_300M, etc.)

                       If None, uses MODEL_NAME from environment variables

            device: Device to use (torch.device object, "cuda", "cpu", etc.)

        """
        # Only initialize once
        if self._initialized:
            return

        # Get model name from environment variable with default fallback
        self.model_card = model_card or MODEL_NAME
        self.device = device

        # Load the pipeline immediately during initialization
        self._load_pipeline()

        # Mark as initialized
        self._initialized = True

    def _load_pipeline(self):
        """Load the MMS pipeline during initialization."""
        logger.info(f"Loading MMS pipeline: {self.model_card}")
        logger.info(f"Target device: {self.device}")

        # Log the fairseq2 cache directory, falling back to ./models when unset
        fairseq2_cache_dir = os.environ.get("FAIRSEQ2_CACHE_DIR", "./models")
        logger.info(f"FAIRSEQ2_CACHE_DIR = {fairseq2_cache_dir}")

        try:
            # Normalize the device to a string (e.g. "cuda", "cpu"); leave None as-is
            device_str = str(self.device) if self.device is not None else None
            self.pipeline = ASRInferencePipeline(
                model_card=self.model_card,
                device=device_str
            )
            logger.info("✓ MMS pipeline loaded successfully")
        except Exception as e:
            logger.error(f"Failed to load MMS pipeline: {e}")
            raise

    def transcribe_audio(self, audio_tensor: torch.Tensor, batch_size: int = 1, language_with_scripts: Optional[List[str]] = None) -> List[Dict[str, Any]]:
        """

        Transcribe audio tensor using the MMS pipeline.



        Args:

            audio_tensor: Audio tensor (1D waveform) to transcribe

            batch_size: Batch size for processing

            language_with_scripts: List of language_with_scripts codes for transcription (3-letter ISO codes with script)

                 If None, uses auto-detection



        Returns:

            List of transcription results

        """
        # Pipeline is already loaded during initialization, no need to check

        # Convert tensor to bytes for the pipeline
        logger.info(f"Converting tensor (shape: {audio_tensor.shape}) to bytes")
        # Move to CPU first if on GPU
        tensor_cpu = audio_tensor.cpu() if audio_tensor.is_cuda else audio_tensor
        # Convert to bytes using wav_to_bytes with 16kHz sample rate
        audio_bytes = wav_to_bytes(tensor_cpu, sample_rate=16000, format="wav")

        logger.info(f"Transcribing audio tensor with batch_size={batch_size}, language_with_scripts={language_with_scripts}")

        try:
            # Use the official pipeline transcribe method with a list containing the single audio bytes
            if language_with_scripts is not None:
                transcriptions = self.pipeline.transcribe([audio_bytes], batch_size=batch_size, lang=language_with_scripts)
            else:
                transcriptions = self.pipeline.transcribe([audio_bytes], batch_size=batch_size)

            logger.info(f"✓ Successfully transcribed audio tensor")
            return transcriptions

        except Exception as e:
            logger.error(f"Transcription failed: {e}")
            raise

    @classmethod
    def get_instance(cls, model_card: Optional[str] = None, device=None):
        """

        Get the singleton instance of MMSModel.



        Args:

            model_card: Model card to use (omniASR_LLM_1B, omniASR_LLM_300M, etc.)

                       If None, uses MODEL_NAME from environment variables

            device: Device to use (torch.device object, "cuda", "cpu", etc.)



        Returns:

            MMSModel: The singleton instance

        """
        if cls._instance is None:
            cls._instance = cls(model_card=model_card, device=device)
        return cls._instance
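

# --- Usage sketch (illustrative; not part of the module API) ---
# A minimal example of how this wrapper might be exercised. The input path
# "sample.wav" and the use of torchaudio for loading/resampling are
# assumptions for illustration; the model card and language code are
# examples only -- consult the supported model cards and lang_ids for
# valid values in your installation.
if __name__ == "__main__":
    import torchaudio

    logging.basicConfig(level=logging.INFO)

    # Reuse the singleton; repeated calls return the same loaded pipeline
    model = MMSModel.get_instance(model_card="omniASR_LLM_300M", device="cpu")

    # Load a local file and convert it to the 1D 16 kHz waveform the wrapper expects
    waveform, sample_rate = torchaudio.load("sample.wav")  # hypothetical input file
    if sample_rate != 16000:
        waveform = torchaudio.functional.resample(waveform, sample_rate, 16000)
    mono = waveform.mean(dim=0)  # average channels down to a 1D tensor

    results = model.transcribe_audio(mono, batch_size=1, language_with_scripts=["eng_Latn"])
    print(results)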