GraphRag
/
graphrag-ollama
/lib
/python3.12
/site-packages
/azure
/identity
/_credentials
/silent.py
| # ------------------------------------ | |
| # Copyright (c) Microsoft Corporation. | |
| # Licensed under the MIT License. | |
| # ------------------------------------ | |
| import platform | |
| import time | |
| from typing import Dict, Optional, Any | |
| from msal import PublicClientApplication, TokenCache | |
| from azure.core.credentials import AccessToken, AccessTokenInfo, TokenRequestOptions | |
| from azure.core.exceptions import ClientAuthenticationError | |
| from .. import CredentialUnavailableError | |
| from .._internal import resolve_tenant, validate_tenant_id, within_dac | |
| from .._internal.decorators import wrap_exceptions | |
| from .._internal.msal_client import MsalClient | |
| from .._internal.shared_token_cache import NO_TOKEN | |
| from .._persistent_cache import _load_persistent_cache, TokenCachePersistenceOptions | |
| from .. import AuthenticationRecord | |
class SilentAuthenticationCredential:
    """Internal class for authenticating from the default shared cache given an AuthenticationRecord.

    :param authentication_record: an AuthenticationRecord from which to authenticate
    :type authentication_record: ~azure.identity.AuthenticationRecord
    :keyword str tenant_id: tenant ID of the application the credential is authenticating for. Defaults to the tenant
        recorded in ``authentication_record``.
    :keyword cache_persistence_options: options used to load the persistent token cache. When omitted, the user's
        default cache is loaded and unencrypted storage is allowed.
    :paramtype cache_persistence_options: ~azure.identity.TokenCachePersistenceOptions
    :keyword additionally_allowed_tenants: tenants, besides ``tenant_id``, this credential may request tokens from.
        Values are interpreted by ``resolve_tenant``.
    :paramtype additionally_allowed_tenants: list[str]

    Any remaining keyword arguments are forwarded to ``MsalClient``.
    """

    def __init__(
        self, authentication_record: AuthenticationRecord, *, tenant_id: Optional[str] = None, **kwargs
    ) -> None:
        self._auth_record = authentication_record

        # authenticate in the tenant that produced the record unless "tenant_id" specifies another
        self._tenant_id = tenant_id or self._auth_record.tenant_id
        validate_tenant_id(self._tenant_id)

        # Private keywords for caller-supplied caches. When either is given, the caches are considered
        # caller-managed: they are kept when pickling instead of being dropped and lazily reloaded.
        self._cache = kwargs.pop("_cache", None)
        self._cae_cache = kwargs.pop("_cae_cache", None)
        self._custom_cache = bool(self._cache or self._cae_cache)

        self._cache_persistence_options = kwargs.pop("cache_persistence_options", None)

        # MSAL applications are created lazily, one per tenant. CAE requests use a separate map because their
        # applications are built with different client capabilities and a different token cache.
        self._client_applications: Dict[str, PublicClientApplication] = {}
        self._cae_client_applications: Dict[str, PublicClientApplication] = {}

        self._additionally_allowed_tenants = kwargs.pop("additionally_allowed_tenants", [])
        # Must come last: every keyword consumed above has been popped, the rest configure the HTTP client.
        self._client = MsalClient(**kwargs)

    def __enter__(self) -> "SilentAuthenticationCredential":
        self._client.__enter__()
        return self

    def __exit__(self, *args):
        self._client.__exit__(*args)

    def close(self) -> None:
        """Close the underlying HTTP client (equivalent to exiting the context manager)."""
        self.__exit__()

    def get_token(
        self,
        *scopes: str,
        claims: Optional[str] = None,
        tenant_id: Optional[str] = None,
        enable_cae: bool = False,
        **kwargs: Any,
    ) -> AccessToken:
        """Request an access token for `scopes` from the shared token cache.

        :param str scopes: desired scopes for the access token. At least one scope is required.
        :keyword str claims: additional claims required in the token; passed to MSAL as a claims challenge.
        :keyword str tenant_id: tenant to request the token from; resolved by ``resolve_tenant`` against this
            credential's tenant and additionally allowed tenants.
        :keyword bool enable_cae: use the CAE token cache and request the "CP1" client capability.
        :return: an access token
        :rtype: ~azure.core.credentials.AccessToken
        :raises ValueError: no scopes were given.
        :raises ~azure.identity.CredentialUnavailableError: the cache is unavailable (when running within
            DefaultAzureCredential), unsupported on this platform, or holds no matching account/token.
        :raises ~azure.core.exceptions.ClientAuthenticationError: the cache is unavailable, or token acquisition
            failed for a matching account.
        """
        options: TokenRequestOptions = {}
        if claims:
            options["claims"] = claims
        if tenant_id:
            options["tenant_id"] = tenant_id
        options["enable_cae"] = enable_cae

        token_info = self._get_token_base(*scopes, options=options, base_method_name="get_token", **kwargs)
        return AccessToken(token_info.token, token_info.expires_on)

    def get_token_info(self, *scopes: str, options: Optional[TokenRequestOptions] = None) -> AccessTokenInfo:
        """Request an access token for `scopes`, returning extended token information.

        :param str scopes: desired scopes for the access token. At least one scope is required.
        :keyword options: request options; the "claims", "tenant_id" and "enable_cae" keys are honored.
        :paramtype options: ~azure.core.credentials.TokenRequestOptions
        :return: token information, including ``refresh_on`` when MSAL reports it
        :rtype: ~azure.core.credentials.AccessTokenInfo
        :raises ValueError: no scopes were given.
        :raises ~azure.identity.CredentialUnavailableError: see :meth:`get_token`.
        :raises ~azure.core.exceptions.ClientAuthenticationError: see :meth:`get_token`.
        """
        return self._get_token_base(*scopes, options=options, base_method_name="get_token_info")

    def _get_token_base(
        self,
        *scopes: str,
        options: Optional[TokenRequestOptions] = None,
        base_method_name: str = "get_token_info",
        **kwargs: Any,
    ) -> AccessTokenInfo:
        """Shared implementation of :meth:`get_token` and :meth:`get_token_info`.

        Makes sure the relevant (CAE or non-CAE) cache is loaded, then acquires the token silently.

        :param str scopes: desired scopes for the access token
        :keyword options: request options ("claims", "tenant_id", "enable_cae")
        :keyword str base_method_name: public method name used in the "no scopes" error message
        :return: token information
        :rtype: ~azure.core.credentials.AccessTokenInfo
        """
        if not scopes:
            raise ValueError(f"'{base_method_name}' requires at least one scope")

        options = options or {}
        claims = options.get("claims")
        tenant_id = options.get("tenant_id")
        enable_cae = options.get("enable_cae", False)

        token_cache = self._cae_cache if enable_cae else self._cache

        # Try to load the cache if it is None.
        if not token_cache:
            token_cache = self._initialize_cache(is_cae=enable_cae)

            # If the cache is still None, raise an error.
            if not token_cache:
                # within_dac selects CredentialUnavailableError over ClientAuthenticationError (presumably so a
                # DefaultAzureCredential chain can fall through to its next credential — confirm).
                if within_dac.get():
                    raise CredentialUnavailableError(message="Shared token cache unavailable")
                raise ClientAuthenticationError(message="Shared token cache unavailable")

        return self._acquire_token_silent(*scopes, claims=claims, tenant_id=tenant_id, enable_cae=enable_cae, **kwargs)

    def _initialize_cache(self, is_cae: bool = False) -> Optional[TokenCache]:
        """Load the persistent (CAE or non-CAE) token cache if it isn't loaded yet.

        :param bool is_cae: whether to load the CAE cache rather than the regular one
        :return: the loaded cache, or None when loading raised
        :rtype: ~msal.TokenCache or None
        :raises ~azure.identity.CredentialUnavailableError: the platform is not Darwin, Linux or Windows.
        """
        # If no cache options were provided, the default cache will be used. This credential accepts the
        # user's default cache regardless of whether it's encrypted. It doesn't create a new cache. If the
        # default cache exists, the user must have created it earlier. If it's unencrypted, the user must
        # have allowed that.
        cache_options = self._cache_persistence_options or TokenCachePersistenceOptions(allow_unencrypted_storage=True)

        if platform.system() not in {"Darwin", "Linux", "Windows"}:
            raise CredentialUnavailableError(message="Shared token cache is not supported on this platform.")

        # Loading is best-effort: any failure is reported to the caller as "no cache" (None).
        if not self._cache and not is_cae:
            try:
                self._cache = _load_persistent_cache(cache_options, is_cae)
            except Exception:  # pylint:disable=broad-except
                return None

        if not self._cae_cache and is_cae:
            try:
                self._cae_cache = _load_persistent_cache(cache_options, is_cae)
            except Exception:  # pylint:disable=broad-except
                return None

        return self._cae_cache if is_cae else self._cache

    def _get_client_application(self, **kwargs: Any):
        """Return the MSAL application for the resolved tenant, creating it on first use.

        :keyword bool enable_cae: select the CAE application map, "CP1" capability and CAE token cache.
        Other keyword arguments (e.g. "tenant_id") are passed to ``resolve_tenant``.
        :return: the cached or newly created application
        :rtype: ~msal.PublicClientApplication
        """
        tenant_id = resolve_tenant(
            self._tenant_id, additionally_allowed_tenants=self._additionally_allowed_tenants, **kwargs
        )

        client_applications_map = self._client_applications
        capabilities = None
        token_cache = self._cache

        if kwargs.get("enable_cae"):
            client_applications_map = self._cae_client_applications
            # CP1 = can handle claims challenges (CAE)
            capabilities = ["CP1"]
            token_cache = self._cae_cache

        if tenant_id not in client_applications_map:
            client_applications_map[tenant_id] = PublicClientApplication(
                client_id=self._auth_record.client_id,
                authority="https://{}/{}".format(self._auth_record.authority, tenant_id),
                token_cache=token_cache,
                http_client=self._client,
                client_capabilities=capabilities,
            )

        return client_applications_map[tenant_id]

    def _acquire_token_silent(self, *scopes: str, **kwargs: Any) -> AccessTokenInfo:
        """Silently acquire a token from MSAL for the account described by the AuthenticationRecord.

        Only cache accounts whose username and home_account_id match the record are tried.

        :param str scopes: desired scopes for the access token
        :keyword str claims: claims challenge forwarded to MSAL
        Other keyword arguments (e.g. "tenant_id", "enable_cae") select the client application.
        :return: token information
        :rtype: ~azure.core.credentials.AccessTokenInfo
        :raises ~azure.identity.CredentialUnavailableError: no matching account, or no token for it in the cache.
        :raises ~azure.core.exceptions.ClientAuthenticationError: MSAL returned an error for a matching account.
        """
        result = None

        client_application = self._get_client_application(**kwargs)
        accounts_for_user = client_application.get_accounts(username=self._auth_record.username)
        if not accounts_for_user:
            raise CredentialUnavailableError("The cache contains no account matching the given AuthenticationRecord.")

        for account in accounts_for_user:
            if account.get("home_account_id") != self._auth_record.home_account_id:
                continue

            # Sample the clock before the request so the computed expiry errs on the early side.
            now = int(time.time())
            result = client_application.acquire_token_silent_with_error(
                list(scopes), account=account, claims_challenge=kwargs.get("claims")
            )
            if result and "access_token" in result and "expires_in" in result:
                refresh_on = int(result["refresh_on"]) if "refresh_on" in result else None
                return AccessTokenInfo(
                    result["access_token"],
                    now + int(result["expires_in"]),
                    token_type=result.get("token_type", "Bearer"),
                    refresh_on=refresh_on,
                )

        # if we get this far, the cache contained a matching account but MSAL failed to authenticate it silently
        if result:
            # cache contains a matching refresh token but STS returned an error response when MSAL tried to use it
            message = "Token acquisition failed"
            details = result.get("error_description") or result.get("error")
            if details:
                message += ": {}".format(details)
            raise ClientAuthenticationError(message=message)

        # cache doesn't contain a matching refresh (or access) token
        raise CredentialUnavailableError(message=NO_TOKEN.format(self._auth_record.username))

    def __getstate__(self) -> Dict[str, Any]:
        state = self.__dict__.copy()
        # Remove the non-picklable entries
        if not self._custom_cache:
            del state["_cache"]
            del state["_cae_cache"]
            # The per-tenant applications hold references to the caches dropped above. Keeping them would
            # leave the restored credential using stale caches while _initialize_cache loads fresh ones, so
            # drop them too; _get_client_application rebuilds them on demand.
            state["_client_applications"] = {}
            state["_cae_client_applications"] = {}
        return state

    def __setstate__(self, state: Dict[str, Any]) -> None:
        self.__dict__.update(state)
        # Re-create the unpickable entries
        if not self._custom_cache:
            self._cache = None
            self._cae_cache = None
            # Also discard any applications restored from older state; they would be bound to the old caches.
            self._client_applications = {}
            self._cae_client_applications = {}