# GraphRag/graphrag-ollama/lib/python3.12/site-packages/azure/identity/_internal/interactive.py
| # ------------------------------------ | |
| # Copyright (c) Microsoft Corporation. | |
| # Licensed under the MIT License. | |
| # ------------------------------------ | |
| """Base class for credentials using MSAL for interactive user authentication""" | |
| import abc | |
| import base64 | |
| import json | |
| import logging | |
| import time | |
| from typing import Any, Optional, Iterable, Dict | |
| from urllib.parse import urlparse | |
| from azure.core.credentials import AccessToken, AccessTokenInfo, TokenRequestOptions | |
| from azure.core.exceptions import ClientAuthenticationError | |
| from .msal_credentials import MsalCredential | |
| from .._auth_record import AuthenticationRecord | |
| from .._constants import KnownAuthorities | |
| from .._exceptions import AuthenticationRequiredError, CredentialUnavailableError | |
| from .._internal import wrap_exceptions | |
# Short alias so classes in this module can list ``ABC`` directly among their bases.
ABC = abc.ABC

# Module-level logger shared by the helpers and credential class defined below.
_LOGGER = logging.getLogger(__name__)

# Scopes that ``authenticate()`` falls back to when the caller passes none.
# Keys are "https://" + a known authority; values are one-element scope tuples.
_DEFAULT_AUTHENTICATE_SCOPES = {
    "https://" + authority: (scope,)
    for authority, scope in (
        (KnownAuthorities.AZURE_CHINA, "https://management.core.chinacloudapi.cn//.default"),
        (KnownAuthorities.AZURE_GOVERNMENT, "https://management.core.usgovcloudapi.net//.default"),
        (KnownAuthorities.AZURE_PUBLIC_CLOUD, "https://management.core.windows.net//.default"),
    )
}
def _decode_client_info(raw: str) -> str:
    """Decode client info. Taken from msal.oauth2cli.oidc.

    The value is URL-safe base64 whose trailing ``=`` padding may have been stripped,
    so the padding is restored before decoding.

    :param str raw: base64-encoded client info
    :return: decoded client info
    :rtype: str
    :raises ValueError: if ``raw`` is not valid base64 or does not decode to UTF-8 text
        (``binascii.Error`` and ``UnicodeDecodeError`` are both ``ValueError`` subclasses)
    """
    # -len(raw) % 4 is the number of "=" characters needed to reach a multiple of four.
    padded = raw + "=" * (-len(raw) % 4)
    return base64.urlsafe_b64decode(padded).decode("utf-8")
def _build_auth_record(response):
    """Create an AuthenticationRecord describing the account in an MSAL token response.

    :param response: The result of an MSAL ClientApplication token request
    :type response: dict[str, typing.Any]
    :return: An AuthenticationRecord populated from the response's ID token claims
    :rtype: ~azure.identity.AuthenticationRecord
    :raises ~azure.core.exceptions.ClientAuthenticationError: If the response doesn't contain expected data
    """
    try:
        claims = response["id_token_claims"]

        if "client_info" not in response:
            # MSAL uses the subject claim as home_account_id when the STS doesn't provide client_info
            account_id = claims["sub"]
        else:
            decoded_info = json.loads(_decode_client_info(response["client_info"]))
            account_id = "{uid}.{utid}".format(**decoded_info)

        # "iss" is the URL of the issuing tenant e.g. https://authority/tenant
        issuer_url = urlparse(claims["iss"])

        # The issuing tenant is not necessarily the user's home tenant; fall back to the
        # issuer URL's path when the token carries no "tid" claim.
        issuing_tenant = claims.get("tid") or issuer_url.path.strip("/")

        # Microsoft Entra ID returns "preferred_username", ADFS returns "upn"
        user = claims.get("preferred_username") or claims["upn"]

        record = AuthenticationRecord(
            authority=issuer_url.netloc,
            client_id=claims["aud"],
            home_account_id=account_id,
            tenant_id=issuing_tenant,
            username=user,
        )
    except (KeyError, ValueError) as ex:
        # A missing claim or undecodable client_info means the token isn't one we understand.
        raise ClientAuthenticationError(
            message="Failed to build AuthenticationRecord from unexpected identity token"
        ) from ex
    return record
