import torch
import torch.nn.functional as F
from transformers import pipeline, AutoTokenizer, AutoModelForSequenceClassification
from sentence_transformers import SentenceTransformer, util

# Device setup (prefer GPU, fall back to CPU)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"✅ Running on: {device}")

# =============================================================================
# 2. Load the models (this may take a while)
# =============================================================================
print("\nโณ [1/3] KoBART ์š”์•ฝ ๋ชจ๋ธ ๋กœ๋”ฉ ์ค‘...")
kobart_summarizer = pipeline(
    "summarization",
    model="gogamza/kobart-summarization",
    device=0 if torch.cuda.is_available() else -1
)

print("โณ [2/3] SBERT ์œ ์‚ฌ๋„ ๋ชจ๋ธ ๋กœ๋”ฉ ์ค‘...")
sbert_model = SentenceTransformer('jhgan/ko-sroberta-multitask')

print("โณ [3/3] NLI(๋ชจ์ˆœ ํƒ์ง€) ๋ชจ๋ธ ๋กœ๋”ฉ ์ค‘...")
nli_model_name = "Huffon/klue-roberta-base-nli"
nli_tokenizer = AutoTokenizer.from_pretrained(nli_model_name)
nli_model = AutoModelForSequenceClassification.from_pretrained(nli_model_name).to(device)

print("๐ŸŽ‰ ๋ชจ๋“  ๋ชจ๋ธ ๋กœ๋“œ ์™„๋ฃŒ!\n")

# =============================================================================
# 3. Helper function definitions (Worker Functions)
# =============================================================================

def summarize_kobart_strict(text):
    """Summarize the article body with KoBART."""
    # Skip summarization for very short bodies (prevents model errors)
    if len(text) < 50:
        return text

    try:
        result = kobart_summarizer(
            text,
            min_length=15,
            max_length=128,
            num_beams=4,
            no_repeat_ngram_size=3,
            early_stopping=True
        )[0]['summary_text']
        return result.strip()
    except Exception:
        return text[:100]  # On failure, fall back to the first 100 characters

def get_cosine_similarity(title, summary):
    """Compute the cosine similarity between the title and the summary with SBERT."""
    emb1 = sbert_model.encode(title, convert_to_tensor=True)
    emb2 = sbert_model.encode(summary, convert_to_tensor=True)
    return util.cos_sim(emb1, emb2).item()

def get_mismatch_score(summary, title):
    """Compute the contradiction probability between the summary (premise) and the title (hypothesis) with the NLI model."""
    inputs = nli_tokenizer(
        summary, title,
        return_tensors='pt',
        truncation=True,
        max_length=512
    ).to(device)

    # RoBERTa does not use token_type_ids; drop them to avoid a forward() error
    if "token_type_ids" in inputs:
        del inputs["token_type_ids"]

    with torch.no_grad():
        outputs = nli_model(**inputs)
        probs = F.softmax(outputs.logits, dim=-1)[0]

    # Label order for Huffon/klue-roberta-base-nli: [Entailment, Neutral, Contradiction]
    # Return the contradiction probability (index 2)
    return round(probs[2].item(), 4)
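
# For interpretation (hypothetical numbers): probs = [0.05, 0.15, 0.80] would mean
# the NLI model assigns an 80% probability that the summary contradicts the title,
# so get_mismatch_score would return 0.8.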

# =============================================================================
# 4. Final main function (Main Logic)
# =============================================================================

def calculate_mismatch_score(article_title, article_body):
    """
    Apply the optimal weights found via grid search:
    - w1 (SBERT, semantic distance): 0.8
    - w2 (NLI, logical contradiction): 0.2
    - Threshold: a score of 0.45 or higher is flagged as 'risky'
    """
    # 1. Summarize the body
    summary = summarize_kobart_strict(article_body)
    
    # 2. SBERT semantic distance (1 - similarity)
    sbert_sim = get_cosine_similarity(article_title, summary)
    semantic_distance = 1 - sbert_sim

    # 3. NLI contradiction probability
    nli_contradiction = get_mismatch_score(summary, article_title)

    # 4. Final weighted score
    w1, w2 = 0.8, 0.2
    final_score = (w1 * semantic_distance) + (w2 * nli_contradiction)
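    # Illustrative arithmetic (hypothetical values): semantic_distance = 0.50 and
    # nli_contradiction = 0.30 give 0.8 * 0.50 + 0.2 * 0.30 = 0.46, just above
    # the 0.45 threshold, so that title/body pair would be flagged as risky.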
    reason = (
        f"[Debug mode]\n"
        f"1. Summary: {summary}\n"
        f"2. SBERT distance: {semantic_distance:.4f}\n"
        f"3. NLI contradiction: {nli_contradiction:.4f}"
    )
    
    # 5. Verdict (threshold 0.45)
    if final_score >= 0.45:
        recommendation = "The title is likely to distort or contradict the body content."
    else:
        recommendation = "The title and body are consistent."

    # Data returned to main.py
    return {
        "score": round(final_score, 4),
        "reason": reason,
        "recommendation": recommendation
    }
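

# Minimal usage sketch. The title/body pair below is a hypothetical, deliberately
# mismatched example (written in Korean, since all three models are Korean-language);
# in the actual service this function is called from main.py with real article data.
if __name__ == "__main__":
    sample_title = "주식 시장, 사상 최고치 경신하며 급등"  # "Stocks surge to a record high"
    sample_body = (  # body says the opposite: three straight days of declines
        "코스피 지수는 투자 심리 위축 속에 3거래일 연속 하락했다. "
        "전문가들은 기업 실적 부진이 이어질 경우 추가 하락 가능성이 있다고 경고했다. "
        "시장 전반의 거래량 역시 전주 대비 크게 줄어든 모습을 보였다."
    )
    result = calculate_mismatch_score(sample_title, sample_body)
    print(result["score"], "-", result["recommendation"])
    print(result["reason"])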