Spaces:
Sleeping
Sleeping
| import json | |
| import logging | |
| import math | |
| import os | |
| import warnings | |
| from collections import OrderedDict | |
| from copy import deepcopy | |
| from pathlib import Path | |
| from random import sample | |
| from typing import Dict, Iterator, List, NamedTuple, Optional, Sequence, Tuple, Union, cast | |
| from typing import OrderedDict as OrderedDictType | |
| import h5py | |
| import numpy as np | |
| import rioxarray | |
| import torch | |
| import xarray as xr | |
| from einops import rearrange, repeat | |
| from torch.utils.data import Dataset as PyTorchDataset | |
| from tqdm import tqdm | |
| from .config import ( | |
| DATASET_OUTPUT_HW, | |
| EE_BUCKET_TIFS, | |
| EE_FOLDER_H5PYS, | |
| EE_FOLDER_TIFS, | |
| NUM_TIMESTEPS, | |
| ) | |
| from .earthengine.eo import ( | |
| ALL_DYNAMIC_IN_TIME_BANDS, | |
| DW_BANDS, | |
| DW_DIV_VALUES, | |
| DW_SHIFT_VALUES, | |
| ERA5_BANDS, | |
| LANDSCAN_BANDS, | |
| LOCATION_BANDS, | |
| S1_BANDS, | |
| SPACE_BANDS, | |
| SPACE_DIV_VALUES, | |
| SPACE_SHIFT_VALUES, | |
| SRTM_BANDS, | |
| TC_BANDS, | |
| TIME_BANDS, | |
| TIME_DIV_VALUES, | |
| TIME_SHIFT_VALUES, | |
| VIIRS_BANDS, | |
| WC_BANDS, | |
| WC_DIV_VALUES, | |
| WC_SHIFT_VALUES, | |
| ) | |
| from .earthengine.eo import SPACE_TIME_BANDS as EO_SPACE_TIME_BANDS | |
| from .earthengine.eo import SPACE_TIME_DIV_VALUES as EO_SPACE_TIME_DIV_VALUES | |
| from .earthengine.eo import SPACE_TIME_SHIFT_VALUES as EO_SPACE_TIME_SHIFT_VALUES | |
| from .earthengine.eo import STATIC_BANDS as EO_STATIC_BANDS | |
| from .earthengine.eo import STATIC_DIV_VALUES as EO_STATIC_DIV_VALUES | |
| from .earthengine.eo import STATIC_SHIFT_VALUES as EO_STATIC_SHIFT_VALUES | |
logger = logging.getLogger("__main__")

# Bands that vary in time as exported by Earth Engine (space-time + time-only).
EO_DYNAMIC_IN_TIME_BANDS_NP = np.array(EO_SPACE_TIME_BANDS + TIME_BANDS)

# NDVI is computed locally (not exported), so it is appended to the exported
# space-time bands with an identity normalization (shift 0, div 1).
SPACE_TIME_BANDS = EO_SPACE_TIME_BANDS + ["NDVI"]
SPACE_TIME_SHIFT_VALUES = np.append(EO_SPACE_TIME_SHIFT_VALUES, [0])
SPACE_TIME_DIV_VALUES = np.append(EO_SPACE_TIME_DIV_VALUES, [1])

SPACE_TIME_BANDS_GROUPS_IDX: OrderedDictType[str, List[int]] = OrderedDict(
    {
        "S1": [SPACE_TIME_BANDS.index(b) for b in S1_BANDS],
        "S2_RGB": [SPACE_TIME_BANDS.index(b) for b in ["B2", "B3", "B4"]],
        "S2_Red_Edge": [SPACE_TIME_BANDS.index(b) for b in ["B5", "B6", "B7"]],
        "S2_NIR_10m": [SPACE_TIME_BANDS.index(b) for b in ["B8"]],
        "S2_NIR_20m": [SPACE_TIME_BANDS.index(b) for b in ["B8A"]],
        "S2_SWIR": [SPACE_TIME_BANDS.index(b) for b in ["B11", "B12"]],
        "NDVI": [SPACE_TIME_BANDS.index("NDVI")],
    }
)

TIME_BAND_GROUPS_IDX: OrderedDictType[str, List[int]] = OrderedDict(
    {
        "ERA5": [TIME_BANDS.index(b) for b in ERA5_BANDS],
        "TC": [TIME_BANDS.index(b) for b in TC_BANDS],
        "VIIRS": [TIME_BANDS.index(b) for b in VIIRS_BANDS],
    }
)

SPACE_BAND_GROUPS_IDX: OrderedDictType[str, List[int]] = OrderedDict(
    {
        "SRTM": [SPACE_BANDS.index(b) for b in SRTM_BANDS],
        "DW": [SPACE_BANDS.index(b) for b in DW_BANDS],
        "WC": [SPACE_BANDS.index(b) for b in WC_BANDS],
    }
)

