"""
Utility functions for multiprocessing
"""

import os
from multiprocessing.dummy import Pool as ThreadPool

import torch
from torch.multiprocessing import Pool as TorchPool, set_start_method
from tqdm import tqdm


def cpu_count():
    """
    Return the number of CPUs available to the current Python process.

    Unlike os.cpu_count(), this respects the process's CPU affinity mask
    (e.g. restrictions set via taskset or cpuset cgroups). Note that
    os.sched_getaffinity is not available on every platform (notably
    macOS and Windows).
    """
    return len(os.sched_getaffinity(0))
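

# A minimal sketch (not part of the original module) contrasting cpu_count()
# with os.cpu_count(); the printed numbers depend on the machine and on any
# taskset/cgroup restrictions applied to the process.
def _cpu_count_example():
    total = os.cpu_count()   # CPUs installed in the machine
    usable = cpu_count()     # CPUs this process is actually allowed to run on
    print(f"installed={total} usable={usable}")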


def parallel_threads(
    function,
    args,
    workers=0,
    star_args=False,
    kw_args=False,
    front_num=1,
    Pool=ThreadPool,
    ordered_res=True,
    **tqdm_kw,
):
    """tqdm-style progress bar with parallel execution.

    Essentially returns
        res = [function(arg)    # default
               function(*arg)   # if star_args is True
               function(**arg)  # if kw_args is True
               for arg in args]

    Note:
        The first <front_num> elements of args are executed sequentially.
        This can be useful for debugging.
    """
    # Determine the number of workers (workers <= 0 means "all CPUs minus |workers|")
    while workers <= 0:
        workers += cpu_count()
    # Convert args to an iterator and compute the number of parallel tasks if possible
    try:
        n_args_parallel = len(args) - front_num
    except TypeError:
        n_args_parallel = None
    args = iter(args)
    # Sequential execution of the first few elements (useful for debugging)
    front = []
    while len(front) < front_num:
        try:
            a = next(args)
        except StopIteration:
            return front  # end of the iterable
        front.append(
            function(*a) if star_args else function(**a) if kw_args else function(a)
        )
    # Parallel execution using the given Pool (multiprocessing.dummy threads by default)
    out = []
    with Pool(workers) as pool:
        map_func = pool.imap if ordered_res else pool.imap_unordered
        if star_args:
            futures = map_func(starcall, [(function, a) for a in args])
        elif kw_args:
            futures = map_func(starstarcall, [(function, a) for a in args])
        else:
            futures = map_func(function, args)
        # Track progress with tqdm
        for f in tqdm(futures, total=n_args_parallel, **tqdm_kw):
            out.append(f)
    return front + out
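

# A minimal usage sketch (not part of the original module). It relies only on
# the documented parameters above; the helper name and the lambdas are
# illustrative, and extra keyword arguments (here `desc`) are forwarded to tqdm.
def _example_parallel_threads():
    # Plain call: one positional argument per task.
    squares = parallel_threads(lambda x: x * x, range(10), desc="squares")

    # star_args: each item is a tuple of positional arguments.
    pairs = [(1, 2), (3, 4), (5, 6)]
    sums = parallel_threads(lambda a, b: a + b, pairs, star_args=True, desc="sums")

    # kw_args: each item is a dict of keyword arguments.
    kwargs = [dict(x=1, y=10), dict(x=2, y=20)]
    prods = parallel_threads(lambda x, y: x * y, kwargs, kw_args=True, desc="products")
    return squares, sums, prods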


def cuda_parallel_threads(
    function,
    args,
    workers=0,
    star_args=False,
    kw_args=False,
    front_num=1,
    Pool=TorchPool,
    ordered_res=True,
    **tqdm_kw,
):
    """
    Parallel execution of a function using torch.multiprocessing with CUDA support.
    This is the CUDA-aware variant of parallel_threads: it uses real processes
    started with the "spawn" method, which is required when workers touch CUDA.
    """
    # CUDA cannot be re-initialized in forked subprocesses, so force "spawn"
    set_start_method("spawn", force=True)
    # Determine the number of workers (workers <= 0 means "all CPUs minus |workers|")
    while workers <= 0:
        workers += torch.multiprocessing.cpu_count()
    # Convert args to an iterator and compute the number of parallel tasks if possible
    try:
        n_args_parallel = len(args) - front_num
    except TypeError:
        n_args_parallel = None
    args = iter(args)
    # Sequential execution of the first few elements (useful for debugging)
    front = []
    while len(front) < front_num:
        try:
            a = next(args)
        except StopIteration:
            return front  # end of the iterable
        front.append(
            function(*a) if star_args else function(**a) if kw_args else function(a)
        )
    # Parallel execution using torch.multiprocessing
    out = []
    with Pool(workers) as pool:
        map_func = pool.imap if ordered_res else pool.imap_unordered
        if star_args:
            futures = map_func(starcall, [(function, a) for a in args])
        elif kw_args:
            futures = map_func(starstarcall, [(function, a) for a in args])
        else:
            futures = map_func(function, args)
        # Track progress with tqdm
        for f in tqdm(futures, total=n_args_parallel, **tqdm_kw):
            out.append(f)
    return front + out
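

# A minimal usage sketch (not part of the original module). With the "spawn"
# start method the mapped function must be defined at module level and be
# picklable; _gpu_norm below is an illustrative placeholder for a real CUDA
# workload.
def _gpu_norm(t):
    # Move a tensor to the GPU (if available) and return its norm as a float.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    return t.to(device).norm().item()


def _example_cuda_parallel_threads():
    tensors = [torch.randn(1000) for _ in range(8)]
    return cuda_parallel_threads(_gpu_norm, tensors, workers=2, desc="norms")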


def parallel_processes(*args, **kwargs):
    """Same as parallel_threads, but with processes (multiprocessing.Pool).

    The mapped function and its arguments must be picklable.
    """
    import multiprocessing as mp

    kwargs["Pool"] = mp.Pool
    return parallel_threads(*args, **kwargs)
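

# A minimal usage sketch (not part of the original module): process pools need
# a module-level (picklable) function, so a lambda would not work here. The
# helper names are illustrative only.
def _square(x):
    return x * x


def _example_parallel_processes():
    return parallel_processes(_square, range(100), desc="squares (processes)")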


def starcall(args):
    """Helper for Pool.imap: unpacks (function, args) and calls function(*args)."""
    function, args = args
    return function(*args)


def starstarcall(args):
    """Helper for Pool.imap: unpacks (function, kwargs) and calls function(**kwargs)."""
    function, kwargs = args
    return function(**kwargs)
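

# A minimal sketch (not part of the original module) showing what the two
# wrappers above do, plus a small self-contained demo guarded so it only runs
# when the file is executed directly.
if __name__ == "__main__":
    # starcall / starstarcall simply unpack a (function, arguments) pair.
    assert starcall((max, (1, 5, 3))) == 5
    assert starstarcall(("{a}-{b}".format, {"a": 1, "b": 2})) == "1-2"

    # Thread-based map with a progress bar.
    results = parallel_threads(lambda x: x * x, range(20), desc="demo")
    print(results)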