sergio-sanz-rodriguez commited on
Commit
bbc7a9d
·
1 Parent(s): 63cd96e
Files changed (2) hide show
  1. app.py +4 -4
  2. model.py +1 -1
app.py CHANGED
@@ -16,7 +16,7 @@ with open(food_vision_class_names_path, "r") as f:
16
  class_names = f.read().splitlines()
17
 
18
  # Specify number of classes
19
- num_classes = len(class_names) - 1 # "unknown class to be removed to properly load the model"
20
 
21
  # Load the food description file
22
  food_descriptions_json = "food_descriptions.json"
@@ -28,7 +28,7 @@ classification_model_name_path = "effnetb0_classif_epoch13.pth"
28
  effnetb0_model, _ = create_effnetb0(
29
  model_weights_dir=".",
30
  model_weights_name=classification_model_name_path,
31
- num_classes=num_classes
32
  )
33
 
34
  # Load the ViT-Base transformer
@@ -38,7 +38,7 @@ vitbase_model = create_vitbase_model(
38
  model_weights_dir=".",
39
  model_weights_name=food_vision_model_name_path,
40
  img_size=IMG_SIZE,
41
- num_classes=num_classes
42
  )
43
 
44
  # Specify manual transforms
@@ -76,7 +76,7 @@ def predict(img) -> Tuple[Dict, float]:
76
  else:
77
 
78
  # Set all probabilites to zero
79
- pred_probs = torch.tensor([[0.0] * num_classes])
80
 
81
  # Calculate entropy
82
  entropy = -torch.sum(pred_probs * torch.log(pred_probs), dim=1).item()
 
16
  class_names = f.read().splitlines()
17
 
18
  # Specify number of classes
19
+ NUM_CLASSES = len(class_names) - 1 # "unknown class to be removed to properly load the model"
20
 
21
  # Load the food description file
22
  food_descriptions_json = "food_descriptions.json"
 
28
  effnetb0_model, _ = create_effnetb0(
29
  model_weights_dir=".",
30
  model_weights_name=classification_model_name_path,
31
+ num_classes=2
32
  )
33
 
34
  # Load the ViT-Base transformer
 
38
  model_weights_dir=".",
39
  model_weights_name=food_vision_model_name_path,
40
  img_size=IMG_SIZE,
41
+ num_classes=NUM_CLASSES
42
  )
43
 
44
  # Specify manual transforms
 
76
  else:
77
 
78
  # Set all probabilites to zero
79
+ pred_probs = torch.tensor([[0.0] * NUM_CLASSES])
80
 
81
  # Calculate entropy
82
  entropy = -torch.sum(pred_probs * torch.log(pred_probs), dim=1).item()
model.py CHANGED
@@ -83,7 +83,7 @@ def create_vitbase_model(
83
  def create_effnetb0(
84
  model_weights_dir: Path,
85
  model_weights_name: str,
86
- num_classes: int=101,
87
  dropout: float=0.2
88
  ):
89
  """Creates an EfficientNetB0 feature extractor model and transforms.
 
83
  def create_effnetb0(
84
  model_weights_dir: Path,
85
  model_weights_name: str,
86
+ num_classes: int=2,
87
  dropout: float=0.2
88
  ):
89
  """Creates an EfficientNetB0 feature extractor model and transforms.