// Package oidc provides OIDC token verification for SPIRE workload attestation.
package oidc

import (
	"context"
	"crypto"
	"crypto/ecdsa"
	"crypto/elliptic"
	"crypto/rsa"
	"encoding/base64"
	"encoding/json"
	"fmt"
	"io"
	"math/big"
	"net/http"
	"net/url"
	"strings"
	"sync"
	"time"
)

// Config holds OIDC verifier configuration. It is validated by NewVerifier.
type Config struct {
	// Issuer is the expected OIDC issuer URL. MUST be an https:// URL.
	// Tokens are accepted only when their "iss" claim equals this string
	// exactly. When JWKSURL is empty, the discovery document is looked up
	// under this URL.
	Issuer string
	// Audience is the expected token audience. Required (NewVerifier rejects
	// an empty value).
	// NOTE(review): Verify checks the "aud" claim against its
	// expectedAudience argument; within this file the configured value is
	// stored but not otherwise consulted — confirm this is intended.
	Audience string
	// JWKSURL overrides automatic OIDC discovery for the JWKS endpoint.
	// If set, MUST be an https:// URL. When empty, the JWKS location is taken
	// from the jwks_uri of <Issuer>/.well-known/openid-configuration.
	JWKSURL string
	// RequireNonce requires a nonce claim in verified tokens. Only the
	// claim's presence is checked; its value is not compared against an
	// expected nonce.
	RequireNonce bool
}

// Claims represents the verified claims from an OIDC token. It is only
// returned after the issuer, audience, time and signature checks pass.
type Claims struct {
	Subject  string   // "sub" claim
	Issuer   string   // "iss" claim; always equal to Config.Issuer
	Audience []string // "aud" claim, normalized to a list (a single string becomes one element)
	Email    string   // "email" claim
	Groups   []string // "groups" claim
	JTI      string   // JWT ID — unique token identifier for replay detection; nothing in this file records seen values, so replay tracking is left to callers
}

// Verifier validates OIDC tokens and extracts claims.
type Verifier interface {
	// Verify validates the token against the expected audience and returns the claims.
	// The expectedAudience parameter is required and MUST be checked by implementations.
	Verify(ctx context.Context, rawToken string, expectedAudience string) (*Claims, error)
}

// jwksVerifier implements Verifier using JWKS key fetching and JWT validation.
// Construct it with NewVerifier.
type jwksVerifier struct {
	issuer       string // expected "iss" value (Config.Issuer)
	audience     string // Config.Audience; NOTE(review): not read by Verify in this file
	jwksURL      string // JWKS endpoint, or the discovery URL until jwks_uri has been resolved
	requireNonce bool   // Config.RequireNonce
	httpClient   *http.Client

	// mu guards the JWKS cache fields below.
	mu      sync.RWMutex
	keys    map[string]crypto.PublicKey // verification keys indexed by JWK "kid"
	fetched time.Time                   // when keys was last replaced by a JWKS fetch
}

