File size: 901 Bytes
95b1715
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
from torch.utils.data import DataLoader


class InfiniteLoader(DataLoader):
    """A ``DataLoader`` that acts as its own iterator and restarts itself.

    The loader keeps one underlying ``DataLoader`` iterator. When that
    iterator is exhausted, a fresh one is created immediately:

    * infinite mode (default): the next batch is taken from the fresh
      iterator, so ``next(loader)`` keeps producing batches pass after pass.
      If the underlying loader yields no batches at all, ``StopIteration``
      is raised instead of looping forever.
    * finite mode: ``StopIteration`` is raised once at the end of each pass,
      and the following iteration starts a new pass.

    Because ``__iter__`` returns ``self``, breaking out of a ``for`` loop and
    iterating again resumes where the previous loop stopped.

    Args:
        *args: Positional arguments forwarded to ``DataLoader``.
        num_workers: Number of worker processes, forwarded to ``DataLoader``.
        pin_memory: Forwarded to ``DataLoader``.
        is_infinite: If True, wrap around at the end of a pass instead of
            raising ``StopIteration``.
        **kwargs: Extra keyword arguments forwarded to ``DataLoader``. A
            caller-supplied ``multiprocessing_context`` takes precedence over
            the default chosen here.
    """

    def __init__(
        self,
        *args,
        num_workers=0,
        pin_memory=True,
        is_infinite=True,
        **kwargs,
    ):
        # Default to the "fork" start method when workers are used, but only
        # where the platform supports it (it does not exist on Windows), and
        # never clobber an explicit caller choice -- passing it through kwargs
        # used to collide with the hard-coded keyword and raise TypeError.
        if "multiprocessing_context" not in kwargs:
            import multiprocessing

            use_fork = (
                num_workers > 0
                and "fork" in multiprocessing.get_all_start_methods()
            )
            kwargs["multiprocessing_context"] = "fork" if use_fork else None
        super().__init__(
            *args,
            num_workers=num_workers,
            pin_memory=pin_memory,
            **kwargs,
        )
        # Created eagerly so the first ``next()`` needs no special case.
        self.dataset_iterator = super().__iter__()
        self.is_infinite = is_infinite

    def __iter__(self):
        """Return ``self``; iteration state lives in ``dataset_iterator``."""
        return self

    def __next__(self):
        """Return the next batch, restarting the pass when it is exhausted.

        Raises:
            StopIteration: In finite mode, at the end of each pass. In
                infinite mode, only if a fresh pass yields no batches.
        """
        try:
            return next(self.dataset_iterator)
        except StopIteration:
            # Start a new pass in both modes so the loader stays reusable.
            self.dataset_iterator = super().__iter__()
            if not self.is_infinite:
                raise
        # Infinite mode: take the first batch of the new pass. For an empty
        # loader this raises StopIteration rather than spinning forever.
        return next(self.dataset_iterator)