NASA-Galileo / src /eval /datasets /mados_dataset.py
openfree's picture
Deploy from GitHub repository
b20c769 verified
import json
import os
from pathlib import Path
import numpy as np
import torch
import torch.multiprocessing
import torch.nn.functional as F
from PIL import Image
from torch.utils.data import Dataset
from ..preprocess import normalize_bands
# Runs at import time: selects PyTorch's "file_system" strategy for sharing tensors
# between processes, for every user of this module.
# NOTE(review): presumably chosen to avoid running out of file descriptors when
# DataLoader workers hand tensors back to the main process — confirm before removing.
torch.multiprocessing.set_sharing_strategy("file_system")
def split_and_filter_tensors(image_tensor, label_tensor, tile_size=80):
    """
    Cut a scene into non-overlapping square tiles and keep only the labelled ones.

    The scene is walked row by row, left to right. A tile is kept when its label
    tile contains at least one value greater than zero; all other tiles are dropped.

    Args:
        image_tensor (torch.Tensor): Input tensor of shape (13, 240, 240)
        label_tensor (torch.Tensor): Label tensor of shape (240, 240)
        tile_size (int): Side length of each square tile. Must be a positive
            divisor of 240. Defaults to 80, which gives a 3 x 3 grid.
    Returns:
        tuple: ``(tiles, labels)``, two lists of equal length. ``tiles[k]`` has
        shape (13, tile_size, tile_size) and ``labels[k]`` is the label tile cut
        from the same location, shape (tile_size, tile_size). Both lists are
        empty when no tile contains a positive label. Tiles are slices of the
        inputs, not copies.
    Raises:
        AssertionError: If either tensor does not have the expected shape.
        ValueError: If ``tile_size`` is not a positive divisor of the scene side.
    """
    assert image_tensor.shape == (13, 240, 240), "Image tensor must be of shape (13, 240, 240)"
    assert label_tensor.shape == (240, 240), "Label tensor must be of shape (240, 240)"

    side = label_tensor.shape[0]
    if tile_size <= 0 or side % tile_size != 0:
        raise ValueError(f"tile_size must be a positive divisor of {side}, got {tile_size}")
    tiles_per_side = side // tile_size

    tiles = []
    labels = []
    for row in range(tiles_per_side):
        rows = slice(row * tile_size, (row + 1) * tile_size)
        for col in range(tiles_per_side):
            cols = slice(col * tile_size, (col + 1) * tile_size)
            label_tile = label_tensor[rows, cols]
            # Only tiles with at least one positive label value are kept
            if torch.any(label_tile > 0):
                tiles.append(image_tensor[:, rows, cols])
                labels.append(label_tile)
    return tiles, labels
