| import abc | |
| import warnings | |
| from pathlib import Path | |
| from typing import List, Union | |
| import torch | |
| from numpy.typing import NDArray | |
| from sentence_transformers import SentenceTransformer | |
class SentenceTransformerModels:
    """Thin wrapper around a ``SentenceTransformer`` model for sentence encoding.

    Instances expose two attributes:

    * ``device`` -- the device handed to ``SentenceTransformer.encode``.
    * ``model`` -- the loaded ``SentenceTransformer``, switched to eval mode.
    """

    def __init__(self, model_id, device: Union[str, torch.device, None] = None):
        """Load the model and pick the device used for encoding.

        Args:
            model_id: Identifier passed straight to ``SentenceTransformer``
                (presumably a hub model name or a local path -- confirm
                against callers).
            device: Device to encode on, e.g. ``"cpu"``, ``"cuda"`` or a
                ``torch.device``. Any falsy value (``None``, or the legacy
                ``False``) selects CUDA when available and CPU otherwise.
        """
        # `or` rather than `is None`: older callers may still pass False,
        # which was the previous default, to request auto-detection.
        self.device = device or torch.device(
            "cuda" if torch.cuda.is_available() else "cpu"
        )
        self.model = SentenceTransformer(model_id).eval()

    def encode(self, sentences: List[str], batch_size: int = 32) -> NDArray:
        """Encode ``sentences`` with the wrapped model.

        Args:
            sentences: Sentences to embed.
            batch_size: Batch size forwarded to the model's ``encode``.

        Returns:
            Whatever the model's ``encode`` returns; a ``torch.Tensor`` result
            is moved to the CPU and converted to a NumPy array first.
        """
        with torch.no_grad():
            embeddings = self.model.encode(
                sentences, batch_size=batch_size, device=self.device
            )
        if isinstance(embeddings, torch.Tensor):
            # Normalise tensor output so callers get NumPy regardless of device.
            return embeddings.cpu().numpy()
        return embeddings