File size: 5,143 Bytes
3a30bf7 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 |
import json
import traceback

import gradio as gr
import numpy as np
import tensorflow as tf
from huggingface_hub import hf_hub_download
from PIL import Image
from tensorflow.keras.applications.mobilenet_v2 import preprocess_input as mb_preprocess

# Hugging Face Hub repository hosting both the trained Keras model and its label map.
_HUB_REPO_ID = "meetran/painting-classifier-keras-v1"

# Fetch both artifacts; hf_hub_download hands back a local file path for each.
print("Downloading model from Hugging Face Model Hub...")
model_path, labels_path = (
    hf_hub_download(repo_id=_HUB_REPO_ID, filename=artifact)
    for artifact in ("wikiart_mobilenetv2_multihead.keras", "class_labels.json")
)
print("Model and labels downloaded successfully")

# Class-name vocabularies, one per classification head (artist / genre / style).
with open(labels_path, "r", encoding="utf-8") as labels_file:
    class_labels = json.load(labels_file)
artist_names, genre_names, style_names = (
    class_labels[key] for key in ("artist_names", "genre_names", "style_names")
)

# Deserialize the trained multi-head network.
print("Loading model...")
model = tf.keras.models.load_model(model_path)
print("Model loaded successfully")

# Target (height, width) passed to tf.image.resize during preprocessing.
IMG_SIZE = (224, 224)
def preprocess_image(image):
    """Turn an uploaded image into a model-ready batch containing one image.

    The image is coerced to 3-channel RGB (the model takes RGB input), resized
    to ``IMG_SIZE``, scaled with MobileNetV2's ``preprocess_input`` and given a
    leading batch axis.

    Args:
        image: A PIL image, or an array-like of shape (H, W), (H, W, 1),
            (H, W, 3) or (H, W, 4).

    Returns:
        A float tensor of shape ``(1, IMG_SIZE[0], IMG_SIZE[1], 3)`` for the
        input shapes listed above.
    """
    if isinstance(image, Image.Image):
        # Palette ("P"), greyscale ("L"), RGBA, CMYK, ... uploads would otherwise
        # become arrays with the wrong number of channels for the network.
        image = image.convert("RGB")
    img = np.array(image)
    # Same coercion for raw arrays that bypass PIL.
    if img.ndim == 2:
        img = img[..., np.newaxis]  # (H, W) -> (H, W, 1)
    if img.shape[-1] == 1:
        img = np.repeat(img, 3, axis=-1)  # greyscale -> RGB
    elif img.shape[-1] == 4:
        img = img[..., :3]  # drop the alpha channel
    img = tf.image.resize(img, IMG_SIZE)
    img = mb_preprocess(img)
    # Add the batch dimension expected by model.predict.
    return tf.expand_dims(img, axis=0)
def _scores_to_label_dict(scores, names):
    """Convert one head's output row into a ``{class_name: probability}`` dict.

    Args:
        scores: The output vector of a single classification head for one image.
        names: Class names for that head, in output-index order.

    Returns:
        A dict mapping each class name to its softmax probability (Python float),
        the format ``gr.Label`` accepts.

    Raises:
        ValueError: If the head's output size differs from the number of names,
            i.e. the label file does not match the model.
    """
    # NOTE(review): the softmax assumes the heads emit logits. If the saved model
    # already ends in a softmax activation, this re-normalises probabilities and
    # flattens the distribution; confirm against the training code.
    probs = tf.nn.softmax(scores).numpy()
    if len(probs) != len(names):
        raise ValueError(
            f"Model head has {len(probs)} outputs but {len(names)} labels were loaded"
        )
    return {name: float(p) for name, p in zip(names, probs)}


