# +-------------------------------------------------------------+
# |                                                             |
# |           Use AporiaAI for your LLM calls                   |
# |                                                             |
# +-------------------------------------------------------------+
#  Thank you users! We ❤️ you! - Krrish & Ishaan

import os
import sys

sys.path.insert(
    0, os.path.abspath("../..")
)  # Adds the parent directory to the system path
import json
from typing import Any, List, Literal, Optional

from fastapi import HTTPException

import litellm
from litellm._logging import verbose_proxy_logger
from litellm.integrations.custom_guardrail import (
    CustomGuardrail,
    log_guardrail_information,
)
from litellm.litellm_core_utils.logging_utils import (
    convert_litellm_response_object_to_str,
)
from litellm.llms.custom_httpx.http_handler import (
    get_async_httpx_client,
    httpxSpecialProvider,
)
from litellm.proxy._types import UserAPIKeyAuth
from litellm.proxy.guardrails.guardrail_helpers import should_proceed_based_on_metadata
from litellm.types.guardrails import GuardrailEventHooks

litellm.set_verbose = True

GUARDRAIL_NAME = "aporia"
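
# Illustrative wiring (not part of this module): the guardrail is typically
# registered through the proxy's config.yaml. The guardrail_name and mode
# values below are placeholders:
#
#   guardrails:
#     - guardrail_name: "aporia-pre-guard"
#       litellm_params:
#         guardrail: aporia
#         mode: "during_call"  # or "post_call" for response moderation
#         api_key: os.environ/APORIO_API_KEY
#         api_base: os.environ/APORIO_API_BASE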


class AporiaGuardrail(CustomGuardrail):
    def __init__(
        self, api_key: Optional[str] = None, api_base: Optional[str] = None, **kwargs
    ):
        self.async_handler = get_async_httpx_client(
            llm_provider=httpxSpecialProvider.GuardrailCallback
        )
        # NOTE: the env var names use the "APORIO" spelling (matching the curl
        # example below); this raises KeyError if unset and not passed in.
        self.aporia_api_key = api_key or os.environ["APORIO_API_KEY"]
        self.aporia_api_base = api_base or os.environ["APORIO_API_BASE"]
        super().__init__(**kwargs)

    #### CALL HOOKS - proxy only ####

    def transform_messages(self, messages: List[dict]) -> List[dict]:
        supported_openai_roles = ["system", "user", "assistant"]
        default_role = "other"  # for unsupported roles - e.g. tool
        new_messages = []
        for m in messages:
            if m.get("role", "") in supported_openai_roles:
                new_messages.append(m)
            else:
                new_messages.append(
                    {
                        "role": default_role,
                        **{key: value for key, value in m.items() if key != "role"},
                    }
                )

        return new_messages
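
    # Illustrative example (inputs are made up): transform_messages keeps
    # OpenAI roles as-is and remaps anything else to "other" while preserving
    # the remaining keys, e.g.
    #   [{"role": "user", "content": "hi"},
    #    {"role": "tool", "content": "42", "tool_call_id": "call_1"}]
    # becomes
    #   [{"role": "user", "content": "hi"},
    #    {"role": "other", "content": "42", "tool_call_id": "call_1"}]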

    async def prepare_aporia_request(
        self, new_messages: List[dict], response_string: Optional[str] = None
    ) -> dict:
        data: dict[str, Any] = {}
        if new_messages is not None:
            data["messages"] = new_messages
        if response_string is not None:
            data["response"] = response_string

        # Set validation target
        if new_messages and response_string:
            data["validation_target"] = "both"
        elif new_messages:
            data["validation_target"] = "prompt"
        elif response_string:
            data["validation_target"] = "response"

        verbose_proxy_logger.debug("Aporia AI request: %s", data)
        return data
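
    # Illustrative payload built by prepare_aporia_request when both a prompt
    # and a response are present (shape follows the code above; values are
    # made up):
    #   {
    #       "messages": [{"role": "user", "content": "hi"}],
    #       "response": "hello!",
    #       "validation_target": "both",
    #   }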

    async def make_aporia_api_request(
        self,
        request_data: dict,
        new_messages: List[dict],
        response_string: Optional[str] = None,
    ):
        data = await self.prepare_aporia_request(
            new_messages=new_messages, response_string=response_string
        )
        data.update(
            self.get_guardrail_dynamic_request_body_params(request_data=request_data)
        )
        _json_data = json.dumps(data)

        """
        export APORIO_API_KEY=<your key>
        curl https://gr-prd-trial.aporia.com/some-id \
            -X POST \
            -H "X-APORIA-API-KEY: $APORIO_API_KEY" \
            -H "Content-Type: application/json" \
            -d '{
                "messages": [
                    {
                        "role": "user",
                        "content": "This is a test prompt"
                    }
                ]
            }'
        """

        response = await self.async_handler.post(
            url=self.aporia_api_base + "/validate",
            data=_json_data,
            headers={
                "X-APORIA-API-KEY": self.aporia_api_key,
                "Content-Type": "application/json",
            },
        )
        verbose_proxy_logger.debug("Aporia AI response: %s", response.text)

        if response.status_code == 200:
            # check if the response was flagged
            _json_response = response.json()
            action: str = _json_response.get(
                "action"
            )  # possible values are modify, passthrough, block, rephrase
            if action == "block":
                raise HTTPException(
                    status_code=400,
                    detail={
                        "error": "Violated guardrail policy",
                        "aporia_ai_response": _json_response,
                    },
                )
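
    # Illustrative error surfaced to proxy clients when Aporia answers with
    # action == "block" (the aporia_ai_response contents are whatever Aporia
    # returned; the body shown here is assumed):
    #   HTTP 400
    #   {"error": "Violated guardrail policy",
    #    "aporia_ai_response": {"action": "block", ...}}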

    @log_guardrail_information
    async def async_post_call_success_hook(
        self,
        data: dict,
        user_api_key_dict: UserAPIKeyAuth,
        response,
    ):
        """
        Use this for post-call moderation with guardrails.
        """
        from litellm.proxy.common_utils.callback_utils import (
            add_guardrail_to_applied_guardrails_header,
        )

        event_type: GuardrailEventHooks = GuardrailEventHooks.post_call
        if self.should_run_guardrail(data=data, event_type=event_type) is not True:
            return

        response_str: Optional[str] = convert_litellm_response_object_to_str(response)
        if response_str is not None:
            await self.make_aporia_api_request(
                request_data=data,
                response_string=response_str,
                new_messages=data.get("messages", []),
            )

            add_guardrail_to_applied_guardrails_header(
                request_data=data, guardrail_name=self.guardrail_name
            )

    @log_guardrail_information
    async def async_moderation_hook(
        self,
        data: dict,
        user_api_key_dict: UserAPIKeyAuth,
        call_type: Literal[
            "completion",
            "embeddings",
            "image_generation",
            "moderation",
            "audio_transcription",
            "responses",
        ],
    ):
        from litellm.proxy.common_utils.callback_utils import (
            add_guardrail_to_applied_guardrails_header,
        )

        event_type: GuardrailEventHooks = GuardrailEventHooks.during_call
        if self.should_run_guardrail(data=data, event_type=event_type) is not True:
            return

        # old implementation - backwards compatibility
        if (
            await should_proceed_based_on_metadata(
                data=data,
                guardrail_name=GUARDRAIL_NAME,
            )
            is False
        ):
            return

        new_messages: Optional[List[dict]] = None
        if "messages" in data and isinstance(data["messages"], list):
            new_messages = self.transform_messages(messages=data["messages"])

        if new_messages is not None:
            await self.make_aporia_api_request(
                request_data=data,
                new_messages=new_messages,
            )

            add_guardrail_to_applied_guardrails_header(
                request_data=data, guardrail_name=self.guardrail_name
            )
        else:
            verbose_proxy_logger.warning(
                "Aporia AI: not running guardrail. No messages in data"
            )
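

# Minimal local sketch for exercising the during-call hook directly. Not part
# of the module; it assumes APORIO_API_KEY / APORIO_API_BASE are set and that
# a guardrail with this name is configured to run on "during_call" (otherwise
# should_run_guardrail returns early and nothing is sent):
#
#   import asyncio
#   from litellm.proxy._types import UserAPIKeyAuth
#
#   async def _demo():
#       guardrail = AporiaGuardrail(guardrail_name="aporia-pre-guard")
#       await guardrail.async_moderation_hook(
#           data={"messages": [{"role": "user", "content": "This is a test"}]},
#           user_api_key_dict=UserAPIKeyAuth(),
#           call_type="completion",
#       )
#
#   asyncio.run(_demo())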