Spaces:
Sleeping
Sleeping
| import unittest | |
| from functools import partial | |
| from pathlib import Path | |
| import torch | |
| from torch.utils.data import DataLoader | |
| from src.collate_fns import galileo_collate_fn | |
| from src.data import Dataset | |
| from src.galileo import Decoder, Encoder | |
| from src.loss import mse_loss | |
| from src.utils import device | |
| DATA_FOLDER = Path(__file__).parents[1] / "data/tifs" | |
| class TestEndtoEnd(unittest.TestCase): | |
| def test_end_to_end(self): | |
| self._test_end_to_end() | |
| def _test_end_to_end(self): | |
| embedding_size = 32 | |
| dataset = Dataset(DATA_FOLDER, download=False, h5py_folder=None) | |
| dataloader = DataLoader( | |
| dataset, | |
| batch_size=1, | |
| shuffle=True, | |
| num_workers=0, | |
| collate_fn=partial( | |
| galileo_collate_fn, | |
| patch_sizes=[1, 2, 3, 4, 5, 6, 7, 8], | |
| shape_time_combinations=[ | |
| {"size": 4, "timesteps": 12}, | |
| {"size": 5, "timesteps": 6}, | |
| {"size": 6, "timesteps": 4}, | |
| {"size": 7, "timesteps": 3}, | |
| {"size": 9, "timesteps": 3}, | |
| {"size": 12, "timesteps": 3}, | |
| ], | |
| st_encode_ratio=0.25, | |
| st_decode_ratio=0.25, | |
| random_encode_ratio=0.25, | |
| random_decode_ratio=0.25, | |
| random_masking="half", | |
| ), | |
| pin_memory=True, | |
| ) | |
| encoder = Encoder(embedding_size=embedding_size, num_heads=1).to(device) | |
| predictor = Decoder( | |
| encoder_embedding_size=embedding_size, | |
| decoder_embedding_size=embedding_size, | |
| num_heads=1, | |
| learnable_channel_embeddings=False, | |
| ).to(device) | |
| param_groups = [{"params": encoder.parameters()}, {"params": predictor.parameters()}] | |
| optimizer = torch.optim.AdamW(param_groups, lr=3e-4) # type: ignore | |
| # let's just consider one of the augmentations | |
| for _, bs in enumerate(dataloader): | |
| b = bs[0] | |
| for x in b: | |
| if isinstance(x, torch.Tensor): | |
| self.assertFalse(torch.isnan(x).any()) | |
| b = [t.to(device) if isinstance(t, torch.Tensor) else t for t in b] | |
| ( | |
| s_t_x, | |
| sp_x, | |
| t_x, | |
| st_x, | |
| s_t_m, | |
| sp_m, | |
| t_m, | |
| st_m, | |
| months, | |
| patch_size, | |
| ) = b | |
| # no autocast since its poorly supported on CPU | |
| (p_s_t, p_sp, p_t, p_st) = predictor( | |
| *encoder( | |
| s_t_x=s_t_x.float(), | |
| sp_x=sp_x.float(), | |
| t_x=t_x.float(), | |
| st_x=st_x.float(), | |
| s_t_m=s_t_m.int(), | |
| sp_m=sp_m.int(), | |
| t_m=t_m.int(), | |
| st_m=st_m.int(), | |
| months=months.long(), | |
| patch_size=patch_size, | |
| ), | |
| patch_size=patch_size, | |
| ) | |
| with torch.no_grad(): | |
| t_s_t, t_sp, t_t, t_st, _, _, _, _ = encoder.apply_linear_projection( | |
| s_t_x.float(), | |
| sp_x.float(), | |
| t_x.float(), | |
| st_x.float(), | |
| ~(s_t_m == 2).int(), # we want 0s where the mask == 2 | |
| ~(sp_m == 2).int(), | |
| ~(t_m == 2).int(), | |
| ~(st_m == 2).int(), | |
| patch_size, | |
| ) | |
| t_s_t = encoder.blocks[0].norm1(t_s_t) | |
| t_sp = encoder.blocks[0].norm1(t_sp) | |
| t_sp = encoder.blocks[0].norm1(t_sp) | |
| t_st = encoder.blocks[0].norm1(t_st) | |
| # commenting out because this fails on the github runner. It doesn't fail locally | |
| # or cause problems when running experiments. | |
| # self.assertFalse(torch.isnan(p_s_t[s_t_m[:, 0::patch_size, 0::patch_size] == 2]).any()) | |
| # self.assertFalse(torch.isnan(p_sp[sp_m[:, 0::patch_size, 0::patch_size] == 2]).any()) | |
| # self.assertFalse(torch.isnan(p_t[t_m == 2]).any()) | |
| # self.assertFalse(torch.isnan(p_st[st_m == 2]).any()) | |
| # self.assertFalse(torch.isnan(t_s_t[s_t_m[:, 0::patch_size, 0::patch_size] == 2]).any()) | |
| # self.assertFalse(torch.isnan(t_sp[sp_m[:, 0::patch_size, 0::patch_size] == 2]).any()) | |
| # self.assertFalse(torch.isnan(t_t[t_m == 2]).any()) | |
| # self.assertFalse(torch.isnan(t_st[st_m == 2]).any()) | |
| self.assertTrue( | |
| len( | |
| torch.concat( | |
| [ | |
| p_s_t[s_t_m[:, 0::patch_size, 0::patch_size] == 2], | |
| p_sp[sp_m[:, 0::patch_size, 0::patch_size] == 2], | |
| p_t[t_m == 2], | |
| p_st[st_m == 2], | |
| ] | |
| ) | |
| > 0 | |
| ) | |
| ) | |
| loss = mse_loss( | |
| t_s_t, | |
| t_sp, | |
| t_t, | |
| t_st, | |
| p_s_t, | |
| p_sp, | |
| p_t, | |
| p_st, | |
| s_t_m[:, 0::patch_size, 0::patch_size], | |
| sp_m[:, 0::patch_size, 0::patch_size], | |
| t_m, | |
| st_m, | |
| ) | |
| # this also only fails on the runner | |
| # self.assertFalse(torch.isnan(loss).any()) | |
| loss.backward() | |
| optimizer.step() | |
| # check the channel embeddings in the decoder didn't change | |
| self.assertTrue( | |
| torch.equal( | |
| predictor.s_t_channel_embed, torch.zeros_like(predictor.s_t_channel_embed) | |
| ) | |
| ) | |
| self.assertTrue( | |
| torch.equal( | |
| predictor.sp_channel_embed, torch.zeros_like(predictor.sp_channel_embed) | |
| ) | |
| ) | |
| self.assertTrue( | |
| torch.equal(predictor.t_channel_embed, torch.zeros_like(predictor.t_channel_embed)) | |
| ) | |
| self.assertTrue( | |
| torch.equal( | |
| predictor.st_channel_embed, torch.zeros_like(predictor.st_channel_embed) | |
| ) | |
| ) | |