BinhQuocNguyen commited on
Commit
ba1857d
·
verified ·
1 Parent(s): 6fa04a2

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +7 -2
app.py CHANGED
@@ -22,6 +22,11 @@ download_file_from_google_drive(MODEL_URL, MODEL_PATH)
22
  model = tf.keras.models.load_model(MODEL_PATH)
23
 
24
  def predict_digit(image):
 
 
 
 
 
25
  if image.ndim == 3:
26
  image = image[..., 0]
27
  image = np.array(image).astype("float32")
@@ -32,10 +37,10 @@ def predict_digit(image):
32
 
33
  iface = gr.Interface(
34
  fn=predict_digit,
35
- inputs=gr.Image(image_mode='L', height=28, width=28),
36
  outputs=gr.Label(num_top_classes=1),
37
  title="MNIST Digit Classifier",
38
- description="Draw a digit (0-9) and the model will predict it."
39
  )
40
 
41
  if __name__ == "__main__":
 
22
  model = tf.keras.models.load_model(MODEL_PATH)
23
 
24
  def predict_digit(image):
25
+ # If the input is a file path, open and process the image
26
+ if isinstance(image, str):
27
+ from PIL import Image
28
+ image = Image.open(image).convert('L')
29
+ image = np.array(image)
30
  if image.ndim == 3:
31
  image = image[..., 0]
32
  image = np.array(image).astype("float32")
 
37
 
38
  iface = gr.Interface(
39
  fn=predict_digit,
40
+ inputs=gr.Image(image_mode='L', type='filepath', height=28, width=28),
41
  outputs=gr.Label(num_top_classes=1),
42
  title="MNIST Digit Classifier",
43
+ description="Upload or draw a digit (0-9) and the model will predict it."
44
  )
45
 
46
  if __name__ == "__main__":