Spaces:
Paused
Paused
| # +-------------------------------------------------------------+ | |
| # | |
| # Use GuardrailsAI for your LLM calls | |
| # | |
| # +-------------------------------------------------------------+ | |
| # Thank you for using Litellm! - Krrish & Ishaan | |
| import json | |
| from typing import Optional, TypedDict | |
| from fastapi import HTTPException | |
| import litellm | |
| from litellm._logging import verbose_proxy_logger | |
| from litellm.integrations.custom_guardrail import ( | |
| CustomGuardrail, | |
| log_guardrail_information, | |
| ) | |
| from litellm.litellm_core_utils.prompt_templates.common_utils import ( | |
| get_content_from_model_response, | |
| ) | |
| from litellm.proxy._types import UserAPIKeyAuth | |
| from litellm.proxy.common_utils.callback_utils import ( | |
| add_guardrail_to_applied_guardrails_header, | |
| ) | |
| from litellm.types.guardrails import GuardrailEventHooks | |
# Shape of the JSON body returned by the Guardrails AI server's
# `guards/<name>/validate` endpoint, as consumed by GuardrailsAI below:
#   callId           - identifier of the validation call
#   rawLlmOutput     - the LLM output that was submitted
#   validatedOutput  - the output after validation
#   validationPassed - whether the guard accepted the output
# Declared with the functional TypedDict form; keys are total (all required
# for type checkers) and keep the same order as before.
GuardrailsAIResponse = TypedDict(
    "GuardrailsAIResponse",
    {
        "callId": str,
        "rawLlmOutput": str,
        "validatedOutput": str,
        "validationPassed": bool,
    },
)
class GuardrailsAI(CustomGuardrail):
    """
    Post-call guardrail that validates LLM output with a Guardrails AI guard.

    The text content of a successful LLM response is POSTed to the Guardrails
    AI server at ``<api_base>/guards/<guard_name>/validate``. If the server
    reports ``validationPassed`` as False, the response is rejected with an
    HTTP 400 error.
    """

    def __init__(
        self,
        guard_name: str,
        api_base: Optional[str] = None,
        **kwargs,
    ):
        """
        Args:
            guard_name: Name of the guard on the Guardrails AI server
                (configured via ``litellm_params::guard_name``).
            api_base: Base URL of the Guardrails AI server. Falls back to
                ``http://0.0.0.0:8000`` when None or empty.
            **kwargs: Kept as ``self.optional_params`` and forwarded to
                ``CustomGuardrail.__init__``.

        Raises:
            Exception: if ``guard_name`` is None.
        """
        if guard_name is None:
            raise Exception(
                "GuardrailsAIException - Please pass the Guardrails AI guard name via 'litellm_params::guard_name'"
            )
        # store kwargs as optional_params
        self.guardrails_ai_api_base = api_base or "http://0.0.0.0:8000"
        self.guardrails_ai_guard_name = guard_name
        self.optional_params = kwargs
        # This guardrail only implements the post-call (response) hook.
        supported_event_hooks = [GuardrailEventHooks.post_call]
        super().__init__(supported_event_hooks=supported_event_hooks, **kwargs)

    async def make_guardrails_ai_api_request(self, llm_output: str, request_data: dict):
        """
        Send ``llm_output`` to this guard's ``validate`` endpoint.

        Args:
            llm_output: Text produced by the LLM that should be validated.
            request_data: The original request data; any dynamic guardrail
                request-body params found in it are merged into the payload.

        Returns:
            The parsed Guardrails AI response (see GuardrailsAIResponse).

        Raises:
            HTTPException: with status 400 when the server reports
                ``validationPassed`` is False.
        """
        from httpx import URL

        data = {
            "llmOutput": llm_output,
            **self.get_guardrail_dynamic_request_body_params(request_data=request_data),
        }
        _json_data = json.dumps(data)

        # URL.join follows RFC 3986 reference resolution: a base whose path
        # does not end in "/" has its last segment replaced, so
        # "http://host/prefix" + "guards/..." would become
        # "http://host/guards/...". Normalize to exactly one trailing slash so
        # any path prefix in api_base is preserved. For the default
        # "http://0.0.0.0:8000" the resulting URL is unchanged.
        base_url = self.guardrails_ai_api_base.rstrip("/") + "/"
        validate_url = str(
            URL(base_url).join(f"guards/{self.guardrails_ai_guard_name}/validate")
        )

        response = await litellm.module_level_aclient.post(
            url=validate_url,
            data=_json_data,
            headers={
                "Content-Type": "application/json",
            },
        )
        verbose_proxy_logger.debug("guardrails_ai response: %s", response)
        _json_response = GuardrailsAIResponse(**response.json())  # type: ignore

        # NOTE(review): only an explicit False rejects; a body without
        # "validationPassed" is treated as passing. Confirm this fail-open
        # behavior is intended.
        if _json_response.get("validationPassed") is False:
            raise HTTPException(
                status_code=400,
                detail={
                    "error": "Violated guardrail policy",
                    "guardrails_ai_response": _json_response,
                },
            )
        return _json_response

    async def async_post_call_success_hook(
        self,
        data: dict,
        user_api_key_dict: UserAPIKeyAuth,
        response,
    ):
        """
        Runs on response from LLM API call

        It can be used to reject a response

        Skips validation when the guardrail should not run for this request,
        when ``response`` is not a ``litellm.ModelResponse``, or when the
        response has no text content. Otherwise it validates the text. A
        failed validation propagates the HTTPException raised by
        ``make_guardrails_ai_api_request``. A passed validation records this
        guardrail in the applied-guardrails header data.
        """
        event_type: GuardrailEventHooks = GuardrailEventHooks.post_call
        if self.should_run_guardrail(data=data, event_type=event_type) is not True:
            return

        if not isinstance(response, litellm.ModelResponse):
            return

        # May be None/empty (e.g. tool-call-only responses), hence Optional.
        response_str: Optional[str] = get_content_from_model_response(response)
        if response_str:
            await self.make_guardrails_ai_api_request(
                llm_output=response_str, request_data=data
            )

            add_guardrail_to_applied_guardrails_header(
                request_data=data, guardrail_name=self.guardrail_name
            )

        return