# Streamlit app that classifies mushroom photos with a ResNet50 model.
# The network architecture is rebuilt in code (build_resnet50_classifier) and the
# trained weights are loaded from src/. Class names come from a JSON file in src/,
# with a hard-coded fallback table (CLASS_NAMES_TABLE).

import os
import json

import numpy as np
import streamlit as st
from PIL import Image
import tensorflow as tf

# --------------------------------------------------
# Page config (must be the first Streamlit command)
# --------------------------------------------------
st.set_page_config(
    page_title="ResNet50 Image Predictor",
    page_icon="🧠",
    layout="centered",
)

st.title("🧠 ResNet50 Image Predictor")
st.write("Classifies mushroom images using a trained ResNet50 model (architecture in code + weights from src/).")

# --------------------------------------------------
# Paths (relative to the repo root, as on HuggingFace Spaces)
# --------------------------------------------------
MODEL_WEIGHTS_PATH = "src/resnet50_weights_noBN3.h5"
CLASS_NAMES_PATH = "src/class_names4.json"
IMG_SIZE = (224, 224)

# --------------------------------------------------
# ✅ Fallback class names, listed in model output-index order.
# Used whenever the JSON class-names file is missing or unusable.
# --------------------------------------------------
CLASS_NAMES_TABLE = [
    "amanita",      # 0
    "boletus",      # 1
    "chantelle",    # 2
    "deterrimus",   # 3
    "rufus",        # 4
    "torminosus",   # 5
    "aurantiacum",  # 6
    "procera",      # 7
    "involutus",    # 8
    "russula",      # 9
]


# --------------------------------------------------
# Helpers
# --------------------------------------------------
def _safe_listdir(path: str):
    """Return the sorted entries of *path*, or an error-message string on failure.

    Only used by the debug expander, so a listing failure is reported as text
    instead of crashing the app.
    """
    try:
        return sorted(os.listdir(path))
    except OSError as e:
        return f"Could not list dir '{path}': {e}"


def _looks_like_index(value) -> bool:
    """Return True for an int or a digit-only string (e.g. ``3`` or ``"3"``)."""
    if isinstance(value, bool):  # bool is an int subclass; JSON true/false is not an index
        return False
    if isinstance(value, int):
        return True
    return isinstance(value, str) and value.strip().isdigit()


def _ordered_names_from_mapping(mapping: dict):
    """Turn a name->index or index->name dict into a list ordered by index.

    Returns None when neither side consists of indices, or when the indices
    are not exactly 0..n-1 (duplicates or gaps).
    """
    if all(_looks_like_index(v) for v in mapping.values()):
        # {"amanita": 0, "boletus": 1, ...}
        idx_to_name = {int(v): k for k, v in mapping.items()}
    elif all(_looks_like_index(k) for k in mapping):
        # {"0": "amanita", "1": "boletus", ...}
        idx_to_name = {int(k): v for k, v in mapping.items()}
    else:
        return None

    if sorted(idx_to_name) != list(range(len(mapping))):
        return None
    return [idx_to_name[i] for i in range(len(idx_to_name))]


def load_class_names(path: str) -> list:
    """Load class names from JSON, falling back to CLASS_NAMES_TABLE.

    Accepted JSON shapes:
      * a list of names in index order, e.g. ["amanita", "boletus", ...]
      * a name->index dict, e.g. {"amanita": 0, "boletus": 1, ...}
      * an index->name dict, e.g. {"0": "amanita", "1": "boletus", ...}

    A copy of CLASS_NAMES_TABLE is returned instead when:
      * the file is missing or unreadable, or is not valid JSON;
      * the shape is unsupported or empty;
      * a dict's indices are not exactly 0..n-1;
      * any entry is not a string;
      * every entry is an index placeholder such as ["0", "1", "2", ...].
    """
    # Copy so callers cannot mutate the module-level table.
    fallback = list(CLASS_NAMES_TABLE)

    try:
        with open(path, "r", encoding="utf-8") as f:
            data = json.load(f)
    except (OSError, ValueError):  # missing/unreadable file, bad UTF-8, invalid JSON
        return fallback

    # Dicts must be checked before the list path: they need reordering by index.
    if isinstance(data, dict):
        names = _ordered_names_from_mapping(data)
    elif isinstance(data, list):
        names = data
    else:
        names = None

    if not names:
        return fallback
    if not all(isinstance(n, str) for n in names):
        return fallback
    # A list of pure indices (["0", "1", ...]) carries no real names.
    if all(_looks_like_index(n) for n in names):
        return fallback

    return names


