Spaces:
Paused
Paused
| import asyncio | |
| import json | |
| import uuid | |
| from datetime import datetime, timezone | |
| from typing import Any, List, Optional | |
| from fastapi import status | |
| import litellm | |
| from litellm._logging import verbose_proxy_logger | |
| from litellm.proxy._types import ( | |
| GenerateKeyRequest, | |
| GenerateKeyResponse, | |
| KeyRequest, | |
| LiteLLM_AuditLogs, | |
| LiteLLM_VerificationToken, | |
| LitellmTableNames, | |
| ProxyErrorTypes, | |
| ProxyException, | |
| RegenerateKeyRequest, | |
| UpdateKeyRequest, | |
| UserAPIKeyAuth, | |
| WebhookEvent, | |
| ) | |
# NOTE: This is the prefix for all virtual keys stored in AWS Secrets Manager
# NOTE(review): not referenced by the code visible in this module, which builds
# secret names from litellm._key_management_settings.prefix_for_stored_virtual_keys
# instead; this may be a default/legacy value - confirm before relying on it.
LITELLM_PREFIX_STORED_VIRTUAL_KEYS = "litellm/"
class KeyManagementEventHooks:
    """
    Post-processing hooks for virtual-key generation, update, rotation and
    deletion.

    Responsibilities: audit logging (when ``litellm.store_audit_logs`` is
    True), mirroring virtual keys into the configured secret manager (when
    ``litellm._key_management_settings.store_virtual_keys`` is True), and
    sending key-created emails.
    """

    # Strong references to fire-and-forget tasks. The event loop keeps only
    # weak references to tasks, so a task created with asyncio.create_task()
    # and never stored can be garbage-collected before it completes.
    _background_tasks: set = set()

    @staticmethod
    async def async_key_generated_hook(
        data: GenerateKeyRequest,
        response: GenerateKeyResponse,
        user_api_key_dict: UserAPIKeyAuth,
        litellm_changed_by: Optional[str] = None,
    ):
        """
        Hook that runs after a successful /key/generate request.

        Handles the following:
        - Storing Audit Logs for key generation
        - Storing the generated key in the secret manager
        - Sending an email with key details (when ``data.send_invite_email``)

        The email step runs last. It raises ValueError when email alerting is
        not configured, and that must not prevent the audit log or the
        secret-manager copy of a key that has already been created.

        Raises:
            ValueError: an invite email was requested but email alerting is
                not configured.
        """
        # Enterprise Feature - Audit Logging. Enable with litellm.store_audit_logs = True
        if litellm.store_audit_logs is True:
            KeyManagementEventHooks._schedule_audit_log(
                user_api_key_dict=user_api_key_dict,
                litellm_changed_by=litellm_changed_by,
                object_id=response.token_id or "",
                action="created",
                updated_values=response.model_dump_json(exclude_none=True),
                before_value=None,
            )

        # store the generated key in the secret manager
        await KeyManagementEventHooks._store_virtual_key_in_secret_manager(
            secret_name=data.key_alias or f"virtual-key-{response.token_id}",
            secret_token=response.key,
        )

        if data.send_invite_email is True:
            await KeyManagementEventHooks._send_key_created_email(
                response.model_dump(exclude_none=True)
            )

    @staticmethod
    async def async_key_updated_hook(
        data: UpdateKeyRequest,
        existing_key_row: Any,
        response: Any,
        user_api_key_dict: UserAPIKeyAuth,
        litellm_changed_by: Optional[str] = None,
    ):
        """
        Post /key/update processing hook.

        Handles the following:
        - Storing Audit Logs for key update (before/after values)

        Args:
            data: the update request; ``data.key`` is used as the audit object id.
            existing_key_row: the key row before the update (must expose ``.json()``).
            response: unused here.
            user_api_key_dict: the caller's auth info, recorded in the audit log.
            litellm_changed_by: optional explicit actor for the audit log.
        """
        # Enterprise Feature - Audit Logging. Enable with litellm.store_audit_logs = True
        if litellm.store_audit_logs is True:
            # NOTE(review): assuming these are pydantic models, `.json()` already
            # returns a JSON string, so json.dumps encodes it a second time (the
            # "created" hook passes a single-encoded string). Kept as-is because
            # the audit-log writer is not visible here and may rely on this
            # shape - confirm before changing.
            _updated_values = json.dumps(data.json(exclude_none=True), default=str)
            _before_value = json.dumps(
                existing_key_row.json(exclude_none=True), default=str
            )
            KeyManagementEventHooks._schedule_audit_log(
                user_api_key_dict=user_api_key_dict,
                litellm_changed_by=litellm_changed_by,
                object_id=data.key,
                action="updated",
                updated_values=_updated_values,
                before_value=_before_value,
            )

    @staticmethod
    async def async_key_rotated_hook(
        data: Optional[RegenerateKeyRequest],
        existing_key_row: Any,
        response: GenerateKeyResponse,
        user_api_key_dict: UserAPIKeyAuth,
        litellm_changed_by: Optional[str] = None,
    ):
        """
        Post key-rotation hook.

        When a regenerate request was supplied and the new key has a token id,
        moves the secret-manager entry from the old key's name
        (alias or ``virtual-key-<old token>``) to the new key's name
        (alias or ``virtual-key-<new token id>``) with the new key value.
        No audit log is written here.
        """
        if data is None or response.token_id is None:
            return

        current_secret_name = (
            existing_key_row.key_alias or f"virtual-key-{existing_key_row.token}"
        )
        await KeyManagementEventHooks._rotate_virtual_key_in_secret_manager(
            current_secret_name=current_secret_name,
            new_secret_name=data.key_alias or f"virtual-key-{response.token_id}",
            new_secret_value=response.key,
        )

    @staticmethod
    async def async_key_deleted_hook(
        data: KeyRequest,
        keys_being_deleted: List[LiteLLM_VerificationToken],
        response: dict,
        user_api_key_dict: UserAPIKeyAuth,
        litellm_changed_by: Optional[str] = None,
    ):
        """
        Post /key/delete processing hook.

        Handles the following:
        - Storing Audit Logs for key deletion (one entry per key in ``data.keys``)
        - Deleting ``keys_being_deleted`` from the secret manager

        Secret-manager cleanup runs even if audit logging raises, so a failed
        audit lookup does not leave orphaned secrets for deleted keys.

        Raises:
            ProxyException: (404) audit logging is enabled and a key in
                ``data.keys`` is not found in the DB.
        """
        try:
            # Enterprise Feature - Audit Logging. Enable with litellm.store_audit_logs = True
            if litellm.store_audit_logs is True and data.keys is not None:
                from litellm.proxy.proxy_server import prisma_client

                # make an audit log for each key deleted
                for key in data.keys:
                    # NOTE(review): this reads the row from the DB at hook time.
                    # If the caller runs this hook after the rows are removed,
                    # the lookup returns None; `keys_being_deleted` holds rows
                    # passed in by the caller and may be a better source - confirm.
                    key_row = await prisma_client.get_data(  # type: ignore
                        token=key, table_name="key", query_type="find_unique"
                    )
                    if key_row is None:
                        raise ProxyException(
                            message=f"Key {key} not found",
                            type=ProxyErrorTypes.bad_request_error,
                            param="key",
                            code=status.HTTP_404_NOT_FOUND,
                        )
                    # NOTE(review): same double JSON encoding as the update hook.
                    _key_row = json.dumps(key_row.json(exclude_none=True), default=str)
                    KeyManagementEventHooks._schedule_audit_log(
                        user_api_key_dict=user_api_key_dict,
                        litellm_changed_by=litellm_changed_by,
                        object_id=key,
                        action="deleted",
                        updated_values="{}",
                        before_value=_key_row,
                    )
        finally:
            # delete the keys from the secret manager
            await KeyManagementEventHooks._delete_virtual_keys_from_secret_manager(
                keys_being_deleted=keys_being_deleted
            )

    @staticmethod
    def _create_background_task(coro: Any) -> "asyncio.Task":
        """Schedule ``coro`` on the running loop, keeping a strong reference until it finishes."""
        task = asyncio.create_task(coro)
        KeyManagementEventHooks._background_tasks.add(task)
        task.add_done_callback(KeyManagementEventHooks._background_tasks.discard)
        return task

    @staticmethod
    def _schedule_audit_log(
        user_api_key_dict: UserAPIKeyAuth,
        litellm_changed_by: Optional[str],
        object_id: str,
        action: str,
        updated_values: Optional[str],
        before_value: Optional[str],
    ) -> None:
        """
        Build a key-table audit log entry and write it in a background task.

        The actor is ``litellm_changed_by``, else the caller's user_id, else
        the proxy admin name.
        """
        from litellm.proxy.management_helpers.audit_logs import (
            create_audit_log_for_update,
        )
        from litellm.proxy.proxy_server import litellm_proxy_admin_name

        KeyManagementEventHooks._create_background_task(
            create_audit_log_for_update(
                request_data=LiteLLM_AuditLogs(
                    id=str(uuid.uuid4()),
                    updated_at=datetime.now(timezone.utc),
                    changed_by=litellm_changed_by
                    or user_api_key_dict.user_id
                    or litellm_proxy_admin_name,
                    changed_by_api_key=user_api_key_dict.api_key,
                    table_name=LitellmTableNames.KEY_TABLE_NAME,
                    object_id=object_id,
                    action=action,
                    updated_values=updated_values,
                    before_value=before_value,
                )
            )
        )

    @staticmethod
    def _get_virtual_key_secret_manager() -> Optional[Any]:
        """
        Return ``litellm.secret_manager_client`` when virtual-key storage is
        enabled and the client is a BaseSecretManager; otherwise None.
        """
        settings = litellm._key_management_settings
        if settings is None or settings.store_virtual_keys is not True:
            return None

        from litellm.secret_managers.base_secret_manager import BaseSecretManager

        client = litellm.secret_manager_client
        if isinstance(client, BaseSecretManager):
            return client
        return None

    @staticmethod
    async def _store_virtual_key_in_secret_manager(secret_name: str, secret_token: str):
        """
        Store a virtual key in the secret manager (no-op unless enabled).

        Args:
            secret_name: Name of the virtual key; the configured prefix is prepended.
            secret_token: Value of the virtual key (example: sk-1234)
        """
        secret_manager = KeyManagementEventHooks._get_virtual_key_secret_manager()
        if secret_manager is None:
            return
        await secret_manager.async_write_secret(
            secret_name=KeyManagementEventHooks._get_secret_name(secret_name),
            secret_value=secret_token,
        )

    @staticmethod
    async def _rotate_virtual_key_in_secret_manager(
        current_secret_name: str, new_secret_name: str, new_secret_value: str
    ):
        """
        Rotate a virtual key in the secret manager (no-op unless enabled).

        Args:
            current_secret_name: Existing secret name (unprefixed).
            new_secret_name: Secret name to rotate to (unprefixed).
            new_secret_value: New virtual key value (example: sk-1234)
        """
        secret_manager = KeyManagementEventHooks._get_virtual_key_secret_manager()
        if secret_manager is None:
            return
        await secret_manager.async_rotate_secret(
            current_secret_name=KeyManagementEventHooks._get_secret_name(
                current_secret_name
            ),
            new_secret_name=KeyManagementEventHooks._get_secret_name(new_secret_name),
            new_secret_value=new_secret_value,
        )

    @staticmethod
    def _get_secret_name(secret_name: str) -> str:
        """
        Prepend the configured ``prefix_for_stored_virtual_keys``, inserting a
        "/" separator unless the prefix already ends with one.

        Assumes ``litellm._key_management_settings`` is not None (callers only
        reach this after checking it).
        """
        prefix = litellm._key_management_settings.prefix_for_stored_virtual_keys
        separator = "" if prefix.endswith("/") else "/"
        return f"{prefix}{separator}{secret_name}"

    @staticmethod
    async def _delete_virtual_keys_from_secret_manager(
        keys_being_deleted: List[LiteLLM_VerificationToken],
    ):
        """
        Delete virtual keys from the secret manager (no-op unless enabled).

        Each secret name mirrors the naming used when storing/rotating:
        the key alias if set, otherwise ``virtual-key-<token>``.

        Args:
            keys_being_deleted: List of keys being deleted, this is passed down
                from the /key/delete operation
        """
        secret_manager = KeyManagementEventHooks._get_virtual_key_secret_manager()
        if secret_manager is None:
            return

        for key in keys_being_deleted:
            if key.key_alias:
                secret_name = key.key_alias
            elif key.token is not None:
                # same fallback async_key_rotated_hook uses for an alias-less key
                secret_name = f"virtual-key-{key.token}"
            else:
                verbose_proxy_logger.warning(
                    "KeyManagementEventHooks._delete_virtual_key_from_secret_manager: Key alias and token not found for key. Skipping deletion from secret manager."
                )
                continue
            await secret_manager.async_delete_secret(
                secret_name=KeyManagementEventHooks._get_secret_name(secret_name)
            )

    @staticmethod
    async def _send_key_created_email(response: dict):
        """
        Schedule a "key created" email for the key described by ``response``.

        Args:
            response: the generated-key response as a dict (model_dump output).

        Raises:
            ValueError: "email" is not listed in ``general_settings["alerting"]``.
        """
        from litellm.proxy.proxy_server import general_settings, proxy_logging_obj

        # `or []` also covers an explicit `alerting: null` in the config
        if "email" not in (general_settings.get("alerting") or []):
            raise ValueError(
                "Email alerting not setup on config.yaml. Please set `alerting=['email']. \nDocs: https://docs.litellm.ai/docs/proxy/email`"
            )
        event = WebhookEvent(
            event="key_created",
            event_group="key",
            event_message="API Key Created",
            token=response.get("token", ""),
            spend=response.get("spend", 0.0),
            max_budget=response.get("max_budget", 0.0),
            user_id=response.get("user_id", None),
            team_id=response.get("team_id", "Default Team"),
            key_alias=response.get("key_alias", None),
        )

        # If user configured email alerting - send an Email letting their end-user know the key was created
        KeyManagementEventHooks._create_background_task(
            proxy_logging_obj.slack_alerting_instance.send_key_created_or_user_invited_email(
                webhook_event=event,
            )
        )