- Network-policy SPIRE plugin extension - Governance event notification with merkle anchoring - Shellstream specs for consent channels + HFL embedded ABI - All 17 audit findings from AUDIT.md remediated - SSH credential composer + substrate key manager updates - Test coverage for config + sshcert packages Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
451 lines
12 KiB
Go
451 lines
12 KiB
Go
// Package oidc provides OIDC token verification for SPIRE workload attestation.
|
|
package oidc
|
|
|
|
import (
	"context"
	"crypto"
	"crypto/ecdsa"
	"crypto/elliptic"
	"crypto/rsa"
	_ "crypto/sha256" // registers crypto.SHA256 for crypto.Hash.New
	_ "crypto/sha512" // registers crypto.SHA384 / crypto.SHA512 for crypto.Hash.New
	"encoding/base64"
	"encoding/json"
	"fmt"
	"io"
	"math/big"
	"net/http"
	"net/url"
	"strings"
	"sync"
	"time"
)
|
|
|
|
// Config holds OIDC verifier configuration.
//
// NewVerifier validates it: Issuer and Audience must be non-empty, and Issuer
// (and JWKSURL, when set) must be https:// URLs with a host.
type Config struct {
	// Issuer is the expected OIDC issuer URL. MUST be an https:// URL.
	// Verify compares it exactly (string equality) against the token's "iss"
	// claim, so trailing slashes matter.
	Issuer string

	// Audience is the expected token audience. Required.
	// NOTE(review): NewVerifier stores this value, but Verify checks the
	// expectedAudience argument it is called with rather than this field —
	// confirm callers pass a consistent value.
	Audience string

	// JWKSURL overrides automatic OIDC discovery for the JWKS endpoint.
	// If set, MUST be an https:// URL. When empty, the JWKS location is
	// resolved from <Issuer>/.well-known/openid-configuration.
	JWKSURL string

	// RequireNonce requires a nonce claim in verified tokens.
	// Only the claim's presence is checked; its value is not compared.
	RequireNonce bool
}
|
|
|
|
// Claims represents the verified claims from an OIDC token.
//
// Verify returns a Claims value only after the issuer, audience, time-based
// and signature checks have all passed.
type Claims struct {
	// Subject is the "sub" claim.
	Subject string

	// Issuer is the "iss" claim; it equals the configured issuer.
	Issuer string

	// Audience is the "aud" claim, normalized to a list whether the token
	// carried a single string or an array.
	Audience []string

	// Email is the "email" claim, empty if absent.
	Email string

	// Groups is the "groups" claim, nil if absent.
	Groups []string

	JTI string // JWT ID — unique token identifier for replay detection (Verify itself does not track it)
}
|
|
|
|
// Verifier validates OIDC tokens and extracts claims.
type Verifier interface {
	// Verify validates the token against the expected audience and returns the claims.
	// The expectedAudience parameter is required and MUST be checked by implementations.
	// On any validation failure it returns a nil *Claims and a non-nil error.
	Verify(ctx context.Context, rawToken string, expectedAudience string) (*Claims, error)
}
|
|
|
|
// jwksVerifier implements Verifier using JWKS key fetching and JWT validation.
//
// Public keys are cached in keys, indexed by their JWK "kid"; fetched records
// when that cache was last filled. keys and fetched should only be read or
// written while holding mu, because fetchJWKS swaps in a new map under the
// write lock (presumably while Verify runs on other goroutines — this is why
// the cache sits behind an RWMutex).
type jwksVerifier struct {
	issuer       string // expected "iss" claim
	audience     string // configured audience; not consulted by Verify
	jwksURL      string // JWKS URL, or the discovery URL until fetchJWKS resolves jwks_uri
	requireNonce bool   // reject tokens without a "nonce" claim
	httpClient   *http.Client

	mu      sync.RWMutex
	keys    map[string]crypto.PublicKey // kid -> parsed public key
	fetched time.Time                   // time of the last successful JWKS fetch
}
|
|
|
|
// NewVerifier creates an OIDC token verifier from the given configuration.
|
|
func NewVerifier(cfg Config) (Verifier, error) {
|
|
if cfg.Issuer == "" {
|
|
return nil, fmt.Errorf("oidc: issuer is required")
|
|
}
|
|
if err := requireHTTPS(cfg.Issuer, "issuer"); err != nil {
|
|
return nil, err
|
|
}
|
|
if cfg.Audience == "" {
|
|
return nil, fmt.Errorf("oidc: audience is required")
|
|
}
|
|
if cfg.JWKSURL != "" {
|
|
if err := requireHTTPS(cfg.JWKSURL, "jwks_url"); err != nil {
|
|
return nil, err
|
|
}
|
|
}
|
|
|
|
jwksURL := cfg.JWKSURL
|
|
if jwksURL == "" {
|
|
// OIDC discovery: fetch .well-known/openid-configuration
|
|
jwksURL = strings.TrimRight(cfg.Issuer, "/") + "/.well-known/openid-configuration"
|
|
}
|
|
|
|
return &jwksVerifier{
|
|
issuer: cfg.Issuer,
|
|
audience: cfg.Audience,
|
|
jwksURL: jwksURL,
|
|
requireNonce: cfg.RequireNonce,
|
|
httpClient: &http.Client{Timeout: 10 * time.Second},
|
|
keys: make(map[string]crypto.PublicKey),
|
|
}, nil
|
|
}
|
|
|
|
// Verify validates a JWT token string and returns the verified claims.
|
|
func (v *jwksVerifier) Verify(ctx context.Context, rawToken string, expectedAudience string) (*Claims, error) {
|
|
if expectedAudience == "" {
|
|
return nil, fmt.Errorf("oidc: expected audience must not be empty")
|
|
}
|
|
|
|
// Split JWT into parts.
|
|
parts := strings.Split(rawToken, ".")
|
|
if len(parts) != 3 {
|
|
return nil, fmt.Errorf("oidc: invalid JWT format: expected 3 parts, got %d", len(parts))
|
|
}
|
|
|
|
// Decode header.
|
|
headerBytes, err := base64URLDecode(parts[0])
|
|
if err != nil {
|
|
return nil, fmt.Errorf("oidc: decode JWT header: %w", err)
|
|
}
|
|
var header struct {
|
|
Alg string `json:"alg"`
|
|
Kid string `json:"kid"`
|
|
Typ string `json:"typ"`
|
|
}
|
|
if err := json.Unmarshal(headerBytes, &header); err != nil {
|
|
return nil, fmt.Errorf("oidc: parse JWT header: %w", err)
|
|
}
|
|
|
|
// Decode payload.
|
|
payloadBytes, err := base64URLDecode(parts[1])
|
|
if err != nil {
|
|
return nil, fmt.Errorf("oidc: decode JWT payload: %w", err)
|
|
}
|
|
|
|
// Parse claims.
|
|
var rawClaims struct {
|
|
Iss string `json:"iss"`
|
|
Sub string `json:"sub"`
|
|
Aud json.RawMessage `json:"aud"`
|
|
Exp int64 `json:"exp"`
|
|
Iat int64 `json:"iat"`
|
|
Nbf int64 `json:"nbf"`
|
|
Email string `json:"email"`
|
|
Groups []string `json:"groups"`
|
|
JTI string `json:"jti"`
|
|
Nonce string `json:"nonce"`
|
|
}
|
|
if err := json.Unmarshal(payloadBytes, &rawClaims); err != nil {
|
|
return nil, fmt.Errorf("oidc: parse JWT claims: %w", err)
|
|
}
|
|
|
|
// Validate issuer.
|
|
if rawClaims.Iss != v.issuer {
|
|
return nil, fmt.Errorf("oidc: issuer mismatch: got %q, expected %q", rawClaims.Iss, v.issuer)
|
|
}
|
|
|
|
// Parse audience (can be string or array).
|
|
var audiences []string
|
|
if len(rawClaims.Aud) > 0 {
|
|
if rawClaims.Aud[0] == '"' {
|
|
var single string
|
|
if err := json.Unmarshal(rawClaims.Aud, &single); err == nil {
|
|
audiences = []string{single}
|
|
}
|
|
} else {
|
|
json.Unmarshal(rawClaims.Aud, &audiences)
|
|
}
|
|
}
|
|
|
|
// Validate audience.
|
|
if err := ValidateAudience(audiences, expectedAudience); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// Validate expiry.
|
|
now := time.Now().Unix()
|
|
if rawClaims.Exp > 0 && now > rawClaims.Exp {
|
|
return nil, fmt.Errorf("oidc: token expired at %d, current time %d", rawClaims.Exp, now)
|
|
}
|
|
|
|
// Validate not-before.
|
|
if rawClaims.Nbf > 0 && now < rawClaims.Nbf {
|
|
return nil, fmt.Errorf("oidc: token not valid until %d, current time %d", rawClaims.Nbf, now)
|
|
}
|
|
|
|
// Validate nonce if required.
|
|
if v.requireNonce && rawClaims.Nonce == "" {
|
|
return nil, fmt.Errorf("oidc: nonce is required but not present in token")
|
|
}
|
|
|
|
// Verify signature.
|
|
sigBytes, err := base64URLDecode(parts[2])
|
|
if err != nil {
|
|
return nil, fmt.Errorf("oidc: decode JWT signature: %w", err)
|
|
}
|
|
|
|
if err := v.verifySignature(ctx, header.Alg, header.Kid, parts[0]+"."+parts[1], sigBytes); err != nil {
|
|
return nil, fmt.Errorf("oidc: signature verification failed: %w", err)
|
|
}
|
|
|
|
return &Claims{
|
|
Subject: rawClaims.Sub,
|
|
Issuer: rawClaims.Iss,
|
|
Audience: audiences,
|
|
Email: rawClaims.Email,
|
|
Groups: rawClaims.Groups,
|
|
JTI: rawClaims.JTI,
|
|
}, nil
|
|
}
|
|
|
|
// verifySignature verifies the JWT signature against the JWKS-fetched public key.
|
|
func (v *jwksVerifier) verifySignature(ctx context.Context, alg, kid, signedContent string, signature []byte) error {
|
|
key, err := v.getKey(ctx, kid)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
signedBytes := []byte(signedContent)
|
|
|
|
switch alg {
|
|
case "RS256":
|
|
rsaKey, ok := key.(*rsa.PublicKey)
|
|
if !ok {
|
|
return fmt.Errorf("key %q is not RSA", kid)
|
|
}
|
|
h := crypto.SHA256.New()
|
|
h.Write(signedBytes)
|
|
return rsa.VerifyPKCS1v15(rsaKey, crypto.SHA256, h.Sum(nil), signature)
|
|
case "ES256":
|
|
ecKey, ok := key.(*ecdsa.PublicKey)
|
|
if !ok {
|
|
return fmt.Errorf("key %q is not ECDSA", kid)
|
|
}
|
|
h := crypto.SHA256.New()
|
|
h.Write(signedBytes)
|
|
if !ecdsa.VerifyASN1(ecKey, h.Sum(nil), signature) {
|
|
return fmt.Errorf("ECDSA signature verification failed")
|
|
}
|
|
return nil
|
|
default:
|
|
return fmt.Errorf("unsupported algorithm: %s", alg)
|
|
}
|
|
}
|
|
|
|
// getKey retrieves a public key by kid, fetching JWKS if needed.
|
|
func (v *jwksVerifier) getKey(ctx context.Context, kid string) (crypto.PublicKey, error) {
|
|
v.mu.RLock()
|
|
key, ok := v.keys[kid]
|
|
fetched := v.fetched
|
|
v.mu.RUnlock()
|
|
|
|
if ok {
|
|
return key, nil
|
|
}
|
|
|
|
// Refresh JWKS if not fetched recently (max once per 5 minutes).
|
|
if time.Since(fetched) < 5*time.Minute && len(v.keys) > 0 {
|
|
return nil, fmt.Errorf("key %q not found in JWKS", kid)
|
|
}
|
|
|
|
if err := v.fetchJWKS(ctx); err != nil {
|
|
return nil, fmt.Errorf("fetch JWKS: %w", err)
|
|
}
|
|
|
|
v.mu.RLock()
|
|
defer v.mu.RUnlock()
|
|
key, ok = v.keys[kid]
|
|
if !ok {
|
|
return nil, fmt.Errorf("key %q not found in JWKS after refresh", kid)
|
|
}
|
|
return key, nil
|
|
}
|
|
|
|
// fetchJWKS fetches the JWKS document and populates the key cache.
|
|
func (v *jwksVerifier) fetchJWKS(ctx context.Context) error {
|
|
jwksURL := v.jwksURL
|
|
|
|
// If the URL is a discovery endpoint, resolve the actual JWKS URI first.
|
|
if strings.Contains(jwksURL, ".well-known/openid-configuration") {
|
|
discoveryURL := jwksURL
|
|
req, err := http.NewRequestWithContext(ctx, http.MethodGet, discoveryURL, nil)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
resp, err := v.httpClient.Do(req)
|
|
if err != nil {
|
|
return fmt.Errorf("fetch OIDC discovery: %w", err)
|
|
}
|
|
defer resp.Body.Close()
|
|
body, err := io.ReadAll(io.LimitReader(resp.Body, 1<<20))
|
|
if err != nil {
|
|
return err
|
|
}
|
|
var discovery struct {
|
|
JWKSURI string `json:"jwks_uri"`
|
|
}
|
|
if err := json.Unmarshal(body, &discovery); err != nil {
|
|
return fmt.Errorf("parse OIDC discovery: %w", err)
|
|
}
|
|
if discovery.JWKSURI == "" {
|
|
return fmt.Errorf("OIDC discovery document missing jwks_uri")
|
|
}
|
|
jwksURL = discovery.JWKSURI
|
|
// Cache resolved JWKS URL for future fetches.
|
|
v.jwksURL = jwksURL
|
|
}
|
|
|
|
req, err := http.NewRequestWithContext(ctx, http.MethodGet, jwksURL, nil)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
resp, err := v.httpClient.Do(req)
|
|
if err != nil {
|
|
return fmt.Errorf("fetch JWKS: %w", err)
|
|
}
|
|
defer resp.Body.Close()
|
|
body, err := io.ReadAll(io.LimitReader(resp.Body, 1<<20))
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
var jwks struct {
|
|
Keys []jwkKey `json:"keys"`
|
|
}
|
|
if err := json.Unmarshal(body, &jwks); err != nil {
|
|
return fmt.Errorf("parse JWKS: %w", err)
|
|
}
|
|
|
|
keys := make(map[string]crypto.PublicKey, len(jwks.Keys))
|
|
for _, k := range jwks.Keys {
|
|
if k.Use != "" && k.Use != "sig" {
|
|
continue
|
|
}
|
|
pub, err := k.toPublicKey()
|
|
if err != nil {
|
|
continue // Skip keys we can't parse.
|
|
}
|
|
keys[k.Kid] = pub
|
|
}
|
|
|
|
v.mu.Lock()
|
|
v.keys = keys
|
|
v.fetched = time.Now()
|
|
v.mu.Unlock()
|
|
|
|
return nil
|
|
}
|
|
|
|
// jwkKey represents a single JWK entry.
//
// Only the members this package consumes are decoded; the binary values
// (n, e, x, y) stay base64url-encoded strings until toPublicKey parses them.
type jwkKey struct {
	Kty string `json:"kty"` // key type; "RSA" and "EC" are handled by toPublicKey
	Kid string `json:"kid"` // key ID, matched against the JWT header "kid"
	Use string `json:"use"` // intended use; fetchJWKS keeps only "" or "sig"
	Alg string `json:"alg"` // declared algorithm (decoded but not consulted here)
	N   string `json:"n"`   // RSA modulus
	E   string `json:"e"`   // RSA exponent
	Crv string `json:"crv"` // EC curve
	X   string `json:"x"`   // EC x coordinate
	Y   string `json:"y"`   // EC y coordinate
}
|
|
|
|
func (k *jwkKey) toPublicKey() (crypto.PublicKey, error) {
|
|
switch k.Kty {
|
|
case "RSA":
|
|
return k.toRSAPublicKey()
|
|
case "EC":
|
|
return k.toECPublicKey()
|
|
default:
|
|
return nil, fmt.Errorf("unsupported key type: %s", k.Kty)
|
|
}
|
|
}
|
|
|
|
func (k *jwkKey) toRSAPublicKey() (*rsa.PublicKey, error) {
|
|
nBytes, err := base64URLDecode(k.N)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("decode RSA n: %w", err)
|
|
}
|
|
eBytes, err := base64URLDecode(k.E)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("decode RSA e: %w", err)
|
|
}
|
|
n := new(big.Int).SetBytes(nBytes)
|
|
e := 0
|
|
for _, b := range eBytes {
|
|
e = e<<8 + int(b)
|
|
}
|
|
return &rsa.PublicKey{N: n, E: e}, nil
|
|
}
|
|
|
|
func (k *jwkKey) toECPublicKey() (*ecdsa.PublicKey, error) {
|
|
var curve elliptic.Curve
|
|
switch k.Crv {
|
|
case "P-256":
|
|
curve = elliptic.P256()
|
|
case "P-384":
|
|
curve = elliptic.P384()
|
|
default:
|
|
return nil, fmt.Errorf("unsupported curve: %s", k.Crv)
|
|
}
|
|
xBytes, err := base64URLDecode(k.X)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("decode EC x: %w", err)
|
|
}
|
|
yBytes, err := base64URLDecode(k.Y)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("decode EC y: %w", err)
|
|
}
|
|
return &ecdsa.PublicKey{
|
|
Curve: curve,
|
|
X: new(big.Int).SetBytes(xBytes),
|
|
Y: new(big.Int).SetBytes(yBytes),
|
|
}, nil
|
|
}
|
|
|
|
// base64URLDecode decodes base64url-encoded data (with or without padding).
//
// JWS segments are normally unpadded, so the missing "=" characters are
// restored before using the padded URL-safe decoder. A length remainder of 1
// can never be valid base64 and is passed through for the decoder to reject.
func base64URLDecode(s string) ([]byte, error) {
	// Index by len(s)%4: remainders 0 and 1 need no padding.
	pad := [4]string{2: "==", 3: "="}
	return base64.URLEncoding.DecodeString(s + pad[len(s)%4])
}
|
|
|
|
// ValidateAudience reports whether expectedAudience is one of the values in
// audiences. It returns nil on a match. It returns an error when there is no
// match, and also when expectedAudience is empty, so an empty expectation can
// never match an empty or missing "aud" claim.
func ValidateAudience(audiences []string, expectedAudience string) error {
	if expectedAudience == "" {
		return fmt.Errorf("oidc: expected audience must not be empty")
	}

	matched := false
	for i := 0; i < len(audiences) && !matched; i++ {
		matched = audiences[i] == expectedAudience
	}
	if matched {
		return nil
	}
	return fmt.Errorf("oidc: token audience %v does not contain expected audience %q", audiences, expectedAudience)
}
|
|
|
|
// requireHTTPS validates that rawURL parses, uses the https scheme, and names
// a host. fieldName identifies the offending config field in the error.
//
// The host check uses Hostname() rather than Host: a URL such as
// "https://:443" has a non-empty Host (":443") but no host name to connect to.
func requireHTTPS(rawURL, fieldName string) error {
	u, err := url.Parse(rawURL)
	if err != nil {
		return fmt.Errorf("oidc: %s is not a valid URL: %w", fieldName, err)
	}
	if u.Scheme != "https" {
		return fmt.Errorf("oidc: %s must use https:// scheme, got %q", fieldName, u.Scheme)
	}
	if u.Hostname() == "" {
		return fmt.Errorf("oidc: %s has no host", fieldName)
	}
	return nil
}
|