fastapi-gsap/gsap_broker/drivers/jwks.py
Tyler J King 5015f3dd43 fix(drivers): JWKS verification for Keycloak, remove Entra fallback, gate on_behalf_of
C-1: Keycloak driver now verifies JWT signatures via JWKS.
     Forged tokens are rejected. Previously any base64 JWT was accepted.
C-2: on_behalf_of requires gsap:impersonate role in JWT claims.
C-3: Entra driver denies on JWKS failure (no unverified fallback).
H-10: JWKS cache refreshes on kid miss for key rotation.

Shared JWKSVerifier used by both drivers. alg=none blocked.
iss, aud, exp validated for all tokens.

Signed-off-by: Tyler King <tking@guildhouse.dev>
2026-04-14 07:51:38 -04:00

137 lines
4.4 KiB
Python

# Copyright 2026 Guildhouse Dev
# SPDX-License-Identifier: Apache-2.0
"""Shared JWKS verification for identity drivers.
Fetches, caches, and verifies JWTs against JWKS endpoints.
Used by both Keycloak and Entra drivers.
SECURITY: Never falls back to unverified claims on JWKS failure.
A JWKS fetch failure MUST result in authentication denial.
Fix C-1: Keycloak tokens are now verified via this module.
Fix C-3: Entra tokens are denied on JWKS failure (no fallback).
Fix H-10: JWKS cache refreshes on kid miss for key rotation.
"""
import logging
import time
from typing import Any, Optional

import httpx
from jose import JWTError, jwt as jose_jwt
from jose.exceptions import JOSEError
# Module-level logger, named after this module so handlers can target it.
logger = logging.getLogger(__name__)
class AuthenticationError(Exception):
    """Signals that a JWT could not be verified.

    Whoever catches this MUST deny the request; it is never a signal to
    fall back to unverified claims.
    """
class JWKSKeyNotFound(AuthenticationError):
    """Signals that the JWT header's ``kid`` matches no key in the cached JWKS.

    Being a subclass of AuthenticationError, it still denies by default;
    callers may catch it specifically to re-fetch the JWKS and retry.
    """
