Spaces:
Paused
Paused
# +-----------------------------------------------+
# |                                               |
# |               PII Masking                     |
# |         with Microsoft Presidio               |
# |   https://github.com/BerriAI/litellm/issues/  |
# +-----------------------------------------------+
#
#  Tell us how we can improve! - Krrish & Ishaan
| import asyncio | |
| import json | |
| import uuid | |
| from typing import Any, List, Optional, Tuple, Union | |
| import aiohttp | |
| from pydantic import BaseModel | |
| import litellm # noqa: E401 | |
| from litellm import get_secret | |
| from litellm._logging import verbose_proxy_logger | |
| from litellm.caching.caching import DualCache | |
| from litellm.integrations.custom_guardrail import ( | |
| CustomGuardrail, | |
| log_guardrail_information, | |
| ) | |
| from litellm.proxy._types import UserAPIKeyAuth | |
| from litellm.types.guardrails import GuardrailEventHooks | |
| from litellm.utils import ( | |
| EmbeddingResponse, | |
| ImageResponse, | |
| ModelResponse, | |
| StreamingChoices, | |
| ) | |
class PresidioPerRequestConfig(BaseModel):
    """
    Presidio settings that can be overridden per request (e.g. via API-key config).

    Attributes:
        language: language code to send to the Presidio analyzer. ``None`` means
            "not overridden" - the caller keeps its own default.
    """

    language: Union[str, None] = None
