fastapi-gsap/tests/test_entra_driver.py
Tyler J King 8196396ce6 feat(drivers): add native Entra identity driver
Validates Entra JWTs directly via JWKS verification.
Extracts device_id for compliance gating, MFA status,
roles, and constructs DID from Entra tenant + oid.
Adds device_id field to AuthResult dataclass.

Signed-off-by: Tyler King <tking@guildhouse.dev>
2026-04-14 05:19:54 -04:00

217 lines
7 KiB
Python

# Copyright 2026 Guildhouse Dev
# SPDX-License-Identifier: Apache-2.0
"""Tests for the Entra identity driver."""
import time
import pytest
from unittest.mock import AsyncMock, patch
from jose import jwt as jose_jwt
from cryptography.hazmat.primitives.asymmetric import rsa
from cryptography.hazmat.primitives import serialization
from gsap_broker.drivers.entra import EntraDriver, _jwks_cache
def _generate_rsa_keypair():
    """Return a fresh ``(private_key, public_key)`` RSA-2048 pair for signing test JWTs."""
    key = rsa.generate_private_key(key_size=2048, public_exponent=65537)
    return key, key.public_key()
def _private_key_pem(private_key):
    """Serialize *private_key* as unencrypted PKCS#8 PEM, the form ``jose`` signs with."""
    pem_format = serialization.PrivateFormat.PKCS8
    no_encryption = serialization.NoEncryption()
    return private_key.private_bytes(
        encoding=serialization.Encoding.PEM,
        format=pem_format,
        encryption_algorithm=no_encryption,
    )
