guildhouse-spire-plugins/pkg/oidc/oidc.go
Tyler King a58d548518 feat: network-policy extension, governance lifecycle, audit remediation
- Network-policy SPIRE plugin extension
- Governance event notification with merkle anchoring
- Shellstream specs for consent channels + HFL embedded ABI
- All 17 audit findings from AUDIT.md remediated
- SSH credential composer + substrate key manager updates
- Test coverage for config + sshcert packages

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-18 15:54:46 -04:00

451 lines
12 KiB
Go

// Package oidc provides OIDC token verification for SPIRE workload attestation.
package oidc
import (
"context"
"crypto"
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rsa"
"encoding/base64"
"encoding/json"
"fmt"
"io"
"math/big"
"net/http"
"net/url"
"strings"
"sync"
"time"
)
// Config describes how a Verifier built by NewVerifier authenticates tokens.
type Config struct {
	// Issuer must be non-empty and an https:// URL with a host. Tokens are
	// accepted only when their "iss" claim equals this value exactly.
	Issuer string

	// Audience must be non-empty; NewVerifier rejects an empty value.
	// NOTE(review): Verify checks the per-call expectedAudience argument; this
	// configured value is stored but not consulted by Verify.
	Audience string

	// JWKSURL, when non-empty, is the https:// endpoint keys are downloaded
	// from. When empty, keys are located through OIDC discovery at
	// <Issuer>/.well-known/openid-configuration.
	JWKSURL string

	// RequireNonce makes Verify reject tokens without a non-empty "nonce"
	// claim. Only presence is checked, not the value.
	RequireNonce bool
}
// Claims is the subset of a token's claims exposed after successful
// verification.
type Claims struct {
	Subject  string   // "sub" claim
	Issuer   string   // "iss" claim
	Audience []string // "aud" claim, normalized to a list
	Email    string   // "email" claim, if present
	Groups   []string // "groups" claim, if present
	JTI      string   // "jti" claim: unique token ID, usable by callers for replay detection
}
// Verifier authenticates raw OIDC tokens and exposes their claims.
type Verifier interface {
	// Verify checks rawToken and, on success, returns its Claims.
	// expectedAudience is mandatory: implementations must reject an empty
	// value and must require it to appear in the token's audience.
	Verify(ctx context.Context, rawToken string, expectedAudience string) (*Claims, error)
}
// jwksVerifier is the Verifier produced by NewVerifier. It checks JWTs against
// public keys downloaded from a JWKS endpoint and caches those keys by kid.
type jwksVerifier struct {
	issuer       string // required "iss" value
	audience     string // configured audience (not read by Verify)
	jwksURL      string // JWKS endpoint, or an OIDC discovery URL to resolve it from
	requireNonce bool   // reject tokens lacking a "nonce" claim
	httpClient   *http.Client

	// mu guards the key cache that follows.
	mu      sync.RWMutex
	keys    map[string]crypto.PublicKey // public keys indexed by JWK "kid"
	fetched time.Time                   // when the key cache was last replaced
}
// NewVerifier validates cfg and returns a Verifier for it.
//
// Issuer and Audience are required. Issuer, and JWKSURL when set, must be
// https:// URLs with a host. When JWKSURL is empty, the verifier starts from the
// OIDC discovery document at <Issuer>/.well-known/openid-configuration. No
// network I/O happens here; keys are fetched on first use.
//
// The verifier's HTTP client times out after 10 seconds. It also refuses
// redirects to non-https URLs, so a redirect from the discovery or JWKS
// endpoint cannot sidestep the https requirement above.
func NewVerifier(cfg Config) (Verifier, error) {
	if cfg.Issuer == "" {
		return nil, fmt.Errorf("oidc: issuer is required")
	}
	if err := requireHTTPS(cfg.Issuer, "issuer"); err != nil {
		return nil, err
	}
	if cfg.Audience == "" {
		return nil, fmt.Errorf("oidc: audience is required")
	}
	jwksURL := cfg.JWKSURL
	if jwksURL != "" {
		if err := requireHTTPS(jwksURL, "jwks_url"); err != nil {
			return nil, err
		}
	} else {
		// No explicit JWKS endpoint: start from the discovery document, whose
		// jwks_uri is resolved lazily when keys are first needed.
		jwksURL = strings.TrimRight(cfg.Issuer, "/") + "/.well-known/openid-configuration"
	}
	return &jwksVerifier{
		issuer:       cfg.Issuer,
		audience:     cfg.Audience,
		jwksURL:      jwksURL,
		requireNonce: cfg.RequireNonce,
		httpClient:   newHTTPSOnlyClient(10 * time.Second),
		keys:         make(map[string]crypto.PublicKey),
	}, nil
}

