File size: 4,164 Bytes
acd7cf4
d2af8b0
3a3b216
 
 
acd7cf4
 
 
 
 
 
3a3b216
acd7cf4
 
 
 
 
 
3a3b216
acd7cf4
 
 
 
 
3a3b216
acd7cf4
 
 
 
 
3a3b216
acd7cf4
 
 
 
 
 
3a3b216
acd7cf4
 
 
 
 
 
 
 
 
 
 
 
3a3b216
d2af8b0
 
 
 
 
 
 
 
 
 
 
 
2ff8798
 
d2af8b0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2ff8798
d2af8b0
 
 
 
 
 
2ff8798
 
 
 
d2af8b0
2ff8798
 
 
 
 
d2af8b0
 
3a3b216
 
 
acd7cf4
 
d2af8b0
 
 
 
 
 
acd7cf4
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
import math
from typing import Dict, List

from graphgen.bases.datatypes import Token


def preprocess_tokens(tokens: List[Token]) -> List[Token]:
    """Filter out tokens whose probability is not strictly positive.

    The confidence metrics below take logs and divide by counts, so
    tokens with ``prob <= 0`` are dropped before any computation.
    """
    return [tok for tok in tokens if tok.prob > 0]


def joint_probability(tokens: List[Token]) -> float:
    """Return the geometric-mean probability of a list of tokens.

    Computed as exp(mean(logprob)), which is numerically safer than
    multiplying raw probabilities directly.

    Raises:
        ValueError: if no token has a positive probability (previously
            this surfaced as a bare ZeroDivisionError).
    """
    tokens = preprocess_tokens(tokens)
    if not tokens:
        raise ValueError("joint_probability: no tokens with positive probability")
    logprob_sum = sum(tok.logprob for tok in tokens)
    return math.exp(logprob_sum / len(tokens))


def min_prob(tokens: List[Token]) -> float:
    """Return the smallest probability among the positive-probability tokens."""
    filtered = preprocess_tokens(tokens)
    # min() raises ValueError when `filtered` is empty, matching the
    # generator-based form this replaces.
    weakest = min(filtered, key=lambda tok: tok.prob)
    return weakest.prob


def average_prob(tokens: List[Token]) -> float:
    """Return the arithmetic-mean probability of a list of tokens.

    Raises:
        ValueError: if no token has a positive probability (previously
            this surfaced as a bare ZeroDivisionError).
    """
    tokens = preprocess_tokens(tokens)
    if not tokens:
        raise ValueError("average_prob: no tokens with positive probability")
    return sum(tok.prob for tok in tokens) / len(tokens)


def average_confidence(tokens: List[Token]) -> float:
    """Return the mean relative confidence of a list of tokens.

    A token's confidence is its probability normalized by the total
    probability of its (up to) five top candidates.

    NOTE(review): assumes ``top_candidates`` is ordered by descending
    probability and that the top-5 probabilities never sum to zero —
    confirm against the Token producer.

    Raises:
        ValueError: if no token has a positive probability (previously
            this surfaced as a bare ZeroDivisionError).
    """
    tokens = preprocess_tokens(tokens)
    if not tokens:
        raise ValueError("average_confidence: no tokens with positive probability")
    total = 0.0
    for tok in tokens:
        denom = sum(cand.prob for cand in tok.top_candidates[:5])
        total += tok.prob / denom
    return total / len(tokens)


