File size: 820 Bytes
bd90063
 
 
bb34072
bd90063
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
import torch
from .embeddings import generate_embeddings

def _model_device(model):
    """Return the device of ``model``'s first parameter, or None if it has none."""
    parameters = getattr(model, "parameters", None)
    if not callable(parameters):
        # Not an nn.Module-like object; let torch use its default device.
        return None
    first = next(iter(parameters()), None)
    return None if first is None else first.device


def predict_relation_old(arg1, arg2, model, embedding_model, processor, best_threshold, label_encoder, model_type="pytorch"):
    """Predict the relation label for an argument pair with a binary classifier.

    The pair is embedded with ``generate_embeddings`` and then scored either by
    a PyTorch module (``model_type == "pytorch"``) or by any estimator exposing
    ``predict_proba`` (every other ``model_type`` value).

    Args:
        arg1, arg2: The two arguments; forwarded unchanged to ``generate_embeddings``.
        model: For the torch path, a module whose output on a batch of one is a
            single logit (``.item()`` is called on its sigmoid). Otherwise, an
            object with ``predict_proba``; column 1 of the first row is used as
            the positive-class probability.
        embedding_model, processor: Forwarded to ``generate_embeddings``.
        best_threshold: The positive class (1) is predicted when the
            probability is strictly greater than this value.
        label_encoder: Object whose ``inverse_transform`` maps 0/1 back to the
            original label.
        model_type: ``"pytorch"`` selects the torch path; anything else uses
            ``predict_proba``.

    Returns:
        dict with ``"predicted_label"`` (decoded label), ``"probability"``
        (positive-class probability as a float) and ``"confidence"``
        (``abs(probability - 0.5) * 2``).
    """
    embeddings = generate_embeddings(arg1, arg2, embedding_model, processor)

    if model_type == "pytorch":
        model.eval()
        # Build the input on the model's own device so GPU-resident models work;
        # the original CPU-only FloatTensor raised a device-mismatch error there.
        tensor = torch.as_tensor(
            embeddings, dtype=torch.float32, device=_model_device(model)
        ).unsqueeze(0)
        with torch.no_grad():
            prob = torch.sigmoid(model(tensor)).item()
    else:
        # NOTE(review): assumes ``embeddings`` is a numpy array (``reshape`` is
        # required here) — confirm against generate_embeddings.
        prob = float(model.predict_proba(embeddings.reshape(1, -1))[0][1])

    prediction = 1 if prob > best_threshold else 0

    return {
        "predicted_label": label_encoder.inverse_transform([prediction])[0],
        "probability": prob,
        "confidence": abs(prob - 0.5) * 2
    }