Spaces:
Sleeping
Sleeping
import torch
import torch.nn.functional as F
from transformers import pipeline, AutoTokenizer, AutoModelForSequenceClassification
from sentence_transformers import SentenceTransformer, util

# Device selection: prefer GPU, fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"โ ํ์ฌ ์คํ ํ๊ฒฝ: {device}")

# =============================================================================
# 2. Model loading (can take a while: weights are downloaded/deserialized).
# =============================================================================
print("\nโณ [1/3] KoBART ์์ฝ ๋ชจ๋ธ ๋ก๋ฉ ์ค...")
# Korean abstractive-summarization pipeline.
# NOTE: the pipeline API takes a device *index* (-1 = CPU, 0 = first GPU),
# unlike torch.device above.
kobart_summarizer = pipeline(
    "summarization",
    model="gogamza/kobart-summarization",
    device=0 if torch.cuda.is_available() else -1
)

print("โณ [2/3] SBERT ์ ์ฌ๋ ๋ชจ๋ธ ๋ก๋ฉ ์ค...")
# Korean sentence-embedding model used for title/summary cosine similarity.
sbert_model = SentenceTransformer('jhgan/ko-sroberta-multitask')

print("โณ [3/3] NLI(๋ชจ์ ํ์ง) ๋ชจ๋ธ ๋ก๋ฉ ์ค...")
# NLI (natural language inference) classifier used for contradiction detection.
nli_model_name = "Huffon/klue-roberta-base-nli"
nli_tokenizer = AutoTokenizer.from_pretrained(nli_model_name)
nli_model = AutoModelForSequenceClassification.from_pretrained(nli_model_name).to(device)

print("๐ ๋ชจ๋ ๋ชจ๋ธ ๋ก๋ ์๋ฃ!\n")
# =============================================================================
# 3. Worker functions
# =============================================================================
def summarize_kobart_strict(text):
    """Summarize article text with the KoBART summarization pipeline.

    Deliberately best-effort — it never raises:
    - empty/None or very short input (< 50 chars) is returned unchanged,
      since the summarizer is unreliable (or errors) on tiny inputs;
    - if the pipeline call fails for any reason, the first 100 characters
      of the original text are returned as a fallback "summary".

    Args:
        text: Article body to summarize (may be None or empty).

    Returns:
        The stripped summary string, or the (possibly truncated) original text.
    """
    # Guard against None as well as short text (original crashed on None).
    if not text or len(text) < 50:
        return text
    try:
        result = kobart_summarizer(
            text,
            min_length=15,
            max_length=128,
            num_beams=4,             # beam search for more stable summaries
            no_repeat_ngram_size=3,  # suppress repeated phrases
            early_stopping=True
        )[0]['summary_text']
        return result.strip()
    except Exception:
        # Best-effort fallback: return the head of the article rather than crash.
        return text[:100]
def get_cosine_similarity(title, summary):
    """Cosine similarity between the SBERT embeddings of a title and a summary."""
    vectors = [
        sbert_model.encode(sentence, convert_to_tensor=True)
        for sentence in (title, summary)
    ]
    return util.cos_sim(vectors[0], vectors[1]).item()
def get_mismatch_score(summary, title):
    """Contradiction probability from the NLI model for (premise=summary, hypothesis=title)."""
    encoded = nli_tokenizer(
        summary, title,
        return_tensors='pt',
        truncation=True,
        max_length=512
    ).to(device)
    # RoBERTa checkpoints reject token_type_ids; drop them if the tokenizer emitted any.
    encoded.pop("token_type_ids", None)
    with torch.no_grad():
        logits = nli_model(**encoded).logits
    # Label order for Huffon/klue-roberta-base-nli: [entailment, neutral, contradiction],
    # so index 2 is the contradiction probability.
    contradiction = F.softmax(logits, dim=-1)[0][2]
    return round(contradiction.item(), 4)
# =============================================================================
# 4. Main entry point
# =============================================================================
def calculate_mismatch_score(article_title, article_body, *, w1=0.8, w2=0.2, threshold=0.45):
    """Score how strongly a headline mismatches its article body.

    Pipeline: summarize the body (KoBART), compute the semantic distance
    between title and summary (1 - SBERT cosine similarity), compute the
    NLI contradiction probability, then combine them into a weighted score.

    Default weights/threshold come from a grid search:
        w1 = 0.8 (semantic distance), w2 = 0.2 (contradiction),
        threshold = 0.45 (scores at/above are flagged as mismatched).

    Args:
        article_title: Headline text.
        article_body: Full article text.
        w1: Weight of the SBERT semantic-distance term (keyword-only).
        w2: Weight of the NLI contradiction term (keyword-only).
        threshold: Flag-as-mismatch cutoff for the final score (keyword-only).

    Returns:
        dict consumed by main.py with keys:
        "score" (float rounded to 4 decimals), "reason" (debug breakdown),
        "recommendation" (human-readable verdict).
    """
    # 1. Summarize the body so we compare title vs. gist, not title vs. full text.
    summary = summarize_kobart_strict(article_body)

    # 2. Semantic distance = 1 - cosine similarity (higher = further apart).
    sbert_sim = get_cosine_similarity(article_title, summary)
    semantic_distance = 1 - sbert_sim

    # 3. Probability that the title contradicts the summary.
    nli_contradiction = get_mismatch_score(summary, article_title)

    # 4. Weighted combination of the two signals.
    final_score = (w1 * semantic_distance) + (w2 * nli_contradiction)

    reason = (
        f"[๋๋ฒ๊ทธ ๋ชจ๋]\n"
        f"1. ์์ฝ๋ฌธ: {summary}\n"
        f"2. SBERT ๊ฑฐ๋ฆฌ: {semantic_distance:.4f}\n"
        f"3. NLI ๋ชจ์: {nli_contradiction:.4f}"
    )

    # 5. Verdict against the threshold.
    if final_score >= threshold:
        recommendation = "์ ๋ชฉ์ด ๋ณธ๋ฌธ์ ๋ด์ฉ์ ์๊ณกํ๊ฑฐ๋ ๋ชจ์๋ ๊ฐ๋ฅ์ฑ์ด ๋์ต๋๋ค."
    else:
        recommendation = "์ ๋ชฉ๊ณผ ๋ณธ๋ฌธ์ ๋ด์ฉ์ด ์ผ์นํฉ๋๋ค."

    # Payload handed back to main.py.
    return {
        "score": round(final_score, 4),
        "reason": reason,
        "recommendation": recommendation
    }