C-1: Keycloak driver now verifies JWT signatures via JWKS.
Forged tokens are now rejected; previously any base64-encoded JWT was accepted without signature verification.
C-2: on_behalf_of requires gsap:impersonate role in JWT claims.
C-3: Entra driver denies on JWKS failure (no unverified fallback).
H-10: JWKS cache refreshes on kid miss for key rotation.
Shared JWKSVerifier used by both drivers. alg=none blocked.
iss, aud, exp validated for all tokens.
Signed-off-by: Tyler King <tking@guildhouse.dev>
137 lines
4.4 KiB
Python
137 lines
4.4 KiB
Python
# Copyright 2026 Guildhouse Dev
|
|
# SPDX-License-Identifier: Apache-2.0
|
|
|
|
"""Shared JWKS verification for identity drivers.
|
|
|
|
Fetches, caches, and verifies JWTs against JWKS endpoints.
|
|
Used by both Keycloak and Entra drivers.
|
|
|
|
SECURITY: Never falls back to unverified claims on JWKS failure.
|
|
A JWKS fetch failure MUST result in authentication denial.
|
|
|
|
Fix C-1: Keycloak tokens are now verified via this module.
|
|
Fix C-3: Entra tokens are denied on JWKS failure (no fallback).
|
|
Fix H-10: JWKS cache refreshes on kid miss for key rotation.
|
|
"""
|
|
|
|
import logging
|
|
import time
|
|
from typing import Any, Optional
|
|
|
|
import httpx
|
|
from jose import JWTError, jwt as jose_jwt
|
|
|
|
# Module-scoped logger; records are tagged with this module's import path.
logger = logging.getLogger(name=__name__)
|
|
|
|
|
|
class AuthenticationError(Exception):
    """Signals that a JWT could not be verified.

    Any code that catches this MUST deny the request.
    """
|
|
|
|
|
|
class JWKSKeyNotFound(AuthenticationError):
    """Raised when the token header's ``kid`` matches no key in the cached JWKS.

    As an AuthenticationError subclass it is still a denial for generic
    handlers; ``JWKSVerifier.verify_or_refresh`` catches it specifically so
    it can re-fetch the JWKS and retry.
    """
