meetran's picture
Create app.py
3a30bf7 verified
import json
import traceback

import gradio as gr
import numpy as np
import tensorflow as tf
from huggingface_hub import hf_hub_download
from PIL import Image
from tensorflow.keras.applications.mobilenet_v2 import preprocess_input as mb_preprocess
# Download model and labels from the Hugging Face Model Hub repository.
# Both files come from the same repo, so its id is defined once here instead
# of being repeated in every download call.
MODEL_REPO_ID = "meetran/painting-classifier-keras-v1"

print("Downloading model from Hugging Face Model Hub...")
# hf_hub_download returns the local filesystem path of the downloaded file.
model_path = hf_hub_download(
    repo_id=MODEL_REPO_ID,
    filename="wikiart_mobilenetv2_multihead.keras",
)
labels_path = hf_hub_download(
    repo_id=MODEL_REPO_ID,
    filename="class_labels.json",
)
print("Model and labels downloaded successfully")
# Read the ordered class-name lists used to label each prediction head.
with open(labels_path, "r", encoding="utf-8") as labels_file:
    class_labels = json.load(labels_file)
artist_names, genre_names, style_names = (
    class_labels[head_key]
    for head_key in ("artist_names", "genre_names", "style_names")
)
# Load the trained model.
# compile=False: this app only runs inference, so the optimizer/loss/metrics
# from the training configuration are not needed. Skipping compilation also
# avoids load failures when training used custom losses or metrics.
print("Loading model...")
model = tf.keras.models.load_model(model_path, compile=False)
print("Model loaded successfully")

# Target (height, width) that every input image is resized to before inference.
IMG_SIZE = (224, 224)
def preprocess_image(image):
    """Preprocess an input image into a single-image batch for model inference.

    Args:
        image: A PIL image or an array-like image (H, W) / (H, W, C).
            Non-RGB PIL images (RGBA, grayscale, palette, ...) are converted
            to RGB. Single-channel arrays are replicated to three channels,
            and a fourth (alpha) channel is dropped.

    Returns:
        A tensor of shape (1, IMG_SIZE[0], IMG_SIZE[1], 3), resized with
        ``tf.image.resize`` and scaled by MobileNetV2's ``preprocess_input``.
    """
    # The backbone expects exactly three colour channels. Without this, RGBA,
    # grayscale or palette uploads yield 4- or 1-channel inputs that fail
    # inside the model.
    if isinstance(image, Image.Image) and image.mode != "RGB":
        image = image.convert("RGB")
    img = np.array(image)
    if img.ndim == 2:
        # Plain grayscale array: add an explicit channel axis.
        img = img[..., np.newaxis]
    if img.ndim == 3 and img.shape[-1] == 1:
        img = np.repeat(img, 3, axis=-1)
    elif img.ndim == 3 and img.shape[-1] == 4:
        # Drop the alpha channel.
        img = img[..., :3]
    img = tf.image.resize(img, IMG_SIZE)
    img = mb_preprocess(img)
    # Add the batch dimension expected by the model.
    img = tf.expand_dims(img, axis=0)
    return img
def _scores_to_label_dict(scores, names):
    """Map one head's output vector to a ``{class_name: probability}`` dict.

    Args:
        scores: Output vector of a single image for one model head.
        names: Class names in the same order as that head's output units.

    Returns:
        Dict mapping each class name to its softmax probability (as float),
        the format consumed by the ``gr.Label`` outputs.

    Raises:
        ValueError: If the number of names differs from the number of scores.
    """
    # NOTE(review): softmax assumes the heads emit logits. If the saved heads
    # already end in a softmax activation, this double-softmax flattens the
    # distribution. Confirm against the training code.
    probs = tf.nn.softmax(scores).numpy()
    if len(probs) != len(names):
        # A mismatch means the labels file does not describe this model, so
        # every label would be wrong; fail loudly instead of mislabeling.
        raise ValueError(
            f"Label/output size mismatch: {len(names)} names vs {len(probs)} scores"
        )
    return {name: float(prob) for name, prob in zip(names, probs)}


