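"""Ray Data execution engine for graphgen pipeline configurations.

Builds the node DAG described by a ``Config``, topologically sorts it, and
materializes each node as a lazy ``ray.data.Dataset`` transformation.
"""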
import inspect
import logging
from collections import defaultdict, deque
from functools import wraps
from typing import Any, Callable, Dict, List, Set

import ray
import ray.data

from graphgen.bases import Config, Node
from graphgen.utils import logger


class Engine:
    def __init__(
        self, config: Dict[str, Any], functions: Dict[str, Callable], **ray_init_kwargs
    ):
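        """Parse the pipeline config and initialize Ray if needed.

        ``functions`` maps each ``op_name`` referenced in the config to the
        function or actor class that implements that operator.
        """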
        self.config = Config(**config)
        self.global_params = self.config.global_params
        self.functions = functions
        self.datasets: Dict[str, ray.data.Dataset] = {}

        if not ray.is_initialized():
            context = ray.init(
                ignore_reinit_error=True,
                logging_level=logging.ERROR,
                log_to_driver=True,
                **ray_init_kwargs,
            )
            logger.info("Ray Dashboard URL: %s", context.dashboard_url)

    @staticmethod
    def _topo_sort(nodes: List[Node]) -> List[Node]:
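        """Order nodes so every node appears after all of its dependencies.

        This is Kahn's algorithm; it raises ``ValueError`` on a dependency id
        that is not defined or when the graph contains a cycle.
        """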
        id_to_node: Dict[str, Node] = {}
        for n in nodes:
            id_to_node[n.id] = n

        # Build the in-degree table and adjacency list of the dependency graph.
        indeg: Dict[str, int] = {nid: 0 for nid in id_to_node}
        adj: Dict[str, List[str]] = defaultdict(list)
        for n in nodes:
            nid = n.id
            deps: List[str] = n.dependencies
            uniq_deps: Set[str] = set(deps)
            for d in uniq_deps:
                if d not in id_to_node:
                    raise ValueError(
                        f"The dependency node id {d} of node {nid} is not defined in the configuration."
                    )
                indeg[nid] += 1
                adj[d].append(nid)

        # Repeatedly pop nodes whose dependencies are all satisfied.
        zero_deg: deque = deque(
            [id_to_node[nid] for nid, deg in indeg.items() if deg == 0]
        )
        sorted_nodes: List[Node] = []
        while zero_deg:
            cur = zero_deg.popleft()
            sorted_nodes.append(cur)
            for nb_id in adj.get(cur.id, []):
                indeg[nb_id] -= 1
                if indeg[nb_id] == 0:
                    zero_deg.append(id_to_node[nb_id])

        # Any node left with a positive in-degree is part of a cycle.
        if len(sorted_nodes) != len(nodes):
            remaining = [nid for nid, deg in indeg.items() if deg > 0]
            raise ValueError(
                f"The configuration contains cycles, unable to execute. "
                f"Remaining nodes with indegree > 0: {remaining}"
            )
        return sorted_nodes

    def _get_input_dataset(
        self, node: Node, initial_ds: ray.data.Dataset
    ) -> ray.data.Dataset:
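        """Resolve a node's input: the initial dataset for nodes without
        dependencies, the single upstream dataset for one dependency, or the
        union of all upstream datasets otherwise."""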
        deps = node.dependencies
        if not deps:
            return initial_ds
        if len(deps) == 1:
            return self.datasets[deps[0]]
        main_ds = self.datasets[deps[0]]
        other_dss = [self.datasets[d] for d in deps[1:]]
        return main_ds.union(*other_dss)

    def _execute_node(self, node: Node, initial_ds: ray.data.Dataset):
        def _filter_kwargs(
            func_or_class: Callable,
            global_params: Dict[str, Any],
            func_params: Dict[str, Any],
        ) -> Dict[str, Any]:
"""
1. global_params: only when specified in function signature, will be passed
2. func_params: pass specified params first, then **kwargs if exists
"""
            try:
                sig = inspect.signature(func_or_class)
            except ValueError:
                # Some builtins and C extensions expose no signature.
                return {}

            params = sig.parameters
            final_kwargs = {}
            has_var_keywords = any(
                p.kind == inspect.Parameter.VAR_KEYWORD for p in params.values()
            )
            valid_keys = set(params.keys())

            for k, v in global_params.items():
                if k in valid_keys:
                    final_kwargs[k] = v
            for k, v in func_params.items():
                if k in valid_keys or has_var_keywords:
                    final_kwargs[k] = v
            return final_kwargs

        if node.op_name not in self.functions:
            raise ValueError(f"Operator {node.op_name} not found for node {node.id}")

        op_handler = self.functions[node.op_name]
        node_params = _filter_kwargs(op_handler, self.global_params, node.params or {})

        if node.type == "source":
            # Source nodes produce their own dataset and take no input.
            self.datasets[node.id] = op_handler(**node_params)
            return

        input_ds = self._get_input_dataset(node, initial_ds)

        if inspect.isclass(op_handler):
            # Class operators run as a pool of Ray actors via map_batches.
            execution_params = node.execution_params or {}
            replicas = execution_params.get("replicas", 1)
            batch_size = (
                int(execution_params.get("batch_size"))
                if "batch_size" in execution_params
                else "default"
            )
            compute_resources = execution_params.get("compute_resources", {})
            num_gpus = compute_resources.get("num_gpus", 0) if compute_resources else 0

            if node.type == "aggregate":
                self.datasets[node.id] = input_ds.repartition(1).map_batches(
                    op_handler,
                    compute=ray.data.ActorPoolStrategy(min_size=1, max_size=1),
                    batch_size=None,  # aggregate processes the whole dataset at once
                    num_gpus=num_gpus,
                    fn_constructor_kwargs=node_params,
                    batch_format="pandas",
                )
            else:
                # map, filter, flatmap, map_batch: actors process data in batches
                self.datasets[node.id] = input_ds.map_batches(
                    op_handler,
                    compute=ray.data.ActorPoolStrategy(min_size=1, max_size=replicas),
                    batch_size=batch_size,
                    num_gpus=num_gpus,
                    fn_constructor_kwargs=node_params,
                    batch_format="pandas",
                )
        else:
            # Plain functions are wrapped so node params are bound on each call.
            @wraps(op_handler)
            def func_wrapper(row_or_batch: Dict[str, Any]) -> Dict[str, Any]:
                return op_handler(row_or_batch, **node_params)

            if node.type == "map":
                self.datasets[node.id] = input_ds.map(func_wrapper)
            elif node.type == "filter":
                self.datasets[node.id] = input_ds.filter(func_wrapper)
            elif node.type == "flatmap":
                self.datasets[node.id] = input_ds.flat_map(func_wrapper)
            elif node.type == "aggregate":
                self.datasets[node.id] = input_ds.repartition(1).map_batches(
                    func_wrapper, batch_format="default"
                )
            elif node.type == "map_batch":
                self.datasets[node.id] = input_ds.map_batches(func_wrapper)
            else:
                raise ValueError(
                    f"Unsupported node type {node.type} for node {node.id}"
                )

    @staticmethod
    def _find_leaf_nodes(nodes: List[Node]) -> Set[str]:
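        """Return the ids of nodes that no other node depends on."""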
        all_ids = {n.id for n in nodes}
        deps_set = set()
        for n in nodes:
            deps_set.update(n.dependencies)
        return all_ids - deps_set

    def execute(self, initial_ds: ray.data.Dataset) -> Dict[str, ray.data.Dataset]:
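        """Build every node in topological order and return the leaf datasets.

        Ray Data transformations are lazy: execution is triggered when a
        returned dataset is consumed (e.g. ``take_all()`` or a write).
        """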
        sorted_nodes = self._topo_sort(self.config.nodes)
        for node in sorted_nodes:
            self._execute_node(node, initial_ds)

        leaf_nodes = self._find_leaf_nodes(sorted_nodes)

        # Helper for materializing a dataset on a worker; currently unused here.
        @ray.remote
        def _fetch_result(ds: ray.data.Dataset) -> List[Any]:
            return ds.take_all()

        return {node_id: self.datasets[node_id] for node_id in leaf_nodes}
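

# Hypothetical usage sketch (the operator names below are illustrative, not
# part of this module). It assumes the Config schema implied above: each node
# carries id/type/op_name plus optional dependencies/params, and "global_params"
# holds values shared across operators.
#
#     import ray
#
#     def add_offset(row, offset=1):
#         row["x"] = row["x"] + offset
#         return row
#
#     def keep_positive(row):
#         return row["x"] > 0
#
#     engine = Engine(
#         config={
#             "global_params": {},
#             "nodes": [
#                 {"id": "a", "type": "map", "op_name": "add_offset",
#                  "params": {"offset": 2}},
#                 {"id": "b", "type": "filter", "op_name": "keep_positive",
#                  "dependencies": ["a"]},
#             ],
#         },
#         functions={"add_offset": add_offset, "keep_positive": keep_positive},
#     )
#     leaves = engine.execute(ray.data.from_items([{"x": -5}, {"x": 1}]))
#     print(leaves["b"].take_all())  # only rows with positive "x" remain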