File size: 3,472 Bytes
f1eedd1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
"""
Orchestration engine for GraphGen: declare operations with ``@op`` and run them as a dependency graph with ``Engine``.
"""

import threading
import traceback
from collections import deque
from functools import wraps
from typing import Any, Callable, List


class Context(dict):
    """A ``dict`` with lock-guarded :meth:`set` / :meth:`get` helpers.

    NOTE: ``_lock`` is a class attribute, so every ``Context`` instance shares
    one single lock. Only ``set`` and ``get`` take it; plain ``ctx[k]`` access,
    ``**ctx`` unpacking and iteration are not synchronised.
    """

    _lock = threading.Lock()  # class-level: shared by all instances

    def set(self, k, v):
        """Store ``v`` under key ``k`` while holding the lock."""
        lock = self._lock
        with lock:
            self[k] = v

    def get(self, k, default=None):
        """Return the value stored under ``k`` (``default`` if absent), under the lock."""
        with self._lock:
            found = super().get(k, default)
        return found


class OpNode:
    """A named operation in the pipeline graph.

    Attributes:
        name: identifier that other operations reference in their ``deps``.
        deps: names of operations that must finish before this one runs.
        func: callable invoked as ``func(node, ctx)``.
    """

    def __init__(
        self, name: str, deps: List[str], func: Callable[["OpNode", Context], Any]
    ):
        self.name = name
        self.deps = deps
        self.func = func


def op(name: str, deps=None):
    """Decorator factory that marks a function as a pipeline operation.

    The decorated function keeps its behaviour (it is wrapped with
    ``functools.wraps`` and forwards all arguments); an :class:`OpNode`
    describing it is attached as the ``op_node`` attribute.

    :param name: operation name stored on the attached ``OpNode``.
    :param deps: names of operations this one depends on; any falsy value
        (``None`` by default) becomes an empty list.
    """
    dependencies = deps if deps else []

    def decorator(func):
        def _as_op(self, ctx):
            # Spread the context mapping into keyword arguments of ``func``.
            # NOTE(review): Engine.run calls ``node.func(node, ctx)``, so
            # ``self`` here would be the OpNode rather than the owning object;
            # collect_ops always supplies its own func — confirm this default
            # is intentionally unused.
            return func(self, **ctx)

        @wraps(func)
        def _wrapper(*args, **kwargs):
            return func(*args, **kwargs)

        _wrapper.op_node = OpNode(name, dependencies, _as_op)
        return _wrapper

    return decorator


class Engine:
    """Run a collection of :class:`OpNode` operations as a dependency graph.

    Every operation gets its own thread. A thread first waits until all of its
    dependencies have finished and only then takes one of ``max_workers``
    slots to run the operation body, so threads that are merely waiting never
    occupy a slot (holding a slot while waiting can deadlock the whole run).
    """

    def __init__(self, max_workers: int = 4):
        """
        :param max_workers: maximum number of operation bodies executing at
            the same time; must be at least 1.
        :raises ValueError: if ``max_workers`` is smaller than 1.
        """
        if max_workers < 1:
            # A zero-sized semaphore would block every operation forever.
            raise ValueError(f"max_workers must be at least 1, got {max_workers}")
        self.max_workers = max_workers

    @staticmethod
    def _topo_sort(name2op: dict) -> List[str]:
        """Return operation names in dependency order (Kahn's algorithm).

        :param name2op: mapping of unique name -> node; every dependency must
            be a key of this mapping.
        :raises ValueError: if the dependencies contain a cycle.
        """
        pending = {n: set(node.deps) for n, node in name2op.items()}
        children = {n: [] for n in name2op}
        for n, deps in pending.items():
            for d in deps:
                children[d].append(n)

        order = []
        ready = deque(n for n, deps in pending.items() if not deps)
        while ready:
            cur = ready.popleft()
            order.append(cur)
            for child in children[cur]:
                pending[child].discard(cur)
                if not pending[child]:
                    ready.append(child)

        if len(order) != len(name2op):
            raise ValueError(
                "Cyclic dependencies detected among operations. "
                "Please check your configuration."
            )
        return order

    def run(self, ops: List["OpNode"], ctx: "Context"):
        """Execute ``ops`` concurrently while respecting their dependencies.

        :param ops: operations to run; names must be unique and every entry in
            a node's ``deps`` must name another operation in ``ops``.
        :param ctx: shared context passed to every operation as
            ``node.func(node, ctx)``.
        :raises ValueError: on duplicate names, unknown dependencies or cycles
            (raised before any operation starts).
        :raises RuntimeError: if any operation failed or was skipped because a
            dependency failed; the message lists each one with its traceback.
        """
        name2op = {operation.name: operation for operation in ops}

        if len(name2op) != len(ops):
            names = [operation.name for operation in ops]
            dups = sorted({n for n in names if names.count(n) > 1})
            raise ValueError(f"Duplicate operation names: {dups}")

        for name, node in name2op.items():
            unknown = [d for d in node.deps if d not in name2op]
            if unknown:
                raise ValueError(
                    f"Operation '{name}' depends on unknown operation(s): {unknown}"
                )

        topo = self._topo_sort(name2op)

        # semaphore bounds how many operation bodies run at once
        sem = threading.Semaphore(self.max_workers)
        done = {n: threading.Event() for n in name2op}
        exc = {}

        def _exec(n: str):
            node = name2op[n]
            try:
                # Wait for dependencies *before* taking a worker slot.
                for d in node.deps:
                    done[d].wait()
                if any(d in exc for d in node.deps):
                    exc[n] = Exception("Skipped due to failed dependencies")
                    return
                with sem:
                    try:
                        node.func(node, ctx)
                    except BaseException:  # pylint: disable=broad-except
                        # Thread boundary: record everything and report it
                        # through the aggregated RuntimeError below.
                        exc[n] = traceback.format_exc()
            finally:
                # Always release dependents so no thread can wait forever.
                done[n].set()

        threads = [
            threading.Thread(target=_exec, args=(n,), daemon=True) for n in topo
        ]
        for t in threads:
            t.start()
        for t in threads:
            t.join()

        if exc:
            # Report in dependency order so the message is deterministic.
            failed = [n for n in topo if n in exc]
            raise RuntimeError(
                "Some operations failed:\n"
                + "\n".join(f"---- {n} ----\n{exc[n]}" for n in failed)
            )


def collect_ops(config: dict, graph_gen) -> List[OpNode]:
    """
    Build operation nodes from a YAML pipeline config.

    A fresh ``OpNode`` is created for every stage instead of mutating the
    ``op_node`` template attached by ``@op``: that template is stored on the
    function object, so it is shared by every ``graph_gen`` instance and every
    call to this function, and mutating it would silently rebind nodes that
    were already returned to an earlier caller.

    :param config: parsed config; ``config["pipeline"]`` is iterated as a
        sequence of stage dicts. Each stage needs ``name`` (an attribute of
        ``graph_gen`` carrying ``op_node``) and may set ``deps`` (overrides the
        decorator's deps) and ``params`` (passed as the single positional
        argument to the method).
    :param graph_gen: object whose methods implement the stages.
    :return: one new ``OpNode`` per stage, in config order.
    """
    ops: List[OpNode] = []
    for stage in config["pipeline"]:
        method = getattr(graph_gen, stage["name"])
        template = method.op_node

        # runtime deps from the config override the decorator's deps; copy the
        # list so the returned node never aliases the shared template's list
        deps = list(stage.get("deps", template.deps))

        node = OpNode(template.name, deps, template.func)
        if "params" in stage:
            node.func = lambda self, ctx, m=method, sc=stage: m(sc.get("params", {}))
        else:
            node.func = lambda self, ctx, m=method: m()
        ops.append(node)
    return ops