# DigitRecognitionCNN / src/streamlit_app.py
# Last commit f303063 (verified) by EnYa32: "Update src/streamlit_app.py"
import numpy as np
import streamlit as st
from PIL import Image, ImageOps
import tensorflow as tf
import matplotlib.pyplot as plt
from pathlib import Path
import cv2 # NEW
# Browser-tab title/icon and page layout, configured before anything is rendered.
st.set_page_config(
    page_title="MNIST Digit Recognizer",
    page_icon="✍️",
    layout="centered",
)

# The trained Keras model file lives in the same directory as this script.
MODEL_PATH = Path(__file__).resolve().parent.joinpath("best_mnist_cnn.keras")

# Display label for each model output index: "0", "1", ..., "9".
CLASS_NAMES = list(map(str, range(10)))
@st.cache_resource
def load_model():
    """
    Load the trained Keras CNN stored at MODEL_PATH.

    Wrapped in st.cache_resource so the loaded model object is cached and reused
    instead of being read from disk on every script rerun.

    Returns:
        The model returned by tf.keras.models.load_model.

    Raises:
        FileNotFoundError: if MODEL_PATH does not exist.
    """
    if not MODEL_PATH.exists():
        # MODEL_PATH is resolved relative to this script file (not the repo root
        # or the working directory), so report the exact path that was searched.
        raise FileNotFoundError(
            f"Model not found: {MODEL_PATH}. "
            f"Put '{MODEL_PATH.name}' in the same folder as {Path(__file__).name}."
        )
    return tf.keras.models.load_model(MODEL_PATH)
def center_digit(binary_28: np.ndarray) -> np.ndarray:
    """
    Center the digit in a 28x28 image using its bounding box.

    Expects a 2-D image with the digit bright (> 0) on a black (0) background;
    every non-zero pixel is treated as part of the digit when computing the box.

    Args:
        binary_28: 2-D array, normally 28x28 with values in [0, 255].

    Returns:
        A new 28x28 uint8 array with the cropped digit pasted in the middle
        (any odd leftover pixel goes to the bottom/right). If the input has no
        non-zero pixels, the input array itself is returned unchanged. A crop
        larger than 28 pixels in either direction is truncated to its top-left
        28x28 part instead of raising.
    """
    ys, xs = np.nonzero(binary_28 > 0)
    if xs.size == 0:
        return binary_28  # empty image: nothing to center

    digit = binary_28[ys.min():ys.max() + 1, xs.min():xs.max() + 1]

    # Clip BEFORE computing the offsets: with an oversized crop the offsets would
    # otherwise be negative and the paste slice would not match the crop shape.
    digit = digit[:28, :28]
    h, w = digit.shape
    top = (28 - h) // 2
    left = (28 - w) // 2

    canvas = np.zeros((28, 28), dtype=np.uint8)
    canvas[top:top + h, left:left + w] = digit
    return canvas
def preprocess_image(pil_img: Image.Image,
                     do_threshold: bool = True,
                     threshold_value: int = 140,
                     do_center: bool = True) -> np.ndarray:
    """
    Convert an uploaded image to an MNIST-like model input of shape (1, 28, 28, 1).

    Pipeline: apply EXIF orientation -> grayscale -> normalize polarity to a dark
    digit on a light background -> pad to a square -> resize to 28x28 -> mild
    Gaussian blur -> make the digit bright on black (binary threshold or plain
    inversion) -> optional bounding-box centering -> scale to [0, 1].

    Args:
        pil_img: The uploaded image (any mode PIL can convert to "L").
        do_threshold: If True, binarize with cv2.THRESH_BINARY_INV at
            threshold_value; if False, only invert intensities.
        threshold_value: Cutoff in [0, 255]. Pixels at or below it (dark
            strokes) become 255 (digit); brighter pixels become 0.
        do_center: If True, re-center the digit with center_digit().

    Returns:
        float32 array of shape (1, 28, 28, 1) with values in [0, 1].
    """
    # 1) Honour the EXIF orientation tag (phone photos are often stored rotated),
    #    then convert to 8-bit grayscale.
    img = ImageOps.exif_transpose(pil_img).convert("L")

    # 2) Normalize polarity BEFORE padding so that the white padding added next
    #    always matches the background. If this were decided after padding, the
    #    white bars of a dark-background image would be inverted to black and
    #    then binarized into solid "digit" strips.
    if np.asarray(img).mean() < 127:
        img = ImageOps.invert(img)

    # 3) Pad to a square with the (now light) background colour to avoid distortion.
    side = max(img.size)
    img = ImageOps.pad(img, (side, side),
                       method=Image.Resampling.LANCZOS, color=255)

    # 4) Resize to 28x28 and convert to numpy.
    img = img.resize((28, 28), Image.Resampling.LANCZOS)
    arr = np.array(img, dtype=np.uint8)

    # 5) Mild denoise to suppress speckle / JPEG artefacts before binarizing.
    arr = cv2.GaussianBlur(arr, (3, 3), 0)

    # 6) Make the digit bright on a black background, as in MNIST.
    if do_threshold:
        # THRESH_BINARY_INV: pixels <= threshold_value (dark strokes) -> 255,
        # brighter pixels (background) -> 0.
        _, arr_bin = cv2.threshold(arr, threshold_value, 255, cv2.THRESH_BINARY_INV)
    else:
        # Keep grey levels; just flip polarity.
        arr_bin = 255 - arr

    # 7) Optionally re-center the digit by its bounding box.
    if do_center:
        arr_bin = center_digit(arr_bin)

    # 8) Scale to [0, 1] float32 and add batch and channel axes.
    arr_norm = arr_bin.astype(np.float32) / 255.0
    return arr_norm.reshape(1, 28, 28, 1)
