from fastapi import FastAPI, Request
from transformers import BertModel, BertForSequenceClassification, AutoTokenizer
from huggingface_hub import hf_hub_download
import torch
import pickle
import os
import sys
import psutil

app = FastAPI()
device = torch.device("cpu")

# Load category.pkl
try:
    with open("category.pkl", "rb") as f:
        category = pickle.load(f)
    print("✅ category.pkl loaded successfully.")
except FileNotFoundError:
    print("❌ Error: category.pkl not found.")
    sys.exit(1)
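
# Note (added): predict_api below recovers the label via list(category.keys()),
# so category is assumed to be a dict whose keys are the label names in index order.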

# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained("skt/kobert-base-v1")
print("✅ Tokenizer loaded successfully.")

class CustomClassifier(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # Must reproduce the exact architecture the weights were saved with
        self.bert = BertModel.from_pretrained("skt/kobert-base-v1")
        self.classifier = torch.nn.Linear(768, len(category))

    def forward(self, input_ids, attention_mask=None, token_type_ids=None):
        outputs = self.bert(input_ids=input_ids,
                            attention_mask=attention_mask,
                            token_type_ids=token_type_ids)
        pooled_output = outputs[1]  # pooled [CLS] representation
        return self.classifier(pooled_output)
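
# Note (added): CustomClassifier mirrors the training-time architecture but is
# not instantiated below; the checkpoint is loaded into
# BertForSequenceClassification, which also exposes `bert` and `classifier`
# submodules, so the state_dict keys are assumed to line up.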

HF_MODEL_REPO_ID = "hiddenFront/TextClassifier"
HF_MODEL_FILENAME = "textClassifierModel.pt"

# Measure memory before downloading the model
process = psutil.Process(os.getpid())
mem_before = process.memory_info().rss / (1024 * 1024)
print(f"📦 Memory usage before model download: {mem_before:.2f} MB")

# Load the model
try:
    model_path = hf_hub_download(repo_id=HF_MODEL_REPO_ID, filename=HF_MODEL_FILENAME)
    print(f"✅ Model file downloaded: {model_path}")
    state_dict = torch.load(model_path, map_location=device)
    model = BertForSequenceClassification.from_pretrained(
        "skt/kobert-base-v1",
        num_labels=len(category),
        state_dict=state_dict,
    )
    model.to(device)
    model.eval()
    print("✅ Model loaded and ready.")
except Exception as e:
    print(f"❌ Error while loading model: {e}")
    sys.exit(1)
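
# Sketch (added, reusing the psutil process handle from above): log memory
# again after loading so the "before" reading has a counterpart.
mem_after = process.memory_info().rss / (1024 * 1024)
print(f"📦 Memory usage after model load: {mem_after:.2f} MB "
      f"(+{mem_after - mem_before:.2f} MB)")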

# Prediction API
# Note: the original defined predict_api without a route decorator; "/predict"
# is an assumed path so FastAPI actually exposes the endpoint.
@app.post("/predict")
async def predict_api(request: Request):
    data = await request.json()
    text = data.get("text")
    if not text:
        return {"error": "No text provided", "classification": "null"}
    encoded = tokenizer.encode_plus(
        text, max_length=64, padding='max_length', truncation=True, return_tensors='pt'
    )
    with torch.no_grad():
        outputs = model(**encoded)
    probs = torch.nn.functional.softmax(outputs.logits, dim=1)
    predicted = torch.argmax(probs, dim=1).item()
    label = list(category.keys())[predicted]
    return {"text": text, "classification": label}