Yichuan Huang committed
Commit 2514593 · verified · 1 Parent(s): a738c69

Upload 9 files

app.py ADDED
@@ -0,0 +1,43 @@
+ import torch
+ from torch import nn
+ import torchvision
+ import gradio as gr
+
+ # Define and load my resnet50 model
+ model = torchvision.models.resnet50()
+ num_ftrs = model.fc.in_features
+ model.fc = nn.Sequential(
+     # Add dropout layer with 50% probability
+     nn.Dropout(0.5),
+     # Add a linear layer in order to deal with 5 classes
+     nn.Linear(num_ftrs, 5),
+ )
+
+ model.load_state_dict(
+     torch.load("model/final_model_state_dict.pth", map_location=torch.device("cpu"))
+ )
+ model.eval()
+
+ # Define the labels
+ labels = ["bird", "cat", "dog", "horse", "sheep"]
+
+
+ # Define the predict function
+ def predict(inp):
+     inp = torchvision.transforms.ToTensor()(inp).unsqueeze(0)
+     with torch.no_grad():
+         prediction = model(inp)
+     # Map prediction to label
+     prediction = labels[prediction.argmax()]
+     return prediction
+
+
+ # Define the gradio interface
+ demo = gr.Interface(
+     fn=predict,
+     inputs=gr.Image(type="pil"),
+     outputs=gr.Label(num_top_classes=5),
+     examples=[["demo/input_imgs/cat.jpeg"], ["demo/input_imgs/dog.jpeg"]],
+ )
+
+ demo.launch()
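
Note that predict() returns only the arg-max class name, while the output component is gr.Label(num_top_classes=5); gr.Label also accepts a {label: confidence} dict, in which case it displays a score for every class. A minimal sketch of such a variant is below; it reuses the `model` and `labels` objects defined in app.py, and `predict_with_scores` is a hypothetical name, not part of this commit.

# Hedged sketch (not in this commit): return per-class probabilities so that
# gr.Label(num_top_classes=5) can show a confidence for all five classes.
# Reuses `model` and `labels` from app.py; `predict_with_scores` is hypothetical.
def predict_with_scores(inp):
    # Same preprocessing as predict(): PIL image -> float tensor with a batch dim
    inp = torchvision.transforms.ToTensor()(inp).unsqueeze(0)
    with torch.no_grad():
        logits = model(inp)                      # shape: (1, 5)
        probs = torch.softmax(logits, dim=1)[0]  # shape: (5,)
    # gr.Label accepts a {label: confidence} dict and sorts it for display
    return {label: float(p) for label, p in zip(labels, probs)}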
input_imgs/.DS_Store ADDED
Binary file (6.15 kB).
input_imgs/bird.jpeg ADDED
input_imgs/cat.jpeg ADDED
input_imgs/dog.jpeg ADDED
input_imgs/horse.jpeg ADDED
input_imgs/sheep.jpeg ADDED
model/final_model_state_dict.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:1bda7d53e1240258ec39a6ed2270516f4144a5813cd4dabdc324a877c1f3dea9
+ size 94392002
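
The weights themselves are not stored in the Git history; the repository versions only this three-line Git LFS pointer (spec version, SHA-256 of the content, and size in bytes, here roughly 94 MB), and the actual .pth payload is fetched from LFS storage when the repository is checked out with git-lfs installed.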
requirements.txt ADDED
@@ -0,0 +1,8 @@
+ pillow==10.3.0
+ pycocotools==2.0.7
+ scikit-learn==1.4.2
+ tensorboard==2.16.2
+ torch==2.3.0
+ torch-tb-profiler==0.4.3
+ torchvision==0.18.0
+ gradio==4.32.2
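
The pins cover both the demo runtime (torch 2.3.0, torchvision 0.18.0, gradio 4.32.2, pillow 10.3.0) and what appear to be training-side utilities (pycocotools, scikit-learn, tensorboard, torch-tb-profiler); presumably installing them with pip install -r requirements.txt and then running python app.py reproduces the Gradio demo locally.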