class InteractiveCredential(MsalCredential, ABC):
    """Base class for credentials that may need to authenticate a user interactively.

    A token request first tries silent acquisition for the account described by the
    credential's authentication record. If that raises ``AuthenticationRequiredError`` and
    prompting is allowed, the request is delegated to :func:`_request_token`, which
    subclasses must implement.

    :keyword authentication_record: record of a previous authentication. When given, its client ID and
        authority are passed to the base class, and its tenant is used unless ``tenant_id`` is specified.
    :paramtype authentication_record: ~azure.identity.AuthenticationRecord
    :keyword bool disable_automatic_authentication: if True, token requests do not fall back to interactive
        authentication unless explicitly allowed (as :func:`authenticate` does). Defaults to False.
    """

    def __init__(
        self,
        *,
        authentication_record: Optional[AuthenticationRecord] = None,
        disable_automatic_authentication: bool = False,
        **kwargs: Any,
    ) -> None:
        self._disable_automatic_authentication = disable_automatic_authentication
        self._auth_record = authentication_record
        if self._auth_record:
            kwargs.pop("client_id", None)  # authentication_record overrides client_id argument
            # an explicit tenant_id wins over the tenant stored in the record
            tenant_id = kwargs.pop("tenant_id", None) or self._auth_record.tenant_id
            super().__init__(
                client_id=self._auth_record.client_id,
                authority=self._auth_record.authority,
                tenant_id=tenant_id,
                **kwargs,
            )
        else:
            super().__init__(**kwargs)

    def get_token(
        self,
        *scopes: str,
        claims: Optional[str] = None,
        tenant_id: Optional[str] = None,
        enable_cae: bool = False,
        **kwargs: Any,
    ) -> AccessToken:
        """Request an access token for `scopes`.

        This method is called automatically by Azure SDK clients.

        :param str scopes: desired scopes for the access token. This method requires at least one scope.
            For more information about scopes, see
            https://learn.microsoft.com/entra/identity-platform/scopes-oidc.
        :keyword str claims: additional claims required in the token, such as those returned in a resource provider's
            claims challenge following an authorization failure
        :keyword str tenant_id: optional tenant to include in the token request.
        :keyword bool enable_cae: indicates whether to enable Continuous Access Evaluation (CAE) for the requested
            token. Defaults to False.

        :return: An access token with the desired scopes.
        :rtype: ~azure.core.credentials.AccessToken
        :raises ValueError: no scopes were given
        :raises CredentialUnavailableError: the credential is unable to attempt authentication because it lacks
            required data, state, or platform support
        :raises ~azure.core.exceptions.ClientAuthenticationError: authentication failed. The error's ``message``
            attribute gives a reason.
        :raises AuthenticationRequiredError: user interaction is necessary to acquire a token, and the credential is
            configured not to begin this automatically. Call :func:`authenticate` to begin interactive authentication.
        """
        # Only truthy claims/tenant_id are forwarded; enable_cae is always set explicitly.
        options: TokenRequestOptions = {}
        if claims:
            options["claims"] = claims
        if tenant_id:
            options["tenant_id"] = tenant_id
        options["enable_cae"] = enable_cae

        token_info = self._get_token_base(*scopes, options=options, base_method_name="get_token", **kwargs)
        return AccessToken(token_info.token, token_info.expires_on)

    def get_token_info(self, *scopes: str, options: Optional[TokenRequestOptions] = None) -> AccessTokenInfo:
        """Request an access token for `scopes`.

        This is an alternative to `get_token` to enable certain scenarios that require additional properties
        on the token. This method is called automatically by Azure SDK clients.

        :param str scopes: desired scopes for the access token. This method requires at least one scope.
            For more information about scopes, see https://learn.microsoft.com/entra/identity-platform/scopes-oidc.
        :keyword options: A dictionary of options for the token request. Unknown options will be ignored. Optional.
        :paramtype options: ~azure.core.credentials.TokenRequestOptions

        :rtype: AccessTokenInfo
        :return: An AccessTokenInfo instance containing information about the token.
        :raises ValueError: no scopes were given
        :raises CredentialUnavailableError: the credential is unable to attempt authentication because it lacks
            required data, state, or platform support
        :raises ~azure.core.exceptions.ClientAuthenticationError: authentication failed. The error's ``message``
            attribute gives a reason.
        :raises AuthenticationRequiredError: user interaction is necessary to acquire a token, and the credential is
            configured not to begin this automatically. Call :func:`authenticate` to begin interactive authentication.
        """
        return self._get_token_base(*scopes, options=options, base_method_name="get_token_info")

    def _get_token_base(
        self,
        *scopes: str,
        options: Optional[TokenRequestOptions] = None,
        base_method_name: str = "get_token_info",
        **kwargs: Any,
    ) -> AccessTokenInfo:
        """Shared implementation of :func:`get_token` and :func:`get_token_info`.

        Tries silent acquisition first. If that raises ``AuthenticationRequiredError`` and prompting is
        allowed, calls :func:`_request_token` and replaces the stored authentication record with one
        built from the result.

        :param str scopes: desired scopes for the access token. At least one is required.
        :keyword options: token request options. Keys not declared by ``TokenRequestOptions`` are forwarded
            as keyword arguments to the acquisition methods.
        :paramtype options: ~azure.core.credentials.TokenRequestOptions
        :keyword str base_method_name: name of the public method being served; used only in log and error
            messages.
        :keyword bool _allow_prompt: private override controlling whether interactive authentication may
            begin. Defaults to the inverse of ``disable_automatic_authentication``.

        :return: information about the acquired token
        :rtype: ~azure.core.credentials.AccessTokenInfo
        :raises ValueError: no scopes were given
        :raises ~azure.core.exceptions.ClientAuthenticationError: the interactive request returned no access token
        :raises AuthenticationRequiredError: silent acquisition failed and prompting is not allowed
        """
        if not scopes:
            message = f"'{base_method_name}' requires at least one scope"
            _LOGGER.warning("%s.%s failed: %s", self.__class__.__name__, base_method_name, message)
            raise ValueError(message)

        allow_prompt = kwargs.pop("_allow_prompt", not self._disable_automatic_authentication)

        options = options or {}
        claims = options.get("claims")
        tenant_id = options.get("tenant_id")
        enable_cae = options.get("enable_cae", False)

        # Check for arbitrary additional options to enable intermediary support for PoP tokens.
        for key in options:
            if key not in TokenRequestOptions.__annotations__:  # pylint:disable=no-member
                kwargs.setdefault(key, options[key])  # type: ignore

        try:
            token = self._acquire_token_silent(
                *scopes, claims=claims, tenant_id=tenant_id, enable_cae=enable_cae, **kwargs
            )
            _LOGGER.info("%s.%s succeeded", self.__class__.__name__, base_method_name)
            return token
        except Exception as ex:  # pylint:disable=broad-except
            # Only an AuthenticationRequiredError with prompting allowed falls through to the
            # interactive path; anything else is logged and propagated unchanged.
            if not (isinstance(ex, AuthenticationRequiredError) and allow_prompt):
                _LOGGER.warning(
                    "%s.%s failed: %s",
                    self.__class__.__name__,
                    base_method_name,
                    ex,
                    exc_info=_LOGGER.isEnabledFor(logging.DEBUG),
                )
                raise

        # silent authentication failed -> authenticate interactively
        # Capture the time before the request so the computed expiry never overstates the lifetime.
        now = int(time.time())

        try:
            result = self._request_token(*scopes, claims=claims, tenant_id=tenant_id, enable_cae=enable_cae, **kwargs)
            if "access_token" not in result:
                reason = result.get("error_description") or result.get("error")
                message = f"Authentication failed: {reason}"
                response = self._client.get_error_response(result)
                raise ClientAuthenticationError(message=message, response=response)

            # this may be the first authentication, or the user may have authenticated a different identity
            self._auth_record = _build_auth_record(result)
        except Exception as ex:  # pylint:disable=broad-except
            _LOGGER.warning(
                "%s.%s failed: %s",
                self.__class__.__name__,
                base_method_name,
                ex,
                exc_info=_LOGGER.isEnabledFor(logging.DEBUG),
            )
            raise

        _LOGGER.info("%s.%s succeeded", self.__class__.__name__, base_method_name)
        refresh_on = int(result["refresh_on"]) if "refresh_on" in result else None
        return AccessTokenInfo(
            result["access_token"],
            now + int(result["expires_in"]),
            token_type=result.get("token_type", "Bearer"),
            refresh_on=refresh_on,
        )

    def authenticate(
        self, *, scopes: Optional[Iterable[str]] = None, claims: Optional[str] = None, **kwargs: Any
    ) -> AuthenticationRecord:
        """Interactively authenticate a user. This method will always generate a challenge to the user.

        :keyword Iterable[str] scopes: scopes to request during authentication, such as those provided by
            :func:`AuthenticationRequiredError.scopes`. If provided, successful authentication will cache an access token
            for these scopes.
        :keyword str claims: additional claims required in the token, such as those provided by
            :func:`AuthenticationRequiredError.claims`

        :return: the authentication record held by the credential after the token request
        :rtype: ~azure.identity.AuthenticationRecord
        :raises CredentialUnavailableError: no scopes were given and the credential's authority has no
            default scope in this module
        :raises ~azure.core.exceptions.ClientAuthenticationError: authentication failed. The error's ``message``
            attribute gives a reason.
        """
        if not scopes:
            if self._authority not in _DEFAULT_AUTHENTICATE_SCOPES:
                # the credential is configured to use a cloud whose ARM scope we can't determine
                raise CredentialUnavailableError(
                    message="Authenticating in this environment requires a value for the 'scopes' keyword argument."
                )
            scopes = _DEFAULT_AUTHENTICATE_SCOPES[self._authority]

        # _allow_prompt=True overrides disable_automatic_authentication for this request.
        _ = self.get_token(*scopes, _allow_prompt=True, claims=claims, **kwargs)
        return self._auth_record  # type: ignore

    # NOTE(review): wrap_exceptions is imported at the top of this module; it presumably converts
    # unexpected errors into ClientAuthenticationError while letting AuthenticationRequiredError
    # through unchanged -- confirm against its definition in .._internal.
    @wrap_exceptions
    def _acquire_token_silent(self, *scopes: str, **kwargs: Any) -> AccessTokenInfo:
        """Try to acquire a token without user interaction for the recorded account.

        Looks up the accounts the MSAL application returns for the record's username, and requests a
        token silently for each one whose ``home_account_id`` matches the record.

        :param str scopes: desired scopes for the access token
        :keyword str claims: optional claims challenge forwarded to MSAL. All keyword arguments are also
            passed to ``_get_app``.

        :return: information about the acquired token
        :rtype: ~azure.core.credentials.AccessTokenInfo
        :raises AuthenticationRequiredError: there is no authentication record, no matching account, or
            no silent request returned an access token. The error carries the last error response, if any.
        """
        result = None
        claims = kwargs.get("claims")
        if self._auth_record:
            app = self._get_app(**kwargs)
            for account in app.get_accounts(username=self._auth_record.username):
                if account.get("home_account_id") != self._auth_record.home_account_id:
                    continue

                # Capture the time before the request so the computed expiry is conservative.
                now = int(time.time())
                result = app.acquire_token_silent_with_error(list(scopes), account=account, claims_challenge=claims)
                if result and "access_token" in result and "expires_in" in result:
                    refresh_on = int(result["refresh_on"]) if "refresh_on" in result else None
                    return AccessTokenInfo(
                        result["access_token"],
                        now + int(result["expires_in"]),
                        token_type=result.get("token_type", "Bearer"),
                        refresh_on=refresh_on,
                    )

        # if we get this far, result is either None or the content of a Microsoft Entra ID error response
        if result:
            response = self._client.get_error_response(result)
            raise AuthenticationRequiredError(scopes, claims=claims, response=response)
        raise AuthenticationRequiredError(scopes, claims=claims)

    @abc.abstractmethod
    def _request_token(self, *scopes, **kwargs) -> Dict:
        """Acquire a token interactively. Subclasses must implement this.

        Called by :func:`_get_token_base` with ``claims``, ``tenant_id`` and ``enable_cae`` keyword
        arguments after silent acquisition fails.

        :param str scopes: desired scopes for the access token
        :return: the token response. The caller reads ``access_token`` and ``expires_in`` on success
            (optionally ``refresh_on`` and ``token_type``), or ``error`` / ``error_description`` on failure.
        :rtype: dict
        """