# project-tdm / main.py
# hy
# round
# 1ae484c
from fastapi import FastAPI, Depends
from pydantic import BaseModel, Field
from typing import Dict, List, Optional
import models
from database import engine, SessionLocal
from sqlalchemy.orm import Session # ํ•„์ˆ˜ import
from models import Article, AnalysisResult
from crossref_model import get_crossref_score_and_reason
from mismatch_model import calculate_mismatch_score # <-- ํŒ€์› ํ•จ์ˆ˜ ๋ถˆ๋Ÿฌ์˜ค๊ธฐ
from aggro_model import get_aggro_score
# Create every table registered on models.Base that does not yet exist in the
# database bound to `engine`. This runs once, as a side effect of importing
# this module.
models.Base.metadata.create_all(bind=engine)
from database import SessionLocal # database.py์—์„œ ์ •์˜ํ•œ SessionLocal ๊ฐ€์ ธ์˜ค๊ธฐ
# Dependency that opens a database session for one request and closes it afterwards.
def get_db():
    """Yield a fresh ``SessionLocal()`` session to the request handler.

    FastAPI resumes this generator once the handler is done with the session
    (also when the handler raised), so the ``finally`` clause always closes it.
    """
    session = SessionLocal()
    try:
        yield session
    finally:
        session.close()
# --- 1. API specification (request/response schemas) ---
# Pydantic models: draft of the API contract
class ArticleRequest(BaseModel):
    """Request body the frontend sends to the article-analysis endpoint."""

    # Headline of the article to analyze.
    article_title: str = Field(..., example="기사 제목이 여기 들어갑니다")
    # Full body text of the article to analyze.
    article_body: str = Field(..., example="기사 본문 텍스트...")
class FoundURL(BaseModel):
    """A related URL verified by SBERT, paired with its similarity score."""

    # Address of the related document.
    url: str = Field(...)
    # Similarity between that document and the analyzed article.
    similarity: float = Field(default=..., example=0.85)
class ScoreBreakdown(BaseModel):
    """Detailed result of one sub-score (aggro, mismatch or crossref)."""

    # Sub-score value as returned by the corresponding model.
    score: float = Field(..., example=0.95)
    # Human-readable explanation of why the model gave this score.
    reason: str = Field(..., example="'충격' 키워드 사용")
    # Suggested revision for the article author.
    recommendation: str = Field(..., example="중립적인 단어로 수정하세요.")
    # Supporting URLs; in this file only the crossref breakdown fills this in,
    # the other breakdowns leave it as None.
    found_urls: Optional[List[FoundURL]] = None
class AnalysisResponse(BaseModel):
    """Final response sent from the backend to the frontend."""

    # Weighted combination of the three sub-scores.
    final_risk_score: float = Field(..., example=0.82)
    # Risk label derived from the final score (danger / caution / safe).
    final_risk_level: str = Field(..., example="위험")
    # Per-model details; analyze_article uses the keys "aggro_score",
    # "mismatch_score" and "crossref_score".
    breakdown: Dict[str, ScoreBreakdown]
# --- 2. Create the FastAPI app ---
# Module-level application object; the @app.post / @app.get decorators below
# register their routes on it.
app = FastAPI()
# --- 3. API endpoints ---

# Only cross-reference URLs whose similarity is at least this value (70%)
# are returned to the client.
SIMILARITY_THRESHOLD = 0.7

# Weights of the three sub-scores in the final risk score. They sum to 1.0,
# so the final score is a weighted average of the sub-scores.
W_AGGRO = 0.2
W_MISMATCH = 0.3
W_CROSSREF = 0.5

# Inclusive lower bounds of the "danger" and "caution" risk levels.
DANGER_THRESHOLD = 0.7
CAUTION_THRESHOLD = 0.4


