fastapi-gsap/tests/test_keycloak_driver.py
Tyler J King 5015f3dd43 fix(drivers): JWKS verification for Keycloak, remove Entra fallback, gate on_behalf_of
C-1: Keycloak driver now verifies JWT signatures via JWKS.
     Forged tokens are rejected. Previously any base64 JWT was accepted.
C-2: on_behalf_of requires gsap:impersonate role in JWT claims.
C-3: Entra driver denies on JWKS failure (no unverified fallback).
H-10: JWKS cache refreshes on kid miss for key rotation.

A shared JWKSVerifier is used by both drivers; alg=none is blocked.
The iss, aud, and exp claims are validated for all tokens.

Signed-off-by: Tyler King <tking@guildhouse.dev>
2026-04-14 07:51:38 -04:00

173 lines
6.1 KiB
Python

# Copyright 2026 Guildhouse Dev
# SPDX-License-Identifier: Apache-2.0
"""Tests for the Keycloak identity driver — C-1: JWKS verification."""
import datetime
import pytest
from unittest.mock import AsyncMock, MagicMock, patch
from jose import jwt as jose_jwt
from cryptography.hazmat.primitives.asymmetric import rsa
from cryptography.hazmat.primitives import serialization
from gsap_broker.drivers.keycloak import KeycloakDriver
def _generate_rsa_keypair():
    """Create a fresh 2048-bit RSA key pair (public exponent 65537).

    Returns:
        A ``(private_key, public_key)`` tuple of ``cryptography`` RSA key objects.
    """
    key = rsa.generate_private_key(key_size=2048, public_exponent=65537)
    public = key.public_key()
    return key, public
def _private_key_pem(private_key):
    """Serialize *private_key* to unencrypted PKCS#8 PEM bytes.

    This is the key form handed to ``jose_jwt.encode`` elsewhere in this module.
    """
    pem_options = dict(
        encoding=serialization.Encoding.PEM,
        format=serialization.PrivateFormat.PKCS8,
        encryption_algorithm=serialization.NoEncryption(),
    )
    return private_key.private_bytes(**pem_options)
