fix(drivers): JWKS verification for Keycloak, remove Entra fallback, gate on_behalf_of

C-1: Keycloak driver now verifies JWT signatures via JWKS.
     Forged tokens are rejected. Previously any base64 JWT was accepted.
C-2: on_behalf_of requires gsap:impersonate role in JWT claims.
C-3: Entra driver denies on JWKS failure (no unverified fallback).
H-10: JWKS cache refreshes on kid miss for key rotation.

Shared JWKSVerifier used by both drivers. alg=none blocked.
iss, aud, exp validated for all tokens.

Signed-off-by: Tyler King <tking@guildhouse.dev>
This commit is contained in:
Tyler J King 2026-04-14 07:51:38 -04:00
parent 4dff879c84
commit 5015f3dd43
6 changed files with 433 additions and 130 deletions

View file

@ -6,98 +6,56 @@
Validates Entra-issued JWTs directly via JWKS verification. Validates Entra-issued JWTs directly via JWKS verification.
Extracts device_id for compliance gating, MFA status, roles, Extracts device_id for compliance gating, MFA status, roles,
and constructs DID from Entra tenant + oid. and constructs DID from Entra tenant + oid.
Fix C-3: JWKS fetch failure results in denial. Never falls back
to unverified claims.
Fix H-10: JWKS cache refreshes on kid miss for key rotation.
""" """
import logging import logging
import time from typing import Optional
from typing import Any, Optional
import httpx
from jose import JWTError, jwt as jose_jwt
from .base import AuthResult, ElevationRequired, IdentityDriver from .base import AuthResult, ElevationRequired, IdentityDriver
from .jwks import AuthenticationError, JWKSVerifier
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# JWKS cache: {tenant_id: (keys, fetched_at)}
_jwks_cache: dict[str, tuple[dict[str, Any], float]] = {}
_JWKS_TTL = 86400 # 24 hours
async def _get_jwks(tenant_id: str) -> dict[str, Any]:
"""Fetch and cache Entra JWKS keys."""
cached = _jwks_cache.get(tenant_id)
if cached and (time.time() - cached[1]) < _JWKS_TTL:
return cached[0]
url = f"https://login.microsoftonline.com/{tenant_id}/discovery/v2.0/keys"
async with httpx.AsyncClient(timeout=10.0) as client:
resp = await client.get(url)
resp.raise_for_status()
jwks = resp.json()
_jwks_cache[tenant_id] = (jwks, time.time())
return jwks
def _find_signing_key(jwks: dict[str, Any], kid: str) -> Optional[dict[str, Any]]:
"""Find the key matching the JWT kid header."""
for key in jwks.get("keys", []):
if key.get("kid") == kid:
return key
return None
class EntraDriver(IdentityDriver): class EntraDriver(IdentityDriver):
"""Identity driver for direct Entra JWT validation.""" """Identity driver for direct Entra JWT validation."""
async def authenticate(self) -> AuthResult: async def authenticate(self) -> AuthResult:
token_data = self.config.get("_token_data", {})
raw_token = self.config.get("_raw_token", "") raw_token = self.config.get("_raw_token", "")
tenant_id = self.config.get("entra_tenant_id", "") tenant_id = self.config.get("entra_tenant_id", "")
expected_audience = self.config.get("entra_client_id", "") expected_audience = self.config.get("entra_client_id", "")
if not token_data: if not raw_token:
return AuthResult( return AuthResult(
status=AuthResult.STATUS_DENIED, status=AuthResult.STATUS_DENIED,
denial_reason="No token in context.", denial_reason="No token in context.",
) )
# If we have raw_token and tenant_id, perform JWKS verification. if not tenant_id:
if raw_token and tenant_id: return AuthResult(
try: status=AuthResult.STATUS_DENIED,
jwks = await _get_jwks(tenant_id) denial_reason="Entra tenant_id not configured.",
)
# Extract kid from unverified header # Fix C-3: verify via JWKS — no fallback on failure
unverified_header = jose_jwt.get_unverified_header(raw_token) verifier = JWKSVerifier(
kid = unverified_header.get("kid", "") jwks_url=f"https://login.microsoftonline.com/{tenant_id}/discovery/v2.0/keys",
signing_key = _find_signing_key(jwks, kid) audience=expected_audience,
if not signing_key: issuer=f"https://login.microsoftonline.com/{tenant_id}/v2.0",
return AuthResult( )
status=AuthResult.STATUS_DENIED, try:
denial_reason=f"No matching signing key for kid={kid}", token_data = await verifier.verify_or_refresh(raw_token)
) except AuthenticationError as e:
return AuthResult(
status=AuthResult.STATUS_DENIED,
denial_reason=str(e),
)
# Verify signature, exp, nbf, iss, aud # Extract claims from VERIFIED token data
verified = jose_jwt.decode(
raw_token,
signing_key,
algorithms=["RS256"],
audience=expected_audience,
issuer=f"https://login.microsoftonline.com/{tenant_id}/v2.0",
)
# Use verified claims instead of unverified decode
token_data = verified
except JWTError as e:
return AuthResult(
status=AuthResult.STATUS_DENIED,
denial_reason=f"JWT verification failed: {e}",
)
except httpx.HTTPError as e:
logger.warning("JWKS fetch failed, falling back to unverified: %s", e)
# Fall through to use the unverified token_data
# Extract claims
oid = token_data.get("oid", "") oid = token_data.get("oid", "")
tid = token_data.get("tid", tenant_id) tid = token_data.get("tid", tenant_id)
roles = token_data.get("roles", []) roles = token_data.get("roles", [])

