from fastapi import FastAPI
from pydantic import BaseModel
import torch
import pickle
import sys          # used to abort startup when required files are missing
import collections  # used to rebuild the state_dict with cleaned keys
import gluonnlp as nlp
import numpy as np
import os
from kobert_tokenizer import KoBERTTokenizer  # keep the kobert_tokenizer import
from transformers import BertModel            # keep the BertModel import
from torch.utils.data import Dataset, DataLoader  # DataLoader import added
import logging  # keep the logging import

# --- 1. BERTClassifier model class definition (moved from model.py) ---
# This class defines the model architecture; a shape sketch follows the class.
class BERTClassifier(torch.nn.Module):
    def __init__(self,
                 bert,
                 hidden_size=768,
                 num_classes=5,  # number of target classes (must match the size of the category dict)
                 dr_rate=None,
                 params=None):
        super(BERTClassifier, self).__init__()
        self.bert = bert
        self.dr_rate = dr_rate
        self.classifier = torch.nn.Linear(hidden_size, num_classes)
        if dr_rate:
            self.dropout = torch.nn.Dropout(p=dr_rate)

    def gen_attention_mask(self, token_ids, valid_length):
        # Build a mask that is 1 for real tokens and 0 for padding.
        attention_mask = torch.zeros_like(token_ids)
        for i, v in enumerate(valid_length):
            attention_mask[i][:v] = 1
        return attention_mask.float()

    def forward(self, token_ids, valid_length, segment_ids):
        attention_mask = self.gen_attention_mask(token_ids, valid_length)
        # Adjusted to the Hugging Face Transformers BertModel output: with
        # return_dict=False it returns (last_hidden_state, pooler_output, ...);
        # we use pooler_output (the [CLS] token's final hidden state passed
        # through the pooler).
        _, pooler = self.bert(input_ids=token_ids,
                              token_type_ids=segment_ids.long(),
                              attention_mask=attention_mask.float().to(token_ids.device),
                              return_dict=False)
        if self.dr_rate:
            out = self.dropout(pooler)
        else:
            out = pooler
        return self.classifier(out)
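
# Shape sketch for forward() (illustrative, assuming batch size B and the
# max_len defined below):
#   token_ids:    LongTensor  (B, max_len)
#   valid_length: Tensor      (B,)  -- number of non-padding tokens per row
#   segment_ids:  LongTensor  (B, max_len)
#   returns:      FloatTensor (B, num_classes) of raw, unnormalized logits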

# --- 2. BERTDataset class definition (moved from dataset.py) ---
# This class converts raw data into the model's input format; see the
# item-layout note after the class.
class BERTDataset(Dataset):
    def __init__(self, dataset, sent_idx, label_idx, bert_tokenizer, vocab, max_len, pad, pair):
        transform = nlp.data.BERTSentenceTransform(
            bert_tokenizer, max_seq_length=max_len, vocab=vocab, pad=pad, pair=pair
        )
        self.sentences = [transform([i[sent_idx]]) for i in dataset]
        self.labels = [np.int32(i[label_idx]) for i in dataset]

    def __getitem__(self, i):
        return (self.sentences[i] + (self.labels[i],))

    def __len__(self):
        return len(self.labels)
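
# Item layout note (based on gluonnlp's BERTSentenceTransform, which yields
# (token_ids, valid_length, segment_ids)): __getitem__ therefore returns a
# 4-tuple (token_ids, valid_length, segment_ids, label), which is exactly
# what the batch unpacking in predict() below expects.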

# --- 3. FastAPI app and global settings ---
app = FastAPI()
device = torch.device("cpu")  # Render's free tier is CPU-only.

# Load the category dict (must be in the repository root).
try:
    with open("category.pkl", "rb") as f:
        category = pickle.load(f)
    print("Loaded category.pkl.")
except FileNotFoundError:
    print("Error: category.pkl not found. Check that it is in the project root.")
    sys.exit(1)  # do not start the service without it

# Load the vocab (must be in the repository root).
try:
    with open("vocab.pkl", "rb") as f:
        vocab = pickle.load(f)
    print("Loaded vocab.pkl.")
except FileNotFoundError:
    print("Error: vocab.pkl not found. Check that it is in the project root.")
    sys.exit(1)  # do not start the service without it

# Load the tokenizer (via kobert_tokenizer).
# This matches the approach used in the Colab training code, so it is kept.
tokenizer = KoBERTTokenizer.from_pretrained('skt/kobert-base-v1')
print("Loaded tokenizer.")

# Model loading:
# define the model architecture, then load the saved state_dict.
# num_classes must match the size of the category dict.
model = BERTClassifier(
    BertModel.from_pretrained('skt/kobert-base-v1'),
    dr_rate=0.5,  # change to the dr_rate used during training
    num_classes=len(category)
)

# Load the textClassifierModel.pt weights.
# The Dockerfile only installs kobert_tokenizer
# (git+https://github.com/SKTBrain/KOBERT#egg=kobert_tokenizer) and does not
# download the model file, so it is fetched here from the Hugging Face Hub.
try:
    from huggingface_hub import hf_hub_download
    HF_MODEL_REPO_ID = "hiddenFront/TextClassifier"  # your actual Hugging Face repository ID
    HF_MODEL_FILENAME = "textClassifierModel.pt"
    model_path = hf_hub_download(repo_id=HF_MODEL_REPO_ID, filename=HF_MODEL_FILENAME)
    print(f"Model file downloaded to '{model_path}'.")

    # Load the model's state_dict.
    loaded_state_dict = torch.load(model_path, map_location=device)

    # Adjust state_dict keys if needed: strip the 'module.' prefix left over
    # from saving a DataParallel-wrapped model.
    new_state_dict = collections.OrderedDict()
    for k, v in loaded_state_dict.items():
        name = k
        if name.startswith('module.'):
            name = name[7:]
        new_state_dict[name] = v

    model.load_state_dict(new_state_dict)
    model.to(device)  # move the model to the target device
    model.eval()      # set inference mode
    print("Model loaded.")
except Exception as e:
    print(f"Error while downloading or loading the model: {e}")
    sys.exit(1)  # do not start the service if the model cannot be loaded

# Parameters needed for dataset construction.
max_len = 64
batch_size = 32

# Prediction function.
def predict(predict_sentence):
    data = [predict_sentence, '0']  # '0' is a dummy label; only the text matters at inference time
    dataset_another = [data]

    # num_workers=0 is recommended in deployment environments.
    another_test = BERTDataset(dataset_another, 0, 1, tokenizer, vocab, max_len, True, False)
    test_dataLoader = DataLoader(another_test, batch_size=batch_size, num_workers=0)

    model.eval()  # evaluation mode for prediction
    with torch.no_grad():  # disable gradient computation
        for batch_id, (token_ids, valid_length, segment_ids, label) in enumerate(test_dataLoader):
            token_ids = token_ids.long().to(device)
            segment_ids = segment_ids.long().to(device)
            out = model(token_ids, valid_length, segment_ids)
            logits = out.detach().cpu().numpy()
            predicted_category_index = np.argmax(logits)
            predicted_category_name = list(category.keys())[predicted_category_index]
            return predicted_category_name
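
# Usage sketch (hypothetical input; the return value is always one of the
# keys of the `category` dict loaded above):
#   predicted = predict("분류할 문장")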

# Endpoint definitions.
class InputText(BaseModel):
    text: str

@app.get("/")
def root():
    return {"message": "Text Classification API (KoBERT)"}

@app.post("/predict")  # path assumed; adjust to match your client
async def predict_route(item: InputText):
    result = predict(item.text)
    return {"text": item.text, "classification": result}