from typing import Any

from graphgen.bases import BaseGraphStorage, BaseKVStorage, BaseTokenizer
from graphgen.models import (
    AnchorBFSPartitioner,
    BFSPartitioner,
    DFSPartitioner,
    ECEPartitioner,
    LeidenPartitioner,
)
from graphgen.utils import logger

from .pre_tokenize import pre_tokenize


async def partition_kg(
    kg_instance: BaseGraphStorage,
    chunk_storage: BaseKVStorage,
    tokenizer: BaseTokenizer | None = None,
    partition_config: dict | None = None,
) -> list[
    tuple[list[tuple[str, dict]], list[tuple[Any, Any, dict] | tuple[Any, Any, Any]]]
]:
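    """
    Partition the knowledge graph into communities and group them into batches.

    :param kg_instance: graph storage holding the knowledge graph to partition.
    :param chunk_storage: KV storage with the source chunks referenced by node ids.
    :param tokenizer: tokenizer used to pre-compute token lengths (needed for "ece").
    :param partition_config: dict with a "method" key ("bfs", "dfs", "ece", "leiden",
        or "anchor_bfs") and a "method_params" dict forwarded to the partitioner.
    :return: a list of (nodes, edges) batches, where each node is a (node_id, node_data)
        tuple and each edge is a (source, target, edge_data) tuple.
    """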
    if partition_config is None:
        raise ValueError("partition_config must provide 'method' and 'method_params'.")
    method = partition_config["method"]
    method_params = partition_config.get("method_params", {})
    if method == "bfs":
        logger.info("Partitioning knowledge graph using BFS method.")
        partitioner = BFSPartitioner()
    elif method == "dfs":
        logger.info("Partitioning knowledge graph using DFS method.")
        partitioner = DFSPartitioner()
    elif method == "ece":
        logger.info("Partitioning knowledge graph using ECE method.")
        # TODO: before ECE partitioning, we need to:
        # 1. 'quiz and judge' to get the comprehension loss if unit_sampling is not random
        # 2. pre-tokenize nodes and edges to get the token length
        edges = kg_instance.get_all_edges()
        nodes = kg_instance.get_all_nodes()
        await pre_tokenize(kg_instance, tokenizer, edges, nodes)
        partitioner = ECEPartitioner()
    elif method == "leiden":
        logger.info("Partitioning knowledge graph using Leiden method.")
        partitioner = LeidenPartitioner()
    elif method == "anchor_bfs":
        logger.info("Partitioning knowledge graph using Anchor BFS method.")
        partitioner = AnchorBFSPartitioner(
            anchor_type=method_params.get("anchor_type"),
            anchor_ids=set(method_params.get("anchor_ids", []))
            if method_params.get("anchor_ids")
            else None,
        )
    else:
        raise ValueError(f"Unsupported partition method: {method}")

    communities = await partitioner.partition(g=kg_instance, **method_params)
    logger.info("Partitioned the graph into %d communities.", len(communities))
    batches = await partitioner.community2batch(communities, g=kg_instance)

    batches = await attach_additional_data_to_node(batches, chunk_storage)
    return batches
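
# A minimal usage sketch. The names `kg`, `chunks`, and `tok` are hypothetical
# placeholders for concrete BaseGraphStorage / BaseKVStorage / BaseTokenizer
# instances, and "bfs" is assumed to need no extra method_params:
#
#     config = {"method": "bfs", "method_params": {}}
#     batches = await partition_kg(kg, chunks, tokenizer=tok, partition_config=config)
#     for nodes, edges in batches:
#         ...  # consume each community's (node_id, node_data) / (src, tgt, edge_data) tuples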


async def attach_additional_data_to_node(
    batches: list[
        tuple[
            list[tuple[str, dict]], list[tuple[Any, Any, dict] | tuple[Any, Any, Any]]
        ]
    ],
    chunk_storage: BaseKVStorage,
) -> list[
    tuple[list[tuple[str, dict]], list[tuple[Any, Any, dict] | tuple[Any, Any, Any]]]
]:
    """
    Attach additional data from chunk_storage to nodes in the batches.
    :param batches:
    :param chunk_storage:
    :return:
    """
    for batch in batches:
        for node_id, node_data in batch[0]:
            await _attach_by_type(node_id, node_data, chunk_storage)
    return batches


async def _attach_by_type(
    node_id: str,
    node_data: dict,
    chunk_storage: BaseKVStorage,
) -> None:
    """
    Attach additional data to the node based on its entity type.
    """
    entity_type = (node_data.get("entity_type") or "").lower()
    if not entity_type:
        return

    source_ids = [
        sid.strip()
        for sid in node_data.get("source_id", "").split("<SEP>")
        if sid.strip()
    ]
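    # e.g. a source_id of "chunk-1<SEP>image-chunk-2" (ids illustrative)
    # yields ["chunk-1", "image-chunk-2"].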

    # Handle images
    if "image" in entity_type:
        image_chunks = [
            data
            for sid in source_ids
            if "image" in sid.lower() and (data := chunk_storage.get_by_id(sid))
        ]
        if image_chunks:
            # The generator expects a dictionary with an 'img_path' key, not a list of captions.
            # We'll use the first image chunk found for this node.
            node_data["images"] = image_chunks[0]
            logger.debug("Attached image data to node %s", node_id)