|
|
from sentence_transformers import models
|
|
|
import torch
|
|
|
from transformers import AutoTokenizer
|
|
|
from optimum.onnxruntime import ORTModelForFeatureExtraction
|
|
|
import numpy as np
|
|
|
import os
|
|
|
import onnxruntime
|
|
|
|
|
|
|
|
|
# Directory that holds the exported ONNX graph (and, further down the script,
# the per-module sub-folders and dense-layer ONNX files).
model_dir = "embeddinggemma-300m"

# The tokenizer is loaded by identifier rather than from ``model_dir``.
tokenizer = AutoTokenizer.from_pretrained(
    "google/embeddinggemma-300m-qat-q4_0-unquantized"
)

# Prefer the GPU whenever torch reports one; otherwise stay on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Load ``model.onnx`` from ``model_dir`` and move it to the selected device.
onnx_model = ORTModelForFeatureExtraction.from_pretrained(
    model_dir,
    file_name="model.onnx",
)
onnx_model = onnx_model.to(device)
|
|
|
|
|
|
class ONNXTransformer:
    """Tokenize sentences and run them through an ONNX feature extractor."""

    def __init__(self, onnx_model, tokenizer, max_seq_length=2048):
        """Keep references to the model, the tokenizer and the length limit.

        Args:
            onnx_model: Callable invoked with the tokenizer output (plus
                ``position_ids``) as keyword arguments; its result must expose
                a ``last_hidden_state`` attribute.
            tokenizer: Callable producing a mapping that contains at least
                ``input_ids`` as a 2-D tensor.
            max_seq_length: Forwarded to the tokenizer as ``max_length``.
        """
        self.max_seq_length = max_seq_length
        self.tokenizer = tokenizer
        self.onnx_model = onnx_model

    def encode(self, sentences):
        """Return the model's ``last_hidden_state`` for ``sentences``.

        Args:
            sentences: Text input handed to the tokenizer unchanged.

        Returns:
            Whatever the model exposes as ``last_hidden_state``.
        """
        batch = self.tokenizer(
            sentences,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=self.max_seq_length,
        )

        token_ids = batch['input_ids']
        batch_size, seq_len = token_ids.shape[0], token_ids.shape[1]

        # Every row receives the same positions 0..seq_len-1; the model is
        # called with them explicitly under the ``position_ids`` key.
        positions = torch.arange(seq_len).unsqueeze(0).expand(batch_size, seq_len)
        batch['position_ids'] = positions.to(token_ids.device)

        with torch.no_grad():
            return self.onnx_model(**batch).last_hidden_state
|
|
|
|
|
|
# Assemble the pipeline stages in order: ONNX encoder first, then the
# post-processing stages found next to the model.
onnx_transformer = ONNXTransformer(onnx_model, tokenizer, max_seq_length=2048)
modules = [onnx_transformer]

for position, kind in enumerate(("Pooling", "Dense", "Dense", "Normalize"), start=1):
    stage_dir = os.path.join(model_dir, f"{position}_{kind}")

    if kind == "Dense":
        # Dense layers are run through ONNX Runtime on the CPU provider.
        # Stage positions 2 and 3 map to dense1.onnx and dense2.onnx.
        session_file = os.path.join(model_dir, "onnx", f"dense{position - 1}.onnx")
        modules.append(
            onnxruntime.InferenceSession(session_file, providers=["CPUExecutionProvider"])
        )
        continue

    if kind == "Pooling":
        # NOTE(review): the stage directory is handed to the Pooling
        # constructor as its first positional argument -- confirm this is
        # intended rather than loading the saved pooling configuration.
        modules.append(models.Pooling(stage_dir))
    elif kind == "Normalize":
        modules.append(models.Normalize())
