# Copyright 2026 Guildhouse Dev
# SPDX-License-Identifier: Apache-2.0
"""Shared JWKS verification for identity drivers.

Fetches, caches, and verifies JWTs against JWKS endpoints.
Used by both Keycloak and Entra drivers.

SECURITY: Never falls back to unverified claims on JWKS failure.
A JWKS fetch failure MUST result in authentication denial.

Fix C-1: Keycloak tokens are now verified via this module.
Fix C-3: Entra tokens are denied on JWKS failure (no fallback).
Fix H-10: JWKS cache refreshes on kid miss for key rotation.
"""

import logging
import time
from typing import Any, Optional

import httpx
from jose import JWTError, jwt as jose_jwt

logger = logging.getLogger(__name__)


class AuthenticationError(Exception):
    """JWT verification failed. The request MUST be denied."""


class JWKSKeyNotFound(AuthenticationError):
    """The JWT's kid does not match any key in the cached JWKS."""


class JWKSVerifier:
    """Verifies JWTs against a remote JWKS endpoint.

    The fetched JWK Set is cached for ``cache_ttl`` seconds (default
    1 hour). ``verify_or_refresh`` handles key rotation: on a kid miss it
    re-fetches the JWKS once and retries before rejecting.

    Forced (kid-miss) refreshes are limited to one per
    ``refresh_cooldown`` seconds, so tokens carrying arbitrary ``kid``
    values cannot turn every request into an outbound JWKS fetch. A
    forced refresh only replaces the cache once the new JWKS has been
    fetched and validated; a failed refresh denies the current request
    but leaves the previously cached keys in place.
    """

    def __init__(
        self,
        jwks_url: str,
        audience: str,
        issuer: str,
        cache_ttl: int = 3600,
        *,
        refresh_cooldown: float = 30.0,
    ):
        """Configure the verifier. No network I/O happens here.

        Args:
            jwks_url: URL of the JWKS endpoint to fetch keys from.
            audience: Expected ``aud`` claim, passed to the JWT decoder.
            issuer: Expected ``iss`` claim, passed to the JWT decoder.
            cache_ttl: Seconds a fetched JWKS is served from the cache.
            refresh_cooldown: Minimum seconds between two forced
                (kid-miss) refreshes in ``verify_or_refresh``.
        """
        self._jwks_url = jwks_url
        self._audience = audience
        self._issuer = issuer
        self._cache_ttl = cache_ttl
        self._refresh_cooldown = refresh_cooldown
        self._jwks_cache: Optional[dict[str, Any]] = None
        self._cache_fetched_at: float = 0.0
        # time.time() of the last forced refresh; 0.0 means "never", so the
        # first kid miss always triggers a refresh.
        self._last_forced_refresh: float = 0.0

    async def verify_token(self, raw_token: str) -> dict[str, Any]:
        """Verify JWT signature and standard claims against the JWKS cache.

        Does not refresh the JWKS on a kid miss; use ``verify_or_refresh``
        for key-rotation handling.

        Args:
            raw_token: The encoded JWT.

        Returns:
            The verified claims dict.

        Raises:
            JWKSKeyNotFound: The token's ``kid`` is not in the cached JWKS.
            AuthenticationError: On ANY other failure (missing or malformed
                token, JWKS fetch failure, bad signature, claim validation
                failure). NEVER falls back to unverified claims.
        """
        if not raw_token:
            raise AuthenticationError("No token provided")

        # Parse the header before touching the network so malformed tokens
        # are rejected without triggering a JWKS fetch.
        try:
            unverified_header = jose_jwt.get_unverified_header(raw_token)
        except JWTError as e:
            raise AuthenticationError(f"Malformed JWT header: {e}") from e

        jwks = await self._fetch_jwks()

        kid = unverified_header.get("kid", "")
        signing_key = self._find_key(jwks, kid)
        if signing_key is None:
            raise JWKSKeyNotFound(f"No matching key for kid={kid}")

        try:
            # algorithms=["RS256"] blocks alg=none and HMAC confusion
            return jose_jwt.decode(
                raw_token,
                signing_key,
                algorithms=["RS256"],
                audience=self._audience,
                issuer=self._issuer,
                options={"require_exp": True},
            )
        except JWTError as e:
            raise AuthenticationError(f"JWT verification failed: {e}") from e

    async def verify_or_refresh(self, raw_token: str) -> dict[str, Any]:
        """Verify with cache; on kid miss, refresh JWKS once and retry.

        Fix H-10: handles key rotation gracefully. The forced refresh is
        rate limited by ``refresh_cooldown``; inside the cooldown window a
        kid miss is rejected immediately without contacting the endpoint.

        Raises:
            JWKSKeyNotFound: The kid is still unknown after the refresh, or
                a refresh was suppressed by the cooldown.
            AuthenticationError: Any other verification or fetch failure.
        """
        try:
            return await self.verify_token(raw_token)
        except JWKSKeyNotFound:
            now = time.time()
            if now - self._last_forced_refresh < self._refresh_cooldown:
                logger.debug(
                    "JWKS kid miss within %.1fs refresh cooldown; denying "
                    "without re-fetch",
                    self._refresh_cooldown,
                )
                raise
            self._last_forced_refresh = now
            # Fetch first, swap cache only on success: a failed refresh must
            # not discard keys that other requests still verify against.
            await self._fetch_jwks(force=True)
            return await self.verify_token(raw_token)

    async def _fetch_jwks(self, *, force: bool = False) -> dict[str, Any]:
        """Return the JWKS, fetching it when the cache is stale or ``force``.

        The cache is only overwritten after a successful fetch whose body is
        a JSON object with a ``keys`` array. Raises on failure — never
        falls back.

        Raises:
            AuthenticationError: The fetch failed or the response is not a
                JWK Set.
        """
        if not force and self._cache_valid():
            return self._jwks_cache

        try:
            async with httpx.AsyncClient(timeout=10.0) as client:
                resp = await client.get(self._jwks_url)
                resp.raise_for_status()
                jwks = resp.json()
        except Exception as e:
            # Fix C-3: NEVER fall back — deny on failure
            logger.warning("JWKS fetch from %s failed: %s", self._jwks_url, e)
            raise AuthenticationError(
                f"JWKS fetch failed from {self._jwks_url}: {e}. "
                "Cannot verify token — denying."
            ) from e

        # Refuse to cache anything that is not a JWK Set, otherwise later key
        # lookups would fail with non-authentication errors.
        if not isinstance(jwks, dict) or not isinstance(jwks.get("keys"), list):
            logger.warning("JWKS response from %s is not a JWK Set", self._jwks_url)
            raise AuthenticationError(
                f"JWKS fetch failed from {self._jwks_url}: response is not a "
                "JWK Set (expected a JSON object with a 'keys' array). "
                "Cannot verify token — denying."
            )

        self._jwks_cache = jwks
        self._cache_fetched_at = time.time()
        return jwks

    def _cache_valid(self) -> bool:
        """True when a JWKS is cached and younger than ``cache_ttl``."""
        return (
            self._jwks_cache is not None
            and (time.time() - self._cache_fetched_at) < self._cache_ttl
        )

    def _invalidate_cache(self) -> None:
        """Drop the cached JWKS so the next lookup re-fetches it."""
        self._jwks_cache = None
        self._cache_fetched_at = 0.0

    @staticmethod
    def _find_key(jwks: dict[str, Any], kid: str) -> Optional[dict[str, Any]]:
        """Return the JWK whose ``kid`` equals *kid*, or None if absent."""
        for key in jwks.get("keys", []):
            # Skip malformed (non-object) entries instead of crashing on them.
            if isinstance(key, dict) and key.get("kid") == kid:
                return key
        return None