# xstress-api-hf / app.py
# NOTE(review): Hugging Face Space page metadata ("gaidasalsaa's picture /
# add token / f812c6c") was pasted into the file verbatim and is not valid
# Python; preserved here as a comment so the module parses.
import logging
import os
import re
from typing import Optional

import requests
import torch
from fastapi import FastAPI
from huggingface_hub import hf_hub_download
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForSequenceClassification
logger = logging.getLogger("app")
logging.basicConfig(level=logging.INFO)
# ===========================
# CONFIG
# ===========================
HF_MODEL_REPO = "gaidasalsaa/model-indobertweet-terbaru"  # fine-tuned weights repo
BASE_MODEL = "indolem/indobertweet-base-uncased"          # tokenizer + architecture
PT_FILE = "model_indobertweet.pth"                        # state_dict filename in the repo
# SECURITY: the X API bearer token was previously hard-coded here, so it is
# leaked in git history — revoke and rotate it. Prefer the BEARER_TOKEN
# environment variable; the old literal is kept only as a fallback so existing
# deployments keep working until the secret is moved.
BEARER_TOKEN = os.environ.get(
    "BEARER_TOKEN",
    "AAAAAAAAAAAAAAAAAAAAACOx5wEAAAAA8dmBFQL26Vn%2FEWRVeQu%2BiTqdd%2F4%3DE8QcDTWabLJphye8PVICImVIHd1BLMB9fEU3pxJGrpO1Uw2TsN",
)
# ===========================
# GLOBAL MODEL
# ===========================
# Populated once by load_model_once() at startup; None until then.
tokenizer = None
model = None
# ===========================
# TEXT CLEANING
# ===========================
def clean_text(t):
    """Normalize a tweet: lowercase, drop URLs and @mentions, unwrap hashtags.

    Returns the cleaned text with leading/trailing whitespace stripped;
    interior whitespace left by removed tokens is NOT collapsed.
    """
    substitutions = (
        (r"http\S+|www\.\S+", ""),  # remove URLs
        (r"@\w+", ""),              # remove @mentions
        (r"#(\w+)", r"\1"),         # keep hashtag word, drop the '#'
    )
    cleaned = t.lower()
    for pattern, replacement in substitutions:
        cleaned = re.sub(pattern, replacement, cleaned)
    return cleaned.strip()
