File size: 3,629 Bytes
f1eedd1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dee1edd
f1eedd1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dee1edd
f1eedd1
 
 
 
 
 
 
 
 
 
 
 
 
 
dee1edd
 
 
 
 
 
 
 
 
 
 
 
 
 
f1eedd1
 
 
 
 
 
 
 
 
 
dee1edd
 
 
f1eedd1
 
dee1edd
 
 
 
f1eedd1
dee1edd
 
 
 
 
f1eedd1
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
"""
Orchestration engine for GraphGen: builds operation nodes from a pipeline config and runs them as a dependency graph on threads.
"""

import threading
import traceback
from collections import deque
from typing import Any, Callable, List


class Context(dict):
    """Dictionary shared between pipeline operations.

    Only ``set`` and ``get`` are serialized through the lock. Every other
    dict operation (``ctx[k] = v``, ``update``, iteration, ...) is inherited
    unchanged from ``dict`` and takes no lock.
    """

    # NOTE: this is a class attribute, so a single lock is shared by every
    # Context instance rather than one lock per dict.
    _lock = threading.Lock()

    def set(self, k, v):
        """Store ``v`` under key ``k`` while holding the lock."""
        lock = self._lock
        with lock:
            self[k] = v

    def get(self, k, default=None):
        """Return the value for ``k`` (or ``default`` if absent) under the lock."""
        with self._lock:
            value = super().get(k, default)
        return value


class OpNode:
    """A single named step of a pipeline.

    :param name: identifier of the operation; other nodes refer to it in
        their ``deps``.
    :param deps: names of operations that must finish before this one runs.
    :param func: callable invoked as ``func(node, ctx)`` with this node and
        the shared :class:`Context`.
    """

    def __init__(
        self, name: str, deps: List[str], func: Callable[["OpNode", Context], Any]
    ):
        self.name = name
        self.deps = deps
        self.func = func