137
gsap_broker/drivers/jwks.py Normal file
View file

@ -0,0 +1,137 @@
# Copyright 2026 Guildhouse Dev
# SPDX-License-Identifier: Apache-2.0
"""Shared JWKS verification for identity drivers.
Fetches, caches, and verifies JWTs against JWKS endpoints.
Used by both Keycloak and Entra drivers.
SECURITY: Never falls back to unverified claims on JWKS failure.
A JWKS fetch failure MUST result in authentication denial.
Fix C-1: Keycloak tokens are now verified via this module.
Fix C-3: Entra tokens are denied on JWKS failure (no fallback).
Fix H-10: JWKS cache refreshes on kid miss for key rotation.
"""
import logging
import time
from typing import Any, Optional
import httpx
from jose import JWTError, jwt as jose_jwt
logger = logging.getLogger(__name__)
class AuthenticationError(Exception):
"""JWT verification failed. The request MUST be denied."""
class JWKSKeyNotFound(AuthenticationError):
"""The JWT's kid does not match any key in the cached JWKS."""
class JWKSVerifier:
"""Verifies JWTs against a remote JWKS endpoint.
Cache TTL defaults to 1 hour. On kid miss, the cache is
invalidated and JWKS is re-fetched once before rejecting.
"""
def __init__(
self,
jwks_url: str,
audience: str,
issuer: str,
cache_ttl: int = 3600,
):
self._jwks_url = jwks_url
self._audience = audience
self._issuer = issuer
self._cache_ttl = cache_ttl
self._jwks_cache: Optional[dict[str, Any]] = None
self._cache_fetched_at: float = 0.0
async def verify_token(self, raw_token: str) -> dict[str, Any]:
"""Verify JWT signature and standard claims.
Returns the verified claims dict.
Raises AuthenticationError on ANY failure NEVER falls back
to unverified claims.
"""
if not raw_token:
raise AuthenticationError("No token provided")
jwks = await self._fetch_jwks()
try:
unverified_header = jose_jwt.get_unverified_header(raw_token)
except JWTError as e:
raise AuthenticationError(f"Malformed JWT header: {e}")
kid = unverified_header.get("kid", "")
signing_key = self._find_key(jwks, kid)
if signing_key is None:
raise JWKSKeyNotFound(f"No matching key for kid={kid}")
try:
# algorithms=["RS256"] blocks alg=none and HMAC confusion
return jose_jwt.decode(
raw_token,
signing_key,
algorithms=["RS256"],
audience=self._audience,
issuer=self._issuer,
options={"require_exp": True},
)
except JWTError as e:
raise AuthenticationError(f"JWT verification failed: {e}")
async def verify_or_refresh(self, raw_token: str) -> dict[str, Any]:
"""Verify with cache; on kid miss, refresh JWKS once and retry.
Fix H-10: handles key rotation gracefully.
"""
try:
return await self.verify_token(raw_token)
except JWKSKeyNotFound:
# kid not in cache — force refresh and retry once
self._invalidate_cache()
return await self.verify_token(raw_token)
async def _fetch_jwks(self) -> dict[str, Any]:
"""Fetch and cache JWKS. Raises on failure — never falls back."""
if self._cache_valid():
return self._jwks_cache
try:
async with httpx.AsyncClient(timeout=10.0) as client:
resp = await client.get(self._jwks_url)
resp.raise_for_status()
self._jwks_cache = resp.json()
self._cache_fetched_at = time.time()
return self._jwks_cache
except Exception as e:
# Fix C-3: NEVER fall back — deny on failure
raise AuthenticationError(
f"JWKS fetch failed from {self._jwks_url}: {e}. "
"Cannot verify token — denying."
)
def _cache_valid(self) -> bool:
return (
self._jwks_cache is not None
and (time.time() - self._cache_fetched_at) < self._cache_ttl
)
def _invalidate_cache(self) -> None:
self._jwks_cache = None
self._cache_fetched_at = 0.0
@staticmethod
def _find_key(jwks: dict[str, Any], kid: str) -> Optional[dict[str, Any]]:
for key in jwks.get("keys", []):
if key.get("kid") == kid:
return key
return None

