import asyncio
import json
import os
import uuid
from datetime import datetime, timedelta, timezone
from typing import TYPE_CHECKING, Any, Dict, List, Optional
from urllib.parse import quote

from litellm._logging import verbose_logger
from litellm.integrations.additional_logging_utils import AdditionalLoggingUtils
from litellm.integrations.gcs_bucket.gcs_bucket_base import GCSBucketBase
from litellm.proxy._types import CommonProxyErrors
from litellm.types.integrations.base_health_check import IntegrationHealthCheckStatus
from litellm.types.integrations.gcs_bucket import *
from litellm.types.utils import StandardLoggingPayload

if TYPE_CHECKING:
    from litellm.llms.vertex_ai.vertex_llm_base import VertexBase
else:
    VertexBase = Any


class GCSBucketLogger(GCSBucketBase, AdditionalLoggingUtils):
    def __init__(self, bucket_name: Optional[str] = None) -> None:
        from litellm.proxy.proxy_server import premium_user

        super().__init__(bucket_name=bucket_name)

        # Init Batch logging settings
        self.log_queue: List[GCSLogQueueItem] = []
        self.batch_size = int(os.getenv("GCS_BATCH_SIZE", GCS_DEFAULT_BATCH_SIZE))
        self.flush_interval = int(
            os.getenv("GCS_FLUSH_INTERVAL", GCS_DEFAULT_FLUSH_INTERVAL_SECONDS)
        )
        asyncio.create_task(self.periodic_flush())
        self.flush_lock = asyncio.Lock()
        super().__init__(
            flush_lock=self.flush_lock,
            batch_size=self.batch_size,
            flush_interval=self.flush_interval,
        )
        AdditionalLoggingUtils.__init__(self)
        if premium_user is not True:
            raise ValueError(
                f"GCS Bucket logging is a premium feature. Please upgrade to use it. {CommonProxyErrors.not_premium_user.value}"
            )
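
    # Note: batch behavior is controlled by two environment variables read in
    # __init__ above -- GCS_BATCH_SIZE (max queued logs before a flush) and
    # GCS_FLUSH_INTERVAL (seconds between periodic_flush() runs). Both fall
    # back to the GCS_DEFAULT_* constants brought in by the wildcard import
    # from litellm.types.integrations.gcs_bucket.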

    #### ASYNC ####
    async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
        from litellm.proxy.proxy_server import premium_user

        if premium_user is not True:
            raise ValueError(
                f"GCS Bucket logging is a premium feature. Please upgrade to use it. {CommonProxyErrors.not_premium_user.value}"
            )
        try:
            verbose_logger.debug(
                "GCS Logger: async_log_success_event logging kwargs: %s, response_obj: %s",
                kwargs,
                response_obj,
            )
            logging_payload: Optional[StandardLoggingPayload] = kwargs.get(
                "standard_logging_object", None
            )
            if logging_payload is None:
                raise ValueError("standard_logging_object not found in kwargs")

            # Add to logging queue - this will be flushed periodically
            self.log_queue.append(
                GCSLogQueueItem(
                    payload=logging_payload, kwargs=kwargs, response_obj=response_obj
                )
            )
        except Exception as e:
            verbose_logger.exception(f"GCS Bucket logging error: {str(e)}")

    async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
        try:
            verbose_logger.debug(
                "GCS Logger: async_log_failure_event logging kwargs: %s, response_obj: %s",
                kwargs,
                response_obj,
            )
            logging_payload: Optional[StandardLoggingPayload] = kwargs.get(
                "standard_logging_object", None
            )
            if logging_payload is None:
                raise ValueError("standard_logging_object not found in kwargs")

            # Add to logging queue - this will be flushed periodically
            self.log_queue.append(
                GCSLogQueueItem(
                    payload=logging_payload, kwargs=kwargs, response_obj=response_obj
                )
            )
        except Exception as e:
            verbose_logger.exception(f"GCS Bucket logging error: {str(e)}")

    async def async_send_batch(self):
        """
        Process queued logs in batch - sends logs to the GCS bucket.

        GCS does not expose a batch upload endpoint, so instead we:
        - collect the logs to flush every `GCS_FLUSH_INTERVAL` seconds
        - during async_send_batch, make one POST request per log to the GCS bucket
        """
        if not self.log_queue:
            return

        for log_item in self.log_queue:
            logging_payload = log_item["payload"]
            kwargs = log_item["kwargs"]
            response_obj = log_item.get("response_obj", None) or {}

            gcs_logging_config: GCSLoggingConfig = await self.get_gcs_logging_config(
                kwargs
            )
            headers = await self.construct_request_headers(
                vertex_instance=gcs_logging_config["vertex_instance"],
                service_account_json=gcs_logging_config["path_service_account"],
            )
            bucket_name = gcs_logging_config["bucket_name"]
            object_name = self._get_object_name(kwargs, logging_payload, response_obj)

            try:
                await self._log_json_data_on_gcs(
                    headers=headers,
                    bucket_name=bucket_name,
                    object_name=object_name,
                    logging_payload=logging_payload,
                )
            except Exception as e:
                # don't let one log item fail the entire batch
                verbose_logger.exception(
                    f"GCS Bucket error logging payload to GCS bucket: {str(e)}"
                )

        # Clear the queue after processing
        self.log_queue.clear()

    def _get_object_name(
        self, kwargs: Dict, logging_payload: StandardLoggingPayload, response_obj: Any
    ) -> str:
        """
        Get the object name to use for the current payload
        """
        current_date = self._get_object_date_from_datetime(datetime.now(timezone.utc))
        if logging_payload.get("error_str", None) is not None:
            object_name = self._generate_failure_object_name(
                request_date_str=current_date,
            )
        else:
            object_name = self._generate_success_object_name(
                request_date_str=current_date,
                response_id=response_obj.get("id", ""),
            )

        # used for testing
        _litellm_params = kwargs.get("litellm_params", None) or {}
        _metadata = _litellm_params.get("metadata", None) or {}
        if "gcs_log_id" in _metadata:
            object_name = _metadata["gcs_log_id"]

        return object_name

    async def get_request_response_payload(
        self,
        request_id: str,
        start_time_utc: Optional[datetime],
        end_time_utc: Optional[datetime],
    ) -> Optional[dict]:
        """
        Get the request and response payload for a given `request_id`.

        Tries the current day, the next day, and the previous day until it finds the payload.
        """
        if start_time_utc is None:
            raise ValueError(
                "start_time_utc is required for getting a payload from GCS Bucket"
            )

        # Try current day, next day, and previous day
        dates_to_try = [
            start_time_utc,
            start_time_utc + timedelta(days=1),
            start_time_utc - timedelta(days=1),
        ]
        date_str = None
        for date in dates_to_try:
            try:
                date_str = self._get_object_date_from_datetime(datetime_obj=date)
                object_name = self._generate_success_object_name(
                    request_date_str=date_str,
                    response_id=request_id,
                )
                encoded_object_name = quote(object_name, safe="")
                response = await self.download_gcs_object(encoded_object_name)
                if response is not None:
                    loaded_response = json.loads(response)
                    return loaded_response
            except Exception as e:
                verbose_logger.debug(
                    f"Failed to fetch payload for date {date_str}: {str(e)}"
                )
                continue
        return None

    def _generate_success_object_name(
        self,
        request_date_str: str,
        response_id: str,
    ) -> str:
        return f"{request_date_str}/{response_id}"

    def _generate_failure_object_name(
        self,
        request_date_str: str,
    ) -> str:
        return f"{request_date_str}/failure-{uuid.uuid4().hex}"

    def _get_object_date_from_datetime(self, datetime_obj: datetime) -> str:
        return datetime_obj.strftime("%Y-%m-%d")

    async def async_health_check(self) -> IntegrationHealthCheckStatus:
        raise NotImplementedError("GCS Bucket does not support health check")
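

if __name__ == "__main__":
    # Minimal usage sketch, illustrative only: it assumes an enterprise/premium
    # proxy setup, a reachable GCS bucket, and valid service-account credentials,
    # none of which are provided here. The bucket name, credential path, and
    # callback wiring below are assumptions for the example; in a real
    # deployment the proxy configures this logger via its config file.
    import litellm

    os.environ.setdefault("GCS_BUCKET_NAME", "my-example-bucket")  # assumed bucket
    os.environ.setdefault("GCS_PATH_SERVICE_ACCOUNT", "/path/to/sa.json")  # assumed creds
    os.environ.setdefault("GCS_FLUSH_INTERVAL", "5")  # flush queued logs every 5s
    os.environ.setdefault("GCS_BATCH_SIZE", "10")  # ...or once 10 logs are queued

    async def _demo() -> None:
        # Construct inside a running event loop: __init__ schedules
        # periodic_flush() with asyncio.create_task().
        gcs_logger = GCSBucketLogger()
        litellm.callbacks = [gcs_logger]

        # A mocked completion still exercises the success-logging path, which
        # appends a GCSLogQueueItem to gcs_logger.log_queue.
        await litellm.acompletion(
            model="gpt-4o-mini",
            messages=[{"role": "user", "content": "hello"}],
            mock_response="hi",
        )

        # Give periodic_flush() time to call async_send_batch(), which makes
        # one POST per queued log to the bucket.
        await asyncio.sleep(6)

    asyncio.run(_demo())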