|
|
---
license: mit
---
|
|
|
|
|
## Demo

An interactive demo is hosted on Hugging Face Spaces: https://huggingface.co/spaces/jerilseb/quickdraw-small
|
|
|
|
|
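## Usage

The snippet below requires `torch` and `torchvision`. It expects `classes.txt` (one class name per line) and the `model.pth` weights to be in the current working directory. `predict` takes a dict whose `"composite"` entry holds the drawn image.

```python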
|
|
import torch |
|
|
from torch import nn |
|
|
import torchvision.transforms as transforms |
|
|
import torch.nn.functional as F |
|
|
from pathlib import Path |
|
|
|
|
|
# Class names, one per line of classes.txt. Decode explicitly as UTF-8 so the
# result does not depend on the platform's default locale encoding.
# Presumably line i names model output i; the ordering must match training.
LABELS = Path("classes.txt").read_text(encoding="utf-8").splitlines()

# Size of the classifier's final layer; must match the saved checkpoint.
num_classes = len(LABELS)
|
|
|
|
|
def _conv_stage(in_channels, out_channels):
    """Return one conv -> ReLU -> 2x2 max-pool stage (halves height and width)."""
    return [
        nn.Conv2d(in_channels, out_channels, 3, padding="same"),
        nn.ReLU(),
        nn.MaxPool2d(2),
    ]


# Three stages take a 28x28 input down to 3x3 (28 -> 14 -> 7 -> 3, pooling
# floors odd sizes), so the flattened feature vector has 256 * 3 * 3 = 2304
# entries. The layer order must stay exactly as written: nn.Sequential names
# its children by position, and those names are the keys in model.pth.
model = nn.Sequential(
    *_conv_stage(1, 64),
    *_conv_stage(64, 128),
    *_conv_stage(128, 256),
    nn.Flatten(),
    nn.Linear(256 * 3 * 3, 512),
    nn.ReLU(),
    nn.Linear(512, num_classes),
)
|
|
|
|
|
# Load the trained weights onto the CPU. weights_only=True restricts unpickling
# to tensors and plain containers, so a tampered checkpoint cannot execute
# arbitrary code; a state_dict needs nothing more (requires torch >= 1.13).
state_dict = torch.load("model.pth", map_location="cpu", weights_only=True)
model.load_state_dict(state_dict)

# Switch to evaluation mode before running inference.
model.eval()
|
|
|
|
|
# Preprocessing: resize to the 28x28 input the network is sized for, convert to
# a float tensor, then normalize with mean 0.5 / std 0.5, i.e. x -> (x - 0.5) / 0.5.
# NOTE(review): the first Conv2d expects 1 input channel, so this presumably
# relies on a single-channel (mode "L") image; confirm against the caller.
_preprocess_steps = [
    transforms.Resize((28, 28)),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5,), std=(0.5,)),
]
transform = transforms.Compose(_preprocess_steps)
|
|
|
|
|
def predict(image, top_k=5):
    """Classify a sketch and return its most likely class labels.

    Args:
        image: dict whose ``"composite"`` entry holds the drawn image; that
            entry is passed through ``transform``.
        top_k: number of highest-probability classes to return. It is clamped
            to the number of classes, so ``torch.topk`` cannot fail on small
            label sets.

    Returns:
        dict mapping class name (from ``LABELS``) to its softmax probability,
        ordered from most to least likely.
    """
    sketch = image["composite"]
    batch = transform(sketch).unsqueeze(0)  # add a batch dimension of size 1

    with torch.no_grad():
        logits = model(batch)

    probabilities = F.softmax(logits[0], dim=0)
    k = min(top_k, probabilities.numel())
    values, indices = torch.topk(probabilities, k)  # sorted, best first

    # Assumes output index i corresponds to LABELS[i] (the final layer is
    # sized from len(LABELS)).
    return {
        LABELS[idx]: float(prob)
        for prob, idx in zip(values.tolist(), indices.tolist())
    }
|
|
```