# Copyright 2026 Guildhouse Dev
# SPDX-License-Identifier: Apache-2.0
"""Entra identity driver — GSAP §2.2.

Validates Entra-issued JWTs directly via JWKS verification.
Extracts device_id for compliance gating, MFA status and roles,
and constructs a did:web DID from the configured domain + oid.
"""

import logging
import time
from typing import Any, Optional

import httpx
from jose import JWTError, jwt as jose_jwt

from .base import AuthResult, ElevationRequired, IdentityDriver

logger = logging.getLogger(__name__)

# JWKS cache: {tenant_id: (keys, fetched_at)}
_jwks_cache: dict[str, tuple[dict[str, Any], float]] = {}
_JWKS_TTL = 86400  # 24 hours
# A forced refresh (triggered by an unknown kid) is honoured only when the
# cached entry is at least this old, so tokens carrying bogus kids cannot
# trigger one outbound JWKS request per authentication attempt.
_JWKS_MIN_REFRESH = 300  # 5 minutes


async def _get_jwks(tenant_id: str, force_refresh: bool = False) -> dict[str, Any]:
    """Fetch and cache the Entra JWKS document for a tenant.

    Args:
        tenant_id: Tenant identifier interpolated into the discovery URL.
        force_refresh: When true, a cached entry is reused only if it is
            younger than ``_JWKS_MIN_REFRESH`` seconds instead of
            ``_JWKS_TTL``. Used after a kid miss to pick up rotated keys.

    Returns:
        The parsed JWKS document (a dict, normally with a ``"keys"`` list).

    Raises:
        httpx.HTTPError: The request failed or returned an error status.
        ValueError: The response body was not JSON or not a JSON object.
    """
    cached = _jwks_cache.get(tenant_id)
    if cached:
        max_age = _JWKS_MIN_REFRESH if force_refresh else _JWKS_TTL
        if (time.time() - cached[1]) < max_age:
            return cached[0]

    url = f"https://login.microsoftonline.com/{tenant_id}/discovery/v2.0/keys"
    async with httpx.AsyncClient(timeout=10.0) as client:
        resp = await client.get(url)
        resp.raise_for_status()
        jwks = resp.json()

    # Never cache a malformed payload: it would poison lookups for the
    # whole TTL window.
    if not isinstance(jwks, dict):
        raise ValueError("JWKS response is not a JSON object")

    _jwks_cache[tenant_id] = (jwks, time.time())
    return jwks


def _find_signing_key(jwks: dict[str, Any], kid: str) -> Optional[dict[str, Any]]:
    """Return the first JWKS entry whose ``kid`` equals *kid*, else ``None``."""
    for key in jwks.get("keys", []):
        if key.get("kid") == kid:
            return key
    return None


