Spaces:
Running
Running
File size: 5,348 Bytes
a12ec62 2755fb0 a12ec62 2755fb0 7134b06 a12ec62 2755fb0 07134b9 2755fb0 a12ec62 7134b06 2755fb0 7134b06 2755fb0 7134b06 2755fb0 1ae484c 4615044 2755fb0 7134b06 2755fb0 a12ec62 2755fb0 a12ec62 7134b06 a12ec62 2755fb0 1ae484c 2755fb0 7134b06 2755fb0 7134b06 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 |
from fastapi import FastAPI, Depends
from pydantic import BaseModel, Field
from typing import Dict, List, Optional
import models
from database import engine, SessionLocal
from sqlalchemy.orm import Session # required import (type of the injected DB session)
from models import Article, AnalysisResult
from crossref_model import get_crossref_score_and_reason
from mismatch_model import calculate_mismatch_score # <-- load the required scoring function
from aggro_model import get_aggro_score
# Runs at import time: asks the metadata attached to models.Base to create its
# tables on the shared engine (presumably SQLAlchemy's MetaData.create_all,
# which only creates tables that are missing -- confirm against models.py).
models.Base.metadata.create_all(bind=engine)
from database import SessionLocal # bring in the SessionLocal defined in database.py
# Dependency function that opens a DB session and closes it again
def get_db():
    """Yield a fresh database session and close it when the consumer is done.

    Written as a generator so it can be plugged into ``Depends``: everything
    after the ``yield`` is cleanup.
    """
    session = SessionLocal()
    try:
        yield session
    finally:
        # Executed whether or not the consumer raised.
        session.close()
# --- 1. API specification (Request/Response) ---
# Pydantic models (draft of the API specification)
class ArticleRequest(BaseModel):
    """Request payload sent by the front end: the article to be analyzed."""

    # NOTE(review): both example texts appear mis-encoded (mojibake). Restore
    # them from the original UTF-8 source instead of guessing the characters.

    # Headline of the article (required).
    article_title: str = Field(..., example="๊ธฐ์ฌ ์ ๋ชฉ์ด ์ฌ๊ธฐ ๋ค์ด๊ฐ๋๋ค")
    # Full body text of the article (required). The example must stay on a
    # single physical line: a plain quoted string cannot contain a line break.
    article_body: str = Field(..., example="๊ธฐ์ฌ ๋ณธ๋ฌธ ํ์คํธ...")
class FoundURL(BaseModel):
    """A URL paired with its similarity value (checked by SBERT, per the original note)."""

    # Address of the matched page.
    url: str
    # Similarity value reported for this URL (required).
    similarity: float = Field(
        ...,
        example=0.85,
    )
class ScoreBreakdown(BaseModel):
    """Detail record for one individual score: value, reason and advice."""

    # Numeric value of the score (required).
    score: float = Field(
        ...,
        example=0.95,
    )
    # Explanation of why the score came out this way (required).
    reason: str = Field(
        ...,
        example="'์ถฉ๊ฒฉ' ํค์๋ ์ฌ์ฉ",
    )
    # Suggested fix for the author (required).
    recommendation: str = Field(
        ...,
        example="์ค๋ฆฝ์ ์ธ ๋จ์ด๋ก ์์ ํ์ธ์.",
    )
    # Optional list of matched URLs; omitted (None) when not applicable.
    found_urls: Optional[List[FoundURL]] = None
class AnalysisResponse(BaseModel):
    """Final response format returned from the back end to the front end."""

    # Combined risk score (required).
    final_risk_score: float = Field(
        ...,
        example=0.82,
    )
    # Label for the combined risk score (required).
    final_risk_level: str = Field(
        ...,
        example="์ํ",
    )
    # Per-score details, keyed by score name.
    breakdown: Dict[str, ScoreBreakdown]
# --- 2. Create the FastAPI application ---
app = FastAPI()
# --- 3. API endpoints ---
# Cross-reference pairs below this similarity are not returned to the client.
_SIMILARITY_THRESHOLD = 0.7

# Contribution of each sub-score to the final risk score.
_WEIGHT_AGGRO = 0.2
_WEIGHT_MISMATCH = 0.5
_WEIGHT_CROSSREF = 0.3


