Spaces:
Running
Running
| from fastapi import FastAPI, File, UploadFile, HTTPException | |
| from fastapi.responses import JSONResponse | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from transformers import pipeline | |
| from PIL import Image | |
| import io | |
| import logging | |
| from datetime import datetime | |
| import asyncio | |
# --- Logging ----------------------------------------------------------------
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# --- Application ------------------------------------------------------------
app = FastAPI(title="Age Detection API", version="1.0.0")

# Cross-origin access so a browser-hosted front end can call this API directly.
# In production, replace the wildcard with the real front-end (e.g. FlutterFlow)
# domain.
# NOTE(review): a wildcard origin together with allow_credentials=True lets any
# website make credentialed cross-origin calls (Starlette's CORSMiddleware
# echoes the caller's Origin in that configuration instead of sending "*").
# Nothing in this file reads cookies or auth headers, so narrowing the origins
# or disabling credentials is likely safe -- confirm with the front end first.
_cors_options = {
    "allow_origins": ["*"],
    "allow_credentials": True,
    "allow_methods": ["GET", "POST"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_cors_options)

# --- Model state ------------------------------------------------------------
# The Hugging Face pipeline; stays None until load_model() succeeds.
pipe = None
def load_model():
    """Create the age-classification pipeline and store it in the module-level ``pipe``.

    The call is synchronous. Building the pipeline may download weights from the
    Hugging Face Hub on first use, so it is presumably slow.

    Returns:
        True if the pipeline was created and assigned to ``pipe``; False if
        creating it raised. Failures are logged with their traceback rather than
        propagated, so callers must check the return value.
    """
    global pipe
    logger.info("Loading age classification model...")
    try:
        pipe = pipeline("image-classification", model="nateraw/vit-age-classifier")
    except Exception as e:
        # Deliberately broad: any failure here leaves the service running
        # without a model. Callers decide whether to retry or report an error.
        # logger.exception keeps the traceback that logger.error would drop.
        logger.exception("Failed to load model: %s", e)
        return False
    logger.info("Model loaded successfully")
    return True
# Load the model once when the application starts.
# NOTE(review): no @app.on_event("startup") (or lifespan) registration is
# visible for this coroutine; unless it is registered elsewhere, FastAPI never
# calls it.
async def startup_event():
    """Eagerly load the model, logging but tolerating failure.

    If this attempt fails the app still starts; predict() and warmup() call
    load_model() again when ``pipe`` is still None.
    """
    if load_model():
        return
    logger.error("Failed to initialize model on startup")
# NOTE(review): no route decorator (e.g. @app.get("/")) is visible in this copy
# of the file; FastAPI only serves handlers registered that way.
async def root():
    """Return a fixed payload confirming the API process is up."""
    payload = {"message": "Age Detection API is running", "status": "healthy"}
    return payload
# NOTE(review): no route decorator is visible in this copy of the file; FastAPI
# only serves handlers registered via @app.get(...)/@app.post(...).
async def health_check():
    """Keep-alive endpoint: report liveness and whether the model is loaded.

    Intended as a cheap target for periodic pings so the host does not idle
    the app.

    Returns:
        dict with ``status`` ("alive"), ``timestamp`` (naive local time,
        ISO-8601) and ``model_status`` ("loaded" or "not_loaded").
    """
    # ``pipe`` is only read here, so no ``global`` declaration is required.
    is_loaded = pipe is not None
    return {
        "status": "alive",
        "timestamp": datetime.now().isoformat(),
        "model_status": "loaded" if is_loaded else "not_loaded",
    }
# Filename suffixes accepted as a fallback when the Content-Type header is
# missing or wrong.
_IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.bmp', '.gif', '.webp')
# Bounding box for model input; larger images are downscaled to fit, for speed.
_MAX_IMAGE_SIZE = (1024, 1024)
# Seconds a request waits for the model before answering 504.
_PREDICTION_TIMEOUT_S = 30.0


def _looks_like_image(file):
    """Return True if the upload's Content-Type OR its filename suggests an image.

    Both signals are client-supplied, so this is only a cheap pre-filter. The
    authoritative check is the Pillow decode in ``_decode_image``.
    """
    content_type = file.content_type or ''
    filename = (file.filename or '').lower()
    return content_type.startswith('image/') or filename.endswith(_IMAGE_EXTENSIONS)


def _decode_image(image_data: bytes):
    """Decode raw upload bytes into an RGB PIL image that fits in ``_MAX_IMAGE_SIZE``.

    Raises:
        HTTPException: 400 if Pillow cannot open, verify, convert or resize the data.
    """
    try:
        # verify() validates the file but leaves that Image object unusable,
        # so the bytes are opened a second time for actual decoding.
        Image.open(io.BytesIO(image_data)).verify()
        image = Image.open(io.BytesIO(image_data)).convert("RGB")
        if image.size[0] > _MAX_IMAGE_SIZE[0] or image.size[1] > _MAX_IMAGE_SIZE[1]:
            # thumbnail() shrinks in place and preserves the aspect ratio.
            image.thumbnail(_MAX_IMAGE_SIZE, Image.Resampling.LANCZOS)
            logger.info("Resized image to %s", image.size)
    except Exception as e:
        # Bad client data rather than a server fault: log without a traceback.
        logger.error("Image processing error: %s", e)
        raise HTTPException(status_code=400, detail="Invalid or corrupted image file") from e
    return image


async def _run_model(image):
    """Run the loaded pipeline on *image* in a worker thread, with a timeout.

    Returns:
        The pipeline's output for the single image.

    Raises:
        HTTPException: 504 if the call exceeds ``_PREDICTION_TIMEOUT_S``,
            500 if the pipeline raises.
    """
    try:
        logger.info("Running model prediction...")
        # wait_for only stops *waiting* at the deadline. A thread started by
        # to_thread cannot be cancelled, so a timed-out inference keeps running
        # in the background until it finishes.
        results = await asyncio.wait_for(
            asyncio.to_thread(pipe, image),
            timeout=_PREDICTION_TIMEOUT_S,
        )
        logger.info("Prediction completed: %d results", len(results))
    except asyncio.TimeoutError as e:
        logger.error("Model prediction timed out")
        raise HTTPException(status_code=504, detail="Prediction timed out") from e
    except Exception as e:
        logger.exception("Model prediction error: %s", e)
        raise HTTPException(status_code=500, detail="Prediction failed") from e
    return results


# NOTE(review): no route decorator (e.g. @app.post(...)) is visible in this copy
# of the file; FastAPI only serves handlers registered that way.
async def predict(file: UploadFile = File(...)):
    """Classify the apparent age in an uploaded image.

    Args:
        file: multipart upload. It must look like an image by Content-Type or
            filename, and must decode with Pillow.

    Returns:
        JSONResponse with the following keys:
        - ``results``: the pipeline output.
        - ``timestamp``: naive local time in ISO-8601.
        - ``image_size``: (width, height) of the image actually fed to the
          model, i.e. after any downscaling.

    Raises:
        HTTPException: 400 for non-image or undecodable uploads; 500 if the
            model cannot be loaded, inference fails, or anything unexpected
            happens; 504 if inference times out.
    """
    try:
        if pipe is None:
            logger.warning("Model not loaded, attempting to load...")
            # NOTE(review): load_model() is synchronous and runs on the event
            # loop, stalling all other requests until it returns. Offloading it
            # with asyncio.to_thread would also need a lock so that concurrent
            # requests do not load the model several times in parallel.
            if not load_model():
                raise HTTPException(status_code=500, detail="Model failed to load")

        # Content-Type and filename are both client-controlled; accept if either
        # looks right and let the Pillow decode be the real validation.
        if not _looks_like_image(file):
            logger.warning(
                "Invalid file type: content_type=%s, filename=%s",
                file.content_type,
                file.filename,
            )
            raise HTTPException(status_code=400, detail="File must be an image")

        logger.info("Processing image: %s", file.filename)
        image_data = await file.read()
        # Pillow decoding/resizing is CPU-bound; run it off the event loop so
        # other requests (including health checks) are not blocked by a large upload.
        image = await asyncio.to_thread(_decode_image, image_data)

        results = await _run_model(image)
        return JSONResponse(content={
            "results": results,
            "timestamp": datetime.now().isoformat(),
            "image_size": image.size,
        })
    except HTTPException:
        # Already mapped to a client-facing status; pass through unchanged.
        raise
    except Exception as e:
        logger.exception("Unexpected error: %s", e)
        raise HTTPException(status_code=500, detail="Internal server error") from e
# Optional endpoint that triggers (or confirms) model loading ahead of the first
# prediction.
# NOTE(review): no route decorator is visible in this copy of the file; FastAPI
# only serves handlers registered via @app.get(...)/@app.post(...).
async def warmup():
    """Make sure the model is loaded, loading it now if necessary.

    Returns:
        {"status": "already_loaded"} if ``pipe`` was already set; otherwise
        {"status": "loaded"} or {"status": "failed"}, depending on whether
        load_model() succeeded on this call.
    """
    # ``pipe`` is only read here (load_model() does the assignment).
    if pipe is not None:
        return {"status": "already_loaded"}
    return {"status": "loaded" if load_model() else "failed"}