import inspect
import logging
from collections import defaultdict, deque
from functools import wraps
from typing import Any, Callable, Dict, List, Set

import ray
import ray.data

from graphgen.bases import Config, Node
from graphgen.utils import logger


class Engine:
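    """Compile a config-defined DAG of operators into lazy Ray Data pipelines."""
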
    def __init__(
        self, config: Dict[str, Any], functions: Dict[str, Callable], **ray_init_kwargs
    ):
        self.config = Config(**config)
        self.global_params = self.config.global_params
        self.functions = functions
        self.datasets: Dict[str, ray.data.Dataset] = {}

        if not ray.is_initialized():
            context = ray.init(
                ignore_reinit_error=True,
                logging_level=logging.ERROR,
                log_to_driver=True,
                **ray_init_kwargs,
            )
            logger.info("Ray Dashboard URL: %s", context.dashboard_url)

    @staticmethod
    def _topo_sort(nodes: List[Node]) -> List[Node]:
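        """Order nodes so dependencies come first (Kahn's algorithm);
        raise on unknown dependency ids or cycles."""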
        id_to_node: Dict[str, Node] = {n.id: n for n in nodes}

        indeg: Dict[str, int] = {nid: 0 for nid in id_to_node}
        adj: Dict[str, List[str]] = defaultdict(list)

        for n in nodes:
            nid = n.id
            for d in set(n.dependencies):
                if d not in id_to_node:
                    raise ValueError(
                        f"The dependency node id {d} of node {nid} is not defined in the configuration."
                    )
                indeg[nid] += 1
                adj[d].append(nid)

        zero_deg: deque = deque(
            [id_to_node[nid] for nid, deg in indeg.items() if deg == 0]
        )
        sorted_nodes: List[Node] = []

        while zero_deg:
            cur = zero_deg.popleft()
            sorted_nodes.append(cur)
            cur_id = cur.id
            for nb_id in adj.get(cur_id, []):
                indeg[nb_id] -= 1
                if indeg[nb_id] == 0:
                    zero_deg.append(id_to_node[nb_id])

        if len(sorted_nodes) != len(nodes):
            remaining = [nid for nid, deg in indeg.items() if deg > 0]
            raise ValueError(
                f"The configuration contains cycles, unable to execute. Remaining nodes with indegree > 0: {remaining}"
            )

        return sorted_nodes

    def _get_input_dataset(
        self, node: Node, initial_ds: ray.data.Dataset
    ) -> ray.data.Dataset:
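        """Resolve a node's input: the initial dataset for root nodes, a single
        parent's output, or the union of all parents' outputs."""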
        deps = node.dependencies

        if not deps:
            return initial_ds

        if len(deps) == 1:
            return self.datasets[deps[0]]

        main_ds = self.datasets[deps[0]]
        other_dss = [self.datasets[d] for d in deps[1:]]
        return main_ds.union(*other_dss)

    def _execute_node(self, node: Node, initial_ds: ray.data.Dataset):
        def _filter_kwargs(
            func_or_class: Callable,
            global_params: Dict[str, Any],
            func_params: Dict[str, Any],
        ) -> Dict[str, Any]:
            """
            1. global_params: only when specified in function signature, will be passed
            2. func_params: pass specified params first, then **kwargs if exists
            """
            try:
                sig = inspect.signature(func_or_class)
            except ValueError:
                return {}

            params = sig.parameters
            final_kwargs = {}

            has_var_keywords = any(
                p.kind == inspect.Parameter.VAR_KEYWORD for p in params.values()
            )
            valid_keys = set(params.keys())
            for k, v in global_params.items():
                if k in valid_keys:
                    final_kwargs[k] = v

            for k, v in func_params.items():
                if k in valid_keys or has_var_keywords:
                    final_kwargs[k] = v
            return final_kwargs
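
        # Illustrative example (hypothetical handler, not from any real config):
        # given `def op(batch, threshold=0.5, **kwargs)`, global_params
        # {"threshold": 0.9, "seed": 7} contributes only {"threshold": 0.9},
        # while func_params {"seed": 7} still gets through because **kwargs
        # absorbs names missing from the signature.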

        if node.op_name not in self.functions:
            raise ValueError(f"Operator {node.op_name} not found for node {node.id}")

        op_handler = self.functions[node.op_name]
        node_params = _filter_kwargs(op_handler, self.global_params, node.params or {})

        if node.type == "source":
            self.datasets[node.id] = op_handler(**node_params)
            return

        input_ds = self._get_input_dataset(node, initial_ds)

        if inspect.isclass(op_handler):
            execution_params = node.execution_params or {}
            replicas = execution_params.get("replicas", 1)
            batch_size = (
                int(execution_params.get("batch_size"))
                if "batch_size" in execution_params
                else "default"
            )
            compute_resources = execution_params.get("compute_resources", {})
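            # execution_params shape, matching the keys read above, e.g.:
            #   {"replicas": 4, "batch_size": 64, "compute_resources": {"num_gpus": 1}}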

            if node.type == "aggregate":
                self.datasets[node.id] = input_ds.repartition(1).map_batches(
                    op_handler,
                    compute=ray.data.ActorPoolStrategy(min_size=1, max_size=1),
                    batch_size=None,  # aggregate processes the whole dataset at once
                    num_gpus=compute_resources.get("num_gpus", 0),
                    fn_constructor_kwargs=node_params,
                    batch_format="pandas",
                )
            else:
                # remaining types (map, filter, flatmap, map_batch): the actor pool processes the data batch by batch
                self.datasets[node.id] = input_ds.map_batches(
                    op_handler,
                    compute=ray.data.ActorPoolStrategy(min_size=1, max_size=replicas),
                    batch_size=batch_size,
                    num_gpus=compute_resources.get("num_gpus", 0),
                    fn_constructor_kwargs=node_params,
                    batch_format="pandas",
                )

        else:

            @wraps(op_handler)
            def func_wrapper(row_or_batch: Dict[str, Any]) -> Any:
                return op_handler(row_or_batch, **node_params)

            if node.type == "map":
                self.datasets[node.id] = input_ds.map(func_wrapper)
            elif node.type == "filter":
                self.datasets[node.id] = input_ds.filter(func_wrapper)
            elif node.type == "flatmap":
                self.datasets[node.id] = input_ds.flat_map(func_wrapper)
            elif node.type == "aggregate":
                self.datasets[node.id] = input_ds.repartition(1).map_batches(
                    func_wrapper,
                    batch_size=None,  # aggregate processes the whole dataset at once
                    batch_format="default",
                )
            elif node.type == "map_batch":
                self.datasets[node.id] = input_ds.map_batches(func_wrapper)
            else:
                raise ValueError(
                    f"Unsupported node type {node.type} for node {node.id}"
                )

    @staticmethod
    def _find_leaf_nodes(nodes: List[Node]) -> Set[str]:
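        """Return ids of sink nodes, i.e. nodes no other node depends on."""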
        all_ids = {n.id for n in nodes}
        deps_set = set()
        for n in nodes:
            deps_set.update(n.dependencies)
        return all_ids - deps_set

    def execute(self, initial_ds: ray.data.Dataset) -> Dict[str, ray.data.Dataset]:
        sorted_nodes = self._topo_sort(self.config.nodes)

        for node in sorted_nodes:
            self._execute_node(node, initial_ds)

        leaf_nodes = self._find_leaf_nodes(sorted_nodes)

        # Datasets are lazy; callers materialize the returned leaf datasets
        # themselves (e.g. with take_all() or a write_* call).
        return {node_id: self.datasets[node_id] for node_id in leaf_nodes}
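

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only; not part of the engine). The
# `double` operator and the config shape below are assumptions inferred from
# the Node fields this module reads (id, op_name, type, dependencies, params,
# execution_params); the real graphgen schema may differ.
#
# def double(row, factor=2):
#     row["value"] *= factor
#     return row
#
# config = {
#     "global_params": {},
#     "nodes": [
#         {"id": "doubled", "op_name": "double", "type": "map",
#          "dependencies": [], "params": {"factor": 2}},
#     ],
# }
# engine = Engine(config, functions={"double": double})
# outputs = engine.execute(ray.data.from_items([{"value": 1}, {"value": 2}]))
# print(outputs["doubled"].take_all())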