C-6: ConnectorRuntime enforces capability_mask per operation.
READ-only ACs cannot invoke MUTATE operations (wipe, lock, retire).
C-7: AC validated against database (exists, active, not expired)
before connector invocation.
C-9: Delegated AC capability bounded by delegator's capability.
C-10: Command counter uses atomic SQL increment with limit check.
M-23: expire_stale() uses same atomic SQL pattern.
H-1: Sensitive credential fields hidden from repr/logs via repr=False.
H-2: Stub backend requires ALLOW_STUB_CREDENTIALS=true to activate.
H-3: Kerberos backend raises CredentialResolutionError instead of
returning stub ticket.
H-4: Chronicle INTENT emitted before execution, RESULT after.
H-5: device_id validated as UUID before Graph API URL interpolation.
H-8: ConnectorRuntime enforces governance for all connector invocations.
Signed-off-by: Tyler King <tking@guildhouse.dev>
179 lines
8.8 KiB
Python
179 lines
8.8 KiB
Python
from threading import Lock
|
|
from hashlib import sha256
|
|
|
|
from .individual_cache import _IndividualCache as IndividualCache
|
|
from .individual_cache import _ExpiringMapping as ExpiringMapping
|
|
from .oauth2cli.http import Response
|
|
from .exceptions import MsalServiceError
|
|
|
|
|
|
# https://datatracker.ietf.org/doc/html/rfc8628#section-3.4
|
|
DEVICE_AUTH_GRANT = "urn:ietf:params:oauth:grant-type:device_code"
|
|
|
|
|
|
def _get_headers(response):
|
|
# MSAL's HttpResponse did not have headers until 1.23.0
|
|
# https://github.com/AzureAD/microsoft-authentication-library-for-python/pull/581/files#diff-28866b706bc3830cd20485685f20fe79d45b58dce7050e68032e9d9372d68654R61
|
|
# This helper ensures graceful degradation to {} without exception
|
|
return getattr(response, "headers", {})
|
|
|
|
|
|
class RetryAfterParser(object):
    """Decide how many seconds a given http response should be throttled for.

    :param default_value:
        Seconds to use when the response qualifies for throttling
        but carries no usable Retry-After value. ``None`` means 5.
    """
    FIELD_NAME_LOWER = "Retry-After".lower()

    def __init__(self, default_value=None):
        self._default_value = 5 if default_value is None else default_value

    def parse(self, *, result, **ignored):
        """Return seconds to throttle.

        :param result:
            The http response. Its ``status_code`` is read,
            and its headers (if any) are looked up case-insensitively.
        :param ignored: Any other keyword arguments are accepted and unused.
        :return:
            0 when the response is neither 429, nor 5xx, nor carries a
            Retry-After header. Otherwise the Retry-After value as an integer,
            falling back to the default when the value is absent or is not
            convertible to an integer. The result is kept within 0..3600.
        """
        response = result
        lowercase_headers = {k.lower(): v for k, v in _get_headers(response).items()}
        if not (response.status_code == 429 or response.status_code >= 500
                or self.FIELD_NAME_LOWER in lowercase_headers):
            return 0  # Quick exit
        retry_after = lowercase_headers.get(self.FIELD_NAME_LOWER, self._default_value)
        try:
            # AAD's retry_after uses integer format only
            # https://stackoverflow.microsoft.com/questions/264931/264932
            delay_seconds = int(retry_after)
        except (TypeError, ValueError):
            # ValueError: a non-numeric string, such as the HTTP-date format.
            # TypeError: a non-string value (e.g. None or a list) coming from
            # an http client that does not normalize its header values.
            delay_seconds = self._default_value
        # Cap at one hour, and never hand a negative lifetime to the cache
        return max(0, min(3600, delay_seconds))
|
|
|
|
|
|
def _extract_data(kwargs, key, default=None):
|
|
data = kwargs.get("data", {}) # data is usually a dict, but occasionally a string
|
|
return data.get(key) if isinstance(data, dict) else default
|
|
|
|
|
|
class NormalizedResponse(Response):
    """An http response shaped like Response, holding only what we cache.

    It copies the status code, the text, and the headers (with lower-cased
    names) from the raw response, and keeps no reference to the raw response.
    """

    def __init__(self, raw_response):
        super().__init__()
        self.status_code = raw_response.status_code
        self.text = raw_response.text

        # Attempted storing only a small set of headers (such as Retry-After),
        # but it tends to lead to missing information (such as WWW-Authenticate).
        # So we store all headers, which are expected to contain only public info,
        # because we throttle only error responses and public responses.
        lowercased = {}
        for name, value in _get_headers(raw_response).items():
            lowercased[name.lower()] = value
        self.headers = lowercased

        ## Note: Don't use the following line,
        ## because when being pickled, it will indirectly pickle the whole raw_response
        # self.raise_for_status = raw_response.raise_for_status

    def raise_for_status(self):
        """Raise MsalServiceError when the status code is 400 or above."""
        if self.status_code < 400:
            return
        raise MsalServiceError(
            "HTTP Error: {}".format(self.status_code),
            error=None, error_description=None,  # Historically required, keeping them for now
        )
