| | import pandas as pd |
| | import albumentations as A |
| | from typing import Optional, List |
| | from sklearn.model_selection import train_test_split |
| | from torch.utils.data import DataLoader |
| | from torchgeo.datamodules import NonGeoDataModule |
| | from methane_classification_dataset import MethaneClassificationDataset |
| |
|
class MethaneClassificationDataModule(NonGeoDataModule):
    """Data module for methane plume classification.

    Reads a summary spreadsheet (CSV or Excel) containing a ``Filename``
    column, splits the listed samples into train/validation subsets, and
    builds the corresponding ``MethaneClassificationDataset`` instances
    and dataloaders.
    """

    def __init__(
        self,
        data_root: str,
        excel_file: str,
        batch_size: int = 8,
        num_workers: int = 0,
        val_split: float = 0.2,
        seed: int = 42,
        **kwargs,
    ):
        """Initialize the data module.

        Args:
            data_root: Root directory containing the sample data files.
            excel_file: Path to the summary file. Read with
                ``pd.read_csv`` if it ends in ``.csv``, otherwise with
                ``pd.read_excel``; must contain a ``Filename`` column.
            batch_size: Samples per batch for all dataloaders.
            num_workers: Number of worker processes for data loading.
            val_split: Fraction of samples held out for validation.
            seed: Random seed controlling the train/val split.
            **kwargs: Forwarded to ``NonGeoDataModule``.
        """
        super().__init__(MethaneClassificationDataset, batch_size, num_workers, **kwargs)

        self.data_root = data_root
        self.excel_file = excel_file
        self.val_split = val_split
        self.seed = seed
        self.batch_size = batch_size
        self.num_workers = num_workers

        # Populated by setup(); kept as attributes so callers can inspect
        # which files landed in each split.
        self.train_paths = []
        self.val_paths = []

    def _get_training_transforms(self):
        """Internal definition of training transforms."""
        return A.Compose([
            A.ElasticTransform(p=0.25),
            A.RandomRotate90(p=0.5),
            A.Flip(p=0.5),
            A.ShiftScaleRotate(rotate_limit=90, shift_limit_x=0.05, shift_limit_y=0.05, p=0.5)
        ])

    def setup(self, stage: Optional[str] = None):
        """Load the summary file, split samples, and build datasets.

        Args:
            stage: One of ``"fit"``/``"train"``, ``"validate"``/``"val"``,
                ``"test"``/``"predict"``, or ``None``. Per the Lightning
                convention, ``None`` sets up every dataset.

        Raises:
            RuntimeError: If the summary file cannot be read.
        """
        try:
            df = pd.read_csv(self.excel_file) if self.excel_file.endswith('.csv') else pd.read_excel(self.excel_file)
        except Exception as e:
            raise RuntimeError(f"Failed to load summary file: {e}") from e

        all_paths = df['Filename'].tolist()

        # Deterministic split so train/val membership is reproducible
        # across runs and processes.
        self.train_paths, self.val_paths = train_test_split(
            all_paths,
            test_size=self.val_split,
            random_state=self.seed
        )

        # stage=None means "set up everything" — without the None check,
        # a bare setup() call would create no datasets at all and every
        # dataloader accessor would fail.
        if stage in (None, "fit", "train"):
            self.train_dataset = MethaneClassificationDataset(
                root_dir=self.data_root,
                excel_file=self.excel_file,
                paths=self.train_paths,
                transform=self._get_training_transforms(),
            )

        if stage in (None, "fit", "validate", "val"):
            # No augmentation for validation.
            self.val_dataset = MethaneClassificationDataset(
                root_dir=self.data_root,
                excel_file=self.excel_file,
                paths=self.val_paths,
                transform=None,
            )

        if stage in (None, "test", "predict"):
            # NOTE: no separate held-out test split exists; the validation
            # subset is reused for test/predict.
            self.test_dataset = MethaneClassificationDataset(
                root_dir=self.data_root,
                excel_file=self.excel_file,
                paths=self.val_paths,
                transform=None,
            )

    def train_dataloader(self):
        """Return the shuffled training dataloader.

        ``drop_last=True`` keeps batch statistics stable during training.
        """
        return DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.num_workers,
            drop_last=True
        )

    def val_dataloader(self):
        """Return the validation dataloader.

        ``drop_last=False``: evaluation must see every sample, and dropping
        the last partial batch would bias metrics (or yield zero batches
        when the val set is smaller than one batch).
        """
        return DataLoader(
            self.val_dataset,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
            drop_last=False
        )

    def test_dataloader(self):
        """Return the test dataloader (full coverage, no dropped batches)."""
        return DataLoader(
            self.test_dataset,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
            drop_last=False
        )
| |
|