def _public_numbers_to_jwk(public_key, kid="kc-kid-1"):
import base64
nums = public_key.public_numbers()
e_bytes = nums.e.to_bytes((nums.e.bit_length() + 7) // 8, "big")
n_bytes = nums.n.to_bytes((nums.n.bit_length() + 7) // 8, "big")
return {
"kty": "RSA", "kid": kid, "use": "sig", "alg": "RS256",
"n": base64.urlsafe_b64encode(n_bytes).rstrip(b"=").decode(),
"e": base64.urlsafe_b64encode(e_bytes).rstrip(b"=").decode(),
}
# Identifiers and endpoints of the fake Keycloak realm used by every test.
KID = "kc-kid-1"
KC_URL = "http://keycloak.test:8080"
KC_REALM = "test-realm"
KC_CLIENT_ID = "test-kc-client"

# Signing key pair generated once at import time; tokens are signed with the
# private half and the public half is published in JWKS below.
PRIVATE_KEY, PUBLIC_KEY = _generate_rsa_keypair()

# JWKS document returned by the mocked HTTP fetch: a single RS256 key under KID.
JWKS = {"keys": [_public_numbers_to_jwk(PUBLIC_KEY, kid=KID)]}
def _make_kc_token(claims: dict, kid: str = KID, expired: bool = False) -> str:
    """Mint an RS256 JWT that looks like a token issued by the test realm.

    Args:
        claims: Extra claims merged over (and overriding) the defaults.
        kid: Value placed in the JOSE header's ``kid`` field.
        expired: When true, ``exp`` is one hour in the past and ``iat``/``nbf``
            two hours in the past; otherwise ``iat``/``nbf`` are now and
            ``exp`` is one hour ahead.

    Returns:
        The compact-serialized JWT, signed with ``PRIVATE_KEY``.
    """
    now = datetime.datetime.now(datetime.UTC)

    def _epoch(offset_hours: int = 0) -> int:
        return int((now + datetime.timedelta(hours=offset_hours)).timestamp())

    if expired:
        issued_at, expires_at = _epoch(-2), _epoch(-1)
    else:
        issued_at, expires_at = _epoch(), _epoch(1)

    payload = {
        "iss": f"{KC_URL}/realms/{KC_REALM}",
        "aud": KC_CLIENT_ID,
        "iat": issued_at,
        "nbf": issued_at,
        "exp": expires_at,
        "sub": "user-sub-1",
        "preferred_username": "bob",
        "name": "Bob Smith",
        "jti": "kc-jti-1",
        "realm_access": {"roles": ["user"]},
    }
    payload.update(claims)

    signing_key = _private_key_pem(PRIVATE_KEY)
    return jose_jwt.encode(payload, signing_key, algorithm="RS256", headers={"kid": kid})
def _driver_config(raw_token: str = "") -> dict:
    """Build the ``KeycloakDriver`` config dict used by these tests.

    Args:
        raw_token: Bearer token stored under ``_raw_token`` for the driver to verify.

    Returns:
        A new dict with the test realm's Keycloak settings, plus fixed
        domain, DID-template and elevated-suffix values.
    """
    keycloak = {
        "keycloak_url": KC_URL,
        "keycloak_realm": KC_REALM,
        "keycloak_client_id": KC_CLIENT_ID,
    }
    identity = {
        "domain": "example.com",
        "did_template": "did:web:{domain}/principal/{alias}",
        "elevated_suffix": "-elevated",
    }
    return {"_raw_token": raw_token, **keycloak, **identity}
@pytest.fixture
def mock_jwks_fetch():
    """Patch the JWKS module's ``httpx.AsyncClient`` so fetches return ``JWKS``.

    Inside ``async with AsyncClient(...) as client``, ``await client.get(...)``
    yields a response whose ``json()`` returns the module-level ``JWKS`` and
    whose ``raise_for_status()`` is a no-op mock.

    Yields:
        The patched ``AsyncClient`` class mock.
    """
    jwks_response = MagicMock()
    jwks_response.json.return_value = JWKS
    jwks_response.raise_for_status = MagicMock()

    client_cm = AsyncMock()
    client_cm.__aenter__.return_value.get = AsyncMock(return_value=jwks_response)

    with patch("gsap_broker.drivers.jwks.httpx.AsyncClient") as async_client_cls:
        async_client_cls.return_value = client_cm
        yield async_client_cls
@pytest.mark.asyncio
async def test_valid_keycloak_jwt_accepted(mock_jwks_fetch):
    """C-1: a correctly signed, unexpired Keycloak JWT is authorized.

    The resulting principal DID must contain ``"bob"``.
    """
    config = _driver_config(raw_token=_make_kc_token({}))
    outcome = await KeycloakDriver(config=config).authenticate()
    assert outcome.is_authorized
    assert "bob" in outcome.principal_did
@pytest.mark.asyncio
async def test_forged_keycloak_jwt_rejected(mock_jwks_fetch):
    """C-1: a JWT signed by an unknown key is rejected even though its kid matches.

    The forged token carries the same claim set a legitimate token has
    (iss, aud, iat, nbf, exp, sub, jti, realm_access). That leaves the
    signature as the only thing wrong with it, so the test cannot pass merely
    because the verifier objected to a missing claim.
    """
    attacker_key, _ = _generate_rsa_keypair()
    now = datetime.datetime.now(datetime.UTC)
    issued_at = int(now.timestamp())
    forged = jose_jwt.encode(
        {
            "iss": f"{KC_URL}/realms/{KC_REALM}",
            "aud": KC_CLIENT_ID,
            "iat": issued_at,
            "nbf": issued_at,
            "exp": int((now + datetime.timedelta(hours=1)).timestamp()),
            "sub": "attacker",
            "preferred_username": "hacker",
            "jti": "forged-jti-1",
            "realm_access": {"roles": ["user"]},
        },
        _private_key_pem(attacker_key),
        algorithm="RS256",
        # Reuse the legitimate kid so the verifier resolves the real public key
        # and has to fail on the signature itself, not on an unknown kid.
        headers={"kid": KID},
    )
    driver = KeycloakDriver(config=_driver_config(raw_token=forged))
    result = await driver.authenticate()
    assert not result.is_authorized
    # `or ""` turns a missing reason into a clear assertion failure instead of
    # an AttributeError on None.lower().
    reason = (result.denial_reason or "").lower()
    assert "verification failed" in reason or "signature" in reason
@pytest.mark.asyncio
async def test_expired_keycloak_jwt_rejected(mock_jwks_fetch):
    """A correctly signed token whose ``exp`` is in the past is denied."""
    expired_token = _make_kc_token({}, expired=True)
    outcome = await KeycloakDriver(
        config=_driver_config(raw_token=expired_token)
    ).authenticate()
    assert not outcome.is_authorized
@pytest.mark.asyncio
async def test_no_token_rejected():
    """An empty raw token is denied.

    No JWKS mock is installed for this test.
    NOTE(review): this assumes the driver denies before attempting any JWKS
    fetch — confirm.
    """
    outcome = await KeycloakDriver(config=_driver_config(raw_token="")).authenticate()
    assert not outcome.is_authorized
@pytest.mark.asyncio
async def test_jwks_unreachable_rejected():
    """A JWKS fetch failure results in denial; there is no unverified fallback.

    NOTE(review): if the verifier caches JWKS across driver instances, a key
    cached by an earlier test could make this test order-dependent. Confirm
    that the cache is per-instance or reset between tests.
    """
    # Every GET through the patched client raises, simulating a DNS failure.
    failing_cm = AsyncMock()
    failing_cm.__aenter__.return_value.get = AsyncMock(side_effect=Exception("DNS failure"))

    with patch("gsap_broker.drivers.jwks.httpx.AsyncClient") as async_client_cls:
        async_client_cls.return_value = failing_cm
        valid_token = _make_kc_token({})
        driver = KeycloakDriver(config=_driver_config(raw_token=valid_token))
        outcome = await driver.authenticate()

    assert not outcome.is_authorized
    assert "JWKS fetch failed" in outcome.denial_reason
@pytest.mark.asyncio
async def test_alg_none_rejected(mock_jwks_fetch):
    """An unsigned alg=none token is blocked, whatever the letter case.

    Apart from the algorithm, the token looks legitimate:
    - it names the real ``kid``;
    - it carries valid iss/aud/iat/nbf/exp claims.
    A denial therefore has to come from refusing ``none``, not from a kid
    miss or a claim check.

    The ``None``/``NONE`` spellings are included because a check that only
    compares against lowercase ``"none"`` would not catch them.
    """
    import base64
    import json

    def _segment(obj):
        # Unpadded base64url encoding of the JSON object (one JWS segment).
        return base64.urlsafe_b64encode(json.dumps(obj).encode()).rstrip(b"=").decode()

    now = int(datetime.datetime.now(datetime.UTC).timestamp())
    payload = _segment({
        "sub": "attacker",
        "iss": f"{KC_URL}/realms/{KC_REALM}",
        "aud": KC_CLIENT_ID,
        "iat": now,
        "nbf": now,
        "exp": now + 3600,
    })
    for alg in ("none", "None", "NONE"):
        header = _segment({"alg": alg, "typ": "JWT", "kid": KID})
        # Empty third segment: no signature at all.
        forged = f"{header}.{payload}."
        driver = KeycloakDriver(config=_driver_config(raw_token=forged))
        result = await driver.authenticate()
        assert not result.is_authorized, f"alg={alg!r} token was accepted"