|
|
|
|
|
|
class ThrottledHttpClientBase(object):
    """Throttle the given http_client by storing and retrieving data from cache.

    What this base provides:

    1. Its post() and get() return a NormalizedResponse.
    2. Its __init__() will NOT re-throttle, even if the caller accidentally
       nested one ThrottledHttpClientBase inside another.

    Subclasses shall only need to dynamically decorate their post() and get()
    methods in their __init__() method.
    """

    def __init__(self, http_client, *, http_cache=None):
        if isinstance(http_client, ThrottledHttpClientBase):
            # It is already throttled, so we use its raw (unthrottled) http client
            http_client = http_client.http_client
        self.http_client = http_client

        backing_store = {} if http_cache is None else http_cache
        self._expiring_mapping = ExpiringMapping(  # It will automatically clean up
            mapping=backing_store,
            capacity=1024,  # To prevent cache blowing up especially for CCA
            lock=Lock(),  # TODO: This should ideally also allow customization
        )

    def post(self, *args, **kwargs):
        """Forward to the underlying client's post() and normalize its response."""
        raw_response = self.http_client.post(*args, **kwargs)
        return NormalizedResponse(raw_response)

    def get(self, *args, **kwargs):
        """Forward to the underlying client's get() and normalize its response."""
        raw_response = self.http_client.get(*args, **kwargs)
        return NormalizedResponse(raw_response)

    def close(self):
        """Close the underlying http client and return whatever it returns."""
        return self.http_client.close()

    @staticmethod
    def _hash(raw):
        """Return the hex SHA-256 digest of the UTF-8 encoded ``repr(raw)``."""
        encoded = repr(raw).encode("utf-8")
        return sha256(encoded).hexdigest()
|
|
|
|
|
|
class ThrottledHttpClient(ThrottledHttpClientBase):
    """A throttled http client that is used by MSAL's non-managed identity clients."""

    def __init__(self, *args, default_throttle_time=None, **kwargs):
        """Decorate self.post() and self.get() dynamically"""
        super(ThrottledHttpClient, self).__init__(*args, **kwargs)

        def retry_after_key(func, args, kwargs):
            url = args[0]  # It is the url, typically containing authority and tenant
            client_id = _extract_data(kwargs, "client_id")  # Per internal specs
            scope = _extract_data(kwargs, "scope")  # Per internal specs
            # The followings are all approximations of the "account" concept
            # to support per-account throttling.
            # TODO: We may want to disable it for confidential client, though
            ropc_account = _extract_data(kwargs, "username")  # "account" of ROPC
            auth_code_account = _extract_data(  # "account" of auth code grant
                kwargs, "code", ropc_account)
            account = _extract_data(  # "account" during refresh
                kwargs, "refresh_token", auth_code_account)
            return "POST {} client_id={} scope={} hash={} 429/5xx/Retry-After".format(
                url, client_id, scope, self._hash(account))

        def ui_required_key(func, args, kwargs):
            url = args[0]  # It is the url, typically containing authority and tenant
            # Here we use literally all parameters, even those short-lived
            # parameters containing timestamps (WS-Trust or POP assertion),
            # because they will automatically be cleaned up by ExpiringMapping.
            #
            # Furthermore, there is no need to implement
            # "interactive requests would reset the cache",
            # because acquire_token_silent()'s would be automatically unblocked
            # due to token cache layer operates on top of http cache layer.
            #
            # And, acquire_token_silent(..., force_refresh=True) will NOT
            # bypass http cache, because there is no real gain from that.
            # We won't bother implement it, nor do we want to encourage
            # acquire_token_silent(..., force_refresh=True) pattern.
            everything = str(kwargs.get("params")) + str(kwargs.get("data"))
            return "POST {} hash={} 400".format(url, self._hash(everything))

        def ui_required_ttl(result=None, kwargs=None, **ignored):
            # Here we choose to cache exact HTTP 400 errors only (rather than 4xx)
            # because they are the ones defined in OAuth2
            # (https://datatracker.ietf.org/doc/html/rfc6749#section-5.2)
            # Other 4xx errors might have different requirements e.g.
            # "407 Proxy auth required" would need a key including http headers.
            if result.status_code != 400:
                return 0
            # Exclude Device Flow whose retry is expected and regulated
            form = kwargs.get("data")
            if isinstance(form, dict) and form.get("grant_type") == DEVICE_AUTH_GRANT:
                return 0
            # Otherwise leave it to the Retry-After decorator
            header_names = {h.lower() for h in _get_headers(result)}
            if RetryAfterParser.FIELD_NAME_LOWER in header_names:
                return 0
            return 60

        def discovery_key(func, args, kwargs):
            url = args[0]  # It is the url, sometimes containing inline params
            return "GET {} hash={} 2xx".format(
                url, self._hash(kwargs.get("params", "")))

        def discovery_ttl(result=None, **ignored):
            if 200 <= result.status_code < 300:
                return 3600*24
            return 0

        # Internal specs requires throttling on at least token endpoint,
        # here we have a generic patch for POST on all endpoints.
        throttle_by_retry_after = IndividualCache(
            mapping=self._expiring_mapping,
            key_maker=retry_after_key,
            expires_in=RetryAfterParser(default_throttle_time or 5).parse,
        )
        self.post = throttle_by_retry_after(self.post)

        # It covers the "UI required cache"
        throttle_ui_required = IndividualCache(
            mapping=self._expiring_mapping,
            key_maker=ui_required_key,
            expires_in=ui_required_ttl,
        )
        self.post = throttle_ui_required(self.post)

        # Typically those discovery GETs
        cache_discovery = IndividualCache(
            mapping=self._expiring_mapping,
            key_maker=discovery_key,
            expires_in=discovery_ttl,
        )
        self.get = cache_discovery(self.get)
|