File size: 2,719 Bytes
799ac7c
 
0b9d8c7
799ac7c
0b9d8c7
799ac7c
 
 
 
 
 
 
 
 
 
 
 
0b9d8c7
799ac7c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0b9d8c7
 
 
 
 
 
 
 
799ac7c
 
 
 
 
 
0b9d8c7
 
 
 
 
 
 
 
 
 
799ac7c
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
from typing import Any

from graphgen.bases import BaseGraphStorage, BaseKVStorage, BaseTokenizer
from graphgen.models import (
    AnchorBFSPartitioner,
    BFSPartitioner,
    DFSPartitioner,
    ECEPartitioner,
    LeidenPartitioner,
)
from graphgen.utils import logger

from .pre_tokenize import pre_tokenize


async def partition_kg(
    kg_instance: BaseGraphStorage,
    chunk_storage: BaseKVStorage,
    tokenizer: Any = BaseTokenizer,
    partition_config: dict | None = None,
) -> list[
    tuple[list[tuple[str, dict]], list[tuple[Any, Any, dict] | tuple[Any, Any, Any]]]
]:
    """Partition a knowledge graph into batches and attach image data to image nodes.

    Args:
        kg_instance: Graph storage to partition.
        chunk_storage: KV storage queried for image data of image-typed nodes.
        tokenizer: Tokenizer forwarded to ``pre_tokenize`` for the ``"ece"`` method.
            NOTE(review): the default is the ``BaseTokenizer`` class itself, not an
            instance; callers using ``"ece"`` presumably must pass a real tokenizer.
        partition_config: Mapping with a required ``"method"`` key (one of
            ``"bfs"``, ``"dfs"``, ``"ece"``, ``"leiden"``, ``"anchor_bfs"``) and an
            optional ``"method_params"`` mapping. ``method_params`` is forwarded as
            keyword arguments to ``partitioner.partition``. For ``"anchor_bfs"``,
            its ``"anchor_type"`` and ``"anchor_ids"`` entries are also passed to
            the partitioner's constructor.

    Returns:
        The batches produced by ``partitioner.community2batch``: a list of
        ``(nodes, edges)`` tuples. Each node is a ``(node_id, node_data)`` pair.
        Image nodes have their ``node_data`` updated in place with an
        ``"images"`` entry.

    Raises:
        ValueError: If ``partition_config`` is missing or names an unsupported method.
        KeyError: If ``partition_config`` has no ``"method"`` key.
    """
    if partition_config is None:
        raise ValueError(
            "partition_config is required and must contain a 'method' key."
        )
    method = partition_config["method"]
    # Treat a missing or null "method_params" as "no extra parameters" instead of
    # failing later on `**None` or `None.get(...)`.
    method_params = partition_config.get("method_params") or {}

    partitioner = await _build_partitioner(
        method, method_params, kg_instance, tokenizer
    )

    communities = await partitioner.partition(g=kg_instance, **method_params)
    logger.info("Partitioned the graph into %d communities.", len(communities))
    batches = await partitioner.community2batch(communities, g=kg_instance)

    await _attach_image_data(batches, chunk_storage)
    return batches


async def _build_partitioner(
    method: str,
    method_params: dict,
    kg_instance: BaseGraphStorage,
    tokenizer: Any,
):
    """Instantiate the partitioner for *method*, running any setup it requires."""
    if method == "bfs":
        logger.info("Partitioning knowledge graph using BFS method.")
        return BFSPartitioner()
    if method == "dfs":
        logger.info("Partitioning knowledge graph using DFS method.")
        return DFSPartitioner()
    if method == "ece":
        logger.info("Partitioning knowledge graph using ECE method.")
        # TODO: before ECE partitioning, we need to:
        # 1. 'quiz and judge' to get the comprehension loss if unit_sampling is not random
        # 2. pre-tokenize nodes and edges to get the token length
        all_edges = await kg_instance.get_all_edges()
        all_nodes = await kg_instance.get_all_nodes()
        await pre_tokenize(kg_instance, tokenizer, all_edges, all_nodes)
        return ECEPartitioner()
    if method == "leiden":
        logger.info("Partitioning knowledge graph using Leiden method.")
        return LeidenPartitioner()
    if method == "anchor_bfs":
        logger.info("Partitioning knowledge graph using Anchor BFS method.")
        anchor_ids = method_params.get("anchor_ids")
        return AnchorBFSPartitioner(
            anchor_type=method_params.get("anchor_type"),
            # An empty or missing list means "no explicit anchors" (None).
            anchor_ids=set(anchor_ids) if anchor_ids else None,
        )
    raise ValueError(f"Unsupported partition method: {method}")


async def _attach_image_data(batches, chunk_storage: BaseKVStorage) -> None:
    """Store chunk data under ``"images"`` on every image-typed node in *batches*."""
    for batch_nodes, _batch_edges in batches:
        for node_id, node_data in batch_nodes:
            entity_type = node_data.get("entity_type")
            if not entity_type or "image" not in entity_type.lower():
                continue
            # Lookup key is the node id with surrounding quotes removed, lowercased;
            # presumably image chunks are stored under that normalized id — verify.
            image_key = node_id.strip('"').lower()
            image_data = await chunk_storage.get_by_id(image_key)
            if image_data:
                node_data["images"] = image_data