TerraMind-Methane-Classification / classification /config /methane_classification_datamodule.py
KPLabs's picture
Upload folder using huggingface_hub
97a17c2 verified
import pandas as pd
import albumentations as A
from typing import Optional, List
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader
from torchgeo.datamodules import NonGeoDataModule
from methane_classification_dataset import MethaneClassificationDataset
class MethaneClassificationDataModule(NonGeoDataModule):
    """LightningDataModule for methane plume classification.

    Reads a summary spreadsheet (CSV or Excel) whose ``Filename`` column
    lists the samples, performs a reproducible train/val split, and builds
    ``MethaneClassificationDataset`` instances per stage in :meth:`setup`.
    """

    def __init__(
        self,
        data_root: str,
        excel_file: str,
        batch_size: int = 8,
        num_workers: int = 0,
        val_split: float = 0.2,
        seed: int = 42,
        **kwargs
    ):
        """
        Args:
            data_root: Directory containing the sample data files.
            excel_file: Path to the summary file (``.csv`` or Excel) with a
                ``Filename`` column listing the samples.
            batch_size: Samples per batch for all dataloaders.
            num_workers: Number of DataLoader worker processes.
            val_split: Fraction of samples held out for validation.
            seed: Random state for the train/val split (reproducibility).
            **kwargs: Forwarded to ``NonGeoDataModule``.
        """
        # The dataset class is passed only to satisfy the parent
        # constructor; concrete datasets are instantiated in setup().
        super().__init__(MethaneClassificationDataset, batch_size, num_workers, **kwargs)
        self.data_root = data_root
        self.excel_file = excel_file
        self.val_split = val_split
        self.seed = seed
        self.batch_size = batch_size
        self.num_workers = num_workers
        # Path lists populated by setup().
        self.train_paths: List[str] = []
        self.val_paths: List[str] = []

    def _get_training_transforms(self):
        """Augmentation pipeline applied to training samples only."""
        return A.Compose([
            A.ElasticTransform(p=0.25),
            A.RandomRotate90(p=0.5),
            A.Flip(p=0.5),
            A.ShiftScaleRotate(rotate_limit=90, shift_limit_x=0.05, shift_limit_y=0.05, p=0.5)
        ])

    def setup(self, stage: Optional[str] = None):
        """Load the summary file, split the paths, and build datasets.

        Args:
            stage: One of ``"fit"``/``"train"``, ``"validate"``/``"val"``,
                ``"test"``, ``"predict"``, or ``None``. Per the Lightning
                convention, ``None`` sets up every split.

        Raises:
            RuntimeError: If the summary file cannot be read.
            KeyError: If the summary file lacks a ``Filename`` column.
        """
        # 1. Read the summary file (CSV or Excel, chosen by extension).
        try:
            df = pd.read_csv(self.excel_file) if self.excel_file.endswith('.csv') else pd.read_excel(self.excel_file)
        except Exception as e:
            # Chain the cause so the underlying pandas error is preserved.
            raise RuntimeError(f"Failed to load summary file: {e}") from e

        # Fail early with a clear message instead of an opaque pandas KeyError.
        if 'Filename' not in df.columns:
            raise KeyError(
                f"Summary file '{self.excel_file}' must contain a 'Filename' column"
            )

        # 2. All rows in the file are used; if per-Fold filtering is
        #    needed, add that logic here before the split.
        all_paths = df['Filename'].tolist()

        # 3. Reproducible 80/20 (by default) train/val split.
        self.train_paths, self.val_paths = train_test_split(
            all_paths,
            test_size=self.val_split,
            random_state=self.seed
        )

        # 4. Instantiate datasets. stage=None means "set up everything".
        if stage in (None, "fit", "train"):
            self.train_dataset = MethaneClassificationDataset(
                root_dir=self.data_root,
                excel_file=self.excel_file,
                paths=self.train_paths,
                transform=self._get_training_transforms(),
            )
        if stage in (None, "fit", "validate", "val"):
            self.val_dataset = MethaneClassificationDataset(
                root_dir=self.data_root,
                excel_file=self.excel_file,
                paths=self.val_paths,
                transform=None,  # No augmentation for validation
            )
        if stage in (None, "test", "predict"):
            # NOTE(review): reuses val_paths as a stand-in test set; replace
            # with a genuine hold-out fold to avoid evaluating on val data.
            self.test_dataset = MethaneClassificationDataset(
                root_dir=self.data_root,
                excel_file=self.excel_file,
                paths=self.val_paths,
                transform=None,
            )

    def train_dataloader(self):
        """Shuffled training loader; drops the ragged final batch."""
        return DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.num_workers,
            drop_last=True
        )

    def val_dataloader(self):
        """Deterministic validation loader.

        drop_last=False so every validation sample is evaluated; dropping
        the last partial batch would silently skew the metrics.
        """
        return DataLoader(
            self.val_dataset,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
            drop_last=False
        )

    def test_dataloader(self):
        """Deterministic test loader; keeps the final partial batch."""
        return DataLoader(
            self.test_dataset,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
            drop_last=False
        )