File size: 3,968 Bytes
02af15b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b9a4f82
02af15b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a5a2c66
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
02af15b
 
 
 
 
 
a5a2c66
 
 
 
 
 
 
 
02af15b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
"""FastAPI exception handlers aligned with HTTP API contract."""

from __future__ import annotations

import logging
from http import HTTPStatus
from typing import Any, Dict, Optional, Tuple

from fastapi import FastAPI, Request, status
from fastapi.exceptions import RequestValidationError
from fastapi.responses import JSONResponse
from starlette.exceptions import HTTPException as StarletteHTTPException

logger = logging.getLogger(__name__)

# Fallback ``(error code, message)`` for each HTTP status this module knows
# about. ``_normalize_error`` uses these when an exception supplies no
# explicit error code and/or message of its own.
DEFAULT_ERRORS: Dict[int, Tuple[str, str]] = dict(
    [
        (status.HTTP_400_BAD_REQUEST, ("validation_error", "Invalid request payload")),
        (status.HTTP_401_UNAUTHORIZED, ("unauthorized", "Authorization required")),
        (status.HTTP_403_FORBIDDEN, ("forbidden", "Forbidden")),
        (status.HTTP_404_NOT_FOUND, ("not_found", "Resource not found")),
        (status.HTTP_409_CONFLICT, ("version_conflict", "Resource version conflict")),
        (
            status.HTTP_413_CONTENT_TOO_LARGE,
            ("payload_too_large", "Payload exceeds allowed size"),
        ),
        (
            status.HTTP_500_INTERNAL_SERVER_ERROR,
            ("internal_error", "Internal server error"),
        ),
    ]
)


def _default_error_for(status_code: int) -> Tuple[str, str]:
    """Return the fallback ``(error, message)`` pair for *status_code*.

    Statuses listed in ``DEFAULT_ERRORS`` use their configured pair. Other
    statuses below 500 that ``http.HTTPStatus`` knows get a code derived from
    the member name (e.g. 405 -> ``"method_not_allowed"``) and the standard
    reason phrase as message. Everything else uses the 500 entry.
    """
    configured = DEFAULT_ERRORS.get(status_code)
    if configured is not None:
        return configured
    # Unmapped non-5xx statuses (405 from routing, 429 from a rate limiter, ...)
    # are not server faults; reporting them as "internal_error" would mislead
    # clients that branch on the error code.
    if status_code < 500:
        try:
            http_status = HTTPStatus(status_code)
        except ValueError:
            pass  # non-standard code: fall through to the generic 500 pair
        else:
            return http_status.name.lower(), http_status.phrase
    return DEFAULT_ERRORS[status.HTTP_500_INTERNAL_SERVER_ERROR]


def _normalize_error(
    status_code: int, detail: Any
) -> Tuple[str, str, Optional[Dict[str, Any]]]:
    """Normalize an exception ``detail`` into ``(error, message, detail)``.

    * dict ``detail``: its ``"error"``/``"message"`` keys override the status
      defaults. Its ``"detail"`` key becomes the payload; when that key is
      absent or ``None``, any remaining keys are used as the payload instead
      (``None`` if there are none).
    * non-empty str ``detail``: used as the message with the default code.
    * anything else: status defaults, no payload.
    """
    default_error, default_message = _default_error_for(status_code)
    if isinstance(detail, dict):
        error = detail.get("error", default_error)
        message = detail.get("message", default_message)
        detail_payload = detail.get("detail")
        if detail_payload is None:
            remainder = {
                k: v for k, v in detail.items() if k not in {"error", "message", "detail"}
            }
            detail_payload = remainder or None
        return error, message, detail_payload
    if isinstance(detail, str) and detail:
        return default_error, detail, None
    return default_error, default_message, None


