File size: 5,143 Bytes
3a30bf7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
import gradio as gr
import tensorflow as tf
from tensorflow.keras.applications.mobilenet_v2 import preprocess_input as mb_preprocess
import numpy as np
import json
from huggingface_hub import hf_hub_download
from PIL import Image

# Fetch the trained network and its label vocabulary from a single Hub repo.
_HUB_REPO_ID = "meetran/painting-classifier-keras-v1"

print("Downloading model from Hugging Face Model Hub...")
model_path = hf_hub_download(
    repo_id=_HUB_REPO_ID,
    filename="wikiart_mobilenetv2_multihead.keras",
)
labels_path = hf_hub_download(repo_id=_HUB_REPO_ID, filename="class_labels.json")
print("Model and labels downloaded successfully")

# The labels file maps each classification head to its list of class names;
# list position i names output unit i of that head.
with open(labels_path, "r", encoding="utf-8") as f:
    class_labels = json.load(f)

artist_names, genre_names, style_names = (
    class_labels[head_key]
    for head_key in ("artist_names", "genre_names", "style_names")
)

# Deserialize the Keras model from the downloaded file.
print("Loading model...")
model = tf.keras.models.load_model(model_path)
print("Model loaded successfully")

# (height, width) every input image is resized to before inference.
IMG_SIZE = (224, 224)

def preprocess_image(image):
    """Turn an input image into a model-ready batch containing one image.

    Args:
        image: A PIL image in any mode (RGB, RGBA, L, P, ...) or an array-like
            of pixels shaped (H, W), (H, W, 1), (H, W, 3) or (H, W, 4).

    Returns:
        A tensor of shape (1, IMG_SIZE[0], IMG_SIZE[1], 3), resized with
        ``tf.image.resize`` and scaled by MobileNetV2's ``preprocess_input``.
    """
    # The backbone needs exactly three colour channels. PNG uploads are often
    # RGBA, and scans may be grayscale or palette images, so normalise first.
    if isinstance(image, Image.Image):
        image = image.convert("RGB")
    img = np.asarray(image)
    if img.ndim == 2:
        # Grayscale (H, W): replicate the single channel into R, G and B.
        img = np.stack([img] * 3, axis=-1)
    elif img.ndim == 3 and img.shape[-1] == 1:
        img = np.repeat(img, 3, axis=-1)
    elif img.ndim == 3 and img.shape[-1] == 4:
        # Drop alpha, matching PIL's RGBA -> RGB conversion (no compositing).
        img = img[..., :3]
    img = tf.image.resize(img, IMG_SIZE)
    img = mb_preprocess(img)
    img = tf.expand_dims(img, axis=0)
    return img

def _probs_to_label_dict(head_output, names):
    """Map one head's output row for a single image to ``{class name: probability}``.

    Args:
        head_output: The row for the single input image from one model head.
        names: Class names ordered like the head's output units.

    Returns:
        A dict (the format ``gr.Label`` consumes) mapping each name to its
        softmax probability as a plain Python float.
    """
    # NOTE(review): softmax is applied on the assumption that the heads emit
    # logits; if the saved model's heads already end in a softmax activation,
    # this re-normalises and flattens the distribution - confirm against the
    # training code.
    probs = tf.nn.softmax(head_output).numpy()
    return {name: float(probs[i]) for i, name in enumerate(names)}


def classify_painting(image):
    """Classify a painting by artist, genre, and style.

    Args:
        image: The image from the ``gr.Image(type="pil")`` input, or ``None``
            when there is no image.

    Returns:
        A tuple ``(artist, genre, style)`` of ``{class name: probability}``
        dicts, or ``(None, None, None)`` when no image was given or inference
        failed.
    """
    if image is None:
        return None, None, None

    try:
        batch = preprocess_image(image)
        # One output per head, looked up by head name; [0] selects the only
        # image in the batch.
        predictions = model.predict(batch, verbose=0)
        return (
            _probs_to_label_dict(predictions['artist'][0], artist_names),
            _probs_to_label_dict(predictions['genre'][0], genre_names),
            _probs_to_label_dict(predictions['style'][0], style_names),
        )
    except Exception as e:
        # Best-effort for the UI: clear the outputs rather than crash the
        # handler, but keep the full traceback in the server log so failures
        # are debuggable.
        import traceback

        print(f"Error during classification: {e}")
        traceback.print_exc()
        return None, None, None

