# temple / app.py — sentiment-analysis prediction service
# Author: Fasika · commit 39c0e94 · 1.33 kB
# (Header reconstructed from HuggingFace Space file-viewer residue.)
from fastapi import FastAPI, HTTPException
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
# FastAPI application instance; endpoints below are registered on it.
app = FastAPI()
# Initialize the model and tokenizer once on app startup
# (loading per-request would be prohibitively slow).
# Pretrained DistilBERT fine-tuned on SST-2: binary sentiment
# (index 0 = NEGATIVE, index 1 = POSITIVE per the model card — confirm).
checkpoint = "distilbert-base-uncased-finetuned-sst-2-english"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForSequenceClassification.from_pretrained(checkpoint)
@app.get("/")
def greet_json():
    """Root endpoint: return a static welcome/health-check message."""
    payload = {"message": "Welcome to the sentiment analysis API!"}
    return payload
@app.post("/predict")
async def predict(sequences: list[str]):
    """Classify the sentiment of a batch of input texts.

    Args:
        sequences: Non-empty list of raw text strings.

    Returns:
        ``{"results": [...]}`` with one entry per input, each holding the
        original ``sequence``, the argmax class index ``prediction``
        (binary: 0/1 for this SST-2 checkpoint), and ``score`` — the full
        softmax probability row for that sequence.

    Raises:
        HTTPException: 400 when ``sequences`` is empty.
    """
    if not sequences:
        raise HTTPException(status_code=400, detail="No sequences provided.")
    # Tokenize the whole batch at once; padding/truncation make a rectangular tensor.
    tokens = tokenizer(sequences, padding=True, truncation=True, return_tensors="pt")
    with torch.no_grad():  # inference only — avoid tracking gradients
        outputs = model(**tokens)
    # Softmax over the class dimension -> list of per-sequence probability rows.
    scores = outputs.logits.softmax(dim=-1).tolist()
    # BUGFIX: the original `scores.index(max(score) for score in scores)`
    # passed a generator object to list.index and raised ValueError at
    # runtime. The intent is a per-row argmax:
    predictions = [row.index(max(row)) for row in scores]
    response = [
        {
            "sequence": seq,
            "prediction": predictions[i],
            "score": scores[i],
        }
        for i, seq in enumerate(sequences)
    ]
    return {"results": response}