Spaces:
Sleeping
Sleeping
| import unittest | |
| import torch | |
| from src.data.dataset import ( | |
| SPACE_BAND_GROUPS_IDX, | |
| SPACE_TIME_BANDS_GROUPS_IDX, | |
| STATIC_BAND_GROUPS_IDX, | |
| TIME_BAND_GROUPS_IDX, | |
| ) | |
| from src.loss import mae_loss | |
class TestLoss(unittest.TestCase):
    """Smoke tests for ``src.loss.mae_loss``."""

    def test_mae_loss(self):
        """``mae_loss`` returns a finite value for random, shape-consistent inputs.

        Builds random predictions for the four token families (space-time,
        space-only, time-only, static), matching random pixel-level targets,
        and masks filled with ``2``, then checks that the loss is finite.
        """
        # Local fixed-seed generator: reproducible inputs without mutating
        # torch's global RNG state (which other tests may rely on).
        gen = torch.Generator().manual_seed(0)

        b, t_h, t_w, t, patch_size = 16, 4, 4, 3, 2
        # Targets are built at pixel resolution: t_h x t_w tokens of
        # patch_size x patch_size pixels each.
        pixel_h, pixel_w = t_h * patch_size, t_w * patch_size
        max_patch_size = 8

        group_families = (
            SPACE_TIME_BANDS_GROUPS_IDX,
            TIME_BAND_GROUPS_IDX,
            SPACE_BAND_GROUPS_IDX,
            STATIC_BAND_GROUPS_IDX,
        )
        # Longest band group across all four families; every prediction
        # token is sized for this group length at the largest patch size.
        max_group_length = max(
            len(bands) for groups in group_families for bands in groups.values()
        )
        token_dim = max_group_length * max_patch_size**2

        # Random predictions, one token of width `token_dim` per group.
        p_s_t = torch.randn(
            (b, t_h, t_w, t, len(SPACE_TIME_BANDS_GROUPS_IDX), token_dim),
            generator=gen,
        )
        p_sp = torch.randn(
            (b, t_h, t_w, len(SPACE_BAND_GROUPS_IDX), token_dim),
            generator=gen,
        )
        p_t = torch.randn(
            (b, t, len(TIME_BAND_GROUPS_IDX), token_dim),
            generator=gen,
        )
        p_st = torch.randn(
            (b, len(STATIC_BAND_GROUPS_IDX), token_dim),
            generator=gen,
        )

        # Random targets; the last dim is the total band count of the family.
        s_t_x = torch.randn(
            (
                b,
                pixel_h,
                pixel_w,
                t,
                sum(len(bands) for bands in SPACE_TIME_BANDS_GROUPS_IDX.values()),
            ),
            generator=gen,
        )
        sp_x = torch.randn(
            (
                b,
                pixel_h,
                pixel_w,
                sum(len(bands) for bands in SPACE_BAND_GROUPS_IDX.values()),
            ),
            generator=gen,
        )
        t_x = torch.randn(
            (b, t, sum(len(bands) for bands in TIME_BAND_GROUPS_IDX.values())),
            generator=gen,
        )
        st_x = torch.randn(
            (b, sum(len(bands) for bands in STATIC_BAND_GROUPS_IDX.values())),
            generator=gen,
        )

        # Every mask entry is 2. Presumably this marks all groups as loss
        # targets -- NOTE(review): confirm the mask encoding in src.loss.
        s_t_m = torch.full((b, pixel_h, pixel_w, t, len(SPACE_TIME_BANDS_GROUPS_IDX)), 2.0)
        sp_m = torch.full((b, pixel_h, pixel_w, len(SPACE_BAND_GROUPS_IDX)), 2.0)
        t_m = torch.full((b, t, len(TIME_BAND_GROUPS_IDX)), 2.0)
        st_m = torch.full((b, len(STATIC_BAND_GROUPS_IDX)), 2.0)

        loss = mae_loss(
            p_s_t,
            p_sp,
            p_t,
            p_st,
            s_t_x,
            sp_x,
            t_x,
            st_x,
            s_t_m,
            sp_m,
            t_m,
            st_m,
            patch_size,
            max_patch_size,
        )
        # isfinite rejects both NaN and +/-inf; .all().item() makes the
        # tensor -> bool conversion explicit.
        self.assertTrue(
            torch.isfinite(loss).all().item(),
            msg=f"mae_loss returned a non-finite value: {loss}",
        )