# project-tdm/mismatch_model.py
import torch
import torch.nn.functional as F
from transformers import pipeline, AutoTokenizer, AutoModelForSequenceClassification
from sentence_transformers import SentenceTransformer, util
# Device setup (prefer GPU, fall back to CPU)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"โœ… ํ˜„์žฌ ์‹คํ–‰ ํ™˜๊ฒฝ: {device}")
# =============================================================================
# 2. Load models (this may take a while)
# =============================================================================
print("\nโณ [1/3] KoBART ์š”์•ฝ ๋ชจ๋ธ ๋กœ๋”ฉ ์ค‘...")
kobart_summarizer = pipeline(
    "summarization",
    model="gogamza/kobart-summarization",
    device=0 if torch.cuda.is_available() else -1
)
print("โณ [2/3] SBERT ์œ ์‚ฌ๋„ ๋ชจ๋ธ ๋กœ๋”ฉ ์ค‘...")
sbert_model = SentenceTransformer('jhgan/ko-sroberta-multitask')
print("โณ [3/3] NLI(๋ชจ์ˆœ ํƒ์ง€) ๋ชจ๋ธ ๋กœ๋”ฉ ์ค‘...")
nli_model_name = "Huffon/klue-roberta-base-nli"
nli_tokenizer = AutoTokenizer.from_pretrained(nli_model_name)
nli_model = AutoModelForSequenceClassification.from_pretrained(nli_model_name).to(device)
print("๐ŸŽ‰ ๋ชจ๋“  ๋ชจ๋ธ ๋กœ๋“œ ์™„๋ฃŒ!\n")
# =============================================================================
# 3. Helper functions (worker functions)
# =============================================================================
def summarize_kobart_strict(text):
    """Summarize the article body with KoBART."""
    # Skip summarization for very short bodies (avoids model errors)
    if len(text) < 50:
        return text
    try:
        result = kobart_summarizer(
            text,
            min_length=15,
            max_length=128,
            num_beams=4,
            no_repeat_ngram_size=3,
            early_stopping=True
        )[0]['summary_text']
        return result.strip()
    except Exception:
        return text[:100]  # On failure, fall back to the first 100 characters
def get_cosine_similarity(title, summary):
    """Compute the cosine similarity between the title and the summary with SBERT."""
    emb1 = sbert_model.encode(title, convert_to_tensor=True)
    emb2 = sbert_model.encode(summary, convert_to_tensor=True)
    return util.cos_sim(emb1, emb2).item()
def get_mismatch_score(summary, title):
    """Compute the contradiction probability between the summary (premise) and the title (hypothesis) with the NLI model."""
    inputs = nli_tokenizer(
        summary, title,
        return_tensors='pt',
        truncation=True,
        max_length=512
    ).to(device)
    # Guard against RoBERTa errors (drop token_type_ids if present)
    if "token_type_ids" in inputs:
        del inputs["token_type_ids"]
    with torch.no_grad():
        outputs = nli_model(**inputs)
    probs = F.softmax(outputs.logits, dim=-1)[0]
    # Label order for Huffon/klue-roberta-base-nli: [Entailment, Neutral, Contradiction]
    # Return the contradiction probability (index 2)
    return round(probs[2].item(), 4)
# =============================================================================
# 4. Final main function (main logic)
# =============================================================================
def calculate_mismatch_score(article_title, article_body):
    """
    Applies the optimal weights found via grid search:
    - w1 (SBERT, semantic distance): 0.8
    - w2 (NLI, logical contradiction): 0.2
    - Threshold: a score of 0.45 or higher is flagged as 'risky'
    """
    # if not (kobart_summarizer and sbert_model and nli_model):
    #     return {"score": 0.0, "reason": "Model loading failed", "recommendation": "Check the server"}
    # 1. Summarize the body
    summary = summarize_kobart_strict(article_body)
    # 2. SBERT semantic distance (1 - similarity)
    sbert_sim = get_cosine_similarity(article_title, summary)
    semantic_distance = 1 - sbert_sim
    # 3. NLI contradiction probability
    nli_contradiction = get_mismatch_score(summary, article_title)
    # 4. Compute the final score
    w1, w2 = 0.8, 0.2
    final_score = (w1 * semantic_distance) + (w2 * nli_contradiction)
    reason = (
        f"[Debug mode]\n"
        f"1. Summary: {summary}\n"
        f"2. SBERT distance: {semantic_distance:.4f}\n"
        f"3. NLI contradiction: {nli_contradiction:.4f}"
    )
    # reason = f"Reflects the semantic distance ({semantic_distance:.4f}) between title and body and the contradiction probability ({nli_contradiction:.4f})."
    # 5. Decide the verdict (threshold 0.45)
    if final_score >= 0.45:
        recommendation = "The title is likely to distort or contradict the body."
    else:
        recommendation = "The title is consistent with the body."
    # Payload returned to main.py
    return {
        "score": round(final_score, 4),
        "reason": reason,
        "recommendation": recommendation
    }
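# -----------------------------------------------------------------------------
# Minimal usage sketch (an assumption, not part of the served API): the sample
# title and body below are made-up placeholders. Real inputs should be Korean
# news text, since all three models above are Korean-language models.
# -----------------------------------------------------------------------------
if __name__ == "__main__":
    sample_title = "placeholder news title"            # hypothetical example input
    sample_body = "placeholder article body text " * 5  # hypothetical example input
    result = calculate_mismatch_score(sample_title, sample_body)
    print(result["score"])
    print(result["reason"])
    print(result["recommendation"])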