github-actions[bot]
Auto-sync from demo at Tue Nov 25 11:19:13 UTC 2025
9e67c3b
raw
history blame
4.45 kB
import math
import gradio as gr
from graphgen.bases import BaseLLMWrapper
from graphgen.models import JsonKVStorage, NetworkXStorage
from graphgen.templates import STATEMENT_JUDGEMENT_PROMPT
from graphgen.utils import logger, run_concurrent, yes_no_loss_entropy
async def judge_statement(  # pylint: disable=too-many-statements
    trainee_llm_client: BaseLLMWrapper,
    graph_storage: NetworkXStorage,
    rephrase_storage: JsonKVStorage,
    re_judge: bool = False,
    progress_bar: gr.Progress = None,
) -> NetworkXStorage:
    """
    Judge every edge (relation) and node (entity) in the graph.

    For each item, the rephrased (statement, ground_truth) pairs are fetched
    from ``rephrase_storage`` and each statement is sent to the trainee LLM;
    the yes/no entropy over the model's top-k answers becomes the item's
    comprehension ``loss``, which is written back into the graph storage.

    :param trainee_llm_client: LLM used to judge the statements
        (comprehension loss)
    :param graph_storage: graph storage instance holding nodes and edges
    :param rephrase_storage: storage mapping a description to its rephrased
        (statement, ground_truth) pairs
    :param re_judge: when True, re-judge items that already carry a loss
    :param progress_bar: optional gradio progress bar, forwarded to
        run_concurrent
    :return: the (mutated) graph storage instance
    """

    async def _statement_loss(description: str) -> float:
        # Shared judging core for both relations and entities:
        # fetch rephrased variants, query the LLM per variant, and fold the
        # top-k answers into a single entropy-based loss.
        #
        # NOTE(review): get_by_id is called without await — confirm it is
        # synchronous on JsonKVStorage; if it returned a coroutine, the
        # None-check below could never fail and the loop would misbehave.
        descriptions = rephrase_storage.get_by_id(description)
        assert descriptions is not None

        judgements = []
        gts = [gt for _, gt in descriptions]
        for statement, _gt in descriptions:
            judgement = await trainee_llm_client.generate_topk_per_token(
                STATEMENT_JUDGEMENT_PROMPT["TEMPLATE"].format(statement=statement)
            )
            # Only the first token's top candidates matter for yes/no judging.
            judgements.append(judgement[0].top_candidates)
        return yes_no_loss_entropy(judgements, gts)

    async def _judge_single_relation(
        edge: tuple,
    ):
        source_id, target_id, edge_data = edge

        # Skip edges that already carry a loss unless a re-judge was requested.
        if (not re_judge) and edge_data.get("loss") is not None:
            logger.debug(
                "Edge %s -> %s already judged, loss: %s, skip",
                source_id,
                target_id,
                edge_data["loss"],
            )
            return source_id, target_id, edge_data

        description = edge_data["description"]
        try:
            edge_data["loss"] = await _statement_loss(description)
            logger.debug(
                "Edge %s -> %s description: %s loss: %s",
                source_id,
                target_id,
                description,
                edge_data["loss"],
            )
        except Exception as e:  # pylint: disable=broad-except
            logger.error(
                "Error in judging relation %s -> %s: %s", source_id, target_id, e
            )
            logger.error("Use default loss 0.1")
            # Fall back to the loss of a statement judged correct with p=0.1.
            edge_data["loss"] = -math.log(0.1)

        # NOTE(review): update_edge is not awaited — confirm NetworkXStorage
        # exposes a synchronous update_edge.
        graph_storage.update_edge(source_id, target_id, edge_data)
        return source_id, target_id, edge_data

    edges = graph_storage.get_all_edges()
    await run_concurrent(
        _judge_single_relation,
        edges,
        desc="Judging relations",
        unit="relation",
        progress_bar=progress_bar,
    )

    async def _judge_single_entity(
        node: tuple,
    ):
        node_id, node_data = node

        # Skip nodes that already carry a loss unless a re-judge was requested.
        if (not re_judge) and node_data.get("loss") is not None:
            logger.debug(
                "Node %s already judged, loss: %s, skip", node_id, node_data["loss"]
            )
            return node_id, node_data

        description = node_data["description"]
        try:
            node_data["loss"] = await _statement_loss(description)
            logger.debug(
                "Node %s description: %s loss: %s",
                node_id,
                description,
                node_data["loss"],
            )
        except Exception as e:  # pylint: disable=broad-except
            logger.error("Error in judging entity %s: %s", node_id, e)
            logger.error("Use default loss 0.1")
            # Same fallback as the relation branch, for consistency.
            node_data["loss"] = -math.log(0.1)

        # NOTE(review): update_node is not awaited — confirm NetworkXStorage
        # exposes a synchronous update_node.
        graph_storage.update_node(node_id, node_data)
        return node_id, node_data

    nodes = graph_storage.get_all_nodes()
    await run_concurrent(
        _judge_single_entity,
        nodes,
        desc="Judging entities",
        unit="entity",
        progress_bar=progress_bar,
    )

    return graph_storage