# type: ignore
# Copyright (c) IBM Corp. 2024. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# --------------------------------------------------------
# References:
# timm: https://github.com/rwightman/pytorch-image-models/tree/master/timm
# transformers: https://github.com/huggingface/transformers
# --------------------------------------------------------
import logging
from functools import partial
from pathlib import Path
from typing import List, Tuple

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import yaml
from einops import rearrange, repeat
from timm.layers import to_2tuple
from timm.models.vision_transformer import Block


class PrithviWrapper(nn.Module):
    # we assume any data passed to this wrapper
    # will contain S2 data with the following channels
    INPUT_S2_BAND_ORDERING = [
        "B01",
        "B02",
        "B03",
        "B04",
        "B05",
        "B06",
        "B07",
        "B08",
        "B08A",
        "B09",
        "B10",
        "B11",
        "B12",
    ]

    def __init__(self, weights_path: Path, do_pool=True, temporal_pooling: str = "mean"):
        super().__init__()

        with (weights_path / "prithvi/config.json").open("r") as f:
            config = yaml.safe_load(f)["pretrained_cfg"]
        config["num_frames"] = 1

        self.model = PrithviMAE(**config)
        state_dict = torch.load(weights_path / "prithvi/Prithvi_EO_V2_300M.pt", map_location="cpu")
        # discard fixed pos_embedding weight, following
        # https://huggingface.co/ibm-nasa-geospatial/Prithvi-EO-2.0-300M/blob/e4aabdc440c8ee703a749def8af5bf4700dee35b/inference.py#L362
        for k in list(state_dict.keys()):
            if "pos_embed" in k:
                del state_dict[k]
        self.model.load_state_dict(state_dict, strict=False)

        self.image_resolution = config["img_size"]
        self.grid_size = int(config["img_size"] // config["patch_size"][-1])
        self.bands = config["bands"]
        self.inputs_to_prithvi = [self.INPUT_S2_BAND_ORDERING.index(b) for b in self.bands]
        self.do_pool = do_pool
        if temporal_pooling not in ["mean", "max"]:
            raise ValueError(
                f"Expected temporal_pooling to be in ['mean', 'max'], got {temporal_pooling}"
            )
        self.temporal_pooling = temporal_pooling
        self.dim = config["embed_dim"]

    def resize(self, images):
        images = F.interpolate(
            images,
            size=(self.image_resolution, self.image_resolution),
            mode="bilinear",
            align_corners=False,
        )
        return images

    def preprocess(self, images):
        if len(images.shape) == 5:
            # take the mean along the temporal dimension
            images = torch.mean(images, dim=2)
        images = rearrange(images, "b h w c -> b c h w")
        assert images.shape[1] == 13
        images = images[:, self.inputs_to_prithvi, :, :]
        images = self.resize(images)  # (bsz, C, H, W)
        return repeat(images, "b c h w -> b c t h w", t=1)

    def forward(self, s2=None, s1=None, months=None):
        if s2 is None:
            raise ValueError("S2 can't be None for Prithvi")
        if len(s2.shape) == 5:
            outputs_l: List[torch.Tensor] = []
            for timestep in range(s2.shape[3]):
                image = self.preprocess(s2[:, :, :, timestep])
                output = self.model.forward_features(image)[-1]
                # following
                # https://github.com/IBM/terratorch/blob/main/terratorch/models/backbones/prithvi_mae.py#L449
                # we remove the class token. This is also the approach they
                # take for classification: https://github.com/IBM/terratorch/blob/main/terratorch/models/scalar_output_model.py#L19
                output = output[:, 1:, :]
                # output shape: (bsz, num_tokens, dim)
                if self.do_pool:
                    output = output.mean(dim=1)
                outputs_l.append(output)
            outputs_t = torch.stack(outputs_l, dim=-1)  # (bsz, dim, t) if pooled, else (bsz, num_tokens, dim, t)
            if self.temporal_pooling == "mean":
                return outputs_t.mean(dim=-1)
            else:
                return torch.amax(outputs_t, dim=-1)
        else:
            s2 = self.preprocess(s2)
            output = self.model.forward_features(s2)[-1]
            output = output[:, 1:, :]
            if self.do_pool:
                return output.mean(dim=1)
            else:
                return output
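

# Illustrative usage sketch (not part of the original module). It assumes a local checkpoint
# directory laid out as PrithviWrapper expects (weights_path / "prithvi/config.json" and
# "prithvi/Prithvi_EO_V2_300M.pt"), and a hypothetical batch of S2 imagery with 13 bands,
# shaped (batch, height, width, timesteps, channels). Nothing here runs at import time.
def _example_prithvi_wrapper_usage(weights_path: Path) -> torch.Tensor:
    wrapper = PrithviWrapper(weights_path, do_pool=True, temporal_pooling="mean")
    s2 = torch.randn(2, 64, 64, 3, 13)  # (bsz, H, W, T, num S2 bands)
    with torch.no_grad():
        features = wrapper(s2=s2)  # pooled, temporally averaged embeddings: (bsz, wrapper.dim)
    return features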


def get_3d_sincos_pos_embed(embed_dim, grid_size, add_cls_token=False):
    """
    Create 3D sin/cos positional embeddings.

    Args:
        embed_dim (int):
            Embedding dimension.
        grid_size (tuple[int, int, int] | list[int]):
            The grid depth, height and width.
        add_cls_token (bool, *optional*, defaults to False):
            Whether or not to add a classification (CLS) token.

    Returns:
        (`np.ndarray` of shape (grid_size[0]*grid_size[1]*grid_size[2], embed_dim) or
        (1+grid_size[0]*grid_size[1]*grid_size[2], embed_dim)): the position embeddings (with or without cls token)
    """
    assert embed_dim % 16 == 0

    t_size, h_size, w_size = grid_size

    w_embed_dim = embed_dim // 16 * 6
    h_embed_dim = embed_dim // 16 * 6
    t_embed_dim = embed_dim // 16 * 4

    w_pos_embed = get_1d_sincos_pos_embed_from_grid(w_embed_dim, np.arange(w_size))
    h_pos_embed = get_1d_sincos_pos_embed_from_grid(h_embed_dim, np.arange(h_size))
    t_pos_embed = get_1d_sincos_pos_embed_from_grid(t_embed_dim, np.arange(t_size))

    w_pos_embed = np.tile(w_pos_embed, (t_size * h_size, 1))
    h_pos_embed = np.tile(np.repeat(h_pos_embed, w_size, axis=0), (t_size, 1))
    t_pos_embed = np.repeat(t_pos_embed, h_size * w_size, axis=0)

    pos_embed = np.concatenate((w_pos_embed, h_pos_embed, t_pos_embed), axis=1)

    if add_cls_token:
        pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
    return pos_embed
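

# Quick shape check (illustrative sketch, not from the original source): for a grid of
# 1 frame x 14 x 14 patches and a 768-dim embedding, the result has one row per patch,
# plus an all-zero row for the CLS token when requested.
def _example_3d_pos_embed_shape() -> None:
    pos = get_3d_sincos_pos_embed(768, (1, 14, 14), add_cls_token=True)
    assert pos.shape == (1 + 1 * 14 * 14, 768)
    assert np.allclose(pos[0], 0.0)  # the CLS row is zeros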


def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """
    embed_dim: output dimension for each position
    pos: a list of positions to be encoded: size (M,)
    out: (M, D)
    """
    if embed_dim % 2 != 0:
        raise ValueError("embed_dim must be even")

    omega = np.arange(embed_dim // 2, dtype=float)
    omega /= embed_dim / 2.0
    omega = 1.0 / 10000**omega  # (D/2,)

    pos = pos.reshape(-1)  # (M,)
    out = np.einsum("m,d->md", pos, omega)  # (M, D/2), outer product

    emb_sin = np.sin(out)  # (M, D/2)
    emb_cos = np.cos(out)  # (M, D/2)

    emb = np.concatenate([emb_sin, emb_cos], axis=1)  # (M, D)
    return emb
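

# Illustrative sketch (not from the original source): the 1D embedding concatenates sin and cos
# halves, so position 0 maps to sin(0)=0 in the first half and cos(0)=1 in the second.
def _example_1d_pos_embed() -> None:
    emb = get_1d_sincos_pos_embed_from_grid(8, np.arange(4))
    assert emb.shape == (4, 8)
    assert np.allclose(emb[0, :4], 0.0) and np.allclose(emb[0, 4:], 1.0)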


def _get_1d_sincos_embed_from_grid_torch(embed_dim: int, pos: torch.Tensor):
    """Torch version of `get_1d_sincos_pos_embed_from_grid()`.

    Modified to cast the omega values to `pos.dtype`, which must be a float type (and not int
    as in regular positional embeddings). This is required for native FSDP mixed-precision
    support: omega takes its dtype from `pos` (which carries the correct float dtype) instead
    of being forced to float32.

    embed_dim: output dimension for each position
    pos: a list of positions to be encoded: size (M,) - must be float dtype!
    out: (M, D)
    """
    assert embed_dim % 2 == 0
    assert pos.dtype in [torch.float32, torch.float16, torch.bfloat16]

    omega = torch.arange(embed_dim // 2, dtype=pos.dtype).to(pos.device)
    omega /= embed_dim / 2.0
    omega = 1.0 / 10000**omega  # (D/2,)

    pos = pos.reshape(-1)  # (M,)
    out = torch.einsum("m,d->md", pos, omega)  # (M, D/2), outer product

    emb_sin = torch.sin(out)  # (M, D/2)
    emb_cos = torch.cos(out)  # (M, D/2)

    emb = torch.cat([emb_sin, emb_cos], dim=1)  # (M, D)
    return emb
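

# Illustrative sketch of the dtype behaviour described above: the torch variant keeps the dtype
# of `pos` (e.g. bfloat16 under mixed precision) instead of promoting to float32.
def _example_torch_1d_embed_dtype() -> None:
    pos = torch.arange(4, dtype=torch.bfloat16)
    emb = _get_1d_sincos_embed_from_grid_torch(8, pos)
    assert emb.shape == (4, 8) and emb.dtype == torch.bfloat16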


def _init_weights(module):
    """Initialize the weights"""
    if isinstance(module, nn.Linear):
        nn.init.xavier_uniform_(module.weight)
        if module.bias is not None:
            module.bias.data.zero_()
    elif isinstance(module, nn.LayerNorm):
        module.bias.data.zero_()
        module.weight.data.fill_(1.0)


class PatchEmbed(nn.Module):
    """3D version of timm.models.vision_transformer.PatchEmbed"""

    def __init__(
        self,
        input_size: Tuple[int, int, int] = (1, 224, 224),
        patch_size: Tuple[int, int, int] = (1, 16, 16),
        in_chans: int = 3,
        embed_dim: int = 768,
        norm_layer: nn.Module | None = None,
        flatten: bool = True,
        bias: bool = True,
    ):
        super().__init__()
        self.input_size = input_size
        self.patch_size = patch_size
        self.grid_size = [s // p for s, p in zip(self.input_size, self.patch_size)]
        self.num_patches = self.grid_size[0] * self.grid_size[1] * self.grid_size[2]
        self.flatten = flatten

        self.proj = nn.Conv3d(
            in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias
        )
        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()

    def forward(self, x):
        B, C, T, H, W = x.shape

        if T / self.patch_size[0] % 1 or H / self.patch_size[1] % 1 or W / self.patch_size[2] % 1:
            logging.warning(
                f"Input {x.shape[-3:]} is not divisible by patch size {self.patch_size}. "
                f"The border will be ignored; add backbone_padding for pixel-wise tasks."
            )

        x = self.proj(x)
        if self.flatten:
            x = x.flatten(2).transpose(1, 2)  # B,C,T,H,W -> B,C,L -> B,L,C
        x = self.norm(x)
        return x
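

# Illustrative sketch (not from the original source): a 3D patch embed with a (1, 16, 16) patch
# turns a (B, C, T, H, W) tensor into (B, T * H/16 * W/16, embed_dim) tokens.
def _example_patch_embed_shapes() -> None:
    patch_embed = PatchEmbed(
        input_size=(1, 224, 224), patch_size=(1, 16, 16), in_chans=6, embed_dim=32
    )
    tokens = patch_embed(torch.randn(2, 6, 1, 224, 224))
    assert tokens.shape == (2, 14 * 14, 32)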


class TemporalEncoder(nn.Module):
    def __init__(self, embed_dim: int, trainable_scale: bool = False):
        super().__init__()
        self.embed_dim = embed_dim
        self.year_embed_dim = embed_dim // 2
        self.julian_day_embed_dim = embed_dim - self.year_embed_dim

        # If trainable, initialize scale with small number
        if trainable_scale:
            self.scale = nn.Parameter(torch.full((1,), 0.1))
        else:
            self.register_buffer("scale", torch.ones(1))

    def forward(self, temporal_coords: torch.Tensor, tokens_per_frame: int | None = None):
        """
        temporal_coords: year and day-of-year info with shape (B, T, 2).
        tokens_per_frame: number of tokens for each frame in the sample. If provided, embeddings will be
            repeated over the T dimension, and the final shape is (B, T*tokens_per_frame, embed_dim).
        """
        shape = temporal_coords.shape[:2] + (-1,)  # B, T, -1

        year = _get_1d_sincos_embed_from_grid_torch(
            self.year_embed_dim, temporal_coords[:, :, 0].flatten()
        ).reshape(shape)
        julian_day = _get_1d_sincos_embed_from_grid_torch(
            self.julian_day_embed_dim, temporal_coords[:, :, 1].flatten()
        ).reshape(shape)

        embedding = self.scale * torch.cat([year, julian_day], dim=-1)

        if tokens_per_frame is not None:
            embedding = torch.repeat_interleave(embedding, tokens_per_frame, dim=1)

        return embedding  # B, T*tokens_per_frame, embed_dim
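

# Illustrative sketch (not from the original source): year/day-of-year coordinates of shape
# (B, T, 2) become one embedding per frame, optionally repeated for every token in that frame.
def _example_temporal_encoder_shapes() -> None:
    enc = TemporalEncoder(embed_dim=32)
    coords = torch.tensor([[[2020.0, 120.0], [2020.0, 150.0]]])  # (B=1, T=2, 2)
    assert enc(coords).shape == (1, 2, 32)
    assert enc(coords, tokens_per_frame=196).shape == (1, 2 * 196, 32)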


class LocationEncoder(nn.Module):
    def __init__(self, embed_dim: int, trainable_scale: bool = False):
        super().__init__()
        self.embed_dim = embed_dim
        self.lat_embed_dim = embed_dim // 2
        self.lon_embed_dim = embed_dim - self.lat_embed_dim

        # If trainable, initialize scale with small number
        if trainable_scale:
            self.scale = nn.Parameter(torch.full((1,), 0.1))
        else:
            self.register_buffer("scale", torch.ones(1))

    def forward(self, location_coords: torch.Tensor):
        """
        location_coords: lat and lon info with shape (B, 2).
        """
        shape = location_coords.shape[:1] + (1, -1)  # B, 1, -1

        lat = _get_1d_sincos_embed_from_grid_torch(
            self.lat_embed_dim, location_coords[:, 0].flatten()
        ).reshape(shape)
        lon = _get_1d_sincos_embed_from_grid_torch(
            self.lon_embed_dim, location_coords[:, 1].flatten()
        ).reshape(shape)

        embedding = self.scale * torch.cat([lat, lon], dim=-1)

        return embedding  # B, 1, embed_dim
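

# Illustrative sketch (not from the original source): latitude/longitude pairs of shape (B, 2)
# become a single per-sample token of shape (B, 1, embed_dim).
def _example_location_encoder_shape() -> None:
    enc = LocationEncoder(embed_dim=32)
    coords = torch.tensor([[45.0, -93.0], [0.0, 10.0]])  # (B=2, 2)
    assert enc(coords).shape == (2, 1, 32)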


class PrithviViT(nn.Module):
    """Prithvi ViT Encoder"""

    def __init__(
        self,
        img_size: int | Tuple[int, int] = 224,
        patch_size: int | Tuple[int, int, int] = (1, 16, 16),
        num_frames: int = 1,
        in_chans: int = 3,
        embed_dim: int = 1024,
        depth: int = 24,
        num_heads: int = 16,
        mlp_ratio: float = 4.0,
        norm_layer: nn.Module = partial(torch.nn.LayerNorm, eps=1e-6),
        coords_encoding: List[str] | None = None,
        coords_scale_learn: bool = False,
        encoder_only: bool = True,  # needed for timm
        **kwargs,
    ):
        super().__init__()

        self.feature_info = []
        self.encoder_only = encoder_only
        self.in_chans = in_chans
        self.num_frames = num_frames
        self.embed_dim = embed_dim
        self.img_size = to_2tuple(img_size)
        if isinstance(patch_size, int):
            patch_size = (1, patch_size, patch_size)

        # 3D patch embedding
        self.patch_embed = PatchEmbed(
            input_size=(num_frames,) + self.img_size,
            patch_size=patch_size,
            in_chans=in_chans,
            embed_dim=embed_dim,
        )

        # Optional temporal and location embedding
        coords_encoding = coords_encoding or []
        self.temporal_encoding = "time" in coords_encoding
        self.location_encoding = "location" in coords_encoding
        if self.temporal_encoding:
            assert (
                patch_size[0] == 1
            ), f"With temporal encoding, patch_size[0] must be 1, received {patch_size[0]}"
            self.temporal_embed_enc = TemporalEncoder(embed_dim, coords_scale_learn)
        if self.location_encoding:
            self.location_embed_enc = LocationEncoder(embed_dim, coords_scale_learn)

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.register_buffer(
            "pos_embed", torch.zeros(1, self.patch_embed.num_patches + 1, embed_dim)
        )

        # Transformer layers
        self.blocks = []
        for i in range(depth):
            self.blocks.append(
                Block(embed_dim, num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer)
            )
            self.feature_info.append(
                {
                    "num_chs": embed_dim * self.patch_embed.patch_size[0],
                    "reduction": 1,
                    "module": f"blocks.{i}",
                }
            )
        self.blocks = nn.ModuleList(self.blocks)

        self.norm = norm_layer(embed_dim)

        self.initialize_weights()

    def initialize_weights(self):
        # initialize (and freeze) position embeddings by sin-cos embedding
        pos_embed = get_3d_sincos_pos_embed(
            self.pos_embed.shape[-1], self.patch_embed.grid_size, add_cls_token=True
        )
        self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))

        # initialize patch_embeddings like nn.Linear (instead of nn.Conv2d)
        w = self.patch_embed.proj.weight.data
        torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1]))

        # timm's trunc_normal_(std=.02) is effectively normal_(std=0.02) as cutoff is too big (2.)
        torch.nn.init.normal_(self.cls_token, std=0.02)
        self.apply(_init_weights)

    def random_masking(self, sequence, mask_ratio, noise=None):
        """
        Perform per-sample random masking by per-sample shuffling. Per-sample shuffling is done by argsort random
        noise.

        Args:
            sequence (`torch.FloatTensor` of shape `(batch_size, sequence_length, dim)`)
            mask_ratio (float): mask ratio to use.
            noise (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*) which is
                mainly used for testing purposes to control randomness and maintain the reproducibility
        """
        batch_size, seq_length, dim = sequence.shape
        len_keep = int(seq_length * (1 - mask_ratio))

        if noise is None:
            noise = torch.rand(batch_size, seq_length, device=sequence.device)  # noise in [0, 1]

        # sort noise for each sample
        ids_shuffle = torch.argsort(noise, dim=1).to(
            sequence.device
        )  # ascend: small is keep, large is remove
        ids_restore = torch.argsort(ids_shuffle, dim=1).to(sequence.device)

        # keep the first subset
        ids_keep = ids_shuffle[:, :len_keep]
        sequence_unmasked = torch.gather(
            sequence, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, dim)
        )

        # generate the binary mask: 0 is keep, 1 is remove
        mask = torch.ones([batch_size, seq_length], device=sequence.device)
        mask[:, :len_keep] = 0
        # unshuffle to get the binary mask
        mask = torch.gather(mask, dim=1, index=ids_restore)

        return sequence_unmasked, mask, ids_restore

    def _get_pos_embed(self, x):
        t, h, w = x.shape[-3:]

        pos_embed = (
            torch.from_numpy(
                get_3d_sincos_pos_embed(
                    self.embed_dim,
                    (
                        t // self.patch_embed.patch_size[0],
                        h // self.patch_embed.patch_size[1],
                        w // self.patch_embed.patch_size[2],
                    ),
                    add_cls_token=True,
                )
            )
            .float()
            .unsqueeze(0)
            .to(x)
        )

        return pos_embed

    def forward(
        self,
        x: torch.Tensor,
        temporal_coords: None | torch.Tensor = None,
        location_coords: None | torch.Tensor = None,
        mask_ratio=0.0,
    ):
        if x.shape[-3:] != self.patch_embed.input_size:
            # changed input size
            pos_embed = self._get_pos_embed(x)
        else:
            pos_embed = self.pos_embed

        # embed patches
        x = self.patch_embed(x)

        # add pos embed w/o cls token
        x = x + pos_embed[:, 1:, :]

        if self.temporal_encoding:
            num_tokens_per_frame = x.shape[1] // self.num_frames
            temporal_encoding = self.temporal_embed_enc(temporal_coords, num_tokens_per_frame)
            x = x + temporal_encoding
        if self.location_encoding:
            location_encoding = self.location_embed_enc(location_coords)
            x = x + location_encoding

        # masking: length -> length * mask_ratio
        x, mask, ids_restore = self.random_masking(x, mask_ratio)

        # append cls token
        cls_token = self.cls_token + pos_embed[:, :1, :]
        cls_tokens = cls_token.expand(x.shape[0], -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)

        # apply Transformer blocks
        for block in self.blocks:
            x = block(x)
        x = self.norm(x)

        return x, mask, ids_restore

    def forward_features(
        self,
        x: torch.Tensor,
        temporal_coords: None | torch.Tensor = None,
        location_coords: None | torch.Tensor = None,
    ) -> list[torch.Tensor]:
        if len(x.shape) == 4 and self.patch_embed.input_size[0] == 1:
            # add time dim
            x = x.unsqueeze(2)

        if x.shape[-3:] != self.patch_embed.input_size:
            pos_embed = self._get_pos_embed(x)
        else:
            pos_embed = self.pos_embed

        # embed patches
        x = self.patch_embed(x)

        # add pos embed w/o cls token
        x = x + pos_embed[:, 1:, :]

        if self.temporal_encoding:
            num_tokens_per_frame = x.shape[1] // self.num_frames
            temporal_encoding = self.temporal_embed_enc(temporal_coords, num_tokens_per_frame)
            x = x + temporal_encoding
        if self.location_encoding:
            location_encoding = self.location_embed_enc(location_coords)
            x = x + location_encoding

        # append cls token
        cls_token = self.cls_token + pos_embed[:, :1, :]
        cls_tokens = cls_token.expand(x.shape[0], -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)

        # apply Transformer blocks
        out = []
        for block in self.blocks:
            x = block(x)
            out.append(x.clone())

        x = self.norm(x)
        out[-1] = x
        return out

    def prepare_features_for_image_model(self, features: list[torch.Tensor]) -> list[torch.Tensor]:
        out = []
        effective_time_dim = self.patch_embed.input_size[0] // self.patch_embed.patch_size[0]
        for x in features:
            x_no_token = x[:, 1:, :]
            number_of_tokens = x_no_token.shape[1]
            tokens_per_timestep = number_of_tokens // effective_time_dim
            h = int(np.sqrt(tokens_per_timestep))
            encoded = rearrange(
                x_no_token,
                "batch (t h w) e -> batch (t e) h w",
                e=self.embed_dim,
                t=effective_time_dim,
                h=h,
            )
            out.append(encoded)
        return out
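

# Illustrative sketch of the encoder in isolation (not from the original source; the small
# configuration is chosen so it runs quickly on CPU): forward_features returns one tensor per
# transformer block, each with a CLS token followed by the patch tokens.
def _example_prithvi_vit_features() -> None:
    encoder = PrithviViT(
        img_size=32, patch_size=(1, 16, 16), num_frames=1, in_chans=6,
        embed_dim=32, depth=2, num_heads=2,
    )
    x = torch.randn(2, 6, 1, 32, 32)
    feats = encoder.forward_features(x)
    assert len(feats) == 2  # one entry per block
    assert feats[-1].shape == (2, 1 + 2 * 2, 32)  # CLS + (32/16)^2 patch tokens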


class MAEDecoder(nn.Module):
    """Transformer Decoder used in the Prithvi MAE"""

    def __init__(
        self,
        patch_size: int | Tuple[int, int, int] = (1, 16, 16),
        grid_size: List[int] | Tuple[int, int, int] = (3, 14, 14),
        in_chans: int = 3,
        encoder_embed_dim: int = 1024,
        decoder_embed_dim: int = 512,
        depth: int = 8,
        num_heads: int = 16,
        mlp_ratio: float = 4.0,
        norm_layer: nn.Module = nn.LayerNorm,
        coords_encoding: List[str] | None = None,
        coords_scale_learn: bool = False,
    ):
        super().__init__()

        self.decoder_embed = nn.Linear(encoder_embed_dim, decoder_embed_dim, bias=True)
        self.decoder_embed_dim = decoder_embed_dim
        self.grid_size = grid_size
        if isinstance(patch_size, int):
            patch_size = (1, patch_size, patch_size)
        self.patch_size = patch_size
        self.num_frames = self.grid_size[0] * patch_size[0]
        num_patches = self.grid_size[0] * self.grid_size[1] * self.grid_size[2]

        # Optional temporal and location embedding
        coords_encoding = coords_encoding or []
        self.temporal_encoding = "time" in coords_encoding
        self.location_encoding = "location" in coords_encoding
        if self.temporal_encoding:
            self.temporal_embed_dec = TemporalEncoder(decoder_embed_dim, coords_scale_learn)
        if self.location_encoding:
            self.location_embed_dec = LocationEncoder(decoder_embed_dim, coords_scale_learn)

        self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim))

        self.register_buffer(
            "decoder_pos_embed", torch.zeros(1, num_patches + 1, decoder_embed_dim)
        )

        self.decoder_blocks = nn.ModuleList(
            [
                Block(
                    decoder_embed_dim, num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer
                )
                for _ in range(depth)
            ]
        )

        self.decoder_norm = norm_layer(decoder_embed_dim)
        self.decoder_pred = nn.Linear(
            decoder_embed_dim, patch_size[0] * patch_size[1] * patch_size[2] * in_chans, bias=True
        )

        self.initialize_weights()

    def initialize_weights(self):
        # initialize (and freeze) position embeddings by sin-cos embedding
        decoder_pos_embed = get_3d_sincos_pos_embed(
            self.decoder_pos_embed.shape[-1], self.grid_size, add_cls_token=True
        )
        self.decoder_pos_embed.data.copy_(torch.from_numpy(decoder_pos_embed).float().unsqueeze(0))

        # timm's trunc_normal_(std=.02) is effectively normal_(std=0.02) as cutoff is too big (2.)
        torch.nn.init.normal_(self.mask_token, std=0.02)
        self.apply(_init_weights)

    def forward(
        self,
        hidden_states: torch.Tensor,
        ids_restore: torch.Tensor,
        temporal_coords: None | torch.Tensor = None,
        location_coords: None | torch.Tensor = None,
        input_size: list[int] | None = None,
    ):
        # embed tokens
        x = self.decoder_embed(hidden_states)

        t, h, w = input_size[-3:]
        decoder_pos_embed = torch.from_numpy(
            get_3d_sincos_pos_embed(
                self.decoder_embed_dim,
                (
                    t // self.patch_size[0],
                    h // self.patch_size[1],
                    w // self.patch_size[2],
                ),
                add_cls_token=True,
            )
        ).to(x)

        # append mask tokens to sequence
        mask_tokens = self.mask_token.repeat(x.shape[0], ids_restore.shape[1] + 1 - x.shape[1], 1)
        x_ = torch.cat([x[:, 1:, :], mask_tokens], dim=1)  # no cls token
        # unshuffle
        x_ = torch.gather(
            x_, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2]).to(x_.device)
        )
        x = torch.cat([x[:, :1, :], x_], dim=1)  # append cls token
        # add pos embed
        x = x + decoder_pos_embed

        # remove cls token
        x_ = x[:, 1:, :]

        if self.temporal_encoding:
            num_tokens_per_frame = x_.shape[1] // self.num_frames
            temporal_encoding = self.temporal_embed_dec(temporal_coords, num_tokens_per_frame)
            # Add temporal encoding w/o cls token
            x_ = x_ + temporal_encoding
        if self.location_encoding:
            location_encoding = self.location_embed_dec(location_coords)
            # Add location encoding w/o cls token
            x_ = x_ + location_encoding

        # append cls token
        x = torch.cat([x[:, :1, :], x_], dim=1)

        # apply Transformer layers (blocks)
        for block in self.decoder_blocks:
            x = block(x)
        x = self.decoder_norm(x)

        # predictor projection
        pred = self.decoder_pred(x)

        # remove cls token
        pred = pred[:, 1:, :]

        return pred


class PrithviMAE(nn.Module):
    """Prithvi Masked Autoencoder"""

    def __init__(
        self,
        img_size: int | Tuple[int, int] = 224,
        patch_size: int | Tuple[int, int, int] = (1, 16, 16),
        num_frames: int = 3,
        in_chans: int = 3,
        embed_dim: int = 1024,
        depth: int = 24,
        num_heads: int = 16,
        decoder_embed_dim: int = 512,
        decoder_depth: int = 8,
        decoder_num_heads: int = 16,
        mlp_ratio: float = 4.0,
        norm_layer: nn.Module = partial(torch.nn.LayerNorm, eps=1e-6),
        norm_pix_loss: bool = False,
        coords_encoding: List[str] | None = None,
        coords_scale_learn: bool = False,
        encoder_only: bool = False,
        **kwargs,
    ):
        super().__init__()

        self.encoder = PrithviViT(
            img_size=img_size,
            num_frames=num_frames,
            patch_size=patch_size,
            in_chans=in_chans,
            embed_dim=embed_dim,
            depth=depth,
            num_heads=num_heads,
            mlp_ratio=mlp_ratio,
            norm_layer=norm_layer,
            coords_encoding=coords_encoding,
            coords_scale_learn=coords_scale_learn,
        )

        self.encoder_only = encoder_only

        if not encoder_only:
            self.decoder = MAEDecoder(
                patch_size=patch_size,
                grid_size=self.encoder.patch_embed.grid_size,
                in_chans=in_chans,
                encoder_embed_dim=embed_dim,
                decoder_embed_dim=decoder_embed_dim,
                depth=decoder_depth,
                num_heads=decoder_num_heads,
                mlp_ratio=mlp_ratio,
                norm_layer=norm_layer,
                coords_encoding=coords_encoding,
                coords_scale_learn=coords_scale_learn,
            )
        else:
            self.decoder = nn.Identity()

        self.norm_pix_loss = norm_pix_loss

    def patchify(self, pixel_values):
        """
        Args:
            pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, time, height, width)`):
                Pixel values.

        Returns:
            `torch.FloatTensor` of shape `(batch_size, num_patches, patch_size[0]*patch_size[1]*patch_size[2] * num_channels)`:
                Patchified pixel values.
        """
        patch_size_t, patch_size_h, patch_size_w = self.encoder.patch_embed.patch_size
        num_channels = self.encoder.in_chans

        # patchify
        patchified_pixel_values = rearrange(
            pixel_values,
            "b c (t s) (h p) (w q) -> b (t h w) (s p q c)",
            c=num_channels,
            s=patch_size_t,
            p=patch_size_h,
            q=patch_size_w,
        )

        return patchified_pixel_values

    def unpatchify(self, patchified_pixel_values, image_size: Tuple[int, int] | None = None):
        """
        Args:
            patchified_pixel_values (`torch.FloatTensor` of shape
                `(batch_size, num_patches, patch_size[0]*patch_size[1]*patch_size[2] * num_channels)`):
                Patchified pixel values.
            image_size (`Tuple[int, int]`, *optional*):
                Original image size.

        Returns:
            `torch.FloatTensor` of shape `(batch_size, num_channels, time, height, width)`:
                Pixel values.
        """
        patch_size_t, patch_size_h, patch_size_w = self.encoder.patch_embed.patch_size
        image_size = to_2tuple(image_size) if image_size is not None else self.encoder.img_size
        original_height, original_width = image_size
        num_patches_h = original_height // patch_size_h
        num_patches_w = original_width // patch_size_w
        num_channels = self.encoder.in_chans

        pixel_values = rearrange(
            patchified_pixel_values,
            "b (t h w) (s p q c) -> b c (t s) (h p) (w q)",
            c=num_channels,
            h=num_patches_h,
            w=num_patches_w,
            s=patch_size_t,
            p=patch_size_h,
            q=patch_size_w,
        )
        return pixel_values

    def forward_loss(self, pixel_values, pred, mask):
        """
        Args:
            pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, time, height, width)`):
                Pixel values.
            pred (`torch.FloatTensor` of shape `(batch_size, num_patches, patch_size[0]*patch_size[1]*patch_size[2] * num_channels)`):
                Predicted pixel values.
            mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
                Tensor indicating which patches are masked (1) and which are not (0).

        Returns:
            `torch.FloatTensor`: Pixel reconstruction loss.
        """
        target = self.patchify(pixel_values)
        if self.norm_pix_loss:
            mean = target.mean(dim=-1, keepdim=True)
            var = target.var(dim=-1, keepdim=True)
            target = (target - mean) / (var + 1.0e-6) ** 0.5

        loss = (pred - target) ** 2
        loss = loss.mean(dim=-1)  # [N, L], mean loss per patch
        loss = (loss * mask).sum() / mask.sum()  # mean loss on removed patches
        return loss

    def forward(
        self,
        pixel_values: torch.Tensor,
        temporal_coords: None | torch.Tensor = None,
        location_coords: None | torch.Tensor = None,
        mask_ratio: float = 0.75,
    ):
        if len(pixel_values.shape) == 4 and self.encoder.patch_embed.input_size[0] == 1:
            # add time dim
            pixel_values = pixel_values.unsqueeze(2)

        latent, mask, ids_restore = self.encoder(
            pixel_values, temporal_coords, location_coords, mask_ratio
        )
        pred = self.decoder(
            latent, ids_restore, temporal_coords, location_coords, input_size=pixel_values.shape
        )
        loss = self.forward_loss(pixel_values, pred, mask)
        return loss, pred, mask

    def forward_features(
        self,
        x: torch.Tensor,
        temporal_coords: None | torch.Tensor = None,
        location_coords: None | torch.Tensor = None,
    ) -> List[torch.Tensor]:
        return self.encoder.forward_features(x, temporal_coords, location_coords)
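

if __name__ == "__main__":
    # Minimal smoke test (illustrative sketch, not part of the original module): run the full
    # masked autoencoder on random data with a deliberately small configuration and check the
    # reconstruction shapes. The configuration values here are arbitrary, chosen to run on CPU.
    model = PrithviMAE(
        img_size=32, patch_size=(1, 16, 16), num_frames=1, in_chans=6,
        embed_dim=32, depth=2, num_heads=2,
        decoder_embed_dim=32, decoder_depth=1, decoder_num_heads=2,
    )
    pixels = torch.randn(2, 6, 1, 32, 32)
    loss, pred, mask = model(pixels, mask_ratio=0.75)
    # pred holds one row per patch; each row is a flattened (1 x 16 x 16 x 6) patch
    assert pred.shape == (2, 4, 16 * 16 * 6)
    assert mask.shape == (2, 4)
    print("reconstruction loss:", loss.item())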