Spaces:
Sleeping
Sleeping
| import numpy as np | |
| import streamlit as st | |
| from PIL import Image, ImageOps | |
| import tensorflow as tf | |
| import matplotlib.pyplot as plt | |
| from pathlib import Path | |
| import cv2 # NEW | |
# Browser-tab title, icon and overall page layout for the app.
st.set_page_config(
    page_title="MNIST Digit Recognizer",
    page_icon="✍️",
    layout="centered",
)

# The trained Keras model is looked up in the same directory as this script.
MODEL_PATH = Path(__file__).resolve().with_name("best_mnist_cnn.keras")

# Tick labels for the ten output classes, in class-index order ("0" .. "9").
CLASS_NAMES = list(map(str, range(10)))
@st.cache_resource
def load_model():
    """Load the trained Keras model from ``MODEL_PATH``.

    Streamlit re-executes the whole script on every widget interaction, so the
    loaded model is cached with ``st.cache_resource`` and reused across reruns
    instead of being deserialized from disk each time.

    Returns:
        The model object returned by ``tf.keras.models.load_model``.

    Raises:
        FileNotFoundError: If no file exists at ``MODEL_PATH``.
    """
    if not MODEL_PATH.exists():
        raise FileNotFoundError(
            f"Model not found: {MODEL_PATH.name}. Put it in the repo root (same folder as app.py)."
        )
    return tf.keras.models.load_model(MODEL_PATH)
def center_digit(binary_28: np.ndarray) -> np.ndarray:
    """Center the digit of a 2-D image on a blank canvas using its bounding box.

    The bounding box of all pixels greater than zero is cropped out and pasted
    into the middle of an all-zero canvas that has the same height and width
    as the input.  Because the canvas is sized from the input, the crop always
    fits; the caller in this file passes 28x28 images.

    Args:
        binary_28: 2-D array holding the digit as non-zero pixels (expected
            255) on a zero background.

    Returns:
        A new ``uint8`` array of the input's shape with the digit centered, or
        the input array itself (unchanged) when it contains no pixel > 0.
    """
    rows, cols = np.nonzero(binary_28 > 0)
    if rows.size == 0:
        return binary_28  # empty image: nothing to center

    # Tight bounding box around the foreground (end indices are exclusive).
    row_start, row_stop = rows.min(), rows.max() + 1
    col_start, col_stop = cols.min(), cols.max() + 1
    digit = binary_28[row_start:row_stop, col_start:col_stop]

    canvas_h, canvas_w = binary_28.shape
    canvas = np.zeros((canvas_h, canvas_w), dtype=np.uint8)

    # Integer division puts any odd leftover pixel on the bottom/right side.
    digit_h, digit_w = digit.shape
    top = (canvas_h - digit_h) // 2
    left = (canvas_w - digit_w) // 2
    canvas[top:top + digit_h, left:left + digit_w] = digit
    return canvas
def preprocess_image(pil_img: Image.Image,
                     do_threshold: bool = True,
                     threshold_value: int = 140,
                     do_center: bool = True) -> np.ndarray:
    """Convert an uploaded image into an MNIST-like tensor of shape (1, 28, 28, 1).

    Pipeline: grayscale -> polarity normalisation (dark digit on light
    background) -> pad to a square with white -> resize to 28x28 -> 3x3
    Gaussian blur -> inversion (by threshold or plain ``255 - x``) so the digit
    is bright on a dark background -> optional centering -> scale to 0..1.

    Args:
        pil_img: The source image, in any mode ``convert("L")`` accepts.
        do_threshold: If true, binarise with ``cv2.THRESH_BINARY_INV`` at
            ``threshold_value``; otherwise only invert the intensities.
        threshold_value: Cut-off used for the binarisation (0..255).
        do_center: If true, re-center the digit with ``center_digit``.

    Returns:
        A ``float32`` array of shape (1, 28, 28, 1) with values in 0..1.
    """
    # 1) Convert to grayscale
    gray = pil_img.convert("L")

    # 2) Auto-invert if the background is dark.  This has to happen BEFORE the
    #    padding step: the padding is always white, so on a dark-background,
    #    non-square image the white bars would otherwise distort the mean test
    #    and later be inverted into large foreground blobs.
    if np.asarray(gray).mean() < 127:
        gray = ImageOps.invert(gray)

    # 3) Make square (pad with white) to avoid distortion when resizing
    side = max(gray.size)
    gray = ImageOps.pad(gray, (side, side),
                        method=Image.Resampling.LANCZOS, color=255)

    # 4) Resize to 28x28 and move to numpy
    gray = gray.resize((28, 28), Image.Resampling.LANCZOS)
    arr = np.array(gray).astype(np.uint8)

    # 5) Mild denoise (helpful for uploaded images)
    arr = cv2.GaussianBlur(arr, (3, 3), 0)

    # 6) Flip to "white digit on black background", like MNIST
    if do_threshold:
        # THRESH_BINARY_INV: pixels above the threshold become 0, the rest 255,
        # so the dark digit turns white and the light background turns black.
        _, arr_bin = cv2.threshold(arr, threshold_value, 255, cv2.THRESH_BINARY_INV)
    else:
        # No threshold: keep the gray levels and just invert the intensity.
        arr_bin = 255 - arr

    # 7) Center the digit (optional)
    if do_center:
        arr_bin = center_digit(arr_bin)

    # 8) Normalize to 0..1 (digit bright) and add batch and channel axes
    arr_norm = arr_bin.astype(np.float32) / 255.0
    return arr_norm.reshape(1, 28, 28, 1)
