import re
from collections import Counter, defaultdict
from typing import Dict, List, Tuple

from graphgen.bases import BaseGraphStorage, BaseKGBuilder, BaseLLMWrapper, Chunk
from graphgen.templates import KG_EXTRACTION_PROMPT, KG_SUMMARIZATION_PROMPT
from graphgen.utils import (
    detect_main_language,
    handle_single_entity_extraction,
    handle_single_relationship_extraction,
    logger,
    pack_history_conversations,
    split_string_by_multi_markers,
)

# Separator used when packing multiple values (source ids, description
# fragments) into a single string field on a node/edge.
# NOTE(review): the original text used an empty string ("") as the separator,
# which makes ``str.split("")`` raise ValueError and makes the split/join
# round-trip in the merge methods a no-op.  "<SEP>" is the conventional
# LightRAG/GraphGen field separator — confirm against the storage backend.
GRAPH_FIELD_SEP = "<SEP>"


class LightRAGKGBuilder(BaseKGBuilder):
    """LightRAG-style KG builder.

    Extracts entities/relationships from text chunks via iterative LLM
    prompting ("gleaning"), and merges extraction results into a graph
    storage, summarizing over-long descriptions with the LLM.
    """

    def __init__(self, llm_client: BaseLLMWrapper, max_loop: int = 3):
        super().__init__(llm_client)
        # Upper bound on refinement ("gleaning") rounds in extract().
        self.max_loop = max_loop

    async def extract(
        self, chunk: Chunk
    ) -> Tuple[Dict[str, List[dict]], Dict[Tuple[str, str], List[dict]]]:
        """
        Extract entities and relationships from a single chunk using the LLM client.

        :param chunk: source chunk; its id is attached to every extraction
        :return: (nodes_data, edges_data) — nodes keyed by entity name,
                 edges keyed by (src_id, tgt_id)
        """
        chunk_id = chunk.id
        content = chunk.content

        # Step 1: choose the prompt variant matching the chunk's main language.
        language = detect_main_language(content)
        hint_prompt = KG_EXTRACTION_PROMPT[language]["TEMPLATE"].format(
            **KG_EXTRACTION_PROMPT["FORMAT"], input_text=content
        )

        # Step 2: initial extraction pass.
        final_result = await self.llm_client.generate_answer(hint_prompt)
        logger.debug("First extraction result: %s", final_result)

        # Step 3: iterative refinement — ask whether anything was missed and,
        # if the model answers "yes", request a continuation (up to max_loop).
        history = pack_history_conversations(hint_prompt, final_result)
        for loop_idx in range(self.max_loop):
            if_loop_result = await self.llm_client.generate_answer(
                text=KG_EXTRACTION_PROMPT[language]["IF_LOOP"], history=history
            )
            if_loop_result = if_loop_result.strip().strip('"').strip("'").lower()
            if if_loop_result != "yes":
                break

            glean_result = await self.llm_client.generate_answer(
                text=KG_EXTRACTION_PROMPT[language]["CONTINUE"], history=history
            )
            logger.debug("Loop %s glean: %s", loop_idx + 1, glean_result)

            history += pack_history_conversations(
                KG_EXTRACTION_PROMPT[language]["CONTINUE"], glean_result
            )
            final_result += glean_result

        # Step 4: split the accumulated output into individual records.
        records = split_string_by_multi_markers(
            final_result,
            [
                KG_EXTRACTION_PROMPT["FORMAT"]["record_delimiter"],
                KG_EXTRACTION_PROMPT["FORMAT"]["completion_delimiter"],
            ],
        )

        nodes = defaultdict(list)
        edges = defaultdict(list)

        for record in records:
            # Each record is expected to be a parenthesized tuple of attributes.
            match = re.search(r"\((.*)\)", record)
            if not match:
                continue
            inner = match.group(1)
            attributes = split_string_by_multi_markers(
                inner, [KG_EXTRACTION_PROMPT["FORMAT"]["tuple_delimiter"]]
            )

            # Try to parse as an entity first; fall back to a relationship.
            entity = await handle_single_entity_extraction(attributes, chunk_id)
            if entity is not None:
                nodes[entity["entity_name"]].append(entity)
                continue

            relation = await handle_single_relationship_extraction(attributes, chunk_id)
            if relation is not None:
                key = (relation["src_id"], relation["tgt_id"])
                edges[key].append(relation)

        return dict(nodes), dict(edges)

    async def merge_nodes(
        self,
        node_data: Tuple[str, List[dict]],
        kg_instance: BaseGraphStorage,
    ) -> None:
        """Merge newly extracted records for one entity with any stored node,
        summarize the combined description, and upsert the result.

        :param node_data: (entity_name, list of extracted entity dicts)
        :param kg_instance: graph storage to read from / write to
        """
        entity_name, new_records = node_data

        entity_types: List[str] = []
        source_ids: List[str] = []
        descriptions: List[str] = []

        # Fold in data already stored for this entity, if any.
        # NOTE(review): get_node()/upsert_node() are used synchronously here
        # (no await) — confirm BaseGraphStorage exposes a sync API.
        node = kg_instance.get_node(entity_name)
        if node is not None:
            entity_types.append(node["entity_type"])
            source_ids.extend(
                split_string_by_multi_markers(node["source_id"], [GRAPH_FIELD_SEP])
            )
            descriptions.append(node["description"])

        # Majority vote: keep the most frequent entity_type across old + new.
        entity_type = sorted(
            Counter([dp["entity_type"] for dp in new_records] + entity_types).items(),
            key=lambda x: x[1],
            reverse=True,
        )[0][0]

        # De-duplicate description fragments, join, and summarize if too long.
        description = GRAPH_FIELD_SEP.join(
            sorted(set([dp["description"] for dp in new_records] + descriptions))
        )
        description = await self._handle_kg_summary(entity_name, description)

        source_id = GRAPH_FIELD_SEP.join(
            set([dp["source_id"] for dp in new_records] + source_ids)
        )

        kg_instance.upsert_node(
            entity_name,
            node_data={
                "entity_type": entity_type,
                "description": description,
                "source_id": source_id,
            },
        )

    async def merge_edges(
        self,
        edges_data: Tuple[Tuple[str, str], List[dict]],
        kg_instance: BaseGraphStorage,
    ) -> None:
        """Merge newly extracted records for one relation with any stored edge,
        ensure both endpoint nodes exist, and upsert the edge.

        :param edges_data: ((src_id, tgt_id), list of extracted relation dicts)
        :param kg_instance: graph storage to read from / write to
        """
        (src_id, tgt_id), new_records = edges_data

        source_ids: List[str] = []
        descriptions: List[str] = []

        edge = kg_instance.get_edge(src_id, tgt_id)
        if edge is not None:
            source_ids.extend(
                split_string_by_multi_markers(edge["source_id"], [GRAPH_FIELD_SEP])
            )
            descriptions.append(edge["description"])

        description = GRAPH_FIELD_SEP.join(
            sorted(set([dp["description"] for dp in new_records] + descriptions))
        )
        source_id = GRAPH_FIELD_SEP.join(
            set([dp["source_id"] for dp in new_records] + source_ids)
        )

        # Ensure both endpoints exist before inserting the edge; placeholder
        # nodes get the pre-summary description and an UNKNOWN type.
        for insert_id in (src_id, tgt_id):
            if not kg_instance.has_node(insert_id):
                kg_instance.upsert_node(
                    insert_id,
                    node_data={
                        "source_id": source_id,
                        "description": description,
                        "entity_type": "UNKNOWN",
                    },
                )

        description = await self._handle_kg_summary(
            f"({src_id}, {tgt_id})", description
        )

        kg_instance.upsert_edge(
            src_id,
            tgt_id,
            edge_data={"source_id": source_id, "description": description},
        )

    async def _handle_kg_summary(
        self,
        entity_or_relation_name: str,
        description: str,
        max_summary_tokens: int = 200,
    ) -> str:
        """
        Summarize an over-long entity/relation description with the LLM.

        :param entity_or_relation_name: display name used in the prompt and log
        :param description: GRAPH_FIELD_SEP-joined description fragments
        :param max_summary_tokens: token budget; shorter descriptions pass through
        :return: the original description if within budget, else the LLM summary
        """
        tokenizer_instance = self.llm_client.tokenizer
        language = detect_main_language(description)

        tokens = tokenizer_instance.encode(description)
        if len(tokens) < max_summary_tokens:
            return description

        # Truncate to the token budget before asking for a summary.
        use_description = tokenizer_instance.decode(tokens[:max_summary_tokens])
        # Fix: the separator here was an empty string, and str.split("")
        # raises ValueError — split on the field separator instead.
        prompt = KG_SUMMARIZATION_PROMPT[language]["TEMPLATE"].format(
            entity_name=entity_or_relation_name,
            description_list=use_description.split(GRAPH_FIELD_SEP),
            **KG_SUMMARIZATION_PROMPT["FORMAT"],
        )
        new_description = await self.llm_client.generate_answer(prompt)
        logger.info(
            "Entity or relation %s summary: %s",
            entity_or_relation_name,
            new_description,
        )
        return new_description