from fastapi import FastAPI
from pydantic import BaseModel
from typing import Optional
import torch
import subprocess
import json
from transformers import AutoTokenizer, BertForSequenceClassification
from huggingface_hub import hf_hub_download
import logging

logger = logging.getLogger("app")
logging.basicConfig(level=logging.INFO)
# =====================================================
# CONFIG
# =====================================================
HF_MODEL_REPO = "gaidasalsaa/indobertweet-xstress-model"
BASE_MODEL = "indolem/indobertweet-base-uncased"
PT_FILE = "best_indobertweet.pth"

# =====================================================
# GLOBAL MODEL STORAGE
# =====================================================
tokenizer = None
model = None

# =====================================================
# LOAD MODEL
# =====================================================
def load_model_once():
    global tokenizer, model
    if tokenizer is not None and model is not None:
        return

    logger.info("Loading tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)

    logger.info("Downloading fine-tuned weights...")
    model_path = hf_hub_download(repo_id=HF_MODEL_REPO, filename=PT_FILE)

    logger.info("Loading base model architecture...")
    model = BertForSequenceClassification.from_pretrained(
        BASE_MODEL,
        num_labels=2
    )
| logger.info("Loading weight .pth...") | |
| state_dict = torch.load(model_path, map_location="cpu") | |
| model.load_state_dict(state_dict, strict=True) | |
| model.to("cpu") | |
| model.eval() | |
| logger.info("MODEL READY") | |
# =====================================================
# FASTAPI
# =====================================================
app = FastAPI(title="Stress Detection API")
@app.on_event("startup")
def startup_event():
    load_model_once()
class StressResponse(BaseModel):
    message: str
    data: Optional[dict] = None

# =====================================================
# SNSCRAPE FETCH TWEETS
# =====================================================
def fetch_tweets_snscrape(username, limit=50):
    tweets = []
    try:
        # The scraper name and the username must be separate arguments;
        # passing "twitter-user {username}" as one token makes snscrape
        # treat it as a single, unknown scraper name.
        command = [
            "snscrape",
            "--jsonl",
            "--max-results", str(limit),
            "twitter-user",
            username
        ]
        result = subprocess.run(command, capture_output=True, text=True)
        if result.returncode != 0:
            return None
        for line in result.stdout.splitlines():
            item = json.loads(line)
            if "content" in item:
                tweets.append(item["content"])
        return tweets
    except Exception:
        return None
# =====================================================
# KEYWORD EXTRACTION
# =====================================================
def extract_keywords(tweets):
    stress_words = [
        "capek", "cape", "capai", "letih", "lelah", "pusing",
        "stress", "stres", "burnout", "kesal", "badmood",
        "sedih", "tertekan", "muak", "bosan"
    ]
    found = set()
    for t in tweets:
        lower = t.lower()
        for word in stress_words:
            if word in lower:
                found.add(word)
    return list(found)
# =====================================================
# MODEL INFERENCE
# =====================================================
def predict_stress(text):
    inputs = tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        padding=True,
        max_length=128
    )
    with torch.no_grad():
        outputs = model(**inputs)
    probs = torch.softmax(outputs.logits, dim=1)[0]
    # Label 1 is counted as "stressed" when aggregating in analyze().
    label = torch.argmax(probs).item()
    return label
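# Quick sanity check outside FastAPI (a sketch: it assumes the weight
# download succeeds, and the sample sentence is illustrative only):
#
#   load_model_once()
#   predict_stress("aku capek banget hari ini")  # -> 0 or 1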
# =====================================================
# API ROUTE
# =====================================================
# The route path below is an assumption; adjust it to whatever the client expects.
@app.get("/analyze/{username}", response_model=StressResponse)
def analyze(username: str):
    tweets = fetch_tweets_snscrape(username)
    if tweets is None or len(tweets) == 0:
        return StressResponse(message="No tweets available", data=None)

    labels = [predict_stress(t) for t in tweets]
    stress_percentage = round(sum(labels) / len(labels) * 100, 2)

    # Map the share of stressed tweets to a coarse 0-3 status bucket.
    if stress_percentage <= 25:
        status = 0
    elif stress_percentage <= 50:
        status = 1
    elif stress_percentage <= 75:
        status = 2
    else:
        status = 3

    keywords = extract_keywords(tweets)

    return StressResponse(
        message="Analysis complete",
        data={
            "username": username,
            "total_tweets": len(tweets),
            "stress_level": stress_percentage,
            "keywords": keywords,  # empty list if no keyword matched
            "stress_status": status
        }
    )
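
# =====================================================
# LOCAL RUN (sketch)
# =====================================================
# A minimal way to serve the app locally. Port 7860 is assumed here because
# it is the Hugging Face Spaces default; change it for other deployments.
# The startup hook above loads the model before the first request.
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)

# Example request once the server is up (using the route assumed above):
#   curl http://localhost:7860/analyze/<username>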