Spaces:
Running
Running
| import re | |
| import asyncio | |
| from typing import List | |
| from collections import defaultdict | |
| import gradio as gr | |
| from tqdm.asyncio import tqdm as tqdm_async | |
| from graphgen.models import Chunk, OpenAIModel, Tokenizer | |
| from graphgen.models.storage.base_storage import BaseGraphStorage | |
| from graphgen.templates import KG_EXTRACTION_PROMPT | |
| from graphgen.utils import (logger, pack_history_conversations, split_string_by_multi_markers, | |
| handle_single_entity_extraction, handle_single_relationship_extraction, | |
| detect_if_chinese) | |
| from graphgen.operators.merge_kg import merge_nodes, merge_edges | |
# pylint: disable=too-many-statements
async def extract_kg(
        llm_client: OpenAIModel,
        kg_instance: BaseGraphStorage,
        tokenizer_instance: Tokenizer,
        chunks: List[Chunk],
        progress_bar: gr.Progress = None,
        max_concurrent: int = 1000
):
    """
    Extract entities and relationships from text chunks and merge them into a
    knowledge graph.

    Each chunk is processed concurrently (bounded by ``max_concurrent``):
    the LLM extracts records, optionally "gleans" more in up to 3 follow-up
    rounds, and the parsed nodes/edges are merged into ``kg_instance``.

    :param llm_client: Synthesizer LLM model to extract entities and relationships
    :param kg_instance: graph storage that receives the merged nodes and edges
    :param tokenizer_instance: tokenizer passed through to the merge helpers
    :param chunks: text chunks to extract from
    :param progress_bar: Gradio progress bar to show the progress of the extraction
    :param max_concurrent: maximum number of chunks processed concurrently
    :return: the populated ``kg_instance``
    """
    semaphore = asyncio.Semaphore(max_concurrent)

    async def _process_single_content(chunk: Chunk, max_loop: int = 3):
        # Extract (nodes, edges) from one chunk; up to max_loop gleaning rounds.
        async with semaphore:
            chunk_id = chunk.id
            content = chunk.content

            language = "Chinese" if detect_if_chinese(content) else "English"
            # Build a per-task copy of the prompt format instead of mutating the
            # shared module-level KG_EXTRACTION_PROMPT dict: concurrent tasks may
            # process chunks in different languages, so writing the language into
            # the shared dict is a race condition.
            prompt_format = {**KG_EXTRACTION_PROMPT["FORMAT"], "language": language}

            hint_prompt = KG_EXTRACTION_PROMPT[language]["TEMPLATE"].format(
                **prompt_format, input_text=content
            )

            final_result = await llm_client.generate_answer(hint_prompt)
            logger.info('First result: %s', final_result)

            history = pack_history_conversations(hint_prompt, final_result)
            for loop_index in range(max_loop):
                # Ask the model whether anything remains to be extracted.
                if_loop_result = await llm_client.generate_answer(
                    text=KG_EXTRACTION_PROMPT[language]["IF_LOOP"],
                    history=history
                )
                if_loop_result = if_loop_result.strip().strip('"').strip("'").lower()
                if if_loop_result != "yes":
                    break

                glean_result = await llm_client.generate_answer(
                    text=KG_EXTRACTION_PROMPT[language]["CONTINUE"],
                    history=history
                )
                logger.info('Loop %s glean: %s', loop_index, glean_result)

                history += pack_history_conversations(
                    KG_EXTRACTION_PROMPT[language]["CONTINUE"], glean_result
                )
                final_result += glean_result
                # NOTE: the original also broke explicitly on the last iteration,
                # which is redundant — range() terminates the loop anyway.

            records = split_string_by_multi_markers(
                final_result,
                [
                    prompt_format["record_delimiter"],
                    prompt_format["completion_delimiter"],
                ],
            )

            nodes = defaultdict(list)
            edges = defaultdict(list)

            for record in records:
                match = re.search(r"\((.*)\)", record)
                if match is None:
                    continue
                # Keep only the content inside the parentheses.
                inner = match.group(1)
                record_attributes = split_string_by_multi_markers(
                    inner, [prompt_format["tuple_delimiter"]]
                )

                # Try to parse the record as an entity first, then as a relation.
                entity = await handle_single_entity_extraction(record_attributes, chunk_id)
                if entity is not None:
                    nodes[entity["entity_name"]].append(entity)
                    continue
                relation = await handle_single_relationship_extraction(record_attributes, chunk_id)
                if relation is not None:
                    edges[(relation["src_id"], relation["tgt_id"])].append(relation)

            return dict(nodes), dict(edges)

    results = []
    chunk_number = len(chunks)
    async for result in tqdm_async(
            asyncio.as_completed([_process_single_content(c) for c in chunks]),
            total=len(chunks),
            desc="[3/4]Extracting entities and relationships from chunks",
            unit="chunk",
    ):
        try:
            if progress_bar is not None:
                progress_bar(len(results) / chunk_number,
                             desc="[3/4]Extracting entities and relationships from chunks")
            results.append(await result)
            if progress_bar is not None and len(results) == chunk_number:
                progress_bar(1, desc="[3/4]Extracting entities and relationships from chunks")
        except Exception as e:  # pylint: disable=broad-except
            # Best-effort: a failed chunk is logged and skipped, not fatal.
            logger.error("Error occurred while extracting entities and relationships from chunks: %s", e)

    # Aggregate per-chunk results before merging into the graph.
    nodes = defaultdict(list)
    edges = defaultdict(list)
    for n, e in results:
        for k, v in n.items():
            nodes[k].extend(v)
        for k, v in e.items():
            # Edges are treated as undirected: sort the endpoint pair so that
            # (a, b) and (b, a) merge under the same key.
            edges[tuple(sorted(k))].extend(v)

    await merge_nodes(nodes, kg_instance, llm_client, tokenizer_instance)
    await merge_edges(edges, kg_instance, llm_client, tokenizer_instance)

    return kg_instance