File size: 4,035 Bytes
0b9d8c7
 
31086ae
0b9d8c7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31086ae
0b9d8c7
 
 
 
31086ae
9e67c3b
 
0b9d8c7
 
 
31086ae
0b9d8c7
31086ae
0b9d8c7
 
 
 
 
 
 
 
 
 
31086ae
0b9d8c7
 
 
31086ae
0b9d8c7
31086ae
0b9d8c7
 
 
 
 
 
 
 
 
 
 
 
 
 
31086ae
0b9d8c7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
import random
from collections import deque
from typing import Any, Iterable, List, Literal, Set, Tuple

from graphgen.bases import BaseGraphStorage
from graphgen.bases.datatypes import Community

from .bfs_partitioner import BFSPartitioner

# Tags for items in the BFS queue of AnchorBFSPartitioner._grow_community:
# each queue entry is a (tag, payload) pair, and the tag says how to read the
# payload.
NODE_UNIT: str = "n"  # payload is a node id
EDGE_UNIT: str = "e"  # payload is an edge key (frozenset of endpoint ids)


class AnchorBFSPartitioner(BFSPartitioner):
    """
    Anchor BFS partitioner that partitions the graph into communities of bounded size.

    1. Select anchor nodes (nodes of a specified type, or an explicit id set) and
       visit them in random order.
    2. Expand a community from each not-yet-used anchor using BFS until the max
       unit count is reached. (A unit is a node or an edge.)
    3. Non-anchor units can only be "pulled" into a community and never become
       seeds themselves. An anchor already pulled into an earlier community is
       not used as a seed either.

    For example, for VQA tasks, we may want to use image nodes as anchors and
    expand to nearby text nodes and edges.
    """

    def __init__(
        self,
        *,
        anchor_type: Literal["image"] = "image",
        anchor_ids: Set[str] | None = None,
    ) -> None:
        """
        :param anchor_type: substring matched case-insensitively against each
            node's ``entity_type`` metadata to select anchors.
        :param anchor_ids: explicit anchor node ids; when not None,
            ``anchor_type`` is ignored.
        """
        super().__init__()
        self.anchor_type = anchor_type
        self.anchor_ids = anchor_ids

    def partition(
        self,
        g: BaseGraphStorage,
        max_units_per_community: int = 1,
        **kwargs: Any,
    ) -> Iterable[Community]:
        """
        Yield communities grown from anchor nodes.

        :param g: graph storage to partition
        :param max_units_per_community: maximum number of units (nodes + edges)
            in each community
        :param kwargs: accepted for interface compatibility; not used here
        :return: generator of ``Community`` objects whose ``id`` is the seed
            anchor. Yields nothing when no anchor exists in the graph.
        """
        nodes = g.get_all_nodes()  # List[tuple[id, meta]]
        edges = g.get_all_edges()  # List[tuple[u, v, meta]]

        adj, _ = self._build_adjacency_list(nodes, edges)

        # Keep only anchors that are actual nodes of this graph: explicitly
        # supplied anchor_ids may name ids the graph does not contain, which
        # would otherwise be seeded as communities of nonexistent nodes.
        node_ids = {node_id for node_id, _ in nodes}
        anchors: Set[str] = {
            a for a in self._pick_anchor_ids(nodes) if a in node_ids
        }
        if not anchors:
            return  # if no anchors, return nothing

        used_n: set[str] = set()
        used_e: set[frozenset[str]] = set()

        # Sort before shuffling: iteration order of a set of str depends on
        # hash randomization, so without sorting random.seed() would not make
        # the seed order reproducible across runs.
        seeds = sorted(anchors)
        random.shuffle(seeds)

        for seed_node in seeds:
            # The anchor may already have been pulled into an earlier community.
            if seed_node in used_n:
                continue
            comm_n, comm_e = self._grow_community(
                seed_node, adj, max_units_per_community, used_n, used_e
            )
            if comm_n or comm_e:
                yield Community(id=seed_node, nodes=comm_n, edges=comm_e)

    def _pick_anchor_ids(
        self,
        nodes: List[tuple[str, dict]],
    ) -> Set[str]:
        """
        Return the ids of anchor nodes.

        :param nodes: (node_id, metadata) pairs
        :return: a new set; a copy of ``self.anchor_ids`` when it is set,
            otherwise the ids whose ``entity_type`` contains ``anchor_type``
            (case-insensitive substring match).
        """
        if self.anchor_ids is not None:
            # Copy so callers never share or mutate the configured set.
            return set(self.anchor_ids)

        wanted = self.anchor_type.lower()
        return {
            node_id
            for node_id, meta in nodes
            if wanted in str(meta.get("entity_type", "")).lower()
        }

    @staticmethod
    def _grow_community(
        seed: str,
        adj: dict[str, List[str]],
        max_units: int,
        used_n: set[str],
        used_e: set[frozenset[str]],
    ) -> Tuple[List[str], List[Tuple[str, str]]]:
        """
        Grow a community from the seed node using BFS.

        Nodes and edges share one FIFO queue and each accepted item counts as
        one unit. Accepted items are added to ``used_n`` / ``used_e`` in place,
        so later communities cannot reuse them. Because the budget is checked
        per unit, the last accepted unit may be an edge whose far endpoint is
        not part of the community.

        :param seed: seed node id
        :param adj: adjacency list (node id -> neighbour ids)
        :param max_units: maximum number of units (nodes + edges) in the community
        :param used_n: set of used node ids (mutated)
        :param used_e: set of used edge keys (mutated)
        :return: (list of node ids, list of edge tuples). Endpoint order within
            an edge tuple comes from iterating a frozenset and is therefore
            arbitrary. NOTE(review): confirm consumers do not rely on the
            storage's edge orientation. A self-loop is returned as ``(n, n)``.
        """
        comm_n: List[str] = []
        comm_e: List[Tuple[str, str]] = []
        queue: deque[tuple[str, Any]] = deque([(NODE_UNIT, seed)])
        cnt = 0

        while queue and cnt < max_units:
            k, it = queue.popleft()

            if k == NODE_UNIT:
                # The same node can be queued from several edges; accept it once.
                if it in used_n:
                    continue
                used_n.add(it)
                comm_n.append(it)
                cnt += 1
                # .get: tolerate nodes without an adjacency entry.
                for nei in adj.get(it, ()):
                    e_key = frozenset((it, nei))
                    if e_key not in used_e:
                        queue.append((EDGE_UNIT, e_key))
            else:  # EDGE_UNIT
                if it in used_e:
                    continue
                used_e.add(it)
                # A self-loop's key is a one-element frozenset, so a plain
                # two-way unpack would raise ValueError.
                ends = tuple(it)
                u, v = ends if len(ends) == 2 else (ends[0], ends[0])
                comm_e.append((u, v))
                cnt += 1
                for n in ends:
                    if n not in used_n:
                        queue.append((NODE_UNIT, n))

        return comm_n, comm_e