def _to_breakdown(result, found_urls=None):
    """Wrap a scorer result (score / reason / recommendation) in a ScoreBreakdown."""
    return ScoreBreakdown(
        score=result["score"],
        reason=result["reason"],
        recommendation=result["recommendation"],
        found_urls=found_urls,
    )


def _select_urls(paired_results):
    """Return the pairs at or above the similarity threshold, most similar first."""
    ranked = sorted(
        paired_results,
        key=lambda pair: pair["similarity"],
        reverse=True,
    )
    return [
        FoundURL(url=pair["url"], similarity=pair["similarity"])
        for pair in ranked
        if pair["similarity"] >= _SIMILARITY_THRESHOLD
    ]


def _risk_level(score):
    """Map a final risk score to its label: >= 0.7, >= 0.4, or anything lower."""
    # NOTE(review): the three labels appear mis-encoded (mojibake); they are
    # kept exactly as found. Restore them from the original UTF-8 source.
    if score >= 0.7:
        return "์ํ"
    if score >= 0.4:
        return "์ฃผ์"
    return "์์ "


def _store_analysis(db, request, aggro, mismatch, crossref, final_score):
    """Save the article and its scores in one transaction; roll back on failure."""
    try:
        article = Article(
            title=request.article_title,
            body=request.article_body,
            source="Swagger UI Test",  # placeholder source used for testing
        )
        db.add(article)
        # Flush (without committing) so the article's primary key is available.
        db.flush()
        db.add(
            AnalysisResult(
                article_id=article.article_id,  # foreign key to the article row
                aggro_score=aggro.score,
                mismatch_score=mismatch.score,
                crossref_score=crossref.score,
                final_risk=final_score,
            )
        )
        # Make both rows permanent together.
        db.commit()
    except Exception:
        # Discard the partially written transaction, then let the error propagate.
        db.rollback()
        raise


@app.post("/api/v1/analyze", response_model=AnalysisResponse)
def analyze_article(request: ArticleRequest, db: Session = Depends(get_db)):
    """Score an article, store the outcome, and return the risk breakdown.

    Args:
        request: Title and body of the article to analyze.
        db: Database session supplied by the ``get_db`` dependency.

    Returns:
        AnalysisResponse with the weighted final score (rounded to 4 decimal
        places), its level label, and the three per-score breakdowns.

    Raises:
        KeyError: If a scorer result lacks ``score``, ``reason``,
            ``recommendation`` (or ``paired_results`` / ``url`` /
            ``similarity`` for the cross-reference scorer).
        Exception: Any error raised while writing to the database is
            re-raised after the transaction has been rolled back.
    """
    # 1. Aggro score: computed from the title only.
    aggro = _to_breakdown(get_aggro_score(request.article_title))

    # 2. Mismatch score: computed from the title and the body.
    mismatch = _to_breakdown(
        calculate_mismatch_score(request.article_title, request.article_body)
    )

    # 3. Cross-reference score: computed from the body; only sufficiently
    #    similar URLs are passed on to the client.
    crossref_raw = get_crossref_score_and_reason(request.article_body)
    crossref = _to_breakdown(
        crossref_raw,
        found_urls=_select_urls(crossref_raw["paired_results"]),
    )

    # Weighted combination of the three sub-scores.
    final_score = (aggro.score * _WEIGHT_AGGRO) + \
                  (mismatch.score * _WEIGHT_MISMATCH) + \
                  (crossref.score * _WEIGHT_CROSSREF)
    final_level = _risk_level(final_score)

    # Persist before answering; the unrounded score is what gets stored.
    _store_analysis(db, request, aggro, mismatch, crossref, final_score)

    return AnalysisResponse(
        final_risk_score=round(final_score, 4),
        final_risk_level=final_level,
        breakdown={
            "aggro_score": aggro,
            "mismatch_score": mismatch,
            "crossref_score": crossref,
        },
    )
@app.get("/")
def read_root():
    """Return a short message identifying the server on the root path."""
    payload = {"message": "AI ๊ธฐ์ฌ ๋ถ์ ์๋ฒ"}
    return payload
|