|
|
from typing import List, Union, Dict, Any, Optional
|
|
|
import torch
|
|
|
import numpy as np
|
|
|
from PIL import Image
|
|
|
from transformers import Pipeline
|
|
|
from custom_st import Transformer
|
|
|
|
|
|
class EmbeddingPipeline(Pipeline):
    """
    Pipeline for generating embeddings using custom transformer model.

    Call-time keyword arguments (routed by ``_sanitize_parameters``):
        task: task name forwarded to ``model.forward``; when omitted the
            pipeline's ``default_task`` is used.
        truncate_dim: forwarded unchanged to ``model.forward``.
    """

    def __init__(self, model, *, default_task: str = "retrieval", **kwargs):
        """
        Args:
            model: Embedding model. This class calls ``model.tokenize(list)``
                and ``model.forward(features, task=..., truncate_dim=...)``.
            default_task: Task used when a call does not pass ``task``.
                Defaults to ``"retrieval"``, the previously hard-coded value.
            **kwargs: Forwarded to ``transformers.Pipeline.__init__``
                (e.g. ``device``).
        """
        # Set before the base initializer runs so the attribute exists no
        # matter which hooks the base class invokes during construction.
        self.default_task = default_task
        super().__init__(model=model, **kwargs)

    def _sanitize_parameters(self, task=None, truncate_dim=None, **kwargs):
        """
        Split call-time kwargs into (preprocess, forward, postprocess) dicts.

        Only ``task`` and ``truncate_dim`` are recognised; both go to
        ``_forward`` and only when explicitly provided. Any other keyword is
        ignored.
        """
        preprocess_params = {}
        forward_params = {}
        postprocess_params = {}

        if task is not None:
            forward_params["task"] = task
        if truncate_dim is not None:
            forward_params["truncate_dim"] = truncate_dim

        return preprocess_params, forward_params, postprocess_params

    def preprocess(self, inputs, **preprocess_params):
        """
        Preprocess the inputs before passing to model.

        A non-list input is wrapped in a one-element list, because
        ``model.tokenize`` is called with a list. The tokenized features are
        returned as produced by the model.
        """
        if not isinstance(inputs, list):
            inputs = [inputs]
        return self.model.tokenize(inputs)

    def _forward(self, features, task=None, truncate_dim=None):
        """
        Forward pass through the model.

        Falls back to ``self.default_task`` when no task was requested.
        ``truncate_dim`` is passed through as-is, including ``None``.
        """
        if task is None:
            task = self.default_task
        return self.model.forward(features, task=task, truncate_dim=truncate_dim)

    def postprocess(self, model_outputs, **postprocess_params):
        """
        Convert model outputs to final embeddings.

        Returns:
            A NumPy array built from ``model_outputs["sentence_embedding"]``.
            bfloat16 tensors are upcast to float32; other dtypes are kept.

        Raises:
            ValueError: if the outputs contain no ``"sentence_embedding"``.
        """
        embeddings = model_outputs.get("sentence_embedding", None)
        if embeddings is None:
            raise ValueError("No embeddings were generated")

        # detach() keeps .numpy() valid even if autograd tracking is on.
        embeddings = embeddings.detach().cpu()
        # NumPy has no bfloat16 dtype: Tensor.numpy() raises TypeError on bf16
        # tensors, so upcast those to float32 before converting.
        if embeddings.dtype == torch.bfloat16:
            embeddings = embeddings.float()
        return embeddings.numpy()
|
|
|
|
|
|
def load_pipeline(
    model_path: str,
    device: Optional[str] = None,
    default_task: str = "retrieval",
):
    """
    Load the embedding pipeline from a model path.

    Args:
        model_path: Path to the model directory.
        device: Device to use for inference (cpu, cuda, etc.). When None,
            "cuda" is chosen if ``torch.cuda.is_available()``, else "cpu".
        default_task: Task given to the model as ``model_args["default_task"]``
            and used by the pipeline when a call does not pass ``task``.
            Defaults to "retrieval", the previously hard-coded value.

    Returns:
        EmbeddingPipeline instance
    """
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"

    # NOTE(review): trust_remote_code presumably lets the model directory ship
    # executable Python (transformers semantics); only load trusted model_path.
    model = Transformer(
        model_name_or_path=model_path,
        model_args={"default_task": default_task, "trust_remote_code": True},
        trust_remote_code=True,
    )
    model.to(device)
    model.eval()

    pipeline = EmbeddingPipeline(model=model, device=device)
    # Assign explicitly so the pipeline's fallback task always matches the
    # task configured on the model, whatever the pipeline's own default is.
    pipeline.default_task = default_task
    return pipeline
|
|
|
|