GraphGen / graphgen /models /kg_builder /light_rag_kg_builder.py
github-actions[bot]
Auto-sync from demo at Tue Nov 25 11:19:13 UTC 2025
9e67c3b
import re
from collections import Counter, defaultdict
from typing import Dict, List, Tuple
from graphgen.bases import BaseGraphStorage, BaseKGBuilder, BaseLLMWrapper, Chunk
from graphgen.templates import KG_EXTRACTION_PROMPT, KG_SUMMARIZATION_PROMPT
from graphgen.utils import (
detect_main_language,
handle_single_entity_extraction,
handle_single_relationship_extraction,
logger,
pack_history_conversations,
split_string_by_multi_markers,
)
class LightRAGKGBuilder(BaseKGBuilder):
def __init__(self, llm_client: BaseLLMWrapper, max_loop: int = 3):
super().__init__(llm_client)
self.max_loop = max_loop
async def extract(
self, chunk: Chunk
) -> Tuple[Dict[str, List[dict]], Dict[Tuple[str, str], List[dict]]]:
"""
Extract entities and relationships from a single chunk using the LLM client.
:param chunk
:return: (nodes_data, edges_data)
"""
chunk_id = chunk.id
content = chunk.content
# step 1: language_detection
language = detect_main_language(content)
hint_prompt = KG_EXTRACTION_PROMPT[language]["TEMPLATE"].format(
**KG_EXTRACTION_PROMPT["FORMAT"], input_text=content
)
# step 2: initial glean
final_result = await self.llm_client.generate_answer(hint_prompt)
logger.debug("First extraction result: %s", final_result)
# step3: iterative refinement
history = pack_history_conversations(hint_prompt, final_result)
for loop_idx in range(self.max_loop):
if_loop_result = await self.llm_client.generate_answer(
text=KG_EXTRACTION_PROMPT[language]["IF_LOOP"], history=history
)
if_loop_result = if_loop_result.strip().strip('"').strip("'").lower()
if if_loop_result != "yes":
break
glean_result = await self.llm_client.generate_answer(
text=KG_EXTRACTION_PROMPT[language]["CONTINUE"], history=history
)
logger.debug("Loop %s glean: %s", loop_idx + 1, glean_result)
history += pack_history_conversations(
KG_EXTRACTION_PROMPT[language]["CONTINUE"], glean_result
)
final_result += glean_result
# step 4: parse the final result
records = split_string_by_multi_markers(
final_result,
[
KG_EXTRACTION_PROMPT["FORMAT"]["record_delimiter"],
KG_EXTRACTION_PROMPT["FORMAT"]["completion_delimiter"],
],
)
nodes = defaultdict(list)
edges = defaultdict(list)
for record in records:
match = re.search(r"\((.*)\)", record)
if not match:
continue
inner = match.group(1)
attributes = split_string_by_multi_markers(
inner, [KG_EXTRACTION_PROMPT["FORMAT"]["tuple_delimiter"]]
)
entity = await handle_single_entity_extraction(attributes, chunk_id)
if entity is not None:
nodes[entity["entity_name"]].append(entity)
continue
relation = await handle_single_relationship_extraction(attributes, chunk_id)
if relation is not None:
key = (relation["src_id"], relation["tgt_id"])
edges[key].append(relation)
return dict(nodes), dict(edges)
async def merge_nodes(
self,
node_data: tuple[str, List[dict]],
kg_instance: BaseGraphStorage,
) -> None:
entity_name, node_data = node_data
entity_types = []
source_ids = []
descriptions = []
node = kg_instance.get_node(entity_name)
if node is not None:
entity_types.append(node["entity_type"])
source_ids.extend(
split_string_by_multi_markers(node["source_id"], ["<SEP>"])
)
descriptions.append(node["description"])
# take the most frequent entity_type
entity_type = sorted(
Counter([dp["entity_type"] for dp in node_data] + entity_types).items(),
key=lambda x: x[1],
reverse=True,
)[0][0]
description = "<SEP>".join(
sorted(set([dp["description"] for dp in node_data] + descriptions))
)
description = await self._handle_kg_summary(entity_name, description)
source_id = "<SEP>".join(
set([dp["source_id"] for dp in node_data] + source_ids)
)
node_data = {
"entity_type": entity_type,
"description": description,
"source_id": source_id,
}
kg_instance.upsert_node(entity_name, node_data=node_data)
async def merge_edges(
self,
edges_data: tuple[Tuple[str, str], List[dict]],
kg_instance: BaseGraphStorage,
) -> None:
(src_id, tgt_id), edge_data = edges_data
source_ids = []
descriptions = []
edge = kg_instance.get_edge(src_id, tgt_id)
if edge is not None:
source_ids.extend(
split_string_by_multi_markers(edge["source_id"], ["<SEP>"])
)
descriptions.append(edge["description"])
description = "<SEP>".join(
sorted(set([dp["description"] for dp in edge_data] + descriptions))
)
source_id = "<SEP>".join(
set([dp["source_id"] for dp in edge_data] + source_ids)
)
for insert_id in [src_id, tgt_id]:
if not kg_instance.has_node(insert_id):
kg_instance.upsert_node(
insert_id,
node_data={
"source_id": source_id,
"description": description,
"entity_type": "UNKNOWN",
},
)
description = await self._handle_kg_summary(
f"({src_id}, {tgt_id})", description
)
kg_instance.upsert_edge(
src_id,
tgt_id,
edge_data={"source_id": source_id, "description": description},
)
async def _handle_kg_summary(
self,
entity_or_relation_name: str,
description: str,
max_summary_tokens: int = 200,
) -> str:
"""
Handle knowledge graph summary
:param entity_or_relation_name
:param description
:param max_summary_tokens
:return summary
"""
tokenizer_instance = self.llm_client.tokenizer
language = detect_main_language(description)
tokens = tokenizer_instance.encode(description)
if len(tokens) < max_summary_tokens:
return description
use_description = tokenizer_instance.decode(tokens[:max_summary_tokens])
prompt = KG_SUMMARIZATION_PROMPT[language]["TEMPLATE"].format(
entity_name=entity_or_relation_name,
description_list=use_description.split("<SEP>"),
**KG_SUMMARIZATION_PROMPT["FORMAT"],
)
new_description = await self.llm_client.generate_answer(prompt)
logger.info(
"Entity or relation %s summary: %s",
entity_or_relation_name,
new_description,
)
return new_description