AfshinMA commited on
Commit
9c98bdb
·
verified ·
1 Parent(s): 538fffb

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +30 -10
app.py CHANGED
@@ -1,6 +1,8 @@
1
  import requests
2
  import tensorflow as tf
3
  import gradio as gr
 
 
4
 
5
  def classify_image(input_image):
6
  # Download human-readable labels for ImageNet.
@@ -22,26 +24,44 @@ def classify_image(input_image):
22
  classifier_activation="softmax"
23
  )
24
 
25
- # Resize the input image to the expected size
26
- input_image = input_image.reshape((1, 224, 224, 3)) # Reshape for a single prediction
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27
  input_image = tf.keras.applications.mobilenet_v2.preprocess_input(input_image)
28
-
29
  # Perform prediction
30
  prediction = inception_net.predict(input_image).flatten()
31
  confidences = {labels[i]: float(prediction[i]) for i in range(len(labels))}
32
  return confidences
33
-
34
  image = gr.Image(interactive=True, label="Upload Image")
35
  label = gr.Label(num_top_classes=3, label="Top Predictions")
36
 
37
  demo = gr.Interface(
38
  title="Image Classifier Keras",
39
- fn= classify_image,
40
- inputs= image,
41
- outputs= label,
42
- examples= [["./images/banana.jpg"], ["./images/car.jpg"], ["./images/guitar.jpg"], ["./images/lion.jpg"]], # Use valid URLs or local paths
43
- theme= "default",
44
- css= ".footer{display:none !important}"
45
  )
46
 
47
  if __name__ == "__main__":
 
1
  import requests
2
  import tensorflow as tf
3
  import gradio as gr
4
+ from PIL import Image
5
+ import numpy as np
6
 
7
  def classify_image(input_image):
8
  # Download human-readable labels for ImageNet.
 
24
  classifier_activation="softmax"
25
  )
26
 
27
+ # Handle input_image (ensure it's a PIL Image)
28
+ if isinstance(input_image, str): # If it's a file path or URL
29
+ input_image = Image.open(input_image).convert("RGB")
30
+ elif isinstance(input_image, np.ndarray): # If it's a numpy array
31
+ input_image = Image.fromarray(input_image).convert("RGB")
32
+
33
+ # Resize the image to 224x224
34
+ input_image = input_image.resize((224, 224))
35
+
36
+ # Convert image to a numpy array
37
+ input_image = np.array(input_image)
38
+
39
+ # Ensure it's in the right format (RGB channels only)
40
+ if input_image.shape[-1] == 4: # If there's an alpha channel
41
+ input_image = input_image[..., :3] # Remove the alpha channel
42
+
43
+ # Reshape for a single prediction
44
+ input_image = input_image.reshape((1, 224, 224, 3))
45
+
46
+ # Preprocess the image
47
  input_image = tf.keras.applications.mobilenet_v2.preprocess_input(input_image)
48
+
49
  # Perform prediction
50
  prediction = inception_net.predict(input_image).flatten()
51
  confidences = {labels[i]: float(prediction[i]) for i in range(len(labels))}
52
  return confidences
53
+
54
  image = gr.Image(interactive=True, label="Upload Image")
55
  label = gr.Label(num_top_classes=3, label="Top Predictions")
56
 
57
  demo = gr.Interface(
58
  title="Image Classifier Keras",
59
+ fn=classify_image,
60
+ inputs=image,
61
+ outputs=label,
62
+ examples=[["./images/banana.jpg"], ["./images/car.jpg"], ["./images/guitar.jpg"], ["./images/lion.jpg"]],
63
+ theme="default",
64
+ css=".footer{display:none !important}"
65
  )
66
 
67
  if __name__ == "__main__":