"""Custom inference handler for HuggingFace Inference Endpoints."""

from typing import Any, Dict, List, Union

try:
    # For remote execution, imports are relative
    from .asr_modeling import ASRModel
    from .asr_pipeline import ASRPipeline
except ImportError:
    # For local execution, imports are not relative
    from asr_modeling import ASRModel  # type: ignore[no-redef]
    from asr_pipeline import ASRPipeline  # type: ignore[no-redef]


class EndpointHandler:
    """HuggingFace Inference Endpoints handler for ASR model.

    Handles model loading and inference requests for deployment on
    HuggingFace Inference Endpoints or similar services.
    """

    def __init__(self, path: str = ""):
        """Initialize the endpoint handler.

        Args:
            path: Path to model directory or HuggingFace model ID
        """
        import os

        import nltk

        # Fetch the NLTK sentence-tokenizer data needed at inference time.
        nltk.download("punkt_tab", quiet=True)

        # Reduce CUDA memory fragmentation in long-running processes.
        os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True")

        # Prepare model kwargs; let transformers handle device placement.
        model_kwargs = {
            "device_map": "auto",
            "torch_dtype": "auto",
            "low_cpu_mem_usage": True,
        }
        # Prefer FlashAttention 2 when the package is installed.
        if self._is_flash_attn_available():
            model_kwargs["attn_implementation"] = "flash_attention_2"

        # Load the model together with its tokenizer and feature extractor.
        self.model = ASRModel.from_pretrained(path, **model_kwargs)

        # Get device from model for pipeline
        self.device = next(self.model.parameters()).device

        # Instantiate the custom pipeline with the model's feature extractor and tokenizer.
        self.pipe = ASRPipeline(
            model=self.model,
            feature_extractor=self.model.feature_extractor,
            tokenizer=self.model.tokenizer,
            device=self.device,
        )

    def _is_flash_attn_available(self) -> bool:
        """Check if flash attention is available."""
        import importlib.util

        return importlib.util.find_spec("flash_attn") is not None

    def __call__(self, data: Dict[str, Any]) -> Union[Dict[str, Any], List[Dict[str, Any]]]:
        """Process an inference request.

        Args:
            data: Request data containing 'inputs' (audio path/bytes) and optional 'parameters'

        Returns:
            Transcription result with 'text' key
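
        Example payload (the 'max_new_tokens' key is illustrative; any
        parameters accepted by the pipeline may be passed):
            {"inputs": "sample.wav", "parameters": {"max_new_tokens": 128}}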
        """
        inputs = data.get("inputs")
        if inputs is None:
            raise ValueError("Missing 'inputs' in request data")

        # Pass through any parameters from the request; the model config
        # supplies defaults for anything omitted.
        params = data.get("parameters", {})

        return self.pipe(inputs, **params)
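

if __name__ == "__main__":
    # Minimal local smoke test; a sketch assuming a model directory and an
    # audio file exist at these illustrative paths.
    handler = EndpointHandler(path="./model")
    result = handler({"inputs": "sample.wav"})
    print(result)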