|
|
import os |
|
|
import numpy as np |
|
|
import tensorflow as tf |
|
|
from huggingface_hub import hf_hub_download |
|
|
from nets import get_model_from_name |
|
|
from utils.utils import cvtColor, get_classes, letterbox_image, preprocess_input |
|
|
import tempfile |
|
|
|
|
|
class Classification:
    """Image classifier backed by pretrained weights from the Hugging Face Hub.

    On construction the class downloads (or reuses from a temp-dir cache) the
    checkpoint matching ``model_choice``. It reads the class names from
    ``src/cls_classes.txt``, builds the network and loads the weights.
    """

    # Architecture key that takes a width multiplier (``alpha``) and has its
    # own checkpoint file in the weights repository.
    _MOBILENET = "mobilenet"

    def __init__(self, model_choice):
        """Download weights for ``model_choice`` and build the model.

        Args:
            model_choice: key into ``get_model_from_name``. ``"mobilenet"``
                selects the ``ep097`` checkpoint; any other value selects
                ``ep089``. The key is only looked up (in ``load_model``)
                after the weights have been downloaded.
        """
        self.model_choice = model_choice
        self.classes_path = "src/cls_classes.txt"
        self.input_shape = (224, 224)
        # Width multiplier; only passed to the MobileNet constructor.
        self.alpha = 0.25

        cache_dir = os.path.join(tempfile.gettempdir(), "hf_cache")
        os.makedirs(cache_dir, exist_ok=True)
        self.model_path = hf_hub_download(
            repo_id="sudo-paras-shah/micro-expression-casme2",
            filename=self._weights_filename(self.model_choice),
            cache_dir=cache_dir,
        )

        self.class_names, self.num_classes = get_classes(self.classes_path)
        self.load_model()

    @staticmethod
    def _weights_filename(model_choice):
        """Return the checkpoint filename in the weights repo for ``model_choice``."""
        # Compare by value, not identity: an `is` check against a literal only
        # works for interned strings and could silently pick the wrong
        # checkpoint for a runtime-built "mobilenet" string.
        if model_choice == Classification._MOBILENET:
            return "ep097.weights.h5"
        return "ep089.weights.h5"

    def load_model(self):
        """Instantiate the selected architecture and load the downloaded weights.

        Sets ``self.model``. The constructor is looked up as
        ``get_model_from_name[self.model_choice]``. Only the MobileNet
        constructor also receives ``alpha``.
        """
        model_kwargs = {
            "input_shape": [self.input_shape[0], self.input_shape[1], 3],
            "classes": self.num_classes,
        }
        if self.model_choice == self._MOBILENET:
            model_kwargs["alpha"] = self.alpha
        self.model = get_model_from_name[self.model_choice](**model_kwargs)

        self.model.load_weights(self.model_path)
        print("Model loaded from", self.model_path)
        print("Classes:", self.class_names)

    def detect_image(self, image):
        """Classify a single image and return the top-scoring class.

        Args:
            image: an image accepted by ``cvtColor``/``letterbox_image``
                (presumably a PIL image — TODO confirm against utils.utils).

        Returns:
            tuple: ``(class_name, probability)``. ``class_name`` is the entry
            of ``self.class_names`` at the argmax of the model output, and
            ``probability`` is the model's score at that index (a probability
            only if the network head ends in a softmax — confirm).
        """
        image = cvtColor(image)
        # Indices are swapped relative to input_shape, presumably because
        # letterbox_image expects (width, height) — confirm in utils.utils.
        image = letterbox_image(image, [self.input_shape[1], self.input_shape[0]])
        image = np.array(image, dtype=np.float32)
        image = preprocess_input(image)
        # Add a leading batch axis of size 1 for model.predict.
        image = np.expand_dims(image, axis=0)

        # Take the output row for the single batch item.
        preds = self.model.predict(image)[0]
        class_index = np.argmax(preds)
        class_name = self.class_names[class_index]
        probability = preds[class_index]
        return class_name, probability
|
|
|