mirror of
https://github.com/ultravioletrs/cocos.git
synced 2026-06-23 04:10:25 +00:00
c1cbcec851
CI / lint (push) Has been cancelled
CI / test (agent) (push) Has been cancelled
CI / test (cli) (push) Has been cancelled
CI / test (cmd) (push) Has been cancelled
CI / test (internal) (push) Has been cancelled
CI / test (manager, true) (push) Has been cancelled
CI / test (pkg) (push) Has been cancelled
CI / upload-coverage (push) Has been cancelled
* feat: Introduce Go-based CoRIM generation and deprecate Rust attestation policy scripts. Signed-off-by: Sammy Oina <sammyoina@gmail.com> * feat: Update dependencies and refactor attestation policy handling Signed-off-by: Sammy Oina <sammyoina@gmail.com> * refactor: Migrate attestation verification to use CoRIM and remove deprecated policy handling and EAT verification tests. Signed-off-by: Sammy Oina <sammyoina@gmail.com> * Removed the `tdx` and `sev-snp` attestation policy scripts and their build configurations, along with related build and installation steps from the main Makefile. Signed-off-by: Sammy Oina <sammyoina@gmail.com> * chore: Remove Rust CI workflow and Cargo Dependabot configuration, and enhance Go test setup for attestation policy paths. Signed-off-by: Sammy Oina <sammyoina@gmail.com> * refactor: Use WriteString instead of Write([]byte) for writing policy file content in test. Signed-off-by: Sammy Oina <sammyoina@gmail.com> * feat: Refactor `ca-bundle` command to fetch bundles by product string using a configurable HTTP getter with improved error handling, and simplify `attestation_policy` command usage. Signed-off-by: Sammy Oina <sammyoina@gmail.com> * fix: ignore return value of cmd.Help() Signed-off-by: Sammy Oina <sammyoina@gmail.com> * feat: Implement CoRIM generation for Azure and GCP attestation policies and add a CLI command to download and verify GCP OVMF files. Signed-off-by: Sammy Oina <sammyoina@gmail.com> * feat: Upgrade Python virtual environment setup to include setuptools and wheel, append computation ID to Docker container names, and improve test robustness with error assertions and conditional skips for runtime tests. Signed-off-by: Sammy Oina <sammyoina@gmail.com> * test: Enhance attestation verification tests, including CoRIM integration and specific platform types like Azure SNP, vTPM, TDX, and IGVM. 
Signed-off-by: Sammy Oina <sammyoina@gmail.com> * feat: Add comprehensive test cases for `VerifyWithCoRIM` including success and measurement mismatch, and refine reference value validation. Signed-off-by: Sammy Oina <sammyoina@gmail.com> * feat: Add Azure and TDX attestation verification tests and abstract external service dependencies for improved testability. Signed-off-by: Sammy Oina <sammyoina@gmail.com> * feat: Add new test cases for Azure measurement extraction, EAT platform types, IGVM measurement stopping, vTPM CoRIM verification, and GCP OVMF download CLI. Signed-off-by: Sammy Oina <sammyoina@gmail.com> * test: enhance CLI CoRIM generation and ATLS certificate verification tests, and refactor the Azure MAA client to use an interface. Signed-off-by: Sammy Oina <sammyoina@gmail.com> --------- Signed-off-by: Sammy Oina <sammyoina@gmail.com>
1309 lines
38 KiB
Go
1309 lines
38 KiB
Go
// Copyright (c) Ultraviolet
|
|
// SPDX-License-Identifier: Apache-2.0
|
|
package atls
|
|
|
|
import (
|
|
"context"
|
|
"crypto/ecdsa"
|
|
"crypto/elliptic"
|
|
"crypto/rand"
|
|
"crypto/rsa"
|
|
"crypto/tls"
|
|
"crypto/x509"
|
|
"crypto/x509/pkix"
|
|
"encoding/asn1"
|
|
"encoding/hex"
|
|
"encoding/pem"
|
|
"fmt"
|
|
"math/big"
|
|
"os"
|
|
"path/filepath"
|
|
"strings"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/absmach/certs"
|
|
certssdk "github.com/absmach/certs/sdk"
|
|
sdkmocks "github.com/absmach/certs/sdk/mocks"
|
|
"github.com/absmach/supermq/pkg/errors"
|
|
"github.com/google/go-sev-guest/proto/sevsnp"
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/stretchr/testify/mock"
|
|
"github.com/stretchr/testify/require"
|
|
"github.com/ultravioletrs/cocos/pkg/attestation"
|
|
"github.com/veraison/corim/corim"
|
|
"golang.org/x/crypto/sha3"
|
|
)
|
|
|
|
// var policy = attestation.Config{Config: &check.Config{Policy: &check.Policy{}, RootOfTrust: &check.RootOfTrust{}}, PcrConfig: &attestation.PcrConfig{}}
|
|
// legacy config removed
|
|
|
|
// ... (existing mocks) ...
|
|
|
|
// mockAttestationClient is a simple mock for testing.
|
|
type mockAttestationClient struct {
|
|
mock.Mock
|
|
}
|
|
|
|
func (m *mockAttestationClient) GetAttestation(ctx context.Context, reportData [64]byte, nonce [32]byte, attType attestation.PlatformType) ([]byte, error) {
|
|
args := m.Called(ctx, reportData, nonce, attType)
|
|
if args.Get(0) == nil {
|
|
return nil, args.Error(1)
|
|
}
|
|
return args.Get(0).([]byte), args.Error(1)
|
|
}
|
|
|
|
func (m *mockAttestationClient) GetRawEvidence(ctx context.Context, reportData [64]byte, nonce [32]byte, attType attestation.PlatformType) ([]byte, error) {
|
|
args := m.Called(ctx, reportData, nonce, attType)
|
|
if args.Get(0) == nil {
|
|
return nil, args.Error(1)
|
|
}
|
|
return args.Get(0).([]byte), args.Error(1)
|
|
}
|
|
|
|
func (m *mockAttestationClient) GetAzureToken(ctx context.Context, nonce [32]byte) ([]byte, error) {
|
|
args := m.Called(ctx, nonce)
|
|
if args.Get(0) == nil {
|
|
return nil, args.Error(1)
|
|
}
|
|
return args.Get(0).([]byte), args.Error(1)
|
|
}
|
|
|
|
func (m *mockAttestationClient) Close() error {
|
|
args := m.Called()
|
|
return args.Error(0)
|
|
}
|
|
|
|
func generateTestCertPEM(t *testing.T) string {
|
|
return generateTestCertPEMWithSubject(t, "test")
|
|
}
|
|
|
|
func generateTestCertPEMWithSubject(t *testing.T, commonName string) string {
|
|
privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
|
|
require.NoError(t, err)
|
|
|
|
template := x509.Certificate{
|
|
SerialNumber: big.NewInt(1),
|
|
Subject: pkix.Name{
|
|
CommonName: commonName,
|
|
},
|
|
NotBefore: time.Now(),
|
|
NotAfter: time.Now().Add(365 * 24 * time.Hour),
|
|
KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign,
|
|
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
|
|
BasicConstraintsValid: true,
|
|
IsCA: true,
|
|
}
|
|
|
|
certDER, err := x509.CreateCertificate(rand.Reader, &template, &template, &privateKey.PublicKey, privateKey)
|
|
require.NoError(t, err)
|
|
|
|
certPEM := pem.EncodeToMemory(&pem.Block{
|
|
Type: "CERTIFICATE",
|
|
Bytes: certDER,
|
|
})
|
|
|
|
return strings.ReplaceAll(string(certPEM), "\n", "\\n")
|
|
}
|
|
|
|
func generateTestCertificateWithExtensions(t *testing.T, extensions []pkix.Extension) *x509.Certificate {
|
|
privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
|
|
require.NoError(t, err)
|
|
|
|
template := x509.Certificate{
|
|
SerialNumber: big.NewInt(1),
|
|
Subject: pkix.Name{
|
|
CommonName: "test",
|
|
},
|
|
NotBefore: time.Now(),
|
|
NotAfter: time.Now().Add(365 * 24 * time.Hour),
|
|
KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature,
|
|
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
|
|
BasicConstraintsValid: true,
|
|
ExtraExtensions: extensions,
|
|
}
|
|
|
|
certDER, err := x509.CreateCertificate(rand.Reader, &template, &template, &privateKey.PublicKey, privateKey)
|
|
require.NoError(t, err)
|
|
|
|
cert, err := x509.ParseCertificate(certDER)
|
|
require.NoError(t, err)
|
|
|
|
return cert
|
|
}
|
|
|
|
// TestCertificateSubject tests the CertificateSubject functionality.
|
|
func TestDefaultCertificateSubject(t *testing.T) {
|
|
subject := DefaultCertificateSubject()
|
|
|
|
assert.Equal(t, "Ultraviolet", subject.Organization)
|
|
assert.Equal(t, "Serbia", subject.Country)
|
|
assert.Equal(t, "", subject.Province)
|
|
assert.Equal(t, "Belgrade", subject.Locality)
|
|
assert.Equal(t, "Bulevar Arsenija Carnojevica 103", subject.StreetAddress)
|
|
assert.Equal(t, "11000", subject.PostalCode)
|
|
}
|
|
|
|
// TestUnifiedCertificateGenerator tests the unified certificate generator.
|
|
func TestUnifiedCertificateGenerator(t *testing.T) {
|
|
t.Run("SelfSignedGenerator", func(t *testing.T) {
|
|
generator, err := NewProvider(nil, attestation.SNPvTPM, "", "", nil)
|
|
assert.NoError(t, err)
|
|
assert.NotNil(t, generator)
|
|
})
|
|
|
|
t.Run("CASignedGenerator", func(t *testing.T) {
|
|
mockSDK := sdkmocks.NewSDK(t)
|
|
|
|
generator, err := NewProvider(nil, attestation.SNPvTPM, "test-token", "test-cvm-id", mockSDK)
|
|
assert.NoError(t, err)
|
|
assert.NotNil(t, generator)
|
|
})
|
|
}
|
|
|
|
// TestPlatformAttestationProvider tests the platform attestation provider.
|
|
func TestPlatformAttestationProvider(t *testing.T) {
|
|
t.Run("NewAttestationProvider", func(t *testing.T) {
|
|
mockClient := new(mockAttestationClient)
|
|
cases := []struct {
|
|
name string
|
|
platformType attestation.PlatformType
|
|
expectError bool
|
|
}{
|
|
{"SNPvTPM", attestation.SNPvTPM, false},
|
|
{"Azure", attestation.Azure, false},
|
|
{"TDX", attestation.TDX, false},
|
|
{"Invalid", attestation.PlatformType(999), true},
|
|
}
|
|
|
|
for _, c := range cases {
|
|
t.Run(c.name, func(t *testing.T) {
|
|
provider, err := NewAttestationProvider(mockClient, c.platformType)
|
|
|
|
if c.expectError {
|
|
assert.Error(t, err)
|
|
assert.Nil(t, provider)
|
|
} else {
|
|
assert.NoError(t, err)
|
|
assert.NotNil(t, provider)
|
|
assert.Equal(t, c.platformType, provider.PlatformType())
|
|
}
|
|
})
|
|
}
|
|
})
|
|
|
|
t.Run("GetAttestation", func(t *testing.T) {
|
|
mockClient := new(mockAttestationClient)
|
|
expectedAttestation := []byte("test-attestation")
|
|
mockClient.On("GetAttestation", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(expectedAttestation, nil)
|
|
|
|
provider, err := NewAttestationProvider(mockClient, attestation.SNPvTPM)
|
|
require.NoError(t, err)
|
|
|
|
pubKey := []byte("test-pubkey")
|
|
nonce := []byte("test-nonce")
|
|
|
|
attestation, err := provider.Attest(pubKey, nonce)
|
|
|
|
assert.NoError(t, err)
|
|
assert.Equal(t, expectedAttestation, attestation)
|
|
mockClient.AssertExpectations(t)
|
|
})
|
|
|
|
t.Run("GetAttestationError", func(t *testing.T) {
|
|
mockClient := new(mockAttestationClient)
|
|
mockClient.On("GetAttestation", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil, errors.New("attestation failed"))
|
|
|
|
provider, err := NewAttestationProvider(mockClient, attestation.SNPvTPM)
|
|
require.NoError(t, err)
|
|
|
|
_, err = provider.Attest([]byte("pubkey"), []byte("nonce"))
|
|
assert.Error(t, err)
|
|
})
|
|
}
|
|
|
|
// TestAttestedCertificateProvider tests the attested certificate provider.
|
|
func TestAttestedCertificateProvider(t *testing.T) {
|
|
t.Run("GetCertificateSuccess", func(t *testing.T) {
|
|
mockClient := new(mockAttestationClient)
|
|
mockClient.On("GetAttestation", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return([]byte("test-attestation"), nil)
|
|
|
|
attestationProvider, err := NewAttestationProvider(mockClient, attestation.SNPvTPM)
|
|
require.NoError(t, err)
|
|
|
|
subject := DefaultCertificateSubject()
|
|
|
|
provider := NewAttestedProvider(attestationProvider, subject)
|
|
|
|
// Create valid client hello with nonce
|
|
nonce := make([]byte, 64)
|
|
_, err = rand.Read(nonce)
|
|
require.NoError(t, err)
|
|
|
|
serverName := hex.EncodeToString(nonce) + ".nonce"
|
|
clientHello := &tls.ClientHelloInfo{ServerName: serverName}
|
|
|
|
cert, err := provider.GetCertificate(clientHello)
|
|
|
|
assert.NoError(t, err)
|
|
assert.NotNil(t, cert)
|
|
assert.NotEmpty(t, cert.Certificate)
|
|
assert.NotNil(t, cert.PrivateKey)
|
|
})
|
|
|
|
t.Run("InvalidServerName", func(t *testing.T) {
|
|
mockClient := new(mockAttestationClient)
|
|
attestationProvider, err := NewAttestationProvider(mockClient, attestation.SNPvTPM)
|
|
require.NoError(t, err)
|
|
|
|
provider := NewAttestedProvider(attestationProvider, DefaultCertificateSubject())
|
|
|
|
clientHello := &tls.ClientHelloInfo{ServerName: "invalid-server-name"}
|
|
|
|
_, err = provider.GetCertificate(clientHello)
|
|
assert.Error(t, err)
|
|
assert.Contains(t, err.Error(), "failed to extract nonce")
|
|
})
|
|
|
|
t.Run("AttestationError", func(t *testing.T) {
|
|
mockClient := new(mockAttestationClient)
|
|
mockClient.On("GetAttestation", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil, errors.New("attestation failed"))
|
|
|
|
attestationProvider, err := NewAttestationProvider(mockClient, attestation.SNPvTPM)
|
|
require.NoError(t, err)
|
|
|
|
provider := NewAttestedProvider(attestationProvider, DefaultCertificateSubject())
|
|
|
|
nonce := make([]byte, 64)
|
|
_, err = rand.Read(nonce)
|
|
require.NoError(t, err)
|
|
|
|
serverName := hex.EncodeToString(nonce) + ".nonce"
|
|
clientHello := &tls.ClientHelloInfo{ServerName: serverName}
|
|
|
|
_, err = provider.GetCertificate(clientHello)
|
|
assert.Error(t, err)
|
|
assert.Contains(t, err.Error(), "failed to get attestation")
|
|
})
|
|
}
|
|
|
|
// TestNewProvider tests the factory function.
|
|
func TestNewProvider(t *testing.T) {
|
|
mockClient := new(mockAttestationClient)
|
|
|
|
t.Run("SelfSignedProvider", func(t *testing.T) {
|
|
provider, err := NewProvider(mockClient, attestation.SNPvTPM, "", "", nil)
|
|
assert.NoError(t, err)
|
|
assert.NotNil(t, provider)
|
|
})
|
|
|
|
t.Run("CASignedProviderWithSDK", func(t *testing.T) {
|
|
mockSDK := sdkmocks.NewSDK(t)
|
|
|
|
provider, err := NewProvider(mockClient, attestation.SNPvTPM, "test-token", "test-cvm-id", mockSDK)
|
|
assert.NoError(t, err)
|
|
assert.NotNil(t, provider)
|
|
})
|
|
|
|
t.Run("SelfSignedProviderNilSDK", func(t *testing.T) {
|
|
provider, err := NewProvider(mockClient, attestation.SNPvTPM, "test-token", "test-cvm-id", nil)
|
|
assert.NoError(t, err)
|
|
assert.NotNil(t, provider)
|
|
})
|
|
|
|
t.Run("InvalidPlatformType", func(t *testing.T) {
|
|
_, err := NewProvider(mockClient, attestation.PlatformType(999), "", "", nil)
|
|
assert.Error(t, err)
|
|
})
|
|
}
|
|
|
|
// TestCertificateVerifier tests certificate verification.
|
|
func TestCertificateVerifier(t *testing.T) {
|
|
// Setup test policy
|
|
tempDir, err := os.MkdirTemp("", "policy")
|
|
require.NoError(t, err)
|
|
defer os.RemoveAll(tempDir)
|
|
|
|
attestationPB := prepVerifyAttReport(t)
|
|
err = setAttestationPolicy(attestationPB, tempDir)
|
|
require.NoError(t, err)
|
|
|
|
t.Run("NewCertificateVerifier", func(t *testing.T) {
|
|
rootCAs := x509.NewCertPool()
|
|
verifier := certificateVerifier{rootCAs: rootCAs}
|
|
|
|
assert.Equal(t, rootCAs, verifier.rootCAs)
|
|
})
|
|
|
|
t.Run("VerifyPeerCertificateNoCertificates", func(t *testing.T) {
|
|
verifier := NewCertificateVerifier(nil)
|
|
err := verifier.VerifyPeerCertificate([][]byte{}, nil, []byte("nonce"))
|
|
|
|
assert.Error(t, err)
|
|
assert.Contains(t, err.Error(), "no certificates provided")
|
|
})
|
|
|
|
t.Run("VerifyPeerCertificateInvalidCert", func(t *testing.T) {
|
|
verifier := NewCertificateVerifier(nil)
|
|
err := verifier.VerifyPeerCertificate([][]byte{[]byte("invalid")}, nil, []byte("nonce"))
|
|
|
|
assert.Error(t, err)
|
|
assert.Contains(t, err.Error(), "failed to parse x509 certificate")
|
|
})
|
|
|
|
t.Run("VerifyPeerCertificateNoAttestationExtension", func(t *testing.T) {
|
|
cert := createSelfSignedCert(t)
|
|
verifier := NewCertificateVerifier(nil)
|
|
|
|
err := verifier.VerifyPeerCertificate([][]byte{cert.Raw}, nil, []byte("nonce"))
|
|
|
|
assert.Error(t, err)
|
|
assert.Contains(t, err.Error(), "attestation extension not found")
|
|
})
|
|
}
|
|
|
|
// TestExtractNonceFromSNI tests nonce extraction from SNI.
|
|
func TestExtractNonceFromSNI(t *testing.T) {
|
|
t.Run("ValidNonce", func(t *testing.T) {
|
|
nonce := make([]byte, 64)
|
|
_, err := rand.Read(nonce)
|
|
require.NoError(t, err)
|
|
|
|
serverName := hex.EncodeToString(nonce) + ".nonce"
|
|
|
|
extractedNonce, err := extractNonceFromSNI(serverName)
|
|
|
|
assert.NoError(t, err)
|
|
assert.Equal(t, nonce, extractedNonce)
|
|
})
|
|
|
|
t.Run("InvalidServerName", func(t *testing.T) {
|
|
_, err := extractNonceFromSNI("invalid-server-name")
|
|
assert.Error(t, err)
|
|
})
|
|
|
|
t.Run("InvalidNonceLength", func(t *testing.T) {
|
|
shortNonce := make([]byte, 32) // Too short
|
|
serverName := hex.EncodeToString(shortNonce) + ".nonce"
|
|
|
|
_, err := extractNonceFromSNI(serverName)
|
|
assert.Error(t, err)
|
|
assert.Contains(t, err.Error(), "invalid nonce length")
|
|
})
|
|
|
|
t.Run("InvalidHexEncoding", func(t *testing.T) {
|
|
serverName := "invalid-hex-encoding.nonce"
|
|
|
|
_, err := extractNonceFromSNI(serverName)
|
|
assert.Error(t, err)
|
|
assert.Contains(t, err.Error(), "failed to decode nonce")
|
|
})
|
|
|
|
t.Run("MissingNonceSuffix", func(t *testing.T) {
|
|
nonce := make([]byte, 64)
|
|
_, err := rand.Read(nonce)
|
|
require.NoError(t, err)
|
|
|
|
serverName := hex.EncodeToString(nonce) + ".invalid"
|
|
|
|
_, err = extractNonceFromSNI(serverName)
|
|
assert.Error(t, err)
|
|
})
|
|
}
|
|
|
|
// TestHasNonceSuffix tests the nonce suffix checking.
|
|
func TestHasNonceSuffix(t *testing.T) {
|
|
t.Run("ValidSuffix", func(t *testing.T) {
|
|
assert.True(t, hasNonceSuffix("test.nonce"))
|
|
})
|
|
|
|
t.Run("InvalidSuffix", func(t *testing.T) {
|
|
assert.False(t, hasNonceSuffix("test.invalid"))
|
|
})
|
|
|
|
t.Run("TooShort", func(t *testing.T) {
|
|
assert.False(t, hasNonceSuffix(".non"))
|
|
})
|
|
|
|
t.Run("EmptyString", func(t *testing.T) {
|
|
assert.False(t, hasNonceSuffix(""))
|
|
})
|
|
}
|
|
|
|
// TestOIDFunctions tests OID-related functions.
|
|
func TestPlatformVerifier(t *testing.T) {
|
|
tempDir, err := os.MkdirTemp("", "policy")
|
|
require.NoError(t, err)
|
|
defer os.RemoveAll(tempDir)
|
|
|
|
attestationPB := prepVerifyAttReport(t)
|
|
err = setAttestationPolicy(attestationPB, tempDir)
|
|
require.NoError(t, err)
|
|
|
|
oldPath := attestation.AttestationPolicyPath
|
|
t.Cleanup(func() {
|
|
attestation.AttestationPolicyPath = oldPath
|
|
})
|
|
|
|
cases := []struct {
|
|
name string
|
|
platformType attestation.PlatformType
|
|
expectedError bool
|
|
}{
|
|
{"SNPvTPM", attestation.SNPvTPM, false},
|
|
{"Azure", attestation.Azure, false},
|
|
{"TDX", attestation.TDX, false}, // Expected success with new verifier logic
|
|
{"Invalid", attestation.PlatformType(999), true},
|
|
}
|
|
|
|
for _, c := range cases {
|
|
t.Run(c.name, func(t *testing.T) {
|
|
verifier, err := platformVerifier(c.platformType)
|
|
|
|
if c.expectedError {
|
|
assert.Error(t, err)
|
|
assert.Nil(t, verifier)
|
|
} else {
|
|
assert.NoError(t, err)
|
|
assert.NotNil(t, verifier)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestGetOID(t *testing.T) {
|
|
cases := []struct {
|
|
name string
|
|
platformType attestation.PlatformType
|
|
expectedOID asn1.ObjectIdentifier
|
|
expectedError bool
|
|
}{
|
|
{"SNPvTPM", attestation.SNPvTPM, SNPvTPMOID, false},
|
|
{"Azure", attestation.Azure, AzureOID, false},
|
|
{"TDX", attestation.TDX, TDXOID, false},
|
|
{"Invalid", attestation.PlatformType(999), nil, true},
|
|
}
|
|
|
|
for _, c := range cases {
|
|
t.Run(c.name, func(t *testing.T) {
|
|
oid, err := OID(c.platformType)
|
|
|
|
if c.expectedError {
|
|
assert.Error(t, err)
|
|
assert.Nil(t, oid)
|
|
} else {
|
|
assert.NoError(t, err)
|
|
assert.Equal(t, c.expectedOID, oid)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestPlatformTypeFromOID(t *testing.T) {
|
|
cases := []struct {
|
|
name string
|
|
oid asn1.ObjectIdentifier
|
|
expectedType attestation.PlatformType
|
|
expectedError bool
|
|
}{
|
|
{"SNPvTPM", SNPvTPMOID, attestation.SNPvTPM, false},
|
|
{"Azure", AzureOID, attestation.Azure, false},
|
|
{"TDX", TDXOID, attestation.TDX, false},
|
|
{"Invalid", asn1.ObjectIdentifier{1, 2, 3}, attestation.PlatformType(0), true},
|
|
}
|
|
|
|
for _, c := range cases {
|
|
t.Run(c.name, func(t *testing.T) {
|
|
pType, err := platformTypeFromOID(c.oid)
|
|
|
|
if c.expectedError {
|
|
assert.Error(t, err)
|
|
assert.Equal(t, attestation.PlatformType(0), pType)
|
|
} else {
|
|
assert.NoError(t, err)
|
|
assert.Equal(t, c.expectedType, pType)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
// TestVerifyCertificateExtension tests certificate extension verification.
|
|
func TestVerifyCertificateExtension(t *testing.T) {
|
|
tempDir, err := os.MkdirTemp("", "policy")
|
|
require.NoError(t, err)
|
|
defer os.RemoveAll(tempDir)
|
|
|
|
attestationPB := prepVerifyAttReport(t)
|
|
err = setAttestationPolicy(attestationPB, tempDir)
|
|
require.NoError(t, err)
|
|
|
|
oldPath := attestation.AttestationPolicyPath
|
|
t.Cleanup(func() {
|
|
attestation.AttestationPolicyPath = oldPath
|
|
})
|
|
|
|
privateKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
|
require.NoError(t, err)
|
|
|
|
pubKeyDER, err := x509.MarshalPKIXPublicKey(&privateKey.PublicKey)
|
|
require.NoError(t, err)
|
|
|
|
nonce := make([]byte, 64)
|
|
_, err = rand.Read(nonce)
|
|
require.NoError(t, err)
|
|
|
|
teeNonce := append(pubKeyDER, nonce...)
|
|
hashNonce := sha3.Sum512(teeNonce)
|
|
|
|
cases := []struct {
|
|
name string
|
|
extension []byte
|
|
pubKey []byte
|
|
nonce []byte
|
|
platformType attestation.PlatformType
|
|
expectError bool
|
|
}{
|
|
{
|
|
name: "ValidExtensionSNPvTPM",
|
|
extension: hashNonce[:],
|
|
pubKey: pubKeyDER,
|
|
nonce: nonce,
|
|
platformType: attestation.SNPvTPM,
|
|
expectError: true, // Expected due to invalid attestation data
|
|
},
|
|
{
|
|
name: "InvalidPlatformType",
|
|
extension: hashNonce[:],
|
|
pubKey: pubKeyDER,
|
|
nonce: nonce,
|
|
platformType: attestation.PlatformType(999),
|
|
expectError: true,
|
|
},
|
|
{
|
|
name: "EmptyExtension",
|
|
extension: []byte{},
|
|
pubKey: pubKeyDER,
|
|
nonce: nonce,
|
|
platformType: attestation.SNPvTPM,
|
|
expectError: true,
|
|
},
|
|
{
|
|
name: "EmptyPublicKey",
|
|
extension: hashNonce[:],
|
|
pubKey: []byte{},
|
|
nonce: nonce,
|
|
platformType: attestation.SNPvTPM,
|
|
expectError: true,
|
|
},
|
|
{
|
|
name: "EmptyNonce",
|
|
extension: hashNonce[:],
|
|
pubKey: pubKeyDER,
|
|
nonce: []byte{},
|
|
platformType: attestation.SNPvTPM,
|
|
expectError: true,
|
|
},
|
|
}
|
|
|
|
for _, c := range cases {
|
|
t.Run(c.name, func(t *testing.T) {
|
|
v := certificateVerifier{}
|
|
err := v.verifyCertificateExtension(c.extension, c.pubKey, c.nonce, c.platformType)
|
|
if c.expectError {
|
|
assert.Error(t, err)
|
|
} else {
|
|
assert.NoError(t, err)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
// Helper functions
|
|
|
|
func prepVerifyAttReport(t *testing.T) *sevsnp.Attestation {
|
|
// Return a dummy attestation report to avoid parsing issues with stale binary
|
|
return &sevsnp.Attestation{
|
|
Report: &sevsnp.Report{
|
|
FamilyId: make([]byte, 16),
|
|
ImageId: make([]byte, 16),
|
|
Measurement: make([]byte, 48),
|
|
HostData: make([]byte, 32),
|
|
ReportIdMa: make([]byte, 32),
|
|
Policy: 0, // Valid policy? Or ignore
|
|
},
|
|
}
|
|
}
|
|
|
|
func setAttestationPolicy(rr *sevsnp.Attestation, policyDirectory string) error {
|
|
// Create a dummy CoRIM
|
|
c := corim.NewUnsignedCorim()
|
|
c.SetID("cocos-test-id")
|
|
|
|
corimBytes, err := c.ToCBOR()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
policyPath := filepath.Join(policyDirectory, "attestation_policy.json")
|
|
|
|
err = os.WriteFile(policyPath, corimBytes, 0o644)
|
|
if err != nil {
|
|
return nil
|
|
}
|
|
|
|
attestation.AttestationPolicyPath = policyPath
|
|
|
|
return nil
|
|
}
|
|
|
|
// TestCertificateVerification unified test suite for certificate verification.
|
|
func TestCertificateVerification(t *testing.T) {
|
|
// Setup common test data
|
|
selfSignedCert := createSelfSignedCert(t)
|
|
leafCert, rootCert := generateCertificateChain(t)
|
|
rootCAs := createCertPool(rootCert)
|
|
emptyPool := x509.NewCertPool()
|
|
|
|
t.Run("SelfSignedCertificates", func(t *testing.T) {
|
|
testCases := []testCase{
|
|
{
|
|
name: "ValidSelfSignedCertificate",
|
|
cert: selfSignedCert,
|
|
rootCAs: nil,
|
|
expectError: false,
|
|
},
|
|
{
|
|
name: "EmptyCertificate",
|
|
cert: &x509.Certificate{},
|
|
rootCAs: nil,
|
|
expectError: true,
|
|
errorMsg: "x509: missing ASN.1 contents; use ParseCertificate",
|
|
},
|
|
}
|
|
|
|
runCertificateVerificationTests(t, testCases)
|
|
})
|
|
|
|
t.Run("CertificateChainVerification", func(t *testing.T) {
|
|
testCases := []testCase{
|
|
{
|
|
name: "ValidCertificateWithRootCA",
|
|
cert: leafCert,
|
|
rootCAs: rootCAs,
|
|
expectError: false,
|
|
},
|
|
{
|
|
name: "SelfSignedCertificate",
|
|
cert: rootCert,
|
|
rootCAs: nil, // Self-signed verification
|
|
expectError: false,
|
|
},
|
|
{
|
|
name: "InvalidCertificateWithEmptyPool",
|
|
cert: rootCert,
|
|
rootCAs: emptyPool,
|
|
expectError: true,
|
|
},
|
|
}
|
|
|
|
runCertificateVerificationTests(t, testCases)
|
|
})
|
|
|
|
t.Run("ATLSPeerCertificateVerification", func(t *testing.T) {
|
|
nonce := generateNonce(t)
|
|
|
|
testCases := []atlsTestCase{
|
|
{
|
|
name: "InvalidCertificateData",
|
|
rawCerts: [][]byte{[]byte("invalid cert data")},
|
|
nonce: nonce,
|
|
rootCAs: rootCAs,
|
|
expectError: true,
|
|
errorMsg: "failed to parse x509 certificate",
|
|
},
|
|
{
|
|
name: "ValidCertificateNoAttestationExtension",
|
|
rawCerts: [][]byte{leafCert.Raw},
|
|
nonce: nonce,
|
|
rootCAs: rootCAs,
|
|
expectError: true,
|
|
errorMsg: "attestation extension not found in certificate",
|
|
},
|
|
}
|
|
|
|
runATLSVerificationTests(t, testCases)
|
|
})
|
|
}
|
|
|
|
// TestAttestedCAProvider tests the CA-signed certificate provider.
|
|
func TestAttestedCAProvider(t *testing.T) {
|
|
mockClient := new(mockAttestationClient)
|
|
attestationProvider, err := NewAttestationProvider(mockClient, attestation.SNPvTPM)
|
|
require.NoError(t, err)
|
|
|
|
subject := DefaultCertificateSubject()
|
|
cvmID := "test-cvm-id"
|
|
agentToken := "test-token"
|
|
|
|
t.Run("NewAttestedCAProvider", func(t *testing.T) {
|
|
provider := NewAttestedCAProvider(attestationProvider, subject, nil, cvmID, agentToken)
|
|
assert.NotNil(t, provider)
|
|
})
|
|
|
|
t.Run("SetTTL", func(t *testing.T) {
|
|
provider := NewAttestedCAProvider(attestationProvider, subject, nil, cvmID, agentToken)
|
|
|
|
newTTL := time.Hour * 48
|
|
provider.(*attestedCertificateProvider).SetTTL(newTTL)
|
|
|
|
attestedProvider := provider.(*attestedCertificateProvider)
|
|
assert.Equal(t, newTTL, attestedProvider.ttl)
|
|
})
|
|
}
|
|
|
|
// TestCASignedCertificateErrors tests error cases in CA-signed certificate generation.
|
|
func TestCASignedCertificateErrors(t *testing.T) {
|
|
mockClient := new(mockAttestationClient)
|
|
attestationProvider, err := NewAttestationProvider(mockClient, attestation.SNPvTPM)
|
|
require.NoError(t, err)
|
|
|
|
subject := DefaultCertificateSubject()
|
|
cvmID := "test-cvm-id"
|
|
agentToken := "test-token"
|
|
|
|
cases := []struct {
|
|
name string
|
|
certificate string
|
|
sdkError error
|
|
expectedError string
|
|
}{
|
|
{"SDKIssueError", "", errors.NewSDKError(errors.New("SDK error")), "SDK error"},
|
|
{"InvalidPEMWithRemainingData", "-----BEGIN CERTIFICATE-----\\nVGVzdA==\\n-----END CERTIFICATE-----\\nExtra data here", nil, "unexpected remaining data"},
|
|
{"NoPEMBlockFound", "", nil, "no PEM block found"},
|
|
}
|
|
|
|
for _, c := range cases {
|
|
t.Run(c.name, func(t *testing.T) {
|
|
mockSDK := sdkmocks.NewSDK(t)
|
|
expectedCSR := certs.CSR{CSR: []byte("test-csr")}
|
|
mockSDK.On("CreateCSR", mock.Anything, mock.Anything, mock.Anything).Return(expectedCSR, errors.SDKError(nil))
|
|
mockSDK.On("IssueFromCSRInternal", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(certssdk.Certificate{Certificate: c.certificate}, c.sdkError)
|
|
|
|
provider := NewAttestedCAProvider(attestationProvider, subject, mockSDK, cvmID, agentToken)
|
|
attestedProvider := provider.(*attestedCertificateProvider)
|
|
|
|
privateKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
|
require.NoError(t, err)
|
|
|
|
extension := pkix.Extension{
|
|
Id: SNPvTPMOID,
|
|
Value: []byte("test-data"),
|
|
}
|
|
|
|
_, err = attestedProvider.generateCASignedCertificate(t.Context(), privateKey, extension)
|
|
assert.Error(t, err)
|
|
assert.Contains(t, err.Error(), c.expectedError)
|
|
})
|
|
}
|
|
}
|
|
|
|
// TestGetCertificateErrors tests error paths in certificate generation.
|
|
func TestGetCertificateErrors(t *testing.T) {
|
|
t.Run("InvalidServerNameFormat", func(t *testing.T) {
|
|
mockClient := new(mockAttestationClient)
|
|
attestationProvider, err := NewAttestationProvider(mockClient, attestation.SNPvTPM)
|
|
require.NoError(t, err)
|
|
|
|
provider := NewAttestedProvider(attestationProvider, DefaultCertificateSubject())
|
|
|
|
clientHello := &tls.ClientHelloInfo{
|
|
ServerName: "invalid-format",
|
|
}
|
|
|
|
_, err = provider.GetCertificate(clientHello)
|
|
assert.Error(t, err)
|
|
assert.Contains(t, err.Error(), "failed to extract nonce")
|
|
})
|
|
|
|
t.Run("AttestationProviderError", func(t *testing.T) {
|
|
mockClient := new(mockAttestationClient)
|
|
mockClient.On("GetAttestation", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil, errors.New("attestation failed"))
|
|
|
|
attestationProvider, err := NewAttestationProvider(mockClient, attestation.SNPvTPM)
|
|
require.NoError(t, err)
|
|
|
|
provider := NewAttestedProvider(attestationProvider, DefaultCertificateSubject())
|
|
|
|
nonce := make([]byte, 64)
|
|
_, err = rand.Read(nonce)
|
|
require.NoError(t, err)
|
|
|
|
serverName := hex.EncodeToString(nonce) + ".nonce"
|
|
clientHello := &tls.ClientHelloInfo{ServerName: serverName}
|
|
|
|
_, err = provider.GetCertificate(clientHello)
|
|
assert.Error(t, err)
|
|
assert.Contains(t, err.Error(), "failed to get attestation")
|
|
})
|
|
|
|
t.Run("CASignedCertificateError", func(t *testing.T) {
|
|
mockClient := new(mockAttestationClient)
|
|
mockClient.On("GetAttestation", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return([]byte("test-attestation"), nil)
|
|
|
|
attestationProvider, err := NewAttestationProvider(mockClient, attestation.SNPvTPM)
|
|
require.NoError(t, err)
|
|
|
|
mockSDK := sdkmocks.NewSDK(t)
|
|
expectedCSR := certs.CSR{CSR: []byte("test-csr")}
|
|
sdkErr := errors.NewSDKError(errors.New("CA error"))
|
|
mockSDK.On("CreateCSR", mock.Anything, mock.Anything, mock.Anything).Return(expectedCSR, errors.SDKError(nil))
|
|
mockSDK.On("IssueFromCSRInternal", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(certssdk.Certificate{}, sdkErr)
|
|
|
|
provider := NewAttestedCAProvider(attestationProvider, DefaultCertificateSubject(), mockSDK, "test-cvm", "test-token")
|
|
|
|
nonce := make([]byte, 64)
|
|
_, err = rand.Read(nonce)
|
|
require.NoError(t, err)
|
|
|
|
serverName := hex.EncodeToString(nonce) + ".nonce"
|
|
clientHello := &tls.ClientHelloInfo{ServerName: serverName}
|
|
|
|
_, err = provider.GetCertificate(clientHello)
|
|
assert.Error(t, err)
|
|
assert.Contains(t, err.Error(), "failed to generate certificate")
|
|
})
|
|
}
|
|
|
|
// TestCertificateVerificationEdgeCases tests edge cases in certificate verification.
|
|
func TestCertificateVerificationEdgeCases(t *testing.T) {
|
|
tempDir, err := os.MkdirTemp("", "policy")
|
|
require.NoError(t, err)
|
|
defer os.RemoveAll(tempDir)
|
|
|
|
attestationPB := prepVerifyAttReport(t)
|
|
err = setAttestationPolicy(attestationPB, tempDir)
|
|
require.NoError(t, err)
|
|
|
|
t.Run("VerifyPeerCertificateWithMultipleCerts", func(t *testing.T) {
|
|
verifier := NewCertificateVerifier(nil)
|
|
cert1 := createSelfSignedCert(t)
|
|
cert2 := createSelfSignedCert(t)
|
|
nonce := generateNonce(t)
|
|
|
|
err := verifier.VerifyPeerCertificate([][]byte{cert1.Raw, cert2.Raw}, nil, nonce)
|
|
assert.Error(t, err)
|
|
assert.Contains(t, err.Error(), "attestation extension not found")
|
|
})
|
|
|
|
t.Run("VerifyAttestationExtensionWithNoExtensions", func(t *testing.T) {
|
|
cert := createSelfSignedCert(t)
|
|
verifier := certificateVerifier{}
|
|
nonce := generateNonce(t)
|
|
|
|
err := verifier.verifyAttestationExtension(cert, nonce)
|
|
assert.Error(t, err)
|
|
assert.Contains(t, err.Error(), "attestation extension not found")
|
|
})
|
|
|
|
t.Run("VerifyAttestationExtensionWithWrongOID", func(t *testing.T) {
|
|
wrongOID := asn1.ObjectIdentifier{1, 2, 3, 4, 5}
|
|
extension := pkix.Extension{
|
|
Id: wrongOID,
|
|
Value: []byte("test-data"),
|
|
}
|
|
|
|
cert := generateTestCertificateWithExtensions(t, []pkix.Extension{extension})
|
|
verifier := certificateVerifier{}
|
|
nonce := generateNonce(t)
|
|
|
|
err := verifier.verifyAttestationExtension(cert, nonce)
|
|
assert.Error(t, err)
|
|
assert.Contains(t, err.Error(), "attestation extension not found")
|
|
})
|
|
|
|
t.Run("VerifyCertificateExtensionPlatformVerifierError", func(t *testing.T) {
|
|
verifier := certificateVerifier{}
|
|
invalidPlatformType := attestation.PlatformType(999)
|
|
|
|
err := verifier.verifyCertificateExtension([]byte("test-extension"), []byte("test-pubkey"), []byte("test-nonce"), invalidPlatformType)
|
|
assert.Error(t, err)
|
|
// The error occurs during EAT token decoding before platform type validation
|
|
assert.Contains(t, err.Error(), "failed to decode EAT token")
|
|
})
|
|
}
|
|
|
|
// TestCertificateWithAttestationExtension tests certificates with attestation extensions.
|
|
func TestCertificateWithAttestationExtension(t *testing.T) {
|
|
tempDir, err := os.MkdirTemp("", "policy")
|
|
require.NoError(t, err)
|
|
defer os.RemoveAll(tempDir)
|
|
|
|
attestationPB := prepVerifyAttReport(t)
|
|
err = setAttestationPolicy(attestationPB, tempDir)
|
|
require.NoError(t, err)
|
|
|
|
t.Run("CertificateWithValidAttestationExtension", func(t *testing.T) {
|
|
// Create certificate with attestation extension
|
|
privateKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
|
require.NoError(t, err)
|
|
|
|
_, err = x509.MarshalPKIXPublicKey(&privateKey.PublicKey)
|
|
require.NoError(t, err)
|
|
|
|
nonce := make([]byte, 64)
|
|
_, err = rand.Read(nonce)
|
|
require.NoError(t, err)
|
|
|
|
extension := pkix.Extension{
|
|
Id: SNPvTPMOID,
|
|
Value: []byte("test-attestation-data"),
|
|
}
|
|
|
|
template := &x509.Certificate{
|
|
SerialNumber: big.NewInt(1),
|
|
Subject: pkix.Name{
|
|
Organization: []string{"Test Org"},
|
|
},
|
|
NotBefore: time.Now(),
|
|
NotAfter: time.Now().Add(24 * time.Hour),
|
|
KeyUsage: x509.KeyUsageDigitalSignature,
|
|
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
|
|
ExtraExtensions: []pkix.Extension{extension},
|
|
}
|
|
|
|
certDER, err := x509.CreateCertificate(rand.Reader, template, template, &privateKey.PublicKey, privateKey)
|
|
require.NoError(t, err)
|
|
|
|
cert, err := x509.ParseCertificate(certDER)
|
|
require.NoError(t, err)
|
|
|
|
verifier := certificateVerifier{}
|
|
err = verifier.verifyAttestationExtension(cert, nonce)
|
|
|
|
// Expect error due to invalid attestation data, but extension should be found
|
|
assert.Error(t, err)
|
|
assert.NotContains(t, err.Error(), "attestation extension not found")
|
|
})
|
|
}
|
|
|
|
// TestIntegrationScenarios tests end-to-end integration scenarios.
|
|
func TestIntegrationScenarios(t *testing.T) {
|
|
tempDir, err := os.MkdirTemp("", "policy")
|
|
require.NoError(t, err)
|
|
defer os.RemoveAll(tempDir)
|
|
|
|
attestationPB := prepVerifyAttReport(t)
|
|
err = setAttestationPolicy(attestationPB, tempDir)
|
|
require.NoError(t, err)
|
|
|
|
t.Run("FullSelfSignedFlow", func(t *testing.T) {
|
|
// Setup mock client
|
|
mockClient := new(mockAttestationClient)
|
|
mockClient.On("GetAttestation", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return([]byte("mock-attestation"), nil)
|
|
|
|
// Create provider
|
|
provider, err := NewProvider(mockClient, attestation.SNPvTPM, "", "", nil)
|
|
require.NoError(t, err)
|
|
|
|
// Generate certificate
|
|
nonce := make([]byte, 64)
|
|
_, err = rand.Read(nonce)
|
|
require.NoError(t, err)
|
|
|
|
serverName := hex.EncodeToString(nonce) + ".nonce"
|
|
clientHello := &tls.ClientHelloInfo{ServerName: serverName}
|
|
|
|
cert, err := provider.GetCertificate(clientHello)
|
|
assert.NoError(t, err)
|
|
assert.NotNil(t, cert)
|
|
assert.NotEmpty(t, cert.Certificate)
|
|
assert.NotNil(t, cert.PrivateKey)
|
|
|
|
// Verify the generated certificate
|
|
parsedCert, err := x509.ParseCertificate(cert.Certificate[0])
|
|
require.NoError(t, err)
|
|
|
|
// Check for attestation extension
|
|
found := false
|
|
for _, ext := range parsedCert.Extensions {
|
|
if ext.Id.Equal(SNPvTPMOID) {
|
|
found = true
|
|
break
|
|
}
|
|
}
|
|
assert.True(t, found, "Attestation extension should be present")
|
|
})
|
|
|
|
t.Run("FullCASignedFlow", func(t *testing.T) {
|
|
mockSDK := sdkmocks.NewSDK(t)
|
|
expectedCSR := certs.CSR{CSR: []byte("test-csr")}
|
|
expectedCert := certssdk.Certificate{Certificate: generateTestCertPEM(t)}
|
|
mockSDK.On("CreateCSR", mock.Anything, mock.Anything, mock.Anything).Return(expectedCSR, errors.SDKError(nil))
|
|
mockSDK.On("IssueFromCSRInternal", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(expectedCert, errors.SDKError(nil))
|
|
|
|
mockClient := new(mockAttestationClient)
|
|
mockClient.On("GetAttestation", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return([]byte("mock-attestation"), nil)
|
|
|
|
provider, err := NewProvider(mockClient, attestation.SNPvTPM, "test-token", "test-cvm-id", mockSDK)
|
|
require.NoError(t, err)
|
|
|
|
nonce := make([]byte, 64)
|
|
_, err = rand.Read(nonce)
|
|
require.NoError(t, err)
|
|
|
|
serverName := hex.EncodeToString(nonce) + ".nonce"
|
|
clientHello := &tls.ClientHelloInfo{ServerName: serverName}
|
|
|
|
cert, err := provider.GetCertificate(clientHello)
|
|
require.NoError(t, err)
|
|
require.NotNil(t, cert)
|
|
require.NotEmpty(t, cert.Certificate)
|
|
require.NotNil(t, cert.PrivateKey)
|
|
|
|
parsedCert, err := x509.ParseCertificate(cert.Certificate[0])
|
|
require.NoError(t, err)
|
|
|
|
assert.NotNil(t, parsedCert.Subject)
|
|
|
|
mockClient.AssertExpectations(t)
|
|
mockSDK.AssertExpectations(t)
|
|
})
|
|
}
|
|
|
|
// TestConcurrentAccess tests concurrent access scenarios.
|
|
func TestConcurrentAccess(t *testing.T) {
|
|
mockClient := new(mockAttestationClient)
|
|
mockClient.On("GetAttestation", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return([]byte("mock-attestation"), nil)
|
|
|
|
provider, err := NewProvider(mockClient, attestation.SNPvTPM, "", "", nil)
|
|
require.NoError(t, err)
|
|
|
|
const numGoroutines = 10
|
|
errors := make(chan error, numGoroutines)
|
|
|
|
for i := 0; i < numGoroutines; i++ {
|
|
go func(id int) {
|
|
nonce := make([]byte, 64)
|
|
_, err := rand.Read(nonce)
|
|
if err != nil {
|
|
errors <- err
|
|
return
|
|
}
|
|
|
|
serverName := hex.EncodeToString(nonce) + ".nonce"
|
|
clientHello := &tls.ClientHelloInfo{ServerName: serverName}
|
|
|
|
cert, err := provider.GetCertificate(clientHello)
|
|
if err != nil {
|
|
errors <- err
|
|
return
|
|
}
|
|
|
|
if cert == nil {
|
|
errors <- fmt.Errorf("nil certificate returned for goroutine %d", id)
|
|
return
|
|
}
|
|
|
|
errors <- nil
|
|
}(i)
|
|
}
|
|
|
|
// Collect results
|
|
for i := 0; i < numGoroutines; i++ {
|
|
err := <-errors
|
|
assert.NoError(t, err)
|
|
}
|
|
}
|
|
|
|
// TestEdgeCasesAndBoundaries tests edge cases and boundary conditions.
|
|
func TestEdgeCasesAndBoundaries(t *testing.T) {
|
|
t.Run("LargeNonce", func(t *testing.T) {
|
|
largeNonce := make([]byte, 1024) // Much larger than expected
|
|
_, err := rand.Read(largeNonce)
|
|
require.NoError(t, err)
|
|
|
|
serverName := hex.EncodeToString(largeNonce) + ".nonce"
|
|
_, err = extractNonceFromSNI(serverName)
|
|
assert.Error(t, err) // Should fail due to invalid length
|
|
})
|
|
|
|
t.Run("MaxLengthServerName", func(t *testing.T) {
|
|
// Create very long server name
|
|
nonce := make([]byte, 64)
|
|
_, err := rand.Read(nonce)
|
|
require.NoError(t, err)
|
|
|
|
longPrefix := strings.Repeat("a", 200)
|
|
serverName := longPrefix + hex.EncodeToString(nonce) + ".nonce"
|
|
|
|
_, err = extractNonceFromSNI(serverName)
|
|
assert.Error(t, err) // Should fail due to invalid format
|
|
})
|
|
|
|
t.Run("MinimalValidNonce", func(t *testing.T) {
|
|
nonce := make([]byte, 64) // Exactly the required length
|
|
_, err := rand.Read(nonce)
|
|
require.NoError(t, err)
|
|
|
|
serverName := hex.EncodeToString(nonce) + ".nonce"
|
|
extractedNonce, err := extractNonceFromSNI(serverName)
|
|
|
|
assert.NoError(t, err)
|
|
assert.Equal(t, nonce, extractedNonce)
|
|
})
|
|
}
|
|
|
|
// Unified test case structures.
|
|
type (
	// testCase describes one certificateVerifier.verifyCertificateSignature
	// scenario; runCertificateVerificationTests executes a slice of them.
	testCase struct {
		name        string            // subtest name
		cert        *x509.Certificate // certificate whose signature is checked
		rootCAs     *x509.CertPool    // pool assigned to the verifier's rootCAs
		expectError bool              // whether verification must fail
		errorMsg    string            // expected error text; empty skips the message check
	}

	// atlsTestCase describes one certificateVerifier.VerifyPeerCertificate
	// scenario; runATLSVerificationTests executes a slice of them.
	atlsTestCase struct {
		name        string         // subtest name
		rawCerts    [][]byte       // raw (DER) peer certificates passed to VerifyPeerCertificate
		nonce       []byte         // nonce passed to VerifyPeerCertificate
		rootCAs     *x509.CertPool // pool assigned to the verifier's rootCAs
		expectError bool           // whether verification must fail
		errorMsg    string         // expected error substring; empty skips the check
	}
)
|
|
|
|
// Unified test runners.
|
|
func runCertificateVerificationTests(t *testing.T, testCases []testCase) {
|
|
for _, tc := range testCases {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
v := certificateVerifier{
|
|
rootCAs: tc.rootCAs,
|
|
}
|
|
err := v.verifyCertificateSignature(tc.cert)
|
|
|
|
if tc.expectError {
|
|
assert.Error(t, err)
|
|
if tc.errorMsg != "" {
|
|
if tc.errorMsg == "x509: missing ASN.1 contents; use ParseCertificate" {
|
|
// For specific error matching
|
|
assert.True(t, errors.Contains(err, errors.New(tc.errorMsg)),
|
|
fmt.Sprintf("expected error %q, got %v", tc.errorMsg, err))
|
|
} else {
|
|
assert.Contains(t, err.Error(), tc.errorMsg)
|
|
}
|
|
}
|
|
} else {
|
|
assert.NoError(t, err)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func runATLSVerificationTests(t *testing.T, testCases []atlsTestCase) {
|
|
for _, tc := range testCases {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
v := certificateVerifier{
|
|
rootCAs: tc.rootCAs,
|
|
}
|
|
err := v.VerifyPeerCertificate(tc.rawCerts, nil, tc.nonce)
|
|
|
|
if tc.expectError {
|
|
assert.Error(t, err)
|
|
if tc.errorMsg != "" {
|
|
assert.Contains(t, err.Error(), tc.errorMsg)
|
|
}
|
|
} else {
|
|
assert.NoError(t, err)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
// Unified certificate creation utilities.
|
|
func createSelfSignedCert(t *testing.T) *x509.Certificate {
|
|
privateKey := generateRSAKey(t)
|
|
|
|
template := x509.Certificate{
|
|
SerialNumber: big.NewInt(1),
|
|
Subject: pkix.Name{
|
|
Organization: []string{"Test Org"},
|
|
},
|
|
NotBefore: time.Now(),
|
|
NotAfter: time.Now().Add(24 * time.Hour),
|
|
KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature,
|
|
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
|
|
BasicConstraintsValid: true,
|
|
}
|
|
|
|
return createCertificateFromTemplate(t, &template, &template, &privateKey.PublicKey, privateKey)
|
|
}
|
|
|
|
func generateCertificateChain(t *testing.T) (leafCert, rootCert *x509.Certificate) {
|
|
// Generate root certificate
|
|
rootKey := generateRSAKey(t)
|
|
rootTemplate := x509.Certificate{
|
|
SerialNumber: big.NewInt(1),
|
|
Subject: pkix.Name{
|
|
Organization: []string{"Test Root CA"},
|
|
Country: []string{"US"},
|
|
},
|
|
NotBefore: time.Now(),
|
|
NotAfter: time.Now().Add(24 * time.Hour),
|
|
KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign,
|
|
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
|
|
BasicConstraintsValid: true,
|
|
IsCA: true,
|
|
}
|
|
|
|
rootCert = createCertificateFromTemplate(t, &rootTemplate, &rootTemplate, &rootKey.PublicKey, rootKey)
|
|
|
|
// Generate leaf certificate signed by root
|
|
leafKey := generateRSAKey(t)
|
|
leafTemplate := x509.Certificate{
|
|
SerialNumber: big.NewInt(2),
|
|
Subject: pkix.Name{
|
|
Organization: []string{"Test Leaf"},
|
|
Country: []string{"US"},
|
|
},
|
|
NotBefore: time.Now(),
|
|
NotAfter: time.Now().Add(24 * time.Hour),
|
|
KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature,
|
|
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
|
|
}
|
|
|
|
leafCert = createCertificateFromTemplate(t, &leafTemplate, &rootTemplate, &leafKey.PublicKey, rootKey)
|
|
|
|
return leafCert, rootCert
|
|
}
|
|
|
|
// Helper functions for consistency.
|
|
func generateRSAKey(t *testing.T) *rsa.PrivateKey {
|
|
privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
|
|
require.NoError(t, err)
|
|
return privateKey
|
|
}
|
|
|
|
func createCertificateFromTemplate(t *testing.T, template, parent *x509.Certificate, pub interface{}, priv interface{}) *x509.Certificate {
|
|
certDER, err := x509.CreateCertificate(rand.Reader, template, parent, pub, priv)
|
|
require.NoError(t, err)
|
|
|
|
cert, err := x509.ParseCertificate(certDER)
|
|
require.NoError(t, err)
|
|
|
|
return cert
|
|
}
|
|
|
|
// createCertPool builds a new x509.CertPool containing every certificate passed
// in. With no arguments it returns an empty, non-nil pool.
func createCertPool(certs ...*x509.Certificate) *x509.CertPool {
	trusted := x509.NewCertPool()
	for i := range certs {
		trusted.AddCert(certs[i])
	}
	return trusted
}
|
|
|
|
func generateNonce(t *testing.T) []byte {
|
|
nonce := make([]byte, 64)
|
|
_, err := rand.Read(nonce)
|
|
require.NoError(t, err)
|
|
return nonce
|
|
}
|