Spaces:
Sleeping
Sleeping
| """ | |
| Utility functions for multiprocessing | |
| """ | |
| import os | |
| from multiprocessing.dummy import Pool as ThreadPool | |
| import torch | |
| from torch.multiprocessing import Pool as TorchPool, set_start_method | |
| from tqdm import tqdm | |
def cpu_count():
    """
    Returns the number of available CPUs for the python process.

    Uses the scheduler affinity mask of the current process when the
    platform exposes it, so CPU restrictions placed on the process are
    taken into account. On platforms where ``os.sched_getaffinity`` does not
    exist, falls back to ``os.cpu_count()`` (and to 1 if that is unknown).
    """
    try:
        return len(os.sched_getaffinity(0))
    except AttributeError:
        # sched_getaffinity is only provided on some Unix platforms;
        # os.cpu_count() may return None, hence the `or 1`.
        return os.cpu_count() or 1
def parallel_threads(
    function,
    args,
    workers=0,
    star_args=False,
    kw_args=False,
    front_num=1,
    Pool=ThreadPool,
    ordered_res=True,
    **tqdm_kw,
):
    """Map ``function`` over ``args`` in a worker pool, with a tqdm progress bar.

    The returned list is equivalent to
        [function(arg)      # default
         function(*arg)     # if star_args is True
         function(**arg)    # if kw_args is True (and star_args is False)
         for arg in args]

    Args:
        function: callable applied to every element of ``args``.
        args: iterable of inputs; it does not need to support ``len()``.
        workers: pool size. While the value is not positive, the CPU count
            is added to it (0 -> all CPUs, -1 -> all CPUs but one).
        star_args: unpack each element as positional arguments.
        kw_args: unpack each element as keyword arguments.
        front_num: number of leading elements processed sequentially in the
            calling thread before the pool is created (handy for debugging).
        Pool: pool factory; called as ``Pool(workers)`` and used as a
            context manager.
        ordered_res: use ``imap`` (results in input order) when True,
            ``imap_unordered`` otherwise. The sequential front results are
            always first.
        **tqdm_kw: forwarded to ``tqdm``.

    Returns:
        List of results: the sequential front results followed by the
        pooled results. If ``args`` runs out during the sequential phase,
        only the front results are returned and no pool is created.
    """
    # A non-positive worker count is an offset relative to the CPU count
    while workers <= 0:
        workers += cpu_count()

    # Progress-bar total; unknown when args has no length
    try:
        remaining = len(args) - front_num
    except TypeError:
        remaining = None
    arg_iter = iter(args)

    # Sequential phase: run the first elements in the calling thread
    head = []
    while len(head) < front_num:
        try:
            item = next(arg_iter)
        except StopIteration:
            return head  # input exhausted before reaching the pool
        if star_args:
            head.append(function(*item))
        elif kw_args:
            head.append(function(**item))
        else:
            head.append(function(item))

    # Parallel phase: everything left in the iterator goes to the pool
    with Pool(workers) as pool:
        mapper = pool.imap if ordered_res else pool.imap_unordered
        if star_args:
            # remaining inputs are collected into a list before submission
            results = mapper(starcall, [(function, item) for item in arg_iter])
        elif kw_args:
            results = mapper(starstarcall, [(function, item) for item in arg_iter])
        else:
            results = mapper(function, arg_iter)
        # Results must be drained while the pool is still alive
        tail = list(tqdm(results, total=remaining, **tqdm_kw))

    return head + tail
def cuda_parallel_threads(
    function,
    args,
    workers=0,
    star_args=False,
    kw_args=False,
    front_num=1,
    Pool=TorchPool,
    ordered_res=True,
    **tqdm_kw,
):
    """
    Parallel execution of a function using torch.multiprocessing with CUDA support.
    This is the CUDA variant of the parallel_threads function.

    It differs from parallel_threads in three ways only:
      * the multiprocessing start method is forced to "spawn" on every call
        (this is a process-wide setting and stays in effect afterwards);
      * a non-positive ``workers`` is resolved with
        ``torch.multiprocessing.cpu_count()``;
      * the default ``Pool`` is ``torch.multiprocessing.Pool``.

    Everything else (sequential handling of the first ``front_num`` elements,
    ``star_args`` / ``kw_args`` unpacking, ordered vs. unordered results,
    tqdm progress bar and ``**tqdm_kw`` forwarding, returned list) is
    delegated to parallel_threads so the two implementations cannot drift.
    """
    # Set the start method for multiprocessing
    set_start_method("spawn", force=True)

    # Resolve the worker count here so that the torch CPU count is used;
    # parallel_threads leaves an already-positive value untouched.
    while workers <= 0:
        workers += torch.multiprocessing.cpu_count()

    return parallel_threads(
        function,
        args,
        workers=workers,
        star_args=star_args,
        kw_args=kw_args,
        front_num=front_num,
        Pool=Pool,
        ordered_res=ordered_res,
        **tqdm_kw,
    )
def parallel_processes(*args, **kwargs):
    """Same as parallel_threads, but backed by a process pool.

    All positional and keyword arguments are forwarded to parallel_threads;
    any ``Pool`` keyword supplied by the caller is replaced by
    ``multiprocessing.Pool``.
    """
    from multiprocessing import Pool as ProcessPool

    forwarded = {**kwargs, "Pool": ProcessPool}
    return parallel_threads(*args, **forwarded)
def starcall(args):
    """Call ``func(*positional)`` for a ``(func, positional)`` pair.

    Pool map functions pass a single argument to the mapped callable; this
    module-level adapter lets them drive a function that takes several
    positional arguments.
    """
    func, positional = args
    result = func(*positional)
    return result
def starstarcall(args):
    """Call ``func(**keywords)`` for a ``(func, keywords)`` pair.

    Pool map functions pass a single argument to the mapped callable; this
    module-level adapter lets them drive a function that takes keyword
    arguments.
    """
    func, keywords = args
    result = func(**keywords)
    return result