Spaces:
Running
Running
| import math | |
| import gradio as gr | |
| from graphgen.bases import BaseLLMWrapper | |
| from graphgen.models import JsonKVStorage, NetworkXStorage | |
| from graphgen.templates import STATEMENT_JUDGEMENT_PROMPT | |
| from graphgen.utils import logger, run_concurrent, yes_no_loss_entropy | |
async def judge_statement(  # pylint: disable=too-many-statements
    trainee_llm_client: BaseLLMWrapper,
    graph_storage: NetworkXStorage,
    rephrase_storage: JsonKVStorage,
    re_judge: bool = False,
    progress_bar: gr.Progress = None,
) -> NetworkXStorage:
    """
    Judge every edge and node description with the trainee model and store the
    resulting comprehension loss on the element under the ``"loss"`` key.

    For each element, ``rephrase_storage.get_by_id(description)`` is expected to
    return an iterable of ``(statement, ground_truth)`` pairs. Each statement is
    formatted into ``STATEMENT_JUDGEMENT_PROMPT["TEMPLATE"]`` and sent to
    ``trainee_llm_client.generate_topk_per_token``. The ``top_candidates`` of the
    first returned item are collected and combined with the ground truths by
    ``yes_no_loss_entropy``. If judging an element fails for any reason, a
    fallback loss of ``-log(0.1)`` is stored instead.

    Edges are judged first, then nodes; each element is written back with
    ``graph_storage.update_edge`` / ``graph_storage.update_node``.

    :param trainee_llm_client: model used to judge the statements
    :param graph_storage: graph whose edges and nodes are judged and updated
    :param rephrase_storage: maps an original description to its rephrased
        ``(statement, ground_truth)`` pairs
    :param re_judge: if True, recompute the loss even for elements that already
        carry a non-None ``"loss"``
    :param progress_bar: optional Gradio progress bar forwarded to ``run_concurrent``
    :return: the same ``graph_storage`` instance
    """
    # Fallback loss used whenever judging fails (≈ 2.303, i.e. -log of probability 0.1).
    default_loss = -math.log(0.1)

    async def _compute_loss(description: str):
        """Judge every rephrasing of ``description`` and return the aggregated loss."""
        rephrased = rephrase_storage.get_by_id(description)
        # Explicit check instead of ``assert`` (stripped under ``python -O``); the
        # raised error routes the caller to the default-loss fallback.
        if not rephrased:
            raise ValueError(
                f"no rephrased statements found for description: {description!r}"
            )
        gts = [gt for _, gt in rephrased]
        judgements = []
        # Use a distinct loop name so the caller's ``description`` is not shadowed.
        for statement, _ in rephrased:
            judgement = await trainee_llm_client.generate_topk_per_token(
                STATEMENT_JUDGEMENT_PROMPT["TEMPLATE"].format(statement=statement)
            )
            judgements.append(judgement[0].top_candidates)
        return yes_no_loss_entropy(judgements, gts)

    async def _judge_single_relation(
        edge: tuple,
    ):
        """Judge one ``(source_id, target_id, edge_data, ...)`` edge and persist its loss."""
        source_id = edge[0]
        target_id = edge[1]
        edge_data = edge[2]
        if (not re_judge) and "loss" in edge_data and edge_data["loss"] is not None:
            logger.debug(
                "Edge %s -> %s already judged, loss: %s, skip",
                source_id,
                target_id,
                edge_data["loss"],
            )
            return source_id, target_id, edge_data

        description = edge_data["description"]
        try:
            loss = await _compute_loss(description)
            logger.debug(
                "Edge %s -> %s description: %s loss: %s",
                source_id,
                target_id,
                description,
                loss,
            )
        except Exception as e:  # pylint: disable=broad-except
            # Best-effort: one failing edge must not abort the whole batch.
            logger.error(
                "Error in judging relation %s -> %s: %s", source_id, target_id, e
            )
            logger.info("Use default loss -log(0.1)")
            loss = default_loss
        edge_data["loss"] = loss

        graph_storage.update_edge(source_id, target_id, edge_data)
        return source_id, target_id, edge_data

    edges = graph_storage.get_all_edges()

    await run_concurrent(
        _judge_single_relation,
        edges,
        desc="Judging relations",
        unit="relation",
        progress_bar=progress_bar,
    )

    async def _judge_single_entity(
        node: tuple,
    ):
        """Judge one ``(node_id, node_data, ...)`` node and persist its loss."""
        node_id = node[0]
        node_data = node[1]

        if (not re_judge) and "loss" in node_data and node_data["loss"] is not None:
            logger.debug(
                "Node %s already judged, loss: %s, skip", node_id, node_data["loss"]
            )
            return node_id, node_data

        description = node_data["description"]
        try:
            loss = await _compute_loss(description)
            logger.debug("Node %s description: %s loss: %s", node_id, description, loss)
        except Exception as e:  # pylint: disable=broad-except
            # Best-effort: one failing node must not abort the whole batch.
            logger.error("Error in judging entity %s: %s", node_id, e)
            logger.info("Use default loss -log(0.1)")
            loss = default_loss
        node_data["loss"] = loss

        graph_storage.update_node(node_id, node_data)
        return node_id, node_data

    nodes = graph_storage.get_all_nodes()

    await run_concurrent(
        _judge_single_entity,
        nodes,
        desc="Judging entities",
        unit="entity",
        progress_bar=progress_bar,
    )

    return graph_storage