// newHTTPSOnlyClient returns an http.Client with the given timeout that
// refuses to follow a redirect to any non-https URL.
func newHTTPSOnlyClient(timeout time.Duration) *http.Client {
	return &http.Client{
		Timeout: timeout,
		CheckRedirect: func(req *http.Request, via []*http.Request) error {
			// Keep net/http's default cap of 10 redirects.
			if len(via) >= 10 {
				return fmt.Errorf("oidc: stopped after 10 redirects")
			}
			// The default policy would follow an https -> http downgrade.
			if req.URL.Scheme != "https" {
				return fmt.Errorf("oidc: refusing redirect to non-https URL %s", req.URL.Redacted())
			}
			return nil
		},
	}
}
// Verify checks a compact-serialized JWT (header.payload.signature) and
// returns its claims.
//
// A token is accepted only if all of the following hold:
//   - its "iss" equals the configured issuer;
//   - expectedAudience appears in "aud";
//   - "exp" is present and still in the future;
//   - any "nbf" is not in the future;
//   - a "nonce" is present when RequireNonce was configured;
//   - the signature verifies against a key from the issuer's JWKS.
//
// Every check must pass, so the order in which they run does not change
// which tokens are accepted. All returned errors are prefixed with "oidc:".
func (v *jwksVerifier) Verify(ctx context.Context, rawToken string, expectedAudience string) (*Claims, error) {
	if expectedAudience == "" {
		return nil, fmt.Errorf("oidc: expected audience must not be empty")
	}
	parts := strings.Split(rawToken, ".")
	if len(parts) != 3 {
		return nil, fmt.Errorf("oidc: invalid JWT format: expected 3 parts, got %d", len(parts))
	}

	headerBytes, err := base64URLDecode(parts[0])
	if err != nil {
		return nil, fmt.Errorf("oidc: decode JWT header: %w", err)
	}
	var header struct {
		Alg string `json:"alg"`
		Kid string `json:"kid"`
		Typ string `json:"typ"`
	}
	if err := json.Unmarshal(headerBytes, &header); err != nil {
		return nil, fmt.Errorf("oidc: parse JWT header: %w", err)
	}

	payloadBytes, err := base64URLDecode(parts[1])
	if err != nil {
		return nil, fmt.Errorf("oidc: decode JWT payload: %w", err)
	}
	var rawClaims struct {
		Iss    string          `json:"iss"`
		Sub    string          `json:"sub"`
		Aud    json.RawMessage `json:"aud"`
		Exp    int64           `json:"exp"`
		Iat    int64           `json:"iat"`
		Nbf    int64           `json:"nbf"`
		Email  string          `json:"email"`
		Groups []string        `json:"groups"`
		JTI    string          `json:"jti"`
		Nonce  string          `json:"nonce"`
	}
	if err := json.Unmarshal(payloadBytes, &rawClaims); err != nil {
		return nil, fmt.Errorf("oidc: parse JWT claims: %w", err)
	}

	if rawClaims.Iss != v.issuer {
		return nil, fmt.Errorf("oidc: issuer mismatch: got %q, expected %q", rawClaims.Iss, v.issuer)
	}
	audiences, err := parseAudienceClaim(rawClaims.Aud)
	if err != nil {
		return nil, err
	}
	if err := ValidateAudience(audiences, expectedAudience); err != nil {
		return nil, err
	}

	now := time.Now().Unix()
	// exp is REQUIRED for ID tokens (OIDC Core §2). RFC 7519 §4.1.4 forbids
	// accepting a token "on or after" its expiration time, hence >=.
	if rawClaims.Exp == 0 {
		return nil, fmt.Errorf("oidc: token is missing the exp claim")
	}
	if now >= rawClaims.Exp {
		return nil, fmt.Errorf("oidc: token expired at %d, current time %d", rawClaims.Exp, now)
	}
	if rawClaims.Nbf > 0 && now < rawClaims.Nbf {
		return nil, fmt.Errorf("oidc: token not valid until %d, current time %d", rawClaims.Nbf, now)
	}
	if v.requireNonce && rawClaims.Nonce == "" {
		return nil, fmt.Errorf("oidc: nonce is required but not present in token")
	}

	sigBytes, err := base64URLDecode(parts[2])
	if err != nil {
		return nil, fmt.Errorf("oidc: decode JWT signature: %w", err)
	}
	// The JWS signing input is the raw "header.payload" text exactly as received.
	if err := v.verifySignature(ctx, header.Alg, header.Kid, parts[0]+"."+parts[1], sigBytes); err != nil {
		return nil, fmt.Errorf("oidc: signature verification failed: %w", err)
	}
	return &Claims{
		Subject:  rawClaims.Sub,
		Issuer:   rawClaims.Iss,
		Audience: audiences,
		Email:    rawClaims.Email,
		Groups:   rawClaims.Groups,
		JTI:      rawClaims.JTI,
	}, nil
}

