Spaces:
Sleeping
Sleeping
File size: 3,593 Bytes
b20c769 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 |
import json
import warnings
from pathlib import Path
import numpy as np
import pandas as pd
import torch
from breizhcrops import BreizhCrops
from breizhcrops.datasets.breizhcrops import SELECTED_BANDS
from einops import repeat
from torch.utils.data import ConcatDataset, Dataset
from src.data.config import DATA_FOLDER
from ..preprocess import normalize_bands
# Processing level used to index breizhcrops' SELECTED_BANDS table.
LEVEL = "L1C"

# Directory under the project data folder reserved for this dataset.
DATAPATH = DATA_FOLDER / "breizhcrops"

# Band order of the arrays this module emits.
OUTPUT_BAND_ORDER = [
    "B1", "B2", "B3", "B4", "B5", "B6", "B7",
    "B8", "B8A", "B9", "B10", "B11", "B12",
]  # fmt: skip

# For every output band, its column position in the raw breizhcrops arrays;
# indexing a raw array with this list reorders its columns to OUTPUT_BAND_ORDER.
INPUT_TO_OUTPUT_BAND_MAPPING = list(map(SELECTED_BANDS[LEVEL].index, OUTPUT_BAND_ORDER))
class BreizhCropsDataset(Dataset):
    """BreizhCrops parcel time series wrapped as a torch ``Dataset``.

    Each item is a dict with the keys ``"s2"`` (normalized bands, reordered to
    ``OUTPUT_BAND_ORDER`` and given two leading singleton spatial axes),
    ``"months"`` (the "doa" column of the series) and ``"target"`` (the label
    returned by the underlying BreizhCrops dataset).
    """

    # Region(s) loaded for each accepted split name.
    _SPLIT_REGIONS = {
        "train": ("frh01", "frh02"),
        "valid": ("frh03",),
        "test": ("frh04",),
        # belle-ile is small, so it's useful for testing
        "belle-ile": ("belle-ile",),
    }

    def __init__(
        self,
        path_to_splits: Path,
        split: str,
        norm_operation,
        augmentation,
        partition,
        monthly_average: bool = True,
    ):
        """
        https://isprs-archives.copernicus.org/articles/XLIII-B2-2020/1545/2020/
        isprs-archives-XLIII-B2-2020-1545-2020.pdf
        We partitioned all acquired field parcels
        according to the NUTS-3 regions and suggest to subdivide the
        dataset into training (FRH01, FRH02), validation (FRH03), and
        evaluation (FRH04) subsets based on these spatially distinct
        regions.

        Args:
            path_to_splits: passed to BreizhCrops as its ``root``.
            split: one of "train", "valid", "test" or "belle-ile".
            norm_operation: forwarded unchanged to ``normalize_bands``.
            augmentation: stored on the instance but not applied by this class.
            partition: only "default" is supported.
            monthly_average: if True, each series is reduced to one averaged
                row per calendar month before normalization.

        Raises:
            AssertionError: if ``split`` is not a supported split name.
            NotImplementedError: if ``partition`` is not "default".
        """
        kwargs = {
            "root": path_to_splits,
            "preload_ram": False,
            "level": LEVEL,
            "transform": raw_transform,
        }
        if split not in self._SPLIT_REGIONS:
            # Raised explicitly (instead of a bare `assert`) so the check is not
            # stripped under `python -O`, where an unknown split would otherwise
            # silently load belle-ile. The exception type is unchanged.
            raise AssertionError(
                f"split must be one of {list(self._SPLIT_REGIONS)}, got {split!r}"
            )
        region_datasets = [
            BreizhCrops(region=region, **kwargs) for region in self._SPLIT_REGIONS[split]
        ]
        # Only multi-region splits are wrapped in a ConcatDataset.
        self.ds: Dataset = (
            region_datasets[0] if len(region_datasets) == 1 else ConcatDataset(region_datasets)
        )
        self.monthly_average = monthly_average

        # Band metadata consumed by normalize_bands lives next to this module.
        config_path = Path(__file__).parent / "configs" / "breizhcrops.json"
        with config_path.open("r") as f:
            config = json.load(f)
        self.band_info = config["band_info"]
        self.norm_operation = norm_operation
        self.augmentation = augmentation
        warnings.warn("Augmentations ignored for time series")
        if partition != "default":
            raise NotImplementedError(f"partition {partition} not implemented yet")

    def __len__(self):
        """Return the number of samples in the selected split."""
        return len(self.ds)

    def __getitem__(self, idx):
        """Return sample ``idx`` as a dict of "s2", "months" and "target"."""
        # The third element returned by the underlying dataset is not used here.
        x, y_true, _ = self.ds[idx]
        if self.monthly_average:
            x = self.average_over_month(x)
        eo = normalize_bands(
            x[:, INPUT_TO_OUTPUT_BAND_MAPPING], self.norm_operation, self.band_info
        )
        # Add two leading singleton axes: (t, d) -> (1, 1, t, d).
        eo = repeat(eo, "t d -> h w t d", h=1, w=1)
        # NOTE(review): when monthly_average is False this column still holds the
        # raw "doa" values rather than month indices - confirm that is intended.
        months = x[:, SELECTED_BANDS[LEVEL].index("doa")]
        return {"s2": torch.tensor(eo), "months": torch.tensor(months), "target": y_true}

    @staticmethod
    def average_over_month(x: np.ndarray):
        """Average the rows of ``x`` that fall into the same calendar month.

        The "doa" column is parsed with ``pandas.to_datetime`` and replaced by a
        zero-based month index (January -> 0). One row is returned per month
        present in ``x``, ordered by month, each being the column-wise mean of
        that month's rows; the "doa" column of the result therefore holds the
        month index.

        ``x`` itself is left unmodified, and rows are grouped by month value
        rather than by position, so the result does not depend on the rows
        already being sorted by date.

        Args:
            x: 2-D array of observations (one row per observation) whose "doa"
                column is understood by ``pandas.to_datetime``.

        Returns:
            Array with one averaged row per month present in ``x``.
        """
        doa_col = SELECTED_BANDS[LEVEL].index("doa")
        month_idx = np.asarray(pd.to_datetime(x[:, doa_col]).month) - 1
        # Work on a copy so the caller's array keeps its original "doa" values.
        with_months = np.array(x, copy=True)
        with_months[:, doa_col] = month_idx
        return np.array(
            [with_months[month_idx == month].mean(axis=0) for month in np.unique(month_idx)]
        )
def raw_transform(input_timeseries):
    """Identity transform: hand back the very same object that was passed in.

    Used in this module as the ``transform`` argument of BreizhCrops, so the
    time series reaches the dataset wrapper untouched.
    """
    untouched = input_timeseries
    return untouched
|