Files
cocos/pkg/atls/atls_test.go
T
Sammy Kerata Oina c1cbcec851
CI / lint (push) Has been cancelled
CI / test (agent) (push) Has been cancelled
CI / test (cli) (push) Has been cancelled
CI / test (cmd) (push) Has been cancelled
CI / test (internal) (push) Has been cancelled
CI / test (manager, true) (push) Has been cancelled
CI / test (pkg) (push) Has been cancelled
CI / upload-coverage (push) Has been cancelled
COCOS-577 - Introduce Go-based CoRIM generation and deprecate Rust attestation policy scripts. (#578)
* feat: Introduce Go-based CoRIM generation and deprecate Rust attestation policy scripts.

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* feat: Update dependencies and refactor attestation policy handling

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* refactor: Migrate attestation verification to use CoRIM and remove deprecated policy handling and EAT verification tests.

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* Removed the `tdx` and `sev-snp` attestation policy scripts and their build configurations, along with related build and installation steps from the main Makefile.

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* chore: Remove Rust CI workflow and Cargo Dependabot configuration, and enhance Go test setup for attestation policy paths.

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* refactor: Use WriteString instead of Write([]byte) for writing policy file content in test.

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* feat: Refactor `ca-bundle` command to fetch bundles by product string using a configurable HTTP getter with improved error handling, and simplify `attestation_policy` command usage.

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* fix: ignore return value of cmd.Help()

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* feat: Implement CoRIM generation for Azure and GCP attestation policies and add a CLI command to download and verify GCP OVMF files.

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* feat: Upgrade Python virtual environment setup to include setuptools and wheel, append computation ID to Docker container names, and improve test robustness with error assertions and conditional skips for runtime tests.

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* test: Enhance attestation verification tests, including CoRIM integration and specific platform types like Azure SNP, vTPM, TDX, and IGVM.

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* feat: Add comprehensive test cases for `VerifyWithCoRIM` including success and measurement mismatch, and refine reference value validation.

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* feat: Add Azure and TDX attestation verification tests and abstract external service dependencies for improved testability.

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* feat: Add new test cases for Azure measurement extraction, EAT platform types, IGVM measurement stopping, vTPM CoRIM verification, and GCP OVMF download CLI.

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* test: enhance CLI CoRIM generation and ATLS certificate verification tests, and refactor the Azure MAA client to use an interface.

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

---------

Signed-off-by: Sammy Oina <sammyoina@gmail.com>
2026-03-19 17:01:24 +01:00

1309 lines
38 KiB
Go

// Copyright (c) Ultraviolet
// SPDX-License-Identifier: Apache-2.0
package atls
import (
"context"
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"crypto/rsa"
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"encoding/asn1"
"encoding/hex"
"encoding/pem"
"fmt"
"math/big"
"os"
"path/filepath"
"strings"
"testing"
"time"
"github.com/absmach/certs"
certssdk "github.com/absmach/certs/sdk"
sdkmocks "github.com/absmach/certs/sdk/mocks"
"github.com/absmach/supermq/pkg/errors"
"github.com/google/go-sev-guest/proto/sevsnp"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
"github.com/ultravioletrs/cocos/pkg/attestation"
"github.com/veraison/corim/corim"
"golang.org/x/crypto/sha3"
)
// var policy = attestation.Config{Config: &check.Config{Policy: &check.Policy{}, RootOfTrust: &check.RootOfTrust{}}, PcrConfig: &attestation.PcrConfig{}}
// The legacy package-level policy config above was removed; tests now install
// a CoRIM-based policy file through the setAttestationPolicy helper.
// mockAttestationClient is a simple mock for testing.
// It embeds testify's mock.Mock; each method defined on it in this file
// forwards its arguments to Called and returns whatever the test configured.
type mockAttestationClient struct {
mock.Mock
}
// GetAttestation records the call on the embedded mock and returns the values
// configured for it. When the first configured return value is nil, a nil
// slice is returned together with the configured error.
func (m *mockAttestationClient) GetAttestation(ctx context.Context, reportData [64]byte, nonce [32]byte, attType attestation.PlatformType) ([]byte, error) {
	ret := m.Called(ctx, reportData, nonce, attType)
	payload := ret.Get(0)
	if payload == nil {
		return nil, ret.Error(1)
	}
	return payload.([]byte), ret.Error(1)
}
// GetRawEvidence records the call on the embedded mock and returns the values
// configured for it. When the first configured return value is nil, a nil
// slice is returned together with the configured error.
func (m *mockAttestationClient) GetRawEvidence(ctx context.Context, reportData [64]byte, nonce [32]byte, attType attestation.PlatformType) ([]byte, error) {
	ret := m.Called(ctx, reportData, nonce, attType)
	evidence := ret.Get(0)
	if evidence == nil {
		return nil, ret.Error(1)
	}
	return evidence.([]byte), ret.Error(1)
}
// GetAzureToken records the call on the embedded mock and returns the values
// configured for it. When the first configured return value is nil, a nil
// slice is returned together with the configured error.
func (m *mockAttestationClient) GetAzureToken(ctx context.Context, nonce [32]byte) ([]byte, error) {
	ret := m.Called(ctx, nonce)
	token := ret.Get(0)
	if token == nil {
		return nil, ret.Error(1)
	}
	return token.([]byte), ret.Error(1)
}
// Close records the call on the embedded mock and returns the configured error.
func (m *mockAttestationClient) Close() error {
	return m.Called().Error(0)
}
// generateTestCertPEM returns the output of generateTestCertPEMWithSubject for
// the common name "test".
func generateTestCertPEM(t *testing.T) string {
	// Mark as a helper so failures are reported at the calling test's line.
	t.Helper()
	return generateTestCertPEMWithSubject(t, "test")
}
// generateTestCertPEMWithSubject creates a throwaway self-signed CA certificate
// (RSA-2048, valid for one year) whose subject common name is commonName.
//
// It returns the PEM encoding of the certificate with every newline replaced
// by the two-character sequence `\n`. Any failure aborts the calling test.
func generateTestCertPEMWithSubject(t *testing.T, commonName string) string {
	// Mark as a helper so require failures point at the calling test.
	t.Helper()

	privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
	require.NoError(t, err)

	// Sample the clock once so NotBefore and NotAfter are exactly one year apart.
	now := time.Now()
	template := x509.Certificate{
		SerialNumber: big.NewInt(1),
		Subject: pkix.Name{
			CommonName: commonName,
		},
		NotBefore:             now,
		NotAfter:              now.Add(365 * 24 * time.Hour),
		KeyUsage:              x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign,
		ExtKeyUsage:           []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
		BasicConstraintsValid: true,
		IsCA:                  true,
	}

	// The template is both subject and issuer, i.e. the certificate is self-signed.
	certDER, err := x509.CreateCertificate(rand.Reader, &template, &template, &privateKey.PublicKey, privateKey)
	require.NoError(t, err)

	certPEM := pem.EncodeToMemory(&pem.Block{
		Type:  "CERTIFICATE",
		Bytes: certDER,
	})
	return strings.ReplaceAll(string(certPEM), "\n", "\\n")
}
// generateTestCertificateWithExtensions creates a self-signed RSA-2048
// certificate with common name "test", valid for one year, that carries the
// given extensions as ExtraExtensions. The DER output is parsed back so the
// returned certificate has its parsed fields populated. Any failure aborts the
// calling test.
func generateTestCertificateWithExtensions(t *testing.T, extensions []pkix.Extension) *x509.Certificate {
	// Mark as a helper so require failures point at the calling test.
	t.Helper()

	privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
	require.NoError(t, err)

	// Sample the clock once so NotBefore and NotAfter are exactly one year apart.
	now := time.Now()
	template := x509.Certificate{
		SerialNumber: big.NewInt(1),
		Subject: pkix.Name{
			CommonName: "test",
		},
		NotBefore:             now,
		NotAfter:              now.Add(365 * 24 * time.Hour),
		KeyUsage:              x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature,
		ExtKeyUsage:           []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
		BasicConstraintsValid: true,
		ExtraExtensions:       extensions,
	}

	certDER, err := x509.CreateCertificate(rand.Reader, &template, &template, &privateKey.PublicKey, privateKey)
	require.NoError(t, err)

	cert, err := x509.ParseCertificate(certDER)
	require.NoError(t, err)
	return cert
}
// TestCertificateSubject tests the CertificateSubject functionality.
func TestDefaultCertificateSubject(t *testing.T) {
subject := DefaultCertificateSubject()
assert.Equal(t, "Ultraviolet", subject.Organization)
assert.Equal(t, "Serbia", subject.Country)
assert.Equal(t, "", subject.Province)
assert.Equal(t, "Belgrade", subject.Locality)
assert.Equal(t, "Bulevar Arsenija Carnojevica 103", subject.StreetAddress)
assert.Equal(t, "11000", subject.PostalCode)
}
// TestUnifiedCertificateGenerator verifies that NewProvider succeeds with a nil
// attestation client, both without a certs SDK and with a mocked one.
func TestUnifiedCertificateGenerator(t *testing.T) {
	t.Run("SelfSignedGenerator", func(t *testing.T) {
		selfSigned, err := NewProvider(nil, attestation.SNPvTPM, "", "", nil)

		assert.NoError(t, err)
		assert.NotNil(t, selfSigned)
	})

	t.Run("CASignedGenerator", func(t *testing.T) {
		sdk := sdkmocks.NewSDK(t)
		caSigned, err := NewProvider(nil, attestation.SNPvTPM, "test-token", "test-cvm-id", sdk)

		assert.NoError(t, err)
		assert.NotNil(t, caSigned)
	})
}
// TestPlatformAttestationProvider tests the platform attestation provider:
// construction for each platform type, and the success and error paths of
// Attest when backed by a mocked attestation client.
func TestPlatformAttestationProvider(t *testing.T) {
	t.Run("NewAttestationProvider", func(t *testing.T) {
		mockClient := new(mockAttestationClient)
		cases := []struct {
			name         string
			platformType attestation.PlatformType
			expectError  bool
		}{
			{"SNPvTPM", attestation.SNPvTPM, false},
			{"Azure", attestation.Azure, false},
			{"TDX", attestation.TDX, false},
			{"Invalid", attestation.PlatformType(999), true},
		}
		for _, c := range cases {
			t.Run(c.name, func(t *testing.T) {
				provider, err := NewAttestationProvider(mockClient, c.platformType)
				if c.expectError {
					assert.Error(t, err)
					assert.Nil(t, provider)
					return
				}
				// Stop here on failure: calling PlatformType on a nil provider
				// would panic instead of reporting a clean test failure.
				require.NoError(t, err)
				require.NotNil(t, provider)
				assert.Equal(t, c.platformType, provider.PlatformType())
			})
		}
	})

	t.Run("GetAttestation", func(t *testing.T) {
		mockClient := new(mockAttestationClient)
		expectedAttestation := []byte("test-attestation")
		mockClient.On("GetAttestation", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(expectedAttestation, nil)

		provider, err := NewAttestationProvider(mockClient, attestation.SNPvTPM)
		require.NoError(t, err)

		pubKey := []byte("test-pubkey")
		nonce := []byte("test-nonce")
		// Named "report" rather than "attestation" so the imported attestation
		// package is not shadowed for the rest of this scope.
		report, err := provider.Attest(pubKey, nonce)
		assert.NoError(t, err)
		assert.Equal(t, expectedAttestation, report)
		mockClient.AssertExpectations(t)
	})

	t.Run("GetAttestationError", func(t *testing.T) {
		mockClient := new(mockAttestationClient)
		mockClient.On("GetAttestation", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil, errors.New("attestation failed"))

		provider, err := NewAttestationProvider(mockClient, attestation.SNPvTPM)
		require.NoError(t, err)

		_, err = provider.Attest([]byte("pubkey"), []byte("nonce"))
		assert.Error(t, err)
	})
}
// TestAttestedCertificateProvider exercises GetCertificate on the provider
// returned by NewAttestedProvider: a successful issuance, a server name that
// carries no nonce, and a failing attestation client.
func TestAttestedCertificateProvider(t *testing.T) {
	t.Run("GetCertificateSuccess", func(t *testing.T) {
		client := new(mockAttestationClient)
		client.On("GetAttestation", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return([]byte("test-attestation"), nil)

		attester, err := NewAttestationProvider(client, attestation.SNPvTPM)
		require.NoError(t, err)
		certProvider := NewAttestedProvider(attester, DefaultCertificateSubject())

		// Build a client hello whose SNI is "<hex nonce>.nonce".
		rawNonce := make([]byte, 64)
		_, err = rand.Read(rawNonce)
		require.NoError(t, err)
		hello := &tls.ClientHelloInfo{ServerName: hex.EncodeToString(rawNonce) + ".nonce"}

		tlsCert, err := certProvider.GetCertificate(hello)
		assert.NoError(t, err)
		assert.NotNil(t, tlsCert)
		assert.NotEmpty(t, tlsCert.Certificate)
		assert.NotNil(t, tlsCert.PrivateKey)
	})

	t.Run("InvalidServerName", func(t *testing.T) {
		client := new(mockAttestationClient)
		attester, err := NewAttestationProvider(client, attestation.SNPvTPM)
		require.NoError(t, err)
		certProvider := NewAttestedProvider(attester, DefaultCertificateSubject())

		_, err = certProvider.GetCertificate(&tls.ClientHelloInfo{ServerName: "invalid-server-name"})
		assert.Error(t, err)
		assert.Contains(t, err.Error(), "failed to extract nonce")
	})

	t.Run("AttestationError", func(t *testing.T) {
		client := new(mockAttestationClient)
		client.On("GetAttestation", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil, errors.New("attestation failed"))

		attester, err := NewAttestationProvider(client, attestation.SNPvTPM)
		require.NoError(t, err)
		certProvider := NewAttestedProvider(attester, DefaultCertificateSubject())

		rawNonce := make([]byte, 64)
		_, err = rand.Read(rawNonce)
		require.NoError(t, err)
		hello := &tls.ClientHelloInfo{ServerName: hex.EncodeToString(rawNonce) + ".nonce"}

		_, err = certProvider.GetCertificate(hello)
		assert.Error(t, err)
		assert.Contains(t, err.Error(), "failed to get attestation")
	})
}
// TestNewProvider covers the NewProvider factory: with and without a certs SDK,
// with credentials but a nil SDK, and with an unsupported platform type.
func TestNewProvider(t *testing.T) {
	client := new(mockAttestationClient)

	t.Run("SelfSignedProvider", func(t *testing.T) {
		got, err := NewProvider(client, attestation.SNPvTPM, "", "", nil)

		assert.NoError(t, err)
		assert.NotNil(t, got)
	})

	t.Run("CASignedProviderWithSDK", func(t *testing.T) {
		sdk := sdkmocks.NewSDK(t)
		got, err := NewProvider(client, attestation.SNPvTPM, "test-token", "test-cvm-id", sdk)

		assert.NoError(t, err)
		assert.NotNil(t, got)
	})

	t.Run("SelfSignedProviderNilSDK", func(t *testing.T) {
		got, err := NewProvider(client, attestation.SNPvTPM, "test-token", "test-cvm-id", nil)

		assert.NoError(t, err)
		assert.NotNil(t, got)
	})

	t.Run("InvalidPlatformType", func(t *testing.T) {
		_, err := NewProvider(client, attestation.PlatformType(999), "", "", nil)

		assert.Error(t, err)
	})
}
// TestCertificateVerifier tests certificate verification.
func TestCertificateVerifier(t *testing.T) {
// Setup test policy
tempDir, err := os.MkdirTemp("", "policy")
require.NoError(t, err)
defer os.RemoveAll(tempDir)
attestationPB := prepVerifyAttReport(t)
err = setAttestationPolicy(attestationPB, tempDir)
require.NoError(t, err)
t.Run("NewCertificateVerifier", func(t *testing.T) {
rootCAs := x509.NewCertPool()
verifier := certificateVerifier{rootCAs: rootCAs}
assert.Equal(t, rootCAs, verifier.rootCAs)
})
t.Run("VerifyPeerCertificateNoCertificates", func(t *testing.T) {
verifier := NewCertificateVerifier(nil)
err := verifier.VerifyPeerCertificate([][]byte{}, nil, []byte("nonce"))
assert.Error(t, err)
assert.Contains(t, err.Error(), "no certificates provided")
})
t.Run("VerifyPeerCertificateInvalidCert", func(t *testing.T) {
verifier := NewCertificateVerifier(nil)
err := verifier.VerifyPeerCertificate([][]byte{[]byte("invalid")}, nil, []byte("nonce"))
assert.Error(t, err)
assert.Contains(t, err.Error(), "failed to parse x509 certificate")
})
t.Run("VerifyPeerCertificateNoAttestationExtension", func(t *testing.T) {
cert := createSelfSignedCert(t)
verifier := NewCertificateVerifier(nil)
err := verifier.VerifyPeerCertificate([][]byte{cert.Raw}, nil, []byte("nonce"))
assert.Error(t, err)
assert.Contains(t, err.Error(), "attestation extension not found")
})
}
// TestExtractNonceFromSNI checks extractNonceFromSNI against a well-formed
// "<hex nonce>.nonce" server name and several malformed variants.
func TestExtractNonceFromSNI(t *testing.T) {
	t.Run("ValidNonce", func(t *testing.T) {
		want := make([]byte, 64)
		_, err := rand.Read(want)
		require.NoError(t, err)

		got, err := extractNonceFromSNI(hex.EncodeToString(want) + ".nonce")
		assert.NoError(t, err)
		assert.Equal(t, want, got)
	})

	t.Run("InvalidServerName", func(t *testing.T) {
		_, err := extractNonceFromSNI("invalid-server-name")
		assert.Error(t, err)
	})

	t.Run("InvalidNonceLength", func(t *testing.T) {
		// 32 bytes: the function is expected to report an invalid length.
		tooShort := make([]byte, 32)

		_, err := extractNonceFromSNI(hex.EncodeToString(tooShort) + ".nonce")
		assert.Error(t, err)
		assert.Contains(t, err.Error(), "invalid nonce length")
	})

	t.Run("InvalidHexEncoding", func(t *testing.T) {
		_, err := extractNonceFromSNI("invalid-hex-encoding.nonce")
		assert.Error(t, err)
		assert.Contains(t, err.Error(), "failed to decode nonce")
	})

	t.Run("MissingNonceSuffix", func(t *testing.T) {
		raw := make([]byte, 64)
		_, err := rand.Read(raw)
		require.NoError(t, err)

		_, err = extractNonceFromSNI(hex.EncodeToString(raw) + ".invalid")
		assert.Error(t, err)
	})
}
// TestHasNonceSuffix runs hasNonceSuffix over a table of server names and
// compares the result with the expected boolean.
func TestHasNonceSuffix(t *testing.T) {
	tests := []struct {
		name       string
		serverName string
		want       bool
	}{
		{"ValidSuffix", "test.nonce", true},
		{"InvalidSuffix", "test.invalid", false},
		{"TooShort", ".non", false},
		{"EmptyString", "", false},
	}

	for _, tc := range tests {
		t.Run(tc.name, func(t *testing.T) {
			assert.Equal(t, tc.want, hasNonceSuffix(tc.serverName))
		})
	}
}
// TestOIDFunctions tests OID-related functions.
func TestPlatformVerifier(t *testing.T) {
tempDir, err := os.MkdirTemp("", "policy")
require.NoError(t, err)
defer os.RemoveAll(tempDir)
attestationPB := prepVerifyAttReport(t)
err = setAttestationPolicy(attestationPB, tempDir)
require.NoError(t, err)
oldPath := attestation.AttestationPolicyPath
t.Cleanup(func() {
attestation.AttestationPolicyPath = oldPath
})
cases := []struct {
name string
platformType attestation.PlatformType
expectedError bool
}{
{"SNPvTPM", attestation.SNPvTPM, false},
{"Azure", attestation.Azure, false},
{"TDX", attestation.TDX, false}, // Expected success with new verifier logic
{"Invalid", attestation.PlatformType(999), true},
}
for _, c := range cases {
t.Run(c.name, func(t *testing.T) {
verifier, err := platformVerifier(c.platformType)
if c.expectedError {
assert.Error(t, err)
assert.Nil(t, verifier)
} else {
assert.NoError(t, err)
assert.NotNil(t, verifier)
}
})
}
}
// TestGetOID checks that OID maps each supported platform type to its
// extension OID and returns an error with a nil OID for an unknown type.
func TestGetOID(t *testing.T) {
	tests := []struct {
		name     string
		platform attestation.PlatformType
		wantOID  asn1.ObjectIdentifier
		wantErr  bool
	}{
		{name: "SNPvTPM", platform: attestation.SNPvTPM, wantOID: SNPvTPMOID},
		{name: "Azure", platform: attestation.Azure, wantOID: AzureOID},
		{name: "TDX", platform: attestation.TDX, wantOID: TDXOID},
		{name: "Invalid", platform: attestation.PlatformType(999), wantErr: true},
	}

	for _, tc := range tests {
		t.Run(tc.name, func(t *testing.T) {
			got, err := OID(tc.platform)
			if tc.wantErr {
				assert.Error(t, err)
				assert.Nil(t, got)
				return
			}
			assert.NoError(t, err)
			assert.Equal(t, tc.wantOID, got)
		})
	}
}
// TestPlatformTypeFromOID checks the reverse mapping from extension OID to
// platform type; an unknown OID must yield an error and PlatformType(0).
func TestPlatformTypeFromOID(t *testing.T) {
	tests := []struct {
		name     string
		oid      asn1.ObjectIdentifier
		wantType attestation.PlatformType
		wantErr  bool
	}{
		{name: "SNPvTPM", oid: SNPvTPMOID, wantType: attestation.SNPvTPM},
		{name: "Azure", oid: AzureOID, wantType: attestation.Azure},
		{name: "TDX", oid: TDXOID, wantType: attestation.TDX},
		{name: "Invalid", oid: asn1.ObjectIdentifier{1, 2, 3}, wantType: attestation.PlatformType(0), wantErr: true},
	}

	for _, tc := range tests {
		t.Run(tc.name, func(t *testing.T) {
			got, err := platformTypeFromOID(tc.oid)
			if tc.wantErr {
				assert.Error(t, err)
				assert.Equal(t, attestation.PlatformType(0), got)
				return
			}
			assert.NoError(t, err)
			assert.Equal(t, tc.wantType, got)
		})
	}
}
// TestVerifyCertificateExtension tests certificate extension verification.
func TestVerifyCertificateExtension(t *testing.T) {
tempDir, err := os.MkdirTemp("", "policy")
require.NoError(t, err)
defer os.RemoveAll(tempDir)
attestationPB := prepVerifyAttReport(t)
err = setAttestationPolicy(attestationPB, tempDir)
require.NoError(t, err)
oldPath := attestation.AttestationPolicyPath
t.Cleanup(func() {
attestation.AttestationPolicyPath = oldPath
})
privateKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
require.NoError(t, err)
pubKeyDER, err := x509.MarshalPKIXPublicKey(&privateKey.PublicKey)
require.NoError(t, err)
nonce := make([]byte, 64)
_, err = rand.Read(nonce)
require.NoError(t, err)
teeNonce := append(pubKeyDER, nonce...)
hashNonce := sha3.Sum512(teeNonce)
cases := []struct {
name string
extension []byte
pubKey []byte
nonce []byte
platformType attestation.PlatformType
expectError bool
}{
{
name: "ValidExtensionSNPvTPM",
extension: hashNonce[:],
pubKey: pubKeyDER,
nonce: nonce,
platformType: attestation.SNPvTPM,
expectError: true, // Expected due to invalid attestation data
},
{
name: "InvalidPlatformType",
extension: hashNonce[:],
pubKey: pubKeyDER,
nonce: nonce,
platformType: attestation.PlatformType(999),
expectError: true,
},
{
name: "EmptyExtension",
extension: []byte{},
pubKey: pubKeyDER,
nonce: nonce,
platformType: attestation.SNPvTPM,
expectError: true,
},
{
name: "EmptyPublicKey",
extension: hashNonce[:],
pubKey: []byte{},
nonce: nonce,
platformType: attestation.SNPvTPM,
expectError: true,
},
{
name: "EmptyNonce",
extension: hashNonce[:],
pubKey: pubKeyDER,
nonce: []byte{},
platformType: attestation.SNPvTPM,
expectError: true,
},
}
for _, c := range cases {
t.Run(c.name, func(t *testing.T) {
v := certificateVerifier{}
err := v.verifyCertificateExtension(c.extension, c.pubKey, c.nonce, c.platformType)
if c.expectError {
assert.Error(t, err)
} else {
assert.NoError(t, err)
}
})
}
}
// Helper functions
// prepVerifyAttReport builds a placeholder SEV-SNP attestation whose report
// holds zero-filled byte fields and a zero policy. It is constructed in memory
// to avoid parsing issues with a stale binary fixture. The t parameter is
// currently unused.
func prepVerifyAttReport(t *testing.T) *sevsnp.Attestation {
	report := &sevsnp.Report{
		FamilyId:    make([]byte, 16),
		ImageId:     make([]byte, 16),
		Measurement: make([]byte, 48),
		HostData:    make([]byte, 32),
		ReportIdMa:  make([]byte, 32),
		Policy:      0, // Valid policy? Or ignore
	}
	return &sevsnp.Attestation{Report: report}
}
// setAttestationPolicy writes a dummy unsigned CoRIM (ID "cocos-test-id"),
// CBOR-encoded, to "attestation_policy.json" inside policyDirectory and points
// the package-level attestation.AttestationPolicyPath at that file.
//
// The rr argument is currently unused. The global path is only updated after
// the file has been written successfully; callers that need isolation must
// save and restore attestation.AttestationPolicyPath themselves.
//
// It returns any error from encoding the CoRIM or from writing the file.
func setAttestationPolicy(rr *sevsnp.Attestation, policyDirectory string) error {
	// Create a dummy CoRIM
	c := corim.NewUnsignedCorim()
	c.SetID("cocos-test-id")
	corimBytes, err := c.ToCBOR()
	if err != nil {
		return err
	}

	policyPath := filepath.Join(policyDirectory, "attestation_policy.json")
	if err := os.WriteFile(policyPath, corimBytes, 0o644); err != nil {
		// Report the write failure; returning nil here would make callers
		// believe the policy was installed when it was not.
		return err
	}

	attestation.AttestationPolicyPath = policyPath
	return nil
}
// TestCertificateVerification is the unified suite for certificate
// verification. It builds a self-signed certificate and a leaf/root chain once,
// then hands table entries to runCertificateVerificationTests and
// runATLSVerificationTests.
func TestCertificateVerification(t *testing.T) {
	// Shared fixtures for all sub-suites.
	selfSigned := createSelfSignedCert(t)
	leaf, root := generateCertificateChain(t)
	trustedRoots := createCertPool(root)
	noRoots := x509.NewCertPool()

	t.Run("SelfSignedCertificates", func(t *testing.T) {
		runCertificateVerificationTests(t, []testCase{
			{
				name:        "ValidSelfSignedCertificate",
				cert:        selfSigned,
				rootCAs:     nil,
				expectError: false,
			},
			{
				name:        "EmptyCertificate",
				cert:        &x509.Certificate{},
				rootCAs:     nil,
				expectError: true,
				errorMsg:    "x509: missing ASN.1 contents; use ParseCertificate",
			},
		})
	})

	t.Run("CertificateChainVerification", func(t *testing.T) {
		runCertificateVerificationTests(t, []testCase{
			{
				name:        "ValidCertificateWithRootCA",
				cert:        leaf,
				rootCAs:     trustedRoots,
				expectError: false,
			},
			{
				name:        "SelfSignedCertificate",
				cert:        root,
				rootCAs:     nil, // Self-signed verification
				expectError: false,
			},
			{
				name:        "InvalidCertificateWithEmptyPool",
				cert:        root,
				rootCAs:     noRoots,
				expectError: true,
			},
		})
	})

	t.Run("ATLSPeerCertificateVerification", func(t *testing.T) {
		peerNonce := generateNonce(t)
		runATLSVerificationTests(t, []atlsTestCase{
			{
				name:        "InvalidCertificateData",
				rawCerts:    [][]byte{[]byte("invalid cert data")},
				nonce:       peerNonce,
				rootCAs:     trustedRoots,
				expectError: true,
				errorMsg:    "failed to parse x509 certificate",
			},
			{
				name:        "ValidCertificateNoAttestationExtension",
				rawCerts:    [][]byte{leaf.Raw},
				nonce:       peerNonce,
				rootCAs:     trustedRoots,
				expectError: true,
				errorMsg:    "attestation extension not found in certificate",
			},
		})
	})
}
// TestAttestedCAProvider tests the CA-signed certificate provider: that
// NewAttestedCAProvider returns a non-nil provider and that SetTTL stores the
// new TTL on the underlying attestedCertificateProvider.
func TestAttestedCAProvider(t *testing.T) {
	client := new(mockAttestationClient)
	attester, err := NewAttestationProvider(client, attestation.SNPvTPM)
	require.NoError(t, err)

	subject := DefaultCertificateSubject()
	const (
		cvmID      = "test-cvm-id"
		agentToken = "test-token"
	)

	t.Run("NewAttestedCAProvider", func(t *testing.T) {
		assert.NotNil(t, NewAttestedCAProvider(attester, subject, nil, cvmID, agentToken))
	})

	t.Run("SetTTL", func(t *testing.T) {
		// The unchecked type assertion panics if the concrete type ever changes.
		impl := NewAttestedCAProvider(attester, subject, nil, cvmID, agentToken).(*attestedCertificateProvider)

		wantTTL := time.Hour * 48
		impl.SetTTL(wantTTL)
		assert.Equal(t, wantTTL, impl.ttl)
	})
}
// TestCASignedCertificateErrors drives generateCASignedCertificate with a
// mocked certs SDK and checks the error text for an SDK failure, a PEM block
// followed by trailing data, and an empty certificate string.
func TestCASignedCertificateErrors(t *testing.T) {
	client := new(mockAttestationClient)
	attester, err := NewAttestationProvider(client, attestation.SNPvTPM)
	require.NoError(t, err)

	subject := DefaultCertificateSubject()
	cvmID, agentToken := "test-cvm-id", "test-token"

	tests := []struct {
		name        string
		certificate string
		sdkError    error
		wantErrText string
	}{
		{
			name:        "SDKIssueError",
			certificate: "",
			sdkError:    errors.NewSDKError(errors.New("SDK error")),
			wantErrText: "SDK error",
		},
		{
			name:        "InvalidPEMWithRemainingData",
			certificate: "-----BEGIN CERTIFICATE-----\\nVGVzdA==\\n-----END CERTIFICATE-----\\nExtra data here",
			sdkError:    nil,
			wantErrText: "unexpected remaining data",
		},
		{
			name:        "NoPEMBlockFound",
			certificate: "",
			sdkError:    nil,
			wantErrText: "no PEM block found",
		},
	}

	for _, tc := range tests {
		t.Run(tc.name, func(t *testing.T) {
			sdk := sdkmocks.NewSDK(t)
			csr := certs.CSR{CSR: []byte("test-csr")}
			sdk.On("CreateCSR", mock.Anything, mock.Anything, mock.Anything).Return(csr, errors.SDKError(nil))
			sdk.On("IssueFromCSRInternal", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(certssdk.Certificate{Certificate: tc.certificate}, tc.sdkError)

			impl := NewAttestedCAProvider(attester, subject, sdk, cvmID, agentToken).(*attestedCertificateProvider)

			key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
			require.NoError(t, err)

			ext := pkix.Extension{
				Id:    SNPvTPMOID,
				Value: []byte("test-data"),
			}
			_, err = impl.generateCASignedCertificate(t.Context(), key, ext)
			assert.Error(t, err)
			assert.Contains(t, err.Error(), tc.wantErrText)
		})
	}
}
// TestGetCertificateErrors covers the error paths of GetCertificate: a server
// name without a nonce, a failing attestation client, and a failing CA when
// the provider is CA-signed.
func TestGetCertificateErrors(t *testing.T) {
	t.Run("InvalidServerNameFormat", func(t *testing.T) {
		client := new(mockAttestationClient)
		attester, err := NewAttestationProvider(client, attestation.SNPvTPM)
		require.NoError(t, err)
		certProvider := NewAttestedProvider(attester, DefaultCertificateSubject())

		_, err = certProvider.GetCertificate(&tls.ClientHelloInfo{ServerName: "invalid-format"})
		assert.Error(t, err)
		assert.Contains(t, err.Error(), "failed to extract nonce")
	})

	t.Run("AttestationProviderError", func(t *testing.T) {
		client := new(mockAttestationClient)
		client.On("GetAttestation", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil, errors.New("attestation failed"))

		attester, err := NewAttestationProvider(client, attestation.SNPvTPM)
		require.NoError(t, err)
		certProvider := NewAttestedProvider(attester, DefaultCertificateSubject())

		rawNonce := make([]byte, 64)
		_, err = rand.Read(rawNonce)
		require.NoError(t, err)
		hello := &tls.ClientHelloInfo{ServerName: hex.EncodeToString(rawNonce) + ".nonce"}

		_, err = certProvider.GetCertificate(hello)
		assert.Error(t, err)
		assert.Contains(t, err.Error(), "failed to get attestation")
	})

	t.Run("CASignedCertificateError", func(t *testing.T) {
		client := new(mockAttestationClient)
		client.On("GetAttestation", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return([]byte("test-attestation"), nil)

		attester, err := NewAttestationProvider(client, attestation.SNPvTPM)
		require.NoError(t, err)

		sdk := sdkmocks.NewSDK(t)
		csr := certs.CSR{CSR: []byte("test-csr")}
		caFailure := errors.NewSDKError(errors.New("CA error"))
		sdk.On("CreateCSR", mock.Anything, mock.Anything, mock.Anything).Return(csr, errors.SDKError(nil))
		sdk.On("IssueFromCSRInternal", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(certssdk.Certificate{}, caFailure)

		certProvider := NewAttestedCAProvider(attester, DefaultCertificateSubject(), sdk, "test-cvm", "test-token")

		rawNonce := make([]byte, 64)
		_, err = rand.Read(rawNonce)
		require.NoError(t, err)
		hello := &tls.ClientHelloInfo{ServerName: hex.EncodeToString(rawNonce) + ".nonce"}

		_, err = certProvider.GetCertificate(hello)
		assert.Error(t, err)
		assert.Contains(t, err.Error(), "failed to generate certificate")
	})
}
// TestCertificateVerificationEdgeCases tests edge cases in certificate verification.
func TestCertificateVerificationEdgeCases(t *testing.T) {
tempDir, err := os.MkdirTemp("", "policy")
require.NoError(t, err)
defer os.RemoveAll(tempDir)
attestationPB := prepVerifyAttReport(t)
err = setAttestationPolicy(attestationPB, tempDir)
require.NoError(t, err)
t.Run("VerifyPeerCertificateWithMultipleCerts", func(t *testing.T) {
verifier := NewCertificateVerifier(nil)
cert1 := createSelfSignedCert(t)
cert2 := createSelfSignedCert(t)
nonce := generateNonce(t)
err := verifier.VerifyPeerCertificate([][]byte{cert1.Raw, cert2.Raw}, nil, nonce)
assert.Error(t, err)
assert.Contains(t, err.Error(), "attestation extension not found")
})
t.Run("VerifyAttestationExtensionWithNoExtensions", func(t *testing.T) {
cert := createSelfSignedCert(t)
verifier := certificateVerifier{}
nonce := generateNonce(t)
err := verifier.verifyAttestationExtension(cert, nonce)
assert.Error(t, err)
assert.Contains(t, err.Error(), "attestation extension not found")
})
t.Run("VerifyAttestationExtensionWithWrongOID", func(t *testing.T) {
wrongOID := asn1.ObjectIdentifier{1, 2, 3, 4, 5}
extension := pkix.Extension{
Id: wrongOID,
Value: []byte("test-data"),
}
cert := generateTestCertificateWithExtensions(t, []pkix.Extension{extension})
verifier := certificateVerifier{}
nonce := generateNonce(t)
err := verifier.verifyAttestationExtension(cert, nonce)
assert.Error(t, err)
assert.Contains(t, err.Error(), "attestation extension not found")
})
t.Run("VerifyCertificateExtensionPlatformVerifierError", func(t *testing.T) {
verifier := certificateVerifier{}
invalidPlatformType := attestation.PlatformType(999)
err := verifier.verifyCertificateExtension([]byte("test-extension"), []byte("test-pubkey"), []byte("test-nonce"), invalidPlatformType)
assert.Error(t, err)
// The error occurs during EAT token decoding before platform type validation
assert.Contains(t, err.Error(), "failed to decode EAT token")
})
}
// TestCertificateWithAttestationExtension tests certificates with attestation extensions.
func TestCertificateWithAttestationExtension(t *testing.T) {
tempDir, err := os.MkdirTemp("", "policy")
require.NoError(t, err)
defer os.RemoveAll(tempDir)
attestationPB := prepVerifyAttReport(t)
err = setAttestationPolicy(attestationPB, tempDir)
require.NoError(t, err)
t.Run("CertificateWithValidAttestationExtension", func(t *testing.T) {
// Create certificate with attestation extension
privateKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
require.NoError(t, err)
_, err = x509.MarshalPKIXPublicKey(&privateKey.PublicKey)
require.NoError(t, err)
nonce := make([]byte, 64)
_, err = rand.Read(nonce)
require.NoError(t, err)
extension := pkix.Extension{
Id: SNPvTPMOID,
Value: []byte("test-attestation-data"),
}
template := &x509.Certificate{
SerialNumber: big.NewInt(1),
Subject: pkix.Name{
Organization: []string{"Test Org"},
},
NotBefore: time.Now(),
NotAfter: time.Now().Add(24 * time.Hour),
KeyUsage: x509.KeyUsageDigitalSignature,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
ExtraExtensions: []pkix.Extension{extension},
}
certDER, err := x509.CreateCertificate(rand.Reader, template, template, &privateKey.PublicKey, privateKey)
require.NoError(t, err)
cert, err := x509.ParseCertificate(certDER)
require.NoError(t, err)
verifier := certificateVerifier{}
err = verifier.verifyAttestationExtension(cert, nonce)
// Expect error due to invalid attestation data, but extension should be found
assert.Error(t, err)
assert.NotContains(t, err.Error(), "attestation extension not found")
})
}
// TestIntegrationScenarios tests end-to-end integration scenarios: obtaining a
// certificate from a provider built without a certs SDK ("FullSelfSignedFlow")
// and from one built with a mocked certs SDK ("FullCASignedFlow").
func TestIntegrationScenarios(t *testing.T) {
	tempDir, err := os.MkdirTemp("", "policy")
	require.NoError(t, err)
	defer os.RemoveAll(tempDir)

	attestationPB := prepVerifyAttReport(t)
	err = setAttestationPolicy(attestationPB, tempDir)
	require.NoError(t, err)

	t.Run("FullSelfSignedFlow", func(t *testing.T) {
		// Setup mock client
		mockClient := new(mockAttestationClient)
		mockClient.On("GetAttestation", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return([]byte("mock-attestation"), nil)

		// Create provider
		provider, err := NewProvider(mockClient, attestation.SNPvTPM, "", "", nil)
		require.NoError(t, err)

		// Generate certificate
		nonce := make([]byte, 64)
		_, err = rand.Read(nonce)
		require.NoError(t, err)
		serverName := hex.EncodeToString(nonce) + ".nonce"
		clientHello := &tls.ClientHelloInfo{ServerName: serverName}

		cert, err := provider.GetCertificate(clientHello)
		// These checks must be fatal: the code below dereferences cert and
		// indexes cert.Certificate, which would panic on a nil or empty result
		// instead of producing a readable test failure.
		require.NoError(t, err)
		require.NotNil(t, cert)
		require.NotEmpty(t, cert.Certificate)
		assert.NotNil(t, cert.PrivateKey)

		// Verify the generated certificate
		parsedCert, err := x509.ParseCertificate(cert.Certificate[0])
		require.NoError(t, err)

		// Check for attestation extension
		found := false
		for _, ext := range parsedCert.Extensions {
			if ext.Id.Equal(SNPvTPMOID) {
				found = true
				break
			}
		}
		assert.True(t, found, "Attestation extension should be present")
	})

	t.Run("FullCASignedFlow", func(t *testing.T) {
		// The mocked SDK answers both the CSR creation and the issuance call.
		mockSDK := sdkmocks.NewSDK(t)
		expectedCSR := certs.CSR{CSR: []byte("test-csr")}
		expectedCert := certssdk.Certificate{Certificate: generateTestCertPEM(t)}
		mockSDK.On("CreateCSR", mock.Anything, mock.Anything, mock.Anything).Return(expectedCSR, errors.SDKError(nil))
		mockSDK.On("IssueFromCSRInternal", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(expectedCert, errors.SDKError(nil))

		mockClient := new(mockAttestationClient)
		mockClient.On("GetAttestation", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return([]byte("mock-attestation"), nil)

		provider, err := NewProvider(mockClient, attestation.SNPvTPM, "test-token", "test-cvm-id", mockSDK)
		require.NoError(t, err)

		nonce := make([]byte, 64)
		_, err = rand.Read(nonce)
		require.NoError(t, err)
		serverName := hex.EncodeToString(nonce) + ".nonce"
		clientHello := &tls.ClientHelloInfo{ServerName: serverName}

		cert, err := provider.GetCertificate(clientHello)
		require.NoError(t, err)
		require.NotNil(t, cert)
		require.NotEmpty(t, cert.Certificate)
		require.NotNil(t, cert.PrivateKey)

		parsedCert, err := x509.ParseCertificate(cert.Certificate[0])
		require.NoError(t, err)
		assert.NotNil(t, parsedCert.Subject)

		// Both mocks must have received every call registered above.
		mockClient.AssertExpectations(t)
		mockSDK.AssertExpectations(t)
	})
}
// TestConcurrentAccess tests concurrent access scenarios: several goroutines
// request a certificate from the same provider at once, each with its own
// random nonce, and every request must succeed with a non-nil certificate.
func TestConcurrentAccess(t *testing.T) {
	mockClient := new(mockAttestationClient)
	mockClient.On("GetAttestation", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return([]byte("mock-attestation"), nil)

	provider, err := NewProvider(mockClient, attestation.SNPvTPM, "", "", nil)
	require.NoError(t, err)

	const numGoroutines = 10

	// Named errCh rather than errors so the imported errors package is not
	// shadowed inside this function. The buffer holds one result per goroutine,
	// so no sender ever blocks.
	errCh := make(chan error, numGoroutines)

	for i := 0; i < numGoroutines; i++ {
		go func(id int) {
			nonce := make([]byte, 64)
			_, err := rand.Read(nonce)
			if err != nil {
				errCh <- err
				return
			}
			serverName := hex.EncodeToString(nonce) + ".nonce"
			clientHello := &tls.ClientHelloInfo{ServerName: serverName}

			cert, err := provider.GetCertificate(clientHello)
			if err != nil {
				errCh <- err
				return
			}
			if cert == nil {
				errCh <- fmt.Errorf("nil certificate returned for goroutine %d", id)
				return
			}
			errCh <- nil
		}(i)
	}

	// Collect results: exactly one value arrives per goroutine, so receiving
	// numGoroutines values also waits for all of them to finish.
	for i := 0; i < numGoroutines; i++ {
		err := <-errCh
		assert.NoError(t, err)
	}
}
// TestEdgeCasesAndBoundaries tests edge cases and boundary conditions of
// extractNonceFromSNI using server names built from random nonces.
func TestEdgeCasesAndBoundaries(t *testing.T) {
	scenarios := []struct {
		name      string
		nonceSize int    // number of random bytes hex-encoded into the server name
		prefix    string // text placed in front of the hex-encoded nonce
		wantErr   bool
	}{
		// Much larger than expected; should fail due to invalid length.
		{name: "LargeNonce", nonceSize: 1024, wantErr: true},
		// Very long server name; should fail due to invalid format.
		{name: "MaxLengthServerName", nonceSize: 64, prefix: strings.Repeat("a", 200), wantErr: true},
		// Exactly the required length; the nonce must round-trip unchanged.
		{name: "MinimalValidNonce", nonceSize: 64},
	}

	for _, sc := range scenarios {
		t.Run(sc.name, func(t *testing.T) {
			raw := make([]byte, sc.nonceSize)
			_, err := rand.Read(raw)
			require.NoError(t, err)

			sni := sc.prefix + hex.EncodeToString(raw) + ".nonce"
			got, err := extractNonceFromSNI(sni)
			if sc.wantErr {
				assert.Error(t, err)
				return
			}
			assert.NoError(t, err)
			assert.Equal(t, raw, got)
		})
	}
}
// Unified test case structures.

// testCase is one table entry consumed by runCertificateVerificationTests.
type testCase struct {
	// name is used as the subtest name.
	name string
	// cert is the certificate passed to verifyCertificateSignature.
	cert *x509.Certificate
	// rootCAs is the pool assigned to the verifier's rootCAs field.
	rootCAs *x509.CertPool
	// expectError states whether verification is expected to return an error.
	expectError bool
	// errorMsg, when non-empty, is the text the returned error must match.
	errorMsg string
}
// atlsTestCase is one table entry consumed by runATLSVerificationTests.
type atlsTestCase struct {
	// name is used as the subtest name.
	name string
	// rawCerts is passed as the first argument to VerifyPeerCertificate.
	rawCerts [][]byte
	// nonce is passed as the last argument to VerifyPeerCertificate.
	nonce []byte
	// rootCAs is the pool assigned to the verifier's rootCAs field.
	rootCAs *x509.CertPool
	// expectError states whether verification is expected to return an error.
	expectError bool
	// errorMsg, when non-empty, is text the returned error must contain.
	errorMsg string
}
// Unified test runners.

// runCertificateVerificationTests runs each entry of testCases as a subtest.
// For every case it builds a certificateVerifier with the case's root pool,
// calls verifyCertificateSignature on the case's certificate and compares the
// result with expectError and errorMsg.
func runCertificateVerificationTests(t *testing.T, testCases []testCase) {
	// This message is matched with errors.Contains instead of a substring check.
	const missingASN1Msg = "x509: missing ASN.1 contents; use ParseCertificate"

	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			v := certificateVerifier{
				rootCAs: tc.rootCAs,
			}
			err := v.verifyCertificateSignature(tc.cert)
			if !tc.expectError {
				assert.NoError(t, err)
				return
			}

			// Fatal on purpose: with a non-fatal check a missing error would
			// reach err.Error() below and panic on the nil error instead of
			// reporting a normal test failure.
			require.Error(t, err)

			switch tc.errorMsg {
			case "":
				// Any error satisfies the case.
			case missingASN1Msg:
				// For specific error matching
				assert.True(t, errors.Contains(err, errors.New(tc.errorMsg)),
					"expected error %q, got %v", tc.errorMsg, err)
			default:
				assert.Contains(t, err.Error(), tc.errorMsg)
			}
		})
	}
}
// runATLSVerificationTests runs each entry of testCases as a subtest. For
// every case it builds a certificateVerifier with the case's root pool, calls
// VerifyPeerCertificate with the case's raw certificates and nonce (and nil
// for the middle argument), and compares the result with expectError and
// errorMsg.
func runATLSVerificationTests(t *testing.T, testCases []atlsTestCase) {
	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			v := certificateVerifier{
				rootCAs: tc.rootCAs,
			}
			err := v.VerifyPeerCertificate(tc.rawCerts, nil, tc.nonce)
			if !tc.expectError {
				assert.NoError(t, err)
				return
			}

			// Fatal on purpose: with a non-fatal check a missing error would
			// reach err.Error() below and panic on the nil error instead of
			// reporting a normal test failure.
			require.Error(t, err)
			if tc.errorMsg != "" {
				assert.Contains(t, err.Error(), tc.errorMsg)
			}
		})
	}
}
// Unified certificate creation utilities.

// createSelfSignedCert returns a parsed certificate for organization
// "Test Org" that is issued and signed with a newly generated RSA key and is
// valid for 24 hours from the time of the call.
func createSelfSignedCert(t *testing.T) *x509.Certificate {
	key := generateRSAKey(t)

	// The same template is used as subject and as issuer, which is what makes
	// the resulting certificate self-signed.
	tmpl := &x509.Certificate{
		SerialNumber:          big.NewInt(1),
		Subject:               pkix.Name{Organization: []string{"Test Org"}},
		NotBefore:             time.Now(),
		NotAfter:              time.Now().Add(24 * time.Hour),
		KeyUsage:              x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature,
		ExtKeyUsage:           []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
		BasicConstraintsValid: true,
	}

	return createCertificateFromTemplate(t, tmpl, tmpl, &key.PublicKey, key)
}
// generateCertificateChain builds a two-certificate chain for tests and
// returns it as (leaf, root). The root is a CA certificate that is its own
// issuer; the leaf is issued under the root template and signed with the
// root's private key. Each certificate has its own newly generated RSA key
// and is valid for 24 hours from the time of the call.
func generateCertificateChain(t *testing.T) (leafCert, rootCert *x509.Certificate) {
	// Root CA: template doubles as parent, key signs its own certificate.
	caKey := generateRSAKey(t)
	caTmpl := &x509.Certificate{
		SerialNumber: big.NewInt(1),
		Subject: pkix.Name{
			Organization: []string{"Test Root CA"},
			Country:      []string{"US"},
		},
		NotBefore:             time.Now(),
		NotAfter:              time.Now().Add(24 * time.Hour),
		KeyUsage:              x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign,
		ExtKeyUsage:           []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
		BasicConstraintsValid: true,
		IsCA:                  true,
	}
	root := createCertificateFromTemplate(t, caTmpl, caTmpl, &caKey.PublicKey, caKey)

	// Leaf: carries its own public key but is signed by the root key.
	eeKey := generateRSAKey(t)
	eeTmpl := &x509.Certificate{
		SerialNumber: big.NewInt(2),
		Subject: pkix.Name{
			Organization: []string{"Test Leaf"},
			Country:      []string{"US"},
		},
		NotBefore:   time.Now(),
		NotAfter:    time.Now().Add(24 * time.Hour),
		KeyUsage:    x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature,
		ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
	}
	leaf := createCertificateFromTemplate(t, eeTmpl, caTmpl, &eeKey.PublicKey, caKey)

	return leaf, root
}
// Helper functions for consistency.

// generateRSAKey returns a new 2048-bit RSA private key and stops the test
// immediately if key generation returns an error.
func generateRSAKey(t *testing.T) *rsa.PrivateKey {
	const keyBits = 2048

	key, genErr := rsa.GenerateKey(rand.Reader, keyBits)
	require.NoError(t, genErr)

	return key
}
// createCertificateFromTemplate creates a certificate from template, issued
// by parent, carrying public key pub and signed with private key priv, then
// parses the DER output and returns the parsed certificate. The test is
// stopped immediately if creation or parsing fails.
func createCertificateFromTemplate(t *testing.T, template, parent *x509.Certificate, pub interface{}, priv interface{}) *x509.Certificate {
	der, createErr := x509.CreateCertificate(rand.Reader, template, parent, pub, priv)
	require.NoError(t, createErr)

	// Round-trip through the parser so callers receive a fully populated
	// certificate rather than the bare template.
	parsed, parseErr := x509.ParseCertificate(der)
	require.NoError(t, parseErr)

	return parsed
}
// createCertPool returns a new certificate pool to which every given
// certificate has been added, in argument order. Called without arguments it
// returns an empty pool.
func createCertPool(certs ...*x509.Certificate) *x509.CertPool {
	result := x509.NewCertPool()
	for i := range certs {
		result.AddCert(certs[i])
	}
	return result
}
// generateNonce returns 64 bytes read from the crypto/rand source and stops
// the test immediately if the read returns an error.
func generateNonce(t *testing.T) []byte {
	const nonceLen = 64

	buf := make([]byte, nonceLen)
	_, readErr := rand.Read(buf)
	require.NoError(t, readErr)

	return buf
}