# transform-eats / app.py
# Hugging Face Space file header (author: sergio-sanz-rodriguez)
# Commit 2bf4af2: updated app with 101 + unknown trained classes
# File size: 6.38 kB
# Imports and class names setup
import os
import torch
import json
import gradio as gr
from model import create_vitbase_model, create_effnetb0
from timeit import default_timer as timer
from typing import Tuple, Dict
from torchvision.transforms import v2
# -----------------------------------------------------------------
# Load class labels and food descriptions from bundled data files
# -----------------------------------------------------------------
# One class label per line; the final entry is the "unknown" label.
with open("class_names.txt", "r") as names_file:
    class_names = names_file.read().splitlines()
# Number of real food classes; the trailing "unknown" label is excluded.
num_classes = len(class_names) - 1  # 101, "unknown" to be discarded
# Per-class description texts shown in the UI alongside predictions.
with open("food_descriptions.json", "r") as descriptions_file:
    food_descriptions = json.load(descriptions_file)
# Instantiate the binary food / not-food "gate" model (EfficientNet-B0).
# num_classes=2: output column 1 is used downstream as the "is food" score
# -- presumably column 0 is "not food"; confirm against the training code.
classification_model_name_path = "effnetb0_classif_epoch13.pth"
effnetb0_model = create_effnetb0(
    model_weights_dir=".",
    model_weights_name=classification_model_name_path,
    num_classes=2
)
# Load the ViT-Base/16 transformer with input image of 384x384 pixels and 101 + unknown classes
# NOTE(review): num_classes here is 101 (len(class_names) - 1) while the
# weights filename says "102" -- verify the checkpoint's head size matches.
vitbase_model = create_vitbase_model(
    model_weights_dir=".",
    model_weights_name="vitbase16_102_2025-01-03.pth",
    img_size=384,
    num_classes=num_classes,
    compile=True
)
# Manual preprocessing pipeline for the ViT model (model_2): resize the
# shorter side to 384, center-crop to 384x384, convert to a scaled float32
# tensor, then normalize with the standard ImageNet statistics.
imagenet_mean = [0.485, 0.456, 0.406]
imagenet_std = [0.229, 0.224, 0.225]
transforms = v2.Compose([
    v2.Resize(384),
    v2.CenterCrop((384, 384)),
    v2.ToImage(),
    v2.ToDtype(torch.float32, scale=True),
    v2.Normalize(mean=imagenet_mean, std=imagenet_std),
])
# Both models are used for inference only: disable training-mode layers
# such as dropout and batch-norm updates.
effnetb0_model.eval()
vitbase_model.eval()
# Predict function
def predict(image) -> Tuple[Dict, str, str]:
    """Transform an image, classify it, and return prediction, time and description.

    Pipeline: a binary EfficientNet-B0 gate first decides whether the
    picture is food at all; if so, the ViT-Base/16 classifier produces
    per-class probabilities. Low-confidence / high-entropy predictions
    are relabelled as "unknown".

    Args:
        image: Input image (PIL image from the Gradio upload widget).

    Returns:
        Tuple of:
            - dict mapping class name -> probability (floats),
            - prediction time as a formatted string (e.g. "0.3 s."),
            - description text for the top predicted class.
        On any failure returns ({}, "Error during prediction.", "N/A").
    """
    # Decision thresholds tuned offline; do not change without re-validation.
    food_gate_threshold = 0.9981166124343872  # min gate-model score to count as food
    unknown_prob_threshold = 0.5              # top prob at or below this may be "unknown"
    unknown_entropy_threshold = 2.6           # entropy above this may be "unknown"
    try:
        # Start the timer
        start_time = timer()
        # Transform the target image and add a batch dimension
        image = transforms(image).unsqueeze(0)
        with torch.inference_mode():
            # Gate: proceed with classification only if the picture is food.
            # NOTE(review): the raw model output column 1 (not a softmaxed
            # probability) is compared against the threshold -- presumably
            # the threshold was tuned on this raw score; confirm.
            if effnetb0_model(image)[:, 1].cpu().item() >= food_gate_threshold:
                # Turn the classifier logits into prediction probabilities.
                pred_probs = torch.softmax(vitbase_model(image), dim=1)
                # Shannon entropy of the distribution (high => uncertain).
                entropy = -torch.sum(pred_probs * torch.log(pred_probs), dim=1).item()
                # Per-class probability dictionary (plain floats for Gradio).
                pred_classes_and_probs = {class_names[i]: float(pred_probs[0][i]) for i in range(num_classes)}
                # Get the top predicted class
                top_class = max(pred_classes_and_probs, key=pred_classes_and_probs.get)
                # Low confidence combined with high entropy => likely an
                # unknown category.
                if pred_probs[0][class_names.index(top_class)] <= unknown_prob_threshold and entropy > unknown_entropy_threshold:
                    # Give "unknown" a slightly higher score than the best
                    # class, then renormalize so the values sum to 1.
                    # float() fixes the original bug of storing a torch
                    # tensor in the dict (which then tainted every value
                    # after renormalization).
                    pred_classes_and_probs["unknown"] = float(pred_probs.max()) * 1.25
                    prob_sum = sum(pred_classes_and_probs.values())
                    pred_classes_and_probs = {key: value / prob_sum for key, value in pred_classes_and_probs.items()}
                    # Get the top predicted class
                    top_class = "unknown"
            # Not food at all: assign all probability mass to "unknown".
            else:
                pred_classes_and_probs = {class_names[i]: 0.0 for i in range(num_classes)}
                pred_classes_and_probs["unknown"] = 1.0
                # Get the top predicted class
                top_class = "unknown"
        # Description of the winning class, with a fallback when missing.
        top_class_description = food_descriptions.get(top_class, "Description not available.")
        # Calculate the prediction time
        pred_time = f"{round(timer() - start_time, 1)} s."
        # Return the prediction dictionary and prediction time
        return pred_classes_and_probs, pred_time, top_class_description
    except Exception as e:
        # Boundary handler: report the failure to the UI instead of crashing
        # the whole app.
        print(f"[ERROR] {e}")
        return {}, "Error during prediction.", "N/A"
# ----------------------- Gradio app configuration -----------------------
# Page heading (emoji strip rendered below the title).
title = "Transform-Eats Large<br>πŸ₯ͺπŸ₯—πŸ₯£πŸ₯©πŸπŸ£πŸ°"
# Short blurb plus the dynamically generated list of supported food types
# (the trailing "unknown" class is omitted from the list).
description = f"""
A cutting-edge Vision Transformer (ViT) model to classify 101 delicious food types. Discover the power of AI in culinary recognition.
### Supported Food Types
{', '.join(class_names[:-1])}.
"""
# Footer credit line.
article = "Created by Sergio Sanz."

# Upload-only image widget (no webcam source).
image_input = gr.Image(
    type="pil",
    label="Upload Image",
    sources=['upload'],
    show_label=True,
    mirror_webcam=False,
)

# One clickable example per file found in the examples/ directory.
example_images = [[f"examples/{filename}"] for filename in os.listdir("examples")]

# Assemble the interface: prediction label (top 3), timing, and the
# food-description text box.
demo = gr.Interface(
    fn=predict,
    inputs=image_input,
    outputs=[
        gr.Label(num_top_classes=3, label="Prediction"),
        gr.Textbox(label="Prediction time:"),
        gr.Textbox(label="Food Description:"),
    ],
    examples=example_images,
    cache_examples=True,
    title=title,
    description=description,
    article=article,
    theme="ocean",
)

# Launch the demo!
demo.launch()