# Per-pixel mean DW / WC probabilities are also exposed as static bands;
# they reuse the dynamic bands' shift/div values.
STATIC_DW_BANDS = [f"{x}_static" for x in DW_BANDS]
STATIC_WC_BANDS = [f"{x}_static" for x in WC_BANDS]
STATIC_BANDS = EO_STATIC_BANDS + STATIC_DW_BANDS + STATIC_WC_BANDS
STATIC_DIV_VALUES = np.append(EO_STATIC_DIV_VALUES, (DW_DIV_VALUES + WC_DIV_VALUES))
STATIC_SHIFT_VALUES = np.append(EO_STATIC_SHIFT_VALUES, (DW_SHIFT_VALUES + WC_SHIFT_VALUES))

STATIC_BAND_GROUPS_IDX: OrderedDictType[str, List[int]] = OrderedDict(
    {
        "LS": [STATIC_BANDS.index(b) for b in LANDSCAN_BANDS],
        "location": [STATIC_BANDS.index(b) for b in LOCATION_BANDS],
        "DW_static": [STATIC_BANDS.index(b) for b in STATIC_DW_BANDS],
        "WC_static": [STATIC_BANDS.index(b) for b in STATIC_WC_BANDS],
    }
)

# The Normalizer keys its dictionaries on the band-list *lengths*, so the four
# lengths must be pairwise distinct. If this changes, the normalizer will need
# to index against something else. NOTE: a chained `a != b != c != d` only
# compares adjacent pairs, so a set-cardinality check is used instead.
assert len({len(SPACE_TIME_BANDS), len(SPACE_BANDS), len(TIME_BANDS), len(STATIC_BANDS)}) == 4
class Normalizer:
    """Shift/divide normalizer for the four modality arrays.

    Arrays are identified purely by the length of their final (band) dimension,
    which is why the module-level assert requires the four band lists to have
    pairwise-distinct lengths.
    """

    # these are the bands we will replace with the 2*std computation
    # if std = True
    std_bands: Dict[int, list] = {
        len(SPACE_TIME_BANDS): [b for b in SPACE_TIME_BANDS if b != "NDVI"],
        len(SPACE_BANDS): SRTM_BANDS,
        len(TIME_BANDS): TIME_BANDS,
        len(STATIC_BANDS): LANDSCAN_BANDS,
    }

    def __init__(
        self, std: bool = True, normalizing_dicts: Optional[Dict] = None, std_multiplier: float = 2
    ):
        """
        std: if True, replace the default shift/div of the bands in std_bands
            with (mean - std_multiplier*std) and a div spanning
            2*std_multiplier*std, computed from normalizing_dicts.
        normalizing_dicts: output of Dataset.compute_normalization_values
            (required when std is True). String keys (e.g. counts) are skipped.
        """
        # deepcopy so the std-based replacement below never mutates the
        # module-level default shift/div arrays
        self.shift_div_dict = {
            len(SPACE_TIME_BANDS): {
                "shift": deepcopy(SPACE_TIME_SHIFT_VALUES),
                "div": deepcopy(SPACE_TIME_DIV_VALUES),
            },
            len(SPACE_BANDS): {
                "shift": deepcopy(SPACE_SHIFT_VALUES),
                "div": deepcopy(SPACE_DIV_VALUES),
            },
            len(TIME_BANDS): {
                "shift": deepcopy(TIME_SHIFT_VALUES),
                "div": deepcopy(TIME_DIV_VALUES),
            },
            len(STATIC_BANDS): {
                "shift": deepcopy(STATIC_SHIFT_VALUES),
                "div": deepcopy(STATIC_DIV_VALUES),
            },
        }
        # was a bare print(); route through the module logger instead
        logger.debug(self.shift_div_dict.keys())
        self.normalizing_dicts = normalizing_dicts
        if std:
            name_to_bands = {
                len(SPACE_TIME_BANDS): SPACE_TIME_BANDS,
                len(SPACE_BANDS): SPACE_BANDS,
                len(TIME_BANDS): TIME_BANDS,
                len(STATIC_BANDS): STATIC_BANDS,
            }
            assert normalizing_dicts is not None
            for key, val in normalizing_dicts.items():
                if isinstance(key, str):
                    # e.g. "total_n" / "sampled_n" bookkeeping entries
                    continue
                bands_to_replace = self.std_bands[key]
                for band in bands_to_replace:
                    band_idx = name_to_bands[key].index(band)
                    mean = val["mean"][band_idx]
                    std = val["std"][band_idx]
                    min_value = mean - (std_multiplier * std)
                    max_value = mean + (std_multiplier * std)
                    div = max_value - min_value
                    if div == 0:
                        raise ValueError(f"{band} has div value of 0")
                    self.shift_div_dict[key]["shift"][band_idx] = min_value
                    self.shift_div_dict[key]["div"][band_idx] = div

    # @staticmethod is required: without it, self._normalize(x, shift, div)
    # passes the instance as `x` and raises a TypeError (4 args for 3 params)
    @staticmethod
    def _normalize(x: np.ndarray, shift_values: np.ndarray, div_values: np.ndarray) -> np.ndarray:
        """Apply (x - shift) / div elementwise along the band dimension."""
        x = (x - shift_values) / div_values
        return x

    def __call__(self, x: np.ndarray):
        """Normalize x, selecting shift/div values by x's final-dim length."""
        div_values = self.shift_div_dict[x.shape[-1]]["div"]
        return self._normalize(x, self.shift_div_dict[x.shape[-1]]["shift"], div_values)
