Spaces:
Sleeping
Sleeping
| from typing import cast | |
| import torch | |
| from src.data import ( | |
| SPACE_BAND_GROUPS_IDX, | |
| SPACE_TIME_BANDS_GROUPS_IDX, | |
| STATIC_BAND_GROUPS_IDX, | |
| TIME_BAND_GROUPS_IDX, | |
| ) | |
| from src.data.dataset import ( | |
| SPACE_BANDS, | |
| SPACE_TIME_BANDS, | |
| STATIC_BANDS, | |
| TIME_BANDS, | |
| Normalizer, | |
| to_cartesian, | |
| ) | |
| from src.data.earthengine.eo import ( | |
| DW_BANDS, | |
| ERA5_BANDS, | |
| LANDSCAN_BANDS, | |
| LOCATION_BANDS, | |
| S1_BANDS, | |
| S2_BANDS, | |
| SRTM_BANDS, | |
| TC_BANDS, | |
| VIIRS_BANDS, | |
| WC_BANDS, | |
| ) | |
| from src.masking import MaskedOutput | |
# Month index assigned to every timestep when construct_galileo_input is called
# without a `months` tensor. NOTE(review): whether month indices are 0- or
# 1-based is not visible here — confirm before relying on which month this is.
DEFAULT_MONTH: int = 5
def _shared_size(sizes: list[int], default: int, error: str) -> int:
    """Return the single size shared by every entry of ``sizes``.

    Returns ``default`` when ``sizes`` is empty, and raises ``ValueError(error)``
    when the entries disagree.
    """
    if len(sizes) == 0:
        return default
    if not all(size == sizes[0] for size in sizes):
        raise ValueError(error)
    return sizes[0]


def _fill_modality(
    x_out: torch.Tensor,
    mask_out: torch.Tensor,
    x: torch.Tensor,
    all_bands,
    modality_bands,
    all_groups,
    group_key: str,
) -> None:
    """Copy one modality into its band slots and unmask its band groups, in place.

    ``x`` is written into the last-axis positions of ``x_out`` whose band name
    (from ``all_bands``) appears in ``modality_bands``. Every group position in
    ``all_groups`` whose key satisfies ``group_key in key`` is set to 0 in
    ``mask_out``. Indexing with ``...`` makes this work for any rank of target.
    """
    band_indices = [idx for idx, band in enumerate(all_bands) if band in modality_bands]
    group_indices = [idx for idx, key in enumerate(all_groups) if group_key in key]
    x_out[..., band_indices] = x
    mask_out[..., group_indices] = 0