// parseAudienceClaim decodes the "aud" claim, which RFC 7519 §4.1.3 allows to
// be a single string or an array of strings. A missing or null claim yields a
// nil slice. Any other shape is an error, not a silently empty audience.
func parseAudienceClaim(raw json.RawMessage) ([]string, error) {
	if len(raw) == 0 {
		return nil, nil
	}
	if raw[0] == '"' {
		var single string
		if err := json.Unmarshal(raw, &single); err != nil {
			return nil, fmt.Errorf("oidc: parse aud claim: %w", err)
		}
		return []string{single}, nil
	}
	var many []string
	if err := json.Unmarshal(raw, &many); err != nil {
		return nil, fmt.Errorf("oidc: aud claim must be a string or an array of strings: %w", err)
	}
	return many, nil
}
// verifySignature checks signature over signedContent (the JWS signing input)
// using the JWKS key identified by kid.
//
// alg comes from the unauthenticated JWT header. Only RS256, ES256 and ES384
// are accepted; anything else, including "none", is rejected before any JWKS
// lookup. The key's type, and for ECDSA its curve, must match alg, which
// prevents algorithm-confusion attacks.
func (v *jwksVerifier) verifySignature(ctx context.Context, alg, kid, signedContent string, signature []byte) error {
	var (
		hash  crypto.Hash
		curve elliptic.Curve // nil selects RSA
	)
	switch alg {
	case "RS256":
		hash = crypto.SHA256
	case "ES256":
		hash, curve = crypto.SHA256, elliptic.P256()
	case "ES384":
		hash, curve = crypto.SHA384, elliptic.P384()
	default:
		return fmt.Errorf("unsupported algorithm: %s", alg)
	}
	// crypto.Hash.New panics when the implementation is not linked in.
	if !hash.Available() {
		return fmt.Errorf("hash function for %s is not available", alg)
	}
	key, err := v.getKey(ctx, kid)
	if err != nil {
		return err
	}
	h := hash.New()
	h.Write([]byte(signedContent))
	digest := h.Sum(nil)

	if curve == nil {
		rsaKey, ok := key.(*rsa.PublicKey)
		if !ok {
			return fmt.Errorf("key %q is not RSA", kid)
		}
		return rsa.VerifyPKCS1v15(rsaKey, hash, digest, signature)
	}
	ecKey, ok := key.(*ecdsa.PublicKey)
	if !ok {
		return fmt.Errorf("key %q is not ECDSA", kid)
	}
	return verifyJWSECDSA(ecKey, curve, digest, signature, kid)
}