class DatasetOutput(NamedTuple):
    """One (possibly batched) sample of the four modalities plus timestep months.

    Unbatched shapes: space_time_x [H, W, T, D_st], space_x [H, W, D_sp],
    time_x [T, D_t], static_x [D_s], months [T] (0-indexed month per timestep).
    """

    space_time_x: np.ndarray
    space_x: np.ndarray
    time_x: np.ndarray
    static_x: np.ndarray
    months: np.ndarray

    # @classmethod is required: as written, calling this on an instance would
    # bind the instance to `cls` and shift every argument by one
    @classmethod
    def concatenate(cls, datasetoutputs: Sequence["DatasetOutput"]) -> "DatasetOutput":
        """Stack a sequence of outputs along a new leading batch dimension."""
        s_t_x = np.stack([o.space_time_x for o in datasetoutputs], axis=0)
        sp_x = np.stack([o.space_x for o in datasetoutputs], axis=0)
        t_x = np.stack([o.time_x for o in datasetoutputs], axis=0)
        st_x = np.stack([o.static_x for o in datasetoutputs], axis=0)
        months = np.stack([o.months for o in datasetoutputs], axis=0)
        return cls(s_t_x, sp_x, t_x, st_x, months)

    def normalize(self, normalizer: Optional[Normalizer]) -> "DatasetOutput":
        """Return a normalized, float16 copy (months unchanged); no-op if normalizer is None."""
        if normalizer is None:
            return self
        return DatasetOutput(
            normalizer(self.space_time_x).astype(np.half),
            normalizer(self.space_x).astype(np.half),
            normalizer(self.time_x).astype(np.half),
            normalizer(self.static_x).astype(np.half),
            self.months,
        )

    def in_pixel_batches(self, batch_size: int, window_size: int) -> Iterator["DatasetOutput"]:
        """Yield batches of [window_size x window_size] spatial patches.

        H and W must both be divisible by window_size. The non-spatial
        modalities (time_x, static_x, months) are repeated per patch.
        """
        if self.space_time_x.shape[0] % window_size != 0:
            raise ValueError("DatasetOutput height must be divisible by the patch size")
        if self.space_time_x.shape[1] % window_size != 0:
            raise ValueError("DatasetOutput width must be divisible by the patch size")
        # how many batches from the height dimension, how many from the width dimension?
        h_b = self.space_time_x.shape[0] // window_size
        w_b = self.space_time_x.shape[1] // window_size
        flat_s_t_x = rearrange(
            self.space_time_x,
            "(h_b h) (w_b w) t d -> (h_b w_b) h w t d",
            h=window_size,
            w=window_size,
            h_b=h_b,
            w_b=w_b,
        )
        flat_sp_x = rearrange(
            self.space_x,
            "(h_b h) (w_b w) d -> (h_b w_b) h w d",
            h=window_size,
            w=window_size,
            h_b=h_b,
            w_b=w_b,
        )
        # static in space modalities will just get repeated per batch
        cur_idx = 0
        while cur_idx < flat_s_t_x.shape[0]:
            cur_idx_s_t_x = flat_s_t_x[cur_idx : cur_idx + batch_size].copy()
            b = cur_idx_s_t_x.shape[0]
            yield DatasetOutput(
                space_time_x=cur_idx_s_t_x,
                space_x=flat_sp_x[cur_idx : cur_idx + batch_size].copy(),
                time_x=repeat(self.time_x, "t d -> b t d", b=b),
                static_x=repeat(self.static_x, "d -> b d", b=b),
                months=repeat(self.months, "t -> b t", b=b),
            )
            cur_idx += batch_size
class ListOfDatasetOutputs(NamedTuple):
    """Accumulator mirroring DatasetOutput, with one list of arrays per field."""

    space_time_x: List[np.ndarray]
    space_x: List[np.ndarray]
    time_x: List[np.ndarray]
    static_x: List[np.ndarray]
    months: List[np.ndarray]

    def to_datasetoutput(self) -> DatasetOutput:
        """Stack every accumulated list along a new leading (batch) axis."""
        # iterating self yields the fields in declaration order, which matches
        # DatasetOutput's field order exactly
        stacked_fields = (np.stack(field, axis=0) for field in self)
        return DatasetOutput(*stacked_fields)
