from functools import partial
from typing import List

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from timm.models.vision_transformer import PatchEmbed, VisionTransformer

class SatMAEWrapper(nn.Module):
    def __init__(
        self,
        pretrained_path,
        size="large",
        img_size=96,
        do_pool=True,
        temporal_pooling: str = "mean",
    ):
        super().__init__()
        if size == "large":
            self.encoder = vit_large(img_size=img_size, patch_size=8, in_chans=10)
            self.dim = 1024
        elif size == "base":
            self.encoder = vit_base(img_size=img_size, patch_size=8, in_chans=10)
            self.dim = 768
        else:
            raise ValueError(f"Expected size to be in ['large', 'base'], got {size}")

        checkpoint = torch.load(pretrained_path, map_location="cpu")["model"]
        if img_size != 96:
            checkpoint = interpolate_pos_embed(self.encoder, checkpoint)
        self.encoder.load_state_dict(checkpoint, strict=False)

        self.image_resolution = img_size
        self.do_pool = do_pool
        self.patch_size = 8
        self.grid_size = int(self.image_resolution / self.patch_size)

        if temporal_pooling not in ["mean", "max"]:
            raise ValueError(
                f"Expected temporal_pooling to be in ['mean', 'max'], got {temporal_pooling}"
            )
        self.temporal_pooling = temporal_pooling

    def resize(self, images):
        images = F.interpolate(
            images,
            size=(self.image_resolution, self.image_resolution),
            mode="bilinear",
            align_corners=False,
        )
        return images

    def preprocess(self, images):
        if len(images.shape) == 5:
            # take the mean along the temporal dimension
            # (the input is assumed to be (b, h, w, t, c), matching the indexing in forward)
            images = torch.mean(images, dim=3)
        images = rearrange(images, "b h w c -> b c h w")
        assert images.shape[1] == 13
        return self.resize(images)  # (bsz, C, H, W)

    def forward(self, s2=None, s1=None, months=None):
        if s2 is None:
            raise ValueError("S2 can't be None for SatMAE")
        if len(s2.shape) == 5:
            outputs_l: List[torch.Tensor] = []
            for timestep in range(s2.shape[3]):
                image = self.preprocess(s2[:, :, :, timestep])
                output = self.encoder.forward_features(image)
                # output shape: (bsz, num_groups * num_patches, dim)
                if self.do_pool:
                    output = output.mean(dim=1)  # (bsz, dim)
                else:
                    # average over the channel-group axis: (bsz, num_patches, dim)
                    output = rearrange(output, "b (c_g l) d -> b l c_g d", c_g=3).mean(dim=-2)
                outputs_l.append(output)
            # (bsz, dim, t) if pooled, else (bsz, num_patches, dim, t)
            outputs_t = torch.stack(outputs_l, dim=-1)
            if self.temporal_pooling == "mean":
                return outputs_t.mean(dim=-1)
            else:
                return torch.amax(outputs_t, dim=-1)
        else:
            s2 = self.preprocess(s2)
            output = self.encoder.forward_features(s2)
            if self.do_pool:
                return output.mean(dim=1)
            else:
                return rearrange(output, "b (c_g l) d -> b l c_g d", c_g=3).mean(dim=-2)
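

# Usage sketch (illustrative, not executed on import): build the wrapper and run a
# random Sentinel-2 batch through it. The checkpoint path is a placeholder, and the
# (batch, height, width, timesteps, bands) input layout is assumed from the indexing
# in forward() above.
def _demo_satmae_wrapper(pretrained_path="satmae_checkpoint.pth"):
    wrapper = SatMAEWrapper(pretrained_path, size="base", img_size=96, do_pool=True)
    s2 = torch.randn(2, 96, 96, 4, 13)  # 2 samples, 96x96 pixels, 4 timesteps, 13 bands
    features = wrapper(s2=s2)
    print(features.shape)  # expected: (2, 768) with do_pool=True and size="base"
    return features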


def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
    """
    grid_size: int of the grid height and width
    return:
    pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
    """
    grid_h = np.arange(grid_size, dtype=float)
    grid_w = np.arange(grid_size, dtype=float)
    grid = np.meshgrid(grid_w, grid_h)  # here w goes first
    grid = np.stack(grid, axis=0)

    grid = grid.reshape([2, 1, grid_size, grid_size])
    pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
    if cls_token:
        pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
    return pos_embed
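

# Quick shape check for the 2-D sin-cos table (sizes here are assumed for illustration):
# a 12x12 grid of 768-dim positions, with one all-zero row prepended for the cls token.
def _demo_2d_pos_embed():
    pe = get_2d_sincos_pos_embed(embed_dim=768, grid_size=12, cls_token=True)
    print(pe.shape)  # expected: (145, 768) == (1 + 12 * 12, 768)
    return pe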


def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
    assert embed_dim % 2 == 0

    # use half of dimensions to encode grid_h
    emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0])  # (H*W, D/2)
    emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1])  # (H*W, D/2)

    emb = np.concatenate([emb_h, emb_w], axis=1)  # (H*W, D)
    return emb


def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """
    embed_dim: output dimension for each position
    pos: a list of positions to be encoded: size (M,)
    out: (M, D)
    """
    assert embed_dim % 2 == 0
    omega = np.arange(embed_dim // 2, dtype=float)
    omega /= embed_dim / 2.0
    omega = 1.0 / 10000**omega  # (D/2,)

    pos = pos.reshape(-1)  # (M,)
    out = np.einsum("m,d->md", pos, omega)  # (M, D/2), outer product

    emb_sin = np.sin(out)  # (M, D/2)
    emb_cos = np.cos(out)  # (M, D/2)

    emb = np.concatenate([emb_sin, emb_cos], axis=1)  # (M, D)
    return emb
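

# Structure check for the 1-D table: the first half of each row is sin(pos * omega)
# and the second half cos(pos * omega), so position 0 maps to zeros followed by ones.
def _demo_1d_pos_embed_structure():
    emb = get_1d_sincos_pos_embed_from_grid(8, np.array([0.0, 1.0]))
    print(emb[0])  # expected: [0. 0. 0. 0. 1. 1. 1. 1.]
    return emb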


def get_1d_sincos_pos_embed_from_grid_torch(embed_dim, pos):
    """
    embed_dim: output dimension for each position
    pos: a list of positions to be encoded: size (M,)
    out: (M, D)
    """
    assert embed_dim % 2 == 0
    omega = torch.arange(embed_dim // 2, dtype=float, device=pos.device)
    omega /= embed_dim / 2.0
    omega = 1.0 / 10000**omega  # (D/2,)

    pos = pos.reshape(-1)  # (M,)
    out = torch.einsum("m,d->md", pos, omega)  # (M, D/2), outer product

    emb_sin = torch.sin(out)  # (M, D/2)
    emb_cos = torch.cos(out)  # (M, D/2)

    emb = torch.cat([emb_sin, emb_cos], dim=1)  # (M, D)
    return emb.double()
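

# Consistency check: the numpy and torch 1-D sin-cos embeddings should agree to
# floating-point precision for the same positions.
def _demo_1d_pos_embed_consistency():
    positions = np.arange(8, dtype=float)
    emb_np = get_1d_sincos_pos_embed_from_grid(64, positions)
    emb_pt = get_1d_sincos_pos_embed_from_grid_torch(64, torch.from_numpy(positions))
    print(np.allclose(emb_np, emb_pt.numpy()))  # expected: True
    return emb_np, emb_pt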


# --------------------------------------------------------
# Interpolate position embeddings for high-resolution
# References:
# DeiT: https://github.com/facebookresearch/deit
# --------------------------------------------------------
def interpolate_pos_embed(model, checkpoint_model):
    if "pos_embed" in checkpoint_model:
        pos_embed_checkpoint = checkpoint_model["pos_embed"]
        embedding_size = pos_embed_checkpoint.shape[-1]
        try:
            num_patches = model.patch_embed.num_patches
        except AttributeError:
            num_patches = model.patch_embed[0].num_patches
        num_extra_tokens = model.pos_embed.shape[-2] - num_patches
        # height (== width) for the checkpoint position embedding
        orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
        # height (== width) for the new position embedding
        new_size = int(num_patches**0.5)
        # class_token and dist_token are kept unchanged
        if orig_size != new_size:
            print(
                "Position interpolate from %dx%d to %dx%d"
                % (orig_size, orig_size, new_size, new_size)
            )
            extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
            # only the position tokens are interpolated
            pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
            pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(
                0, 3, 1, 2
            )
            pos_tokens = torch.nn.functional.interpolate(
                pos_tokens, size=(new_size, new_size), mode="bicubic", align_corners=False
            )
            pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
            new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
            checkpoint_model["pos_embed"] = new_pos_embed
    return checkpoint_model
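

# Illustration of the interpolation path (sizes are assumptions for the demo): resize a
# checkpoint pos_embed built for an 8x8 patch grid to the 12x12 grid of a model built
# for 96x96 inputs with patch_size=8. Only the shape is checked here, so zeros suffice.
def _demo_interpolate_pos_embed():
    model = vit_base(img_size=96, patch_size=8, in_chans=10)
    old_ckpt = {"pos_embed": torch.zeros(1, 1 + 8 * 8, model.pos_embed.shape[-1])}
    new_ckpt = interpolate_pos_embed(model, old_ckpt)
    print(new_ckpt["pos_embed"].shape)  # expected: (1, 145, 512) for the base model
    return new_ckpt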


class GroupChannelsVisionTransformer(VisionTransformer):
    """Vision Transformer with support for global average pooling"""

    def __init__(
        self,
        global_pool=False,
        channel_embed=256,
        channel_groups=((0, 1, 2, 6), (3, 4, 5, 7), (8, 9)),
        **kwargs,
    ):
        super().__init__(**kwargs)

        img_size = kwargs["img_size"]
        patch_size = kwargs["patch_size"]
        embed_dim = kwargs["embed_dim"]

        self.channel_groups = channel_groups
        self.patch_embed = nn.ModuleList(
            [PatchEmbed(img_size, patch_size, len(group), embed_dim) for group in channel_groups]
        )
        num_patches = self.patch_embed[0].num_patches

        # Positional and channel embed
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim - channel_embed))
        pos_embed = get_2d_sincos_pos_embed(
            self.pos_embed.shape[-1], int(num_patches**0.5), cls_token=True
        )
        self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))

        num_groups = len(channel_groups)
        self.channel_embed = nn.Parameter(torch.zeros(1, num_groups, channel_embed))
        chan_embed = get_1d_sincos_pos_embed_from_grid(
            self.channel_embed.shape[-1], torch.arange(num_groups).numpy()
        )
        self.channel_embed.data.copy_(torch.from_numpy(chan_embed).float().unsqueeze(0))

        # Extra embedding for cls to fill embed_dim
        self.channel_cls_embed = nn.Parameter(torch.zeros(1, 1, channel_embed))
        channel_cls_embed = torch.zeros((1, channel_embed))
        self.channel_cls_embed.data.copy_(channel_cls_embed.float().unsqueeze(0))

        self.global_pool = global_pool
        if self.global_pool:
            norm_layer = kwargs["norm_layer"]
            embed_dim = kwargs["embed_dim"]
            self.fc_norm = norm_layer(embed_dim)
            del self.norm  # remove the original norm

    def forward_features(self, x):
        b, c, h, w = x.shape

        x_c_embed = []
        for i, group in enumerate(self.channel_groups):
            x_c = x[:, group, :, :]
            x_c_embed.append(self.patch_embed[i](x_c))  # (N, L, D)

        x = torch.stack(x_c_embed, dim=1)  # (N, G, L, D)
        _, G, L, D = x.shape

        # add channel embed
        channel_embed = self.channel_embed.unsqueeze(2)  # (1, c, 1, cD)
        pos_embed = self.pos_embed[:, 1:, :].unsqueeze(1)  # (1, 1, L, pD)

        # Channel embed same across (x,y) position, and pos embed same across channel (c)
        channel_embed = channel_embed.expand(-1, -1, pos_embed.shape[2], -1)  # (1, c, L, cD)
        pos_embed = pos_embed.expand(-1, channel_embed.shape[1], -1, -1)  # (1, c, L, pD)
        pos_channel = torch.cat((pos_embed, channel_embed), dim=-1)  # (1, c, L, D)

        # add pos embed w/o cls token
        x = x + pos_channel  # (N, G, L, D)
        x = x.view(b, -1, D)  # (N, G*L, D)

        cls_pos_channel = torch.cat(
            (self.pos_embed[:, :1, :], self.channel_cls_embed), dim=-1
        )  # (1, 1, D)
        cls_tokens = cls_pos_channel + self.cls_token.expand(b, -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)  # (N, 1 + c*L, D)
        x = self.pos_drop(x)

        for blk in self.blocks:
            x = blk(x)

        return x[:, 1:, :]  # remove cls token


def vit_base(**kwargs):
    model = GroupChannelsVisionTransformer(
        channel_embed=256,
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        **kwargs,
    )
    return model


def vit_large(**kwargs):
    model = GroupChannelsVisionTransformer(
        channel_embed=256,
        embed_dim=1024,
        depth=24,
        num_heads=16,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        **kwargs,
    )
    return model
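

# Token-layout sketch for the grouped-channel encoder (sizes assumed for illustration):
# with 3 channel groups and a 12x12 patch grid (96x96 input, patch_size=8),
# forward_features returns tokens in group-major order, i.e. (batch, 3 * 144, 768)
# for the base model.
def _demo_grouped_tokens():
    model = vit_base(img_size=96, patch_size=8, in_chans=10)
    x = torch.randn(2, 10, 96, 96)  # 10 bands, covering channel-group indices 0..9
    tokens = model.forward_features(x)
    print(tokens.shape)  # expected: (2, 432, 768)
    return tokens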