# +-------------------------------------------------------------+
#
#           Use Aim Security Guardrails for your LLM calls
#                   https://www.aim.security/
#
# +-------------------------------------------------------------+

import asyncio
import json
import os
from typing import Any, AsyncGenerator, Literal, Optional, Union

from fastapi import HTTPException
from pydantic import BaseModel
from websockets.asyncio.client import ClientConnection, connect

from litellm import DualCache
from litellm._version import version as litellm_version
from litellm._logging import verbose_proxy_logger
from litellm.integrations.custom_guardrail import CustomGuardrail
from litellm.llms.custom_httpx.http_handler import (
    get_async_httpx_client,
    httpxSpecialProvider,
)
from litellm.proxy._types import UserAPIKeyAuth
from litellm.proxy.proxy_server import StreamingCallbackError
from litellm.types.utils import (
    Choices,
    EmbeddingResponse,
    ImageResponse,
    ModelResponse,
    ModelResponseStream,
)


class AimGuardrailMissingSecrets(Exception):
    pass


class AimGuardrail(CustomGuardrail):
    def __init__(
        self, api_key: Optional[str] = None, api_base: Optional[str] = None, **kwargs
    ):
        self.async_handler = get_async_httpx_client(
            llm_provider=httpxSpecialProvider.GuardrailCallback
        )
        self.api_key = api_key or os.environ.get("AIM_API_KEY")
        if not self.api_key:
            msg = (
                "Couldn't get Aim api key, either set the `AIM_API_KEY` in the environment or "
                "pass it as a parameter to the guardrail in the config file"
            )
            raise AimGuardrailMissingSecrets(msg)
        self.api_base = (
            api_base or os.environ.get("AIM_API_BASE") or "https://api.aim.security"
        )
        # Derive the websocket base URL from the HTTP base; used by the
        # streaming output-detection endpoint below.
        self.ws_api_base = self.api_base.replace("http://", "ws://").replace(
            "https://", "wss://"
        )
        super().__init__(**kwargs)
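
    # A minimal proxy config sketch for enabling this guardrail (assumes
    # LiteLLM's standard `guardrails` YAML format; the guardrail name and
    # mode below are illustrative, adjust for your deployment):
    #
    #   guardrails:
    #     - guardrail_name: aim-protected-app
    #       litellm_params:
    #         guardrail: aim
    #         mode: pre_call          # or during_call / post_call
    #         api_key: os.environ/AIM_API_KEY
    #         api_base: https://api.aim.security  # optional, this is the default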

    async def async_pre_call_hook(
        self,
        user_api_key_dict: UserAPIKeyAuth,
        cache: DualCache,
        data: dict,
        call_type: Literal[
            "completion",
            "text_completion",
            "embeddings",
            "image_generation",
            "moderation",
            "audio_transcription",
            "pass_through_endpoint",
            "rerank",
        ],
    ) -> Union[Exception, str, dict, None]:
        """Run Aim input detection before the LLM call; raises on violation."""
        verbose_proxy_logger.debug("Inside AIM Pre-Call Hook")
        await self.call_aim_guardrail(
            data, hook="pre_call", key_alias=user_api_key_dict.key_alias
        )
        return data

    async def async_moderation_hook(
        self,
        data: dict,
        user_api_key_dict: UserAPIKeyAuth,
        call_type: Literal[
            "completion",
            "embeddings",
            "image_generation",
            "moderation",
            "audio_transcription",
            "responses",
        ],
    ) -> Union[Exception, str, dict, None]:
        """Run Aim input detection in parallel with the LLM call."""
        verbose_proxy_logger.debug("Inside AIM Moderation Hook")
        await self.call_aim_guardrail(
            data, hook="moderation", key_alias=user_api_key_dict.key_alias
        )
        return data

    async def call_aim_guardrail(
        self, data: dict, hook: str, key_alias: Optional[str]
    ) -> None:
        """Send the request messages to Aim's input-detection endpoint.

        Raises an HTTP 400 if Aim flags the content.
        """
        user_email = data.get("metadata", {}).get("headers", {}).get("x-aim-user-email")
        call_id = data.get("litellm_call_id")
        headers = self._build_aim_headers(
            hook=hook,
            key_alias=key_alias,
            user_email=user_email,
            litellm_call_id=call_id,
        )
        response = await self.async_handler.post(
            f"{self.api_base}/detect/openai",
            headers=headers,
            json={"messages": data.get("messages", [])},
        )
        response.raise_for_status()
        res = response.json()
        detected = res["detected"]
        verbose_proxy_logger.info(
            "Aim: detected: {detected}, enabled policies: {policies}".format(
                detected=detected,
                policies=list(res["details"].keys()),
            ),
        )
        if detected:
            raise HTTPException(status_code=400, detail=res["detection_message"])
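
    # For reference, a /detect/openai response is expected to look roughly
    # like the following (shape inferred from the fields accessed above;
    # not an official schema, and the policy names are placeholders):
    #
    #   {
    #     "detected": true,
    #     "detection_message": "Prompt injection detected",
    #     "details": {"prompt_injection": {...}, "pii": {...}}
    #   }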

    async def call_aim_guardrail_on_output(
        self, request_data: dict, output: str, hook: str, key_alias: Optional[str]
    ) -> Optional[str]:
        """Send a completed (non-streaming) output to Aim's output-detection
        endpoint. Returns the detection message if flagged, otherwise None."""
        user_email = (
            request_data.get("metadata", {}).get("headers", {}).get("x-aim-user-email")
        )
        call_id = request_data.get("litellm_call_id")
        response = await self.async_handler.post(
            f"{self.api_base}/detect/output",
            headers=self._build_aim_headers(
                hook=hook,
                key_alias=key_alias,
                user_email=user_email,
                litellm_call_id=call_id,
            ),
            json={"output": output, "messages": request_data.get("messages", [])},
        )
        response.raise_for_status()
        res = response.json()
        detected = res["detected"]
        verbose_proxy_logger.info(
            "Aim: detected: {detected}, enabled policies: {policies}".format(
                detected=detected,
                policies=list(res["details"].keys()),
            ),
        )
        if detected:
            return res["detection_message"]
        return None

    def _build_aim_headers(
        self,
        *,
        hook: str,
        key_alias: Optional[str],
        user_email: Optional[str],
        litellm_call_id: Optional[str],
    ):
        """Build the HTTP headers required by the Aim guardrails API."""
        return (
            {
                "Authorization": f"Bearer {self.api_key}",
                # Used by Aim to apply only the guardrails that should run in a specific request phase.
                "x-aim-litellm-hook": hook,
                # Used by Aim to track the LiteLLM version and provide backward compatibility.
                "x-aim-litellm-version": litellm_version,
            }
            # Used by Aim to correlate a single call's input and output.
            | ({"x-aim-litellm-call-id": litellm_call_id} if litellm_call_id else {})
            # Used by Aim to track guardrail violations by user.
            | ({"x-aim-user-email": user_email} if user_email else {})
            | (
                {
                    # Used by Aim to apply only the guardrails associated with the key alias.
                    "x-aim-litellm-key-alias": key_alias,
                }
                if key_alias
                else {}
            )
        )
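
    # For illustration, with every optional field present the merged header
    # mapping produced above looks like this (placeholder values):
    #
    #   {
    #       "Authorization": "Bearer <AIM_API_KEY>",
    #       "x-aim-litellm-hook": "pre_call",
    #       "x-aim-litellm-version": "<litellm version>",
    #       "x-aim-litellm-call-id": "<litellm_call_id>",
    #       "x-aim-user-email": "user@example.com",
    #       "x-aim-litellm-key-alias": "<key_alias>",
    #   }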

    async def async_post_call_success_hook(
        self,
        data: dict,
        user_api_key_dict: UserAPIKeyAuth,
        response: Union[Any, ModelResponse, EmbeddingResponse, ImageResponse],
    ) -> Any:
        """Check a non-streaming chat completion's content with Aim after the call."""
        if (
            isinstance(response, ModelResponse)
            and response.choices
            and isinstance(response.choices[0], Choices)
        ):
            content = response.choices[0].message.content or ""
            detection = await self.call_aim_guardrail_on_output(
                data, content, hook="output", key_alias=user_api_key_dict.key_alias
            )
            if detection:
                raise HTTPException(status_code=400, detail=detection)

    async def async_post_call_streaming_iterator_hook(
        self,
        user_api_key_dict: UserAPIKeyAuth,
        response,
        request_data: dict,
    ) -> AsyncGenerator[ModelResponseStream, None]:
        """Stream chunks through Aim over a websocket, yielding only chunks
        that Aim has verified."""
        user_email = (
            request_data.get("metadata", {}).get("headers", {}).get("x-aim-user-email")
        )
        call_id = request_data.get("litellm_call_id")
        async with connect(
            f"{self.ws_api_base}/detect/output/ws",
            additional_headers=self._build_aim_headers(
                hook="output",
                key_alias=user_api_key_dict.key_alias,
                user_email=user_email,
                litellm_call_id=call_id,
            ),
        ) as websocket:
            # Forward upstream chunks to Aim in the background while this
            # coroutine consumes Aim's verdicts.
            sender = asyncio.create_task(
                self.forward_the_stream_to_aim(websocket, response)
            )
            while True:
                result = json.loads(await websocket.recv())
                if verified_chunk := result.get("verified_chunk"):
                    yield ModelResponseStream.model_validate(verified_chunk)
                else:
                    sender.cancel()
                    if result.get("done"):
                        return
                    if blocking_message := result.get("blocking_message"):
                        raise StreamingCallbackError(blocking_message)
                    verbose_proxy_logger.error(
                        f"Unknown message received from AIM: {result}"
                    )
                    return

    async def forward_the_stream_to_aim(
        self,
        websocket: ClientConnection,
        response_iter,
    ) -> None:
        """Serialize each upstream chunk as JSON and send it to Aim, followed
        by a final done marker."""
        async for chunk in response_iter:
            if isinstance(chunk, BaseModel):
                chunk = chunk.model_dump_json()
            if isinstance(chunk, dict):
                chunk = json.dumps(chunk)
            await websocket.send(chunk)
        await websocket.send(json.dumps({"done": True}))
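
# A sketch of the websocket exchange implied by the two methods above (message
# shapes taken from the keys this module reads and writes; not an official
# protocol specification):
#
#   client -> Aim:  one JSON-serialized ModelResponseStream chunk per message,
#                   then {"done": true} once the upstream stream is exhausted
#   Aim -> client:  {"verified_chunk": {...}}   chunk is safe, yield it to the caller
#                   {"done": true}              stream finished with no violation
#                   {"blocking_message": "..."} policy violation, abort the stream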