# Spaces:
# Sleeping
# Sleeping
import functools

import gradio as gr
import numpy as np
import requests
import tensorflow as tf
from PIL import Image
| def classify_image(input_image): | |
| # Download human-readable labels for ImageNet. | |
| try: | |
| response = requests.get("https://git.io/JJkYN") | |
| response.raise_for_status() # Ensure the request was successful | |
| labels = response.text.split("\n") | |
| except Exception as e: | |
| print("Error fetching labels:", e) | |
| labels = ["Unknown"] * 1000 # Fallback in case the request fails | |
| # Load the MobileNetV2 model | |
| inception_net = tf.keras.applications.MobileNetV2( | |
| input_shape=(224, 224, 3), | |
| alpha=1.0, | |
| include_top=True, | |
| weights="imagenet", | |
| classes=1000, | |
| classifier_activation="softmax" | |
| ) | |
| # Handle input_image (ensure it's a PIL Image) | |
| if isinstance(input_image, str): # If it's a file path or URL | |
| input_image = Image.open(input_image).convert("RGB") | |
| elif isinstance(input_image, np.ndarray): # If it's a numpy array | |
| input_image = Image.fromarray(input_image).convert("RGB") | |
| # Resize the image to 224x224 | |
| input_image = input_image.resize((224, 224)) | |
| # Convert image to a numpy array | |
| input_image = np.array(input_image) | |
| # Ensure it's in the right format (RGB channels only) | |
| if input_image.shape[-1] == 4: # If there's an alpha channel | |
| input_image = input_image[..., :3] # Remove the alpha channel | |
| # Reshape for a single prediction | |
| input_image = input_image.reshape((1, 224, 224, 3)) | |
| # Preprocess the image | |
| input_image = tf.keras.applications.mobilenet_v2.preprocess_input(input_image) | |
| # Perform prediction | |
| prediction = inception_net.predict(input_image).flatten() | |
| # Get the top indices and their respective confidence scores | |
| top_indices = np.argsort(prediction)[-3:][::-1] # Get the top 3 indices | |
| confidences = {labels[i]: float(prediction[i]) for i in top_indices} | |
| return confidences | |
# --- Gradio user interface -------------------------------------------------
# Each inner list is one example row for the single image input.
_EXAMPLE_NAMES = ("banana", "car", "guitar", "lion")
example_rows = [[f"./images/{name}.jpg"] for name in _EXAMPLE_NAMES]

# Input: an uploadable image. Output: a label widget showing the top 3 classes.
image = gr.Image(label="Upload Image", interactive=True)
label = gr.Label(label="Top Predictions", num_top_classes=3)

demo = gr.Interface(
    fn=classify_image,
    inputs=image,
    outputs=label,
    examples=example_rows,
    title="Image Classifier Keras",
    theme="default",
    # Hides elements with the `.footer` class (presumably Gradio's page footer).
    css=".footer{display:none !important}",
)


def main():
    """Start the Gradio server for the demo."""
    demo.launch()


if __name__ == "__main__":
    main()