# ===========================
# LOAD MODEL
# ===========================
def load_model_once():
    """Initialize the module-level ``tokenizer`` and ``model`` exactly once.

    Idempotent: returns immediately if both globals are already set.
    Downloads the fine-tuned state_dict from ``HF_MODEL_REPO`` and loads it
    into the base IndoBERTweet sequence-classification head (2 labels),
    then puts the model in eval mode for inference.
    """
    global tokenizer, model
    if tokenizer is not None and model is not None:
        logger.info("Model already loaded.")
        return
    logger.info("Loading tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
    logger.info("Downloading model weights...")
    model_path = hf_hub_download(
        repo_id=HF_MODEL_REPO,
        filename=PT_FILE,
    )
    logger.info("Loading IndoBERTweet architecture...")
    model = AutoModelForSequenceClassification.from_pretrained(
        BASE_MODEL,
        num_labels=2
    )
    logger.info("Loading state_dict...")
    # map_location="cpu" so checkpoints saved on GPU still load on CPU-only
    # hosts. NOTE(review): torch.load unpickles arbitrary objects — the
    # weights repo must be trusted.
    state_dict = torch.load(model_path, map_location="cpu")
    model.load_state_dict(state_dict)
    model.eval()
    logger.info("MODEL READY")
# ===========================
# FASTAPI
# ===========================
app = FastAPI(title="Stress Detection API")
# Warm the tokenizer/model at process start so the first request does not pay
# the download/load cost. load_model_once() is idempotent.
# NOTE(review): @app.on_event is deprecated in recent FastAPI releases;
# consider migrating to a lifespan context manager.
@app.on_event("startup")
def startup_event():
    load_model_once()
class StressResponse(BaseModel):
    """Uniform response envelope: a human-readable message plus an optional
    payload dict (None when the analysis could not be performed)."""
    message: str
    data: Optional[dict] = None
# ===========================
# TWITTER API
# ===========================
def get_user_id(username):
    """Resolve an X (Twitter) username to a user id.

    Returns ``(user_id, raw)`` where ``user_id`` is None on any failure and
    ``raw`` is the parsed API response, or an ``{"error": ...}`` dict when the
    request itself (or JSON parsing) failed.
    """
    url = f"https://api.x.com/2/users/by/username/{username}"
    headers = {"Authorization": f"Bearer {BEARER_TOKEN}"}
    try:
        r = requests.get(url, headers=headers, timeout=10)
        payload = r.json()  # parse once instead of re-parsing per access
    # Narrow handling replaces the old bare `except:`, which also swallowed
    # KeyboardInterrupt/SystemExit. ValueError covers non-JSON bodies.
    except (requests.RequestException, ValueError) as e:
        logger.warning("User lookup failed for %r: %s", username, e)
        return None, {"error": "Request failed"}
    if r.status_code != 200:
        return None, payload
    try:
        return payload["data"]["id"], payload
    except (KeyError, TypeError):
        # 200 response with an unexpected shape: surface the payload instead
        # of masking it as a generic request failure.
        return None, payload
def fetch_tweets(user_id, limit=25):
    """Fetch up to ``limit`` recent tweet texts for ``user_id``.

    Returns ``(texts, raw)`` where ``texts`` is a list of tweet strings
    (possibly empty) or None on failure, and ``raw`` is the parsed API
    response or an ``{"error": ...}`` dict on transport/parse errors.
    """
    url = f"https://api.x.com/2/users/{user_id}/tweets"
    params = {"max_results": limit, "tweet.fields": "id,text,created_at"}
    headers = {"Authorization": f"Bearer {BEARER_TOKEN}"}
    try:
        r = requests.get(url, headers=headers, params=params, timeout=10)
        payload = r.json()  # parse once instead of re-parsing per access
    # Narrow handling replaces the old bare `except:`; ValueError covers
    # non-JSON response bodies.
    except (requests.RequestException, ValueError) as e:
        logger.warning("Tweet fetch failed for user %s: %s", user_id, e)
        return None, {"error": "Request failed"}
    if r.status_code != 200:
        return None, payload
    tweets = payload.get("data", [])
    return [t["text"] for t in tweets], payload
# ===========================
# KEYWORDS
# ===========================
def extract_keywords(tweets):
    """Collect the unique stress-related keywords appearing in any tweet.

    Matching is case-insensitive plain substring containment (so "jam" also
    matches inside "jamuan"). Returns the matches as a list in no particular
    order.
    """
    stress_words = (
        "gelisah","cemas","tidur","takut","hati","resah","sampe","tenang",
        "suka","mulu","sedih","ngerasa","gimana","gatau","perasaan",
        "nangis","deg","khawatir","pikiran","harap","gabisa","bener",
        "pengen","sakit","susah","bangun","biar","jam","kaya","bingung",
        "mikir","tuhan","mikirin","bawaannya","marah","tbtb","anjir",
        "cape","panik","enak","kali","pusing","semoga","kadang","langsung",
        "kemarin","tugas","males",
    )
    matches = set()
    for tweet in tweets:
        text = tweet.lower()  # lowercase once per tweet, not per keyword
        matches.update(word for word in stress_words if word in text)
    return list(matches)
# ===========================
# INFERENCE
# ===========================
def predict_stress(text):
    """Classify a single tweet with the globally loaded model.

    Returns ``(label, p_stress)``: ``label`` is the argmax class (0 or 1) and
    ``p_stress`` is the softmax probability of class 1 (stress).
    """
    encoded = tokenizer(
        clean_text(text),
        return_tensors="pt",
        truncation=True,
        padding=True,
        max_length=128,
    )
    # Inference only — no gradients needed.
    with torch.no_grad():
        logits = model(**encoded).logits
    probabilities = torch.softmax(logits, dim=1)[0]
    predicted = int(torch.argmax(probabilities).item())
    return predicted, float(probabilities[1])
# ===========================
# ROUTE
# ===========================
@app.get("/analyze/{username}", response_model=StressResponse)
def analyze(username: str):
    """Analyze a user's recent tweets and report an overall stress estimate.

    Resolves the username, fetches tweets, classifies each one, and returns
    the percentage of tweets labeled stressed plus a 4-level status code and
    the stress-related keywords found.
    """
    user_id, _ = get_user_id(username)
    if user_id is None:
        return StressResponse(message="Failed to fetch profile", data=None)
    tweets, _ = fetch_tweets(user_id)
    if not tweets:
        return StressResponse(message="No tweets available", data=None)
    predictions = [predict_stress(tweet)[0] for tweet in tweets]
    stress_percentage = round(100 * sum(predictions) / len(predictions), 2)
    # Map onto a 4-level status: 0 (<=25), 1 (<=50), 2 (<=75), 3 (>75) —
    # count how many thresholds the percentage exceeds.
    status = sum(stress_percentage > threshold for threshold in (25, 50, 75))
    return StressResponse(
        message="Analysis complete",
        data={
            "username": username,
            "total_tweets": len(tweets),
            "stress_level": stress_percentage,
            "keywords": extract_keywords(tweets),
            "stress_status": status,
        },
    )