View file

@ -1,14 +1,47 @@
"""Keycloak identity driver — GSAP §2.2.""" # Copyright 2026 Guildhouse Dev
# SPDX-License-Identifier: Apache-2.0
"""Keycloak identity driver — GSAP §2.2.
Fix C-1: JWT signatures are now verified via JWKS.
Previously this driver accepted any base64-decoded JWT without
signature verification. Now uses shared JWKSVerifier.
"""
import logging import logging
from .base import IdentityDriver, AuthResult, ElevationRequired from .base import IdentityDriver, AuthResult, ElevationRequired
from .jwks import AuthenticationError, JWKSVerifier
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class KeycloakDriver(IdentityDriver): class KeycloakDriver(IdentityDriver):
async def authenticate(self) -> AuthResult: async def authenticate(self) -> AuthResult:
token_data = self.config.get("_token_data", {}) raw_token = self.config.get("_raw_token", "")
if not token_data: keycloak_url = self.config.get("keycloak_url", "http://localhost:8080")
return AuthResult(status=AuthResult.STATUS_DENIED, denial_reason="No token in context.") keycloak_realm = self.config.get("keycloak_realm", "substrate")
keycloak_client_id = self.config.get("keycloak_client_id", "")
# Fix C-1: verify JWT signature via JWKS before trusting claims.
if raw_token and keycloak_url and keycloak_realm:
verifier = JWKSVerifier(
jwks_url=f"{keycloak_url}/realms/{keycloak_realm}/protocol/openid-connect/certs",
audience=keycloak_client_id,
issuer=f"{keycloak_url}/realms/{keycloak_realm}",
)
try:
token_data = await verifier.verify_or_refresh(raw_token)
except AuthenticationError as e:
return AuthResult(
status=AuthResult.STATUS_DENIED,
denial_reason=str(e),
)
else:
# No raw token or no Keycloak config — cannot verify
return AuthResult(
status=AuthResult.STATUS_DENIED,
denial_reason="No token or Keycloak configuration missing.",
)
realm_roles = token_data.get("realm_access", {}).get("roles", []) realm_roles = token_data.get("realm_access", {}).get("roles", [])
requested_accord = self.config.get("requested_accord", "") requested_accord = self.config.get("requested_accord", "")

