from pathlib import Path

import numpy as np
import streamlit as st
from PIL import Image, ImageOps
import tensorflow as tf
import matplotlib.pyplot as plt
import cv2

# Page config is issued before any other st.* call in this script.
st.set_page_config(page_title="MNIST Digit Recognizer", page_icon="✍️", layout="centered")

MODEL_PATH = Path(__file__).resolve().parent / "best_mnist_cnn.keras"
CLASS_NAMES = [str(i) for i in range(10)]


@st.cache_resource
def load_model():
    """Load the trained Keras model from MODEL_PATH.

    Wrapped in ``st.cache_resource`` so the model is not reloaded on every
    Streamlit rerun.

    Raises:
        FileNotFoundError: if the model file does not exist next to this script.
    """
    if not MODEL_PATH.exists():
        raise FileNotFoundError(
            f"Model not found: {MODEL_PATH.name}. Put it in the repo root (same folder as app.py)."
        )
    return tf.keras.models.load_model(MODEL_PATH)


def center_digit(binary_28: np.ndarray) -> np.ndarray:
    """Center the digit in a 2-D binary image using its bounding box.

    Expects values in [0, 255] with the digit white (> 0) on a black (0)
    background. The output has the same shape and dtype as the input; an
    image with no foreground pixels is returned unchanged.
    """
    ys, xs = np.nonzero(binary_28 > 0)
    if xs.size == 0:
        return binary_28  # nothing to center

    digit = binary_28[ys.min():ys.max() + 1, xs.min():xs.max() + 1]
    h, w = digit.shape
    canvas_h, canvas_w = binary_28.shape

    # The crop comes from the same image, so it always fits on the canvas.
    canvas = np.zeros_like(binary_28)
    top = (canvas_h - h) // 2
    left = (canvas_w - w) // 2
    canvas[top:top + h, left:left + w] = digit
    return canvas


def _flatten_transparency(pil_img: Image.Image) -> Image.Image:
    """Composite an image that has transparency onto an opaque white background.

    Without this, ``convert("L")`` simply drops alpha, and a digit drawn on a
    transparent canvas (whose hidden RGB is often black) turns into a blank
    black image.
    """
    has_alpha = "A" in pil_img.getbands() or (
        pil_img.mode == "P" and "transparency" in pil_img.info
    )
    if not has_alpha:
        return pil_img
    rgba = pil_img.convert("RGBA")
    white = Image.new("RGBA", rgba.size, (255, 255, 255, 255))
    return Image.alpha_composite(white, rgba)


def _has_dark_background(gray: Image.Image) -> bool:
    """Return True if the border pixels of a grayscale image are mostly dark.

    The border median estimates the background colour and, unlike the global
    mean, is not skewed by thick strokes or a large digit.
    """
    arr = np.asarray(gray)
    border = np.concatenate([arr[0, :], arr[-1, :], arr[:, 0], arr[:, -1]])
    return float(np.median(border)) < 127


