Spaces:
Sleeping
Sleeping
| import os | |
| import json | |
| import numpy as np | |
| import streamlit as st | |
| from PIL import Image | |
| import tensorflow as tf | |
# --------------------------------------------------
# Page config
# --------------------------------------------------
# Kept as the very first Streamlit call (older Streamlit versions require it).
st.set_page_config(
    page_title="ResNet50 Image Predictor",
    page_icon="🧠",
    layout="centered",
)
st.title("🧠 ResNet50 Image Predictor")
st.write("Classifies mushroom images using a trained ResNet50 model (architecture in code + weights from src/).")
# --------------------------------------------------
# Asset paths (relative; assumes the app is started from the repo root)
# --------------------------------------------------
MODEL_WEIGHTS_PATH = "src/resnet50_weights_noBN3.h5"
CLASS_NAMES_PATH = "src/class_names4.json"
# Square model input. It is used both as Keras (height, width) and as PIL's
# resize() (width, height), so it only works unambiguously while it is square.
IMG_SIZE = (224, 224)

# --------------------------------------------------
# Fallback class names, used when the JSON label file is missing or unusable.
# List position == model output index:
#   0 amanita      1 boletus       2 chantelle    3 deterrimus   4 rufus
#   5 torminosus   6 aurantiacum   7 procera      8 involutus    9 russula
# NOTE(review): "chantelle" may be a misspelling of "chanterelle"; the label is
# shown to users as-is, so confirm against the training labels before renaming.
# --------------------------------------------------
CLASS_NAMES_TABLE = (
    "amanita boletus chantelle deterrimus rufus "
    "torminosus aurantiacum procera involutus russula"
).split()
# --------------------------------------------------
# Helpers
# --------------------------------------------------
def _safe_listdir(path: str):
    """Return ``sorted(os.listdir(path))``, or a human-readable error string.

    Used by the debug panel, so any failure is reported as text instead of
    raising.
    """
    try:
        entries = os.listdir(path)
    except Exception as exc:
        return f"Could not list dir '{path}': {exc}"
    return sorted(entries)
def _names_from_mapping(mapping: dict):
    """Order a JSON object of class labels by index; return None if it is malformed.

    Accepts ``{"amanita": 0, "boletus": 1, ...}`` (name -> index) and, failing
    that, ``{"0": "amanita", "1": "boletus", ...}`` (index -> name; JSON object
    keys are always strings).
    """
    try:
        by_index = {int(v): k for k, v in mapping.items()}
    except (TypeError, ValueError, OverflowError):
        try:
            by_index = {int(k): v for k, v in mapping.items()}
        except (TypeError, ValueError, OverflowError):
            return None
    # Indices must be unique and exactly 0..n-1, otherwise the order is ambiguous.
    if len(by_index) != len(mapping) or sorted(by_index) != list(range(len(by_index))):
        return None
    return [by_index[i] for i in range(len(by_index))]


def _looks_like_class_names(names) -> bool:
    """Return True if *names* is a non-empty list of non-blank, not-all-numeric strings."""
    if not isinstance(names, list) or not names:
        return False
    if not all(isinstance(name, str) and name.strip() for name in names):
        return False
    # e.g. ["0", "1", "2", ...]: the label indices were saved instead of the names.
    return not all(name.strip().isdigit() for name in names)


def load_class_names(path: str, fallback: "list | None" = None) -> list:
    """Load the model's class names, in output-index order, from a JSON file.

    Accepted JSON shapes:
      * a list of names: ``["amanita", "boletus", ...]``
      * a name -> index object: ``{"amanita": 0, "boletus": 1, ...}``
      * an index -> name object: ``{"0": "amanita", "1": "boletus", ...}``

    Falls back to a copy of *fallback* (default: CLASS_NAMES_TABLE) when the
    file is missing or unreadable, is not valid JSON, or does not contain
    usable names (empty, non-string entries, non-contiguous indices, or only
    numbers such as ``["0", "1", ...]``).

    Args:
        path: Path to the JSON label file.
        fallback: Names to use when the file is unusable; ``None`` means
            CLASS_NAMES_TABLE.

    Returns:
        A new list of class names; position i is the label for model output i.
    """
    # Copy so callers can never mutate the shared module-level fallback table.
    default = list(CLASS_NAMES_TABLE if fallback is None else fallback)
    try:
        with open(path, "r", encoding="utf-8") as f:
            data = json.load(f)
    except (OSError, ValueError):
        # OSError: missing/unreadable file; ValueError: invalid JSON or UTF-8.
        return default
    if isinstance(data, dict):
        # Convert mappings to an index-ordered list before the common validation.
        data = _names_from_mapping(data)
    return data if _looks_like_class_names(data) else default
def build_resnet50_classifier(num_classes: int) -> tf.keras.Model:
    """Rebuild the classifier architecture so saved weights can be loaded into it.

    A frozen ResNet50 backbone (no top, ImageNet-initialised) followed by a
    head of global average pooling, Dense(256, relu), Dropout(0.5) and a
    softmax Dense layer with *num_classes* outputs.
    """
    height, width = IMG_SIZE
    # NOTE(review): these ImageNet weights are presumably overwritten by the
    # later load_weights() call; if the saved file also contains the backbone,
    # weights=None would avoid the download — confirm before changing.
    backbone = tf.keras.applications.ResNet50(
        include_top=False,
        weights="imagenet",
        input_shape=(height, width, 3),
    )
    backbone.trainable = False

    # Layers are created in the same order as at training time, so their
    # auto-generated names and the weight-loading order stay the same.
    layers = tf.keras.layers
    head = [
        layers.GlobalAveragePooling2D(),
        layers.Dense(256, activation="relu"),
        layers.Dropout(0.5),
        layers.Dense(num_classes, activation="softmax"),
    ]
    classifier = tf.keras.Sequential([backbone, *head])

    # Compiling is not needed for predict(); kept for parity with training.
    classifier.compile(
        optimizer="adam",
        loss="sparse_categorical_crossentropy",
        metrics=["accuracy"],
    )
    return classifier
