import logging
import tempfile
from pathlib import Path
from typing import Dict, List, Optional, Sequence, Tuple, Union

import numpy as np
import torch
from einops import repeat
from sklearn.base import BaseEstimator, clone
from sklearn.linear_model import LogisticRegression
from sklearn.utils import shuffle
from torch.utils.data import DataLoader, TensorDataset

from src.eval.cropharvest.columns import NullableColumns, RequiredColumns
from src.eval.cropharvest.cropharvest_eval import Hyperparams
from src.eval.cropharvest.datasets import CropHarvest, Task, TestInstance
from src.eval.cropharvest.datasets import CropHarvestLabels as OrgCropHarvestLabels
from src.eval.cropharvest.utils import NoDataForBoundingBoxError, memoized
from src.utils import DEFAULT_SEED, data_dir, device

from .single_file_presto import NUM_DYNAMIC_WORLD_CLASSES, PRESTO_ADD_BY, PRESTO_DIV_BY, Encoder

logger = logging.getLogger("__main__")

cropharvest_data_dir = data_dir / "cropharvest_data"


class PrestoNormalizer:
    """Normalizes inputs using Presto's per-band shift and scale values,
    with the scale optionally widened by ``std_multiplier``."""

    def __init__(self, std_multiplier: float = 1):
        self.std_multiplier = std_multiplier
        # add by -> subtract by
        self.shift_values = np.array(PRESTO_ADD_BY) * -1
        self.div_values = np.array(PRESTO_DIV_BY) * std_multiplier

    @staticmethod
    def _normalize(x: np.ndarray, shift_values: np.ndarray, div_values: np.ndarray) -> np.ndarray:
        return (x - shift_values) / div_values

    def __call__(self, x: np.ndarray) -> np.ndarray:
        return self._normalize(x, self.shift_values, self.div_values)
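
# A minimal usage sketch (illustrative, not from the original file): `raw_array`
# is a hypothetical (batch, timesteps, bands) array whose band order is assumed
# to match PRESTO_ADD_BY / PRESTO_DIV_BY.
#
#   normalizer = PrestoNormalizer(std_multiplier=2.0)
#   normalized = normalizer(raw_array)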


class CropHarvestLabels(OrgCropHarvestLabels):
    def construct_fao_classification_labels(
        self, task: Task, filter_test: bool = True
    ) -> List[Tuple[Path, int]]:
        gpdf = self.as_geojson()
        if filter_test:
            gpdf = gpdf[gpdf[RequiredColumns.IS_TEST] == False]  # noqa
        if task.bounding_box is not None:
            gpdf = self.filter_geojson(
                gpdf, task.bounding_box, task.include_externally_contributed_labels
            )

        # This should probably be a required column since it has no
        # None values (and shouldn't have any)
        gpdf = gpdf[~gpdf[NullableColumns.CLASSIFICATION_LABEL].isnull()]

        if len(gpdf) == 0:
            raise NoDataForBoundingBoxError

        ys = gpdf[NullableColumns.CLASSIFICATION_LABEL]
        paths = self._dataframe_to_paths(gpdf)
        return [(path, y) for path, y in zip(paths, ys) if path.exists()]


# memoized caches the result so repeated evaluators share the same dataset
# objects (the decorator is the likely use of the otherwise-unused import)
@memoized
def get_eval_datasets():
    return CropHarvest.create_benchmark_datasets(
        root=cropharvest_data_dir, balance_negative_crops=False, normalize=False
    )


def download_cropharvest_data(root_name: str = ""):
    root = Path(root_name) if root_name != "" else cropharvest_data_dir
    if not root.exists():
        root.mkdir()
    CropHarvest(root, download=True)
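
# Hedged sketch of how the two helpers above compose (assumes network access
# for the download; `get_eval_datasets` reads from `cropharvest_data_dir`):
#
#   download_cropharvest_data()
#   datasets = get_eval_datasets()
#   togo_dataset = [d for d in datasets if "Togo" in d.id][0]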


class BinaryCropHarvestEval:
    start_month = 1
    num_outputs = 1
    country_to_sizes: Dict[str, List] = {
        "Kenya": [20, 32, 64, 96, 128, 160, 192, 224, 256, None],
        "Togo": [20, 50, 126, 254, 382, 508, 636, 764, 892, 1020, 1148, None],
    }
    all_classification_sklearn_models = ["LogisticRegression"]

    def __init__(
        self,
        country: str,
        normalizer: PrestoNormalizer,
        num_timesteps: Optional[int] = None,
        sample_size: Optional[int] = None,
        seed: int = DEFAULT_SEED,
        include_latlons: bool = True,
        eval_mode: str = "test",
    ):
        if eval_mode == "val":
            assert country in self.country_to_sizes
        self.eval_mode = eval_mode

        suffix = f"_{sample_size}" if sample_size else ""
        suffix = f"{suffix}_{num_timesteps}" if num_timesteps is not None else suffix
        self.include_latlons = include_latlons
        self.name = f"CropHarvest_{country}{suffix}{'_latlons' if include_latlons else ''}"
        self.seed = seed

        download_cropharvest_data()
        evaluation_datasets = get_eval_datasets()
        evaluation_datasets = [d for d in evaluation_datasets if country in d.id]
        assert len(evaluation_datasets) == 1
        self.dataset: CropHarvest = evaluation_datasets[0]
        assert self.dataset.task.normalize is False

        self.num_timesteps = num_timesteps
        self.sample_size = sample_size
        self.normalize = normalizer

    def truncate_timesteps(self, x):
        # truncate the timestep dimension to self.num_timesteps, if it is set;
        # all call sites pass a single array, so this must be an instance method
        if (self.num_timesteps is None) or (x is None):
            return x
        return x[:, : self.num_timesteps]

    @staticmethod
    def _mask_to_batch_tensor(
        mask: Optional[np.ndarray], batch_size: int
    ) -> Optional[torch.Tensor]:
        # broadcast a (timestep, channel) mask across the batch dimension
        if mask is not None:
            return repeat(torch.from_numpy(mask).to(device), "t c -> b t c", b=batch_size).float()
        return None

    def _evaluate_model(
        self,
        pretrained_model: Encoder,
        sklearn_model: BaseEstimator,
    ) -> Dict:
        pretrained_model.eval()

        with tempfile.TemporaryDirectory() as results_dir:
            for test_id, test_instance in self.dataset.test_data(max_size=10000):
                savepath = Path(results_dir) / f"{test_id}.nc"

                test_x = self.truncate_timesteps(
                    torch.from_numpy(self.normalize(test_instance.x)).to(device).float()  # type: ignore
                )
                # mypy fails with these lines uncommented, but this is how we
                # pass the other values to the model
                test_latlons_np = np.stack([test_instance.lats, test_instance.lons], axis=-1)
                test_latlon = torch.from_numpy(test_latlons_np).to(device).float()
                # mask out Dynamic World by setting every timestep to the
                # "missing" class index
                test_dw = self.truncate_timesteps(
                    torch.ones_like(test_x[:, :, 0]).to(device).long() * NUM_DYNAMIC_WORLD_CLASSES
                )
                batch_mask = self.truncate_timesteps(
                    self._mask_to_batch_tensor(None, test_x.shape[0])
                )

                # no_grad is required here: calling .numpy() on a tensor that
                # requires grad raises a RuntimeError
                with torch.no_grad():
                    encodings = (
                        pretrained_model(
                            test_x,
                            dynamic_world=test_dw,
                            mask=batch_mask,
                            latlons=test_latlon,
                            month=self.start_month,
                        )
                        .cpu()
                        .numpy()
                    )
                preds = sklearn_model.predict_proba(encodings)[:, 1]

                ds = test_instance.to_xarray(preds)
                ds.to_netcdf(savepath)

            all_nc_files = list(Path(results_dir).glob("*.nc"))
            combined_instance, combined_preds = TestInstance.load_from_nc(all_nc_files)
            combined_results = combined_instance.evaluate_predictions(combined_preds)

        prefix = sklearn_model.__class__.__name__
        return {f"{self.name}: {prefix}_{key}": val for key, val in combined_results.items()}

    def finetune_sklearn_model(
        self,
        dl: DataLoader,
        pretrained_model: Encoder,
        models: List[str] = ["LogisticRegression"],
    ) -> Union[Sequence[BaseEstimator], Dict]:
        for model_mode in models:
            assert model_mode in ["LogisticRegression"]
        pretrained_model.eval()

        encoding_list, target_list = [], []
        for x, y, dw, latlons, month in dl:
            x, dw, latlons, y, month = [t.to(device) for t in (x, dw, latlons, y, month)]
            batch_mask = self._mask_to_batch_tensor(None, x.shape[0])
            target_list.append(y.cpu().numpy())
            with torch.no_grad():
                encodings = (
                    pretrained_model(
                        x, dynamic_world=dw, mask=batch_mask, latlons=latlons, month=month
                    )
                    .cpu()
                    .numpy()
                )
            encoding_list.append(encodings)
        encodings_np = np.concatenate(encoding_list)
        targets = np.concatenate(target_list)
        if len(targets.shape) == 2 and targets.shape[1] == 1:
            targets = targets.ravel()

        fit_models = []
        model_dict = {
            "LogisticRegression": LogisticRegression(
                class_weight="balanced", max_iter=1000, random_state=self.seed
            )
        }
        for model in models:
            fit_models.append(clone(model_dict[model]).fit(encodings_np, targets))
        return fit_models

    @staticmethod
    def random_subset(
        array: np.ndarray, latlons: np.ndarray, labels: np.ndarray, fraction: Optional[float]
    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        # shuffle the three arrays in unison, keeping `fraction` of the rows
        # (or all of them if fraction is None)
        if fraction is not None:
            num_samples = int(array.shape[0] * fraction)
        else:
            num_samples = array.shape[0]
        return shuffle(array, latlons, labels, random_state=DEFAULT_SEED, n_samples=num_samples)

    def evaluate_model_on_task(
        self,
        pretrained_model: Encoder,
        model_modes: Optional[List[str]] = None,
        fraction: Optional[float] = None,
    ) -> Dict:
        if model_modes is None:
            model_modes = self.all_classification_sklearn_models
        for model_mode in model_modes:
            assert model_mode in self.all_classification_sklearn_models

        results_dict = {}
        if len(model_modes) > 0:
            array, latlons, y = self.dataset.as_array(num_samples=self.sample_size)
            array, latlons, y = self.random_subset(array, latlons, y, fraction=fraction)
            dw = np.ones_like(array[:, :, 0]) * NUM_DYNAMIC_WORLD_CLASSES
            month = np.array([self.start_month] * array.shape[0])
            dl = DataLoader(
                TensorDataset(
                    torch.from_numpy(self.truncate_timesteps(self.normalize(array))).float(),
                    torch.from_numpy(y).long(),
                    torch.from_numpy(self.truncate_timesteps(dw)).long(),
                    torch.from_numpy(latlons).float(),
                    torch.from_numpy(month).long(),
                ),
                batch_size=Hyperparams.batch_size,
                shuffle=False,
                num_workers=Hyperparams.num_workers,
            )
            sklearn_models = self.finetune_sklearn_model(
                dl,
                pretrained_model,
                models=model_modes,
            )
            for sklearn_model in sklearn_models:
                results_dict.update(
                    self._evaluate_model(
                        pretrained_model=pretrained_model, sklearn_model=sklearn_model
                    )
                )
        return results_dict
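

# Hedged end-to-end sketch (not part of the original file). `load_encoder` is
# a hypothetical helper standing in for however the pretrained
# single_file_presto.Encoder is constructed and loaded:
#
#   normalizer = PrestoNormalizer()
#   evaluator = BinaryCropHarvestEval(country="Togo", normalizer=normalizer)
#   encoder = load_encoder().to(device)  # hypothetical
#   results = evaluator.evaluate_model_on_task(encoder, fraction=0.1)
#   logger.info(results)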