Testing / app.py
Propvia's picture
fixing name
2a91b49
import os

# HF_HOME has to be exported BEFORE transformers / huggingface_hub are
# imported: those libraries resolve their cache directory once, at import
# time, so assigning the variable after the import is silently ignored and
# the models are cached in the default location instead of /tmp.
os.environ["HF_HOME"] = "/tmp"

from fastapi import FastAPI  # noqa: E402
from fastapi.middleware.cors import CORSMiddleware  # noqa: E402
from pydantic import BaseModel  # noqa: E402
from transformers.pipelines import pipeline  # noqa: E402

# Hugging Face Hub identifiers of the models served by this API.
SPAM_MODEL = "cjell/spam-detector-roberta"
TOXIC_MODEL = "s-nlp/roberta_toxicity_classifier"
SENTIMENT_MODEL = "nlptown/bert-base-multilingual-uncased-sentiment"
NSFW_MODEL = "michellejieli/NSFW_text_classifier"
HATE_MODEL = "facebook/roberta-hate-speech-dynabench-r4-target"
IMAGE_MODEL = "Falconsai/nsfw_image_detection"

# Each pipeline is built once, when the module is imported, and then shared
# by every request handler below.
spam = pipeline("text-classification", model=SPAM_MODEL)
toxic = pipeline("text-classification", model=TOXIC_MODEL)
sentiment = pipeline("text-classification", model=SENTIMENT_MODEL)
nsfw = pipeline("text-classification", model=NSFW_MODEL)
hate = pipeline("text-classification", model=HATE_MODEL)
image = pipeline("image-classification", model=IMAGE_MODEL)
app = FastAPI()

# CORS: the API is open to any origin, method and header.
# allow_credentials is False because wildcard origins/methods/headers must
# not be combined with credentialed requests (the FastAPI docs require
# explicit lists in that case), and nothing in this service relies on
# cookies or Authorization headers.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)
@app.get("/")
def root():
    """Answer the root path with a fixed ``{"status": "ok"}`` payload."""
    payload = {"status": "ok"}
    return payload
class Query(BaseModel):
    """Request body taken by the text-classification POST endpoints."""

    # Raw text handed to the selected pipeline.
    text: str
@app.post("/spam")
def predict_spam(query: Query):
    """Run the spam pipeline on ``query.text``.

    Returns:
        A dict holding the ``label`` and ``score`` of the first result
        produced by the pipeline.
    """
    # truncation=True clips inputs that exceed the model's maximum sequence
    # length; without it, over-long text typically makes the model call fail
    # and the request ends in a 500 error.
    result = spam(query.text, truncation=True)[0]
    return {"label": result["label"], "score": result["score"]}
@app.post("/toxic")
def predict_toxic(query: Query):
    """Run the toxicity pipeline on ``query.text``.

    Returns:
        A dict holding the ``label`` and ``score`` of the first result
        produced by the pipeline.
    """
    # truncation=True clips inputs that exceed the model's maximum sequence
    # length; without it, over-long text typically makes the model call fail
    # and the request ends in a 500 error.
    result = toxic(query.text, truncation=True)[0]
    return {"label": result["label"], "score": result["score"]}
@app.post("/sentiment")
def predict_sentiment(query: Query):
    """Run the sentiment pipeline on ``query.text``.

    Returns:
        A dict holding the ``label`` and ``score`` of the first result
        produced by the pipeline.
    """
    # truncation=True clips inputs that exceed the model's maximum sequence
    # length; without it, over-long text typically makes the model call fail
    # and the request ends in a 500 error.
    result = sentiment(query.text, truncation=True)[0]
    return {"label": result["label"], "score": result["score"]}
@app.post("/nsfw")
def predict_nsfw(query: Query):
    """Run the NSFW text pipeline on ``query.text``.

    Returns:
        A dict holding the ``label`` and ``score`` of the first result
        produced by the pipeline.
    """
    # truncation=True clips inputs that exceed the model's maximum sequence
    # length; without it, over-long text typically makes the model call fail
    # and the request ends in a 500 error.
    result = nsfw(query.text, truncation=True)[0]
    return {"label": result["label"], "score": result["score"]}
@app.post("/hate")
def predict_hate(query: Query):
    """Run the hate-speech pipeline on ``query.text``.

    Returns:
        A dict holding the ``label`` and ``score`` of the first result
        produced by the pipeline.
    """
    # truncation=True clips inputs that exceed the model's maximum sequence
    # length; without it, over-long text typically makes the model call fail
    # and the request ends in a 500 error.
    result = hate(query.text, truncation=True)[0]
    return {"label": result["label"], "score": result["score"]}
@app.get("/health")
def health_check():
    """Report server status and probe every text-classification pipeline.

    Each text pipeline is called once with the string ``"test"``. A pipeline
    that returns is reported as ``"running"``; one that raises is reported as
    ``"error: <message>"`` so a single broken model does not fail the whole
    health check.

    Returns:
        ``{"server": "running", "models": {<key>: {"model_name": ...,
        "status": ...}}}``.
    """
    # The hate-speech model is served at /hate, so it is probed alongside
    # the other text models.
    # NOTE(review): the image pipeline is left out of the probe because it
    # expects an image input rather than a plain string; add a dedicated
    # check if image health needs to be reported.
    models = {
        "spam": (SPAM_MODEL, spam),
        "toxic": (TOXIC_MODEL, toxic),
        "sentiment": (SENTIMENT_MODEL, sentiment),
        "nsfw": (NSFW_MODEL, nsfw),
        "hate": (HATE_MODEL, hate),
    }

    status = {
        "server": "running",
        "models": {},
    }
    for key, (model_name, model_pipeline) in models.items():
        try:
            model_pipeline("test")
        except Exception as e:
            # Deliberately broad: any failure is reported in the response
            # instead of turning the health endpoint itself into a 500.
            model_status = f"error: {e}"
        else:
            model_status = "running"
        status["models"][key] = {
            "model_name": model_name,
            "status": model_status,
        }
    return status