# MushroomClassificationCNN / src/streamlit_app.py
# Last commit by EnYa32: "Update src/streamlit_app.py" (a798df7, verified)
import os
import json
import numpy as np
import streamlit as st
from PIL import Image
import tensorflow as tf
# --------------------------------------------------
# Page config
# --------------------------------------------------
# Browser-tab title/icon and a centered page layout.
st.set_page_config(
    layout="centered",
    page_icon="🧠",
    page_title="ResNet50 Image Predictor",
)

# App header plus a one-line description of what the app does.
st.title("🧠 ResNet50 Image Predictor")
st.write("Classifies mushroom images using a trained ResNet50 model (architecture in code + weights from src/).")
# --------------------------------------------------
# Paths (resolved relative to this file, so they work from any CWD)
# --------------------------------------------------
# This script lives in src/, next to the weights and class-name files.
# Anchoring on __file__ (instead of "src/...") keeps loading working even
# when the app is not launched from the repository root.
_SRC_DIR = os.path.dirname(os.path.abspath(__file__))
MODEL_WEIGHTS_PATH = os.path.join(_SRC_DIR, "resnet50_weights_noBN3.h5")
CLASS_NAMES_PATH = os.path.join(_SRC_DIR, "class_names4.json")
# Model input resolution passed to PIL's resize().
IMG_SIZE = (224, 224)
# --------------------------------------------------
# ✅ FIXED CLASS NAMES (YOUR TABLE) — fallback if JSON is wrong
# --------------------------------------------------
# Position i is the name displayed for model output index i.
CLASS_NAMES_TABLE = [
    "amanita",      # 0
    "boletus",      # 1
    "chantelle",    # 2
    "deterrimus",   # 3
    "rufus",        # 4
    "torminosus",   # 5
    "aurantiacum",  # 6
    "procera",      # 7
    "involutus",    # 8
    "russula",      # 9
]
# --------------------------------------------------
# Helpers
# --------------------------------------------------
def _safe_listdir(path: str):
try:
return sorted(os.listdir(path))
except Exception as e:
return f"Could not list dir '{path}': {e}"
def _is_index_like(value) -> bool:
    """True for a non-negative int or a string of decimal digits (e.g. 3 or "3")."""
    if isinstance(value, bool):
        return False
    if isinstance(value, int):
        return value >= 0
    # isdecimal (not isdigit) so that int() below can always parse it.
    return isinstance(value, str) and value.strip().isdecimal()


def _ordered_names_from_mapping(mapping: dict):
    """
    Order a name->index or index->name mapping by index.

    Returns None when the mapping is empty, fits neither shape, or its indices
    are not exactly 0..n-1 (duplicates or gaps make the order ambiguous).
    """
    if mapping and all(_is_index_like(v) for v in mapping.values()):
        by_index = {int(v): k for k, v in mapping.items()}  # {"amanita": 0, ...}
    elif mapping and all(_is_index_like(k) for k in mapping):
        by_index = {int(k): v for k, v in mapping.items()}  # {"0": "amanita", ...}
    else:
        return None

    if sorted(by_index) != list(range(len(mapping))):
        return None
    return [by_index[i] for i in range(len(by_index))]


def load_class_names(path: str) -> list:
    """
    Load the ordered class names from a JSON file.

    Accepted JSON shapes:
      * a list of names:       ["amanita", "boletus", ...]
      * a name->index mapping: {"amanita": 0, "boletus": 1, ...}
      * an index->name mapping: {"0": "amanita", "1": "boletus", ...}

    Falls back to a copy of CLASS_NAMES_TABLE when the file is missing or
    unreadable, is not valid JSON, has any other shape, contains non-string
    names, or only holds placeholder labels such as ["0", "1", "2", ...].

    Args:
        path: Path to the JSON file.

    Returns:
        A list where position i is the class name for model output index i.
    """
    try:
        with open(path, "r", encoding="utf-8") as f:
            raw = json.load(f)
    except (OSError, ValueError):
        # Missing/unreadable file, invalid JSON or bad encoding
        # (json.JSONDecodeError and UnicodeDecodeError are ValueErrors).
        return list(CLASS_NAMES_TABLE)

    # Dicts must be converted before validating, since a dict is not a list.
    if isinstance(raw, dict):
        names = _ordered_names_from_mapping(raw)
    elif isinstance(raw, list):
        names = raw
    else:
        names = None

    if not names or not all(isinstance(x, str) for x in names):
        return list(CLASS_NAMES_TABLE)

    # Only numbers as strings (e.g. ["0", "1", ...]) means the real names were
    # lost when the file was written -> use the built-in table instead.
    if all(x.strip().isdigit() for x in names):
        return list(CLASS_NAMES_TABLE)

    return names
def build_resnet50_classifier(num_classes: int) -> tf.keras.Model:
    """
    Build the classifier architecture that the trained weights are loaded into.

    Layout: ImageNet-initialised ResNet50 backbone (no top, frozen) ->
    global average pooling -> Dense(256, relu) -> Dropout(0.5) ->
    Dense(num_classes, softmax).

    Args:
        num_classes: Number of softmax output units (one per class name).

    Returns:
        A compiled tf.keras.Sequential model whose head is freshly initialised;
        trained weights are expected to be loaded afterwards.
    """
    height, width = IMG_SIZE
    backbone = tf.keras.applications.ResNet50(
        include_top=False,
        weights="imagenet",
        input_shape=(height, width, 3),
    )
    # Freeze the backbone so only the head layers are trainable.
    backbone.trainable = False

    head_layers = [
        tf.keras.layers.GlobalAveragePooling2D(),
        tf.keras.layers.Dense(256, activation="relu"),
        tf.keras.layers.Dropout(0.5),
        tf.keras.layers.Dense(num_classes, activation="softmax"),
    ]
    classifier = tf.keras.Sequential([backbone, *head_layers])

    # predict() does not need a compiled model; compiling is harmless here.
    classifier.compile(
        optimizer="adam",
        loss="sparse_categorical_crossentropy",
        metrics=["accuracy"],
    )
    return classifier
