import os

# HF_HOME must be set before transformers/huggingface_hub are imported,
# otherwise the cache path is already resolved; /tmp is writable on
# Hugging Face Spaces.
os.environ["HF_HOME"] = "/tmp"

from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from transformers import pipeline

SPAM_MODEL = "cjell/spam-detector-roberta"
TOXIC_MODEL = "s-nlp/roberta_toxicity_classifier"
SENTIMENT_MODEL = "nlptown/bert-base-multilingual-uncased-sentiment"
NSFW_MODEL = "michellejieli/NSFW_text_classifier"
HATE_MODEL = "facebook/roberta-hate-speech-dynabench-r4-target"
IMAGE_MODEL = "Falconsai/nsfw_image_detection"

# Load every pipeline once at startup so requests don't pay the model-load cost.
spam = pipeline("text-classification", model=SPAM_MODEL)
toxic = pipeline("text-classification", model=TOXIC_MODEL)
sentiment = pipeline("text-classification", model=SENTIMENT_MODEL)
nsfw = pipeline("text-classification", model=NSFW_MODEL)
hate = pipeline("text-classification", model=HATE_MODEL)
image = pipeline("image-classification", model=IMAGE_MODEL)  # no endpoint uses this yet


app = FastAPI()

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
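# Wide-open CORS (all origins, methods, and headers) is convenient for a
# demo; restrict allow_origins before exposing this API publicly.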

@app.get("/")
def root():
    return {"status": "ok"}

class Query(BaseModel):
    text: str
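
# Example client call for the text endpoints below (a sketch; assumes the
# server is reachable on localhost:8000 and `requests` is installed --
# labels vary per model, e.g. the sentiment model returns "1 star".."5 stars"):
#
#   import requests
#   r = requests.post("http://localhost:8000/spam",
#                     json={"text": "Win a free iPhone now!"})
#   print(r.json())  # -> {"label": "...", "score": 0.97}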

@app.post("/spam")
def predict_spam(query: Query):
    result = spam(query.text)[0]
    return {"label": result["label"], "score": result["score"]}

@app.post("/toxic")
def predict_toxic(query: Query):
    result = toxic(query.text)[0]
    return {"label": result["label"], "score": result["score"]}

@app.post("/sentiment")
def predict_sentiment(query: Query):
    result = sentiment(query.text)[0]
    return {"label": result["label"], "score": result["score"]}

@app.post("/nsfw")
def predict_nsfw(query: Query):
    result = nsfw(query.text)[0]
    return {"label": result["label"], "score": result["score"]}

@app.post("/hate")
def predict_hate(query: Query):
    result = hate(query.text)[0]
    return {"label": result["label"], "score": result["score"]}


@app.get("/health")
def health_check():
    status = {
        "server": "running",
        "models": {}
    }

    # Probe each text pipeline with a tiny input; the image pipeline is
    # skipped because it cannot be probed with a plain string.
    models = {
        "spam": (SPAM_MODEL, spam),
        "toxic": (TOXIC_MODEL, toxic),
        "sentiment": (SENTIMENT_MODEL, sentiment),
        "nsfw": (NSFW_MODEL, nsfw),
        "hate": (HATE_MODEL, hate),
    }

    for key, (model_name, model_pipeline) in models.items():
        try:
            model_pipeline("test")
            status["models"][key] = {
                "model_name": model_name,
                "status": "running"
            }
        except Exception as e:
            status["models"][key] = {
                "model_name": model_name,
                "status": f"error: {str(e)}"
            }

    return status
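

# Local entry point (a sketch, assuming uvicorn is installed; port 7860 is
# the convention on Hugging Face Spaces, where this app appears to run):
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)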