Spaces:
Sleeping
Sleeping
| import json | |
| from pathlib import Path | |
| import pandas as pd | |
| import rioxarray | |
| import torch | |
| from einops import rearrange | |
| from torch.utils.data import Dataset | |
| from tqdm import tqdm | |
| from src.utils import data_dir | |
| from ..preprocess import normalize_bands | |
# Root directory holding the Sen1Floods11 rasters and cached split tensors.
flood_folder = data_dir / "sen1floods"
class Sen1Floods11Processor:
    """Load Sen1Floods11 hand-labelled scenes and chip them into small tiles.

    Each scene is a 512x512 raster triple (Sentinel-1, Sentinel-2, label).
    ``__getitem__`` reads one scene from disk and returns parallel lists of
    64x64 tiles that contain at least one labelled (non-zero) pixel.
    """

    input_hw = 512
    output_tile_size = 64
    s1_bands = ("VV", "VH")
    s2_bands = ("B1", "B2", "B3", "B4", "B5", "B6", "B7", "B8", "B8A", "B9", "B10", "B11", "B12")

    def __init__(self, folder: Path, split_path: Path):
        """Collect the label rasters belonging to one split.

        Args:
            folder: Root folder containing ``LabelHand/``, ``s1/`` and ``s2/``.
            split_path: CSV (no header) whose second column lists the label
                filenames included in this split.
        """
        # Use a set for O(1) membership tests while filtering the glob results.
        split_labelnames = set(pd.read_csv(split_path, header=None)[1].tolist())
        self.all_labels = [
            label
            for label in folder.glob("LabelHand/*.tif")
            if label.name in split_labelnames
        ]

    def __len__(self):
        return len(self.all_labels)

    # Fix: the method takes ``cls`` but had no @classmethod decorator.
    @classmethod
    def split_and_filter_tensors(cls, s1, s2, labels):
        """Split S1/S2/label tensors into tiles and keep only labelled ones.

        Args:
            s1: Tensor of shape (len(cls.s1_bands), input_hw, input_hw).
            s2: Tensor of shape (len(cls.s2_bands), input_hw, input_hw).
            labels: Tensor of shape (1, input_hw, input_hw).

        Returns:
            Three parallel lists (s1_tiles, s2_tiles, label_tiles); a tile is
            kept only when its label patch has at least one value > 0.
        """
        assert s1.shape == (
            len(cls.s1_bands),
            cls.input_hw,
            cls.input_hw,
        ), (
            f"s1 tensor must be of shape ({len(cls.s1_bands)}, {cls.input_hw}, {cls.input_hw}), "
            f"got {s1.shape}"
        )
        assert s2.shape == (
            len(cls.s2_bands),
            cls.input_hw,
            cls.input_hw,
        ), f"s2 tensor must be of shape ({len(cls.s2_bands)}, {cls.input_hw}, {cls.input_hw})"
        assert labels.shape == (
            1,
            cls.input_hw,
            cls.input_hw,
        ), f"labels tensor must be of shape (1, {cls.input_hw}, {cls.input_hw})"
        tile_size = cls.output_tile_size
        s1_list, s2_list, labels_list = [], [], []
        num_tiles_per_dim = cls.input_hw // cls.output_tile_size
        for i in range(num_tiles_per_dim):
            for j in range(num_tiles_per_dim):
                rows = slice(i * tile_size, (i + 1) * tile_size)
                cols = slice(j * tile_size, (j + 1) * tile_size)
                label_tile = labels[:, rows, cols]
                # Keep only tiles that actually contain labelled pixels.
                if torch.any(label_tile > 0):
                    s1_list.append(s1[:, rows, cols])
                    s2_list.append(s2[:, rows, cols])
                    labels_list.append(label_tile)
        return s1_list, s2_list, labels_list

    # Fix: without a decorator (and no ``self``), calling ``self.label_to(path,
    # "s1")`` bound ``self`` to ``label`` and ``path`` to ``to``, breaking the
    # path construction. @staticmethod restores the intended call signature.
    @staticmethod
    def label_to(label: Path, to: str = "s1"):
        """Map a LabelHand raster path to its sibling S1 or S2 raster path.

        Raises:
            ValueError: If ``to`` is neither "s1" nor "s2".
        """
        sen_root = label.parents[1]
        location, tile_id, _ = label.stem.split("_")
        if to == "s1":
            return sen_root / f"s1/{location}_{tile_id}_S1Hand.tif"
        elif to == "s2":
            return sen_root / f"s2/{location}_{tile_id}_S2Hand.tif"
        else:
            raise ValueError(f"Expected `to` to be s1 or s2, got {to}")

    def __getitem__(self, idx: int):
        """Load one scene (label, S1, S2) from disk and return filtered tiles."""
        labels_path = self.all_labels[idx]
        with rioxarray.open_rasterio(labels_path) as ds:  # type: ignore
            labels = torch.from_numpy(ds.values)  # type: ignore
        with rioxarray.open_rasterio(self.label_to(labels_path, "s1")) as ds:  # type: ignore
            s1 = torch.from_numpy(ds.values)  # type: ignore
        with rioxarray.open_rasterio(self.label_to(labels_path, "s2")) as ds:  # type: ignore
            s2 = torch.from_numpy(ds.values)  # type: ignore
        return self.split_and_filter_tensors(s1, s2, labels)
