"""Utility functions for running a function over many arguments in parallel."""

import os
from multiprocessing.dummy import Pool as ThreadPool

import torch
from torch.multiprocessing import Pool as TorchPool, set_start_method
from tqdm import tqdm


def cpu_count():
    """Return the number of CPUs available to the current Python process.

    Uses the process CPU-affinity mask where the platform provides it
    (``os.sched_getaffinity`` is not available on every OS); otherwise falls
    back to ``os.cpu_count()``, and to 1 if that cannot be determined.
    """
    try:
        return len(os.sched_getaffinity(0))
    except AttributeError:
        # e.g. macOS / Windows: no affinity API, use the total CPU count.
        return os.cpu_count() or 1


def _resolve_workers(workers, count_fn):
    """Turn a non-positive ``workers`` value into a positive pool size.

    ``count_fn()`` is added until the value is positive, so ``0`` means
    "one per CPU", ``-1`` means "all but one", and so on.
    """
    while workers <= 0:
        workers += count_fn()
    return workers


def _call_one(function, arg, star_args, kw_args):
    """Apply ``function`` to one ``arg`` using the requested calling style."""
    if star_args:
        return function(*arg)
    if kw_args:
        return function(**arg)
    return function(arg)


def parallel_threads(
    function,
    args,
    workers=0,
    star_args=False,
    kw_args=False,
    front_num=1,
    Pool=ThreadPool,
    ordered_res=True,
    **tqdm_kw,
):
    """tqdm, but with parallel execution.

    Essentially returns::

        res = [
            function(arg)    # default
            function(*arg)   # if star_args is True
            function(**arg)  # if kw_args is True
            for arg in args
        ]

    Args:
        function: Callable applied to each element of ``args``.
        args: Iterable of arguments. If it supports ``len()`` the progress
            bar shows a total; plain iterators/generators also work.
        workers: Pool size. Values ``<= 0`` are relative to
            :func:`cpu_count` (0 -> all CPUs, -1 -> all but one, ...).
        star_args: Call ``function(*arg)`` for each element.
        kw_args: Call ``function(**arg)`` for each element (only used when
            ``star_args`` is False).
        front_num: Number of leading elements evaluated sequentially in the
            calling thread before any pool is created. Useful for debugging.
        Pool: Pool class to use (a thread pool by default).
        ordered_res: If False, parallel results come back via
            ``imap_unordered`` instead of in input order.
        **tqdm_kw: Forwarded to ``tqdm`` (e.g. ``desc``, ``disable``).

    Returns:
        List of the ``front_num`` sequential results followed by the parallel
        results. If ``args`` has no more than ``front_num`` elements, only the
        sequential results are returned and no pool is created.
    """
    workers = _resolve_workers(workers, cpu_count)

    try:
        n_args_parallel = len(args) - front_num
    except TypeError:
        # No len(): tqdm shows a running count without a total.
        n_args_parallel = None
    args = iter(args)

    # Sequential execution for the first few elements (useful for debugging).
    front = []
    while len(front) < front_num:
        try:
            a = next(args)
        except StopIteration:
            return front  # iterable exhausted before the parallel phase
        front.append(_call_one(function, a, star_args, kw_args))

    # Parallel execution with whatever pool class was supplied.
    with Pool(workers) as pool:
        # imap passes exactly one argument per task, so star/keyword calls go
        # through the module-level starcall/starstarcall wrappers (module-level
        # functions stay picklable for process pools).
        if star_args:
            worker, tasks = starcall, [(function, a) for a in args]
        elif kw_args:
            worker, tasks = starstarcall, [(function, a) for a in args]
        else:
            worker, tasks = function, args
        map_func = pool.imap if ordered_res else pool.imap_unordered
        out = list(
            tqdm(map_func(worker, tasks), total=n_args_parallel, **tqdm_kw)
        )
    return front + out


def cuda_parallel_threads(
    function,
    args,
    workers=0,
    star_args=False,
    kw_args=False,
    front_num=1,
    Pool=TorchPool,
    ordered_res=True,
    **tqdm_kw,
):
    """CUDA-friendly variant of :func:`parallel_threads`.

    Forces the ``"spawn"`` start method (PyTorch requires spawn/forkserver to
    use CUDA in subprocesses) and uses a ``torch.multiprocessing`` pool by
    default. Note: ``set_start_method(..., force=True)`` changes the start
    method for the whole process, not just this call.

    Non-positive ``workers`` are resolved against
    ``torch.multiprocessing.cpu_count()``; all other arguments behave exactly
    as in :func:`parallel_threads`.
    """
    set_start_method("spawn", force=True)
    workers = _resolve_workers(workers, torch.multiprocessing.cpu_count)
    return parallel_threads(
        function,
        args,
        workers=workers,
        star_args=star_args,
        kw_args=kw_args,
        front_num=front_num,
        Pool=Pool,
        ordered_res=ordered_res,
        **tqdm_kw,
    )


def parallel_processes(*args, **kwargs):
    """Same as :func:`parallel_threads`, but with a ``multiprocessing`` process pool."""
    import multiprocessing as mp

    kwargs["Pool"] = mp.Pool
    return parallel_threads(*args, **kwargs)


def starcall(args):
    """Call ``function(*fargs)`` for a ``(function, fargs)`` pair (pool helper)."""
    function, args = args
    return function(*args)


def starstarcall(args):
    """Call ``function(**fkwargs)`` for a ``(function, fkwargs)`` pair (pool helper)."""
    function, args = args
    return function(**args)