File size: 4,507 Bytes
9507532
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
"""
Utility functions for multiprocessing
"""

import os
from multiprocessing.dummy import Pool as ThreadPool

import torch
from torch.multiprocessing import Pool as TorchPool, set_start_method
from tqdm import tqdm


def cpu_count():
    """Return how many CPUs the current process may run on.

    Uses the scheduler affinity mask (Linux), so this respects cgroup /
    taskset restrictions, unlike ``os.cpu_count()``.
    """
    affinity = os.sched_getaffinity(0)
    return len(affinity)


def parallel_threads(
    function,
    args,
    workers=0,
    star_args=False,
    kw_args=False,
    front_num=1,
    Pool=ThreadPool,
    ordered_res=True,
    **tqdm_kw,
):
    """tqdm but with parallel execution.

    Will essentially return
      res = [ function(arg) # default
              function(*arg) # if star_args is True
              function(**arg) # if kw_args is True
              for arg in args]

    Args:
        function: callable applied to each element of ``args``.
        args: iterable of arguments (tuples if ``star_args``, dicts if
            ``kw_args``).
        workers: pool size; values <= 0 are offset by the CPU count
            (0 -> cpu_count(), -1 -> cpu_count() - 1, ...).
        star_args: call ``function(*arg)`` for each element.
        kw_args: call ``function(**arg)`` for each element.
        front_num: number of leading elements executed sequentially
            before the pool starts. Useful for debugging.
        Pool: pool class to use (thread pool by default).
        ordered_res: if True, results keep the input order.
        **tqdm_kw: extra keyword arguments forwarded to ``tqdm``.

    Returns:
        list: results, the ``front_num`` sequential results first.
    """
    # Values <= 0 mean "CPU count minus |workers|" (e.g. -1 leaves one core free)
    while workers <= 0:
        workers += cpu_count()

    # Total for the progress bar; None (unknown) for unsized iterables
    try:
        n_args_parallel = len(args) - front_num
    except TypeError:
        n_args_parallel = None
    args = iter(args)

    # Sequential execution for the first few elements (useful for debugging)
    front = []
    while len(front) < front_num:
        try:
            a = next(args)
        except StopIteration:
            return front  # end of the iterable
        front.append(
            function(*a) if star_args else function(**a) if kw_args else function(a)
        )

    # Parallel execution of the remaining elements
    out = []
    with Pool(workers) as pool:
        # imap preserves input order; imap_unordered yields results as they finish
        map_func = pool.imap if ordered_res else pool.imap_unordered
        if star_args:
            # generator (not a list) so unsized/large iterables stay lazy
            futures = map_func(starcall, ((function, a) for a in args))
        elif kw_args:
            futures = map_func(starstarcall, ((function, a) for a in args))
        else:
            futures = map_func(function, args)
        # Track progress with tqdm
        for f in tqdm(futures, total=n_args_parallel, **tqdm_kw):
            out.append(f)
    return front + out


def cuda_parallel_threads(
    function,
    args,
    workers=0,
    star_args=False,
    kw_args=False,
    front_num=1,
    Pool=TorchPool,
    ordered_res=True,
    **tqdm_kw,
):
    """
    Parallel execution of a function using torch.multiprocessing with CUDA support.
    This is the CUDA variant of the parallel_threads function.

    Args:
        function: callable applied to each element of ``args``. For process
            pools it must be picklable (a top-level function).
        args: iterable of arguments (tuples if ``star_args``, dicts if
            ``kw_args``).
        workers: pool size; values <= 0 are offset by the CPU count
            (0 -> cpu_count(), -1 -> cpu_count() - 1, ...).
        star_args: call ``function(*arg)`` for each element.
        kw_args: call ``function(**arg)`` for each element.
        front_num: number of leading elements executed sequentially in the
            parent process before the pool starts. Useful for debugging.
        Pool: pool class to use (torch.multiprocessing Pool by default).
        ordered_res: if True, results keep the input order.
        **tqdm_kw: extra keyword arguments forwarded to ``tqdm``.

    Returns:
        list: results, the ``front_num`` sequential results first.
    """
    # Set the start method for multiprocessing.
    # "spawn" is the start method documented as safe for CUDA in subprocesses.
    # NOTE(review): force=True silently overrides any start method the caller
    # (or another library) already configured globally — confirm this is intended.
    set_start_method("spawn", force=True)

    # Determine the number of workers
    # Values <= 0 mean "CPU count minus |workers|" (e.g. -1 leaves one core free).
    while workers <= 0:
        workers += torch.multiprocessing.cpu_count()

    # Convert args to an iterable
    # len() works only for sized containers; fall back to an unknown total.
    try:
        n_args_parallel = len(args) - front_num
    except TypeError:
        n_args_parallel = None
    args = iter(args)

    # Sequential execution for the first few elements (useful for debugging)
    front = []
    while len(front) < front_num:
        try:
            a = next(args)
        except StopIteration:
            return front  # End of the iterable
        front.append(
            function(*a) if star_args else function(**a) if kw_args else function(a)
        )

    # Parallel execution using torch.multiprocessing
    out = []
    with Pool(workers) as pool:
        if star_args:
            # imap preserves input order; imap_unordered yields results as they finish
            map_func = pool.imap if ordered_res else pool.imap_unordered
            # starcall unpacks (function, arg_tuple) inside the worker
            futures = map_func(starcall, [(function, a) for a in args])
        elif kw_args:
            map_func = pool.imap if ordered_res else pool.imap_unordered
            # starstarcall unpacks (function, kwargs_dict) inside the worker
            futures = map_func(starstarcall, [(function, a) for a in args])
        else:
            map_func = pool.imap if ordered_res else pool.imap_unordered
            futures = map_func(function, args)
        # Track progress with tqdm
        for f in tqdm(futures, total=n_args_parallel, **tqdm_kw):
            out.append(f)
    return front + out


def parallel_processes(*args, **kwargs):
    """Same as parallel_threads, with processes"""
    import multiprocessing as mp

    # Force a real process pool, overriding any Pool the caller passed.
    return parallel_threads(*args, **{**kwargs, "Pool": mp.Pool})


def starcall(args):
    """Unpack a ``(function, arg_tuple)`` pair and call ``function(*arg_tuple)``.

    Top-level so it stays picklable for process pools.
    """
    fn, fn_args = args
    return fn(*fn_args)


def starstarcall(args):
    """convenient wrapper for Process.Pool"""
    function, args = args
    return function(**args)