|
|
|
|
|
|
|
|
import os |
|
|
import torch |
|
|
import json |
|
|
import gradio as gr |
|
|
from model import create_vitbase_model, create_effnetb0 |
|
|
from timeit import default_timer as timer |
|
|
from typing import Tuple, Dict |
|
|
from torchvision.transforms import v2 |
|
|
|
|
|
|
|
|
|
|
|
# Load the class labels, one label per line of the text file.
# The encoding is given explicitly so the result does not depend on the
# platform's default locale encoding.
food_vision_class_names_path = "class_names.txt"
with open(food_vision_class_names_path, "r", encoding="utf-8") as f:
    class_names = f.read().splitlines()

# One label fewer than the file contains. Elsewhere in this file the UI
# description lists class_names[:-1] and predict() adds its own "unknown" key,
# so the last line is presumably a catch-all label -- confirm against
# class_names.txt.
num_classes = len(class_names) - 1

# Load the mapping used by predict() to look up a description for the top
# label. JSON text is UTF-8, so decode it as such.
food_descriptions_json = "food_descriptions.json"
with open(food_descriptions_json, "r", encoding="utf-8") as f:
    food_descriptions = json.load(f)
|
|
|
|
|
|
|
|
# Two-output model used by predict() as a gate: the larger classifier below is
# only run when column 1 of this model's output reaches a threshold.
classification_model_name_path = "effnetb0_classif_epoch13.pth"
effnetb0_model = create_effnetb0(
    model_weights_name=classification_model_name_path,
    model_weights_dir=".",
    num_classes=2,
)

