emp-admin committed on
Commit
8b69317
·
verified ·
1 Parent(s): 47e0ff5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +24 -112
app.py CHANGED
@@ -1,165 +1,77 @@
1
-
2
  from fastapi import FastAPI, HTTPException
3
  from pydantic import BaseModel
4
- import xgboost as xgb
5
  import numpy as np
6
  import pickle
7
  from huggingface_hub import hf_hub_download
8
  import os
9
- import sys
10
- from typing import List, Union
11
 
12
  app = FastAPI(title="Headache Predictor API")
13
 
14
- # Load model at startup
15
- model = None
16
 
17
  @app.on_event("startup")
18
  async def load_model():
19
- global model
20
  try:
21
- # Set cache directory to writable location
22
  cache_dir = "/tmp/hf_cache"
23
  os.makedirs(cache_dir, exist_ok=True)
24
 
25
- # Get HF token from environment (set as Space secret)
26
  hf_token = os.environ.get("HF_TOKEN")
27
 
28
  model_path = hf_hub_download(
29
  repo_id="emp-admin/headache-predictor-xgboost",
30
  filename="model.pkl",
31
  cache_dir=cache_dir,
32
- token=hf_token # Use token for private repo access
33
  )
34
 
35
- with open(model_path, 'rb') as f:
36
  model_data = pickle.load(f)
37
 
38
- # Handle both dict format and raw model
39
  if isinstance(model_data, dict):
40
- model = model_data['model']
41
- print(f"✅ Model loaded successfully (threshold: {model_data.get('optimal_threshold', 0.5)})")
 
42
  else:
43
- model = model_data
44
- print("✅ Model loaded successfully")
 
45
 
46
  except Exception as e:
47
  print(f"❌ Error loading model: {e}")
48
  import traceback
49
  traceback.print_exc()
50
 
51
- class SinglePredictionRequest(BaseModel):
52
- features: List[float]
53
-
54
  class BatchPredictionRequest(BaseModel):
55
  instances: List[List[float]]
56
 
57
  class DayPrediction(BaseModel):
58
  day: int
59
  prediction: int
60
- probability: float # Probability of HEADACHE (class 1), regardless of prediction
61
-
62
- class SinglePredictionResponse(BaseModel):
63
- prediction: int
64
- probability: float # Probability of HEADACHE (class 1), regardless of prediction
65
 
66
  class BatchPredictionResponse(BaseModel):
67
  predictions: List[DayPrediction]
68
 
69
- @app.get("/")
70
- def read_root():
71
- return {
72
- "message": "Headache Predictor API",
73
- "status": "running",
74
- "endpoints": {
75
- "predict": "/predict - Single day prediction",
76
- "predict_batch": "/predict/batch - 7-day forecast",
77
- "health": "/health"
78
- },
79
- "examples": {
80
- "single": {
81
- "url": "/predict",
82
- "body": {"features": [1, 0, 0, 0, 1, 0, 1005.0, -9.5, 85.0, 15.5, 64.0, 5.5, 41.0, 0.0, 1, 1, 1, 0, 0, 1, 0, 1, 1, 1, 1, 0, 0, 0, 10, 40, 4, 7.0, 50.0, 60.0, 3.5, 1.5, 6.8]}
83
- },
84
- "batch": {
85
- "url": "/predict/batch",
86
- "body": {"instances": [["array of 37 features for day 1"], ["array for day 2"], "..."]}
87
- }
88
- }
89
- }
90
-
91
- @app.get("/health")
92
- def health_check():
93
- return {
94
- "status": "healthy",
95
- "model_loaded": model is not None
96
- }
97
-
98
- @app.post("/predict", response_model=SinglePredictionResponse)
99
- def predict(request: SinglePredictionRequest):
100
- """Predict headache risk for a single day"""
101
- if model is None:
102
- raise HTTPException(status_code=503, detail="Model not loaded")
103
-
104
- try:
105
- # Convert input to numpy array
106
- features = np.array(request.features).reshape(1, -1)
107
-
108
- # Get probability array for both classes
109
- prob_array = model.predict_proba(features)[0]
110
-
111
- # Always return probability of headache (class 1)
112
- headache_probability = float(prob_array[1])
113
-
114
- # Make prediction using threshold if available
115
- if isinstance(model, dict) and 'optimal_threshold' in model:
116
- threshold = model['optimal_threshold']
117
- prediction = 1 if headache_probability >= threshold else 0
118
- else:
119
- prediction = model.predict(features)[0]
120
-
121
- return SinglePredictionResponse(
122
- prediction=int(prediction),
123
- probability=headache_probability
124
- )
125
- except Exception as e:
126
- raise HTTPException(status_code=400, detail=f"Prediction error: {str(e)}")
127
-
128
  @app.post("/predict/batch", response_model=BatchPredictionResponse)
