from fastapi import FastAPI, Request
from pydantic import BaseModel
import torch
import pickle
import gluonnlp as nlp
import numpy as np
import os
import sys  # imported so the service can exit on fatal startup errors

# Only AutoTokenizer from transformers is used;
# BertModel, BertForSequenceClassification, etc. are no longer needed directly.
from transformers import AutoTokenizer
from torch.utils.data import Dataset, DataLoader
import logging  # logging import kept
from huggingface_hub import hf_hub_download
# collections may no longer be needed, but it is kept just in case.
import collections
# --- 1. FastAPI app and global settings ---
app = FastAPI()
device = torch.device("cpu")  # the free tier of Hugging Face Spaces mostly runs on CPU
# Load the category mapping (must be in the GitHub repository root)
try:
    with open("category.pkl", "rb") as f:
        category = pickle.load(f)
    print("Loaded category.pkl.")
except FileNotFoundError:
    print("Error: category.pkl not found. Check that it is in the project root.")
    sys.exit(1)  # do not start the service without this file

# Load the vocab (must be in the GitHub repository root)
try:
    with open("vocab.pkl", "rb") as f:
        vocab = pickle.load(f)
    print("Loaded vocab.pkl.")
except FileNotFoundError:
    print("Error: vocab.pkl not found. Check that it is in the project root.")
    sys.exit(1)  # do not start the service without this file
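# NOTE (assumption): predict() below resolves labels via list(category.keys())[index],
# so category.pkl presumably holds a dict whose keys are the class names in
# training-time index order, e.g. {"economy": 0, "sports": 1, ...}; the example
# names here are illustrative only.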
# Load the tokenizer (using transformers.AutoTokenizer)
tokenizer = AutoTokenizer.from_pretrained('skt/kobert-base-v1')
print("Loaded tokenizer.")
# Load the model (downloaded from the Hugging Face Hub).
# textClassifierModel.pt is assumed to be a slimmed-down, complete model object,
# so it is loaded directly rather than rebuilt from a state dict.
try:
    HF_MODEL_REPO_ID = "hiddenFront/TextClassifier"  # the actual Hugging Face repo ID
    HF_MODEL_FILENAME = "textClassifierModel.pt"     # must match the filename uploaded to the Hub
    model_path = hf_hub_download(repo_id=HF_MODEL_REPO_ID, filename=HF_MODEL_FILENAME)
    print(f"Model file downloaded to '{model_path}'.")

    # The file already contains a full PyTorch model object (including the quantized
    # model), so it is loaded and used as-is. Note: on PyTorch >= 2.6, loading a full
    # object like this may additionally require weights_only=False.
    model = torch.load(model_path, map_location=device)
    model.eval()  # inference mode
    print("Loaded model.")
except Exception as e:
    print(f"Error while downloading or loading the model: {e}")
    sys.exit(1)  # do not start the service if the model fails to load
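# Sketch (assumption, not part of the service): a file like the one above would
# typically be produced at training time by quantizing and then saving the whole
# model object; variable names below are illustrative.
#
#   quantized = torch.quantization.quantize_dynamic(
#       trained_model, {torch.nn.Linear}, dtype=torch.qint8
#   )
#   torch.save(quantized, "textClassifierModel.pt")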
# --- 2. BERTDataset definition (moved here from dataset.py) ---
# Converts raw (sentence, label) rows into the model's input format.
class BERTDataset(Dataset):
    def __init__(self, dataset, sent_idx, label_idx, bert_tokenizer, vocab, max_len, pad, pair):
        # nlp.data.BERTSentenceTransform expects a tokenizer function, so
        # AutoTokenizer's tokenize method is passed in directly.
        transform = nlp.data.BERTSentenceTransform(
            bert_tokenizer, max_seq_length=max_len, vocab=vocab, pad=pad, pair=pair
        )
        self.sentences = [transform([i[sent_idx]]) for i in dataset]
        self.labels = [np.int32(i[label_idx]) for i in dataset]

    def __getitem__(self, i):
        return self.sentences[i] + (self.labels[i],)

    def __len__(self):
        return len(self.labels)
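# For reference: BERTSentenceTransform returns a (token_ids, valid_length,
# segment_ids) triple per sentence, which is why each dataset item unpacks into
# four values (the label is appended in __getitem__).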
# Parameters needed to build the dataset
max_len = 64
batch_size = 32
# Prediction function
def predict(predict_sentence):
    data = [predict_sentence, '0']  # dummy label; only the sentence matters here
    dataset_another = [data]

    # tokenizer.tokenize is handed to BERTDataset (see above);
    # num_workers=0 is recommended in deployment environments.
    another_test = BERTDataset(dataset_another, 0, 1, tokenizer.tokenize, vocab, max_len, True, False)
    test_dataloader = DataLoader(another_test, batch_size=batch_size, num_workers=0)

    model.eval()  # evaluation mode for prediction
    with torch.no_grad():  # disable gradient tracking
        for batch_id, (token_ids, valid_length, segment_ids, label) in enumerate(test_dataloader):
            token_ids = token_ids.long().to(device)
            segment_ids = segment_ids.long().to(device)
            out = model(token_ids, valid_length, segment_ids)
            logits = out.detach().cpu().numpy()
            predicted_category_index = np.argmax(logits)
            predicted_category_name = list(category.keys())[predicted_category_index]
            return predicted_category_name
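# Usage sketch: predict("some news headline") returns the top-1 class name as a
# string; the actual label set depends on the category.pkl built at training time.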
# --- Endpoint definitions ---
class InputText(BaseModel):
    text: str

@app.get("/")
def root():
    return {"message": "Text Classification API (KoBERT)"}

@app.post("/predict")  # route path assumed; a decorator is required for FastAPI routing
async def predict_route(item: InputText):
    result = predict(item.text)
    return {"text": item.text, "classification": result}
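# Usage sketch (assumes this file is app.py and the Space serves it on the default
# port 7860):
#   uvicorn app:app --host 0.0.0.0 --port 7860
#   curl -X POST http://localhost:7860/predict \
#        -H "Content-Type: application/json" \
#        -d '{"text": "some input sentence"}'
# -> {"text": "some input sentence", "classification": "<predicted category>"}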