"""FastAPI service exposing a pickled XGBoost headache-risk classifier.

The model artifact is downloaded from the Hugging Face Hub at startup and
kept in module-level globals; endpoints serve single-day and batch (7-day)
predictions using an optional tuned decision threshold stored alongside
the model.
"""

import logging
import os
import pickle
from typing import List

import numpy as np
from fastapi import FastAPI, HTTPException
from huggingface_hub import hf_hub_download
from pydantic import BaseModel

logger = logging.getLogger(__name__)

app = FastAPI(title="Headache Predictor API")

# Global variables for model and threshold.
# clf stays None if loading fails; /health and both predict endpoints
# check this before serving, so the app degrades gracefully.
clf = None
threshold = 0.5


# --- Pydantic Models --- #
class SinglePredictionRequest(BaseModel):
    # Flat feature vector for one day (order must match training).
    features: List[float]


class SinglePredictionResponse(BaseModel):
    # 0/1 class decision and the class-1 (headache) probability.
    prediction: int
    probability: float


class BatchPredictionRequest(BaseModel):
    # One feature vector per day, all the same length.
    instances: List[List[float]]


class DayPrediction(BaseModel):
    # 1-based day index within the submitted batch.
    day: int
    prediction: int
    probability: float


class BatchPredictionResponse(BaseModel):
    predictions: List[DayPrediction]


# --- Startup Event --- #
# NOTE(review): @app.on_event is deprecated in recent FastAPI versions in
# favor of lifespan handlers; kept as-is for compatibility — confirm the
# installed FastAPI version before migrating.
@app.on_event("startup")
async def load_model():
    """Download and unpickle the model, populating ``clf`` and ``threshold``.

    Failures are logged rather than raised on purpose: the app still starts
    and /health reports ``model_loaded: false`` instead of crashing.
    """
    global clf, threshold
    try:
        cache_dir = "/tmp/hf_cache"
        os.makedirs(cache_dir, exist_ok=True)

        # Token is optional; needed only if the Hub repo is private.
        hf_token = os.environ.get("HF_TOKEN")
        model_path = hf_hub_download(
            repo_id="emp-admin/headache-predictor-xgboost",
            filename="model.pkl",
            cache_dir=cache_dir,
            token=hf_token,
        )

        # SECURITY: pickle.load executes arbitrary code embedded in the
        # artifact. Acceptable only because the Hub repo is first-party;
        # do not point this at untrusted repos.
        with open(model_path, "rb") as f:
            model_data = pickle.load(f)

        if isinstance(model_data, dict):
            # Newer artifacts bundle the estimator with a tuned threshold.
            clf = model_data["model"]
            # Load threshold if available, otherwise default to 0.5
            threshold = float(model_data.get("optimal_threshold", 0.5))
            logger.info("✅ Model loaded (optimal_threshold=%s)", threshold)
        else:
            # Legacy artifacts are the bare estimator.
            clf = model_data
            threshold = 0.5
            logger.info("✅ Model loaded (threshold=0.5 default)")
    except Exception:
        # Deliberate best-effort: log the full traceback and keep serving
        # so /health can surface the missing model.
        logger.exception("❌ Error loading model")


# --- Endpoints --- #
@app.get("/")
def read_root():
    """Service banner: status, endpoint index, and example request bodies."""
    return {
        "message": "Headache Predictor API",
        "status": "running",
        "endpoints": {
            "predict": "/predict - Single day prediction",
            "predict_batch": "/predict/batch - 7-day forecast",
            "health": "/health",
        },
        "examples": {
            "single": {
                "url": "/predict",
                # Example shortened for brevity in display
                "body": {
                    "features": [
                        1, 0, 0, 0, 1, 0, 1005.0, -9.5, 85.0, 15.5, 64.0,
                        5.5, 41.0, 0.0, 1, 1, 1, 0, 0, 1, 0, 1, 1, 1, 1,
                        0, 0, 0, 10, 40, 4, 7.0, 50.0, 60.0, 3.5, 1.5, 6.8,
                    ]
                },
            },
            "batch": {
                "url": "/predict/batch",
                "body": {
                    "instances": [
                        ["array of 37 features for day 1"],
                        ["array for day 2"],
                    ]
                },
            },
        },
    }


@app.get("/health")
def health_check():
    """Liveness probe; also reports whether the model loaded at startup."""
    return {
        "status": "healthy",
        "model_loaded": clf is not None,
    }


@app.post("/predict", response_model=SinglePredictionResponse)
def predict(request: SinglePredictionRequest):
    """Predict headache risk for a single day"""
    if clf is None:
        raise HTTPException(status_code=503, detail="Model not loaded")

    try:
        # Convert input to numpy array shaped (1, n_features).
        features = np.array(request.features).reshape(1, -1)

        # Get probability array for both classes [prob_0, prob_1]
        prob_array = clf.predict_proba(features)[0]

        # Always return probability of headache (class 1)
        headache_probability = float(prob_array[1])

        # Make prediction using the loaded threshold
        prediction = 1 if headache_probability >= threshold else 0

        return SinglePredictionResponse(
            prediction=int(prediction),
            probability=headache_probability,
        )
    except Exception as e:
        # Feature-count mismatches etc. surface as client errors.
        raise HTTPException(status_code=400, detail=f"Prediction error: {str(e)}")


@app.post("/predict/batch", response_model=BatchPredictionResponse)
def predict_batch(request: BatchPredictionRequest):
    """Predict headache risk for multiple days (batch)"""
    if clf is None:
        raise HTTPException(status_code=503, detail="Model not loaded")

    try:
        X = np.array(request.instances, dtype=float)
        # Ragged or empty input collapses to a non-2D array — reject early
        # with a clear message instead of an opaque model error.
        if X.ndim != 2:
            raise ValueError(f"Expected 2D array, got shape {X.shape}")

        probas = clf.predict_proba(X)[:, 1]  # class-1 prob
        # Use the global threshold
        preds = (probas >= threshold).astype(int)

        results = [
            DayPrediction(day=i + 1, prediction=int(preds[i]), probability=float(probas[i]))
            for i in range(len(probas))
        ]
        return BatchPredictionResponse(predictions=results)
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"Batch prediction error: {str(e)}")