from typing import Any

from graphgen.bases import BaseGraphStorage, BaseKVStorage, BaseTokenizer
from graphgen.models import (
    AnchorBFSPartitioner,
    BFSPartitioner,
    DFSPartitioner,
    ECEPartitioner,
    LeidenPartitioner,
)
from graphgen.utils import logger

from .pre_tokenize import pre_tokenize


async def partition_kg(
    kg_instance: BaseGraphStorage,
    chunk_storage: BaseKVStorage,
    tokenizer: BaseTokenizer = None,
    partition_config: dict = None,
) -> list[
    tuple[list[tuple[str, dict]], list[tuple[Any, Any, dict] | tuple[Any, Any, Any]]]
]:
    method = partition_config["method"]
    method_params = partition_config["method_params"]

    if method == "bfs":
        logger.info("Partitioning knowledge graph using BFS method.")
        partitioner = BFSPartitioner()
    elif method == "dfs":
        logger.info("Partitioning knowledge graph using DFS method.")
        partitioner = DFSPartitioner()
    elif method == "ece":
        logger.info("Partitioning knowledge graph using ECE method.")
        # TODO: before ECE partitioning, we need to:
        # 1. 'quiz and judge' to get the comprehension loss if unit_sampling is not random
        # 2. pre-tokenize nodes and edges to get the token length
        edges = kg_instance.get_all_edges()
        nodes = kg_instance.get_all_nodes()
        await pre_tokenize(kg_instance, tokenizer, edges, nodes)
        partitioner = ECEPartitioner()
    elif method == "leiden":
        logger.info("Partitioning knowledge graph using Leiden method.")
        partitioner = LeidenPartitioner()
    elif method == "anchor_bfs":
        logger.info("Partitioning knowledge graph using Anchor BFS method.")
        partitioner = AnchorBFSPartitioner(
            anchor_type=method_params.get("anchor_type"),
            anchor_ids=set(method_params.get("anchor_ids", []))
            if method_params.get("anchor_ids")
            else None,
        )
    else:
        raise ValueError(f"Unsupported partition method: {method}")

    communities = await partitioner.partition(g=kg_instance, **method_params)
    logger.info("Partitioned the graph into %d communities.", len(communities))

    batches = await partitioner.community2batch(communities, g=kg_instance)
    batches = await attach_additional_data_to_node(batches, chunk_storage)
    return batches


async def attach_additional_data_to_node(
    batches: list[
        tuple[
            list[tuple[str, dict]], list[tuple[Any, Any, dict] | tuple[Any, Any, Any]]
        ]
    ],
    chunk_storage: BaseKVStorage,
) -> list[
    tuple[list[tuple[str, dict]], list[tuple[Any, Any, dict] | tuple[Any, Any, Any]]]
]:
    """
    Attach additional data from chunk_storage to nodes in the batches.

    :param batches: partitioned batches of (nodes, edges) produced by ``community2batch``
    :param chunk_storage: KV storage holding the source chunks referenced by the nodes
    :return: the same batches, with node data enriched in place
    """
    for batch in batches:
        for node_id, node_data in batch[0]:
            await _attach_by_type(node_id, node_data, chunk_storage)
    return batches


async def _attach_by_type(
    node_id: str,
    node_data: dict,
    chunk_storage: BaseKVStorage,
) -> None:
    """
    Attach additional data to the node based on its entity type.
    """
    entity_type = (node_data.get("entity_type") or "").lower()
    if not entity_type:
        return

    # NOTE: this assumes source_id values are joined with a "<SEP>" separator;
    # splitting on an empty string (as before) raises ValueError.
    source_ids = [
        sid.strip()
        for sid in node_data.get("source_id", "").split("<SEP>")
        if sid.strip()
    ]

    # Handle images
    if "image" in entity_type:
        image_chunks = [
            data
            for sid in source_ids
            if "image" in sid.lower() and (data := chunk_storage.get_by_id(sid))
        ]
        if image_chunks:
            # The generator expects a dictionary with an 'img_path' key, not a list of captions.
            # We'll use the first image chunk found for this node.
            node_data["images"] = image_chunks[0]
            logger.debug("Attached image data to node %s", node_id)
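

# Usage sketch (illustrative only). The storage objects and tokenizer below are
# placeholders, not concrete classes from this repository; only the shape of
# ``partition_config`` ({"method": ..., "method_params": {...}}) is taken from
# the code above.
#
#   import asyncio
#
#   async def _demo():
#       kg = ...      # any BaseGraphStorage implementation
#       chunks = ...  # any BaseKVStorage implementation
#       batches = await partition_kg(
#           kg_instance=kg,
#           chunk_storage=chunks,
#           tokenizer=None,  # only used by the "ece" method for pre-tokenization
#           partition_config={"method": "bfs", "method_params": {...}},
#       )
#       # Each batch is (nodes, edges): a list of (node_id, node_data) pairs
#       # plus a list of edge tuples for one community.
#
#   asyncio.run(_demo())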