from typing import Dict, List, Optional

from fastapi import FastAPI, Depends
from pydantic import BaseModel, Field
from sqlalchemy.orm import Session

import models
from database import engine, SessionLocal
from models import Article, AnalysisResult
from crossref_model import get_crossref_score_and_reason
from mismatch_model import calculate_mismatch_score
from aggro_model import get_aggro_score

# Create all ORM-declared tables (no-op for tables that already exist).
models.Base.metadata.create_all(bind=engine)

# Weight of each sub-score in the final risk score (weights sum to 1.0).
W_AGGRO = 0.2
W_MISMATCH = 0.3
W_CROSSREF = 0.5

# Only cross-reference URLs with at least this similarity are returned.
SIMILARITY_THRESHOLD = 0.7


def get_db():
    """FastAPI dependency: yield a DB session and always close it afterwards."""
    db = SessionLocal()
    try:
        yield db
    finally:
        db.close()


# --- 1. API contract (Request/Response) ---

class ArticleRequest(BaseModel):
    """Request payload sent by the frontend."""
    article_title: str = Field(..., example="기사 제목이 여기 들어갑니다")
    article_body: str = Field(..., example="기사 본문 텍스트...")


class FoundURL(BaseModel):
    """A URL verified by the SBERT model, with its similarity score."""
    url: str
    similarity: float = Field(..., example=0.85)


class ScoreBreakdown(BaseModel):
    """Detail record for one individual sub-score."""
    score: float = Field(..., example=0.95)
    reason: str = Field(..., example="'충격' 키워드 사용")
    recommendation: str = Field(..., example="중립적인 단어로 수정하세요.")
    found_urls: Optional[List[FoundURL]] = None


class AnalysisResponse(BaseModel):
    """Final backend -> frontend response shape."""
    final_risk_score: float = Field(..., example=0.82)
    final_risk_level: str = Field(..., example="위험")
    breakdown: Dict[str, ScoreBreakdown]


# --- 2. FastAPI app ---
app = FastAPI()


def _risk_level(score: float) -> str:
    """Map a numeric risk score onto one of the three Korean level labels."""
    if score >= 0.7:
        return "위험"
    if score >= 0.4:
        return "주의"
    return "안전"


def _crossref_breakdown(article_body: str) -> ScoreBreakdown:
    """Run the cross-reference model and keep only strong URL matches.

    Matches are sorted by similarity (descending) and filtered by
    SIMILARITY_THRESHOLD before being attached to the breakdown.
    """
    data = get_crossref_score_and_reason(article_body)
    ranked = sorted(
        data["paired_results"],
        key=lambda item: item["similarity"],
        reverse=True,
    )
    urls = [
        FoundURL(url=item["url"], similarity=item["similarity"])
        for item in ranked
        if item["similarity"] >= SIMILARITY_THRESHOLD
    ]
    return ScoreBreakdown(
        score=data["score"],
        reason=data["reason"],
        recommendation=data["recommendation"],
        found_urls=urls,
    )


# --- 3. API endpoints ---

@app.post("/api/v1/analyze", response_model=AnalysisResponse)
def analyze_article(request: ArticleRequest, db: Session = Depends(get_db)):
    """Article analysis API.

    Scores the article with the three models (aggro / mismatch / crossref),
    combines them into a weighted final risk score, persists the article and
    its analysis result, and returns the full breakdown.
    """
    # 1. Aggro (clickbait) score on the title.
    aggro_result = get_aggro_score(request.article_title)
    real_aggro = ScoreBreakdown(
        score=aggro_result["score"],
        reason=aggro_result["reason"],
        recommendation=aggro_result["recommendation"],
        found_urls=None,
    )

    # 2. Title/body mismatch score.
    mismatch_result = calculate_mismatch_score(
        request.article_title, request.article_body
    )
    real_mismatch = ScoreBreakdown(
        score=mismatch_result["score"],
        reason=mismatch_result["reason"],
        recommendation=mismatch_result["recommendation"],
        found_urls=None,
    )

    # 3. Cross-reference score with filtered supporting URLs.
    final_crossref = _crossref_breakdown(request.article_body)

    # Weighted combination of the three sub-scores.
    final_score = (
        real_aggro.score * W_AGGRO
        + real_mismatch.score * W_MISMATCH
        + final_crossref.score * W_CROSSREF
    )
    final_level = _risk_level(final_score)

    # Persist article + analysis result in one transaction. Roll back on
    # failure so the session is not left dirty (fix: the original committed
    # with no rollback path).
    try:
        new_article = Article(
            title=request.article_title,
            body=request.article_body,
            source="Swagger UI Test",  # test-only provenance marker
        )
        db.add(new_article)
        # Flush (without committing) so new_article.article_id is assigned.
        db.flush()

        db.add(
            AnalysisResult(
                article_id=new_article.article_id,  # FK to the new article
                aggro_score=real_aggro.score,
                mismatch_score=real_mismatch.score,
                crossref_score=final_crossref.score,
                final_risk=final_score,
            )
        )
        db.commit()
    except Exception:
        db.rollback()
        raise

    # Shape the response per the AnalysisResponse contract.
    return AnalysisResponse(
        final_risk_score=round(final_score, 4),  # round to 4 decimal places
        final_risk_level=final_level,
        breakdown={
            "aggro_score": real_aggro,
            "mismatch_score": real_mismatch,
            "crossref_score": final_crossref,
        },
    )


@app.get("/")
def read_root():
    """Health/landing endpoint."""
    return {"message": "AI 기사 분석 서버"}