---
title: ExplainableCNN
app_file: app/gradio_app.py
sdk: gradio
sdk_version: 5.47.0
---
# ExplainableCNN
End‑to‑end image classification with explainability. Train CNNs on common vision datasets, save checkpoints and metrics, and visualize Grad‑CAM/Grad‑CAM++ heatmaps in a Streamlit app.
> **Online App**: You can try the app online at `https://explainable-cnn.streamlit.app`
## Contents
- Quick start
- Installation
- Datasets
- Training
- Configuration reference
- Streamlit Grad‑CAM demo
- Checkpoints and outputs
- Project layout
- FAQ / Tips
## Quick start
1) Install dependencies (CPU‑only by default):
```bash
python -m venv .venv && source .venv/bin/activate
pip install -r requirements.txt
```
2) Train with defaults (Fashion‑MNIST, small CNN):
```bash
python -m src.train --config configs/baseline.yaml
```
3) Launch the Grad‑CAM demo and visualize predictions:
```bash
streamlit run app/streamlit_app.py
```
## Installation
This repo ships with CPU‑only PyTorch wheels via the official extra index in `requirements.txt`. If you have CUDA, you can install the matching GPU wheels from PyTorch and keep the rest of the requirements.
```bash
python -m venv .venv
source .venv/bin/activate
pip install -r requirements.txt
```
### GPU installation (recommended for training)
1) Install CUDA‑enabled PyTorch that matches your driver and CUDA version (see `https://pytorch.org/get-started/locally/`). Examples:
```bash
# CUDA 12.1
pip install --index-url https://download.pytorch.org/whl/cu121 torch torchvision torchaudio
# CUDA 11.8
# pip install --index-url https://download.pytorch.org/whl/cu118 torch torchvision torchaudio
```
2) Install the rest of the project dependencies (excluding torch*):
```bash
pip install -r requirements-gpu.txt
```
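After installing, it is worth sanity-checking that the GPU build is active:
```python
import torch

# Expect a +cuXXX build tag and True when the CUDA wheels are installed correctly
print(torch.__version__)
print(torch.cuda.is_available())
```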
Notes
- If you want GPU builds: follow the selector at `https://pytorch.org/get-started/locally/` and install the torch/torchvision/torchaudio triplet before installing the rest of the requirements.
- This project uses torch/torchvision, torchmetrics, captum/torchcam, lightning (the current package name for PyTorch Lightning), albumentations, TensorBoard, Streamlit, and PyYAML, among others.
## Datasets
Built‑in dataset options for training:
- `fashion-mnist` (default)
- `mnist`
- `cifar10`
Where data lives
- By default datasets are downloaded under `data/`.
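For reference, these are standard torchvision datasets, which download themselves on first use; a minimal sketch (the project's own dataloaders may wrap this differently):
```python
from torchvision import datasets, transforms

# Downloads to ./data on the first call; later runs reuse the cached files
ds = datasets.FashionMNIST(root="data", train=True, download=True,
                           transform=transforms.ToTensor())
print(len(ds), ds.classes[:3])
```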
## Training
Run with a YAML config
```bash
python -m src.train --config configs/baseline.yaml
```
Override config values from the CLI
```bash
python -m src.train --config configs/baseline.yaml --epochs 12 --lr 5e-4
```
Switch dataset from the CLI
```bash
python -m src.train --config configs/baseline.yaml --dataset mnist
```
Use a ResNet‑18 backbone for CIFAR‑10 (adapted conv1/no maxpool)
```bash
python -m src.train --config configs/cifar10_resnet18.yaml
```
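The adaptation referenced above is the standard CIFAR‑10 tweak: replace the 7×7 stride‑2 stem with a 3×3 stride‑1 convolution and drop the initial maxpool, so 32×32 inputs are not downsampled too aggressively. A sketch of what `resnet18_cifar` likely amounts to (the repo's own builder may differ in details):
```python
import torch.nn as nn
from torchvision.models import resnet18

def resnet18_cifar(num_classes: int = 10) -> nn.Module:
    # Sketch only: swap the ImageNet stem for a CIFAR-friendly one
    m = resnet18(weights=None, num_classes=num_classes)
    m.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
    m.maxpool = nn.Identity()  # remove the initial 2x downsampling
    return m
```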
Training flow (high level)
- Loads YAML config and merges CLI overrides (see the merge sketch after this list)
- Builds dataloaders with dataset‑specific transforms and normalization
- Builds model: `smallcnn`, `resnet18_cifar`, or `resnet18_imagenet`
- Optimizer: Adam (default) or SGD with momentum
- Trains with early stopping and ReduceLROnPlateau on val loss
- Writes TensorBoard logs, metrics JSONs, and image reports (confusion matrix)
- Saves `last.ckpt` and `best.ckpt` with model weights and metadata
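How the config merge referenced above typically works (a minimal sketch assuming an argparse-over-YAML pattern; the actual `src/train.py` may differ):
```python
import argparse
import yaml

def load_config(path: str, cli_args: argparse.Namespace) -> dict:
    """Load a YAML config, letting explicitly passed CLI flags win."""
    with open(path) as f:
        cfg = yaml.safe_load(f)
    overrides = {k: v for k, v in vars(cli_args).items()
                 if k != "config" and v is not None}
    cfg.update(overrides)
    return cfg
```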
Outputs per run (under roots from config)
- `runs/<run_id>/` TensorBoard logs
- `checkpoints/<run_id>/last.ckpt` and `best.ckpt`
- `reports/<run_id>/config_effective.yaml`, `metrics.json`, and `figures/confusion_matrix.png`
## Configuration reference
See examples in `configs/`:
- `baseline.yaml` (Fashion‑MNIST + `smallcnn`)
- `cifar10_resnet18.yaml` (CIFAR‑10 + adapted ResNet‑18)
Common keys
- `dataset`: one of `fashion-mnist`, `mnist`, `cifar10`
- `model_name`: `smallcnn` | `resnet18_cifar` | `resnet18_imagenet`
- `data_dir`: root folder for data (default `./data`)
- `batch_size`, `epochs`, `lr`, `weight_decay`, `num_workers`, `seed`, `device`
- `img_size`, `mean`, `std`: image shape and normalization stats
- `optimizer`: `adam` (default) or `sgd`; `momentum` used for SGD
- `log_root`, `ckpt_root`, `reports_root`: base folders for artifacts
- `early_stop`: `{ monitor: val_loss|val_acc, mode: min|max, patience, min_delta }`
CLI flags such as `--dataset`, `--epochs`, `--lr`, and `--model-name` override values from the YAML.
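Putting these keys together, a config in the spirit of `configs/baseline.yaml` might look like the following (illustrative values only; consult the actual files in `configs/`):
```yaml
dataset: fashion-mnist
model_name: smallcnn
data_dir: ./data
batch_size: 128
epochs: 20
lr: 1e-3
weight_decay: 1e-4
num_workers: 4
seed: 42
device: auto
img_size: 28
mean: [0.2860]   # Fashion-MNIST single-channel stats
std: [0.3530]
optimizer: adam
log_root: runs
ckpt_root: checkpoints
reports_root: reports
early_stop:
  monitor: val_loss
  mode: min
  patience: 5
  min_delta: 0.0
```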
## Streamlit Grad‑CAM demo
Start the app (or try it online at `https://explainable-cnn.streamlit.app`)
```bash
streamlit run app/streamlit_app.py
```
What it does
- Load a trained checkpoint (`.ckpt`)
- Upload an image or sample one from the corresponding dataset
- Run inference and display top‑k predictions
- Visualize Grad‑CAM or Grad‑CAM++ overlays with adjustable alpha
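Under the hood, the CAM step reduces to a few `torchcam` calls. A self-contained sketch with placeholder model, input, and target layer (the app wires these up from the loaded checkpoint and your image):
```python
import torch
from torchvision.models import resnet18
from torchvision.transforms.functional import to_pil_image
from torchcam.methods import GradCAMpp
from torchcam.utils import overlay_mask

model = resnet18(weights=None).eval()  # placeholder; the app uses the checkpoint's model
img = torch.rand(3, 224, 224)          # placeholder; the app uses your uploaded image

cam_extractor = GradCAMpp(model, target_layer="layer4")  # hypothetical layer name
scores = model(img.unsqueeze(0))
cam = cam_extractor(scores.squeeze(0).argmax().item(), scores)[0]
overlay = overlay_mask(to_pil_image(img), to_pil_image(cam.squeeze(0), mode="F"),
                       alpha=0.5)
overlay.save("gradcam_overlay.png")
```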
Supplying checkpoints
- Local discovery: put `.ckpt` files under `saved_checkpoints/` or use the file uploader
- Download from a URL: paste a direct link to a `.ckpt` asset and click “Download checkpoint”
- Presets: provide a map of names → URLs via one of:
  - Streamlit secrets: `st.secrets["release_checkpoints"] = { "Name": "https://...best.ckpt" }`
  - `.streamlit/presets.json` or `presets.json` in the repo root, either
    ```json
    { "release_checkpoints": { "FMNIST SmallCNN": "https://.../best.ckpt" } }
    ```
    or a flat mapping `{ "FMNIST SmallCNN": "https://..." }`
  - the environment variable `RELEASE_CKPTS_JSON` set to a JSON mapping string
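For example, the environment-variable route with a placeholder URL:
```bash
export RELEASE_CKPTS_JSON='{"FMNIST SmallCNN": "https://example.com/best.ckpt"}'
streamlit run app/streamlit_app.py
```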
Devices and CAM methods
- Device: `auto` (default), `cuda`, or `cpu`
- CAM: `Grad-CAM` or `Grad-CAM++` via `torchcam`
Checkpoint metadata expected
- `meta`: `{ dataset, model_name, img_size, mean, std, default_target_layer }`
- `classes`: list of class names (used to label predictions)
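A quick way to inspect a checkpoint against this format (a sketch; only the `meta` and `classes` keys are documented above, other fields may vary):
```python
import torch

ckpt = torch.load("saved_checkpoints/best.ckpt", map_location="cpu")
print(ckpt["meta"]["model_name"], ckpt["meta"]["img_size"])
print(ckpt["classes"][:5])  # class names used to label predictions
```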
## Checkpoints and outputs
Each run writes:
- Checkpoints: `<ckpt_root>/<run_id>/{last.ckpt,best.ckpt}`
- Logs: `<log_root>/<run_id>/` (TensorBoard)
- Reports: `<reports_root>/<run_id>/metrics.json`, `figures/confusion_matrix.png`
Best checkpoint selection respects early‑stopping monitor (`val_loss` or `val_acc`).
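The comparison behind that monitor is simple; a minimal sketch of the idea (not the repo's exact code):
```python
def improved(current: float, best: float, mode: str, min_delta: float = 0.0) -> bool:
    """True if `current` beats `best` under the early-stopping monitor."""
    if mode == "min":                      # e.g. monitor=val_loss
        return current < best - min_delta
    return current > best + min_delta      # e.g. monitor=val_acc, mode=max
```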
## License and acknowledgements
- Uses `torchcam` for CAM extraction and `captum` as a general explainability dependency
- TorchVision models and datasets are used for baselines and data handling
___
If you run into issues, please open an issue with your command, config file, and environment details.