Spaces:
Sleeping
Sleeping
File size: 901 Bytes
95b1715 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 |
import multiprocessing

from torch.utils.data import DataLoader
class InfiniteLoader(DataLoader):
    """A ``DataLoader`` that is its own iterator and can cycle forever.

    ``iter(loader)`` returns the loader itself; batches come from an
    underlying ``DataLoader`` iterator stored in ``self.dataset_iterator``.
    When that iterator is exhausted, a fresh one is created via
    ``DataLoader.__iter__`` and then:

    * ``is_infinite=True`` (default): the fresh iterator is used
      immediately, so batches keep flowing across epoch boundaries.
    * ``is_infinite=False``: ``StopIteration`` ends the current pass; the
      next ``for`` loop over the loader consumes the fresh iterator.

    Because ``__iter__`` returns ``self``, starting a new ``for`` loop in the
    middle of a pass resumes that pass rather than restarting it.

    The first underlying iterator is created eagerly in ``__init__`` (and a
    new one right after each exhausted pass), so with ``num_workers > 0``
    worker processes may be started before any batch is requested.

    Args:
        *args: Positional arguments forwarded to ``DataLoader``.
        num_workers: Number of loader worker processes (forwarded).
        pin_memory: Forwarded to ``DataLoader``; defaults to ``True`` here.
        is_infinite: Whether to restart automatically when a pass ends.
        multiprocessing_context: Forwarded to ``DataLoader``. When ``None``
            and ``num_workers > 0``, the ``"fork"`` start method is used if
            the platform provides it, otherwise the platform default.
        **kwargs: Any other keyword arguments, forwarded to ``DataLoader``.
    """

    def __init__(
        self,
        *args,
        num_workers=0,
        pin_memory=True,
        is_infinite=True,
        multiprocessing_context=None,
        **kwargs,
    ):
        if multiprocessing_context is None and num_workers > 0:
            # Keep the historical "fork" choice where it exists; on platforms
            # without it (e.g. Windows) DataLoader would reject "fork" with a
            # ValueError, so leave the context as None (platform default).
            if "fork" in multiprocessing.get_all_start_methods():
                multiprocessing_context = "fork"
        super().__init__(
            *args,
            multiprocessing_context=multiprocessing_context,
            num_workers=num_workers,
            pin_memory=pin_memory,
            **kwargs,
        )
        # Iterator over the current pass; replaced whenever it is exhausted.
        self.dataset_iterator = super().__iter__()
        self.is_infinite = is_infinite

    def __iter__(self):
        # The loader is its own iterator, so the pass position lives on the
        # instance and survives across separate ``for`` loops.
        return self

    def __next__(self):
        try:
            return next(self.dataset_iterator)
        except StopIteration:
            # Current pass exhausted: prepare a fresh one in both modes so the
            # loader can be iterated again later.
            self.dataset_iterator = super().__iter__()
            if not self.is_infinite:
                raise
        # Infinite mode: continue straight into the new pass. If the loader
        # yields no batches at all, this StopIteration propagates and ends
        # iteration even in infinite mode.
        return next(self.dataset_iterator)
|