Spaces:
Sleeping
Sleeping
| import json | |
| import warnings | |
| from pathlib import Path | |
| import numpy as np | |
| import pandas as pd | |
| import torch | |
| from breizhcrops import BreizhCrops | |
| from breizhcrops.datasets.breizhcrops import SELECTED_BANDS | |
| from einops import repeat | |
| from torch.utils.data import ConcatDataset, Dataset | |
| from src.data.config import DATA_FOLDER | |
| from ..preprocess import normalize_bands | |
# Sentinel-2 processing level requested from BreizhCrops.
LEVEL = "L1C"
# The "breizhcrops" subdirectory of the project's DATA_FOLDER
# (not referenced by the code in this module).
DATAPATH = DATA_FOLDER / "breizhcrops"
# Sentinel-2 bands in the order they are emitted in the "s2" output.
OUTPUT_BAND_ORDER = [
    "B1", "B2", "B3", "B4", "B5", "B6", "B7",
    "B8", "B8A", "B9", "B10", "B11", "B12",
]
# For each band in OUTPUT_BAND_ORDER, its column index in the raw BreizhCrops
# array, whose columns follow SELECTED_BANDS[LEVEL].
INPUT_TO_OUTPUT_BAND_MAPPING = list(map(SELECTED_BANDS[LEVEL].index, OUTPUT_BAND_ORDER))
class BreizhCropsDataset(Dataset):
    """Per-parcel Sentinel-2 time series from one BreizhCrops split.

    Each item is a dict with:

    * ``"s2"``: tensor of the ``OUTPUT_BAND_ORDER`` bands after
      ``normalize_bands``, reshaped to ``(1, 1, t, d)`` by the einops pattern
      ``"t d -> h w t d"``;
    * ``"months"``: tensor of the ``doa`` column per time step (a 0-based month
      index when ``monthly_average`` is enabled);
    * ``"target"``: the label returned by the underlying ``BreizhCrops`` dataset.
    """

    # Regions making up each supported split. The train/valid/test assignment
    # follows the NUTS-3 partition suggested in the BreizhCrops paper (see
    # __init__); belle-ile is small, so it's useful for testing.
    _SPLIT_REGIONS = {
        "train": ("frh01", "frh02"),
        "valid": ("frh03",),
        "test": ("frh04",),
        "belle-ile": ("belle-ile",),
    }

    def __init__(
        self,
        path_to_splits: Path,
        split: str,
        norm_operation,
        augmentation,
        partition,
        monthly_average: bool = True,
    ):
        """Open the BreizhCrops region(s) that make up ``split``.

        https://isprs-archives.copernicus.org/articles/XLIII-B2-2020/1545/2020/
        isprs-archives-XLIII-B2-2020-1545-2020.pdf
        We partitioned all acquired field parcels
        according to the NUTS-3 regions and suggest to subdivide the
        dataset into training (FRH01, FRH02), validation (FRH03), and
        evaluation (FRH04) subsets based on these spatially distinct
        regions.

        Args:
            path_to_splits: root directory passed to ``BreizhCrops``.
            split: one of ``"train"``, ``"valid"``, ``"test"`` or ``"belle-ile"``.
            norm_operation: forwarded to ``normalize_bands`` for every item.
            augmentation: stored on the instance but not applied (a warning is
                emitted).
            partition: only ``"default"`` is supported.
            monthly_average: if True, time steps are averaged per calendar month.

        Raises:
            AssertionError: if ``split`` is not a supported split.
            NotImplementedError: if ``partition`` is not ``"default"``.
        """
        assert split in self._SPLIT_REGIONS, f"unknown split {split!r}"
        # Fail fast, before any region data is opened.
        if partition != "default":
            raise NotImplementedError(f"partition {partition} not implemented yet")

        kwargs = {
            "root": path_to_splits,
            "preload_ram": False,
            "level": LEVEL,
            "transform": raw_transform,
        }
        region_datasets = [
            BreizhCrops(region=r, **kwargs) for r in self._SPLIT_REGIONS[split]
        ]
        # Multi-region splits are concatenated; a single region is used as-is.
        self.ds: Dataset = (
            region_datasets[0]
            if len(region_datasets) == 1
            else ConcatDataset(region_datasets)
        )
        self.monthly_average = monthly_average

        config_path = Path(__file__).parent / "configs" / "breizhcrops.json"
        with config_path.open("r") as f:
            config = json.load(f)
        self.band_info = config["band_info"]
        self.norm_operation = norm_operation
        self.augmentation = augmentation
        warnings.warn("Augmentations ignored for time series")

    def __len__(self):
        """Return the number of samples in the underlying dataset."""
        return len(self.ds)

    def __getitem__(self, idx):
        """Return the ``{"s2", "months", "target"}`` dict for sample ``idx``."""
        x, y_true, _ = self.ds[idx]
        if self.monthly_average:
            x = self.average_over_month(x)
        eo = normalize_bands(
            x[:, INPUT_TO_OUTPUT_BAND_MAPPING], self.norm_operation, self.band_info
        )
        # Add singleton spatial dims: (t, d) -> (1, 1, t, d).
        eo = repeat(eo, "t d -> h w t d", h=1, w=1)
        # NOTE(review): without monthly averaging this is the raw ``doa`` column;
        # if it holds date strings torch.tensor will reject it — confirm.
        months = x[:, SELECTED_BANDS[LEVEL].index("doa")]
        return {"s2": torch.tensor(eo), "months": torch.tensor(months), "target": y_true}

    @staticmethod
    def average_over_month(x: np.ndarray) -> np.ndarray:
        """Average all time steps that fall in the same calendar month.

        The ``doa`` column is parsed with ``pandas.to_datetime`` and replaced by a
        0-based month index (January = 0); then every column is averaged over the
        time steps of each month.

        Args:
            x: 2-D array (time steps x columns), columns in
                ``SELECTED_BANDS[LEVEL]`` order; its ``doa`` column must be
                parseable by ``pandas.to_datetime``. It is not modified.

        Returns:
            One row per distinct month present in ``x``, ordered by month index;
            the ``doa`` column of each row holds that month index.
        """
        doa = SELECTED_BANDS[LEVEL].index("doa")
        month_idx = np.asarray(pd.to_datetime(x[:, doa]).month) - 1
        # Copy so the array handed out by the underlying dataset is left intact.
        x = x.copy()
        x[:, doa] = month_idx
        # Mask per month instead of np.split on first-occurrence indices, which
        # silently mis-groups rows when months are not monotonically increasing.
        return np.array([x[month_idx == m].mean(axis=0) for m in np.unique(month_idx)])
def raw_transform(input_timeseries):
    """Identity transform given to ``BreizhCrops`` so samples come back untouched.

    Args:
        input_timeseries: the raw sample produced by the BreizhCrops loader.

    Returns:
        ``input_timeseries`` itself (the same object, not a copy).
    """
    unchanged = input_timeseries
    return unchanged