def _public_numbers_to_jwk(public_key, kid="test-kid-1"):
"""Convert RSA public key to JWK dict."""
import base64
nums = public_key.public_numbers()
e_bytes = nums.e.to_bytes((nums.e.bit_length() + 7) // 8, "big")
n_bytes = nums.n.to_bytes((nums.n.bit_length() + 7) // 8, "big")
return {
"kty": "RSA",
"kid": kid,
"use": "sig",
"alg": "RS256",
"n": base64.urlsafe_b64encode(n_bytes).rstrip(b"=").decode(),
"e": base64.urlsafe_b64encode(e_bytes).rstrip(b"=").decode(),
}
# Shared signing material for the whole module: identifiers for the fake
# tenant/app, one RSA key pair generated at import, and the JWKS document
# the mocked JWKS fetch hands back to the driver.
KID = "test-kid-1"
TENANT_ID = "test-tenant-id-1234"
CLIENT_ID = "test-client-id-5678"
PRIVATE_KEY, PUBLIC_KEY = _generate_rsa_keypair()
JWKS = {"keys": [_public_numbers_to_jwk(PUBLIC_KEY, KID)]}
def _make_token(claims: dict, kid: str = KID, expired: bool = False) -> str:
    """Return an RS256 JWT signed with the module's test private key.

    Default claims carry the TENANT_ID issuer, CLIENT_ID audience and a fixed
    test user, valid for one hour from now. With ``expired=True`` the token
    was issued two hours ago and expired one hour ago. Entries in *claims*
    override the defaults; *kid* goes into the JWT header.
    """
    import datetime

    now = datetime.datetime.now(datetime.UTC)
    hour = datetime.timedelta(hours=1)

    def _epoch(moment):
        return int(moment.timestamp())

    if expired:
        issued_at, expires_at = now - 2 * hour, now - hour
    else:
        issued_at, expires_at = now, now + hour

    payload = {
        "iss": f"https://login.microsoftonline.com/{TENANT_ID}/v2.0",
        "aud": CLIENT_ID,
        "iat": _epoch(issued_at),
        "nbf": _epoch(issued_at),
        "exp": _epoch(expires_at),
        "oid": "user-oid-1",
        "tid": TENANT_ID,
        "preferred_username": "alice@contoso.com",
        "name": "Alice Smith",
        "jti": "test-jti",
    }
    payload.update(claims)
    signing_key = _private_key_pem(PRIVATE_KEY)
    return jose_jwt.encode(payload, signing_key, algorithm="RS256", headers={"kid": kid})
def _driver_config(raw_token: str = "", extra: dict | None = None) -> dict:
    """Build the driver config dict the way the authorize router would.

    Args:
        raw_token: Compact JWT. Its payload segment is decoded WITHOUT
            signature verification into ``_token_data``. A token whose
            payload cannot be decoded yields an empty dict, so tests may pass
            malformed tokens.
        extra: Optional overrides merged into the config last.

    Returns:
        Dict with ``_token_data``, ``_raw_token``, ``entra_tenant_id``,
        ``entra_client_id`` and ``domain``, plus any keys from *extra*.
    """
    import base64
    import json

    token_data = {}
    if raw_token:
        try:
            payload = raw_token.split(".")[1]
            # JWT segments are unpadded base64url; add back only the missing
            # 0-3 "=" characters (never a spurious full "====").
            payload += "=" * (-len(payload) % 4)
            token_data = json.loads(base64.urlsafe_b64decode(payload))
        except (IndexError, ValueError):
            # IndexError: no payload segment. ValueError covers
            # binascii.Error, UnicodeDecodeError and json.JSONDecodeError.
            token_data = {}
    config = {
        "_token_data": token_data,
        "_raw_token": raw_token,
        "entra_tenant_id": TENANT_ID,
        "entra_client_id": CLIENT_ID,
        "domain": "contoso.com",
    }
    if extra:
        config.update(extra)
    return config
@pytest.fixture(autouse=True)
def clear_jwks_cache():
    """Isolate every test from JWKS entries cached by the driver in other tests."""
    cache = _jwks_cache
    cache.clear()  # start each test with an empty cache
    yield
    cache.clear()  # discard whatever this test populated
@pytest.fixture
def mock_jwks():
    """Patch the driver's ``_get_jwks`` with an AsyncMock resolving to ``JWKS``.

    Yields the mock so a test can inspect how it was awaited.
    """
    target = "gsap_broker.drivers.entra._get_jwks"
    with patch(target, new_callable=AsyncMock, return_value=JWKS) as fake_get_jwks:
        yield fake_get_jwks
@pytest.mark.asyncio
async def test_authenticate_valid_token(mock_jwks):
    """A valid token with roles and an MFA ``amr`` authorizes and maps identity fields."""
    signed = _make_token({"roles": ["admin"], "amr": ["pwd", "mfa"]})
    result = await EntraDriver(config=_driver_config(raw_token=signed)).authenticate()

    assert result.is_authorized
    assert result.principal_did == "did:web:contoso.com:principal:user-oid-1"
    assert result.display_name == "Alice Smith"
    assert result.stable_id == "user-oid-1"
    assert result.token_jti == "test-jti"
    assert result.mfa_satisfied is True
    assert result.device_id is None
@pytest.mark.asyncio
async def test_authenticate_extracts_device_id(mock_jwks):
    """The token's ``deviceid`` claim surfaces as ``AuthResult.device_id``."""
    config = _driver_config(raw_token=_make_token({"deviceid": "device-abc-123"}))
    result = await EntraDriver(config=config).authenticate()

    assert result.is_authorized
    assert result.device_id == "device-abc-123"
@pytest.mark.asyncio
async def test_authenticate_no_device_id(mock_jwks):
    """Without a ``deviceid`` claim the token still authorizes and ``device_id`` is None."""
    config = _driver_config(raw_token=_make_token({}))
    result = await EntraDriver(config=config).authenticate()

    assert result.is_authorized
    assert result.device_id is None
@pytest.mark.asyncio
async def test_authenticate_expired_token(mock_jwks):
    """An expired token is denied, with a reason naming expiry or failed verification."""
    config = _driver_config(raw_token=_make_token({}, expired=True))
    result = await EntraDriver(config=config).authenticate()

    assert not result.is_authorized
    reason = result.denial_reason.lower()
    assert any(marker in reason for marker in ("expired", "verification failed"))
@pytest.mark.asyncio
async def test_authenticate_no_token():
    """An empty raw token is denied before any JWKS lookup (no mock needed)."""
    empty_config = {"_token_data": {}, "_raw_token": ""}
    result = await EntraDriver(config=empty_config).authenticate()

    assert not result.is_authorized
    assert "No token" in result.denial_reason
@pytest.mark.asyncio
@pytest.mark.parametrize(
    ("amr", "expected_mfa"),
    [(["pwd", "mfa"], True), (["pwd"], False)],
    ids=["with-mfa", "password-only"],
)
async def test_authenticate_mfa_detection(mock_jwks, amr, expected_mfa):
    """An ``amr`` including ``"mfa"`` reports MFA satisfied; password-only does not.

    Each case is a separate parametrized run, so one failing case cannot mask
    the other, and each gets its own JWKS mock and a cleared JWKS cache.
    """
    token = _make_token({"amr": amr})
    result = await EntraDriver(config=_driver_config(raw_token=token)).authenticate()
    assert result.mfa_satisfied is expected_mfa
@pytest.mark.asyncio
async def test_authenticate_elevation_required(mock_jwks):
    """Requesting an accord whose mapped role the token lacks asks for PIM elevation."""
    token = _make_token({"roles": ["reader"]})
    accord_settings = {
        "requested_accord": "admin-ops",
        "accord_roles": {"admin-ops": "admin-role"},
    }
    driver = EntraDriver(config=_driver_config(raw_token=token, extra=accord_settings))
    result = await driver.authenticate()

    assert result.needs_elevation
    elevation = result.elevation_required
    assert elevation.role == "admin-role"
    assert elevation.mechanism == "entra_pim"
@pytest.mark.asyncio
async def test_did_construction(mock_jwks):
    """The principal DID combines the configured domain with the token's ``oid``."""
    config = _driver_config(
        raw_token=_make_token({"oid": "unique-user-oid"}),
        extra={"domain": "example.dev"},
    )
    result = await EntraDriver(config=config).authenticate()

    assert result.principal_did == "did:web:example.dev:principal:unique-user-oid"
@pytest.mark.asyncio
async def test_wrong_kid_rejected(mock_jwks):
    """A token whose header ``kid`` is absent from the JWKS is denied for its signing key."""
    config = _driver_config(raw_token=_make_token({}, kid="unknown-kid"))
    result = await EntraDriver(config=config).authenticate()

    assert not result.is_authorized
    assert "signing key" in result.denial_reason.lower()