def classify_painting(image):
    """Classify a painting by artist, genre, and style.

    Args:
        image: Image from the ``gr.Image`` input, or ``None`` when the input
            is empty or has been cleared.

    Returns:
        Tuple ``(artist_dict, genre_dict, style_dict)`` of
        ``{class_name: probability}`` dicts, or ``(None, None, None)`` when
        there is no image or inference fails.
    """
    if image is None:
        return None, None, None
    try:
        processed_img = preprocess_image(image)
        # Call the model directly: Model.predict() targets large batched
        # datasets and adds per-call overhead for a single image.
        predictions = model(processed_img, training=False)
        artist_dict = _scores_to_label_dict(predictions["artist"][0], artist_names)
        genre_dict = _scores_to_label_dict(predictions["genre"][0], genre_names)
        style_dict = _scores_to_label_dict(predictions["style"][0], style_names)
    except Exception as e:
        # UI boundary: keep the app responsive by returning empty outputs, but
        # log the full traceback so failures are diagnosable.
        print(f"Error during classification: {e}")
        traceback.print_exc()
        return None, None, None
    return artist_dict, genre_dict, style_dict
# Create Gradio interface.
# Representative examples listed in the "Model Information" panel.
NOTABLE_ARTISTS = (
    "Claude Monet", "Vincent van Gogh", "Pablo Picasso", "Leonardo da Vinci",
    "Rembrandt", "Salvador Dali", "Michelangelo", "Edgar Degas",
    "Paul Cezanne", "Henri Matisse",
)
NOTABLE_STYLES = (
    "Impressionism", "Cubism", "Renaissance", "Baroque", "Expressionism",
    "Abstract Expressionism", "Realism", "Pop Art", "Romanticism", "Symbolism",
)
# Class counts come from the loaded label lists, so the UI text always matches
# the model's output heads instead of relying on hard-coded numbers.
NUM_ARTISTS = len(artist_names)
NUM_GENRES = len(genre_names)
NUM_STYLES = len(style_names)
MORE_ARTISTS = max(NUM_ARTISTS - len(NOTABLE_ARTISTS), 0)
MORE_STYLES = max(NUM_STYLES - len(NOTABLE_STYLES), 0)

with gr.Blocks(title="WikiArt Painting Classifier", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# WikiArt Painting Classifier")
    gr.Markdown(
        f"Upload a painting image to classify its Artist ({NUM_ARTISTS} classes), "
        f"Genre ({NUM_GENRES} classes), and Style ({NUM_STYLES} classes) "
        "using a MobileNetV2-based multi-task model."
    )
    gr.Markdown(
        "**Model Repository**: [meetran/painting-classifier-keras-v1]"
        "(https://huggingface.co/meetran/painting-classifier-keras-v1)"
    )
    with gr.Row():
        with gr.Column():
            image_input = gr.Image(type="pil", label="Upload Painting Image")
            classify_btn = gr.Button("Classify Painting", variant="primary", size="lg")
            gr.Markdown("### Tips for Best Results")
            gr.Markdown(
                "- Upload clear, high-quality images of paintings\n"
                "- Works best with Western classical and modern art\n"
                f"- Supports paintings from {NUM_ARTISTS} famous artists\n"
                f"- Can identify {NUM_STYLES} different art styles"
            )
        with gr.Column():
            artist_output = gr.Label(label="Artist Prediction (Top 10)", num_top_classes=10)
            genre_output = gr.Label(label="Genre Prediction", num_top_classes=5)
            style_output = gr.Label(label="Art Style Prediction (Top 10)", num_top_classes=10)
    gr.Markdown("---")
    gr.Markdown("### Model Information")
    gr.Markdown(
        "- **Architecture**: MobileNetV2 (ImageNet pre-trained) with multi-head classification\n"
        "- **Dataset**: WikiArt dataset containing 84,440 paintings\n"
        "- **Training**: Two-stage training (frozen backbone + fine-tuning)\n"
        "- **Input Size**: 224x224 RGB images\n"
        "- **Framework**: TensorFlow/Keras\n\n"
        f"**Notable Artists**: {', '.join(NOTABLE_ARTISTS)}, and {MORE_ARTISTS} more.\n\n"
        f"**Art Styles**: {', '.join(NOTABLE_STYLES)}, and {MORE_STYLES} more."
    )
    # Connect button to function
    classify_btn.click(
        fn=classify_painting,
        inputs=image_input,
        outputs=[artist_output, genre_output, style_output],
    )
    # Auto-classify whenever the image input changes (upload or clear;
    # clearing passes None, which yields empty outputs).
    image_input.change(
        fn=classify_painting,
        inputs=image_input,
        outputs=[artist_output, genre_output, style_output],
    )

# Launch the app
if __name__ == "__main__":
    demo.launch()