p2002814
Now using the best fine-tuned BERT model, uploaded to the Hugging Face Hub
bb34072
raw
history blame contribute delete
820 Bytes
import torch
from .embeddings import generate_embeddings
def predict_relation_old(arg1, arg2, model, embedding_model, processor, best_threshold, label_encoder, model_type="pytorch"):
    """Predict the relation label for a pair of arguments with a binary model.

    The pair is embedded via ``generate_embeddings``, scored by ``model`` to
    obtain a positive-class probability, thresholded at ``best_threshold``
    into class 0 or 1, and mapped back to a label with ``label_encoder``.

    Args:
        arg1: First argument; forwarded unchanged to ``generate_embeddings``.
        arg2: Second argument; forwarded unchanged to ``generate_embeddings``.
        model: The classifier. When ``model_type == "pytorch"`` it is called
            on a tensor and its output is passed through a sigmoid, so it
            must produce exactly one element (``.item()`` is used).
            Otherwise it must expose ``predict_proba``.
        embedding_model: Forwarded to ``generate_embeddings``.
        processor: Forwarded to ``generate_embeddings``.
        best_threshold: Probability above which (strictly) class 1 is chosen.
        label_encoder: Object whose ``inverse_transform`` maps ``[0]`` or
            ``[1]`` back to the original label.
        model_type: ``"pytorch"`` selects the torch path; any other value
            selects the ``predict_proba`` path.

    Returns:
        dict: ``"predicted_label"`` (decoded label), ``"probability"``
        (positive-class probability as a built-in float) and ``"confidence"``
        (``abs(probability - 0.5) * 2``).

    Note:
        On the torch path ``model.eval()`` is called and the previous
        training/eval mode is not restored.
    """
    # NOTE(review): presumably a 1-D numeric array (it is wrapped in a
    # FloatTensor or reshaped to (1, -1) below) -- confirm against
    # generate_embeddings.
    embeddings = generate_embeddings(arg1, arg2, embedding_model, processor)

    if model_type == "pytorch":
        model.eval()
        with torch.no_grad():
            # unsqueeze(0) adds a leading batch dimension of size 1.
            tensor = torch.FloatTensor(embeddings).unsqueeze(0)
            prob = torch.sigmoid(model(tensor)).item()
    else:
        # Row 0 is the single sample; column 1 is taken as the positive class.
        # Cast to a built-in float so both paths return the same type and the
        # result dict never carries a NumPy scalar (which may not be
        # JSON-serializable, e.g. float32).
        prob = float(model.predict_proba(embeddings.reshape(1, -1))[0][1])

    # The decision rule is the same for both model types.
    prediction = 1 if prob > best_threshold else 0

    return {
        "predicted_label": label_encoder.inverse_transform([prediction])[0],
        "probability": prob,
        # Distance from 0.5 rescaled to [0, 1]; note this is measured from
        # 0.5, not from best_threshold.
        "confidence": abs(prob - 0.5) * 2,
    }