feat(drivers): add native Entra identity driver
Validates Entra JWTs directly via JWKS verification. Extracts device_id for compliance gating, MFA status, roles, and constructs DID from Entra tenant + oid. Adds device_id field to AuthResult dataclass. Signed-off-by: Tyler King <tking@guildhouse.dev>
This commit is contained in:
parent
1ab47417c9
commit
8196396ce6
5 changed files with 380 additions and 1 deletions
|
|
@ -25,6 +25,7 @@ class AuthResult:
|
||||||
mfa_satisfied: bool = False
|
mfa_satisfied: bool = False
|
||||||
elevation_required: Optional[ElevationRequired] = None
|
elevation_required: Optional[ElevationRequired] = None
|
||||||
denial_reason: str = ""
|
denial_reason: str = ""
|
||||||
|
device_id: Optional[str] = None
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def is_authorized(self): return self.status == self.STATUS_AUTHORIZED
|
def is_authorized(self): return self.status == self.STATUS_AUTHORIZED
|
||||||
|
|
|
||||||
152
gsap_broker/drivers/entra.py
Normal file
152
gsap_broker/drivers/entra.py
Normal file
|
|
@ -0,0 +1,152 @@
|
||||||
|
# Copyright 2026 Guildhouse Dev
|
||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
|
"""Entra identity driver — GSAP §2.2.
|
||||||
|
|
||||||
|
Validates Entra-issued JWTs directly via JWKS verification.
|
||||||
|
Extracts device_id for compliance gating, MFA status, roles,
|
||||||
|
and constructs DID from Entra tenant + oid.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import time
|
||||||
|
from typing import Any, Optional
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
from jose import JWTError, jwt as jose_jwt
|
||||||
|
|
||||||
|
from .base import AuthResult, ElevationRequired, IdentityDriver
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# JWKS cache: {tenant_id: (keys, fetched_at)}
|
||||||
|
_jwks_cache: dict[str, tuple[dict[str, Any], float]] = {}
|
||||||
|
_JWKS_TTL = 86400 # 24 hours
|
||||||
|
|
||||||
|
|
||||||
|
async def _get_jwks(tenant_id: str) -> dict[str, Any]:
|
||||||
|
"""Fetch and cache Entra JWKS keys."""
|
||||||
|
cached = _jwks_cache.get(tenant_id)
|
||||||
|
if cached and (time.time() - cached[1]) < _JWKS_TTL:
|
||||||
|
return cached[0]
|
||||||
|
|
||||||
|
url = f"https://login.microsoftonline.com/{tenant_id}/discovery/v2.0/keys"
|
||||||
|
async with httpx.AsyncClient(timeout=10.0) as client:
|
||||||
|
resp = await client.get(url)
|
||||||
|
resp.raise_for_status()
|
||||||
|
jwks = resp.json()
|
||||||
|
|
||||||
|
_jwks_cache[tenant_id] = (jwks, time.time())
|
||||||
|
return jwks
|
||||||
|
|
||||||
|
|
||||||
|
def _find_signing_key(jwks: dict[str, Any], kid: str) -> Optional[dict[str, Any]]:
|
||||||
|
"""Find the key matching the JWT kid header."""
|
||||||
|
for key in jwks.get("keys", []):
|
||||||
|
if key.get("kid") == kid:
|
||||||
|
return key
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
class EntraDriver(IdentityDriver):
|
||||||
|
"""Identity driver for direct Entra JWT validation."""
|
||||||
|
|
||||||
|
async def authenticate(self) -> AuthResult:
|
||||||
|
token_data = self.config.get("_token_data", {})
|
||||||
|
raw_token = self.config.get("_raw_token", "")
|
||||||
|
tenant_id = self.config.get("entra_tenant_id", "")
|
||||||
|
expected_audience = self.config.get("entra_client_id", "")
|
||||||
|
|
||||||
|
if not token_data:
|
||||||
|
return AuthResult(
|
||||||
|
status=AuthResult.STATUS_DENIED,
|
||||||
|
denial_reason="No token in context.",
|
||||||
|
)
|
||||||
|
|
||||||
|
# If we have raw_token and tenant_id, perform JWKS verification.
|
||||||
|
if raw_token and tenant_id:
|
||||||
|
try:
|
||||||
|
jwks = await _get_jwks(tenant_id)
|
||||||
|
|
||||||
|
# Extract kid from unverified header
|
||||||
|
unverified_header = jose_jwt.get_unverified_header(raw_token)
|
||||||
|
kid = unverified_header.get("kid", "")
|
||||||
|
signing_key = _find_signing_key(jwks, kid)
|
||||||
|
if not signing_key:
|
||||||
|
return AuthResult(
|
||||||
|
status=AuthResult.STATUS_DENIED,
|
||||||
|
denial_reason=f"No matching signing key for kid={kid}",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify signature, exp, nbf, iss, aud
|
||||||
|
verified = jose_jwt.decode(
|
||||||
|
raw_token,
|
||||||
|
signing_key,
|
||||||
|
algorithms=["RS256"],
|
||||||
|
audience=expected_audience,
|
||||||
|
issuer=f"https://login.microsoftonline.com/{tenant_id}/v2.0",
|
||||||
|
)
|
||||||
|
# Use verified claims instead of unverified decode
|
||||||
|
token_data = verified
|
||||||
|
except JWTError as e:
|
||||||
|
return AuthResult(
|
||||||
|
status=AuthResult.STATUS_DENIED,
|
||||||
|
denial_reason=f"JWT verification failed: {e}",
|
||||||
|
)
|
||||||
|
except httpx.HTTPError as e:
|
||||||
|
logger.warning("JWKS fetch failed, falling back to unverified: %s", e)
|
||||||
|
# Fall through to use the unverified token_data
|
||||||
|
|
||||||
|
# Extract claims
|
||||||
|
oid = token_data.get("oid", "")
|
||||||
|
tid = token_data.get("tid", tenant_id)
|
||||||
|
roles = token_data.get("roles", [])
|
||||||
|
acrs = token_data.get("acrs", [])
|
||||||
|
amr = token_data.get("amr", [])
|
||||||
|
device_id = token_data.get("deviceid") or token_data.get("device_id")
|
||||||
|
upn = token_data.get("preferred_username") or token_data.get("upn", "")
|
||||||
|
display_name = token_data.get("name", upn)
|
||||||
|
|
||||||
|
if not oid:
|
||||||
|
return AuthResult(
|
||||||
|
status=AuthResult.STATUS_DENIED,
|
||||||
|
denial_reason="Token missing oid claim.",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check role requirement for requested accord
|
||||||
|
requested_accord = self.config.get("requested_accord", "")
|
||||||
|
required_role = self.config.get("accord_roles", {}).get(requested_accord, "")
|
||||||
|
suffix = self.config.get("elevated_suffix", "-elevated")
|
||||||
|
elevation_active = [r for r in roles if r.endswith(suffix)]
|
||||||
|
|
||||||
|
if required_role and required_role not in roles:
|
||||||
|
return AuthResult(
|
||||||
|
status=AuthResult.STATUS_PENDING_ELEVATION,
|
||||||
|
elevation_required=ElevationRequired(
|
||||||
|
role=required_role,
|
||||||
|
activation_url="/governance/elevate/",
|
||||||
|
instructions=f"Request elevation to '{required_role}' via POST /governance/elevate/",
|
||||||
|
mechanism="entra_pim",
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Construct DID
|
||||||
|
domain = self.config.get("domain", "guildhouse.dev")
|
||||||
|
principal_did = f"did:web:{domain}:principal:{oid}"
|
||||||
|
|
||||||
|
# MFA detection
|
||||||
|
mfa_satisfied = "mfa" in amr or "ngcmfa" in amr
|
||||||
|
|
||||||
|
return AuthResult(
|
||||||
|
status=AuthResult.STATUS_AUTHORIZED,
|
||||||
|
principal_did=principal_did,
|
||||||
|
display_name=display_name,
|
||||||
|
stable_id=oid,
|
||||||
|
token_jti=token_data.get("jti", ""),
|
||||||
|
elevation_active=elevation_active,
|
||||||
|
mfa_satisfied=mfa_satisfied,
|
||||||
|
device_id=device_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def revoke(self, session_id: str) -> None:
|
||||||
|
logger.info("Entra revoke: %s", session_id)
|
||||||
|
|
@ -1,8 +1,9 @@
|
||||||
"""Driver Registry — GSAP §2.5."""
|
"""Driver Registry — GSAP §2.5."""
|
||||||
from .base import IdentityDriver
|
from .base import IdentityDriver
|
||||||
|
from .entra import EntraDriver
|
||||||
from .keycloak import KeycloakDriver
|
from .keycloak import KeycloakDriver
|
||||||
|
|
||||||
_DRIVERS: dict[str, type[IdentityDriver]] = {"keycloak": KeycloakDriver}
|
_DRIVERS: dict[str, type[IdentityDriver]] = {"keycloak": KeycloakDriver, "entra": EntraDriver}
|
||||||
|
|
||||||
class DriverRegistry:
|
class DriverRegistry:
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
|
|
||||||
|
|
@ -34,6 +34,11 @@ def _extract_token_data(http_request: Request) -> dict:
|
||||||
async def authorize(body: AuthorizeRequest, http_request: Request, db: AsyncSession = Depends(get_session)):
|
async def authorize(body: AuthorizeRequest, http_request: Request, db: AsyncSession = Depends(get_session)):
|
||||||
request = body
|
request = body
|
||||||
token_data = _extract_token_data(http_request)
|
token_data = _extract_token_data(http_request)
|
||||||
|
raw_token = ""
|
||||||
|
auth_header = http_request.headers.get("authorization", "")
|
||||||
|
if auth_header.startswith("Bearer "):
|
||||||
|
raw_token = auth_header[7:]
|
||||||
|
|
||||||
try:
|
try:
|
||||||
driver = DriverRegistry.get(request.driver_id, config={
|
driver = DriverRegistry.get(request.driver_id, config={
|
||||||
"requested_accord": request.accord_template,
|
"requested_accord": request.accord_template,
|
||||||
|
|
@ -41,6 +46,9 @@ async def authorize(body: AuthorizeRequest, http_request: Request, db: AsyncSess
|
||||||
"did_template": settings.keycloak_did_template,
|
"did_template": settings.keycloak_did_template,
|
||||||
"elevated_suffix": settings.keycloak_elevated_role_suffix,
|
"elevated_suffix": settings.keycloak_elevated_role_suffix,
|
||||||
"_token_data": token_data,
|
"_token_data": token_data,
|
||||||
|
"_raw_token": raw_token,
|
||||||
|
"entra_tenant_id": settings.entra_tenant_id,
|
||||||
|
"entra_client_id": settings.entra_client_id,
|
||||||
})
|
})
|
||||||
except KeyError as e:
|
except KeyError as e:
|
||||||
raise HTTPException(status_code=400, detail=str(e))
|
raise HTTPException(status_code=400, detail=str(e))
|
||||||
|
|
|
||||||
217
tests/test_entra_driver.py
Normal file
217
tests/test_entra_driver.py
Normal file
|
|
@ -0,0 +1,217 @@
|
||||||
|
# Copyright 2026 Guildhouse Dev
|
||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
|
"""Tests for the Entra identity driver."""
|
||||||
|
|
||||||
|
import time
|
||||||
|
import pytest
|
||||||
|
from unittest.mock import AsyncMock, patch
|
||||||
|
from jose import jwt as jose_jwt
|
||||||
|
from cryptography.hazmat.primitives.asymmetric import rsa
|
||||||
|
from cryptography.hazmat.primitives import serialization
|
||||||
|
from gsap_broker.drivers.entra import EntraDriver, _jwks_cache
|
||||||
|
|
||||||
|
|
||||||
|
def _generate_rsa_keypair():
|
||||||
|
"""Generate an RSA key pair for test JWT signing."""
|
||||||
|
private_key = rsa.generate_private_key(public_exponent=65537, key_size=2048)
|
||||||
|
public_key = private_key.public_key()
|
||||||
|
return private_key, public_key
|
||||||
|
|
||||||
|
|
||||||
|
def _private_key_pem(private_key):
|
||||||
|
return private_key.private_bytes(
|
||||||
|
encoding=serialization.Encoding.PEM,
|
||||||
|
format=serialization.PrivateFormat.PKCS8,
|
||||||
|
encryption_algorithm=serialization.NoEncryption(),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _public_numbers_to_jwk(public_key, kid="test-kid-1"):
|
||||||
|
"""Convert RSA public key to JWK dict."""
|
||||||
|
import base64
|
||||||
|
nums = public_key.public_numbers()
|
||||||
|
e_bytes = nums.e.to_bytes((nums.e.bit_length() + 7) // 8, "big")
|
||||||
|
n_bytes = nums.n.to_bytes((nums.n.bit_length() + 7) // 8, "big")
|
||||||
|
return {
|
||||||
|
"kty": "RSA",
|
||||||
|
"kid": kid,
|
||||||
|
"use": "sig",
|
||||||
|
"alg": "RS256",
|
||||||
|
"n": base64.urlsafe_b64encode(n_bytes).rstrip(b"=").decode(),
|
||||||
|
"e": base64.urlsafe_b64encode(e_bytes).rstrip(b"=").decode(),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
PRIVATE_KEY, PUBLIC_KEY = _generate_rsa_keypair()
|
||||||
|
KID = "test-kid-1"
|
||||||
|
TENANT_ID = "test-tenant-id-1234"
|
||||||
|
CLIENT_ID = "test-client-id-5678"
|
||||||
|
JWKS = {"keys": [_public_numbers_to_jwk(PUBLIC_KEY, KID)]}
|
||||||
|
|
||||||
|
|
||||||
|
def _make_token(claims: dict, kid: str = KID, expired: bool = False) -> str:
|
||||||
|
"""Create a signed test JWT."""
|
||||||
|
import datetime
|
||||||
|
now = datetime.datetime.now(datetime.UTC)
|
||||||
|
base_claims = {
|
||||||
|
"iss": f"https://login.microsoftonline.com/{TENANT_ID}/v2.0",
|
||||||
|
"aud": CLIENT_ID,
|
||||||
|
"iat": int(now.timestamp()),
|
||||||
|
"nbf": int(now.timestamp()),
|
||||||
|
"exp": int((now + datetime.timedelta(hours=1)).timestamp()),
|
||||||
|
"oid": "user-oid-1",
|
||||||
|
"tid": TENANT_ID,
|
||||||
|
"preferred_username": "alice@contoso.com",
|
||||||
|
"name": "Alice Smith",
|
||||||
|
"jti": "test-jti",
|
||||||
|
}
|
||||||
|
if expired:
|
||||||
|
base_claims["exp"] = int((now - datetime.timedelta(hours=1)).timestamp())
|
||||||
|
base_claims["nbf"] = int((now - datetime.timedelta(hours=2)).timestamp())
|
||||||
|
base_claims["iat"] = int((now - datetime.timedelta(hours=2)).timestamp())
|
||||||
|
base_claims.update(claims)
|
||||||
|
return jose_jwt.encode(
|
||||||
|
base_claims, _private_key_pem(PRIVATE_KEY), algorithm="RS256", headers={"kid": kid}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _driver_config(raw_token: str = "", extra: dict = None) -> dict:
|
||||||
|
"""Build config dict as the authorize router would."""
|
||||||
|
import base64, json
|
||||||
|
token_data = {}
|
||||||
|
if raw_token:
|
||||||
|
try:
|
||||||
|
payload = raw_token.split(".")[1]
|
||||||
|
payload += "=" * (4 - len(payload) % 4)
|
||||||
|
token_data = json.loads(base64.urlsafe_b64decode(payload))
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
config = {
|
||||||
|
"_token_data": token_data,
|
||||||
|
"_raw_token": raw_token,
|
||||||
|
"entra_tenant_id": TENANT_ID,
|
||||||
|
"entra_client_id": CLIENT_ID,
|
||||||
|
"domain": "contoso.com",
|
||||||
|
}
|
||||||
|
if extra:
|
||||||
|
config.update(extra)
|
||||||
|
return config
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
def clear_jwks_cache():
|
||||||
|
_jwks_cache.clear()
|
||||||
|
yield
|
||||||
|
_jwks_cache.clear()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_jwks():
|
||||||
|
"""Mock the JWKS fetch to return test keys."""
|
||||||
|
with patch("gsap_broker.drivers.entra._get_jwks", new_callable=AsyncMock) as m:
|
||||||
|
m.return_value = JWKS
|
||||||
|
yield m
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_authenticate_valid_token(mock_jwks):
|
||||||
|
token = _make_token({"roles": ["admin"], "amr": ["pwd", "mfa"]})
|
||||||
|
driver = EntraDriver(config=_driver_config(raw_token=token))
|
||||||
|
result = await driver.authenticate()
|
||||||
|
|
||||||
|
assert result.is_authorized
|
||||||
|
assert result.principal_did == "did:web:contoso.com:principal:user-oid-1"
|
||||||
|
assert result.display_name == "Alice Smith"
|
||||||
|
assert result.stable_id == "user-oid-1"
|
||||||
|
assert result.token_jti == "test-jti"
|
||||||
|
assert result.mfa_satisfied is True
|
||||||
|
assert result.device_id is None
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_authenticate_extracts_device_id(mock_jwks):
|
||||||
|
token = _make_token({"deviceid": "device-abc-123"})
|
||||||
|
driver = EntraDriver(config=_driver_config(raw_token=token))
|
||||||
|
result = await driver.authenticate()
|
||||||
|
|
||||||
|
assert result.is_authorized
|
||||||
|
assert result.device_id == "device-abc-123"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_authenticate_no_device_id(mock_jwks):
|
||||||
|
token = _make_token({})
|
||||||
|
driver = EntraDriver(config=_driver_config(raw_token=token))
|
||||||
|
result = await driver.authenticate()
|
||||||
|
|
||||||
|
assert result.is_authorized
|
||||||
|
assert result.device_id is None
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_authenticate_expired_token(mock_jwks):
|
||||||
|
token = _make_token({}, expired=True)
|
||||||
|
driver = EntraDriver(config=_driver_config(raw_token=token))
|
||||||
|
result = await driver.authenticate()
|
||||||
|
|
||||||
|
assert not result.is_authorized
|
||||||
|
assert "expired" in result.denial_reason.lower() or "verification failed" in result.denial_reason.lower()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_authenticate_no_token():
|
||||||
|
driver = EntraDriver(config={"_token_data": {}, "_raw_token": ""})
|
||||||
|
result = await driver.authenticate()
|
||||||
|
|
||||||
|
assert not result.is_authorized
|
||||||
|
assert "No token" in result.denial_reason
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_authenticate_mfa_detection(mock_jwks):
|
||||||
|
# With MFA
|
||||||
|
token = _make_token({"amr": ["pwd", "mfa"]})
|
||||||
|
driver = EntraDriver(config=_driver_config(raw_token=token))
|
||||||
|
result = await driver.authenticate()
|
||||||
|
assert result.mfa_satisfied is True
|
||||||
|
|
||||||
|
# Without MFA
|
||||||
|
token = _make_token({"amr": ["pwd"]})
|
||||||
|
driver = EntraDriver(config=_driver_config(raw_token=token))
|
||||||
|
result = await driver.authenticate()
|
||||||
|
assert result.mfa_satisfied is False
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_authenticate_elevation_required(mock_jwks):
|
||||||
|
token = _make_token({"roles": ["reader"]})
|
||||||
|
config = _driver_config(raw_token=token, extra={
|
||||||
|
"requested_accord": "admin-ops",
|
||||||
|
"accord_roles": {"admin-ops": "admin-role"},
|
||||||
|
})
|
||||||
|
driver = EntraDriver(config=config)
|
||||||
|
result = await driver.authenticate()
|
||||||
|
|
||||||
|
assert result.needs_elevation
|
||||||
|
assert result.elevation_required.role == "admin-role"
|
||||||
|
assert result.elevation_required.mechanism == "entra_pim"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_did_construction(mock_jwks):
|
||||||
|
token = _make_token({"oid": "unique-user-oid"})
|
||||||
|
driver = EntraDriver(config=_driver_config(raw_token=token, extra={"domain": "example.dev"}))
|
||||||
|
result = await driver.authenticate()
|
||||||
|
|
||||||
|
assert result.principal_did == "did:web:example.dev:principal:unique-user-oid"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_wrong_kid_rejected(mock_jwks):
|
||||||
|
token = _make_token({}, kid="unknown-kid")
|
||||||
|
driver = EntraDriver(config=_driver_config(raw_token=token))
|
||||||
|
result = await driver.authenticate()
|
||||||
|
|
||||||
|
assert not result.is_authorized
|
||||||
|
assert "signing key" in result.denial_reason.lower()
|
||||||
Loading…
Reference in a new issue