import torch
import torch.nn as nn
from sklearn.metrics import accuracy_score, f1_score

from .metrics import class_wise_f1


def run_knn(
    eval_type,
    train_embeddings,
    train_labels,
    test_embeddings,
    test_labels,
    num_classes,
    is_multilabel,
    device,
    skip_idx=False,
    return_class_wise=False,
):
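    """Evaluate embeddings with a similarity-weighted kNN classifier.

    `eval_type` selects the number of neighbors ("KNN-5" or "KNN-20"). For
    single-label data this returns accuracy (or per-class F1 scores when
    `return_class_wise` is set). For multilabel data (labels of shape
    (num_samples, num_classes)) a binary kNN is run once per class and the
    micro-averaged F1 (or a per-class F1 list) is returned. Set
    `skip_idx=True` when the test set is the train set, so each query's own
    embedding is excluded from its neighbor set.
    """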
    # Resolve k from the eval type up front; fail loudly on an unsupported
    # eval_type instead of leaving `predictions` undefined.
    if eval_type == "KNN-5":
        k = 5
    elif eval_type == "KNN-20":
        k = 20
    else:
        raise ValueError(f"Unsupported eval_type: {eval_type}")

    if not is_multilabel:
        predictions = _run_knn_for_k(
            train_embeddings=train_embeddings,
            train_labels=train_labels,
            test_embeddings=test_embeddings,
            num_classes=num_classes,
            k=k,
            device=device,
            skip_idx=skip_idx,
        )

        if return_class_wise:
            return class_wise_f1(y_true=test_labels, y_pred=predictions, num_classes=num_classes)
        return accuracy_score(y_true=test_labels, y_pred=predictions)
    else:
        # Multilabel dataset (e.g., BigEarthNet): run a binary KNN once per
        # class and stack the per-class predictions.
        # Labels have shape (num_samples, num_classes).
        assert num_classes == train_labels.shape[-1]
        assert num_classes == test_labels.shape[-1]
        predictions = []
        for class_idx in range(num_classes):
            train_single_labels = train_labels[:, class_idx]  # (num_samples,)
            single_predictions = _run_knn_for_k(
                train_embeddings=train_embeddings,
                train_labels=train_single_labels,
                test_embeddings=test_embeddings,
                num_classes=2,  # binary prediction for each class
                k=k,
                device=device,
                skip_idx=skip_idx,
            )  # (num_samples,)
            predictions.append(single_predictions)

        predictions = torch.stack(predictions, dim=1)  # (num_samples, num_classes)

        if return_class_wise:
            return [f1_score(test_labels[:, i], predictions[:, i]) for i in range(num_classes)]
        else:
            return f1_score(y_true=test_labels, y_pred=predictions, average="micro")


def _run_knn_for_k(
    train_embeddings, train_labels, test_embeddings, num_classes, k, device, skip_idx
):
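    """Similarity-weighted kNN for a single value of k.

    Each test embedding votes over its k nearest train embeddings (by cosine
    similarity); neighbor labels are one-hot encoded, weighted by
    exp(similarity / 0.07), summed, and argmax'd into a prediction.
    Returns a CPU LongTensor of shape (num_test_samples,).
    """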
    train_embeddings = train_embeddings.to(device)
    test_embeddings = test_embeddings.to(device)
    train_labels = train_labels.to(device)
    cos = nn.CosineSimilarity(dim=-1)
    all_preds = []
    for idx in range(test_embeddings.shape[0]):
        # Cosine similarity of this test embedding against every train
        # embedding; broadcasting makes an explicit repeat unnecessary.
        sims = cos(test_embeddings[idx].unsqueeze(dim=0), train_embeddings)
        # When skip_idx is set the test set is the train set, so the top hit
        # is the query itself; fetch one extra neighbor and drop it below so
        # that k real neighbors remain.
        top_k = torch.topk(sims, k=k + 1 if skip_idx else k)
        if skip_idx:
            top_k_values = top_k.values[1:]
            top_k_indices = top_k.indices[1:]
        else:
            top_k_values = top_k.values
            top_k_indices = top_k.indices

        fetched_labels = train_labels[top_k_indices]
        fetched_onehots = nn.functional.one_hot(fetched_labels, num_classes=num_classes)
        # Weight each neighbor's one-hot label by a temperature-scaled
        # exponential of its cosine similarity (T = 0.07, the standard
        # temperature for weighted-kNN evaluation).
        weights = (top_k_values / 0.07).exp()
        weighted_sum_onehots = (weights.unsqueeze(dim=1) * fetched_onehots).sum(dim=0)
        prediction = torch.argmax(weighted_sum_onehots)
        all_preds.append(prediction)

    return torch.stack(all_preds).cpu()
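

# Hypothetical usage sketch (names, shapes, and the module path are
# illustrative assumptions, not part of this module). Assuming the module
# lives at evaluation/knn.py inside its package:
#
#     from evaluation.knn import run_knn
#
#     train_emb, test_emb = torch.randn(200, 64), torch.randn(50, 64)
#     train_lbl = torch.randint(0, 10, (200,))
#     test_lbl = torch.randint(0, 10, (50,))
#     acc = run_knn(
#         eval_type="KNN-5",
#         train_embeddings=train_emb,
#         train_labels=train_lbl,
#         test_embeddings=test_emb,
#         test_labels=test_lbl,
#         num_classes=10,
#         is_multilabel=False,
#         device="cpu",
#     )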