# Main classifier; its output size is num_classes and it is built for
# 384-pixel inputs, matching the preprocessing pipeline defined in this file.
vitbase_model = create_vitbase_model(
    model_weights_name="vitbase16_102_2025-01-03.pth",
    model_weights_dir=".",
    num_classes=num_classes,
    img_size=384,
    compile=True,
)
|
|
|
|
|
|
|
|
# Preprocessing applied to every input image before inference:
# resize, centre-crop to 384x384, convert to a float32 tensor scaled to [0, 1],
# then normalise per channel (these are the commonly used ImageNet statistics).
transforms = v2.Compose(
    [
        v2.Resize(size=384),
        v2.CenterCrop(size=(384, 384)),
        v2.ToImage(),
        v2.ToDtype(dtype=torch.float32, scale=True),
        v2.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

# Switch both models to evaluation mode once, at start-up.
effnetb0_model.eval()
vitbase_model.eval()
|
|
|
|
|
|
|
|
# Decision thresholds used by predict(). The values are the literals that were
# previously written inline.
_GATE_MIN_SCORE = 0.9981166124343872  # minimum gate-model score (column 1) to run the main classifier
_LOW_CONFIDENCE_MAX_PROB = 0.5  # a top probability at or below this counts as "not confident"
_HIGH_ENTROPY_MIN = 2.6  # entropy (in nats) above this counts as "uncertain"
_UNKNOWN_SCORE_FACTOR = 1.25  # "unknown" score = top probability * this factor, before renormalising


def predict(image) -> Tuple[Dict, str, str]:
    """Classify an image and return label scores, elapsed time and a description.

    The image is preprocessed with the module-level ``transforms`` and scored
    by ``effnetb0_model`` first. ``vitbase_model`` is only run when column 1 of
    that output reaches ``_GATE_MIN_SCORE``; otherwise every known label gets
    0.0 and "unknown" gets 1.0. When the main classifier runs but its top
    probability is at most ``_LOW_CONFIDENCE_MAX_PROB`` and the entropy of its
    distribution exceeds ``_HIGH_ENTROPY_MIN``, an "unknown" entry is added,
    the scores are renormalised to sum to 1, and "unknown" becomes the top label.

    Args:
        image: Any input accepted by ``transforms`` (a single image, not a batch).

    Returns:
        A tuple ``(scores, elapsed, description)`` where ``scores`` maps label
        to a plain Python float, ``elapsed`` is a string such as ``"0.3 s."``
        and ``description`` is the ``food_descriptions`` entry for the top
        label (or ``"Description not available."``).
        If anything raises, the error is printed and
        ``({}, "Error during prediction.", "N/A")`` is returned instead.
    """
    try:
        start_time = timer()

        # Preprocess and add the leading batch dimension.
        batch = transforms(image).unsqueeze(0)

        with torch.inference_mode():
            # Gate: single-image batch, so column 1 holds exactly one value.
            gate_score = effnetb0_model(batch)[:, 1].item()

            if gate_score >= _GATE_MIN_SCORE:
                logits = vitbase_model(batch)
                pred_probs = torch.softmax(logits, dim=1)

                # Take log-probabilities from log_softmax rather than
                # log(softmax): a probability that underflows to 0 would give
                # 0 * -inf = NaN, and a NaN entropy makes the comparison below
                # always False, silently disabling the "unknown" fallback.
                log_probs = torch.log_softmax(logits, dim=1)
                entropy = -torch.sum(pred_probs * log_probs, dim=1).item()

                # tolist() yields plain floats, so the dict never holds tensors.
                pred_classes_and_probs = dict(
                    zip(class_names[:num_classes], pred_probs[0].tolist())
                )

                top_class = max(pred_classes_and_probs, key=pred_classes_and_probs.get)
                top_prob = pred_classes_and_probs[top_class]

                if top_prob <= _LOW_CONFIDENCE_MAX_PROB and entropy > _HIGH_ENTROPY_MIN:
                    # Low confidence spread over many labels: add "unknown"
                    # with a score just above the best label, then renormalise.
                    pred_classes_and_probs["unknown"] = top_prob * _UNKNOWN_SCORE_FACTOR
                    prob_sum = sum(pred_classes_and_probs.values())
                    pred_classes_and_probs = {
                        label: score / prob_sum
                        for label, score in pred_classes_and_probs.items()
                    }
                    top_class = "unknown"
            else:
                # The gate rejected the image: report "unknown" with full score.
                pred_classes_and_probs = dict.fromkeys(class_names[:num_classes], 0.0)
                pred_classes_and_probs["unknown"] = 1.0
                top_class = "unknown"

        top_class_description = food_descriptions.get(top_class, "Description not available.")

        pred_time = f"{round(timer() - start_time, 1)} s."

        return pred_classes_and_probs, pred_time, top_class_description

    except Exception as e:
        # UI boundary: report the failure in the outputs instead of raising.
        print(f"[ERROR] {e}")
        return {}, "Error during prediction.", "N/A"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Page title (HTML is allowed, hence the <br>).
# NOTE(review): the characters after <br> look like emoji whose bytes were
# decoded with the wrong encoding. They are left untouched because the intended
# emoji cannot be recovered from this file -- restore them from the original source.
title = "Transform-Eats Large<br>π₯ͺπ₯π₯£π₯©ππ£π°"

# Markdown shown under the title; lists every label except the last one.
description = f"""
A cutting-edge Vision Transformer (ViT) model to classify 101 delicious food types. Discover the power of AI in culinary recognition.

### Supported Food Types
{', '.join(class_names[:-1])}.
"""

# Input widget: upload only, delivered to predict() as a PIL image.
upload_input = gr.Image(
    type="pil",
    label="Upload Image",
    sources=["upload"],
    show_label=True,
    mirror_webcam=False,
)

# One example row per file in the "examples" directory. os.listdir() returns
# entries in arbitrary order, so sort them to keep the example gallery (and any
# per-index example cache) stable between runs.
food_vision_examples = [
    ["examples/" + example] for example in sorted(os.listdir("examples"))
]

# Footer text.
article = "Created by Sergio Sanz."
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Wire predict() to the UI. Its three return values map, in order, onto the
# label widget (top 3 scores shown), the timing box and the description box.
demo = gr.Interface(
    fn=predict,
    inputs=upload_input,
    outputs=[
        gr.Label(num_top_classes=3, label="Prediction"),
        gr.Textbox(label="Prediction time:"),
        gr.Textbox(label="Food Description:"),
    ],
    examples=food_vision_examples,
    cache_examples=True,
    title=title,
    description=description,
    article=article,
    theme="ocean",
)

# Start the web app when this script runs.
demo.launch()
|
|
|