from threading import Lock
from hashlib import sha256

from .individual_cache import _IndividualCache as IndividualCache
from .individual_cache import _ExpiringMapping as ExpiringMapping
from .oauth2cli.http import Response
from .exceptions import MsalServiceError


# https://datatracker.ietf.org/doc/html/rfc8628#section-3.4
DEVICE_AUTH_GRANT = "urn:ietf:params:oauth:grant-type:device_code"


def _get_headers(response):
    # MSAL's HttpResponse did not have headers until 1.23.0
    # https://github.com/AzureAD/microsoft-authentication-library-for-python/pull/581/files#diff-28866b706bc3830cd20485685f20fe79d45b58dce7050e68032e9d9372d68654R61
    # This helper ensures graceful degradation to {} without exception
    return getattr(response, "headers", {})


class RetryAfterParser(object):
    FIELD_NAME_LOWER = "Retry-After".lower()

    def __init__(self, default_value=None):
        self._default_value = 5 if default_value is None else default_value

    def parse(self, *, result, **ignored):
        """Return the number of seconds to throttle; 0 means no throttling."""
        response = result
        lowercase_headers = {k.lower(): v for k, v in _get_headers(response).items()}
        if not (response.status_code == 429 or response.status_code >= 500
                or self.FIELD_NAME_LOWER in lowercase_headers):
            return 0  # Quick exit
        retry_after = lowercase_headers.get(self.FIELD_NAME_LOWER, self._default_value)
        try:
            # AAD's retry_after uses integer format only
            # https://stackoverflow.microsoft.com/questions/264931/264932
            delay_seconds = int(retry_after)
        except ValueError:
            delay_seconds = self._default_value
        return min(3600, delay_seconds)


def _extract_data(kwargs, key, default=None):
    data = kwargs.get("data", {})  # data is usually a dict, but occasionally a string
    return data.get(key) if isinstance(data, dict) else default


class NormalizedResponse(Response):
    """An HTTP response with the shape defined in Response,
    but containing only the data we will store in cache.
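
    For example (a minimal sketch; ``_FakeRawResponse`` is a hypothetical
    stand-in for any object exposing ``status_code``, ``text`` and ``headers``),
    header names are stored lowercased:

        >>> class _FakeRawResponse:
        ...     status_code = 429
        ...     text = "{}"
        ...     headers = {"Retry-After": "7"}
        >>> NormalizedResponse(_FakeRawResponse()).headers
        {'retry-after': '7'}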
""" def __init__(self, raw_response): super().__init__() self.status_code = raw_response.status_code self.text = raw_response.text self.headers = { k.lower(): v for k, v in _get_headers(raw_response).items() # Attempted storing only a small set of headers (such as Retry-After), # but it tends to lead to missing information (such as WWW-Authenticate). # So we store all headers, which are expected to contain only public info, # because we throttle only error responses and public responses. } ## Note: Don't use the following line, ## because when being pickled, it will indirectly pickle the whole raw_response # self.raise_for_status = raw_response.raise_for_status def raise_for_status(self): if self.status_code >= 400: raise MsalServiceError( "HTTP Error: {}".format(self.status_code), error=None, error_description=None, # Historically required, keeping them for now ) class ThrottledHttpClientBase(object): """Throttle the given http_client by storing and retrieving data from cache. This base exists so that: 1. These base post() and get() will return a NormalizedResponse 2. The base __init__() will NOT re-throttle even if caller accidentally nested ThrottledHttpClient. Subclasses shall only need to dynamically decorate their post() and get() methods in their __init__() method. 
""" def __init__(self, http_client, *, http_cache=None): self.http_client = http_client.http_client if isinstance( # If it is already a ThrottledHttpClientBase, we use its raw (unthrottled) http client http_client, ThrottledHttpClientBase) else http_client self._expiring_mapping = ExpiringMapping( # It will automatically clean up mapping=http_cache if http_cache is not None else {}, capacity=1024, # To prevent cache blowing up especially for CCA lock=Lock(), # TODO: This should ideally also allow customization ) def post(self, *args, **kwargs): return NormalizedResponse(self.http_client.post(*args, **kwargs)) def get(self, *args, **kwargs): return NormalizedResponse(self.http_client.get(*args, **kwargs)) def close(self): return self.http_client.close() @staticmethod def _hash(raw): return sha256(repr(raw).encode("utf-8")).hexdigest() class ThrottledHttpClient(ThrottledHttpClientBase): """A throttled http client that is used by MSAL's non-managed identity clients.""" def __init__(self, *args, default_throttle_time=None, **kwargs): """Decorate self.post() and self.get() dynamically""" super(ThrottledHttpClient, self).__init__(*args, **kwargs) self.post = IndividualCache( # Internal specs requires throttling on at least token endpoint, # here we have a generic patch for POST on all endpoints. mapping=self._expiring_mapping, key_maker=lambda func, args, kwargs: "POST {} client_id={} scope={} hash={} 429/5xx/Retry-After".format( args[0], # It is the url, typically containing authority and tenant _extract_data(kwargs, "client_id"), # Per internal specs _extract_data(kwargs, "scope"), # Per internal specs self._hash( # The followings are all approximations of the "account" concept # to support per-account throttling. 
                        # TODO: We may want to disable it for confidential client, though
                        _extract_data(kwargs, "refresh_token",  # "account" during refresh
                            _extract_data(kwargs, "code",  # "account" of auth code grant
                                _extract_data(kwargs, "username")))),  # "account" of ROPC
                    ),
            expires_in=RetryAfterParser(default_throttle_time or 5).parse,
            )(self.post)

        self.post = IndividualCache(  # It covers the "UI required cache"
            mapping=self._expiring_mapping,
            key_maker=lambda func, args, kwargs: "POST {} hash={} 400".format(
                args[0],  # It is the url, typically containing authority and tenant
                self._hash(
                    # Here we use literally all parameters, even those short-lived
                    # parameters containing timestamps (WS-Trust or POP assertion),
                    # because they will automatically be cleaned up by ExpiringMapping.
                    #
                    # Furthermore, there is no need to implement
                    # "interactive requests would reset the cache",
                    # because acquire_token_silent() calls would be automatically unblocked,
                    # since the token cache layer operates on top of the http cache layer.
                    #
                    # And, acquire_token_silent(..., force_refresh=True) will NOT
                    # bypass http cache, because there is no real gain from that.
                    # We won't bother implementing it, nor do we want to encourage
                    # the acquire_token_silent(..., force_refresh=True) pattern.
                    str(kwargs.get("params")) + str(kwargs.get("data"))),
                ),
            expires_in=lambda result=None, kwargs=None, **ignored:
                60
                if result.status_code == 400
                    # Here we choose to cache exact HTTP 400 errors only (rather than 4xx)
                    # because they are the ones defined in OAuth2
                    # (https://datatracker.ietf.org/doc/html/rfc6749#section-5.2)
                    # Other 4xx errors might have different requirements, e.g.
                    # "407 Proxy auth required" would need a key including http headers.
                    and not(  # Exclude Device Flow whose retry is expected and regulated
                        isinstance(kwargs.get("data"), dict)
                        and kwargs["data"].get("grant_type") == DEVICE_AUTH_GRANT
                        )
                    and RetryAfterParser.FIELD_NAME_LOWER not in set(  # Otherwise leave it to the Retry-After decorator
                        h.lower() for h in _get_headers(result))
                else 0,
            )(self.post)

        self.get = IndividualCache(  # Typically those discovery GETs
            mapping=self._expiring_mapping,
            key_maker=lambda func, args, kwargs: "GET {} hash={} 2xx".format(
                args[0],  # It is the url, sometimes containing inline params
                self._hash(kwargs.get("params", "")),
                ),
            expires_in=lambda result=None, **ignored:
                3600*24 if 200 <= result.status_code < 300 else 0,
            )(self.get)