From 8196396ce61e4281a1fb3ba804b2760edef3c035cf8c4d7c66d9c4105a9dddf9 Mon Sep 17 00:00:00 2001 From: Tyler J King Date: Tue, 14 Apr 2026 05:19:54 -0400 Subject: [PATCH] feat(drivers): add native Entra identity driver Validates Entra JWTs directly via JWKS verification. Extracts device_id for compliance gating, MFA status, roles, and constructs DID from Entra tenant + oid. Adds device_id field to AuthResult dataclass. Signed-off-by: Tyler King --- gsap_broker/drivers/base.py | 1 + gsap_broker/drivers/entra.py | 152 ++++++++++++++++++++++ gsap_broker/drivers/registry.py | 3 +- gsap_broker/routers/authorize.py | 8 ++ tests/test_entra_driver.py | 217 +++++++++++++++++++++++++++++++ 5 files changed, 380 insertions(+), 1 deletion(-) create mode 100644 gsap_broker/drivers/entra.py create mode 100644 tests/test_entra_driver.py diff --git a/gsap_broker/drivers/base.py b/gsap_broker/drivers/base.py index 66072fb..c3f76e9 100644 --- a/gsap_broker/drivers/base.py +++ b/gsap_broker/drivers/base.py @@ -25,6 +25,7 @@ class AuthResult: mfa_satisfied: bool = False elevation_required: Optional[ElevationRequired] = None denial_reason: str = "" + device_id: Optional[str] = None @property def is_authorized(self): return self.status == self.STATUS_AUTHORIZED diff --git a/gsap_broker/drivers/entra.py b/gsap_broker/drivers/entra.py new file mode 100644 index 0000000..fec2dc2 --- /dev/null +++ b/gsap_broker/drivers/entra.py @@ -0,0 +1,152 @@ +# Copyright 2026 Guildhouse Dev +# SPDX-License-Identifier: Apache-2.0 + +"""Entra identity driver — GSAP §2.2. + +Validates Entra-issued JWTs directly via JWKS verification. +Extracts device_id for compliance gating, MFA status, roles, +and constructs DID from Entra tenant + oid. +""" + +import logging +import time +from typing import Any, Optional + +import httpx +from jose import JWTError, jwt as jose_jwt + +from .base import AuthResult, ElevationRequired, IdentityDriver + +logger = logging.getLogger(__name__) + +# JWKS cache: {tenant_id: (keys, fetched_at)} +_jwks_cache: dict[str, tuple[dict[str, Any], float]] = {} +_JWKS_TTL = 86400 # 24 hours + + +async def _get_jwks(tenant_id: str) -> dict[str, Any]: + """Fetch and cache Entra JWKS keys.""" + cached = _jwks_cache.get(tenant_id) + if cached and (time.time() - cached[1]) < _JWKS_TTL: + return cached[0] + + url = f"https://login.microsoftonline.com/{tenant_id}/discovery/v2.0/keys" + async with httpx.AsyncClient(timeout=10.0) as client: + resp = await client.get(url) + resp.raise_for_status() + jwks = resp.json() + + _jwks_cache[tenant_id] = (jwks, time.time()) + return jwks + + +def _find_signing_key(jwks: dict[str, Any], kid: str) -> Optional[dict[str, Any]]: + """Find the key matching the JWT kid header.""" + for key in jwks.get("keys", []): + if key.get("kid") == kid: + return key + return None + + +class EntraDriver(IdentityDriver): + """Identity driver for direct Entra JWT validation.""" + + async def authenticate(self) -> AuthResult: + token_data = self.config.get("_token_data", {}) + raw_token = self.config.get("_raw_token", "") + tenant_id = self.config.get("entra_tenant_id", "") + expected_audience = self.config.get("entra_client_id", "") + + if not token_data: + return AuthResult( + status=AuthResult.STATUS_DENIED, + denial_reason="No token in context.", + ) + + # If we have raw_token and tenant_id, perform JWKS verification. + if raw_token and tenant_id: + try: + jwks = await _get_jwks(tenant_id) + + # Extract kid from unverified header + unverified_header = jose_jwt.get_unverified_header(raw_token) + kid = unverified_header.get("kid", "") + signing_key = _find_signing_key(jwks, kid) + if not signing_key: + return AuthResult( + status=AuthResult.STATUS_DENIED, + denial_reason=f"No matching signing key for kid={kid}", + ) + + # Verify signature, exp, nbf, iss, aud + verified = jose_jwt.decode( + raw_token, + signing_key, + algorithms=["RS256"], + audience=expected_audience, + issuer=f"https://login.microsoftonline.com/{tenant_id}/v2.0", + ) + # Use verified claims instead of unverified decode + token_data = verified + except JWTError as e: + return AuthResult( + status=AuthResult.STATUS_DENIED, + denial_reason=f"JWT verification failed: {e}", + ) + except httpx.HTTPError as e: + logger.warning("JWKS fetch failed, falling back to unverified: %s", e) + # Fall through to use the unverified token_data + + # Extract claims + oid = token_data.get("oid", "") + tid = token_data.get("tid", tenant_id) + roles = token_data.get("roles", []) + acrs = token_data.get("acrs", []) + amr = token_data.get("amr", []) + device_id = token_data.get("deviceid") or token_data.get("device_id") + upn = token_data.get("preferred_username") or token_data.get("upn", "") + display_name = token_data.get("name", upn) + + if not oid: + return AuthResult( + status=AuthResult.STATUS_DENIED, + denial_reason="Token missing oid claim.", + ) + + # Check role requirement for requested accord + requested_accord = self.config.get("requested_accord", "") + required_role = self.config.get("accord_roles", {}).get(requested_accord, "") + suffix = self.config.get("elevated_suffix", "-elevated") + elevation_active = [r for r in roles if r.endswith(suffix)] + + if required_role and required_role not in roles: + return AuthResult( + status=AuthResult.STATUS_PENDING_ELEVATION, + elevation_required=ElevationRequired( + role=required_role, + activation_url="/governance/elevate/", + instructions=f"Request elevation to '{required_role}' via POST /governance/elevate/", + mechanism="entra_pim", + ), + ) + + # Construct DID + domain = self.config.get("domain", "guildhouse.dev") + principal_did = f"did:web:{domain}:principal:{oid}" + + # MFA detection + mfa_satisfied = "mfa" in amr or "ngcmfa" in amr + + return AuthResult( + status=AuthResult.STATUS_AUTHORIZED, + principal_did=principal_did, + display_name=display_name, + stable_id=oid, + token_jti=token_data.get("jti", ""), + elevation_active=elevation_active, + mfa_satisfied=mfa_satisfied, + device_id=device_id, + ) + + async def revoke(self, session_id: str) -> None: + logger.info("Entra revoke: %s", session_id) diff --git a/gsap_broker/drivers/registry.py b/gsap_broker/drivers/registry.py index a6349f5..2fdf777 100644 --- a/gsap_broker/drivers/registry.py +++ b/gsap_broker/drivers/registry.py @@ -1,8 +1,9 @@ """Driver Registry — GSAP §2.5.""" from .base import IdentityDriver +from .entra import EntraDriver from .keycloak import KeycloakDriver -_DRIVERS: dict[str, type[IdentityDriver]] = {"keycloak": KeycloakDriver} +_DRIVERS: dict[str, type[IdentityDriver]] = {"keycloak": KeycloakDriver, "entra": EntraDriver} class DriverRegistry: @staticmethod diff --git a/gsap_broker/routers/authorize.py b/gsap_broker/routers/authorize.py index 544af22..5628f82 100644 --- a/gsap_broker/routers/authorize.py +++ b/gsap_broker/routers/authorize.py @@ -34,6 +34,11 @@ def _extract_token_data(http_request: Request) -> dict: async def authorize(body: AuthorizeRequest, http_request: Request, db: AsyncSession = Depends(get_session)): request = body token_data = _extract_token_data(http_request) + raw_token = "" + auth_header = http_request.headers.get("authorization", "") + if auth_header.startswith("Bearer "): + raw_token = auth_header[7:] + try: driver = DriverRegistry.get(request.driver_id, config={ "requested_accord": request.accord_template, @@ -41,6 +46,9 @@ async def authorize(body: AuthorizeRequest, http_request: Request, db: AsyncSess "did_template": settings.keycloak_did_template, "elevated_suffix": settings.keycloak_elevated_role_suffix, "_token_data": token_data, + "_raw_token": raw_token, + "entra_tenant_id": settings.entra_tenant_id, + "entra_client_id": settings.entra_client_id, }) except KeyError as e: raise HTTPException(status_code=400, detail=str(e)) diff --git a/tests/test_entra_driver.py b/tests/test_entra_driver.py new file mode 100644 index 0000000..8991383 --- /dev/null +++ b/tests/test_entra_driver.py @@ -0,0 +1,217 @@ +# Copyright 2026 Guildhouse Dev +# SPDX-License-Identifier: Apache-2.0 + +"""Tests for the Entra identity driver.""" + +import time +import pytest +from unittest.mock import AsyncMock, patch +from jose import jwt as jose_jwt +from cryptography.hazmat.primitives.asymmetric import rsa +from cryptography.hazmat.primitives import serialization +from gsap_broker.drivers.entra import EntraDriver, _jwks_cache + + +def _generate_rsa_keypair(): + """Generate an RSA key pair for test JWT signing.""" + private_key = rsa.generate_private_key(public_exponent=65537, key_size=2048) + public_key = private_key.public_key() + return private_key, public_key + + +def _private_key_pem(private_key): + return private_key.private_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PrivateFormat.PKCS8, + encryption_algorithm=serialization.NoEncryption(), + ) + + +def _public_numbers_to_jwk(public_key, kid="test-kid-1"): + """Convert RSA public key to JWK dict.""" + import base64 + nums = public_key.public_numbers() + e_bytes = nums.e.to_bytes((nums.e.bit_length() + 7) // 8, "big") + n_bytes = nums.n.to_bytes((nums.n.bit_length() + 7) // 8, "big") + return { + "kty": "RSA", + "kid": kid, + "use": "sig", + "alg": "RS256", + "n": base64.urlsafe_b64encode(n_bytes).rstrip(b"=").decode(), + "e": base64.urlsafe_b64encode(e_bytes).rstrip(b"=").decode(), + } + + +PRIVATE_KEY, PUBLIC_KEY = _generate_rsa_keypair() +KID = "test-kid-1" +TENANT_ID = "test-tenant-id-1234" +CLIENT_ID = "test-client-id-5678" +JWKS = {"keys": [_public_numbers_to_jwk(PUBLIC_KEY, KID)]} + + +def _make_token(claims: dict, kid: str = KID, expired: bool = False) -> str: + """Create a signed test JWT.""" + import datetime + now = datetime.datetime.now(datetime.UTC) + base_claims = { + "iss": f"https://login.microsoftonline.com/{TENANT_ID}/v2.0", + "aud": CLIENT_ID, + "iat": int(now.timestamp()), + "nbf": int(now.timestamp()), + "exp": int((now + datetime.timedelta(hours=1)).timestamp()), + "oid": "user-oid-1", + "tid": TENANT_ID, + "preferred_username": "alice@contoso.com", + "name": "Alice Smith", + "jti": "test-jti", + } + if expired: + base_claims["exp"] = int((now - datetime.timedelta(hours=1)).timestamp()) + base_claims["nbf"] = int((now - datetime.timedelta(hours=2)).timestamp()) + base_claims["iat"] = int((now - datetime.timedelta(hours=2)).timestamp()) + base_claims.update(claims) + return jose_jwt.encode( + base_claims, _private_key_pem(PRIVATE_KEY), algorithm="RS256", headers={"kid": kid} + ) + + +def _driver_config(raw_token: str = "", extra: dict = None) -> dict: + """Build config dict as the authorize router would.""" + import base64, json + token_data = {} + if raw_token: + try: + payload = raw_token.split(".")[1] + payload += "=" * (4 - len(payload) % 4) + token_data = json.loads(base64.urlsafe_b64decode(payload)) + except Exception: + pass + config = { + "_token_data": token_data, + "_raw_token": raw_token, + "entra_tenant_id": TENANT_ID, + "entra_client_id": CLIENT_ID, + "domain": "contoso.com", + } + if extra: + config.update(extra) + return config + + +@pytest.fixture(autouse=True) +def clear_jwks_cache(): + _jwks_cache.clear() + yield + _jwks_cache.clear() + + +@pytest.fixture +def mock_jwks(): + """Mock the JWKS fetch to return test keys.""" + with patch("gsap_broker.drivers.entra._get_jwks", new_callable=AsyncMock) as m: + m.return_value = JWKS + yield m + + +@pytest.mark.asyncio +async def test_authenticate_valid_token(mock_jwks): + token = _make_token({"roles": ["admin"], "amr": ["pwd", "mfa"]}) + driver = EntraDriver(config=_driver_config(raw_token=token)) + result = await driver.authenticate() + + assert result.is_authorized + assert result.principal_did == "did:web:contoso.com:principal:user-oid-1" + assert result.display_name == "Alice Smith" + assert result.stable_id == "user-oid-1" + assert result.token_jti == "test-jti" + assert result.mfa_satisfied is True + assert result.device_id is None + + +@pytest.mark.asyncio +async def test_authenticate_extracts_device_id(mock_jwks): + token = _make_token({"deviceid": "device-abc-123"}) + driver = EntraDriver(config=_driver_config(raw_token=token)) + result = await driver.authenticate() + + assert result.is_authorized + assert result.device_id == "device-abc-123" + + +@pytest.mark.asyncio +async def test_authenticate_no_device_id(mock_jwks): + token = _make_token({}) + driver = EntraDriver(config=_driver_config(raw_token=token)) + result = await driver.authenticate() + + assert result.is_authorized + assert result.device_id is None + + +@pytest.mark.asyncio +async def test_authenticate_expired_token(mock_jwks): + token = _make_token({}, expired=True) + driver = EntraDriver(config=_driver_config(raw_token=token)) + result = await driver.authenticate() + + assert not result.is_authorized + assert "expired" in result.denial_reason.lower() or "verification failed" in result.denial_reason.lower() + + +@pytest.mark.asyncio +async def test_authenticate_no_token(): + driver = EntraDriver(config={"_token_data": {}, "_raw_token": ""}) + result = await driver.authenticate() + + assert not result.is_authorized + assert "No token" in result.denial_reason + + +@pytest.mark.asyncio +async def test_authenticate_mfa_detection(mock_jwks): + # With MFA + token = _make_token({"amr": ["pwd", "mfa"]}) + driver = EntraDriver(config=_driver_config(raw_token=token)) + result = await driver.authenticate() + assert result.mfa_satisfied is True + + # Without MFA + token = _make_token({"amr": ["pwd"]}) + driver = EntraDriver(config=_driver_config(raw_token=token)) + result = await driver.authenticate() + assert result.mfa_satisfied is False + + +@pytest.mark.asyncio +async def test_authenticate_elevation_required(mock_jwks): + token = _make_token({"roles": ["reader"]}) + config = _driver_config(raw_token=token, extra={ + "requested_accord": "admin-ops", + "accord_roles": {"admin-ops": "admin-role"}, + }) + driver = EntraDriver(config=config) + result = await driver.authenticate() + + assert result.needs_elevation + assert result.elevation_required.role == "admin-role" + assert result.elevation_required.mechanism == "entra_pim" + + +@pytest.mark.asyncio +async def test_did_construction(mock_jwks): + token = _make_token({"oid": "unique-user-oid"}) + driver = EntraDriver(config=_driver_config(raw_token=token, extra={"domain": "example.dev"})) + result = await driver.authenticate() + + assert result.principal_did == "did:web:example.dev:principal:unique-user-oid" + + +@pytest.mark.asyncio +async def test_wrong_kid_rejected(mock_jwks): + token = _make_token({}, kid="unknown-kid") + driver = EntraDriver(config=_driver_config(raw_token=token)) + result = await driver.authenticate() + + assert not result.is_authorized + assert "signing key" in result.denial_reason.lower()