# micro-expression-recognition / src/classification.py
# Author: sudo-paras-shah
# Last commit: ec56169 ("Add async processing")
import os
import numpy as np
import tensorflow as tf
from huggingface_hub import hf_hub_download
from nets import get_model_from_name
from utils.utils import cvtColor, get_classes, letterbox_image, preprocess_input
import tempfile
class Classification:
    """Image classifier that wraps a pretrained Keras backbone.

    On construction, this class:

    * downloads a weights file from the Hugging Face Hub repository
      ``sudo-paras-shah/micro-expression-casme2`` into ``<tempdir>/hf_cache``;
    * reads class names from ``src/cls_classes.txt`` (a path relative to the
      current working directory);
    * builds the network named by ``model_choice`` and loads the weights.
    """

    def __init__(self, model_choice):
        """Download weights, read class names and build the model.

        Args:
            model_choice: Key into ``get_model_from_name``. ``"mobilenet"``
                selects the ``ep097`` weights and receives ``alpha``. Any
                other value uses the ``ep089`` weights.
        """
        self.model_choice = model_choice
        self.classes_path = "src/cls_classes.txt"
        self.input_shape = (224, 224)
        # Forwarded as ``alpha`` to the mobilenet builder only (presumably
        # MobileNet's width multiplier; TODO confirm in nets).
        self.alpha = 0.25

        cache_dir = os.path.join(tempfile.gettempdir(), "hf_cache")
        os.makedirs(cache_dir, exist_ok=True)
        self.model_path = hf_hub_download(
            repo_id="sudo-paras-shah/micro-expression-casme2",
            filename=self._weights_filename(self.model_choice),
            cache_dir=cache_dir,
        )

        # Load class names and model.
        self.class_names, self.num_classes = get_classes(self.classes_path)
        self.load_model()

    @staticmethod
    def _weights_filename(model_choice):
        """Return the Hub weights filename for ``model_choice``."""
        # Compare by value: ``is`` tests object identity. Equal strings
        # (e.g. ones built at runtime or received from a UI) are not
        # guaranteed to be the same object, so ``is`` could pick the
        # wrong checkpoint.
        if model_choice == "mobilenet":
            return "ep097.weights.h5"
        return "ep089.weights.h5"

    def load_model(self):
        """Build the selected network and load ``self.model_path`` into it."""
        kwargs = {
            "input_shape": [self.input_shape[0], self.input_shape[1], 3],
            "classes": self.num_classes,
        }
        # Only the mobilenet builder takes an ``alpha`` argument.
        if self.model_choice == "mobilenet":
            kwargs["alpha"] = self.alpha
        self.model = get_model_from_name[self.model_choice](**kwargs)
        self.model.load_weights(self.model_path)
        print("Model loaded from", self.model_path)
        print("Classes:", self.class_names)

    def detect_image(self, image):
        """Classify a single image.

        Args:
            image: Input accepted by ``cvtColor`` / ``letterbox_image``
                (presumably a PIL image; TODO confirm against utils).

        Returns:
            Tuple ``(class_name, probability)``. ``class_name`` is the entry
            of ``self.class_names`` at the arg-max of the model output.
            ``probability`` is the model's output value at that index
            (a probability only if the network ends in softmax; not
            verifiable from this file).
        """
        image = cvtColor(image)
        # letterbox_image receives (input_shape[1], input_shape[0]),
        # i.e. the order is swapped relative to self.input_shape.
        image = letterbox_image(image, [self.input_shape[1], self.input_shape[0]])
        image = np.array(image, dtype=np.float32)
        image = preprocess_input(image)
        image = np.expand_dims(image, axis=0)  # add a batch axis of size 1
        # verbose=0: suppress the per-call progress bar, which is pure noise
        # for single-image inference on a serving path.
        preds = self.model.predict(image, verbose=0)[0]
        class_index = np.argmax(preds)
        class_name = self.class_names[class_index]
        probability = preds[class_index]
        return class_name, probability