import math
from typing import Dict, List

from graphgen.bases.datatypes import Token

# Lower-cased, whitespace-stripped token texts counted as an affirmative answer.
_YES_SYNONYMS = frozenset(
    {
        # English yes synonyms
        "yes", "yeah", "yea", "yep", "yup", "yay", "ya", "yah",
        "sure", "certainly", "absolutely", "definitely", "exactly",
        "indeed", "right", "correct", "true", "t", "1",
        # Chinese yes synonyms
        "是", "对", "好的", "行", "可以", "没错", "当然", "确实",
        "正确", "真", "对的",
    }
)

# Lower-cased, whitespace-stripped token texts counted as a negative answer.
_NO_SYNONYMS = frozenset(
    {
        # English no synonyms
        "no", "nope", "nop", "nah", "naw", "na", "negative", "never",
        "not", "false", "f", "0",
        # Chinese no synonyms
        "不", "不是", "没有", "错", "不对", "不行", "不能", "否", "假的",
    }
)

# Floor applied before taking a log so a zero probability on the correct
# label yields a large finite loss instead of math.log(0) raising ValueError.
_MIN_PROB = 1e-12


def preprocess_tokens(tokens: List[Token]) -> List[Token]:
    """Keep only the tokens whose probability is strictly positive.

    Every confidence metric in this module filters its input through this
    function first. Presumably this keeps zero-probability tokens out of the
    log-based and averaged metrics.

    Args:
        tokens: Tokens to filter.

    Returns:
        A new list with the tokens whose ``prob`` is greater than 0, in their
        original order.
    """
    return [tok for tok in tokens if tok.prob > 0]


def joint_probability(tokens: List[Token]) -> float:
    """Return the length-normalised joint probability of ``tokens``.

    Computes ``exp(mean(logprob))`` over the tokens with positive probability.
    Assuming ``logprob == log(prob)``, this is the geometric mean of the
    per-token probabilities. It is not the raw product, so sequences of
    different lengths stay comparable.

    Args:
        tokens: Tokens carrying ``prob`` and ``logprob``.

    Returns:
        The geometric-mean probability.

    Raises:
        ZeroDivisionError: if no token has a positive probability.
    """
    kept = preprocess_tokens(tokens)
    logprob_sum = sum(tok.logprob for tok in kept)
    return math.exp(logprob_sum / len(kept))


def min_prob(tokens: List[Token]) -> float:
    """Return the smallest positive per-token probability.

    Raises:
        ValueError: if no token has a positive probability (``min`` of an
            empty sequence).
    """
    kept = preprocess_tokens(tokens)
    return min(tok.prob for tok in kept)


def average_prob(tokens: List[Token]) -> float:
    """Return the arithmetic mean of the positive per-token probabilities.

    Raises:
        ZeroDivisionError: if no token has a positive probability.
    """
    kept = preprocess_tokens(tokens)
    return sum(tok.prob for tok in kept) / len(kept)


def average_confidence(tokens: List[Token], top_k: int = 5) -> float:
    """Return the mean of each token's probability relative to its top candidates.

    For each token with positive probability, the confidence is
    ``prob / sum(c.prob for c in top_candidates[:top_k])``. The result is the
    mean of those ratios.

    Args:
        tokens: Tokens carrying ``prob`` and ``top_candidates``. Presumably
            ``top_candidates`` is ordered by descending probability and
            includes the token itself (TODO confirm against the producer).
        top_k: Number of leading candidates summed as the normaliser
            (default 5).

    Returns:
        The mean relative confidence.

    Raises:
        ZeroDivisionError: if no token has positive probability, or if a
            token's leading candidates have zero total probability.
    """
    kept = preprocess_tokens(tokens)
    confidences = [
        tok.prob / sum(cand.prob for cand in tok.top_candidates[:top_k])
        for tok in kept
    ]
    return sum(confidences) / len(kept)


