File size: 1,445 Bytes
bd90063
bb34072
bd90063
bb34072
bd90063
 
 
 
 
 
 
 
 
 
 
 
 
bb34072
bd90063
 
 
 
bb34072
db24b31
bd90063
 
db24b31
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
import time
from .predict_bert import predict_relation

def print_pretty_prediction(result: dict, test_case_num: int, expected: str, claim1: str, claim2: str):
    """Pretty-print the result of a single test case."""
    prediction = result["predicted_label"]
    confidence = result["confidence"]
    probability = result["probability"]
    status = "✅ Correct" if prediction == expected else "❌ Incorrect"

    print(f"\n{'='*70}")
    print(f"Test Case {test_case_num}")
    print(f"{'='*70}")
    print(f"Claim 1: {claim1}")
    print(f"Claim 2: {claim2}")
    print(f"\nExpected:   {expected}")
    print(f"Predicted:  {prediction}")
    print(f"Probability: {probability:.4f}")
    print(f"Confidence: {confidence:.2%}")
    print(f"Status:     {status}")


def run_tests(model, tokenizer, device, test_cases):
    results_list = []

    for i, case in enumerate(test_cases, start=1):
        result = predict_relation(case["claim1"], case["claim2"], model, tokenizer, device)
        correct = result["predicted_label"] == case["expected"]

        results_list.append({
            "test_case": i,
            "claim1": case["claim1"],
            "claim2": case["claim2"],
            "expected": case["expected"],
            "predicted": result["predicted_label"],
            "probability": result["probability"],
            "confidence": result["confidence"],
            "correct": correct
        })

    return results_list