---
title: ExplainableCNN
app_file: app/gradio_app.py
sdk: gradio
sdk_version: 5.47.0
---
# ExplainableCNN

End‑to‑end image classification with explainability. Train CNNs on common vision datasets, save checkpoints and metrics, and visualize Grad‑CAM/Grad‑CAM++ heatmaps in a Streamlit app.

> **Online App**: You can try the app online at `https://explainable-cnn.streamlit.app`

Contents
- Quick start
- Installation
- Datasets
- Training
- Configuration reference
- Streamlit Grad‑CAM demo
- Checkpoints and outputs
- Project layout
- FAQ / Tips

## Quick start

1) Install dependencies (CPU‑only by default):
```bash
python -m venv .venv && source .venv/bin/activate
pip install -r requirements.txt
```

2) Train with defaults (Fashion‑MNIST, small CNN):
```bash
python -m src.train --config configs/baseline.yaml
```

3) Launch the Grad‑CAM demo and visualize predictions:
```bash
streamlit run app/streamlit_app.py
```

## Installation

This repo ships with CPU‑only PyTorch wheels via the official extra index in `requirements.txt`. If you have CUDA, you can install the matching GPU wheels from PyTorch and keep the rest of the requirements.

```bash
python -m venv .venv
source .venv/bin/activate
pip install -r requirements.txt
```

### GPU installation (recommended for training)

1) Install CUDA‑enabled PyTorch that matches your driver and CUDA version (see `https://pytorch.org/get-started/locally/`). Examples:

```bash
# CUDA 12.1
pip install --index-url https://download.pytorch.org/whl/cu121 torch torchvision torchaudio

# CUDA 11.8
# pip install --index-url https://download.pytorch.org/whl/cu118 torch torchvision torchaudio
```

2) Install the rest of the project dependencies (excluding torch*):

```bash
pip install -r requirements-gpu.txt
```

Notes
- For GPU builds, follow the selector at `https://pytorch.org/get-started/locally/` and install the torch/torchvision/torchaudio triplet before installing the rest of the requirements.
- Key dependencies: torch/torchvision, torchmetrics, captum/torchcam, lightning (the current package name for PyTorch Lightning), albumentations, TensorBoard, Streamlit, and PyYAML.

## Datasets

Built‑in dataset options for training:
- `fashion-mnist` (default)
- `mnist`
- `cifar10`

Where data lives
- By default, datasets are downloaded under `data/`.
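
A minimal sketch of how these datasets are typically fetched with torchvision (illustrative only; the repo's dataloaders add dataset‑specific transforms and normalization on top):

```python
# Illustrative: download Fashion-MNIST under ./data, as the default config does.
from torchvision import datasets, transforms

train_ds = datasets.FashionMNIST(
    root="data", train=True, download=True, transform=transforms.ToTensor()
)
print(len(train_ds), "training images,", len(train_ds.classes), "classes")
```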

## Training

Run with a YAML config
```bash
python -m src.train --config configs/baseline.yaml
```

Override config values from the CLI
```bash
python -m src.train --config configs/baseline.yaml --epochs 12 --lr 5e-4
```

Switch dataset from the CLI
```bash
python -m src.train --config configs/baseline.yaml --dataset mnist
```

Use a ResNet‑18 backbone for CIFAR‑10 (adapted conv1/no maxpool)
```bash
python -m src.train --config configs/cifar10_resnet18.yaml
```

Training flow (high level)
- Loads YAML config and merges CLI overrides
- Builds dataloaders with dataset‑specific transforms and normalization
- Builds model: `smallcnn`, `resnet18_cifar`, or `resnet18_imagenet`
- Optimizer: Adam (default) or SGD with momentum
- Trains with early stopping and ReduceLROnPlateau on val loss
- Writes TensorBoard logs, metrics JSONs, and image reports (confusion matrix)
- Saves `last.ckpt` and `best.ckpt` with model weights and metadata
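
The scheduler and early‑stopping interplay is standard PyTorch; here is a minimal, self‑contained sketch of that part of the loop (names and the placeholder loss are illustrative, not the repo's actual code):

```python
# Illustrative sketch: Adam + ReduceLROnPlateau on val loss + early stopping.
import torch
import torch.nn as nn

model = nn.Linear(10, 2)  # stand-in for smallcnn / resnet18_cifar
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-4)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", patience=2)

best_val, bad_epochs, patience, min_delta = float("inf"), 0, 5, 0.0
for epoch in range(20):
    # ... one training epoch would run here ...
    val_loss = 1.0 / (epoch + 1)  # placeholder for the real validation loss
    scheduler.step(val_loss)      # lowers LR when val loss plateaus
    if val_loss < best_val - min_delta:
        best_val, bad_epochs = val_loss, 0  # this is where best.ckpt would be saved
    else:
        bad_epochs += 1
        if bad_epochs >= patience:
            break  # early stop
```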

Outputs per run (under roots from config)
- `runs/<run_id>/` TensorBoard logs
- `checkpoints/<run_id>/last.ckpt` and `best.ckpt`
- `reports/<run_id>/config_effective.yaml`, `metrics.json`, and `figures/confusion_matrix.png`
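
To browse a run's logs, point TensorBoard at the log root (standard TensorBoard CLI; `runs` is the default `log_root`):

```bash
tensorboard --logdir runs
```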

## Configuration reference

See examples in `configs/`:
- `baseline.yaml` (Fashion‑MNIST + `smallcnn`)
- `cifar10_resnet18.yaml` (CIFAR‑10 + adapted ResNet‑18)

Common keys
- `dataset`: one of `fashion-mnist`, `mnist`, `cifar10`
- `model_name`: `smallcnn` | `resnet18_cifar` | `resnet18_imagenet`
- `data_dir`: root folder for data (default `./data`)
- `batch_size`, `epochs`, `lr`, `weight_decay`, `num_workers`, `seed`, `device`
- `img_size`, `mean`, `std`: image shape and normalization stats
- `optimizer`: `adam` (default) or `sgd`; `momentum` used for SGD
- `log_root`, `ckpt_root`, `reports_root`: base folders for artifacts
- `early_stop`: `{ monitor: val_loss|val_acc, mode: min|max, patience, min_delta }`

CLI flags override the YAML; for example `--dataset`, `--epochs`, `--lr`, and `--model-name`.
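
Putting the common keys together, a hypothetical config in the spirit of `baseline.yaml` (values are illustrative; see the files in `configs/` for the real defaults):

```yaml
dataset: fashion-mnist
model_name: smallcnn
data_dir: ./data
batch_size: 128
epochs: 20
lr: 1e-3
weight_decay: 1e-4
num_workers: 4
seed: 42
device: auto
img_size: 28
mean: [0.2860]   # Fashion-MNIST grayscale stats
std: [0.3530]
optimizer: adam
log_root: runs
ckpt_root: checkpoints
reports_root: reports
early_stop:
  monitor: val_loss
  mode: min
  patience: 5
  min_delta: 0.0
```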

## Streamlit Grad‑CAM demo

Start the app (or try it online at `https://explainable-cnn.streamlit.app`)
```bash
streamlit run app/streamlit_app.py
```

What it does
- Load a trained checkpoint (`.ckpt`)
- Upload an image or sample one from the corresponding dataset
- Run inference and display top‑k predictions
- Visualize Grad‑CAM or Grad‑CAM++ overlays with adjustable alpha

Supplying checkpoints
- Local discovery: put `.ckpt` files under `saved_checkpoints/` or use the file uploader
- Download from a URL: paste a direct link to a `.ckpt` asset and click “Download checkpoint”
- Presets: provide a map of names → URLs via one of:
  - Streamlit secrets: `st.secrets["release_checkpoints"] = { "Name": "https://...best.ckpt" }`
  - `.streamlit/presets.json` or `presets.json` in repo root, either:
    ```json
    { "release_checkpoints": { "FMNIST SmallCNN": "https://.../best.ckpt" } }
    ```
    or a flat mapping `{ "FMNIST SmallCNN": "https://..." }`
  - Environment variable `RELEASE_CKPTS_JSON` with a JSON mapping string
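
For example, the environment‑variable route can be set before launching the app (same flat mapping as the `presets.json` example above):

```bash
export RELEASE_CKPTS_JSON='{"FMNIST SmallCNN": "https://.../best.ckpt"}'
streamlit run app/streamlit_app.py
```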

Devices and CAM methods
- Device: `auto` (default), `cuda`, or `cpu`
- CAM: `Grad-CAM` or `Grad-CAM++` via `torchcam`
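
Under the hood the app relies on `torchcam`'s extractors; a minimal standalone sketch (model, layer, and random input are placeholders, not tied to a checkpoint):

```python
# Illustrative Grad-CAM extraction with torchcam (use GradCAMpp for Grad-CAM++).
import torch
from torchvision.models import resnet18
from torchcam.methods import GradCAM

model = resnet18(weights=None).eval()
with GradCAM(model, target_layer="layer4") as cam_extractor:
    scores = model(torch.randn(1, 3, 224, 224))  # placeholder input
    top_class = scores.squeeze(0).argmax().item()
    cams = cam_extractor(top_class, scores)       # one CAM per target layer
print(cams[0].shape)
```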

Expected checkpoint metadata
- `meta`: `{ dataset, model_name, img_size, mean, std, default_target_layer }`
- `classes`: list of class names (used to label predictions)
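
A quick way to inspect those fields in a checkpoint (the path is illustrative; `weights_only=False` is needed on newer PyTorch because the file carries non‑tensor metadata):

```python
# Peek at the metadata the app expects from a .ckpt file.
import torch

ckpt = torch.load("saved_checkpoints/best.ckpt", map_location="cpu", weights_only=False)
meta = ckpt.get("meta", {})        # dataset, model_name, img_size, mean, std, default_target_layer
classes = ckpt.get("classes", [])  # class names used to label predictions
print(meta.get("model_name"), "-", len(classes), "classes")
```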

## Checkpoints and outputs

Each run writes:
- Checkpoints: `<ckpt_root>/<run_id>/{last.ckpt,best.ckpt}`
- Logs: `<log_root>/<run_id>/` (TensorBoard)
- Reports: `<reports_root>/<run_id>/metrics.json`, `figures/confusion_matrix.png`

Best‑checkpoint selection follows the early‑stopping monitor (`val_loss` or `val_acc`).

## License and acknowledgements

- Uses `torchcam` for CAM extraction and `captum` as a general explainability dependency
- TorchVision models and datasets are used for baselines and data handling

___

If you run into issues, please open an issue with your command, config file, and environment details.