# Create Gradio interface.
# Class counts and the input size shown in the UI are derived from the loaded
# label lists and IMG_SIZE, so the text cannot drift from the model's actual
# vocabulary if class_labels.json changes.
_NOTABLE_ARTISTS = (
    "Claude Monet", "Vincent van Gogh", "Pablo Picasso", "Leonardo da Vinci",
    "Rembrandt", "Salvador Dali", "Michelangelo", "Edgar Degas",
    "Paul Cezanne", "Henri Matisse",
)
_NOTABLE_STYLES = (
    "Impressionism", "Cubism", "Renaissance", "Baroque", "Expressionism",
    "Abstract Expressionism", "Realism", "Pop Art", "Romanticism", "Symbolism",
)

with gr.Blocks(title="WikiArt Painting Classifier", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# WikiArt Painting Classifier")
    gr.Markdown(
        f"Upload a painting image to classify its Artist ({len(artist_names)} classes), "
        f"Genre ({len(genre_names)} classes), and Style ({len(style_names)} classes) "
        "using a MobileNetV2-based multi-task model."
    )
    gr.Markdown(
        "**Model Repository**: [meetran/painting-classifier-keras-v1]"
        "(https://huggingface.co/meetran/painting-classifier-keras-v1)"
    )

    with gr.Row():
        with gr.Column():
            image_input = gr.Image(type="pil", label="Upload Painting Image")
            classify_btn = gr.Button("Classify Painting", variant="primary", size="lg")

            gr.Markdown("### Tips for Best Results")
            gr.Markdown(
                "- Upload clear, high-quality images of paintings\n"
                "- Works best with Western classical and modern art\n"
                f"- Supports paintings from {len(artist_names)} famous artists\n"
                f"- Can identify {len(style_names)} different art styles"
            )

        with gr.Column():
            artist_output = gr.Label(label="Artist Prediction (Top 10)", num_top_classes=10)
            genre_output = gr.Label(label="Genre Prediction", num_top_classes=5)
            style_output = gr.Label(label="Art Style Prediction (Top 10)", num_top_classes=10)

    gr.Markdown("---")
    gr.Markdown("### Model Information")
    gr.Markdown(
        "- **Architecture**: MobileNetV2 (ImageNet pre-trained) with multi-head classification\n"
        "- **Dataset**: WikiArt dataset containing 84,440 paintings\n"
        "- **Training**: Two-stage training (frozen backbone + fine-tuning)\n"
        f"- **Input Size**: {IMG_SIZE[0]}x{IMG_SIZE[1]} RGB images\n"
        "- **Framework**: TensorFlow/Keras\n\n"
        f"**Notable Artists**: {', '.join(_NOTABLE_ARTISTS)}, "
        f"and {len(artist_names) - len(_NOTABLE_ARTISTS)} more.\n\n"
        f"**Art Styles**: {', '.join(_NOTABLE_STYLES)}, "
        f"and {len(style_names) - len(_NOTABLE_STYLES)} more."
    )

    # Both triggers run the same inference into the same three label widgets.
    prediction_outputs = [artist_output, genre_output, style_output]

    # Connect button to function
    classify_btn.click(
        fn=classify_painting,
        inputs=image_input,
        outputs=prediction_outputs,
    )

    # Auto-classify whenever the image changes; a cleared image arrives as
    # None, which classify_painting maps to empty outputs.
    image_input.change(
        fn=classify_painting,
        inputs=image_input,
        outputs=prediction_outputs,
    )

# Launch the app only when run as a script, not when imported.
if __name__ == "__main__":
    demo.launch()