// guildhouse-spire-plugins/pkg/shellstream/shellstream_test.go

package shellstream
import (
"encoding/base64"
"strings"
"testing"
)
// minimalExtensions returns the smallest extension set the tests treat as
// valid: a lowercase tenant UUID plus a single "analyst" role, with every
// other field left at its zero value.
func minimalExtensions() *ShellstreamExtensions {
	ext := new(ShellstreamExtensions)
	ext.TenantID = "a1b2c3d4-e5f6-7890-abcd-ef1234567890"
	ext.Roles = []string{"analyst"}
	return ext
}
// fullExtensions returns an extension set with the SAT scope and hash,
// ceremony, Merkle root/proof and governance-epoch (42) fields all populated
// alongside the required tenant and roles.
func fullExtensions() *ShellstreamExtensions {
	scope := &SatScope{
		RegistryType:    "oci",
		Verbs:           []string{"push", "pull"},
		ResourcePattern: "tenant-a/*",
	}
	ext := &ShellstreamExtensions{
		TenantID:     "a1b2c3d4-e5f6-7890-abcd-ef1234567890",
		Roles:        []string{"administrator", "engineer"},
		SatScope:     scope,
		SatHash:      "abcdef0123456789abcdef0123456789abcdef0123456789abcdef0123456789",
		CeremonyID:   "11223344-5566-7788-99aa-bbccddeeff00",
		CeremonyType: "single_approval",
		MerkleRoot:   "fedcba9876543210fedcba9876543210fedcba9876543210fedcba9876543210",
		MerkleProof:  []byte{0x01, 0x02, 0x03, 0x04, 0x05},
	}
	// The epoch goes through the setter rather than the field; presumably
	// this is what makes HasGovernanceEpoch report true — confirm in the
	// ShellstreamExtensions implementation.
	ext.WithGovernanceEpoch(42)
	return ext
}
// TestEncodeDecodeRoundTrip encodes a fully-populated extension set, decodes
// the resulting map, and checks that every field comes back unchanged —
// including the contents (not just the lengths) of every slice field and the
// governance-epoch presence flag.
func TestEncodeDecodeRoundTrip(t *testing.T) {
	ext := fullExtensions()
	encoded, err := Encode(ext)
	if err != nil {
		t.Fatalf("Encode: %v", err)
	}
	decoded, err := Decode(encoded)
	if err != nil {
		t.Fatalf("Decode: %v", err)
	}
	// Verify all fields round-trip.
	if decoded.TenantID != ext.TenantID {
		t.Errorf("TenantID: got %q, want %q", decoded.TenantID, ext.TenantID)
	}
	if len(decoded.Roles) != len(ext.Roles) {
		t.Fatalf("Roles length: got %d, want %d", len(decoded.Roles), len(ext.Roles))
	}
	for i, r := range decoded.Roles {
		if r != ext.Roles[i] {
			t.Errorf("Roles[%d]: got %q, want %q", i, r, ext.Roles[i])
		}
	}
	if decoded.SatHash != ext.SatHash {
		t.Errorf("SatHash: got %q, want %q", decoded.SatHash, ext.SatHash)
	}
	if decoded.SatScope == nil {
		t.Fatal("SatScope: got nil")
	}
	if decoded.SatScope.RegistryType != ext.SatScope.RegistryType {
		t.Errorf("SatScope.RegistryType: got %q, want %q", decoded.SatScope.RegistryType, ext.SatScope.RegistryType)
	}
	if len(decoded.SatScope.Verbs) != len(ext.SatScope.Verbs) {
		t.Fatalf("SatScope.Verbs length: got %d, want %d", len(decoded.SatScope.Verbs), len(ext.SatScope.Verbs))
	}
	// Compare each verb, not just the count, so a reordered or altered verb
	// list cannot pass as a successful round-trip.
	for i, v := range decoded.SatScope.Verbs {
		if v != ext.SatScope.Verbs[i] {
			t.Errorf("SatScope.Verbs[%d]: got %q, want %q", i, v, ext.SatScope.Verbs[i])
		}
	}
	if decoded.SatScope.ResourcePattern != ext.SatScope.ResourcePattern {
		t.Errorf("SatScope.ResourcePattern: got %q, want %q", decoded.SatScope.ResourcePattern, ext.SatScope.ResourcePattern)
	}
	if decoded.CeremonyID != ext.CeremonyID {
		t.Errorf("CeremonyID: got %q, want %q", decoded.CeremonyID, ext.CeremonyID)
	}
	if decoded.CeremonyType != ext.CeremonyType {
		t.Errorf("CeremonyType: got %q, want %q", decoded.CeremonyType, ext.CeremonyType)
	}
	if decoded.MerkleRoot != ext.MerkleRoot {
		t.Errorf("MerkleRoot: got %q, want %q", decoded.MerkleRoot, ext.MerkleRoot)
	}
	if len(decoded.MerkleProof) != len(ext.MerkleProof) {
		t.Fatalf("MerkleProof length: got %d, want %d", len(decoded.MerkleProof), len(ext.MerkleProof))
	}
	for i, b := range decoded.MerkleProof {
		if b != ext.MerkleProof[i] {
			t.Errorf("MerkleProof[%d]: got %x, want %x", i, b, ext.MerkleProof[i])
		}
	}
	if decoded.GovernanceEpoch != ext.GovernanceEpoch {
		t.Errorf("GovernanceEpoch: got %d, want %d", decoded.GovernanceEpoch, ext.GovernanceEpoch)
	}
	if !decoded.HasGovernanceEpoch() {
		t.Error("HasGovernanceEpoch: got false, want true")
	}
}
// TestEncodeDecodeMinimal checks that an extension set carrying only the
// required fields encodes to exactly the tenant-id and roles keys, that both
// required fields survive decoding, and that every optional field decodes to
// its zero value.
func TestEncodeDecodeMinimal(t *testing.T) {
	ext := minimalExtensions()
	encoded, err := Encode(ext)
	if err != nil {
		t.Fatalf("Encode: %v", err)
	}
	// Only tenant-id and roles should be present.
	if len(encoded) != 2 {
		t.Errorf("encoded map length: got %d, want 2 (keys: %v)", len(encoded), mapKeys(encoded))
	}
	if _, ok := encoded[ExtTenantID]; !ok {
		t.Error("missing tenant-id")
	}
	if _, ok := encoded[ExtRoles]; !ok {
		t.Error("missing roles")
	}
	decoded, err := Decode(encoded)
	if err != nil {
		t.Fatalf("Decode: %v", err)
	}
	if decoded.TenantID != ext.TenantID {
		t.Errorf("TenantID: got %q, want %q", decoded.TenantID, ext.TenantID)
	}
	// Roles are required too, so they must round-trip like TenantID does.
	if len(decoded.Roles) != len(ext.Roles) {
		t.Fatalf("Roles length: got %d, want %d", len(decoded.Roles), len(ext.Roles))
	}
	for i, r := range decoded.Roles {
		if r != ext.Roles[i] {
			t.Errorf("Roles[%d]: got %q, want %q", i, r, ext.Roles[i])
		}
	}
	// None of the optional fields were set, so none may appear after decoding.
	if decoded.SatScope != nil {
		t.Error("SatScope should be nil for minimal extensions")
	}
	if decoded.SatHash != "" {
		t.Errorf("SatHash should be empty for minimal extensions, got %q", decoded.SatHash)
	}
	if decoded.CeremonyID != "" {
		t.Error("CeremonyID should be empty for minimal extensions")
	}
	if decoded.CeremonyType != "" {
		t.Errorf("CeremonyType should be empty for minimal extensions, got %q", decoded.CeremonyType)
	}
	if decoded.MerkleRoot != "" {
		t.Errorf("MerkleRoot should be empty for minimal extensions, got %q", decoded.MerkleRoot)
	}
	if len(decoded.MerkleProof) != 0 {
		t.Errorf("MerkleProof should be empty for minimal extensions, got %x", decoded.MerkleProof)
	}
	if decoded.HasGovernanceEpoch() {
		t.Error("HasGovernanceEpoch should be false for minimal extensions")
	}
}
// TestDecodeUnknownExtensionsIgnored checks that Decode tolerates keys it does
// not recognise (a foreign @guildhouse.io key, an empty-valued permit-pty key
// and an unrelated key). The known fields must still decode correctly, and no
// unknown value may leak into an optional field.
func TestDecodeUnknownExtensionsIgnored(t *testing.T) {
	m := map[string]string{
		ExtTenantID:                 "a1b2c3d4-e5f6-7890-abcd-ef1234567890",
		ExtRoles:                    "analyst",
		"unknown-ext@guildhouse.io": "some-value",
		"permit-pty":                "",
		"completely-unrelated":      "ignored",
	}
	decoded, err := Decode(m)
	if err != nil {
		t.Fatalf("Decode: %v", err)
	}
	if decoded.TenantID != "a1b2c3d4-e5f6-7890-abcd-ef1234567890" {
		t.Errorf("TenantID: got %q", decoded.TenantID)
	}
	if len(decoded.Roles) != 1 || decoded.Roles[0] != "analyst" {
		t.Errorf("Roles: got %q, want [\"analyst\"]", decoded.Roles)
	}
	// Unknown keys must be dropped, not mapped onto optional fields.
	if decoded.SatScope != nil {
		t.Error("SatScope should be nil when only unknown extensions are present")
	}
	if decoded.CeremonyID != "" {
		t.Errorf("CeremonyID: got %q, want empty", decoded.CeremonyID)
	}
	if decoded.HasGovernanceEpoch() {
		t.Error("HasGovernanceEpoch: got true, want false")
	}
}
// TestValidateRequiredTenantID checks that Validate rejects an extension set
// that has roles but no tenant-id, naming the missing field in the error.
func TestValidateRequiredTenantID(t *testing.T) {
	noTenant := &ShellstreamExtensions{Roles: []string{"analyst"}}
	switch err := Validate(noTenant); {
	case err == nil:
		t.Fatal("expected error for missing tenant-id")
	case !strings.Contains(err.Error(), "tenant-id is required"):
		t.Errorf("unexpected error: %v", err)
	}
}
// TestValidateRequiredRoles checks that Validate rejects an extension set
// with a tenant-id but an absent roles list.
func TestValidateRequiredRoles(t *testing.T) {
	noRoles := &ShellstreamExtensions{TenantID: "a1b2c3d4-e5f6-7890-abcd-ef1234567890"}
	switch err := Validate(noRoles); {
	case err == nil:
		t.Fatal("expected error for missing roles")
	case !strings.Contains(err.Error(), "roles is required"):
		t.Errorf("unexpected error: %v", err)
	}
}
// TestValidateInvalidTenantIDFormat checks that a tenant-id that is not a
// UUID is rejected with a "not a valid UUID" error.
func TestValidateInvalidTenantIDFormat(t *testing.T) {
	badTenant := &ShellstreamExtensions{
		Roles:    []string{"analyst"},
		TenantID: "not-a-uuid",
	}
	switch err := Validate(badTenant); {
	case err == nil:
		t.Fatal("expected error for invalid UUID")
	case !strings.Contains(err.Error(), "not a valid UUID"):
		t.Errorf("unexpected error: %v", err)
	}
}
// TestValidateSatScopeRequiresSatHash checks the sat-scope → sat-hash
// co-occurrence rule: a scope without its hash must be rejected.
func TestValidateSatScopeRequiresSatHash(t *testing.T) {
	ext := minimalExtensions()
	// Attach a scope but deliberately leave SatHash empty.
	ext.SatScope = &SatScope{ResourcePattern: "*", RegistryType: "oci", Verbs: []string{"pull"}}
	switch err := Validate(ext); {
	case err == nil:
		t.Fatal("expected error for sat-scope without sat-hash")
	case !strings.Contains(err.Error(), "sat-scope requires sat-hash"):
		t.Errorf("unexpected error: %v", err)
	}
}
// TestValidateSatHashRequiresSatScope checks the reverse co-occurrence rule:
// a well-formed sat-hash with no sat-scope must be rejected.
func TestValidateSatHashRequiresSatScope(t *testing.T) {
	ext := minimalExtensions()
	ext.SatHash = "abcdef0123456789abcdef0123456789abcdef0123456789abcdef0123456789" // no SatScope set
	switch err := Validate(ext); {
	case err == nil:
		t.Fatal("expected error for sat-hash without sat-scope")
	case !strings.Contains(err.Error(), "sat-hash requires sat-scope"):
		t.Errorf("unexpected error: %v", err)
	}
}
// TestValidateSatHashFormat checks that sat-hash must be exactly 64 lowercase
// hex characters: a short value and an uppercase value are each rejected with
// a matching error, and a well-formed value passes.
func TestValidateSatHashFormat(t *testing.T) {
	ext := minimalExtensions()
	ext.SatScope = &SatScope{
		RegistryType:    "oci",
		Verbs:           []string{"pull"},
		ResourcePattern: "*",
	}
	malformed := []struct {
		hash     string // value assigned to SatHash
		fragment string // substring the validation error must contain
		label    string // short name used in the failure message
	}{
		{"abcdef", "64 hex characters", "64-char"},
		{"ABCDEF0123456789ABCDEF0123456789ABCDEF0123456789ABCDEF0123456789", "lowercase", "lowercase"},
	}
	for _, tc := range malformed {
		ext.SatHash = tc.hash
		if err := Validate(ext); err == nil || !strings.Contains(err.Error(), tc.fragment) {
			t.Errorf("expected %s error, got: %v", tc.label, err)
		}
	}
	ext.SatHash = "abcdef0123456789abcdef0123456789abcdef0123456789abcdef0123456789"
	if err := Validate(ext); err != nil {
		t.Errorf("unexpected error for valid sat-hash: %v", err)
	}
}
// TestValidateCeremonyCooccurrence checks that ceremony-id and ceremony-type
// must be supplied together: setting either one alone is a validation error.
func TestValidateCeremonyCooccurrence(t *testing.T) {
	ext := minimalExtensions()
	halves := []struct {
		id, kind, fragment string
	}{
		{"11223344-5566-7788-99aa-bbccddeeff00", "", "ceremony-id requires ceremony-type"},
		{"", "single_approval", "ceremony-type requires ceremony-id"},
	}
	for _, h := range halves {
		ext.CeremonyID, ext.CeremonyType = h.id, h.kind
		err := Validate(ext)
		if err == nil || !strings.Contains(err.Error(), h.fragment) {
			t.Errorf("expected co-occurrence error, got: %v", err)
		}
	}
}
// TestValidateUnknownCeremonyType checks that a ceremony-type outside the
// recognised set is rejected even when ceremony-id is present.
func TestValidateUnknownCeremonyType(t *testing.T) {
	ext := minimalExtensions()
	ext.CeremonyID, ext.CeremonyType = "11223344-5566-7788-99aa-bbccddeeff00", "unknown_type"
	if err := Validate(ext); err == nil || !strings.Contains(err.Error(), "unknown ceremony-type") {
		t.Errorf("expected unknown ceremony-type error, got: %v", err)
	}
}
// TestValidateMerkleProofRequiresRoot checks that a merkle-proof without a
// merkle-root fails validation.
func TestValidateMerkleProofRequiresRoot(t *testing.T) {
	ext := minimalExtensions()
	ext.MerkleProof = []byte{0x01, 0x02} // MerkleRoot intentionally left empty
	err := Validate(ext)
	if rejected := err != nil && strings.Contains(err.Error(), "merkle-proof requires merkle-root"); !rejected {
		t.Errorf("expected co-occurrence error, got: %v", err)
	}
}
// TestValidateMerkleRootFormat checks that merkle-root must be 64 lowercase
// hex characters: both a short value and an uppercase value are rejected.
func TestValidateMerkleRootFormat(t *testing.T) {
	ext := minimalExtensions()
	cases := []struct {
		root     string // value assigned to MerkleRoot
		fragment string // substring the validation error must contain
		label    string // short name used in the failure message
	}{
		{"abcdef", "64 hex characters", "64-char"},
		{"FEDCBA9876543210FEDCBA9876543210FEDCBA9876543210FEDCBA9876543210", "lowercase", "lowercase"},
	}
	for _, tc := range cases {
		ext.MerkleRoot = tc.root
		if err := Validate(ext); err == nil || !strings.Contains(err.Error(), tc.fragment) {
			t.Errorf("expected %s error, got: %v", tc.label, err)
		}
	}
}
func TestValidateEmptyRoleName(t *testing.T) {
ext := &ShellstreamExtensions{
TenantID: "a1b2c3d4-e5f6-7890-abcd-ef1234567890",
Roles: []string{"analyst", ""},
}
err := Validate(ext)
if err == nil || !strings.Contains(err.Error(), "empty role name") {
t.Errorf("expected empty role error, got: %v", err)
}
}
// TestValidateRoleWithComma checks that role names containing a comma or a
// space are rejected. Commas would be ambiguous with the comma-separated roles
// encoding that Decode parses. The space case is covered because the asserted
// error text names both characters.
func TestValidateRoleWithComma(t *testing.T) {
	for _, role := range []string{"analyst,viewer", "analyst viewer"} {
		ext := &ShellstreamExtensions{
			TenantID: "a1b2c3d4-e5f6-7890-abcd-ef1234567890",
			Roles:    []string{role},
		}
		err := Validate(ext)
		if err == nil || !strings.Contains(err.Error(), "commas or spaces") {
			t.Errorf("role %q: expected comma/space error, got: %v", role, err)
		}
	}
}
// TestDecodeSatScopeJSON checks that a JSON-encoded sat-scope value decodes
// into a SatScope with the registry type, the full ordered verb list and the
// resource pattern intact.
func TestDecodeSatScopeJSON(t *testing.T) {
	m := map[string]string{
		ExtTenantID: "a1b2c3d4-e5f6-7890-abcd-ef1234567890",
		ExtRoles:    "analyst",
		ExtSatScope: `{"registry_type":"helm","verbs":["install","upgrade"],"resource_pattern":"ns/*"}`,
		ExtSatHash:  "abcdef0123456789abcdef0123456789abcdef0123456789abcdef0123456789",
	}
	decoded, err := Decode(m)
	if err != nil {
		t.Fatalf("Decode: %v", err)
	}
	if decoded.SatScope == nil {
		t.Fatal("SatScope is nil")
	}
	if decoded.SatScope.RegistryType != "helm" {
		t.Errorf("RegistryType: got %q, want %q", decoded.SatScope.RegistryType, "helm")
	}
	// Check every verb, in order, not just the first one.
	wantVerbs := []string{"install", "upgrade"}
	if got := decoded.SatScope.Verbs; len(got) != len(wantVerbs) {
		t.Errorf("Verbs: got %v, want %v", got, wantVerbs)
	} else {
		for i, want := range wantVerbs {
			if got[i] != want {
				t.Errorf("Verbs[%d]: got %q, want %q", i, got[i], want)
			}
		}
	}
	if decoded.SatScope.ResourcePattern != "ns/*" {
		t.Errorf("ResourcePattern: got %q", decoded.SatScope.ResourcePattern)
	}
}
func TestDecodeInvalidSatScopeJSON(t *testing.T) {
m := map[string]string{
ExtTenantID: "a1b2c3d4-e5f6-7890-abcd-ef1234567890",
ExtRoles: "analyst",
ExtSatScope: "not-valid-json",
}
_, err := Decode(m)
if err == nil {
t.Fatal("expected error for invalid JSON")
}
if !strings.Contains(err.Error(), "unmarshal sat-scope") {
t.Errorf("unexpected error: %v", err)
}
}
func TestDecodeRolesCommaParsing(t *testing.T) {
m := map[string]string{
ExtTenantID: "a1b2c3d4-e5f6-7890-abcd-ef1234567890",
ExtRoles: "administrator,engineer,analyst",
}
decoded, err := Decode(m)
if err != nil {
t.Fatalf("Decode: %v", err)
}
if len(decoded.Roles) != 3 {
t.Fatalf("Roles length: got %d, want 3", len(decoded.Roles))
}
expected := []string{"administrator", "engineer", "analyst"}
for i, r := range decoded.Roles {
if r != expected[i] {
t.Errorf("Roles[%d]: got %q, want %q", i, r, expected[i])
}
}
}
// TestDecodeMerkleProofBase64 checks that Decode turns a standard-base64
// merkle-proof value back into the original raw bytes.
func TestDecodeMerkleProofBase64(t *testing.T) {
	raw := []byte{0xde, 0xad, 0xbe, 0xef, 0x01, 0x02, 0x03, 0x04}
	input := map[string]string{
		ExtTenantID:    "a1b2c3d4-e5f6-7890-abcd-ef1234567890",
		ExtRoles:       "analyst",
		ExtMerkleRoot:  "fedcba9876543210fedcba9876543210fedcba9876543210fedcba9876543210",
		ExtMerkleProof: base64.StdEncoding.EncodeToString(raw),
	}
	decoded, err := Decode(input)
	if err != nil {
		t.Fatalf("Decode: %v", err)
	}
	got := decoded.MerkleProof
	if len(got) != len(raw) {
		t.Fatalf("MerkleProof length: got %d, want %d", len(got), len(raw))
	}
	for i := range got {
		if got[i] != raw[i] {
			t.Errorf("MerkleProof[%d]: got %x, want %x", i, got[i], raw[i])
		}
	}
}
func TestDecodeInvalidMerkleProofBase64(t *testing.T) {
m := map[string]string{
ExtTenantID: "a1b2c3d4-e5f6-7890-abcd-ef1234567890",
ExtRoles: "analyst",
ExtMerkleProof: "not-valid-base64!!!",
}
_, err := Decode(m)
if err == nil {
t.Fatal("expected error for invalid base64")
}
if !strings.Contains(err.Error(), "decode merkle-proof") {
t.Errorf("unexpected error: %v", err)
}
}
func TestDecodeGovernanceEpoch(t *testing.T) {
m := map[string]string{
ExtTenantID: "a1b2c3d4-e5f6-7890-abcd-ef1234567890",
ExtRoles: "analyst",
ExtGovernanceEpoch: "12345",
}
decoded, err := Decode(m)
if err != nil {
t.Fatalf("Decode: %v", err)
}
if decoded.GovernanceEpoch != 12345 {
t.Errorf("GovernanceEpoch: got %d, want 12345", decoded.GovernanceEpoch)
}
if !decoded.HasGovernanceEpoch() {
t.Error("HasGovernanceEpoch: got false, want true")
}
}
func TestDecodeGovernanceEpochZero(t *testing.T) {
m := map[string]string{
ExtTenantID: "a1b2c3d4-e5f6-7890-abcd-ef1234567890",
ExtRoles: "analyst",
ExtGovernanceEpoch: "0",
}
decoded, err := Decode(m)
if err != nil {
t.Fatalf("Decode: %v", err)
}
if decoded.GovernanceEpoch != 0 {
t.Errorf("GovernanceEpoch: got %d, want 0", decoded.GovernanceEpoch)
}
if !decoded.HasGovernanceEpoch() {
t.Error("HasGovernanceEpoch: got false, want true (epoch 0 is valid)")
}
}
func TestDecodeInvalidGovernanceEpoch(t *testing.T) {
m := map[string]string{
ExtTenantID: "a1b2c3d4-e5f6-7890-abcd-ef1234567890",
ExtRoles: "analyst",
ExtGovernanceEpoch: "not-a-number",
}
_, err := Decode(m)
if err == nil {
t.Fatal("expected error for invalid epoch")
}
if !strings.Contains(err.Error(), "parse governance-epoch") {
t.Errorf("unexpected error: %v", err)
}
}
// TestEncodeNilReturnsError checks that Encode reports an error for a nil
// extension set instead of panicking or returning an empty map.
func TestEncodeNilReturnsError(t *testing.T) {
	if _, err := Encode(nil); err == nil {
		t.Fatal("expected error for nil extensions")
	}
}
// TestValidateNilReturnsError checks that Validate rejects a nil extension set.
func TestValidateNilReturnsError(t *testing.T) {
	if Validate(nil) == nil {
		t.Fatal("expected error for nil extensions")
	}
}
func TestValidateFullExtensions(t *testing.T) {
ext := fullExtensions()
err := Validate(ext)
if err != nil {
t.Fatalf("unexpected validation error: %v", err)
}
}
func TestValidateMinimalExtensions(t *testing.T) {
ext := minimalExtensions()
err := Validate(ext)
if err != nil {
t.Fatalf("unexpected validation error: %v", err)
}
}
// TestIsValidUUID pins isValidUUID's accepted format using the cases below:
// lowercase hex in hyphenated 8-4-4-4-12 groups. Uppercase, a missing hyphen,
// and wrong lengths are all rejected.
func TestIsValidUUID(t *testing.T) {
	cases := [...]struct {
		uuid  string
		valid bool
	}{
		{"a1b2c3d4-e5f6-7890-abcd-ef1234567890", true},
		{"00000000-0000-0000-0000-000000000000", true},
		{"A1B2C3D4-E5F6-7890-ABCD-EF1234567890", false}, // uppercase
		{"a1b2c3d4e5f6-7890-abcd-ef1234567890", false},  // missing hyphen
		{"too-short", false},
		{"", false},
		{"a1b2c3d4-e5f6-7890-abcd-ef12345678901", false}, // too long
	}
	for i := range cases {
		tc := &cases[i]
		if got := isValidUUID(tc.uuid); got != tc.valid {
			t.Errorf("isValidUUID(%q) = %v, want %v", tc.uuid, got, tc.valid)
		}
	}
}
// mapKeys collects the keys of m, in Go's unspecified map-iteration order, so
// failure messages can show which extension keys were present. A nil or empty
// map yields a non-nil, zero-length slice.
func mapKeys(m map[string]string) []string {
	out := make([]string, len(m))
	i := 0
	for key := range m {
		out[i] = key
		i++
	}
	return out
}