@st.cache_resource
def load_model_and_assets(weights_path: str, class_names_path: str):
    """
    Build the classifier and load its trained weights (cached by Streamlit).

    Args:
        weights_path: Path to the Keras weights file.
        class_names_path: Path to the class-names JSON (see load_class_names).

    Returns:
        ``(model, class_names)`` where ``class_names[i]`` names output index i.

    Raises:
        FileNotFoundError: If ``weights_path`` does not exist.
        ValueError: If Keras rejects the weights for a model with
            ``len(class_names)`` outputs; re-raised with the paths involved.
    """
    # Check first: building ResNet50 with weights="imagenet" may download the
    # backbone weights, which is wasted work if the trained weights are missing.
    if not os.path.exists(weights_path):
        raise FileNotFoundError(f"Missing weights file: {weights_path}")

    class_names = load_class_names(class_names_path)
    model = build_resnet50_classifier(num_classes=len(class_names))
    try:
        model.load_weights(weights_path)
    except ValueError as err:
        # Most often the class list length does not match the trained head.
        raise ValueError(
            f"Could not load weights from '{weights_path}' into a model with "
            f"{len(class_names)} output classes (class names from "
            f"'{class_names_path}' or the built-in fallback table): {err}"
        ) from err
    return model, class_names
def preprocess_image(pil_img: Image.Image) -> np.ndarray:
    """
    Turn an uploaded PIL image into a single-image model input batch.

    The image is converted to RGB, resized to IMG_SIZE, cast to float32, given
    a leading batch axis of size 1, and then normalised with the ResNet50
    ``preprocess_input`` function from tf.keras.

    Args:
        pil_img: Any PIL image (mode is converted to RGB).

    Returns:
        A float array of shape (1, height, width, 3) ready for model.predict().
    """
    resized = pil_img.convert("RGB").resize(IMG_SIZE)
    batch = np.array(resized, dtype=np.float32)[np.newaxis, ...]
    return tf.keras.applications.resnet50.preprocess_input(batch)
# --------------------------------------------------
# Debug info (HF)
# --------------------------------------------------
with st.expander("🔍 Debug info (HuggingFace check)"):
    st.write("Files in repo root:", _safe_listdir("."))
    # List the directory the weights are actually read from ("." if the path
    # has no directory part), so this stays accurate however paths are set.
    st.write("Files in src/:", _safe_listdir(os.path.dirname(MODEL_WEIGHTS_PATH) or "."))
    st.write("Weights exists:", os.path.exists(MODEL_WEIGHTS_PATH), "->", MODEL_WEIGHTS_PATH)
    st.write("Class names exists:", os.path.exists(CLASS_NAMES_PATH), "->", CLASS_NAMES_PATH)
    st.write("TensorFlow version:", tf.__version__)
# --------------------------------------------------
# Load model + assets
# --------------------------------------------------
try:
    model, class_names = load_model_and_assets(MODEL_WEIGHTS_PATH, CLASS_NAMES_PATH)
except Exception as e:
    # UI boundary: show any load failure in the app and halt this run,
    # since nothing below can work without the model.
    st.error("❌ Model could not be loaded.")
    st.exception(e)
    st.stop()
else:
    st.success("✅ Model + weights loaded successfully!")
# --------------------------------------------------
# Show available classes (INDEX + NAME)
# --------------------------------------------------
with st.expander("🧪 Available mushroom classes (you can test these)"):
    st.write(f"Total classes: **{len(class_names)}**")
    for i, name in enumerate(class_names):
        st.write(f"**{i}** — **{name}**")
# --------------------------------------------------
# Image upload + prediction
# --------------------------------------------------
def _class_label(idx: int) -> str:
    """Return the class name for model output index ``idx``, or "Class <idx>" if out of range."""
    return class_names[idx] if 0 <= idx < len(class_names) else f"Class {idx}"


uploaded_file = st.file_uploader(
    "Upload a mushroom image",
    type=["jpg", "jpeg", "png", "webp"],
)

if uploaded_file is None:
    st.info("👆 Please upload an image to start prediction.")
else:
    try:
        img = Image.open(uploaded_file)
        # Image.open() is lazy; pixel data is decoded when preprocess_image()
        # processes it, so unreadable, truncated or oversized uploads fail
        # here and are reported instead of crashing the script.
        x = preprocess_image(img)
    except (OSError, Image.DecompressionBombError) as err:
        # PIL.UnidentifiedImageError (file is not an image) subclasses OSError.
        st.error(f"❌ Could not read the uploaded file as an image: {err}")
    else:
        st.image(img, caption="Uploaded image", use_container_width=True)

        preds = model.predict(x, verbose=0)[0]
        pred_idx = int(np.argmax(preds))
        pred_conf = float(preds[pred_idx])
        pred_name = _class_label(pred_idx)

        st.subheader("✅ Prediction")
        st.write(f"**Predicted class index:** {pred_idx}")
        st.write(f"**Predicted class name:** {pred_name}")
        st.write(f"**Confidence:** {pred_conf:.4f}")

        st.subheader("🏁 Top-3 predictions")
        # Indices of the (up to) three highest probabilities, best first.
        top3_idx = np.argsort(preds)[::-1][:3]
        for rank, idx in enumerate(top3_idx, start=1):
            idx = int(idx)
            prob = float(preds[idx])
            st.write(f"{rank}. **{_class_label(idx)}** (class {idx}) — **{prob*100:.2f}%**")