diff --git "a/task1/train_mcs_models.py" "b/task1/train_mcs_models.py" deleted file mode 100644--- "a/task1/train_mcs_models.py" +++ /dev/null @@ -1,2578 +0,0 @@ -#!/usr/bin/env python3 -"""Benchmark multiple backbones on modulation classification across dataset sizes. - -For each desired training size (samples per MCS class) and repetition, the script: - 1. Randomly samples spectrograms from distinct (modulation, rate, SNR, doppler) configs. - 2. Builds train/val/test splits (val/test sizes are configurable). - 3. Fine-tunes several backbones (LWM, ResNet18, EfficientNet-B0, MobileNet-V3, - and a small CNN) using the same splits. - 4. Reports accuracy statistics and stores checkpoints/metrics per experiment. - -Input spectrograms are globally normalized using the dataset mean/std stored with -the specified pretrained checkpoint (defaults to the latest run under `models/`). - -Usage example (defaults cover city_1_losangeles/LTE with all available SNR·mobility combos): - - python task1/train_mcs_models.py --train-sizes 128 --models resnet18 mobilenet_v3_small -""" - -from __future__ import annotations - -import argparse -import copy -import csv -import glob -import json -import os -import pickle -import random -import re -import sys -from collections import Counter, defaultdict -from pathlib import Path - -from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple - -from contextlib import nullcontext -from datetime import datetime - -import numpy as np -import torch -import torch.nn as nn -import torch.nn.functional as F -from torch.utils.data import DataLoader, Dataset -from torch.amp import autocast, GradScaler - -try: - from tqdm import tqdm -except ImportError: # pragma: no cover - optional dependency - def tqdm(iterable, *args, **kwargs): - return iterable - -PROJECT_ROOT = Path(__file__).resolve().parent.parent -if str(PROJECT_ROOT) not in sys.path: - sys.path.insert(0, str(PROJECT_ROOT)) - -from pretraining.pretrained_model import lwm as 
lwm_model -from utils import count_parameters - - -COMM_CANONICAL = { - "lte": "LTE", - "wifi": "WiFi", - "5g": "5G", -} -COMM_LOWER = {v: k for k, v in COMM_CANONICAL.items()} -try: - from sklearn.metrics import f1_score as sklearn_f1_score - HAVE_SKLEARN = True -except ImportError: - HAVE_SKLEARN = False - -try: - import matplotlib.pyplot as plt - - HAVE_MPL = True -except ImportError: - HAVE_MPL = False - -try: - from task2.mobility_utils import LWMClassifierMinimal # type: ignore -except ImportError: # pragma: no cover - optional dependency - LWMClassifierMinimal = None # type: ignore[misc] - -# HPU support detection -HPU_AVAILABLE = False -try: - import habana_frameworks.torch.core as htcore # type: ignore[import-not-found] - HPU_AVAILABLE = hasattr(torch, "hpu") and torch.hpu.is_available() -except (ImportError, AttributeError): - pass - - -def compute_f1(y_true: np.ndarray, y_pred: np.ndarray) -> float: - if HAVE_SKLEARN: - return float(sklearn_f1_score(y_true, y_pred, average="macro")) - classes = np.unique(np.concatenate([y_true, y_pred])) - scores = [] - for cls in classes: - tp = np.sum((y_true == cls) & (y_pred == cls)) - fp = np.sum((y_true != cls) & (y_pred == cls)) - fn = np.sum((y_true == cls) & (y_pred != cls)) - precision = tp / (tp + fp) if (tp + fp) > 0 else 0.0 - recall = tp / (tp + fn) if (tp + fn) > 0 else 0.0 - denom = precision + recall - f1 = (2 * precision * recall / denom) if denom > 0 else 0.0 - scores.append(f1) - return float(np.mean(scores)) - - -MODULATION_LABELS = { - "BPSK": 0, - "QPSK": 1, - "QAM16": 2, - "QAM64": 3, - "QAM256": 4, -} - -LABEL_NAMES = {idx: name for name, idx in MODULATION_LABELS.items()} - -DEFAULT_LWM_TRAINABLE_LAYERS = 2 # fine-tune the last two transformer blocks - -_SAMPLE_COUNT_CACHE: Dict[str, int] = {} - - -def normalize_per_sample(specs: np.ndarray, eps: float = 1e-6) -> np.ndarray: - if specs.size == 0: - return specs.astype(np.float32, copy=False) - means = specs.mean(axis=(1, 2), keepdims=True) - stds 
= specs.std(axis=(1, 2), keepdims=True) - stds = np.maximum(stds, eps) - normalized = (specs - means) / stds - return normalized.astype(np.float32, copy=False) - - -def apply_normalization(specs: np.ndarray, stats: Dict[str, object]) -> np.ndarray: - mode = str(stats.get("normalization", "dataset")).lower() - mean = float(stats.get("mean", 0.0)) - std = float(stats.get("std", 1.0)) - if abs(std) < 1e-6: - std = 1e-6 - if mode == "dataset": - return ((specs.astype(np.float32, copy=False) - mean) / std).astype(np.float32, copy=False) - return normalize_per_sample(specs) - - -def _unique_parameters(params: Iterable[nn.Parameter]) -> List[nn.Parameter]: - seen: set[int] = set() - unique: List[nn.Parameter] = [] - for param in params: - pid = id(param) - if pid not in seen: - unique.append(param) - seen.add(pid) - return unique - - -def parse_args() -> argparse.Namespace: - parser = argparse.ArgumentParser(description=__doc__) - parser.add_argument( - "--data-root", - default=str(PROJECT_ROOT / "spectrograms"), - help="Root directory containing city folders (default: project_root/spectrograms)", - ) - parser.add_argument( - "--cities", - nargs="*", - default=["city_1_losangeles"], - help="City directories to include (default: %(default)s)", - ) - parser.add_argument("--comm-types", nargs="*", default=["LTE"], help="Communication standards to include (default: %(default)s)") - parser.add_argument("--LTE", dest="select_lte", action="store_true", help="Shortcut for --comm-types LTE") - parser.add_argument("--WiFi", dest="select_wifi", action="store_true", help="Shortcut for --comm-types WiFi") - parser.add_argument("--5G", dest="select_5g", action="store_true", help="Shortcut for --comm-types 5G") - parser.add_argument("--snrs", nargs="*", default=None, help="SNR folders to include for training (default: all available)") - parser.add_argument("--val-snrs", nargs="*", default=None, help="SNR folders for validation/test (default: all available)") - parser.add_argument( - 
"--mobilities", - nargs="*", - default=None, - help="Mobility folders to include for training (default: all available)", - ) - parser.add_argument( - "--val-mobilities", - nargs="*", - default=None, - help="Mobility folders for validation/test (default: all available)", - ) - parser.add_argument("--fft-folder", default="512FFT", help="FFT folder name (default: %(default)s)") - parser.add_argument( - "--device", - type=str, - default="auto", - choices=["auto", "cuda", "hpu", "cpu"], - help="Device to use for training (default: auto - detects HPU, then CUDA, then CPU)", - ) - parser.add_argument( - "--gpu-ids", - type=int, - nargs="*", - default=None, - help="Specific GPU device IDs to use (only for CUDA, default: all visible GPUs)", - ) - parser.add_argument( - "--train-sizes", - type=int, - nargs="*", - default=[2, 4, 8, 16, 32, 64, 128, 256], - help="Training samples per class to benchmark", - ) - parser.add_argument("--val-per-class", type=int, default=512, help="Validation samples per class") - parser.add_argument("--test-per-class", type=int, default=512, help="Test samples per class") - parser.add_argument("--repetitions", type=int, default=1, help="Repetitions per train size") - parser.add_argument("--epochs", type=int, default=200, help="Epochs per run") - parser.add_argument("--batch-size", type=int, default=32, help="Mini-batch size") - parser.add_argument("--lr", type=float, default=8e-4, help="Learning rate for fine-tuning") - parser.add_argument("--weight-decay", type=float, default=3e-2, help="Weight decay") - parser.add_argument( - "--no-epoch-history", - action="store_true", - help="Disable aggregated per-epoch history tracking", - ) - parser.add_argument( - "--no-epoch-plot", - action="store_true", - help="Disable per-repetition metric plots", - ) - parser.add_argument( - "--save-epoch-checkpoints", - action="store_true", - help="Persist per-epoch checkpoints (default: disabled)", - ) - parser.add_argument( - "--backbone-lr-factor", - type=float, - 
default=0.3, - help="Relative LR multiplier applied to unfrozen backbone parameters (default: %(default)s)", - ) - parser.add_argument( - "--early-patience", - type=int, - default=5, - help="Early stopping patience based on validation F1 (default: %(default)s)", - ) - parser.add_argument( - "--early-min-epochs", - type=int, - default=10, - help="Minimum number of epochs to run before early stopping can trigger (default: %(default)s)", - ) - parser.add_argument( - "--finetune-epochs", - type=int, - default=0, - help="Additional fine-tuning epochs to run after the main schedule (default: %(default)s)", - ) - parser.add_argument( - "--finetune-lr-factor", - type=float, - default=0.1, - help="Multiplier applied to the base learning rate during fine-tuning (default: %(default)s)", - ) - parser.add_argument( - "--finetune-patience", - type=int, - default=3, - help="Early stopping patience for the fine-tuning phase (default: %(default)s)", - ) - parser.add_argument( - "--finetune-min-epochs", - type=int, - default=0, - help="Minimum epochs to execute in the fine-tuning phase before early stopping is considered (default: %(default)s)", - ) - parser.add_argument( - "--debug-eval-batches", - type=int, - default=0, - help="Log detailed stats for the first N evaluation batches (0 disables logging)", - ) - parser.add_argument( - "--debug-eval-interval", - type=int, - default=1, - help="Evaluate logging interval in batches when debug logging is enabled (default: %(default)s)", - ) - parser.add_argument( - "--debug-eval-softmax", - action="store_true", - help="When debugging evaluations, also log softmax statistics per batch", - ) - parser.add_argument( - "--models", - nargs="*", - default=["lwm", "resnet18", "efficientnet_b0", "mobilenet_v3_small", "simple_cnn", "ieee_cnn"], - help="Models to benchmark", - ) - parser.add_argument( - "--raw-input-models", - nargs="*", - default=None, - help=( - "Models that should receive raw spectrograms without additional normalization " - 
"(default: all non-LWM models)" - ), - ) - parser.add_argument( - "--lwm-trainable-layers", - type=int, - default=2, - help="Number of transformer layers (from the end) to fine-tune in LWM (default: %(default)s)", - ) - parser.add_argument( - "--lwm-classifier-dim", - type=int, - default=64, - help="Hidden width for the LWM classifier MLP head (default: %(default)s; ignored for linear head)", - ) - parser.add_argument( - "--lwm-head-dropout", - type=float, - default=0.0, - help="Dropout applied inside the LWM classifier head (default: %(default)s)", - ) - parser.add_argument( - "--lwm-head-type", - choices=("linear", "mlp", "res1dcnn"), - default="res1dcnn", - help="Classifier head architecture for LWM (default: %(default)s)", - ) - parser.add_argument( - "--lwm-backbone-lr-factor", - type=float, - default=0.2, - help="LR multiplier for unfrozen LWM backbone layers (default: %(default)s)", - ) - parser.add_argument( - "--resnet-head-width", - type=int, - default=512, - help="Hidden width for the ResNet18 classifier head (default: %(default)s)", - ) - parser.add_argument( - "--efficientnet-head-width", - type=int, - default=296, - help="Hidden width for the EfficientNet-B0 classifier head (default: %(default)s)", - ) - parser.add_argument( - "--mobilenet-head-width", - type=int, - default=576, - help="Hidden width for the MobileNetV3-Small classifier head (default: %(default)s)", - ) - parser.add_argument( - "--imagenet-head-dropout", - type=float, - default=0.6, - help="Dropout probability used inside ImageNet backbone classifier heads (default: %(default)s)", - ) - parser.add_argument( - "--imagenet-weight-decay-scale", - type=float, - default=2.0, - help="Multiplier applied to weight decay for ImageNet backbone trainable parameters (default: %(default)s)", - ) - parser.add_argument( - "--simple-cnn-hidden-dims", - type=int, - nargs="*", - default=[272, 128], - help="Hidden layer widths for the Simple CNN classifier (default: %(default)s)", - ) - 
parser.add_argument( - "--ieee-cnn-hidden-dims", - type=int, - nargs="*", - default=[512, 256], - help="Hidden layer widths for the IEEE CNN classifier (default: %(default)s)", - ) - parser.add_argument( - "--ieee-cnn-dropout", - type=float, - default=0.3, - help="Dropout rate for the IEEE CNN model (default: %(default)s)", - ) - parser.add_argument("--checkpoint", type=Path, default=None, help="Path to pretrained LWM checkpoint (.pth)") - parser.add_argument("--stats", type=Path, default=None, help="dataset_stats.json path") - parser.add_argument( - "--models-root", - type=Path, - default=PROJECT_ROOT / "models", - help="Root with pretrained runs (default: project_root/models)", - ) - parser.add_argument( - "--output-dir", - type=Path, - default=PROJECT_ROOT / "task1" / "mcs_benchmarks", - help="Results root directory (per-run subfolder created automatically)", - ) - parser.add_argument( - "--export-full-model", - type=Path, - default=None, - help="Directory where best full-model checkpoints (backbone + head) will be exported per run", - ) - parser.add_argument("--seed", type=int, default=42, help="Base random seed") - args = parser.parse_args() - args.output_root = args.output_dir - - quick_comm: List[str] = [] - if getattr(args, "select_lte", False): - quick_comm.append("LTE") - if getattr(args, "select_wifi", False): - quick_comm.append("WiFi") - if getattr(args, "select_5g", False): - quick_comm.append("5G") - if quick_comm: - args.comm_types = quick_comm - - normalized: List[str] = [] - for comm in args.comm_types: - upper = comm.upper() - if upper == "WIFI": - normalized.append("WiFi") - elif upper == "LTE": - normalized.append("LTE") - elif upper == "5G": - normalized.append("5G") - else: - normalized.append(comm) - args.comm_types = normalized - timestamp = datetime.now().strftime("%Y%m%d-%H%M%S") - comm_tokens: List[str] = [] - for comm in args.comm_types: - canonical = COMM_LOWER.get(comm, comm.lower()) - token = re.sub(r"[^a-z0-9]+", "-", 
canonical.lower()).strip("-") - comm_tokens.append(token or "unknown") - comm_suffix = "-".join(comm_tokens) if comm_tokens else "unknown" - args.run_timestamp = timestamp - args.output_dir = args.output_root / comm_suffix / timestamp - args.comm_suffix = comm_suffix - - if args.gpu_ids is not None and len(args.gpu_ids) == 0: - args.gpu_ids = None - - if not args.simple_cnn_hidden_dims: - args.simple_cnn_hidden_dims = [512, 256] - if not args.ieee_cnn_hidden_dims: - args.ieee_cnn_hidden_dims = [512, 256] - if args.raw_input_models is None: - args.raw_input_models = [ - model.lower() - for model in args.models - if model.lower() not in {"lwm"} - ] - else: - args.raw_input_models = [model.lower() for model in args.raw_input_models] - - args.models = [model for model in args.models] - args.save_epoch_history = not args.no_epoch_history - args.plot_epoch_history = not args.no_epoch_plot - - args.imagenet_head_dropout = float(max(0.0, min(args.imagenet_head_dropout, 0.95))) - args.imagenet_weight_decay_scale = float(max(0.0, args.imagenet_weight_decay_scale)) - - return args - - -def find_latest_run(models_root: Path) -> Path: - run_dirs = [p for p in models_root.iterdir() if p.is_dir()] - run_dirs = [p for p in run_dirs if not p.name.lower().endswith("_models")] - valid_runs = [p for p in run_dirs if any(p.glob("*.pth"))] - if valid_runs: - return max(valid_runs, key=lambda p: p.stat().st_mtime) - - checkpoints = list(models_root.glob("*.pth")) - if checkpoints: - print(f"[INFO] No checkpoint-bearing subdirectories under {models_root}; using root as run directory.") - return models_root - - raise FileNotFoundError(f"No checkpoints found under {models_root}") - - -def find_best_checkpoint(run_dir: Path) -> Path: - candidates = list(run_dir.glob("*.pth")) - if not candidates: - raise FileNotFoundError(f"No checkpoints in {run_dir}") - - def metric(path: Path) -> float: - match = re.search(r"_val([0-9]+(?:\.[0-9]+)?)", path.name) - if match: - try: - return 
float(match.group(1)) - except ValueError: - pass - return float("inf") - - best = min(candidates, key=metric) - return best - - -def resolve_models_directory(args: argparse.Namespace) -> Path: - base = args.models_root.expanduser().resolve() - if not base.exists(): - raise FileNotFoundError(f"Models root not found: {base}") - matches: List[Path] = [] - for comm in args.comm_types: - subdir = base / f"{comm}_models" - if subdir.exists(): - matches.append(subdir) - else: - print(f"[WARN] Models directory for {comm} not found at {subdir}") - if len(matches) == 1: - print(f"[INFO] Using models directory for {args.comm_types[0]}: {matches[0]}") - return matches[0] - if len(matches) > 1: - raise ValueError( - "Multiple communication-specific model directories detected; please provide --checkpoint explicitly." - ) - print(f"[INFO] Using shared models directory: {base}") - return base - - -def resolve_checkpoint_and_stats(args: argparse.Namespace, require_checkpoint: bool) -> Tuple[Path | None, Dict[str, object]]: - checkpoint: Path | None = None - models_dir = resolve_models_directory(args) - user_provided_stats = args.stats is not None - - if args.checkpoint is not None: - checkpoint = args.checkpoint.expanduser().resolve() - if not checkpoint.exists(): - raise FileNotFoundError(f"Checkpoint not found: {checkpoint}") - stats_path = args.stats.expanduser().resolve() if user_provided_stats else checkpoint.parent / "dataset_stats.json" - else: - run_dir = find_latest_run(models_dir) - stats_path = run_dir / "dataset_stats.json" - if require_checkpoint: - checkpoint = find_best_checkpoint(run_dir) - else: - checkpoint = None - - if stats_path.exists(): - try: - with open(stats_path, "r", encoding="utf-8") as f: - stats = json.load(f) - except json.JSONDecodeError as exc: - if user_provided_stats: - raise ValueError( - f"Failed to parse dataset_stats.json at {stats_path}: {exc}" - ) from exc - print( - f"[WARN] Corrupt dataset_stats.json at {stats_path}; " - "falling back to 
mean=0/std=1 per-sample normalization." - ) - stats = {"mean": 0.0, "std": 1.0, "normalization": "per_sample"} - else: - if "mean" not in stats or "std" not in stats: - raise ValueError("dataset_stats.json must contain 'mean' and 'std'") - stats.setdefault("normalization", stats.get("mode", "dataset")) - else: - if user_provided_stats: - raise FileNotFoundError(f"dataset_stats.json not found: {stats_path}") - stats = {"mean": 0.0, "std": 1.0, "normalization": "per_sample"} - print(f"[WARN] dataset_stats.json not found at {stats_path}. Falling back to per-sample normalization.") - - if checkpoint is not None: - print(f"[INFO] Using checkpoint: {checkpoint}") - elif require_checkpoint: - raise FileNotFoundError("LWM requested but no checkpoint available") - else: - print("[INFO] No LWM checkpoint required for selected models") - - norm_mode = str(stats.get("normalization", "dataset")) - if norm_mode.lower() == "dataset": - print(f"[INFO] Dataset stats -> mean={stats['mean']:.4f}, std={stats['std']:.4f}") - else: - print("[INFO] Normalization mode: per_sample") - - return checkpoint, { - "mean": float(stats.get("mean", 0.0)), - "std": float(stats.get("std", 1.0)), - "normalization": norm_mode, - } - - -def identify_modulation(path: str) -> tuple[int | None, str | None]: - for mod_name, label in MODULATION_LABELS.items(): - if mod_name in path: - return label, mod_name - return None, None - - -def _extract_metadata(parts: Sequence[str]) -> Tuple[str, str, str]: - rate = next((part for part in parts if part.startswith("rate")), "rate_unknown") - snr = next((part for part in parts if part.startswith("SNR")), "SNR_unknown") - mobility = next((part for part in parts if part in {"static", "pedestrian", "vehicular"}), "mobility_unknown") - return rate, snr, mobility - - -def discover_snr_mobility( - data_root: Path, - cities: Sequence[str], - comm_types: Sequence[str], - fft_folder: str, -) -> Tuple[List[str], List[str]]: - snrs: set[str] = set() - mobilities: set[str] = 
set() - for city in cities: - for comm in comm_types: - base = data_root / city / comm - if not base.exists(): - continue - for root, dirs, _ in os.walk(base): - parts = Path(root).parts - for part in parts: - if part.startswith("SNR") and part.endswith("dB"): - snrs.add(part) - elif part in {"static", "pedestrian", "vehicular"}: - mobilities.add(part) - if not snrs: - snrs.add("SNR20dB") - if not mobilities: - mobilities.add("static") - return sorted(snrs), sorted(mobilities) - - -def build_config_map( - data_root: Path, - cities: Sequence[str], - comm_types: Sequence[str], - snrs: Sequence[str], - mobilities: Sequence[str], - fft_folder: str, -) -> Dict[int, Dict[str, List[str]]]: - class_configs: Dict[int, Dict[str, List[str]]] = defaultdict(lambda: defaultdict(list)) - for city in cities: - for comm in comm_types: - base = data_root / city / comm - for snr in snrs: - for mobility in mobilities: - pattern = str(base / "**" / snr / mobility / "**" / fft_folder / "**" / "spectrograms" / "*.pkl") - for path_str in glob.glob(pattern, recursive=True): - cls, modulation_name = identify_modulation(path_str) - if cls is None: - continue - rate, _, _ = _extract_metadata(Path(path_str).parts) - config_name = f"{modulation_name}_{rate}_{snr}_{mobility}" - class_configs[cls][config_name].append(path_str) - return class_configs - - -def build_global_config_map( - data_root: Path, - cities: Sequence[str], - comm_types: Sequence[str], - fft_folder: str, -) -> Dict[int, Dict[str, List[str]]]: - class_configs: Dict[int, Dict[str, List[str]]] = defaultdict(lambda: defaultdict(list)) - for city in cities: - for comm in comm_types: - base = data_root / city / comm - pattern = str(base / "**" / fft_folder / "**" / "spectrograms" / "*.pkl") - for path_str in glob.glob(pattern, recursive=True): - cls, modulation_name = identify_modulation(path_str) - if cls is None: - continue - rate, snr_part, mobility_part = _extract_metadata(Path(path_str).parts) - config_name = 
f"{modulation_name}_{rate}_{snr_part}_{mobility_part}" - class_configs[cls][config_name].append(path_str) - return class_configs - - -def _count_samples_in_path(path: str) -> int: - cached = _SAMPLE_COUNT_CACHE.get(path) - if cached is not None: - return cached - arr = load_all_samples(path) - count = int(arr.shape[0]) - _SAMPLE_COUNT_CACHE[path] = count - return count - - -class LazyConfigArray: - """Lazily views spectrograms spread across multiple pickled files.""" - - __slots__ = ("paths", "_counts", "_offsets", "_total", "shape", "dtype", "ndim") - - def __init__(self, paths: Sequence[str]) -> None: - filtered_paths: List[str] = [] - counts: List[int] = [] - for path in sorted(paths): - count = _count_samples_in_path(path) - if count <= 0: - continue - filtered_paths.append(path) - counts.append(count) - - self.paths: Tuple[str, ...] = tuple(filtered_paths) - if counts: - self._counts = np.array(counts, dtype=np.int64) - self._offsets = np.concatenate(([0], np.cumsum(self._counts))) - self._total = int(self._offsets[-1]) - else: - self._counts = np.empty(0, dtype=np.int64) - self._offsets = np.array([0], dtype=np.int64) - self._total = 0 - - self.shape = (self._total, 128, 128) - self.dtype = np.float32 - self.ndim = 3 - - def __len__(self) -> int: - return self._total - - def _resolve_index(self, index: int) -> Tuple[int, int]: - if self._total == 0: - raise IndexError("attempting to index empty LazyConfigArray") - if index < 0: - index += self._total - if index < 0 or index >= self._total: - raise IndexError("index out of range for LazyConfigArray") - path_idx = int(np.searchsorted(self._offsets[1:], index, side="right")) - start = int(self._offsets[path_idx]) - return path_idx, int(index - start) - - def _load_path(self, path_idx: int) -> np.ndarray: - path = self.paths[path_idx] - return load_all_samples(path) - - def __getitem__(self, item: Any) -> np.ndarray: - if isinstance(item, (int, np.integer)): - path_idx, local_idx = self._resolve_index(int(item)) 
- data = self._load_path(path_idx) - sample = data[local_idx].copy() - return sample - - indices = np.asarray(item, dtype=np.int64) - if indices.ndim == 0: - indices = indices.reshape(1) - else: - indices = indices.reshape(-1) - if indices.size == 0: - return np.empty((0, 128, 128), dtype=np.float32) - - resolved: Dict[int, List[Tuple[int, int]]] = {} - for pos, raw_idx in enumerate(indices): - path_idx, local_idx = self._resolve_index(int(raw_idx)) - resolved.setdefault(path_idx, []).append((pos, local_idx)) - - result = np.empty((indices.size, 128, 128), dtype=np.float32) - for path_idx, items in resolved.items(): - data = self._load_path(path_idx) - local_positions = [loc for _, loc in items] - chunk = data[local_positions] - for offset, (pos, _) in enumerate(items): - result[pos] = chunk[offset] - return result - - -def load_config_arrays(class_configs: Dict[int, Dict[str, List[str]]]) -> Dict[int, Dict[str, LazyConfigArray]]: - loaded: Dict[int, Dict[str, LazyConfigArray]] = {} - for cls, configs in class_configs.items(): - arrays_for_cls: Dict[str, LazyConfigArray] = {} - for config_name, paths in configs.items(): - lazy_array = LazyConfigArray(paths) - if len(lazy_array) == 0: - continue - arrays_for_cls[config_name] = lazy_array - if arrays_for_cls: - loaded[cls] = arrays_for_cls - return loaded - - -def load_all_samples(path: str) -> np.ndarray: - with open(path, "rb") as f: - data = pickle.load(f) - if isinstance(data, dict) and "spectrograms" in data: - arr = data["spectrograms"] - elif isinstance(data, np.ndarray): - arr = data - else: - return np.empty((0, 128, 128), dtype=np.float32) - - arr = np.asarray(arr, dtype=np.float32) - if arr.ndim == 2: - arr = arr[None, ...] 
- if arr.shape[1:] != (128, 128): - return np.empty((0, 128, 128), dtype=np.float32) - return arr - - -def sample_from_paths( - paths: Sequence[str], - n_samples: int, - rng: np.random.Generator, - used_map: Dict[str, set[int]], -) -> Tuple[np.ndarray, List[Tuple[str, np.ndarray]]]: - if not paths: - raise RuntimeError("No files available for sampling") - - paths_array = np.array(paths, dtype=object) - order = rng.permutation(len(paths_array)) - remaining = n_samples - collected: List[np.ndarray] = [] - info: List[Tuple[str, np.ndarray]] = [] - - for idx in order: - if remaining <= 0: - break - path = str(paths_array[idx]) - samples = load_all_samples(path) - total = samples.shape[0] - used = used_map[path] - if used: - used_idx = np.fromiter(used, dtype=np.int64, count=len(used)) - available = np.setdiff1d(np.arange(total), used_idx, assume_unique=True) - else: - available = np.arange(total) - if available.size == 0: - continue - take = min(remaining, available.size) - chosen = rng.choice(available, size=take, replace=False) - collected.append(samples[chosen]) - used_map[path].update(int(i) for i in chosen) - info.append((path, chosen)) - remaining -= take - - if remaining > 0: - raise RuntimeError("Insufficient samples remaining to satisfy request") - - result = np.concatenate(collected, axis=0) if len(collected) > 1 else collected[0] - return result, info - - -def _ensure_available(total_needed: int, availability: Dict[str, set]) -> None: - remaining = sum(len(indices) for indices in availability.values()) - if remaining < total_needed: - raise RuntimeError( - f"Insufficient samples: need {total_needed}, only {remaining} available across configs" - ) - - -def _sample_from_availability( - arrays_map: Dict[str, LazyConfigArray], - availability: Dict[str, set[int]], - total_needed: int, - rng: np.random.Generator, -) -> Tuple[np.ndarray, Dict[str, set[int]]]: - if total_needed <= 0: - return np.empty((0, 128, 128), dtype=np.float32), {cfg: set() for cfg in 
arrays_map} - - _ensure_available(total_needed, availability) - remaining = total_needed - configs = [cfg for cfg, indices in availability.items() if indices] - used: Dict[str, set[int]] = {cfg: set() for cfg in arrays_map} - collected: List[np.ndarray] = [] - - while remaining > 0 and configs: - cfg = rng.choice(configs) - available_indices = np.array(list(availability[cfg]), dtype=np.int64) - if available_indices.size == 0: - configs = [c for c in configs if c != cfg] - continue - take = min(max(1, remaining // max(len(configs), 1)), remaining, available_indices.size) - chosen = rng.choice(available_indices, size=take, replace=False) - collected.append(arrays_map[cfg][chosen]) - chosen_set = {int(idx) for idx in chosen} - used[cfg].update(chosen_set) - availability[cfg].difference_update(chosen_set) - remaining -= take - configs = [c for c in configs if availability[c]] - - if remaining > 0: - raise RuntimeError("Sampling failed to collect the requested number of samples") - - stacked = np.concatenate(collected, axis=0) if collected else np.empty((0, 128, 128), dtype=np.float32) - return stacked.astype(np.float32, copy=False), used - - -def sample_train_arrays( - arrays_map: Dict[str, LazyConfigArray], - availability: Dict[str, set[int]], - train_size: int, - rng: np.random.Generator, -) -> Tuple[np.ndarray, Dict[str, set[int]]]: - return _sample_from_availability(arrays_map, availability, train_size, rng) - - -def sample_global_arrays( - arrays_map: Dict[str, LazyConfigArray], - availability: Dict[str, set[int]], - per_class: int, - rng: np.random.Generator, -) -> Tuple[np.ndarray, Dict[str, set[int]]]: - return _sample_from_availability(arrays_map, availability, per_class, rng) - - -class SpectrogramDataset(Dataset): - def __init__(self, specs: np.ndarray, labels: np.ndarray): - self.specs = specs.astype(np.float32, copy=False) - self.labels = labels.astype(np.int64, copy=False) - - def __len__(self) -> int: - return len(self.labels) - - def __getitem__(self, 
idx: int) -> Tuple[torch.Tensor, int]: - return torch.from_numpy(self.specs[idx]), int(self.labels[idx]) - - -def normalize_batch(specs: torch.Tensor, eps: float = 1e-6) -> torch.Tensor: - mean = specs.mean() - std = specs.std(unbiased=False) - std = torch.clamp(std, min=eps) - return (specs - mean) / std - - -def apply_spec_augment( - specs: torch.Tensor, - *, - freq_mask_width: int = 12, - time_mask_width: int = 16, - freq_masks: int = 2, - time_masks: int = 2, - mask_prob: float = 0.5, - noise_std: float = 0.0, -) -> torch.Tensor: - """Apply light-weight SpecAugment-style masking to a batch of spectrograms. - - The function accepts tensors shaped ``[B, H, W]`` or ``[B, 1, H, W]`` and - returns an augmented tensor with the same shape. Masks use the sample mean - to avoid introducing large bias and are applied per-sample with the given - probability. Gaussian noise (if requested) is injected before masking. - """ - - if mask_prob <= 0.0 and noise_std <= 0.0: - return specs - - if specs.dim() not in (3, 4): - raise ValueError(f"Spectrograms must be rank-3 or rank-4, got shape {tuple(specs.shape)}") - - needs_squeeze = specs.dim() == 3 - augmented = specs.unsqueeze(1) if needs_squeeze else specs - batch_size, _, freq_dim, time_dim = augmented.shape - - if mask_prob < 1.0: - apply_mask = torch.rand(batch_size, device=augmented.device) < mask_prob - else: - apply_mask = torch.ones(batch_size, dtype=torch.bool, device=augmented.device) - - freq_mask_width = max(0, int(freq_mask_width)) - time_mask_width = max(0, int(time_mask_width)) - freq_masks = max(0, int(freq_masks)) - time_masks = max(0, int(time_masks)) - - for idx in range(batch_size): - if not apply_mask[idx]: - continue - - sample = augmented[idx] - if noise_std > 0.0: - sample = sample + noise_std * torch.randn_like(sample) - - fill_value = sample.mean() - - if freq_mask_width > 0 and freq_masks > 0: - max_width = min(freq_mask_width, freq_dim) - for _ in range(freq_masks): - width = int(torch.randint(0, 
max_width + 1, (1,), device=augmented.device).item()) - if width == 0 or width > freq_dim: - continue - start = 0 if freq_dim == width else int(torch.randint(0, freq_dim - width + 1, (1,), device=augmented.device).item()) - sample[:, start:start + width, :] = fill_value - - if time_mask_width > 0 and time_masks > 0: - max_width = min(time_mask_width, time_dim) - for _ in range(time_masks): - width = int(torch.randint(0, max_width + 1, (1,), device=augmented.device).item()) - if width == 0 or width > time_dim: - continue - start = 0 if time_dim == width else int(torch.randint(0, time_dim - width + 1, (1,), device=augmented.device).item()) - sample[:, :, start:start + width] = fill_value - - augmented[idx] = sample - - return augmented.squeeze(1) if needs_squeeze else augmented - - - - -def _write_epoch_history( - rep_root: Path, - records: Sequence[Dict[str, object]], - enable_csv: bool, - enable_plot: bool, -) -> None: - if not records: - return - - rep_root.mkdir(parents=True, exist_ok=True) - - if enable_csv: - base_fields = [ - "model", - "epoch", - "phase", - "train_loss", - "val_loss", - "val_acc", - "val_f1", - "lr", - "train_size_requested", - "train_size_effective", - ] - extra_fields = sorted( - {key for rec in records for key in rec.keys() if key not in base_fields} - ) - fieldnames = base_fields + extra_fields - sorted_records = sorted(records, key=lambda r: (r["epoch"], r["model"], r.get("phase", ""))) - history_path = rep_root / "epoch_history.csv" - with open(history_path, "w", newline="", encoding="utf-8") as csvfile: - writer = csv.DictWriter(csvfile, fieldnames=fieldnames, extrasaction="ignore") - writer.writeheader() - writer.writerows(sorted_records) - - if enable_plot and HAVE_MPL: - models_in_run = sorted({rec["model"] for rec in records}) - fig, axes = plt.subplots(2, 1, figsize=(8, 6), sharex=True) - for ax in axes: - ax.grid(True, linestyle='--', alpha=0.3) - for model_name_plot in models_in_run: - model_records = [rec for rec in records if 
class ResidualBlock1D(nn.Module):
    """Two-convolution 1D residual block used by the res1dcnn head.

    Main path: Conv1d(k=3) -> BatchNorm1d -> ReLU -> Conv1d(k=3) -> BatchNorm1d.
    Skip path: identity when the channel count is unchanged, otherwise a 1x1
    convolution followed by batch norm so that both paths can be summed.
    Kernel size 3 with padding 1 keeps the sequence length unchanged.
    """

    def __init__(self, in_channels: int, out_channels: int) -> None:
        super().__init__()
        self.conv1 = nn.Conv1d(in_channels, out_channels, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm1d(out_channels)
        self.conv2 = nn.Conv1d(out_channels, out_channels, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm1d(out_channels)

        # A projection is only needed when the skip tensor must change width.
        if in_channels == out_channels:
            skip_path: nn.Module = nn.Identity()
        else:
            skip_path = nn.Sequential(
                nn.Conv1d(in_channels, out_channels, kernel_size=1),
                nn.BatchNorm1d(out_channels),
            )
        self.shortcut = skip_path

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``relu(main_path(x) + shortcut(x))``."""
        skip = self.shortcut(x)
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        return F.relu(out + skip)
class LWMClassifier(nn.Module):
    """Classification wrapper around an LWM-style token backbone.

    A 2D spectrogram is cut into non-overlapping ``patch_size x patch_size``
    patches, a constant token (filled with 0.2) is prepended, the sequence is
    passed through ``backbone`` and the non-leading output tokens are
    mean-pooled before being fed to the classification head.

    Args:
        backbone: Module mapping ``(batch, tokens, patch_size**2)`` to
            ``(batch, tokens, feature_dim)``. It must expose ``layers``, a
            sliceable sequence of sub-modules, when ``trainable_layers > 0``.
        trainable_layers: Number of trailing ``backbone.layers`` entries left
            trainable; every other backbone parameter is frozen.
        num_classes: Number of output logits.
        classifier_dim: Hidden width of the MLP head (used by the MLP head only).
        head_dropout: Dropout probability in the head; negative values are
            treated as 0.
        head_type: ``"linear"``, ``"res1dcnn"``, or anything else for the
            default MLP head.
        feature_dim: Width of the backbone output features (default 128, the
            value that was previously hard-coded).
        patch_size: Side length of the square patches (default 4, the value
            that was previously hard-coded).
    """

    def __init__(
        self,
        backbone: nn.Module,
        trainable_layers: int,
        num_classes: int,
        classifier_dim: int = 128,
        head_dropout: float = 0.1,
        head_type: str = "mlp",
        feature_dim: int = 128,
        patch_size: int = 4,
    ):
        super().__init__()
        self.backbone = backbone
        self.patch_size = int(patch_size)
        self.feature_dim = int(feature_dim)
        self.unfold = nn.Unfold(kernel_size=self.patch_size, stride=self.patch_size)
        head_dropout = max(0.0, float(head_dropout))
        head_type = head_type.lower().strip()

        self.classifier = self._build_head(
            head_type, self.feature_dim, num_classes, classifier_dim, head_dropout
        )

        # Freeze the whole backbone, then re-enable only the trailing layers.
        for param in self.backbone.parameters():
            param.requires_grad = False
        if trainable_layers > 0:
            for layer in self.backbone.layers[-trainable_layers:]:
                for param in layer.parameters():
                    param.requires_grad = True
                # Enable gradient checkpointing for memory efficiency on
                # layers that expose the flag.
                if hasattr(layer, 'gradient_checkpointing'):
                    layer.gradient_checkpointing = True

    @staticmethod
    def _build_head(
        head_type: str,
        feature_dim: int,
        num_classes: int,
        classifier_dim: int,
        head_dropout: float,
    ) -> nn.Sequential:
        """Build the classification head selected by ``head_type``."""
        if head_type == "res1dcnn":
            return nn.Sequential(
                nn.LayerNorm(feature_dim),
                Res1DCNNHead(feature_dim, num_classes, dropout=head_dropout),
            )

        head_layers: List[nn.Module] = [nn.LayerNorm(feature_dim)]
        if head_type == "linear":
            out_in_dim = feature_dim
        else:
            # Default MLP head; unknown head types intentionally land here.
            head_layers.extend([nn.Linear(feature_dim, classifier_dim), nn.GELU()])
            out_in_dim = classifier_dim
        if head_dropout > 0:
            head_layers.append(nn.Dropout(head_dropout))
        head_layers.append(nn.Linear(out_in_dim, num_classes))
        return nn.Sequential(*head_layers)

    def spectrogram_to_tokens(self, x: torch.Tensor) -> torch.Tensor:
        """Turn ``(batch, H, W)`` spectrograms into a patch-token sequence.

        Returns a tensor of shape ``(batch, 1 + num_patches, patch_size**2)``
        whose first token is a constant vector filled with 0.2.
        """
        x = x.unsqueeze(1)  # add the single channel expected by nn.Unfold
        patches = self.unfold(x).transpose(1, 2)
        cls_token = torch.full(
            (patches.size(0), 1, patches.size(-1)), 0.2, dtype=patches.dtype, device=patches.device
        )
        return torch.cat([cls_token, patches], dim=1)

    def forward_features(self, x: torch.Tensor) -> torch.Tensor:
        """Return one pooled feature vector per sample.

        The leading token is skipped and the remaining tokens are averaged;
        if the backbone returns a single token, that token is used directly.
        """
        tokens = self.spectrogram_to_tokens(x)
        outputs = self.backbone(tokens)
        if outputs.size(1) <= 1:
            return outputs[:, 0, :]
        return outputs[:, 1:, :].mean(dim=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return class logits of shape ``(batch, num_classes)``."""
        features = self.forward_features(x)
        return self.classifier(features)
def _lwm_backbone_state(state: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
    """Extract the backbone-only weights from a raw LWM checkpoint state dict."""
    # Only a *leading* "module." (added by nn.DataParallel) is removed; a plain
    # str.replace would also mangle keys that merely contain that substring.
    state = _strip_module_prefix(state)
    if any(key.startswith("backbone.") for key in state):
        return {
            key.split("backbone.", 1)[1]: value
            for key, value in state.items()
            if key.startswith("backbone.")
        }
    return {
        key: value
        for key, value in state.items()
        if not key.startswith(("classifier.", "projection_head."))
    }


def _build_finetune_head(
    in_features: int,
    head_width: int,
    num_classes: int,
    overrides: Dict[str, object],
    activation: nn.Module,
) -> nn.Sequential:
    """Build the Dropout-Linear-LayerNorm-activation-Dropout-Linear head."""
    head_dropout = float(overrides.get("imagenet_head_dropout", 0.45))
    head_dropout = max(0.0, min(head_dropout, 0.9))
    pre_fc_dropout = max(0.0, min(head_dropout * 0.5, 0.9))
    return nn.Sequential(
        nn.Dropout(p=pre_fc_dropout),
        nn.Linear(in_features, head_width),
        nn.LayerNorm(head_width),
        activation,
        nn.Dropout(p=head_dropout),
        nn.Linear(head_width, num_classes),
    )


def _finetune_param_groups(
    backbone: nn.Module,
    head: nn.Module,
    adapt_params: List[nn.Parameter],
    backbone_lr_factor: float,
    overrides: Dict[str, object],
) -> List[Dict[str, object]]:
    """Freeze ``backbone`` and return param groups for the head and adapted layers."""
    for param in backbone.parameters():
        param.requires_grad = False
    weight_decay = overrides.get("imagenet_weight_decay", None)

    def trainable_group(params: List[nn.Parameter], scale: float) -> Dict[str, object]:
        for param in params:
            param.requires_grad = True
        group: Dict[str, object] = {"params": params, "scale": scale}
        if weight_decay is not None:
            group["weight_decay"] = float(weight_decay)
        return group

    groups = [trainable_group(list(head.parameters()), 1.0)]
    unique_adapt = _unique_parameters(adapt_params)
    if unique_adapt:
        groups.append(trainable_group(unique_adapt, backbone_lr_factor))
    return groups


def _batchnorm_params(module: nn.Module) -> List[nn.Parameter]:
    """Collect the parameters of every BatchNorm2d found inside ``module``."""
    collected: List[nn.Parameter] = []
    for child in module.modules():
        if isinstance(child, nn.BatchNorm2d):
            collected.extend(child.parameters())
    return collected


def _hidden_dims_tuple(value: object) -> Tuple[int, ...]:
    """Normalise a scalar or a sequence of widths into a tuple of ints."""
    if isinstance(value, Sequence) and not isinstance(value, str):
        return tuple(int(dim) for dim in value)
    return (int(value),)


def build_model(
    name: str,
    num_classes: int,
    checkpoint: Path,
    device: torch.device,
    trainable_layers: int,
    backbone_lr_factor: float,
    overrides: Dict[str, object] | None = None,
) -> Tuple[nn.Module, List[Dict[str, object]]]:
    """Instantiate a backbone by name and describe its optimiser param groups.

    Args:
        name: Case-insensitive model key: ``lwm``, ``resnet18``,
            ``efficientnet_b0``, ``mobilenet_v3_small``, ``simple_cnn`` /
            ``simplecnn`` or ``ieee_cnn`` / ``ieeecnn``.
        num_classes: Number of output logits.
        checkpoint: Pretrained LWM weights; only read for ``lwm``.
        device: Device the returned model is moved to.
        trainable_layers: Trailing LWM transformer layers left trainable
            (only used by ``lwm``).
        backbone_lr_factor: ``scale`` recorded for the backbone/adapted
            parameter group (the head group always uses ``scale`` 1.0).
        overrides: Optional per-model settings (head widths, dropout,
            hidden dims, weight decay).

    Returns:
        ``(model, param_groups)`` where each group is a dict holding
        ``params``, ``scale`` and optionally ``weight_decay``.

    Raises:
        FileNotFoundError: ``name`` is ``lwm`` and ``checkpoint`` is ``None``.
        ValueError: ``name`` is not a known model key.
    """
    name = name.lower()
    param_groups: List[Dict[str, object]] = []
    overrides = overrides or {}

    if name == "lwm":
        # Fail before paying for backbone construction.
        if checkpoint is None:
            raise FileNotFoundError("Checkpoint is required for LWM-based models")
        backbone = lwm_model(element_length=16, d_model=128, n_layers=12, max_len=1025, n_heads=8, dropout=0.1)
        try:
            state = torch.load(checkpoint, map_location="cpu", weights_only=True)
        except TypeError:
            # Older torch versions do not support weights_only
            state = torch.load(checkpoint, map_location="cpu")
        # strict=False: the checkpoint may carry extra or missing keys.
        backbone.load_state_dict(_lwm_backbone_state(state), strict=False)

        classifier_dim = int(overrides.get("lwm_classifier_dim", 96))
        head_dropout = float(overrides.get("lwm_head_dropout", 0.1))
        head_type = str(overrides.get("lwm_head_type", "mlp")).lower()
        model = LWMClassifier(
            backbone,
            trainable_layers=trainable_layers,
            num_classes=num_classes,
            classifier_dim=classifier_dim,
            head_dropout=head_dropout,
            head_type=head_type,
        )
        head_params = list(model.classifier.parameters())
        param_groups.append({"params": head_params, "scale": 1.0})

        if trainable_layers > 0:
            backbone_params: List[nn.Parameter] = []
            for layer in model.backbone.layers[-trainable_layers:]:
                backbone_params.extend(layer.parameters())
            backbone_params = _unique_parameters(backbone_params)
            if backbone_params:
                param_groups.append({"params": backbone_params, "scale": backbone_lr_factor})

    elif name == "resnet18":
        from torchvision import models

        backbone = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
        # Single-channel stem replaces the pretrained RGB stem (fresh init).
        backbone.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
        nn.init.kaiming_normal_(backbone.conv1.weight, mode='fan_out', nonlinearity='relu')
        in_features = backbone.fc.in_features
        head_width = int(overrides.get("resnet_head_width", 384))
        backbone.fc = _build_finetune_head(
            in_features, head_width, num_classes, overrides, nn.ReLU(inplace=True)
        )

        # Adapt the layer4 downsample path plus the norms of the last block.
        adapt_params: List[nn.Parameter] = []
        downsample = getattr(backbone.layer4[0], "downsample", None)
        if downsample is not None:
            adapt_params.extend(downsample[0].parameters())
            if len(downsample) > 1:
                adapt_params.extend(downsample[1].parameters())
        adapt_params.extend(_batchnorm_params(backbone.layer4[-1]))

        param_groups.extend(
            _finetune_param_groups(backbone, backbone.fc, adapt_params, backbone_lr_factor, overrides)
        )
        model = backbone

    elif name == "efficientnet_b0":
        from torchvision import models

        backbone = models.efficientnet_b0(weights=models.EfficientNet_B0_Weights.DEFAULT)
        first_conv = backbone.features[0][0]
        backbone.features[0][0] = nn.Conv2d(1, first_conv.out_channels, kernel_size=3, stride=2, padding=1, bias=False)
        nn.init.kaiming_normal_(backbone.features[0][0].weight, mode='fan_out', nonlinearity='relu')
        in_features = backbone.classifier[-1].in_features
        head_width = int(overrides.get("efficientnet_head_width", 192))
        backbone.classifier = _build_finetune_head(
            in_features, head_width, num_classes, overrides, nn.ReLU(inplace=True)
        )

        final_block = backbone.features[7][0]
        # Depthwise conv + associated norms for the last MBConv block
        adapt_params = list(final_block.block[1][0].parameters())
        for idx in (0, 1, 3):
            adapt_params.extend(_batchnorm_params(final_block.block[idx]))

        param_groups.extend(
            _finetune_param_groups(backbone, backbone.classifier, adapt_params, backbone_lr_factor, overrides)
        )
        model = backbone

    elif name == "mobilenet_v3_small":
        from torchvision import models

        backbone = models.mobilenet_v3_small(weights=models.MobileNet_V3_Small_Weights.DEFAULT)
        backbone.features[0][0] = nn.Conv2d(1, 16, kernel_size=3, stride=2, padding=1, bias=False)
        nn.init.kaiming_normal_(backbone.features[0][0].weight, mode='fan_out', nonlinearity='relu')
        # Probe the feature width with a dummy batch. The probe runs in eval
        # mode: no_grad alone does not stop BatchNorm from updating its
        # pretrained running statistics with the all-zero input.
        was_training = backbone.training
        backbone.eval()
        with torch.no_grad():
            dummy = torch.zeros(1, 1, 128, 128)
            features = backbone.features(dummy)
            pooled = backbone.avgpool(features)
            in_features = torch.flatten(pooled, 1).shape[1]
        backbone.train(was_training)
        head_width = int(overrides.get("mobilenet_head_width", 320))
        backbone.classifier = _build_finetune_head(
            in_features, head_width, num_classes, overrides, nn.Hardswish()
        )

        adapt_params = list(backbone.features[-1][0].parameters())
        adapt_params.extend(_batchnorm_params(backbone.features[-2]))

        param_groups.extend(
            _finetune_param_groups(backbone, backbone.classifier, adapt_params, backbone_lr_factor, overrides)
        )
        model = backbone

    elif name in {"simple_cnn", "simplecnn"}:
        simple_dims = _hidden_dims_tuple(overrides.get("simple_cnn_hidden_dims", (192,)))
        model = create_simple_cnn(num_classes, hidden_dims=simple_dims)
        # Trained from scratch: every parameter belongs to the single group.
        param_groups.append({"params": list(model.parameters()), "scale": 1.0})

    elif name in {"ieee_cnn", "ieeecnn"}:
        ieee_dims = _hidden_dims_tuple(overrides.get("ieee_cnn_hidden_dims", (512, 256)))
        dropout = float(overrides.get("ieee_cnn_dropout", 0.3))
        model = create_ieee_cnn(num_classes, hidden_dims=ieee_dims, dropout=dropout)
        param_groups.append({"params": list(model.parameters()), "scale": 1.0})

    else:
        raise ValueError(f"Unknown model: {name}")

    return model.to(device), param_groups
def train_one_epoch(model, loader, optimizer, device, scaler=None, batch_normalize: bool = False):
    """Run one optimisation pass over ``loader`` and return the mean loss.

    Args:
        model: Network to train; it is switched to training mode.
        loader: Iterable of ``(specs, labels)`` or ``(specs, stats, labels)``
            batches, the same two layouts that ``evaluate`` accepts.
        optimizer: Optimiser stepped once per batch.
        device: Target device for the batch tensors.
        scaler: Optional gradient scaler; mixed precision is used only when a
            scaler is given and ``device`` is CUDA.
        batch_normalize: When true, ``normalize_batch`` is applied to the
            spectrograms after they are moved to ``device``.

    Returns:
        Sample-weighted average cross-entropy loss (0.0 for an empty loader).
    """
    criterion = nn.CrossEntropyLoss(reduction='mean')
    model.train()
    use_amp = scaler is not None and device.type == 'cuda'
    total_loss = 0.0
    total = 0
    for batch in loader:
        # Accept the optional per-sample statistics tensor instead of
        # failing on three-element batches.
        if isinstance(batch, (list, tuple)) and len(batch) == 3:
            specs, stats_batch, labels = batch
            stats_batch = stats_batch.to(device, non_blocking=True)
        else:
            specs, labels = batch
            stats_batch = None
        specs = specs.to(device, non_blocking=True)
        if batch_normalize:
            specs = normalize_batch(specs)
        labels = labels.to(device, non_blocking=True)
        optimizer.zero_grad(set_to_none=True)

        # Use autocast only for CUDA
        if use_amp:
            with autocast(device_type='cuda'):
                logits = _model_forward(model, specs, stats_batch)
                loss = criterion(logits, labels)
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        else:
            # HPU and CPU use standard forward/backward
            logits = _model_forward(model, specs, stats_batch)
            loss = criterion(logits, labels)
            loss.backward()
            optimizer.step()

        batch_size = labels.size(0)
        # The criterion returns a per-batch mean; re-weight by batch size so
        # a smaller final batch does not skew the epoch average.
        total_loss += loss.item() * batch_size
        total += batch_size

        # Release cached CUDA blocks to limit memory fragmentation.
        if device.type == 'cuda':
            torch.cuda.empty_cache()

    return total_loss / max(total, 1)
model.eval() - total_loss = 0.0 - correct = 0 - total = 0 - all_preds: List[np.ndarray] = [] - all_labels: List[np.ndarray] = [] - - debug_batches = int(debug.get("log_batches", 0)) if debug else 0 - debug_every = max(1, int(debug.get("log_every", 1))) if debug else 1 - log_softmax = bool(debug.get("log_softmax", False)) if debug else False - debug_logged = 0 - - for batch_idx, batch in enumerate(loader, start=1): - stats_batch: Optional[torch.Tensor] - if isinstance(batch, (list, tuple)) and len(batch) == 3: - specs, stats_batch, labels = batch - stats_batch = stats_batch.to(device, non_blocking=True) - else: - specs, labels = batch # type: ignore[misc] - stats_batch = None - specs = specs.to(device, non_blocking=True) - if batch_normalize: - specs = normalize_batch(specs) - labels = labels.to(device, non_blocking=True) - - # Use autocast only for CUDA, not for HPU or CPU - if device.type == 'cuda': - context = autocast(device_type='cuda') - else: - context = nullcontext() - - with context: - logits = _model_forward(model, specs, stats_batch) - loss = criterion(logits, labels) - - preds = logits.argmax(dim=1) - total_loss += loss.item() * labels.size(0) - correct += (preds == labels).sum().item() - total += labels.size(0) - all_preds.append(preds.detach().cpu().numpy()) - all_labels.append(labels.detach().cpu().numpy()) - - should_log = ( - debug_batches > 0 - and debug_logged < debug_batches - and ((batch_idx - 1) % debug_every == 0) - ) - if should_log: - specs_cpu = specs.detach().cpu() - logits_cpu = logits.detach().cpu() - loss_scalar = float(loss.detach().cpu().item()) - finite_specs = torch.isfinite(specs).all().item() - finite_logits = torch.isfinite(logits).all().item() - print( - f" [DEBUG][eval][batch {batch_idx}] loss={loss_scalar:.6f} " - f"reduction={criterion.reduction} labels_shape={tuple(labels.shape)}" - ) - print( - f" specs dtype={specs.dtype} mean={specs_cpu.mean():.4f} std={specs_cpu.std():.4f} " - f"min={specs_cpu.min():.4f} 
def set_seed(seed: int) -> None:
    """Seed the Python, NumPy and PyTorch RNGs, plus CUDA/HPU when present."""
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)

    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)

    # HPU seed setting (if available); HPU_AVAILABLE is tested first so that
    # torch.hpu is only touched when the Habana backend was detected.
    hpu_seedable = HPU_AVAILABLE and hasattr(torch.hpu, "manual_seed")
    if hpu_seedable:
        torch.hpu.manual_seed(seed)
str(stats.get("normalization", "dataset")).lower() - print(f"[INFO] Normalization mode from stats: {normalization_mode}") - - data_root = Path(args.data_root) - available_snrs, available_mobilities = discover_snr_mobility( - data_root, args.cities, args.comm_types, args.fft_folder - ) - train_snrs = args.snrs if args.snrs else available_snrs - train_mobilities = args.mobilities if args.mobilities else available_mobilities - val_snrs = args.val_snrs if args.val_snrs else available_snrs - val_mobilities = args.val_mobilities if args.val_mobilities else available_mobilities - - class_configs = build_config_map( - data_root, args.cities, args.comm_types, train_snrs, train_mobilities, args.fft_folder - ) - active_labels = [cls for cls, configs in class_configs.items() if any(configs.values())] - if not active_labels: - raise RuntimeError("No modulation classes found with the provided filters.") - class_configs = {cls: class_configs[cls] for cls in active_labels} - label_to_local = {cls: idx for idx, cls in enumerate(sorted(active_labels))} - num_classes = len(active_labels) - print("[INFO] Active modulation classes:", ", ".join(LABEL_NAMES.get(cls, str(cls)) for cls in sorted(active_labels))) - config_arrays = load_config_arrays(class_configs) - global_config_map = build_config_map( - Path(args.data_root), args.cities, args.comm_types, val_snrs, val_mobilities, args.fft_folder - ) - global_config_arrays = load_config_arrays(global_config_map) - print("[INFO] Training SNRs:", ", ".join(train_snrs)) - print("[INFO] Training mobilities:", ", ".join(train_mobilities)) - print("[INFO] Validation/Test SNRs:", ", ".join(val_snrs)) - print("[INFO] Validation/Test mobilities:", ", ".join(val_mobilities)) - - per_class_totals: Dict[int, int] = {} - for cls in sorted(active_labels): - configs = config_arrays[cls] - total_samples = sum(arr.shape[0] for arr in configs.values()) - per_class_totals[cls] = total_samples - print(f"[INFO] Class {LABEL_NAMES.get(cls, str(cls))}: 
{len(configs)} configs, {total_samples} samples") - if cls not in global_config_arrays or not global_config_arrays[cls]: - raise RuntimeError(f"No global data found for modulation {LABEL_NAMES.get(cls, str(cls))}") - - min_class_total = min(per_class_totals.values()) - max_train_per_class = min_class_total - args.val_per_class - args.test_per_class - if max_train_per_class <= 0: - raise RuntimeError( - "Requested val/test splits leave no data for training. " - f"Minimum class has {min_class_total} samples; " - f"val={args.val_per_class}, test={args.test_per_class}." - ) - if any(size > max_train_per_class for size in args.train_sizes): - adjusted: List[int] = [] - for size in args.train_sizes: - if size > max_train_per_class: - print( - f"[WARN] Requested train size {size} exceeds available " - f"{max_train_per_class} per class after val/test splits; capping." - ) - capped = min(size, max_train_per_class) - if capped not in adjusted: - adjusted.append(capped) - args.train_sizes = adjusted if adjusted else [max_train_per_class] - print(f"[INFO] Effective train sizes per class: {args.train_sizes}") - - # Device selection: auto, cuda, hpu, or cpu - requested_device = args.device.lower() - - if requested_device == "auto": - if HPU_AVAILABLE: - requested_device = "hpu" - elif torch.cuda.is_available(): - requested_device = "cuda" - else: - requested_device = "cpu" - - # Setup device based on selection - if requested_device == "hpu": - if not HPU_AVAILABLE: - raise RuntimeError( - "HPU device requested but not available. " - "Install Habana PyTorch or select --device cuda/cpu." 
- ) - device = torch.device("hpu") - # Set HPU device (typically device 0 for single-process) - if hasattr(torch.hpu, "set_device"): - torch.hpu.set_device(0) - print(f"[INFO] Using HPU device") - active_gpu_ids = [] # Not applicable for HPU - multi_gpu = False - - elif requested_device == "cuda": - cuda_available = torch.cuda.is_available() - if not cuda_available: - raise RuntimeError("CUDA device requested but not available.") - - available_gpu_ids = list(range(torch.cuda.device_count())) - if args.gpu_ids is not None: - invalid_ids = [gpu_id for gpu_id in args.gpu_ids if gpu_id not in available_gpu_ids] - if invalid_ids: - raise ValueError( - f"Requested GPU IDs not available: {invalid_ids}; available: {available_gpu_ids}" - ) - active_gpu_ids = list(dict.fromkeys(args.gpu_ids)) - else: - active_gpu_ids = available_gpu_ids - - if active_gpu_ids: - primary_gpu = active_gpu_ids[0] - torch.cuda.set_device(primary_gpu) - device = torch.device(f"cuda:{primary_gpu}") - print(f"[INFO] Using CUDA device(s): {', '.join(str(i) for i in active_gpu_ids)}") - else: - device = torch.device("cpu") - print("[INFO] CUDA requested but no GPUs available, using CPU") - - multi_gpu = len(active_gpu_ids) > 1 - if multi_gpu: - print(f"[INFO] Enabling DataParallel across GPUs: {', '.join(str(i) for i in active_gpu_ids)}") - - else: # cpu - device = torch.device("cpu") - if args.gpu_ids is not None: - print("[WARN] GPU IDs specified but using CPU") - print("[INFO] Using CPU") - active_gpu_ids = [] - multi_gpu = False - - print(f"[INFO] Using device: {device}") - args.output_dir.mkdir(parents=True, exist_ok=True) - print(f"[INFO] Saving outputs under: {args.output_dir}") - - eval_debug_config: Optional[Dict[str, object]] = None - if args.debug_eval_batches > 0: - eval_debug_config = { - "log_batches": int(args.debug_eval_batches), - "log_every": max(1, int(args.debug_eval_interval)), - "log_softmax": bool(args.debug_eval_softmax), - } - print( - "[INFO] Evaluation debug logging enabled 
-> batches:" - f" {eval_debug_config['log_batches']}, interval: {eval_debug_config['log_every']}" - ) - - summary_device = torch.device("cpu") if device.type == "cuda" else device - model_overrides: Dict[str, object] = { - "resnet_head_width": args.resnet_head_width, - "efficientnet_head_width": args.efficientnet_head_width, - "mobilenet_head_width": args.mobilenet_head_width, - "simple_cnn_hidden_dims": tuple(args.simple_cnn_hidden_dims), - "ieee_cnn_hidden_dims": tuple(args.ieee_cnn_hidden_dims), - "ieee_cnn_dropout": args.ieee_cnn_dropout, - "lwm_classifier_dim": args.lwm_classifier_dim, - "lwm_head_dropout": args.lwm_head_dropout, - "lwm_head_type": args.lwm_head_type, - "imagenet_head_dropout": args.imagenet_head_dropout, - "imagenet_weight_decay": args.weight_decay * args.imagenet_weight_decay_scale, - } - print("\n[INFO] Parameter counts per model (total/trainable):") - for model_name in args.models: - lower_name = model_name.lower() - trainable_layers = args.lwm_trainable_layers if lower_name == "lwm" else 0 - model_checkpoint = checkpoint - backbone_lr_factor = args.backbone_lr_factor - if lower_name == "lwm" and args.lwm_backbone_lr_factor is not None: - backbone_lr_factor = args.lwm_backbone_lr_factor - model, _ = build_model( - model_name, - num_classes, - model_checkpoint, - summary_device, - trainable_layers=trainable_layers, - backbone_lr_factor=backbone_lr_factor, - overrides=model_overrides, - ) - total_params = sum(p.numel() for p in model.parameters()) - trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad) - print(f" {model_name}: {total_params:,} / {trainable_params:,}") - del model - if device.type == 'cuda': - torch.cuda.empty_cache() - - raw_input_models = set(args.raw_input_models) - active_raw_models = [model for model in args.models if model.lower() in raw_input_models] - if active_raw_models: - print( - "[INFO] Raw spectrogram input (per-batch normalization) for models: " - + ", ".join(active_raw_models) - ) - 
- normalized_models = [model for model in args.models if model.lower() not in raw_input_models] - requires_normalized_inputs = len(normalized_models) > 0 - if requires_normalized_inputs: - print( - "[INFO] Applying normalization for models: " - + ", ".join(normalized_models) - ) - else: - print("[INFO] All selected models consume raw spectrograms; normalization skipped") - - summary: Dict[str, Dict[int, Dict[str, List[float]]]] = { - model: {size: {"acc": [], "f1": [], "val_f1": [], "val_loss": []} for size in args.train_sizes} - for model in args.models - } - - train_sizes_sorted = sorted(args.train_sizes) - - for repetition in range(1, args.repetitions + 1): - selection_for_repetition: Dict[int, Dict[str, set[int]]] = {} - val_rng_seed = args.seed + repetition * 100000 - val_rng = np.random.default_rng(val_rng_seed) - fixed_val_samples: Dict[int, np.ndarray] = {} - fixed_test_samples: Dict[int, np.ndarray] = {} - val_reserved_indices: Dict[int, Dict[str, set[int]]] = {} - test_reserved_indices: Dict[int, Dict[str, set[int]]] = {} - - for cls in sorted(config_arrays.keys()): - global_arrays = global_config_arrays[cls] - global_avail = {cfg: set(range(arr.shape[0])) for cfg, arr in global_arrays.items()} - val_samples, val_used = sample_global_arrays(global_arrays, global_avail, args.val_per_class, val_rng) - test_samples, test_used = sample_global_arrays(global_arrays, global_avail, args.test_per_class, val_rng) - fixed_val_samples[cls] = val_samples - fixed_test_samples[cls] = test_samples - val_reserved_indices[cls] = {cfg: set(indices) for cfg, indices in val_used.items()} - test_reserved_indices[cls] = {cfg: set(indices) for cfg, indices in test_used.items()} - - for train_size in train_sizes_sorted: - rep_seed = args.seed + train_size * 1000 + repetition - rng = np.random.default_rng(rep_seed) - - repetition_records: List[Dict[str, object]] = [] - per_size_val_metrics: List[Tuple[str, float, float, float]] = [] - - train_specs, train_labels = [], [] - 
val_specs, val_labels = [], [] - test_specs, test_labels = [], [] - - class_contexts: Dict[int, Dict[str, Any]] = {} - class_capacities: Dict[int, int] = {} - for cls in sorted(config_arrays.keys()): - arrays_map = config_arrays[cls] - if not arrays_map: - raise RuntimeError(f"No data for class {LABEL_NAMES[cls]}") - - if cls not in selection_for_repetition: - selection_for_repetition[cls] = defaultdict(set) - prev_config_indices = selection_for_repetition[cls] - - prev_total = sum(len(sel_indices) for sel_indices in prev_config_indices.values()) - if prev_total > train_size: - raise ValueError( - f"Requested train size {train_size} is smaller than previously selected {prev_total} " - f"for class {LABEL_NAMES[cls]}" - ) - - train_avail = {config: set(range(arr.shape[0])) for config, arr in arrays_map.items()} - for config, sel_indices in prev_config_indices.items(): - if sel_indices and config in train_avail: - train_avail[config].difference_update(sel_indices) - val_reserved = val_reserved_indices.get(cls, {}) - test_reserved = test_reserved_indices.get(cls, {}) - for config, reserved in val_reserved.items(): - if reserved and config in train_avail: - train_avail[config].difference_update(reserved) - for config, reserved in test_reserved.items(): - if reserved and config in train_avail: - train_avail[config].difference_update(reserved) - - available_now = sum(len(indices) for indices in train_avail.values()) - capacity = prev_total + available_now - class_contexts[cls] = { - "arrays_map": arrays_map, - "prev_indices": prev_config_indices, - "train_avail": train_avail, - } - class_capacities[cls] = capacity - - if not class_capacities: - raise RuntimeError("No modulation classes available for training") - - min_capacity = min(class_capacities.values()) - limiting_classes = sorted(cls for cls, cap in class_capacities.items() if cap == min_capacity) - effective_train_size = min(train_size, min_capacity) - if effective_train_size < train_size: - limiting_labels = ", 
".join(LABEL_NAMES.get(cls, str(cls)) for cls in limiting_classes) - if not limiting_labels: - limiting_labels = "unknown" - print( - f"[WARN] Requested train size {train_size} exceeds available " - f"{min_capacity} after reserving val/test samples; using {effective_train_size} " - f"(limited by {limiting_labels})" - ) - if effective_train_size <= 0: - raise RuntimeError("No training samples available after reserving val/test splits") - - for cls in sorted(config_arrays.keys()): - ctx = class_contexts[cls] - arrays_map = ctx["arrays_map"] - prev_config_indices = ctx["prev_indices"] - train_avail = ctx["train_avail"] - - selected_arrays: List[np.ndarray] = [] - prev_total = 0 - for config, sel_indices in prev_config_indices.items(): - if sel_indices: - idx_sorted = sorted(sel_indices) - selected_arrays.append(arrays_map[config][idx_sorted]) - prev_total += len(sel_indices) - - needed = max(effective_train_size - prev_total, 0) - if needed > 0: - additional_samples, train_used = sample_train_arrays(arrays_map, train_avail, needed, rng) - if additional_samples.size == 0: - raise RuntimeError("Failed to collect additional training samples") - selected_arrays.append(additional_samples) - for config, indices in train_used.items(): - prev_config_indices[config].update(int(idx) for idx in indices) - - if not selected_arrays: - raise RuntimeError("No training samples collected") - - train_samples = np.concatenate(selected_arrays, axis=0) - if train_samples.shape[0] != effective_train_size: - print( - f"[WARN] Collected {train_samples.shape[0]} training samples for " - f"{LABEL_NAMES.get(cls, str(cls))}, expected {effective_train_size}" - ) - - val_samples = fixed_val_samples[cls] - test_samples = fixed_test_samples[cls] - train_specs.append(train_samples) - val_specs.append(val_samples) - test_specs.append(test_samples) - local_label = label_to_local[cls] - train_labels.append(np.full(train_samples.shape[0], local_label, dtype=np.int64)) - 
val_labels.append(np.full(val_samples.shape[0], local_label, dtype=np.int64)) - test_labels.append(np.full(test_samples.shape[0], local_label, dtype=np.int64)) - - train_specs_raw = np.concatenate(train_specs) - val_specs_raw = np.concatenate(val_specs) - test_specs_raw = np.concatenate(test_specs) - train_labels = np.concatenate(train_labels) - val_labels = np.concatenate(val_labels) - test_labels = np.concatenate(test_labels) - - # Verify no data leakage (all splits are disjoint) - # Note: Since we sample from different configs with availability tracking, - # there should be no overlap, but we verify to be safe - print( - f"[INFO] Verifying data splits for train_size={train_size} " - f"(effective {effective_train_size}), rep={repetition}..." - ) - print( - f" Train: {len(train_labels)} samples " - f"(~{effective_train_size} per class expected)" - ) - print(f" Val: {len(val_labels)} samples ({args.val_per_class} per class)") - print(f" Test: {len(test_labels)} samples ({args.test_per_class} per class)") - - # Check class balance - train_class_counts = Counter(train_labels) - val_class_counts = Counter(val_labels) - test_class_counts = Counter(test_labels) - - print(f"[INFO] Train class distribution: {dict(train_class_counts)}") - print(f"[INFO] Val class distribution: {dict(val_class_counts)}") - print(f"[INFO] Test class distribution: {dict(test_class_counts)}") - - # Verify all classes have expected counts - expected_train_per_class = effective_train_size - for cls_idx in range(num_classes): - if train_class_counts[cls_idx] != expected_train_per_class: - print(f"[WARN] Class {cls_idx} has {train_class_counts[cls_idx]} train samples, expected {expected_train_per_class}") - if val_class_counts[cls_idx] != args.val_per_class: - print(f"[WARN] Class {cls_idx} has {val_class_counts[cls_idx]} val samples, expected {args.val_per_class}") - if test_class_counts[cls_idx] != args.test_per_class: - print(f"[WARN] Class {cls_idx} has {test_class_counts[cls_idx]} test 
samples, expected {args.test_per_class}") - - print(f"[INFO] ✓ All splits have balanced class distribution") - - train_ds_raw = SpectrogramDataset(train_specs_raw, train_labels) - val_ds_raw = SpectrogramDataset(val_specs_raw, val_labels) - test_ds_raw = SpectrogramDataset(test_specs_raw, test_labels) - - train_loader_raw = DataLoader( - train_ds_raw, - batch_size=args.batch_size, - shuffle=True, - num_workers=2, - pin_memory=False, - ) - val_loader_raw = DataLoader( - val_ds_raw, - batch_size=args.batch_size, - shuffle=False, - num_workers=2, - pin_memory=False, - ) - test_loader_raw = DataLoader( - test_ds_raw, - batch_size=args.batch_size, - shuffle=False, - num_workers=2, - pin_memory=False, - ) - - train_loader_normalized: Optional[DataLoader] = None - val_loader_normalized: Optional[DataLoader] = None - test_loader_normalized: Optional[DataLoader] = None - - if requires_normalized_inputs: - train_specs_normalized = apply_normalization(train_specs_raw, stats) - val_specs_normalized = apply_normalization(val_specs_raw, stats) - test_specs_normalized = apply_normalization(test_specs_raw, stats) - - train_ds_normalized = SpectrogramDataset(train_specs_normalized, train_labels) - val_ds_normalized = SpectrogramDataset(val_specs_normalized, val_labels) - test_ds_normalized = SpectrogramDataset(test_specs_normalized, test_labels) - - train_loader_normalized = DataLoader( - train_ds_normalized, - batch_size=args.batch_size, - shuffle=True, - num_workers=2, - pin_memory=False, - ) - val_loader_normalized = DataLoader( - val_ds_normalized, - batch_size=args.batch_size, - shuffle=False, - num_workers=2, - pin_memory=False, - ) - test_loader_normalized = DataLoader( - test_ds_normalized, - batch_size=args.batch_size, - shuffle=False, - num_workers=2, - pin_memory=False, - ) - - rep_root = args.output_dir / f"size_{train_size}" / f"rep_{repetition}" - rep_root.mkdir(parents=True, exist_ok=True) - - for model_name in args.models: - model_root = rep_root / model_name - 
model_root.mkdir(parents=True, exist_ok=True) - epoch_ckpt_dir: Optional[Path] = None - if args.save_epoch_checkpoints: - epoch_ckpt_dir = model_root / "epoch_checkpoints" - epoch_ckpt_dir.mkdir(parents=True, exist_ok=True) - for old_path in epoch_ckpt_dir.glob("epoch_*.pth"): - if old_path.is_file(): - old_path.unlink() - print( - f"\n[INFO] Size {train_size} (effective {effective_train_size}), " - f"repetition {repetition}, model {model_name}" - ) - set_seed(rep_seed + hash(model_name) % 1000) - lower_name = model_name.lower() - use_raw_input = lower_name in raw_input_models - if use_raw_input: - train_loader = train_loader_raw - val_loader = val_loader_raw - test_loader = test_loader_raw - print(" [INFO] Feeding raw spectrograms with per-batch normalization") - else: - if ( - train_loader_normalized is None - or val_loader_normalized is None - or test_loader_normalized is None - ): - raise RuntimeError( - "Normalized loaders were requested but could not be constructed." - ) - train_loader = train_loader_normalized - val_loader = val_loader_normalized - test_loader = test_loader_normalized - - trainable_layers = args.lwm_trainable_layers if lower_name == "lwm" else 0 - backbone_lr_factor = args.backbone_lr_factor - if lower_name == "lwm" and args.lwm_backbone_lr_factor is not None: - backbone_lr_factor = args.lwm_backbone_lr_factor - model_checkpoint = checkpoint - model, param_groups = build_model( - model_name, - num_classes, - model_checkpoint, - device, - trainable_layers=trainable_layers, - backbone_lr_factor=backbone_lr_factor, - overrides=model_overrides, - ) - if multi_gpu: - model = nn.DataParallel(model, device_ids=active_gpu_ids) - total_params = sum(p.numel() for p in model.parameters()) - trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad) - print( - f"[INFO] Parameters (total/trainable): {total_params:,} / {trainable_params:,}" - ) - def make_optimizer(base_lr: float) -> torch.optim.Optimizer: - optim_groups: 
List[Dict[str, object]] = [] - if param_groups: - for group in param_groups: - scale = float(group.get("scale", 1.0)) - params = [p for p in group.get("params", []) if p.requires_grad] - if params: - group_cfg: Dict[str, object] = { - "params": list(params), - "lr": base_lr * scale, - } - if "weight_decay" in group: - group_cfg["weight_decay"] = float(group["weight_decay"]) - optim_groups.append(group_cfg) - if not optim_groups: - optim_groups.append({ - "params": [p for p in model.parameters() if p.requires_grad], - "lr": base_lr, - }) - return torch.optim.AdamW(optim_groups, lr=base_lr, weight_decay=args.weight_decay) - - def make_scheduler(optimizer: torch.optim.Optimizer, base_lr: float, patience_limit: int): - plateau_patience = max(2, patience_limit // 2) - return torch.optim.lr_scheduler.ReduceLROnPlateau( - optimizer, - mode="min", - factor=0.5, - patience=plateau_patience, - min_lr=base_lr * 0.01, - ) - - # Initialize mixed precision scaler for CUDA - # GradScaler only for CUDA, not for HPU or CPU - scaler = GradScaler('cuda') if device.type == 'cuda' else None - best_val_loss = float("inf") - best_val_acc = 0.0 - best_state = None - epoch_history: List[Dict[str, object]] = [] - best_val_f1 = 0.0 - best_epoch = 0 - total_epochs_ran = 0 - overall_early_stopped = False - - phase_configs = [ - { - "name": "main", - "max_epochs": args.epochs, - "base_lr": args.lr, - "patience": max(1, args.early_patience), - "min_epochs": max(0, args.early_min_epochs), - } - ] - - ft_epochs = max(0, args.finetune_epochs) - ft_lr_factor = args.finetune_lr_factor - ft_patience = max(1, args.finetune_patience) - ft_min_epochs = max(0, args.finetune_min_epochs) - - if ft_epochs > 0: - phase_configs.append( - { - "name": "finetune", - "max_epochs": ft_epochs, - "base_lr": args.lr * ft_lr_factor, - "patience": ft_patience, - "min_epochs": ft_min_epochs, - } - ) - - for phase_idx, phase in enumerate(phase_configs): - if phase["max_epochs"] <= 0: - continue - - if phase_idx > 0: - 
print( - f"\n [INFO] Starting {phase['name']} phase: lr={phase['base_lr']:.2e}, " - f"max_epochs={phase['max_epochs']}" - ) - if best_state is not None: - model.load_state_dict(best_state["model"]) - - optimizer = make_optimizer(phase["base_lr"]) - scheduler = make_scheduler(optimizer, phase["base_lr"], phase["patience"]) - patience_counter = 0 - phase_early_stopped = False - phase_epochs_completed = 0 - phase_min_epochs = max(0, phase["min_epochs"]) - - for local_epoch in range(1, phase["max_epochs"] + 1): - overall_epoch = total_epochs_ran + local_epoch - train_loss = train_one_epoch( - model, - train_loader, - optimizer, - device, - scaler, - batch_normalize=use_raw_input, - ) - val_loss, val_acc, val_f1 = evaluate( - model, - val_loader, - device, - eval_debug_config, - batch_normalize=use_raw_input, - ) - scheduler.step(val_loss) - current_lr = optimizer.param_groups[0]["lr"] - print( - f" [{phase['name']}] Epoch {overall_epoch:02d}: " - f"train_loss={train_loss:.4f} val_loss={val_loss:.4f} " - f"val_acc={val_acc:.4%} val_f1={val_f1:.4f}" - ) - epoch_history.append( - { - "epoch": int(overall_epoch), - "train_loss": float(train_loss), - "val_loss": float(val_loss), - "val_acc": float(val_acc), - "val_f1": float(val_f1), - "lr": float(current_lr), - "phase": phase["name"], - } - ) - repetition_records.append( - { - "model": model_name, - "epoch": int(overall_epoch), - "phase": phase["name"], - "train_loss": float(train_loss), - "val_loss": float(val_loss), - "val_acc": float(val_acc), - "val_f1": float(val_f1), - "lr": float(current_lr), - "train_size_requested": int(train_size), - "train_size_effective": int(effective_train_size), - } - ) - _write_epoch_history(rep_root, repetition_records, args.save_epoch_history, args.plot_epoch_history) - raw_epoch_state = _strip_module_prefix(model.state_dict()) - if epoch_ckpt_dir is not None: - epoch_state = raw_epoch_state.__class__() - for key, value in raw_epoch_state.items(): - epoch_state[key] = value.detach().cpu() 
- epoch_ckpt_path = epoch_ckpt_dir / f"epoch_{overall_epoch:03d}.pth" - torch.save(epoch_state, epoch_ckpt_path) - if val_loss < best_val_loss: - best_val_loss = val_loss - best_val_acc = val_acc - best_val_f1 = val_f1 - best_model_state = { - key: value.detach().cpu() - for key, value in model.state_dict().items() - } - best_state = { - "model": best_model_state, - "val_loss": val_loss, - "val_acc": val_acc, - "val_f1": val_f1, - "epoch": int(overall_epoch), - "lr": current_lr, - "phase": phase["name"], - } - best_epoch = int(overall_epoch) - patience_counter = 0 - else: - if local_epoch >= phase_min_epochs: - patience_counter += 1 - if patience_counter >= phase["patience"]: - print( - f" [INFO] Early stopping ({phase['name']}) at epoch {overall_epoch:02d} " - f"after {patience_counter} epochs without val loss improvement" - ) - overall_early_stopped = True - phase_early_stopped = True - phase_epochs_completed = local_epoch - break - phase_epochs_completed = local_epoch - - total_epochs_ran += phase_epochs_completed - - if phase_early_stopped is False and phase_epochs_completed < phase["max_epochs"]: - # Loop exited early via break without setting the flag (should not happen) - phase_early_stopped = True - - if best_state is None: - raise RuntimeError("Training finished without recording a validation improvement") - - model.load_state_dict(best_state["model"]) - test_loss, test_acc, test_f1 = evaluate( - model, - test_loader, - device, - eval_debug_config, - batch_normalize=use_raw_input, - ) - print( - f" -> Test loss={test_loss:.4f} Test acc={test_acc:.4%} Test f1={test_f1:.4f}" - ) - - export_dir = getattr(args, "export_full_model", None) - if export_dir is not None: - export_dir = export_dir.expanduser().resolve() - export_dir.mkdir(parents=True, exist_ok=True) - comm_token = getattr(args, "comm_suffix", "multi") - filename = f"{comm_token}_{model_name}_size{train_size}_rep{repetition}.pth" - export_path = export_dir / filename - full_state = {k: 
v.detach().cpu() for k, v in model.state_dict().items()} - torch.save(full_state, export_path) - print(f" [INFO] Saved full model (backbone + head) to {export_path}") - - summary[model_name][train_size]["acc"].append(test_acc) - summary[model_name][train_size]["f1"].append(test_f1) - summary[model_name][train_size]["val_f1"].append(best_state["val_f1"]) - summary[model_name][train_size]["val_loss"].append(best_state["val_loss"]) - per_size_val_metrics.append( - (model_name, best_state["val_f1"], best_state["val_loss"], test_f1) - ) - - result_dir = args.output_dir / f"size_{train_size}" / f"rep_{repetition}" / model_name - result_dir.mkdir(parents=True, exist_ok=True) - state_to_save = copy.deepcopy(best_state) - state_to_save["model"] = _strip_module_prefix(state_to_save["model"]) - torch.save(state_to_save, result_dir / "checkpoint.pt") - with open(result_dir / "metrics.json", "w", encoding="utf-8") as f: - json.dump( - { - "train_size_per_class": effective_train_size, - "train_size_per_class_requested": train_size, - "repetition": repetition, - "model": model_name, - "best_val_loss": best_state.get("val_loss", None), - "best_val_acc": best_val_acc, - "best_val_f1": best_state["val_f1"], - "test_loss": test_loss, - "test_acc": test_acc, - "test_f1": test_f1, - "best_epoch": best_epoch, - "epochs_ran": total_epochs_ran, - "early_stopped": overall_early_stopped, - "history": epoch_history, - }, - f, - indent=2, - ) - - rep_root = args.output_dir / f"size_{train_size}" / f"rep_{repetition}" - rep_root.mkdir(parents=True, exist_ok=True) - - if args.plot_epoch_history and HAVE_MPL and repetition_records: - models_in_run = sorted({rec["model"] for rec in repetition_records}) - fig, axes = plt.subplots(2, 1, figsize=(8, 6), sharex=True) - for ax in axes: - ax.grid(True, linestyle='--', alpha=0.3) - for model_name_plot in models_in_run: - model_records = [rec for rec in repetition_records if rec["model"] == model_name_plot] - if not model_records: - continue - epochs = 
[rec["epoch"] for rec in model_records] - val_loss_values = [rec["val_loss"] for rec in model_records] - val_f1_values = [rec["val_f1"] for rec in model_records] - axes[0].plot(epochs, val_loss_values, marker='o', label=model_name_plot) - axes[1].plot(epochs, val_f1_values, marker='o', label=model_name_plot) - axes[0].set_ylabel('Val Loss') - axes[1].set_ylabel('Val F1') - axes[1].set_xlabel('Epoch') - axes[0].legend(loc='best') - axes[0].set_title(f'Size {train_size} / Rep {repetition} per-epoch metrics') - fig.tight_layout() - fig.savefig(rep_root / 'epoch_history.png', dpi=150) - plt.close(fig) - - # Clean up memory after each model - del model, optimizer, scheduler, best_state - if scaler is not None: - del scaler - if device.type == 'cuda': - torch.cuda.empty_cache() - torch.cuda.synchronize() - - if per_size_val_metrics: - print(f"\n[INFO] Validation summary for train_size={train_size}, rep={repetition}:") - for model_name, val_f1, val_loss, test_f1 in sorted( - per_size_val_metrics, key=lambda item: item[1], reverse=True - ): - print( - f" {model_name:<16} val_f1={val_f1:.4f} " - f"val_loss={val_loss:.4f} test_f1={test_f1:.4f}" - ) - - summary_path = args.output_dir / "summary.json" - summary_path.parent.mkdir(parents=True, exist_ok=True) - serializable_summary = { - model_name: { - size: { - "acc": metrics["acc"], - "f1": metrics["f1"], - "val_f1": metrics["val_f1"], - "val_loss": metrics["val_loss"], - } - for size, metrics in size_dict.items() - } - for model_name, size_dict in summary.items() - } - with open(summary_path, "w", encoding="utf-8") as f: - json.dump(serializable_summary, f, indent=2) - - print("\n[INFO] Final accuracy summary:") - for model_name, results in summary.items(): - for size, metrics in results.items(): - if metrics["acc"]: - acc_mean = float(np.mean(metrics["acc"])) - acc_std = float(np.std(metrics["acc"])) - f1_mean = float(np.mean(metrics["f1"])) - f1_std = float(np.std(metrics["f1"])) - n = len(metrics["acc"]) - print( - f" 
{model_name} @ {size:4d}/class -> " - f"acc={acc_mean:.4%} ± {acc_std:.4%}, f1={f1_mean:.4f} ± {f1_std:.4f} (n={n})" - ) - - print("\n[INFO] Final validation F1 summary:") - for model_name, results in summary.items(): - for size, metrics in results.items(): - if metrics["val_f1"]: - val_mean = float(np.mean(metrics["val_f1"])) - val_std = float(np.std(metrics["val_f1"])) - n = len(metrics["val_f1"]) - print( - f" {model_name} @ {size:4d}/class -> " - f"val_f1={val_mean:.4f} ± {val_std:.4f} (n={n})" - ) - - print("\n[INFO] Final validation loss summary:") - for model_name, results in summary.items(): - for size, metrics in results.items(): - if metrics["val_loss"]: - loss_mean = float(np.mean(metrics["val_loss"])) - loss_std = float(np.std(metrics["val_loss"])) - n = len(metrics["val_loss"]) - print( - f" {model_name} @ {size:4d}/class -> " - f"val_loss={loss_mean:.4f} ± {loss_std:.4f} (n={n})" - ) - - if HAVE_MPL: - train_sizes_sorted = sorted(args.train_sizes) - plt.figure(figsize=(8, 5)) - plotted = False - for model_name in args.models: - model_results = summary.get(model_name, {}) - means: List[float] = [] - for size in train_sizes_sorted: - val_list = model_results.get(size, {}).get("val_f1", []) - means.append(float(np.mean(val_list)) if val_list else float("nan")) - if not any(np.isfinite(means)): - continue - plt.plot(train_sizes_sorted, means, marker="o", linewidth=2, label=model_name) - plotted = True - if plotted: - plt.title("Validation F1 vs. 
Training Size") - plt.xlabel("Training samples per class") - plt.ylabel("Validation F1 (macro)") - plt.xticks(train_sizes_sorted) - plt.ylim(0.0, 1.0) - plt.grid(True, which="both", linestyle="--", alpha=0.4) - plt.legend(title="Model", frameon=False) - plt.tight_layout() - plot_path = args.output_dir / "val_f1_summary.png" - plt.savefig(plot_path, dpi=200) - plt.close() - print(f"[INFO] Saved validation F1 plot to {plot_path}") - else: - plt.close() - print("[WARN] No validation F1 data available to plot.") - else: - print("[WARN] Matplotlib not available; skipping validation F1 plot.") - - -if __name__ == "__main__": - main()