from typing import List, Union, Dict, Any, Optional

import numpy as np
import torch
from PIL import Image
from transformers import Pipeline

from custom_st import Transformer


class EmbeddingPipeline(Pipeline):
    """
    Pipeline for generating embeddings using a custom transformer model.
    """

    def __init__(self, model, **kwargs):
        super().__init__(model=model, **kwargs)
        # Task used when the caller does not specify one
        self.default_task = "retrieval"

    def _sanitize_parameters(self, task=None, truncate_dim=None, **kwargs):
        # Route pipeline kwargs to the preprocess/_forward/postprocess stages;
        # only the forward pass takes extra parameters here
        preprocess_params = {}
        forward_params = {}
        postprocess_params = {}
        if task is not None:
            forward_params["task"] = task
        if truncate_dim is not None:
            forward_params["truncate_dim"] = truncate_dim
        return preprocess_params, forward_params, postprocess_params

    def preprocess(self, inputs, **preprocess_params):
        """
        Preprocess the inputs before passing them to the model.
        """
        # Accept a single input as well as a list of inputs
        if not isinstance(inputs, list):
            inputs = [inputs]
        # Tokenize/prepare the inputs via the underlying model
        features = self.model.tokenize(inputs)
        return features

    def _forward(self, features, task=None, truncate_dim=None):
        """
        Forward pass through the model.
        """
        # Fall back to the default task if none was provided
        if task is None:
            task = self.default_task
        # Run the model's forward pass
        outputs = self.model.forward(features, task=task, truncate_dim=truncate_dim)
        return outputs

    def postprocess(self, model_outputs, **postprocess_params):
        """
        Convert model outputs to the final embeddings.
        """
        # Extract the sentence embeddings from the model outputs
        embeddings = model_outputs.get("sentence_embedding", None)
        if embeddings is None:
            raise ValueError("No embeddings were generated")
        # Move to CPU and convert to a numpy array
        embeddings = embeddings.cpu().numpy()
        return embeddings


def load_pipeline(model_path: str, device: Optional[str] = None) -> EmbeddingPipeline:
    """
    Load the embedding pipeline from a model path.

    Args:
        model_path: Path to the model directory.
        device: Device to use for inference ("cpu", "cuda", etc.).
            Defaults to "cuda" when available, otherwise "cpu".

    Returns:
        EmbeddingPipeline instance.
    """
    # Pick a device automatically when none is given
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    # Load the custom transformer model
    model = Transformer(
        model_name_or_path=model_path,
        model_args={"default_task": "retrieval", "trust_remote_code": True},
        trust_remote_code=True,
    )
    model.to(device)
    model.eval()
    # Wrap the model in the embedding pipeline
    pipeline = EmbeddingPipeline(model=model, device=device)
    return pipeline
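

# Example usage: a minimal sketch, not part of the pipeline API above. The
# model directory path and the input sentences are illustrative placeholders;
# running this requires a model directory compatible with custom_st.Transformer,
# and the exact output shapes depend on that custom model.
if __name__ == "__main__":
    pipe = load_pipeline("path/to/model")  # hypothetical local model directory
    embeddings = pipe(["A sample sentence to embed.", "Another sentence."])
    # With list inputs, the Pipeline returns one embedding array per input
    print(len(embeddings), embeddings[0].shape)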