# ColamanAI's picture
# Upload 243 files
# 0efdb24 verified
"""
Utility functions for multiprocessing
"""
import os
from multiprocessing.dummy import Pool as ThreadPool
import torch
from torch.multiprocessing import Pool as TorchPool, set_start_method
from tqdm import tqdm
def cpu_count():
    """Return the number of CPUs usable by the current process.

    Prefers ``os.sched_getaffinity(0)``, which respects CPU affinity masks
    (e.g. under ``taskset`` or cgroup limits). That function only exists on
    some Unix platforms, so fall back to ``os.cpu_count()`` elsewhere
    (macOS, Windows), and to 1 if even that is unknown.
    """
    try:
        return len(os.sched_getaffinity(0))
    except AttributeError:
        # sched_getaffinity is unavailable on this platform.
        return os.cpu_count() or 1
def parallel_threads(
    function,
    args,
    workers=0,
    star_args=False,
    kw_args=False,
    front_num=1,
    Pool=ThreadPool,
    ordered_res=True,
    **tqdm_kw,
):
    """tqdm but with parallel execution.

    Will essentially return
        res = [ function(arg)   # default
                function(*arg)  # if star_args is True
                function(**arg) # if kw_args is True
                for arg in args]

    Note:
        the <front_num> first elements of args will not be parallelized.
        This can be useful for debugging.
        If both star_args and kw_args are set, star_args takes precedence.

    Args:
        function: callable applied to each element of args.
        args: iterable of arguments (may be a generator).
        workers: pool size; <= 0 means "use all available CPUs".
        star_args: call function(*arg) instead of function(arg).
        kw_args: call function(**arg) instead of function(arg).
        front_num: number of leading elements run sequentially.
        Pool: pool factory (thread pool by default).
        ordered_res: if True results keep the order of args (imap),
            otherwise they arrive as they complete (imap_unordered).
        **tqdm_kw: forwarded to tqdm (e.g. desc=..., disable=...).

    Returns:
        list of results (sequential front results first).
    """
    # Determine the number of workers: non-positive values mean
    # "all CPUs" (0 -> cpu_count()).
    while workers <= 0:
        workers += cpu_count()
    # len() fails for generators; tqdm then runs without a total.
    # Clamp at 0 so front_num > len(args) never yields a negative total.
    try:
        n_args_parallel = max(len(args) - front_num, 0)
    except TypeError:
        n_args_parallel = None
    args = iter(args)
    # Sequential execution for the first few elements (useful for debugging)
    front = []
    while len(front) < front_num:
        try:
            a = next(args)
        except StopIteration:
            return front  # end of the iterable
        front.append(
            function(*a) if star_args else function(**a) if kw_args else function(a)
        )
    # Parallel execution using the provided Pool.
    out = []
    with Pool(workers) as pool:
        # The imap/imap_unordered choice is identical in every branch: hoist it.
        map_func = pool.imap if ordered_res else pool.imap_unordered
        if star_args:
            # Generator expression (not a list) preserves the laziness of
            # `args` so huge/streamed inputs are never fully materialized.
            futures = map_func(starcall, ((function, a) for a in args))
        elif kw_args:
            futures = map_func(starstarcall, ((function, a) for a in args))
        else:
            futures = map_func(function, args)
        # Track progress with tqdm while draining the result iterator.
        for res in tqdm(futures, total=n_args_parallel, **tqdm_kw):
            out.append(res)
    return front + out
def cuda_parallel_threads(
    function,
    args,
    workers=0,
    star_args=False,
    kw_args=False,
    front_num=1,
    Pool=TorchPool,
    ordered_res=True,
    **tqdm_kw,
):
    """
    Parallel execution of a function using torch.multiprocessing with CUDA support.
    This is the CUDA variant of the parallel_threads function.

    The body was previously a near-verbatim copy of parallel_threads; it now
    forces the 'spawn' start method (CUDA cannot be re-initialized in forked
    subprocesses) and delegates everything else. Parameters are identical to
    parallel_threads, except Pool defaults to torch.multiprocessing.Pool so
    tensors can be shared between processes.
    """
    # 'spawn' is the only start method compatible with CUDA in workers.
    set_start_method("spawn", force=True)
    return parallel_threads(
        function,
        args,
        workers=workers,
        star_args=star_args,
        kw_args=kw_args,
        front_num=front_num,
        Pool=Pool,
        ordered_res=ordered_res,
        **tqdm_kw,
    )
def parallel_processes(*args, **kwargs):
    """Same as parallel_threads, but running in a process pool."""
    from multiprocessing import Pool
    # Override the pool factory, then forward everything unchanged.
    kwargs["Pool"] = Pool
    return parallel_threads(*args, **kwargs)
def starcall(args):
    """convenient wrapper for Process.Pool"""
    # Unpack the (callable, argument-tuple) pair and apply positionally.
    fn, positional = args
    return fn(*positional)
def starstarcall(args):
    """convenient wrapper for Process.Pool"""
    # Unpack the (callable, kwargs-dict) pair and apply as keywords.
    fn, keyword_args = args
    return fn(**keyword_args)