File size: 5,348 Bytes
a12ec62
2755fb0
 
 
a12ec62
 
 
2755fb0
7134b06
 
a12ec62
 
 
 
 
 
 
 
 
 
 
2755fb0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
07134b9
2755fb0
 
 
 
a12ec62
7134b06
 
 
 
2755fb0
7134b06
 
 
 
2755fb0
 
 
7134b06
 
 
 
 
 
 
2755fb0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1ae484c
4615044
 
2755fb0
7134b06
 
2755fb0
 
 
 
 
 
 
a12ec62
 
 
2755fb0
a12ec62
 
 
 
 
 
 
 
 
 
 
 
 
 
7134b06
 
a12ec62
 
 
 
 
 
 
 
 
2755fb0
 
1ae484c
2755fb0
 
7134b06
 
2755fb0
 
 
 
 
 
7134b06
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
from fastapi import FastAPI, Depends
from pydantic import BaseModel, Field
from typing import Dict, List, Optional
import models
from database import engine, SessionLocal
from sqlalchemy.orm import Session # ํ•„์ˆ˜ import
from models import Article, AnalysisResult
from crossref_model import get_crossref_score_and_reason
from mismatch_model import calculate_mismatch_score  # <-- ํŒ€์› ํ•จ์ˆ˜ ๋ถˆ๋Ÿฌ์˜ค๊ธฐ
from aggro_model import get_aggro_score

# Create the tables declared on models.Base using the shared engine.
# This runs once, as a side effect of importing this module.
models.Base.metadata.create_all(bind=engine)
from database import SessionLocal # database.py์—์„œ ์ •์˜ํ•œ SessionLocal ๊ฐ€์ ธ์˜ค๊ธฐ
# FastAPI dependency: one database session per request, always closed afterwards.
def get_db():
    """Yield a fresh ``SessionLocal`` session for a single request.

    Intended for use as ``Depends(get_db)``. The ``finally`` block guarantees
    ``close()`` is called however the generator is finished, so the session
    never outlives the request that asked for it.
    """
    session = SessionLocal()
    try:
        yield session
    finally:
        session.close()

# --- 1. API schema (Request/Response) ---
# Pydantic models: draft of the API specification.

class ArticleRequest(BaseModel):
    """Request body sent by the frontend: the article title and body to analyze."""
    article_title: str = Field(default=..., example="๊ธฐ์‚ฌ ์ œ๋ชฉ์ด ์—ฌ๊ธฐ ๋“ค์–ด๊ฐ‘๋‹ˆ๋‹ค")
    article_body: str = Field(default=..., example="๊ธฐ์‚ฌ ๋ณธ๋ฌธ ํ…์ŠคํŠธ...")

class FoundURL(BaseModel):
    """A URL verified by the SBERT cross-reference check, with its similarity score."""
    url: str
    similarity: float = Field(default=..., example=0.85)

class ScoreBreakdown(BaseModel):
    """Details for one sub-score: its value, the reason, a suggested fix, and
    (optionally) the URLs found while computing it."""
    score: float = Field(default=..., example=0.95)
    reason: str = Field(default=..., example="'์ถฉ๊ฒฉ' ํ‚ค์›Œ๋“œ ์‚ฌ์šฉ")
    recommendation: str = Field(default=..., example="์ค‘๋ฆฝ์ ์ธ ๋‹จ์–ด๋กœ ์ˆ˜์ •ํ•˜์„ธ์š”.")
    found_urls: Optional[List[FoundURL]] = Field(default=None)

class AnalysisResponse(BaseModel):
    """Final backend-to-frontend response: overall risk score and level, plus
    a per-score breakdown keyed by score name."""
    final_risk_score: float = Field(default=..., example=0.82)
    final_risk_level: str = Field(default=..., example="์œ„ํ—˜")
    breakdown: Dict[str, ScoreBreakdown]

# --- 2. Create the FastAPI app ---
# The endpoints below register themselves on this instance via @app decorators.
app = FastAPI()

# --- 3. API endpoints ---

# Crossref matches with a similarity below this value are not returned.
_SIMILARITY_THRESHOLD = 0.7

# Weight of each sub-score in the final risk score (they sum to 1.0).
_RISK_WEIGHTS = {
    "aggro": 0.2,
    "mismatch": 0.5,
    "crossref": 0.3,
}

