Spaces:
Runtime error
Runtime error
File size: 2,410 Bytes
899f482 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 |
import torch
import numpy as np
from transformers import AutoModel
# KoBERT ์ ์ฉ ํ ํฌ๋์ด์ ๋ก๋ (Hugging Face ํ ํฌ๋์ด์ ์ ๋ค๋ฆ)
from kobert_tokenizer import KoBERTTokenizer
# 1. GPU/CPU ์ฅ์น ์ค์
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"์ฌ์ฉ ์ฅ์น: {device}")
# 2. ๋ชจ๋ธ ๋ฐ ํ ํฌ๋์ด์ ๋ก๋ (์ถ๊ฐ ์์ )
MODEL_NAME = "monologg/kobert"
# ํ ํฌ๋์ด์ ๋ฅผ ๋ก๋ํ ๋ 'monologg/kobert' ๋์
# SKT Brain์ ๊ณต์ ์ ์ฅ์ ์ด๋ฆ์ธ 'skt/kobert-base-v1'์ ์ฌ์ฉํ๋ ๊ฒ์ด ๋ ์์ ์ ์
๋๋ค.
tokenizer = KoBERTTokenizer.from_pretrained('skt/kobert-base-v1')
model = AutoModel.from_pretrained(MODEL_NAME)
# ๋ชจ๋ธ์ ์ค์ ๋ ์ฅ์น(GPU ๋๋ CPU)๋ก ์ด๋
model.to(device)
# 3. ์๋ฒ ๋ฉ(Embedding) ์ถ์ถ ํจ์ ์ ์
def get_kobert_embedding(text):
# ํ
์คํธ ํ ํฐํ ๋ฐ ์
๋ ฅ ํ์์ผ๋ก ๋ณํ
inputs = tokenizer.batch_encode_plus(
[text], # ๋ฆฌ์คํธ ํํ๋ก ์
๋ ฅ
padding='max_length',
max_length=64, # ์ต๋ ๊ธธ์ด ์ง์ (ํ์์ ๋ฐ๋ผ ์กฐ์ )
truncation=True,
return_tensors="pt" # PyTorch ํ
์๋ก ๋ฐํ
).to(device)
# ๋ชจ๋ธ ์ถ๋ก (Inference)
with torch.no_grad():
# output์๋ last_hidden_state (๊ฐ ํ ํฐ์ ์๋ฒ ๋ฉ) ๋ฑ์ด ํฌํจ๋ฉ๋๋ค.
outputs = model(**inputs)
# ๋ฌธ์ฅ ์๋ฒ ๋ฉ ์ถ์ถ: [CLS] ํ ํฐ์ ์๋ฒ ๋ฉ์ ์ฌ์ฉํฉ๋๋ค.
# last_hidden_state์ ์ฒซ ๋ฒ์งธ ํ ํฐ (์ธ๋ฑ์ค 0)์ด [CLS] ํ ํฐ์ด๋ฉฐ, ์ ์ฒด ๋ฌธ์ฅ์ ๋ํํฉ๋๋ค.
# shape: (1, 768)
sentence_embedding = outputs.last_hidden_state[:, 0, :].cpu().numpy()
return sentence_embedding[0] # numpy ๋ฐฐ์ด (768์ฐจ์)๋ก ๋ฐํ
# 4. ๋น๊ทผ๋ง์ผ ๋ฆฌ๋ทฐ ์์ ์คํ
review_sentences = [
"ํ๋งค์๋ ๋งค๋๊ฐ ๋๋ฌด ์ข์์ ๊ธฐ๋ถ ์ข์ ๊ฑฐ๋์์ต๋๋ค.",
"๋ฌผ๊ฑด ์ํ๊ฐ ์๊ฐ๋ณด๋ค ๋ณ๋ก์ฌ์ ์์ฝ๋ค์. ๋ค์์ ๊ฑฐ๋ ์ ํ ๊ฒ ๊ฐ์์.",
"์ด ์์ ๊ฑฐ ๋ชจ๋ธ์ ์ค๊ณ ์์ธ๊ฐ ์ด๋ ์ ๋์ผ๊น์?",
]
print("\n--- KoBERT ์๋ฒ ๋ฉ ์ถ์ถ ๊ฒฐ๊ณผ ---")
for sentence in review_sentences:
embedding = get_kobert_embedding(sentence)
print(f"๋ฌธ์ฅ: '{sentence}'")
print(f" -> ์๋ฒ ๋ฉ ์ฐจ์: {embedding.shape}") # 768์ฐจ์
print(f" -> ์๋ฒ ๋ฉ ๋ฒกํฐ ์ผ๋ถ (์ฒซ 5๊ฐ): {embedding[:5].round(4)}")
print("-" * 30) |