def _to_breakdown(result, found_urls=None):
    """Wrap one model's result dict in a ScoreBreakdown.

    Reads the "score", "reason" and "recommendation" keys of ``result``;
    a missing key raises KeyError.
    """
    return ScoreBreakdown(
        score=result["score"],
        reason=result["reason"],
        recommendation=result["recommendation"],
        found_urls=found_urls,
    )


def _filter_found_urls(paired_results, threshold=SIMILARITY_THRESHOLD):
    """Return FoundURLs whose similarity is >= ``threshold``, most similar first.

    ``paired_results`` is an iterable of dicts with "url" and "similarity"
    keys; ``None`` is treated as "no results".
    """
    kept = [
        item for item in paired_results or () if item["similarity"] >= threshold
    ]
    # list.sort is stable (also with reverse=True), so ties keep the order the
    # crossref model produced them in.
    kept.sort(key=lambda item: item["similarity"], reverse=True)
    return [FoundURL(url=item["url"], similarity=item["similarity"]) for item in kept]


def _risk_level(score):
    """Map a final risk score to its label: danger, caution or safe."""
    if score >= DANGER_THRESHOLD:
        return "위험"
    if score >= CAUTION_THRESHOLD:
        return "주의"
    return "안전"


def _save_analysis(db, request, aggro, mismatch, crossref, final_risk):
    """Store the article and its analysis scores, then commit once.

    The article row is flushed first so its generated ``article_id`` can be
    used as the foreign key of the analysis row. If anything fails before the
    commit, nothing is committed; the pending changes are discarded when
    get_db() closes the session.
    """
    new_article = Article(
        title=request.article_title,
        body=request.article_body,
        # TODO: hard-coded test label - every request is recorded with this source.
        source="Swagger UI Test",
    )
    db.add(new_article)
    db.flush()  # populate new_article.article_id without committing yet
    db.add(
        AnalysisResult(
            article_id=new_article.article_id,  # foreign key to the article row
            aggro_score=aggro.score,
            mismatch_score=mismatch.score,
            crossref_score=crossref.score,
            final_risk=final_risk,
        )
    )
    db.commit()


@app.post("/api/v1/analyze", response_model=AnalysisResponse)
def analyze_article(request: ArticleRequest, db: Session = Depends(get_db)):
    """Article analysis API.

    Scores the article with the aggro, mismatch and crossref models and
    combines the three scores into a final risk score and level. It stores
    the article together with the scores, then returns everything in the
    AnalysisResponse shape.
    """
    # 1. Aggro score (computed from the title only).
    aggro = _to_breakdown(get_aggro_score(request.article_title))

    # 2. Title/body mismatch score.
    mismatch = _to_breakdown(
        calculate_mismatch_score(request.article_title, request.article_body)
    )

    # 3. Cross-reference score (computed from the body), with the related URLs
    #    that pass the similarity threshold.
    crossref_data = get_crossref_score_and_reason(request.article_body)
    crossref = _to_breakdown(
        crossref_data,
        found_urls=_filter_found_urls(crossref_data.get("paired_results")),
    )

    # Final risk: weighted average of the three sub-scores.
    raw_score = (
        aggro.score * W_AGGRO
        + mismatch.score * W_MISMATCH
        + crossref.score * W_CROSSREF
    )
    # The client sees the score rounded to 4 decimal places. Derive the level
    # from that same value so a reported score and its label never disagree
    # at a threshold boundary.
    final_score = round(raw_score, 4)
    final_level = _risk_level(final_score)

    # The unrounded score is what gets stored.
    _save_analysis(db, request, aggro, mismatch, crossref, raw_score)

    return AnalysisResponse(
        final_risk_score=final_score,
        final_risk_level=final_level,
        breakdown={
            "aggro_score": aggro,
            "mismatch_score": mismatch,
            "crossref_score": crossref,
        },
    )
@app.get("/")
def read_root():
    """Root endpoint: return a fixed message naming this server.

    The message reads "AI article analysis server".
    """
    return {"message": "AI 기사 분석 서버"}