# IRMSEmbeddingsV4 / pipeline.py
# Krishna Indukuri
# Upload 31 files
# 22fcf31 verified
from typing import List, Union, Dict, Any, Optional
import torch
import numpy as np
from PIL import Image
from transformers import Pipeline
from custom_st import Transformer
class EmbeddingPipeline(Pipeline):
    """
    Pipeline for generating embeddings using custom transformer model.

    Inputs are tokenized with ``model.tokenize``, passed to ``model.forward``
    together with the optional ``task`` and ``truncate_dim`` call-time
    parameters, and the ``"sentence_embedding"`` entry of the model output is
    returned as a NumPy array.
    """

    def __init__(self, model, **kwargs):
        """
        Create the pipeline around ``model``.

        Args:
            model: Model object; this class calls its ``tokenize`` and
                ``forward`` methods.
            **kwargs: Forwarded unchanged to ``Pipeline.__init__``.
        """
        super().__init__(model=model, **kwargs)
        # Task used by ``_forward`` when the caller does not pass one.
        self.default_task = "retrieval"

    def _sanitize_parameters(self, task=None, truncate_dim=None, **kwargs):
        """
        Split call-time keyword arguments into per-stage parameter dicts.

        Only ``task`` and ``truncate_dim`` are recognised, and both are routed
        to the forward stage. Any other keyword argument is accepted and
        ignored.

        Returns:
            Tuple ``(preprocess_params, forward_params, postprocess_params)``;
            the first and last are always empty.
        """
        preprocess_params = {}
        postprocess_params = {}
        # Options left at None are omitted so ``_forward`` applies its defaults.
        forward_params = {
            name: value
            for name, value in (("task", task), ("truncate_dim", truncate_dim))
            if value is not None
        }
        return preprocess_params, forward_params, postprocess_params

    def preprocess(self, inputs, **preprocess_params):
        """
        Preprocess the inputs before passing to model.

        A non-list input is wrapped in a one-element list before being handed
        to ``model.tokenize``; the tokenizer's result is returned as-is.
        """
        batch = inputs if isinstance(inputs, list) else [inputs]
        return self.model.tokenize(batch)

    def _forward(self, features, task=None, truncate_dim=None):
        """
        Forward pass through the model.

        Args:
            features: Output of ``preprocess``.
            task: Task passed to the model; falls back to ``self.default_task``
                when None.
            truncate_dim: Passed through to ``model.forward`` unchanged
                (None included).

        Returns:
            Whatever ``model.forward`` returns.
        """
        if task is None:
            task = self.default_task
        return self.model.forward(features, task=task, truncate_dim=truncate_dim)

    def postprocess(self, model_outputs, **postprocess_params) -> np.ndarray:
        """
        Convert model outputs to final embeddings.

        Args:
            model_outputs: Mapping produced by ``_forward``; must contain a
                ``"sentence_embedding"`` tensor.

        Returns:
            The embeddings as a NumPy array. bfloat16 tensors are upcast to
            float32; all other dtypes are converted unchanged.

        Raises:
            ValueError: If ``"sentence_embedding"`` is missing or None.
        """
        embeddings = model_outputs.get("sentence_embedding", None)
        if embeddings is None:
            raise ValueError("No embeddings were generated")

        # ``Tensor.numpy()`` refuses tensors that still require grad, so drop
        # any autograd history before converting.
        embeddings = embeddings.detach().cpu()
        # NumPy has no bfloat16 dtype and ``Tensor.numpy()`` raises TypeError
        # for it; upcast to float32 so such outputs can be returned.
        if embeddings.dtype == torch.bfloat16:
            embeddings = embeddings.float()
        return embeddings.numpy()
def load_pipeline(model_path: str, device: str = None):
    """
    Build an ``EmbeddingPipeline`` for the model stored at ``model_path``.

    Args:
        model_path: Path to the model directory
        device: Device to use for inference (cpu, cuda, etc.). When None,
            "cuda" is chosen if ``torch.cuda.is_available()`` and "cpu"
            otherwise.

    Returns:
        EmbeddingPipeline instance wrapping the loaded model, which has been
        moved to the chosen device and switched to eval mode.
    """
    # Resolve the target device first; an explicit argument always wins.
    target_device = device
    if target_device is None:
        if torch.cuda.is_available():
            target_device = "cuda"
        else:
            target_device = "cpu"

    # NOTE(review): trust_remote_code is enabled both in model_args and as a
    # direct argument - presumably this lets code shipped with the model run;
    # confirm model_path always points at a trusted source.
    loader_args = {"default_task": "retrieval", "trust_remote_code": True}
    transformer = Transformer(
        model_name_or_path=model_path,
        model_args=loader_args,
        trust_remote_code=True,
    )
    transformer.to(target_device)
    transformer.eval()

    return EmbeddingPipeline(model=transformer, device=target_device)