File size: 4,980 Bytes
b20c769
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
import json
from pathlib import Path
from typing import Optional

import geobench
import numpy as np
import torch.multiprocessing
from sklearn.utils import shuffle
from torch.utils.data import Dataset

from src.utils import DEFAULT_SEED

from ..preprocess import impute_bands, impute_normalization_stats, normalize_bands

torch.multiprocessing.set_sharing_strategy("file_system")


class GeobenchDataset(Dataset):
    """
    Class implementation inspired by: https://github.com/vishalned/MMEarth-train/tree/main
    """

    def __init__(
        self,
        dataset_config_file: str,
        split: str,
        norm_operation,
        augmentation,
        partition,
        manual_subsetting: Optional[float] = None,
    ):
        with (Path(__file__).parents[0] / Path("configs") / Path(dataset_config_file)).open(
            "r"
        ) as f:
            config = json.load(f)

        assert split in ["train", "valid", "test"]

        self.split = split
        self.config = config
        self.norm_operation = norm_operation
        self.augmentation = augmentation
        self.partition = partition

        if config["task_type"] == "cls":
            self.tiles_per_img = 1
        elif config["task_type"] == "seg":
            assert self.config["dataset_name"] in ["m-SA-crop-type", "m-cashew-plant"]
            # for cashew plant and SA crop type
            # images are 256x256, we want 64x64
            self.tiles_per_img = 16
        else:
            raise ValueError(f"task_type must be cls or seg, not {config['task_type']}")

        for task in geobench.task_iterator(benchmark_name=self.config["benchmark_name"]):
            if task.dataset_name == self.config["dataset_name"]:
                break

        self.dataset = task.get_dataset(split=self.split, partition_name=self.partition)
        print(
            f"In dataset length for split {split} and partition {partition}: length = {len(self.dataset)}"
        )

        original_band_names = [
            self.dataset[0].bands[i].band_info.name for i in range(len(self.dataset[0].bands))
        ]

        self.band_names = list(self.config["band_info"].keys())
        self.band_indices = [original_band_names.index(band_name) for band_name in self.band_names]
        self.band_info = impute_normalization_stats(
            self.config["band_info"], self.config["imputes"]
        )
        self.manual_subsetting = manual_subsetting

        if self.manual_subsetting is not None:
            num_vals_to_keep = int(self.manual_subsetting * len(self.dataset) * self.tiles_per_img)
            active_indices = list(range(int(len(self.dataset) * self.tiles_per_img)))
            self.active_indices = shuffle(
                active_indices, random_state=DEFAULT_SEED, n_samples=num_vals_to_keep
            )
        else:
            self.active_indices = list(range(int(len(self.dataset) * self.tiles_per_img)))

    def __getitem__(self, idx):
        dataset_idx = self.active_indices[idx]
        img_idx = dataset_idx // self.tiles_per_img  # thanks Gabi / Marlena
        label = self.dataset[img_idx].label

        x = []
        for band_idx in self.band_indices:
            x.append(self.dataset[img_idx].bands[band_idx].data)

        x = impute_bands(x, self.band_names, self.config["imputes"])

        x = np.stack(x, axis=2)  # (h, w, 13)
        assert x.shape[-1] == 13, f"All datasets must have 13 channels, not {x.shape[-1]}"
        if self.config["dataset_name"] == "m-so2sat":
            x = x * 10_000

        x = torch.tensor(normalize_bands(x, self.norm_operation, self.band_info))

        # check if label is an object or a number
        if not (isinstance(label, int) or isinstance(label, list)):
            label = label.data
            # label is a memoryview object, convert it to a list, and then to a numpy array
            label = np.array(list(label))

        target = torch.tensor(label, dtype=torch.long)

        if self.tiles_per_img == 16:
            # thanks Gabi / Marlena
            # for cashew plant and SA crop type
            subtiles_per_dim = 4
            h = 256
            assert h % subtiles_per_dim == 0
            pixels_per_dim = h // subtiles_per_dim
            subtile_idx = idx % self.tiles_per_img

            row_idx = subtile_idx // subtiles_per_dim
            col_idx = subtile_idx % subtiles_per_dim

            x = x[
                row_idx * pixels_per_dim : (row_idx + 1) * pixels_per_dim,
                col_idx * pixels_per_dim : (col_idx + 1) * pixels_per_dim,
                :,
            ]

            target = target[
                row_idx * pixels_per_dim : (row_idx + 1) * pixels_per_dim,
                col_idx * pixels_per_dim : (col_idx + 1) * pixels_per_dim,
            ]

        x, target = self.augmentation.apply(x, target, self.config["task_type"])
        return {"s2": x, "target": target}

    def __len__(self):
        return int(len(self.active_indices))