"""Custom inference handler for HuggingFace Inference Endpoints."""
from typing import Any, Dict, List, Union
try:
# For remote execution, imports are relative
from .asr_modeling import ASRModel
from .asr_pipeline import ASRPipeline
except ImportError:
# For local execution, imports are not relative
from asr_modeling import ASRModel # type: ignore[no-redef]
from asr_pipeline import ASRPipeline # type: ignore[no-redef]
class EndpointHandler:
    """HuggingFace Inference Endpoints handler for ASR model.

    Handles model loading, warmup, and inference requests for deployment
    on HuggingFace Inference Endpoints or similar services.
    """

    def __init__(self, path: str = ""):
        """Initialize the endpoint handler.

        Args:
            path: Path to model directory or HuggingFace model ID
        """
        import os

        import nltk

        # Download NLTK Punkt sentence tokenizer data
        nltk.download("punkt_tab", quiet=True)

        # Let the CUDA caching allocator grow segments, reducing fragmentation
        os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True")

        # Prepare model kwargs; let transformers handle device placement
        model_kwargs = {
            "device_map": "auto",
            "torch_dtype": "auto",
            "low_cpu_mem_usage": True,
        }

        # Use FlashAttention 2 when the package is installed
        if self._is_flash_attn_available():
            model_kwargs["attn_implementation"] = "flash_attention_2"

        # Load the model (this also loads the tokenizer and feature extractor)
        self.model = ASRModel.from_pretrained(path, **model_kwargs)

        # Get the device from the model for the pipeline
        self.device = next(self.model.parameters()).device

        # Instantiate the custom pipeline with the feature extractor and tokenizer taken from the model
        self.pipe = ASRPipeline(
            model=self.model,
            feature_extractor=self.model.feature_extractor,
            tokenizer=self.model.tokenizer,
            device=self.device,
        )
    def _is_flash_attn_available(self) -> bool:
        """Check if flash attention is available."""
        import importlib.util

        return importlib.util.find_spec("flash_attn") is not None
    def __call__(self, data: Dict[str, Any]) -> Union[Dict[str, Any], List[Dict[str, Any]]]:
        """Process an inference request.

        Args:
            data: Request data containing 'inputs' (audio path/bytes) and optional 'parameters'

        Returns:
            Transcription result with 'text' key
        """
        inputs = data.get("inputs")
        if inputs is None:
            raise ValueError("Missing 'inputs' in request data")

        # Pass through any parameters from the request; the model config provides defaults
        params = data.get("parameters", {})
        return self.pipe(inputs, **params)
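

# A minimal local smoke test: an illustrative sketch, not part of the Inference
# Endpoints contract. The checkpoint path "." and the audio file "sample.wav"
# are placeholder assumptions; substitute real ones before running.
if __name__ == "__main__":
    handler = EndpointHandler(path=".")
    with open("sample.wav", "rb") as f:
        audio_bytes = f.read()
    result = handler({"inputs": audio_bytes})
    print(result)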