Spaces:
Sleeping
Sleeping
File size: 1,570 Bytes
8556fe3 984cb3d 8556fe3 984cb3d 8556fe3 984cb3d ba1857d 984cb3d ba1857d 984cb3d ba1857d 984cb3d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 |
import os
import requests
import tensorflow as tf
import gradio as gr
import numpy as np
# Google Drive file ID of the trained Keras model; the direct-download URL
# below is built from it so the ID is easy to swap.
_MODEL_FILE_ID = "1ECjloRVUkgnKACeBZA06UU_-JddZQ5Z5"
MODEL_URL = "https://drive.google.com/uc?export=download&id=" + _MODEL_FILE_ID
# Local cache location for the downloaded model (a path relative to the
# current working directory).
MODEL_PATH = "my_mnist_model.keras"
def download_file_from_google_drive(url, destination, timeout=60):
    """Download ``url`` to ``destination`` unless that file already exists.

    The body is streamed into a temporary ``<destination>.part`` file and is
    only renamed to ``destination`` after the whole transfer succeeded. A
    failed or interrupted download therefore never leaves a truncated or
    bogus file behind. Such a file would otherwise be mistaken for a valid
    cached copy on every later run, because an existing ``destination``
    short-circuits the download.

    Args:
        url: Direct-download URL of the file.
        destination: Local path (str or path-like) to write the file to.
        timeout: Seconds to wait for the server to connect / send data;
            forwarded to ``requests.get``.

    Raises:
        requests.HTTPError: If the server answers with an error status.
        requests.RequestException: On network errors or timeouts.
        ValueError: If the server returns an HTML page instead of the file.
    """
    if os.path.exists(destination):
        return

    print("Downloading model from Google Drive...")
    tmp_path = os.fspath(destination) + ".part"
    try:
        # The context manager releases the connection even on error.
        with requests.get(url, stream=True, timeout=timeout) as response:
            # Without this check, a 403/404 error page would be saved as the file.
            response.raise_for_status()
            # NOTE(review): Drive can answer with an HTML page and status 200
            # (e.g. a confirmation or sign-in page); a model file is never HTML.
            if "text/html" in response.headers.get("Content-Type", ""):
                raise ValueError(
                    f"Expected a file download from {url}, got an HTML page"
                )
            with open(tmp_path, "wb") as f:
                for chunk in response.iter_content(chunk_size=8192):
                    if chunk:  # skip empty keep-alive chunks
                        f.write(chunk)
        # Atomic rename: ``destination`` only ever holds a complete download.
        os.replace(tmp_path, destination)
    finally:
        # After a successful replace this is a no-op. On failure it removes
        # the partial file so the next run retries cleanly.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
    print("Download complete.")
# On first start-up fetch the model file (skipped when it is already on disk),
# then load it once into the module-level ``model`` used by predict_digit.
download_file_from_google_drive(url=MODEL_URL, destination=MODEL_PATH)
model = tf.keras.models.load_model(MODEL_PATH)
def predict_digit(image):
    """Classify a single handwritten digit image with the loaded model.

    Args:
        image: Either a path to an image file (what the Gradio input passes,
            since it uses ``type='filepath'``) or an array-like image. A 3-D
            array keeps only its first channel. Images of any size are
            resized to the 28x28 input the model is fed.

    Returns:
        The predicted class index (the argmax of the model output) as a string.

    Raises:
        ValueError: If no image is given or it is not a 2-D grayscale image
            after channel selection.
    """
    from PIL import Image  # lazy import, as in the original script

    # Gradio passes None when the user submits without providing an image.
    if image is None:
        raise ValueError("No image provided; please upload a digit image.")

    # If the input is a file path, open it as 8-bit grayscale; ``with``
    # closes the underlying file handle.
    if isinstance(image, str):
        with Image.open(image) as img:
            image = np.array(img.convert("L"))

    image = np.asarray(image)
    if image.ndim == 3:
        image = image[..., 0]
    if image.ndim != 2:
        raise ValueError(
            f"Expected a 2-D grayscale image, got shape {image.shape}"
        )
    image = image.astype("float32")

    # The Gradio component does not resize uploads, so arbitrary-size
    # images must be scaled down here; a plain reshape would raise.
    if image.shape != (28, 28):
        resized = Image.fromarray(image).resize((28, 28), Image.LANCZOS)
        # LANCZOS can overshoot slightly; keep values in the 0-255 range.
        image = np.clip(np.asarray(resized, dtype="float32"), 0.0, 255.0)

    # NOTE(review): MNIST digits are light-on-dark. If the model was trained
    # on raw MNIST, dark-on-light uploads may need inverting (255 - x);
    # confirm against the training code.
    image = image.reshape(1, 28, 28) / 255.0
    prediction = model.predict(image)
    return str(np.argmax(prediction))
# Gradio UI: a grayscale image input is handed to predict_digit as a file
# path, and the predicted digit is shown as a single label.
iface = gr.Interface(
    fn=predict_digit,
    # height/width only size the on-page widget; they do not resize the
    # uploaded image. At 28x28 the input box was a 28-pixel square, so it
    # is shown at 10x the MNIST resolution instead.
    inputs=gr.Image(image_mode='L', type='filepath', height=280, width=280),
    outputs=gr.Label(num_top_classes=1),
    title="MNIST Digit Classifier",
    # NOTE(review): gr.Image with these arguments presumably offers no
    # drawing tool; "draw" may need a sketch-capable component (e.g.
    # gr.Sketchpad). Confirm against the installed Gradio version.
    description="Upload or draw a digit (0-9) and the model will predict it.",
)

if __name__ == "__main__":
    iface.launch()