import os
import tempfile

import numpy as np
import tensorflow as tf
from huggingface_hub import hf_hub_download

from nets import get_model_from_name
from utils.utils import cvtColor, get_classes, letterbox_image, preprocess_input


class Classification:
    """Image classifier that downloads trained weights from the Hugging Face Hub.

    On construction it downloads the checkpoint that matches ``model_choice``,
    reads the class names from ``src/cls_classes.txt``, builds the network
    with ``get_model_from_name[model_choice]`` and loads the weights into it.
    """

    # Hugging Face Hub repository that holds the trained checkpoints.
    HF_REPO_ID = "sudo-paras-shah/micro-expression-casme2"
    # Checkpoint used for the "mobilenet" backbone.
    MOBILENET_WEIGHTS = "ep097.weights.h5"
    # Checkpoint used for every other backbone choice.
    DEFAULT_WEIGHTS = "ep089.weights.h5"

    def __init__(self, model_choice):
        """Download the weights and build the model for ``model_choice``.

        Args:
            model_choice: Key into ``get_model_from_name``. The value
                ``"mobilenet"`` selects the MobileNet checkpoint and passes
                ``alpha``; any other key uses the default checkpoint.
        """
        self.model_choice = model_choice
        self.classes_path = "src/cls_classes.txt"
        self.input_shape = (224, 224)  # (height, width) fed to the network
        self.alpha = 0.25  # width multiplier, only passed for "mobilenet"

        # Cache downloads under the system temp dir so no write access to the
        # working directory (or the default HF cache) is required.
        cache_dir = os.path.join(tempfile.gettempdir(), "hf_cache")
        os.makedirs(cache_dir, exist_ok=True)
        self.model_path = hf_hub_download(
            repo_id=self.HF_REPO_ID,
            filename=self._weights_filename(self.model_choice),
            cache_dir=cache_dir,
        )

        # Load class names and model
        self.class_names, self.num_classes = get_classes(self.classes_path)
        self.load_model()

    @classmethod
    def _weights_filename(cls, model_choice):
        """Return the checkpoint filename that matches ``model_choice``."""
        # Use equality, not identity: ``is`` on strings depends on interning
        # and silently picks the wrong checkpoint for equal-but-distinct strings.
        if model_choice == "mobilenet":
            return cls.MOBILENET_WEIGHTS
        return cls.DEFAULT_WEIGHTS

    def load_model(self):
        """Build the network for ``self.model_choice`` and load ``self.model_path``.

        Sets ``self.model``. Only the "mobilenet" constructor receives
        ``alpha``; other constructors are called with the input shape and
        class count alone.
        """
        model_kwargs = {
            "input_shape": [self.input_shape[0], self.input_shape[1], 3],
            "classes": self.num_classes,
        }
        if self.model_choice == "mobilenet":
            model_kwargs["alpha"] = self.alpha

        self.model = get_model_from_name[self.model_choice](**model_kwargs)
        self.model.load_weights(self.model_path)
        print("Model loaded from", self.model_path)
        print("Classes:", self.class_names)

    def detect_image(self, image):
        """Classify a single image.

        Args:
            image: Input image; presumably a PIL image, since it is passed
                through ``cvtColor`` and ``letterbox_image`` (TODO confirm).

        Returns:
            Tuple ``(class_name, probability)``. ``class_name`` is the entry
            of ``self.class_names`` at the argmax of the model output, and
            ``probability`` is the model's output value at that index.
        """
        image = cvtColor(image)
        # letterbox_image receives [width, height] (input_shape is (h, w)).
        image = letterbox_image(image, [self.input_shape[1], self.input_shape[0]])
        image = np.array(image, dtype=np.float32)
        image = preprocess_input(image)
        # Add the batch axis expected by Model.predict.
        image = np.expand_dims(image, axis=0)

        preds = self.model.predict(image)[0]
        class_index = np.argmax(preds)
        class_name = self.class_names[class_index]
        probability = preds[class_index]
        return class_name, probability