---
title: ExplainableCNN
app_file: app/gradio_app.py
sdk: gradio
sdk_version: 5.47.0
---
# ExplainableCNN
End‑to‑end image classification with explainability. Train CNNs on common vision datasets, save checkpoints and metrics, and visualize Grad‑CAM/Grad‑CAM++ heatmaps in a Streamlit app.
>**Online App**: You can try the app online at `https://explainable-cnn.streamlit.app`
Contents
- Quick start
- Installation
- Datasets
- Training
- Configuration reference
- Streamlit Grad‑CAM demo
- Checkpoints and outputs
- Project layout
- License and acknowledgements
## Quick start
1) Install dependencies (CPU‑only by default):
```bash
python -m venv .venv && source .venv/bin/activate
pip install -r requirements.txt
```
2) Train with defaults (Fashion‑MNIST, small CNN):
```bash
python -m src.train --config configs/baseline.yaml
```
3) Launch the Grad‑CAM demo and visualize predictions:
```bash
streamlit run app/streamlit_app.py
```
## Installation
By default, `requirements.txt` installs CPU‑only PyTorch wheels from the official PyTorch extra index. If you have a CUDA‑capable GPU, install the matching GPU wheels from PyTorch instead and keep the rest of the requirements.
```bash
python -m venv .venv
source .venv/bin/activate
pip install -r requirements.txt
```
### GPU installation (recommended for training)
1) Install CUDA‑enabled PyTorch that matches your driver and CUDA version (see `https://pytorch.org/get-started/locally/`). Examples:
```bash
# CUDA 12.1
pip install --index-url https://download.pytorch.org/whl/cu121 torch torchvision torchaudio
# CUDA 11.8
# pip install --index-url https://download.pytorch.org/whl/cu118 torch torchvision torchaudio
```
2) Install the rest of the project dependencies (excluding torch*):
```bash
pip install -r requirements-gpu.txt
```
Notes
- For GPU builds, install the torch/torchvision/torchaudio triplet chosen with the selector at `https://pytorch.org/get-started/locally/` first, then install the rest of the requirements.
- Main dependencies: torch/torchvision, torchmetrics, captum/torchcam, `lightning` (the current package name for PyTorch Lightning), albumentations, TensorBoard, Streamlit, PyYAML, and a few others.
## Datasets
Built‑in dataset options for training:
- `fashion-mnist` (default)
- `mnist`
- `cifar10`
Where data lives
- Datasets are downloaded under `data/` by default; change this with the `data_dir` config key.
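
The three dataset names could map onto `torchvision.datasets` classes roughly as sketched below. The `DATASET_SPECS` table and the `get_spec`/`build_dataset` helpers are hypothetical, not the project's actual loader; the channel counts, native image sizes, and class counts are the standard ones for each dataset, and `torchvision` is imported only when a dataset is actually built.

```python
from typing import NamedTuple


class DatasetSpec(NamedTuple):
    tv_class: str     # class name inside torchvision.datasets
    channels: int     # 1 = grayscale, 3 = RGB
    img_size: int     # native height and width in pixels
    num_classes: int


# Hypothetical lookup table keyed by the README's dataset names.
DATASET_SPECS = {
    "fashion-mnist": DatasetSpec("FashionMNIST", 1, 28, 10),
    "mnist": DatasetSpec("MNIST", 1, 28, 10),
    "cifar10": DatasetSpec("CIFAR10", 3, 32, 10),
}


def get_spec(name: str) -> DatasetSpec:
    """Look up a dataset name, failing fast on typos."""
    try:
        return DATASET_SPECS[name]
    except KeyError:
        raise ValueError(
            f"unknown dataset {name!r}; expected one of {sorted(DATASET_SPECS)}"
        ) from None


def build_dataset(name: str, root: str = "./data", train: bool = True, transform=None):
    """Return the torchvision dataset for `name`, downloading it under `root` if needed."""
    # Imported lazily so the lookup above works even without torchvision installed.
    from torchvision import datasets

    cls = getattr(datasets, get_spec(name).tv_class)
    return cls(root=root, train=train, transform=transform, download=True)
```

Validating the name before touching the network keeps a typo such as `--dataset cifar-10` from failing halfway through a download.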
## Training
Run with a YAML config
```bash
python -m src.train --config configs/baseline.yaml
```
Override config values from the CLI
```bash
python -m src.train --config configs/baseline.yaml --epochs 12 --lr 5e-4
```
Switch dataset from the CLI
```bash
python -m src.train --config configs/baseline.yaml --dataset mnist
```
Use a ResNet‑18 backbone for CIFAR‑10 (with an adapted `conv1` and no max‑pool layer)
```bash
python -m src.train --config configs/cifar10_resnet18.yaml
```
Training flow (high level)
- Loads YAML config and merges CLI overrides
- Builds dataloaders with dataset‑specific transforms and normalization
- Builds model: `smallcnn`, `resnet18_cifar`, or `resnet18_imagenet`
- Optimizer: Adam (default) or SGD with momentum
- Trains with early stopping and ReduceLROnPlateau on val loss
- Writes TensorBoard logs, metrics JSONs, and image reports (confusion matrix)
- Saves `last.ckpt` and `best.ckpt` with model weights and metadata
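
The early‑stopping step can be sketched as a small stand‑alone class that uses the same `monitor`, `mode`, `patience`, and `min_delta` keys as the `early_stop` config block. This is an illustration, not the project's implementation: the defaults are placeholders, and it assumes an epoch counts as an improvement only when it beats the best value so far by more than `min_delta`. The same comparison is a natural point at which to overwrite `best.ckpt`.

```python
import math
from dataclasses import dataclass, field


@dataclass
class EarlyStopping:
    monitor: str = "val_loss"  # key in the per-epoch metrics dict
    mode: str = "min"          # "min" for losses, "max" for accuracies
    patience: int = 5          # epochs without improvement before stopping
    min_delta: float = 0.0     # smallest change that counts as an improvement
    best: float = field(init=False, default=math.inf)
    bad_epochs: int = field(init=False, default=0)

    def __post_init__(self) -> None:
        if self.mode not in ("min", "max"):
            raise ValueError("mode must be 'min' or 'max'")
        # Start from the worst possible value so the first epoch always improves.
        self.best = math.inf if self.mode == "min" else -math.inf

    def is_improvement(self, value: float) -> bool:
        if self.mode == "min":
            return value < self.best - self.min_delta
        return value > self.best + self.min_delta

    def step(self, metrics: dict) -> tuple[bool, bool]:
        """Record one epoch; return (improved, should_stop)."""
        value = metrics[self.monitor]
        improved = self.is_improvement(value)
        if improved:
            self.best = value  # a trainer would also overwrite best.ckpt here
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
        return improved, self.bad_epochs >= self.patience
```

In a training loop, `step()` would be called once per epoch with the validation metrics, and training would stop as soon as `should_stop` is true.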
Outputs per run (each under the corresponding `log_root`, `ckpt_root`, or `reports_root` folder from the config)
- `runs/<run_id>/` TensorBoard logs
- `checkpoints/<run_id>/last.ckpt` and `best.ckpt`
- `reports/<run_id>/config_effective.yaml`, `metrics.json`, and `figures/confusion_matrix.png`
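
A small sketch of how these per‑run paths fit together. It assumes the config dict provides `log_root`, `ckpt_root`, and `reports_root` (falling back to the folder names shown above) and that `run_id` is any unique string; the timestamp format in `make_run_id` is a hypothetical choice, not necessarily the one the project uses.

```python
from datetime import datetime
from pathlib import Path
from typing import Optional


def make_run_id(now: Optional[datetime] = None) -> str:
    """Hypothetical run-id scheme: a sortable timestamp such as 20240131-154500."""
    return (now or datetime.now()).strftime("%Y%m%d-%H%M%S")


def run_paths(cfg: dict, run_id: str) -> dict[str, Path]:
    """Map a config dict and a run id to the per-run artifact locations."""
    log_dir = Path(cfg.get("log_root", "runs")) / run_id
    ckpt_dir = Path(cfg.get("ckpt_root", "checkpoints")) / run_id
    report_dir = Path(cfg.get("reports_root", "reports")) / run_id
    return {
        "tensorboard": log_dir,
        "last_ckpt": ckpt_dir / "last.ckpt",
        "best_ckpt": ckpt_dir / "best.ckpt",
        "config": report_dir / "config_effective.yaml",
        "metrics": report_dir / "metrics.json",
        "confusion_matrix": report_dir / "figures" / "confusion_matrix.png",
    }
```

Keeping the three roots separate means TensorBoard can be pointed at the log root alone, without scanning checkpoints or figures.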
## Configuration reference
See examples in `configs/`:
- `baseline.yaml` (Fashion‑MNIST + `smallcnn`)
- `cifar10_resnet18.yaml` (CIFAR‑10 + adapted ResNet‑18)
Common keys
- `dataset`: one of `fashion-mnist`, `mnist`, `cifar10`
- `model_name`: `smallcnn` | `resnet18_cifar` | `resnet18_imagenet`
- `data_dir`: root folder for data (default `./data`)
- `batch_size`, `epochs`, `lr`, `weight_decay`, `num_workers`, `seed`, `device`
- `img_size`, `mean`, `std`: image shape and normalization stats
- `optimizer`: `adam` (default) or `sgd`; `momentum` used for SGD
- `log_root`, `ckpt_root`, `reports_root`: base folders for artifacts
- `early_stop`: `{ monitor: val_loss|val_acc, mode: min|max, patience, min_delta }`
CLI flags override the corresponding YAML keys, for example `--dataset`, `--epochs`, `--lr`, and `--model-name`.
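
One way the YAML‑plus‑CLI merge could work is sketched below. It is not the project's `src.train` parser: only the four flags named above are included, and the key assumption is that every flag defaults to `None`, so only values actually passed on the command line replace YAML entries. argparse turns `--model-name` into the `model_name` attribute, which lines up with the config key. The YAML is passed in as an already‑parsed dict (for example from `yaml.safe_load`) so the sketch needs only the standard library.

```python
import argparse


def build_parser() -> argparse.ArgumentParser:
    p = argparse.ArgumentParser(description="Train a CNN (illustrative subset of flags)")
    p.add_argument("--config", required=True, help="path to a YAML config")
    # default=None means "not given on the CLI", so the YAML value is kept.
    p.add_argument("--dataset", choices=["fashion-mnist", "mnist", "cifar10"], default=None)
    p.add_argument(
        "--model-name",
        choices=["smallcnn", "resnet18_cifar", "resnet18_imagenet"],
        default=None,
    )
    p.add_argument("--epochs", type=int, default=None)
    p.add_argument("--lr", type=float, default=None)
    return p


def merge_config(yaml_cfg: dict, args: argparse.Namespace) -> dict:
    """Return a new dict: YAML values, overridden by any CLI flag that was set."""
    merged = dict(yaml_cfg)
    for key, value in vars(args).items():
        if key != "config" and value is not None:
            merged[key] = value
    return merged
```

With `--epochs 12 --lr 5e-4`, only `epochs` and `lr` change; the dataset and model keep their YAML values, and the original dict is left untouched.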
## Streamlit Grad‑CAM demo
Start the app (or try it online at `https://explainable-cnn.streamlit.app`)
```bash
streamlit run app/streamlit_app.py
```
What it does
- Load a trained checkpoint (`.ckpt`)
- Upload an image or sample one from the corresponding dataset
- Run inference and display top‑k predictions
- Visualize Grad‑CAM or Grad‑CAM++ overlays with adjustable alpha
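
The top‑k step amounts to a softmax over the model's logits followed by picking the k largest probabilities. The stand‑alone sketch below shows that arithmetic on a plain list of logits; in the app the same step would run on a PyTorch tensor (for example via `torch.softmax` and `torch.topk`), with class names taken from the checkpoint's `classes` list. The function names are illustrative.

```python
import math


def softmax(logits: list[float]) -> list[float]:
    """Numerically stable softmax: subtract the max logit before exponentiating."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def top_k(logits: list[float], classes: list[str], k: int = 3) -> list[tuple[str, float]]:
    """Return the k most probable (class name, probability) pairs, highest first."""
    if len(logits) != len(classes):
        raise ValueError("logits and classes must have the same length")
    probs = softmax(logits)
    ranked = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    return [(classes[i], probs[i]) for i in ranked[:k]]
```

For example, `top_k([2.0, 0.5, 1.0], ["cat", "dog", "bird"], k=2)` ranks `cat` (about 0.63) ahead of `bird` (about 0.23).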
Supplying checkpoints
- Local discovery: put `.ckpt` files under `saved_checkpoints/` or use the file uploader
- Download from a URL: paste a direct link to a `.ckpt` asset and click “Download checkpoint”
- Presets: provide a map of names → URLs via one of:
  - Streamlit secrets: a `release_checkpoints` table (e.g. `"Name" = "https://...best.ckpt"` in `.streamlit/secrets.toml`), read as `st.secrets["release_checkpoints"]`
  - `.streamlit/presets.json` or `presets.json` in the repo root, containing either
    ```json
    { "release_checkpoints": { "FMNIST SmallCNN": "https://.../best.ckpt" } }
    ```
    or a flat mapping `{ "FMNIST SmallCNN": "https://..." }`
  - Environment variable `RELEASE_CKPTS_JSON` containing a JSON mapping string
Devices and CAM methods
- Device: `auto` (default), `cuda`, or `cpu`
- CAM: `Grad-CAM` or `Grad-CAM++` via `torchcam`
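
For reference, Grad‑CAM weights each channel k of the target layer's activations by the spatial mean of the gradient of the class score with respect to that channel, then takes ReLU of the weighted sum of channels; Grad‑CAM++ changes only how those channel weights are computed. `torchcam` does this internally with hooks on the target layer. The pure‑Python sketch below reproduces just the Grad‑CAM arithmetic on nested lists shaped `[K][H][W]`, plus a simple linear alpha blend of the kind an overlay slider might control; the blend formula and both function names are assumptions, not the app's code.

```python
def grad_cam(activations, gradients):
    """Grad-CAM from one layer's activations and d(score)/d(activations).

    Both inputs are nested lists shaped [K][H][W]; returns an [H][W] map scaled to [0, 1].
    """
    K, H, W = len(activations), len(activations[0]), len(activations[0][0])
    # Channel weights: global-average-pooled gradients.
    weights = [sum(sum(row) for row in gradients[k]) / (H * W) for k in range(K)]
    # Weighted sum of activation maps, followed by ReLU.
    cam = [
        [max(0.0, sum(weights[k] * activations[k][i][j] for k in range(K))) for j in range(W)]
        for i in range(H)
    ]
    # Scale to [0, 1] for display; an all-zero map stays all zeros.
    peak = max(max(row) for row in cam)
    return [[v / peak if peak > 0 else 0.0 for v in row] for row in cam]


def overlay(image, heatmap, alpha=0.5):
    """Linearly blend a grayscale image and a heatmap (same [H][W] shape, values in [0, 1])."""
    if not 0.0 <= alpha <= 1.0:
        raise ValueError("alpha must be in [0, 1]")
    return [
        [(1 - alpha) * p + alpha * h for p, h in zip(img_row, hm_row)]
        for img_row, hm_row in zip(image, heatmap)
    ]
```

In practice the CAM is computed at the target layer's (lower) resolution and upsampled to the input size before blending; that step is omitted here.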
Checkpoint metadata expected
- `meta`: `{ dataset, model_name, img_size, mean, std, default_target_layer }`
- `classes`: list of class names (used to label predictions)
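
A hedged sketch of the checkpoint layout the app expects, with a validator for the `meta` and `classes` fields listed above. Everything beyond those two keys is an assumption: the `state_dict` key name, the placeholder normalization values, and the `features.conv2` target layer are hypothetical. In practice the dict would be written with `torch.save` and read back with `torch.load`; this stdlib‑only example leaves that out.

```python
REQUIRED_META = ("dataset", "model_name", "img_size", "mean", "std", "default_target_layer")


def validate_checkpoint(ckpt: dict) -> list[str]:
    """Return a list of problems with a loaded checkpoint dict (empty if it looks usable)."""
    problems = []
    meta = ckpt.get("meta")
    if not isinstance(meta, dict):
        problems.append("missing 'meta' dict")
    else:
        problems += [f"meta.{key} missing" for key in REQUIRED_META if key not in meta]
        mean, std = meta.get("mean"), meta.get("std")
        # Per-channel stats must line up (1 value for grayscale, 3 for RGB).
        if isinstance(mean, (list, tuple)) and isinstance(std, (list, tuple)) and len(mean) != len(std):
            problems.append("meta.mean and meta.std differ in length")
    classes = ckpt.get("classes")
    if not isinstance(classes, (list, tuple)) or not classes:
        problems.append("missing or empty 'classes' list")
    return problems


# Example shape; 'state_dict', the 0.5 stats, and the layer name are placeholders.
example = {
    "state_dict": {},
    "meta": {
        "dataset": "fashion-mnist",
        "model_name": "smallcnn",
        "img_size": 28,
        "mean": [0.5],
        "std": [0.5],
        "default_target_layer": "features.conv2",  # hypothetical layer name
    },
    "classes": ["T-shirt/top", "Trouser", "Pullover", "Dress", "Coat",
                "Sandal", "Shirt", "Sneaker", "Bag", "Ankle boot"],
}
```

Returning a list of problems rather than raising on the first one lets a UI show everything that is wrong with an uploaded checkpoint at once.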
## Checkpoints and outputs
Each run writes:
- Checkpoints: `<ckpt_root>/<run_id>/{last.ckpt,best.ckpt}`
- Logs: `<log_root>/<run_id>/` (TensorBoard)
- Reports: `<reports_root>/<run_id>/config_effective.yaml`, `metrics.json`, `figures/confusion_matrix.png`
`best.ckpt` is selected using the early‑stopping monitor (`val_loss` or `val_acc`).
## License and acknowledgements
- Uses `torchcam` for CAM extraction and `captum` as a general explainability dependency
- TorchVision models and datasets are used for baselines and data handling
___
If you run into issues, please open an issue with your command, config file, and environment details.