class _OPTIONAL_PresidioPIIMasking(CustomGuardrail):
    """
    Guardrail that masks PII in chat messages using Microsoft Presidio.

    Each message's text is sent to the Presidio analyzer (``/analyze``), and the
    analyzer results are sent with the same text to the anonymizer
    (``/anonymize``). The anonymized text replaces the message content. When
    ``output_parse_pii`` is enabled, tokens produced by the ``replace`` operator
    are recorded in ``self.pii_tokens`` (token -> original value).
    ``async_post_call_success_hook`` uses that map to swap the tokens back in the
    LLM response.

    NOTE(review): ``self.pii_tokens`` lives on the guardrail instance, and nothing
    in this class ever removes entries from it. It grows for the lifetime of the
    instance, and tokens recorded for one request are also applied when
    un-masking other requests' responses. If one instance serves several users,
    this can leak original values across requests; consider keeping the mapping
    per request.
    """

    user_api_key_cache = None
    # Parsed JSON from the `presidio_ad_hoc_recognizers` file; forwarded to /analyze.
    ad_hoc_recognizers = None

    def __init__(
        self,
        mock_testing: bool = False,
        mock_redacted_text: Optional[dict] = None,
        presidio_analyzer_api_base: Optional[str] = None,
        presidio_anonymizer_api_base: Optional[str] = None,
        output_parse_pii: Optional[bool] = False,
        presidio_ad_hoc_recognizers: Optional[str] = None,
        logging_only: Optional[bool] = None,
        **kwargs,
    ):
        """
        Args:
            mock_testing: skip loading ad-hoc recognizers and validating the
                environment (for tests only).
            mock_redacted_text: canned anonymizer response, used instead of
                calling Presidio.
            presidio_analyzer_api_base: analyzer base URL. Falls back to the
                ``PRESIDIO_ANALYZER_API_BASE`` secret.
            presidio_anonymizer_api_base: anonymizer base URL. Falls back to the
                ``PRESIDIO_ANONYMIZER_API_BASE`` secret.
            output_parse_pii: record replacement tokens so they can be restored
                in LLM responses.
            presidio_ad_hoc_recognizers: path to a JSON file of ad-hoc recognizers.
            logging_only: when True, register with the ``logging_only`` event hook.
            **kwargs: forwarded to ``CustomGuardrail``.

        Raises:
            Exception: if the ad-hoc recognizers file cannot be read or parsed,
                or if a Presidio base URL is missing.
        """
        if logging_only is True:
            self.logging_only = True
            kwargs["event_hook"] = GuardrailEventHooks.logging_only
        super().__init__(**kwargs)
        # Replacement token -> original text. Only populated for the Presidio
        # `replace` operator when output parsing is enabled.
        self.pii_tokens: dict = {}
        self.mock_redacted_text = mock_redacted_text
        self.output_parse_pii = output_parse_pii or False
        if mock_testing is True:  # for testing purposes only
            return

        ad_hoc_recognizers = presidio_ad_hoc_recognizers
        if ad_hoc_recognizers is not None:
            try:
                # JSON files are UTF-8 by specification (RFC 8259).
                with open(ad_hoc_recognizers, "r", encoding="utf-8") as file:
                    self.ad_hoc_recognizers = json.load(file)
            except FileNotFoundError as e:
                raise Exception(
                    f"File not found. file_path={ad_hoc_recognizers}"
                ) from e
            except json.JSONDecodeError as e:
                raise Exception(
                    f"Error decoding JSON file: {str(e)}, file_path={ad_hoc_recognizers}"
                ) from e
            except Exception as e:
                raise Exception(
                    f"An error occurred: {str(e)}, file_path={ad_hoc_recognizers}"
                ) from e

        self.validate_environment(
            presidio_analyzer_api_base=presidio_analyzer_api_base,
            presidio_anonymizer_api_base=presidio_anonymizer_api_base,
        )

    def validate_environment(
        self,
        presidio_analyzer_api_base: Optional[str] = None,
        presidio_anonymizer_api_base: Optional[str] = None,
    ):
        """
        Resolve and normalize the analyzer and anonymizer base URLs.

        Explicit arguments win over the ``PRESIDIO_ANALYZER_API_BASE`` and
        ``PRESIDIO_ANONYMIZER_API_BASE`` secrets. Each URL gets a trailing ``/``.
        It also gets an ``http://`` scheme if none is given.

        Raises:
            Exception: if either base URL cannot be resolved. The analyzer URL
                is checked first.
        """
        self.presidio_analyzer_api_base: Optional[
            str
        ] = presidio_analyzer_api_base or get_secret(
            "PRESIDIO_ANALYZER_API_BASE", None
        )  # type: ignore
        self.presidio_anonymizer_api_base: Optional[
            str
        ] = presidio_anonymizer_api_base or litellm.get_secret(
            "PRESIDIO_ANONYMIZER_API_BASE", None
        )  # type: ignore
        self.presidio_analyzer_api_base = self._normalize_api_base(
            self.presidio_analyzer_api_base, "PRESIDIO_ANALYZER_API_BASE"
        )
        self.presidio_anonymizer_api_base = self._normalize_api_base(
            self.presidio_anonymizer_api_base, "PRESIDIO_ANONYMIZER_API_BASE"
        )

    @staticmethod
    def _normalize_api_base(api_base: Optional[str], env_var_name: str) -> str:
        """
        Validate a Presidio base URL and normalize it.

        Ensures the URL has a trailing ``/`` and an ``http(s)://`` scheme.

        Raises:
            Exception: if ``api_base`` is None.
        """
        if api_base is None:
            raise Exception(f"Missing `{env_var_name}` from environment")
        if not api_base.endswith("/"):
            api_base += "/"
        if not api_base.startswith(("http://", "https://")):
            # add http:// if unset, assume communicating over private network - e.g. render
            api_base = "http://" + api_base
        return api_base

    async def check_pii(
        self,
        text: str,
        output_parse_pii: bool,
        presidio_config: Optional[PresidioPerRequestConfig],
        request_data: dict,
    ) -> str:
        """
        Mask PII in ``text`` via Presidio and return the masked text.

        Args:
            text: text to mask.
            output_parse_pii: record `replace`-operator tokens in
                ``self.pii_tokens`` so they can be restored in the LLM response.
            presidio_config: optional per-request settings, such as the
                analyzer language.
            request_data: request body, used to pick up dynamic guardrail params.

        Returns:
            The anonymized text. With ``output_parse_pii`` set, any token that
            had to be made unique appears in its unique form.

        Raises:
            Exception: if the anonymizer response is not a dict with a ``text`` key.

        [TODO] make this more performant for high-throughput scenario
        """
        if self.mock_redacted_text is not None:
            anonymizer_response: Any = self.mock_redacted_text
        else:
            anonymizer_response = await self._call_presidio(
                text=text,
                presidio_config=presidio_config,
                request_data=request_data,
            )
        return self._apply_anonymizer_response(
            text=text,
            anonymizer_response=anonymizer_response,
            output_parse_pii=output_parse_pii,
        )

    async def _call_presidio(
        self,
        text: str,
        presidio_config: Optional[PresidioPerRequestConfig],
        request_data: dict,
    ) -> Any:
        """
        Call /analyze, then /anonymize, and return the anonymizer's JSON body.

        The analyzer language defaults to "en" unless ``presidio_config``
        overrides it.
        """
        analyze_url = f"{self.presidio_analyzer_api_base}analyze"
        analyze_payload: dict = {"text": text, "language": "en"}
        if presidio_config and presidio_config.language:
            analyze_payload["language"] = presidio_config.language
        if self.ad_hoc_recognizers is not None:
            analyze_payload["ad_hoc_recognizers"] = self.ad_hoc_recognizers
        analyze_payload.update(
            self.get_guardrail_dynamic_request_body_params(request_data=request_data)
        )

        # TODO: a new session per call; reuse one for high-throughput use.
        async with aiohttp.ClientSession() as session:
            verbose_proxy_logger.debug(
                "Making request to: %s with payload: %s",
                analyze_url,
                analyze_payload,
            )
            async with session.post(analyze_url, json=analyze_payload) as response:
                analyze_results = await response.json()

            anonymize_url = f"{self.presidio_anonymizer_api_base}anonymize"
            verbose_proxy_logger.debug("Making request to: %s", anonymize_url)
            anonymize_payload = {
                "text": text,
                "analyzer_results": analyze_results,
            }
            async with session.post(anonymize_url, json=anonymize_payload) as response:
                return await response.json()

    def _apply_anonymizer_response(
        self, text: str, anonymizer_response: Any, output_parse_pii: bool
    ) -> str:
        """
        Turn a Presidio anonymizer response into the masked text to send on.

        With ``output_parse_pii`` set, every `replace`-operator token is
        recorded in ``self.pii_tokens`` as token -> original value. If the token
        is already recorded for a *different* value, a uuid is appended to make
        it unique, and that unique token is spliced into the returned text.
        This keeps the mapping reversible.

        Raises:
            Exception: if ``anonymizer_response`` is not a dict with a ``text`` key.
        """
        if not isinstance(anonymizer_response, dict) or "text" not in anonymizer_response:
            raise Exception(f"Invalid anonymizer response: {anonymizer_response}")
        verbose_proxy_logger.debug("redacted_text: %s", anonymizer_response)
        masked_text: str = anonymizer_response["text"]
        if output_parse_pii is not True:
            return masked_text

        items = sorted(
            anonymizer_response.get("items") or [], key=lambda item: item["start"]
        )
        original_values = self._recover_original_values(text, masked_text, items)

        # (start, end, unique_token) splices for masked_text, in ascending order.
        renamed: List[Tuple[int, int, str]] = []
        for index, item in enumerate(items):
            if item.get("operator") != "replace":
                continue
            token = item["text"]  # replacement token
            if original_values is not None:
                original = original_values[index]
            else:
                # Alignment failed. Fall back to slicing with the item offsets,
                # which is only exact when replacements keep the text length.
                original = text[item["start"] : item["end"]]
            if token in self.pii_tokens and self.pii_tokens[token] != original:
                # Token already stands for another value: make it unique so the
                # LLM-response output parsing can swap it back unambiguously.
                unique_token = token + str(uuid.uuid4())
                renamed.append((item["start"], item["end"], unique_token))
                token = unique_token
            self.pii_tokens[token] = original

        # Splice right-to-left so the earlier offsets stay valid.
        for start, end, unique_token in reversed(renamed):
            masked_text = masked_text[:start] + unique_token + masked_text[end:]
        return masked_text

    @staticmethod
    def _recover_original_values(
        original_text: str, anonymized_text: str, sorted_items: List[dict]
    ) -> Optional[List[str]]:
        """
        Return the original substring replaced by each anonymizer item.

        Presidio reports each item's ``start``/``end`` as offsets into the
        *anonymized* text. Those offsets cannot slice ``original_text`` once a
        replacement changes the length. The text around the items is copied
        unchanged, so this finds those unchanged gaps in ``original_text`` and
        takes what lies between them.

        Args:
            original_text: the text sent to Presidio.
            anonymized_text: the ``text`` field of the anonymizer response.
            sorted_items: the anonymizer ``items``, sorted by ``start``.

        Returns:
            One original value per item, in the same order. Returns ``None`` if
            the texts cannot be aligned, e.g. the response does not belong to
            ``original_text``, or two items touch with no text between them.

        NOTE(review): interior gaps are matched left-most. If a gap string also
        occurs inside the preceding PII value (e.g. a single space between two
        entities), the values can be split in the wrong place. The analyzer
        results, which hold offsets into the original text, could disambiguate.
        """
        if not sorted_items:
            return []
        # Unchanged text before, between and after the items (len(items) + 1 gaps).
        gaps: List[str] = []
        cursor = 0
        for item in sorted_items:
            gaps.append(anonymized_text[cursor : item["start"]])
            cursor = item["end"]
        gaps.append(anonymized_text[cursor:])

        head, tail = gaps[0], gaps[-1]
        tail_start = len(original_text) - len(tail)
        if (
            len(head) > tail_start
            or not original_text.startswith(head)
            or not original_text.endswith(tail)
        ):
            return None

        values: List[str] = []
        position = len(head)
        for gap in gaps[1:-1]:
            if not gap:
                return None  # adjacent items: the split point is unknowable
            # Search from position + 1 so every recovered value is non-empty.
            found = original_text.find(gap, position + 1, tail_start)
            if found == -1:
                return None
            values.append(original_text[position:found])
            position = found + len(gap)
        values.append(original_text[position:tail_start])
        return values

    async def _mask_string_messages(
        self,
        messages: List,
        output_parse_pii: bool,
        presidio_config: Optional[PresidioPerRequestConfig],
        request_data: dict,
    ) -> None:
        """
        Mask, in place, every message whose ``content`` is a string.

        One Presidio call is made per message, concurrently, because Presidio
        has context window limits. Results are written back to the messages they
        came from. Messages with other content (None, missing, lists of parts)
        are left untouched.
        """
        indices = [
            index
            for index, message in enumerate(messages)
            if isinstance(message.get("content"), str)
        ]
        masked = await asyncio.gather(
            *(
                self.check_pii(
                    text=messages[index]["content"],
                    output_parse_pii=output_parse_pii,
                    presidio_config=presidio_config,
                    request_data=request_data,
                )
                for index in indices
            )
        )
        for index, masked_content in zip(indices, masked):
            messages[index]["content"] = masked_content  # replace with redacted string

    async def async_pre_call_hook(
        self,
        user_api_key_dict: UserAPIKeyAuth,
        cache: DualCache,
        data: dict,
        call_type: str,
    ):
        """
        Mask PII in /chat/completions messages before the LLM call.

        For ``call_type == "completion"``, each string message content is sent
        through Presidio (/analyze, then /anonymize) in parallel and replaced
        with the redacted text. Other call types pass through unchanged.

        NOTE(review): per-request opt-out and key-permission checks
        ('allow_pii_controls') are not implemented here.
        """
        content_safety = data.get("content_safety", None)
        verbose_proxy_logger.debug("content_safety: %s", content_safety)
        presidio_config = self.get_presidio_settings_from_request_data(data)
        if call_type == "completion":  # /chat/completions requests
            messages = data["messages"]
            await self._mask_string_messages(
                messages=messages,
                output_parse_pii=self.output_parse_pii,
                presidio_config=presidio_config,
                request_data=data,
            )
            verbose_proxy_logger.info(
                "Presidio PII Masking: Redacted pii message: %s", data["messages"]
            )
            data["messages"] = messages
        return data

    def logging_hook(
        self, kwargs: dict, result: Any, call_type: str
    ) -> Tuple[dict, Any]:
        """
        Synchronous wrapper around ``async_logging_hook``.

        The coroutine runs on a fresh event loop. If this thread already has a
        running loop, it runs on a worker thread, because ``run_until_complete``
        cannot nest inside a running loop.
        """
        from concurrent.futures import ThreadPoolExecutor

        def run_in_new_loop():
            """Run the coroutine in a new event loop within this thread."""
            new_loop = asyncio.new_event_loop()
            try:
                asyncio.set_event_loop(new_loop)
                return new_loop.run_until_complete(
                    self.async_logging_hook(
                        kwargs=kwargs, result=result, call_type=call_type
                    )
                )
            finally:
                new_loop.close()
                asyncio.set_event_loop(None)

        # Only the loop probe belongs in the try. A RuntimeError raised by the
        # hook itself must propagate, not be mistaken for "no running loop".
        try:
            asyncio.get_running_loop()
        except RuntimeError:
            # No running event loop, we can safely run in this thread
            return run_in_new_loop()

        # Already inside an event loop: run in a separate thread to avoid
        # nested event loop issues.
        with ThreadPoolExecutor(max_workers=1) as executor:
            return executor.submit(run_in_new_loop).result()

    async def async_logging_hook(
        self, kwargs: dict, result: Any, call_type: str
    ) -> Tuple[dict, Any]:
        """
        Masks the input before logging to langfuse, datadog, etc.

        Only "completion"/"acompletion" calls are handled. String message
        contents in ``kwargs["messages"]`` are replaced in place with their
        redacted form. Tokens are not recorded for output parsing.
        """
        if call_type == "completion" or call_type == "acompletion":
            messages: Optional[List] = kwargs.get("messages", None)
            if messages is None:
                return kwargs, result
            presidio_config = self.get_presidio_settings_from_request_data(kwargs)
            await self._mask_string_messages(
                messages=messages,
                output_parse_pii=False,
                presidio_config=presidio_config,
                request_data=kwargs,
            )
            verbose_proxy_logger.info(
                "Presidio PII Masking: Redacted pii message: %s", messages
            )
            kwargs["messages"] = messages
        return kwargs, result

    async def async_post_call_success_hook(  # type: ignore
        self,
        data: dict,
        user_api_key_dict: UserAPIKeyAuth,
        response: Union[ModelResponse, EmbeddingResponse, ImageResponse],
    ):
        """
        Output parse the response object to replace the masked tokens with user sent values

        This only acts when output parsing is enabled on this guardrail or
        globally (``litellm.output_parse_pii``). It only handles non-streaming
        chat responses whose first choice has string content.
        """
        verbose_proxy_logger.debug(
            "PII Masking Args: self.output_parse_pii=%s; type of response=%s",
            self.output_parse_pii,
            type(response),
        )
        if self.output_parse_pii is False and litellm.output_parse_pii is False:
            return response

        if (
            isinstance(response, ModelResponse)
            and response.choices
            and not isinstance(response.choices[0], StreamingChoices)
        ):  # /chat/completions requests
            message = response.choices[0].message
            if isinstance(message.content, str):
                verbose_proxy_logger.debug(
                    "self.pii_tokens: %s; initial response: %s",
                    self.pii_tokens,
                    message.content,
                )
                content = message.content
                # Longest tokens first: uuid-suffixed tokens begin with the
                # plain token, so replacing the plain one first corrupts them.
                for token in sorted(self.pii_tokens, key=len, reverse=True):
                    content = content.replace(token, self.pii_tokens[token])
                message.content = content
        return response

    def get_presidio_settings_from_request_data(
        self, data: dict
    ) -> Optional[PresidioPerRequestConfig]:
        """
        Build per-request Presidio settings from ``data["metadata"]["guardrail_config"]``.

        Returns None when there is no metadata (including ``metadata: None``)
        or no guardrail config.
        """
        metadata = data.get("metadata") or {}
        guardrail_config = metadata.get("guardrail_config")
        if guardrail_config:
            return PresidioPerRequestConfig(**guardrail_config)
        return None

    def print_verbose(self, print_statement):
        """Log at debug level and, when ``litellm.set_verbose`` is on, also print (best effort)."""
        try:
            verbose_proxy_logger.debug(print_statement)
            if litellm.set_verbose:
                print(print_statement)  # noqa
        except Exception:
            pass