import math

import gradio as gr

from graphgen.bases import BaseLLMWrapper
from graphgen.models import JsonKVStorage, NetworkXStorage
from graphgen.templates import STATEMENT_JUDGEMENT_PROMPT
from graphgen.utils import logger, run_concurrent, yes_no_loss_entropy


async def judge_statement(  # pylint: disable=too-many-statements
    trainee_llm_client: BaseLLMWrapper,
    graph_storage: NetworkXStorage,
    rephrase_storage: JsonKVStorage,
    re_judge: bool = False,
    progress_bar: gr.Progress = None,
) -> NetworkXStorage:
    """
    Get all edges and nodes and judge them

    :param trainee_llm_client: judge the statements to get comprehension loss
    :param graph_storage: graph storage instance
    :param rephrase_storage: rephrase storage instance
    :param re_judge: re-judge the relations
    :param progress_bar
    :return:
    """

    async def _judge_single_relation(
        edge: tuple,
    ):
        source_id = edge[0]
        target_id = edge[1]
        edge_data = edge[2]

        if (not re_judge) and "loss" in edge_data and edge_data["loss"] is not None:
            logger.debug(
                "Edge %s -> %s already judged, loss: %s, skip",
                source_id,
                target_id,
                edge_data["loss"],
            )
            return source_id, target_id, edge_data

        description = edge_data["description"]

        try:
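            # rephrase_storage maps the original description to a list of
            # (rephrased statement, yes/no ground truth) pairs; each statement is
            # judged by the trainee LLM and the top-k candidates of the first
            # generated token are scored against the ground truth with
            # yes_no_loss_entropy to obtain a comprehension loss.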
            descriptions = rephrase_storage.get_by_id(description)
            assert descriptions is not None

            judgements = []
            gts = [gt for _, gt in descriptions]
            for description, gt in descriptions:
                judgement = await trainee_llm_client.generate_topk_per_token(
                    STATEMENT_JUDGEMENT_PROMPT["TEMPLATE"].format(statement=description)
                )
                judgements.append(judgement[0].top_candidates)

            loss = yes_no_loss_entropy(judgements, gts)

            logger.debug(
                "Edge %s -> %s description: %s loss: %s",
                source_id,
                target_id,
                description,
                loss,
            )

            edge_data["loss"] = loss
        except Exception as e:  # pylint: disable=broad-except
            logger.error(
                "Error in judging relation %s -> %s: %s", source_id, target_id, e
            )
            logger.info("Use default loss 0.1")
            edge_data["loss"] = -math.log(0.1)

        graph_storage.update_edge(source_id, target_id, edge_data)
        return source_id, target_id, edge_data

    edges = graph_storage.get_all_edges()

    await run_concurrent(
        _judge_single_relation,
        edges,
        desc="Judging relations",
        unit="relation",
        progress_bar=progress_bar,
    )

    async def _judge_single_entity(
        node: tuple,
    ):
        node_id = node[0]
        node_data = node[1]

        if (not re_judge) and "loss" in node_data and node_data["loss"] is not None:
            logger.debug(
                "Node %s already judged, loss: %s, skip", node_id, node_data["loss"]
            )
            return node_id, node_data

        description = node_data["description"]

        try:
            descriptions = rephrase_storage.get_by_id(description)
            assert descriptions is not None

            judgements = []
            gts = [gt for _, gt in descriptions]
            for description, gt in descriptions:
                judgement = await trainee_llm_client.generate_topk_per_token(
                    STATEMENT_JUDGEMENT_PROMPT["TEMPLATE"].format(statement=description)
                )
                judgements.append(judgement[0].top_candidates)

            loss = yes_no_loss_entropy(judgements, gts)

            logger.debug("Node %s description: %s loss: %s", node_id, description, loss)

            node_data["loss"] = loss
        except Exception as e:  # pylint: disable=broad-except
            logger.error("Error in judging entity %s: %s", node_id, e)
            logger.error("Use default loss 0.1")
            node_data["loss"] = -math.log(0.1)

        graph_storage.update_node(node_id, node_data)
        return node_id, node_data

    nodes = graph_storage.get_all_nodes()

    await run_concurrent(
        _judge_single_entity,
        nodes,
        desc="Judging entities",
        unit="entity",
        progress_bar=progress_bar,
    )

    return graph_storage
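

# Example usage (a minimal sketch, not part of the module's public surface):
# how judge_statement is expected to be driven. The storage and LLM wrapper
# constructions below are placeholders; their real initialisation arguments
# live in graphgen.models / graphgen.bases and are not taken from this file.
#
#     import asyncio
#
#     async def main():
#         graph_storage = NetworkXStorage(...)      # placeholder construction
#         rephrase_storage = JsonKVStorage(...)     # placeholder construction
#         trainee_llm_client = ...                  # any BaseLLMWrapper implementation
#         await judge_statement(
#             trainee_llm_client,
#             graph_storage,
#             rephrase_storage,
#             re_judge=False,  # set True to re-judge already-scored edges/nodes
#         )
#
#     asyncio.run(main())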