from typing import Any, Dict, List, Optional, Tuple, Union

import numpy as np
import torch
from PIL import Image
from transformers import Pipeline

from custom_st import Transformer


class EmbeddingPipeline(Pipeline):
    """Pipeline for generating embeddings using a custom transformer model.

    The wrapped model is expected to expose ``tokenize(inputs)`` and
    ``forward(features, task=..., truncate_dim=...)``, and its forward output
    must provide a ``"sentence_embedding"`` entry.
    """

    def __init__(self, model, **kwargs):
        """Create the pipeline around ``model``.

        Args:
            model: The embedding model to wrap.
            **kwargs: Forwarded unchanged to ``Pipeline.__init__``.
        """
        super().__init__(model=model, **kwargs)
        # Task used by ``_forward`` when the caller does not specify one.
        self.default_task = "retrieval"

    def _sanitize_parameters(
        self,
        task: Optional[str] = None,
        truncate_dim: Optional[int] = None,
        **kwargs,
    ) -> Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]:
        """Split call-time keyword arguments into per-stage parameter dicts.

        Only ``task`` and ``truncate_dim`` are recognised, and both are routed
        to the forward stage. Any other keyword argument is ignored.

        Returns:
            A ``(preprocess_params, forward_params, postprocess_params)``
            tuple; the first and last dicts are always empty.
        """
        preprocess_params: Dict[str, Any] = {}
        forward_params: Dict[str, Any] = {}
        postprocess_params: Dict[str, Any] = {}

        if task is not None:
            forward_params["task"] = task
        if truncate_dim is not None:
            forward_params["truncate_dim"] = truncate_dim

        return preprocess_params, forward_params, postprocess_params

    def preprocess(self, inputs, **preprocess_params):
        """Turn raw inputs into model features via ``self.model.tokenize``.

        Args:
            inputs: A single input or a ``list`` of inputs. Anything that is
                not a ``list`` is wrapped into a one-element list.
            **preprocess_params: Accepted for interface compatibility; unused.

        Returns:
            Whatever ``self.model.tokenize`` returns for the list of inputs.
        """
        if not isinstance(inputs, list):
            inputs = [inputs]

        return self.model.tokenize(inputs)

    def _forward(self, features, task=None, truncate_dim=None):
        """Run the model on the prepared features.

        Args:
            features: Output of :meth:`preprocess`.
            task: Task name passed to the model; falls back to
                ``self.default_task`` when ``None``.
            truncate_dim: Passed through to the model unchanged.

        Returns:
            The model's forward output.
        """
        if task is None:
            task = self.default_task

        return self.model.forward(features, task=task, truncate_dim=truncate_dim)

    def postprocess(self, model_outputs, **postprocess_params) -> np.ndarray:
        """Convert model outputs to a NumPy array of embeddings.

        Args:
            model_outputs: Mapping returned by :meth:`_forward`; must contain
                a ``"sentence_embedding"`` tensor.
            **postprocess_params: Accepted for interface compatibility; unused.

        Returns:
            The embeddings as a NumPy array.

        Raises:
            ValueError: If ``"sentence_embedding"`` is missing or ``None``.
        """
        embeddings = model_outputs.get("sentence_embedding")
        if embeddings is None:
            raise ValueError("No embeddings were generated")

        # ``Tensor.numpy()`` refuses tensors that still require grad, so drop
        # any autograd history before converting.
        embeddings = embeddings.detach().cpu()

        # NumPy has no bfloat16 dtype; upcast so the conversion cannot fail.
        if embeddings.dtype == torch.bfloat16:
            embeddings = embeddings.float()

        return embeddings.numpy()


def load_pipeline(model_path: str, device: Optional[str] = None) -> EmbeddingPipeline:
    """Load the embedding pipeline from a model path.

    Args:
        model_path: Path to the model directory.
        device: Device to use for inference (cpu, cuda, etc.). When ``None``,
            ``"cuda"`` is used if available, otherwise ``"cpu"``.

    Returns:
        EmbeddingPipeline instance wrapping the loaded model.
    """
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"

    # NOTE(review): trust_remote_code=True presumably lets the loader execute
    # Python code shipped with the model -- only use with trusted model paths.
    model = Transformer(
        model_name_or_path=model_path,
        model_args={"default_task": "retrieval", "trust_remote_code": True},
        trust_remote_code=True,
    )
    model.to(device)
    model.eval()

    pipeline = EmbeddingPipeline(model=model, device=device)
    return pipeline