| import os | |
| from pathlib import Path | |
| from typing import Any, BinaryIO, Mapping, Optional, Union | |
| import torch | |
| from config import default_config | |
| from featex import load_audio, Preprocessor | |
| from model import Classifier | |
class Pipeline:
    """Inference wrapper around ``Classifier`` and its ``Preprocessor``.

    On construction the model is built from ``config``, its weights are
    loaded from ``checkpoint``, and it is moved to ``device`` and put in
    eval mode.  The ``run_on_*`` methods then go from a file, an audio
    tensor, or pre-computed features to model scores.
    """

    # File name of the checkpoint looked up next to this module when the
    # caller does not pass an explicit ``checkpoint``.
    DEFAULT_CHECKPOINT_NAME = "dam3.1.ckpt"

    def __init__(
        self,
        checkpoint: Optional[Union[str, Path]] = None,
        config: Optional[Mapping[str, Any]] = None,
        device: Optional[torch.device] = None,
    ):
        """Build the model, load its weights and prepare it for inference.

        Args:
            checkpoint: Path to a state-dict checkpoint.  Defaults to
                ``DEFAULT_CHECKPOINT_NAME`` in this module's directory.
            config: Keyword arguments forwarded to ``Classifier``.
                Defaults to ``default_config``.
            device: Device to run on.  Defaults to ``cuda:0`` when CUDA is
                available, otherwise the CPU.
        """
        if checkpoint is None:
            checkpoint = (
                Path(__file__).parent.resolve() / self.DEFAULT_CHECKPOINT_NAME
            )
        if config is None:
            config = default_config
        if device is None:
            device = torch.device(
                "cuda:0" if torch.cuda.is_available() else "cpu"
            )

        self.device = device
        self.model = Classifier(**config)
        # The preprocessor is configured from the model so both agree on
        # the feature-extraction settings.
        self.preprocessor = Preprocessor(**self.model.preprocessor_config)

        # NOTE(review): torch.load unpickles the file, so only trusted
        # checkpoints should be passed here.  Consider weights_only=True
        # if the minimum supported torch version provides it.
        state_dict = torch.load(checkpoint, map_location=device)
        self.model.load_state_dict(state_dict)
        self.model.to(self.device)
        self.model.eval()

    def run_on_features(self, features: torch.Tensor, quantize: bool = True):
        """Score pre-computed features.

        Args:
            features: Feature tensor; it is moved to ``self.device`` if it
                is not already there.  The size of its first dimension is
                passed to the model as the length argument (presumably the
                number of frames -- confirm against ``Classifier``).
            quantize: When true, return the output of
                ``model.quantize_scores`` with every value converted to a
                Python ``int``; otherwise return the raw model scores.

        Returns:
            A ``dict`` of integer scores when ``quantize`` is true, else
            the first element of the model output.  Scores are computed
            without gradient tracking.
        """
        # Callers may hand in features living on another device; make this
        # entry point as forgiving as ``run_on_audio``.
        features = features.to(self.device)
        lengths = torch.tensor([features.shape[0]], device=self.device)

        # Pure inference: skip building the autograd graph.
        with torch.no_grad():
            scores = self.model(features, lengths)[0]
            if not quantize:
                return scores
            quantized = self.model.quantize_scores(scores)

        return {name: int(value.item()) for name, value in quantized.items()}

    def run_on_audio(self, audio: torch.Tensor, quantize: bool = True):
        """Preprocess an audio tensor (with normalization) and score it.

        Args:
            audio: Audio tensor accepted by
                ``Preprocessor.preprocess_with_audio_normalization``.
            quantize: Forwarded to ``run_on_features``.

        Returns:
            Whatever ``run_on_features`` returns for the extracted features.
        """
        features = self.preprocessor.preprocess_with_audio_normalization(audio)
        return self.run_on_features(features.to(self.device), quantize=quantize)

    def run_on_file(
        self,
        source: Union[BinaryIO, str, os.PathLike],
        quantize: bool = True,
    ):
        """Load audio from a path or binary file object and score it.

        Args:
            source: Path or open binary stream handed to ``load_audio``.
            quantize: Forwarded to ``run_on_audio``.

        Returns:
            Whatever ``run_on_audio`` returns for the loaded audio.
        """
        audio = load_audio(source)
        return self.run_on_audio(audio, quantize=quantize)