import itertools
import math
import warnings
from pathlib import Path
from typing import List

import torch
import torch.nn.functional as F
from einops import rearrange
from torch import einsum, nn


class CROMAWrapper(nn.Module):
    def __init__(
        self,
        weights_path: Path,
        size="base",
        modality="optical",
        do_pool=True,
        temporal_pooling: str = "mean",
    ):
        super().__init__()
        assert modality in ["SAR", "optical"]
        if size == "base":
            self.croma = PretrainedCROMA(
                str(weights_path / "CROMA_base.pt"), size, modality=modality, image_resolution=120
            )
            self.dim = 768
        elif size == "large":
            self.croma = PretrainedCROMA(
                str(weights_path / "CROMA_large.pt"), size, modality=modality, image_resolution=120
            )
            self.dim = 1024
        else:
            raise ValueError(f"size must be base or large, not {size}")
        self.image_resolution = 120
        self.patch_size = 8
        self.grid_size = int(self.image_resolution / self.patch_size)
        self.do_pool = do_pool
        if temporal_pooling not in ["mean", "max"]:
            raise ValueError(
                f"Expected temporal_pooling to be in ['mean', 'max'], got {temporal_pooling}"
            )
        self.temporal_pooling = temporal_pooling

    def resize(self, images):
        images = F.interpolate(
            images,
            size=(self.image_resolution, self.image_resolution),
            mode="bilinear",
            align_corners=False,
        )
        return images

    def preproccess(self, images):
        images = rearrange(images, "b h w c -> b c h w")
        assert images.shape[1] == 13
        # remove cirrus
        remove_idx = 10
        images = torch.cat(
            [images[:, :remove_idx, :, :], images[:, (remove_idx + 1) :, :, :]], dim=1
        )
        assert images.shape[1] == 12
        return self.resize(images)  # (bsz, 12, 120, 120)

    def preproccess_s1(self, images):
        images = rearrange(images, "b h w c -> b c h w")
        assert images.shape[1] == 2
        return self.resize(images)  # (bsz, 2, 120, 120)

    def forward(self, s2=None, s1=None, months=None):
        if s1 is not None:
            assert s2 is None, "joint s2 and s1 not implemented for CROMA"
            # SAR inputs must index the SAR keys returned by PretrainedCROMA;
            # using the optical keys here would raise a KeyError for modality="SAR"
            output_key = "SAR_GAP" if self.do_pool else "SAR_encodings"
            if len(s1.shape) == 5:
                outputs: List[torch.Tensor] = []
                for timestep in range(s1.shape[3]):
                    image = self.preproccess_s1(s1[:, :, :, timestep])
                    outputs.append(self.croma(SAR_images=image)[output_key])
                outputs_t = torch.stack(outputs, dim=-1)  # (bsz, ..., dim, timesteps)
                if self.temporal_pooling == "mean":
                    return outputs_t.mean(dim=-1)
                else:
                    return torch.amax(outputs_t, dim=-1)
            else:
                s1 = self.preproccess_s1(s1)
                return self.croma(SAR_images=s1)[output_key]
        else:
            # just S2
            output_key = "optical_GAP" if self.do_pool else "optical_encodings"
            if len(s2.shape) == 5:
                outputs: List[torch.Tensor] = []
                for timestep in range(s2.shape[3]):
                    image = self.preproccess(s2[:, :, :, timestep])
                    outputs.append(self.croma(optical_images=image)[output_key])
                outputs_t = torch.stack(outputs, dim=-1)  # (bsz, ..., dim, timesteps)
                if self.temporal_pooling == "mean":
                    return outputs_t.mean(dim=-1)
                else:
                    return torch.amax(outputs_t, dim=-1)
            else:
                s2 = self.preproccess(s2)
                return self.croma(optical_images=s2)[output_key]
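

# Usage sketch (added for illustration; not part of the original interface): embed a batch of
# Sentinel-2 chips with the base optical model. `weights_dir` is a hypothetical directory that
# is assumed to already contain CROMA_base.pt. The input layout matches what
# CROMAWrapper.preproccess expects: channels-last (bsz, h, w, 13) with the cirrus band still
# present; chips are resized to 120x120 internally.
def _example_croma_wrapper(weights_dir: Path) -> torch.Tensor:
    wrapper = CROMAWrapper(weights_path=weights_dir, size="base", modality="optical")
    s2 = torch.randn(2, 64, 64, 13)  # two chips, 13 Sentinel-2 bands
    with torch.no_grad():
        embeddings = wrapper(s2=s2)  # (2, 768) pooled embeddings, since do_pool defaults to True
    return embeddings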


class PretrainedCROMA(nn.Module):
    def __init__(
        self, pretrained_path="CROMA_base.pt", size="base", modality="both", image_resolution=120
    ):
        """
        NOTE: image_resolution is not the spatial, spectral, or temporal resolution. It is the
        height and width of the image, in pixels. E.g., CROMA was pretrained on 120x120px
        images, hence image_resolution is 120 by default.
        """
        super().__init__()
        # check types
        assert isinstance(pretrained_path, str)
        assert isinstance(size, str)
        assert isinstance(modality, str)
        assert isinstance(image_resolution, int)
        # check values
        assert size in ["base", "large"], f"size must be either base or large, not {size}"
        assert (
            image_resolution % 8 == 0
        ), f"image_resolution must be a multiple of 8, not {image_resolution}"
        assert modality in [
            "both",
            "SAR",
            "optical",
        ], f"modality must be either both, SAR, or optical, not {modality}"

        # warn the user if the path contains a different size than the size parameter
        if size == "base" and "large" in pretrained_path:
            warnings.warn(
                "The size is set to base, but the word large appears in the pretrained path!"
            )
        elif size == "large" and "base" in pretrained_path:
            warnings.warn(
                "The size is set to large, but the word base appears in the pretrained path!"
            )

        if size == "base":
            self.encoder_dim = 768
            self.encoder_depth = 12
            self.num_heads = 16
            self.patch_size = 8
        else:
            # large by default
            self.encoder_dim = 1024
            self.encoder_depth = 24
            self.num_heads = 16
            self.patch_size = 8

        self.modality = modality
        self.num_patches = int((image_resolution / 8) ** 2)
        self.s1_channels = 2  # fixed at 2 SAR backscatter channels
        self.s2_channels = 12  # fixed at 12 multispectral optical channels
        self.attn_bias = get_2dalibi(num_heads=self.num_heads, num_patches=self.num_patches)

        # load the checkpoint once and share it across all sub-modules
        checkpoint = torch.load(pretrained_path, map_location="cpu")

        if modality in ["SAR", "both"]:
            print("Initializing SAR encoder")
            self.s1_encoder = ViT(
                dim=self.encoder_dim,
                depth=int(self.encoder_depth / 2),
                in_channels=self.s1_channels,
            )
            self.GAP_FFN_s1 = nn.Sequential(
                nn.LayerNorm(self.encoder_dim),
                nn.Linear(
                    self.encoder_dim, int(4 * self.encoder_dim)
                ),  # (BSZ, num_patches, inner_dim)
                nn.GELU(),  # (BSZ, num_patches, inner_dim)
                nn.Linear(int(4 * self.encoder_dim), self.encoder_dim),  # (BSZ, num_patches, dim)
            )
            # load weights
            self.s1_encoder.load_state_dict(checkpoint["s1_encoder"])
            self.GAP_FFN_s1.load_state_dict(checkpoint["s1_GAP_FFN"])

        if modality in ["optical", "both"]:
            print("Initializing optical encoder")
            self.s2_encoder = ViT(
                dim=self.encoder_dim, depth=self.encoder_depth, in_channels=self.s2_channels
            )
            self.GAP_FFN_s2 = nn.Sequential(
                nn.LayerNorm(self.encoder_dim),
                nn.Linear(
                    self.encoder_dim, int(4 * self.encoder_dim)
                ),  # (BSZ, num_patches, inner_dim)
                nn.GELU(),  # (BSZ, num_patches, inner_dim)
                nn.Linear(int(4 * self.encoder_dim), self.encoder_dim),  # (BSZ, num_patches, dim)
            )
            # load weights
            self.s2_encoder.load_state_dict(checkpoint["s2_encoder"])
            self.GAP_FFN_s2.load_state_dict(checkpoint["s2_GAP_FFN"])

        if modality == "both":
            print("Initializing joint SAR-optical encoder")
            self.cross_encoder = BaseTransformerCrossAttn(
                dim=self.encoder_dim,
                depth=int(self.encoder_depth / 2),
                num_heads=self.num_heads,
            )
            # load weights
            self.cross_encoder.load_state_dict(checkpoint["joint_encoder"])

    def forward(self, SAR_images=None, optical_images=None):
        return_dict = {}
        if self.modality in ["SAR", "both"]:
            assert (
                SAR_images is not None
            ), f"Modality is set to {self.modality}, but SAR_images are None"
            SAR_encodings = self.s1_encoder(
                imgs=SAR_images, attn_bias=self.attn_bias.to(SAR_images.device)
            )  # (bsz, num_patches, encoder_dim)
            SAR_GAP = self.GAP_FFN_s1(SAR_encodings.mean(dim=1))  # (bsz, encoder_dim)
            return_dict["SAR_encodings"] = SAR_encodings
            return_dict["SAR_GAP"] = SAR_GAP

        if self.modality in ["optical", "both"]:
            assert (
                optical_images is not None
            ), f"Modality is set to {self.modality}, but optical_images are None"
            optical_encodings = self.s2_encoder(
                imgs=optical_images, attn_bias=self.attn_bias.to(optical_images.device)
            )  # (bsz, num_patches, encoder_dim)
            optical_GAP = self.GAP_FFN_s2(optical_encodings.mean(dim=1))  # (bsz, encoder_dim)
            return_dict["optical_encodings"] = optical_encodings
            return_dict["optical_GAP"] = optical_GAP

        if self.modality == "both":
            joint_encodings = self.cross_encoder(
                x=SAR_encodings,
                context=optical_encodings,
                relative_position_bias=self.attn_bias.to(optical_images.device),
            )  # (bsz, num_patches, encoder_dim)
            joint_GAP = joint_encodings.mean(dim=1)  # (bsz, encoder_dim)
            return_dict["joint_encodings"] = joint_encodings
            return_dict["joint_GAP"] = joint_GAP

        return return_dict
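

# Usage sketch (illustrative only): the keys returned by PretrainedCROMA.forward depend on the
# modality chosen at construction time. The shapes below assume the base model and its default
# 120 px / 8 px-patch configuration (15 x 15 = 225 patches per image); the checkpoint path is
# the default and must already exist on disk.
def _example_pretrained_croma(pretrained_path: str = "CROMA_base.pt") -> None:
    model = PretrainedCROMA(pretrained_path=pretrained_path, size="base", modality="both")
    sar = torch.randn(1, 2, 120, 120)  # VV/VH backscatter
    optical = torch.randn(1, 12, 120, 120)  # 12 multispectral bands (cirrus already removed)
    with torch.no_grad():
        out = model(SAR_images=sar, optical_images=optical)
    assert out["SAR_GAP"].shape == (1, 768) and out["optical_GAP"].shape == (1, 768)
    assert out["joint_encodings"].shape == (1, 225, 768)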


def get_2dalibi(num_heads, num_patches):
    # inspired by: https://github.com/ofirpress/attention_with_linear_biases
    points = list(
        itertools.product(range(int(math.sqrt(num_patches))), range(int(math.sqrt(num_patches))))
    )

    def get_slopes(n):
        def get_slopes_power_of_2(n):
            start = 2 ** (-(2 ** -(math.log2(n) - 3)))
            ratio = start
            return [start * ratio**i for i in range(n)]

        if math.log2(n).is_integer():
            return get_slopes_power_of_2(n)
        else:
            closest_power_of_2 = 2 ** math.floor(math.log2(n))
            return (
                get_slopes_power_of_2(closest_power_of_2)
                + get_slopes(2 * closest_power_of_2)[0::2][: n - closest_power_of_2]
            )

    slopes = torch.Tensor(get_slopes(num_heads)).unsqueeze(1)
    idxs = []
    for p1 in points:
        for p2 in points:
            dist = math.sqrt((p1[0] - p2[0]) ** 2 + (p1[1] - p2[1]) ** 2)
            idxs.append(dist * slopes * -1)
    all_bias = torch.cat(idxs, dim=1)
    return all_bias.view(1, num_heads, num_patches, num_patches)
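

# Shape check (sketch): for the default CROMA configuration (16 heads, 225 patches), the 2D
# ALiBi bias is a dense pairwise Euclidean-distance penalty, broadcast over the batch dimension.
def _example_attn_bias() -> torch.Tensor:
    bias = get_2dalibi(num_heads=16, num_patches=225)
    assert bias.shape == (1, 16, 225, 225)
    assert bias.max().item() == 0.0  # zero on the diagonal, increasingly negative with distance
    return bias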


class FFN(nn.Module):
    def __init__(
        self,
        dim,
        mult=4,
        dropout=0.0,
    ):
        super().__init__()
        inner_dim = int(dim * mult)
        self.net = nn.Sequential(
            nn.Linear(dim, inner_dim),  # (BSZ, num_patches, inner_dim)
            nn.GELU(),  # (BSZ, num_patches, inner_dim)
            nn.Dropout(dropout),  # (BSZ, num_patches, inner_dim)
            nn.Linear(inner_dim, dim),  # (BSZ, num_patches, dim)
        )
        self.input_norm = nn.LayerNorm(dim)

    def forward(self, x):
        x = self.input_norm(x)  # (BSZ, num_patches, dim)
        return self.net(x)  # (BSZ, num_patches, dim)


class Attention(nn.Module):
    def __init__(
        self,
        dim,
        num_heads=8,
        dropout=0.0,
    ):
        super().__init__()
        self.num_heads = num_heads
        assert dim % num_heads == 0, "dim must be evenly divisible by num_heads"
        dim_head = int(dim / num_heads)
        self.scale = dim_head**-0.5
        self.to_qkv = nn.Linear(dim, dim * 3, bias=False)
        self.to_out = nn.Linear(dim, dim)
        self.input_norm = nn.LayerNorm(dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, relative_position_bias):
        x = self.input_norm(x)  # (BSZ, num_patches, dim)
        q, k, v = self.to_qkv(x).chunk(3, dim=-1)  # (BSZ, num_patches, dim)
        q, k, v = map(
            lambda t: rearrange(t, "b n (h d) -> b h n d", h=self.num_heads), (q, k, v)
        )  # (BSZ, num_heads, num_patches, dim_head)
        attention_scores = (
            einsum("b h i d, b h j d -> b h i j", q, k) * self.scale
        )  # (BSZ, num_heads, num_patches, num_patches)
        attention_scores = (
            attention_scores + relative_position_bias
        )  # (BSZ, num_heads, num_patches, num_patches)
        attn = attention_scores.softmax(dim=-1)  # (BSZ, num_heads, num_patches, num_patches)
        attn = self.dropout(attn)  # (BSZ, num_heads, num_patches, num_patches)
        out = einsum(
            "b h i j, b h j d -> b h i d", attn, v
        )  # (BSZ, num_heads, num_patches, dim_head)
        out = rearrange(out, "b h n d -> b n (h d)")  # (BSZ, num_patches, dim)
        return self.to_out(out)  # (BSZ, num_patches, dim)


class CrossAttention(nn.Module):
    def __init__(
        self,
        dim,
        num_heads=8,
        dropout=0.0,
    ):
        super().__init__()
        self.num_heads = num_heads
        assert dim % num_heads == 0, "dim must be evenly divisible by num_heads"
        dim_head = int(dim / num_heads)
        self.scale = dim_head**-0.5
        self.to_q = nn.Linear(dim, dim, bias=False)
        self.to_k = nn.Linear(dim, dim, bias=False)
        self.to_v = nn.Linear(dim, dim, bias=False)
        self.to_out = nn.Linear(dim, dim)
        self.input_norm = nn.LayerNorm(dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, context, relative_position_bias):
        x = self.input_norm(x)  # (BSZ, num_patches, dim)
        context = self.input_norm(context)  # (BSZ, num_patches, dim)
        q = self.to_q(x)  # (BSZ, num_patches, dim)
        k = self.to_k(context)  # (BSZ, num_patches, dim)
        v = self.to_v(context)  # (BSZ, num_patches, dim)
        q, k, v = map(
            lambda t: rearrange(t, "b n (h d) -> b h n d", h=self.num_heads), (q, k, v)
        )  # (BSZ, num_heads, num_patches, dim_head)
        attention_scores = (
            einsum("b h i d, b h j d -> b h i j", q, k) * self.scale
        )  # (BSZ, num_heads, num_patches, num_patches)
        attention_scores = (
            attention_scores + relative_position_bias
        )  # (BSZ, num_heads, num_patches, num_patches)
        attn = attention_scores.softmax(dim=-1)  # (BSZ, num_heads, num_patches, num_patches)
        attn = self.dropout(attn)  # (BSZ, num_heads, num_patches, num_patches)
        out = einsum(
            "b h i j, b h j d -> b h i d", attn, v
        )  # (BSZ, num_heads, num_patches, dim_head)
        out = rearrange(out, "b h n d -> b n (h d)")  # (BSZ, num_patches, dim)
        return self.to_out(out)  # (BSZ, num_patches, dim)


class BaseTransformer(nn.Module):
    def __init__(
        self,
        dim,
        depth,
        num_heads=8,
        attn_dropout=0.0,
        ff_dropout=0.0,
        ff_mult=4,
        final_norm=True,
    ):
        super().__init__()
        self.final_norm = final_norm
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(
                nn.ModuleList(
                    [
                        Attention(dim=dim, num_heads=num_heads, dropout=attn_dropout),
                        FFN(dim=dim, mult=ff_mult, dropout=ff_dropout),
                    ]
                )
            )
        if self.final_norm:
            self.norm_out = nn.LayerNorm(dim)

    def forward(self, x, relative_position_bias=False):
        for self_attn, ffn in self.layers:
            x = self_attn(x, relative_position_bias) + x  # (BSZ, num_patches, dim)
            x = ffn(x) + x  # (BSZ, num_patches, dim)
        if self.final_norm:
            return self.norm_out(x)
        else:
            return x


class BaseTransformerCrossAttn(nn.Module):
    def __init__(
        self,
        dim,
        depth,
        num_heads=8,
        attn_dropout=0.0,
        ff_dropout=0.0,
        ff_mult=4,
    ):
        super().__init__()
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(
                nn.ModuleList(
                    [
                        Attention(dim=dim, num_heads=num_heads, dropout=attn_dropout),
                        CrossAttention(dim=dim, num_heads=num_heads, dropout=attn_dropout),
                        FFN(dim=dim, mult=ff_mult, dropout=ff_dropout),
                    ]
                )
            )
        self.norm_out = nn.LayerNorm(dim)

    def forward(self, x, context, relative_position_bias):
        for self_attn, cross_attn, ffn in self.layers:
            x = self_attn(x, relative_position_bias) + x  # (BSZ, num_patches, dim)
            x = cross_attn(x, context, relative_position_bias) + x  # (BSZ, num_patches, dim)
            x = ffn(x) + x  # (BSZ, num_patches, dim)
        x = self.norm_out(x)
        return x  # (BSZ, num_patches, dim)


class ViT(nn.Module):
    def __init__(self, dim, depth, in_channels):
        super().__init__()
        self.depth = depth
        self.in_channels = in_channels
        self.dim = dim
        self.num_heads = 16  # always 16, for base and large models
        self.patch_size = 8  # always 8, for base and large models
        pixels_per_patch = int(self.patch_size * self.patch_size * in_channels)
        self.linear_input = nn.Linear(pixels_per_patch, self.dim)
        self.transformer = BaseTransformer(
            dim=self.dim,
            depth=self.depth,
            num_heads=self.num_heads,
        )

    def forward(self, imgs, attn_bias):
        x = rearrange(
            imgs, "b c (h i) (w j) -> b (h w) (c i j)", i=self.patch_size, j=self.patch_size
        )
        # x is shape -> (bsz, num_patches, self.in_channels * self.patch_size * self.patch_size)
        x = self.linear_input(x)  # (bsz, num_patches, dim)
        x = self.transformer(x, relative_position_bias=attn_bias)
        return x
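

# Patchify sketch (illustrative): how ViT.forward tokenises an image before the transformer.
# With the fixed 8 px patch size, a 120x120 Sentinel-1 chip (2 channels) becomes
# 15 x 15 = 225 tokens of 2 * 8 * 8 = 128 raw pixel values each, which linear_input then
# projects to `dim`.
def _example_vit_tokens() -> torch.Tensor:
    imgs = torch.randn(1, 2, 120, 120)
    tokens = rearrange(imgs, "b c (h i) (w j) -> b (h w) (c i j)", i=8, j=8)
    assert tokens.shape == (1, 225, 128)
    return tokens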