// verifyJWSECDSA checks a JWS ECDSA signature. RFC 7518 §3.4 encodes it as the
// fixed-width big-endian concatenation R || S, not as ASN.1 DER.
func verifyJWSECDSA(key *ecdsa.PublicKey, want elliptic.Curve, digest, signature []byte, kid string) error {
	if key.Curve == nil || key.Curve.Params().Name != want.Params().Name {
		return fmt.Errorf("key %q curve does not match algorithm curve %s", kid, want.Params().Name)
	}
	size := (want.Params().BitSize + 7) / 8
	if len(signature) != 2*size {
		return fmt.Errorf("ECDSA signature must be %d bytes, got %d", 2*size, len(signature))
	}
	r := new(big.Int).SetBytes(signature[:size])
	s := new(big.Int).SetBytes(signature[size:])
	if !ecdsa.Verify(key, digest, r, s) {
		return fmt.Errorf("ECDSA signature verification failed")
	}
	return nil
}
// getKey returns the cached public key for kid and refreshes the JWKS on a
// cache miss.
//
// While the cache is non-empty, a refresh happens at most once per 5 minutes.
// Tokens carrying unknown kids therefore cannot force a JWKS download on every
// request. All reads of the cache happen under v.mu, because fetchJWKS
// replaces v.keys under the same lock.
func (v *jwksVerifier) getKey(ctx context.Context, kid string) (crypto.PublicKey, error) {
	v.mu.RLock()
	key, ok := v.keys[kid]
	fetched := v.fetched
	cached := len(v.keys)
	v.mu.RUnlock()
	if ok {
		return key, nil
	}
	if cached > 0 && time.Since(fetched) < 5*time.Minute {
		return nil, fmt.Errorf("key %q not found in JWKS", kid)
	}
	if err := v.fetchJWKS(ctx); err != nil {
		return nil, fmt.Errorf("fetch JWKS: %w", err)
	}
	v.mu.RLock()
	key, ok = v.keys[kid]
	v.mu.RUnlock()
	if !ok {
		return nil, fmt.Errorf("key %q not found in JWKS after refresh", kid)
	}
	return key, nil
}
// fetchJWKS downloads the JWKS document and replaces the key cache.
//
// Keys with a "use" other than "sig", and keys that cannot be converted into
// a supported public key, are skipped rather than failing the refresh. On
// success, v.keys and v.fetched are replaced together under v.mu. On any error,
// the existing cache is left untouched.
func (v *jwksVerifier) fetchJWKS(ctx context.Context) error {
	jwksURL, err := v.resolveJWKSURL(ctx)
	if err != nil {
		return err
	}
	var jwks struct {
		Keys []jwkKey `json:"keys"`
	}
	if err := v.fetchJSON(ctx, jwksURL, &jwks); err != nil {
		return err
	}
	keys := make(map[string]crypto.PublicKey, len(jwks.Keys))
	for i := range jwks.Keys {
		k := &jwks.Keys[i]
		if k.Use != "" && k.Use != "sig" {
			continue
		}
		pub, err := k.toPublicKey()
		if err != nil {
			continue // Skip keys we can't parse.
		}
		keys[k.Kid] = pub
	}
	v.mu.Lock()
	v.keys = keys
	v.fetched = time.Now()
	v.mu.Unlock()
	return nil
}