def to_cartesian(
    lat: Union[float, np.ndarray, torch.Tensor], lon: Union[float, np.ndarray, torch.Tensor]
) -> Union[np.ndarray, torch.Tensor]:
    """Project latitude/longitude (degrees, EPSG:4326) onto the unit sphere.

    Returns (x, y, z) stacked along a trailing axis; the output type mirrors
    the input type (np.ndarray for float/ndarray inputs, torch.Tensor for
    tensor inputs). Raises AssertionError for out-of-range or mismatched
    input types.
    """
    if isinstance(lat, float):
        assert -90 <= lat <= 90, f"lat out of range ({lat}). Make sure you are in EPSG:4326"
        assert -180 <= lon <= 180, f"lon out of range ({lon}). Make sure you are in EPSG:4326"
        assert isinstance(lon, float), f"Expected float got {type(lon)}"
        # transform to radians
        lat_r = lat * math.pi / 180
        lon_r = lon * math.pi / 180
        return np.array(
            [
                math.cos(lat_r) * math.cos(lon_r),
                math.cos(lat_r) * math.sin(lon_r),
                math.sin(lat_r),
            ]
        )

    if isinstance(lon, np.ndarray):
        assert -90 <= lat.min(), f"lat out of range ({lat.min()}). Make sure you are in EPSG:4326"
        assert 90 >= lat.max(), f"lat out of range ({lat.max()}). Make sure you are in EPSG:4326"
        assert -180 <= lon.min(), f"lon out of range ({lon.min()}). Make sure you are in EPSG:4326"
        assert 180 >= lon.max(), f"lon out of range ({lon.max()}). Make sure you are in EPSG:4326"
        assert isinstance(lat, np.ndarray), f"Expected np.ndarray got {type(lat)}"
        # transform to radians
        lat_rad = lat * math.pi / 180
        lon_rad = lon * math.pi / 180
        components = [
            np.cos(lat_rad) * np.cos(lon_rad),
            np.cos(lat_rad) * np.sin(lon_rad),
            np.sin(lat_rad),
        ]
        return np.stack(components, axis=-1)

    if isinstance(lon, torch.Tensor):
        assert -90 <= lat.min(), f"lat out of range ({lat.min()}). Make sure you are in EPSG:4326"
        assert 90 >= lat.max(), f"lat out of range ({lat.max()}). Make sure you are in EPSG:4326"
        assert -180 <= lon.min(), f"lon out of range ({lon.min()}). Make sure you are in EPSG:4326"
        assert 180 >= lon.max(), f"lon out of range ({lon.max()}). Make sure you are in EPSG:4326"
        assert isinstance(lat, torch.Tensor), f"Expected torch.Tensor got {type(lat)}"
        # transform to radians
        lat_rad = lat * math.pi / 180
        lon_rad = lon * math.pi / 180
        components = [
            torch.cos(lat_rad) * torch.cos(lon_rad),
            torch.cos(lat_rad) * torch.sin(lon_rad),
            torch.sin(lat_rad),
        ]
        return torch.stack(components, dim=-1)

    raise AssertionError(f"Unexpected input type {type(lon)}")