def build_resnet50_classifier(num_classes: int) -> tf.keras.Model:
    """Rebuild the classifier architecture whose trained weights live in src/.

    The model is a frozen ResNet50 backbone (no top), global average pooling,
    Dense(256, relu), Dropout(0.5) and a softmax Dense(num_classes) head.
    The layer stack must match the one used when the weights file was saved.
    """
    # NOTE(review): ImageNet weights are fetched here, but load_model_and_assets
    # then calls model.load_weights(). If the .h5 includes the backbone weights,
    # weights=None may suffice — confirm before changing.
    base_model = tf.keras.applications.ResNet50(
        weights="imagenet",
        include_top=False,
        input_shape=(IMG_SIZE[0], IMG_SIZE[1], 3),
    )
    base_model.trainable = False

    model = tf.keras.Sequential([
        base_model,
        tf.keras.layers.GlobalAveragePooling2D(),
        tf.keras.layers.Dense(256, activation="relu"),
        tf.keras.layers.Dropout(0.5),
        tf.keras.layers.Dense(num_classes, activation="softmax"),
    ])

    # Compiling is not required for inference, but it is harmless.
    model.compile(optimizer="adam", loss="sparse_categorical_crossentropy", metrics=["accuracy"])
    return model


@st.cache_resource
def load_model_and_assets(weights_path: str, class_names_path: str):
    """Build the model, load its weights and return ``(model, class_names)``.

    Cached by Streamlit so the model is built once per process.

    Raises:
        FileNotFoundError: if *weights_path* does not exist.
    """
    # Fail fast, before building the model (which downloads ImageNet weights).
    if not os.path.exists(weights_path):
        raise FileNotFoundError(f"Missing weights file: {weights_path}")

    class_names = load_class_names(class_names_path)
    model = build_resnet50_classifier(num_classes=len(class_names))
    model.load_weights(weights_path)
    return model, class_names


def preprocess_image(pil_img: Image.Image) -> np.ndarray:
    """Convert a PIL image into a ResNet50-ready batch of shape (1, H, W, 3).

    The image is converted to RGB, resized to IMG_SIZE, given a batch axis and
    passed through ``tf.keras.applications.resnet50.preprocess_input``.
    """
    img = pil_img.convert("RGB").resize(IMG_SIZE)
    x = np.array(img, dtype=np.float32)
    x = np.expand_dims(x, axis=0)
    x = tf.keras.applications.resnet50.preprocess_input(x)
    return x


# --------------------------------------------------
# Debug info (helps diagnose missing files on HF Spaces)
# --------------------------------------------------
with st.expander("🔍 Debug info (HuggingFace check)"):
    st.write("Files in repo root:", _safe_listdir("."))
    st.write("Files in src/:", _safe_listdir("src"))
    st.write("Weights exists:", os.path.exists(MODEL_WEIGHTS_PATH), "->", MODEL_WEIGHTS_PATH)
    st.write("Class names exists:", os.path.exists(CLASS_NAMES_PATH), "->", CLASS_NAMES_PATH)
    st.write("TensorFlow version:", tf.__version__)

# --------------------------------------------------
# Load model + assets (top-level boundary: show any failure and stop the app)
# --------------------------------------------------
try:
    model, class_names = load_model_and_assets(MODEL_WEIGHTS_PATH, CLASS_NAMES_PATH)
    st.success("✅ Model + weights loaded successfully!")
except Exception as e:
    st.error("❌ Model could not be loaded.")
    st.exception(e)
    st.stop()

# --------------------------------------------------
# Show available classes (INDEX + NAME)
# --------------------------------------------------
with st.expander("🧪 Available mushroom classes (you can test these)"):
    st.write(f"Total classes: **{len(class_names)}**")
    for i, name in enumerate(class_names):
        st.write(f"**{i}** — **{name}**")

# --------------------------------------------------
# Image upload + prediction
# --------------------------------------------------
uploaded_file = st.file_uploader(
    "Upload a mushroom image",
    type=["jpg", "jpeg", "png", "webp"],
)

if uploaded_file is None:
    st.info("👆 Please upload an image to start prediction.")
else:
    try:
        img = Image.open(uploaded_file)
        # Image.open is lazy; decode now so a corrupt file fails here with a clear message.
        img.load()
    except OSError as e:  # PIL.UnidentifiedImageError subclasses OSError
        st.error("❌ Could not read the uploaded file as an image.")
        st.exception(e)
        st.stop()

    st.image(img, caption="Uploaded image", use_container_width=True)

    x = preprocess_image(img)
    preds = model.predict(x, verbose=0)[0]

    pred_idx = int(np.argmax(preds))
    pred_conf = float(preds[pred_idx])
    pred_name = class_names[pred_idx] if 0 <= pred_idx < len(class_names) else f"Class {pred_idx}"

    st.subheader("✅ Prediction")
    st.write(f"**Predicted class index:** {pred_idx}")
    st.write(f"**Predicted class name:** {pred_name}")
    st.write(f"**Confidence:** {pred_conf:.4f}")

    st.subheader("🏁 Top-3 predictions")
    top3_idx = np.argsort(preds)[::-1][:3]
    for rank, idx in enumerate(top3_idx, start=1):
        idx = int(idx)
        name = class_names[idx] if 0 <= idx < len(class_names) else f"Class {idx}"
        prob = float(preds[idx])
        st.write(f"{rank}. **{name}** (class {idx}) — **{prob*100:.2f}%**")