File size: 5,837 Bytes
799ac7c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2a0edfe
 
799ac7c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9e67c3b
 
799ac7c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
import asyncio
import random
from typing import Any, Dict, List, Optional, Set, Tuple

from tqdm.asyncio import tqdm as tqdm_async

from graphgen.bases import BaseGraphStorage
from graphgen.bases.datatypes import Community
from graphgen.models.partitioner.bfs_partitioner import BFSPartitioner

# Tag stored as the first element of a unit tuple to mark the unit as a graph node.
NODE_UNIT: str = "n"
# Tag stored as the first element of a unit tuple to mark the unit as a graph edge.
EDGE_UNIT: str = "e"


class ECEPartitioner(BFSPartitioner):
    """
    ECE partitioner that partitions the graph into communities based on Expected Calibration Error (ECE).
    We calculate ECE for units in KG (represented as 'comprehension loss')
    and group units with similar ECE values into the same community.
    1. Select a sampling strategy.
    2. Choose a seed unit based on the sampling strategy.
    3. Expand the community using BFS.
    4. When expanding, prefer to add units according to the sampling strategy.
    5. Stop when the max unit size is reached or the max input length is reached.
    (A unit is a node or an edge.)
    """

    @staticmethod
    def _sort_units(units: list, edge_sampling: str) -> list:
        """
        Order units according to the sampling strategy.

        The input list is never modified; a new list is always returned.

        :param units: units to order; each unit is a tuple whose last element
            is its data dict. The loss-based strategies read the "loss" key of
            that dict.
        :param edge_sampling: sampling strategy (random, min_loss, max_loss)
        :return: a new list with the units in sampling order
        :raises ValueError: if the strategy name is not recognised
        """
        if edge_sampling == "random":
            # Shuffle a copy so that "random" leaves the caller's list alone,
            # just like the sorted() based strategies do.
            shuffled = list(units)
            random.shuffle(shuffled)
            return shuffled
        if edge_sampling == "min_loss":
            return sorted(units, key=lambda x: x[-1]["loss"])
        if edge_sampling == "max_loss":
            return sorted(units, key=lambda x: x[-1]["loss"], reverse=True)
        raise ValueError(f"Invalid edge sampling: {edge_sampling}")

    async def partition(
        self,
        g: BaseGraphStorage,
        max_units_per_community: int = 10,
        min_units_per_community: int = 1,
        max_tokens_per_community: int = 10240,
        unit_sampling: str = "random",
        **kwargs: Any,
    ) -> List[Community]:
        """
        Partition the graph into communities of nodes and edges ("units").

        Units are visited in the order given by ``unit_sampling``. Every unit
        that is not yet part of a community seeds a new one, which is then
        grown breadth-first: a node expands to its incident edges and an edge
        expands to its endpoint nodes, neighbours being tried in
        ``unit_sampling`` order.

        :param g: graph storage providing the nodes and edges
        :param max_units_per_community: growth stops once a community holds
            this many units (nodes + edges)
        :param min_units_per_community: communities with fewer units are
            discarded; their units stay marked as used and are not retried
        :param max_tokens_per_community: growth stops once the summed "length"
            field of the community's units reaches this value (units without
            a "length" field count as 0). The limit is checked before each
            addition, so the last unit added may push the sum past it.
        :param unit_sampling: sampling strategy (random, min_loss, max_loss)
        :param kwargs: accepted but not used by this method
        :return: communities in creation order; each id equals its index in
            the returned list
        :raises ValueError: if ``unit_sampling`` is not a known strategy
        """
        nodes: List[Tuple[str, dict]] = g.get_all_nodes()
        edges: List[Tuple[str, str, dict]] = g.get_all_edges()

        adj, _ = self._build_adjacency_list(nodes, edges)
        node_dict = dict(nodes)

        # Edges are keyed by the frozenset of their endpoints so that (u, v)
        # and (v, u) refer to the same unit. The original endpoint tuple is
        # kept alongside: a frozenset cannot be unpacked back into two
        # endpoints for a self-loop (it holds a single element) and does not
        # remember the stored orientation.
        edge_dict: Dict[frozenset[str], dict] = {}
        edge_endpoints: Dict[frozenset[str], Tuple[str, str]] = {}
        for u, v, d in edges:
            key = frozenset((u, v))
            edge_dict[key] = d
            edge_endpoints[key] = (u, v)

        all_units: List[Tuple[str, Any, dict]] = [
            (NODE_UNIT, nid, d) for nid, d in nodes
        ] + [(EDGE_UNIT, frozenset((u, v)), d) for u, v, d in edges]

        # Units already assigned to a community (or consumed by a discarded one).
        used_n: Set[str] = set()
        used_e: Set[frozenset[str]] = set()
        communities: List = []

        all_units = self._sort_units(all_units, unit_sampling)

        async def _grow_community(
            seed_unit: Tuple[str, Any, dict]
        ) -> Optional[Community]:
            """Grow one community from ``seed_unit``; None if it ends up too small."""
            community_nodes: Dict[str, dict] = {}
            community_edges: Dict[frozenset[str], dict] = {}
            queue: asyncio.Queue = asyncio.Queue()
            token_sum = 0

            def _is_full() -> bool:
                # Either limit being reached ends the growth of this community.
                return (
                    len(community_nodes) + len(community_edges)
                    >= max_units_per_community
                    or token_sum >= max_tokens_per_community
                )

            def _add_unit(unit: Tuple[str, Any, dict]) -> bool:
                """Claim ``unit`` for this community; False if already taken."""
                nonlocal token_sum
                unit_type, unit_id, data = unit
                if unit_type == NODE_UNIT:
                    if unit_id in used_n or unit_id in community_nodes:
                        return False
                    community_nodes[unit_id] = data
                    used_n.add(unit_id)
                else:
                    if unit_id in used_e or unit_id in community_edges:
                        return False
                    community_edges[unit_id] = data
                    used_e.add(unit_id)
                token_sum += data.get("length", 0)
                return True

            # The seed is always taken, even if it alone exceeds a limit.
            _add_unit(seed_unit)
            await queue.put(seed_unit)

            # BFS
            while not queue.empty():
                if _is_full():
                    break

                cur_type, cur_id, _ = await queue.get()

                neighbors: List[Tuple[str, Any, dict]] = []
                if cur_type == NODE_UNIT:
                    # A node expands to its incident edges that are still free.
                    for nb_id in adj.get(cur_id, []):
                        e_key = frozenset((cur_id, nb_id))
                        if e_key not in used_e and e_key not in community_edges:
                            neighbors.append((EDGE_UNIT, e_key, edge_dict[e_key]))
                else:
                    # An edge expands to its endpoint nodes that are still free.
                    for n_id in cur_id:
                        if n_id not in used_n and n_id not in community_nodes:
                            neighbors.append((NODE_UNIT, n_id, node_dict[n_id]))

                neighbors = self._sort_units(neighbors, unit_sampling)
                for nb in neighbors:
                    if _is_full():
                        break
                    if _add_unit(nb):
                        await queue.put(nb)

            if len(community_nodes) + len(community_edges) < min_units_per_community:
                return None

            return Community(
                id=len(communities),
                nodes=list(community_nodes.keys()),
                # Emit the stored (u, v) endpoints rather than unpacking the
                # frozenset key, which fails for self-loops.
                edges=[edge_endpoints[e_key] for e_key in community_edges],
            )

        async for unit in tqdm_async(all_units, desc="ECE partition"):
            utype, uid, _ = unit
            if (utype == NODE_UNIT and uid in used_n) or (
                utype == EDGE_UNIT and uid in used_e
            ):
                continue
            comm = await _grow_community(unit)
            if comm is not None:
                communities.append(comm)

        return communities