def plot_probabilities(probs: np.ndarray):
    """Draw the ten class probabilities as a bar chart on the Streamlit page.

    Args:
        probs: Sequence of ten probabilities, one per digit class, in class
            order.
    """
    digit_positions = range(10)

    fig, ax = plt.subplots(figsize=(7, 3))
    ax.bar(digit_positions, probs)
    ax.set_xticks(digit_positions)
    ax.set_xticklabels(CLASS_NAMES)
    ax.set_xlabel("Digit")
    ax.set_ylabel("Probability")
    ax.set_title("Prediction Probabilities")

    st.pyplot(fig)
    # Release the figure once it has been rendered.
    plt.close(fig)
# ---- Page header ---------------------------------------------------------
st.title("✍️ MNIST Digit Recognizer")
st.write("Upload an image of a handwritten digit (0–9). The model will predict the digit.")

# ---- Preprocessing controls (passed straight to preprocess_image) --------
with st.expander("⚙️ Preprocessing options"):
    do_threshold = st.checkbox("Use thresholding (recommended)", value=True)
    threshold_value = st.slider("Threshold value", 0, 255, 140, 5)
    do_center = st.checkbox("Center the digit (recommended)", value=True)

with st.expander("ℹ️ Tips for best results"):
    st.markdown(
        "- Use a **single digit** centered in the image.\n"
        "- Prefer **white background** and **dark digit**.\n"
        "- PNG/JPG works fine.\n"
        "- If prediction is wrong, try adjusting the **threshold** slider."
    )

uploaded = st.file_uploader("Upload an image (png/jpg/jpeg)", type=["png", "jpg", "jpeg"])

# ---- Model ---------------------------------------------------------------
# Without a model nothing below can work, so show the error and halt the run.
try:
    model = load_model()
except Exception as e:
    st.error(str(e))
    st.stop()

# ---- Prediction ----------------------------------------------------------
if uploaded is None:
    st.info("Upload an image to get a prediction.")
else:
    # A file with an accepted extension can still be corrupt or not an image.
    # Decode it eagerly so the failure is reported here as a message instead
    # of surfacing later as an unhandled exception.
    try:
        pil_img = Image.open(uploaded)
        pil_img.load()
        # Apply the EXIF orientation so photos are shown and classified
        # upright rather than in their raw sensor orientation.
        pil_img = ImageOps.exif_transpose(pil_img)
    except OSError as e:
        st.error(f"Could not read the uploaded file as an image: {e}")
        st.stop()

    st.subheader("Your image")
    st.image(pil_img, use_container_width=True)

    x = preprocess_image(pil_img, do_threshold=do_threshold,
                         threshold_value=threshold_value, do_center=do_center)

    # predict() returns one row per input; there is a single input here.
    probs = model.predict(x, verbose=0)[0]
    pred = int(np.argmax(probs))
    conf = float(np.max(probs))

    st.subheader("Prediction")
    st.metric("Predicted Digit", pred)
    st.write(f"Confidence: **{conf:.3f}**")

    st.subheader("Model probabilities")
    plot_probabilities(probs)

    # Show the exact 28x28 input, scaled back from 0..1 to 0..255 for display.
    st.subheader("Preprocessed 28×28 image (what the model sees)")
    st.image((x[0].reshape(28, 28) * 255).astype(np.uint8), clamp=True)