import os import torch import torch.nn as nn import torch.nn.functional as F from transformers import AutoTokenizer, AutoModel from typing import List, Dict # --------------------------------------------------- # 학습 스크립트에서 사용된 모델 클래스와 함수 (변경 없음) # --------------------------------------------------- class ColBERTEncoder(nn.Module): def __init__(self, model_name: str, colbert_dim: int): super().__init__() self.encoder = AutoModel.from_pretrained(model_name) hidden = self.encoder.config.hidden_size self.proj = nn.Linear(hidden, colbert_dim, bias=False) def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor: out = self.encoder(input_ids=input_ids, attention_mask=attention_mask, return_dict=True) H = out.last_hidden_state H = self.proj(H) H = F.normalize(H, p=2, dim=-1) return H def colbert_logits( Q: torch.Tensor, MQ: torch.Tensor, D: torch.Tensor, MD: torch.Tensor, ) -> torch.Tensor: sim = torch.einsum("qd,kd->qk", Q.view(-1, Q.size(-1)), D.view(-1, D.size(-1))) sim = sim.view(Q.size(0), Q.size(1), D.size(0), D.size(1)) sim = sim.masked_fill(~MD.bool().unsqueeze(0).unsqueeze(1), -1e4) sim = sim.max(dim=-1).values sim = sim.masked_fill(~MQ.bool().unsqueeze(-1), 0) scores = sim.sum(dim=1) return scores.squeeze(0) # --------------------------------------------------- # 추론 및 결과 출력을 위한 헬퍼 함수 # --------------------------------------------------- def run_inference(model: ColBERTEncoder, tokenizer: AutoTokenizer, query: str, documents: List[str], device: torch.device): """주어진 모델로 추론을 실행하고 결과를 출력하는 함수""" # 쿼리 및 문서 인코딩 with torch.no_grad(): q_inputs = tokenizer(query, return_tensors="pt", max_length=64, truncation=True).to(device) Hq = model(**q_inputs) d_inputs = tokenizer(documents, padding=True, truncation=True, return_tensors="pt", max_length=192).to(device) Hd = model(**d_inputs) # ColBERT 점수 계산 scores = [] for i in range(len(documents)): score = colbert_logits( Q=Hq, MQ=q_inputs['attention_mask'], D=Hd[i].unsqueeze(0), MD=d_inputs['attention_mask'][i].unsqueeze(0) ) scores.append(score.item()) # 결과 출력 ranked_results = sorted(zip(scores, documents), key=lambda x: x[0], reverse=True) for i, (score, doc) in enumerate(ranked_results): print(f" Rank {i+1} (Score: {score:.2f}): {doc}") # --------------------------------------------------- # 메인 비교 로직 # --------------------------------------------------- def main(): # --- ⚠️ 사용자가 수정해야 할 부분 --- MODEL_NAME = "google/embeddinggemma-300m" COLBERT_DIM = 128 CHECKPOINT_PATH = "ckpts_dist/vB/epoch1" # ------------------------------------ device = torch.device("cuda" if torch.cuda.is_available() else "cpu") print(f"Using device: {device}\n") # 1. 파인튜닝된 모델 로딩 print("Loading fine-tuned model...") tokenizer = AutoTokenizer.from_pretrained(os.path.join(CHECKPOINT_PATH, "tokenizer")) finetuned_model = ColBERTEncoder(MODEL_NAME, COLBERT_DIM).to(device) finetuned_model.encoder = AutoModel.from_pretrained(os.path.join(CHECKPOINT_PATH, "encoder")).to(device) proj_path = os.path.join(CHECKPOINT_PATH, "proj.pt") finetuned_model.proj.load_state_dict(torch.load(proj_path, map_location=device)) finetuned_model.eval() print("Fine-tuned model loaded.") # 2. 원본(pre-trained) 모델 로딩 print("\nLoading original (pre-trained) model for comparison...") original_model = ColBERTEncoder(MODEL_NAME, COLBERT_DIM).to(device) # encoder는 허깅페이스에서 바로 로드, proj 레이어는 랜덤 초기화 상태 그대로 둠 original_model.eval() print("Original model loaded.") # 3. 검색할 쿼리와 문서 정의 query = "일론 머스크가 설립한 전기차 회사는 어디야?" documents = [ "스페이스X는 재사용 가능한 로켓을 개발하여 우주 탐사 비용을 크게 낮췄습니다.", # 정답과 같은 인물, 다른 주제 (강력한 오답 후보 1) "테슬라는 모델 S, 3, X, Y를 생산하며 오토파일럿 기능으로 유명합니다.", # ✅ 키워드 없이 의미적으로 정답 "아마존 웹 서비스(AWS)는 클라우드 컴퓨팅 시장의 선두주자입니다.", # 관련 없는 내용 "일본의 수도는 도쿄입니다. 벚꽃이 아름다운 도시죠.", "대한민국의 수도는 서울입니다. 서울은 경제와 문화의 중심지입니다.", "수도권 전철은 서울과 주변 도시를 연결하는 중요한 교통수단입니다.", "프랑스의 수도는 파리이며, 에펠탑으로 유명합니다.", ] print("\n" + "="*50) print(f"Query: {query}") print("="*50 + "\n") # 4. 각 모델로 추론 실행 및 결과 비교 print("--- 1. ✅ Fine-tuned Model Results ---") run_inference(finetuned_model, tokenizer, query, documents, device) print("\n--- 2. ❌ Original Model Results ---") run_inference(original_model, tokenizer, query, documents, device) if __name__ == "__main__": main()