def yes_no_loss(tokens_list: List[List[Token]], ground_truth: List[str]) -> float:
    """Calculate the mean loss over a batch of yes/no questions.

    Only the first token of each answer is considered. The per-question
    loss is ``1 - prob`` when the answer matches the ground truth and
    ``prob`` otherwise.

    Args:
        tokens_list: one token list per question.
        ground_truth: expected "yes"/"no" answer per question.

    Raises:
        ValueError: if a first token is not a yes/no answer
            (case-insensitive).
    """
    losses = []
    for tokens, truth in zip(tokens_list, ground_truth):
        token = tokens[0]
        answer = token.text.lower()
        # Raise instead of assert: asserts are stripped under `python -O`.
        if answer not in ("yes", "no"):
            raise ValueError(f"expected a yes/no answer, got {token.text!r}")
        # Bug fix: compare case-insensitively on both sides. The old code
        # compared raw `token.text`, so "Yes" vs ground truth "yes" was
        # scored as a wrong answer even though the assert accepted it.
        if answer == truth.lower():
            losses.append(1 - token.prob)
        else:
            losses.append(token.prob)
    return sum(losses) / len(losses)


def _normalize_yes_no(tokens: List[Token]) -> Dict[str, float]:
    """
    Mapping yes/no synonyms to their probabilities and normalizing.
    For example, given tokens with probabilities:
    - "yes" (0.6)
    - "yeah" (0.2)
    - "no" (0.1)
    - "nope" (0.1)
    The function will return:
    {"yes": 0.8, "no": 0.2}
    Among them, "yes" and "yeah" are synonyms for "yes",
    while "no" and "nope" are synonyms for "no".
    If no "yes" or "no" synonyms are present, it will be judged as uncertain.
    An uncertain result will also be considered as opposite to the ground truth.
    """
    yes_syno = {
        # English yes synonyms
        "yes",
        "yeah",
        "yea",
        "yep",
        "yup",
        "yay",
        "ya",
        "yah",
        "sure",
        "certainly",
        "absolutely",
        "definitely",
        "exactly",
        "indeed",
        "right",
        "correct",
        "true",
        "t",
        "1",
        # Chinese yes synonyms
        "是",
        "对",
        "好的",
        "行",
        "可以",
        "没错",
        "当然",
        "确实",
        "正确",
        "真",
        "对的",
    }
    no_syno = {
        # English no synonyms
        "no",
        "nope",
        "nop",
        "nah",
        "naw",
        "na",
        "negative",
        "never",
        "not",
        "false",
        "f",
        "0",
        # Chinese no synonyms
        "不",
        "不是",
        "没有",
        "错",
        "不对",
        "不行",
        "不能",
        "否",
        "假的",
    }

    yes_prob = 0.0
    no_prob = 0.0
    uncertain_prob = 0.0
    for tok in tokens:
        t = tok.text.lower().strip()
        if t in yes_syno:
            yes_prob += tok.prob
        elif t in no_syno:
            no_prob += tok.prob
        else:
            uncertain_prob += tok.prob

    total = yes_prob + no_prob + uncertain_prob

    return {
        "yes": yes_prob / total,
        "no": no_prob / total,
        "uncertain": uncertain_prob / total,
    }


def yes_no_loss_entropy(
    tokens_list: List[List[Token]], ground_truth: List[str]
) -> float:
    """Calculate the mean cross-entropy loss over yes/no questions.

    Each question's token distribution is normalized with
    ``_normalize_yes_no`` and scored as ``-log(p(correct answer))``.

    Args:
        tokens_list: one token list per question.
        ground_truth: expected "yes"/"no" answer per question.

    Raises:
        ValueError: if the batch is empty or a ground-truth label is
            not "yes"/"no" (case-insensitive).
    """
    if not tokens_list:
        raise ValueError("yes_no_loss_entropy: empty batch")
    # Floor keeps log() defined when the correct answer got zero mass
    # (previously math.log(0) raised a domain ValueError mid-batch).
    epsilon = 1e-12
    losses = []
    for toks, gt in zip(tokens_list, ground_truth):
        dist = _normalize_yes_no(toks)
        truth = gt.lower()
        # Raise instead of assert: asserts are stripped under `python -O`.
        if truth not in {"yes", "no"}:
            raise ValueError(f"ground truth must be yes/no, got {gt!r}")
        prob_correct = dist[truth]
        losses.append(-math.log(max(prob_correct, epsilon)))
    return sum(losses) / len(losses)