Spaces:
Sleeping
Sleeping
| import unittest | |
| import numpy as np | |
| from src.data.dataset import ( | |
| SPACE_BAND_GROUPS_IDX, | |
| SPACE_BANDS, | |
| SPACE_TIME_BANDS, | |
| SPACE_TIME_BANDS_GROUPS_IDX, | |
| STATIC_BAND_GROUPS_IDX, | |
| STATIC_BANDS, | |
| TIME_BAND_GROUPS_IDX, | |
| TIME_BANDS, | |
| Normalizer, | |
| ) | |
| from src.eval.cropharvest.cropharvest_eval import BANDS, BinaryCropHarvestEval | |
class TestCropHarvest(unittest.TestCase):
    """Checks on the CropHarvest -> Galileo array conversion used by the eval."""

    def test_to_galileo_arrays(self):
        """Convert an all-ones batch and verify output shapes and selected mask values.

        The conversion is exercised for every combination of
        ``ignore_band_groups`` in ``(["S1"], None)`` and ``include_latlons`` in
        ``(True, False)``. Each combination runs in its own ``subTest`` so a
        failure reports the offending parameters and does not hide the others.
        """

        # Declared once, outside the loops: the subclass captures nothing from
        # the loop variables, everything it needs arrives through __init__.
        class BinaryCropHarvestEvalNoDownload(BinaryCropHarvestEval):
            """Test double that sets attributes directly instead of calling the parent __init__.

            NOTE(review): judging by the name, bypassing the parent constructor
            avoids a dataset download -- confirm against BinaryCropHarvestEval.
            """

            def __init__(self, ignore_band_groups, include_latlons: bool = True):
                self.include_latlons = include_latlons
                self.normalizer = Normalizer(std=False)
                self.ignore_s_t_band_groups = self.indices_of_ignored(
                    SPACE_TIME_BANDS_GROUPS_IDX, ignore_band_groups
                )
                self.ignore_sp_band_groups = self.indices_of_ignored(
                    SPACE_BAND_GROUPS_IDX, ignore_band_groups
                )
                self.ignore_t_band_groups = self.indices_of_ignored(
                    TIME_BAND_GROUPS_IDX, ignore_band_groups
                )
                self.ignore_st_band_groups = self.indices_of_ignored(
                    STATIC_BAND_GROUPS_IDX, ignore_band_groups
                )

        # Positions of the band groups whose mask channels are inspected below.
        s1_idx = list(SPACE_TIME_BANDS_GROUPS_IDX.keys()).index("S1")
        srtm_idx = list(SPACE_BAND_GROUPS_IDX.keys()).index("SRTM")
        dw_idx = list(SPACE_BAND_GROUPS_IDX.keys()).index("DW")

        b, t = 8, 12  # batch size and number of timesteps of the dummy input

        for ignore_band_groups in [["S1"], None]:
            for include_latlons in [True, False]:
                with self.subTest(
                    ignore_band_groups=ignore_band_groups,
                    include_latlons=include_latlons,
                ):
                    # Named `evaluator` rather than `eval` to avoid shadowing the builtin.
                    evaluator = BinaryCropHarvestEvalNoDownload(
                        ignore_band_groups, include_latlons
                    )
                    array = np.ones((b, t, len(BANDS)))
                    latlons = np.ones((b, 2))
                    (
                        s_t_x,
                        sp_x,
                        t_x,
                        st_x,
                        s_t_m,
                        sp_m,
                        t_m,
                        st_m,
                        months,
                    ) = evaluator.cropharvest_array_to_normalized_galileo(
                        array, latlons, start_month=1
                    )

                    # Space-time tensors: (batch, 1, 1, time, bands / band groups).
                    self.assertEqual(s_t_x.shape, (b, 1, 1, t, len(SPACE_TIME_BANDS)))
                    self.assertEqual(
                        s_t_m.shape, (b, 1, 1, t, len(SPACE_TIME_BANDS_GROUPS_IDX))
                    )
                    if ignore_band_groups is not None:
                        # check s1 got masked
                        self.assertTrue((s_t_m[:, :, :, :, s1_idx] == 1).all())

                    # Space tensors: the SRTM mask must be 0 and the DW mask 1.
                    self.assertEqual(sp_x.shape, (b, 1, 1, len(SPACE_BANDS)))
                    self.assertEqual(sp_m.shape, (b, 1, 1, len(SPACE_BAND_GROUPS_IDX)))
                    self.assertTrue((sp_m[:, :, :, srtm_idx] == 0).all())
                    self.assertTrue((sp_m[:, :, :, dw_idx] == 1).all())

                    # Time tensors.
                    self.assertEqual(t_x.shape, (b, t, len(TIME_BANDS)))
                    self.assertEqual(t_m.shape, (b, t, len(TIME_BAND_GROUPS_IDX)))

                    # Static tensors.
                    self.assertEqual(st_x.shape, (b, len(STATIC_BANDS)))
                    self.assertEqual(st_m.shape, (b, len(STATIC_BAND_GROUPS_IDX)))

                    self.assertEqual(months.shape, (b, t))