fastapi-gsap/gsap_broker/drivers/entra.py
Tyler J King 8196396ce6 feat(drivers): add native Entra identity driver
Validates Entra JWTs directly via JWKS verification.
Extracts device_id for compliance gating, MFA status,
roles, and constructs DID from Entra tenant + oid.
Adds device_id field to AuthResult dataclass.

Signed-off-by: Tyler King <tking@guildhouse.dev>
2026-04-14 05:19:54 -04:00

152 lines
5.5 KiB
Python

# Copyright 2026 Guildhouse Dev
# SPDX-License-Identifier: Apache-2.0
"""Entra identity driver — GSAP §2.2.
Validates Entra-issued JWTs directly via JWKS verification.
Extracts device_id for compliance gating, MFA status, roles,
and constructs a did:web DID from the configured domain and the Entra oid.
"""
import logging
import time
from typing import Any, Optional
import httpx
from jose import JWTError, jwt as jose_jwt
from .base import AuthResult, ElevationRequired, IdentityDriver
logger = logging.getLogger(__name__)

# Process-wide JWKS cache keyed by tenant id. Each value pairs the fetched
# JWKS document with the timestamp recorded by _get_jwks when it was stored.
_jwks_cache: dict[str, tuple[dict[str, Any], float]] = dict()

# Age, in seconds, after which a cached JWKS document is refetched (one day).
_JWKS_TTL = 24 * 60 * 60
async def _get_jwks(tenant_id: str) -> dict[str, Any]:
    """Return the Entra JWKS document for *tenant_id*, using the module cache.

    A cached document younger than ``_JWKS_TTL`` seconds is returned as-is;
    otherwise the tenant's v2.0 discovery keys endpoint is fetched, validated
    and cached.

    Raises:
        httpx.HTTPError: the request failed or returned an error status.
        ValueError: the body is not JSON, or not an object with a ``keys`` list.
    """
    cached = _jwks_cache.get(tenant_id)
    # time.monotonic() is immune to wall-clock adjustments (NTP steps, manual
    # changes) that could otherwise expire the cache early or pin it forever.
    if cached is not None and (time.monotonic() - cached[1]) < _JWKS_TTL:
        return cached[0]

    url = f"https://login.microsoftonline.com/{tenant_id}/discovery/v2.0/keys"
    async with httpx.AsyncClient(timeout=10.0) as client:
        resp = await client.get(url)
        resp.raise_for_status()
        jwks = resp.json()  # a non-JSON body raises JSONDecodeError (a ValueError)

    # Never cache a malformed document: it would break every key lookup for
    # this tenant until the TTL expired.
    if not isinstance(jwks, dict) or not isinstance(jwks.get("keys"), list):
        raise ValueError(f"Malformed JWKS response from {url}")

    _jwks_cache[tenant_id] = (jwks, time.monotonic())
    return jwks
def _find_signing_key(jwks: dict[str, Any], kid: str) -> Optional[dict[str, Any]]:
    """Return the first JWK in ``jwks["keys"]`` whose ``kid`` equals *kid*.

    Returns None when the document has no ``keys`` entry or no key matches.
    """
    candidates = jwks.get("keys", [])
    return next((jwk for jwk in candidates if jwk.get("kid") == kid), None)
# tenant_id -> time.monotonic() of the most recent kid-miss JWKS refetch.
_kid_miss_refetched_at: dict[str, float] = {}
_KID_MISS_REFETCH_INTERVAL = 300.0  # seconds


def _kid_miss_refetch_allowed(tenant_id: str) -> bool:
    """Return True, and record the attempt, if a kid-miss refetch may run now.

    Rate-limits refetches per tenant so tokens carrying unknown ``kid``
    values cannot force a JWKS download on every request.
    """
    now = time.monotonic()
    last = _kid_miss_refetched_at.get(tenant_id)
    if last is not None and now - last < _KID_MISS_REFETCH_INTERVAL:
        return False
    _kid_miss_refetched_at[tenant_id] = now
    return True