129
  def predict_batch(request: BatchPredictionRequest):
130
- """Predict headache risk for multiple days (7-day forecast)"""
131
- if model is None:
132
  raise HTTPException(status_code=503, detail="Model not loaded")
133
 
134
  try:
135
- # Convert all instances to numpy array
136
- features = np.array(request.instances)
137
-
138
- if features.ndim != 2:
139
- raise ValueError(f"Expected 2D array, got shape {features.shape}")
140
-
141
- # Get probabilities for all days
142
- probabilities = model.predict_proba(features)
143
-
144
- # Format results
145
- results = []
146
- for i, prob_array in enumerate(probabilities, 1):
147
- # Always use probability of headache (class 1)
148
- headache_probability = float(prob_array[1])
149
-
150
- # Make prediction using threshold if available
151
- if isinstance(model, dict) and 'optimal_threshold' in model:
152
- threshold = model['optimal_threshold']
153
- prediction = 1 if headache_probability >= threshold else 0
154
- else:
155
- prediction = model.predict(features[i-1:i])[0]
156
-
157
- results.append(DayPrediction(
158
- day=i,
159
- prediction=int(prediction),
160
- probability=headache_probability
161
- ))
162
 
 
 
 
 
163
  return BatchPredictionResponse(predictions=results)
164
 
165
  except Exception as e:
 
 
1
  from fastapi import FastAPI, HTTPException
2
  from pydantic import BaseModel
 
3
  import numpy as np
4
  import pickle
5
  from huggingface_hub import hf_hub_download
6
  import os
7
+ from typing import List
 
8
 
9
  app = FastAPI(title="Headache Predictor API")
10
 
11
+ clf = None
12
+ threshold = 0.5
13
 
14
  @app.on_event("startup")
15
  async def load_model():
16
+ global clf, threshold
17
  try:
 
18
  cache_dir = "/tmp/hf_cache"
19
  os.makedirs(cache_dir, exist_ok=True)
20
 
 
21
  hf_token = os.environ.get("HF_TOKEN")
22
 
23
  model_path = hf_hub_download(
24
  repo_id="emp-admin/headache-predictor-xgboost",
25
  filename="model.pkl",
26
  cache_dir=cache_dir,
27
+ token=hf_token
28
  )
29
 
30
+ with open(model_path, "rb") as f:
31
  model_data = pickle.load(f)
32
 
 
33
  if isinstance(model_data, dict):
34
+ clf = model_data["model"]
35
+ threshold = float(model_data.get("optimal_threshold", 0.5))
36
+ print(f"✅ Model loaded (optimal_threshold={threshold})")
37
  else:
38
+ clf = model_data
39
+ threshold = 0.5
40
+ print("✅ Model loaded (threshold=0.5 default)")
41
 
42
  except Exception as e:
43
  print(f"❌ Error loading model: {e}")
44
  import traceback
45
  traceback.print_exc()
46
 
 
 
 
47
  class BatchPredictionRequest(BaseModel):
48
  instances: List[List[float]]
49
 
50
  class DayPrediction(BaseModel):
51
  day: int
52
  prediction: int
53
+ probability: float
 
 
 
 
54
 
55
  class BatchPredictionResponse(BaseModel):
56
  predictions: List[DayPrediction]
57
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
58
  @app.post("/predict/batch", response_model=BatchPredictionResponse)
59
  def predict_batch(request: BatchPredictionRequest):
60
+ if clf is None:
 
61
  raise HTTPException(status_code=503, detail="Model not loaded")
62
 
63
  try:
64
+ X = np.array(request.instances, dtype=float)
65
+ if X.ndim != 2:
66
+ raise ValueError(f"Expected 2D array, got shape {X.shape}")
67
+
68
+ probas = clf.predict_proba(X)[:, 1] # class-1 prob
69
+ preds = (probas >= threshold).astype(int)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
70
 
71
+ results = [
72
+ DayPrediction(day=i+1, prediction=int(preds[i]), probability=float(probas[i]))
73
+ for i in range(len(probas))
74
+ ]
75
  return BatchPredictionResponse(predictions=results)
76
 
77
  except Exception as e: