""" FastAPI inference service for waste classification Provides REST API for predictions, feedback collection, and retraining """ from fastapi import FastAPI, HTTPException, BackgroundTasks from fastapi.middleware.cors import CORSMiddleware from pydantic import BaseModel from pathlib import Path import base64 from datetime import datetime import json import sys import os # Add ML directory to path sys.path.append(str(Path(__file__).parent.parent)) from ml.predict import WasteClassifier from ml.retrain import retrain_model app = FastAPI( title="AI Waste Segregation API", description="ML inference service for waste classification", version="1.0.0" ) # CORS middleware app.add_middleware( CORSMiddleware, allow_origins=["*"], # Configure appropriately for production allow_credentials=True, allow_methods=["*"], allow_headers=["*"], ) # Global classifier instance classifier = None MODEL_PATH = Path(__file__).parent.parent / "ml" / "models" / "best_model.pth" RETRAINING_DIR = Path(__file__).parent.parent / "ml" / "data" / "retraining" class PredictionRequest(BaseModel): image: str # Base64 encoded image class PredictionResponse(BaseModel): category: str confidence: float probabilities: dict timestamp: int class FeedbackRequest(BaseModel): image: str predicted_category: str corrected_category: str confidence: float class FeedbackResponse(BaseModel): status: str message: str saved_path: str @app.on_event("startup") async def startup_event(): """Load ML model on startup""" global classifier if not MODEL_PATH.exists(): print(f"Warning: Model not found at {MODEL_PATH}") print("Please train a model first using: python ml/train.py") return try: classifier = WasteClassifier(str(MODEL_PATH)) print(f"Model loaded successfully from {MODEL_PATH}") except Exception as e: print(f"Error loading model: {e}") @app.get("/") async def root(): """Health check endpoint""" return { "status": "online", "service": "AI Waste Segregation API", "model_loaded": classifier is not None, "version": "1.0.0" } @app.get("/health") async def health(): """Detailed health check""" return { "status": "healthy", "model_loaded": classifier is not None, "model_path": str(MODEL_PATH), "timestamp": datetime.now().isoformat() } @app.post("/predict", response_model=PredictionResponse) async def predict(request: PredictionRequest): """ Predict waste category from image Args: request: PredictionRequest with base64 encoded image Returns: PredictionResponse with category, confidence, and probabilities """ if classifier is None: raise HTTPException( status_code=503, detail="Model not loaded. Please train a model first." ) try: # Perform prediction result = classifier.predict(request.image) return PredictionResponse( category=result['category'], confidence=result['confidence'], probabilities=result['probabilities'], timestamp=result['timestamp'] ) except Exception as e: print(f"Prediction error: {e}") raise HTTPException( status_code=500, detail=f"Prediction failed: {str(e)}" ) @app.post("/feedback", response_model=FeedbackResponse) async def save_feedback(request: FeedbackRequest): """ Save user feedback for continuous learning Args: request: FeedbackRequest with image and corrected category Returns: FeedbackResponse with save status """ try: # Create retraining directory for corrected category category_dir = RETRAINING_DIR / request.corrected_category category_dir.mkdir(parents=True, exist_ok=True) # Generate unique filename timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f") filename = f"feedback_{timestamp}.jpg" filepath = category_dir / filename # Decode and save image if request.image.startswith('data:image'): image_data = request.image.split(',')[1] else: image_data = request.image image_bytes = base64.b64decode(image_data) with open(filepath, 'wb') as f: f.write(image_bytes) # Save metadata metadata = { 'timestamp': timestamp, 'predicted_category': request.predicted_category, 'corrected_category': request.corrected_category, 'confidence': request.confidence, 'saved_at': datetime.now().isoformat() } metadata_path = category_dir / f"feedback_{timestamp}.json" with open(metadata_path, 'w') as f: json.dump(metadata, f, indent=2) print(f"Feedback saved: {request.predicted_category} -> {request.corrected_category}") return FeedbackResponse( status="success", message="Feedback saved for retraining", saved_path=str(filepath) ) except Exception as e: print(f"Feedback save error: {e}") raise HTTPException( status_code=500, detail=f"Failed to save feedback: {str(e)}" ) @app.post("/retrain") async def trigger_retrain(background_tasks: BackgroundTasks): """ Trigger model retraining with accumulated feedback Runs as background task to avoid timeout """ # Check if there's feedback to retrain on if not RETRAINING_DIR.exists(): raise HTTPException( status_code=400, detail="No feedback data available for retraining" ) feedback_count = sum(1 for _ in RETRAINING_DIR.rglob('*.jpg')) if feedback_count == 0: raise HTTPException( status_code=400, detail="No feedback samples found for retraining" ) # Add retraining to background tasks background_tasks.add_task(retrain_model) return { "status": "started", "message": f"Retraining initiated with {feedback_count} new samples", "feedback_count": feedback_count } @app.get("/retrain/status") async def get_retrain_status(): """Get retraining history and status""" log_file = Path(__file__).parent.parent / "ml" / "models" / "retraining_log.json" if not log_file.exists(): return { "status": "no_history", "message": "No retraining history available", "events": [] } try: with open(log_file, 'r') as f: log = json.load(f) return { "status": "success", "total_retrains": len(log), "events": log[-10:], # Last 10 events "latest": log[-1] if log else None } except Exception as e: raise HTTPException( status_code=500, detail=f"Failed to read retraining log: {str(e)}" ) @app.get("/stats") async def get_stats(): """Get system statistics""" # Count feedback samples feedback_count = 0 feedback_by_category = {} if RETRAINING_DIR.exists(): for category in classifier.categories if classifier else []: category_dir = RETRAINING_DIR / category if category_dir.exists(): count = len(list(category_dir.glob('*.jpg'))) feedback_by_category[category] = count feedback_count += count return { "model_loaded": classifier is not None, "categories": classifier.categories if classifier else [], "feedback_samples": feedback_count, "feedback_by_category": feedback_by_category, "model_path": str(MODEL_PATH), "model_exists": MODEL_PATH.exists() } if __name__ == "__main__": import uvicorn port = int(os.getenv("PORT", 7860)) uvicorn.run( "inference_service:app", host="0.0.0.0", port=port, reload=True )