# Source: uploaded by junaid17 ("Upload 8 files"), commit 3d14250 (verified).
import logging

from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field

from predict_helper import predict
# Application instance; title and version are passed straight to FastAPI.
app = FastAPI(version='1.0', title='Customer Segmentation')

# Cross-origin policy: every origin, method and header is accepted and
# credentialed requests are permitted.
# NOTE(review): FastAPI's CORS docs advise against using "*" wildcards together
# with allow_credentials=True; restrict allow_origins to known front-ends
# before this API starts relying on cookies or auth headers.
_cors_settings = dict(
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
    allow_credentials=True,
)
app.add_middleware(CORSMiddleware, **_cors_settings)
class BaseInput(BaseModel):
    """Request body for ``POST /predict``: a single customer's profile.

    Every field is a required integer; the ``ge``/``le`` bounds below are
    enforced by pydantic when the model is constructed.
    """

    Age: int = Field(
        default=...,
        description="Customer age between 18 and 100",
        ge=18,
        le=100,
    )
    Income: int = Field(
        default=...,
        description="Income between 0 and 200000",
        ge=0,
        le=200000,
    )
    Total_Spendings: int = Field(
        default=...,
        description="Total spendings (sum of purchases)",
        ge=0,
        le=5000,
    )
    NumWebPurchases: int = Field(
        default=...,
        description="Number of web purchases",
        ge=0,
        le=100,
    )
    NumStorePurchases: int = Field(
        default=...,
        description="Number of store purchases",
        ge=0,
        le=100,
    )
    NumWebVisitsMonth: int = Field(
        default=...,
        description="Number of web visits per month",
        ge=0,
        le=50,
    )
    Recency: int = Field(
        default=...,
        description="Recency (days since last purchase)",
        ge=0,
        le=365,
    )
class BaseOutput(BaseModel):
    """Response body of ``POST /predict`` describing the assigned segment."""

    # Identifier of the assigned cluster.
    cluster_id: int
    # Display name of the assigned cluster.
    cluster_name: str
    # Text describing the cluster.
    description: str
    # Text with the recommendation returned for this cluster.
    recommendation: str
@app.get('/')
def Status():
    """Health check for ``GET /``: report that the API server is running."""
    payload = {'message': 'The api server is live and working'}
    return payload
@app.post('/predict', response_model=BaseOutput)
def predict_segment(input_data: BaseInput):
    """Assign a customer to a segment by calling ``predict``.

    Args:
        input_data: Customer profile parsed and bounds-checked as
            ``BaseInput``.

    Returns:
        The value returned by ``predict``. FastAPI validates and serialises
        it against ``response_model=BaseOutput``, so it presumably carries
        cluster_id, cluster_name, description and recommendation.

    Raises:
        HTTPException: with status 400 when ``predict`` raises any exception;
            the original exception is chained as ``__cause__``.
    """
    try:
        result = predict(
            age=input_data.Age,
            income=input_data.Income,
            total_spending=input_data.Total_Spendings,
            num_web_purchases=input_data.NumWebPurchases,
            num_store_purchases=input_data.NumStorePurchases,
            num_web_visits=input_data.NumWebVisitsMonth,
            recency=input_data.Recency,
        )
    except Exception as e:
        # FastAPI turns the HTTPException below into a plain JSON response, so
        # record the full traceback here; otherwise it is lost server-side.
        logging.getLogger(__name__).exception("Prediction failed")
        # NOTE(review): BaseInput already enforces the input bounds, so a
        # failure here is more likely a server-side fault (500) than a client
        # error; 400 and this detail text are kept for existing clients.
        raise HTTPException(
            status_code=400,
            detail=f"Error while predicting the output: {e}",
        ) from e
    return result