// resolveJWKSURL returns the JWKS endpoint to download.
//
// If the configured URL is an OIDC discovery document, it is fetched and its
// jwks_uri is validated. Like a directly configured JWKS URL, jwks_uri must be
// https. The result is cached in v.jwksURL, and v.mu guards both the read and
// the write so concurrent refreshes do not race.
func (v *jwksVerifier) resolveJWKSURL(ctx context.Context) (string, error) {
	v.mu.RLock()
	configured := v.jwksURL
	v.mu.RUnlock()
	if !strings.Contains(configured, ".well-known/openid-configuration") {
		return configured, nil
	}
	var discovery struct {
		JWKSURI string `json:"jwks_uri"`
	}
	if err := v.fetchJSON(ctx, configured, &discovery); err != nil {
		return "", fmt.Errorf("fetch OIDC discovery: %w", err)
	}
	if discovery.JWKSURI == "" {
		return "", fmt.Errorf("OIDC discovery document missing jwks_uri")
	}
	if err := requireHTTPS(discovery.JWKSURI, "jwks_uri"); err != nil {
		return "", err
	}
	v.mu.Lock()
	v.jwksURL = discovery.JWKSURI
	v.mu.Unlock()
	return discovery.JWKSURI, nil
}

// fetchJSON GETs rawURL and decodes a JSON body of at most 1 MiB into dst.
// A non-200 response is an error, so an error page can never be parsed as a
// document or replace the key cache.
func (v *jwksVerifier) fetchJSON(ctx context.Context, rawURL string, dst interface{}) error {
	req, err := http.NewRequestWithContext(ctx, http.MethodGet, rawURL, nil)
	if err != nil {
		return err
	}
	resp, err := v.httpClient.Do(req)
	if err != nil {
		return err // *url.Error already names the method and URL
	}
	defer resp.Body.Close()
	if resp.StatusCode != http.StatusOK {
		return fmt.Errorf("GET %s: unexpected HTTP status %s", rawURL, resp.Status)
	}
	body, err := io.ReadAll(io.LimitReader(resp.Body, 1<<20))
	if err != nil {
		return fmt.Errorf("read %s: %w", rawURL, err)
	}
	if err := json.Unmarshal(body, dst); err != nil {
		return fmt.Errorf("parse %s: %w", rawURL, err)
	}
	return nil
}
// jwkKey is a single JSON Web Key (RFC 7517) entry from a JWKS document. Only
// the members needed to rebuild RSA and EC public keys are decoded.
type jwkKey struct {
	Kty string `json:"kty"`
	Kid string `json:"kid"`
	Use string `json:"use"`
	Alg string `json:"alg"`
	N   string `json:"n"`   // RSA modulus
	E   string `json:"e"`   // RSA exponent
	Crv string `json:"crv"` // EC curve
	X   string `json:"x"`   // EC x coordinate
	Y   string `json:"y"`   // EC y coordinate
}

// toPublicKey converts the JWK into a crypto.PublicKey based on its "kty".
func (k *jwkKey) toPublicKey() (crypto.PublicKey, error) {
	switch k.Kty {
	case "RSA":
		return k.toRSAPublicKey()
	case "EC":
		return k.toECPublicKey()
	default:
		return nil, fmt.Errorf("unsupported key type: %s", k.Kty)
	}
}

