from fastapi import FastAPI, Request
from pydantic import BaseModel
import torch
import pickle
import gluonnlp as nlp
import numpy as np
import os
import sys
import collections
import logging  # logging module

# AutoTokenizer and BertModel from transformers
from transformers import AutoTokenizer, BertModel
from torch.utils.data import Dataset, DataLoader
from huggingface_hub import hf_hub_download
# --- Logging setup ---
# Emit logs at INFO level and above.
# In a production deployment, raise the level to WARNING or ERROR to cut down on noise.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
# --- 1. BERTClassifier model class ---
class BERTClassifier(torch.nn.Module):
    def __init__(self,
                 bert,
                 hidden_size=768,
                 num_classes=5,  # number of target classes (must match the size of the category dict)
                 dr_rate=None,
                 params=None):
        super(BERTClassifier, self).__init__()
        self.bert = bert
        self.dr_rate = dr_rate
        self.classifier = torch.nn.Linear(hidden_size, num_classes)
        if dr_rate:
            self.dropout = torch.nn.Dropout(p=dr_rate)

    def gen_attention_mask(self, token_ids, valid_length):
        # Build a padding mask: 1 for real tokens, 0 for padding.
        attention_mask = torch.zeros_like(token_ids)
        for i, v in enumerate(valid_length):
            attention_mask[i][:v] = 1
        return attention_mask.float()
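    # For example, with max_len=5 and valid_length=[3], the mask for that
    # sequence is [1, 1, 1, 0, 0]: attention covers the three real tokens
    # and ignores the two padding positions.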
    def forward(self, token_ids, valid_length, segment_ids):
        attention_mask = self.gen_attention_mask(token_ids, valid_length)
        # Use the pooled [CLS] representation for classification.
        _, pooler = self.bert(input_ids=token_ids,
                              token_type_ids=segment_ids.long(),
                              attention_mask=attention_mask.to(token_ids.device),
                              return_dict=False)
        out = self.dropout(pooler) if self.dr_rate else pooler
        return self.classifier(out)
# --- 2. BERTDataset class ---
class BERTDataset(Dataset):
    def __init__(self, dataset, sent_idx, label_idx, bert_tokenizer, vocab, max_len, pad, pair):
        transform = nlp.data.BERTSentenceTransform(
            bert_tokenizer, max_seq_length=max_len, vocab=vocab, pad=pad, pair=pair
        )
        self.sentences = [transform([i[sent_idx]]) for i in dataset]
        self.labels = [np.int32(i[label_idx]) for i in dataset]

    def __getitem__(self, i):
        return (self.sentences[i] + (self.labels[i],))

    def __len__(self):
        return len(self.labels)
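# Each BERTSentenceTransform output is a (token_ids, valid_length, segment_ids)
# tuple, so __getitem__ yields (token_ids, valid_length, segment_ids, label) --
# the exact shape unpacked from the DataLoader in predict() below.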
# --- 3. FastAPI app and global setup ---
app = FastAPI()
device = torch.device("cpu")  # the Hugging Face Spaces free tier runs on CPU

# Load the category mapping
try:
    with open("category.pkl", "rb") as f:
        category = pickle.load(f)
    logger.info("Loaded category.pkl.")
except FileNotFoundError:
    logger.error("Error: category.pkl not found. Check that it is in the project root.")
    sys.exit(1)

# Load the vocab
try:
    with open("vocab.pkl", "rb") as f:
        vocab = pickle.load(f)
    logger.info("Loaded vocab.pkl.")
except FileNotFoundError:
    logger.error("Error: vocab.pkl not found. Check that it is in the project root.")
    sys.exit(1)

# Load the tokenizer
tokenizer = AutoTokenizer.from_pretrained('skt/kobert-base-v1')
logger.info("Loaded tokenizer.")
# Load the model (downloaded from the Hugging Face Hub)
try:
    HF_MODEL_REPO_ID = "hiddenFront/TextClassifier"
    HF_MODEL_FILENAME = "textClassifierModel.pt"
    model_path = hf_hub_download(repo_id=HF_MODEL_REPO_ID, filename=HF_MODEL_FILENAME)
    logger.info(f"Model file downloaded to '{model_path}'.")

    bert_base_model = BertModel.from_pretrained('skt/kobert-base-v1')
    model = BERTClassifier(
        bert_base_model,
        dr_rate=0.5,  # set this to the dr_rate used during training
        num_classes=len(category)
    )

    # Strip the 'module.' prefix that torch.nn.DataParallel adds to state-dict
    # keys, so the weights load into a plain (non-parallel) model.
    loaded_state_dict = torch.load(model_path, map_location=device)
    new_state_dict = collections.OrderedDict()
    for k, v in loaded_state_dict.items():
        name = k
        if name.startswith('module.'):
            name = name[7:]
        new_state_dict[name] = v

    model.load_state_dict(new_state_dict, strict=False)
    model.to(device)
    model.eval()
    logger.info("Model loaded.")
except Exception as e:
    logger.error(f"Error while downloading or loading the model: {e}")
    sys.exit(1)
# Parameters for dataset construction
max_len = 64
batch_size = 32
# Prediction function
def predict(predict_sentence):
    # Wrap the sentence in the (text, label) shape BERTDataset expects;
    # the dummy label '0' is ignored at inference time.
    data = [predict_sentence, '0']
    dataset_another = [data]

    another_test = BERTDataset(dataset_another, 0, 1, tokenizer.tokenize, vocab, max_len, True, False)
    test_dataloader = DataLoader(another_test, batch_size=batch_size, num_workers=0)

    model.eval()
    with torch.no_grad():
        for batch_id, (token_ids, valid_length, segment_ids, label) in enumerate(test_dataloader):
            token_ids = token_ids.long().to(device)
            segment_ids = segment_ids.long().to(device)

            logits = model(token_ids, valid_length, segment_ids)  # the model's raw output is logits
            probs = torch.nn.functional.softmax(logits, dim=1)  # class probabilities
            predicted_category_index = torch.argmax(probs, dim=1).item()  # predicted index
            # Map the index back to a name; this assumes the insertion order of
            # the category dict matches the label indices used in training.
            predicted_category_name = list(category.keys())[predicted_category_index]

            # --- Detailed prediction logging ---
            logger.info(f"Input Text: '{predict_sentence}'")
            logger.info(f"Raw Logits: {logits.tolist()}")
            logger.info(f"Probabilities: {probs.tolist()}")
            logger.info(f"Predicted Index: {predicted_category_index}")
            logger.info(f"Predicted Label: '{predicted_category_name}'")
            # --- End detailed prediction logging ---

            return predicted_category_name
# Endpoint definitions (route paths "/" and "/predict" assumed)
class InputText(BaseModel):
    text: str


@app.get("/")
def root():
    return {"message": "Text Classification API (KoBERT)"}


@app.post("/predict")
async def predict_route(item: InputText):
    result = predict(item.text)
    return {"text": item.text, "classification": result}