|
|
from fastapi import FastAPI, HTTPException |
|
|
import torch |
|
|
from transformers import AutoTokenizer, AutoModelForSequenceClassification |
|
|
|
|
|
# FastAPI application instance; route decorators below register onto it.
app = FastAPI()




# Hugging Face checkpoint: DistilBERT fine-tuned on SST-2 for binary
# sentiment classification (per the model card: 0 = NEGATIVE, 1 = POSITIVE
# — TODO confirm against model.config.id2label).
checkpoint = "distilbert-base-uncased-finetuned-sst-2-english"

# Tokenizer and model are loaded once at import/startup time and shared by
# all requests (downloads weights on first run).
tokenizer = AutoTokenizer.from_pretrained(checkpoint)


model = AutoModelForSequenceClassification.from_pretrained(checkpoint)
|
|
|
|
|
@app.get("/")
def greet_json():
    """Landing endpoint: return a static JSON welcome message."""
    welcome = "Welcome to the sentiment analysis API!"
    return {"message": welcome}
|
|
|
|
|
@app.post("/predict")
async def predict(sequences: list[str]):
    """Classify the sentiment of each input string.

    Args:
        sequences: Non-empty list of raw text strings to classify.

    Returns:
        ``{"results": [...]}`` with one entry per input containing the
        original ``sequence``, the predicted class index ``prediction``
        (argmax over the model's labels), and ``score``, the full softmax
        probability distribution for that sequence.

    Raises:
        HTTPException: 400 when ``sequences`` is empty.
    """
    if not sequences:
        raise HTTPException(status_code=400, detail="No sequences provided.")

    # Batch-tokenize; padding/truncation lets variable-length inputs share
    # one fixed-size tensor batch.
    tokens = tokenizer(sequences, padding=True, truncation=True, return_tensors="pt")

    # Inference only — disable autograd bookkeeping.
    with torch.no_grad():
        outputs = model(**tokens)

    # Per-sequence class probabilities, shape (batch, num_labels).
    scores = outputs.logits.softmax(dim=-1).tolist()

    # BUG FIX: the original did `scores.index(max(score) for score in scores)`,
    # which passes a generator object to list.index() and raises ValueError.
    # The predicted label is the argmax of each row's logits.
    predictions = outputs.logits.argmax(dim=-1).tolist()

    response = [
        {
            "sequence": seq,
            "prediction": int(predictions[i]),
            "score": scores[i],
        }
        for i, seq in enumerate(sequences)
    ]

    return {"results": response}
|
|
|