import os
from typing import List

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoModel, AutoTokenizer
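
# Compares a fine-tuned ColBERT-style encoder (an EmbeddingGemma backbone with a
# linear projection head) against the untouched pre-trained model: both score
# the same Korean documents for one query, and their rankings are printed side
# by side.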


class ColBERTEncoder(nn.Module):
    """A pre-trained backbone followed by a linear projection into the ColBERT token-embedding space."""

    def __init__(self, model_name: str, colbert_dim: int):
        super().__init__()
        self.encoder = AutoModel.from_pretrained(model_name)
        hidden = self.encoder.config.hidden_size
        self.proj = nn.Linear(hidden, colbert_dim, bias=False)

    def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        out = self.encoder(input_ids=input_ids, attention_mask=attention_mask, return_dict=True)
        H = out.last_hidden_state        # (batch, seq_len, hidden)
        H = self.proj(H)                 # (batch, seq_len, colbert_dim)
        H = F.normalize(H, p=2, dim=-1)  # unit-length token embeddings
        return H
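
# Because the token embeddings are L2-normalized, the dot products computed in
# colbert_logits below are cosine similarities in [-1, 1].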


def colbert_logits(
    Q: torch.Tensor, MQ: torch.Tensor,
    D: torch.Tensor, MD: torch.Tensor,
) -> torch.Tensor:
    # All-pairs token similarities, reshaped to (Bq, Lq, Bd, Ld).
    sim = torch.einsum("qd,kd->qk", Q.view(-1, Q.size(-1)), D.view(-1, D.size(-1)))
    sim = sim.view(Q.size(0), Q.size(1), D.size(0), D.size(1))
    # Mask padded document tokens before the per-query-token max (MaxSim).
    sim = sim.masked_fill(~MD.bool().unsqueeze(0).unsqueeze(1), -1e4)
    sim = sim.max(dim=-1).values
    # Zero out padded query tokens so they do not contribute to the sum.
    sim = sim.masked_fill(~MQ.bool().unsqueeze(-1), 0)
    scores = sim.sum(dim=1)  # (Bq, Bd)
    return scores.squeeze(0)
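
# Shape sketch: for a single query (Bq = 1, Lq tokens) scored against k padded
# documents, Q is (1, Lq, dim), MQ is (1, Lq), D is (k, Ld, dim), MD is (k, Ld),
# and the final squeeze(0) yields a (k,) tensor with one MaxSim score per
# document.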


def run_inference(model: ColBERTEncoder, tokenizer: AutoTokenizer, query: str, documents: List[str], device: torch.device):
    """Run inference with the given model and print the ranked results."""
    with torch.no_grad():
        q_inputs = tokenizer(query, return_tensors="pt", max_length=64, truncation=True).to(device)
        Hq = model(q_inputs["input_ids"], q_inputs["attention_mask"])

        d_inputs = tokenizer(documents, padding=True, truncation=True, return_tensors="pt", max_length=192).to(device)
        Hd = model(d_inputs["input_ids"], d_inputs["attention_mask"])

        # Score each document against the query.
        scores = []
        for i in range(len(documents)):
            score = colbert_logits(
                Q=Hq, MQ=q_inputs["attention_mask"],
                D=Hd[i].unsqueeze(0), MD=d_inputs["attention_mask"][i].unsqueeze(0),
            )
            scores.append(score.item())

    ranked_results = sorted(zip(scores, documents), key=lambda x: x[0], reverse=True)
    for i, (score, doc) in enumerate(ranked_results):
        print(f" Rank {i+1} (Score: {score:.2f}): {doc}")


def main():
    MODEL_NAME = "google/embeddinggemma-300m"
    COLBERT_DIM = 128
    CHECKPOINT_PATH = "ckpts_dist/vB/epoch1"
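    # Expected checkpoint layout (matching the loading code below):
    #   CHECKPOINT_PATH/tokenizer/  - tokenizer in save_pretrained format
    #   CHECKPOINT_PATH/encoder/    - fine-tuned backbone in save_pretrained format
    #   CHECKPOINT_PATH/proj.pt     - state_dict of the projection layer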

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}\n")

    print("Loading fine-tuned model...")
    tokenizer = AutoTokenizer.from_pretrained(os.path.join(CHECKPOINT_PATH, "tokenizer"))
    finetuned_model = ColBERTEncoder(MODEL_NAME, COLBERT_DIM).to(device)
    # Swap in the fine-tuned backbone and projection weights from the checkpoint.
    finetuned_model.encoder = AutoModel.from_pretrained(os.path.join(CHECKPOINT_PATH, "encoder")).to(device)
    proj_path = os.path.join(CHECKPOINT_PATH, "proj.pt")
    finetuned_model.proj.load_state_dict(torch.load(proj_path, map_location=device))
    finetuned_model.eval()
    print("Fine-tuned model loaded.")

    print("\nLoading original (pre-trained) model for comparison...")
    original_model = ColBERTEncoder(MODEL_NAME, COLBERT_DIM).to(device)
    # Note: this baseline keeps a randomly initialized projection layer, since
    # the pre-trained backbone ships without a ColBERT head.
    original_model.eval()
    print("Original model loaded.")

    # Query: "Which electric car company did Elon Musk found?"
    query = "일론 머스크가 설립한 전기차 회사는 어디야?"
    # English glosses are given in the trailing comments.
    documents = [
        "스페이스X는 재사용 가능한 로켓을 개발하여 우주 탐사 비용을 크게 낮췄습니다.",  # SpaceX developed reusable rockets, greatly lowering the cost of space exploration.
        "테슬라는 모델 S, 3, X, Y를 생산하며 오토파일럿 기능으로 유명합니다.",  # Tesla produces the Model S, 3, X, and Y and is famous for its Autopilot feature.
        "아마존 웹 서비스(AWS)는 클라우드 컴퓨팅 시장의 선두주자입니다.",  # Amazon Web Services (AWS) is the leader in the cloud computing market.
        "일본의 수도는 도쿄입니다. 벚꽃이 아름다운 도시죠.",  # The capital of Japan is Tokyo. A city with beautiful cherry blossoms.
        "대한민국의 수도는 서울입니다. 서울은 경제와 문화의 중심지입니다.",  # The capital of South Korea is Seoul. Seoul is the center of its economy and culture.
        "수도권 지하철은 서울과 주변 도시를 연결하는 중요한 교통수단입니다.",  # The metropolitan subway is a key means of transport linking Seoul with nearby cities.
        "프랑스의 수도는 파리이며, 에펠탑으로 유명합니다.",  # The capital of France is Paris, famous for the Eiffel Tower.
    ]

    print("\n" + "=" * 50)
    print(f"Query: {query}")
    print("=" * 50 + "\n")

    print("--- 1. ✅ Fine-tuned Model Results ---")
    run_inference(finetuned_model, tokenizer, query, documents, device)

    print("\n--- 2. ❌ Original Model Results ---")
    run_inference(original_model, tokenizer, query, documents, device)


if __name__ == "__main__":
    main()