Spaces:
Paused
Paused
| from litellm.proxy._types import UserAPIKeyAuth | |
async def check_oauth2_token(token: str) -> UserAPIKeyAuth:
    """
    Validate an OAuth2 bearer token against the configured token info endpoint.

    The token is sent as ``Authorization: Bearer <token>`` in a GET request to the
    URL in the ``OAUTH_TOKEN_INFO_ENDPOINT`` environment variable. Fields from the
    JSON response are mapped onto a ``UserAPIKeyAuth``. The response field names
    are configurable through these environment variables:

    - ``OAUTH_USER_ID_FIELD_NAME`` (default ``"sub"``) -> ``user_id``
    - ``OAUTH_USER_ROLE_FIELD_NAME`` (default ``"role"``) -> ``user_role``
    - ``OAUTH_USER_TEAM_ID_FIELD_NAME`` (default ``"team_id"``) -> ``team_id``

    Args:
        token (str): The OAuth2 token to validate.

    Returns:
        UserAPIKeyAuth: Auth object with ``api_key`` set to the token, and
        ``user_id`` / ``team_id`` / ``user_role`` read from the token info
        response. A field missing from the response is set to ``None``.

    Raises:
        ValueError: If the proxy is not running as a premium user, if
            ``OAUTH_TOKEN_INFO_ENDPOINT`` is not set, if the endpoint returns a
            4xx/5xx status, or if any other error occurs during the request or
            while parsing the response.
    """
    import os

    import httpx

    from litellm._logging import verbose_proxy_logger
    from litellm.llms.custom_httpx.http_handler import (
        get_async_httpx_client,
        httpxSpecialProvider,
    )
    from litellm.proxy._types import CommonProxyErrors
    from litellm.proxy.proxy_server import premium_user

    if premium_user is not True:
        raise ValueError(
            "Oauth2 token validation is only available for premium users"
            + CommonProxyErrors.not_premium_user.value
        )

    # Never write the raw bearer token to logs. A short suffix is still enough
    # to correlate log lines for the same token.
    token_str = str(token)
    redacted_token = (
        f"...{token_str[-4:]}" if len(token_str) > 8 else "<redacted>"
    )
    verbose_proxy_logger.debug(
        "Oauth2 token validation for token=%s", redacted_token
    )

    # Endpoint and response-field names all come from the environment.
    token_info_endpoint = os.getenv("OAUTH_TOKEN_INFO_ENDPOINT")
    user_id_field_name = os.environ.get("OAUTH_USER_ID_FIELD_NAME", "sub")
    user_role_field_name = os.environ.get("OAUTH_USER_ROLE_FIELD_NAME", "role")
    user_team_id_field_name = os.environ.get(
        "OAUTH_USER_TEAM_ID_FIELD_NAME", "team_id"
    )

    if not token_info_endpoint:
        raise ValueError("OAUTH_TOKEN_INFO_ENDPOINT environment variable is not set")

    client = get_async_httpx_client(llm_provider=httpxSpecialProvider.Oauth2Check)
    headers = {"Authorization": f"Bearer {token}", "Content-Type": "application/json"}

    try:
        response = await client.get(token_info_endpoint, headers=headers)
        # A rejected/invalid token is expected to surface as an HTTPStatusError.
        response.raise_for_status()

        data = response.json()
        verbose_proxy_logger.debug(
            "Oauth2 token validation for token=%s, response from /token/info=%s",
            redacted_token,
            data,
        )

        # Report a clear error instead of an AttributeError from `data.get`
        # when the endpoint returns e.g. a JSON list.
        if not isinstance(data, dict):
            raise ValueError(
                f"token info endpoint returned a non-object JSON payload: {type(data).__name__}"
            )

        # Further checks (expiry, scopes) could be added here based on `data`.
        user_id = data.get(user_id_field_name)
        user_team_id = data.get(user_team_id_field_name)
        user_role = data.get(user_role_field_name)

        return UserAPIKeyAuth(
            api_key=token,
            team_id=user_team_id,
            user_id=user_id,
            user_role=user_role,
        )
    except httpx.HTTPStatusError as e:
        # 4xx / 5xx from the token info endpoint.
        raise ValueError(f"Oauth 2.0 Token validation failed: {e}") from e
    except Exception as e:
        # Network failures, malformed JSON, unexpected payloads, etc.
        raise ValueError(f"An error occurred during token validation: {e}") from e