Commit
·
bbc7a9d
1
Parent(s):
63cd96e
bug fix
Browse files
app.py
CHANGED
|
@@ -16,7 +16,7 @@ with open(food_vision_class_names_path, "r") as f:
|
|
| 16 |
class_names = f.read().splitlines()
|
| 17 |
|
| 18 |
# Specify number of classes
|
| 19 |
-
|
| 20 |
|
| 21 |
# Load the food description file
|
| 22 |
food_descriptions_json = "food_descriptions.json"
|
|
@@ -28,7 +28,7 @@ classification_model_name_path = "effnetb0_classif_epoch13.pth"
|
|
| 28 |
effnetb0_model, _ = create_effnetb0(
|
| 29 |
model_weights_dir=".",
|
| 30 |
model_weights_name=classification_model_name_path,
|
| 31 |
-
num_classes=
|
| 32 |
)
|
| 33 |
|
| 34 |
# Load the ViT-Base transformer
|
|
@@ -38,7 +38,7 @@ vitbase_model = create_vitbase_model(
|
|
| 38 |
model_weights_dir=".",
|
| 39 |
model_weights_name=food_vision_model_name_path,
|
| 40 |
img_size=IMG_SIZE,
|
| 41 |
-
num_classes=
|
| 42 |
)
|
| 43 |
|
| 44 |
# Specify manual transforms
|
|
@@ -76,7 +76,7 @@ def predict(img) -> Tuple[Dict, float]:
|
|
| 76 |
else:
|
| 77 |
|
| 78 |
# Set all probabilites to zero
|
| 79 |
-
pred_probs = torch.tensor([[0.0] *
|
| 80 |
|
| 81 |
# Calculate entropy
|
| 82 |
entropy = -torch.sum(pred_probs * torch.log(pred_probs), dim=1).item()
|
|
|
|
| 16 |
class_names = f.read().splitlines()
|
| 17 |
|
| 18 |
# Specify number of classes
|
| 19 |
+
NUM_CLASSES = len(class_names) - 1 # "unknown class to be removed to properly load the model"
|
| 20 |
|
| 21 |
# Load the food description file
|
| 22 |
food_descriptions_json = "food_descriptions.json"
|
|
|
|
| 28 |
effnetb0_model, _ = create_effnetb0(
|
| 29 |
model_weights_dir=".",
|
| 30 |
model_weights_name=classification_model_name_path,
|
| 31 |
+
num_classes=2
|
| 32 |
)
|
| 33 |
|
| 34 |
# Load the ViT-Base transformer
|
|
|
|
| 38 |
model_weights_dir=".",
|
| 39 |
model_weights_name=food_vision_model_name_path,
|
| 40 |
img_size=IMG_SIZE,
|
| 41 |
+
num_classes=NUM_CLASSES
|
| 42 |
)
|
| 43 |
|
| 44 |
# Specify manual transforms
|
|
|
|
| 76 |
else:
|
| 77 |
|
| 78 |
# Set all probabilites to zero
|
| 79 |
+
pred_probs = torch.tensor([[0.0] * NUM_CLASSES])
|
| 80 |
|
| 81 |
# Calculate entropy
|
| 82 |
entropy = -torch.sum(pred_probs * torch.log(pred_probs), dim=1).item()
|
model.py
CHANGED
|
@@ -83,7 +83,7 @@ def create_vitbase_model(
|
|
| 83 |
def create_effnetb0(
|
| 84 |
model_weights_dir: Path,
|
| 85 |
model_weights_name: str,
|
| 86 |
-
num_classes: int=
|
| 87 |
dropout: float=0.2
|
| 88 |
):
|
| 89 |
"""Creates an EfficientNetB0 feature extractor model and transforms.
|
|
|
|
| 83 |
def create_effnetb0(
|
| 84 |
model_weights_dir: Path,
|
| 85 |
model_weights_name: str,
|
| 86 |
+
num_classes: int=2,
|
| 87 |
dropout: float=0.2
|
| 88 |
):
|
| 89 |
"""Creates an EfficientNetB0 feature extractor model and transforms.
|