def classify_painting(image):
    """Classify a painting by artist, genre, and style.

    Args:
        image: The uploaded painting (a PIL image from the ``gr.Image`` input),
            or ``None`` when the input is empty or was cleared.

    Returns:
        An ``(artist, genre, style)`` tuple of ``{class_name: probability}``
        dicts, or ``(None, None, None)`` when there is no image or inference
        fails. On failure, the error and its traceback are printed to the
        console.
    """
    if image is None:
        return None, None, None
    try:
        predictions = model.predict(preprocess_image(image), verbose=0)
        # [0] selects the single image in the batch of one built by preprocess_image.
        artist_dict = _scores_to_label_dict(predictions["artist"][0], artist_names)
        genre_dict = _scores_to_label_dict(predictions["genre"][0], genre_names)
        style_dict = _scores_to_label_dict(predictions["style"][0], style_names)
    except Exception as e:
        # UI boundary: keep the app responsive and clear the outputs, but keep
        # the full traceback so failures can be diagnosed from the server log.
        print(f"Error during classification: {e}")
        traceback.print_exc()
        return None, None, None
    return artist_dict, genre_dict, style_dict
def _summarise(shown, total):
    """Join *shown* with ", " and append ", and N more" for the remaining classes.

    Lets the UI text quote the real class counts from the loaded label lists
    instead of hard-coded numbers that can drift from the model.
    """
    extra = total - len(shown)
    joined = ", ".join(shown)
    return f"{joined}, and {extra} more" if extra > 0 else joined


# Hand-picked examples for the "Model Information" section.
NOTABLE_ARTISTS = [
    "Claude Monet", "Vincent van Gogh", "Pablo Picasso", "Leonardo da Vinci",
    "Rembrandt", "Salvador Dali", "Michelangelo", "Edgar Degas",
    "Paul Cezanne", "Henri Matisse",
]
NOTABLE_STYLES = [
    "Impressionism", "Cubism", "Renaissance", "Baroque", "Expressionism",
    "Abstract Expressionism", "Realism", "Pop Art", "Romanticism", "Symbolism",
]

# Create Gradio interface
with gr.Blocks(title="WikiArt Painting Classifier", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# WikiArt Painting Classifier")
    gr.Markdown(
        f"Upload a painting image to classify its Artist ({len(artist_names)} classes), "
        f"Genre ({len(genre_names)} classes), and Style ({len(style_names)} classes) "
        "using a MobileNetV2-based multi-task model."
    )
    gr.Markdown(
        "**Model Repository**: [meetran/painting-classifier-keras-v1]"
        "(https://huggingface.co/meetran/painting-classifier-keras-v1)"
    )
    with gr.Row():
        with gr.Column():
            image_input = gr.Image(type="pil", label="Upload Painting Image")
            classify_btn = gr.Button("Classify Painting", variant="primary", size="lg")
            gr.Markdown("### Tips for Best Results")
            gr.Markdown(
                "- Upload clear, high-quality images of paintings\n"
                "- Works best with Western classical and modern art\n"
                f"- Supports paintings from {len(artist_names)} famous artists\n"
                f"- Can identify {len(style_names)} different art styles"
            )
        with gr.Column():
            artist_output = gr.Label(label="Artist Prediction (Top 10)", num_top_classes=10)
            genre_output = gr.Label(label="Genre Prediction (Top 5)", num_top_classes=5)
            style_output = gr.Label(label="Art Style Prediction (Top 10)", num_top_classes=10)
    gr.Markdown("---")
    gr.Markdown("### Model Information")
    gr.Markdown(
        "- **Architecture**: MobileNetV2 (ImageNet pre-trained) with multi-head classification\n"
        "- **Dataset**: WikiArt dataset containing 84,440 paintings\n"
        "- **Training**: Two-stage training (frozen backbone + fine-tuning)\n"
        "- **Input Size**: 224x224 RGB images\n"
        "- **Framework**: TensorFlow/Keras\n\n"
        f"**Notable Artists**: {_summarise(NOTABLE_ARTISTS, len(artist_names))}.\n\n"
        f"**Art Styles**: {_summarise(NOTABLE_STYLES, len(style_names))}."
    )

    # Both triggers feed the same three Label components.
    prediction_outputs = [artist_output, genre_output, style_output]
    # Classify on explicit button press...
    classify_btn.click(
        fn=classify_painting,
        inputs=image_input,
        outputs=prediction_outputs,
    )
    # ...and automatically whenever the image changes. An empty input makes
    # classify_painting return (None, None, None), which clears the outputs.
    image_input.change(
        fn=classify_painting,
        inputs=image_input,
        outputs=prediction_outputs,
    )

# Launch the app
if __name__ == "__main__":
    demo.launch()
|