Spaces:
Paused
Paused
| import traceback | |
| from typing import Optional | |
| from fastapi import HTTPException | |
| import litellm | |
| from litellm._logging import verbose_proxy_logger | |
| from litellm.caching.caching import DualCache | |
| from litellm.integrations.custom_logger import CustomLogger | |
| from litellm.proxy._types import UserAPIKeyAuth | |
class _PROXY_AzureContentSafety(
    CustomLogger
):  # https://docs.litellm.ai/docs/observability/custom_callback#callback-class
    """Proxy guardrail that screens text with Azure AI Content Safety.

    Prompt messages are screened in ``async_pre_call_hook`` and completion
    choices in ``async_post_call_success_hook``. When any category's reported
    severity is at or above its configured threshold, an ``HTTPException``
    (status 400) is raised.
    """

    def __init__(self, endpoint, api_key, thresholds=None):
        """Create the async Azure Content Safety client.

        Args:
            endpoint: Azure Content Safety endpoint passed to ``ContentSafetyClient``.
            api_key: key wrapped in an ``AzureKeyCredential``.
            thresholds: optional mapping of ``TextCategory`` -> severity at or
                above which content is blocked. Categories that are missing
                default to 4. The mapping is copied, never mutated.

        Raises:
            Exception: if the ``azure-ai-contentsafety`` SDK cannot be imported.
        """
        try:
            from azure.ai.contentsafety.aio import ContentSafetyClient
            from azure.ai.contentsafety.models import (
                AnalyzeTextOptions,
                AnalyzeTextOutputType,
                TextCategory,
            )
            from azure.core.credentials import AzureKeyCredential
            from azure.core.exceptions import HttpResponseError
        except ImportError as e:
            raise Exception(
                f"\033[91mAzure Content-Safety not installed, try running 'pip install azure-ai-contentsafety' to fix this error: {e}\n{traceback.format_exc()}\033[0m"
            ) from e
        self.endpoint = endpoint
        self.api_key = api_key
        # The SDK types are kept on the instance because they were imported
        # lazily above and are not visible at module scope.
        self.text_category = TextCategory
        self.analyze_text_options = AnalyzeTextOptions
        self.analyze_text_output_type = AnalyzeTextOutputType
        self.azure_http_error = HttpResponseError
        self.thresholds = self._configure_thresholds(thresholds)
        self.client = ContentSafetyClient(
            self.endpoint, AzureKeyCredential(self.api_key)
        )

    def _configure_thresholds(self, thresholds=None):
        """Return per-category block thresholds, with defaults filling any gaps.

        Args:
            thresholds: optional user mapping of ``TextCategory`` -> threshold.

        Returns:
            A new dict holding a threshold for every default category. Values
            supplied by the user override the defaults. The caller's mapping
            is left unmodified.
        """
        default_thresholds = {
            self.text_category.HATE: 4,
            self.text_category.SELF_HARM: 4,
            self.text_category.SEXUAL: 4,
            self.text_category.VIOLENCE: 4,
        }
        if thresholds is None:
            return default_thresholds
        # Merge into a fresh dict rather than writing defaults into the
        # caller's mapping, which may be shared configuration.
        return {**default_thresholds, **thresholds}

    def _compute_result(self, response):
        """Summarise an Azure analysis response against the configured thresholds.

        Returns:
            dict mapping each ``TextCategory`` present in
            ``response.categories_analysis`` to
            ``{"filtered": severity >= threshold, "severity": severity}``.
        """
        category_severity = {
            item.category: item.severity for item in response.categories_analysis
        }
        result = {}
        for category in self.text_category:
            severity = category_severity.get(category)
            if severity is None:
                continue
            result[category] = {
                "filtered": severity >= self.thresholds[category],
                "severity": severity,
            }
        return result

    async def test_violation(self, content: str, source: Optional[str] = None):
        """Analyse ``content`` and raise if any category reaches its threshold.

        Args:
            content: text to screen.
            source: label (the hooks pass "input" or "output") that is echoed
                in the error detail.

        Raises:
            HTTPException: status 400 for the first category marked as filtered.
            HttpResponseError: re-raised unchanged when the Azure call fails.
        """
        verbose_proxy_logger.debug("Testing Azure Content-Safety for: %s", content)
        request = self.analyze_text_options(
            text=content,
            # The thresholds are compared against this eight-level severity scale.
            output_type=self.analyze_text_output_type.EIGHT_SEVERITY_LEVELS,
        )
        try:
            response = await self.client.analyze_text(request)
        except self.azure_http_error:
            verbose_proxy_logger.debug(
                "Error in Azure Content-Safety: %s", traceback.format_exc()
            )
            raise

        result = self._compute_result(response)
        verbose_proxy_logger.debug("Azure Content-Safety Result: %s", result)

        for category, value in result.items():
            if value["filtered"]:
                raise HTTPException(
                    status_code=400,
                    detail={
                        "error": "Violated content safety policy",
                        "source": source,
                        "category": category,
                        "severity": value["severity"],
                    },
                )

    @staticmethod
    def _iter_message_texts(message):
        """Yield each non-empty text fragment of a chat message.

        Handles plain-string ``content`` and list-form (multimodal) content,
        where only ``{"type": "text", "text": ...}`` parts carry text.
        Empty strings are skipped because there is nothing to screen.
        """
        if "content" not in message:
            return
        content = message["content"]
        if isinstance(content, str):
            if content:
                yield content
        elif isinstance(content, list):
            for part in content:
                if isinstance(part, dict) and part.get("type") == "text":
                    text = part.get("text")
                    if isinstance(text, str) and text:
                        yield text

    async def async_pre_call_hook(
        self,
        user_api_key_dict: UserAPIKeyAuth,
        cache: DualCache,
        data: dict,
        call_type: str,  # "completion", "embeddings", "image_generation", "moderation"
    ):
        """Screen the text of every prompt message before a completion call.

        A policy violation propagates as ``HTTPException``. Any other error is
        logged and swallowed, so the request proceeds.
        """
        verbose_proxy_logger.debug("Inside Azure Content-Safety Pre-Call Hook")
        try:
            if call_type == "completion" and "messages" in data:
                for m in data["messages"]:
                    for text in self._iter_message_texts(m):
                        await self.test_violation(content=text, source="input")
        except HTTPException:
            raise
        except Exception as e:
            # NOTE(review): this fails open on Azure/network errors; confirm
            # that is the intended policy for input screening.
            verbose_proxy_logger.error(
                "litellm.proxy.hooks.azure_content_safety.py::async_pre_call_hook(): Exception occured - {}".format(
                    str(e)
                )
            )
            verbose_proxy_logger.debug(traceback.format_exc())

    async def async_post_call_success_hook(
        self,
        data: dict,
        user_api_key_dict: UserAPIKeyAuth,
        response,
    ):
        """Screen the message text of every choice in a completed model response.

        Every ``Choices`` entry is checked, so all completions of an ``n > 1``
        request are screened. An empty ``choices`` list is tolerated, and
        choices without text content are skipped.

        Raises:
            HTTPException: status 400 when a choice violates the policy.
        """
        verbose_proxy_logger.debug("Inside Azure Content-Safety Post-Call Hook")
        if not isinstance(response, litellm.ModelResponse):
            return
        for choice in response.choices or []:
            if not isinstance(choice, litellm.utils.Choices):
                continue
            content = choice.message.content
            if content:
                await self.test_violation(content=content, source="output")

    # NOTE: this class overrides no streaming hook, so streamed output is not
    # screened here.