BinhQuocNguyen committed on
Commit
bc964a2
·
verified ·
1 Parent(s): 04b4f08

Upload app.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. app.py +298 -0
app.py ADDED
@@ -0,0 +1,298 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Hugging Face Spaces FastAPI Food Recognition Service
4
+ Optimized for Hugging Face Spaces deployment
5
+ """
6
+
7
+ import gradio as gr
8
+ import requests
9
+ import base64
10
+ import io
11
+ from PIL import Image
12
+ import torch
13
+ from transformers import pipeline
14
+ import logging
15
+ from datetime import datetime
16
+ import os
17
+
18
# Logging: INFO-level root configuration plus a module-scoped logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Model state shared across the module; both are rebound by load_model().
classifier = None
model_loaded = False

# Identifier of the model to load and the labels advertised by the API.
MODEL_ID = "BinhQuocNguyen/food-recognition-vit"
FOOD_CLASSES = [
    "apple_pie",
    "caesar_salad",
    "chocolate_cake",
    "cup_cakes",
    "donuts",
    "hamburger",
    "ice_cream",
    "pancakes",
    "pizza",
    "waffles",
]
32
+
33
def load_model():
    """Load the Hugging Face image-classification pipeline.

    On success the module-level ``classifier`` is bound to the pipeline built
    for ``MODEL_ID`` and ``model_loaded`` is set to ``True``. On failure the
    error is logged together with its traceback, ``model_loaded`` is set to
    ``False`` and ``classifier`` is left as it was.

    Returns:
        bool: ``True`` if the pipeline was created, ``False`` otherwise.
    """
    global classifier, model_loaded

    logger.info("Loading model: %s", MODEL_ID)
    try:
        # Keep the try body to the one call that can realistically fail.
        classifier = pipeline(
            "image-classification",
            model=MODEL_ID,
            device=-1,  # Use CPU (change to 0 for GPU)
        )
    except Exception as e:
        # Deliberately broad: this runs at startup and must not take the
        # whole application down. logger.exception keeps the traceback that
        # a plain logger.error(str(e)) would discard.
        logger.exception("Failed to load model: %s", e)
        model_loaded = False
        return False

    model_loaded = True
    logger.info("Model loaded successfully!")
    return True
50
+
51
def preprocess_image(image):
    """Normalize an uploaded image to an RGB image object.

    Args:
        image: One of
            * a file path (``str`` or ``os.PathLike``), opened with
              ``Image.open``;
            * an object exposing a ``convert`` method (treated as an
              already-loaded PIL image);
            * anything else, which is handed to ``Image.fromarray``.

    Returns:
        The image in ``'RGB'`` mode. An input that is already an RGB image
        object is returned unchanged (same object).

    Raises:
        ValueError: If ``image`` is ``None`` or cannot be opened/converted.
            For conversion failures the underlying exception is chained as
            ``__cause__``.
    """
    # Fail early with a clear message instead of letting None fall through to
    # Image.fromarray and surface as an obscure attribute error.
    if image is None:
        raise ValueError("Invalid image format: no image provided")

    try:
        if isinstance(image, (str, os.PathLike)):
            # A filesystem path (plain string or pathlib-style object).
            image = Image.open(image)
        elif not hasattr(image, 'convert'):
            # Not a PIL-like image: assume an array-like that PIL can wrap.
            image = Image.fromarray(image)

        # Convert to RGB only when necessary.
        if image.mode != 'RGB':
            image = image.convert('RGB')
        return image
    except Exception as e:
        # Explicit chaining keeps the root cause visible to callers and logs.
        raise ValueError(f"Invalid image format: {e}") from e
70
+
71
def predict_food(image):
    """Classify a food image and build a Markdown summary of the result.

    Args:
        image: The uploaded image, in any form accepted by
            ``preprocess_image``. ``None`` (nothing uploaded) is reported as
            an error message rather than raising.

    Returns:
        tuple: ``(result_text, processed_image)``.

        * On success ``result_text`` is Markdown containing the first
          prediction with its confidence followed by up to three predictions
          in the order returned by the classifier, and ``processed_image`` is
          the RGB image that was classified.
        * If the model is not loaded, no image was supplied, the classifier
          returned nothing, or any exception occurred, ``result_text`` holds
          the message and ``processed_image`` is ``None``.
    """
    if not model_loaded:
        return "Model not loaded. Please try again.", None

    # Clicking the button without uploading anything passes None through.
    if image is None:
        return "❌ Error: No image provided. Please upload an image first.", None

    try:
        processed_image = preprocess_image(image)
        results = classifier(processed_image)

        # Guard against an empty result so predictions[0] cannot raise.
        if not results:
            return "❌ Error: The model returned no predictions.", None

        predictions = [
            {'label': result['label'], 'confidence': result['score']}
            for result in results
        ]
        top_prediction = predictions[0]

        # NOTE: the emoji below replace mis-encoded byte sequences (mojibake)
        # that were present in the user-facing strings.
        lines = [
            f"🍕 **Predicted Food:** {top_prediction['label'].replace('_', ' ').title()}\n",
            f"🎯 **Confidence:** {top_prediction['confidence'] * 100:.1f}%\n\n",
            "**Top 3 Predictions:**\n",
        ]
        for i, pred in enumerate(predictions[:3], 1):
            food_name = pred['label'].replace('_', ' ').title()
            conf_percent = pred['confidence'] * 100
            lines.append(f"{i}. {food_name}: {conf_percent:.1f}%\n")

        return "".join(lines), processed_image

    except Exception as e:
        # UI boundary: report the problem to the user instead of crashing the
        # request, but keep the traceback in the server log.
        logger.exception("Prediction error: %s", e)
        return f"❌ Error: {str(e)}", None
110
+
111
def get_model_info():
    """Describe the served model as a plain dictionary.

    Returns:
        dict: ``model_id``, ``model_url`` (the Hugging Face page derived from
        ``MODEL_ID``), ``classes`` (the ``FOOD_CLASSES`` list itself),
        ``num_classes`` and a fixed ``device`` value of ``"cpu"``.
    """
    info = dict(
        model_id=MODEL_ID,
        model_url="https://huggingface.co/" + MODEL_ID,
        classes=FOOD_CLASSES,
        num_classes=len(FOOD_CLASSES),
        device="cpu",
    )
    return info
120
+
121
# Load the model once, at import time. load_model() reports failure through
# its return value and the model_loaded flag rather than raising, so the
# result is intentionally not checked here.
load_model()
123
+
124
+ # Create Gradio interface
125
def create_interface():
    """Create the Gradio interface.

    Builds a two-column ``gr.Blocks`` layout: image upload, predict button and
    static model information on the left; the Markdown result and the
    processed image on the right. The predict button is wired to
    ``predict_food``.

    Example images are only offered for paths that exist on disk, so a
    deployment that ships without the sample data still starts cleanly.

    Returns:
        gr.Blocks: The assembled interface (not launched).
    """
    # These paths point into a training-data tree that is not necessarily
    # deployed with the app; keep only the files that are really there.
    example_paths = [
        "food_recognition_model/data/processed/val/apple_pie/apple_pie_000.jpg",
        "food_recognition_model/data/processed/val/pizza/pizza_000.jpg",
        "food_recognition_model/data/processed/val/hamburger/hamburger_000.jpg",
    ]
    available_examples = [[path] for path in example_paths if os.path.exists(path)]

    # NOTE: the emoji in the UI strings below replace mis-encoded byte
    # sequences (mojibake) that were present in the original literals.
    with gr.Blocks(
        title="Food Recognition API",
        theme=gr.themes.Soft(),
        css="""
        .gradio-container {
            max-width: 800px !important;
            margin: auto !important;
        }
        """
    ) as interface:

        gr.Markdown("""
        # 🍕 Food Recognition API

        Upload an image of food and get instant predictions! This API uses a Vision Transformer model
        trained to recognize 10 different types of food.

        **Supported Food Types:** Apple Pie, Caesar Salad, Chocolate Cake, Cup Cakes, Donuts,
        Hamburger, Ice Cream, Pancakes, Pizza, Waffles
        """)

        with gr.Row():
            with gr.Column(scale=1):
                image_input = gr.Image(
                    label="Upload Food Image",
                    type="pil",
                    height=300
                )

                predict_btn = gr.Button(
                    "🔍 Predict Food",
                    variant="primary",
                    size="lg"
                )

                gr.Markdown("""
                ### 📊 Model Information
                - **Model:** Vision Transformer (ViT)
                - **Accuracy:** 68%
                - **Classes:** 10 food types
                - **Source:** [Hugging Face Model](https://huggingface.co/BinhQuocNguyen/food-recognition-vit)
                """)

            with gr.Column(scale=1):
                output_text = gr.Markdown(
                    label="Prediction Results",
                    value="👆 Upload an image and click 'Predict Food' to get started!"
                )

                output_image = gr.Image(
                    label="Processed Image",
                    height=300
                )

        # Example images (section is omitted entirely when none are available)
        if available_examples:
            gr.Markdown("### 📸 Example Images")
            gr.Examples(
                examples=available_examples,
                inputs=image_input,
                label="Click on an example to test"
            )

        # Event handlers
        predict_btn.click(
            fn=predict_food,
            inputs=image_input,
            outputs=[output_text, output_image]
        )

        # Footer
        gr.Markdown("""
        ---
        **Built with:** FastAPI, Gradio, Hugging Face Transformers, PyTorch

        **Model Performance:** 68% accuracy on 10 food classes

        **API Endpoints:** Available at `/docs` for programmatic access
        """)

    return interface
212
+
213
# Build the Gradio UI once at import time; it is mounted on the FastAPI
# application further down in this module.
interface = create_interface()
215
+
216
+ # FastAPI app for additional endpoints
217
+ from fastapi import FastAPI
218
+ from fastapi.middleware.cors import CORSMiddleware
219
+ from pydantic import BaseModel
220
+ from typing import List, Optional
221
+ import uvicorn
222
+
223
# Initialize FastAPI app
app = FastAPI(
    title="Food Recognition API",
    description="API for food recognition using Hugging Face Vision Transformer model",
    version="1.0.0"
)

# Add CORS middleware.
# Any origin may call the API, so credentials are disabled: the CORS
# specification does not permit a wildcard origin together with credentialed
# requests, and none of the routes defined in this module rely on cookies or
# authorization headers. If credentials are ever needed, list the allowed
# origins explicitly instead of using "*".
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)
238
+
239
+ # Pydantic models
240
# A single classification outcome: the class label and its score.
# (Documented with comments rather than a docstring so the generated schema
# description is left untouched.)
class PredictionResult(BaseModel):
    label: str  # class name as reported by the classifier
    confidence: float  # score attached to that label
243
+
244
# Response envelope for a prediction request: the list of results plus
# timing and model metadata. No route in the visible part of this module
# returns it yet.
class PredictionResponse(BaseModel):
    predictions: List[PredictionResult]  # one entry per predicted label
    processing_time: float  # NOTE(review): unit is not defined here - confirm
    # NOTE(review): with Pydantic v2, field names starting with "model_" may
    # trigger a protected-namespace warning - confirm the installed version.
    model_info: dict
248
+
249
# Payload returned by the /health route.
class HealthResponse(BaseModel):
    status: str  # "healthy" or "unhealthy", as set by health_check()
    model_loaded: bool  # mirrors the module-level model_loaded flag
    timestamp: str  # ISO-8601 string produced with datetime.now().isoformat()
    # Model metadata; health_check() leaves it as None when no model is loaded.
    model_info: Optional[dict] = None
254
+
255
+ # FastAPI routes
256
@app.get("/")
async def root():
    """Root endpoint"""
    # NOTE(review): the Gradio app is mounted at path "/" later in this
    # module. Because this route is registered first, it presumably answers
    # GET / instead of the Gradio UI - confirm that this is intended.
    service_summary = dict(
        message="Food Recognition API",
        version="1.0.0",
        model=MODEL_ID,
        gradio_interface="/",
        api_docs="/docs",
    )
    return service_summary
266
+
267
@app.get("/health", response_model=HealthResponse)
async def health_check():
    """Health check endpoint"""
    # Status and model metadata both follow the model_loaded flag.
    if model_loaded:
        current_status = "healthy"
    else:
        current_status = "unhealthy"

    checked_at = datetime.now().isoformat()
    details = get_model_info() if model_loaded else None

    return HealthResponse(
        status=current_status,
        model_loaded=model_loaded,
        timestamp=checked_at,
        model_info=details,
    )
276
+
277
@app.get("/classes")
async def get_classes():
    """Get supported food classes"""
    # Served straight from the module constants; no model access is needed.
    class_count = len(FOOD_CLASSES)
    return dict(
        classes=FOOD_CLASSES,
        num_classes=class_count,
        model_id=MODEL_ID,
    )
285
+
286
@app.get("/model-info")
async def get_model_information():
    """Get detailed model information"""
    # When the model is unavailable an error payload is returned in place of
    # the metadata (the response is still a normal JSON body).
    return get_model_info() if model_loaded else {"error": "Model not loaded"}
292
+
293
# Attach the Gradio UI to the FastAPI application at the site root and keep
# the returned application under the same module-level name.
app = gr.mount_gradio_app(app, interface, path="/")

if __name__ == "__main__":
    # Local development entry point: serve on every interface, port 7860.
    server_options = {"host": "0.0.0.0", "port": 7860}
    uvicorn.run(app, **server_options)