class EntraDriver(IdentityDriver):
    """Identity driver for direct Entra JWT validation.

    Configuration keys read from ``self.config``:

    * ``_token_data`` -- decoded token claims; authentication is denied when
      this is empty.
    * ``_raw_token`` and ``entra_tenant_id`` -- when both are present, the raw
      JWT is verified against the tenant's published JWKS (signature, exp,
      nbf, iss, aud). The verified claims then replace ``_token_data``.
    * ``entra_client_id`` -- audience required during verification.
    * ``requested_accord`` / ``accord_roles`` -- map the requested accord to
      the role it requires.
    * ``elevated_suffix`` -- suffix marking elevated roles
      (default ``"-elevated"``).
    * ``domain`` -- did:web domain for the principal DID
      (default ``"guildhouse.dev"``).
    """

    async def authenticate(self) -> AuthResult:
        """Authenticate the caller from Entra token claims.

        Returns:
            ``STATUS_DENIED`` when any of the following holds:
            no claims are in context, verification fails, the tenant's
            signing keys cannot be retrieved, or ``oid`` is missing.

            ``STATUS_PENDING_ELEVATION`` when the requested accord requires
            a role the token does not carry.

            Otherwise ``STATUS_AUTHORIZED`` with the principal DID, MFA
            status, device id and any active elevated roles.
        """
        token_data = self.config.get("_token_data", {})
        raw_token = self.config.get("_raw_token", "")
        tenant_id = self.config.get("entra_tenant_id", "")
        expected_audience = self.config.get("entra_client_id", "")

        if not token_data:
            return self._deny("No token in context.")

        if raw_token and tenant_id:
            verified = await self._verify_raw_token(
                raw_token, tenant_id, expected_audience
            )
            if isinstance(verified, AuthResult):
                return verified
            token_data = verified
        # NOTE(review): without both _raw_token and entra_tenant_id, the claims
        # in _token_data are used WITHOUT signature verification. Confirm every
        # caller in that configuration has already validated them.

        oid = token_data.get("oid", "")
        if not oid:
            return self._deny("Token missing oid claim.")

        # `or []` tolerates claims that are present but null in the payload.
        roles = token_data.get("roles") or []
        amr = token_data.get("amr") or []
        device_id = token_data.get("deviceid") or token_data.get("device_id")
        upn = token_data.get("preferred_username") or token_data.get("upn", "")
        display_name = token_data.get("name", upn)

        # Role gate for the requested accord.
        requested_accord = self.config.get("requested_accord", "")
        required_role = self.config.get("accord_roles", {}).get(requested_accord, "")
        if required_role and required_role not in roles:
            return AuthResult(
                status=AuthResult.STATUS_PENDING_ELEVATION,
                elevation_required=ElevationRequired(
                    role=required_role,
                    activation_url="/governance/elevate/",
                    instructions=f"Request elevation to '{required_role}' via POST /governance/elevate/",
                    mechanism="entra_pim",
                ),
            )

        suffix = self.config.get("elevated_suffix", "-elevated")
        elevation_active = [r for r in roles if r.endswith(suffix)]

        domain = self.config.get("domain", "guildhouse.dev")
        principal_did = f"did:web:{domain}:principal:{oid}"

        # The amr values "mfa" and "ngcmfa" are treated as multi-factor sign-in.
        mfa_satisfied = "mfa" in amr or "ngcmfa" in amr

        return AuthResult(
            status=AuthResult.STATUS_AUTHORIZED,
            principal_did=principal_did,
            display_name=display_name,
            stable_id=oid,
            token_jti=token_data.get("jti", ""),
            elevation_active=elevation_active,
            mfa_satisfied=mfa_satisfied,
            device_id=device_id,
        )

    @staticmethod
    def _deny(reason: str) -> AuthResult:
        """Build a STATUS_DENIED result carrying *reason*."""
        return AuthResult(status=AuthResult.STATUS_DENIED, denial_reason=reason)

    async def _verify_raw_token(
        self, raw_token: str, tenant_id: str, expected_audience: str
    ) -> "dict[str, Any] | AuthResult":
        """Verify *raw_token* against the JWKS of *tenant_id*.

        Returns the verified claims dict, or a STATUS_DENIED ``AuthResult``.
        Fails closed: if the signing keys cannot be retrieved, the token is
        denied rather than trusted unverified.
        """
        try:
            # The header is read unverified only to select the key; the
            # signature is checked by jose_jwt.decode below. Parsing it first
            # also keeps malformed tokens from triggering a JWKS fetch.
            kid = jose_jwt.get_unverified_header(raw_token).get("kid", "")
            signing_key = _find_signing_key(await _get_jwks(tenant_id), kid)
            if signing_key is None and _kid_miss_refetch_allowed(tenant_id):
                # Entra rotates signing keys; an unknown kid may be a key
                # published after the cached JWKS was fetched.
                _jwks_cache.pop(tenant_id, None)
                signing_key = _find_signing_key(await _get_jwks(tenant_id), kid)
            if signing_key is None:
                return self._deny(f"No matching signing key for kid={kid}")
            # Verify signature, exp, nbf, iss, aud.
            return jose_jwt.decode(
                raw_token,
                signing_key,
                algorithms=["RS256"],
                audience=expected_audience,
                issuer=f"https://login.microsoftonline.com/{tenant_id}/v2.0",
            )
        except JWTError as e:
            return self._deny(f"JWT verification failed: {e}")
        except (httpx.HTTPError, ValueError) as e:
            # ValueError covers a non-JSON or malformed JWKS response body.
            logger.warning("JWKS fetch failed for tenant %s; denying: %s", tenant_id, e)
            return self._deny("JWKS fetch failed; token could not be verified.")

    async def revoke(self, session_id: str) -> None:
        """Handle a revocation request for *session_id*.

        Only logs the request; this driver makes no revocation call to Entra.
        """
        logger.info("Entra revoke: %s", session_id)