// toRSAPublicKey builds an RSA public key from the base64url "n" and "e"
// members. It rejects an empty modulus, and any exponent longer than 4 bytes
// or outside [2, 2^31-1]. This bounds the exponent before it is stored in an
// int, so a hostile JWKS cannot overflow it.
func (k *jwkKey) toRSAPublicKey() (*rsa.PublicKey, error) {
	nBytes, err := base64URLDecode(k.N)
	if err != nil {
		return nil, fmt.Errorf("decode RSA n: %w", err)
	}
	if len(nBytes) == 0 {
		return nil, fmt.Errorf("RSA n is empty")
	}
	eBytes, err := base64URLDecode(k.E)
	if err != nil {
		return nil, fmt.Errorf("decode RSA e: %w", err)
	}
	if len(eBytes) == 0 || len(eBytes) > 4 {
		return nil, fmt.Errorf("RSA e has invalid length %d", len(eBytes))
	}
	var e uint64
	for _, b := range eBytes {
		e = e<<8 | uint64(b)
	}
	if e < 2 || e > 1<<31-1 {
		return nil, fmt.Errorf("RSA e %d out of range", e)
	}
	return &rsa.PublicKey{N: new(big.Int).SetBytes(nBytes), E: int(e)}, nil
}

// toECPublicKey builds a P-256 or P-384 public key from "crv", "x" and "y".
// RFC 7518 §6.2.1.2-3 requires each coordinate to be the full byte size of the
// curve, so other lengths are rejected.
func (k *jwkKey) toECPublicKey() (*ecdsa.PublicKey, error) {
	var curve elliptic.Curve
	switch k.Crv {
	case "P-256":
		curve = elliptic.P256()
	case "P-384":
		curve = elliptic.P384()
	default:
		return nil, fmt.Errorf("unsupported curve: %s", k.Crv)
	}
	xBytes, err := base64URLDecode(k.X)
	if err != nil {
		return nil, fmt.Errorf("decode EC x: %w", err)
	}
	yBytes, err := base64URLDecode(k.Y)
	if err != nil {
		return nil, fmt.Errorf("decode EC y: %w", err)
	}
	size := (curve.Params().BitSize + 7) / 8
	if len(xBytes) != size || len(yBytes) != size {
		return nil, fmt.Errorf("EC coordinates for %s must be %d bytes, got x=%d y=%d", k.Crv, size, len(xBytes), len(yBytes))
	}
	return &ecdsa.PublicKey{
		Curve: curve,
		X:     new(big.Int).SetBytes(xBytes),
		Y:     new(big.Int).SetBytes(yBytes),
	}, nil
}
// base64URLDecode decodes a base64url (RFC 4648 §5) string. JOSE values are
// normally unpadded, so missing "=" padding is restored before decoding.
// Input that is already padded is decoded as-is.
func base64URLDecode(s string) ([]byte, error) {
	// A remainder of 1 can never be valid base64 and is left for the decoder
	// to reject; remainders of 2 and 3 need 2 and 1 pad characters.
	if rem := len(s) % 4; rem > 1 {
		s += strings.Repeat("=", 4-rem)
	}
	return base64.URLEncoding.DecodeString(s)
}
// ValidateAudience returns nil only when expectedAudience appears verbatim in
// audiences (a token's decoded "aud" claim). An empty expectedAudience is
// rejected outright, so it is never matched against an empty audience entry.
func ValidateAudience(audiences []string, expectedAudience string) error {
	if len(expectedAudience) == 0 {
		return fmt.Errorf("oidc: expected audience must not be empty")
	}
	for i := range audiences {
		if audiences[i] != expectedAudience {
			continue
		}
		return nil
	}
	return fmt.Errorf("oidc: token audience %v does not contain expected audience %q", audiences, expectedAudience)
}
// requireHTTPS returns an error unless rawURL parses as an https:// URL with a
// non-empty hostname. fieldName is used only in error messages.
//
// The hostname, not the raw Host, is checked so that a port-only authority
// such as "https://:443" is rejected.
func requireHTTPS(rawURL, fieldName string) error {
	u, err := url.Parse(rawURL)
	if err != nil {
		return fmt.Errorf("oidc: %s is not a valid URL: %w", fieldName, err)
	}
	if u.Scheme != "https" {
		return fmt.Errorf("oidc: %s must use https:// scheme, got %q", fieldName, u.Scheme)
	}
	if u.Hostname() == "" {
		return fmt.Errorf("oidc: %s has no host", fieldName)
	}
	return nil
}