Spaces:
Sleeping
Sleeping
| from typing import List, NamedTuple, Optional, Tuple | |
| import numpy as np | |
| import torch | |
| from torch.utils.data import default_collate | |
| from src.masking import ( | |
| MASKING_MODES, | |
| MaskingFunctions, | |
| batch_subset_mask_galileo, | |
| ) | |
class CollateFnOutput(NamedTuple):
    """One masked view of a collated batch, as built by ``collated_batch_to_output``.

    The tuple carries four data tensors, one mask tensor per data tensor, the
    ``months`` tensor and the patch size that was used when masking. Field order
    is part of the contract because instances are built and unpacked positionally.
    """

    # Data tensors returned by the masking routine.
    # NOTE(review): the prefixes presumably stand for space-time (s_t), space (sp),
    # time (t) and static (st) inputs -- confirm against src.masking.
    s_t_x: torch.Tensor
    sp_x: torch.Tensor
    t_x: torch.Tensor
    st_x: torch.Tensor

    # Masks returned alongside the data tensors, one per tensor above (same prefix).
    s_t_m: torch.Tensor
    sp_m: torch.Tensor
    t_m: torch.Tensor
    st_m: torch.Tensor

    # Months tensor as returned by the masking routine.
    months: torch.Tensor

    # Patch size used for this view (either the fixed value or the sampled one).
    patch_size: float
def collated_batch_to_output(
    s_t_x: torch.Tensor,
    sp_x: torch.Tensor,
    t_x: torch.Tensor,
    st_x: torch.Tensor,
    months: torch.Tensor,
    patch_sizes,
    shape_time_combinations,
    encode_ratio,
    decode_ratio,
    masking_function: MaskingFunctions,
    augmentation_strategies=None,
    fixed_patch_size=None,
    fixed_space_time_combination=None,
    masking_probabilities=None,
    max_unmasking_channels=4,
    unmasking_channels_combo: str = "shapes",
    ignore_band_groups: Optional[List[str]] = None,
) -> CollateFnOutput:
    """Pick a patch size and a space/time combination, then subset and mask the batch.

    Args:
        s_t_x, sp_x, t_x, st_x, months: Already-collated batch tensors; forwarded
            to ``batch_subset_mask_galileo``.
        patch_sizes: Candidates to sample the patch size from (via
            ``np.random.choice``) when ``fixed_patch_size`` is ``None``.
        shape_time_combinations: Candidates to sample the space/time combination
            from when ``fixed_space_time_combination`` is ``None``. Each entry
            must provide the keys ``"size"`` and ``"timesteps"``.
        encode_ratio, decode_ratio: Forwarded unchanged to the masking routine.
        masking_function: Masking function selector, forwarded unchanged.
        augmentation_strategies: Forwarded unchanged.
        fixed_patch_size: If not ``None``, used instead of sampling a patch size.
        fixed_space_time_combination: If not ``None``, used instead of sampling
            a combination.
        masking_probabilities: Forwarded to the masking routine; defaults to a
            list of ones with one entry per ``MASKING_MODES`` element.
        max_unmasking_channels, unmasking_channels_combo, ignore_band_groups:
            Forwarded unchanged.

    Returns:
        A ``CollateFnOutput`` holding the tensors and masks returned by
        ``batch_subset_mask_galileo`` plus the patch size that was used.
    """
    # The patch size is drawn before the combination; keeping that order keeps
    # seeded runs reproducible.
    patch_size = (
        fixed_patch_size if fixed_patch_size is not None else np.random.choice(patch_sizes)
    )
    combo = (
        fixed_space_time_combination
        if fixed_space_time_combination is not None
        else np.random.choice(shape_time_combinations)
    )

    patches_per_side = combo["size"]
    # Clamp so that patches_per_side * patch_size does not exceed s_t_x.shape[1]
    # (presumably the spatial side length of the input -- confirm).
    if int(patches_per_side * patch_size) > s_t_x.shape[1]:
        patches_per_side = int(s_t_x.shape[1] / patch_size)
    num_timesteps = combo["timesteps"]
    side_length = patch_size * patches_per_side

    if masking_probabilities is None:
        masking_probabilities = [1] * len(MASKING_MODES)

    # Subset the batch to the chosen extent and apply the requested masking.
    (
        out_s_t_x,
        out_sp_x,
        out_t_x,
        out_st_x,
        out_s_t_m,
        out_sp_m,
        out_t_m,
        out_st_m,
        out_months,
    ) = batch_subset_mask_galileo(
        s_t_x,
        sp_x,
        t_x,
        st_x,
        months,
        encode_ratio=encode_ratio,
        patch_size=patch_size,
        image_size=side_length,
        num_timesteps=num_timesteps,
        decode_ratio=decode_ratio,
        augmentation_strategies=augmentation_strategies,
        masking_probabilities=masking_probabilities,
        masking_function=masking_function,
        max_unmasking_channels=max_unmasking_channels,
        unmasking_channels_combo=unmasking_channels_combo,
        ignore_band_groups=ignore_band_groups,
    )

    return CollateFnOutput(
        s_t_x=out_s_t_x,
        sp_x=out_sp_x,
        t_x=out_t_x,
        st_x=out_st_x,
        s_t_m=out_s_t_m,
        sp_m=out_sp_m,
        t_m=out_t_m,
        st_m=out_st_m,
        months=out_months,
        patch_size=patch_size,
    )