class PrepMADOSDataset(Dataset):
    """
    Reads raw MADOS scenes from disk and turns each scene into labelled tiles.

    Layout expected under ``root_dir`` (taken from the paths built in this class):
        splits/<split_file>                                  one scene name per line
        Scene_<a>/<res>/Scene_<a>_L2R_rhorc_<band>_<b>.tif    one raster per band
        Scene_<a>/10/Scene_<a>_L2R_cl_<b>.tif                 label mask

    ``<a>`` and ``<b>`` are the second and third "_"-separated parts of a scene name.
    """

    # Bands stored on disk: name -> (candidate band ids used in the file name, folder).
    # Where several ids are listed, the first one whose file exists is used.
    # NOTE(review): the ids look like wavelengths in nm and the folder names like
    # metres per pixel — confirm against the MADOS documentation.
    _STORED_BANDS = {
        "B1": ([442, 443], 60),
        "B2": ([492], 10),
        "B3": ([559, 560], 10),
        "B4": ([665], 10),
        "B5": ([704], 20),
        "B7": ([780, 783], 20),
        "B8": ([833], 10),
        "B8A": ([864, 865], 20),
        "B11": ([1610, 1614], 20),
        "B12": ([2186, 2202], 20),
    }
    # Rasters from this folder are used as-is; they are assumed to already be
    # 240 x 240, otherwise the concatenation in __getitem__ fails.
    _NATIVE_RESOLUTION = 10

    def __init__(self, root_dir, split_file):
        """
        Args:
            root_dir: Root folder of the MADOS download.
            split_file: File name inside ``<root_dir>/splits`` listing the scenes.
        """
        self.root_dir = root_dir
        with open(os.path.join(root_dir, "splits", split_file), "r") as f:
            # Blank lines (e.g. a trailing newline) are not scenes; keeping them
            # would raise IndexError when the name is split in __getitem__.
            self.scene_list = [line.strip() for line in f if line.strip()]

    def __len__(self):
        """Return the number of scenes listed in the split file."""
        return len(self.scene_list)

    def __getitem__(self, idx):
        """
        Load one scene and cut it into tiles.

        Returns:
            tuple: ``(images, masks)`` as returned by ``split_and_filter_tensors``.
        Raises:
            IndexError: If ``idx`` is out of range. Plain ``for sample in dataset``
                iteration relies on this to stop.
            FileNotFoundError: If a band raster is missing for the scene.
        """
        scene_name = self.scene_list[idx]
        name_parts = scene_name.split("_")
        scene_num_1, scene_num_2 = name_parts[1], name_parts[2]

        # Load all bands, bringing the coarser ones onto the 240 x 240 grid
        band = {}
        for name, (band_ids, resolution) in self._STORED_BANDS.items():
            raster = self._load_band(scene_num_1, scene_num_2, band_ids, resolution)
            if resolution != self._NATIVE_RESOLUTION:
                raster = self._resize(raster)
            band[name] = raster

        # Interpolate missing bands
        b6 = (band["B5"] + band["B7"]) / 2
        b9 = band["B8A"]
        b10 = (band["B8A"] + band["B11"]) / 2

        channels = [
            band["B1"],
            band["B2"],
            band["B3"],
            band["B4"],
            band["B5"],
            b6,
            band["B7"],
            band["B8"],
            band["B8A"],
            b9,
            b10,
            band["B11"],
            band["B12"],
        ]
        # Each raster is (1, 1, H, W); stack along the channel axis, drop the batch axis
        image = torch.cat(channels, dim=1).squeeze(0)  # (13, 240, 240)
        mask = self._load_mask(scene_num_1, scene_num_2).squeeze(0).squeeze(0)  # (240, 240)
        images, masks = split_and_filter_tensors(image, mask)
        return images, masks

    @staticmethod
    def _read_tif(path):
        """Read a raster into a tensor and close the file handle straight away."""
        with Image.open(path) as raster:
            return torch.from_numpy(np.array(raster))

    def _load_band(self, scene_num_1, scene_num_2, bands, resolution):
        """
        Load the first existing raster among the candidate band ids.

        Returns:
            torch.Tensor: Float tensor with two leading singleton axes, (1, 1, H, W).
        Raises:
            FileNotFoundError: If none of the candidate files exists.
        """
        for band in bands:
            band_path = f"{self.root_dir}/Scene_{scene_num_1}/{resolution}/Scene_{scene_num_1}_L2R_rhorc_{band}_{scene_num_2}.tif"
            if os.path.exists(band_path):
                return self._read_tif(band_path).float().unsqueeze(0).unsqueeze(0)
        # Fail here with the scene details instead of returning None and letting
        # the caller crash later with an unrelated error.
        raise FileNotFoundError(
            f"No band file found for scene ({scene_num_1}, {scene_num_2}), "
            f"bands {bands}, resolution {resolution} under {self.root_dir}"
        )

    def _resize(self, image):
        """Bilinearly resample a (1, 1, H, W) raster to 240 x 240."""
        return F.interpolate(image, size=240, mode="bilinear", align_corners=False)

    def _load_mask(self, scene_num_1, scene_num_2):
        """Load the scene's label raster as a long tensor of shape (1, 1, H, W)."""
        mask_path = (
            f"{self.root_dir}/Scene_{scene_num_1}/10/Scene_{scene_num_1}_L2R_cl_{scene_num_2}.tif"
        )
        return self._read_tif(mask_path).long().unsqueeze(0).unsqueeze(0)
