Spaces:
Running
Running
File size: 4,454 Bytes
fb9c306 1189434 fb9c306 d02622b acd7cf4 1189434 acd7cf4 fb9c306 d02622b fb9c306 1189434 fb9c306 acd7cf4 1189434 acd7cf4 1189434 acd7cf4 1189434 acd7cf4 1189434 9e67c3b 1189434 acd7cf4 1189434 9e67c3b 1189434 acd7cf4 1189434 acd7cf4 1189434 acd7cf4 1189434 acd7cf4 9e67c3b 1189434 acd7cf4 9e67c3b acd7cf4 1189434 fb9c306 1189434 acd7cf4 1189434 acd7cf4 1189434 acd7cf4 1189434 acd7cf4 1189434 9e67c3b 1189434 acd7cf4 1189434 9e67c3b 1189434 acd7cf4 1189434 acd7cf4 9e67c3b acd7cf4 1189434 acd7cf4 9e67c3b 1189434 acd7cf4 9e67c3b acd7cf4 1189434 fb9c306 1189434 acd7cf4 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 |
import math
import gradio as gr
from graphgen.bases import BaseLLMWrapper
from graphgen.models import JsonKVStorage, NetworkXStorage
from graphgen.templates import STATEMENT_JUDGEMENT_PROMPT
from graphgen.utils import logger, run_concurrent, yes_no_loss_entropy
async def judge_statement(  # pylint: disable=too-many-statements
    trainee_llm_client: BaseLLMWrapper,
    graph_storage: NetworkXStorage,
    rephrase_storage: JsonKVStorage,
    re_judge: bool = False,
    progress_bar: gr.Progress = None,
) -> NetworkXStorage:
    """
    Judge every edge (relation) and node (entity) in the graph and store a
    comprehension loss on each of them.

    For each element, the rephrased statements stored in ``rephrase_storage``
    under the element's ``description`` are sent to the trainee model using
    ``STATEMENT_JUDGEMENT_PROMPT``. The ``top_candidates`` of the first item
    returned by ``generate_topk_per_token`` are collected and scored against
    the paired ``gt`` labels with ``yes_no_loss_entropy``. The score is written
    to the element's ``"loss"`` field and persisted through ``graph_storage``.
    If judging an element fails for any reason, a fallback loss of
    ``-log(0.1)`` is stored instead.

    :param trainee_llm_client: model whose comprehension of the statements is measured
    :param graph_storage: graph storage; edges and nodes are read from it and
        written back to it
    :param rephrase_storage: storage whose ``get_by_id(description)`` yields
        ``(statement, gt)`` pairs for that description
    :param re_judge: if True, also re-judge elements that already carry a
        non-None ``"loss"``
    :param progress_bar: optional Gradio progress bar forwarded to ``run_concurrent``
    :return: the same ``graph_storage`` instance, updated in place
    """
    # Loss recorded when judging fails: the negative log of a 0.1 probability.
    default_loss = -math.log(0.1)

    def _already_judged(data: dict) -> bool:
        """Return True when *data* has a loss and re-judging is disabled."""
        return (not re_judge) and data.get("loss") is not None

    async def _compute_loss(description: str) -> float:
        """Score the rephrased statements of *description* with the trainee model."""
        rephrased = rephrase_storage.get_by_id(description)
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if rephrased is None:
            raise ValueError(
                f"No rephrased statements found for description: {description!r}"
            )
        gts = [gt for _, gt in rephrased]
        judgements = []
        # `statement` is deliberately a distinct name so the caller's original
        # description is not shadowed (it is what gets logged afterwards).
        for statement, _ in rephrased:
            judgement = await trainee_llm_client.generate_topk_per_token(
                STATEMENT_JUDGEMENT_PROMPT["TEMPLATE"].format(statement=statement)
            )
            judgements.append(judgement[0].top_candidates)
        return yes_no_loss_entropy(judgements, gts)

    async def _judge_single_relation(
        edge: tuple,
    ):
        """Judge one edge ``(source_id, target_id, edge_data)`` and persist its loss."""
        source_id = edge[0]
        target_id = edge[1]
        edge_data = edge[2]
        if _already_judged(edge_data):
            logger.debug(
                "Edge %s -> %s already judged, loss: %s, skip",
                source_id,
                target_id,
                edge_data["loss"],
            )
            return source_id, target_id, edge_data
        try:
            # Read inside the try so a missing description also gets the fallback.
            description = edge_data["description"]
            loss = await _compute_loss(description)
            logger.debug(
                "Edge %s -> %s description: %s loss: %s",
                source_id,
                target_id,
                description,
                loss,
            )
        except Exception as e:  # pylint: disable=broad-except
            logger.error(
                "Error in judging relation %s -> %s: %s", source_id, target_id, e
            )
            logger.info("Use default loss -log(0.1)")
            loss = default_loss
        edge_data["loss"] = loss
        graph_storage.update_edge(source_id, target_id, edge_data)
        return source_id, target_id, edge_data

    edges = graph_storage.get_all_edges()
    await run_concurrent(
        _judge_single_relation,
        edges,
        desc="Judging relations",
        unit="relation",
        progress_bar=progress_bar,
    )

    async def _judge_single_entity(
        node: tuple,
    ):
        """Judge one node ``(node_id, node_data)`` and persist its loss."""
        node_id = node[0]
        node_data = node[1]
        if _already_judged(node_data):
            logger.debug(
                "Node %s already judged, loss: %s, skip", node_id, node_data["loss"]
            )
            return node_id, node_data
        try:
            # Read inside the try so a missing description also gets the fallback.
            description = node_data["description"]
            loss = await _compute_loss(description)
            logger.debug("Node %s description: %s loss: %s", node_id, description, loss)
        except Exception as e:  # pylint: disable=broad-except
            logger.error("Error in judging entity %s: %s", node_id, e)
            logger.info("Use default loss -log(0.1)")
            loss = default_loss
        node_data["loss"] = loss
        graph_storage.update_node(node_id, node_data)
        return node_id, node_data

    nodes = graph_storage.get_all_nodes()
    await run_concurrent(
        _judge_single_entity,
        nodes,
        desc="Judging entities",
        unit="entity",
        progress_bar=progress_bar,
    )
    return graph_storage
|