Spaces:
Runtime error
Runtime error
| import torch | |
| from transformers import AutoTokenizer, AutoModelForSequenceClassification | |
| # 1. ๋ชจ๋ธ์ด ์ ์ฅ๋ ํด๋ ๊ฒฝ๋ก ์ง์ | |
| LOAD_MODEL_PATH = "./xtreme-distil-review-classifier" | |
| # 2. GPU/CPU ์ฅ์น ์ค์ | |
| device = torch.device("cuda" if torch.cuda.is_available() else "cpu") | |
| print(f"์ฌ์ฉ ์ฅ์น: {device}") | |
| # 3. ์ ์ฅ๋ ํ ํฌ๋์ด์ ์ ๋ชจ๋ธ ๋ก๋ | |
| # ์ ์ฅ๋ config.json๊ณผ model.safetensors ํ์ผ์ ๋ฐํ์ผ๋ก ๋ก๋ํฉ๋๋ค. | |
| print(f"\n--- ๋ชจ๋ธ ๋ก๋ ์ค: {LOAD_MODEL_PATH} ---") | |
| loaded_tokenizer = AutoTokenizer.from_pretrained(LOAD_MODEL_PATH) | |
| loaded_model = AutoModelForSequenceClassification.from_pretrained(LOAD_MODEL_PATH) | |
| # ๋ชจ๋ธ์ ์ค์ ๋ ์ฅ์น(GPU ๋๋ CPU)๋ก ์ด๋ | |
| loaded_model.to(device) | |
| loaded_model.eval() # ๋ชจ๋ธ์ ํ๊ฐ ๋ชจ๋๋ก ์ค์ (ํ์) | |
| # 4. ๋ถ๋ฅ ํจ์ ์ ์ | |
| def classify_review(text): | |
| # ํ ์คํธ๋ฅผ ํ ํฐํํ๊ณ ์ฅ์น๋ก ์ด๋ | |
| inputs = loaded_tokenizer( | |
| text, | |
| return_tensors="pt", # PyTorch ํ ์๋ก ๋ฐํ | |
| padding=True, | |
| truncation=True | |
| ).to(device) | |
| # ๋ชจ๋ธ ์ถ๋ก (Inference) | |
| with torch.no_grad(): | |
| outputs = loaded_model(**inputs) | |
| # ๊ฒฐ๊ณผ ์ฒ๋ฆฌ | |
| probabilities = torch.softmax(outputs.logits, dim=1) | |
| predicted_class_id = probabilities.argmax().item() | |
| # ๋ ์ด๋ธ ๋งคํ (ํ์ธ ํ๋ ์ ์ค์ ํ 0: ๋ถ์ , 1: ๊ธ์ ๊ธฐ์ค) | |
| label_map = {0: "๋ถ์ (Negative)", 1: "๊ธ์ (Positive)"} | |
| predicted_label = label_map[predicted_class_id] | |
| confidence = probabilities[0][predicted_class_id].item() | |
| return predicted_label, confidence | |
| # 5. ์๋ก์ด ๋น๊ทผ๋ง์ผ ๋ฆฌ๋ทฐ ํ ์คํธ ์คํ | |
| new_reviews = [ | |
| "๋งค๋๊ฐ ์ ๋ง ์ข์ผ์ธ์! ๊ธฐ๋ถ ์ข์ ๊ฑฐ๋๋ค์", | |
| "๋ฌผ๊ฑด ์ํ๊ฐ ์๊ฐ๋ณด๋ค ๋๋ฌด ์ ์ข์์ ์์๋ค๋ ๋๋์ด ๋ญ๋๋ค.", | |
| "๋น ๋ฅธ ๊ฑฐ๋ ๊ฐ์ฌํฉ๋๋ค. ๋ฌธ์ ์์ด ์ ๋ฐ์์ด์.", | |
| "์ฐ๋ฝ์ ์๋ฐ๋ค์", | |
| ] | |
| print("\n--- ์๋ก์ด ๋ฆฌ๋ทฐ ๋ถ๋ฅ ๊ฒฐ๊ณผ ---") | |
| for review in new_reviews: | |
| label, confidence = classify_review(review) | |
| print(f"๋ฆฌ๋ทฐ: '{review}'") | |
| print(f" -> ์์ธก ๋ถ๋ฅ: **{label}** (ํ๋ฅ : {confidence:.4f})") | |
| print("-" * 35) |