from pathlib import Path
from typing import List
import satlaspretrain_models
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from satlaspretrain_models.utils import Backbone


class SatlasWrapper(nn.Module):
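    """Wraps a SatlasPretrain Swin backbone (base or tiny) for multispectral Sentinel-2 inputs.

    Inputs are channels-last Sentinel-2 images with 13 bands; preprocessing selects and
    reorders 9 bands and resizes them to 512x512. The last backbone feature map is returned
    either globally average-pooled to (bsz, dim) or flattened to (bsz, grid_size**2, dim),
    with optional mean/max pooling over timesteps for time-series inputs.
    """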
def __init__(
self, weights_path: Path, size="base", do_pool=True, temporal_pooling: str = "mean"
):
super().__init__()
if size == "base":
self.dim = 1024
weights = torch.load(
weights_path / "satlas-model-v1-lowres-band.pth", map_location="cpu"
)
self.satlas = satlaspretrain_models.Model(
num_channels=9,
multi_image=False,
backbone=Backbone.SWINB,
fpn=False,
head=None,
num_categories=None,
weights=weights,
)
elif size == "tiny":
self.dim = 768
weights = torch.load(weights_path / "sentinel2_swint_si_ms.pth", map_location="cpu")
self.satlas = satlaspretrain_models.Model(
num_channels=9,
multi_image=False,
backbone=Backbone.SWINT,
fpn=False,
head=None,
num_categories=None,
weights=weights,
)
else:
raise ValueError(f"size must be base or tiny, not {size}")
self.image_resolution = 512
        self.grid_size = 16  # Swin downsamples the 512x512 input by 32x to a 16x16 feature map
self.do_pool = do_pool
if temporal_pooling not in ["mean", "max"]:
raise ValueError(
f"Expected temporal_pooling to be in ['mean', 'max'], got {temporal_pooling}"
)
self.temporal_pooling = temporal_pooling

    def resize(self, images):
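        """Bilinearly resize (bsz, c, h, w) images to image_resolution x image_resolution (512x512)."""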
images = F.interpolate(
images,
size=(self.image_resolution, self.image_resolution),
mode="bilinear",
align_corners=False,
)
return images

    def preprocess(self, images):
        images = rearrange(images, "b h w c -> b c h w")
        assert images.shape[1] == 13, f"Expected 13 S2 bands, got {images.shape[1]}"
        # Band order from: https://github.com/allenai/satlas/blob/main/Normalization.md
        # Select and reorder 9 of the 13 input bands into the order the pretrained model expects.
        images = images[:, (1, 2, 3, 4, 5, 6, 0, 7, 8), :, :]
        return self.resize(images)  # (bsz, 9, 512, 512)

def forward(self, s2=None, s1=None, months=None):
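        """Encode Sentinel-2 imagery with the Satlas backbone.

        Args:
            s2: required; (bsz, h, w, 13) for a single timestep or
                (bsz, h, w, timesteps, 13) for a time series.
            s1: unused by this wrapper.
            months: unused by this wrapper.

        Returns:
            (bsz, dim) features if do_pool, otherwise (bsz, grid_size**2, dim);
            time-series inputs are pooled over timesteps with mean or max.
        """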
if s2 is None:
raise ValueError("S2 can't be None for Satlas")
        # fpn=False, so self.satlas returns the list of backbone feature maps.
        # Expected shapes for base:
        # [[bsz, 128, 128, 128], [bsz, 256, 64, 64], [bsz, 512, 32, 32], [bsz, 1024, 16, 16]]
        # and for tiny:
        # [[bsz, 96, 128, 128], [bsz, 192, 64, 64], [bsz, 384, 32, 32], [bsz, 768, 16, 16]]
if len(s2.shape) == 5:
outputs_l: List[torch.Tensor] = []
for timestep in range(s2.shape[3]):
                image = self.preprocess(s2[:, :, :, timestep])
output = self.satlas(image)
                # self.satlas returns a list of feature maps; the last is
                # (bsz, 1024, 16, 16) for base and (bsz, 768, 16, 16) for tiny
if self.do_pool:
output = output[-1].mean(dim=-1).mean(dim=-1)
else:
output = rearrange(output[-1], "b c h w -> b (h w) c")
outputs_l.append(output)
            outputs_t = torch.stack(outputs_l, dim=-1)  # (bsz, dim, t) if pooled, else (bsz, seq_len, dim, t)
if self.temporal_pooling == "mean":
return outputs_t.mean(dim=-1)
else:
return torch.amax(outputs_t, dim=-1)
else:
            s2 = self.preprocess(s2)
output = self.satlas(s2)
if self.do_pool:
return output[-1].mean(dim=-1).mean(dim=-1) # (bsz, dim)
else:
return rearrange(output[-1], "b c h w -> b (h w) c") # (bsz, seq_len, dim)
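

if __name__ == "__main__":
    # Minimal usage sketch (illustrative, not part of the original module). It assumes the
    # SatlasPretrain checkpoints referenced above have already been downloaded into
    # `checkpoints/`; the path, batch size, and tile size below are arbitrary examples.
    checkpoint_dir = Path("checkpoints")
    wrapper = SatlasWrapper(checkpoint_dir, size="tiny", do_pool=True, temporal_pooling="mean")
    wrapper.eval()

    # Single timestep: (bsz, h, w, 13 S2 bands); the time series adds a timestep dimension.
    s2_single = torch.rand(2, 120, 120, 13)
    s2_series = torch.rand(2, 120, 120, 3, 13)

    with torch.no_grad():
        single_feats = wrapper(s2=s2_single)  # (2, 768) for size="tiny" with do_pool=True
        series_feats = wrapper(s2=s2_series)  # (2, 768) after mean pooling over 3 timesteps
    print(single_feats.shape, series_feats.shape)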