meetran commited on
Commit
3a30bf7
·
verified ·
1 Parent(s): dd3e61b

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +138 -0
app.py ADDED
@@ -0,0 +1,138 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import tensorflow as tf
3
+ from tensorflow.keras.applications.mobilenet_v2 import preprocess_input as mb_preprocess
4
+ import numpy as np
5
+ import json
6
+ from huggingface_hub import hf_hub_download
7
+ from PIL import Image
8
+
9
+ # Download model and labels from Model Hub repository
10
+ print("Downloading model from Hugging Face Model Hub...")
11
+ model_path = hf_hub_download(
12
+ repo_id="meetran/painting-classifier-keras-v1",
13
+ filename="wikiart_mobilenetv2_multihead.keras"
14
+ )
15
+ labels_path = hf_hub_download(
16
+ repo_id="meetran/painting-classifier-keras-v1",
17
+ filename="class_labels.json"
18
+ )
19
+
20
+ print("Model and labels downloaded successfully")
21
+
22
+ # Load class labels
23
+ with open(labels_path, "r", encoding="utf-8") as f:
24
+ class_labels = json.load(f)
25
+
26
+ artist_names = class_labels["artist_names"]
27
+ genre_names = class_labels["genre_names"]
28
+ style_names = class_labels["style_names"]
29
+
30
+ # Load the trained model
31
+ print("Loading model...")
32
+ model = tf.keras.models.load_model(model_path)
33
+ print("Model loaded successfully")
34
+
35
+ IMG_SIZE = (224, 224)
36
+
37
+ def preprocess_image(image):
38
+ """Preprocess input image for model inference"""
39
+ img = np.array(image)
40
+ img = tf.image.resize(img, IMG_SIZE)
41
+ img = mb_preprocess(img)
42
+ img = tf.expand_dims(img, axis=0)
43
+ return img
44
+
45
+ def classify_painting(image):
46
+ """Classify painting by artist, genre, and style"""
47
+ if image is None:
48
+ return None, None, None
49
+
50
+ try:
51
+ # Preprocess image
52
+ processed_img = preprocess_image(image)
53
+
54
+ # Get predictions
55
+ predictions = model.predict(processed_img, verbose=0)
56
+
57
+ # Process artist predictions
58
+ artist_probs = tf.nn.softmax(predictions['artist'][0]).numpy()
59
+ artist_dict = {artist_names[i]: float(artist_probs[i])
60
+ for i in range(len(artist_names))}
61
+
62
+ # Process genre predictions
63
+ genre_probs = tf.nn.softmax(predictions['genre'][0]).numpy()
64
+ genre_dict = {genre_names[i]: float(genre_probs[i])
65
+ for i in range(len(genre_names))}
66
+
67
+ # Process style predictions
68
+ style_probs = tf.nn.softmax(predictions['style'][0]).numpy()
69
+ style_dict = {style_names[i]: float(style_probs[i])
70
+ for i in range(len(style_names))}
71
+
72
+ return artist_dict, genre_dict, style_dict
73
+
74
+ except Exception as e:
75
+ print(f"Error during classification: {e}")
76
+ return None, None, None
77
+
78
+ # Create Gradio interface
79
+ with gr.Blocks(title="WikiArt Painting Classifier", theme=gr.themes.Soft()) as demo:
80
+ gr.Markdown("# WikiArt Painting Classifier")
81
+ gr.Markdown(
82
+ "Upload a painting image to classify its Artist (129 classes), "
83
+ "Genre (11 classes), and Style (27 classes) using a MobileNetV2-based multi-task model."
84
+ )
85
+ gr.Markdown(
86
+ "**Model Repository**: [meetran/painting-classifier-keras-v1]"
87
+ "(https://huggingface.co/meetran/painting-classifier-keras-v1)"
88
+ )
89
+
90
+ with gr.Row():
91
+ with gr.Column():
92
+ image_input = gr.Image(type="pil", label="Upload Painting Image")
93
+ classify_btn = gr.Button("Classify Painting", variant="primary", size="lg")
94
+
95
+ gr.Markdown("### Tips for Best Results")
96
+ gr.Markdown(
97
+ "- Upload clear, high-quality images of paintings\n"
98
+ "- Works best with Western classical and modern art\n"
99
+ "- Supports paintings from 129 famous artists\n"
100
+ "- Can identify 27 different art styles"
101
+ )
102
+
103
+ with gr.Column():
104
+ artist_output = gr.Label(label="Artist Prediction (Top 10)", num_top_classes=10)
105
+ genre_output = gr.Label(label="Genre Prediction", num_top_classes=5)
106
+ style_output = gr.Label(label="Art Style Prediction (Top 10)", num_top_classes=10)
107
+
108
+ gr.Markdown("---")
109
+ gr.Markdown("### Model Information")
110
+ gr.Markdown(
111
+ "- **Architecture**: MobileNetV2 (ImageNet pre-trained) with multi-head classification\n"
112
+ "- **Dataset**: WikiArt dataset containing 84,440 paintings\n"
113
+ "- **Training**: Two-stage training (frozen backbone + fine-tuning)\n"
114
+ "- **Input Size**: 224x224 RGB images\n"
115
+ "- **Framework**: TensorFlow/Keras\n\n"
116
+ "**Notable Artists**: Claude Monet, Vincent van Gogh, Pablo Picasso, Leonardo da Vinci, "
117
+ "Rembrandt, Salvador Dali, Michelangelo, Edgar Degas, Paul Cezanne, Henri Matisse, and 119 more.\n\n"
118
+ "**Art Styles**: Impressionism, Cubism, Renaissance, Baroque, Expressionism, "
119
+ "Abstract Expressionism, Realism, Pop Art, Romanticism, Symbolism, and 17 more."
120
+ )
121
+
122
+ # Connect button to function
123
+ classify_btn.click(
124
+ fn=classify_painting,
125
+ inputs=image_input,
126
+ outputs=[artist_output, genre_output, style_output]
127
+ )
128
+
129
+ # Auto-classify on image upload
130
+ image_input.change(
131
+ fn=classify_painting,
132
+ inputs=image_input,
133
+ outputs=[artist_output, genre_output, style_output]
134
+ )
135
+
136
+ # Launch the app
137
+ if __name__ == "__main__":
138
+ demo.launch()