from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
import numpy as np
import pickle
from huggingface_hub import hf_hub_download
import os
from typing import List
# FastAPI application object; the startup handler and route decorators in this
# file register themselves on it.
app = FastAPI(title="Headache Predictor API")
# Global variables for model and threshold
# `clf` stays None until the startup handler assigns a loaded model; the
# prediction endpoints answer 503 while it is None.
clf = None
# Class-1 probability cut-off used to turn a probability into a 0/1 label.
# Remains 0.5 unless the loaded model file provides "optimal_threshold".
threshold = 0.5
# --- Pydantic Models ---
# Request body for POST /predict: the feature vector for one day.
class SinglePredictionRequest(BaseModel):
    # Flat list of numeric model inputs; the handler reshapes it to one row.
    # NOTE(review): the examples in the root endpoint use 37 values — confirm
    # the length the model actually requires.
    features: List[float]
# Response body for POST /predict.
class SinglePredictionResponse(BaseModel):
    # 1 when `probability` is at or above the global threshold, otherwise 0.
    prediction: int
    # Class-1 (headache) probability returned by the classifier.
    probability: float
# Request body for POST /predict/batch: one feature vector per day.
class BatchPredictionRequest(BaseModel):
    # List of feature vectors; the handler converts it to a float array and
    # rejects anything that does not come out two-dimensional.
    instances: List[List[float]]
# One entry of a batch response, describing a single submitted instance.
class DayPrediction(BaseModel):
    # 1-based position of the instance within the request.
    day: int
    # 1 when `probability` is at or above the global threshold, otherwise 0.
    prediction: int
    # Class-1 (headache) probability for this instance.
    probability: float
# Response body for POST /predict/batch.
class BatchPredictionResponse(BaseModel):
    # One DayPrediction per submitted instance, in request order.
    predictions: List[DayPrediction]
# --- Startup Event ---
@app.on_event("startup")
async def load_model():
    """Download the pickled model and populate the ``clf``/``threshold`` globals.

    The file ``model.pkl`` is fetched from the hub repository into a cache
    under ``/tmp`` (using ``HF_TOKEN`` from the environment when set) and
    unpickled. A dict payload supplies the classifier under ``"model"`` and,
    optionally, a cut-off under ``"optimal_threshold"``; any other payload is
    used as the classifier itself with the 0.5 default.

    Any exception is printed together with its traceback and is not re-raised.
    """
    global clf, threshold
    try:
        hf_cache = "/tmp/hf_cache"
        os.makedirs(hf_cache, exist_ok=True)
        local_file = hf_hub_download(
            repo_id="emp-admin/headache-predictor-xgboost",
            filename="model.pkl",
            cache_dir=hf_cache,
            token=os.environ.get("HF_TOKEN"),
        )
        # NOTE(review): pickle.load can execute arbitrary code embedded in the
        # file, so this is only safe while the hub repository is trusted.
        with open(local_file, "rb") as fh:
            payload = pickle.load(fh)
        if not isinstance(payload, dict):
            # Bare estimator: no stored threshold, fall back to the default.
            clf = payload
            threshold = 0.5
            print("✅ Model loaded (threshold=0.5 default)")
        else:
            clf = payload["model"]
            # Use the stored threshold when present, otherwise 0.5.
            threshold = float(payload.get("optimal_threshold", 0.5))
            print(f"✅ Model loaded (optimal_threshold={threshold})")
    except Exception as e:
        print(f"❌ Error loading model: {e}")
        import traceback
        traceback.print_exc()
# --- Endpoints ---
@app.get("/")
def read_root():
    """Return a service overview: status, endpoint list and example payloads."""
    # 37-value sample feature vector shown as the example /predict body.
    sample_features = [
        1, 0, 0, 0, 1, 0, 1005.0, -9.5, 85.0, 15.5, 64.0, 5.5, 41.0, 0.0,
        1, 1, 1, 0, 0, 1, 0, 1, 1, 1, 1, 0, 0, 0, 10, 40, 4,
        7.0, 50.0, 60.0, 3.5, 1.5, 6.8,
    ]
    endpoint_summary = {
        "predict": "/predict - Single day prediction",
        "predict_batch": "/predict/batch - 7-day forecast",
        "health": "/health",
    }
    single_example = {
        "url": "/predict",
        "body": {"features": sample_features},
    }
    batch_example = {
        "url": "/predict/batch",
        # Placeholder strings only; they illustrate the nesting, not real data.
        "body": {"instances": [["array of 37 features for day 1"], ["array for day 2"]]},
    }
    return {
        "message": "Headache Predictor API",
        "status": "running",
        "endpoints": endpoint_summary,
        "examples": {
            "single": single_example,
            "batch": batch_example,
        },
    }
@app.get("/health")
def health_check():
    """Report that the service is up and whether a classifier is loaded."""
    # `clf` is still None when the startup handler has not assigned a model.
    model_ready = clf is not None
    return dict(status="healthy", model_loaded=model_ready)
@app.post("/predict", response_model=SinglePredictionResponse)
def predict(request: SinglePredictionRequest):
    """Predict headache risk for a single day.

    Args:
        request: Body whose ``features`` list is the feature vector for one day.

    Returns:
        SinglePredictionResponse holding the class-1 (headache) probability
        and the 0/1 label obtained by comparing it with the global threshold.

    Raises:
        HTTPException: 503 when no model is loaded; 400 when building the
            input array or running the classifier fails.
    """
    if clf is None:
        raise HTTPException(status_code=503, detail="Model not loaded")
    try:
        # Explicit float dtype, as in the batch endpoint; shape is (1, n_features).
        features = np.array(request.features, dtype=float).reshape(1, -1)
        # predict_proba gives [prob_0, prob_1] per row; always report class 1.
        headache_probability = float(clf.predict_proba(features)[0][1])
    except Exception as e:
        # Chain the cause so the original traceback stays attached.
        raise HTTPException(status_code=400, detail=f"Prediction error: {e}") from e
    # Apply the threshold loaded at startup (0.5 when none was provided).
    prediction = 1 if headache_probability >= threshold else 0
    return SinglePredictionResponse(
        prediction=prediction,
        probability=headache_probability,
    )
@app.post("/predict/batch", response_model=BatchPredictionResponse)
def predict_batch(request: BatchPredictionRequest):
    """Predict headache risk for multiple days (batch).

    Args:
        request: Body whose ``instances`` list holds one feature vector per day.

    Returns:
        BatchPredictionResponse with one DayPrediction per instance, numbered
        from 1 in request order.

    Raises:
        HTTPException: 503 when no model is loaded; 400 when the instances do
            not form a 2D float array or the classifier call fails.
    """
    if clf is None:
        raise HTTPException(status_code=503, detail="Model not loaded")
    try:
        X = np.array(request.instances, dtype=float)
        # An empty list, for example, produces a 1D array and is rejected here.
        if X.ndim != 2:
            raise ValueError(f"Expected 2D array, got shape {X.shape}")
        probas = clf.predict_proba(X)[:, 1]  # class-1 prob
        # Use the global threshold
        preds = (probas >= threshold).astype(int)
    except Exception as e:
        # Chain the cause so the original traceback stays attached.
        raise HTTPException(status_code=400, detail=f"Batch prediction error: {e}") from e
    results = [
        DayPrediction(day=day, prediction=int(pred), probability=float(proba))
        for day, (pred, proba) in enumerate(zip(preds, probas), start=1)
    ]
    return BatchPredictionResponse(predictions=results)