def plot_probabilities(probs: np.ndarray):
    """
    Render a bar chart of per-class probabilities in the Streamlit page.

    Builds a standalone matplotlib Figure instead of using pyplot's global
    "current figure" state, so concurrently running Streamlit sessions cannot
    draw into, grab, or close each other's figure.

    Args:
        probs: 1-D sequence of class probabilities, one per entry in CLASS_NAMES.
    """
    from matplotlib.figure import Figure  # object-oriented API, no pyplot globals

    fig = Figure(figsize=(7, 3))
    ax = fig.subplots()
    positions = list(range(len(CLASS_NAMES)))
    ax.bar(positions, probs)
    ax.set_xticks(positions)
    ax.set_xticklabels(CLASS_NAMES)
    ax.set_xlabel("Digit")
    ax.set_ylabel("Probability")
    ax.set_title("Prediction Probabilities")
    # The figure is not registered with pyplot, so no plt.close() is needed;
    # it is released once it goes out of scope.
    st.pyplot(fig)
st.title("✍️ MNIST Digit Recognizer")
st.write("Upload an image of a handwritten digit (0–9). The model will predict the digit.")

# Preprocessing controls; their current values are passed to preprocess_image().
with st.expander("⚙️ Preprocessing options"):
    do_threshold = st.checkbox("Use thresholding (recommended)", value=True)
    # The threshold only takes effect when thresholding is enabled.
    threshold_value = st.slider("Threshold value", 0, 255, 140, 5,
                                disabled=not do_threshold)
    do_center = st.checkbox("Center the digit (recommended)", value=True)

with st.expander("ℹ️ Tips for best results"):
    st.markdown(
        "- Use a **single digit** centered in the image.\n"
        "- Prefer **white background** and **dark digit**.\n"
        "- PNG/JPG works fine.\n"
        "- If prediction is wrong, try adjusting the **threshold** slider."
    )

uploaded = st.file_uploader("Upload an image (png/jpg/jpeg)", type=["png", "jpg", "jpeg"])

# Without a model nothing else can work: show the error and halt this run.
try:
    model = load_model()
except Exception as e:
    st.error(str(e))
    st.stop()

if uploaded is not None:
    # Image.open is lazy; force decoding with load() so that corrupt, truncated,
    # or non-image uploads are reported here instead of crashing in preprocessing.
    try:
        pil_img = Image.open(uploaded)
        pil_img.load()
    except OSError as e:  # PIL.UnidentifiedImageError is a subclass of OSError
        st.error(f"Could not read the uploaded file as an image: {e}")
        st.stop()

    st.subheader("Your image")
    st.image(pil_img, use_container_width=True)

    x = preprocess_image(pil_img, do_threshold=do_threshold,
                         threshold_value=threshold_value, do_center=do_center)
    probs = model.predict(x, verbose=0)[0]
    pred = int(np.argmax(probs))
    conf = float(np.max(probs))

    st.subheader("Prediction")
    st.metric("Predicted Digit", pred)
    st.write(f"Confidence: **{conf:.3f}**")

    st.subheader("Model probabilities")
    plot_probabilities(probs)

    st.subheader("Preprocessed 28×28 image (what the model sees)")
    st.image((x[0].reshape(28, 28) * 255).astype(np.uint8), clamp=True)
else:
    st.info("Upload an image to get a prediction.")