def preprocess_image(pil_img: Image.Image, do_threshold: bool = True, threshold_value: int = 140,
                     do_center: bool = True) -> np.ndarray:
    """Convert an uploaded image to an MNIST-like float tensor.

    Args:
        pil_img: Uploaded image in any PIL mode.
        do_threshold: Binarize with ``threshold_value`` (pixels at or below
            the threshold, i.e. dark ink, become the digit).
        threshold_value: Threshold in [0, 255] applied after blurring.
        do_center: Re-center the digit by its bounding box.

    Returns:
        float32 array of shape (1, 28, 28, 1) in [0, 1], with the digit bright
        on a dark background.
    """
    # 1) Flatten transparency, then grayscale.
    img = _flatten_transparency(pil_img).convert("L")

    # 2) Normalize polarity to "dark digit on light background" BEFORE padding.
    #    The padding below is white, so inverting afterwards would turn the pad
    #    bars into foreground after thresholding.
    if _has_dark_background(img):
        img = ImageOps.invert(img)

    # 3) Pad to a square with the (now light) background so resizing keeps aspect.
    side = max(img.size)
    img = ImageOps.pad(img, (side, side), method=Image.Resampling.LANCZOS, color=255)

    # 4) Resize to the MNIST resolution.
    img = img.resize((28, 28), Image.Resampling.LANCZOS)
    arr = np.asarray(img, dtype=np.uint8)

    # 5) Mild denoise (helps with photographed / compressed uploads).
    arr = cv2.GaussianBlur(arr, (3, 3), 0)

    # 6) Flip to MNIST polarity: white digit (255) on black (0).
    if do_threshold:
        # BINARY_INV: values <= threshold (dark ink) -> 255, lighter background -> 0.
        _, arr_bin = cv2.threshold(arr, threshold_value, 255, cv2.THRESH_BINARY_INV)
    else:
        arr_bin = 255 - arr

    # 7) Optionally center the digit.
    if do_center:
        arr_bin = center_digit(arr_bin)

    # 8) Scale to [0, 1] and add batch/channel axes.
    arr_norm = arr_bin.astype(np.float32) / 255.0
    return arr_norm.reshape(1, 28, 28, 1)


def plot_probabilities(probs: np.ndarray):
    """Render a bar chart of the per-class probabilities in Streamlit.

    Uses an explicit Figure instead of pyplot's global "current figure" so
    concurrent sessions cannot draw onto each other's plots.
    """
    fig, ax = plt.subplots(figsize=(7, 3))
    try:
        positions = range(len(CLASS_NAMES))
        ax.bar(positions, probs)
        ax.set_xticks(list(positions))
        ax.set_xticklabels(CLASS_NAMES)
        ax.set_xlabel("Digit")
        ax.set_ylabel("Probability")
        ax.set_title("Prediction Probabilities")
        st.pyplot(fig)
    finally:
        plt.close(fig)


st.title("✍️ MNIST Digit Recognizer")
st.write("Upload an image of a handwritten digit (0–9). The model will predict the digit.")

with st.expander("⚙️ Preprocessing options"):
    do_threshold = st.checkbox("Use thresholding (recommended)", value=True)
    threshold_value = st.slider("Threshold value", 0, 255, 140, 5)
    do_center = st.checkbox("Center the digit (recommended)", value=True)

with st.expander("ℹ️ Tips for best results"):
    st.markdown(
        "- Use a **single digit** centered in the image.\n"
        "- Prefer **white background** and **dark digit**.\n"
        "- PNG/JPG works fine.\n"
        "- If prediction is wrong, try adjusting the **threshold** slider."
    )

uploaded = st.file_uploader("Upload an image (png/jpg/jpeg)", type=["png", "jpg", "jpeg"])

try:
    model = load_model()
except Exception as e:  # top-level UI boundary: show any load failure and halt
    st.error(str(e))
    st.stop()

if uploaded is None:
    st.info("Upload an image to get a prediction.")
    st.stop()

try:
    # exif_transpose applies the camera's orientation tag (phone photos are
    # often stored rotated). PIL.UnidentifiedImageError subclasses OSError.
    pil_img = ImageOps.exif_transpose(Image.open(uploaded))
except OSError as err:
    st.error(f"Could not read the uploaded file as an image: {err}")
    st.stop()

st.subheader("Your image")
st.image(pil_img, use_container_width=True)

x = preprocess_image(pil_img, do_threshold=do_threshold, threshold_value=threshold_value,
                     do_center=do_center)
if not x.any():
    st.warning("No digit pixels remain after preprocessing — try adjusting the threshold slider.")

probs = model.predict(x, verbose=0)[0]
pred = int(np.argmax(probs))
conf = float(probs[pred])

st.subheader("Prediction")
st.metric("Predicted Digit", pred)
st.write(f"Confidence: **{conf:.3f}**")

st.subheader("Model probabilities")
plot_probabilities(probs)

st.subheader("Preprocessed 28×28 image (what the model sees)")
st.image((x[0, :, :, 0] * 255).astype(np.uint8), clamp=True)