def get_mados(save_path, root_dir="MADOS", split_file="test_X.txt"):
    """
    Tile every scene of one split and save the result as a single file.

    Args:
        save_path: Destination passed to ``torch.save``.
        root_dir: Root folder of the MADOS download.
        split_file: File name of the split listing inside ``<root_dir>/splits``.

    The saved object is a dict with two stacked tensors:
        "images": shape (N, 13, 80, 80)
        "labels": shape (N, 80, 80)
    """
    scenes = PrepMADOSDataset(root_dir=root_dir, split_file=split_file)

    image_tiles, mask_tiles = [], []
    # Each scene yields two parallel lists of tiles; flatten them across scenes
    for scene_images, scene_masks in scenes:
        image_tiles.extend(scene_images)
        mask_tiles.extend(scene_masks)

    payload = {
        "images": torch.stack(image_tiles),  # shape (N, 13, 80, 80)
        "labels": torch.stack(mask_tiles),  # shape (N, 80, 80)
    }
    torch.save(obj=payload, f=save_path)
class MADOSDataset(Dataset):
    """
    Serves pre-tiled MADOS samples stored in ``MADOS_<split>.pt``.

    Each item is a dict ``{"s2": image, "target": label}`` produced by normalising
    the stored image and passing image and label through ``augmentation.apply``.
    """

    def __init__(self, path_to_splits: Path, split: str, norm_operation, augmentation, partition):
        """
        Args:
            path_to_splits: Folder holding ``MADOS_<split>.pt`` and any
                ``<partition>_partition.json`` files.
            split: One of "train", "val", "valid" or "test"; "valid" means "val".
            norm_operation: Forwarded unchanged to ``normalize_bands``.
            augmentation: Object whose ``apply(image, label, "seg")`` is called per item.
            partition: "default" keeps every sample. Any other value selects the
                indices listed in ``<partition>_partition.json`` (train split only).
        """
        config_path = Path(__file__).parents[0] / "configs" / "mados.json"
        with open(config_path, "r") as config_file:
            dataset_config = json.load(config_file)

        # NOTE: bands were already imputed when the tensors were saved, so no
        # imputation happens here
        assert split in ("train", "val", "valid", "test")
        split = "val" if split == "valid" else split

        self.band_info = dataset_config["band_info"]
        self.split = split
        self.augmentation = augmentation
        self.norm_operation = norm_operation

        saved = torch.load(path_to_splits / f"MADOS_{split}.pt")
        images = saved["images"]
        labels = saved["labels"]

        use_subset = (partition != "default") and (split == "train")
        if use_subset:
            with open(path_to_splits / f"{partition}_partition.json", "r") as json_file:
                keep = json.load(json_file)
            images = images[keep]
            labels = labels[keep]

        self.images = images
        self.labels = labels

    def __len__(self):
        """Return the number of stored samples (size of the first image axis)."""
        num_samples = self.images.shape[0]
        return num_samples

    def __getitem__(self, idx):
        """Return the normalised, augmented sample at ``idx`` as a dict."""
        # NOTE(review): the axis order of a stored image is whatever was written
        # to the .pt file — confirm channel-first vs channel-last before relying on it.
        raw_image = self.images[idx]
        raw_label = self.labels[idx]

        normalised = normalize_bands(raw_image.numpy(), self.norm_operation, self.band_info)
        image, label = self.augmentation.apply(torch.tensor(normalised), raw_label, "seg")
        return {"s2": image, "target": label}