| """ | |
| FastAPI inference service for waste classification | |
| Provides REST API for predictions, feedback collection, and retraining | |
| """ | |
| from fastapi import FastAPI, HTTPException, BackgroundTasks | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from pydantic import BaseModel | |
| from pathlib import Path | |
| import base64 | |
| from datetime import datetime | |
| import json | |
| import sys | |
| import os | |
| # Add ML directory to path | |
| sys.path.append(str(Path(__file__).parent.parent)) | |
| from ml.predict import WasteClassifier | |
| from ml.retrain import retrain_model | |
app = FastAPI(
    title="AI Waste Segregation API",
    description="ML inference service for waste classification",
    version="1.0.0"
)

# CORS middleware
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # Configure appropriately for production
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
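# A minimal sketch of a stricter production CORS setup (the origin below is an
# assumption for illustration, not part of this project's configuration):
#
#   app.add_middleware(
#       CORSMiddleware,
#       allow_origins=["https://your-frontend.example.com"],
#       allow_credentials=True,
#       allow_methods=["GET", "POST"],
#       allow_headers=["Content-Type"],
#   )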
# Global classifier instance
classifier = None
MODEL_PATH = Path(__file__).parent.parent / "ml" / "models" / "best_model.pth"
RETRAINING_DIR = Path(__file__).parent.parent / "ml" / "data" / "retraining"
class PredictionRequest(BaseModel):
    image: str  # Base64 encoded image


class PredictionResponse(BaseModel):
    category: str
    confidence: float
    probabilities: dict
    timestamp: int


class FeedbackRequest(BaseModel):
    image: str
    predicted_category: str
    corrected_category: str
    confidence: float


class FeedbackResponse(BaseModel):
    status: str
    message: str
    saved_path: str
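# Illustrative request/response payloads for the models above (field values are
# made up, and the category names depend on the trained model's label set):
#
#   POST /predict   {"image": "<base64-encoded JPEG>"}
#   ->              {"category": "plastic", "confidence": 0.93,
#                    "probabilities": {"plastic": 0.93, "paper": 0.04, ...},
#                    "timestamp": 1700000000}
#
#   POST /feedback  {"image": "<base64>", "predicted_category": "plastic",
#                    "corrected_category": "metal", "confidence": 0.61}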
@app.on_event("startup")
async def startup_event():
    """Load the ML model on startup."""
    global classifier

    if not MODEL_PATH.exists():
        print(f"Warning: Model not found at {MODEL_PATH}")
        print("Please train a model first using: python ml/train.py")
        return

    try:
        classifier = WasteClassifier(str(MODEL_PATH))
        print(f"Model loaded successfully from {MODEL_PATH}")
    except Exception as e:
        print(f"Error loading model: {e}")
@app.get("/")
async def root():
    """Health check endpoint."""
    return {
        "status": "online",
        "service": "AI Waste Segregation API",
        "model_loaded": classifier is not None,
        "version": "1.0.0"
    }
@app.get("/health")
async def health():
    """Detailed health check."""
    return {
        "status": "healthy",
        "model_loaded": classifier is not None,
        "model_path": str(MODEL_PATH),
        "timestamp": datetime.now().isoformat()
    }
@app.post("/predict", response_model=PredictionResponse)
async def predict(request: PredictionRequest):
    """
    Predict the waste category for an image.

    Args:
        request: PredictionRequest with a base64-encoded image

    Returns:
        PredictionResponse with category, confidence, and per-class probabilities
    """
    if classifier is None:
        raise HTTPException(
            status_code=503,
            detail="Model not loaded. Please train a model first."
        )

    try:
        # Perform prediction
        result = classifier.predict(request.image)
        return PredictionResponse(
            category=result['category'],
            confidence=result['confidence'],
            probabilities=result['probabilities'],
            timestamp=result['timestamp']
        )
    except Exception as e:
        print(f"Prediction error: {e}")
        raise HTTPException(
            status_code=500,
            detail=f"Prediction failed: {str(e)}"
        )
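# Minimal client sketch for the endpoint above (assumes the service is reachable
# at http://localhost:7860, that "sample.jpg" exists, and that `requests` is
# installed):
#
#   import base64, requests
#   with open("sample.jpg", "rb") as f:
#       payload = {"image": base64.b64encode(f.read()).decode("utf-8")}
#   resp = requests.post("http://localhost:7860/predict", json=payload, timeout=30)
#   resp.raise_for_status()
#   print(resp.json()["category"], resp.json()["confidence"])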
@app.post("/feedback", response_model=FeedbackResponse)
async def save_feedback(request: FeedbackRequest):
    """
    Save user feedback for continuous learning.

    Args:
        request: FeedbackRequest with the image and corrected category

    Returns:
        FeedbackResponse with the save status
    """
    try:
        # Create the retraining directory for the corrected category
        category_dir = RETRAINING_DIR / request.corrected_category
        category_dir.mkdir(parents=True, exist_ok=True)

        # Generate a unique filename
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
        filename = f"feedback_{timestamp}.jpg"
        filepath = category_dir / filename

        # Decode and save the image (strip a data-URL prefix if present)
        if request.image.startswith('data:image'):
            image_data = request.image.split(',')[1]
        else:
            image_data = request.image

        image_bytes = base64.b64decode(image_data)
        with open(filepath, 'wb') as f:
            f.write(image_bytes)

        # Save metadata alongside the image
        metadata = {
            'timestamp': timestamp,
            'predicted_category': request.predicted_category,
            'corrected_category': request.corrected_category,
            'confidence': request.confidence,
            'saved_at': datetime.now().isoformat()
        }
        metadata_path = category_dir / f"feedback_{timestamp}.json"
        with open(metadata_path, 'w') as f:
            json.dump(metadata, f, indent=2)

        print(f"Feedback saved: {request.predicted_category} -> {request.corrected_category}")

        return FeedbackResponse(
            status="success",
            message="Feedback saved for retraining",
            saved_path=str(filepath)
        )
    except Exception as e:
        print(f"Feedback save error: {e}")
        raise HTTPException(
            status_code=500,
            detail=f"Failed to save feedback: {str(e)}"
        )
@app.post("/retrain")
async def trigger_retrain(background_tasks: BackgroundTasks):
    """
    Trigger model retraining with accumulated feedback.

    Runs as a background task to avoid request timeouts.
    """
    # Check whether there is any feedback to retrain on
    if not RETRAINING_DIR.exists():
        raise HTTPException(
            status_code=400,
            detail="No feedback data available for retraining"
        )

    feedback_count = sum(1 for _ in RETRAINING_DIR.rglob('*.jpg'))
    if feedback_count == 0:
        raise HTTPException(
            status_code=400,
            detail="No feedback samples found for retraining"
        )

    # Schedule retraining as a background task
    background_tasks.add_task(retrain_model)

    return {
        "status": "started",
        "message": f"Retraining initiated with {feedback_count} new samples",
        "feedback_count": feedback_count
    }
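# Example interaction with the retraining endpoints (a sketch; assumes the
# service runs on localhost:7860 and curl is available):
#
#   curl -X POST http://localhost:7860/retrain
#   curl http://localhost:7860/retrain/status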
@app.get("/retrain/status")
async def get_retrain_status():
    """Get retraining history and status."""
    log_file = Path(__file__).parent.parent / "ml" / "models" / "retraining_log.json"

    if not log_file.exists():
        return {
            "status": "no_history",
            "message": "No retraining history available",
            "events": []
        }

    try:
        with open(log_file, 'r') as f:
            log = json.load(f)
        return {
            "status": "success",
            "total_retrains": len(log),
            "events": log[-10:],  # Last 10 events
            "latest": log[-1] if log else None
        }
    except Exception as e:
        raise HTTPException(
            status_code=500,
            detail=f"Failed to read retraining log: {str(e)}"
        )
@app.get("/stats")
async def get_stats():
    """Get system statistics."""
    # Count feedback samples per category
    feedback_count = 0
    feedback_by_category = {}

    if RETRAINING_DIR.exists():
        for category in (classifier.categories if classifier else []):
            category_dir = RETRAINING_DIR / category
            if category_dir.exists():
                count = len(list(category_dir.glob('*.jpg')))
                feedback_by_category[category] = count
                feedback_count += count

    return {
        "model_loaded": classifier is not None,
        "categories": classifier.categories if classifier else [],
        "feedback_samples": feedback_count,
        "feedback_by_category": feedback_by_category,
        "model_path": str(MODEL_PATH),
        "model_exists": MODEL_PATH.exists()
    }
| if __name__ == "__main__": | |
| import uvicorn | |
| port = int(os.getenv("PORT", 7860)) | |
| uvicorn.run( | |
| "inference_service:app", | |
| host="0.0.0.0", | |
| port=port, | |
| reload=True | |
| ) | |
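# To run locally (a sketch; assumes FastAPI, uvicorn, and the project's ML
# dependencies are installed, and that you start from the directory containing
# inference_service.py):
#
#   python inference_service.py
#   # or, equivalently:
#   uvicorn inference_service:app --host 0.0.0.0 --port 7860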