CIFAR10 Classifier trained with PyTorch Lightning
Introduction
A ResNet18 classifier that reaches 94% accuracy on CIFAR-10. Key features, with the accuracy improvement attributed to each:
- Data normalization and randomized data augmentation (10% improvement)
- Dropout before the fully connected (FC) classifier (1% improvement)
- Batch normalization in each ResNet block (2% improvement)
- Cosine learning-rate schedule (1% improvement)
- A ResNet18 backbone, which is deeper than a simple CNN
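Batch normalization inside the residual block can be sketched with the standard ResNet18 basic block below. The class name `BasicBlock` and its layout (two 3x3 convolutions, each followed by `BatchNorm2d`, plus a 1x1 projection shortcut when the shape changes) are assumptions; the card's actual `ResNetBlock` may differ.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class BasicBlock(nn.Module):
    """Hypothetical ResNet18-style block: conv -> BN -> ReLU -> conv -> BN, plus skip."""

    def __init__(self, in_channels: int, out_channels: int, stride: int = 1):
        super().__init__()
        # bias=False because the following BatchNorm has its own learnable shift
        self.conv1 = nn.Conv2d(in_channels, out_channels, 3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, 3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)
        # Identity skip, or a 1x1 projection when the spatial size or width changes
        self.shortcut = nn.Identity()
        if stride != 1 or in_channels != out_channels:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, 1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels),
            )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        return F.relu(out + self.shortcut(x))  # add the residual, then activate
```

For example, `BasicBlock(64, 128, stride=2)` halves the spatial size and doubles the channel count, as happens at the start of each later ResNet18 stage.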
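Dropout before the FC classifier can be sketched as a small head on top of the backbone's final feature map. The class `ClassifierHead` is hypothetical, and the dropout rate `p=0.5` is an assumption; 512 input features matches the width of a standard ResNet18's last stage.

```python
import torch
import torch.nn as nn


class ClassifierHead(nn.Module):
    """Hypothetical head: global average pooling -> dropout -> linear layer."""

    def __init__(self, in_features: int = 512, num_classes: int = 10, p: float = 0.5):
        super().__init__()
        self.pool = nn.AdaptiveAvgPool2d(1)  # (N, C, H, W) -> (N, C, 1, 1)
        self.dropout = nn.Dropout(p)         # zeroes features only in train() mode
        self.fc = nn.Linear(in_features, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = torch.flatten(self.pool(x), 1)   # (N, C, 1, 1) -> (N, C)
        return self.fc(self.dropout(x))      # dropout sits directly before the FC layer
```

Because `nn.Dropout` is disabled by `model.eval()`, predictions at inference time are deterministic.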
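A cosine learning-rate schedule lowers the learning rate from its initial value toward zero along half a cosine curve. The sketch below uses PyTorch's `CosineAnnealingLR`; the stand-in model, the SGD hyperparameters, and the 30-epoch budget are all hypothetical, not the author's settings.

```python
import torch
import torch.nn as nn

model = nn.Linear(10, 10)  # hypothetical stand-in for the ResNet18
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4)

epochs = 30  # hypothetical training length
# lr(t) = 0.1 * (1 + cos(pi * t / T_max)) / 2, stepped once per epoch
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)

lrs = []
for epoch in range(epochs):
    # ... one epoch of training would call optimizer.step() once per batch ...
    optimizer.step()
    lrs.append(scheduler.get_last_lr()[0])  # learning rate used during this epoch
    scheduler.step()                          # advance the schedule after the epoch
```

In a LightningModule, the same optimizer and scheduler would typically be returned from `configure_optimizers`, letting the Trainer step the scheduler.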
Usage
Approach 1: use PyTorch to predict
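```python
import torch
from model import CIFARCNN

# Load the trained weights from the Lightning checkpoint
model = CIFARCNN.load_from_checkpoint("model.ckpt")
model.eval()  # disable dropout and use BatchNorm running statistics

# A dummy batch of 4 RGB 32x32 images; replace with real, normalized data
x = torch.randn(4, 3, 32, 32).to(model.device)
with torch.no_grad():
    predictions = model(x)  # requires CIFARCNN to implement forward()
print(predictions.shape)  # torch.Size([4, 10]): one score per class
```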
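To turn the `[4, 10]` output into readable predictions, the scores can be passed through a softmax and mapped to class names. This sketch assumes the model's `forward` returns unnormalized logits; `decode_predictions` is a hypothetical helper, and the class order follows the label mapping used in the visualization example.

```python
import torch
import torch.nn.functional as F

CIFAR10_CLASSES = ["airplane", "automobile", "bird", "cat", "deer",
                   "dog", "frog", "horse", "ship", "truck"]


def decode_predictions(logits: torch.Tensor):
    """Map an (N, 10) batch of logits to (class name, probability) pairs."""
    probs = F.softmax(logits, dim=1)   # each row now sums to 1
    conf, idx = probs.max(dim=1)       # top-1 probability and class index per row
    return [(CIFAR10_CLASSES[i], c) for i, c in zip(idx.tolist(), conf.tolist())]
```

Usage: `decode_predictions(predictions)` returns one `(name, probability)` tuple per image in the batch.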
Approach 2: use Lightning to predict
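```python
from torch.utils.data import DataLoader
from lightning import Trainer
from model import CIFARCNN

test_dataloader = DataLoader(...)  # e.g. the CIFAR-10 test split, normalized as in training
model = CIFARCNN.load_from_checkpoint("model.ckpt")
trainer = Trainer()  # the Trainer moves the model to the available accelerator
trainer.test(model, test_dataloader)  # requires CIFARCNN to implement test_step()
```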
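`trainer.test` reports whatever metrics the module's `test_step` logs. As a framework-free cross-check, top-1 accuracy can also be computed with a plain loop. The helper `evaluate_accuracy` below is hypothetical and assumes the loader yields `(images, labels)` batches and the model returns per-class scores.

```python
import torch
from torch.utils.data import DataLoader


@torch.no_grad()
def evaluate_accuracy(model: torch.nn.Module, loader: DataLoader, device: str = "cpu") -> float:
    """Top-1 accuracy of `model` over all (images, labels) batches in `loader`."""
    model.eval()       # disable dropout, use BatchNorm running statistics
    model.to(device)
    correct, total = 0, 0
    for images, labels in loader:
        images, labels = images.to(device), labels.to(device)
        preds = model(images).argmax(dim=1)  # index of the highest score per image
        correct += (preds == labels).sum().item()
        total += labels.numel()
    return correct / total
```

Usage: `evaluate_accuracy(model, test_dataloader, device=str(model.device))` should agree with an accuracy metric logged by `test_step`, if one is logged.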
Visualize results
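```python
import matplotlib.pyplot as plt
import torch

cifar10_labels = {
    0: "airplane",
    1: "automobile",
    2: "bird",
    3: "cat",
    4: "deer",
    5: "dog",
    6: "frog",
    7: "horse",
    8: "ship",
    9: "truck",
}

# Assumes `model` is loaded as above and `train_loader` is a CIFAR-10
# DataLoader with a batch size of at least 10.
samples, labels = next(iter(train_loader))
model.eval()
with torch.no_grad():
    logits = model(samples.to(model.device))
preds = logits.argmax(dim=1).cpu()  # predicted class index per image

fig, axes = plt.subplots(2, 5, figsize=(10, 5))
for i, ax in enumerate(axes.flatten()):
    img = samples[i].permute(1, 2, 0)  # CHW -> HWC for imshow
    img = (img - img.min()) / (img.max() - img.min())  # rescale normalized pixels to [0, 1] for display
    ax.imshow(img)
    ax.set_title(f"pred: {cifar10_labels[preds[i].item()]}\ntrue: {cifar10_labels[labels[i].item()]}")
    ax.axis("off")
plt.tight_layout()
plt.show()
```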