File size: 3,593 Bytes
b20c769
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
import json
import warnings
from pathlib import Path

import numpy as np
import pandas as pd
import torch
from breizhcrops import BreizhCrops
from breizhcrops.datasets.breizhcrops import SELECTED_BANDS
from einops import repeat
from torch.utils.data import ConcatDataset, Dataset

from src.data.config import DATA_FOLDER

from ..preprocess import normalize_bands

# Value passed as the `level` argument to BreizhCrops and used as the key into
# SELECTED_BANDS to look up the column layout of the loaded time series.
LEVEL = "L1C"
# BreizhCrops folder under the project data folder (not referenced by the
# dataset class in this file, which takes its root path as an argument).
DATAPATH = DATA_FOLDER / "breizhcrops"
# Order in which the band columns are handed to normalize_bands.
OUTPUT_BAND_ORDER = [
    "B1",
    "B2",
    "B3",
    "B4",
    "B5",
    "B6",
    "B7",
    "B8",
    "B8A",
    "B9",
    "B10",
    "B11",
    "B12",
]
# For each entry of OUTPUT_BAND_ORDER, its column position in SELECTED_BANDS[LEVEL].
# Indexing the column axis of a loaded series with this list selects the band
# columns and reorders them to OUTPUT_BAND_ORDER.
INPUT_TO_OUTPUT_BAND_MAPPING = [SELECTED_BANDS[LEVEL].index(b) for b in OUTPUT_BAND_ORDER]


class BreizhCropsDataset(Dataset):
    """BreizhCrops time series wrapped as a dataset of normalized band tensors.

    Each item is a dict with:
      * ``"s2"``: the band columns listed in ``OUTPUT_BAND_ORDER``, normalized
        by ``normalize_bands`` and expanded with singleton ``h`` and ``w`` axes
        (einops pattern ``"t d -> h w t d"``).
      * ``"months"``: the ``doa`` column of the (optionally month-averaged) series.
      * ``"target"``: the label returned by the underlying BreizhCrops dataset.
    """

    def __init__(
        self,
        path_to_splits: Path,
        split: str,
        norm_operation,
        augmentation,
        partition,
        monthly_average: bool = True,
    ):
        """
        https://isprs-archives.copernicus.org/articles/XLIII-B2-2020/1545/2020/
        isprs-archives-XLIII-B2-2020-1545-2020.pdf

        We partitioned all acquired field parcels
        according to the NUTS-3 regions and suggest to subdivide the
        dataset into training (FRH01, FRH02), validation (FRH03), and
        evaluation (FRH04) subsets based on these spatially distinct
        regions.

        Args:
            path_to_splits: passed to ``BreizhCrops`` as its ``root``.
            split: one of ``"train"``, ``"valid"``, ``"test"`` or ``"belle-ile"``.
            norm_operation: forwarded to ``normalize_bands`` for every item.
            augmentation: stored on the instance but not applied (a warning
                is emitted to that effect).
            partition: only ``"default"`` is supported.
            monthly_average: when True, each series is averaged per calendar
                month before normalization.

        Raises:
            NotImplementedError: if ``partition`` is not ``"default"``.
        """
        # belle-ile is small, so it's useful for testing
        assert split in ["train", "valid", "test", "belle-ile"]
        # Reject unsupported partitions up front, before any BreizhCrops region
        # dataset is constructed, so a bad argument fails fast.
        if partition != "default":
            raise NotImplementedError(f"partition {partition} not implemented yet")

        kwargs = {
            "root": path_to_splits,
            "preload_ram": False,
            "level": LEVEL,
            "transform": raw_transform,
        }
        if split == "train":
            # Training spans two regions, exposed as one concatenated dataset.
            self.ds: Dataset = ConcatDataset(
                [BreizhCrops(region=r, **kwargs) for r in ["frh01", "frh02"]]
            )
        elif split == "valid":
            self.ds = BreizhCrops(region="frh03", **kwargs)
        elif split == "test":
            self.ds = BreizhCrops(region="frh04", **kwargs)
        else:
            self.ds = BreizhCrops(region="belle-ile", **kwargs)
        self.monthly_average = monthly_average

        # Band metadata consumed by normalize_bands lives next to this module.
        config_path = Path(__file__).parent / "configs" / "breizhcrops.json"
        with config_path.open("r") as f:
            config = json.load(f)
        self.band_info = config["band_info"]
        self.norm_operation = norm_operation
        self.augmentation = augmentation
        warnings.warn("Augmentations ignored for time series")

    def __len__(self):
        """Return the number of samples in the wrapped dataset."""
        return len(self.ds)

    def __getitem__(self, idx):
        """Load sample ``idx`` and return its ``s2`` / ``months`` / ``target`` dict."""
        x, y_true, _ = self.ds[idx]
        if self.monthly_average:
            x = self.average_over_month(x)
        # Select and reorder the band columns, then normalize them.
        eo = normalize_bands(
            x[:, INPUT_TO_OUTPUT_BAND_MAPPING], self.norm_operation, self.band_info
        )
        # Add singleton spatial axes: (t, d) -> (1, 1, t, d).
        eo = repeat(eo, "t d -> h w t d", h=1, w=1)
        # NOTE(review): with monthly_average=False this column still holds the raw
        # "doa" values rather than month indices - confirm callers expect that.
        months = x[:, SELECTED_BANDS[LEVEL].index("doa")]
        return {"s2": torch.tensor(eo), "months": torch.tensor(months), "target": y_true}

    @staticmethod
    def average_over_month(x: np.ndarray):
        """Average the rows of ``x`` per calendar month.

        The ``doa`` column is parsed with ``pd.to_datetime`` and replaced by a
        zero-based month index; rows are then split where each month value
        first occurs and every chunk is averaged column-wise.

        The input array is left untouched; all work happens on a copy.

        Args:
            x: 2-D array with one row per acquisition, whose columns follow
                ``SELECTED_BANDS[LEVEL]``.

        Returns:
            Array with one row per chunk, each row being the column-wise mean
            of that chunk (its ``doa`` entry is the mean month index).
        """
        doa_idx = SELECTED_BANDS[LEVEL].index("doa")
        # Copy first: overwriting the caller's "doa" column in place would make a
        # second call on the same array parse month indices as timestamps.
        x = np.array(x)
        x[:, doa_idx] = np.array([t.month - 1 for t in pd.to_datetime(x[:, doa_idx])])
        # np.unique returns the first row index of every distinct month value.
        # NOTE(review): assumes rows are in chronological order within a single
        # calendar year, so these indices are increasing - confirm for other data.
        _, first_rows = np.unique(x[:, doa_idx], return_index=True)
        # The first split point is row 0, so the leading chunk is empty; drop it.
        per_month = np.split(x, first_rows)[1:]
        return np.array([chunk.mean(axis=0) for chunk in per_month])


def raw_transform(input_timeseries):
    """Return ``input_timeseries`` unchanged.

    Supplied as the ``transform`` argument when constructing ``BreizhCrops``.
    NOTE(review): presumably this replaces the library's default transform so
    samples come back unmodified - confirm against the BreizhCrops docs.
    """
    return input_timeseries