View file

@ -60,6 +60,10 @@ async def authorize(body: AuthorizeRequest, http_request: Request, db: AsyncSess
"_raw_token": raw_token, "_raw_token": raw_token,
"entra_tenant_id": settings.entra_tenant_id, "entra_tenant_id": settings.entra_tenant_id,
"entra_client_id": settings.entra_client_id, "entra_client_id": settings.entra_client_id,
# Fix C-1: Keycloak driver needs these for JWKS verification
"keycloak_url": settings.keycloak_url,
"keycloak_realm": settings.keycloak_realm,
"keycloak_client_id": settings.keycloak_admin_client_id,
}) })
except KeyError as e: except KeyError as e:
raise HTTPException(status_code=400, detail=str(e)) raise HTTPException(status_code=400, detail=str(e))
@ -136,7 +140,17 @@ async def authorize(body: AuthorizeRequest, http_request: Request, db: AsyncSess
expires = now + timedelta(minutes=settings.ac_ttl_minutes) expires = now + timedelta(minutes=settings.ac_ttl_minutes)
ctx_id = uuid.uuid4() ctx_id = uuid.uuid4()
# on_behalf_of: trusted caller (Bascule SA) asserts who the AC is for # Fix C-2: on_behalf_of requires gsap:impersonate role
if request.on_behalf_of:
caller_roles = getattr(auth_result, "elevation_active", [])
# Check for impersonation role in JWT roles (passed through auth_result)
token_roles = token_data.get("roles", []) + token_data.get("realm_access", {}).get("roles", [])
if "gsap:impersonate" not in token_roles:
raise HTTPException(
status_code=403,
detail="on_behalf_of requires gsap:impersonate role.",
)
principal_did = request.on_behalf_of or auth_result.principal_did principal_did = request.on_behalf_of or auth_result.principal_did
display_name = request.on_behalf_of.rsplit("/", 1)[-1] if request.on_behalf_of else auth_result.display_name display_name = request.on_behalf_of.rsplit("/", 1)[-1] if request.on_behalf_of else auth_result.display_name

View file

@ -1,19 +1,18 @@
# Copyright 2026 Guildhouse Dev # Copyright 2026 Guildhouse Dev
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
"""Tests for the Entra identity driver.""" """Tests for the Entra identity driver — C-3, H-10."""
import time
import pytest import pytest
from unittest.mock import AsyncMock, patch from unittest.mock import AsyncMock, patch
from jose import jwt as jose_jwt from jose import jwt as jose_jwt
from cryptography.hazmat.primitives.asymmetric import rsa from cryptography.hazmat.primitives.asymmetric import rsa
from cryptography.hazmat.primitives import serialization from cryptography.hazmat.primitives import serialization
from gsap_broker.drivers.entra import EntraDriver, _jwks_cache from gsap_broker.drivers.entra import EntraDriver
from gsap_broker.drivers.jwks import AuthenticationError
def _generate_rsa_keypair(): def _generate_rsa_keypair():
"""Generate an RSA key pair for test JWT signing."""
private_key = rsa.generate_private_key(public_exponent=65537, key_size=2048) private_key = rsa.generate_private_key(public_exponent=65537, key_size=2048)
public_key = private_key.public_key() public_key = private_key.public_key()
return private_key, public_key return private_key, public_key
@ -28,7 +27,6 @@ def _private_key_pem(private_key):
def _public_numbers_to_jwk(public_key, kid="test-kid-1"): def _public_numbers_to_jwk(public_key, kid="test-kid-1"):
"""Convert RSA public key to JWK dict."""
import base64 import base64
nums = public_key.public_numbers() nums = public_key.public_numbers()
e_bytes = nums.e.to_bytes((nums.e.bit_length() + 7) // 8, "big") e_bytes = nums.e.to_bytes((nums.e.bit_length() + 7) // 8, "big")
@ -51,7 +49,6 @@ JWKS = {"keys": [_public_numbers_to_jwk(PUBLIC_KEY, KID)]}
def _make_token(claims: dict, kid: str = KID, expired: bool = False) -> str: def _make_token(claims: dict, kid: str = KID, expired: bool = False) -> str:
"""Create a signed test JWT."""
import datetime import datetime
now = datetime.datetime.now(datetime.UTC) now = datetime.datetime.now(datetime.UTC)
base_claims = { base_claims = {
@ -77,18 +74,7 @@ def _make_token(claims: dict, kid: str = KID, expired: bool = False) -> str:
def _driver_config(raw_token: str = "", extra: dict = None) -> dict: def _driver_config(raw_token: str = "", extra: dict = None) -> dict:
"""Build config dict as the authorize router would."""
import base64, json
token_data = {}
if raw_token:
try:
payload = raw_token.split(".")[1]
payload += "=" * (4 - len(payload) % 4)
token_data = json.loads(base64.urlsafe_b64decode(payload))
except Exception:
pass
config = { config = {
"_token_data": token_data,
"_raw_token": raw_token, "_raw_token": raw_token,
"entra_tenant_id": TENANT_ID, "entra_tenant_id": TENANT_ID,
"entra_client_id": CLIENT_ID, "entra_client_id": CLIENT_ID,
@ -99,84 +85,71 @@ def _driver_config(raw_token: str = "", extra: dict = None) -> dict:
return config return config
@pytest.fixture(autouse=True)
def clear_jwks_cache():
_jwks_cache.clear()
yield
_jwks_cache.clear()
@pytest.fixture @pytest.fixture
def mock_jwks(): def mock_jwks_fetch():
"""Mock the JWKS fetch to return test keys.""" """Mock the JWKS HTTP fetch to return test keys."""
with patch("gsap_broker.drivers.entra._get_jwks", new_callable=AsyncMock) as m: with patch("gsap_broker.drivers.jwks.httpx.AsyncClient") as mock_http:
m.return_value = JWKS import unittest.mock
yield m mock_resp = unittest.mock.MagicMock()
mock_resp.json.return_value = JWKS
mock_resp.raise_for_status = unittest.mock.MagicMock()
ctx_manager = AsyncMock()
ctx_manager.__aenter__.return_value.get = AsyncMock(return_value=mock_resp)
mock_http.return_value = ctx_manager
yield mock_http
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_authenticate_valid_token(mock_jwks): async def test_authenticate_valid_token(mock_jwks_fetch):
token = _make_token({"roles": ["admin"], "amr": ["pwd", "mfa"]}) token = _make_token({"roles": ["admin"], "amr": ["pwd", "mfa"]})
driver = EntraDriver(config=_driver_config(raw_token=token)) driver = EntraDriver(config=_driver_config(raw_token=token))
result = await driver.authenticate() result = await driver.authenticate()
assert result.is_authorized assert result.is_authorized
assert result.principal_did == "did:web:contoso.com:principal:user-oid-1" assert result.principal_did == "did:web:contoso.com:principal:user-oid-1"
assert result.display_name == "Alice Smith"
assert result.stable_id == "user-oid-1"
assert result.token_jti == "test-jti"
assert result.mfa_satisfied is True assert result.mfa_satisfied is True
assert result.device_id is None
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_authenticate_extracts_device_id(mock_jwks): async def test_authenticate_extracts_device_id(mock_jwks_fetch):
token = _make_token({"deviceid": "device-abc-123"}) token = _make_token({"deviceid": "device-abc-123"})
driver = EntraDriver(config=_driver_config(raw_token=token)) driver = EntraDriver(config=_driver_config(raw_token=token))
result = await driver.authenticate() result = await driver.authenticate()
assert result.is_authorized
assert result.device_id == "device-abc-123" assert result.device_id == "device-abc-123"
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_authenticate_no_device_id(mock_jwks): async def test_authenticate_no_device_id(mock_jwks_fetch):
token = _make_token({}) token = _make_token({})
driver = EntraDriver(config=_driver_config(raw_token=token)) driver = EntraDriver(config=_driver_config(raw_token=token))
result = await driver.authenticate() result = await driver.authenticate()
assert result.is_authorized
assert result.device_id is None assert result.device_id is None
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_authenticate_expired_token(mock_jwks): async def test_authenticate_expired_token(mock_jwks_fetch):
token = _make_token({}, expired=True) token = _make_token({}, expired=True)
driver = EntraDriver(config=_driver_config(raw_token=token)) driver = EntraDriver(config=_driver_config(raw_token=token))
result = await driver.authenticate() result = await driver.authenticate()
assert not result.is_authorized assert not result.is_authorized
assert "expired" in result.denial_reason.lower() or "verification failed" in result.denial_reason.lower()
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_authenticate_no_token(): async def test_authenticate_no_token():
driver = EntraDriver(config={"_token_data": {}, "_raw_token": ""}) driver = EntraDriver(config={"_raw_token": "", "entra_tenant_id": TENANT_ID})
result = await driver.authenticate() result = await driver.authenticate()
assert not result.is_authorized assert not result.is_authorized
assert "No token" in result.denial_reason assert "No token" in result.denial_reason
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_authenticate_mfa_detection(mock_jwks): async def test_authenticate_mfa_detection(mock_jwks_fetch):
# With MFA
token = _make_token({"amr": ["pwd", "mfa"]}) token = _make_token({"amr": ["pwd", "mfa"]})
driver = EntraDriver(config=_driver_config(raw_token=token)) driver = EntraDriver(config=_driver_config(raw_token=token))
result = await driver.authenticate() result = await driver.authenticate()
assert result.mfa_satisfied is True assert result.mfa_satisfied is True
# Without MFA
token = _make_token({"amr": ["pwd"]}) token = _make_token({"amr": ["pwd"]})
driver = EntraDriver(config=_driver_config(raw_token=token)) driver = EntraDriver(config=_driver_config(raw_token=token))
result = await driver.authenticate() result = await driver.authenticate()
@ -184,7 +157,7 @@ async def test_authenticate_mfa_detection(mock_jwks):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_authenticate_elevation_required(mock_jwks): async def test_authenticate_elevation_required(mock_jwks_fetch):
token = _make_token({"roles": ["reader"]}) token = _make_token({"roles": ["reader"]})
config = _driver_config(raw_token=token, extra={ config = _driver_config(raw_token=token, extra={
"requested_accord": "admin-ops", "requested_accord": "admin-ops",
@ -192,26 +165,41 @@ async def test_authenticate_elevation_required(mock_jwks):
}) })
driver = EntraDriver(config=config) driver = EntraDriver(config=config)
result = await driver.authenticate() result = await driver.authenticate()
assert result.needs_elevation assert result.needs_elevation
assert result.elevation_required.role == "admin-role"
assert result.elevation_required.mechanism == "entra_pim"
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_did_construction(mock_jwks): async def test_did_construction(mock_jwks_fetch):
token = _make_token({"oid": "unique-user-oid"}) token = _make_token({"oid": "unique-user-oid"})
driver = EntraDriver(config=_driver_config(raw_token=token, extra={"domain": "example.dev"})) driver = EntraDriver(config=_driver_config(raw_token=token, extra={"domain": "example.dev"}))
result = await driver.authenticate() result = await driver.authenticate()
assert result.principal_did == "did:web:example.dev:principal:unique-user-oid" assert result.principal_did == "did:web:example.dev:principal:unique-user-oid"
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_wrong_kid_rejected(mock_jwks): async def test_wrong_kid_rejected_then_refreshed(mock_jwks_fetch):
"""H-10: kid miss triggers JWKS refresh. With only one JWKS response,
the second fetch still has the same keys, so unknown kid is rejected."""
token = _make_token({}, kid="unknown-kid") token = _make_token({}, kid="unknown-kid")
driver = EntraDriver(config=_driver_config(raw_token=token)) driver = EntraDriver(config=_driver_config(raw_token=token))
result = await driver.authenticate() result = await driver.authenticate()
assert not result.is_authorized assert not result.is_authorized
assert "signing key" in result.denial_reason.lower() assert "signing key" in result.denial_reason.lower() or "key" in result.denial_reason.lower()
@pytest.mark.asyncio
async def test_jwks_failure_denies_no_fallback():
"""C-3: JWKS fetch failure results in denial, no fallback."""
with patch("gsap_broker.drivers.jwks.httpx.AsyncClient") as mock_http:
ctx_manager = AsyncMock()
ctx_manager.__aenter__.return_value.get = AsyncMock(
side_effect=Exception("Network unreachable")
)
mock_http.return_value = ctx_manager
token = _make_token({})
driver = EntraDriver(config=_driver_config(raw_token=token))
result = await driver.authenticate()
assert not result.is_authorized
assert "JWKS fetch failed" in result.denial_reason

View file

@ -0,0 +1,173 @@
# Copyright 2026 Guildhouse Dev
# SPDX-License-Identifier: Apache-2.0
"""Tests for the Keycloak identity driver — C-1: JWKS verification."""
import datetime
import pytest
from unittest.mock import AsyncMock, MagicMock, patch
from jose import jwt as jose_jwt
from cryptography.hazmat.primitives.asymmetric import rsa
from cryptography.hazmat.primitives import serialization
from gsap_broker.drivers.keycloak import KeycloakDriver
def _generate_rsa_keypair():
private_key = rsa.generate_private_key(public_exponent=65537, key_size=2048)
return private_key, private_key.public_key()
def _private_key_pem(private_key):
return private_key.private_bytes(
encoding=serialization.Encoding.PEM,
format=serialization.PrivateFormat.PKCS8,
encryption_algorithm=serialization.NoEncryption(),
)
def _public_numbers_to_jwk(public_key, kid="kc-kid-1"):
import base64
nums = public_key.public_numbers()
e_bytes = nums.e.to_bytes((nums.e.bit_length() + 7) // 8, "big")
n_bytes = nums.n.to_bytes((nums.n.bit_length() + 7) // 8, "big")
return {
"kty": "RSA", "kid": kid, "use": "sig", "alg": "RS256",
"n": base64.urlsafe_b64encode(n_bytes).rstrip(b"=").decode(),
"e": base64.urlsafe_b64encode(e_bytes).rstrip(b"=").decode(),
}
PRIVATE_KEY, PUBLIC_KEY = _generate_rsa_keypair()
KID = "kc-kid-1"
KC_URL = "http://keycloak.test:8080"
KC_REALM = "test-realm"
KC_CLIENT_ID = "test-kc-client"
JWKS = {"keys": [_public_numbers_to_jwk(PUBLIC_KEY, KID)]}
def _make_kc_token(claims: dict, kid: str = KID, expired: bool = False) -> str:
now = datetime.datetime.now(datetime.UTC)
base_claims = {
"iss": f"{KC_URL}/realms/{KC_REALM}",
"aud": KC_CLIENT_ID,
"iat": int(now.timestamp()),
"nbf": int(now.timestamp()),
"exp": int((now + datetime.timedelta(hours=1)).timestamp()),
"sub": "user-sub-1",
"preferred_username": "bob",
"name": "Bob Smith",
"jti": "kc-jti-1",
"realm_access": {"roles": ["user"]},
}
if expired:
base_claims["exp"] = int((now - datetime.timedelta(hours=1)).timestamp())
base_claims["nbf"] = int((now - datetime.timedelta(hours=2)).timestamp())
base_claims["iat"] = int((now - datetime.timedelta(hours=2)).timestamp())
base_claims.update(claims)
return jose_jwt.encode(
base_claims, _private_key_pem(PRIVATE_KEY), algorithm="RS256", headers={"kid": kid}
)
def _driver_config(raw_token: str = "") -> dict:
return {
"_raw_token": raw_token,
"keycloak_url": KC_URL,
"keycloak_realm": KC_REALM,
"keycloak_client_id": KC_CLIENT_ID,
"domain": "example.com",
"did_template": "did:web:{domain}/principal/{alias}",
"elevated_suffix": "-elevated",
}
@pytest.fixture
def mock_jwks_fetch():
with patch("gsap_broker.drivers.jwks.httpx.AsyncClient") as mock_http:
mock_resp = MagicMock()
mock_resp.json.return_value = JWKS
mock_resp.raise_for_status = MagicMock()
ctx_manager = AsyncMock()
ctx_manager.__aenter__.return_value.get = AsyncMock(return_value=mock_resp)
mock_http.return_value = ctx_manager
yield mock_http
@pytest.mark.asyncio
async def test_valid_keycloak_jwt_accepted(mock_jwks_fetch):
"""C-1: Valid signed Keycloak JWT is accepted."""
token = _make_kc_token({})
driver = KeycloakDriver(config=_driver_config(raw_token=token))
result = await driver.authenticate()
assert result.is_authorized
assert "bob" in result.principal_did
@pytest.mark.asyncio
async def test_forged_keycloak_jwt_rejected(mock_jwks_fetch):
"""C-1: Forged JWT (wrong signature) is rejected."""
# Create a token signed with a DIFFERENT key
other_key, _ = _generate_rsa_keypair()
now = datetime.datetime.now(datetime.UTC)
forged = jose_jwt.encode(
{
"iss": f"{KC_URL}/realms/{KC_REALM}",
"aud": KC_CLIENT_ID,
"exp": int((now + datetime.timedelta(hours=1)).timestamp()),
"sub": "attacker",
"preferred_username": "hacker",
},
_private_key_pem(other_key),
algorithm="RS256",
headers={"kid": KID},
)
driver = KeycloakDriver(config=_driver_config(raw_token=forged))
result = await driver.authenticate()
assert not result.is_authorized
assert "verification failed" in result.denial_reason.lower() or "signature" in result.denial_reason.lower()
@pytest.mark.asyncio
async def test_expired_keycloak_jwt_rejected(mock_jwks_fetch):
token = _make_kc_token({}, expired=True)
driver = KeycloakDriver(config=_driver_config(raw_token=token))
result = await driver.authenticate()
assert not result.is_authorized
@pytest.mark.asyncio
async def test_no_token_rejected():
driver = KeycloakDriver(config=_driver_config(raw_token=""))
result = await driver.authenticate()
assert not result.is_authorized
@pytest.mark.asyncio
async def test_jwks_unreachable_rejected():
"""JWKS fetch failure denies — no fallback."""
with patch("gsap_broker.drivers.jwks.httpx.AsyncClient") as mock_http:
ctx = AsyncMock()
ctx.__aenter__.return_value.get = AsyncMock(side_effect=Exception("DNS failure"))
mock_http.return_value = ctx
token = _make_kc_token({})
driver = KeycloakDriver(config=_driver_config(raw_token=token))
result = await driver.authenticate()
assert not result.is_authorized
assert "JWKS fetch failed" in result.denial_reason
@pytest.mark.asyncio
async def test_alg_none_rejected(mock_jwks_fetch):
"""alg=none attack is blocked."""
import base64, json
header = base64.urlsafe_b64encode(json.dumps({"alg": "none", "typ": "JWT"}).encode()).rstrip(b"=").decode()
payload = base64.urlsafe_b64encode(json.dumps({
"sub": "attacker", "iss": f"{KC_URL}/realms/{KC_REALM}",
"aud": KC_CLIENT_ID, "exp": 9999999999,
}).encode()).rstrip(b"=").decode()
forged = f"{header}.{payload}."
driver = KeycloakDriver(config=_driver_config(raw_token=forged))
result = await driver.authenticate()
assert not result.is_authorized