Spaces:
Sleeping
Sleeping
| import json | |
| from pathlib import Path | |
| from typing import Optional | |
| import geobench | |
| import numpy as np | |
| import torch.multiprocessing | |
| from sklearn.utils import shuffle | |
| from torch.utils.data import Dataset | |
| from src.utils import DEFAULT_SEED | |
| from ..preprocess import impute_bands, impute_normalization_stats, normalize_bands | |
| torch.multiprocessing.set_sharing_strategy("file_system") | |
| class GeobenchDataset(Dataset): | |
| """ | |
| Class implementation inspired by: https://github.com/vishalned/MMEarth-train/tree/main | |
| """ | |
| def __init__( | |
| self, | |
| dataset_config_file: str, | |
| split: str, | |
| norm_operation, | |
| augmentation, | |
| partition, | |
| manual_subsetting: Optional[float] = None, | |
| ): | |
| with (Path(__file__).parents[0] / Path("configs") / Path(dataset_config_file)).open( | |
| "r" | |
| ) as f: | |
| config = json.load(f) | |
| assert split in ["train", "valid", "test"] | |
| self.split = split | |
| self.config = config | |
| self.norm_operation = norm_operation | |
| self.augmentation = augmentation | |
| self.partition = partition | |
| if config["task_type"] == "cls": | |
| self.tiles_per_img = 1 | |
| elif config["task_type"] == "seg": | |
| assert self.config["dataset_name"] in ["m-SA-crop-type", "m-cashew-plant"] | |
| # for cashew plant and SA crop type | |
| # images are 256x256, we want 64x64 | |
| self.tiles_per_img = 16 | |
| else: | |
| raise ValueError(f"task_type must be cls or seg, not {config['task_type']}") | |
| for task in geobench.task_iterator(benchmark_name=self.config["benchmark_name"]): | |
| if task.dataset_name == self.config["dataset_name"]: | |
| break | |
| self.dataset = task.get_dataset(split=self.split, partition_name=self.partition) | |
| print( | |
| f"In dataset length for split {split} and partition {partition}: length = {len(self.dataset)}" | |
| ) | |
| original_band_names = [ | |
| self.dataset[0].bands[i].band_info.name for i in range(len(self.dataset[0].bands)) | |
| ] | |
| self.band_names = list(self.config["band_info"].keys()) | |
| self.band_indices = [original_band_names.index(band_name) for band_name in self.band_names] | |
| self.band_info = impute_normalization_stats( | |
| self.config["band_info"], self.config["imputes"] | |
| ) | |
| self.manual_subsetting = manual_subsetting | |
| if self.manual_subsetting is not None: | |
| num_vals_to_keep = int(self.manual_subsetting * len(self.dataset) * self.tiles_per_img) | |
| active_indices = list(range(int(len(self.dataset) * self.tiles_per_img))) | |
| self.active_indices = shuffle( | |
| active_indices, random_state=DEFAULT_SEED, n_samples=num_vals_to_keep | |
| ) | |
| else: | |
| self.active_indices = list(range(int(len(self.dataset) * self.tiles_per_img))) | |
| def __getitem__(self, idx): | |
| dataset_idx = self.active_indices[idx] | |
| img_idx = dataset_idx // self.tiles_per_img # thanks Gabi / Marlena | |
| label = self.dataset[img_idx].label | |
| x = [] | |
| for band_idx in self.band_indices: | |
| x.append(self.dataset[img_idx].bands[band_idx].data) | |
| x = impute_bands(x, self.band_names, self.config["imputes"]) | |
| x = np.stack(x, axis=2) # (h, w, 13) | |
| assert x.shape[-1] == 13, f"All datasets must have 13 channels, not {x.shape[-1]}" | |
| if self.config["dataset_name"] == "m-so2sat": | |
| x = x * 10_000 | |
| x = torch.tensor(normalize_bands(x, self.norm_operation, self.band_info)) | |
| # check if label is an object or a number | |
| if not (isinstance(label, int) or isinstance(label, list)): | |
| label = label.data | |
| # label is a memoryview object, convert it to a list, and then to a numpy array | |
| label = np.array(list(label)) | |
| target = torch.tensor(label, dtype=torch.long) | |
| if self.tiles_per_img == 16: | |
| # thanks Gabi / Marlena | |
| # for cashew plant and SA crop type | |
| subtiles_per_dim = 4 | |
| h = 256 | |
| assert h % subtiles_per_dim == 0 | |
| pixels_per_dim = h // subtiles_per_dim | |
| subtile_idx = idx % self.tiles_per_img | |
| row_idx = subtile_idx // subtiles_per_dim | |
| col_idx = subtile_idx % subtiles_per_dim | |
| x = x[ | |
| row_idx * pixels_per_dim : (row_idx + 1) * pixels_per_dim, | |
| col_idx * pixels_per_dim : (col_idx + 1) * pixels_per_dim, | |
| :, | |
| ] | |
| target = target[ | |
| row_idx * pixels_per_dim : (row_idx + 1) * pixels_per_dim, | |
| col_idx * pixels_per_dim : (col_idx + 1) * pixels_per_dim, | |
| ] | |
| x, target = self.augmentation.apply(x, target, self.config["task_type"]) | |
| return {"s2": x, "target": target} | |
| def __len__(self): | |
| return int(len(self.active_indices)) | |