# Auto-synced from demo by github-actions[bot] at Tue Nov 25 11:19:13 UTC 2025
# (commit 9e67c3b, 3.93 kB).
from typing import Any
from graphgen.bases import BaseGraphStorage, BaseKVStorage, BaseTokenizer
from graphgen.models import (
AnchorBFSPartitioner,
BFSPartitioner,
DFSPartitioner,
ECEPartitioner,
LeidenPartitioner,
)
from graphgen.utils import logger
from .pre_tokenize import pre_tokenize
async def partition_kg(
    kg_instance: BaseGraphStorage,
    chunk_storage: BaseKVStorage,
    tokenizer: Any = BaseTokenizer,
    partition_config: dict = None,
) -> list[
    tuple[list[tuple[str, dict]], list[tuple[Any, Any, dict] | tuple[Any, Any, Any]]]
]:
    """
    Partition the knowledge graph into batches and enrich their nodes.

    The partitioner is selected by ``partition_config["method"]``, one of
    ``"bfs"``, ``"dfs"``, ``"ece"``, ``"leiden"`` or ``"anchor_bfs"``.
    ``partition_config["method_params"]`` is forwarded as keyword arguments to
    the partitioner's ``partition`` call. For ``"anchor_bfs"``, its
    ``anchor_type`` / ``anchor_ids`` entries also configure the partitioner.

    :param kg_instance: graph storage to partition.
    :param chunk_storage: key-value storage used to attach extra data
        (currently image chunks) to the nodes of the resulting batches.
    :param tokenizer: tokenizer handed to ``pre_tokenize`` for the ``"ece"``
        method. NOTE(review): the default is the ``BaseTokenizer`` class
        itself, not an instance; callers presumably always pass one. Confirm.
    :param partition_config: dict with a required ``"method"`` key and an
        optional ``"method_params"`` dict (treated as empty when missing or
        None).
    :return: list of ``(nodes, edges)`` batches whose node data has been
        enriched from ``chunk_storage``.
    :raises ValueError: if ``partition_config`` is missing or has no
        ``"method"`` key, or if the method is not supported.
    """
    # Fail with a clear message instead of a TypeError/KeyError when the
    # (None-defaulted) config is absent or incomplete.
    if not partition_config or "method" not in partition_config:
        raise ValueError(
            "partition_config must be a dict with a 'method' key, "
            f"got: {partition_config!r}"
        )
    method = partition_config["method"]
    # Tolerate a missing/null "method_params" so parameterless methods can be
    # configured with just {"method": ...}; both .get and ** need a dict.
    method_params = partition_config.get("method_params") or {}

    if method == "bfs":
        logger.info("Partitioning knowledge graph using BFS method.")
        partitioner = BFSPartitioner()
    elif method == "dfs":
        logger.info("Partitioning knowledge graph using DFS method.")
        partitioner = DFSPartitioner()
    elif method == "ece":
        logger.info("Partitioning knowledge graph using ECE method.")
        # TODO: before ECE partitioning, we need to:
        # 1. 'quiz and judge' to get the comprehension loss if unit_sampling is not random
        # 2. pre-tokenize nodes and edges to get the token length
        edges = kg_instance.get_all_edges()
        nodes = kg_instance.get_all_nodes()
        await pre_tokenize(kg_instance, tokenizer, edges, nodes)
        partitioner = ECEPartitioner()
    elif method == "leiden":
        logger.info("Partitioning knowledge graph using Leiden method.")
        partitioner = LeidenPartitioner()
    elif method == "anchor_bfs":
        logger.info("Partitioning knowledge graph using Anchor BFS method.")
        anchor_ids = method_params.get("anchor_ids")
        partitioner = AnchorBFSPartitioner(
            anchor_type=method_params.get("anchor_type"),
            # Pass None (not an empty set) when no anchor ids are configured.
            anchor_ids=set(anchor_ids) if anchor_ids else None,
        )
    else:
        raise ValueError(f"Unsupported partition method: {method}")

    # NOTE(review): for "anchor_bfs", anchor_type/anchor_ids are forwarded
    # here again via **method_params; presumably partition() accepts or
    # ignores them. Confirm.
    communities = await partitioner.partition(g=kg_instance, **method_params)
    logger.info("Partitioned the graph into %d communities.", len(communities))
    batches = await partitioner.community2batch(communities, g=kg_instance)
    batches = await attach_additional_data_to_node(batches, chunk_storage)
    return batches
async def attach_additional_data_to_node(
    batches: list[
        tuple[
            list[tuple[str, dict]], list[tuple[Any, Any, dict] | tuple[Any, Any, Any]]
        ]
    ],
    chunk_storage: BaseKVStorage,
) -> list[
    tuple[list[tuple[str, dict]], list[tuple[Any, Any, dict] | tuple[Any, Any, Any]]]
]:
    """
    Enrich every node in every batch with data from ``chunk_storage``.

    Only the first element of each batch (its list of ``(node_id, node_data)``
    pairs) is visited. Each node's data dict is handed to ``_attach_by_type``,
    which may modify it in place. The same ``batches`` object is returned.

    :param batches: list of ``(nodes, edges)`` batches.
    :param chunk_storage: key-value storage holding the source chunks.
    :return: the input ``batches`` object, with node data updated in place.
    """
    # Lazily flatten all (node_id, node_data) pairs across batches, in order;
    # nodes are processed one after another, never concurrently.
    every_node = (pair for group in batches for pair in group[0])
    for ident, attrs in every_node:
        await _attach_by_type(ident, attrs, chunk_storage)
    return batches
async def _attach_by_type(
    node_id: str,
    node_data: dict,
    chunk_storage: BaseKVStorage,
) -> None:
    """
    Attach additional data to the node based on its entity type.

    Only image-like entities are handled so far. If ``entity_type`` contains
    ``"image"`` (case-insensitive), the node's ``<SEP>``-separated
    ``source_id`` values are scanned for ids containing ``"image"``. The first
    such id that resolves to a truthy chunk in ``chunk_storage`` is stored
    under ``node_data["images"]``. ``node_data`` is mutated in place.

    :param node_id: id of the node (used for logging only).
    :param node_data: mutable node attribute dict.
    :param chunk_storage: storage used to look up source chunks by id.
    """
    entity_type = (node_data.get("entity_type") or "").lower()
    if not entity_type:
        return

    # `or ""` also covers an explicit None source_id (a plain default would
    # not), mirroring the entity_type handling above.
    source_ids = [
        sid.strip()
        for sid in (node_data.get("source_id") or "").split("<SEP>")
        if sid.strip()
    ]

    # Handle images
    if "image" in entity_type:
        # The generator expects a dictionary with an 'img_path' key, not a list
        # of captions, so only the first matching image chunk is used; stop
        # querying storage as soon as one is found.
        image_chunk = next(
            (
                data
                for sid in source_ids
                if "image" in sid.lower() and (data := chunk_storage.get_by_id(sid))
            ),
            None,
        )
        if image_chunk:
            node_data["images"] = image_chunk
            logger.debug("Attached image data to node %s", node_id)