Spaces:
Runtime error
Runtime error
| import torch | |
| from transformers import AutoTokenizer, AutoModel | |
| import numpy as np | |
# 1. Device selection: use CUDA (GPU) when it is available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
# NOTE(review): the non-ASCII text in this message looks mis-encoded in this
# copy of the file; it is left untouched here -- confirm against the original.
print(f"์ฌ์ฉ ์ฅ์น: {device}")

# 2. Load the tokenizer and the model.
# The checkpoint is loaded through AutoModel (a plain encoder used for
# embedding extraction), not through a sequence-classification head.
MODEL_NAME = "FacebookAI/xlm-roberta-base"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModel.from_pretrained(MODEL_NAME)

# Move the model's weights onto the device chosen above (GPU or CPU).
model.to(device)
# 3. Define the embedding-extraction function
def get_text_embedding(text):
    """Return the first-token embedding of ``text`` as a NumPy array.

    The text is tokenized with the module-level ``tokenizer`` (padding and
    truncation enabled), moved to the module-level ``device``, and passed
    through the module-level ``model`` with gradient tracking disabled.
    The hidden state at token position 0 is then copied to the CPU,
    converted to NumPy, and its first row is returned.

    Args:
        text: Input passed straight to the tokenizer.

    Returns:
        The first row of the position-0 hidden states. The original
        comments describe this as the "[CLS]" sentence embedding with 768
        dimensions -- presumably true for this checkpoint; confirm.
    """
    # Tokenize into PyTorch tensors, then move them to the target device.
    encoded = tokenizer(
        text,
        padding=True,
        truncation=True,
        return_tensors="pt",
    )
    encoded = encoded.to(device)

    # Inference only: no gradients are needed.
    with torch.no_grad():
        model_output = model(**encoded)

    # last_hidden_state holds one embedding per token; keep position 0 only.
    first_token_states = model_output.last_hidden_state[:, 0, :]
    as_numpy = first_token_states.cpu().numpy()

    # Only row 0 is returned, so any further rows are dropped.
    return as_numpy[0]
# 4. Run the example on a handful of sample review sentences.
# NOTE(review): the non-ASCII string literals below look mis-encoded in this
# copy of the file; they are reproduced unchanged -- confirm against the original.
review_sentences = [
    "๋งค๋๊ฐ ์ ๋ง ์ข์ผ์๊ณ ๋ฌผ๊ฑด๋ ๊นจ๋ํด์ ๋ง์กฑ์ค๋ฌ์ ์ด์.",
    "์ด๊ฑด ์ข ์๋๋ฏ. ๋ฌผ๊ฑด ์ํ๋ ๋ณ๋ก๊ณ ๋ต๋ณ๋ ๋๋ ธ์ต๋๋ค.",
    "์ด ๋ชจ๋ธ์ ์ค๊ณ ์์ธ๋ ์ผ๋ง์ธ๊ฐ์?",  # a general question sentence
    "This is a great product for the price.",  # foreign-language input works too
]

print("\n--- XLM-RoBERTa ์๋ฒ ๋ฉ ์ถ์ถ ๊ฒฐ๊ณผ ---")
for sentence in review_sentences:
    embedding = get_text_embedding(sentence)
    # One print call, newline-separated: same four output lines per sentence.
    print(
        f"๋ฌธ์ฅ: '{sentence}'",
        f" -> ์๋ฒ ๋ฉ ์ฐจ์: {embedding.shape}",  # expected to be 768-dimensional
        f" -> ์๋ฒ ๋ฉ ๋ฒกํฐ ์ผ๋ถ (์ฒซ 5๊ฐ): {embedding[:5].round(4)}",
        "-" * 20,
        sep="\n",
    )