python_roberta_hf / xtreme_distil_use.py
WildOjisan's picture
.
e8a2c53
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
# 1. Directory that holds the saved model files.
LOAD_MODEL_PATH = "./xtreme-distil-review-classifier"

# 2. Select the compute device: CUDA GPU when available, otherwise CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"์‚ฌ์šฉ ์žฅ์น˜: {device}")

# 3. Load the saved tokenizer and model.
# Both are restored from the files stored under LOAD_MODEL_PATH
# (config.json and model.safetensors).
print(f"\n--- ๋ชจ๋ธ ๋กœ๋“œ ์ค‘: {LOAD_MODEL_PATH} ---")
loaded_tokenizer = AutoTokenizer.from_pretrained(LOAD_MODEL_PATH)
loaded_model = AutoModelForSequenceClassification.from_pretrained(LOAD_MODEL_PATH)

# Move the model to the selected device and switch it to evaluation mode
# (required for inference).
loaded_model.to(device).eval()
# 4. Classification function
def classify_review(text):
    """Classify a single review text as negative or positive.

    The text is tokenized with the module-level ``loaded_tokenizer``, moved
    to ``device``, and passed through ``loaded_model`` with gradient tracking
    disabled. The prediction is read from the first row of the model output.

    Args:
        text: The review text to classify.

    Returns:
        A ``(predicted_label, confidence)`` tuple, where ``predicted_label``
        is the display string for class 0 (negative) or 1 (positive) and
        ``confidence`` is the softmax probability of that class as a float.

    Raises:
        KeyError: If the predicted class id is not 0 or 1 (the label map
            below only covers those two classes).
    """
    # Tokenize the text into PyTorch tensors and move them to the device.
    inputs = loaded_tokenizer(
        text,
        return_tensors="pt",
        padding=True,
        truncation=True,
    ).to(device)

    # Inference only: no gradients are needed.
    with torch.no_grad():
        outputs = loaded_model(**inputs)

    # Softmax over the class dimension, then take row 0 explicitly.
    # Calling argmax() on the full 2-D tensor would index the *flattened*
    # tensor, which is only a valid class id when there is exactly one row;
    # selecting the row first keeps the class id and the confidence lookup
    # consistent with each other.
    probabilities = torch.softmax(outputs.logits, dim=1)[0]
    predicted_class_id = probabilities.argmax().item()

    # Label mapping (0: negative, 1: positive, as set during fine-tuning).
    label_map = {0: "๋ถ€์ • (Negative)", 1: "๊ธ์ • (Positive)"}
    predicted_label = label_map[predicted_class_id]
    confidence = probabilities[predicted_class_id].item()
    return predicted_label, confidence
# 5. Run the classifier on a few new sample reviews.
new_reviews = [
    "๋งค๋„ˆ๊ฐ€ ์ •๋ง ์ข‹์œผ์„ธ์š”! ๊ธฐ๋ถ„ ์ข‹์€ ๊ฑฐ๋ž˜๋„ค์š”",
    "๋ฌผ๊ฑด ์ƒํƒœ๊ฐ€ ์ƒ๊ฐ๋ณด๋‹ค ๋„ˆ๋ฌด ์•ˆ ์ข‹์•„์„œ ์†์•˜๋‹ค๋Š” ๋А๋‚Œ์ด ๋“ญ๋‹ˆ๋‹ค.",
    "๋น ๋ฅธ ๊ฑฐ๋ž˜ ๊ฐ์‚ฌํ•ฉ๋‹ˆ๋‹ค. ๋ฌธ์ œ ์—†์ด ์ž˜ ๋ฐ›์•˜์–ด์š”.",
    "์—ฐ๋ฝ์„ ์•ˆ๋ฐ›๋„ค์š”",
]

print("\n--- ์ƒˆ๋กœ์šด ๋ฆฌ๋ทฐ ๋ถ„๋ฅ˜ ๊ฒฐ๊ณผ ---")
for review in new_reviews:
    label, confidence = classify_review(review)
    # One report per review: the text, the predicted label with its
    # probability, then a separator line.
    print(
        f"๋ฆฌ๋ทฐ: '{review}'",
        f" -> ์˜ˆ์ธก ๋ถ„๋ฅ˜: **{label}** (ํ™•๋ฅ : {confidence:.4f})",
        "-" * 35,
        sep="\n",
    )