Stefano01 committed on
Commit dfafaa4 · verified · 1 Parent(s): 38cdd14

Upload folder using huggingface_hub

.devcontainer/devcontainer.json ADDED
@@ -0,0 +1,33 @@
+ {
+   "name": "Python 3",
+   // Or use a Dockerfile or Docker Compose file. More info: https://containers.dev/guide/dockerfile
+   "image": "mcr.microsoft.com/devcontainers/python:1-3.11-bullseye",
+   "customizations": {
+     "codespaces": {
+       "openFiles": [
+         "README.md",
+         "app/streamlit_app.py"
+       ]
+     },
+     "vscode": {
+       "settings": {},
+       "extensions": [
+         "ms-python.python",
+         "ms-python.vscode-pylance"
+       ]
+     }
+   },
+   "updateContentCommand": "[ -f packages.txt ] && sudo apt update && sudo apt upgrade -y && sudo xargs apt install -y <packages.txt; [ -f requirements.txt ] && pip3 install --user -r requirements.txt; pip3 install --user streamlit; echo '✅ Packages installed and Requirements met'",
+   "postAttachCommand": {
+     "server": "streamlit run app/streamlit_app.py --server.enableCORS false --server.enableXsrfProtection false"
+   },
+   "portsAttributes": {
+     "8501": {
+       "label": "Application",
+       "onAutoForward": "openPreview"
+     }
+   },
+   "forwardPorts": [
+     8501
+   ]
+ }
.flake8 ADDED
@@ -0,0 +1,3 @@
1
+ [flake8]
2
+ max-line-length = 88
3
+ extend-ignore = E203, W503
.github/workflows/update_space.yml ADDED
@@ -0,0 +1,28 @@
+ name: Run Python script
+
+ on:
+   push:
+     branches:
+       - main   # branch that triggers the deploy
+
+ jobs:
+   build:
+     runs-on: ubuntu-latest
+
+     steps:
+       - name: Checkout
+         uses: actions/checkout@v2
+
+       - name: Set up Python
+         uses: actions/setup-python@v2
+         with:
+           python-version: '3.9'
+
+       - name: Install Gradio
+         run: python -m pip install gradio
+
+       - name: Log in to Hugging Face
+         run: python -c 'import huggingface_hub; huggingface_hub.login(token="${{ secrets.hf_token }}")'
+
+       - name: Deploy to Spaces
+         run: gradio deploy
.gitignore ADDED
@@ -0,0 +1,44 @@
1
+ # bytecode & caches
2
+ __pycache__/
3
+ *.py[cod]
4
+ *.pyo
5
+
6
+ # envs
7
+ .env
8
+ .venv
9
+ venv/
10
+ ENV/
11
+ .conda/
12
+ .ipynb_checkpoints/
13
+
14
+ # OS/editor
15
+ .DS_Store
16
+ Thumbs.db
17
+ .vscode/
18
+ .idea/
19
+
20
+ # build/dist
21
+ build/
22
+ dist/
23
+ *.egg-info/
24
+
25
+ # data & artifacts
26
+ data/
27
+ checkpoints/
28
+ reports/figures/
29
+ reports/cams/
30
+ logs/
31
+ runs/
32
+ wandb/
33
+ input/
34
+ notebooks/reports
35
+ reports
36
+
37
+ # notebooks temp
38
+ *.checkpoint.ipynb
39
+
40
+ # configs with secrets (if any)
41
+ *.secret.*
42
+
43
+ # Model checkpoints
44
+ saved_checkpoints/*.ckpt
.gradio/certificate.pem ADDED
@@ -0,0 +1,31 @@
1
+ -----BEGIN CERTIFICATE-----
2
+ MIIFazCCA1OgAwIBAgIRAIIQz7DSQONZRGPgu2OCiwAwDQYJKoZIhvcNAQELBQAw
3
+ TzELMAkGA1UEBhMCVVMxKTAnBgNVBAoTIEludGVybmV0IFNlY3VyaXR5IFJlc2Vh
4
+ cmNoIEdyb3VwMRUwEwYDVQQDEwxJU1JHIFJvb3QgWDEwHhcNMTUwNjA0MTEwNDM4
5
+ WhcNMzUwNjA0MTEwNDM4WjBPMQswCQYDVQQGEwJVUzEpMCcGA1UEChMgSW50ZXJu
6
+ ZXQgU2VjdXJpdHkgUmVzZWFyY2ggR3JvdXAxFTATBgNVBAMTDElTUkcgUm9vdCBY
7
+ MTCCAiIwDQYJKoZIhvcNAQEBBQADggIPADCCAgoCggIBAK3oJHP0FDfzm54rVygc
8
+ h77ct984kIxuPOZXoHj3dcKi/vVqbvYATyjb3miGbESTtrFj/RQSa78f0uoxmyF+
9
+ 0TM8ukj13Xnfs7j/EvEhmkvBioZxaUpmZmyPfjxwv60pIgbz5MDmgK7iS4+3mX6U
10
+ A5/TR5d8mUgjU+g4rk8Kb4Mu0UlXjIB0ttov0DiNewNwIRt18jA8+o+u3dpjq+sW
11
+ T8KOEUt+zwvo/7V3LvSye0rgTBIlDHCNAymg4VMk7BPZ7hm/ELNKjD+Jo2FR3qyH
12
+ B5T0Y3HsLuJvW5iB4YlcNHlsdu87kGJ55tukmi8mxdAQ4Q7e2RCOFvu396j3x+UC
13
+ B5iPNgiV5+I3lg02dZ77DnKxHZu8A/lJBdiB3QW0KtZB6awBdpUKD9jf1b0SHzUv
14
+ KBds0pjBqAlkd25HN7rOrFleaJ1/ctaJxQZBKT5ZPt0m9STJEadao0xAH0ahmbWn
15
+ OlFuhjuefXKnEgV4We0+UXgVCwOPjdAvBbI+e0ocS3MFEvzG6uBQE3xDk3SzynTn
16
+ jh8BCNAw1FtxNrQHusEwMFxIt4I7mKZ9YIqioymCzLq9gwQbooMDQaHWBfEbwrbw
17
+ qHyGO0aoSCqI3Haadr8faqU9GY/rOPNk3sgrDQoo//fb4hVC1CLQJ13hef4Y53CI
18
+ rU7m2Ys6xt0nUW7/vGT1M0NPAgMBAAGjQjBAMA4GA1UdDwEB/wQEAwIBBjAPBgNV
19
+ HRMBAf8EBTADAQH/MB0GA1UdDgQWBBR5tFnme7bl5AFzgAiIyBpY9umbbjANBgkq
20
+ hkiG9w0BAQsFAAOCAgEAVR9YqbyyqFDQDLHYGmkgJykIrGF1XIpu+ILlaS/V9lZL
21
+ ubhzEFnTIZd+50xx+7LSYK05qAvqFyFWhfFQDlnrzuBZ6brJFe+GnY+EgPbk6ZGQ
22
+ 3BebYhtF8GaV0nxvwuo77x/Py9auJ/GpsMiu/X1+mvoiBOv/2X/qkSsisRcOj/KK
23
+ NFtY2PwByVS5uCbMiogziUwthDyC3+6WVwW6LLv3xLfHTjuCvjHIInNzktHCgKQ5
24
+ ORAzI4JMPJ+GslWYHb4phowim57iaztXOoJwTdwJx4nLCgdNbOhdjsnvzqvHu7Ur
25
+ TkXWStAmzOVyyghqpZXjFaH3pO3JLF+l+/+sKAIuvtd7u+Nxe5AW0wdeRlN8NwdC
26
+ jNPElpzVmbUq4JUagEiuTDkHzsxHpFKVK7q4+63SM1N95R1NbdWhscdCb+ZAJzVc
27
+ oyi3B43njTOQ5yOf+1CceWxG1bQVs5ZufpsMljq4Ui0/1lvh+wjChP4kqKOJ2qxq
28
+ 4RgqsahDYVvTH9w7jXbyLeiNdd8XM2w9U/t7y0Ff/9yi0GE44Za4rF2LN9d11TPA
29
+ mRGunUHBcnWEvgJBQl9nJEiU0Zsnvgc/ubhPgXRR4Xq37Z0j4r7g1SgEEzwxA57d
30
+ emyPxgcYxn/eR44/KJ4EBs+lVDR3veyJm+kXQ99b21/+jh5Xos1AnX5iItreGCc=
31
+ -----END CERTIFICATE-----
.pre-commit-config.yaml ADDED
@@ -0,0 +1,27 @@
+ repos:
+   - repo: https://github.com/psf/black
+     rev: 24.8.0
+     hooks:
+       - id: black
+         args: [--line-length=88]
+
+   - repo: https://github.com/pycqa/isort
+     rev: 5.13.2
+     hooks:
+       - id: isort
+         args: [--profile=black]
+
+   - repo: https://github.com/pycqa/flake8
+     rev: 7.1.1
+     hooks:
+       - id: flake8
+         args: [--max-line-length=88, --extend-ignore=E203,W503]
+
+   - repo: https://github.com/pre-commit/pre-commit-hooks
+     rev: v4.6.0
+     hooks:
+       - id: end-of-file-fixer
+       - id: trailing-whitespace
+       - id: check-yaml
+       - id: check-json
+       - id: check-merge-conflict
.streamlit/presets.json ADDED
@@ -0,0 +1,7 @@
+ {
+   "release_checkpoints": {
+     "resnet18_cifar10": "https://github.com/Stefanoo01/ExplainableCNN/releases/download/v1.0.0/resnet18_cifar10.ckpt",
+     "smallcnn_fmnist": "https://github.com/Stefanoo01/ExplainableCNN/releases/download/v1.0.0/smallcnn_fmnist.ckpt",
+     "smallcnn_aug_fmnist": "https://github.com/Stefanoo01/ExplainableCNN/releases/download/v1.0.0/smallcnn_aug_fmnist.ckpt"
+   }
+ }
README.MD ADDED
@@ -0,0 +1,188 @@
1
+ ---
2
+ title: ExplainableCNN
3
+ app_file: app/gradio_app.py
4
+ sdk: gradio
5
+ sdk_version: 5.47.0
6
+ ---
7
+ # ExplainableCNN
8
+
9
+ End‑to‑end image classification with explainability. Train CNNs on common vision datasets, save checkpoints and metrics, and visualize Grad‑CAM/Grad‑CAM++ heatmaps in a Streamlit app.
10
+
11
+ >**Online App**: You can try the app online at `https://explainable-cnn.streamlit.app`
12
+
13
+ Contents
14
+ - Quick start
15
+ - Installation
16
+ - Datasets
17
+ - Training
18
+ - Configuration reference
19
+ - Streamlit Grad‑CAM demo
20
+ - Checkpoints and outputs
21
+ - License and acknowledgements
23
+
24
+ ## Quick start
25
+
26
+ 1) Install dependencies (CPU‑only by default):
27
+ ```bash
28
+ python -m venv .venv && source .venv/bin/activate
29
+ pip install -r requirements.txt
30
+ ```
31
+
32
+ 2) Train with defaults (Fashion‑MNIST, small CNN):
33
+ ```bash
34
+ python -m src.train --config configs/baseline.yaml
35
+ ```
36
+
37
+ 3) Launch the Grad‑CAM demo and visualize predictions:
38
+ ```bash
39
+ streamlit run app/streamlit_app.py
40
+ ```
41
+
42
+ ## Installation
43
+
44
+ This repo ships with CPU‑only PyTorch wheels via the official extra index in `requirements.txt`. If you have CUDA, you can install the matching GPU wheels from PyTorch and keep the rest of the requirements.
45
+
46
+ ```bash
47
+ python -m venv .venv
48
+ source .venv/bin/activate
49
+ pip install -r requirements.txt
50
+ ```
51
+
52
+ ### GPU installation (recommended for training)
53
+
54
+ 1) Install CUDA‑enabled PyTorch that matches your driver and CUDA version (see `https://pytorch.org/get-started/locally/`). Examples:
55
+
56
+ ```bash
57
+ # CUDA 12.1
58
+ pip install --index-url https://download.pytorch.org/whl/cu121 torch torchvision torchaudio
59
+
60
+ # CUDA 11.8
61
+ # pip install --index-url https://download.pytorch.org/whl/cu118 torch torchvision torchaudio
62
+ ```
63
+
64
+ 2) Install the rest of the project dependencies (excluding torch*):
65
+
66
+ ```bash
67
+ pip install -r requirements-gpu.txt
68
+ ```
69
+
70
+ Notes
71
+ - If you want GPU builds: follow the selector at `https://pytorch.org/get-started/locally/` and install the torch/torchvision/torchaudio triplet before installing the rest of the requirements.
72
+ - This project uses: torch/torchvision, torchmetrics, captum/torchcam, lightning (the newer package name for pytorch-lightning), albumentations, TensorBoard, Streamlit, PyYAML, etc.
73
+
74
+ ## Datasets
75
+
76
+ Built‑in dataset options for training:
77
+ - `fashion-mnist` (default)
78
+ - `mnist`
79
+ - `cifar10`
80
+
81
+ Where data lives
82
+ - By default datasets are downloaded under `data/`.
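+
+ The apps load these datasets through torchvision, which downloads them on first use; a minimal sketch mirroring the apps' `load_raw_dataset` helper:
+
+ ```python
+ # torchvision fetches the Fashion-MNIST test split into ./data the first time it is requested
+ import torchvision.transforms as T
+ from torchvision.datasets import FashionMNIST
+
+ ds = FashionMNIST(root="data", train=False, download=True, transform=T.ToTensor())
+ print(len(ds), ds.classes[:3])  # 10000 samples and the first few class names
+ ```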
83
+
84
+ ## Training
85
+
86
+ Run with a YAML config
87
+ ```bash
88
+ python -m src.train --config configs/baseline.yaml
89
+ ```
90
+
91
+ Override config values from the CLI
92
+ ```bash
93
+ python -m src.train --config configs/baseline.yaml --epochs 12 --lr 5e-4
94
+ ```
95
+
96
+ Switch dataset from the CLI
97
+ ```bash
98
+ python -m src.train --config configs/baseline.yaml --dataset mnist
99
+ ```
100
+
101
+ Use a ResNet‑18 backbone for CIFAR‑10 (adapted conv1/no maxpool)
102
+ ```bash
103
+ python -m src.train --config configs/cifar10_resnet18.yaml
104
+ ```
105
+
106
+
107
+ Training flow (high level)
108
+ - Loads the YAML config and merges CLI overrides (see the sketch after this list)
109
+ - Builds dataloaders with dataset‑specific transforms and normalization
110
+ - Builds model: `smallcnn`, `resnet18_cifar`, or `resnet18_imagenet`
111
+ - Optimizer: Adam (default) or SGD with momentum
112
+ - Trains with early stopping and ReduceLROnPlateau on val loss
113
+ - Writes TensorBoard logs, metrics JSONs, and image reports (confusion matrix)
114
+ - Saves `last.ckpt` and `best.ckpt` with model weights and metadata
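+
+ The merge logic lives in `src.train`; conceptually it works roughly like the sketch below (hypothetical snippet, not the actual implementation; only `--epochs` and `--lr` are shown):
+
+ ```python
+ # Rough sketch of YAML config + CLI override merging
+ import argparse
+
+ import yaml
+
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--config", required=True)
+ parser.add_argument("--epochs", type=int)
+ parser.add_argument("--lr", type=float)
+ args = parser.parse_args()
+
+ with open(args.config) as f:
+     cfg = yaml.safe_load(f)
+
+ # Flags that were actually passed on the CLI win over the YAML values
+ for key in ("epochs", "lr"):
+     if getattr(args, key) is not None:
+         cfg[key] = getattr(args, key)
+ ```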
115
+
116
+ Outputs per run (under roots from config)
117
+ - `runs/<run_id>/` TensorBoard logs
118
+ - `checkpoints/<run_id>/last.ckpt` and `best.ckpt`
119
+ - `reports/<run_id>/config_effective.yaml`, `metrics.json`, and `figures/confusion_matrix.png`
120
+
121
+ ## Configuration reference
122
+
123
+ See examples in `configs/`:
124
+ - `baseline.yaml` (Fashion‑MNIST + `smallcnn`)
125
+ - `cifar10_resnet18.yaml` (CIFAR‑10 + adapted ResNet‑18)
126
+
127
+ Common keys
128
+ - `dataset`: one of `fashion-mnist`, `mnist`, `cifar10`
129
+ - `model_name`: `smallcnn` | `resnet18_cifar` | `resnet18_imagenet`
130
+ - `data_dir`: root folder for data (default `./data`)
131
+ - `batch_size`, `epochs`, `lr`, `weight_decay`, `num_workers`, `seed`, `device`
132
+ - `img_size`, `mean`, `std`: image shape and normalization stats
133
+ - `optimizer`: `adam` (default) or `sgd`; `momentum` used for SGD
134
+ - `log_root`, `ckpt_root`, `reports_root`: base folders for artifacts
135
+ - `early_stop`: `{ monitor: val_loss|val_acc, mode: min|max, patience, min_delta }`
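+
+ An illustrative reading of the `early_stop` block (not the project's trainer code):
+
+ ```python
+ # mode "min" suits val_loss (lower is better); "max" suits val_acc (higher is better)
+ def improved(current, best, mode="min", min_delta=0.0):
+     return current < best - min_delta if mode == "min" else current > best + min_delta
+
+ # Training stops after `patience` consecutive epochs without an improvement of at least min_delta.
+ ```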
136
+
137
+ CLI flags can override the YAML. For example `--dataset`, `--epochs`, `--lr`, `--model-name`.
138
+
139
+ ## Streamlit Grad‑CAM demo
140
+
141
+ Start the app (or try it online at `https://explainable-cnn.streamlit.app`)
142
+ ```bash
143
+ streamlit run app/streamlit_app.py
144
+ ```
145
+
146
+ What it does
147
+ - Load a trained checkpoint (`.ckpt`)
148
+ - Upload an image or sample one from the corresponding dataset
149
+ - Run inference and display top‑k predictions
150
+ - Visualize Grad‑CAM or Grad‑CAM++ overlays with adjustable alpha
151
+
152
+ Supplying checkpoints
153
+ - Local discovery: put `.ckpt` files under `saved_checkpoints/` or use the file uploader
154
+ - Download from a URL: paste a direct link to a `.ckpt` asset and click “Download checkpoint”
155
+ - Presets: provide a map of names → URLs via one of:
156
+ - Streamlit secrets: `st.secrets["release_checkpoints"] = { "Name": "https://...best.ckpt" }`
157
+ - `.streamlit/presets.json` or `presets.json` in repo root, either:
158
+ ```json
159
+ { "release_checkpoints": { "FMNIST SmallCNN": "https://.../best.ckpt" } }
160
+ ```
161
+ or a flat mapping `{ "FMNIST SmallCNN": "https://..." }`
162
+ - Environment variable `RELEASE_CKPTS_JSON` with a JSON mapping string
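+
+ For the environment-variable route, the value is the same name → URL mapping, JSON-encoded; a sketch using one of the release URLs from `.streamlit/presets.json`:
+
+ ```python
+ # Set before launching the app so the checkpoint shows up as a preset
+ import json
+ import os
+
+ os.environ["RELEASE_CKPTS_JSON"] = json.dumps({
+     "FMNIST SmallCNN": "https://github.com/Stefanoo01/ExplainableCNN/releases/download/v1.0.0/smallcnn_fmnist.ckpt"
+ })
+ ```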
163
+
164
+ Devices and CAM methods
165
+ - Device: `auto` (default), `cuda`, or `cpu`
166
+ - CAM: `Grad-CAM` or `Grad-CAM++` via `torchcam`
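+
+ Under the hood the apps drive `torchcam` roughly as in this condensed sketch, where `model`, `x`, and `base_pil` stand for the loaded model, the normalized input tensor, and the denormalized PIL image:
+
+ ```python
+ # Condensed from app/streamlit_app.py: Grad-CAM map for the top prediction
+ import torchvision.transforms as T
+ from torchcam.methods import GradCAM
+ from torchcam.utils import overlay_mask
+
+ cam_extractor = GradCAM(model, target_layer="conv2")  # layer name comes from the checkpoint meta
+ logits = model(x)                                      # x: [1, C, H, W], normalized
+ class_idx = int(logits.argmax(dim=1))
+ cam = cam_extractor(class_idx, logits)[0].detach().cpu()
+ overlay = overlay_mask(base_pil, T.ToPILImage()(cam), alpha=0.5)
+ ```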
167
+
168
+ Checkpoint metadata expected
169
+ - `meta`: `{ dataset, model_name, img_size, mean, std, default_target_layer }`
170
+ - `classes`: list of class names (used to label predictions)
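+
+ Any checkpoint saved with this structure can be loaded by the demo apps; a minimal sketch (field values are the Fashion-MNIST SmallCNN examples used elsewhere in this repo, and `model` is your trained network):
+
+ ```python
+ # Save a demo-compatible checkpoint: the apps read model_state, classes and meta
+ import torch
+
+ torch.save({
+     "model_state": model.state_dict(),
+     "classes": ["T-shirt/top", "Trouser", "Pullover", "Dress", "Coat",
+                 "Sandal", "Shirt", "Sneaker", "Bag", "Ankle boot"],
+     "meta": {
+         "dataset": "fashion-mnist",
+         "model_name": "smallcnn",  # smallcnn | resnet18_cifar | resnet18_imagenet
+         "img_size": 28,
+         "mean": [0.2860],
+         "std": [0.3530],
+         "default_target_layer": "conv2",
+     },
+ }, "saved_checkpoints/smallcnn_fmnist.ckpt")
+ ```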
171
+
172
+ ## Checkpoints and outputs
173
+
174
+ Each run writes:
175
+ - Checkpoints: `<ckpt_root>/<run_id>/{last.ckpt,best.ckpt}`
176
+ - Logs: `<log_root>/<run_id>/` (TensorBoard)
177
+ - Reports: `<reports_root>/<run_id>/metrics.json`, `figures/confusion_matrix.png`
178
+
179
+ Best checkpoint selection respects early‑stopping monitor (`val_loss` or `val_acc`).
180
+
181
+ ## License and acknowledgements
182
+
183
+ - Uses `torchcam` for CAM extraction and `captum` as a general explainability dependency
184
+ - TorchVision models and datasets are used for baselines and data handling
185
+
186
+ ___
187
+
188
+ If you run into issues, please open an issue with your command, config file, and environment details.
app/gradio_app.py ADDED
@@ -0,0 +1,668 @@
1
+ import datetime as dt
2
+ import random
3
+ from pathlib import Path
4
+ import os
5
+ import hashlib
6
+ import requests
7
+ import json
8
+ import tempfile
9
+
10
+ import numpy as np
11
+ import gradio as gr
12
+ import torch
13
+ import torch.nn as nn
14
+ import torch.nn.functional as F
15
+ import torchvision.models as tvm
16
+ import torchvision.transforms as T
17
+ from PIL import Image
18
+ from torchcam.methods import GradCAM, GradCAMpp
19
+ from torchcam.utils import overlay_mask
20
+ from torchvision.datasets import CIFAR10, MNIST, FashionMNIST
21
+
22
+ # Global state for model and configuration
23
+ app_state = {
24
+ "model": None,
25
+ "classes": None,
26
+ "meta": None,
27
+ "transform": None,
28
+ "target_layer": None,
29
+ "dataset": None,
30
+ "dataset_classes": None
31
+ }
32
+
33
+ custom_theme = gr.themes.Soft(
34
+ primary_hue="green", # main brand color
35
+ secondary_hue="purple", # accent color
36
+ neutral_hue="slate" # backgrounds/borders/text neutrals
37
+ )
38
+
39
+ def download_release_asset(url: str, dest_dir: str = "saved_checkpoints") -> str:
40
+ """Download a remote checkpoint to dest_dir and return its local path."""
41
+ Path(dest_dir).mkdir(parents=True, exist_ok=True)
42
+ url_hash = hashlib.sha256(url.encode("utf-8")).hexdigest()[:16]
43
+ fname = Path(url).name or f"asset_{url_hash}.ckpt"
44
+ if not fname.endswith(".ckpt"):
45
+ fname = f"{fname}.ckpt"
46
+ local_path = Path(dest_dir) / f"{url_hash}_{fname}"
47
+
48
+ if local_path.exists() and local_path.stat().st_size > 0:
49
+ return str(local_path)
50
+
51
+ with requests.get(url, stream=True, timeout=120) as r:
52
+ r.raise_for_status()
53
+ with open(local_path, "wb") as f:
54
+ for chunk in r.iter_content(chunk_size=1024 * 1024):
55
+ if chunk:
56
+ f.write(chunk)
57
+ return str(local_path)
58
+
59
+
60
+ def load_release_presets() -> dict:
61
+ """Load release preset URLs from multiple sources."""
62
+ # Try environment variable containing JSON mapping
63
+ env_json = os.environ.get("RELEASE_CKPTS_JSON", "").strip()
64
+ if env_json:
65
+ try:
66
+ data = json.loads(env_json)
67
+ if isinstance(data, dict):
68
+ return dict(data)
69
+ except Exception:
70
+ pass
71
+
72
+ # Try local JSON files for dev
73
+ for rel in (".streamlit/presets.json", "presets.json"):
74
+ p = Path(rel)
75
+ if p.exists():
76
+ try:
77
+ with open(p, "r", encoding="utf-8") as f:
78
+ data = json.load(f)
79
+ if isinstance(data, dict) and data:
80
+ if "release_checkpoints" in data and isinstance(data["release_checkpoints"], dict):
81
+ return dict(data["release_checkpoints"])
82
+ return dict(data)
83
+ except Exception:
84
+ pass
85
+
86
+ return {}
87
+
88
+
89
+ def get_device(choice="auto"):
90
+ if choice == "cpu":
91
+ return "cpu"
92
+ if choice == "cuda":
93
+ return "cuda"
94
+ return "cuda" if torch.cuda.is_available() else "cpu"
95
+
96
+
97
+ def denorm_to_pil(x, mean, std):
98
+ """Convert normalized tensor to PIL Image."""
99
+ x = x.detach().cpu().clone()
100
+ if len(mean) == 1:
101
+ # grayscale
102
+ m, s = float(mean[0]), float(std[0])
103
+ x = x * s + m
104
+ x = x.clamp(0, 1)
105
+ pil = T.ToPILImage()(x)
106
+ pil = pil.convert("RGB")
107
+ return pil
108
+ else:
109
+ mean = torch.tensor(mean)[:, None, None]
110
+ std = torch.tensor(std)[:, None, None]
111
+ x = x * std + mean
112
+ x = x.clamp(0, 1)
113
+ return T.ToPILImage()(x)
114
+
115
+
116
+ DATASET_CLASSES = {
117
+ "fashion-mnist": [
118
+ "T-shirt/top", "Trouser", "Pullover", "Dress", "Coat",
119
+ "Sandal", "Shirt", "Sneaker", "Bag", "Ankle boot",
120
+ ],
121
+ "mnist": [str(i) for i in range(10)],
122
+ "cifar10": [
123
+ "airplane", "automobile", "bird", "cat", "deer",
124
+ "dog", "frog", "horse", "ship", "truck",
125
+ ],
126
+ }
127
+
128
+
129
+ def load_raw_dataset(name: str, root="data"):
130
+ """Load the test split with ToTensor() only (for preview)."""
131
+ tt = T.ToTensor()
132
+ if name == "fashion-mnist":
133
+ ds = FashionMNIST(root=root, train=False, download=True, transform=tt)
134
+ elif name == "mnist":
135
+ ds = MNIST(root=root, train=False, download=True, transform=tt)
136
+ elif name == "cifar10":
137
+ ds = CIFAR10(root=root, train=False, download=True, transform=tt)
138
+ else:
139
+ raise ValueError(f"Unknown dataset: {name}")
140
+ classes = getattr(ds, "classes", None) or [str(i) for i in range(10)]
141
+ return ds, classes
142
+
143
+
144
+ def pil_from_tensor(img_tensor, grayscale_to_rgb=True):
145
+ pil = T.ToPILImage()(img_tensor)
146
+ if grayscale_to_rgb and img_tensor.ndim == 3 and img_tensor.shape[0] == 1:
147
+ pil = pil.convert("RGB")
148
+ return pil
149
+
150
+
151
+ class SmallCNN(nn.Module):
152
+ def __init__(self, num_classes=10):
153
+ super().__init__()
154
+ self.conv1 = nn.Conv2d(1, 32, 3, padding=1)
155
+ self.pool1 = nn.MaxPool2d(2, 2)
156
+ self.conv2 = nn.Conv2d(32, 64, 3, padding=1)
157
+ self.pool2 = nn.MaxPool2d(2, 2)
158
+ self.fc = nn.Linear(64 * 7 * 7, num_classes)
159
+
160
+ def forward(self, x):
161
+ x = F.relu(self.conv1(x))
162
+ x = self.pool1(x)
163
+ x = F.relu(self.conv2(x))
164
+ x = self.pool2(x)
165
+ x = torch.flatten(x, 1)
166
+ return self.fc(x)
167
+
168
+
169
+ def load_model_from_ckpt(ckpt_path: Path, device: str):
170
+ ckpt = torch.load(str(ckpt_path), map_location=device)
171
+ classes = ckpt.get("classes", None)
172
+ meta = ckpt.get("meta", {})
173
+ num_classes = len(classes) if classes else 10
174
+ model_name = meta.get("model_name", "smallcnn")
175
+
176
+ if model_name == "smallcnn":
177
+ model = SmallCNN(num_classes=num_classes).to(device)
178
+ default_target_layer = "conv2"
179
+ elif model_name == "resnet18_cifar":
180
+ m = tvm.resnet18(weights=None)
181
+ m.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
182
+ m.maxpool = nn.Identity()
183
+ m.fc = nn.Linear(m.fc.in_features, num_classes)
184
+ model = m.to(device)
185
+ default_target_layer = "layer4"
186
+ elif model_name == "resnet18_imagenet":
187
+ try:
188
+ w = tvm.ResNet18_Weights.IMAGENET1K_V1
189
+ except Exception:
190
+ w = None
191
+ m = tvm.resnet18(weights=w)
192
+ m.fc = nn.Linear(m.fc.in_features, num_classes)
193
+ model = m.to(device)
194
+ default_target_layer = "layer4"
195
+ else:
196
+ raise ValueError(f"Unknown model_name in ckpt: {model_name}")
197
+
198
+ model.load_state_dict(ckpt["model_state"])
199
+ model.eval()
200
+ meta.setdefault("default_target_layer", default_target_layer)
201
+ return model, classes, meta
202
+
203
+
204
+ def build_transform_from_meta(meta):
205
+ img_size = int(meta.get("img_size", 28))
206
+ mean = meta.get("mean", [0.2860])
207
+ std = meta.get("std", [0.3530])
208
+ if len(mean) == 1:
209
+ return T.Compose([
210
+ T.Grayscale(num_output_channels=1),
211
+ T.Resize((img_size, img_size)),
212
+ T.ToTensor(),
213
+ T.Normalize(mean, std),
214
+ ])
215
+ else:
216
+ return T.Compose([
217
+ T.Resize((img_size, img_size)),
218
+ T.ToTensor(),
219
+ T.Normalize(mean, std),
220
+ ])
221
+
222
+
223
+ def predict_and_cam(model, x, device, target_layer, topk=3, method="Grad-CAM"):
224
+ """Predict and generate CAM for top-k classes."""
225
+ cam_cls = GradCAM if method == "Grad-CAM" else GradCAMpp
226
+ cam_extractor = cam_cls(model, target_layer=target_layer)
227
+
228
+ logits = model(x.to(device))
229
+ probs = torch.softmax(logits, dim=1)[0].detach().cpu()
230
+ top_vals, top_idxs = probs.topk(topk)
231
+
232
+ results = []
233
+ for rank, (p, idx) in enumerate(zip(top_vals.tolist(), top_idxs.tolist())):
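+ # Keep the autograd graph alive for every pass except the last so each class's CAM can backprop through the same logits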
234
+ retain = rank < topk - 1
235
+ cams = cam_extractor(idx, logits, retain_graph=retain)
236
+ cam = cams[0].detach().cpu()
237
+ results.append({
238
+ "rank": rank + 1,
239
+ "class_index": int(idx),
240
+ "prob": float(p),
241
+ "cam": cam
242
+ })
243
+ return results, probs
244
+
245
+
246
+ def overlay_pil(base_pil_rgb: Image.Image, cam_tensor, alpha=0.5):
247
+ """Create overlay of CAM on base image."""
248
+ cam = cam_tensor.clone()
249
+ cam -= cam.min()
250
+ cam = cam / (cam.max() + 1e-8)
251
+ heat = T.ToPILImage()(cam)
252
+ return overlay_mask(base_pil_rgb, heat, alpha=alpha)
253
+
254
+
255
+ # Gradio interface functions
256
+ def load_checkpoint_from_url(url, preset_name):
257
+ """Load checkpoint from URL or preset."""
258
+ presets = load_release_presets()
259
+
260
+ if preset_name and preset_name != "None":
261
+ url = presets.get(preset_name, "")
262
+
263
+ if not url:
264
+ return "❌ No URL provided", "", gr.update(visible=False), gr.update(choices=["(any)"], value="(any)"), gr.update(visible=False)
265
+
266
+ try:
267
+ ckpt_path = download_release_asset(url)
268
+ device = get_device("cpu")
269
+ model, classes, meta = load_model_from_ckpt(Path(ckpt_path), device)
270
+
271
+ # Update global state
272
+ app_state["model"] = model
273
+ app_state["classes"] = classes
274
+ app_state["meta"] = meta
275
+ app_state["transform"] = build_transform_from_meta(meta)
276
+ app_state["target_layer"] = meta.get("default_target_layer", "conv2")
277
+
278
+ # Load dataset for samples
279
+ ds_name = meta.get("dataset", "fashion-mnist")
280
+ try:
281
+ dataset, dataset_classes = load_raw_dataset(ds_name)
282
+ app_state["dataset"] = dataset
283
+ app_state["dataset_classes"] = dataset_classes
284
+ except Exception:
285
+ app_state["dataset"] = None
286
+ app_state["dataset_classes"] = None
287
+
288
+ meta_info = {
289
+ "dataset": meta.get("dataset"),
290
+ "model_name": meta.get("model_name"),
291
+ "img_size": meta.get("img_size"),
292
+ "target_layer": app_state["target_layer"],
293
+ "mean": meta.get("mean"),
294
+ "std": meta.get("std"),
295
+ "classes": len(classes) if classes else "N/A"
296
+ }
297
+
298
+ # Create class choices for filter
299
+ class_choices = ["(any)"] + (dataset_classes if app_state["dataset"] else [])
300
+ max_samples = len(dataset) - 1 if app_state["dataset"] else 0
301
+
302
+ return (f"✅ Loaded: {ckpt_path}", json.dumps(meta_info, indent=2),
303
+ gr.update(visible=True), gr.update(choices=class_choices, value="(any)"),
304
+ gr.update(visible=True, maximum=max_samples, value=0))
305
+
306
+ except Exception as e:
307
+ return f"❌ Failed: {str(e)}", "", gr.update(visible=False), gr.update(choices=["(any)"], value="(any)"), gr.update(visible=False)
308
+
309
+
310
+ def load_checkpoint_from_file(file):
311
+ """Load checkpoint from uploaded file."""
312
+ if file is None:
313
+ return "❌ No file uploaded", "", gr.update(visible=False), gr.update(choices=["(any)"], value="(any)"), gr.update(visible=False)
314
+
315
+ try:
316
+ # Save uploaded file temporarily
317
+ Path("saved_checkpoints").mkdir(parents=True, exist_ok=True)
318
+ with open(file.name, "rb") as f:
319
+ content = f.read()
320
+
321
+ content_hash = hashlib.sha256(content).hexdigest()[:16]
322
+ base_name = Path(file.name).name
323
+ if not base_name.endswith(".ckpt"):
324
+ base_name = f"{base_name}.ckpt"
325
+ local_path = Path("saved_checkpoints") / f"{content_hash}_{base_name}"
326
+
327
+ with open(local_path, "wb") as f:
328
+ f.write(content)
329
+
330
+ device = get_device("cpu")
331
+ model, classes, meta = load_model_from_ckpt(local_path, device)
332
+
333
+ # Update global state
334
+ app_state["model"] = model
335
+ app_state["classes"] = classes
336
+ app_state["meta"] = meta
337
+ app_state["transform"] = build_transform_from_meta(meta)
338
+ app_state["target_layer"] = meta.get("default_target_layer", "conv2")
339
+
340
+ # Load dataset for samples
341
+ ds_name = meta.get("dataset", "fashion-mnist")
342
+ try:
343
+ dataset, dataset_classes = load_raw_dataset(ds_name)
344
+ app_state["dataset"] = dataset
345
+ app_state["dataset_classes"] = dataset_classes
346
+ except Exception:
347
+ app_state["dataset"] = None
348
+ app_state["dataset_classes"] = None
349
+
350
+ meta_info = {
351
+ "dataset": meta.get("dataset"),
352
+ "model_name": meta.get("model_name"),
353
+ "img_size": meta.get("img_size"),
354
+ "target_layer": app_state["target_layer"],
355
+ "mean": meta.get("mean"),
356
+ "std": meta.get("std"),
357
+ "classes": len(classes) if classes else "N/A"
358
+ }
359
+
360
+ # Create class choices for filter
361
+ class_choices = ["(any)"] + (dataset_classes if app_state["dataset"] else [])
362
+ max_samples = len(dataset) - 1 if app_state["dataset"] else 0
363
+
364
+ return (f"✅ Loaded: {local_path}", json.dumps(meta_info, indent=2),
365
+ gr.update(visible=True), gr.update(choices=class_choices, value="(any)"),
366
+ gr.update(visible=True, maximum=max_samples, value=0))
367
+
368
+ except Exception as e:
369
+ return f"❌ Failed: {str(e)}", "", gr.update(visible=False), gr.update(choices=["(any)"], value="(any)"), gr.update(visible=False)
370
+
371
+
372
+ def get_random_sample():
373
+ """Get a random sample from the loaded dataset."""
374
+ if app_state["dataset"] is None:
375
+ return None, "No dataset loaded", gr.update(visible=False)
376
+
377
+ dataset = app_state["dataset"]
378
+ idx = random.randint(0, len(dataset) - 1)
379
+ img_tensor, label = dataset[idx]
380
+ sample_img = pil_from_tensor(img_tensor, grayscale_to_rgb=True)
381
+
382
+ class_name = app_state["dataset_classes"][label] if app_state["dataset_classes"] else str(label)
383
+ caption = f"Sample from {app_state['meta'].get('dataset', 'dataset')} • class: {class_name} • idx: {idx}"
384
+
385
+ # Update slider maximum and current value
386
+ max_idx = len(dataset) - 1
387
+ return sample_img, caption, gr.update(visible=True, maximum=max_idx, value=idx)
388
+
389
+
390
+ def get_sample_by_index(idx, class_filter):
391
+ """Get a specific sample by index with optional class filtering."""
392
+ if app_state["dataset"] is None:
393
+ return None, "No dataset loaded"
394
+
395
+ dataset = app_state["dataset"]
396
+ dataset_classes = app_state["dataset_classes"]
397
+
398
+ # Apply class filter
399
+ if class_filter != "(any)":
400
+ targets = np.array([dataset[i][1] for i in range(len(dataset))])
401
+ class_id = dataset_classes.index(class_filter)
402
+ filtered_indices = np.where(targets == class_id)[0]
403
+
404
+ if len(filtered_indices) == 0:
405
+ return None, f"No samples found for class: {class_filter}"
406
+
407
+ # Clamp index to filtered range
408
+ idx = max(0, min(idx, len(filtered_indices) - 1))
409
+ actual_idx = filtered_indices[idx]
410
+ else:
411
+ # Clamp index to dataset range
412
+ idx = max(0, min(idx, len(dataset) - 1))
413
+ actual_idx = idx
414
+
415
+ img_tensor, label = dataset[actual_idx]
416
+ sample_img = pil_from_tensor(img_tensor, grayscale_to_rgb=True)
417
+
418
+ class_name = dataset_classes[label] if dataset_classes else str(label)
419
+ caption = f"Sample {actual_idx} from {app_state['meta'].get('dataset', 'dataset')} • class: {class_name}"
420
+
421
+ return sample_img, caption
422
+
423
+
424
+ def update_class_filter(class_filter):
425
+ """Update the slider range when class filter changes."""
426
+ if app_state["dataset"] is None:
427
+ return gr.update(visible=False, maximum=0, value=0)
428
+
429
+ dataset = app_state["dataset"]
430
+ dataset_classes = app_state["dataset_classes"]
431
+
432
+ if class_filter == "(any)":
433
+ max_idx = len(dataset) - 1
434
+ else:
435
+ targets = np.array([dataset[i][1] for i in range(len(dataset))])
436
+ class_id = dataset_classes.index(class_filter)
437
+ filtered_indices = np.where(targets == class_id)[0]
438
+ max_idx = len(filtered_indices) - 1 if len(filtered_indices) > 0 else 0
439
+
440
+ return gr.update(visible=True, maximum=max_idx, value=0)
441
+
442
+
443
+ def process_image(image, method, topk, alpha):
444
+ """Process image and generate Grad-CAM visualizations."""
445
+ if app_state["model"] is None:
446
+ return "❌ No model loaded", [], []
447
+
448
+ if image is None:
449
+ return "❌ No image provided", [], []
450
+
451
+ try:
452
+ # Convert to PIL if needed
453
+ if isinstance(image, np.ndarray):
454
+ image = Image.fromarray(image)
455
+
456
+ # Prepare image
457
+ pil = image.convert("RGB")
458
+ x = app_state["transform"](pil)
459
+ x_batched = x.unsqueeze(0)
460
+
461
+ # Generate base image for overlay
462
+ base_pil = denorm_to_pil(
463
+ x,
464
+ app_state["meta"].get("mean", [0.2860]),
465
+ app_state["meta"].get("std", [0.3530])
466
+ )
467
+
468
+ # Run prediction and CAM
469
+ device = get_device("cpu")
470
+ cam_results, probs = predict_and_cam(
471
+ app_state["model"], x_batched, device,
472
+ app_state["target_layer"], topk=topk, method=method
473
+ )
474
+
475
+ # Create predictions table
476
+ predictions = []
477
+ for r in cam_results:
478
+ class_name = app_state["classes"][r["class_index"]] if app_state["classes"] else str(r["class_index"])
479
+ predictions.append([
480
+ r["rank"],
481
+ class_name,
482
+ r["class_index"],
483
+ f"{r['prob']:.4f}"
484
+ ])
485
+
486
+ # Create overlay images
487
+ overlays = []
488
+ for r in cam_results:
489
+ class_name = app_state["classes"][r["class_index"]] if app_state["classes"] else str(r["class_index"])
490
+ overlay_img = overlay_pil(base_pil, r["cam"], alpha=alpha)
491
+ overlays.append((overlay_img, f"Top{r['rank']}: {class_name} ({r['prob']:.3f})"))
492
+
493
+ return "✅ Processing complete", predictions, overlays
494
+
495
+ except Exception as e:
496
+ return f"❌ Processing failed: {str(e)}", [], []
497
+
498
+
499
+ # Create Gradio interface
500
+ def create_interface():
501
+ presets = load_release_presets()
502
+ preset_choices = ["None"] + list(presets.keys()) if presets else ["None"]
503
+
504
+ with gr.Blocks(title="🔍 Grad-CAM Demo", theme=custom_theme) as demo:
505
+ gr.Markdown("# 🔍 Grad-CAM Demo — Upload an image, get top-k predictions + heatmaps")
506
+
507
+ with gr.Row():
508
+ with gr.Column(scale=1):
509
+ gr.Markdown("## Settings")
510
+
511
+ # Checkpoint loading
512
+ gr.Markdown("### Load Checkpoint")
513
+ with gr.Group():
514
+ preset_dropdown = gr.Dropdown(
515
+ choices=preset_choices,
516
+ value="None",
517
+ label="Preset (GitHub Releases)"
518
+ )
519
+ url_input = gr.Textbox(
520
+ label="Or paste asset URL",
521
+ placeholder="https://github.com/user/repo/releases/download/..."
522
+ )
523
+ url_button = gr.Button("Download from URL", variant="primary")
524
+
525
+ with gr.Group():
526
+ file_input = gr.File(
527
+ label="Upload checkpoint (.ckpt)",
528
+ file_types=[".ckpt"]
529
+ )
530
+ file_button = gr.Button("Load uploaded file", variant="primary")
531
+
532
+ status_text = gr.Textbox(
533
+ label="Status",
534
+ interactive=False,
535
+ value="No checkpoint loaded"
536
+ )
537
+
538
+ meta_display = gr.Code(
539
+ label="Model Metadata",
540
+ language="json",
541
+ interactive=False
542
+ )
543
+
544
+ # Processing options
545
+ gr.Markdown("### Processing Options")
546
+ method_radio = gr.Radio(
547
+ choices=["Grad-CAM", "Grad-CAM++"],
548
+ value="Grad-CAM",
549
+ label="CAM Method"
550
+ )
551
+ topk_slider = gr.Slider(
552
+ minimum=1, maximum=10, value=3, step=1,
553
+ label="Top-k classes"
554
+ )
555
+ alpha_slider = gr.Slider(
556
+ minimum=0.1, maximum=0.9, value=0.5, step=0.05,
557
+ label="Overlay alpha"
558
+ )
559
+
560
+ with gr.Column(scale=2):
561
+ gr.Markdown("## Image Input")
562
+
563
+ with gr.Group():
564
+ image_input = gr.Image(
565
+ label="Upload Image",
566
+ type="pil"
567
+ )
568
+
569
+ with gr.Row():
570
+ sample_button = gr.Button("Random Sample", visible=False)
571
+
572
+ with gr.Group():
573
+ gr.Markdown("**Dataset Sample Browser**")
574
+ class_filter = gr.Dropdown(
575
+ label="Filter by class",
576
+ choices=["(any)"],
577
+ value="(any)",
578
+ visible=False
579
+ )
580
+ sample_slider = gr.Slider(
581
+ label="Sample index",
582
+ minimum=0,
583
+ maximum=0,
584
+ value=0,
585
+ step=1,
586
+ visible=False,
587
+ interactive=True
588
+ )
589
+ sample_info = gr.Textbox(
590
+ label="Sample Info",
591
+ interactive=False,
592
+ visible=False
593
+ )
594
+
595
+ process_button = gr.Button("🔍 Process Image", variant="primary", size="lg")
596
+ process_status = gr.Textbox(
597
+ label="Processing Status",
598
+ interactive=False
599
+ )
600
+
601
+ gr.Markdown("## Results")
602
+
603
+ with gr.Group():
604
+ gr.Markdown("### Top-k Predictions")
605
+ predictions_table = gr.Dataframe(
606
+ headers=["Rank", "Class", "Index", "Probability"],
607
+ datatype=["number", "str", "number", "str"],
608
+ interactive=False
609
+ )
610
+
611
+ with gr.Group():
612
+ gr.Markdown("### Grad-CAM Overlays")
613
+ overlay_gallery = gr.Gallery(
614
+ label="CAM Overlays",
615
+ show_label=False,
616
+ elem_id="gallery",
617
+ columns=3,
618
+ rows=2,
619
+ object_fit="contain",
620
+ height="auto"
621
+ )
622
+
623
+ # Event handlers
624
+ url_button.click(
625
+ fn=load_checkpoint_from_url,
626
+ inputs=[url_input, preset_dropdown],
627
+ outputs=[status_text, meta_display, sample_button, class_filter, sample_slider]
628
+ )
629
+
630
+ file_button.click(
631
+ fn=load_checkpoint_from_file,
632
+ inputs=[file_input],
633
+ outputs=[status_text, meta_display, sample_button, class_filter, sample_slider]
634
+ )
635
+
636
+ sample_button.click(
637
+ fn=get_random_sample,
638
+ outputs=[image_input, sample_info, sample_slider]
639
+ )
640
+
641
+ class_filter.change(
642
+ fn=update_class_filter,
643
+ inputs=[class_filter],
644
+ outputs=[sample_slider]
645
+ )
646
+
647
+ sample_slider.change(
648
+ fn=get_sample_by_index,
649
+ inputs=[sample_slider, class_filter],
650
+ outputs=[image_input, sample_info]
651
+ )
652
+
653
+ process_button.click(
654
+ fn=process_image,
655
+ inputs=[image_input, method_radio, topk_slider, alpha_slider],
656
+ outputs=[process_status, predictions_table, overlay_gallery]
657
+ )
658
+
659
+ return demo
660
+
661
+
662
+ if __name__ == "__main__":
663
+ demo = create_interface()
664
+ demo.launch(
665
+ share=True,
666
+ server_name="0.0.0.0",
667
+ server_port=7860
668
+ )
app/streamlit_app.py ADDED
@@ -0,0 +1,557 @@
1
+ import datetime as dt
2
+ import random
3
+ from pathlib import Path
4
+ import os
5
+ import hashlib
6
+ import requests
7
+ import json
8
+
9
+ import numpy as np
10
+ import streamlit as st
11
+ import torch
12
+ import torch.nn as nn
13
+ import torch.nn.functional as F
14
+ import torchvision.models as tvm
15
+ import torchvision.transforms as T
16
+ from PIL import Image
17
+ from torchcam.methods import GradCAM, GradCAMpp
18
+ from torchcam.utils import overlay_mask
19
+ from torchvision.datasets import CIFAR10, MNIST, FashionMNIST
20
+
21
+ # Persist selected checkpoint across reruns
22
+ if "ckpt_path" not in st.session_state:
23
+ st.session_state["ckpt_path"] = None
24
+
25
+
26
+ @st.cache_data(show_spinner=True)
27
+ def download_release_asset(url: str, dest_dir: str = "saved_checkpoints") -> str:
28
+ """Download a remote checkpoint to dest_dir and return its local path.
29
+ Cached so subsequent reruns won't redownload.
30
+ """
31
+ Path(dest_dir).mkdir(parents=True, exist_ok=True)
32
+ url_hash = hashlib.sha256(url.encode("utf-8")).hexdigest()[:16]
33
+ fname = Path(url).name or f"asset_{url_hash}.ckpt"
34
+ if not fname.endswith(".ckpt"):
35
+ fname = f"{fname}.ckpt"
36
+ local_path = Path(dest_dir) / f"{url_hash}_{fname}"
37
+ if local_path.exists() and local_path.stat().st_size > 0:
38
+ return str(local_path)
39
+ with requests.get(url, stream=True, timeout=120) as r:
40
+ r.raise_for_status()
41
+ with open(local_path, "wb") as f:
42
+ for chunk in r.iter_content(chunk_size=1024 * 1024):
43
+ if chunk:
44
+ f.write(chunk)
45
+ return str(local_path)
46
+
47
+
48
+ def load_release_presets() -> dict:
49
+ """Load release preset URLs from multiple sources.
50
+ Order: Streamlit secrets → .streamlit/presets.json → presets.json → env var RELEASE_CKPTS_JSON.
51
+ Returns a dict name -> url. Safe if nothing is configured.
52
+ """
53
+ # 1) Streamlit secrets
54
+ try:
55
+ if hasattr(st, "secrets") and "release_checkpoints" in st.secrets:
56
+ # Convert to plain dict in case it's a Secrets object
57
+ return dict(st.secrets["release_checkpoints"]) # type: ignore[index]
58
+ except Exception:
59
+ pass
60
+
61
+ # 2) Local JSON files for dev
62
+ for rel in (".streamlit/presets.json", "presets.json"):
63
+ p = Path(rel)
64
+ if p.exists():
65
+ try:
66
+ with open(p, "r", encoding="utf-8") as f:
67
+ data = json.load(f)
68
+ # Either the file is a mapping directly, or has a top-level key
69
+ if isinstance(data, dict) and data:
70
+ if "release_checkpoints" in data and isinstance(data["release_checkpoints"], dict):
71
+ return dict(data["release_checkpoints"]) # nested
72
+ return dict(data) # flat mapping
73
+ except Exception:
74
+ pass
75
+
76
+ # 3) Environment variable containing JSON mapping
77
+ env_json = os.environ.get("RELEASE_CKPTS_JSON", "").strip()
78
+ if env_json:
79
+ try:
80
+ data = json.loads(env_json)
81
+ if isinstance(data, dict):
82
+ return dict(data)
83
+ except Exception:
84
+ pass
85
+
86
+ return {}
87
+
88
+
89
+ # ---------- Small utilities ----------
90
+ def get_device(choice="auto"):
91
+ if choice == "cpu":
92
+ return "cpu"
93
+ if choice == "cuda":
94
+ return "cuda"
95
+ return "cuda" if torch.cuda.is_available() else "cpu"
96
+
97
+
98
+ def find_latest_best_ckpt():
99
+ ckpts = sorted(
100
+ Path("checkpoints").rglob("best.ckpt"), key=lambda p: p.stat().st_mtime
101
+ )
102
+ return ckpts[-1] if ckpts else None
103
+
104
+
105
+ def denorm_to_pil(x, mean, std):
106
+ """
107
+ x: torch.Tensor CxHxW (normalized), mean/std lists
108
+ returns PIL.Image (RGB)
109
+ """
110
+ x = x.detach().cpu().clone()
111
+ if len(mean) == 1:
112
+ # grayscale
113
+ m, s = float(mean[0]), float(std[0])
114
+ x = x * s + m # de-normalize
115
+ x = x.clamp(0, 1)
116
+ # convert to RGB for overlay convenience
117
+ pil = T.ToPILImage()(x)
118
+ pil = pil.convert("RGB")
119
+ return pil
120
+ else:
121
+ mean = torch.tensor(mean)[:, None, None]
122
+ std = torch.tensor(std)[:, None, None]
123
+ x = x * std + mean
124
+ x = x.clamp(0, 1)
125
+ return T.ToPILImage()(x)
126
+
127
+
128
+ DATASET_CLASSES = {
129
+ "fashion-mnist": [
130
+ "T-shirt/top",
131
+ "Trouser",
132
+ "Pullover",
133
+ "Dress",
134
+ "Coat",
135
+ "Sandal",
136
+ "Shirt",
137
+ "Sneaker",
138
+ "Bag",
139
+ "Ankle boot",
140
+ ],
141
+ "mnist": [str(i) for i in range(10)],
142
+ "cifar10": [
143
+ "airplane",
144
+ "automobile",
145
+ "bird",
146
+ "cat",
147
+ "deer",
148
+ "dog",
149
+ "frog",
150
+ "horse",
151
+ "ship",
152
+ "truck",
153
+ ],
154
+ }
155
+
156
+
157
+ @st.cache_resource
158
+ def load_raw_dataset(name: str, root="data"):
159
+ """Load the test split with ToTensor() only (for preview)."""
160
+ tt = T.ToTensor()
161
+ if name == "fashion-mnist":
162
+ ds = FashionMNIST(root=root, train=False, download=True, transform=tt)
163
+ elif name == "mnist":
164
+ ds = MNIST(root=root, train=False, download=True, transform=tt)
165
+ elif name == "cifar10":
166
+ ds = CIFAR10(root=root, train=False, download=True, transform=tt)
167
+ else:
168
+ raise ValueError(f"Unknown dataset: {name}")
169
+ classes = getattr(ds, "classes", None) or [str(i) for i in range(10)]
170
+ return ds, classes
171
+
172
+
173
+ def pil_from_tensor(img_tensor, grayscale_to_rgb=True):
174
+ pil = T.ToPILImage()(img_tensor)
175
+ if grayscale_to_rgb and img_tensor.ndim == 3 and img_tensor.shape[0] == 1:
176
+ pil = pil.convert("RGB")
177
+ return pil
178
+
179
+
180
+ @st.cache_data(ttl=5, show_spinner=False)
181
+ def list_ckpts(root_dir: str, recursive: bool = True, filter: str = ""):
182
+ """Return (labels, paths) sorted by mtime desc."""
183
+ root = Path(root_dir)
184
+ if not root.exists():
185
+ return [], []
186
+ files = sorted(
187
+ (root.rglob("*.ckpt") if recursive else root.glob("*.ckpt")),
188
+ key=lambda p: p.stat().st_mtime,
189
+ reverse=True,
190
+ )
191
+ files = [p for p in files if filter in str(p)]
192
+ labels = []
193
+ for p in files:
194
+ rel = p.relative_to(root)
195
+ mtime = dt.datetime.fromtimestamp(p.stat().st_mtime).strftime("%Y-%m-%d %H:%M")
196
+ labels.append(f"{rel} • {mtime}")
197
+ return labels, [str(p) for p in files]
198
+
199
+
200
+ # ---------- Your SmallCNN (for FMNIST) ----------
201
+ class SmallCNN(nn.Module):
202
+ def __init__(self, num_classes=10):
203
+ super().__init__()
204
+ self.conv1 = nn.Conv2d(1, 32, 3, padding=1)
205
+ self.pool1 = nn.MaxPool2d(2, 2)
206
+ self.conv2 = nn.Conv2d(32, 64, 3, padding=1)
207
+ self.pool2 = nn.MaxPool2d(2, 2)
208
+ self.fc = nn.Linear(64 * 7 * 7, num_classes)
209
+
210
+ def forward(self, x):
211
+ x = F.relu(self.conv1(x))
212
+ x = self.pool1(x)
213
+ x = F.relu(self.conv2(x))
214
+ x = self.pool2(x)
215
+ x = torch.flatten(x, 1)
216
+ return self.fc(x)
217
+
218
+
219
+ # ---------- Load model + meta from checkpoint ----------
220
+ def load_model_from_ckpt(ckpt_path: Path, device: str):
221
+ ckpt = torch.load(str(ckpt_path), map_location=device)
222
+ classes = ckpt.get("classes", None)
223
+ meta = ckpt.get("meta", {})
224
+ num_classes = len(classes) if classes else 10
225
+ model_name = meta.get("model_name", "smallcnn")
226
+
227
+ if model_name == "smallcnn":
228
+ model = SmallCNN(num_classes=num_classes).to(device)
229
+ default_target_layer = "conv2"
230
+ elif model_name == "resnet18_cifar":
231
+ m = tvm.resnet18(weights=None)
232
+ m.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
233
+ m.maxpool = nn.Identity()
234
+ m.fc = nn.Linear(m.fc.in_features, num_classes)
235
+ model = m.to(device)
236
+ default_target_layer = "layer4"
237
+ elif model_name == "resnet18_imagenet":
238
+ try:
239
+ w = tvm.ResNet18_Weights.IMAGENET1K_V1
240
+ except Exception:
241
+ w = None
242
+ m = tvm.resnet18(weights=w)
243
+ m.fc = nn.Linear(m.fc.in_features, num_classes)
244
+ model = m.to(device)
245
+ default_target_layer = "layer4"
246
+ else:
247
+ raise ValueError(f"Unknown model_name in ckpt: {model_name}")
248
+
249
+ model.load_state_dict(ckpt["model_state"])
250
+ model.eval()
251
+ # ensure meta has defaults
252
+ meta.setdefault("default_target_layer", default_target_layer)
253
+ return model, classes, meta
254
+
255
+
256
+ def build_transform_from_meta(meta):
257
+ img_size = int(meta.get("img_size", 28))
258
+ mean = meta.get("mean", [0.2860]) # FMNIST fallback
259
+ std = meta.get("std", [0.3530])
260
+ if len(mean) == 1:
261
+ return T.Compose(
262
+ [
263
+ T.Grayscale(num_output_channels=1),
264
+ T.Resize((img_size, img_size)),
265
+ T.ToTensor(),
266
+ T.Normalize(mean, std),
267
+ ]
268
+ )
269
+ else:
270
+ return T.Compose(
271
+ [
272
+ T.Resize((img_size, img_size)),
273
+ T.ToTensor(),
274
+ T.Normalize(mean, std),
275
+ ]
276
+ )
277
+
278
+
279
+ def predict_and_cam(model, x, device, target_layer, topk=3, method="Grad-CAM"):
280
+ """
281
+ x: Tensor [1,C,H,W] normalized
282
+ returns: list of dicts: {rank, class_index, prob, cam_tensor(H,W)}
283
+ """
284
+ cam_cls = GradCAM if method == "Grad-CAM" else GradCAMpp
285
+ cam_extractor = cam_cls(model, target_layer=target_layer)
286
+
287
+ logits = model(x.to(device))
288
+ probs = torch.softmax(logits, dim=1)[0].detach().cpu()
289
+ top_vals, top_idxs = probs.topk(topk)
290
+
291
+ results = []
292
+ for rank, (p, idx) in enumerate(zip(top_vals.tolist(), top_idxs.tolist())):
293
+ retain = rank < topk - 1
294
+ cams = cam_extractor(idx, logits, retain_graph=retain) # list
295
+ cam = cams[0].detach().cpu() # [H,W] at feature-map resolution
296
+ results.append(
297
+ {"rank": rank + 1, "class_index": int(idx), "prob": float(p), "cam": cam}
298
+ )
299
+ return results, probs
300
+
301
+
302
+ def overlay_pil(base_pil_rgb: Image.Image, cam_tensor, alpha=0.5):
303
+ # cam_tensor: torch.Tensor HxW in [0,1] (we'll min-max it)
304
+ cam = cam_tensor.clone()
305
+ cam -= cam.min()
306
+ cam = cam / (cam.max() + 1e-8)
307
+ heat = T.ToPILImage()(cam) # single-channel PIL
308
+ return overlay_mask(base_pil_rgb, heat, alpha=alpha)
309
+
310
+
311
+ # ---------- UI ----------
312
+ st.set_page_config(page_title="Grad-CAM Demo", page_icon="🔍", layout="wide")
313
+ st.title("🔍 Grad-CAM Demo — upload an image, get top-k + heatmaps")
314
+
315
+ # Sidebar: checkpoint + options
316
+ with st.sidebar:
317
+ st.header("Settings")
318
+
319
+ ckpt_path = st.session_state.get("ckpt_path")
320
+
321
+ st.subheader("Checkpoints")
322
+ # Remote download (presets or URL), saved automatically to saved_checkpoints/
323
+ presets = load_release_presets()
324
+ preset_names = list(presets.keys())
325
+ preset_sel = st.selectbox("Preset (GitHub Releases)", options=["(none)"] + preset_names, index=0) if preset_names else "(none)"
326
+ url_input = st.text_input("Or paste asset URL", value="")
327
+ if st.button("Download checkpoint", use_container_width=True):
328
+ url = presets.get(preset_sel, "") if preset_sel != "(none)" else url_input.strip()
329
+ if not url:
330
+ st.warning("Provide a preset or paste a URL")
331
+ else:
332
+ try:
333
+ path_dl = download_release_asset(url, dest_dir="saved_checkpoints")
334
+ st.success(f"Downloaded to: {path_dl}")
335
+ ckpt_path = path_dl
336
+ st.session_state["ckpt_path"] = ckpt_path
337
+ st.cache_data.clear()
338
+ except Exception as e:
339
+ st.error(f"Download failed: {e}")
340
+
341
+ # Upload a user-provided .ckpt directly in the online app
342
+ uploaded_ckpt = st.file_uploader("Upload checkpoint (.ckpt)", type=["ckpt"], accept_multiple_files=False)
343
+ if uploaded_ckpt is not None and st.button("Use uploaded checkpoint", use_container_width=True):
344
+ try:
345
+ Path("saved_checkpoints").mkdir(parents=True, exist_ok=True)
346
+ raw = uploaded_ckpt.read()
347
+ content_hash = hashlib.sha256(raw).hexdigest()[:16]
348
+ base_name = Path(uploaded_ckpt.name).name
349
+ if not base_name.endswith(".ckpt"):
350
+ base_name = f"{base_name}.ckpt"
351
+ local_path = Path("saved_checkpoints") / f"{content_hash}_{base_name}"
352
+ with open(local_path, "wb") as f:
353
+ f.write(raw)
354
+ ckpt_path = str(local_path)
355
+ st.session_state["ckpt_path"] = ckpt_path
356
+ st.success(f"Uploaded to: {ckpt_path}")
357
+ st.cache_data.clear()
358
+ except Exception as e:
359
+ st.error(f"Upload failed: {e}")
360
+
361
+ st.caption(f"Selected: {ckpt_path}")
362
+
363
+ with st.expander("Checkpoint meta preview", expanded=False):
364
+ try:
365
+ if ckpt_path:
366
+ m, c, meta_preview = load_model_from_ckpt(Path(ckpt_path), device="cpu")
367
+ st.json(
368
+ {
369
+ "dataset": meta_preview.get("dataset"),
370
+ "model_name": meta_preview.get("model_name"),
371
+ "img_size": meta_preview.get("img_size"),
372
+ "target_layer": meta_preview.get("default_target_layer"),
373
+ }
374
+ )
375
+ else:
376
+ st.info("No checkpoint selected yet.")
377
+ except Exception as e:
378
+ st.info(f"Could not read meta: {e}")
379
+
380
+ method = st.selectbox("CAM method", ["Grad-CAM", "Grad-CAM++"], index=0)
381
+ topk = st.slider("Top-k classes", min_value=1, max_value=10, value=3, step=1)
382
+ alpha = st.slider(
383
+ "Overlay alpha", min_value=0.1, max_value=0.9, value=0.5, step=0.05
384
+ )
385
+
386
+ # Load model/meta
387
+ if not ckpt_path or not Path(ckpt_path).exists():
388
+ st.info(
389
+ "First choose a checkpoint:\n"
390
+ "- Preset: pick from the list and click 'Download checkpoint'\n"
391
+ "- URL: paste a direct .ckpt URL and click 'Download checkpoint'\n"
392
+ "- Upload: select a .ckpt and click 'Use uploaded checkpoint'\n\n"
393
+ "After a checkpoint is selected, upload an image or use the sample picker to see predictions and Grad-CAM overlays."
394
+ )
395
+ st.stop()
396
+
397
+ device = "cpu"
398
+ model, classes, meta = load_model_from_ckpt(Path(ckpt_path), device)
399
+ tf = build_transform_from_meta(meta)
400
+ target_layer = meta.get("default_target_layer", "conv2")
401
+
402
+ # Main: uploader
403
+ # Main: uploader OR dataset sample
405
+ uploaded = st.file_uploader(
406
+ "Upload PNG/JPG (or pick a sample below)", type=["png", "jpg", "jpeg"]
407
+ )
408
+
409
+ with st.expander("…or pick a sample from this model's dataset", expanded=False):
410
+ ds_default = meta.get("dataset", "fashion-mnist")
411
+ ds, ds_classes = load_raw_dataset(ds_default, root="data")
412
+ targets = np.array(getattr(ds, "targets", [ds[i][1] for i in range(len(ds))]))
413
+
414
+ # --- class filter (persisted) ---
415
+ class_opts = ["(any)"] + list(ds_classes)
416
+ class_sel = st.selectbox("Class filter", options=class_opts, index=0, key="class_sel")
417
+
418
+ if class_sel == "(any)":
419
+ filtered_idx = np.arange(len(ds))
420
+ else:
421
+ class_id = ds_classes.index(class_sel)
422
+ filtered_idx = np.nonzero(targets == class_id)[0]
423
+
424
+ # --- ensure we have a session index and keep it valid ---
425
+ if "sample_idx" not in st.session_state:
426
+ st.session_state["sample_idx"] = 0
427
+
428
+ # clamp when filter changes or dataset length is small
429
+ if len(filtered_idx) > 0:
430
+ st.session_state["sample_idx"] = int(
431
+ np.clip(st.session_state["sample_idx"], 0, len(filtered_idx) - 1)
432
+ )
433
+
434
+ if len(filtered_idx) == 0:
435
+ st.info("No samples found for this class.")
436
+ sample_img = None
437
+ else:
438
+ col_l, col_r = st.columns([2, 1])
439
+
440
+ with col_r:
441
+ picked = st.button("Pick random", use_container_width=True, key="btn_pick_random")
442
+ if picked:
443
+ # IMPORTANT: update session_state BEFORE creating the slider
444
+ cur = st.session_state["sample_idx"]
445
+ if len(filtered_idx) > 1:
446
+ new_idx = random.randrange(len(filtered_idx) - 1)
447
+ if new_idx >= cur:
448
+ new_idx += 1
449
+ else:
450
+ new_idx = 0
451
+ st.session_state["sample_idx"] = new_idx
452
+ # no st.rerun() needed; the app will rerun after the button
453
+
454
+ with col_l:
455
+ # Now instantiate the slider (AFTER any state changes above)
456
+ st.slider(
457
+ "Pick index (within filtered samples)",
458
+ 0, max(0, len(filtered_idx) - 1),
459
+ key="sample_idx", # same key as the state we set above
460
+ )
461
+
462
+ raw_idx = int(filtered_idx[st.session_state["sample_idx"]])
463
+ img_tensor, label = ds[raw_idx]
464
+ sample_img = pil_from_tensor(img_tensor, grayscale_to_rgb=True)
465
+
466
+ st.image(
467
+ sample_img,
468
+ caption=f"Sample • {ds_default} • class={ds_classes[label]} • idx={raw_idx}",
469
+ width=160,
470
+ use_container_width=False,
471
+ )
472
+
473
+ # Decide the input image used downstream
474
+ if uploaded is not None:
475
+ pil = Image.open(uploaded).convert("RGB")
476
+ elif "sample_img" in locals() and sample_img is not None:
477
+ pil = sample_img
478
+ else:
479
+ st.info("Upload an image or open the sample picker above.")
480
+ st.stop()
481
+
482
+ col_in, col_cfg = st.columns([2, 1])
483
+
484
+ with col_in:
+ st.image(pil, caption="Input", use_container_width=True)
494
+
495
+ with col_cfg:
496
+ st.markdown("**Model meta**")
497
+ st.json(
498
+ {
499
+ "dataset": meta.get("dataset"),
500
+ "model_name": meta.get("model_name"),
501
+ "img_size": meta.get("img_size"),
502
+ "target_layer": target_layer,
503
+ "mean": meta.get("mean"),
504
+ "std": meta.get("std"),
505
+ "classes": (
506
+ classes
507
+ if classes and len(classes) <= 10
508
+ else f"{len(classes) if classes else 'N/A'} classes"
509
+ ),
510
+ }
511
+ )
512
+
513
+ # Prepare tensor + denormalized PIL base for overlay
514
+ x = tf(pil) # CxHxW normalized
515
+ x_batched = x.unsqueeze(0) # 1xCxHxW
516
+ base_pil = denorm_to_pil(x, meta.get("mean", [0.2860]), meta.get("std", [0.3530]))
517
+
518
+ # Predict + CAM
519
+ with st.spinner("Running inference + Grad-CAM..."):
520
+ try:
521
+ cam_results, probs = predict_and_cam(
522
+ model, x_batched, device, target_layer, topk=topk, method=method
523
+ )
524
+ except Exception as e:
525
+ st.error(
526
+ f"Grad-CAM failed. Target layer likely incorrect."
527
+ f"\nLayer: {target_layer}\nError: {e}"
528
+ )
529
+ st.stop()
530
+
531
+ # Top-k table
532
+ st.subheader("2) Top-k predictions")
533
+ rows = []
534
+ for r in cam_results:
535
+ name = classes[r["class_index"]] if classes else str(r["class_index"])
536
+ rows.append(
537
+ {
538
+ "rank": r["rank"],
539
+ "class": name,
540
+ "index": r["class_index"],
541
+ "prob": round(r["prob"], 4),
542
+ }
543
+ )
544
+ st.dataframe(rows, use_container_width=True)
545
+
546
+ # Overlays
547
+ st.subheader("3) Grad-CAM overlays")
548
+ cols = st.columns(len(cam_results))
549
+ for c, r in zip(cols, cam_results):
550
+ name = classes[r["class_index"]] if classes else str(r["class_index"])
551
+ ov = overlay_pil(base_pil, r["cam"], alpha=alpha)
552
+ with c:
553
+ st.image(
554
+ ov,
555
+ caption=f"Top{r['rank']}: {name} ({r['prob']:.3f})",
556
+ use_container_width=True,
557
+ )
configs/baseline.yaml ADDED
@@ -0,0 +1,20 @@
1
+ dataset: fashion-mnist
2
+ data_dir: ./data
3
+
4
+ batch_size: 128
5
+ epochs: 8
6
+ lr: 0.001
7
+ weight_decay: 0.0001
8
+ num_workers: 2
9
+ seed: 41
10
+ device: auto
11
+
12
+ log_root: runs
13
+ ckpt_root: checkpoints
14
+ reports_root: reports
15
+
16
+ early_stop:
17
+ monitor: val_loss # val_loss or val_acc
18
+ mode: min # min for loss, max for acc
19
+ patience: 3
20
+ min_delta: 0.0
configs/cifar10_resnet18.yaml ADDED
@@ -0,0 +1,28 @@
1
+ dataset: cifar10
2
+ model_name: resnet18_cifar # CIFAR variant (3×3 conv1, no maxpool)
3
+ data_dir: ./data
4
+
5
+ # training
6
+ batch_size: 128
7
+ epochs: 40
8
+ lr: 0.001
9
+ weight_decay: 0.0005
10
+ num_workers: 2
11
+ seed: 41
12
+ device: auto
13
+
14
+ # image + normalization (CIFAR-10 stats)
15
+ img_size: 32
16
+ mean: [0.4914, 0.4822, 0.4465]
17
+ std: [0.2470, 0.2435, 0.2616]
18
+
19
+ # logging+artifacts
20
+ log_root: runs
21
+ ckpt_root: checkpoints
22
+ reports_root: reports
23
+
24
+ early_stop:
25
+ monitor: val_loss
26
+ mode: min
27
+ patience: 5
28
+ min_delta: 0.0
configs/cifar10_resnet18_adam.yaml ADDED
@@ -0,0 +1,27 @@
1
+ dataset: cifar10
2
+ model_name: resnet18_cifar
3
+ data_dir: ./data
4
+
5
+ batch_size: 128
6
+ epochs: 40
7
+ lr: 0.001 # Adam baseline
8
+ weight_decay: 0.0005
9
+ optimizer: adam # or sgd
10
+ momentum: 0.9
11
+ num_workers: 2
12
+ seed: 41
13
+ device: auto
14
+
15
+ img_size: 32
16
+ mean: [0.4914, 0.4822, 0.4465]
17
+ std: [0.2470, 0.2435, 0.2616]
18
+
19
+ log_root: runs
20
+ ckpt_root: checkpoints
21
+ reports_root: reports
22
+
23
+ early_stop:
24
+ monitor: val_loss
25
+ mode: min
26
+ patience: 5
27
+ min_delta: 0.0
configs/cifar10_resnet18_imagenet.yaml ADDED
@@ -0,0 +1,26 @@
1
+ dataset: cifar10
2
+ model_name: resnet18_imagenet # ImageNet-pretrained; resize to 224
3
+ data_dir: ./data
4
+
5
+ batch_size: 128
6
+ epochs: 20
7
+ lr: 0.0005
8
+ weight_decay: 0.0005
9
+ num_workers: 2
10
+ seed: 41
11
+ device: auto
12
+
13
+ # image + normalization (ImageNet stats)
14
+ img_size: 224
15
+ mean: [0.485, 0.456, 0.406]
16
+ std: [0.229, 0.224, 0.225]
17
+
18
+ log_root: runs
19
+ ckpt_root: checkpoints
20
+ reports_root: reports
21
+
22
+ early_stop:
23
+ monitor: val_loss
24
+ mode: min
25
+ patience: 5
26
+ min_delta: 0.0
configs/cifar10_resnet18_sgd.yaml ADDED
@@ -0,0 +1,27 @@
1
+ dataset: cifar10
2
+ model_name: resnet18_cifar
3
+ data_dir: ./data
4
+
5
+ batch_size: 128
6
+ epochs: 60
7
+ lr: 0.1 # classic SGD start
8
+ weight_decay: 0.0005
9
+ optimizer: sgd # or adam
10
+ momentum: 0.9
11
+ num_workers: 2
12
+ seed: 41
13
+ device: auto
14
+
15
+ img_size: 32
16
+ mean: [0.4914, 0.4822, 0.4465]
17
+ std: [0.2470, 0.2435, 0.2616]
18
+
19
+ log_root: runs
20
+ ckpt_root: checkpoints
21
+ reports_root: reports
22
+
23
+ early_stop:
24
+ monitor: val_loss
25
+ mode: min
26
+ patience: 8
27
+ min_delta: 0.0
configs/fmnist_smallcnn.yaml ADDED
@@ -0,0 +1,25 @@
1
+ dataset: fashion-mnist
2
+ model_name: smallcnn
3
+ data_dir: ./data
4
+
5
+ batch_size: 128
6
+ epochs: 8
7
+ lr: 0.001
8
+ weight_decay: 0.0001
9
+ num_workers: 2
10
+ seed: 41
11
+ device: auto
12
+
13
+ img_size: 28
14
+ mean: [0.2860]
15
+ std: [0.3530]
16
+
17
+ log_root: runs
18
+ ckpt_root: checkpoints
19
+ reports_root: reports
20
+
21
+ early_stop:
22
+ monitor: val_loss
23
+ mode: min
24
+ patience: 3
25
+ min_delta: 0.0
configs/fmnist_smallcnn_aug.yaml ADDED
@@ -0,0 +1,25 @@
1
+ dataset: fashion-mnist
2
+ model_name: smallcnn
3
+ data_dir: ./data
4
+
5
+ batch_size: 128
6
+ epochs: 12
7
+ lr: 0.001
8
+ weight_decay: 0.0001
9
+ num_workers: 2
10
+ seed: 41
11
+ device: auto
12
+
13
+ img_size: 28
14
+ mean: [0.2860]
15
+ std: [0.3530]
16
+
17
+ log_root: runs
18
+ ckpt_root: checkpoints
19
+ reports_root: reports
20
+
21
+ early_stop:
22
+ monitor: val_loss
23
+ mode: min
24
+ patience: 3
25
+ min_delta: 0.0
model_card.md ADDED
File without changes
notebooks/01_baseline_fmnist.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
requirements-gpu.txt ADDED
@@ -0,0 +1,19 @@
1
+ # GPU build (install torch/torchvision/torchaudio separately per https://pytorch.org/get-started/locally/)
2
+ # Example (CUDA 12.1):
3
+ # pip install --index-url https://download.pytorch.org/whl/cu121 torch torchvision torchaudio
4
+
5
+ # Core libs (do NOT pin torch* here; users install matching CUDA builds first)
6
+ torchmetrics
7
+ torchcam
8
+ captum
9
+ lightning
10
+ albumentations
11
+ opencv-python-headless
12
+ matplotlib
13
+ seaborn
14
+ tensorboard
15
+ rich
16
+ PyYAML
17
+ streamlit
18
+ requests
19
+ gradio
requirements.txt ADDED
@@ -0,0 +1,21 @@
1
+ # Use CPU-only PyTorch wheels from the official index
2
+ --extra-index-url https://download.pytorch.org/whl/cpu
3
+ torch==2.2.2+cpu
4
+ torchvision==0.17.2+cpu
5
+ torchaudio==2.2.2+cpu
6
+ torchmetrics
7
+ torchcam
8
+ captum
9
+ # Lightning (the older "pytorch-lightning" package name still works, but "lightning" is preferred)
10
+ lightning
11
+ albumentations
12
+ # Headless OpenCV for servers without GUI
13
+ opencv-python-headless
14
+ matplotlib
15
+ seaborn
16
+ tensorboard
17
+ rich
18
+ PyYAML
19
+ streamlit
20
+ requests
21
+ gradio
runtime.txt ADDED
@@ -0,0 +1,2 @@
1
+ python-3.11
2
+
src/explain.py ADDED
@@ -0,0 +1,236 @@
1
+ import argparse
2
+ import json
3
+ from pathlib import Path
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ import torchvision.transforms.functional as TF
8
+ from PIL import Image
9
+ from torchcam.methods import GradCAM
10
+ from torchcam.utils import overlay_mask
11
+ from torchvision import models as tvm
12
+ from torchvision import transforms
13
+
14
+ from src.train import SmallCNN, get_device
15
+
16
+
17
+ def build_argparser():
18
+ p = argparse.ArgumentParser(description="Grad-CAM explanations")
19
+ p.add_argument("--ckpt", type=str, required=True, help="Path to best.ckpt")
20
+ p.add_argument("--image", type=str, required=True, help="Path to an input image")
21
+ p.add_argument(
22
+ "--dataset",
23
+ choices=["fashion-mnist", "mnist", "cifar10"],
24
+ default="fashion-mnist",
25
+ help="Used to apply the right normalization and class names",
26
+ )
27
+ p.add_argument(
28
+ "--target-layer",
29
+ type=str,
30
+ default="conv2",
31
+ help="Layer to attach CAMs (e.g., 'conv2' for SmallCNN, 'layer4' for ResNet)",
32
+ )
33
+ p.add_argument(
34
+ "--outdir",
35
+ type=str,
36
+ default=None,
37
+ help="Where to store results; defaults near the checkpoint",
38
+ )
39
+ p.add_argument("--device", choices=["auto", "cpu", "cuda"], default="auto")
40
+ p.add_argument("--topk", type=int, default=3, help="How many top classes to render")
41
+ return p
42
+
43
+
44
+ def get_transforms_from_meta(meta):
45
+ img_size = int(meta.get("img_size", 28))
46
+ mean = meta.get("mean", [0.2860]) # fallback FMNIST
47
+ std = meta.get("std", [0.3530])
48
+
49
+ # channels: grayscale if mean/std length==1, else RGB
50
+ if len(mean) == 1:
51
+ tf = transforms.Compose(
52
+ [
53
+ transforms.Grayscale(num_output_channels=1),
54
+ transforms.Resize((img_size, img_size)),
55
+ transforms.ToTensor(),
56
+ transforms.Normalize(mean, std),
57
+ ]
58
+ )
59
+ else:
60
+ tf = transforms.Compose(
61
+ [
62
+ transforms.Resize((img_size, img_size)),
63
+ transforms.ToTensor(),
64
+ transforms.Normalize(mean, std),
65
+ ]
66
+ )
67
+ return tf
68
+
69
+
70
+ def denorm_to_pil(x: torch.Tensor, mean, std) -> Image.Image:
71
+ """
72
+ x: normalized tensor CxHxW
73
+ mean/std: list(s) from meta
74
+ returns: PIL RGB image for overlay
75
+ """
76
+ x = x.detach().cpu().clone()
77
+ if len(mean) == 1: # grayscale
78
+ m, s = float(mean[0]), float(std[0])
79
+ x = x * s + m
80
+ x = x.clamp(0, 1)
81
+ pil = transforms.ToPILImage()(x) # grayscale PIL
82
+ return pil.convert("RGB")
83
+ else: # RGB
84
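+ # reshape the per-channel stats to [C, 1, 1] so they broadcast over H and W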
+ mean_t = torch.tensor(mean)[:, None, None]
85
+ std_t = torch.tensor(std)[:, None, None]
86
+ x = x * std_t + mean_t
87
+ x = x.clamp(0, 1)
88
+ return transforms.ToPILImage()(x)
89
+
90
+
91
+ def load_model(ckpt_path, device):
92
+ ckpt = torch.load(ckpt_path, map_location=device)
93
+ classes = ckpt.get("classes", None)
94
+ meta = ckpt.get("meta", {})
95
+ num_classes = len(classes) if classes else 10
96
+ model_name = meta.get("model_name", "smallcnn")
97
+
98
+ if model_name == "smallcnn":
99
+ model = SmallCNN(num_classes=num_classes).to(device)
100
+ elif model_name == "resnet18_cifar":
101
+ m = tvm.resnet18(weights=None)
102
+ m.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
103
+ m.maxpool = nn.Identity()
104
+ m.fc = nn.Linear(m.fc.in_features, num_classes)
105
+ model = m.to(device)
106
+ elif model_name == "resnet18_imagenet":
107
+ try:
108
+ w = tvm.ResNet18_Weights.IMAGENET1K_V1
109
+ except Exception:
110
+ w = None
111
+ m = tvm.resnet18(weights=w)
112
+ m.fc = nn.Linear(m.fc.in_features, num_classes)
113
+ model = m.to(device)
114
+ else:
115
+ raise ValueError(f"Unknown model in ckpt: {model_name}")
116
+
117
+ model.load_state_dict(ckpt["model_state"])
118
+ model.eval()
119
+ return model, classes, meta
120
+
121
+
122
+ def run_gradcam(
123
+ model,
124
+ target_layer,
125
+ img_tensor,
126
+ device,
127
+ classes,
128
+ outdir: Path,
129
+ topk=3,
130
+ base_pil_rgb: Image.Image = None,
131
+ ):
132
+ """
133
+ img_tensor: CxHxW normalized (not batched)
134
+ base_pil_rgb: PIL image already denormalized & RGB for overlay (optional).
135
+ If None, will min-max scale from img_tensor (last-resort).
136
+ """
137
+ model.eval()
138
+ x = img_tensor.to(device).unsqueeze(0) # [1,C,H,W]
139
+ H, W = img_tensor.shape[-2:]
140
+ cam_extractor = GradCAM(model, target_layer=target_layer)
141
+
142
+ # forward once to get top-k
143
+ logits = model(x)
144
+ probs = torch.softmax(logits, dim=1)[0].detach().cpu()
145
+ top_vals, top_idxs = probs.topk(topk)
146
+
147
+ if base_pil_rgb is None:
148
+ # Fallback: simple min-max scaling (works but less faithful than denorm)
149
+ xx = img_tensor.detach().cpu()
150
+ xx = (xx - xx.min()) / (xx.max() - xx.min() + 1e-8)
151
+ base_pil_rgb = transforms.ToPILImage()(xx)
152
+ if xx.shape[0] == 1:
153
+ base_pil_rgb = base_pil_rgb.convert("RGB")
154
+
155
+ results = []
156
+ for rank, (score, cls_idx) in enumerate(zip(top_vals.tolist(), top_idxs.tolist())):
157
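+ # keep the autograd graph for every class except the last so Grad-CAM can backprop repeatedly from the same forward pass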
+ retain = rank < topk - 1
158
+ cams = cam_extractor(int(cls_idx), logits, retain_graph=retain)
159
+ cam = cams[0].detach().cpu() # [h,w]
160
+ cam_up = TF.resize(cam.unsqueeze(0), size=[H, W])[0] # upsample to input size
161
+
162
+ heat = transforms.ToPILImage()(cam_up)
163
+ overlay = overlay_mask(base_pil_rgb, heat, alpha=0.6)
164
+
165
+ out_png = outdir / (
+ f"gradcam_top{rank+1}_class{cls_idx}_"
+ f"{classes[cls_idx] if classes else cls_idx}.png"
+ )
169
+ overlay.save(out_png)
170
+
171
+ results.append(
172
+ {
173
+ "rank": rank + 1,
174
+ "class_index": int(cls_idx),
175
+ "class_name": classes[cls_idx] if classes else str(cls_idx),
176
+ "prob": float(score),
177
+ "file": str(out_png),
178
+ }
179
+ )
180
+
181
+ with open(outdir / "summary.json", "w") as f:
182
+ json.dump({"topk": results}, f, indent=2)
183
+
184
+ print("Saved:", outdir)
185
+ return results
186
+
187
+
188
+ def main():
189
+ args = build_argparser().parse_args()
190
+ device = get_device(args.device)
191
+
192
+ ckpt_path = Path(args.ckpt)
193
+
194
+ # outdir default
195
+ if args.outdir is None:
196
+ run_id = ckpt_path.parent.name
197
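+ # assumes the default <ckpt_root>/<run_id>/best.ckpt layout, so outputs land in <project>/reports/<run_id>/explain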
+ outdir = ckpt_path.parent.parent.parent / "reports" / run_id / "explain"
198
+ else:
199
+ outdir = Path(args.outdir)
200
+ outdir.mkdir(parents=True, exist_ok=True)
201
+
202
+ # 1) load model+meta first
203
+ model, classes, meta = load_model(str(ckpt_path), device)
204
+
205
+ # 2) build tf from meta
206
+ tf = get_transforms_from_meta(meta)
207
+
208
+ # 3) load and transform image
209
+ pil = Image.open(args.image).convert("RGB")
210
+ x = tf(pil) # CxHxW normalized
211
+
212
+ # 4) make a denormalized RGB base image for overlay
213
+ base_pil = denorm_to_pil(x, meta.get("mean", [0.2860]), meta.get("std", [0.3530]))
214
+
215
+ # 5) target layer (CLI overrides meta default)
216
+ target_layer = args.target_layer or meta.get("default_target_layer", "conv2")
217
+
218
+ # 6) run Grad-CAM
219
+ results = run_gradcam(
220
+ model,
221
+ target_layer,
222
+ x,
223
+ device,
224
+ classes,
225
+ outdir,
226
+ topk=args.topk,
227
+ base_pil_rgb=base_pil,
228
+ )
229
+
230
+ # 7) print summary
231
+ for r in results:
232
+ print(f"Top{r['rank']}: {r['class_name']} ({r['prob']:.3f}) -> {r['file']}")
233
+
234
+
235
+ if __name__ == "__main__":
236
+ main()
src/infer.py ADDED
File without changes
src/simCLR.py ADDED
@@ -0,0 +1,134 @@
1
+ import math
+ import os
2
+ from pathlib import Path
3
+
4
+ import torch
5
+ from torch import nn
6
+ from torch.utils.data import DataLoader
7
+ from torchvision import models as tvm
8
+ from torchvision.datasets import ImageFolder
9
+
10
+ from lightly.loss import NTXentLoss
11
+ from lightly.models.modules import SimCLRProjectionHead
12
+ from lightly.transforms.simclr_transform import SimCLRTransform
13
+ from torch.optim.lr_scheduler import LambdaLR
14
+
15
+ # ----------------------------
16
+ # Config
17
+ # ----------------------------
18
+ DATA_ROOT = "data/eurosat_custom/train" # prepared split (train only, unlabeled)
19
+ BATCH_SIZE = 256
20
+ EPOCHS = 150
21
+ LR = 0.06
22
+ NUM_WORKERS = 8
23
+ IMG_SIZE = 224 # resize inside transform
24
+ OUT_DIR = Path("checkpoints_ssl")
25
+ OUT_DIR.mkdir(parents=True, exist_ok=True)
26
+ warmup_epochs = 10
27
+ total_epochs = EPOCHS
28
+
29
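+ # schedule: linear warmup for warmup_epochs, then cosine decay toward zero for the remaining epochs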
+ def lr_lambda(epoch):
30
+ if epoch < warmup_epochs:
31
+ return float(epoch + 1) / warmup_epochs
32
+ progress = (epoch - warmup_epochs) / float(total_epochs - warmup_epochs)
33
+ return 0.5 * (1.0 + math.cos(math.pi * progress))
34
+
35
+ # NOTE: the LambdaLR scheduler is created right after the optimizer is defined (see the optimizer section below)
36
+
37
+ device = "cuda" if torch.cuda.is_available() else "cpu"
38
+ use_amp = torch.cuda.is_available() # mixed precision if GPU
39
+
40
+
41
+ # ----------------------------
42
+ # Model: ResNet18 encoder + SimCLR projection head
43
+ # ----------------------------
44
+ class SimCLR(nn.Module):
45
+ def __init__(self, backbone, in_dim=512, proj_hidden=512, proj_out=128):
46
+ super().__init__()
47
+ self.backbone = backbone
48
+ self.projection_head = SimCLRProjectionHead(in_dim, proj_hidden, proj_out)
49
+
50
+ def forward(self, x):
51
+ # backbone assumed to output [N, C, 1, 1] after global pooling
52
+ x = self.backbone(x).flatten(start_dim=1)
53
+ z = self.projection_head(x)
54
+ return z
55
+
56
+ # Build a torchvision resnet18 backbone without the FC layer
57
+ resnet = tvm.resnet18(weights=None)
58
+ # replace avgpool+fc stack with Identity + keep global avgpool:
59
+ # torchvision resnet18 returns features after avgpool as 512-d before fc.
60
+ backbone = nn.Sequential(*list(resnet.children())[:-1]) # until avgpool, outputs [N,512,1,1]
61
+ model = SimCLR(backbone, in_dim=512, proj_hidden=512, proj_out=128).to(device)
62
+
63
+
64
+ # ----------------------------
65
+ # Data: EuroSAT train images as unlabeled pairs of views
66
+ # ----------------------------
67
+ # SimCLR default normalization in Lightly is ImageNet stats; perfect for ResNet18 at 224.
68
+ transform = SimCLRTransform(
69
+ input_size=IMG_SIZE,
70
+ gaussian_blur=0.1, # EuroSAT is small; mild blur helps but keep modest
71
+ cj_strength=0.5, # color jitter strength
72
+ )
73
+
74
+ dataset = ImageFolder(DATA_ROOT, transform=transform)
75
+ # ImageFolder returns ( (v1, v2), label ) because transform yields two views.
76
+ # We'll ignore labels during pretraining.
77
+ loader = DataLoader(
78
+ dataset,
79
+ batch_size=BATCH_SIZE,
80
+ shuffle=True,
81
+ drop_last=True,
82
+ num_workers=NUM_WORKERS,
83
+ pin_memory=torch.cuda.is_available(),
84
+ )
85
+
86
+
87
+ # ----------------------------
88
+ # Objective & Optimizer
89
+ # ----------------------------
90
+ criterion = NTXentLoss(temperature=0.5) # standard SimCLR temperature
91
+ optimizer = torch.optim.SGD(model.parameters(), lr=LR, momentum=0.9, weight_decay=1e-4)
+ scheduler = LambdaLR(optimizer, lr_lambda=lr_lambda) # warmup + cosine schedule defined above
92
+
93
+ scaler = torch.cuda.amp.GradScaler(enabled=use_amp)
94
+
95
+
96
+ # ----------------------------
97
+ # Training loop
98
+ # ----------------------------
99
+ print(f"Starting SimCLR pretraining on {device} for {EPOCHS} epochs…")
100
+ model.train()
101
+ for epoch in range(1, EPOCHS + 1):
102
+ total_loss = 0.0
103
+ for (v1, v2), _ in loader: # labels are unused
104
+ v1 = v1.to(device, non_blocking=True)
105
+ v2 = v2.to(device, non_blocking=True)
106
+
107
+ optimizer.zero_grad(set_to_none=True)
108
+
109
+ with torch.cuda.amp.autocast(enabled=use_amp):
110
+ z1 = model(v1)
111
+ z2 = model(v2)
112
+ loss = criterion(z1, z2)
113
+
114
+ scaler.scale(loss).backward()
115
+ scaler.step(optimizer)
116
+ scaler.update()
117
+
118
+ total_loss += loss.detach().item()
119
+
120
+ avg_loss = total_loss / len(loader)
121
+ current_lr = scheduler.get_last_lr()[0]
122
+ print(f"epoch {epoch:03d} | loss {avg_loss:.5f} | lr {current_lr:.5f}")
123
+
124
+ scheduler.step()
125
+
126
+ # (optional) save checkpoints every N epochs
127
+ if epoch % 25 == 0 or epoch == EPOCHS:
128
+ # save only the encoder (backbone) weights for fine-tuning
129
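+ # strip the leading "backbone." so the keys match the encoder module on its own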
+ enc_state = {k.replace("backbone.", "", 1): v
130
+ for k, v in model.state_dict().items()
131
+ if k.startswith("backbone.")}
132
+ torch.save(enc_state, OUT_DIR / f"simclr_resnet18_eurosat_epoch{epoch}.pt")
133
+
134
+ print("Done.")
src/train.py ADDED
@@ -0,0 +1,583 @@
1
+ import argparse
2
+ import json
3
+ import time
4
+ from copy import deepcopy
5
+ from pathlib import Path
6
+
7
+ import matplotlib.pyplot as plt
8
+ import numpy as np
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.nn.functional as F
12
+ import torch.optim as optim
13
+ import torchvision as tv
14
+ import torchvision.models as models
15
+ import yaml
16
+ from torch.utils.data import DataLoader
17
+ from torch.utils.tensorboard import SummaryWriter
18
+ from torchmetrics.classification import MulticlassConfusionMatrix
19
+ from torchvision import transforms
20
+ from torchvision.datasets import ImageFolder
21
+
22
+
23
+ # ----------------- argparse -----------------
24
+ def build_argparser():
25
+ p = argparse.ArgumentParser(description="Train a CNN (SmallCNN or ResNet-18) on MNIST, Fashion-MNIST, or CIFAR-10")
26
+ p.add_argument(
27
+ "--dataset", choices=["fashion-mnist", "mnist", "cifar10"], default="fashion-mnist"
28
+ )
29
+ p.add_argument("--data-dir", type=str, default="./data")
30
+ p.add_argument("--batch-size", type=int, default=128)
31
+ p.add_argument("--epochs", type=int, default=8)
32
+ p.add_argument("--lr", type=float, default=1e-3)
33
+ p.add_argument("--weight-decay", type=float, default=1e-4)
34
+ p.add_argument("--num-workers", type=int, default=2)
35
+ p.add_argument("--seed", type=int, default=41)
36
+ p.add_argument("--device", choices=["auto", "cpu", "cuda"], default="auto")
37
+ # legacy path args (we’ll map them into roots if provided)
38
+ p.add_argument("--logdir", type=str, default=None)
39
+ p.add_argument("--ckpt", type=str, default=None)
40
+ p.add_argument("--metrics", type=str, default=None)
41
+ p.add_argument("--reports-dir", type=str, default=None)
42
+ # config
43
+ p.add_argument(
44
+ "--config",
45
+ type=str,
46
+ default="configs/baseline.yaml",
47
+ help="Path to YAML config with defaults",
48
+ )
49
+ p.add_argument(
50
+ "--model-name",
51
+ type=str,
52
+ default=None,
53
+ choices=["smallcnn", "resnet18_cifar", "resnet18_imagenet"],
54
+ help="Choose model architecture",
55
+ )
56
+ return p
57
+
58
+
59
+ # ----------------- small utils -----------------
60
+ def get_device(choice: str) -> str:
61
+ if choice == "cpu":
62
+ return "cpu"
63
+ if choice == "cuda":
64
+ return "cuda"
65
+ return "cuda" if torch.cuda.is_available() else "cpu"
66
+
67
+
68
+ def seed_all(seed: int):
69
+ import random
70
+
71
+ import numpy as np
72
+
73
+ random.seed(seed)
74
+ np.random.seed(seed)
75
+ torch.manual_seed(seed)
76
+ torch.cuda.manual_seed_all(seed)
77
+ torch.backends.cudnn.deterministic = True
78
+ torch.backends.cudnn.benchmark = False
79
+
80
+
81
+ def accuracy(logits, targets):
82
+ preds = logits.argmax(dim=1)
83
+ return (preds == targets).float().mean().item()
84
+
85
+
86
+ def load_yaml(path: str) -> dict:
87
+ with open(path, "r") as f:
88
+ return yaml.safe_load(f)
89
+
90
+
91
+ def merge_cli_over_config_with_defaults(cfg, args, parser):
92
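+ # precedence: a CLI flag overrides the YAML value only if the user changed it from its argparse default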
+ cfg = deepcopy(cfg)
93
+ defaults = parser.parse_args([]) # argparse defaults only
94
+ for arg_name, cfg_key in [
95
+ ("dataset", "dataset"),
96
+ ("data_dir", "data_dir"),
97
+ ("batch_size", "batch_size"),
98
+ ("epochs", "epochs"),
99
+ ("lr", "lr"),
100
+ ("weight_decay", "weight_decay"),
101
+ ("num_workers", "num_workers"),
102
+ ("seed", "seed"),
103
+ ("device", "device"),
104
+ ("logdir", "log_root"),
105
+ ("ckpt", "ckpt_root"),
106
+ ("metrics", "reports_root"),
107
+ ("reports_dir", "reports_root"),
108
+ ("model_name", "model_name"),
109
+ ]:
110
+ val = getattr(args, arg_name)
111
+ defval = getattr(defaults, arg_name)
112
+ if val is not None and val != defval:
113
+ if arg_name == "ckpt":
114
+ cfg[cfg_key] = str(Path(val).parent)
115
+ elif arg_name in ("metrics", "reports_dir"):
116
+ cfg[cfg_key] = str(Path(val).parent)
117
+ else:
118
+ cfg[cfg_key] = val
119
+ cfg["_config_path"] = args.config
120
+ return cfg
121
+
122
+
123
+ def is_improved(best_value, current, mode: str, min_delta: float) -> bool:
124
+ if mode == "min":
125
+ return current < (best_value - min_delta)
126
+ return current > (best_value + min_delta)
127
+
128
+
129
+ def save_checkpoint(payload: dict, path: Path):
130
+ torch.save(payload, str(path))
131
+
132
+
133
+ # ----------------- model -----------------
134
+ class SmallCNN(nn.Module):
135
+ def __init__(self, num_classes: int = 10):
136
+ super().__init__()
137
+ self.conv1 = nn.Conv2d(1, 32, 3, padding=1)
138
+ self.pool1 = nn.MaxPool2d(2, 2)
139
+ self.conv2 = nn.Conv2d(32, 64, 3, padding=1)
140
+ self.pool2 = nn.MaxPool2d(2, 2)
141
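+ # two 2x2 max-poolings reduce a 28x28 input to 7x7, hence 64 * 7 * 7 flattened features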
+ self.fc = nn.Linear(64 * 7 * 7, num_classes)
142
+
143
+ def forward(self, x):
144
+ x = F.relu(self.conv1(x))
145
+ x = self.pool1(x)
146
+ x = F.relu(self.conv2(x))
147
+ x = self.pool2(x)
148
+ x = torch.flatten(x, 1)
149
+ return self.fc(x) # logits
150
+
151
+
152
+ def build_model(model_name: str, num_classes: int, img_size: int):
153
+ """
154
+ Returns (model, default_target_layer)
155
+ """
156
+ if model_name == "smallcnn":
157
+ m = SmallCNN(num_classes=num_classes)
158
+ return m, "conv2"
159
+
160
+ if model_name == "resnet18_cifar":
161
+ # Start from vanilla resnet18 but adapt for CIFAR (32x32)
162
+ m = models.resnet18(weights=None)
163
+ # 3x3 conv, stride=1, padding=1 instead of 7x7/stride=2, and remove maxpool
164
+ m.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
165
+ m.maxpool = nn.Identity()
166
+ # replace classifier
167
+ m.fc = nn.Linear(m.fc.in_features, num_classes)
168
+ return m, "layer4"
169
+
170
+ if model_name == "resnet18_imagenet":
171
+ # Use ImageNet weights and resize input to 224
172
+ try:
173
+ w = models.ResNet18_Weights.IMAGENET1K_V1
174
+ except Exception:
175
+ w = None
176
+ m = models.resnet18(weights=w)
177
+ m.fc = nn.Linear(m.fc.in_features, num_classes)
178
+ return m, "layer4"
179
+
180
+ raise ValueError(f"Unknown model_name: {model_name}")
181
+
182
+
183
+ # ----------------- data -----------------
184
+ def get_transforms_for(dataset_name: str, img_size: int, mean, std, train: bool):
185
+ tfms = []
186
+ if dataset_name in {"cifar10"}:
187
+ if train:
188
+ # light augments for CIFAR
189
+ if img_size == 32:
190
+ tfms += [
191
+ transforms.RandomCrop(32, padding=4),
192
+ transforms.RandomHorizontalFlip(),
193
+ ]
194
+ else:
195
+ tfms += [
196
+ transforms.Resize((img_size, img_size)),
197
+ transforms.RandomHorizontalFlip(),
198
+ ]
199
+ else:
200
+ tfms += [transforms.Resize((img_size, img_size))]
201
+ tfms += [transforms.ToTensor(), transforms.Normalize(mean, std)]
202
+ return transforms.Compose(tfms)
203
+
204
+
205
+
206
+ # fashion-mnist / mnist (grayscale)
208
+ m, s = float(mean[0]), float(std[0])
209
+ tfms = [transforms.ToTensor(), transforms.Normalize((m,), (s,))]
210
+ return transforms.Compose(tfms)
211
+
212
+
213
+ def get_dataloaders(
214
+ dataset_name: str,
215
+ data_dir: str,
216
+ batch_size: int,
217
+ num_workers: int,
218
+ seed: int,
219
+ img_size: int,
220
+ mean,
221
+ std,
222
+ ):
223
+ root = Path(data_dir)
224
+ g = torch.Generator().manual_seed(seed)
225
+
226
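+ # NOTE: the official test split doubles as the validation set below; there is no separate held-out val split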
+ if dataset_name == "fashion-mnist":
227
+ train_tf = get_transforms_for("fashion-mnist", img_size, mean, std, train=True)
228
+ eval_tf = get_transforms_for("fashion-mnist", img_size, mean, std, train=False)
229
+ train_ds = tv.datasets.FashionMNIST(root=root, train=True, download=True, transform=train_tf)
230
+ test_ds = tv.datasets.FashionMNIST(root=root, train=False, download=True, transform=eval_tf)
231
+
232
+ train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=num_workers, generator=g)
233
+ val_loader = DataLoader(test_ds, batch_size=batch_size, shuffle=False, num_workers=num_workers)
234
+ test_loader = DataLoader(test_ds, batch_size=batch_size, shuffle=False, num_workers=num_workers)
235
+ classes = train_ds.classes
236
+ return train_loader, val_loader, test_loader, classes
237
+
238
+ elif dataset_name == "mnist":
239
+ train_tf = get_transforms_for("mnist", img_size, mean, std, train=True)
240
+ eval_tf = get_transforms_for("mnist", img_size, mean, std, train=False)
241
+ train_ds = tv.datasets.MNIST(root=root, train=True, download=True, transform=train_tf)
242
+ test_ds = tv.datasets.MNIST(root=root, train=False, download=True, transform=eval_tf)
243
+
244
+ train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=num_workers, generator=g)
245
+ val_loader = DataLoader(test_ds, batch_size=batch_size, shuffle=False, num_workers=num_workers)
246
+ test_loader = DataLoader(test_ds, batch_size=batch_size, shuffle=False, num_workers=num_workers)
247
+ classes = train_ds.classes
248
+ return train_loader, val_loader, test_loader, classes
249
+
250
+ elif dataset_name == "cifar10":
251
+ train_tf = get_transforms_for("cifar10", img_size, mean, std, train=True)
252
+ eval_tf = get_transforms_for("cifar10", img_size, mean, std, train=False)
253
+ train_ds = tv.datasets.CIFAR10(root=root, train=True, download=True, transform=train_tf)
254
+ test_ds = tv.datasets.CIFAR10(root=root, train=False, download=True, transform=eval_tf)
255
+
256
+ train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=num_workers, generator=g)
257
+ val_loader = DataLoader(test_ds, batch_size=batch_size, shuffle=False, num_workers=num_workers)
258
+ test_loader = DataLoader(test_ds, batch_size=batch_size, shuffle=False, num_workers=num_workers)
259
+ classes = train_ds.classes
260
+ return train_loader, val_loader, test_loader, classes
261
+
262
+
263
+
264
+ else:
265
+ raise ValueError(f"Unsupported dataset: {dataset_name}")
266
+
267
+
268
+ # ----------------- train/eval -----------------
269
+ def train_one_epoch(model, loader, device, optimizer, loss_fn):
270
+ model.train()
271
+ loss_sum = 0.0
272
+ acc_sum = 0.0
273
+ n = 0
274
+ for xb, yb in loader:
275
+ xb, yb = xb.to(device), yb.to(device)
276
+ optimizer.zero_grad()
277
+ logits = model(xb)
278
+ loss = loss_fn(logits, yb)
279
+ loss.backward()
280
+ optimizer.step()
281
+ b = yb.size(0)
282
+ loss_sum += loss.item() * b
283
+ acc_sum += accuracy(logits, yb) * b
284
+ n += b
285
+ return loss_sum / n, acc_sum / n
286
+
287
+
288
+ @torch.no_grad()
289
+ def eval_one_epoch(model, loader, device, loss_fn):
290
+ model.eval()
291
+ loss_sum = 0.0
292
+ acc_sum = 0.0
293
+ n = 0
294
+ for xb, yb in loader:
295
+ xb, yb = xb.to(device), yb.to(device)
296
+ logits = model(xb)
297
+ loss = loss_fn(logits, yb)
298
+ b = yb.size(0)
299
+ loss_sum += loss.item() * b
300
+ acc_sum += accuracy(logits, yb) * b
301
+ n += b
302
+ return loss_sum / n, acc_sum / n
303
+
304
+
305
+ @torch.no_grad()
306
+ def confusion_matrix_report(
307
+ model,
308
+ test_loader,
309
+ device,
310
+ classes,
311
+ reports_dir: Path,
312
+ metrics_path: Path,
313
+ title_prefix: str,
314
+ ):
315
+ model.eval()
316
+ all_preds, all_targets = [], []
317
+ for xb, yb in test_loader:
318
+ xb = xb.to(device)
319
+ logits = model(xb)
320
+ preds = logits.argmax(dim=1).cpu()
321
+ all_preds.append(preds)
322
+ all_targets.append(yb)
323
+ all_preds = torch.cat(all_preds)
324
+ all_targets = torch.cat(all_targets)
325
+
326
+ num_classes = len(classes)
327
+ cm_metric = MulticlassConfusionMatrix(num_classes=num_classes)
328
+ cm = cm_metric(all_preds, all_targets).numpy()
329
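+ # row-normalize: each true-class row sums to 1, so cells read as per-class recall fractions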
+ cm_norm = cm / cm.sum(axis=1, keepdims=True)
330
+
331
+ reports_dir.mkdir(parents=True, exist_ok=True)
332
+ fig, ax = plt.subplots(figsize=(7, 6))
333
+ im = ax.imshow(cm_norm, interpolation="nearest")
334
+ ax.figure.colorbar(im, ax=ax)
335
+ ax.set(
336
+ xticks=np.arange(num_classes),
337
+ yticks=np.arange(num_classes),
338
+ xticklabels=classes,
339
+ yticklabels=classes,
340
+ ylabel="True label",
341
+ xlabel="Predicted label",
342
+ title=f"{title_prefix} Confusion Matrix (row-normalized)",
343
+ )
344
+ plt.setp(ax.get_xticklabels(), rotation=45, ha="right", rotation_mode="anchor")
345
+ for i in range(num_classes):
346
+ for j in range(num_classes):
347
+ ax.text(
348
+ j, i, f"{cm_norm[i, j]*100:.1f}%", ha="center", va="center", fontsize=8
349
+ )
350
+ fig.tight_layout()
351
+ fig_path = reports_dir / "confusion_matrix.png"
352
+ plt.savefig(fig_path, dpi=200)
353
+ plt.close(fig)
354
+ print("Saved figure to:", fig_path)
355
+
356
+ np.save(reports_dir / "confusion_matrix_counts.npy", cm)
357
+ np.save(reports_dir / "confusion_matrix_norm.npy", cm_norm)
358
+
359
+ try:
360
+ with open(metrics_path) as f:
361
+ metrics = json.load(f)
362
+ except FileNotFoundError:
363
+ metrics = {}
364
+ metrics.update(
365
+ {
366
+ "confusion_matrix_counts_path": str(
367
+ reports_dir / "confusion_matrix_counts.npy"
368
+ ),
369
+ "confusion_matrix_norm_path": str(
370
+ reports_dir / "confusion_matrix_norm.npy"
371
+ ),
372
+ "confusion_matrix_figure": str(fig_path),
373
+ }
374
+ )
375
+ with open(metrics_path, "w") as f:
376
+ json.dump(metrics, f, indent=2)
377
+
378
+
379
+ # ----------------- main -----------------
380
+ def main():
381
+ parser = build_argparser()
382
+ args = parser.parse_args()
383
+ seed_all(args.seed)
384
+
385
+ base_cfg = load_yaml(args.config)
386
+ cfg = merge_cli_over_config_with_defaults(base_cfg, args, parser)
387
+
388
+ dataset = cfg["dataset"]
389
+ model_name = cfg.get("model_name", "smallcnn")
390
+
391
+ img_size = int(
392
+ cfg.get("img_size", 28 if dataset in ["fashion-mnist", "mnist"] else 32)
393
+ )
394
+ mean = cfg.get("mean", None)
395
+ std = cfg.get("std", None)
396
+
397
+ # defaults for grayscale datasets
398
+ if dataset in ["fashion-mnist", "mnist"]:
399
+ if mean is None or std is None:
400
+ if dataset == "fashion-mnist":
401
+ mean, std = [0.2860], [0.3530]
402
+ else:
403
+ mean, std = [0.1307], [0.3081]
404
+ # defaults for cifar10
405
+ if dataset == "cifar10" and (mean is None or std is None):
406
+ mean, std = [0.4914, 0.4822, 0.4465], [0.2470, 0.2435, 0.2616]
407
+
408
+
409
+
410
+ device = get_device(cfg["device"])
411
+ print("device:", device)
412
+
413
+ run_id = f'{cfg["dataset"]}_{int(time.time())}'
414
+ LOG_DIR = Path(cfg["log_root"]) / run_id
415
+ CKPTS_DIR = Path(cfg["ckpt_root"]) / run_id
416
+ REPORTS_DIR = Path(cfg["reports_root"]) / run_id
417
+ for d in (LOG_DIR, CKPTS_DIR, REPORTS_DIR):
418
+ d.mkdir(parents=True, exist_ok=True)
419
+
420
+ effective_cfg = deepcopy(cfg)
421
+ effective_cfg["run_id"] = run_id
422
+ with open(REPORTS_DIR / "config_effective.yaml", "w") as f:
423
+ yaml.safe_dump(effective_cfg, f)
424
+
425
+ train_loader, val_loader, test_loader, classes = get_dataloaders(
426
+ dataset,
427
+ cfg["data_dir"],
428
+ cfg["batch_size"],
429
+ cfg["num_workers"],
430
+ cfg["seed"],
431
+ img_size,
432
+ mean,
433
+ std,
434
+ )
435
+
436
+ loss_fn = nn.CrossEntropyLoss()
437
+
438
+ model, default_target_layer = build_model(
439
+ model_name, num_classes=len(classes), img_size=img_size
440
+ )
441
+ model = model.to(device)
442
+
443
+ opt_name = str(cfg.get("optimizer", "adam")).lower()
444
+ if opt_name == "sgd":
445
+ optimizer = optim.SGD(
446
+ model.parameters(),
447
+ lr=cfg["lr"],
448
+ momentum=float(cfg.get("momentum", 0.9)),
449
+ weight_decay=cfg["weight_decay"],
450
+ nesterov=True,
451
+ )
452
+ else:
453
+ optimizer = optim.Adam(
454
+ model.parameters(), lr=cfg["lr"], weight_decay=cfg["weight_decay"]
455
+ )
456
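+ # halve the learning rate when val loss fails to improve for 2 consecutive epochs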
+ scheduler = optim.lr_scheduler.ReduceLROnPlateau(
457
+ optimizer, mode="min", factor=0.5, patience=2
458
+ )
459
+
460
+ writer = SummaryWriter(log_dir=str(LOG_DIR))
461
+
462
+ monitor = cfg["early_stop"]["monitor"]
463
+ mode = cfg["early_stop"]["mode"]
464
+ patience = int(cfg["early_stop"]["patience"])
465
+ min_delta = float(cfg["early_stop"]["min_delta"])
466
+
467
+ best_val = float("inf") if mode == "min" else -float("inf")
468
+ epochs_no_improve = 0
469
+
470
+ ckpt_last = CKPTS_DIR / "last.ckpt"
471
+ ckpt_best = CKPTS_DIR / "best.ckpt"
472
+
473
+ for epoch in range(1, cfg["epochs"] + 1):
474
+ tr_loss, tr_acc = train_one_epoch(
475
+ model, train_loader, device, optimizer, loss_fn
476
+ )
477
+ va_loss, va_acc = eval_one_epoch(model, val_loader, device, loss_fn)
478
+ scheduler.step(va_loss)
479
+
480
+ writer.add_scalar("Loss/train", tr_loss, epoch)
481
+ writer.add_scalar("Loss/val", va_loss, epoch)
482
+ writer.add_scalar("Acc/train", tr_acc, epoch)
483
+ writer.add_scalar("Acc/val", va_acc, epoch)
484
+ writer.add_scalar("LR", optimizer.param_groups[0]["lr"], epoch)
485
+
486
+ print(
487
+ f"Epoch {epoch:02d} | train_loss={tr_loss:.4f} acc={tr_acc:.4f}"
488
+ + f" | val_loss={va_loss:.4f} acc={va_acc:.4f}"
489
+ )
490
+
491
+ mon_value = va_loss if monitor == "val_loss" else va_acc
492
+
493
+ payload = {
494
+ "epoch": epoch,
495
+ "model_state": model.state_dict(),
496
+ "optimizer_state": optimizer.state_dict(),
497
+ "val_acc": va_acc,
498
+ "val_loss": va_loss,
499
+ "dataset": cfg["dataset"],
500
+ "classes": classes,
501
+ "config_path": cfg.get("_config_path"),
502
+ "meta": {
503
+ "dataset": dataset,
504
+ "model_name": model_name,
505
+ "img_size": img_size,
506
+ "mean": mean,
507
+ "std": std,
508
+ "default_target_layer": default_target_layer,
509
+ },
510
+ }
511
+ save_checkpoint(payload, ckpt_last)
512
+
513
+ if is_improved(best_val, mon_value, mode, min_delta):
514
+ best_val = mon_value
515
+ epochs_no_improve = 0
516
+ save_checkpoint(payload, ckpt_best)
517
+ best_json = {
518
+ "epoch": epoch,
519
+ "monitor": monitor,
520
+ "mode": mode,
521
+ "best_value": float(best_val),
522
+ "val_acc": float(va_acc),
523
+ "val_loss": float(va_loss),
524
+ "ckpt_path": str(ckpt_best),
525
+ "meta": {
526
+ "dataset": dataset,
527
+ "model_name": model_name,
528
+ "img_size": img_size,
529
+ "mean": mean,
530
+ "std": std,
531
+ "default_target_layer": default_target_layer,
532
+ },
533
+ }
534
+ with open(REPORTS_DIR / "best.json", "w") as f:
535
+ json.dump(best_json, f, indent=2)
536
+ else:
537
+ epochs_no_improve += 1
538
+
539
+ if epochs_no_improve >= patience:
540
+ print(f"Early stopping: no improvement in {patience} epochs.")
541
+ break
542
+
543
+ writer.close()
544
+ print(f"Best {monitor}: {best_val:.4f}")
545
+
546
+ # Use best checkpoint for reports
547
+ best_ckpt = torch.load(str(ckpt_best), map_location=device)
548
+ model.load_state_dict(best_ckpt["model_state"])
549
+ model.eval()
550
+
551
+ metrics_path = REPORTS_DIR / "metrics.json"
552
+ confusion_matrix_report(
553
+ model,
554
+ test_loader,
555
+ device,
556
+ classes,
557
+ reports_dir=REPORTS_DIR / "figures",
558
+ metrics_path=metrics_path,
559
+ title_prefix=cfg["dataset"].replace("-", " ").title(),
560
+ )
561
+
562
+ metrics = {
563
+ "dataset": cfg["dataset"],
564
+ "epochs_ran": epoch,
565
+ "batch_size": cfg["batch_size"],
566
+ "lr": cfg["lr"],
567
+ "weight_decay": cfg["weight_decay"],
568
+ "best_monitor": monitor,
569
+ "best_mode": mode,
570
+ "best_value": float(best_val),
571
+ "logs_dir": str(LOG_DIR),
572
+ "ckpts_dir": str(CKPTS_DIR),
573
+ "reports_dir": str(REPORTS_DIR),
574
+ }
575
+ with open(metrics_path, "w") as f:
576
+ json.dump(metrics, f, indent=2)
577
+
578
+ print("Saved metrics to:", metrics_path)
579
+ print("Best checkpoint:", ckpt_best)
580
+
581
+
582
+ if __name__ == "__main__":
583
+ main()