class EntraDriver(IdentityDriver):
    """Identity driver for direct Entra JWT validation."""

    async def authenticate(self) -> AuthResult:
        """Authenticate the principal described by the token in ``self.config``.

        Config keys read: ``_token_data``, ``_raw_token``,
        ``entra_tenant_id``, ``entra_client_id``, ``requested_accord``,
        ``accord_roles``, ``elevated_suffix`` and ``domain``.

        Returns:
            An ``AuthResult`` that is:

            * denied when there is no token data, signature/claim
              verification fails, the signing keys cannot be retrieved,
              or the ``oid`` claim is missing;
            * pending elevation when the requested accord maps to a role
              the token does not carry;
            * authorized otherwise, carrying the DID, display name, MFA
              flag, device id and currently held elevated roles.
        """
        token_data = self.config.get("_token_data", {})
        raw_token = self.config.get("_raw_token", "")
        tenant_id = self.config.get("entra_tenant_id", "")
        expected_audience = self.config.get("entra_client_id", "")

        if not token_data:
            return AuthResult(
                status=AuthResult.STATUS_DENIED,
                denial_reason="No token in context.",
            )

        # If we have raw_token and tenant_id, perform JWKS verification.
        # NOTE(review): when either is absent the claims in _token_data are
        # used as-is; presumably the caller verified them upstream — confirm.
        if raw_token and tenant_id:
            try:
                # Parse the header first so a malformed token is rejected
                # without a network round trip.
                unverified_header = jose_jwt.get_unverified_header(raw_token)
                kid = unverified_header.get("kid", "")

                jwks = await _get_jwks(tenant_id)
                signing_key = _find_signing_key(jwks, kid)
                if signing_key is None and kid:
                    # Unknown kid: the tenant may have rotated its keys
                    # since the cache was filled, so retry once.
                    jwks = await _get_jwks(tenant_id, force_refresh=True)
                    signing_key = _find_signing_key(jwks, kid)

                if not signing_key:
                    return AuthResult(
                        status=AuthResult.STATUS_DENIED,
                        denial_reason=f"No matching signing key for kid={kid}",
                    )

                # Verify signature, exp, nbf, iss, aud
                verified = jose_jwt.decode(
                    raw_token,
                    signing_key,
                    algorithms=["RS256"],
                    audience=expected_audience,
                    issuer=f"https://login.microsoftonline.com/{tenant_id}/v2.0",
                )
                # Use verified claims instead of unverified decode
                token_data = verified
            except JWTError as e:
                return AuthResult(
                    status=AuthResult.STATUS_DENIED,
                    denial_reason=f"JWT verification failed: {e}",
                )
            except (httpx.HTTPError, ValueError) as e:
                # Fail closed: without the key set the signature cannot be
                # checked, and accepting unverified claims here would let a
                # JWKS outage disable authentication entirely.
                logger.warning("JWKS fetch failed, denying authentication: %s", e)
                return AuthResult(
                    status=AuthResult.STATUS_DENIED,
                    denial_reason="Unable to retrieve signing keys for token verification.",
                )

        # Extract claims. ``or []`` guards against claims that are present
        # but null, which would break the membership tests below.
        oid = token_data.get("oid", "")
        roles = token_data.get("roles") or []
        amr = token_data.get("amr") or []
        device_id = token_data.get("deviceid") or token_data.get("device_id")
        upn = token_data.get("preferred_username") or token_data.get("upn", "")
        display_name = token_data.get("name", upn)

        if not oid:
            return AuthResult(
                status=AuthResult.STATUS_DENIED,
                denial_reason="Token missing oid claim.",
            )

        # Check role requirement for requested accord
        requested_accord = self.config.get("requested_accord", "")
        required_role = self.config.get("accord_roles", {}).get(requested_accord, "")
        suffix = self.config.get("elevated_suffix", "-elevated")
        elevation_active = [r for r in roles if r.endswith(suffix)]

        if required_role and required_role not in roles:
            return AuthResult(
                status=AuthResult.STATUS_PENDING_ELEVATION,
                elevation_required=ElevationRequired(
                    role=required_role,
                    activation_url="/governance/elevate/",
                    instructions=(
                        f"Request elevation to '{required_role}' via "
                        "POST /governance/elevate/"
                    ),
                    mechanism="entra_pim",
                ),
            )

        # Construct DID
        domain = self.config.get("domain", "guildhouse.dev")
        principal_did = f"did:web:{domain}:principal:{oid}"

        # MFA detection
        mfa_satisfied = "mfa" in amr or "ngcmfa" in amr

        return AuthResult(
            status=AuthResult.STATUS_AUTHORIZED,
            principal_did=principal_did,
            display_name=display_name,
            stable_id=oid,
            token_jti=token_data.get("jti", ""),
            elevation_active=elevation_active,
            mfa_satisfied=mfa_satisfied,
            device_id=device_id,
        )

    async def revoke(self, session_id: str) -> None:
        """Record a revocation request for *session_id*.

        Only logs the session id; no request is made to Entra here.
        """
        logger.info("Entra revoke: %s", session_id)