def load_model_and_assets(weights_path: str, class_names_path: str):
    """Build the classifier and load its trained weights.

    The weights file is checked first, so a missing file fails fast. That
    happens before the class names are read and before ResNet50 is built
    (build_resnet50_classifier uses ``weights="imagenet"``, which may download
    the backbone weights).

    Args:
        weights_path: Path to the saved Keras weights file.
        class_names_path: Path to the JSON label file (see load_class_names).

    Returns:
        ``(model, class_names)``; the model has ``len(class_names)`` outputs.

    Raises:
        FileNotFoundError: If *weights_path* does not exist.
    """
    if not os.path.exists(weights_path):
        raise FileNotFoundError(f"Missing weights file: {weights_path}")
    class_names = load_class_names(class_names_path)
    model = build_resnet50_classifier(num_classes=len(class_names))
    model.load_weights(weights_path)
    return model, class_names
def preprocess_image(pil_img: Image.Image) -> np.ndarray:
    """Turn a PIL image into a one-image float32 batch ready for the model.

    The image is converted to RGB and resized to IMG_SIZE. It then gets a
    leading batch axis and goes through Keras' ResNet50 ``preprocess_input``,
    the same function used for ResNet50 inputs.
    """
    resized = pil_img.convert("RGB").resize(IMG_SIZE)
    batch = np.array(resized, dtype=np.float32)[np.newaxis, ...]
    return tf.keras.applications.resnet50.preprocess_input(batch)
# --------------------------------------------------
# Debug info (HF)
# --------------------------------------------------
# Shows what the running app actually sees on disk, to diagnose missing assets.
with st.expander("🔍 Debug info (HuggingFace check)"):
    st.write("Files in repo root:", _safe_listdir("."))
    st.write("Files in src/:", _safe_listdir("src"))
    st.write("Weights exists:", os.path.exists(MODEL_WEIGHTS_PATH), "->", MODEL_WEIGHTS_PATH)
    st.write("Class names exists:", os.path.exists(CLASS_NAMES_PATH), "->", CLASS_NAMES_PATH)
    st.write("TensorFlow version:", tf.__version__)
# --------------------------------------------------
# Load model + assets
# --------------------------------------------------
# Streamlit re-executes this whole script on every widget interaction (e.g.
# each upload). Caching the loaded model as a shared resource means ResNet50 is
# built and the weights are read once per process instead of on every rerun.
# A call that raises is not cached, so a failed load is retried on the next run.
@st.cache_resource(show_spinner="Loading model...")
def _load_model_cached(weights_path: str, class_names_path: str):
    """Cached wrapper around load_model_and_assets (keyed on both paths)."""
    return load_model_and_assets(weights_path, class_names_path)


try:
    model, class_names = _load_model_cached(MODEL_WEIGHTS_PATH, CLASS_NAMES_PATH)
    st.success("✅ Model + weights loaded successfully!")
except Exception as e:
    # Top-level UI boundary: show the full error in the app and halt the script.
    st.error("❌ Model could not be loaded.")
    st.exception(e)
    st.stop()
# --------------------------------------------------
# Show available classes (INDEX + NAME)
# --------------------------------------------------
with st.expander("🧪 Available mushroom classes (you can test these)"):
    st.write(f"Total classes: **{len(class_names)}**")
    for i, name in enumerate(class_names):
        st.write(f"**{i}** → **{name}**")
# --------------------------------------------------
# Image upload + prediction
# --------------------------------------------------
def _class_label(idx: int) -> str:
    """Return the class name for model output *idx*, or "Class <idx>" if it has none."""
    return class_names[idx] if 0 <= idx < len(class_names) else f"Class {idx}"


uploaded_file = st.file_uploader(
    "Upload a mushroom image",
    type=["jpg", "jpeg", "png", "webp"],
)
if uploaded_file is None:
    st.info("👆 Please upload an image to start prediction.")
else:
    try:
        img = Image.open(uploaded_file)
        # Image.open() is lazy: decode now so corrupt or truncated files fail
        # here instead of deep inside preprocessing.
        img.load()
    except (OSError, Image.DecompressionBombError) as err:
        # The uploader filters by file extension only, so non-image content can
        # still arrive here (PIL's UnidentifiedImageError is an OSError subclass).
        st.error(f"❌ Could not read the uploaded file as an image: {err}")
    else:
        st.image(img, caption="Uploaded image", use_container_width=True)

        x = preprocess_image(img)
        # The input is a batch of one, so take the single row of probabilities.
        preds = model.predict(x, verbose=0)[0]

        pred_idx = int(np.argmax(preds))
        pred_conf = float(preds[pred_idx])
        pred_name = _class_label(pred_idx)

        st.subheader("✅ Prediction")
        st.write(f"**Predicted class index:** {pred_idx}")
        st.write(f"**Predicted class name:** {pred_name}")
        st.write(f"**Confidence:** {pred_conf:.4f}")

        st.subheader("📊 Top-3 predictions")
        # Indices of the (up to) three highest probabilities, best first.
        top3_idx = np.argsort(preds)[::-1][:3]
        for rank, idx in enumerate(top3_idx, start=1):
            idx = int(idx)
            prob = float(preds[idx])
            st.write(f"{rank}. **{_class_label(idx)}** (class {idx}) — **{prob*100:.2f}%**")