def _require_ratios(mode: str, **ratios) -> None:
    """Raise ``ValueError`` for the first ratio (in keyword order) that is ``None``."""
    for name, value in ratios.items():
        if value is None:
            raise ValueError(f"{name} can't be None for random_masking='{mode}'")


def galileo_collate_fn(
    batch,
    patch_sizes,
    shape_time_combinations,
    st_encode_ratio=None,
    st_decode_ratio=None,
    random_encode_ratio=None,
    random_decode_ratio=None,
    augmentation_strategies=None,
    fixed_patch_size=None,
    fixed_space_time_combination=None,
    masking_probabilities=None,
    max_unmasking_channels=4,
    random_masking: str = "None",
    unmasking_channels_combo: str = "shapes",
    ignore_band_groups: Optional[List[str]] = None,
) -> Tuple[CollateFnOutput, CollateFnOutput, CollateFnOutput, CollateFnOutput]:
    """Collate a batch and return four independently masked views of it.

    The batch is stacked with ``default_collate`` and must yield exactly five
    entries (``s_t_x, sp_x, t_x, st_x, months``). Each of the four views comes
    from its own call to ``collated_batch_to_output``, so the patch size and
    space/time combination are sampled per view unless the ``fixed_*``
    arguments pin them.

    ``random_masking`` selects the masking functions and ratios of the views
    and is matched case-insensitively, so the declared default ``"None"``
    behaves like ``"none"``:

    * ``"none"``: TIME, SPACE, TIME, SPACE, all with the ``st_*`` ratios.
    * ``"half"``: TIME, SPACE with the ``st_*`` ratios, then RANDOM, RANDOM
      with the ``random_*`` ratios.
    * ``"full"``: four RANDOM views with the ``random_*`` ratios.

    Args:
        batch: Samples to collate.
        patch_sizes, shape_time_combinations, augmentation_strategies,
        fixed_patch_size, fixed_space_time_combination, masking_probabilities,
        max_unmasking_channels, unmasking_channels_combo, ignore_band_groups:
            Forwarded unchanged to ``collated_batch_to_output``.
        st_encode_ratio, st_decode_ratio: Ratios for the TIME/SPACE views;
            required for the ``"none"`` and ``"half"`` modes.
        random_encode_ratio, random_decode_ratio: Ratios for the RANDOM views;
            required for the ``"half"`` and ``"full"`` modes.
        random_masking: One of ``"none"``, ``"half"`` or ``"full"`` (any case).

    Returns:
        A 4-tuple of ``CollateFnOutput``.

    Raises:
        ValueError: If a ratio required by the selected mode is ``None``, or if
            ``random_masking`` is not a recognised mode.
    """
    s_t_x, sp_x, t_x, st_x, months = default_collate(batch)

    # Arguments shared by all four views. This dict is never mutated; the
    # per-view ratios and masking function are passed alongside it.
    shared_args = {
        "s_t_x": s_t_x,
        "sp_x": sp_x,
        "t_x": t_x,
        "st_x": st_x,
        "months": months,
        "patch_sizes": patch_sizes,
        "augmentation_strategies": augmentation_strategies,
        "fixed_patch_size": fixed_patch_size,
        "fixed_space_time_combination": fixed_space_time_combination,
        "masking_probabilities": masking_probabilities,
        "shape_time_combinations": shape_time_combinations,
        "max_unmasking_channels": max_unmasking_channels,
        "unmasking_channels_combo": unmasking_channels_combo,
        "ignore_band_groups": ignore_band_groups,
    }

    # The signature default is "None" while the modes are lower-case; match
    # case-insensitively so the default selects the "none" mode instead of
    # falling through to the unknown-mode error.
    mode = random_masking.lower() if isinstance(random_masking, str) else random_masking

    st_ratios = (st_encode_ratio, st_decode_ratio)
    random_ratios = (random_encode_ratio, random_decode_ratio)

    # Validate before touching MaskingFunctions so bad configuration fails fast.
    if mode == "none":
        _require_ratios(
            mode, st_encode_ratio=st_encode_ratio, st_decode_ratio=st_decode_ratio
        )
        plan = [
            (MaskingFunctions.TIME, st_ratios),
            (MaskingFunctions.SPACE, st_ratios),
            (MaskingFunctions.TIME, st_ratios),
            (MaskingFunctions.SPACE, st_ratios),
        ]
    elif mode == "half":
        _require_ratios(
            mode,
            st_encode_ratio=st_encode_ratio,
            st_decode_ratio=st_decode_ratio,
            random_encode_ratio=random_encode_ratio,
            random_decode_ratio=random_decode_ratio,
        )
        plan = [
            (MaskingFunctions.TIME, st_ratios),
            (MaskingFunctions.SPACE, st_ratios),
            (MaskingFunctions.RANDOM, random_ratios),
            (MaskingFunctions.RANDOM, random_ratios),
        ]
    elif mode == "full":
        _require_ratios(
            mode,
            random_encode_ratio=random_encode_ratio,
            random_decode_ratio=random_decode_ratio,
        )
        plan = [(MaskingFunctions.RANDOM, random_ratios)] * 4
    else:
        raise ValueError(
            f"Expected random_masking to be (none, half, full), got {random_masking}"
        )

    # Views are built strictly in plan order so RNG consumption is deterministic.
    view_1, view_2, view_3, view_4 = (
        collated_batch_to_output(
            **shared_args,
            encode_ratio=encode_ratio,
            decode_ratio=decode_ratio,
            masking_function=masking_function,
        )
        for masking_function, (encode_ratio, decode_ratio) in plan
    )
    return view_1, view_2, view_3, view_4