def _response(status_code: int, detail: Any) -> JSONResponse:
    """Build the JSON error envelope ``{"error", "message", "detail"}``.

    *detail* is normalized by ``_normalize_error``; the HTTP status of the
    response is *status_code* unchanged.
    """
    code, text, payload = _normalize_error(status_code, detail)
    body = {"error": code, "message": text, "detail": payload}
    return JSONResponse(content=body, status_code=status_code)


def _field_path(loc: Tuple[Any, ...]) -> str:
    """Render a pydantic error location as a dotted field path.

    Only a *leading* ``"body"`` segment is dropped, because it just marks the
    error as coming from the request body. Dropping every ``"body"`` segment
    would erase real fields named ``body``: for a model with a ``body`` field,
    ``("body", "body")`` would collapse to an empty path.
    """
    parts = list(loc)
    if parts and parts[0] == "body":
        del parts[0]
    return ".".join(str(part) for part in parts)


async def validation_exception_handler(
    request: Request, exc: RequestValidationError
) -> JSONResponse:
    """Translate request validation failures into a 400 error envelope.

    Each pydantic error becomes ``{"field", "reason", "type"}``. Errors whose
    location is only the ``body`` marker are reported against ``"request"``.
    The errors are also logged at WARNING level together with the request URL.
    """
    errors = [
        {
            "field": _field_path(error["loc"]) or "request",
            "reason": error["msg"],
            "type": error["type"],
        }
        for error in exc.errors()
    ]

    detail = {
        "error": "validation_error",
        "message": "Request validation failed",
        "detail": {"fields": errors},
    }
    logger.warning(
        "Validation error",
        extra={"url": str(request.url), "errors": errors},
    )
    return _response(status.HTTP_400_BAD_REQUEST, detail)


async def http_exception_handler(
    request: Request, exc: StarletteHTTPException
) -> JSONResponse:
    """Render a Starlette/FastAPI ``HTTPException`` as the JSON error envelope.

    The status code is preserved, ``exc.detail`` is normalized via
    ``_response``, and any ``exc.headers`` are copied onto the response.
    """
    logger.warning(
        "HTTP exception",
        extra={
            "url": str(request.url),
            "status_code": exc.status_code,
            "detail": exc.detail,
        },
    )
    response = _response(exc.status_code, exc.detail)
    # Headers on the exception are part of the status semantics (e.g.
    # ``WWW-Authenticate`` on 401, ``Allow`` on the 405 Starlette's router
    # raises); Starlette's default handler forwards them, so keep them too.
    headers = getattr(exc, "headers", None)
    if headers:
        response.headers.update(headers)
    return response


async def internal_exception_handler(request: Request, exc: Exception) -> JSONResponse:
    """Convert any otherwise-unhandled exception into a generic 500 response.

    The exception and its traceback are logged, but its text is never sent to
    the client: the message of an arbitrary exception can contain internals
    (queries, file paths, credentials), and a dict argument would otherwise be
    spread directly into the response envelope.
    """
    logger.error(
        "Unhandled exception: %s",
        exc,
        # Pass the traceback explicitly instead of relying on sys.exc_info().
        exc_info=(type(exc), exc, exc.__traceback__),
        extra={"url": str(request.url)},
    )
    return _response(status.HTTP_500_INTERNAL_SERVER_ERROR, None)


def register_error_handlers(app: FastAPI) -> None:
    """Install this module's exception handlers on *app*.

    Registers, in order: request-validation errors, HTTP exceptions, and a
    catch-all for any other ``Exception``.
    """
    handler_table = (
        (RequestValidationError, validation_exception_handler),
        (StarletteHTTPException, http_exception_handler),
        (Exception, internal_exception_handler),
    )
    for exc_class, handler in handler_table:
        app.add_exception_handler(exc_class, handler)


# Public API: the one-call registration helper, plus each handler so that an
# application can register them individually if it prefers.
__all__ = [
    "register_error_handlers",
    "validation_exception_handler", "http_exception_handler", "internal_exception_handler",
]