Spaces:
Running
Running
File size: 3,629 Bytes
f1eedd1 dee1edd f1eedd1 dee1edd f1eedd1 dee1edd f1eedd1 dee1edd f1eedd1 dee1edd f1eedd1 dee1edd f1eedd1 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 |
"""
orchestration engine for GraphGen
"""
import threading
import traceback
from typing import Any, Callable, List
class Context(dict):
    """A ``dict`` whose ``set``/``get`` helpers run while holding a lock.

    The lock is a class attribute, so every ``Context`` instance shares the
    same lock. Plain item access (``ctx[k]`` / ``ctx[k] = v``) is not guarded.
    """

    _lock = threading.Lock()  # class-level: shared by all Context instances

    def set(self, k, v):
        """Store *v* under key *k* while holding the shared lock."""
        guard = self._lock
        with guard:
            self[k] = v

    def get(self, k, default=None):
        """Return the value stored for *k*, or *default*, under the lock."""
        guard = self._lock
        with guard:
            found = super().get(k, default)
        return found
class OpNode:
    """A named operation in the pipeline graph.

    Attributes are stored exactly as passed in:
    name -- operation name; must be unique within one ``Engine.run`` call
    deps -- names of operations that must finish before this one runs
    func -- callable invoked as ``func(node, ctx)`` with this node and the
            shared context
    """

    def __init__(
        self, name: str, deps: List[str], func: Callable[["OpNode", Context], Any]
    ):
        self.name = name
        self.deps = deps
        self.func = func
class Engine:
    """Run a graph of ``OpNode`` operations concurrently.

    Each operation runs in its own thread once all of its dependencies have
    finished; at most ``max_workers`` operation functions execute at once.
    """

    def __init__(self, max_workers: int = 4):
        # Semaphore(0) would block every operation forever and a negative
        # value makes threading.Semaphore raise later; reject both up front.
        if max_workers < 1:
            raise ValueError(f"max_workers must be >= 1, got {max_workers}")
        self.max_workers = max_workers

    def run(self, ops: List["OpNode"], ctx: "Context"):
        """Execute *ops* in dependency order.

        :param ops: operations to run; names must be unique and every
            dependency must name another operation in *ops*.
        :param ctx: shared context passed to every operation function.
        :raises ValueError: on duplicate names, unknown dependencies or
            cyclic dependencies (checked before anything runs).
        :raises RuntimeError: if any operation failed, or was skipped because
            a dependency failed; the message lists each one in topological
            order with its traceback.
        """
        self._validate(ops)
        name2op = {op.name: op for op in ops}
        topo = self._topo_sort(name2op)
        if len(topo) != len(ops):
            raise ValueError(
                "Cyclic dependencies detected among operations. "
                "Please check your configuration."
            )

        sem = threading.Semaphore(self.max_workers)
        done = {n: threading.Event() for n in name2op}
        errors = {}

        def _exec(name: str):
            op = name2op[name]
            try:
                # Wait for dependencies *before* taking a worker slot: a thread
                # holding the semaphore while it waits could starve the very
                # dependency it is waiting on, deadlocking the run.
                for dep in op.deps:
                    done[dep].wait()
                if any(dep in errors for dep in op.deps):
                    errors[name] = "Skipped due to failed dependencies"
                    return
                with sem:
                    op.func(op, ctx)
            except BaseException:
                # Worker-thread boundary: nothing else observes this exception
                # (a SystemExit in a thread is silently dropped), so record it
                # for the final report instead of losing it.
                errors[name] = traceback.format_exc()
            finally:
                # Always signal completion so dependants never wait forever.
                done[name].set()

        threads = [
            threading.Thread(target=_exec, args=(n,), daemon=True) for n in topo
        ]
        for t in threads:
            t.start()
        for t in threads:
            t.join()

        if errors:
            # Report in topological order so the message is deterministic.
            raise RuntimeError(
                "Some operations failed:\n"
                + "\n".join(
                    f"---- {n} ----\n{errors[n]}" for n in topo if n in errors
                )
            )

    @staticmethod
    def _topo_sort(name2op: dict) -> List[str]:
        """Return operation names in Kahn (FIFO) topological order.

        Names caught in a cycle are omitted, so a result shorter than
        *name2op* signals a cycle. Every dependency must be a key of
        *name2op* (``_validate`` ensures this).
        """
        pending = {n: len(set(op.deps)) for n, op in name2op.items()}
        children = {n: [] for n in name2op}
        for n, op in name2op.items():
            for dep in set(op.deps):
                children[dep].append(n)
        # ``order`` doubles as the FIFO queue: entries at index >= i are
        # ready but not yet expanded.
        order = [n for n, count in pending.items() if count == 0]
        i = 0
        while i < len(order):
            for child in children[order[i]]:
                pending[child] -= 1
                if pending[child] == 0:
                    order.append(child)
            i += 1
        return order

    @staticmethod
    def _validate(ops: List["OpNode"]):
        """Raise ValueError on duplicate names or unknown dependencies."""
        name_set = set()
        for op in ops:
            if op.name in name_set:
                raise ValueError(f"Duplicate operation name: {op.name}")
            name_set.add(op.name)
        for op in ops:
            for dep in op.deps:
                if dep not in name_set:
                    raise ValueError(
                        f"Operation {op.name} has unknown dependency: {dep}"
                    )
def collect_ops(config: dict, graph_gen) -> List[OpNode]:
    """
    Build operation nodes from a parsed YAML pipeline config.

    Each entry of ``config["pipeline"]`` is a mapping with keys:
      * ``name``   -- operation name (required)
      * ``op_key`` -- name of the ``graph_gen`` method to call (required)
      * ``deps``   -- optional list of operation names this one depends on;
                      a missing or null value means no dependencies
      * ``params`` -- optional; when the key is present its value is passed
                      as the single positional argument to the method,
                      otherwise the method is called with no arguments

    :param config: parsed pipeline configuration
    :param graph_gen: object whose methods implement the operations
    :return: one ``OpNode`` per pipeline stage, in config order
    :raises ValueError: if a stage has no ``op_key``
    :raises AttributeError: if ``graph_gen`` has no method named ``op_key``
    """

    # Factories give every stage its own closure, so each node binds its own
    # method/params (no late-binding surprises from defining funcs in a loop).
    def _call_with_params(method, params):
        def func(op, ctx):
            return method(params)

        return func

    def _call_without_params(method):
        def func(op, ctx):
            return method()

        return func

    ops: List[OpNode] = []
    for stage in config["pipeline"]:
        name = stage["name"]
        method_name = stage.get("op_key")
        if method_name is None:
            raise ValueError(f"Stage {name!r} is missing required 'op_key'")
        method = getattr(graph_gen, method_name)
        # A bare "deps:" in YAML loads as None; treat it as "no dependencies".
        deps = stage.get("deps") or []
        if "params" in stage:
            func = _call_with_params(method, stage["params"])
        else:
            func = _call_without_params(method)
        ops.append(OpNode(name=name, deps=deps, func=func))
    return ops
|