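# Page-header residue from the hosting site's file viewer (not part of the program):
# uploader "emp-admin"; commit message "Update app.py"; commit 67f34aa (verified).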
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
import numpy as np
import pickle
from huggingface_hub import hf_hub_download
import os
from typing import List
# FastAPI application object; the startup hook and route decorators in this file attach to it.
app = FastAPI(title="Headache Predictor API")
# Global variables for model and threshold
# `clf` stays None until the startup hook assigns a loaded model; the
# prediction endpoints answer 503 while it is None.
clf = None
# Cutoff compared against the class-1 probability to produce the 0/1 label.
# The startup hook may overwrite it with the model file's "optimal_threshold".
threshold = 0.5
# --- Pydantic Models ---
class SinglePredictionRequest(BaseModel):
    """Request body for ``POST /predict``.

    The handler reshapes ``features`` into a single row and passes it to
    the classifier. No length check is performed here; the example served
    by the root endpoint uses 37 values.
    """

    # Flat feature vector describing one day.
    features: List[float]
class SinglePredictionResponse(BaseModel):
    """Response body for ``POST /predict``."""

    # 1 when `probability` is at or above the module-level threshold, else 0.
    prediction: int
    # Class-1 ("headache") probability reported by the classifier.
    probability: float
class BatchPredictionRequest(BaseModel):
    """Request body for ``POST /predict/batch``.

    The handler converts ``instances`` to a float array and rejects
    anything that does not come out two-dimensional.
    """

    # One inner list of feature values per day, in day order.
    instances: List[List[float]]
class DayPrediction(BaseModel):
    """Prediction for one entry of a batch request."""

    # 1-based position of the instance within the submitted batch.
    day: int
    # 1 when `probability` is at or above the module-level threshold, else 0.
    prediction: int
    # Class-1 ("headache") probability reported by the classifier.
    probability: float
class BatchPredictionResponse(BaseModel):
    """Response body for ``POST /predict/batch``."""

    # One entry per submitted instance, in the same order as the request.
    predictions: List[DayPrediction]
# --- Startup Event ---
@app.on_event("startup")
async def load_model():
    """Download the pickled model at startup and publish it via module globals.

    Fetches ``model.pkl`` from the ``emp-admin/headache-predictor-xgboost``
    Hub repository (using ``HF_TOKEN`` from the environment when set),
    unpickles it, and assigns the module-level ``clf`` and ``threshold``.

    The pickle may be either a bare classifier or a dict holding the
    classifier under ``"model"`` and, optionally, a decision cutoff under
    ``"optimal_threshold"``. A missing, ``None`` or non-numeric cutoff
    falls back to 0.5 instead of being reported as a failed model load.

    Any other failure is printed with a traceback and swallowed, so the
    application still starts; ``clf`` then stays ``None`` and the
    prediction endpoints answer 503.
    """
    # NOTE(review): FastAPI documents `on_event` as deprecated in favour of
    # lifespan handlers; migrating would require changing how `app` is built.
    global clf, threshold
    default_threshold = 0.5
    try:
        cache_dir = "/tmp/hf_cache"
        os.makedirs(cache_dir, exist_ok=True)
        hf_token = os.environ.get("HF_TOKEN")
        model_path = hf_hub_download(
            repo_id="emp-admin/headache-predictor-xgboost",
            filename="model.pkl",
            cache_dir=cache_dir,
            token=hf_token,
        )

        # SECURITY: unpickling executes arbitrary code embedded in the file.
        # This is only safe while the Hub repository above is fully trusted.
        with open(model_path, "rb") as f:
            model_data = pickle.load(f)

        if isinstance(model_data, dict):
            loaded_clf = model_data["model"]
            raw_threshold = model_data.get("optimal_threshold")
            if raw_threshold is None:
                loaded_threshold = default_threshold
            else:
                try:
                    loaded_threshold = float(raw_threshold)
                except (TypeError, ValueError):
                    # A bad cutoff should not be logged as a model-load
                    # failure: the classifier itself is usable.
                    print(
                        f"⚠️ Ignoring invalid optimal_threshold "
                        f"{raw_threshold!r}; using {default_threshold}"
                    )
                    loaded_threshold = default_threshold
            # Assign both globals together, after everything has resolved.
            clf, threshold = loaded_clf, loaded_threshold
            print(f"✅ Model loaded (optimal_threshold={threshold})")
        else:
            clf = model_data
            threshold = default_threshold
            print("✅ Model loaded (threshold=0.5 default)")
    except Exception as e:
        # Best-effort startup: report the failure but let the app come up.
        print(f"❌ Error loading model: {e}")
        import traceback
        traceback.print_exc()
# --- Endpoints ---
@app.get("/")
def read_root():
    """Return a static description of the API.

    The payload carries a status string, a map of the available endpoints,
    and example request bodies for the single and batch prediction routes.
    """
    endpoint_summary = {
        "predict": "/predict - Single day prediction",
        "predict_batch": "/predict/batch - 7-day forecast",
        "health": "/health",
    }

    # Sample feature vector for one day; the list below holds 37 values,
    # the same count the batch example text refers to.
    single_example = {
        "url": "/predict",
        "body": {"features": [1, 0, 0, 0, 1, 0, 1005.0, -9.5, 85.0, 15.5, 64.0, 5.5, 41.0, 0.0, 1, 1, 1, 0, 0, 1, 0, 1, 1, 1, 1, 0, 0, 0, 10, 40, 4, 7.0, 50.0, 60.0, 3.5, 1.5, 6.8]},
    }

    # Placeholder strings only: this illustrates the shape of a batch body,
    # it is not a payload the batch endpoint would accept as-is.
    batch_example = {
        "url": "/predict/batch",
        "body": {"instances": [["array of 37 features for day 1"], ["array for day 2"]]},
    }

    return {
        "message": "Headache Predictor API",
        "status": "running",
        "endpoints": endpoint_summary,
        "examples": {
            "single": single_example,
            "batch": batch_example,
        },
    }
@app.get("/health")
def health_check():
    """Report liveness and whether a model is currently loaded.

    ``status`` is always ``"healthy"``; ``model_loaded`` is true once the
    module-level ``clf`` has been assigned.
    """
    model_ready = clf is not None
    return {"status": "healthy", "model_loaded": model_ready}
@app.post("/predict", response_model=SinglePredictionResponse)
def predict(request: SinglePredictionRequest):
    """Predict headache risk for a single day.

    Returns the class-1 probability from the classifier together with a
    0/1 label obtained by comparing that probability with the module-level
    ``threshold``.

    Raises:
        HTTPException: 503 when no model is loaded; 400 when anything
            fails while building the input or running the classifier.
    """
    if clf is None:
        raise HTTPException(status_code=503, detail="Model not loaded")

    try:
        # The classifier is given one row, so the flat vector becomes (1, n).
        flat = np.array(request.features)
        sample = flat.reshape(1, -1)

        # predict_proba yields one row per sample; index 1 is the
        # headache (class 1) probability.
        headache_probability = float(clf.predict_proba(sample)[0][1])
        is_headache = headache_probability >= threshold

        return SinglePredictionResponse(
            prediction=int(is_headache),
            probability=headache_probability,
        )
    except Exception as exc:
        raise HTTPException(status_code=400, detail=f"Prediction error: {str(exc)}")
@app.post("/predict/batch", response_model=BatchPredictionResponse)
def predict_batch(request: BatchPredictionRequest):
    """Predict headache risk for several days in one call.

    Each instance is scored independently. Results are numbered from 1 in
    the order the instances were submitted, and labels come from comparing
    the class-1 probability with the module-level ``threshold``.

    Raises:
        HTTPException: 503 when no model is loaded; 400 when the instances
            do not form a 2D float array or the classifier call fails.
    """
    if clf is None:
        raise HTTPException(status_code=503, detail="Model not loaded")

    try:
        matrix = np.array(request.instances, dtype=float)
        if matrix.ndim != 2:
            raise ValueError(f"Expected 2D array, got shape {matrix.shape}")

        # Column 1 of predict_proba is the headache (class 1) probability.
        headache_probs = clf.predict_proba(matrix)[:, 1]
        labels = (headache_probs >= threshold).astype(int)

        per_day = [
            DayPrediction(day=day_number, prediction=int(label), probability=float(prob))
            for day_number, (label, prob) in enumerate(zip(labels, headache_probs), start=1)
        ]
        return BatchPredictionResponse(predictions=per_day)
    except Exception as exc:
        raise HTTPException(status_code=400, detail=f"Batch prediction error: {str(exc)}")