def get_sen1floods11(split_name: str = "flood_bolivia_data.csv"):
    """Chip every scene in a split and cache the stacked tiles to disk.

    Reads the split CSV under ``flood_folder``, processes each scene with
    ``Sen1Floods11Processor``, then saves the stacked s1/s2/label tensors to
    ``<split stem>.pt`` in the same folder.
    """
    split_path = flood_folder / split_name
    processor = Sen1Floods11Processor(folder=flood_folder, split_path=split_path)
    s1_tiles, s2_tiles, label_tiles = [], [], []
    for sample_idx in tqdm(range(len(processor))):
        s1_batch, s2_batch, label_batch = processor[sample_idx]
        s1_tiles.extend(s1_batch)
        s2_tiles.extend(s2_batch)
        label_tiles.extend(label_batch)
    torch.save(
        obj={
            "s1": torch.stack(s1_tiles),
            "labels": torch.stack(label_tiles),
            "s2": torch.stack(s2_tiles),
        },
        f=flood_folder / f"{split_path.stem}.pt",
    )
def remove_nan(s1, target):
    """Drop samples whose S1 tile contains any NaN or +/-inf value.

    Args:
        s1: Tensor of shape (N, H, W, C).
        target: Tensor indexed along dim 0 in parallel with ``s1``.

    Returns:
        The filtered ``(s1, target)`` pair with non-finite samples removed
        from both. Unlike the previous loop-and-stack implementation, this
        returns empty tensors (instead of raising) when every sample is
        filtered out.
    """
    # torch.isfinite is False for NaN, +inf and -inf, covering both of the
    # original isnan/isinf checks in a single vectorized pass per sample.
    keep = torch.isfinite(s1.reshape(s1.shape[0], -1)).all(dim=1)
    return s1[keep], target[keep]
class Sen1Floods11Dataset(Dataset):
    """Torch dataset serving normalized Sentinel-1 flood-segmentation tiles."""

    def __init__(
        self,
        path_to_splits: Path,
        split: str,
        norm_operation,
        augmentation,
        partition,
        mode: str = "s1",  # not sure if we would ever want s2?
    ):
        config_path = Path(__file__).parents[0] / Path("configs") / Path("sen1floods11.json")
        with config_path.open("r") as f:
            config = json.load(f)
        assert split in ["train", "val", "valid", "test", "bolivia"]
        # Cached files are named with "valid", so accept "val" as an alias.
        split = "valid" if split == "val" else split
        self.band_info = config["band_info"]["s1"]
        self.split = split
        self.augmentation = augmentation
        self.norm_operation = norm_operation
        cached = torch.load(path_to_splits / f"flood_{split}_data.pt")
        # Move to channels-last layout: (N, 2, 64, 64) -> (N, 64, 64, 2).
        self.s1 = rearrange(cached["s1"], "n c h w -> n h w c")
        self.labels = cached["labels"]
        # should we remove the tile or impute the pixel?
        self.s1, self.labels = remove_nan(self.s1, self.labels)
        if (partition != "default") and (split == "train"):
            with open(path_to_splits / f"{partition}_partition.json", "r") as json_file:
                subset_indices = json.load(json_file)
            self.s1 = self.s1[subset_indices]
            self.labels = self.labels[subset_indices]
        if mode != "s1":
            raise ValueError(f"Modes other than s1 not yet supported, got {mode}")

    def __len__(self):
        return self.s1.shape[0]

    def __getitem__(self, idx):
        raw = self.s1[idx]
        label = self.labels[idx][0]
        normalized = normalize_bands(raw.numpy(), self.norm_operation, self.band_info)
        image = torch.tensor(normalized)
        image, label = self.augmentation.apply(image, label, "seg")
        return {"s1": image, "target": label.long()}