BinhQuocNguyen commited on
Commit
8556fe3
·
verified ·
1 Parent(s): faf4879

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +20 -6
app.py CHANGED
@@ -1,15 +1,29 @@
1
- import gradio as gr
 
2
  import tensorflow as tf
 
3
  import numpy as np
4
 
5
- # Load the model
6
- model = tf.keras.models.load_model("my_mnist_model.keras")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
 
8
  def predict_digit(image):
9
- # Convert to grayscale if needed
10
  if image.ndim == 3:
11
  image = image[..., 0]
12
- # Resize and preprocess
13
  image = np.array(image).astype("float32")
14
  image = image.reshape(1, 28, 28)
15
  image = image / 255.0
@@ -18,7 +32,7 @@ def predict_digit(image):
18
 
19
  iface = gr.Interface(
20
  fn=predict_digit,
21
- inputs=gr.Image(shape=(28, 28), image_mode='L', source="canvas", tool="editor"),
22
  outputs=gr.Label(num_top_classes=1),
23
  title="MNIST Digit Classifier",
24
  description="Draw a digit (0-9) and the model will predict it."
 
1
+ import os
2
+ import requests
3
  import tensorflow as tf
4
+ import gradio as gr
5
  import numpy as np
6
 
7
+ MODEL_URL = "https://drive.google.com/uc?export=download&id=1ECjloRVUkgnKACeBZA06UU_-JddZQ5Z5"
8
+ MODEL_PATH = "my_mnist_model.keras"
9
+
10
+ def download_file_from_google_drive(url, destination):
11
+ if not os.path.exists(destination):
12
+ print("Downloading model from Google Drive...")
13
+ response = requests.get(url, stream=True)
14
+ with open(destination, "wb") as f:
15
+ for chunk in response.iter_content(chunk_size=8192):
16
+ if chunk:
17
+ f.write(chunk)
18
+ print("Download complete.")
19
+
20
+ download_file_from_google_drive(MODEL_URL, MODEL_PATH)
21
+
22
+ model = tf.keras.models.load_model(MODEL_PATH)
23
 
24
  def predict_digit(image):
 
25
  if image.ndim == 3:
26
  image = image[..., 0]
 
27
  image = np.array(image).astype("float32")
28
  image = image.reshape(1, 28, 28)
29
  image = image / 255.0
 
32
 
33
  iface = gr.Interface(
34
  fn=predict_digit,
35
+ inputs=gr.Image(image_mode='L', source="canvas", tool="editor", height=28, width=28),
36
  outputs=gr.Label(num_top_classes=1),
37
  title="MNIST Digit Classifier",
38
  description="Draw a digit (0-9) and the model will predict it."