"""Gradio app that classifies WikiArt paintings by artist, genre, and style.

The trained multi-head Keras model and its class-label file are downloaded from
the Hugging Face Model Hub at import time, then served through a Gradio Blocks
UI.
"""

import json
import logging

import gradio as gr
import numpy as np
import tensorflow as tf
from huggingface_hub import hf_hub_download
from PIL import Image
from tensorflow.keras.applications.mobilenet_v2 import (
    preprocess_input as mb_preprocess,
)

logger = logging.getLogger(__name__)

# Model Hub repository that hosts both the model file and its label file.
REPO_ID = "meetran/painting-classifier-keras-v1"

# Download model and labels from Model Hub repository
print("Downloading model from Hugging Face Model Hub...")
model_path = hf_hub_download(
    repo_id=REPO_ID,
    filename="wikiart_mobilenetv2_multihead.keras",
)
labels_path = hf_hub_download(
    repo_id=REPO_ID,
    filename="class_labels.json",
)
print("Model and labels downloaded successfully")

# Load class labels: one ordered name list per output head.
with open(labels_path, "r", encoding="utf-8") as f:
    class_labels = json.load(f)

artist_names = class_labels["artist_names"]
genre_names = class_labels["genre_names"]
style_names = class_labels["style_names"]

# Load the trained model
print("Loading model...")
model = tf.keras.models.load_model(model_path)
print("Model loaded successfully")

IMG_SIZE = (224, 224)

# How many top classes the artist and style panels display.
TOP_K = 10

# Examples listed in the "Model Information" panel; the "and N more" counts
# are derived from the loaded label lists so the text stays accurate.
NOTABLE_ARTISTS = [
    "Claude Monet", "Vincent van Gogh", "Pablo Picasso", "Leonardo da Vinci",
    "Rembrandt", "Salvador Dali", "Michelangelo", "Edgar Degas",
    "Paul Cezanne", "Henri Matisse",
]
NOTABLE_STYLES = [
    "Impressionism", "Cubism", "Renaissance", "Baroque", "Expressionism",
    "Abstract Expressionism", "Realism", "Pop Art", "Romanticism", "Symbolism",
]


def preprocess_image(image):
    """Turn an input image into a model-ready batch containing one image.

    Args:
        image: A PIL image (what ``gr.Image(type="pil")`` delivers) or any
            array-like that ``np.array`` accepts.

    Returns:
        A float tensor of shape ``(1, *IMG_SIZE, channels)`` scaled with
        MobileNetV2's ``preprocess_input``; ``channels`` is 3 for PIL input.
    """
    # Force 3-channel RGB: RGBA PNGs, grayscale ("L") and palette ("P") images
    # would otherwise give 4-channel or 2-D arrays the RGB backbone rejects.
    if isinstance(image, Image.Image):
        image = image.convert("RGB")
    img = np.array(image)
    img = tf.image.resize(img, IMG_SIZE)
    img = mb_preprocess(img)
    return tf.expand_dims(img, axis=0)


def _head_to_label_dict(head_output, names):
    """Map the first row of one output head to ``{class_name: probability}``.

    Raises:
        ValueError: if the head's width does not match the number of labels,
            which means the label file does not belong to this model.
    """
    # NOTE(review): softmax assumes the head emits raw logits. If the saved
    # model's heads already end in a softmax activation, this second softmax
    # flattens the distribution -- confirm against the training code.
    probs = tf.nn.softmax(head_output[0]).numpy()
    if len(probs) != len(names):
        raise ValueError(
            f"head produced {len(probs)} scores but {len(names)} labels were loaded"
        )
    return {name: float(p) for name, p in zip(names, probs)}


def classify_painting(image):
    """Classify a painting by artist, genre, and style.

    Args:
        image: The uploaded image, or ``None`` when the input was cleared.

    Returns:
        An ``(artist, genre, style)`` tuple of ``{label: probability}`` dicts
        for the three ``gr.Label`` outputs, or ``(None, None, None)`` when
        there is no image or inference fails.
    """
    if image is None:
        return None, None, None

    try:
        processed_img = preprocess_image(image)
        # The code expects predict() to return a mapping keyed by head name;
        # these keys must match the saved model's output names.
        predictions = model.predict(processed_img, verbose=0)
        return (
            _head_to_label_dict(predictions["artist"], artist_names),
            _head_to_label_dict(predictions["genre"], genre_names),
            _head_to_label_dict(predictions["style"], style_names),
        )
    except Exception:
        # UI event boundary: keep the full traceback in the logs, and clear
        # the outputs instead of letting the handler fail.
        logger.exception("Error during classification")
        return None, None, None


# Create Gradio interface
with gr.Blocks(title="WikiArt Painting Classifier", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# WikiArt Painting Classifier")
    gr.Markdown(
        f"Upload a painting image to classify its Artist ({len(artist_names)} classes), "
        f"Genre ({len(genre_names)} classes), and Style ({len(style_names)} classes) "
        "using a MobileNetV2-based multi-task model."
    )
    gr.Markdown(
        f"**Model Repository**: [{REPO_ID}]"
        f"(https://huggingface.co/{REPO_ID})"
    )

    with gr.Row():
        with gr.Column():
            image_input = gr.Image(type="pil", label="Upload Painting Image")
            classify_btn = gr.Button("Classify Painting", variant="primary", size="lg")

            gr.Markdown("### Tips for Best Results")
            gr.Markdown(
                "- Upload clear, high-quality images of paintings\n"
                "- Works best with Western classical and modern art\n"
                f"- Supports paintings from {len(artist_names)} famous artists\n"
                f"- Can identify {len(style_names)} different art styles"
            )

        with gr.Column():
            artist_output = gr.Label(
                label=f"Artist Prediction (Top {TOP_K})", num_top_classes=TOP_K
            )
            genre_output = gr.Label(label="Genre Prediction", num_top_classes=5)
            style_output = gr.Label(
                label=f"Art Style Prediction (Top {TOP_K})", num_top_classes=TOP_K
            )

    gr.Markdown("---")
    gr.Markdown("### Model Information")
    gr.Markdown(
        "- **Architecture**: MobileNetV2 (ImageNet pre-trained) with multi-head classification\n"
        "- **Dataset**: WikiArt dataset containing 84,440 paintings\n"
        "- **Training**: Two-stage training (frozen backbone + fine-tuning)\n"
        f"- **Input Size**: {IMG_SIZE[0]}x{IMG_SIZE[1]} RGB images\n"
        "- **Framework**: TensorFlow/Keras\n\n"
        f"**Notable Artists**: {', '.join(NOTABLE_ARTISTS)}, "
        f"and {len(artist_names) - len(NOTABLE_ARTISTS)} more.\n\n"
        f"**Art Styles**: {', '.join(NOTABLE_STYLES)}, "
        f"and {len(style_names) - len(NOTABLE_STYLES)} more."
    )

    classifier_outputs = [artist_output, genre_output, style_output]

    # Connect button to function
    classify_btn.click(
        fn=classify_painting,
        inputs=image_input,
        outputs=classifier_outputs,
    )

    # Auto-classify on image upload (clearing the image also fires this; the
    # None input is mapped to empty outputs by classify_painting).
    image_input.change(
        fn=classify_painting,
        inputs=image_input,
        outputs=classifier_outputs,
    )

# Launch the app
if __name__ == "__main__":
    demo.launch()