|
|
import os

# huggingface_hub reads HF_HOME (and derives its cache directories from it)
# when it is first imported, which happens as a side effect of importing
# transformers. The variable must therefore be set BEFORE the transformers
# import below; assigning it afterwards does not change where models are
# cached.
os.environ["HF_HOME"] = "/tmp"

from fastapi import FastAPI  # noqa: E402
from fastapi.middleware.cors import CORSMiddleware  # noqa: E402
from pydantic import BaseModel  # noqa: E402
from transformers.pipelines import pipeline  # noqa: E402
|
|
|
|
|
# Hugging Face Hub checkpoint identifiers, one per classifier served by this app.
SPAM_MODEL: str = "cjell/spam-detector-roberta"
TOXIC_MODEL: str = "s-nlp/roberta_toxicity_classifier"
SENTIMENT_MODEL: str = "nlptown/bert-base-multilingual-uncased-sentiment"
NSFW_MODEL: str = "michellejieli/NSFW_text_classifier"
HATE_MODEL: str = "facebook/roberta-hate-speech-dynabench-r4-target"

# Checkpoint for the image classifier.
IMAGE_MODEL: str = "Falconsai/nsfw_image_detection"
|
|
|
|
|
# Build every text-classification pipeline once, at import time, in the order
# spam -> toxic -> sentiment -> nsfw -> hate. The resulting module-level
# objects are shared by the request handlers.
spam, toxic, sentiment, nsfw, hate = (
    pipeline("text-classification", model=checkpoint)
    for checkpoint in (
        SPAM_MODEL,
        TOXIC_MODEL,
        SENTIMENT_MODEL,
        NSFW_MODEL,
        HATE_MODEL,
    )
)

# NOTE(review): no route in this file uses `image`, yet it is still loaded at
# import time. Confirm whether it is needed.
image = pipeline("image-classification", model=IMAGE_MODEL)
|
|
|
|
|
|
|
|
app = FastAPI()

# Fully permissive CORS: any origin, method and header may call the API.
# NOTE(review): combining allow_origins=["*"] with allow_credentials=True
# makes Starlette reflect the caller's Origin (rather than "*") on some
# responses, so credentialed cross-site requests are accepted from any site.
# No route in this file reads cookies or auth headers; consider whether
# credentials are needed at all.
app.add_middleware(
    CORSMiddleware,
    allow_credentials=True,
    allow_origins=["*"],
    allow_headers=["*"],
    allow_methods=["*"],
)
|
|
|
|
|
@app.get("/")
def root():
    """Return a constant payload confirming the service is reachable."""
    payload = {"status": "ok"}
    return payload
|
|
|
|
|
class Query(BaseModel):
    """JSON request body accepted by the text-classification endpoints."""

    # Raw text to be classified.
    text: str
|
|
|
|
|
def _classify(classifier, text):
    """Run a text-classification pipeline and return its top prediction.

    ``truncation=True`` is forwarded to the tokenizer so that inputs longer
    than the model's maximum sequence length are cut down. Without it, the
    forward pass can raise on long inputs and the request fails with a 500.

    Args:
        classifier: A ``transformers`` text-classification pipeline.
        text: The input string to classify.

    Returns:
        A dict with the top prediction's ``"label"`` and ``"score"``.
    """
    top = classifier(text, truncation=True)[0]
    return {"label": top["label"], "score": top["score"]}


@app.post("/spam")
def predict_spam(query: Query):
    """Classify ``query.text`` with the spam-detection model."""
    return _classify(spam, query.text)


@app.post("/toxic")
def predict_toxic(query: Query):
    """Classify ``query.text`` with the toxicity model."""
    return _classify(toxic, query.text)


@app.post("/sentiment")
def predict_sentiment(query: Query):
    """Classify ``query.text`` with the sentiment model."""
    return _classify(sentiment, query.text)


@app.post("/nsfw")
def predict_nsfw(query: Query):
    """Classify ``query.text`` with the NSFW text model."""
    return _classify(nsfw, query.text)


@app.post("/hate")
def predict_hate(query: Query):
    """Classify ``query.text`` with the hate-speech model."""
    return _classify(hate, query.text)
|
|
|
|
|
|
|
|
@app.get("/health")
def health_check():
    """Report server status and smoke-test each text-classification model.

    Each text pipeline runs once on the string ``"test"``. A model is
    reported as ``"running"`` if that call succeeds. Otherwise its status is
    ``"error: <message>"``, and the endpoint itself still returns normally.
    The image pipeline is not probed here.

    Returns:
        ``{"server": "running", "models": {key: {"model_name", "status"}}}``
    """
    probes = {
        "spam": (SPAM_MODEL, spam),
        "toxic": (TOXIC_MODEL, toxic),
        "sentiment": (SENTIMENT_MODEL, sentiment),
        "nsfw": (NSFW_MODEL, nsfw),
        # Previously missing, so the model behind /hate was never reported.
        "hate": (HATE_MODEL, hate),
    }

    status = {"server": "running", "models": {}}

    for key, (model_name, model_pipeline) in probes.items():
        try:
            model_pipeline("test")
        except Exception as e:  # health probe: report the failure, don't raise
            model_status = f"error: {e}"
        else:
            model_status = "running"
        status["models"][key] = {"model_name": model_name, "status": model_status}

    return status
|
|
|