File size: 4,695 Bytes
6e2eb3e
 
 
 
 
75e8c0c
8b69317
6e2eb3e
 
 
67f34aa
8b69317
 
6e2eb3e
67f34aa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6e2eb3e
 
8b69317
6e2eb3e
75e8c0c
 
 
d1a95be
 
6e2eb3e
 
75e8c0c
d1a95be
8b69317
6e2eb3e
ab6fe11
8b69317
33f5254
 
 
8b69317
67f34aa
8b69317
 
33f5254
8b69317
 
 
ab6fe11
6e2eb3e
 
75e8c0c
 
6e2eb3e
67f34aa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
96b173e
67f34aa
 
 
6e2eb3e
67f34aa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
96b173e
 
 
67f34aa
8b69317
96b173e
 
 
8b69317
 
 
 
 
67f34aa
 
8b69317
96b173e
8b69317
 
 
 
96b173e
 
 
67f34aa
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
import numpy as np
import pickle
from huggingface_hub import hf_hub_download
import os
from typing import List

app = FastAPI(title="Headache Predictor API")

# Global variables for model and threshold.
# Both are populated by the startup hook below; until it succeeds, `clf`
# stays None and the prediction endpoints respond with 503.
clf = None  # fitted classifier exposing predict_proba(); loaded at startup
threshold = 0.5  # decision threshold for class 1; overridden if the model file carries one

# --- Pydantic Models ---

class SinglePredictionRequest(BaseModel):
    """Request body for /predict: the feature vector for one day."""
    features: List[float]

class SinglePredictionResponse(BaseModel):
    """Response body for /predict: binary label plus class-1 probability."""
    prediction: int
    probability: float

class BatchPredictionRequest(BaseModel):
    """Request body for /predict/batch: one feature vector per day."""
    instances: List[List[float]]

class DayPrediction(BaseModel):
    """One row of a batch result: 1-based day index, label, and probability."""
    day: int
    prediction: int
    probability: float

class BatchPredictionResponse(BaseModel):
    """Response body for /predict/batch: per-day predictions in input order."""
    predictions: List[DayPrediction]

# --- Startup Event ---

@app.on_event("startup")
async def load_model():
    """Fetch the pickled model from the Hugging Face Hub and load it.

    On success the module-level ``clf`` (and, when the artifact provides one,
    ``threshold``) are populated. On any failure the error is printed and the
    globals are left untouched, so the endpoints keep answering 503.
    """
    global clf, threshold
    try:
        cache_dir = "/tmp/hf_cache"
        os.makedirs(cache_dir, exist_ok=True)

        model_path = hf_hub_download(
            repo_id="emp-admin/headache-predictor-xgboost",
            filename="model.pkl",
            cache_dir=cache_dir,
            token=os.environ.get("HF_TOKEN"),
        )

        # NOTE(review): pickle.load is acceptable here only because the
        # artifact comes from a repo we control — never point this at
        # untrusted files.
        with open(model_path, "rb") as fh:
            payload = pickle.load(fh)

        if isinstance(payload, dict):
            # Newer artifacts bundle the estimator together with a tuned
            # decision threshold; fall back to 0.5 when absent.
            clf = payload["model"]
            threshold = float(payload.get("optimal_threshold", 0.5))
            print(f"✅ Model loaded (optimal_threshold={threshold})")
        else:
            # Legacy artifacts are the bare estimator only.
            clf = payload
            threshold = 0.5
            print("✅ Model loaded (threshold=0.5 default)")

    except Exception as e:
        print(f"❌ Error loading model: {e}")
        import traceback
        traceback.print_exc()

# --- Endpoints ---

@app.get("/")
def read_root():
    """Landing endpoint: describes the API surface and shows example requests."""
    endpoints = {
        "predict": "/predict - Single day prediction",
        "predict_batch": "/predict/batch - 7-day forecast",
        "health": "/health",
    }
    single_example = {
        "url": "/predict",
        # Example shortened for brevity in display
        "body": {"features": [1, 0, 0, 0, 1, 0, 1005.0, -9.5, 85.0, 15.5, 64.0, 5.5, 41.0, 0.0, 1, 1, 1, 0, 0, 1, 0, 1, 1, 1, 1, 0, 0, 0, 10, 40, 4, 7.0, 50.0, 60.0, 3.5, 1.5, 6.8]},
    }
    batch_example = {
        "url": "/predict/batch",
        "body": {"instances": [["array of 37 features for day 1"], ["array for day 2"]]},
    }
    return {
        "message": "Headache Predictor API",
        "status": "running",
        "endpoints": endpoints,
        "examples": {"single": single_example, "batch": batch_example},
    }

@app.get("/health")
def health_check():
    """Liveness probe: also reports whether the model finished loading."""
    model_ready = clf is not None
    return {"status": "healthy", "model_loaded": model_ready}

@app.post("/predict", response_model=SinglePredictionResponse)
def predict(request: SinglePredictionRequest):
    """Predict headache risk for a single day.

    Returns the binary label (positive-class probability compared against the
    threshold loaded at startup) together with that probability.

    Raises:
        HTTPException 503: model has not been loaded yet.
        HTTPException 400: malformed feature vector or prediction failure.
    """
    if clf is None:
        raise HTTPException(status_code=503, detail="Model not loaded")

    try:
        # dtype=float mirrors the batch endpoint (/predict/batch) so that
        # non-numeric input fails with a clear ValueError instead of being
        # coerced into an object array.
        features = np.array(request.features, dtype=float).reshape(1, -1)

        # Get probability array for both classes [prob_0, prob_1]
        prob_array = clf.predict_proba(features)[0]

        # Always return probability of headache (class 1)
        headache_probability = float(prob_array[1])

        # Make prediction using the loaded threshold
        prediction = 1 if headache_probability >= threshold else 0

        return SinglePredictionResponse(
            prediction=int(prediction),
            probability=headache_probability
        )
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"Prediction error: {str(e)}")

@app.post("/predict/batch", response_model=BatchPredictionResponse)
def predict_batch(request: BatchPredictionRequest):
    """Predict headache risk for multiple days (batch)"""
    if clf is None:
        raise HTTPException(status_code=503, detail="Model not loaded")

    try:
        X = np.array(request.instances, dtype=float)
        if X.ndim != 2:
            raise ValueError(f"Expected 2D array, got shape {X.shape}")

        # Positive-class (headache) probability per row, thresholded with
        # the global decision threshold loaded at startup.
        probas = clf.predict_proba(X)[:, 1]
        preds = (probas >= threshold).astype(int)

        results = [
            DayPrediction(day=idx + 1, prediction=int(label), probability=float(p))
            for idx, (label, p) in enumerate(zip(preds, probas))
        ]
        return BatchPredictionResponse(predictions=results)

    except Exception as e:
        raise HTTPException(status_code=400, detail=f"Batch prediction error: {str(e)}")