class Dataset(PyTorchDataset):
    """PyTorch dataset over exported GeoTIFF tiles, optionally cached as h5py.

    Each item is a DatasetOutput randomly subset to
    [output_hw, output_hw, output_timesteps, ...] and normalized (if a
    Normalizer is given). When h5py_folder is set, full tiles are cached to
    .h5 files on first read; when h5pys_only is set, only the cache is used.
    """

    def __init__(
        self,
        data_folder: Path,
        download: bool = True,
        h5py_folder: Optional[Path] = None,
        h5pys_only: bool = False,
        output_hw: int = DATASET_OUTPUT_HW,
        output_timesteps: int = NUM_TIMESTEPS,
        normalizer: Optional[Normalizer] = None,
    ):
        self.data_folder = data_folder
        self.h5pys_only = h5pys_only
        self.h5py_folder = h5py_folder
        # caching is enabled whenever a cache folder is provided
        self.cache = False
        self.normalizer = normalizer
        if h5py_folder is not None:
            self.cache = True
        if h5pys_only:
            assert h5py_folder is not None, "Can't use h5pys only if there is no cache folder"
            self.tifs: List[Path] = []
            if download:
                self.download_h5pys(h5py_folder)
            self.h5pys = list(h5py_folder.glob("*.h5"))
        else:
            if download:
                self.download_tifs(data_folder)
            self.tifs = []
            tifs = list(data_folder.glob("*.tif")) + list(data_folder.glob("*.tiff"))
            for tif in tifs:
                # keep only files whose names contain a parseable "dates=" start date
                try:
                    _ = self.start_month_from_file(tif)
                    self.tifs.append(tif)
                except IndexError:
                    warnings.warn(f"IndexError for {tif}")
            self.h5pys = []
        self.output_hw = output_hw
        self.output_timesteps = output_timesteps

    def __len__(self) -> int:
        if self.h5pys_only:
            return len(self.h5pys)
        return len(self.tifs)

    @staticmethod
    def download_tifs(data_folder):
        """Sync the exported tifs bucket into data_folder."""
        # Download files (faster than using Python API)
        os.system(f"gcloud storage rsync -r gs://{EE_BUCKET_TIFS}/{EE_FOLDER_TIFS} {data_folder}")

    @staticmethod
    def download_h5pys(data_folder):
        """Sync the cached h5py bucket into data_folder."""
        # Download files (faster than using Python API)
        os.system(f"gcloud storage rsync -r gs://{EE_BUCKET_TIFS}/{EE_FOLDER_H5PYS} {data_folder}")

    @staticmethod
    def return_subset_indices(
        total_h,
        total_w,
        total_t,
        size: int,
        num_timesteps: int,
    ) -> Tuple[int, int, int]:
        """Sample random (start_h, start_w, start_t) for a
        [size, size, num_timesteps] crop of a [total_h, total_w, total_t] tile.

        size must be less than or equal to total_h and total_w, and
        num_timesteps less than or equal to total_t.
        """
        possible_h = total_h - size
        possible_w = total_w - size
        assert (possible_h >= 0) & (possible_w >= 0)
        possible_t = total_t - num_timesteps
        assert possible_t >= 0
        # np.random.choice(n) fails for n == 0, so fall back to the only
        # valid start index (0) when the crop exactly fits
        if possible_h > 0:
            start_h = np.random.choice(possible_h)
        else:
            start_h = possible_h
        if possible_w > 0:
            start_w = np.random.choice(possible_w)
        else:
            start_w = possible_w
        if possible_t > 0:
            start_t = np.random.choice(possible_t)
        else:
            start_t = possible_t
        return start_h, start_w, start_t

    @staticmethod
    def subset_image(
        space_time_x: np.ndarray,
        space_x: np.ndarray,
        time_x: np.ndarray,
        static_x: np.ndarray,
        months: np.ndarray,
        size: int,
        num_timesteps: int,
    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
        """Randomly crop a full-tile sample to [size, size, num_timesteps].

        space_time_x: array of shape [H, W, T, D]
        space_x: array of shape [H, W, D]
        time_x: array of shape [T, D]
        static_x: array of shape [D]
        size must be smaller than or equal to H & W.
        """
        assert (space_time_x.shape[0] == space_x.shape[0]) & (
            space_time_x.shape[1] == space_x.shape[1]
        )
        assert space_time_x.shape[2] == time_x.shape[0]
        # delegate index sampling instead of duplicating it here
        start_h, start_w, start_t = Dataset.return_subset_indices(
            space_time_x.shape[0], space_time_x.shape[1], space_time_x.shape[2], size, num_timesteps
        )
        return (
            space_time_x[
                start_h : start_h + size,
                start_w : start_w + size,
                start_t : start_t + num_timesteps,
            ],
            space_x[start_h : start_h + size, start_w : start_w + size],
            time_x[start_t : start_t + num_timesteps],
            static_x,
            months[start_t : start_t + num_timesteps],
        )

    @staticmethod
    def _fillna(data: np.ndarray, bands_np: np.ndarray):
        """Fill in the missing values in the data array.

        NaNs and infinities are replaced by the per-band mean over the
        spatial (and, for 4D input, temporal) dimensions; bands that are
        entirely missing fall back to 0. 1D/2D inputs are simply zero-filled.
        """
        if data.shape[-1] != len(bands_np):
            raise ValueError(f"Expected data to have {len(bands_np)} bands - got {data.shape[-1]}")
        is_nan_inf = np.isnan(data) | np.isinf(data)
        if not is_nan_inf.any():
            return data
        if len(data.shape) <= 2:
            return np.nan_to_num(data, nan=0)
        if len(data.shape) == 3:
            has_time = False
        elif len(data.shape) == 4:
            has_time = True
        else:
            raise ValueError(
                f"Expected data to be 3D or 4D (x, y, (time), band) - got {data.shape}"
            )
        # treat infinities as NaNs
        data = np.nan_to_num(data, nan=np.nan, posinf=np.nan, neginf=np.nan)
        with warnings.catch_warnings():
            # all-NaN slices emit a RuntimeWarning; the resulting NaN means
            # are zeroed on the next line
            warnings.simplefilter("ignore", category=RuntimeWarning)
            mean_per_time_band = np.nanmean(data, axis=(0, 1))  # t, b or b
        mean_per_time_band = np.nan_to_num(mean_per_time_band, nan=0, posinf=0, neginf=0)
        assert not (np.isnan(mean_per_time_band).any() | np.isinf(mean_per_time_band).any())
        if is_nan_inf.any():
            if has_time:
                means_to_fill = (
                    repeat(
                        np.nanmean(mean_per_time_band, axis=0),
                        "b -> h w t b",
                        h=data.shape[0],
                        w=data.shape[1],
                        t=data.shape[2],
                    )
                    * is_nan_inf
                )
            else:
                means_to_fill = (
                    repeat(mean_per_time_band, "b -> h w b", h=data.shape[0], w=data.shape[1])
                    * is_nan_inf
                )
            data = np.nan_to_num(data, nan=0, posinf=0, neginf=0) + means_to_fill
        return data

    def tif_to_h5py_path(self, tif_path: Path) -> Path:
        """Map a tif path to its cached h5py path (same stem, .h5 suffix)."""
        assert self.h5py_folder is not None
        tif_name = tif_path.stem
        return self.h5py_folder / f"{tif_name}.h5"

    @staticmethod
    def start_month_from_file(tif_path: Path) -> int:
        """Parse the 1-indexed start month from a *dates=YYYY-MM-dd* filename.

        Raises IndexError when the filename contains no parseable date.
        """
        start_date = tif_path.name.partition("dates=")[2][:10]
        start_month = int(start_date.split("-")[1])
        return start_month

    @classmethod
    def month_array_from_file(cls, tif_path: Path, num_timesteps: int) -> np.ndarray:
        """
        Given a filepath and num_timesteps, extract start_month and return an array of
        months where months[idx] is the month for list(range(num_timesteps))[i]
        """
        # assumes all files are exported with filenames including:
        # *dates=<start_date>*, where the start_date is in a YYYY-MM-dd format
        start_month = cls.start_month_from_file(tif_path)
        # >>> np.fmod(np.array([9., 10, 11, 12, 13, 14]), 12)
        # array([ 9., 10., 11., 0., 1., 2.])
        # - 1 because we want to index from 0
        return np.fmod(np.arange(start_month - 1, start_month - 1 + num_timesteps), 12)

    @classmethod
    def _tif_to_array(cls, tif_path: Path) -> DatasetOutput:
        """Read one exported GeoTIFF into a full-tile, float16 DatasetOutput.

        The tif stores [all_combined_bands, H, W] where all_combined_bands is
        every dynamic-in-time band interleaved per timestep, followed by the
        static-in-time bands.
        """
        with cast(xr.Dataset, rioxarray.open_rasterio(tif_path)) as data:
            values = cast(np.ndarray, data.values)
            # tile centre coordinates, used for the cartesian location bands
            lon = np.mean(cast(np.ndarray, data.x)).item()
            lat = np.mean(cast(np.ndarray, data.y)).item()
        # this is a bit hackey but is a unique edge case for locations,
        # which are not part of the exported bands but are instead
        # computed here
        static_bands_in_tif = len(EO_STATIC_BANDS) - len(LOCATION_BANDS)
        num_timesteps = (values.shape[0] - len(SPACE_BANDS) - static_bands_in_tif) / len(
            ALL_DYNAMIC_IN_TIME_BANDS
        )
        assert num_timesteps % 1 == 0, f"{tif_path} has incorrect number of channels"
        dynamic_in_time_x = rearrange(
            values[: -(len(SPACE_BANDS) + static_bands_in_tif)],
            "(t c) h w -> h w t c",
            c=len(ALL_DYNAMIC_IN_TIME_BANDS),
            t=int(num_timesteps),
        )
        dynamic_in_time_x = cls._fillna(dynamic_in_time_x, EO_DYNAMIC_IN_TIME_BANDS_NP)
        space_time_x = dynamic_in_time_x[:, :, :, : -len(TIME_BANDS)]
        # calculate indices, which have shape [h, w, t, 1]
        ndvi = cls.calculate_ndi(space_time_x, band_1="B8", band_2="B4")
        space_time_x = np.concatenate((space_time_x, ndvi), axis=-1)
        # time-only bands are exported per-pixel; collapse space by averaging
        time_x = dynamic_in_time_x[:, :, :, -len(TIME_BANDS) :]
        time_x = np.nanmean(time_x, axis=(0, 1))
        space_x = rearrange(
            values[-(len(SPACE_BANDS) + static_bands_in_tif) : -static_bands_in_tif],
            "c h w -> h w c",
        )
        space_x = cls._fillna(space_x, np.array(SPACE_BANDS))
        static_x = values[-static_bands_in_tif:]
        # add DW_STATIC and WC_STATIC
        dw_bands = space_x[:, :, [i for i, v in enumerate(SPACE_BANDS) if v in DW_BANDS]]
        wc_bands = space_x[:, :, [i for i, v in enumerate(SPACE_BANDS) if v in WC_BANDS]]
        static_x = np.concatenate(
            [
                np.nanmean(static_x, axis=(1, 2)),
                to_cartesian(lat, lon),
                np.nanmean(dw_bands, axis=(0, 1)),
                np.nanmean(wc_bands, axis=(0, 1)),
            ]
        )
        static_x = cls._fillna(static_x, np.array(STATIC_BANDS))
        months = cls.month_array_from_file(tif_path, int(num_timesteps))
        # the no-op try/except AssertionError wrapper was removed; these
        # asserts propagate identically without it
        assert not np.isnan(space_time_x).any(), f"NaNs in s_t_x for {tif_path}"
        assert not np.isnan(space_x).any(), f"NaNs in sp_x for {tif_path}"
        assert not np.isnan(time_x).any(), f"NaNs in t_x for {tif_path}"
        assert not np.isnan(static_x).any(), f"NaNs in st_x for {tif_path}"
        assert not np.isinf(space_time_x).any(), f"Infs in s_t_x for {tif_path}"
        assert not np.isinf(space_x).any(), f"Infs in sp_x for {tif_path}"
        assert not np.isinf(time_x).any(), f"Infs in t_x for {tif_path}"
        assert not np.isinf(static_x).any(), f"Infs in st_x for {tif_path}"
        return DatasetOutput(
            space_time_x.astype(np.half),
            space_x.astype(np.half),
            time_x.astype(np.half),
            static_x.astype(np.half),
            months,
        )

    def _tif_to_array_with_checks(self, idx):
        """Load self.tifs[idx]; on failure, permanently replace it with a
        neighbouring tif (idx+1 for idx 0, else idx-1) and load that instead.

        NOTE: mutates self.tifs, and re-raises if the replacement also fails.
        """
        tif_path = self.tifs[idx]
        try:
            output = self._tif_to_array(tif_path)
            return output
        except Exception as e:
            print(f"Replacing tif {tif_path} due to {e}")
            if idx == 0:
                new_idx = idx + 1
            else:
                new_idx = idx - 1
            self.tifs[idx] = self.tifs[new_idx]
            tif_path = self.tifs[idx]
            output = self._tif_to_array(tif_path)
            return output

    def load_tif(self, idx: int) -> DatasetOutput:
        """Load a random subset of tile idx, reading/writing the h5py cache
        when a cache folder is configured."""
        if self.h5py_folder is None:
            s_t_x, sp_x, t_x, st_x, months = self._tif_to_array_with_checks(idx)
            return DatasetOutput(
                *self.subset_image(
                    s_t_x,
                    sp_x,
                    t_x,
                    st_x,
                    months,
                    size=self.output_hw,
                    num_timesteps=self.output_timesteps,
                )
            )
        else:
            h5py_path = self.tif_to_h5py_path(self.tifs[idx])
            if h5py_path.exists():
                try:
                    return self.read_and_slice_h5py_file(h5py_path)
                except Exception as e:
                    # corrupt cache entry: drop it and rebuild from the tif
                    logger.warning(f"Exception {e} for {self.tifs[idx]}")
                    h5py_path.unlink()
                    s_t_x, sp_x, t_x, st_x, months = self._tif_to_array_with_checks(idx)
                    self.save_h5py(s_t_x, sp_x, t_x, st_x, self.tifs[idx].stem)
                    return DatasetOutput(
                        *self.subset_image(
                            s_t_x, sp_x, t_x, st_x, months, self.output_hw, self.output_timesteps
                        )
                    )
            else:
                s_t_x, sp_x, t_x, st_x, months = self._tif_to_array_with_checks(idx)
                self.save_h5py(s_t_x, sp_x, t_x, st_x, self.tifs[idx].stem)
                return DatasetOutput(
                    *self.subset_image(
                        s_t_x, sp_x, t_x, st_x, months, self.output_hw, self.output_timesteps
                    )
                )

    def save_h5py(self, s_t_x, sp_x, t_x, st_x, tif_stem):
        """Write the full-tile arrays to <h5py_folder>/<tif_stem>.h5.

        Months are not stored; they are re-derived from the filename on read.
        """
        assert self.h5py_folder is not None
        with h5py.File(self.h5py_folder / f"{tif_stem}.h5", "w") as hf:
            hf.create_dataset("s_t_x", data=s_t_x)
            hf.create_dataset("sp_x", data=sp_x)
            hf.create_dataset("t_x", data=t_x)
            hf.create_dataset("st_x", data=st_x)

    @staticmethod
    def calculate_ndi(input_array: np.ndarray, band_1: str, band_2: str) -> np.ndarray:
        r"""
        Given an input array of shape [h, w, t, bands]
        where bands == len(EO_DYNAMIC_IN_TIME_BANDS_NP), returns an array of shape
        [h, w, t, 1] representing NDI,
        (band_1 - band_2) / (band_1 + band_2)
        """
        band_1_np = input_array[:, :, :, EO_SPACE_TIME_BANDS.index(band_1)]
        band_2_np = input_array[:, :, :, EO_SPACE_TIME_BANDS.index(band_2)]
        with warnings.catch_warnings():
            warnings.filterwarnings("ignore", message="invalid value encountered in divide")
            # suppress the following warning
            # RuntimeWarning: invalid value encountered in divide
            # for cases where near_infrared + red == 0
            # since this is handled in the where condition
            return np.expand_dims(
                np.where(
                    (band_1_np + band_2_np) > 0,
                    (band_1_np - band_2_np) / (band_1_np + band_2_np),
                    0,
                ),
                -1,
            )

    def read_and_slice_h5py_file(self, h5py_path: Path):
        """Read a random [output_hw, output_hw, output_timesteps] subset of a
        cached tile; slicing the h5py datasets avoids loading the full tile."""
        with h5py.File(h5py_path, "r") as hf:
            h, w, t, _ = hf["s_t_x"].shape
            start_h, start_w, start_t = self.return_subset_indices(
                h, w, t, self.output_hw, self.output_timesteps
            )
            # the .h5 stem mirrors the tif stem, so the date parse still works
            months = self.month_array_from_file(h5py_path, t)
            output = DatasetOutput(
                hf["s_t_x"][
                    start_h : start_h + self.output_hw,
                    start_w : start_w + self.output_hw,
                    start_t : start_t + self.output_timesteps,
                ],
                hf["sp_x"][
                    start_h : start_h + self.output_hw,
                    start_w : start_w + self.output_hw,
                ],
                hf["t_x"][start_t : start_t + self.output_timesteps],
                hf["st_x"][:],
                months[start_t : start_t + self.output_timesteps],
            )
        return output

    def __getitem__(self, idx):
        if self.h5pys_only:
            return self.read_and_slice_h5py_file(self.h5pys[idx]).normalize(self.normalizer)
        else:
            return self.load_tif(idx).normalize(self.normalizer)

    def process_h5pys(self):
        """Materialize the whole h5py cache by iterating the dataset once."""
        # iterate through the dataset and save it all as h5pys
        assert self.h5py_folder is not None
        assert not self.h5pys_only
        assert self.cache
        for i in tqdm(range(len(self))):
            # loading the tifs also saves them
            # if they don't exist
            _ = self[i]

    @staticmethod
    def load_normalization_values(path: Path):
        """Load a normalization dict from JSON, restoring int keys.

        JSON stringifies the integer band-length keys; every key without an
        "n" (i.e. not "total_n"/"sampled_n") is converted back to int.
        """
        if not path.exists():
            raise ValueError(f"No file found at path {path}")
        with path.open("r") as f:
            norm_dict = json.load(f)
        # we computed the normalizing dict using the same dataset
        output_dict = {}
        for key, val in norm_dict.items():
            if "n" not in key:
                output_dict[int(key)] = val
            else:
                output_dict[key] = val
        return output_dict

    def compute_normalization_values(
        self,
        output_hw: int = 96,
        output_timesteps: int = 24,
        estimate_from: Optional[int] = 10000,
    ):
        """Estimate per-band mean/std for each modality, keyed by band count.

        Temporarily overrides output_hw/output_timesteps, samples up to
        estimate_from items (all items if None), and restores the original
        settings before returning.
        """
        org_hw = self.output_hw
        self.output_hw = output_hw
        org_t = self.output_timesteps
        self.output_timesteps = output_timesteps
        if estimate_from is not None:
            # clamp so random.sample doesn't raise when the dataset is small
            indices_to_sample = sample(list(range(len(self))), k=min(estimate_from, len(self)))
        else:
            indices_to_sample = list(range(len(self)))
        output = ListOfDatasetOutputs([], [], [], [], [])
        for i in tqdm(indices_to_sample):
            s_t_x, sp_x, t_x, st_x, months = self[i]
            # accumulate in float64 to keep the mean/std numerically stable
            output.space_time_x.append(s_t_x.astype(np.float64))
            output.space_x.append(sp_x.astype(np.float64))
            output.time_x.append(t_x.astype(np.float64))
            output.static_x.append(st_x.astype(np.float64))
            output.months.append(months)
        d_o = output.to_datasetoutput()
        norm_dict = {
            "total_n": len(self),
            "sampled_n": len(indices_to_sample),
            len(SPACE_TIME_BANDS): {
                "mean": d_o.space_time_x.mean(axis=(0, 1, 2, 3)).tolist(),
                "std": d_o.space_time_x.std(axis=(0, 1, 2, 3)).tolist(),
            },
            len(SPACE_BANDS): {
                "mean": d_o.space_x.mean(axis=(0, 1, 2)).tolist(),
                "std": d_o.space_x.std(axis=(0, 1, 2)).tolist(),
            },
            len(TIME_BANDS): {
                "mean": d_o.time_x.mean(axis=(0, 1)).tolist(),
                "std": d_o.time_x.std(axis=(0, 1)).tolist(),
            },
            len(STATIC_BANDS): {
                "mean": d_o.static_x.mean(axis=0).tolist(),
                "std": d_o.static_x.std(axis=0).tolist(),
            },
        }
        self.output_hw = org_hw
        self.output_timesteps = org_t
        return norm_dict