|
|
|
|
|
|
class JWKSVerifier:
    """Verifies RS256 JWTs against a remote JWKS endpoint.

    The fetched JWK Set is cached for ``cache_ttl`` seconds (default one
    hour). When a token's ``kid`` is not in the cache, ``verify_or_refresh``
    re-fetches the JWKS once and retries -- but only if the cached copy is at
    least ``min_refresh_interval`` seconds old, so forged tokens carrying
    random ``kid`` values cannot force an outbound JWKS request per call.

    Every failure (network, malformed JWKS, bad signature, bad or missing
    claims) surfaces as AuthenticationError; there is no unverified fallback.
    """

    def __init__(
        self,
        jwks_url: str,
        audience: str,
        issuer: str,
        cache_ttl: int = 3600,
        *,
        min_refresh_interval: float = 30.0,
    ):
        """
        Args:
            jwks_url: URL of the JWK Set document.
            audience: Expected ``aud`` value, passed to ``jose_jwt.decode``.
            issuer: Expected ``iss`` value, passed to ``jose_jwt.decode``.
            cache_ttl: Seconds a fetched JWKS is reused before re-fetching.
            min_refresh_interval: Minimum age (seconds) the cached JWKS must
                reach before a ``kid`` miss may force a re-fetch. ``0``
                restores refresh-on-every-miss.
        """
        self._jwks_url = jwks_url
        self._audience = audience
        self._issuer = issuer
        self._cache_ttl = cache_ttl
        self._min_refresh_interval = min_refresh_interval
        self._jwks_cache: Optional[dict[str, Any]] = None
        # time.monotonic() reading taken when _jwks_cache was stored;
        # monotonic so wall-clock jumps cannot stretch or cut the TTL.
        self._cache_fetched_at: float = 0.0

    async def verify_token(self, raw_token: str) -> dict[str, Any]:
        """Verify the JWT signature and standard claims against the cached JWKS.

        Only RS256 is accepted, which blocks ``alg=none`` and HMAC key
        confusion. ``exp``, ``aud`` and ``iss`` must all be present and valid.

        Args:
            raw_token: Compact-serialized JWT.

        Returns:
            The verified claims dict.

        Raises:
            JWKSKeyNotFound: The header's ``kid`` is not in the cached JWKS.
            AuthenticationError: Any other failure. NEVER falls back to
                unverified claims.
        """
        if not raw_token:
            raise AuthenticationError("No token provided")

        # Parse the (unverified) header before any network I/O so malformed
        # tokens are rejected without fetching the JWKS.
        try:
            unverified_header = jose_jwt.get_unverified_header(raw_token)
        except JWTError as e:
            raise AuthenticationError(f"Malformed JWT header: {e}") from e

        kid = unverified_header.get("kid")
        if not kid:
            # Nothing to look up, and a JWKS refresh cannot help: raise the
            # base error so verify_or_refresh does not trigger a re-fetch.
            raise AuthenticationError("JWT header has no kid")

        jwks = await self._fetch_jwks()
        signing_key = self._find_key(jwks, kid)
        if signing_key is None:
            # repr() keeps attacker-controlled text (e.g. newlines) inert
            # if this message ends up in logs.
            raise JWKSKeyNotFound(f"No matching key for kid={kid!r}")

        try:
            return jose_jwt.decode(
                raw_token,
                signing_key,
                # algorithms=["RS256"] blocks alg=none and HMAC confusion
                algorithms=["RS256"],
                audience=self._audience,
                issuer=self._issuer,
                # python-jose skips the audience check when the token has no
                # "aud" claim unless require_aud is set.
                options={
                    "require_exp": True,
                    "require_aud": True,
                    "require_iss": True,
                },
            )
        except Exception as e:
            # JWTError covers bad signature / expiry / aud / iss, but jose can
            # also raise e.g. JWKError (not a JWTError) when the selected JWK
            # is not a usable RSA key. Every such case must fail closed.
            raise AuthenticationError(f"JWT verification failed: {e}") from e

    async def verify_or_refresh(self, raw_token: str) -> dict[str, Any]:
        """Verify with the cache; on a kid miss, refresh the JWKS once and retry.

        Fix H-10: handles signing-key rotation. The refresh is skipped (and
        the JWKSKeyNotFound re-raised) while the cached JWKS is younger than
        ``min_refresh_interval``, bounding how often unauthenticated callers
        can make this service hit the JWKS endpoint.
        """
        try:
            return await self.verify_token(raw_token)
        except JWKSKeyNotFound:
            if not self._refresh_allowed():
                raise
            # kid not in cache -- force a refresh, then retry once below.
            self._invalidate_cache()
        return await self.verify_token(raw_token)

    async def _fetch_jwks(self) -> dict[str, Any]:
        """Return the cached JWKS, fetching it when missing or expired.

        Raises:
            AuthenticationError: The fetch failed or the response is not a
                JWK Set. Never falls back.
        """
        if self._cache_valid():
            return self._jwks_cache

        try:
            async with httpx.AsyncClient(timeout=10.0) as client:
                resp = await client.get(self._jwks_url)
                resp.raise_for_status()
                payload = resp.json()
        except Exception as e:
            # Fix C-3: NEVER fall back — deny on failure
            logger.warning("JWKS fetch from %s failed: %s", self._jwks_url, e)
            raise AuthenticationError(
                f"JWKS fetch failed from {self._jwks_url}: {e}. "
                "Cannot verify token — denying."
            ) from e

        # Validate before caching: a malformed response must be reported as a
        # JWKS failure, not cached for a full TTL and surfaced later as a
        # misleading kid miss (or a non-auth exception).
        if not self._is_valid_jwks(payload):
            raise AuthenticationError(
                f"JWKS fetch failed from {self._jwks_url}: response is not a "
                "JWK Set. Cannot verify token — denying."
            )

        self._jwks_cache = payload
        self._cache_fetched_at = time.monotonic()
        return payload

    def _cache_valid(self) -> bool:
        """Return True if a JWKS is cached and younger than ``cache_ttl``."""
        return (
            self._jwks_cache is not None
            and (time.monotonic() - self._cache_fetched_at) < self._cache_ttl
        )

    def _refresh_allowed(self) -> bool:
        """Return True if a kid miss may force a JWKS re-fetch right now."""
        if self._jwks_cache is None:
            return True
        age = time.monotonic() - self._cache_fetched_at
        return age >= self._min_refresh_interval

    def _invalidate_cache(self) -> None:
        """Drop the cached JWKS so the next lookup re-fetches it."""
        self._jwks_cache = None
        self._cache_fetched_at = 0.0

    @staticmethod
    def _is_valid_jwks(payload: Any) -> bool:
        """Return True if *payload* has the minimal JWK Set shape (dict with a "keys" list)."""
        return isinstance(payload, dict) and isinstance(payload.get("keys"), list)

    @staticmethod
    def _find_key(jwks: dict[str, Any], kid: str) -> Optional[dict[str, Any]]:
        """Return the JWK in *jwks* whose ``kid`` equals *kid*, or None.

        Non-dict entries and a missing/non-list "keys" member are treated as
        "no match" rather than raising, so a malformed JWKS cannot produce a
        non-authentication exception.
        """
        keys = jwks.get("keys") if isinstance(jwks, dict) else None
        if not isinstance(keys, list):
            return None
        return next(
            (key for key in keys if isinstance(key, dict) and key.get("kid") == kid),
            None,
        )
|