---
title: ExplainableCNN
app_file: app/gradio_app.py
sdk: gradio
sdk_version: 5.47.0
---
# ExplainableCNN

End‑to‑end image classification with explainability. Train CNNs on common vision datasets, save checkpoints and metrics, and visualize Grad‑CAM/Grad‑CAM++ heatmaps in a Streamlit app.

> **Online App**: You can try the app online at `https://explainable-cnn.streamlit.app`
Contents

- Quick start
- Installation
- Datasets
- Training
- Configuration reference
- Streamlit Grad‑CAM demo
- Checkpoints and outputs
- License and acknowledgements
## Quick start

1) Install dependencies (CPU‑only by default):

```bash
python -m venv .venv && source .venv/bin/activate
pip install -r requirements.txt
```

2) Train with defaults (Fashion‑MNIST, small CNN):

```bash
python -m src.train --config configs/baseline.yaml
```

3) Launch the Grad‑CAM demo and visualize predictions:

```bash
streamlit run app/streamlit_app.py
```
## Installation

This repo pins CPU‑only PyTorch wheels via the official extra index in `requirements.txt`. If you have CUDA, you can install the matching GPU wheels from PyTorch and keep the rest of the requirements.

```bash
python -m venv .venv
source .venv/bin/activate
pip install -r requirements.txt
```
### GPU installation (recommended for training)

1) Install a CUDA‑enabled PyTorch build that matches your driver and CUDA version (see `https://pytorch.org/get-started/locally/`). Examples:

```bash
# CUDA 12.1
pip install --index-url https://download.pytorch.org/whl/cu121 torch torchvision torchaudio
# CUDA 11.8
# pip install --index-url https://download.pytorch.org/whl/cu118 torch torchvision torchaudio
```

2) Install the rest of the project dependencies (excluding `torch*`):

```bash
pip install -r requirements-gpu.txt
```
Notes

- For GPU builds, follow the selector at `https://pytorch.org/get-started/locally/` and install the torch/torchvision/torchaudio triplet before installing the rest of the requirements.
- Key dependencies: torch/torchvision, torchmetrics, captum/torchcam, lightning (the newer package name for PyTorch Lightning), albumentations, TensorBoard, Streamlit, PyYAML, and others.
## Datasets

Built‑in dataset options for training:

- `fashion-mnist` (default)
- `mnist`
- `cifar10`

Where data lives

- By default, datasets are downloaded under `data/` on first use (see the sketch below).
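A minimal sketch of the standard torchvision download path, assuming the project uses the stock dataset classes (as the acknowledgements suggest); the `root` value mirrors the default `data_dir`:

```python
from torchvision import datasets, transforms

# Downloads Fashion-MNIST into data/ on the first call, then reuses the local copy.
ds = datasets.FashionMNIST(
    root="data",
    train=True,
    download=True,
    transform=transforms.ToTensor(),
)
print(len(ds), ds.classes[:3])  # 60000 ['T-shirt/top', 'Trouser', 'Pullover']
```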
## Training

Run with a YAML config:

```bash
python -m src.train --config configs/baseline.yaml
```

Override config values from the CLI:

```bash
python -m src.train --config configs/baseline.yaml --epochs 12 --lr 5e-4
```

Switch dataset from the CLI:

```bash
python -m src.train --config configs/baseline.yaml --dataset mnist
```

Use a ResNet‑18 backbone for CIFAR‑10 (adapted conv1, no maxpool):

```bash
python -m src.train --config configs/cifar10_resnet18.yaml
```
Training flow (high level)

- Loads the YAML config and merges CLI overrides
- Builds dataloaders with dataset‑specific transforms and normalization
- Builds the model: `smallcnn`, `resnet18_cifar`, or `resnet18_imagenet`
- Optimizer: Adam (default) or SGD with momentum
- Trains with early stopping and ReduceLROnPlateau on validation loss (see the sketch after this list)
- Writes TensorBoard logs, metrics JSONs, and image reports (confusion matrix)
- Saves `last.ckpt` and `best.ckpt` with model weights and metadata
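A minimal sketch of the scheduler/early‑stopping pattern described above, with a stand‑in model and a placeholder validation pass; the real loop lives in `src/train.py` and its hyperparameters come from the config:

```python
import torch

model = torch.nn.Linear(8, 2)  # stand-in model; the real one is built from the config
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", patience=2)

best_loss, bad_epochs = float("inf"), 0
patience, min_delta = 5, 1e-4  # mirrors the early_stop config keys
for epoch in range(100):
    val_loss = max(0.1, 1.0 / (epoch + 1))  # placeholder for a real validation pass
    scheduler.step(val_loss)                # ReduceLROnPlateau watches the monitored metric
    if val_loss < best_loss - min_delta:
        best_loss, bad_epochs = val_loss, 0  # improvement: best.ckpt is saved here
    else:
        bad_epochs += 1
        if bad_epochs >= patience:           # early stopping
            break
```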
Outputs per run (under the root folders from the config)

- `runs/<run_id>/`: TensorBoard logs
- `checkpoints/<run_id>/`: `last.ckpt` and `best.ckpt`
- `reports/<run_id>/`: `config_effective.yaml`, `metrics.json`, and `figures/confusion_matrix.png`
## Configuration reference

See the examples in `configs/`:

- `baseline.yaml` (Fashion‑MNIST + `smallcnn`)
- `cifar10_resnet18.yaml` (CIFAR‑10 + adapted ResNet‑18)

Common keys

- `dataset`: one of `fashion-mnist`, `mnist`, `cifar10`
- `model_name`: `smallcnn` | `resnet18_cifar` | `resnet18_imagenet`
- `data_dir`: root folder for data (default `./data`)
- `batch_size`, `epochs`, `lr`, `weight_decay`, `num_workers`, `seed`, `device`
- `img_size`, `mean`, `std`: image shape and normalization stats
- `optimizer`: `adam` (default) or `sgd`; `momentum` is used for SGD
- `log_root`, `ckpt_root`, `reports_root`: base folders for artifacts
- `early_stop`: `{ monitor: val_loss|val_acc, mode: min|max, patience, min_delta }`
CLI flags (for example `--dataset`, `--epochs`, `--lr`, `--model-name`) override the corresponding YAML keys; a minimal sketch of the merge follows.
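A minimal sketch of the YAML‑plus‑CLI merge behavior, using a hypothetical subset of the flags; the authoritative logic lives in `src/train.py`:

```python
import argparse
import yaml

parser = argparse.ArgumentParser()
parser.add_argument("--config", required=True)
parser.add_argument("--epochs", type=int)  # default None means "not set on the CLI"
parser.add_argument("--lr", type=float)
args = parser.parse_args()

with open(args.config) as f:
    cfg = yaml.safe_load(f)

# A CLI flag overrides its YAML key only when it was explicitly provided.
for key in ("epochs", "lr"):
    value = getattr(args, key)
    if value is not None:
        cfg[key] = value
print(cfg)
```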
## Streamlit Grad‑CAM demo

Start the app (or try it online at `https://explainable-cnn.streamlit.app`):

```bash
streamlit run app/streamlit_app.py
```
What it does

- Loads a trained checkpoint (`.ckpt`)
- Uploads an image or samples one from the corresponding dataset
- Runs inference and displays the top‑k predictions (see the sketch below)
- Visualizes Grad‑CAM or Grad‑CAM++ overlays with adjustable alpha
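The top‑k step is plain softmax‑plus‑topk; a sketch with stand‑in logits (the app uses the real model outputs and the checkpoint's class names):

```python
import torch

scores = torch.randn(1, 10)  # stand-in logits for a 10-class model
probs = torch.softmax(scores, dim=1)
top_probs, top_idx = probs.topk(3, dim=1)
for p, i in zip(top_probs[0], top_idx[0]):
    print(f"class {i.item()}: {p.item():.2%}")
```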
Supplying checkpoints

- Local discovery: put `.ckpt` files under `saved_checkpoints/` or use the file uploader
- Download from a URL: paste a direct link to a `.ckpt` asset and click “Download checkpoint”
- Presets: provide a map of names → URLs via one of the following (resolution sketched below):
  - Streamlit secrets: `st.secrets["release_checkpoints"] = { "Name": "https://...best.ckpt" }`
  - `.streamlit/presets.json` or `presets.json` in the repo root, either nested:

    ```json
    { "release_checkpoints": { "FMNIST SmallCNN": "https://.../best.ckpt" } }
    ```

    or a flat mapping `{ "FMNIST SmallCNN": "https://..." }`
  - Environment variable `RELEASE_CKPTS_JSON` holding a JSON mapping string
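A hedged sketch of how the file and environment sources could be resolved (the precedence here is an assumption; `app/streamlit_app.py` is authoritative, and Streamlit secrets are read via `st.secrets` inside the app):

```python
import json
import os
import pathlib

def load_presets() -> dict:
    """Return a {name: url} mapping from presets.json files or the env var."""
    for path in (".streamlit/presets.json", "presets.json"):
        p = pathlib.Path(path)
        if p.exists():
            data = json.loads(p.read_text())
            # Accept both the nested and the flat layouts shown above.
            return data.get("release_checkpoints", data)
    env = os.environ.get("RELEASE_CKPTS_JSON")
    return json.loads(env) if env else {}

print(load_presets())
```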
Devices and CAM methods

- Device: `auto` (default), `cuda`, or `cpu`
- CAM: `Grad-CAM` or `Grad-CAM++` via `torchcam` (see the sketch below)
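A minimal sketch of the `torchcam` extraction path, with a placeholder model and target layer (the app resolves the real layer from the checkpoint's `default_target_layer`, described below):

```python
import torch
from torchvision.models import resnet18
from torchcam.methods import GradCAMpp

model = resnet18(weights=None).eval()  # placeholder; the app loads the checkpointed CNN
cam_extractor = GradCAMpp(model, target_layer="layer4")

x = torch.rand(1, 3, 224, 224)  # stand-in input tensor
scores = model(x)
class_idx = scores.squeeze(0).argmax().item()
cams = cam_extractor(class_idx, scores)  # one CAM tensor per target layer
print(cams[0].shape)
```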
Checkpoint metadata expected

- `meta`: `{ dataset, model_name, img_size, mean, std, default_target_layer }`
- `classes`: list of class names (used to label predictions); the sketch below shows the assumed layout
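A sketch of the checkpoint layout these fields imply; the `state_dict` key, the values, and the layer name are illustrative assumptions, not the project's actual defaults:

```python
import torch

model = torch.nn.Linear(8, 10)  # stand-in; real runs save the trained CNN

ckpt = {
    "state_dict": model.state_dict(),  # assumed key name for the weights
    "meta": {
        "dataset": "fashion-mnist",
        "model_name": "smallcnn",
        "img_size": 28,
        "mean": [0.2860],  # commonly used Fashion-MNIST stats
        "std": [0.3530],
        "default_target_layer": "conv2",  # hypothetical layer name
    },
    "classes": ["T-shirt/top", "Trouser", "Pullover"],  # truncated for brevity
}
torch.save(ckpt, "best.ckpt")

restored = torch.load("best.ckpt", map_location="cpu")
print(restored["meta"]["model_name"], restored["classes"][:2])
```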
## Checkpoints and outputs

Each run writes:

- Checkpoints: `<ckpt_root>/<run_id>/{last.ckpt,best.ckpt}`
- Logs: `<log_root>/<run_id>/` (TensorBoard)
- Reports: `<reports_root>/<run_id>/metrics.json` and `figures/confusion_matrix.png`

Best‑checkpoint selection respects the early‑stopping monitor (`val_loss` or `val_acc`).
## License and acknowledgements

- Uses `torchcam` for CAM extraction and `captum` as a general explainability dependency
- TorchVision models and datasets are used for baselines and data handling

---

If you run into issues, please open an issue with your command, config file, and environment details.