# type: ignore
import math
from argparse import Namespace
from collections import OrderedDict
from pathlib import Path
from typing import AnyStr, Dict, List, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from timm.models.layers import DropPath, trunc_normal_
from torch import Tensor
PIXEL_WISE_MODALITIES = [
    "sentinel2",
    "sentinel1",
    "aster",
    "canopy_height_eth",
    "esa_worldcover",
    "dynamic_world",
]

# Input modalities for training
INP_MODALITIES = {
    "sentinel2": [
        "B1",
        "B2",
        "B3",
        "B4",
        "B5",
        "B6",
        "B7",
        "B8A",
        "B8",
        "B9",
        "B11",
        "B12",
    ],
}

# Output modalities for training
OUT_MODALITIES = {
    "sentinel2": [
        "B1",
        "B2",
        "B3",
        "B4",
        "B5",
        "B6",
        "B7",
        "B8A",
        "B8",
        "B9",
        "B11",
        "B12",
    ],
    "sentinel1": "all",
    "aster": "all",
    "era5": "all",
    "dynamic_world": "all",
    "canopy_height_eth": "all",
    "lat": "all",
    "lon": "all",
    "biome": "all",
    "eco_region": "all",
    "month": "all",
    "esa_worldcover": "all",
}
# an example of all the modalities. DO NOT CHANGE THIS, ALWAYS CHANGE THE INP and OUT MODALITIES ABOVE
MODALITIES_FULL = {
    "sentinel2": [
        "B1",
        "B2",
        "B3",
        "B4",
        "B5",
        "B6",
        "B7",
        "B8A",
        "B8",
        "B9",
        "B10",
        "B11",
        "B12",
    ],
    "sentinel2_cloudmask": ["QA60"],
    "sentinel2_cloudprod": ["MSK_CLDPRB"],
    "sentinel2_scl": ["SCL"],
    "sentinel1": [
        "asc_VV",
        "asc_VH",
        "asc_HH",
        "asc_HV",
        "desc_VV",
        "desc_VH",
        "desc_HH",
        "desc_HV",
    ],
    "aster": ["elevation", "slope"],
    "era5": [
        "prev_month_avg_temp",
        "prev_month_min_temp",
        "prev_month_max_temp",
        "prev_month_total_precip",
        "curr_month_avg_temp",
        "curr_month_min_temp",
        "curr_month_max_temp",
        "curr_month_total_precip",
        "year_avg_temp",
        "year_min_temp",
        "year_max_temp",
        "year_total_precip",
    ],
    "dynamic_world": ["landcover"],
    "canopy_height_eth": ["height", "std"],
    "lat": ["sin", "cos"],
    "lon": ["sin", "cos"],
    "biome": ["biome"],
    "eco_region": ["eco_region"],
    "month": ["sin_month", "cos_month"],
    "esa_worldcover": ["map"],
}


class MMEarthWrapper(nn.Module):
    def __init__(
        self, weights_path: Path, size="atto", do_pool=True, temporal_pooling: str = "mean"
    ):
        super().__init__()
        if size == "atto":
            self.dim = 320
            check = weights_path / "mmearth-atto-checkpoint-199.pth"
            checkpoint = torch.load(check, map_location="cpu")
            weights = remap_checkpoint_keys(checkpoint["model"])
            args = Namespace(
                checkpoint_dir=check,
                random_crop=True,
                random_crop_size=112,
                patch_size=16,
                loss_aggr="uncertainty",
                use_orig_stem=False,
                mask_ratio=0.6,
                linear_probe=False,
            )
            args.inp_modalities = INP_MODALITIES
            args.out_modalities = OUT_MODALITIES
            args.modalities = args.inp_modalities.copy()
            args.modalities.update(args.out_modalities)
            args.modalities_full = MODALITIES_FULL
            model = convnextv2_atto(
                mask_ratio=args.mask_ratio,
                decoder_depth=1,
                decoder_embed_dim=512,
                norm_pix_loss=True,
                patch_size=args.patch_size,
                img_size=args.random_crop_size,
                args=args,
            )
            self.encoder = model.encoder
            self.encoder.load_state_dict(weights, strict=False)
            self.image_resolution = 112
            self.grid_size = 7
        elif size == "tiny":
            self.dim = 768
            check = weights_path / "mmearth-tiny-checkpoint-199.pth"
            checkpoint = torch.load(check, map_location="cpu")
            weights = remap_checkpoint_keys(checkpoint["model"])
            args = Namespace(
                checkpoint_dir=check,
                random_crop=True,
                random_crop_size=56,
                patch_size=8,
                loss_aggr="uncertainty",
                use_orig_stem=False,
                mask_ratio=0.6,
                linear_probe=False,
            )
            args.inp_modalities = INP_MODALITIES
            args.out_modalities = OUT_MODALITIES
            args.modalities = args.inp_modalities.copy()
            args.modalities.update(args.out_modalities)
            args.modalities_full = MODALITIES_FULL
            model = convnextv2_tiny(
                mask_ratio=args.mask_ratio,
                decoder_depth=1,
                decoder_embed_dim=512,
                norm_pix_loss=True,
                patch_size=args.patch_size,
                img_size=args.random_crop_size,
                args=args,
            )
            self.encoder = model.encoder
            self.encoder.load_state_dict(weights, strict=False)
            self.image_resolution = 56
            self.grid_size = 6
        else:
            raise ValueError(f"size must be atto or tiny, not {size}")
        self.do_pool = do_pool
        if temporal_pooling not in ["mean", "max"]:
            raise ValueError(
                f"Expected temporal_pooling to be in ['mean', 'max'], got {temporal_pooling}"
            )
        self.temporal_pooling = temporal_pooling

    def resize(self, images):
        images = F.interpolate(
            images,
            size=(self.image_resolution, self.image_resolution),
            mode="bilinear",
            align_corners=False,
        )
        return images

    def preprocess(self, images):
        if len(images.shape) == 5:
            raise ValueError(f"Unexpected input shape {images.shape}")
        images = rearrange(images, "b h w c -> b c h w")
        assert images.shape[1] == 13
        # MMEarth does not use B10 as input
        remove_idx = 10
        images = torch.cat(
            [images[:, :remove_idx, :, :], images[:, (remove_idx + 1) :, :, :]], dim=1
        )
        assert images.shape[1] == 12
        return self.resize(images)  # (bsz, 12, image_resolution, image_resolution)

    def forward(self, s2=None, s1=None, months=None):
        if s2 is None:
            raise ValueError("S2 can't be None for MMEarth")
        if len(s2.shape) == 5:
            outputs_l: List[torch.Tensor] = []
            for timestep in range(s2.shape[3]):
                image = self.preprocess(s2[:, :, :, timestep])
                output = self.encoder(image)
                # output shape for atto: (bsz, 320, 7, 7)
                # output shape for tiny: (bsz, 768, 6, 6)
                if self.do_pool:
                    output = output.mean(dim=-1).mean(dim=-1)
                else:
                    output = rearrange(output, "b c h w -> b (h w) c")
                outputs_l.append(output)
            outputs_t = torch.stack(outputs_l, dim=-1)  # (bsz, dim, t) if pooled, else (bsz, seq_len, dim, t)
            if self.temporal_pooling == "mean":
                return outputs_t.mean(dim=-1)
            else:
                return torch.amax(outputs_t, dim=-1)
        else:
            s2 = self.preprocess(s2)
            output = self.encoder(s2)
            if self.do_pool:
                return output.mean(dim=-1).mean(dim=-1)  # (bsz, dim)
            else:
                return rearrange(output, "b c h w -> b (h w) c")  # (bsz, seq_len, dim)
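

# Minimal usage sketch (as comments; assumes the MMEarth checkpoint files exist under the
# given `weights_path`). Sentinel-2 inputs are channels-last with all 13 bands including B10,
# which preprocess() drops before resizing:
#
#   wrapper = MMEarthWrapper(weights_path=Path("checkpoints"), size="atto", do_pool=True)
#   s2 = torch.randn(2, 120, 120, 13)         # (bsz, H, W, 13) single timestep
#   feats = wrapper(s2=s2)                    # (bsz, 320)
#   s2_ts = torch.randn(2, 120, 120, 4, 13)   # (bsz, H, W, T, 13) time series
#   feats_ts = wrapper(s2=s2_ts)              # (bsz, 320) after temporal pooling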


def remap_checkpoint_keys(ckpt):
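    """Remap keys/shapes from an MMEarth FCMAE checkpoint to this dense encoder's state_dict.

    Illustrative example (hypothetical key name; shapes follow the reshapes below): a
    depthwise-conv entry "encoder.stages.0.0.dwconv.kernel" with a value of shape
    (49, 320) becomes "stages.0.0.dwconv.weight" with shape (320, 1, 7, 7), matching
    nn.Conv2d(dim, dim, kernel_size=7, groups=dim).
    """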
    new_ckpt = OrderedDict()
    for k, v in ckpt.items():
        if k.startswith("encoder"):
            k = ".".join(k.split(".")[1:])  # remove encoder in the name
        if k.endswith("kernel"):
            k = ".".join(k.split(".")[:-1])  # remove kernel in the name
            new_k = k + ".weight"
            if len(v.shape) == 3:  # reshape standard convolution
                kv, in_dim, out_dim = v.shape
                ks = int(math.sqrt(kv))
                new_ckpt[new_k] = (
                    v.permute(2, 1, 0).reshape(out_dim, in_dim, ks, ks).transpose(3, 2)
                )
            elif len(v.shape) == 2:  # reshape depthwise convolution
                kv, dim = v.shape
                ks = int(math.sqrt(kv))
                new_ckpt[new_k] = v.permute(1, 0).reshape(dim, 1, ks, ks).transpose(3, 2)
            continue
        elif "ln" in k or "linear" in k:
            k = k.split(".")
            k.pop(-2)  # remove ln and linear in the name
            new_k = ".".join(k)
        elif "backbone.resnet" in k:
            # sometimes the resnet model is saved with the prefix backbone.resnet
            # we need to remove this prefix
            new_k = k.split("backbone.resnet.")[1]
        else:
            new_k = k
        new_ckpt[new_k] = v

    # reshape grn affine parameters and biases
    for k, v in new_ckpt.items():
        if k.endswith("bias") and len(v.shape) != 1:
            new_ckpt[k] = v.reshape(-1)
        elif "grn" in k:
            new_ckpt[k] = v.unsqueeze(0).unsqueeze(1)
    return new_ckpt


class LayerNorm(nn.Module):
    """LayerNorm that supports two data formats: channels_last (default) or channels_first.
    channels_last corresponds to inputs with shape (batch_size, height, width, channels),
    while channels_first corresponds to inputs with shape (batch_size, channels, height, width).
    """

    def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(normalized_shape))
        self.bias = nn.Parameter(torch.zeros(normalized_shape))
        self.eps = eps
        self.data_format = data_format
        if self.data_format not in ["channels_last", "channels_first"]:
            raise NotImplementedError
        self.normalized_shape = (normalized_shape,)

    def forward(self, x):
        if self.data_format == "channels_last":
            return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
        elif self.data_format == "channels_first":
            u = x.mean(1, keepdim=True)
            s = (x - u).pow(2).mean(1, keepdim=True)
            x = (x - u) / torch.sqrt(s + self.eps)
            x = self.weight[:, None, None] * x + self.bias[:, None, None]
            return x


class GRN(nn.Module):
    """GRN (Global Response Normalization) layer"""

    def __init__(self, dim):
        super().__init__()
        self.gamma = nn.Parameter(torch.zeros(1, 1, 1, dim))
        self.beta = nn.Parameter(torch.zeros(1, 1, 1, dim))

    def forward(self, x):
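        # x is channels-last here (Block permutes to NHWC before calling GRN).
        # Gx: per-channel global L2 norm over the spatial dims, shape (N, 1, 1, C).
        # Nx: each channel's norm divided by the mean norm across channels.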
        Gx = torch.norm(x, p=2, dim=(1, 2), keepdim=True)
        Nx = Gx / (Gx.mean(dim=-1, keepdim=True) + 1e-4)
        return self.gamma * (x * Nx) + self.beta + x


class Block(nn.Module):
    """ConvNeXtV2 Block.

    Args:
        dim (int): Number of input channels.
        drop_path (float): Stochastic depth rate. Default: 0.0
    """

    def __init__(self, dim, drop_path=0.0):
        super().__init__()
        self.dwconv: nn.Module = nn.Conv2d(
            dim, dim, kernel_size=7, padding=3, groups=dim
        )  # depth-wise conv
        self.norm: nn.Module = LayerNorm(dim, eps=1e-6)
        self.pwconv1: nn.Module = nn.Linear(
            dim, 4 * dim
        )  # point-wise/1x1 convs, implemented with linear layers
        self.act: nn.Module = nn.GELU()
        self.grn: nn.Module = GRN(4 * dim)
        self.pwconv2: nn.Module = nn.Linear(4 * dim, dim)
        self.drop_path: nn.Module = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

    def forward(self, x: Tensor) -> Tensor:
        input = x
        x = self.dwconv(x)
        x = x.permute(0, 2, 3, 1)  # (N, C, H, W) -> (N, H, W, C)
        x = self.norm(x)
        x = self.pwconv1(x)
        x = self.act(x)
        x = self.grn(x)
        x = self.pwconv2(x)
        x = x.permute(0, 3, 1, 2)  # (N, H, W, C) -> (N, C, H, W)
        x = input + self.drop_path(x)
        return x


class ConvNeXtV2(nn.Module):
    """ConvNeXt V2.

    Args:
        in_chans (int): Number of input image channels. Default: 3
        num_classes (int): Number of classes for classification head. Default: 1000
        depths (list[int]): Number of blocks at each stage. Default: [3, 3, 9, 3]
        dims (list[int]): Feature dimension at each stage. Default: [96, 192, 384, 768]
        drop_path_rate (float): Stochastic depth rate. Default: 0.
        head_init_scale (float): Init scaling value for classifier weights and biases. Default: 1.
    """

    def __init__(
        self,
        patch_size: int = 32,
        img_size: int = 128,
        in_chans: int = 3,
        num_classes: int = 1000,
        depths: Optional[list[int]] = None,
        dims: Optional[list[int]] = None,
        drop_path_rate: float = 0.0,
        head_init_scale: float = 1.0,
        use_orig_stem: bool = False,
        args: Optional[Namespace] = None,
    ):
        super().__init__()
        self.depths = depths
        if self.depths is None:  # set default value
            self.depths = [3, 3, 9, 3]
        depths = self.depths
        self.img_size = img_size
        self.use_orig_stem = use_orig_stem
        self.num_stage = len(depths)
        self.downsample_layers = nn.ModuleList()  # stem and 3 intermediate downsampling conv layers
        self.patch_size = patch_size
        if dims is None:
            dims = [96, 192, 384, 768]
        if self.use_orig_stem:
            self.stem_orig = nn.Sequential(
                nn.Conv2d(
                    in_chans,
                    dims[0],
                    kernel_size=patch_size // (2 ** (self.num_stage - 1)),
                    stride=patch_size // (2 ** (self.num_stage - 1)),
                ),
                LayerNorm(dims[0], eps=1e-6, data_format="channels_first"),
            )
        else:
            self.initial_conv = nn.Sequential(
                nn.Conv2d(in_chans, dims[0], kernel_size=3, stride=1),
                LayerNorm(dims[0], eps=1e-6, data_format="channels_first"),
                nn.GELU(),
            )
            # depthwise conv for stem
            self.stem = nn.Sequential(
                nn.Conv2d(
                    dims[0],
                    dims[0],
                    kernel_size=patch_size // (2 ** (self.num_stage - 1)),
                    stride=patch_size // (2 ** (self.num_stage - 1)),
                    padding=(patch_size // (2 ** (self.num_stage - 1))) // 2,
                    groups=dims[0],
                ),
                LayerNorm(dims[0], eps=1e-6, data_format="channels_first"),
            )
        for i in range(3):
            downsample_layer = nn.Sequential(
                LayerNorm(dims[i], eps=1e-6, data_format="channels_first"),
                nn.Conv2d(dims[i], dims[i + 1], kernel_size=2, stride=2),
            )
            self.downsample_layers.append(downsample_layer)

        # 4 feature resolution stages, each consisting of multiple residual blocks
        self.stages = nn.ModuleList()
        dp_rates = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
        cur = 0
        for i in range(self.num_stage):
            stage = nn.Sequential(
                *[Block(dim=dims[i], drop_path=dp_rates[cur + j]) for j in range(depths[i])]
            )
            self.stages.append(stage)
            cur += depths[i]
        self.norm = nn.LayerNorm(dims[-1], eps=1e-6)  # final norm layer
        self.head = nn.Linear(dims[-1], num_classes)

        self.apply(self._init_weights)
        self.head.weight.data.mul_(head_init_scale)
        self.head.bias.data.mul_(head_init_scale)

    def _init_weights(self, m):
        if isinstance(m, (nn.Conv2d, nn.Linear)):
            trunc_normal_(m.weight, std=0.02)
            nn.init.constant_(m.bias, 0)

    def forward_features(self, x):
        if self.use_orig_stem:
            x = self.stem_orig(x)
        else:
            x = self.initial_conv(x)
            x = self.stem(x)
        x = self.stages[0](x)
        for i in range(3):
            x = self.downsample_layers[i](x)
            x = self.stages[i + 1](x)
        return x  # pool with wrapper

    def upsample_mask(self, mask, scale):
        assert len(mask.shape) == 2
        p = int(mask.shape[1] ** 0.5)
        return (
            mask.reshape(-1, p, p)
            .repeat_interleave(scale, dim=1)
            .repeat_interleave(scale, dim=2)
        )
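
    # Note: upsample_mask above is unused in this extract; forward() below ignores the mask.
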
    def forward(self, x: Tensor, mask: Optional[Tensor] = None) -> Tensor:
        # no masking
        return self.forward_features(x)


class FCMAE(nn.Module):
    """Fully Convolutional Masked Autoencoder with ConvNeXtV2 backbone"""

    def __init__(
        self,
        img_size: int = 112,
        depths: Optional[list[int]] = None,
        dims: Optional[list[int]] = None,
        decoder_depth: int = 1,
        decoder_embed_dim: int = 512,
        patch_size: int = 16,
        mask_ratio: float = 0.6,
        norm_pix_loss: bool = False,
        args: Optional[Namespace] = None,
        loss_fn=None,
        sparse: bool = True,
    ):
        super().__init__()
        print("using the multi-modal fcmae model")
        # configs
        self.args = args
        self.img_size = img_size
        if depths is None:  # set default value
            depths = [3, 3, 9, 3]
        self.depths = depths
        if dims is None:
            dims = [96, 192, 384, 768]
        self.dims = dims
        self.patch_size = patch_size
        self.mask_ratio = mask_ratio
        self.num_patches = (img_size // patch_size) ** 2
        self.decoder_embed_dim = decoder_embed_dim
        self.decoder_depth = decoder_depth
        self.norm_pix_loss = norm_pix_loss
        self.loss_fn = loss_fn
        self.sparse = sparse
        self.in_chans = (
            len(args.modalities["sentinel2"])
            if args.modalities["sentinel2"] != "all"
            else len(args.modalities_full["sentinel2"])
        )
        self.out_chans = {}
        for modality in self.args.modalities.keys():
            if modality in ["sentinel2", "sentinel1", "aster", "canopy_height_eth"]:
                # all the continuous pixel level modalities
                if self.args.modalities[modality] == "all":
                    self.out_chans[modality] = len(self.args.modalities_full[modality])
                else:
                    self.out_chans[modality] = len(self.args.modalities[modality])
            elif modality == "biome":
                self.out_chans[modality] = 14  # 14 biomes
            elif modality == "eco_region":
                self.out_chans[modality] = 846  # 846 eco regions
            elif modality in ["lat", "lon", "month", "era5"]:
                if self.args.modalities[modality] == "all":
                    self.out_chans[modality] = len(self.args.modalities_full[modality])
                else:
                    self.out_chans[modality] = len(self.args.modalities[modality])
            elif modality == "esa_worldcover":
                self.out_chans[modality] = 11  # 11 classes for esa worldcover
            elif modality == "dynamic_world":
                self.out_chans[modality] = 9  # 9 classes for dynamic world

        # encoder
        self.encoder = ConvNeXtV2(
            in_chans=self.in_chans,
            depths=depths,
            dims=dims,
            patch_size=patch_size,
            img_size=img_size,
            use_orig_stem=args.use_orig_stem,
        )
        self.proj = nn.Conv2d(in_channels=dims[-1], out_channels=decoder_embed_dim, kernel_size=1)
        # mask tokens
        self.mask_token = nn.Parameter(torch.zeros(1, decoder_embed_dim, 1, 1))
        decoder = [Block(dim=decoder_embed_dim, drop_path=0.0) for _ in range(decoder_depth)]
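        # Note: the same Block instances in `decoder` are reused below for every modality,
        # so the decoder blocks are shared across modalities; only the per-modality
        # prediction heads in pred_dict differ.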
        # creating a decoder for each modality
        self.decoder_dict = nn.ModuleDict()
        self.pred_dict = nn.ModuleDict()
        for modality in self.args.out_modalities.keys():
            if modality in [
                "sentinel2",
                "sentinel1",
                "aster",
                "canopy_height_eth",
                "dynamic_world",
                "esa_worldcover",
                "IMNET",
            ]:
                # all the pixel level modalities
                self.decoder_dict[modality] = nn.Sequential(*decoder)
                self.pred_dict[modality] = nn.Conv2d(
                    in_channels=decoder_embed_dim,
                    out_channels=patch_size**2 * self.out_chans[modality],
                    kernel_size=1,
                )
            elif modality in ["biome", "eco_region", "lat", "lon", "month", "era5"]:
                # all the non-pixel level modalities along with a global average pooling
                self.decoder_dict[modality] = nn.Sequential(*decoder)
                self.layer_norm_tmp = LayerNorm(
                    decoder_embed_dim, eps=1e-6, data_format="channels_first"
                )
                self.pred_dict[modality] = nn.Linear(
                    in_features=decoder_embed_dim, out_features=self.out_chans[modality]
                )

        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Conv2d):
            w = m.weight.data
            trunc_normal_(w.view([w.shape[0], -1]))
            nn.init.constant_(m.bias, 0)
        if isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=0.02)
            nn.init.constant_(m.bias, 0)
        if hasattr(self, "mask_token"):
            torch.nn.init.normal_(self.mask_token, std=0.02)

    def patchify(self, imgs: Tensor, modality: str) -> Tensor:
        """
        imgs: (N, C, H, W)
        x: (N, L, patch_size**2 * C)
        """
        p = self.patch_size
        assert imgs.shape[2] == imgs.shape[3] and imgs.shape[2] % p == 0

        if modality in ["dynamic_world", "esa_worldcover"]:
            # for these modalities, we only have one channel
            channels = 1
        else:
            channels = self.out_chans[modality]
        h = w = imgs.shape[2] // p
        x = imgs.reshape(shape=(imgs.shape[0], channels, h, p, w, p))
        x = torch.einsum("nchpwq->nhwpqc", x)
        x = x.reshape(shape=(imgs.shape[0], h * w, p**2 * channels))
        return x

    def unpatchify(self, x: Tensor) -> Tensor:
        """
        x: (N, L, patch_size**2 * C)
        imgs: (N, C, H, W)
        """
        p = self.patch_size
        print("shape of x:", x.shape)
        h = w = self.img_size // p
        # assert h * w == x.shape[1]
        x = x.reshape(shape=(x.shape[0], h, w, p, p, self.in_chans))
        x = torch.einsum("nhwpqc->nchpwq", x)
        imgs = x.reshape(shape=(x.shape[0], self.in_chans, h * p, h * p))
        return imgs

    def gen_random_mask(self, x: Tensor, mask_ratio: float) -> Tensor:
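        """Generate a per-sample binary patch mask (0 = keep, 1 = mask).

        For example, with img_size=112 and patch_size=16 there are
        L = (112 // 16) ** 2 = 49 patches; at mask_ratio=0.6, int(49 * 0.4) = 19
        patches are kept and the remaining 30 are masked.
        """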
        N = x.shape[0]  # number of samples
        L = (x.shape[2] // self.patch_size) ** 2  # number of patches
        len_keep = int(L * (1 - mask_ratio))  # number of patches to keep

        # the following lines generate a mask with 0s and 1s at random locations
        noise = torch.randn(N, L, device=x.device)
        # sort noise for each sample
        ids_shuffle = torch.argsort(noise, dim=1)
        ids_restore = torch.argsort(ids_shuffle, dim=1)
        # generate the binary mask: 0 is keep 1 is remove
        mask = torch.ones([N, L], device=x.device)
        mask[:, :len_keep] = 0
        # unshuffle to get the binary mask
        mask = torch.gather(mask, dim=1, index=ids_restore)
        return mask  # (batch_size, num_patches)

    def upsample_mask(self, mask: Tensor, scale: float):
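        """Upsample a (N, L) patch mask to pixel resolution.

        For example, a (N, 49) mask (7x7 patches) with scale=16 becomes a
        (N, 112, 112) pixel mask.
        """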
        assert len(mask.shape) == 2
        p = int(mask.shape[1] ** 0.5)
        return (
            mask.reshape(-1, p, p).repeat_interleave(scale, dim=1).repeat_interleave(scale, dim=2)
        )

    def forward_encoder(self, imgs: Tensor, mask_ratio: float) -> Tuple[Tensor, Tensor]:
        # generate random masks
        mask = self.gen_random_mask(imgs, mask_ratio)
        # encoding
        x = self.encoder(imgs, mask)
        return x, mask

    def forward_decoder(self, x: Tensor, mask: Tensor) -> Dict[AnyStr, Tensor]:
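        """Decode the (masked) feature map into per-modality predictions.

        Pixel-level modalities go through a 1x1 conv head, giving
        (N, patch_size**2 * C_mod, h, w); image-level modalities are globally
        average-pooled and passed through a linear head, giving (N, C_mod).
        """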
        pred = {}
        x = self.proj(x)
        n, c, h, w = x.shape
        mask = mask.reshape(-1, h, w).unsqueeze(1).type_as(x)
        mask_token = self.mask_token.repeat(x.shape[0], 1, x.shape[2], x.shape[3])
        x = x * (1.0 - mask) + mask_token * mask
        for modalities in self.args.out_modalities.keys():
            # decoding
            x_ = self.decoder_dict[modalities](x)
            if modalities in ["biome", "eco_region", "lat", "lon", "month", "era5"]:
                x_ = self.layer_norm_tmp(x_)
                # for the image level modalities we use global average pooling followed by the linear layer in pred_dict
                x_ = x_.mean(dim=[-2, -1])
            # pred
            pred[modalities] = self.pred_dict[modalities](x_)
        return pred

    def forward_loss(
        self, imgs_dict: Dict[AnyStr, Tensor], preds: Dict[AnyStr, Tensor], mask: Tensor
    ) -> Tuple[Tensor, Dict, Tensor, Tensor]:
        """
        imgs_dict: A dict of different modalities, each with shape [N, C, H, W], where C is the number of channels/bands
        preds: A dict of predictions for different modalities, each of shape [N, L, p*p*C]
        mask: [N, L], 0 is keep, 1 is remove
        """
        loss_dict = {}
        for modality in self.args.out_modalities.keys():
            if modality in ["biome", "eco_region", "lat", "lon", "month", "era5"]:
                # all the image level modalities
                # we still further divide this into categorical and continuous modalities
                if modality in ["biome", "eco_region"]:
                    # categorical modalities
                    imgs = imgs_dict[modality]
                    pred = preds[modality]
                    imgs_classes = torch.argmax(imgs, dim=-1)
                    # we don't need to patchify the image for these modalities
                    # compute the loss
                    loss = nn.CrossEntropyLoss()(pred, imgs_classes)
                    loss_dict[modality] = loss
                elif modality in ["lat", "lon", "month", "era5"]:
                    # continuous modalities
                    imgs = imgs_dict[modality]
                    pred = preds[modality]
                    # we don't need to patchify the image for these modalities, but we can still ignore any nan values
                    nan_mask = torch.isnan(imgs)
                    pred = pred[~nan_mask]
                    imgs = imgs[~nan_mask]
                    # compute the loss
                    loss = nn.MSELoss()(pred, imgs)
                    loss_dict[modality] = loss
            elif modality in ["dynamic_world", "esa_worldcover"]:
                # pixel level modalities but categorical
                imgs = imgs_dict[modality]
                pred = preds[modality]
                if len(pred.shape) == 4:
                    n, c, _, _ = pred.shape
                    pred = pred.reshape(n, c, -1)
                    pred = torch.einsum("ncl->nlc", pred)
                # pred is of the shape [N, L, C] where C is patch_size**2 * num_classes; we first
                # convert this to [N, L, patch_size**2, num_classes]. L is the number of patches
                pred = pred.reshape(pred.shape[0], pred.shape[1], self.patch_size**2, -1)
                target = self.patchify(imgs, modality)
                # we only compute the loss on the patches where the mask is 1
                # mask is of the shape [N, L]
                # target is of the shape [N, L, patch_size**2 * num_classes]
                # pred is of the shape [N, L, patch_size**2, num_classes]
                # we need to apply the mask on target and pred for every channel
                target = target.reshape(target.shape[0], target.shape[1], self.patch_size**2, -1)
                mask_tmp = mask.unsqueeze(-1).repeat(1, 1, self.patch_size**2).unsqueeze(-1)
                target = target.reshape(target.shape[0], -1)
                pred = pred.reshape(pred.shape[0], -1, self.out_chans[modality])
                mask_tmp = mask_tmp.reshape(mask.shape[0], -1)
                # we only compute the loss on the patches where the mask is 1
                target = target[mask_tmp == 1]
                pred = pred[mask_tmp == 1]
                # we also mask out invalid targets (stored as -1), since the target can sometimes be missing
                nan_mask = target == -1
                target = target[~nan_mask]
                pred = pred[~nan_mask]
                loss = nn.CrossEntropyLoss()(pred, target)
                loss_dict[modality] = loss
| elif modality == "IMNET": | |
| imgs = imgs_dict[modality] | |
| pred = preds[modality] | |
| if len(pred.shape) == 4: | |
| n, c, _, _ = pred.shape | |
| pred = pred.reshape(n, c, -1) | |
| pred = torch.einsum("ncl->nlc", pred) | |
| target = self.patchify(imgs, modality) | |
| if self.norm_pix_loss: | |
| mean = target.mean(dim=-1, keepdim=True) | |
| var = target.var(dim=-1, keepdim=True) | |
| target = (target - mean) / (var + 1.0e-6) ** 0.5 | |
| loss = (pred - target) ** 2 | |
| loss = loss.mean(dim=-1) # [N, L], mean loss per patch | |
| loss = (loss * mask).sum() / mask.sum() # mean loss on removed patches | |
| loss_dict[modality] = loss | |
            else:
                # pixel level modalities but continuous
                imgs = imgs_dict[modality]
                pred = preds[modality]
                if len(pred.shape) == 4:
                    n, c, _, _ = pred.shape  # [N, C, H, W]
                    pred = pred.reshape(n, c, -1)
                    pred = torch.einsum("ncl->nlc", pred)
                target = self.patchify(imgs, modality)
                if (
                    self.norm_pix_loss and modality == "sentinel2"
                ):  # we only compute the per-patch norm on sentinel2
                    mean = target.mean(dim=-1, keepdim=True)
                    var = target.var(dim=-1, keepdim=True)
                    target = (target - mean) / (var + 1.0e-6) ** 0.5
                loss = (pred - target) ** 2  # using mean squared error
                nan_mask = torch.isnan(loss)
                count = torch.count_nonzero(~nan_mask, dim=-1)
                loss[nan_mask] = 0
                loss = loss.sum(dim=-1) / count
                # uncomment the below line to compute the loss on the whole image - this results in better
                # reconstructions, but not better representations for downstream tasks
                # mask = torch.ones_like(mask)

                # counting the number of pixels where mask is 1 and loss is not nan, since we only compute the loss on these.
                # we create the nan mask again, since sometimes count can be 0.
                nan_mask = torch.isnan(loss * mask)
                tmp = loss * mask
                tmp[nan_mask] = 0
                sum_ = tmp.sum()
                count = torch.count_nonzero(tmp)
                loss = sum_ / count  # mean loss on removed patches
                loss_dict[modality] = loss

        loss_list = [loss_dict[modality] for modality in loss_dict.keys()]
        if self.args.loss_aggr == "uncertainty":
            uncertainty_loss_, log_vars = self.loss_fn(loss_list)
            loss_combined = sum(uncertainty_loss_)
            return loss_combined, loss_dict, log_vars, uncertainty_loss_
        elif self.args.loss_aggr == "unweighted":
            loss_combined = sum(loss_list)
            return loss_combined, loss_dict, None, None

    def forward(self, imgs_dict: Dict[AnyStr, Tensor], labels=None, mask_ratio: float = 0.6):
        # apply random crop to all pixel-wise modalities
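        # NOTE: self.random_crop is not defined in this extract; it is expected to be a
        # Kornia-style random-crop augmentation set up in the original MMEarth training code.
        # MMEarthWrapper above only uses self.encoder, so this forward pass is not exercised here.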
        params = self.random_crop.generate_parameters(imgs_dict["sentinel2"].shape)
        # Apply the same transform to all images in the batch
        for modality in imgs_dict:
            if modality in PIXEL_WISE_MODALITIES:
                imgs_dict[modality] = self.random_crop.apply_transform(
                    imgs_dict[modality], params, None
                )
        # here imgs_dict is a dictionary with every modality; we set imgs to be the input, which
        # in this case is always sentinel2.
        imgs = imgs_dict["sentinel2"]
        # convert nan to 0 for "sentinel2", "sentinel1", "aster", "canopy_height_eth".
        # This is done since the data is normalized to have a mean of 0 and std of 1, hence
        # effectively we are setting the nan values to the mean. In the case of the input,
        # setting to 0 also ensures that these values become sparse.
        for modality in imgs_dict.keys():
            if modality in ["sentinel2", "sentinel1", "aster", "canopy_height_eth"]:
                imgs_dict[modality] = torch.nan_to_num(
                    imgs_dict[modality], nan=0.0, posinf=0.0, neginf=0.0
                )
        x, mask = self.forward_encoder(imgs, mask_ratio)
        pred = self.forward_decoder(x, mask)
        loss, loss_dict, log_vars, normalized_loss_list = self.forward_loss(imgs_dict, pred, mask)
        return loss, pred, mask, loss_dict, log_vars, normalized_loss_list


def convnextv2_atto(**kwargs):
    model = FCMAE(depths=[2, 2, 6, 2], dims=[40, 80, 160, 320], **kwargs)
    return model


def convnextv2_femto(**kwargs):
    model = FCMAE(depths=[2, 2, 6, 2], dims=[48, 96, 192, 384], **kwargs)
    return model


def convnextv2_pico(**kwargs):
    model = FCMAE(depths=[2, 2, 6, 2], dims=[64, 128, 256, 512], **kwargs)
    return model


def convnextv2_nano(**kwargs):
    model = FCMAE(depths=[2, 2, 8, 2], dims=[80, 160, 320, 640], **kwargs)
    return model


def convnextv2_tiny(**kwargs):
    model = FCMAE(depths=[3, 3, 9, 3], dims=[96, 192, 384, 768], **kwargs)
    return model


def convnextv2_base(**kwargs):
    model = FCMAE(depths=[3, 3, 27, 3], dims=[128, 256, 512, 1024], **kwargs)
    return model


def convnextv2_large(**kwargs):
    model = FCMAE(depths=[3, 3, 27, 3], dims=[192, 384, 768, 1536], **kwargs)
    return model


def convnextv2_huge(**kwargs):
    model = FCMAE(depths=[3, 3, 27, 3], dims=[352, 704, 1408, 2816], **kwargs)
    return model