hiddenFront committed on
Commit 95b43d8 · verified · 1 Parent(s): 7233753

Update app.py

Files changed (1)
  1. app.py +52 -76
app.py CHANGED
@@ -1,96 +1,72 @@
  from fastapi import FastAPI, Request
- from transformers import BertModel, BertForSequenceClassification, AutoTokenizer
- from huggingface_hub import hf_hub_download
  import torch
  import pickle
  import os
- import sys
- import psutil

  app = FastAPI()
  device = torch.device("cpu")

- # Load category.pkl
- try:
-     with open("category.pkl", "rb") as f:
-         category = pickle.load(f)
-     print("✅ category.pkl loaded successfully.")
- except FileNotFoundError:
-     print("❌ Error: category.pkl file not found.")
-     sys.exit(1)

- # Load the tokenizer
- tokenizer = AutoTokenizer.from_pretrained("skt/kobert-base-v1")
- print("✅ Tokenizer loaded successfully.")

- class CustomClassifier(torch.nn.Module):
-     def __init__(self):
-         super().__init__()
-         # Must restore the exact architecture defined at training time
-         self.bert = BertModel.from_pretrained("skt/kobert-base-v1")
-         self.classifier = torch.nn.Linear(768, len(category))

-     def forward(self, input_ids, attention_mask=None, token_type_ids=None):
-         outputs = self.bert(input_ids=input_ids,
-                             attention_mask=attention_mask,
-                             token_type_ids=token_type_ids)
-         pooled_output = outputs[1]  # CLS token
-         return self.classifier(pooled_output)

- HF_MODEL_REPO_ID = "hiddenFront/TextClassifier"
- HF_MODEL_FILENAME = "textClassifierModel.pt"

- # Memory usage before downloading the model
- process = psutil.Process(os.getpid())
- mem_before = process.memory_info().rss / (1024 * 1024)
- print(f"📦 Memory usage before model download: {mem_before:.2f} MB")

- # Load the model
- try:
-     model_path = hf_hub_download(repo_id=HF_MODEL_REPO_ID, filename=HF_MODEL_FILENAME)
-     print(f"✅ Model file downloaded successfully: {model_path}")
-
-     state_dict = torch.load(model_path, map_location=device)
-     model = BertForSequenceClassification.from_pretrained(
-         "skt/kobert-base-v1",
-         num_labels=len(category),
-         state_dict=state_dict,
-     )
-     model.to(device)
      model.eval()
-     print("✅ Model loaded and ready.")
- except Exception as e:
-     print(f"❌ Error while loading the model: {e}")
-     sys.exit(1)

  @app.get("/")
- def root(request: Request):
-     client_host = request.client.host
-     client_port = request.client.port
-     return {
-         "message": "Text Classification API is running!",
-         "client_ip": client_host,
-         "client_port": client_port
-     }

- # Prediction API
  @app.post("/predict")
- async def predict_api(request: Request):
-     data = await request.json()
-     text = data.get("text")
-     print("request data:", data)
-     if not text:
-         return {"error": "No text provided", "classification": "null"}
-
-     encoded = tokenizer.encode_plus(
-         text, max_length=64, padding='max_length', truncation=True, return_tensors='pt'
-     )
-
-     with torch.no_grad():
-         outputs = model(**encoded)
-         probs = torch.nn.functional.softmax(outputs.logits, dim=1)
-         predicted = torch.argmax(probs, dim=1).item()
-
-     label = list(category.keys())[predicted]
-     return {"text": text, "classification": label}
  from fastapi import FastAPI, Request
+ from pydantic import BaseModel
  import torch
  import pickle
+ import gluonnlp as nlp
+ import numpy as np
  import os
+ from kobert_tokenizer import KoBERTTokenizer
+ from model import BERTClassifier
+ from dataset import BERTDataset
+ from transformers import BertModel
+ import logging

  app = FastAPI()
  device = torch.device("cpu")

+ # ✅ Load category
+ with open("category.pkl", "rb") as f:
+     category = pickle.load(f)

+ # ✅ Load vocab
+ with open("vocab.pkl", "rb") as f:
+     vocab = pickle.load(f)

+ # ✅ Tokenizer
+ tokenizer = KoBERTTokenizer.from_pretrained('skt/kobert-base-v1')

+ # ✅ Load the model
+ model = BERTClassifier(
+     BertModel.from_pretrained('skt/kobert-base-v1'),
+     dr_rate=0.5,
+     num_classes=len(category)
+ )
+ model.load_state_dict(torch.load("textClassifierModel.pt", map_location=device))
+ model.to(device)
+ model.eval()

+ # ✅ Parameters needed to build the dataset
+ max_len = 64
+ batch_size = 32

+ # ✅ Prediction function
+ def predict(predict_sentence):
+     data = [predict_sentence, '0']
+     dataset_another = [data]
+     another_test = BERTDataset(dataset_another, 0, 1, tokenizer, vocab, max_len, True, False)
+     test_dataLoader = torch.utils.data.DataLoader(another_test, batch_size=batch_size, num_workers=0)
      model.eval()
+     for batch_id, (token_ids, valid_length, segment_ids, label) in enumerate(test_dataLoader):
+         token_ids = token_ids.long().to(device)
+         segment_ids = segment_ids.long().to(device)
+
+         out = model(token_ids, valid_length, segment_ids)
+         test_eval = []
+         for i in out:
+             logits = i.detach().cpu().numpy()
+             test_eval.append(list(category.keys())[np.argmax(logits)])
+         return test_eval[0]

+ # ✅ Endpoint definitions
+ class InputText(BaseModel):
+     text: str

  @app.get("/")
+ def root():
+     return {"message": "Text Classification API (KoBERT)"}

  @app.post("/predict")
+ async def predict_route(item: InputText):
+     result = predict(item.text)
+     return {"text": item.text, "classification": result}