Spaces:
Running
Running
File size: 4,876 Bytes
7134b06 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 |
import torch
import torch.nn.functional as F
from transformers import pipeline, AutoTokenizer, AutoModelForSequenceClassification
from sentence_transformers import SentenceTransformer, util
# Device selection: use the GPU when CUDA is available, else fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# NOTE(review): the original status string was corrupted by a mojibake byte
# that rendered as a newline, splitting the f-string mid-literal; rejoined
# here into a single valid print statement.
print(f"โ ํ์ฌ ์คํ ํ๊ฒฝ: {device}")
# =============================================================================
# 2. Model loading (may take a while on first run — weights are downloaded)
# =============================================================================
print("\nโณ [1/3] KoBART ์์ฝ ๋ชจ๋ธ ๋ก๋ฉ ์ค...")
# Korean summarization pipeline; device=0 selects the first GPU, -1 the CPU.
kobart_summarizer = pipeline(
    "summarization",
    model="gogamza/kobart-summarization",
    device=0 if torch.cuda.is_available() else -1
)
print("โณ [2/3] SBERT ์ ์ฌ๋ ๋ชจ๋ธ ๋ก๋ฉ ์ค...")
# Sentence embedding model used for title/summary cosine similarity.
sbert_model = SentenceTransformer('jhgan/ko-sroberta-multitask')
print("โณ [3/3] NLI(๋ชจ์ ํ์ง) ๋ชจ๋ธ ๋ก๋ฉ ์ค...")
# NLI (natural language inference) model used to detect contradiction
# between the body summary (premise) and the title (hypothesis).
nli_model_name = "Huffon/klue-roberta-base-nli"
nli_tokenizer = AutoTokenizer.from_pretrained(nli_model_name)
nli_model = AutoModelForSequenceClassification.from_pretrained(nli_model_name).to(device)
print("๐ ๋ชจ๋ ๋ชจ๋ธ ๋ก๋ ์๋ฃ!\n")
# =============================================================================
# 3. Helper functions (Worker Functions)
# =============================================================================
def summarize_kobart_strict(text, min_length=15, max_length=128):
    """Summarize a Korean article body with the KoBART pipeline.

    Args:
        text: Article body to summarize.
        min_length: Minimum summary length in tokens (default 15).
        max_length: Maximum summary length in tokens (default 128).

    Returns:
        The stripped summary string. Very short inputs (< 50 chars) are
        returned unchanged, and on any pipeline failure the first 100
        characters of the input are returned as a best-effort fallback.
    """
    # Skip summarization for very short bodies (avoids pipeline errors).
    if len(text) < 50:
        return text
    try:
        result = kobart_summarizer(
            text,
            min_length=min_length,
            max_length=max_length,
            num_beams=4,
            no_repeat_ngram_size=3,
            early_stopping=True
        )[0]['summary_text']
        return result.strip()
    except Exception:
        # Deliberate best-effort: a failed summarization must not fail the
        # whole scoring request, so fall back to the head of the article.
        return text[:100]
def get_cosine_similarity(title, summary):
    """Return the SBERT cosine similarity between a title and a summary."""
    title_vec = sbert_model.encode(title, convert_to_tensor=True)
    summary_vec = sbert_model.encode(summary, convert_to_tensor=True)
    similarity = util.cos_sim(title_vec, summary_vec)
    return similarity.item()
def get_mismatch_score(summary, title):
    """Return the NLI contradiction probability for (summary, title).

    The summary is used as the premise and the title as the hypothesis.
    """
    encoded = nli_tokenizer(
        summary,
        title,
        return_tensors='pt',
        truncation=True,
        max_length=512
    ).to(device)
    # RoBERTa-based models take no token_type_ids; drop them if the
    # tokenizer produced any, to avoid a forward() error.
    encoded.pop("token_type_ids", None)
    with torch.no_grad():
        logits = nli_model(**encoded).logits
    probabilities = F.softmax(logits, dim=-1)[0]
    # Label order for Huffon/klue-roberta-base-nli:
    # [entailment, neutral, contradiction] -> index 2 is contradiction.
    return round(probabilities[2].item(), 4)
# =============================================================================
# 4. Final scoring entry point (Main Logic)
# =============================================================================
def calculate_mismatch_score(article_title, article_body):
    """Score how strongly an article's title mismatches its body.

    Pipeline: summarize the body with KoBART, measure the semantic
    distance between title and summary with SBERT, measure the logical
    contradiction probability with the NLI model, then combine the two.

    Weights and threshold come from a grid search:
        - w1 (SBERT semantic distance): 0.8
        - w2 (NLI contradiction):       0.2
        - threshold: final scores >= 0.45 are flagged as suspicious

    Args:
        article_title: The headline to evaluate.
        article_body: The full article text.

    Returns:
        dict with keys "score" (float rounded to 4 decimals), "reason"
        (debug breakdown string), and "recommendation" (verdict string) —
        the payload consumed by main.py.
    """
    # 1. Summarize the body.
    summary = summarize_kobart_strict(article_body)
    # 2. Semantic distance = 1 - cosine similarity.
    sbert_sim = get_cosine_similarity(article_title, summary)
    semantic_distance = 1 - sbert_sim
    # 3. NLI contradiction probability (summary = premise, title = hypothesis).
    nli_contradiction = get_mismatch_score(summary, article_title)
    # 4. Weighted final score.
    w1, w2 = 0.8, 0.2
    final_score = (w1 * semantic_distance) + (w2 * nli_contradiction)
    reason = (
        f"[๋๋ฒ๊ทธ ๋ชจ๋]\n"
        f"1. ์์ฝ๋ฌธ: {summary}\n"
        f"2. SBERT ๊ฑฐ๋ฆฌ: {semantic_distance:.4f}\n"
        f"3. NLI ๋ชจ์: {nli_contradiction:.4f}"
    )
    # 5. Verdict against the 0.45 threshold.
    if final_score >= 0.45:
        recommendation = "์ ๋ชฉ์ด ๋ณธ๋ฌธ์ ๋ด์ฉ์ ์๊ณกํ๊ฑฐ๋ ๋ชจ์๋ ๊ฐ๋ฅ์ฑ์ด ๋์ต๋๋ค."
    else:
        recommendation = "์ ๋ชฉ๊ณผ ๋ณธ๋ฌธ์ ๋ด์ฉ์ด ์ผ์นํฉ๋๋ค."
    return {
        "score": round(final_score, 4),
        "reason": reason,
        "recommendation": recommendation
    }