Spaces:
Sleeping
Sleeping
| import os | |
| import requests | |
| import tensorflow as tf | |
| import gradio as gr | |
| import numpy as np | |
# Google Drive direct-download URL for the saved Keras model file.
MODEL_URL = (
    "https://drive.google.com/uc?export=download"
    "&id=1ECjloRVUkgnKACeBZA06UU_-JddZQ5Z5"
)
# Local path (relative to the working directory) the model is cached at and
# loaded from.
MODEL_PATH = "my_mnist_model.keras"
def download_file_from_google_drive(url, destination, *, timeout=60):
    """Download *url* to *destination* unless *destination* already exists.

    The body is streamed into ``<destination>.part`` and only moved onto
    *destination* (atomically, via ``os.replace``) once the transfer has
    finished. A failed or interrupted download therefore never leaves a
    truncated file behind. Such a file would otherwise satisfy the
    ``os.path.exists`` check and be loaded on every later run.

    Args:
        url: Direct-download URL to fetch.
        destination: Local file path to write to.
        timeout: Seconds ``requests`` waits for the connection and for each
            read before giving up (keyword-only; previously unbounded).

    Raises:
        requests.HTTPError: if the server answers with an error status.
        requests.RequestException: on other network errors or timeouts.
        RuntimeError: if the server returns an HTML page instead of a file.
    """
    if os.path.exists(destination):
        return

    print("Downloading model from Google Drive...")
    tmp_path = os.fspath(destination) + ".part"
    try:
        with requests.get(url, stream=True, timeout=timeout) as response:
            response.raise_for_status()
            # Google Drive can answer with an HTML warning/confirmation page
            # (status 200) instead of the file; never save that as the model.
            content_type = response.headers.get("Content-Type", "")
            if "text/html" in content_type:
                raise RuntimeError(
                    f"Expected a file from {url!r} but received an HTML page "
                    "(the link may require confirmation or not be public)."
                )
            with open(tmp_path, "wb") as f:
                for chunk in response.iter_content(chunk_size=8192):
                    if chunk:  # skip keep-alive chunks
                        f.write(chunk)
        os.replace(tmp_path, destination)
    finally:
        # After a successful os.replace the temp file no longer exists; on
        # any failure, remove the partial download so the next run retries.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
    print("Download complete.")
# Make sure the model file is present locally (downloaded on first run only),
# then load it once at import time so every prediction reuses the same model.
download_file_from_google_drive(url=MODEL_URL, destination=MODEL_PATH)
model = tf.keras.models.load_model(MODEL_PATH)
def _image_to_pixels(image):
    """Return *image* as a 2-D float32 array of raw grayscale pixel values.

    A ``str``/``os.PathLike`` is treated as a file path and opened with Pillow
    in ``"L"`` (8-bit grayscale) mode. Anything else goes through
    ``np.asarray`` and, if it has a trailing channel axis, is reduced to its
    first channel.
    """
    if isinstance(image, (str, os.PathLike)):
        from PIL import Image

        # Context manager closes the underlying file handle promptly.
        with Image.open(image) as img:
            return np.asarray(img.convert("L"), dtype="float32")

    pixels = np.asarray(image)
    if pixels.ndim == 3:
        # NOTE(review): for RGB(A) arrays this keeps the red channel, not a
        # luminance conversion — preserved from the original behavior.
        pixels = pixels[..., 0]
    return pixels.astype("float32")


def _resize_to_28x28(pixels):
    """Resample a 2-D float32 pixel array to 28x28 unless it already is."""
    if pixels.shape == (28, 28):
        return pixels
    from PIL import Image

    # A 2-D float32 array becomes a mode-"F" image, so resampling keeps the
    # original value range.
    resized = Image.fromarray(np.ascontiguousarray(pixels)).resize((28, 28))
    return np.asarray(resized, dtype="float32")


def predict_digit(image):
    """Predict which digit (0-9) *image* shows.

    The Gradio ``height``/``width`` settings only size the on-screen widget.
    They do not resize the uploaded file, so the image is explicitly
    resampled to 28x28 here. Without this, ``reshape(1, 28, 28)`` would fail
    for any real upload.

    Args:
        image: A file path (what the ``type='filepath'`` Gradio input in this
            app supplies) or an array-like image.

    Returns:
        The index of the highest-scoring model output, as a string.

    Raises:
        gr.Error: if no image was provided (e.g. "Submit" with an empty
            input), shown to the user instead of an opaque traceback.
    """
    if image is None:
        raise gr.Error("Please upload or draw a digit first.")

    pixels = _resize_to_28x28(_image_to_pixels(image))
    # Assumes the model expects a (batch, 28, 28) input scaled to [0, 1],
    # matching the original preprocessing — TODO confirm against training.
    # NOTE(review): MNIST digits are light-on-dark; dark-on-light uploads may
    # need inverting (255 - pixels) — not done here.
    batch = (pixels / 255.0).reshape(1, 28, 28)
    prediction = model.predict(batch, verbose=0)
    return str(np.argmax(prediction))
# Web UI: one grayscale image input (delivered to predict_digit as a file
# path) and a label output showing the single predicted class.
# NOTE(review): in current Gradio, height/width only set the displayed widget
# size (28px is very small to interact with), not the pixel size of the image
# passed to predict_digit — confirm for the pinned Gradio version.
iface = gr.Interface(
    predict_digit,
    gr.Image(image_mode="L", type="filepath", height=28, width=28),
    gr.Label(num_top_classes=1),
    title="MNIST Digit Classifier",
    description="Upload or draw a digit (0-9) and the model will predict it.",
)

if __name__ == "__main__":
    # Serve the app only when run as a script, not when imported.
    iface.launch()