File size: 981 Bytes
e8a8654
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
725063e
e8a8654
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
import logging
from pathlib import Path

import joblib
import numpy as np
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel

# Load the saved model used by the /predict endpoint.
# The file is looked up next to this script first, so the service starts no
# matter which working directory it is launched from; if it is not there, fall
# back to the original working-directory-relative path for backward
# compatibility.
# NOTE(review): joblib.load unpickles the file, which can execute arbitrary
# code -- only ever point this at a trusted model file.
_MODEL_FILENAME = "linear_regression_model.pkl"
_model_path = Path(__file__).resolve().parent / _MODEL_FILENAME
if not _model_path.is_file():
    _model_path = Path(_MODEL_FILENAME)
model = joblib.load(str(_model_path))

# Initialize the FastAPI app
app = FastAPI()

# Define the input schema for predictions (request body of POST /predict)
class PredictionInput(BaseModel):
    """Request body for ``POST /predict``: a single numeric feature."""

    # The only feature fed to the model; declared as float so pydantic
    # validates it before the endpoint body runs.
    feature1: float

# Define the prediction endpoint
@app.post("/predict")
def predict(input_data: PredictionInput):
    """Run the loaded model on a single input and return its prediction.

    Args:
        input_data: Validated request body carrying ``feature1``.

    Returns:
        ``{"prediction": [...]}`` -- the model output converted to a
        (possibly nested) Python list.

    Raises:
        HTTPException: 422 if ``feature1`` is NaN or infinite; 500 if the
            model call fails or yields a non-finite value.
    """
    value = input_data.feature1
    # NaN/Infinity are accepted by the JSON parser and a plain float field,
    # but cannot be encoded back into standard JSON; reject them up front.
    if not np.isfinite(value):
        raise HTTPException(
            status_code=422, detail="feature1 must be a finite number"
        )

    # One sample with one feature, shaped as a 2D (n_samples, n_features) array.
    input_features = np.array([[value]])

    try:
        # asarray makes .tolist() below work even if predict returns a sequence.
        prediction = np.asarray(model.predict(input_features))
        is_finite = bool(np.all(np.isfinite(prediction)))
    except Exception as e:
        # Top-level boundary: keep the traceback in the server log and chain
        # the cause instead of silently discarding it.
        logging.getLogger(__name__).exception("Model prediction failed")
        raise HTTPException(status_code=500, detail=str(e)) from e

    # An overflowing result (e.g. inf) would otherwise crash response rendering
    # outside this handler with an opaque error.
    if not is_finite:
        raise HTTPException(
            status_code=500, detail="Model produced a non-finite prediction"
        )

    return {"prediction": prediction.tolist()}  # Return prediction as a list

# Root endpoint (optional): answers GET "/" with a fixed greeting.
@app.get("/")
def greet_json():
    """Return a static JSON welcome message for the API root."""
    greeting = "Welcome to the Linear Regression API!"
    return dict(message=greeting)