class Engine:
    """Run a DAG of :class:`OpNode` operations concurrently on threads.

    Every operation gets its own thread, which starts the operation as soon
    as all of its dependencies have finished. At most ``max_workers``
    operation bodies execute at the same time. Operations with a failed (or
    skipped) dependency are skipped and reported as failures.
    """

    def __init__(self, max_workers: int = 4):
        """
        :param max_workers: maximum number of operations whose ``func`` may
            run concurrently; must be at least 1.
        :raises ValueError: if ``max_workers`` is smaller than 1.
        """
        if max_workers < 1:
            # Semaphore(0) would block every operation forever; fail fast.
            raise ValueError(f"max_workers must be >= 1, got {max_workers}")
        self.max_workers = max_workers

    def run(self, ops: List[OpNode], ctx: Context):
        """Execute ``ops`` while respecting their dependencies.

        :param ops: operations to run; names must be unique and every
            dependency must be the name of another operation in ``ops``.
        :param ctx: shared context passed to every operation's ``func``.
        :raises ValueError: on duplicate names, unknown dependencies, or
            cyclic dependencies.
        :raises RuntimeError: if any operation raised or was skipped because
            a dependency failed. The message lists every failure in
            topological order.
        """
        self._validate(ops)
        name2op = {operation.name: operation for operation in ops}

        topo = self._topo_sort(name2op)
        if len(topo) != len(ops):
            resolved = set(topo)
            unresolved = [n for n in name2op if n not in resolved]
            raise ValueError(
                "Cyclic dependencies detected among operations. "
                "Please check your configuration. "
                f"Operations in or blocked by a cycle: {unresolved}"
            )

        # Caps how many operation bodies run at the same time.
        sem = threading.Semaphore(self.max_workers)
        done = {n: threading.Event() for n in name2op}
        # op name -> failure description. Each thread writes only its own
        # key, and reads a dependency's key only after that dependency's
        # ``done`` event is set.
        exc = {}

        def _exec(n: str):
            op = name2op[n]
            try:
                # Wait for dependencies *before* taking a worker slot. Holding
                # the semaphore while blocked on a dependency wastes the slot
                # and can deadlock once every slot is held by a dependent
                # whose dependency cannot acquire one.
                for d in op.deps:
                    done[d].wait()
                if any(d in exc for d in op.deps):
                    exc[n] = "Skipped due to failed dependencies"
                    return
                with sem:
                    try:
                        op.func(op, ctx)
                    except BaseException:  # noqa: BLE001
                        # Record *any* failure (including SystemExit raised
                        # inside the op) so dependents are skipped and run()
                        # reports it, instead of the thread dying silently.
                        exc[n] = traceback.format_exc()
            finally:
                # Always signal completion; otherwise dependents wait forever.
                done[n].set()

        threads = [
            threading.Thread(target=_exec, args=(n,), daemon=True) for n in topo
        ]
        for t in threads:
            t.start()
        for t in threads:
            t.join()
        if exc:
            # Report in topological order so the message is deterministic.
            raise RuntimeError(
                "Some operations failed:\n"
                + "\n".join(f"---- {n} ----\n{exc[n]}" for n in topo if n in exc)
            )

    @staticmethod
    def _topo_sort(name2op: dict) -> List[str]:
        """Return op names in Kahn topological order (FIFO, ops in input order).

        Nodes that are part of, or depend on, a cycle are omitted, so a result
        shorter than ``name2op`` signals a cycle.
        """
        indegree = {n: len(set(op.deps)) for n, op in name2op.items()}
        children = {n: [] for n in name2op}
        for n, op in name2op.items():
            # set(): duplicate deps must count once, matching ``indegree``.
            for d in set(op.deps):
                children[d].append(n)

        queue = deque(n for n, deg in indegree.items() if deg == 0)
        order = []
        while queue:
            cur = queue.popleft()
            order.append(cur)
            for child in children[cur]:
                indegree[child] -= 1
                if indegree[child] == 0:
                    queue.append(child)
        return order

    @staticmethod
    def _validate(ops: List[OpNode]):
        """Check that op names are unique and every dependency names an op.

        :raises ValueError: on a duplicate name or an unknown dependency.
        """
        name_set = set()
        for op in ops:
            if op.name in name_set:
                raise ValueError(f"Duplicate operation name: {op.name}")
            name_set.add(op.name)
        for op in ops:
            for dep in op.deps:
                if dep not in name_set:
                    raise ValueError(
                        f"Operation {op.name} has unknown dependency: {dep}"
                    )


def collect_ops(config: dict, graph_gen) -> List[OpNode]:
    """
    Build operation nodes from a (YAML-loaded) pipeline config.

    Each entry of ``config["pipeline"]`` is a stage mapping with keys:

    * ``name`` (required): operation name.
    * ``op_key`` (required): name of the method on ``graph_gen`` to call.
    * ``deps`` (optional): names of operations that must finish first;
      a missing or null value means no dependencies.
    * ``params`` (optional): if the key is present, its value is passed to
      the method as the single positional argument; otherwise the method is
      called with no arguments.

    :param config: parsed pipeline configuration containing a ``pipeline`` list
    :param graph_gen: object whose methods implement the operations
    :return: one OpNode per stage, in config order
    :raises ValueError: if a stage has no ``op_key``
    :raises AttributeError: if ``graph_gen`` has no attribute named ``op_key``
    """
    ops: List[OpNode] = []
    for stage in config["pipeline"]:
        name = stage["name"]
        method_name = stage.get("op_key")
        if method_name is None:
            # getattr(obj, None) would raise an unhelpful TypeError instead.
            raise ValueError(f"Stage {name!r} is missing required 'op_key'")
        method = getattr(graph_gen, method_name)
        # YAML turns an empty ``deps:`` into None; treat it as "no deps".
        # Copy so the node does not alias (and cannot mutate) the config list.
        deps = list(stage.get("deps") or [])

        # Bind method/params as defaults to avoid the late-binding closure
        # pitfall. ``self`` receives the OpNode; ``ctx`` is unused here.
        if "params" in stage:

            def func(self, ctx, _method=method, _params=stage["params"]):
                return _method(_params)

        else:

            def func(self, ctx, _method=method):
                return _method()

        op_node = OpNode(name=name, deps=deps, func=func)
        ops.append(op_node)
    return ops