# Source: python_roberta_hf / roberta_test.py (uploaded by WildOjisan, revision 899f482)
import torch
from transformers import AutoTokenizer, AutoModel
import numpy as np
# 1. Device selection
# Run on the GPU when CUDA is available; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"์‚ฌ์šฉ ์žฅ์น˜: {device}")

# 2. Load the tokenizer and the model
# The checkpoint is loaded through AutoModel (no sequence-classification
# head) because the script only extracts embeddings from it.
MODEL_NAME = "FacebookAI/xlm-roberta-base"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
# Load the weights and move them to the selected device in one step;
# Module.to() returns the module itself, so `model` is the same object.
model = AutoModel.from_pretrained(MODEL_NAME).to(device)
# 3. Embedding extraction helper
def get_text_embedding(text):
    """Return the first-token embedding of ``text`` as a NumPy array.

    The input is tokenized with the module-level ``tokenizer`` (padding and
    truncation enabled), moved to ``device``, and passed through the
    module-level ``model`` with gradient tracking disabled. The hidden state
    at sequence position 0 of ``last_hidden_state`` is used as the sentence
    representation. The original author's note calls this the [CLS] token;
    XLM-RoBERTa's tokenizer names the leading special token ``<s>``.

    Args:
        text: Text handed directly to the tokenizer.

    Returns:
        Row 0 of the position-0 hidden states, copied to the CPU and
        converted to NumPy (768 values for this checkpoint, per the
        original author's note). Only the first row of the batch is
        returned, whatever the batch size.
    """
    # Tokenize into PyTorch tensors, then move them next to the model.
    encoded = tokenizer(text, return_tensors="pt", padding=True, truncation=True)
    encoded = encoded.to(device)

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        hidden_states = model(**encoded).last_hidden_state

    # Keep position 0 of every sequence, then hand back the first row.
    first_token_states = hidden_states[:, 0, :]
    return first_token_states.cpu().numpy()[0]
# 4. Run the marketplace-review examples
review_sentences = [
    "๋งค๋„ˆ๊ฐ€ ์ •๋ง ์ข‹์œผ์‹œ๊ณ  ๋ฌผ๊ฑด๋„ ๊นจ๋—ํ•ด์„œ ๋งŒ์กฑ์Šค๋Ÿฌ์› ์–ด์š”.",
    "์ด๊ฑด ์ข€ ์•„๋‹Œ๋“ฏ. ๋ฌผ๊ฑด ์ƒํƒœ๋„ ๋ณ„๋กœ๊ณ  ๋‹ต๋ณ€๋„ ๋А๋ ธ์Šต๋‹ˆ๋‹ค.",
    "์ด ๋ชจ๋ธ์˜ ์ค‘๊ณ  ์‹œ์„ธ๋Š” ์–ผ๋งˆ์ธ๊ฐ€์š”?",  # a general question sentence
    "This is a great product for the price.",  # non-Korean text works too
]

print("\n--- XLM-RoBERTa ์ž„๋ฒ ๋”ฉ ์ถ”์ถœ ๊ฒฐ๊ณผ ---")
separator = "-" * 20
for review in review_sentences:
    vector = get_text_embedding(review)
    # Show only the first five components, rounded to four decimals.
    preview = vector[:5].round(4)
    print(f"๋ฌธ์žฅ: '{review}'")
    print(f" -> ์ž„๋ฒ ๋”ฉ ์ฐจ์›: {vector.shape}")  # embedding dimensionality
    print(f" -> ์ž„๋ฒ ๋”ฉ ๋ฒกํ„ฐ ์ผ๋ถ€ (์ฒซ 5๊ฐœ): {preview}")
    print(separator)