// NewVerifier creates an OIDC token verifier from the given configuration.
func NewVerifier(cfg Config) (Verifier, error) { if cfg.Issuer == "" { return nil, fmt.Errorf("oidc: issuer is required") } if err := requireHTTPS(cfg.Issuer, "issuer"); err != nil { return nil, err } if cfg.Audience == "" { return nil, fmt.Errorf("oidc: audience is required") } if cfg.JWKSURL != "" { if err := requireHTTPS(cfg.JWKSURL, "jwks_url"); err != nil { return nil, err } } jwksURL := cfg.JWKSURL if jwksURL == "" { // OIDC discovery: fetch .well-known/openid-configuration jwksURL = strings.TrimRight(cfg.Issuer, "/") + "/.well-known/openid-configuration" } return &jwksVerifier{ issuer: cfg.Issuer, audience: cfg.Audience, jwksURL: jwksURL, requireNonce: cfg.RequireNonce, httpClient: &http.Client{Timeout: 10 * time.Second}, keys: make(map[string]crypto.PublicKey), }, nil } // Verify validates a JWT token string and returns the verified claims. func (v *jwksVerifier) Verify(ctx context.Context, rawToken string, expectedAudience string) (*Claims, error) { if expectedAudience == "" { return nil, fmt.Errorf("oidc: expected audience must not be empty") } // Split JWT into parts. parts := strings.Split(rawToken, ".") if len(parts) != 3 { return nil, fmt.Errorf("oidc: invalid JWT format: expected 3 parts, got %d", len(parts)) } // Decode header. headerBytes, err := base64URLDecode(parts[0]) if err != nil { return nil, fmt.Errorf("oidc: decode JWT header: %w", err) } var header struct { Alg string `json:"alg"` Kid string `json:"kid"` Typ string `json:"typ"` } if err := json.Unmarshal(headerBytes, &header); err != nil { return nil, fmt.Errorf("oidc: parse JWT header: %w", err) } // Decode payload. payloadBytes, err := base64URLDecode(parts[1]) if err != nil { return nil, fmt.Errorf("oidc: decode JWT payload: %w", err) } // Parse claims. 
var rawClaims struct { Iss string `json:"iss"` Sub string `json:"sub"` Aud json.RawMessage `json:"aud"` Exp int64 `json:"exp"` Iat int64 `json:"iat"` Nbf int64 `json:"nbf"` Email string `json:"email"` Groups []string `json:"groups"` JTI string `json:"jti"` Nonce string `json:"nonce"` } if err := json.Unmarshal(payloadBytes, &rawClaims); err != nil { return nil, fmt.Errorf("oidc: parse JWT claims: %w", err) } // Validate issuer. if rawClaims.Iss != v.issuer { return nil, fmt.Errorf("oidc: issuer mismatch: got %q, expected %q", rawClaims.Iss, v.issuer) } // Parse audience (can be string or array). var audiences []string if len(rawClaims.Aud) > 0 { if rawClaims.Aud[0] == '"' { var single string if err := json.Unmarshal(rawClaims.Aud, &single); err == nil { audiences = []string{single} } } else { json.Unmarshal(rawClaims.Aud, &audiences) } } // Validate audience. if err := ValidateAudience(audiences, expectedAudience); err != nil { return nil, err } // Validate expiry. now := time.Now().Unix() if rawClaims.Exp > 0 && now > rawClaims.Exp { return nil, fmt.Errorf("oidc: token expired at %d, current time %d", rawClaims.Exp, now) } // Validate not-before. if rawClaims.Nbf > 0 && now < rawClaims.Nbf { return nil, fmt.Errorf("oidc: token not valid until %d, current time %d", rawClaims.Nbf, now) } // Validate nonce if required. if v.requireNonce && rawClaims.Nonce == "" { return nil, fmt.Errorf("oidc: nonce is required but not present in token") } // Verify signature. 
sigBytes, err := base64URLDecode(parts[2]) if err != nil { return nil, fmt.Errorf("oidc: decode JWT signature: %w", err) } if err := v.verifySignature(ctx, header.Alg, header.Kid, parts[0]+"."+parts[1], sigBytes); err != nil { return nil, fmt.Errorf("oidc: signature verification failed: %w", err) } return &Claims{ Subject: rawClaims.Sub, Issuer: rawClaims.Iss, Audience: audiences, Email: rawClaims.Email, Groups: rawClaims.Groups, JTI: rawClaims.JTI, }, nil } // verifySignature verifies the JWT signature against the JWKS-fetched public key. func (v *jwksVerifier) verifySignature(ctx context.Context, alg, kid, signedContent string, signature []byte) error { key, err := v.getKey(ctx, kid) if err != nil { return err } signedBytes := []byte(signedContent) switch alg { case "RS256": rsaKey, ok := key.(*rsa.PublicKey) if !ok { return fmt.Errorf("key %q is not RSA", kid) } h := crypto.SHA256.New() h.Write(signedBytes) return rsa.VerifyPKCS1v15(rsaKey, crypto.SHA256, h.Sum(nil), signature) case "ES256": ecKey, ok := key.(*ecdsa.PublicKey) if !ok { return fmt.Errorf("key %q is not ECDSA", kid) } h := crypto.SHA256.New() h.Write(signedBytes) if !ecdsa.VerifyASN1(ecKey, h.Sum(nil), signature) { return fmt.Errorf("ECDSA signature verification failed") } return nil default: return fmt.Errorf("unsupported algorithm: %s", alg) } } // getKey retrieves a public key by kid, fetching JWKS if needed. func (v *jwksVerifier) getKey(ctx context.Context, kid string) (crypto.PublicKey, error) { v.mu.RLock() key, ok := v.keys[kid] fetched := v.fetched v.mu.RUnlock() if ok { return key, nil } // Refresh JWKS if not fetched recently (max once per 5 minutes). 
if time.Since(fetched) < 5*time.Minute && len(v.keys) > 0 { return nil, fmt.Errorf("key %q not found in JWKS", kid) } if err := v.fetchJWKS(ctx); err != nil { return nil, fmt.Errorf("fetch JWKS: %w", err) } v.mu.RLock() defer v.mu.RUnlock() key, ok = v.keys[kid] if !ok { return nil, fmt.Errorf("key %q not found in JWKS after refresh", kid) } return key, nil } // fetchJWKS fetches the JWKS document and populates the key cache. func (v *jwksVerifier) fetchJWKS(ctx context.Context) error { jwksURL := v.jwksURL // If the URL is a discovery endpoint, resolve the actual JWKS URI first. if strings.Contains(jwksURL, ".well-known/openid-configuration") { discoveryURL := jwksURL req, err := http.NewRequestWithContext(ctx, http.MethodGet, discoveryURL, nil) if err != nil { return err } resp, err := v.httpClient.Do(req) if err != nil { return fmt.Errorf("fetch OIDC discovery: %w", err) } defer resp.Body.Close() body, err := io.ReadAll(io.LimitReader(resp.Body, 1<<20)) if err != nil { return err } var discovery struct { JWKSURI string `json:"jwks_uri"` } if err := json.Unmarshal(body, &discovery); err != nil { return fmt.Errorf("parse OIDC discovery: %w", err) } if discovery.JWKSURI == "" { return fmt.Errorf("OIDC discovery document missing jwks_uri") } jwksURL = discovery.JWKSURI // Cache resolved JWKS URL for future fetches. 
v.jwksURL = jwksURL } req, err := http.NewRequestWithContext(ctx, http.MethodGet, jwksURL, nil) if err != nil { return err } resp, err := v.httpClient.Do(req) if err != nil { return fmt.Errorf("fetch JWKS: %w", err) } defer resp.Body.Close() body, err := io.ReadAll(io.LimitReader(resp.Body, 1<<20)) if err != nil { return err } var jwks struct { Keys []jwkKey `json:"keys"` } if err := json.Unmarshal(body, &jwks); err != nil { return fmt.Errorf("parse JWKS: %w", err) } keys := make(map[string]crypto.PublicKey, len(jwks.Keys)) for _, k := range jwks.Keys { if k.Use != "" && k.Use != "sig" { continue } pub, err := k.toPublicKey() if err != nil { continue // Skip keys we can't parse. } keys[k.Kid] = pub } v.mu.Lock() v.keys = keys v.fetched = time.Now() v.mu.Unlock() return nil } // jwkKey represents a single JWK entry. type jwkKey struct { Kty string `json:"kty"` Kid string `json:"kid"` Use string `json:"use"` Alg string `json:"alg"` N string `json:"n"` // RSA modulus E string `json:"e"` // RSA exponent Crv string `json:"crv"` // EC curve X string `json:"x"` // EC x coordinate Y string `json:"y"` // EC y coordinate } func (k *jwkKey) toPublicKey() (crypto.PublicKey, error) { switch k.Kty { case "RSA": return k.toRSAPublicKey() case "EC": return k.toECPublicKey() default: return nil, fmt.Errorf("unsupported key type: %s", k.Kty) } } func (k *jwkKey) toRSAPublicKey() (*rsa.PublicKey, error) { nBytes, err := base64URLDecode(k.N) if err != nil { return nil, fmt.Errorf("decode RSA n: %w", err) } eBytes, err := base64URLDecode(k.E) if err != nil { return nil, fmt.Errorf("decode RSA e: %w", err) } n := new(big.Int).SetBytes(nBytes) e := 0 for _, b := range eBytes { e = e<<8 + int(b) } return &rsa.PublicKey{N: n, E: e}, nil } func (k *jwkKey) toECPublicKey() (*ecdsa.PublicKey, error) { var curve elliptic.Curve switch k.Crv { case "P-256": curve = elliptic.P256() case "P-384": curve = elliptic.P384() default: return nil, fmt.Errorf("unsupported curve: %s", k.Crv) } xBytes, err 
:= base64URLDecode(k.X) if err != nil { return nil, fmt.Errorf("decode EC x: %w", err) } yBytes, err := base64URLDecode(k.Y) if err != nil { return nil, fmt.Errorf("decode EC y: %w", err) } return &ecdsa.PublicKey{ Curve: curve, X: new(big.Int).SetBytes(xBytes), Y: new(big.Int).SetBytes(yBytes), }, nil } // base64URLDecode decodes base64url-encoded data (with or without padding). func base64URLDecode(s string) ([]byte, error) { // Add padding if needed. switch len(s) % 4 { case 2: s += "==" case 3: s += "=" } return base64.URLEncoding.DecodeString(s) } // ValidateAudience checks whether expectedAudience is present in the audience claim list. func ValidateAudience(audiences []string, expectedAudience string) error { if expectedAudience == "" { return fmt.Errorf("oidc: expected audience must not be empty") } for _, a := range audiences { if a == expectedAudience { return nil } } return fmt.Errorf("oidc: token audience %v does not contain expected audience %q", audiences, expectedAudience) } // requireHTTPS validates that a URL uses the https scheme. func requireHTTPS(rawURL, fieldName string) error { u, err := url.Parse(rawURL) if err != nil { return fmt.Errorf("oidc: %s is not a valid URL: %w", fieldName, err) } if u.Scheme != "https" { return fmt.Errorf("oidc: %s must use https:// scheme, got %q", fieldName, u.Scheme) } if u.Host == "" { return fmt.Errorf("oidc: %s has no host", fieldName) } return nil }