| | from fastapi import FastAPI, HTTPException |
| | from pydantic import BaseModel |
| | import numpy as np |
| | import pickle |
| | from huggingface_hub import hf_hub_download |
| | import os |
| | from typing import List |
| |
|
# FastAPI application object; endpoints are registered on it below.
app = FastAPI(title="Headache Predictor API")

# Module-level model state, populated by the startup handler load_model().
# clf stays None if loading fails, which /health and the predict endpoints check.
clf = None
# Decision threshold for the positive class; may be overridden by the
# "optimal_threshold" value stored alongside the pickled model.
threshold = 0.5
| |
|
| | |
| |
|
class SinglePredictionRequest(BaseModel):
    """Request body for /predict: one day's feature vector.

    The model expects the features in the training order; see the example
    payload served by the root endpoint.
    """

    features: List[float]
| |
|
class SinglePredictionResponse(BaseModel):
    """Response body for /predict."""

    # 1 = headache predicted, 0 = no headache (thresholded probability).
    prediction: int
    # Raw positive-class probability from the classifier.
    probability: float
| |
|
class BatchPredictionRequest(BaseModel):
    """Request body for /predict/batch: one feature vector per day."""

    instances: List[List[float]]
| |
|
class DayPrediction(BaseModel):
    """One day's result inside a batch response."""

    # 1-based position of the day within the submitted batch.
    day: int
    # 1 = headache predicted, 0 = no headache (thresholded probability).
    prediction: int
    # Raw positive-class probability from the classifier.
    probability: float
| |
|
class BatchPredictionResponse(BaseModel):
    """Response body for /predict/batch: one entry per submitted day."""

    predictions: List[DayPrediction]
| |
|
| | |
| |
|
@app.on_event("startup")
async def load_model():
    """Download and unpickle the classifier from the Hugging Face Hub.

    Populates the module-level ``clf`` and ``threshold``. On any failure the
    error is printed and the globals are left untouched, so the app still
    starts and /health reports ``model_loaded: false``.
    """
    global clf, threshold
    try:
        download_dir = "/tmp/hf_cache"
        os.makedirs(download_dir, exist_ok=True)

        # Token is read from the environment; None is fine for public repos.
        token = os.environ.get("HF_TOKEN")

        local_path = hf_hub_download(
            repo_id="emp-admin/headache-predictor-xgboost",
            filename="model.pkl",
            cache_dir=download_dir,
            token=token,
        )

        # NOTE(review): pickle.load executes arbitrary code from the
        # downloaded file — acceptable only while the Hub repo is trusted.
        with open(local_path, "rb") as fh:
            payload = pickle.load(fh)

        # The pickle is either a dict bundling the model with its tuned
        # threshold, or the bare estimator (then we fall back to 0.5).
        if isinstance(payload, dict):
            clf = payload["model"]
            threshold = float(payload.get("optimal_threshold", 0.5))
            print(f"✅ Model loaded (optimal_threshold={threshold})")
        else:
            clf = payload
            threshold = 0.5
            print("✅ Model loaded (threshold=0.5 default)")

    except Exception as e:
        # Best-effort load: keep the app up, surface the failure in logs.
        print(f"❌ Error loading model: {e}")
        import traceback
        traceback.print_exc()
| |
|
| | |
| |
|
@app.get("/")
def read_root():
    """Service landing page: lists endpoints and example request bodies."""
    # Example feature vector kept verbatim (int vs. float spelling matters
    # for the JSON the client sees and copies).
    single_example = {
        "url": "/predict",
        "body": {"features": [1, 0, 0, 0, 1, 0, 1005.0, -9.5, 85.0, 15.5, 64.0, 5.5, 41.0, 0.0, 1, 1, 1, 0, 0, 1, 0, 1, 1, 1, 1, 0, 0, 0, 10, 40, 4, 7.0, 50.0, 60.0, 3.5, 1.5, 6.8]},
    }
    batch_example = {
        "url": "/predict/batch",
        "body": {"instances": [["array of 37 features for day 1"], ["array for day 2"]]},
    }
    return {
        "message": "Headache Predictor API",
        "status": "running",
        "endpoints": {
            "predict": "/predict - Single day prediction",
            "predict_batch": "/predict/batch - 7-day forecast",
            "health": "/health",
        },
        "examples": {"single": single_example, "batch": batch_example},
    }
| |
|
@app.get("/health")
def health_check():
    """Liveness probe: reports whether the startup model load succeeded."""
    model_ready = clf is not None
    return {"status": "healthy", "model_loaded": model_ready}
| |
|
@app.post("/predict", response_model=SinglePredictionResponse)
def predict(request: SinglePredictionRequest):
    """Predict headache risk for a single day.

    Args:
        request: One day's feature vector, in the order the model was
            trained on (see the example on the root endpoint).

    Returns:
        SinglePredictionResponse with the binary decision (positive-class
        probability compared against the tuned ``threshold``) and the raw
        probability.

    Raises:
        HTTPException: 503 if the model failed to load at startup;
            400 if the features cannot be scored (e.g. wrong vector length).
    """
    if clf is None:
        raise HTTPException(status_code=503, detail="Model not loaded")

    try:
        # dtype=float for consistency with the batch endpoint; reshape a
        # single sample into the (1, n_features) matrix sklearn-style
        # estimators expect.
        features = np.array(request.features, dtype=float).reshape(1, -1)

        # Column 1 of predict_proba is the positive ("headache") class.
        headache_probability = float(clf.predict_proba(features)[0][1])

        # Apply the tuned decision threshold loaded with the model.
        prediction = 1 if headache_probability >= threshold else 0

        return SinglePredictionResponse(
            prediction=prediction,
            probability=headache_probability,
        )
    except Exception as e:
        # Chain the cause so server logs keep the original traceback.
        raise HTTPException(status_code=400, detail=f"Prediction error: {str(e)}") from e
| |
|
@app.post("/predict/batch", response_model=BatchPredictionResponse)
def predict_batch(request: BatchPredictionRequest):
    """Predict headache risk for multiple days (batch)"""
    if clf is None:
        raise HTTPException(status_code=503, detail="Model not loaded")

    try:
        # Ragged rows raise here and surface as a 400 below.
        X = np.array(request.instances, dtype=float)
        if X.ndim != 2:
            raise ValueError(f"Expected 2D array, got shape {X.shape}")

        # Positive-class probability for every row, thresholded in one pass.
        positive_probs = clf.predict_proba(X)[:, 1]
        labels = (positive_probs >= threshold).astype(int)

        day_results = []
        for day_number, (label, prob) in enumerate(zip(labels, positive_probs), start=1):
            day_results.append(
                DayPrediction(
                    day=day_number,
                    prediction=int(label),
                    probability=float(prob),
                )
            )
        return BatchPredictionResponse(predictions=day_results)

    except Exception as e:
        raise HTTPException(status_code=400, detail=f"Batch prediction error: {str(e)}")