from typing import Any

from graphgen.bases import BaseGraphStorage, BaseKVStorage, BaseTokenizer
from graphgen.models import (
    AnchorBFSPartitioner,
    BFSPartitioner,
    DFSPartitioner,
    ECEPartitioner,
    LeidenPartitioner,
)
from graphgen.utils import logger

from .pre_tokenize import pre_tokenize


async def partition_kg(
    kg_instance: BaseGraphStorage,
    chunk_storage: BaseKVStorage,
    tokenizer: Any = BaseTokenizer,
    partition_config: dict | None = None,
) -> list[
    tuple[list[tuple[str, dict]], list[tuple[Any, Any, dict] | tuple[Any, Any, Any]]]
]:
    """Partition the knowledge graph into batches of nodes and edges.

    The partitioner is selected by ``partition_config["method"]`` (one of
    ``"bfs"``, ``"dfs"``, ``"ece"``, ``"leiden"`` or ``"anchor_bfs"``) and is
    called with ``partition_config["method_params"]`` expanded as keyword
    arguments. The resulting communities are converted to batches through the
    partitioner's ``community2batch`` method.

    After batching, every node whose ``entity_type`` contains ``"image"``
    (case-insensitive) is looked up in ``chunk_storage``; when a truthy record
    is found it is stored on the node data under the ``"images"`` key.

    Args:
        kg_instance: Graph storage to partition.
        chunk_storage: Key-value storage queried for image node data.
        tokenizer: Tokenizer handed to ``pre_tokenize``; only used by the
            ``"ece"`` method.
        partition_config: Mapping with a required ``"method"`` entry and an
            optional ``"method_params"`` mapping (a missing or ``None`` value
            is treated as no parameters).

    Returns:
        The batches produced by ``community2batch``; each batch is unpacked
        here as a ``(nodes, edges)`` pair, with image nodes enriched in place.

    Raises:
        ValueError: If ``partition_config`` is ``None`` or the method is not
            supported.
        KeyError: If ``partition_config`` has no ``"method"`` entry.
    """
    if partition_config is None:
        # Fail with an explicit message instead of the opaque
        # "'NoneType' object is not subscriptable" TypeError.
        raise ValueError("partition_config must be provided")

    method = partition_config["method"]
    # Methods that take no parameters may omit "method_params" entirely.
    method_params = partition_config.get("method_params") or {}

    if method == "bfs":
        logger.info("Partitioning knowledge graph using BFS method.")
        partitioner = BFSPartitioner()
    elif method == "dfs":
        logger.info("Partitioning knowledge graph using DFS method.")
        partitioner = DFSPartitioner()
    elif method == "ece":
        logger.info("Partitioning knowledge graph using ECE method.")
        # TODO: before ECE partitioning, we need to:
        # 1. 'quiz and judge' to get the comprehension loss if unit_sampling is not random
        # 2. pre-tokenize nodes and edges to get the token length
        edges = await kg_instance.get_all_edges()
        nodes = await kg_instance.get_all_nodes()
        await pre_tokenize(kg_instance, tokenizer, edges, nodes)
        partitioner = ECEPartitioner()
    elif method == "leiden":
        logger.info("Partitioning knowledge graph using Leiden method.")
        partitioner = LeidenPartitioner()
    elif method == "anchor_bfs":
        logger.info("Partitioning knowledge graph using Anchor BFS method.")
        anchor_ids = method_params.get("anchor_ids")
        partitioner = AnchorBFSPartitioner(
            anchor_type=method_params.get("anchor_type"),
            # An empty or missing list means "no explicit anchors".
            anchor_ids=set(anchor_ids) if anchor_ids else None,
        )
    else:
        raise ValueError(f"Unsupported partition method: {method}")

    communities = await partitioner.partition(g=kg_instance, **method_params)
    logger.info("Partitioned the graph into %d communities.", len(communities))
    batches = await partitioner.community2batch(communities, g=kg_instance)

    # Attach stored image data to image-typed nodes; edges are left untouched.
    for batch_nodes, _ in batches:
        for node_id, node_data in batch_nodes:
            entity_type = node_data.get("entity_type")
            if not entity_type or "image" not in entity_type.lower():
                continue
            # The storage key is the node id without surrounding double
            # quotes, lower-cased.
            image_key = node_id.strip('"').lower()
            image_data = await chunk_storage.get_by_id(image_key)
            if image_data:
                node_data["images"] = image_data

    return batches