class JWKSVerifier:
    """Verifies JWTs against a remote JWKS endpoint.

    The fetched JWK Set is cached for ``cache_ttl`` seconds (default 1 hour).
    :meth:`verify_or_refresh` handles key rotation: on a ``kid`` miss the
    cache is invalidated and the JWKS is re-fetched once before rejecting.

    Every failure path raises AuthenticationError (or its subclass
    JWKSKeyNotFound); unverified claims are never returned.
    """

    def __init__(
        self,
        jwks_url: str,
        audience: str,
        issuer: str,
        cache_ttl: int = 3600,
        *,
        min_refresh_interval: float = 0.0,
    ):
        """
        Args:
            jwks_url: URL of the JWK Set document.
            audience: Value the token's ``aud`` claim must contain.
            issuer: Value the token's ``iss`` claim must equal.
            cache_ttl: Seconds a fetched JWKS is reused before re-fetching.
            min_refresh_interval: Minimum age in seconds the cached JWKS
                must reach before a ``kid`` miss may force a re-fetch.
                ``0`` (the default) re-fetches on every miss. A positive
                value prevents tokens carrying random ``kid`` values from
                triggering one outbound JWKS request per call.
        """
        self._jwks_url = jwks_url
        self._audience = audience
        self._issuer = issuer
        self._cache_ttl = cache_ttl
        self._min_refresh_interval = min_refresh_interval
        self._jwks_cache: Optional[dict[str, Any]] = None
        self._cache_fetched_at: float = 0.0

    async def verify_token(self, raw_token: str) -> dict[str, Any]:
        """Verify the JWT signature and standard claims.

        Returns:
            The verified claims dict.

        Raises:
            JWKSKeyNotFound: The header's ``kid`` is not in the cached JWKS.
            AuthenticationError: On ANY other failure (empty or malformed
                token, JWKS fetch failure, bad signature, wrong/missing
                ``iss``/``aud``, missing or expired ``exp``). Never falls
                back to unverified claims.
        """
        if not raw_token:
            raise AuthenticationError("No token provided")
        # Parse the header before any network I/O so garbage tokens never
        # cause a JWKS fetch.
        try:
            unverified_header = jose_jwt.get_unverified_header(raw_token)
        except JWTError as e:
            raise AuthenticationError(f"Malformed JWT header: {e}") from e
        jwks = await self._fetch_jwks()
        kid = unverified_header.get("kid", "")
        signing_key = self._find_key(jwks, kid)
        if signing_key is None:
            raise JWKSKeyNotFound(f"No matching key for kid={kid}")
        try:
            # algorithms=["RS256"] blocks alg=none and HMAC confusion.
            # python-jose skips the aud check entirely when the claim is
            # absent unless require_aud is set, so both aud and iss are
            # required explicitly alongside exp.
            return jose_jwt.decode(
                raw_token,
                signing_key,
                algorithms=["RS256"],
                audience=self._audience,
                issuer=self._issuer,
                options={
                    "require_exp": True,
                    "require_aud": True,
                    "require_iss": True,
                },
            )
        except JOSEError as e:
            # JOSEError (not just JWTError): a malformed JWK from the remote
            # JWKS raises JWKError during key construction, which must also
            # end in denial rather than an unexpected exception type.
            raise AuthenticationError(f"JWT verification failed: {e}") from e

    async def verify_or_refresh(self, raw_token: str) -> dict[str, Any]:
        """Verify with the cache; on a kid miss, refresh JWKS once and retry.

        Fix H-10: handles key rotation gracefully. When
        ``min_refresh_interval`` is positive and the cached JWKS is younger
        than that, the kid miss is re-raised without a re-fetch.
        """
        try:
            return await self.verify_token(raw_token)
        except JWKSKeyNotFound:
            if self._refresh_throttled():
                raise
            # kid not in cache — force refresh and retry once
            self._invalidate_cache()
            return await self.verify_token(raw_token)

    async def _fetch_jwks(self) -> dict[str, Any]:
        """Return the cached JWKS, fetching it if the cache is stale.

        Raises AuthenticationError on any fetch/parse failure or when the
        response is not a JWK Set. Nothing is cached on failure, and it
        never falls back to unverified claims.
        """
        if self._cache_valid():
            return self._jwks_cache
        try:
            async with httpx.AsyncClient(timeout=10.0) as client:
                resp = await client.get(self._jwks_url)
                resp.raise_for_status()
                payload = resp.json()
        except Exception as e:
            # Fix C-3: NEVER fall back — deny on failure
            logger.warning("JWKS fetch from %s failed: %s", self._jwks_url, e)
            raise AuthenticationError(
                f"JWKS fetch failed from {self._jwks_url}: {e}. "
                "Cannot verify token — denying."
            ) from e
        if not self._is_jwk_set(payload):
            # Don't cache a malformed document for the whole TTL.
            raise AuthenticationError(
                f"JWKS fetch failed from {self._jwks_url}: "
                "response is not a JWK Set. Cannot verify token — denying."
            )
        self._jwks_cache = payload
        self._cache_fetched_at = time.time()
        return payload

    def _cache_valid(self) -> bool:
        """True when a JWKS is cached and younger than ``cache_ttl``."""
        return (
            self._jwks_cache is not None
            and (time.time() - self._cache_fetched_at) < self._cache_ttl
        )

    def _refresh_throttled(self) -> bool:
        """True when a kid-miss refresh must be skipped (JWKS too fresh)."""
        if self._min_refresh_interval <= 0:
            return False
        return (time.time() - self._cache_fetched_at) < self._min_refresh_interval

    def _invalidate_cache(self) -> None:
        """Drop the cached JWKS so the next lookup re-fetches it."""
        self._jwks_cache = None
        self._cache_fetched_at = 0.0

    @staticmethod
    def _is_jwk_set(payload: Any) -> bool:
        """True when *payload* is a JSON object with a list ``keys`` member."""
        return isinstance(payload, dict) and isinstance(payload.get("keys"), list)

    @staticmethod
    def _find_key(jwks: dict[str, Any], kid: str) -> Optional[dict[str, Any]]:
        """Return the JWK in *jwks* whose ``kid`` equals *kid*, else ``None``.

        Non-list ``keys`` members and non-dict entries are ignored instead
        of raising, so a malformed JWKS yields a denial, not a crash.
        """
        keys = jwks.get("keys")
        if not isinstance(keys, list):
            return None
        return next(
            (key for key in keys if isinstance(key, dict) and key.get("kid") == kid),
            None,
        )