from fastapi import FastAPI, Request
from pydantic import BaseModel
import torch
import pickle
import gluonnlp as nlp
import numpy as np
import os
import sys
import collections
import logging  # import the logging module
from transformers import AutoTokenizer, BertModel
from torch.utils.data import Dataset, DataLoader
from huggingface_hub import hf_hub_download

logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
class BERTClassifier(torch.nn.Module):
    def __init__(self,
                 bert,
                 hidden_size=768,
                 num_classes=5,  # number of target classes (must match the size of the category dict)
                 dr_rate=None,
                 params=None):
        super(BERTClassifier, self).__init__()
        self.bert = bert
        self.dr_rate = dr_rate
        self.classifier = torch.nn.Linear(hidden_size, num_classes)
        if dr_rate:
            self.dropout = torch.nn.Dropout(p=dr_rate)

    def gen_attention_mask(self, token_ids, valid_length):
        # Binary mask: 1 for real tokens, 0 for padding.
        attention_mask = torch.zeros_like(token_ids)
        for i, v in enumerate(valid_length):
            attention_mask[i][:v] = 1
        return attention_mask.float()

    def forward(self, token_ids, valid_length, segment_ids):
        attention_mask = self.gen_attention_mask(token_ids, valid_length)
        _, pooler = self.bert(
            input_ids=token_ids,
            token_type_ids=segment_ids.long(),
            attention_mask=attention_mask.float().to(token_ids.device),
            return_dict=False,
        )
        if self.dr_rate:
            out = self.dropout(pooler)
        else:
            out = pooler
        return self.classifier(out)
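
# For illustration of gen_attention_mask: given token_ids of shape (2, 5) and
# valid_length [3, 5], the returned mask is
#   [[1., 1., 1., 0., 0.],
#    [1., 1., 1., 1., 1.]]
# i.e. ones over the first `valid_length` positions of each row, zeros over padding.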

# --- 2. BERTDataset class definition ---
class BERTDataset(Dataset):
    def __init__(self, dataset, sent_idx, label_idx, bert_tokenizer, vocab, max_len, pad, pair):
        transform = nlp.data.BERTSentenceTransform(
            bert_tokenizer, max_seq_length=max_len, vocab=vocab, pad=pad, pair=pair
        )
        self.sentences = [transform([i[sent_idx]]) for i in dataset]
        self.labels = [np.int32(i[label_idx]) for i in dataset]

    def __getitem__(self, i):
        return (self.sentences[i] + (self.labels[i],))

    def __len__(self):
        return len(self.labels)
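
# Each dataset item is the 3-tuple produced by BERTSentenceTransform,
# (token_ids, valid_length, segment_ids), with the label appended:
# (token_ids, valid_length, segment_ids, label). With max_len=64, token_ids and
# segment_ids are length-64 arrays and valid_length counts the non-padding tokens.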

app = FastAPI()
device = torch.device("cpu")  # The free tier of Hugging Face Spaces mostly runs on CPU.

# ✅ Load category mapping
try:
    with open("category.pkl", "rb") as f:
        category = pickle.load(f)
    logger.info("category.pkl loaded successfully.")
except FileNotFoundError:
    logger.error("Error: category.pkl not found. Make sure it is in the project root.")
    sys.exit(1)

# ✅ Load vocab
try:
    with open("vocab.pkl", "rb") as f:
        vocab = pickle.load(f)
    logger.info("vocab.pkl loaded successfully.")
except FileNotFoundError:
    logger.error("Error: vocab.pkl not found. Make sure it is in the project root.")
    sys.exit(1)
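
# `category` is assumed to be a dict mapping label names to training indices,
# e.g. {"politics": 0, "economy": 1} (hypothetical labels); its key order must
# match the index order used in training, because the prediction code below
# recovers the label via list(category.keys())[predicted_index].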

# ✅ Load tokenizer
tokenizer = AutoTokenizer.from_pretrained('skt/kobert-base-v1')
logger.info("Tokenizer loaded successfully.")

# ✅ Load model (downloaded from the Hugging Face Hub)
try:
    HF_MODEL_REPO_ID = "hiddenFront/TextClassifier"
    HF_MODEL_FILENAME = "textClassifierModel.pt"
    model_path = hf_hub_download(repo_id=HF_MODEL_REPO_ID, filename=HF_MODEL_FILENAME)
    logger.info(f"Model file downloaded successfully to '{model_path}'.")

    bert_base_model = BertModel.from_pretrained('skt/kobert-base-v1')
    model = BERTClassifier(
        bert_base_model,
        dr_rate=0.5,  # change this to the dr_rate used during training
        num_classes=len(category)
    )

    loaded_state_dict = torch.load(model_path, map_location=device)
    # Strip the 'module.' prefix that torch.nn.DataParallel adds to parameter
    # names, so the checkpoint keys match this single-device model.
    new_state_dict = collections.OrderedDict()
    for k, v in loaded_state_dict.items():
        name = k
        if name.startswith('module.'):
            name = name[7:]
        new_state_dict[name] = v
    model.load_state_dict(new_state_dict, strict=False)
    model.to(device)
    model.eval()
    logger.info("Model loaded successfully.")
except Exception as e:
    logger.error(f"Error while downloading or loading the model: {e}")
    sys.exit(1)

# ✅ Parameters for dataset construction (max_len should match the value used in training)
max_len = 64
batch_size = 32

# ✅ Prediction function
def predict(predict_sentence):
    data = [predict_sentence, '0']  # dummy label; only the sentence is used at inference
    dataset_another = [data]
    another_test = BERTDataset(dataset_another, 0, 1, tokenizer.tokenize, vocab, max_len, True, False)
    test_dataLoader = DataLoader(another_test, batch_size=batch_size, num_workers=0)

    model.eval()
    with torch.no_grad():
        for batch_id, (token_ids, valid_length, segment_ids, label) in enumerate(test_dataLoader):
            token_ids = token_ids.long().to(device)
            segment_ids = segment_ids.long().to(device)

            out = model(token_ids, valid_length, segment_ids)
            logits = out  # the model's raw output is the logits
            probs = torch.nn.functional.softmax(logits, dim=1)  # convert logits to probabilities
            predicted_category_index = torch.argmax(probs, dim=1).item()  # predicted class index
            predicted_category_name = list(category.keys())[predicted_category_index]  # predicted label name

            # --- detailed prediction logging ---
            logger.info(f"Input Text: '{predict_sentence}'")
            logger.info(f"Raw Logits: {logits.tolist()}")
            logger.info(f"Probabilities: {probs.tolist()}")
            logger.info(f"Predicted Index: {predicted_category_index}")
            logger.info(f"Predicted Label: '{predicted_category_name}'")
            # --- end of detailed prediction logging ---

            return predicted_category_name
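
# Illustrative usage (hypothetical input and label; actual names come from category.pkl):
#   predict("오늘 주식 시장이 크게 올랐다")  # "The stock market rose sharply today" -> e.g. "economy"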

# ✅ Endpoint definitions
class InputText(BaseModel):
    text: str

@app.get("/")
def root():
    return {"message": "Text Classification API (KoBERT)"}

@app.post("/predict")  # route path assumed
async def predict_route(item: InputText):
    result = predict(item.text)
    return {"text": item.text, "classification": result}
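
# Minimal local runner: a sketch assuming uvicorn is installed; 7860 is the
# default port Hugging Face Spaces expects.
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)

# Example request against the /predict endpoint above:
#   curl -X POST http://localhost:7860/predict \
#        -H "Content-Type: application/json" \
#        -d '{"text": "분류할 문장"}'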