def yes_no_loss(tokens_list: List[List[Token]], ground_truth: List[str]) -> float:
    """Return the mean probability-based loss for yes/no answers.

    Only the first token of each sequence is inspected. Its text must be
    "yes" or "no" after lower-casing and stripping whitespace. If it matches
    the ground truth (compared the same way), the loss is ``1 - prob``;
    otherwise the loss is ``prob``.

    Args:
        tokens_list: One token sequence per example.
        ground_truth: The expected "yes"/"no" label per example, aligned by
            index with ``tokens_list``.

    Returns:
        The mean loss over all examples.

    Raises:
        AssertionError: if a first token is not "yes"/"no".
        IndexError: if ``ground_truth`` is shorter than ``tokens_list``.
        ZeroDivisionError: if ``tokens_list`` is empty.
    """
    losses = []
    for i, tokens in enumerate(tokens_list):
        token = tokens[0]
        answer = token.text.lower().strip()
        assert answer in ("yes", "no")
        # Normalise both sides so e.g. "Yes" is not scored wrong against "yes".
        if answer == ground_truth[i].lower().strip():
            losses.append(1 - token.prob)
        else:
            losses.append(token.prob)
    return sum(losses) / len(losses)


def _normalize_yes_no(tokens: List[Token]) -> Dict[str, float]:
    """
    Bucket token probabilities into yes / no / uncertain and normalise them.

    Each token's text is lower-cased and stripped, then looked up in
    ``_YES_SYNONYMS`` and ``_NO_SYNONYMS``. The probability mass of tokens
    that match neither set counts as "uncertain". The three totals are
    divided by their sum, so they add up to 1.

    For example, given tokens with probabilities:
      - "yes" (0.6)
      - "yeah" (0.2)
      - "no" (0.1)
      - "nope" (0.1)
    the function returns approximately
    {"yes": 0.8, "no": 0.2, "uncertain": 0.0}.
    Here "yes" and "yeah" are synonyms for "yes", and "no" and "nope" are
    synonyms for "no".

    If no yes/no synonym is present, the result is judged as uncertain. This
    includes an empty list or one whose probabilities sum to 0, which returns
    {"yes": 0.0, "no": 0.0, "uncertain": 1.0}. Since callers score only the
    ground-truth label's share, uncertain mass counts as opposite to the
    ground truth.
    """
    yes_prob = 0.0
    no_prob = 0.0
    uncertain_prob = 0.0
    for tok in tokens:
        text = tok.text.lower().strip()
        if text in _YES_SYNONYMS:
            yes_prob += tok.prob
        elif text in _NO_SYNONYMS:
            no_prob += tok.prob
        else:
            uncertain_prob += tok.prob

    total = yes_prob + no_prob + uncertain_prob
    if total <= 0:
        # Nothing to normalise: treat the answer as entirely uncertain rather
        # than dividing by zero.
        return {"yes": 0.0, "no": 0.0, "uncertain": 1.0}
    return {
        "yes": yes_prob / total,
        "no": no_prob / total,
        "uncertain": uncertain_prob / total,
    }


def yes_no_loss_entropy(
    tokens_list: List[List[Token]], ground_truth: List[str]
) -> float:
    """Return the mean cross-entropy loss for yes/no answers.

    For each example, ``_normalize_yes_no`` turns the tokens into a
    yes / no / uncertain distribution. The loss is ``-log(p)``, where ``p`` is
    the share assigned to the ground-truth label. ``p`` is clamped to
    ``_MIN_PROB``, so an example with no mass on the correct label yields a
    large finite loss instead of a math domain error.

    Args:
        tokens_list: Per-example token lists. Every token in an entry is
            summed as an alternative answer. NOTE(review): this suggests each
            entry holds candidate tokens for the answer position rather than a
            generated sequence; confirm against callers.
        ground_truth: The expected "yes"/"no" label per example. Pairs beyond
            the shorter of the two lists are ignored (``zip``).

    Returns:
        The mean loss over the paired examples.

    Raises:
        AssertionError: if a label is not "yes"/"no".
        ZeroDivisionError: if there are no paired examples.
    """
    losses = []
    for toks, gt in zip(tokens_list, ground_truth):
        dist = _normalize_yes_no(toks)
        label = gt.lower().strip()
        assert label in {"yes", "no"}
        prob_correct = max(dist[label], _MIN_PROB)
        losses.append(-math.log(prob_correct))
    return sum(losses) / len(losses)