|
|
|
|
|
|
class ONNXSentenceTransformer:
    """Run an encoder followed by pooling, ONNX dense layers and normalisation.

    ``modules[0]`` must provide ``encode(sentences)`` returning token
    embeddings as a torch tensor.  Every following entry is applied in order
    and dispatched on its type: ``models.Pooling``,
    ``onnxruntime.InferenceSession`` or ``models.Normalize``.  Entries of any
    other type are skipped.
    """

    def __init__(self, modules):
        """Store the ordered list of pipeline stages."""
        self.modules = modules

    def encode(self, sentences):
        """Embed ``sentences`` and return the result as a NumPy array.

        Args:
            sentences: Text input forwarded to ``modules[0].encode``.

        Returns:
            The output of the last stage; torch tensors are moved to the CPU
            and converted to ``numpy.ndarray`` before returning.
        """
        features = self.modules[0].encode(sentences)

        for module in self.modules[1:]:
            if isinstance(module, models.Pooling):
                # Pool with the real padding mask so that padded positions of
                # shorter sentences do not leak into the sentence embedding.
                features = module({
                    'token_embeddings': features,
                    'attention_mask': self._attention_mask(sentences, features),
                })['sentence_embedding']
            elif isinstance(module, onnxruntime.InferenceSession):
                features = self._run_session(module, features)
            elif isinstance(module, models.Normalize):
                # Normalize operates on torch tensors; dense stages emit NumPy.
                if not isinstance(features, torch.Tensor):
                    features = torch.from_numpy(features)
                features = module({'sentence_embedding': features})['sentence_embedding']

        if isinstance(features, torch.Tensor):
            return features.cpu().detach().numpy()
        return features

    def _attention_mask(self, sentences, token_embeddings):
        """Return the padding mask matching ``token_embeddings``.

        The mask is rebuilt by calling the encoder's tokenizer with the same
        arguments the encoder uses.  When the first module exposes no
        ``tokenizer`` attribute, or the rebuilt mask does not match the
        embedding batch/sequence shape, an all-ones mask is returned instead.
        """
        fallback = torch.ones(token_embeddings.shape[:2], device=token_embeddings.device)

        encoder = self.modules[0]
        tokenizer = getattr(encoder, "tokenizer", None)
        if tokenizer is None:
            return fallback

        encoded = tokenizer(
            sentences,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=getattr(encoder, "max_seq_length", None),
        )
        mask = encoded['attention_mask']
        if tuple(mask.shape) != tuple(fallback.shape):
            return fallback
        # Keep the dtype/device of the all-ones mask that was used before.
        return mask.to(device=fallback.device, dtype=fallback.dtype)

    @staticmethod
    def _run_session(session, features):
        """Feed each row of ``features`` through an ONNX session, one at a time."""
        if isinstance(features, torch.Tensor):
            features = features.cpu().detach().numpy()

        # The input name does not change between rows; look it up once.
        input_name = session.get_inputs()[0].name
        rows = [
            session.run(None, {input_name: vec.reshape(1, -1)})[0].squeeze(0)
            for vec in features
        ]
        return np.stack(rows, axis=0)
|
|
|
|
|
|
# Pipeline instance used by the demo at the bottom of the script.
onnx_st = ONNXSentenceTransformer(modules=modules)
|
|
|
|
|
|
def cosine_similarity(a, b):
    """Return the cosine similarity of two vectors.

    Both inputs are flattened first, so ``(1, d)`` and ``(d,)`` arrays are
    interchangeable.  Any array-like (list, tuple, ndarray) is accepted.

    Args:
        a: First vector.
        b: Second vector with the same number of elements as ``a``.

    Returns:
        The cosine of the angle between ``a`` and ``b``, or ``0.0`` when
        either vector has zero norm (the angle is undefined in that case and
        a plain division would yield NaN with a runtime warning).
    """
    a = np.asarray(a).flatten()
    b = np.asarray(b).flatten()

    denominator = np.linalg.norm(a) * np.linalg.norm(b)
    if denominator == 0:
        return 0.0
    return np.dot(a, b) / denominator
|
|
|
|
|
|
if __name__ == "__main__":
    # Small smoke test: embed three words and compare them pairwise.
    words = ["apple", "banana", "car"]
    embeddings = onnx_st.encode(words)
    print(embeddings)

    for number, vector in enumerate(embeddings, start=1):
        print(f"Embedding {number}: {vector.shape}")

    print("\nCosine similarities:")
    # Same three pairs, in the same order, as the labels printed below.
    for first, second in ((0, 1), (0, 2), (1, 2)):
        score = cosine_similarity(embeddings[first], embeddings[second])
        print(f"{words[first]} vs {words[second]}: {score:.4f}")
|
|
|
|