CIFAR10 Classifier trained with PyTorch Lightning

Introduction

A ResNet18 model for CIFAR-10 image classification that reaches 94% prediction accuracy. The main contributors to that accuracy are:

  1. Data normalization and random data augmentation (10% improvement)
  2. Dropout before the fully connected classifier (1% improvement)
  3. Batch normalization in each ResNet block (2% improvement)
  4. Cosine learning-rate schedule (1% improvement)
  5. Greater depth: ResNet18 is deeper than a simple CNN.
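Item 4 can be made concrete. A cosine schedule decays the learning rate along half a cosine period: lr(t) = min_lr + (base_lr - min_lr) * (1 + cos(pi * t / T)) / 2. This is the closed form that PyTorch's torch.optim.lr_scheduler.CosineAnnealingLR follows; whether this model used that scheduler, and with which settings, is not stated, so the base rate, minimum rate and epoch count below are hypothetical.

```python
import math


def cosine_lr(step: int, total_steps: int, base_lr: float, min_lr: float = 0.0) -> float:
    """Cosine-annealed learning rate at `step` (0-based) out of `total_steps`."""
    if total_steps <= 0:
        raise ValueError("total_steps must be positive")
    # Clamp so the rate stays at min_lr once the schedule has finished.
    t = min(max(step, 0), total_steps)
    # Half a cosine period: the factor falls from 1 to 0 as t goes 0 -> total_steps.
    factor = 0.5 * (1.0 + math.cos(math.pi * t / total_steps))
    return min_lr + (base_lr - min_lr) * factor


if __name__ == "__main__":
    # Hypothetical settings; the model card does not give the real ones.
    base_lr, epochs = 0.1, 30
    for epoch in (0, 10, 20, 30):
        print(epoch, round(cosine_lr(epoch, epochs, base_lr), 4))
```

The rate changes slowly at the start and end and fastest in the middle, which is the usual motivation for preferring it over a step schedule.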

Usage

Approach 1: predict with plain PyTorch


import torch
from model import CIFARCNN  # the LightningModule defined in model.py

# Load the checkpoint and switch to inference mode (disables dropout; batch norm uses running statistics)
model = CIFARCNN.load_from_checkpoint("model.ckpt")
model.eval()

# Random stand-in for a batch of 4 normalized 32x32 RGB images
x = torch.randn(4, 3, 32, 32).to(model.device)

with torch.no_grad():
    predictions = model(x)  # CIFARCNN must implement forward()
print(predictions.shape)  # torch.Size([4, 10]): one score per CIFAR-10 class
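Each row of the output holds one score per class. The card does not say whether forward() returns raw logits or probabilities; assuming logits, the plain-Python sketch below turns one row into a class name, using the same label order as the cifar10_labels mapping in the visualization section. With tensors, the equivalent is predictions.softmax(dim=1) followed by argmax(dim=1). The helper names here are hypothetical.

```python
import math

# Same order as the cifar10_labels mapping used for visualization.
CIFAR10_LABELS = [
    "airplane", "automobile", "bird", "cat", "deer",
    "dog", "frog", "horse", "ship", "truck",
]


def softmax(logits: list[float]) -> list[float]:
    """Numerically stable softmax over one row of logits."""
    m = max(logits)  # subtracting the max keeps exp() from overflowing
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


def top_prediction(logits: list[float]) -> tuple[str, float]:
    """Return (class name, probability) of the highest-scoring class."""
    if len(logits) != len(CIFAR10_LABELS):
        raise ValueError(f"expected {len(CIFAR10_LABELS)} logits, got {len(logits)}")
    probs = softmax(logits)
    best = max(range(len(probs)), key=probs.__getitem__)
    return CIFAR10_LABELS[best], probs[best]


if __name__ == "__main__":
    # Hypothetical logits for one image, not real model output.
    example = [0.1, 2.5, -1.0, 0.0, 0.3, 0.2, -0.5, 0.4, 1.0, 0.0]
    print(top_prediction(example))
```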

Approach 2: evaluate with the Lightning Trainer

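import torch
from torch.utils.data import DataLoader
from lightning import Trainer
from model import CIFARCNN

test_dataloader = DataLoader(...)  # CIFAR-10 test set, normalized the same way as during training
model = CIFARCNN.load_from_checkpoint("model.ckpt")
trainer = Trainer()  # the Trainer moves the model to the available accelerator (GPU if present)

# Requires CIFARCNN to implement test_step(); the reported metrics are whatever it logs.
trainer.test(model, dataloaders=test_dataloader)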
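trainer.test() reports whatever CIFARCNN.test_step() logs, and the card does not show that method. Assuming it logs classification accuracy, the figure should count correct predictions over the whole test set rather than average per-batch accuracies, since the latter over-weights a short final batch. A minimal plain-Python sketch of that aggregation; the function name and the (predictions, labels) batch format are hypothetical.

```python
def dataset_accuracy(batches):
    """Accuracy over a whole test set from per-batch (predictions, labels) pairs.

    Counts correct predictions across all samples, so a smaller final batch
    is weighted correctly (averaging per-batch accuracies would not be).
    """
    correct = 0
    total = 0
    for preds, labels in batches:
        if len(preds) != len(labels):
            raise ValueError("each batch needs one prediction per label")
        correct += sum(int(p == y) for p, y in zip(preds, labels))
        total += len(labels)
    if total == 0:
        raise ValueError("no samples")
    return correct / total


if __name__ == "__main__":
    # Two hypothetical batches: 3/4 correct, then 0/1 correct.
    batches = [([1, 2, 3, 4], [1, 2, 3, 0]), ([5], [6])]
    print(dataset_accuracy(batches))
```

Here the per-sample result is 3 / 5, while averaging the two batch accuracies would give (0.75 + 0) / 2.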

Visualize results

import matplotlib.pyplot as plt
import torch

cifar10_labels = {
    0: "airplane",
    1: "automobile",
    2: "bird",
    3: "cat",
    4: "deer",
    5: "dog",
    6: "frog",
    7: "horse",
    8: "ship",
    9: "truck",
}

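# Take one batch (at least 10 images) from the test loader used in Approach 2.
samples, true_labels = next(iter(test_dataloader))

# trainer.predict() expects a DataLoader, not a tensor, so call the model directly.
model.eval()
with torch.no_grad():
    pred_labels = model(samples.to(model.device)).argmax(dim=1).cpu()

fig, axes = plt.subplots(2, 5, figsize=(10, 4))
for i, ax in enumerate(axes.flatten()):
    # CHW -> HWC; min-max rescale because normalized pixels fall outside [0, 1].
    img = samples[i].permute(1, 2, 0)
    ax.imshow((img - img.min()) / (img.max() - img.min() + 1e-8))
    pred, true = pred_labels[i].item(), true_labels[i].item()
    ax.set_title(f"pred: {cifar10_labels[pred]}\ntrue: {cifar10_labels[true]}")
    ax.axis("off")
plt.show()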
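Because the inputs are normalized, displaying them directly distorts their colors. If the normalization statistics used in training are known, inverting x_norm = (x - mean) / std recovers the original pixel values. The card does not state those statistics, so MEAN and STD below are assumptions (per-channel values often quoted for CIFAR-10), and the image is a plain nested list [channels][height][width] to keep the sketch dependency-free.

```python
# Assumed per-channel statistics; replace with the values actually used in training.
MEAN = (0.4914, 0.4822, 0.4465)
STD = (0.2470, 0.2435, 0.2616)


def denormalize(image, mean=MEAN, std=STD):
    """Invert x_norm = (x - mean) / std per channel, then clip to [0, 1].

    `image` is a nested list shaped [channels][height][width].
    """
    if len(image) != len(mean) or len(mean) != len(std):
        raise ValueError("need one mean and one std per channel")
    return [
        [[min(1.0, max(0.0, v * s + m)) for v in row] for row in channel]
        for channel, m, s in zip(image, mean, std)
    ]
```

With tensors, the same inversion is samples[i] * std.view(3, 1, 1) + mean.view(3, 1, 1), where mean and std are 3-element tensors, followed by clamp(0, 1).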