Fasika committed on
Commit
39c0e94
·
1 Parent(s): 4e68ef5

prediction

Browse files
Files changed (2) hide show
  1. app.py +35 -2
  2. requirements.txt +3 -1
app.py CHANGED
@@ -1,7 +1,40 @@
1
- from fastapi import FastAPI
 
 
2
 
3
  app = FastAPI()
4
 
 
 
 
 
 
5
@app.get("/")
def greet_json():
    # Root health-check endpoint: returns a static JSON greeting.
    # (This is the pre-commit version shown on the diff's old side.)
    return {"Hello": "World!"}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI, HTTPException
2
+ import torch
3
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
4
 
5
app = FastAPI()

# Initialize the model and tokenizer once at import time so every request
# reuses the same loaded weights — loading from the hub/disk is slow and
# would dominate per-request latency if done inside the endpoint.
# NOTE(review): from_pretrained downloads on first run — assumes network
# access (or a warm cache) at startup; confirm for the deploy target.
checkpoint = "distilbert-base-uncased-finetuned-sst-2-english"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForSequenceClassification.from_pretrained(checkpoint)
12
@app.get("/")
def greet_json():
    """Root endpoint: confirm the API is up with a static welcome payload."""
    payload = {"message": "Welcome to the sentiment analysis API!"}
    return payload
15
+
16
@app.post("/predict")
async def predict(sequences: list[str]):
    """Run sentiment classification over a batch of input strings.

    Args:
        sequences: Non-empty list of texts to classify.

    Returns:
        {"results": [{"sequence": str, "prediction": int, "score": list[float]}, ...]}
        where "prediction" is the argmax class index and "score" the per-class
        softmax probabilities for that sequence.

    Raises:
        HTTPException: 400 if the input list is empty.
    """
    if not sequences:
        raise HTTPException(status_code=400, detail="No sequences provided.")

    # Tokenize the whole batch at once; pad/truncate so tensors are rectangular.
    tokens = tokenizer(sequences, padding=True, truncation=True, return_tensors="pt")

    # Inference only — disable gradient tracking to save memory/time.
    with torch.no_grad():
        outputs = model(**tokens)

    # Per-sequence class probabilities: shape (batch, num_classes) -> nested list.
    scores = outputs.logits.softmax(dim=-1).tolist()

    # BUG FIX: the original `scores.index(max(score) for score in scores)`
    # passed a generator object to list.index, which always raises ValueError.
    # The intent is the argmax class index per row.
    predictions = [row.index(max(row)) for row in scores]

    response = []
    for i, seq in enumerate(sequences):
        response.append({
            "sequence": seq,
            "prediction": predictions[i],
            "score": scores[i],
        })

    return {"results": response}
requirements.txt CHANGED
@@ -1,2 +1,4 @@
1
  fastapi
2
- uvicorn[standard]
 
 
 
1
  fastapi
2
+ uvicorn[standard]
3
+ torch
4
+ transformers