# (minimum score, level) pairs checked from highest to lowest:
# "danger" at >= 0.7, "caution" at >= 0.4, otherwise "safe".
_RISK_LEVELS = (
    (0.7, "์œ„ํ—˜"),
    (0.4, "์ฃผ์˜"),
)
_DEFAULT_RISK_LEVEL = "์•ˆ์ „"

# Decimal places of the final score returned to the client.
_SCORE_DECIMALS = 4


def _to_breakdown(result, found_urls=None):
    """Build a ScoreBreakdown from a scorer result dict.

    ``result`` must provide the keys ``"score"``, ``"reason"`` and
    ``"recommendation"``; ``found_urls`` is attached as-is.
    """
    return ScoreBreakdown(
        score=result["score"],
        reason=result["reason"],
        recommendation=result["recommendation"],
        found_urls=found_urls,
    )


def _filter_similar_urls(paired_results, threshold=_SIMILARITY_THRESHOLD):
    """Return FoundURL objects with similarity >= threshold, most similar first.

    Each item in ``paired_results`` must provide ``"url"`` and ``"similarity"``.
    Filtering before the (stable) sort yields the same order as sorting first.
    """
    matches = [item for item in paired_results if item["similarity"] >= threshold]
    matches.sort(key=lambda item: item["similarity"], reverse=True)
    return [FoundURL(url=item["url"], similarity=item["similarity"]) for item in matches]


def _risk_level(score):
    """Map a final risk score to its level label using _RISK_LEVELS."""
    for minimum, level in _RISK_LEVELS:
        if score >= minimum:
            return level
    return _DEFAULT_RISK_LEVEL


def _save_analysis(db, request, aggro, mismatch, crossref, final_score):
    """Persist the article and its analysis scores, then commit."""
    article = Article(
        title=request.article_title,
        body=request.article_body,
        # TODO: hard-coded test source; should come from the request/caller.
        source="Swagger UI Test",
    )
    db.add(article)
    # Flush (without committing) so article.article_id is populated; the
    # analysis row references it as a foreign key.
    db.flush()
    db.add(AnalysisResult(
        article_id=article.article_id,
        aggro_score=aggro.score,
        mismatch_score=mismatch.score,
        crossref_score=crossref.score,
        final_risk=final_score,
    ))
    # Commit the article and its analysis together as one transaction.
    db.commit()


@app.post("/api/v1/analyze", response_model=AnalysisResponse)
def analyze_article(request: ArticleRequest, db: Session = Depends(get_db)):
    """Article analysis API.

    Scores the title (aggro), the title/body pair (mismatch) and the body
    (crossref), combines them with _RISK_WEIGHTS into a final risk score,
    stores the article and scores, and returns the rounded score, its risk
    level and the per-score breakdown.
    """
    aggro = _to_breakdown(get_aggro_score(request.article_title))
    mismatch = _to_breakdown(
        calculate_mismatch_score(request.article_title, request.article_body)
    )
    crossref_data = get_crossref_score_and_reason(request.article_body)
    crossref = _to_breakdown(
        crossref_data,
        _filter_similar_urls(crossref_data["paired_results"]),
    )

    final_score = (
        aggro.score * _RISK_WEIGHTS["aggro"]
        + mismatch.score * _RISK_WEIGHTS["mismatch"]
        + crossref.score * _RISK_WEIGHTS["crossref"]
    )
    display_score = round(final_score, _SCORE_DECIMALS)
    # Classify the same rounded score the client sees, so a displayed 0.7
    # is never labelled below the 0.7 threshold.
    final_level = _risk_level(display_score)

    # The database keeps the unrounded score.
    _save_analysis(db, request, aggro, mismatch, crossref, final_score)

    return AnalysisResponse(
        final_risk_score=display_score,
        final_risk_level=final_level,
        breakdown={
            "aggro_score": aggro,
            "mismatch_score": mismatch,
            "crossref_score": crossref,
        },
    )

@app.get("/")
def read_root():
    """Root endpoint: return a fixed JSON message identifying this AI article-analysis server."""
    return dict(message="AI ๊ธฐ์‚ฌ ๋ถ„์„ ์„œ๋ฒ„")