Spaces:
Runtime error
Runtime error
| import torch | |
| import numpy as np | |
| from transformers import AutoModel | |
| # KoBERT ์ ์ฉ ํ ํฌ๋์ด์ ๋ก๋ (Hugging Face ํ ํฌ๋์ด์ ์ ๋ค๋ฆ) | |
| from kobert_tokenizer import KoBERTTokenizer | |
| # 1. GPU/CPU ์ฅ์น ์ค์ | |
| device = torch.device("cuda" if torch.cuda.is_available() else "cpu") | |
| print(f"์ฌ์ฉ ์ฅ์น: {device}") | |
| # 2. ๋ชจ๋ธ ๋ฐ ํ ํฌ๋์ด์ ๋ก๋ (์ถ๊ฐ ์์ ) | |
| MODEL_NAME = "monologg/kobert" | |
| # ํ ํฌ๋์ด์ ๋ฅผ ๋ก๋ํ ๋ 'monologg/kobert' ๋์ | |
| # SKT Brain์ ๊ณต์ ์ ์ฅ์ ์ด๋ฆ์ธ 'skt/kobert-base-v1'์ ์ฌ์ฉํ๋ ๊ฒ์ด ๋ ์์ ์ ์ ๋๋ค. | |
| tokenizer = KoBERTTokenizer.from_pretrained('skt/kobert-base-v1') | |
| model = AutoModel.from_pretrained(MODEL_NAME) | |
| # ๋ชจ๋ธ์ ์ค์ ๋ ์ฅ์น(GPU ๋๋ CPU)๋ก ์ด๋ | |
| model.to(device) | |
| # 3. ์๋ฒ ๋ฉ(Embedding) ์ถ์ถ ํจ์ ์ ์ | |
| def get_kobert_embedding(text): | |
| # ํ ์คํธ ํ ํฐํ ๋ฐ ์ ๋ ฅ ํ์์ผ๋ก ๋ณํ | |
| inputs = tokenizer.batch_encode_plus( | |
| [text], # ๋ฆฌ์คํธ ํํ๋ก ์ ๋ ฅ | |
| padding='max_length', | |
| max_length=64, # ์ต๋ ๊ธธ์ด ์ง์ (ํ์์ ๋ฐ๋ผ ์กฐ์ ) | |
| truncation=True, | |
| return_tensors="pt" # PyTorch ํ ์๋ก ๋ฐํ | |
| ).to(device) | |
| # ๋ชจ๋ธ ์ถ๋ก (Inference) | |
| with torch.no_grad(): | |
| # output์๋ last_hidden_state (๊ฐ ํ ํฐ์ ์๋ฒ ๋ฉ) ๋ฑ์ด ํฌํจ๋ฉ๋๋ค. | |
| outputs = model(**inputs) | |
| # ๋ฌธ์ฅ ์๋ฒ ๋ฉ ์ถ์ถ: [CLS] ํ ํฐ์ ์๋ฒ ๋ฉ์ ์ฌ์ฉํฉ๋๋ค. | |
| # last_hidden_state์ ์ฒซ ๋ฒ์งธ ํ ํฐ (์ธ๋ฑ์ค 0)์ด [CLS] ํ ํฐ์ด๋ฉฐ, ์ ์ฒด ๋ฌธ์ฅ์ ๋ํํฉ๋๋ค. | |
| # shape: (1, 768) | |
| sentence_embedding = outputs.last_hidden_state[:, 0, :].cpu().numpy() | |
| return sentence_embedding[0] # numpy ๋ฐฐ์ด (768์ฐจ์)๋ก ๋ฐํ | |
| # 4. ๋น๊ทผ๋ง์ผ ๋ฆฌ๋ทฐ ์์ ์คํ | |
| review_sentences = [ | |
| "ํ๋งค์๋ ๋งค๋๊ฐ ๋๋ฌด ์ข์์ ๊ธฐ๋ถ ์ข์ ๊ฑฐ๋์์ต๋๋ค.", | |
| "๋ฌผ๊ฑด ์ํ๊ฐ ์๊ฐ๋ณด๋ค ๋ณ๋ก์ฌ์ ์์ฝ๋ค์. ๋ค์์ ๊ฑฐ๋ ์ ํ ๊ฒ ๊ฐ์์.", | |
| "์ด ์์ ๊ฑฐ ๋ชจ๋ธ์ ์ค๊ณ ์์ธ๊ฐ ์ด๋ ์ ๋์ผ๊น์?", | |
| ] | |
| print("\n--- KoBERT ์๋ฒ ๋ฉ ์ถ์ถ ๊ฒฐ๊ณผ ---") | |
| for sentence in review_sentences: | |
| embedding = get_kobert_embedding(sentence) | |
| print(f"๋ฌธ์ฅ: '{sentence}'") | |
| print(f" -> ์๋ฒ ๋ฉ ์ฐจ์: {embedding.shape}") # 768์ฐจ์ | |
| print(f" -> ์๋ฒ ๋ฉ ๋ฒกํฐ ์ผ๋ถ (์ฒซ 5๊ฐ): {embedding[:5].round(4)}") | |
| print("-" * 30) |