Spaces:
Running
Running
| """ | |
| orchestration engine for GraphGen | |
| """ | |
import threading
import traceback
from collections import deque
from typing import Any, Callable, List
class Context(dict):
    """Dictionary with lock-guarded ``set`` / ``get`` helpers.

    Only the two helpers below take the lock; ordinary dict access such as
    ``ctx[key]`` goes straight to the underlying dict.
    """

    # Class attribute: a single lock object shared by every Context instance.
    _lock = threading.Lock()

    def set(self, k, v):
        """
        Store ``v`` under key ``k`` while holding the lock.
        :param k: key to write
        :param v: value to store
        """
        guard = self._lock
        with guard:
            self[k] = v

    def get(self, k, default=None):
        """
        Look up ``k`` while holding the lock.
        :param k: key to read
        :param default: value returned when ``k`` is absent
        :return: the stored value, or ``default``
        """
        with self._lock:
            value = super().get(k, default)
        return value
class OpNode:
    """One named operation together with the names of the operations it depends on."""

    def __init__(
        self, name: str, deps: List[str], func: Callable[["OpNode", Context], Any]
    ):
        """
        :param name: identifier of this operation
        :param deps: names of operations this one depends on
        :param func: callable taking ``(node, ctx)``, per its annotation
        """
        self.name = name
        self.deps = deps
        self.func = func
class Engine:
    """Run a set of inter-dependent operations on worker threads.

    Every operation gets its own thread; at most ``max_workers`` operation
    callables execute at the same time.
    """

    def __init__(self, max_workers: int = 4):
        """
        :param max_workers: upper bound on operations executing concurrently
        """
        self.max_workers = max_workers

    def run(self, ops: List["OpNode"], ctx: "Context"):
        """
        Execute every operation once, honouring its declared dependencies.

        An operation whose dependency failed (or was itself skipped) is not
        executed and is reported as skipped.

        :param ops: operations to run; names must be unique and every
            dependency must name another operation in ``ops``
        :param ctx: shared context handed to every operation callable
        :raises ValueError: on duplicate names, unknown dependencies or cycles
        :raises RuntimeError: if any operation failed or was skipped
        """
        self._validate(ops)
        name2op = {operation.name: operation for operation in ops}

        topo = self._topo_order(name2op)
        if len(topo) != len(ops):
            raise ValueError(
                "Cyclic dependencies detected among operations. "
                "Please check your configuration."
            )

        # semaphore for max_workers
        sem = threading.Semaphore(self.max_workers)
        done = {n: threading.Event() for n in name2op}
        exc = {}

        def _exec(n: str):
            op = name2op[n]
            try:
                # Wait for prerequisites BEFORE taking a worker slot. If the
                # slot were held while waiting, blocked threads could occupy
                # every slot and starve the very operation they wait for.
                for d in op.deps:
                    done[d].wait()
                if any(d in exc for d in op.deps):
                    exc[n] = Exception("Skipped due to failed dependencies")
                    return
                with sem:
                    try:
                        op.func(op, ctx)
                    except Exception:
                        # Keep the formatted traceback for the final report.
                        exc[n] = traceback.format_exc()
            finally:
                # Always release dependents, whatever happened above;
                # otherwise they would block on this event forever.
                done[n].set()

        ts = [threading.Thread(target=_exec, args=(n,), daemon=True) for n in topo]
        for t in ts:
            t.start()
        for t in ts:
            t.join()

        if exc:
            raise RuntimeError(
                "Some operations failed:\n"
                + "\n".join(f"---- {op} ----\n{tb}" for op, tb in exc.items())
            )

    @staticmethod
    def _topo_order(name2op) -> List[str]:
        """
        Order operation names so that dependencies come first (Kahn's algorithm).

        :param name2op: mapping of operation name to operation
        :return: names in dependency order; names that take part in a cycle
            are left out, so a result shorter than ``name2op`` signals a cycle
        """
        # Remaining unmet dependencies per operation.
        pending = {n: set(op.deps) for n, op in name2op.items()}
        # Reverse edges: which operations are waiting on a given one.
        children = {n: [] for n in name2op}
        for n, deps in pending.items():
            for d in deps:
                children[d].append(n)

        queue = deque(n for n, deps in pending.items() if not deps)
        topo = []
        while queue:
            cur = queue.popleft()
            topo.append(cur)
            for child in children[cur]:
                pending[child].discard(cur)
                if not pending[child]:
                    queue.append(child)
        return topo

    @staticmethod
    def _validate(ops: List["OpNode"]):
        """
        Check that operation names are unique and all dependencies exist.

        :param ops: operations to check
        :raises ValueError: on a duplicate name or an unknown dependency
        """
        name_set = set()
        for op in ops:
            if op.name in name_set:
                raise ValueError(f"Duplicate operation name: {op.name}")
            name_set.add(op.name)
        for op in ops:
            for dep in op.deps:
                if dep not in name_set:
                    raise ValueError(
                        f"Operation {op.name} has unknown dependency: {dep}"
                    )
def collect_ops(config: dict, graph_gen) -> List[OpNode]:
    """
    Build one operation node per entry of ``config["pipeline"]``.

    Each entry supplies ``"name"``, an ``"op_key"`` naming an attribute of
    ``graph_gen`` to call, and optionally ``"deps"`` and ``"params"``.

    :param config: configuration mapping containing a ``"pipeline"`` sequence
    :param graph_gen: object whose attributes named by ``"op_key"`` are called
    :return: the nodes, in pipeline order
    """
    nodes: List[OpNode] = []
    for spec in config["pipeline"]:
        op_name = spec["name"]
        bound = getattr(graph_gen, spec.get("op_key"))
        requires = spec.get("deps", [])

        # Default arguments pin the per-iteration values; a plain closure
        # would see whatever the loop variables hold when it finally runs.
        if "params" not in spec:

            def func(self, ctx, _method=bound):
                return _method()

        else:

            def func(self, ctx, _method=bound, _params=spec["params"]):
                return _method(_params)

        nodes.append(OpNode(name=op_name, deps=requires, func=func))
    return nodes