import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModel
from typing import List

# ---------------------------------------------------
# Model class and scoring function from the training script (unchanged)
# ---------------------------------------------------
class ColBERTEncoder(nn.Module):
    """Token-level encoder: a pretrained LM followed by a linear projection
    to `colbert_dim`, producing L2-normalized per-token embeddings."""

    def __init__(self, model_name: str, colbert_dim: int):
        super().__init__()
        self.encoder = AutoModel.from_pretrained(model_name)
        hidden = self.encoder.config.hidden_size
        self.proj = nn.Linear(hidden, colbert_dim, bias=False)

    def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        out = self.encoder(input_ids=input_ids, attention_mask=attention_mask, return_dict=True)
        H = out.last_hidden_state        # (B, L, hidden)
        H = self.proj(H)                 # (B, L, colbert_dim)
        H = F.normalize(H, p=2, dim=-1)  # unit-norm token embeddings
        return H
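
# A minimal usage sketch for ColBERTEncoder (illustrative only; the toy input
# below is not part of the original script):
#   enc = ColBERTEncoder("google/embeddinggemma-300m", colbert_dim=128)
#   tok = AutoTokenizer.from_pretrained("google/embeddinggemma-300m")
#   batch = tok(["hello world"], return_tensors="pt")
#   H = enc(batch["input_ids"], batch["attention_mask"])  # (1, L, 128), rows unit-norm
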
def colbert_logits(
    Q: torch.Tensor, MQ: torch.Tensor,
    D: torch.Tensor, MD: torch.Tensor,
) -> torch.Tensor:
    """MaxSim scoring: for each query token, take the maximum similarity over
    document tokens, then sum over query tokens. Padded document tokens are
    masked with a large negative value; padded query tokens contribute 0."""
    sim = torch.einsum("qd,kd->qk", Q.view(-1, Q.size(-1)), D.view(-1, D.size(-1)))
    sim = sim.view(Q.size(0), Q.size(1), D.size(0), D.size(1))  # (Bq, Lq, Bd, Ld)
    sim = sim.masked_fill(~MD.bool().unsqueeze(0).unsqueeze(1), -1e4)
    sim = sim.max(dim=-1).values                                # best doc token per query token
    sim = sim.masked_fill(~MQ.bool().unsqueeze(-1), 0)
    scores = sim.sum(dim=1)                                     # (Bq, Bd)
    return scores.squeeze(0)
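
# Optional sanity check for colbert_logits on random tensors (a sketch;
# `_sanity_check_maxsim` is a helper added for illustration, not part of the
# training script). It verifies the einsum/max path against a naive loop.
def _sanity_check_maxsim():
    torch.manual_seed(0)
    Q = F.normalize(torch.randn(1, 4, 8), dim=-1)  # (1, query_len, dim)
    D = F.normalize(torch.randn(1, 6, 8), dim=-1)  # (1, doc_len, dim)
    MQ, MD = torch.ones(1, 4), torch.ones(1, 6)
    fast = colbert_logits(Q, MQ, D, MD)
    # Naive MaxSim: best-matching doc token per query token, summed.
    naive = sum(max(float(Q[0, i] @ D[0, j]) for j in range(6)) for i in range(4))
    assert torch.allclose(fast, torch.tensor([naive]), atol=1e-5)
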
# ---------------------------------------------------
# Helper for running inference and printing results
# ---------------------------------------------------
def run_inference(model: ColBERTEncoder, tokenizer: AutoTokenizer, query: str, documents: List[str], device: torch.device):
    """Run inference with the given model and print the ranked results."""
    # Encode the query and the documents
    with torch.no_grad():
        q_inputs = tokenizer(query, return_tensors="pt", max_length=64, truncation=True).to(device)
        Hq = model(**q_inputs)
        d_inputs = tokenizer(documents, padding=True, truncation=True, return_tensors="pt", max_length=192).to(device)
        Hd = model(**d_inputs)

        # Compute the ColBERT (MaxSim) score for each document
        scores = []
        for i in range(len(documents)):
            score = colbert_logits(
                Q=Hq, MQ=q_inputs['attention_mask'],
                D=Hd[i].unsqueeze(0), MD=d_inputs['attention_mask'][i].unsqueeze(0)
            )
            scores.append(score.item())
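        # Note: colbert_logits broadcasts over the document batch dimension, so
        # the per-document loop above could be replaced by a single call (a
        # sketch, equivalent for the shapes used here):
        #   scores = colbert_logits(Hq, q_inputs["attention_mask"],
        #                           Hd, d_inputs["attention_mask"]).tolist()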
    # Sort by score and print the ranking
    ranked_results = sorted(zip(scores, documents), key=lambda x: x[0], reverse=True)
    for i, (score, doc) in enumerate(ranked_results):
        print(f" Rank {i+1} (Score: {score:.2f}): {doc}")
# ---------------------------------------------------
# Main comparison logic
# ---------------------------------------------------
def main():
    # --- โš ๏ธ Settings the user must adjust ---
    MODEL_NAME = "google/embeddinggemma-300m"
    COLBERT_DIM = 128
    CHECKPOINT_PATH = "ckpts_dist/vB/epoch1"
    # ------------------------------------

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}\n")

    # 1. Load the fine-tuned model: base architecture, then checkpointed weights
    print("Loading fine-tuned model...")
    tokenizer = AutoTokenizer.from_pretrained(os.path.join(CHECKPOINT_PATH, "tokenizer"))
    finetuned_model = ColBERTEncoder(MODEL_NAME, COLBERT_DIM).to(device)
    finetuned_model.encoder = AutoModel.from_pretrained(os.path.join(CHECKPOINT_PATH, "encoder")).to(device)
    proj_path = os.path.join(CHECKPOINT_PATH, "proj.pt")
    finetuned_model.proj.load_state_dict(torch.load(proj_path, map_location=device))
    finetuned_model.eval()
    print("Fine-tuned model loaded.")

    # 2. Load the original (pre-trained) model
    print("\nLoading original (pre-trained) model for comparison...")
    original_model = ColBERTEncoder(MODEL_NAME, COLBERT_DIM).to(device)
    # The encoder is loaded straight from the Hugging Face Hub; the proj layer
    # stays in its randomly initialized state.
    original_model.eval()
    print("Original model loaded.")
    # 3. Define the query and candidate documents (Korean test data, kept
    #    as-is; English glosses in the trailing comments)
    query = "์ผ๋ก  ๋จธ์Šคํฌ๊ฐ€ ์„ค๋ฆฝํ•œ ์ „๊ธฐ์ฐจ ํšŒ์‚ฌ๋Š” ์–ด๋””์•ผ?"  # "Which electric-car company did Elon Musk found?"
    documents = [
        "์ŠคํŽ˜์ด์ŠคX๋Š” ์žฌ์‚ฌ์šฉ ๊ฐ€๋Šฅํ•œ ๋กœ์ผ“์„ ๊ฐœ๋ฐœํ•˜์—ฌ ์šฐ์ฃผ ํƒ์‚ฌ ๋น„์šฉ์„ ํฌ๊ฒŒ ๋‚ฎ์ท„์Šต๋‹ˆ๋‹ค.",  # SpaceX: same founder, different topic (strong distractor)
        "ํ…Œ์Šฌ๋ผ๋Š” ๋ชจ๋ธ S, 3, X, Y๋ฅผ ์ƒ์‚ฐํ•˜๋ฉฐ ์˜คํ† ํŒŒ์ผ๋Ÿฟ ๊ธฐ๋Šฅ์œผ๋กœ ์œ ๋ช…ํ•ฉ๋‹ˆ๋‹ค.",  # โœ… Tesla: the answer, semantically correct with no keyword overlap
        "์•„๋งˆ์กด ์›น ์„œ๋น„์Šค(AWS)๋Š” ํด๋ผ์šฐ๋“œ ์ปดํ“จํŒ… ์‹œ์žฅ์˜ ์„ ๋‘์ฃผ์ž์ž…๋‹ˆ๋‹ค.",  # AWS: unrelated
        "์ผ๋ณธ์˜ ์ˆ˜๋„๋Š” ๋„์ฟ„์ž…๋‹ˆ๋‹ค. ๋ฒš๊ฝƒ์ด ์•„๋ฆ„๋‹ค์šด ๋„์‹œ์ฃ .",  # Tokyo: unrelated
        "๋Œ€ํ•œ๋ฏผ๊ตญ์˜ ์ˆ˜๋„๋Š” ์„œ์šธ์ž…๋‹ˆ๋‹ค. ์„œ์šธ์€ ๊ฒฝ์ œ์™€ ๋ฌธํ™”์˜ ์ค‘์‹ฌ์ง€์ž…๋‹ˆ๋‹ค.",  # Seoul: unrelated
        "์ˆ˜๋„๊ถŒ ์ „์ฒ ์€ ์„œ์šธ๊ณผ ์ฃผ๋ณ€ ๋„์‹œ๋ฅผ ์—ฐ๊ฒฐํ•˜๋Š” ์ค‘์š”ํ•œ ๊ตํ†ต์ˆ˜๋‹จ์ž…๋‹ˆ๋‹ค.",  # Seoul metro: unrelated
        "ํ”„๋ž‘์Šค์˜ ์ˆ˜๋„๋Š” ํŒŒ๋ฆฌ์ด๋ฉฐ, ์—ํŽ ํƒ‘์œผ๋กœ ์œ ๋ช…ํ•ฉ๋‹ˆ๋‹ค.",  # Paris: unrelated
    ]
print("\n" + "="*50)
print(f"Query: {query}")
print("="*50 + "\n")
# 4. ๊ฐ ๋ชจ๋ธ๋กœ ์ถ”๋ก  ์‹คํ–‰ ๋ฐ ๊ฒฐ๊ณผ ๋น„๊ต
print("--- 1. โœ… Fine-tuned Model Results ---")
run_inference(finetuned_model, tokenizer, query, documents, device)
print("\n--- 2. โŒ Original Model Results ---")
run_inference(original_model, tokenizer, query, documents, device)
if __name__ == "__main__":
main()