def construct_galileo_input(
    s1: torch.Tensor | None = None,  # [H, W, T, D]
    s2: torch.Tensor | None = None,  # [H, W, T, D]
    era5: torch.Tensor | None = None,  # [T, D]
    tc: torch.Tensor | None = None,  # [T, D]
    viirs: torch.Tensor | None = None,  # [T, D]
    srtm: torch.Tensor | None = None,  # [H, W, D]
    dw: torch.Tensor | None = None,  # [H, W, D]
    wc: torch.Tensor | None = None,  # [H, W, D]
    landscan: torch.Tensor | None = None,  # [D]
    latlon: torch.Tensor | None = None,  # [D]
    months: torch.Tensor | None = None,  # [T]
    normalize: bool = False,
):
    """Build a ``MaskedOutput`` from whichever input modalities are available.

    Tensors for every band of every modality are allocated up front. Bands with
    no input stay zero and their band groups stay masked (mask value 1). The
    band groups of each supplied modality are unmasked (mask value 0).

    Args:
        s1, s2: space-time inputs, shaped [H, W, T, D].
        era5, tc, viirs: time-only inputs, shaped [T, D].
        srtm, dw, wc: space-only inputs, shaped [H, W, D].
        landscan: static input, shaped [D].
        latlon: static input; its first two entries are passed to
            ``to_cartesian`` (presumably latitude and longitude) and the result
            is stored in the location bands.
        months: per-timestep month indices, shaped [T]. Defaults to
            ``DEFAULT_MONTH`` repeated for every timestep.
        normalize: if True, every data tensor is passed through
            ``Normalizer(std=False)``.

    Returns:
        A ``MaskedOutput``. Data and mask tensors are float32 on the inputs'
        device. H, W and T default to 1 when no input carries that axis.

    Raises:
        ValueError: if no input is given, inputs live on different devices,
            inputs disagree on T, H or W, or ``months`` has the wrong length.
    """
    space_time_inputs = [s1, s2]
    time_inputs = [era5, tc, viirs]
    space_inputs = [srtm, dw, wc]
    static_inputs = [landscan, latlon]

    devices = [
        x.device
        for x in space_time_inputs + time_inputs + space_inputs + static_inputs
        if x is not None
    ]
    if len(devices) == 0:
        raise ValueError("At least one input must be not None")
    if not all(devices[0] == device for device in devices):
        raise ValueError("Received tensors on multiple devices")
    device = devices[0]

    present_space_time = [x for x in space_time_inputs if x is not None]
    present_time = [x for x in time_inputs if x is not None]
    present_space = [x for x in space_inputs if x is not None]

    # Check that all input shapes agree. Space-time inputs are [H, W, T, D], so
    # T is axis 2. Time-only inputs are [T, D], so T is axis 0 (axis 1 is the
    # band count, and reading it here would compare bands against timesteps).
    t = _shared_size(
        [x.shape[2] for x in present_space_time] + [x.shape[0] for x in present_time],
        default=1,
        error="Inconsistent number of timesteps per input",
    )
    h = _shared_size(
        [x.shape[0] for x in present_space_time + present_space],
        default=1,
        error="Inconsistent heights per input",
    )
    w = _shared_size(
        [x.shape[1] for x in present_space_time + present_space],
        default=1,
        error="Inconsistent widths per input",
    )

    # Construct the empty input tensors. By default, everything is masked (1).
    s_t_x = torch.zeros((h, w, t, len(SPACE_TIME_BANDS)), dtype=torch.float, device=device)
    s_t_m = torch.ones(
        (h, w, t, len(SPACE_TIME_BANDS_GROUPS_IDX)), dtype=torch.float, device=device
    )
    sp_x = torch.zeros((h, w, len(SPACE_BANDS)), dtype=torch.float, device=device)
    sp_m = torch.ones((h, w, len(SPACE_BAND_GROUPS_IDX)), dtype=torch.float, device=device)
    t_x = torch.zeros((t, len(TIME_BANDS)), dtype=torch.float, device=device)
    t_m = torch.ones((t, len(TIME_BAND_GROUPS_IDX)), dtype=torch.float, device=device)
    st_x = torch.zeros((len(STATIC_BANDS),), dtype=torch.float, device=device)
    st_m = torch.ones((len(STATIC_BAND_GROUPS_IDX),), dtype=torch.float, device=device)

    for x, bands, key in zip(space_time_inputs, [S1_BANDS, S2_BANDS], ["S1", "S2"]):
        if x is not None:
            _fill_modality(
                s_t_x, s_t_m, x, SPACE_TIME_BANDS, bands, SPACE_TIME_BANDS_GROUPS_IDX, key
            )
    for x, bands, key in zip(
        space_inputs, [SRTM_BANDS, DW_BANDS, WC_BANDS], ["SRTM", "DW", "WC"]
    ):
        if x is not None:
            _fill_modality(sp_x, sp_m, x, SPACE_BANDS, bands, SPACE_BAND_GROUPS_IDX, key)
    for x, bands, key in zip(
        time_inputs, [ERA5_BANDS, TC_BANDS, VIIRS_BANDS], ["ERA5", "TC", "VIIRS"]
    ):
        if x is not None:
            _fill_modality(t_x, t_m, x, TIME_BANDS, bands, TIME_BAND_GROUPS_IDX, key)
    for x, bands, key in zip(
        static_inputs, [LANDSCAN_BANDS, LOCATION_BANDS], ["LS", "location"]
    ):
        if x is None:
            continue
        if key == "location":
            # The model stores location as the output of to_cartesian, not raw latlon.
            x = cast(torch.Tensor, to_cartesian(x[0], x[1]))
        _fill_modality(st_x, st_m, x, STATIC_BANDS, bands, STATIC_BAND_GROUPS_IDX, key)

    if months is None:
        months = torch.ones((t,), dtype=torch.long, device=device) * DEFAULT_MONTH
    elif months.shape[0] != t:
        raise ValueError("Incorrect number of input months")

    if normalize:
        normalizer = Normalizer(std=False)
        # Normalizer is applied to numpy arrays. Cast the result back to float32
        # on the original device so the output dtype does not depend on
        # `normalize` (the normalizer's own output dtype is not visible here).
        s_t_x, sp_x, t_x, st_x = (
            torch.from_numpy(normalizer(arr.cpu().numpy())).to(device=device, dtype=torch.float)
            for arr in (s_t_x, sp_x, t_x, st_x)
        )

    return MaskedOutput(
        space_time_x=s_t_x,
        space_time_mask=s_t_m,
        space_x=sp_x,
        space_mask=sp_m,
        time_x=t_x,
        time_mask=t_m,
        static_x=st_x,
        static_mask=st_m,
        months=months,
    )