from typing import Any

from graphgen.bases import BaseGraphStorage, BaseKVStorage, BaseTokenizer
from graphgen.models import (
    AnchorBFSPartitioner,
    BFSPartitioner,
    DFSPartitioner,
    ECEPartitioner,
    LeidenPartitioner,
)
from graphgen.utils import logger

from .pre_tokenize import pre_tokenize


async def partition_kg(
    kg_instance: BaseGraphStorage,
    chunk_storage: BaseKVStorage,
    tokenizer: BaseTokenizer | None = None,
    partition_config: dict | None = None,
) -> list[
    tuple[list[tuple[str, dict]], list[tuple[Any, Any, dict] | tuple[Any, Any, Any]]]
]:
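    """
    Partition the knowledge graph into communities and convert them to batches.

    :param kg_instance: graph storage holding the knowledge graph to partition.
    :param chunk_storage: KV storage with the source chunks referenced by nodes.
    :param tokenizer: tokenizer used to pre-compute token lengths (needed by ECE).
    :param partition_config: dict with a "method" key ("bfs", "dfs", "ece",
        "leiden", or "anchor_bfs") and a "method_params" dict that is forwarded
        to the chosen partitioner.
    :return: a list of (nodes, edges) batches, one per community.
    """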
    if partition_config is None:
        raise ValueError("partition_config must be provided.")
    method = partition_config["method"]
    method_params = partition_config["method_params"]

    if method == "bfs":
        logger.info("Partitioning knowledge graph using BFS method.")
        partitioner = BFSPartitioner()
    elif method == "dfs":
        logger.info("Partitioning knowledge graph using DFS method.")
        partitioner = DFSPartitioner()
    elif method == "ece":
        logger.info("Partitioning knowledge graph using ECE method.")
        # TODO: before ECE partitioning, we need to:
        # 1. 'quiz and judge' to get the comprehension loss if unit_sampling is not random
        # 2. pre-tokenize nodes and edges to get the token length
        edges = kg_instance.get_all_edges()
        nodes = kg_instance.get_all_nodes()
        await pre_tokenize(kg_instance, tokenizer, edges, nodes)
        partitioner = ECEPartitioner()
    elif method == "leiden":
        logger.info("Partitioning knowledge graph using Leiden method.")
        partitioner = LeidenPartitioner()
    elif method == "anchor_bfs":
        logger.info("Partitioning knowledge graph using Anchor BFS method.")
        partitioner = AnchorBFSPartitioner(
            anchor_type=method_params.get("anchor_type"),
            anchor_ids=set(method_params.get("anchor_ids", []))
            if method_params.get("anchor_ids")
            else None,
        )
    else:
        raise ValueError(f"Unsupported partition method: {method}")
    communities = await partitioner.partition(g=kg_instance, **method_params)
    logger.info("Partitioned the graph into %d communities.", len(communities))

    batches = await partitioner.community2batch(communities, g=kg_instance)
    batches = await attach_additional_data_to_node(batches, chunk_storage)
    return batches
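

# Example call (illustrative sketch; the storage objects and the "max_size"
# parameter below are hypothetical stand-ins, not part of this module --
# pass whichever method_params your chosen partitioner actually accepts):
#
#     batches = await partition_kg(
#         kg_instance=graph_storage,
#         chunk_storage=chunk_storage,
#         partition_config={
#             "method": "bfs",
#             "method_params": {"max_size": 10},
#         },
#     )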


async def attach_additional_data_to_node(
    batches: list[
        tuple[
            list[tuple[str, dict]], list[tuple[Any, Any, dict] | tuple[Any, Any, Any]]
        ]
    ],
    chunk_storage: BaseKVStorage,
) -> list[
    tuple[list[tuple[str, dict]], list[tuple[Any, Any, dict] | tuple[Any, Any, Any]]]
]:
    """
    Attach additional data from chunk_storage to nodes in the batches.

    :param batches: list of (nodes, edges) tuples produced by a partitioner,
        where each node is a (node_id, node_data) pair.
    :param chunk_storage: KV storage holding the chunks referenced by the
        nodes' source_id fields.
    :return: the same batches, with node data enriched in place.
    """
    for batch in batches:
        for node_id, node_data in batch[0]:
            await _attach_by_type(node_id, node_data, chunk_storage)
    return batches
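

# Shape of a single batch, for reference (values are hypothetical):
#
#     (
#         [("node_a", {"entity_type": "person", "source_id": "chunk-1"})],  # nodes
#         [("node_a", "node_b", {"description": "..."})],                   # edges
#     )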


async def _attach_by_type(
    node_id: str,
    node_data: dict,
    chunk_storage: BaseKVStorage,
) -> None:
    """
    Attach additional data to the node based on its entity type.
    """
    entity_type = (node_data.get("entity_type") or "").lower()
    if not entity_type:
        return

    source_ids = [
        sid.strip()
        for sid in node_data.get("source_id", "").split("<SEP>")
        if sid.strip()
    ]

    # Handle images
    if "image" in entity_type:
        # get_by_id is awaited on the assumption that BaseKVStorage exposes an
        # async interface, consistent with the other awaited storage calls in
        # this module; drop the await if your backend is synchronous.
        image_chunks = [
            data
            for sid in source_ids
            if "image" in sid.lower() and (data := await chunk_storage.get_by_id(sid))
        ]
        if image_chunks:
            # The generator expects a dictionary with an 'img_path' key, not a list of captions.
            # We'll use the first image chunk found for this node.
            node_data["images"] = image_chunks[0]
            logger.debug("Attached image data to node %s", node_id)
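

# Illustrative node_data (hypothetical values) that would trigger the image
# branch above:
#
#     {"entity_type": "image", "source_id": "image-chunk-1<SEP>text-chunk-7"}
#
# "image-chunk-1" passes the substring check, so the chunk stored under that
# id is attached as node_data["images"].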