From 80bf813c48300f02259b338b91f844c71be582ea Mon Sep 17 00:00:00 2001 From: Danko Miladinovic <72250944+danko-miladinovic@users.noreply.github.com> Date: Thu, 26 Mar 2026 16:57:09 +0100 Subject: [PATCH] NOISSUE - Post-handshake aTLS (#582) * initial post-handshake aTLS implementation * add header * rebased * remove grpc.go and http.go * fix authenticator issues * add freshness nonce --------- Co-authored-by: ultraviolet Co-authored-by: ultraviolet --- pkg/atls/atls.go | 69 - pkg/atls/atls_test.go | 1308 ------------------ pkg/atls/attestation_provider.go | 82 -- pkg/atls/certificate_provider.go | 190 --- pkg/atls/certificate_verifier.go | 210 --- pkg/atls/certificate_verifier_test.go | 338 ----- pkg/atls/ea/authenticator.go | 370 +++++ pkg/atls/ea/authenticator_test.go | 537 +++++++ pkg/atls/ea/certificate.go | 95 ++ pkg/atls/ea/certverify.go | 148 ++ pkg/atls/ea/cmw_attestation.go | 51 + pkg/atls/ea/exporters.go | 67 + pkg/atls/ea/extensions.go | 61 + pkg/atls/ea/finished.go | 27 + pkg/atls/ea/handshake.go | 55 + pkg/atls/ea/policy.go | 60 + pkg/atls/ea/request.go | 214 +++ pkg/atls/ea/session.go | 59 + pkg/atls/ea/sigscheme.go | 82 ++ pkg/atls/ea/util.go | 34 + pkg/atls/eaattestation/binding.go | 88 ++ pkg/atls/eaattestation/binding_test.go | 199 +++ pkg/atls/eaattestation/types.go | 87 ++ pkg/atls/eaattestation/verify.go | 75 + pkg/atls/evidence_verifier.go | 92 ++ pkg/atls/internal_transport/conn.go | 280 ++++ pkg/atls/internal_transport/conn_test.go | 100 ++ pkg/atls/internal_transport/listener.go | 49 + pkg/atls/internal_transport/protocol.go | 62 + pkg/atls/internal_transport/protocol_test.go | 40 + pkg/atls/mocks/certificateprovider.go | 100 +- pkg/atls/provider.go | 84 ++ pkg/atls/server_tls.go | 93 ++ pkg/atls/tls_helpers.go | 104 ++ pkg/atls/transport.go | 83 ++ pkg/atls/transport_test.go | 78 ++ pkg/clients/clients.go | 33 +- pkg/clients/grpc/connect_test.go | 32 + pkg/clients/grpc/grpc.go | 54 +- pkg/clients/http/client.go | 43 +- pkg/clients/http/client_test.go | 57 + pkg/ingress/proxy.go | 86 +- pkg/ingress/proxy_test.go | 33 +- pkg/tls/tls.go | 22 +- pkg/tls/tls_test.go | 10 +- 45 files changed, 3716 insertions(+), 2325 deletions(-) delete mode 100644 pkg/atls/atls.go delete mode 100644 pkg/atls/atls_test.go delete mode 100644 pkg/atls/attestation_provider.go delete mode 100644 pkg/atls/certificate_provider.go delete mode 100644 pkg/atls/certificate_verifier.go delete mode 100644 pkg/atls/certificate_verifier_test.go create mode 100644 pkg/atls/ea/authenticator.go create mode 100644 pkg/atls/ea/authenticator_test.go create mode 100644 pkg/atls/ea/certificate.go create mode 100644 pkg/atls/ea/certverify.go create mode 100644 pkg/atls/ea/cmw_attestation.go create mode 100644 pkg/atls/ea/exporters.go create mode 100644 pkg/atls/ea/extensions.go create mode 100644 pkg/atls/ea/finished.go create mode 100644 pkg/atls/ea/handshake.go create mode 100644 pkg/atls/ea/policy.go create mode 100644 pkg/atls/ea/request.go create mode 100644 pkg/atls/ea/session.go create mode 100644 pkg/atls/ea/sigscheme.go create mode 100644 pkg/atls/ea/util.go create mode 100644 pkg/atls/eaattestation/binding.go create mode 100644 pkg/atls/eaattestation/binding_test.go create mode 100644 pkg/atls/eaattestation/types.go create mode 100644 pkg/atls/eaattestation/verify.go create mode 100644 pkg/atls/evidence_verifier.go create mode 100644 pkg/atls/internal_transport/conn.go create mode 100644 pkg/atls/internal_transport/conn_test.go create mode 100644 pkg/atls/internal_transport/listener.go create mode 100644 pkg/atls/internal_transport/protocol.go create mode 100644 pkg/atls/internal_transport/protocol_test.go create mode 100644 pkg/atls/provider.go create mode 100644 pkg/atls/server_tls.go create mode 100644 pkg/atls/tls_helpers.go create mode 100644 pkg/atls/transport.go create mode 100644 pkg/atls/transport_test.go diff --git a/pkg/atls/atls.go b/pkg/atls/atls.go deleted file mode 100644 index 399080ed..00000000 --- a/pkg/atls/atls.go +++ /dev/null @@ -1,69 +0,0 @@ -// Copyright (c) Ultraviolet -// SPDX-License-Identifier: Apache-2.0 -package atls - -import ( - "encoding/asn1" - "encoding/hex" - "fmt" -) - -const ( - defaultNotAfterYears = 1 - nonceLength = 64 - nonceSuffix = ".nonce" -) - -// Platform-specific OIDs for certificate extensions. -var ( - SNPvTPMOID = asn1.ObjectIdentifier{2, 99999, 1, 0} - AzureOID = asn1.ObjectIdentifier{2, 99999, 1, 1} - TDXOID = asn1.ObjectIdentifier{2, 99999, 1, 2} -) - -// CertificateSubject contains certificate subject information. -type CertificateSubject struct { - Organization string - CommonName string - Country string - Province string - Locality string - StreetAddress string - PostalCode string -} - -// DefaultCertificateSubject returns the default certificate subject for Ultraviolet. -func DefaultCertificateSubject() CertificateSubject { - return CertificateSubject{ - Organization: "Ultraviolet", - CommonName: "Ultraviolet", - Country: "Serbia", - Province: "", - Locality: "Belgrade", - StreetAddress: "Bulevar Arsenija Carnojevica 103", - PostalCode: "11000", - } -} - -func extractNonceFromSNI(serverName string) ([]byte, error) { - if len(serverName) < len(nonceSuffix) || !hasNonceSuffix(serverName) { - return nil, fmt.Errorf("invalid server name: %s", serverName) - } - - nonceStr := serverName[:len(serverName)-len(nonceSuffix)] - nonce, err := hex.DecodeString(nonceStr) - if err != nil { - return nil, fmt.Errorf("failed to decode nonce: %w", err) - } - - if len(nonce) != nonceLength { - return nil, fmt.Errorf("invalid nonce length: expected %d bytes, got %d bytes", nonceLength, len(nonce)) - } - - return nonce, nil -} - -func hasNonceSuffix(serverName string) bool { - return len(serverName) >= len(nonceSuffix) && - serverName[len(serverName)-len(nonceSuffix):] == nonceSuffix -} diff --git a/pkg/atls/atls_test.go b/pkg/atls/atls_test.go deleted file mode 100644 index d1b185a5..00000000 --- a/pkg/atls/atls_test.go +++ /dev/null @@ -1,1308 +0,0 @@ -// Copyright (c) Ultraviolet -// SPDX-License-Identifier: Apache-2.0 -package atls - -import ( - "context" - "crypto/ecdsa" - "crypto/elliptic" - "crypto/rand" - "crypto/rsa" - "crypto/tls" - "crypto/x509" - "crypto/x509/pkix" - "encoding/asn1" - "encoding/hex" - "encoding/pem" - "fmt" - "math/big" - "os" - "path/filepath" - "strings" - "testing" - "time" - - "github.com/absmach/certs" - certssdk "github.com/absmach/certs/sdk" - sdkmocks "github.com/absmach/certs/sdk/mocks" - "github.com/absmach/supermq/pkg/errors" - "github.com/google/go-sev-guest/proto/sevsnp" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/mock" - "github.com/stretchr/testify/require" - "github.com/ultravioletrs/cocos/pkg/attestation" - "github.com/veraison/corim/corim" - "golang.org/x/crypto/sha3" -) - -// var policy = attestation.Config{Config: &check.Config{Policy: &check.Policy{}, RootOfTrust: &check.RootOfTrust{}}, PcrConfig: &attestation.PcrConfig{}} -// legacy config removed - -// ... (existing mocks) ... - -// mockAttestationClient is a simple mock for testing. -type mockAttestationClient struct { - mock.Mock -} - -func (m *mockAttestationClient) GetAttestation(ctx context.Context, reportData [64]byte, nonce [32]byte, attType attestation.PlatformType) ([]byte, error) { - args := m.Called(ctx, reportData, nonce, attType) - if args.Get(0) == nil { - return nil, args.Error(1) - } - return args.Get(0).([]byte), args.Error(1) -} - -func (m *mockAttestationClient) GetRawEvidence(ctx context.Context, reportData [64]byte, nonce [32]byte, attType attestation.PlatformType) ([]byte, error) { - args := m.Called(ctx, reportData, nonce, attType) - if args.Get(0) == nil { - return nil, args.Error(1) - } - return args.Get(0).([]byte), args.Error(1) -} - -func (m *mockAttestationClient) GetAzureToken(ctx context.Context, nonce [32]byte) ([]byte, error) { - args := m.Called(ctx, nonce) - if args.Get(0) == nil { - return nil, args.Error(1) - } - return args.Get(0).([]byte), args.Error(1) -} - -func (m *mockAttestationClient) Close() error { - args := m.Called() - return args.Error(0) -} - -func generateTestCertPEM(t *testing.T) string { - return generateTestCertPEMWithSubject(t, "test") -} - -func generateTestCertPEMWithSubject(t *testing.T, commonName string) string { - privateKey, err := rsa.GenerateKey(rand.Reader, 2048) - require.NoError(t, err) - - template := x509.Certificate{ - SerialNumber: big.NewInt(1), - Subject: pkix.Name{ - CommonName: commonName, - }, - NotBefore: time.Now(), - NotAfter: time.Now().Add(365 * 24 * time.Hour), - KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign, - ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, - BasicConstraintsValid: true, - IsCA: true, - } - - certDER, err := x509.CreateCertificate(rand.Reader, &template, &template, &privateKey.PublicKey, privateKey) - require.NoError(t, err) - - certPEM := pem.EncodeToMemory(&pem.Block{ - Type: "CERTIFICATE", - Bytes: certDER, - }) - - return strings.ReplaceAll(string(certPEM), "\n", "\\n") -} - -func generateTestCertificateWithExtensions(t *testing.T, extensions []pkix.Extension) *x509.Certificate { - privateKey, err := rsa.GenerateKey(rand.Reader, 2048) - require.NoError(t, err) - - template := x509.Certificate{ - SerialNumber: big.NewInt(1), - Subject: pkix.Name{ - CommonName: "test", - }, - NotBefore: time.Now(), - NotAfter: time.Now().Add(365 * 24 * time.Hour), - KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature, - ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, - BasicConstraintsValid: true, - ExtraExtensions: extensions, - } - - certDER, err := x509.CreateCertificate(rand.Reader, &template, &template, &privateKey.PublicKey, privateKey) - require.NoError(t, err) - - cert, err := x509.ParseCertificate(certDER) - require.NoError(t, err) - - return cert -} - -// TestCertificateSubject tests the CertificateSubject functionality. -func TestDefaultCertificateSubject(t *testing.T) { - subject := DefaultCertificateSubject() - - assert.Equal(t, "Ultraviolet", subject.Organization) - assert.Equal(t, "Serbia", subject.Country) - assert.Equal(t, "", subject.Province) - assert.Equal(t, "Belgrade", subject.Locality) - assert.Equal(t, "Bulevar Arsenija Carnojevica 103", subject.StreetAddress) - assert.Equal(t, "11000", subject.PostalCode) -} - -// TestUnifiedCertificateGenerator tests the unified certificate generator. -func TestUnifiedCertificateGenerator(t *testing.T) { - t.Run("SelfSignedGenerator", func(t *testing.T) { - generator, err := NewProvider(nil, attestation.SNPvTPM, "", "", nil) - assert.NoError(t, err) - assert.NotNil(t, generator) - }) - - t.Run("CASignedGenerator", func(t *testing.T) { - mockSDK := sdkmocks.NewSDK(t) - - generator, err := NewProvider(nil, attestation.SNPvTPM, "test-token", "test-cvm-id", mockSDK) - assert.NoError(t, err) - assert.NotNil(t, generator) - }) -} - -// TestPlatformAttestationProvider tests the platform attestation provider. -func TestPlatformAttestationProvider(t *testing.T) { - t.Run("NewAttestationProvider", func(t *testing.T) { - mockClient := new(mockAttestationClient) - cases := []struct { - name string - platformType attestation.PlatformType - expectError bool - }{ - {"SNPvTPM", attestation.SNPvTPM, false}, - {"Azure", attestation.Azure, false}, - {"TDX", attestation.TDX, false}, - {"Invalid", attestation.PlatformType(999), true}, - } - - for _, c := range cases { - t.Run(c.name, func(t *testing.T) { - provider, err := NewAttestationProvider(mockClient, c.platformType) - - if c.expectError { - assert.Error(t, err) - assert.Nil(t, provider) - } else { - assert.NoError(t, err) - assert.NotNil(t, provider) - assert.Equal(t, c.platformType, provider.PlatformType()) - } - }) - } - }) - - t.Run("GetAttestation", func(t *testing.T) { - mockClient := new(mockAttestationClient) - expectedAttestation := []byte("test-attestation") - mockClient.On("GetAttestation", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(expectedAttestation, nil) - - provider, err := NewAttestationProvider(mockClient, attestation.SNPvTPM) - require.NoError(t, err) - - pubKey := []byte("test-pubkey") - nonce := []byte("test-nonce") - - attestation, err := provider.Attest(pubKey, nonce) - - assert.NoError(t, err) - assert.Equal(t, expectedAttestation, attestation) - mockClient.AssertExpectations(t) - }) - - t.Run("GetAttestationError", func(t *testing.T) { - mockClient := new(mockAttestationClient) - mockClient.On("GetAttestation", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil, errors.New("attestation failed")) - - provider, err := NewAttestationProvider(mockClient, attestation.SNPvTPM) - require.NoError(t, err) - - _, err = provider.Attest([]byte("pubkey"), []byte("nonce")) - assert.Error(t, err) - }) -} - -// TestAttestedCertificateProvider tests the attested certificate provider. -func TestAttestedCertificateProvider(t *testing.T) { - t.Run("GetCertificateSuccess", func(t *testing.T) { - mockClient := new(mockAttestationClient) - mockClient.On("GetAttestation", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return([]byte("test-attestation"), nil) - - attestationProvider, err := NewAttestationProvider(mockClient, attestation.SNPvTPM) - require.NoError(t, err) - - subject := DefaultCertificateSubject() - - provider := NewAttestedProvider(attestationProvider, subject) - - // Create valid client hello with nonce - nonce := make([]byte, 64) - _, err = rand.Read(nonce) - require.NoError(t, err) - - serverName := hex.EncodeToString(nonce) + ".nonce" - clientHello := &tls.ClientHelloInfo{ServerName: serverName} - - cert, err := provider.GetCertificate(clientHello) - - assert.NoError(t, err) - assert.NotNil(t, cert) - assert.NotEmpty(t, cert.Certificate) - assert.NotNil(t, cert.PrivateKey) - }) - - t.Run("InvalidServerName", func(t *testing.T) { - mockClient := new(mockAttestationClient) - attestationProvider, err := NewAttestationProvider(mockClient, attestation.SNPvTPM) - require.NoError(t, err) - - provider := NewAttestedProvider(attestationProvider, DefaultCertificateSubject()) - - clientHello := &tls.ClientHelloInfo{ServerName: "invalid-server-name"} - - _, err = provider.GetCertificate(clientHello) - assert.Error(t, err) - assert.Contains(t, err.Error(), "failed to extract nonce") - }) - - t.Run("AttestationError", func(t *testing.T) { - mockClient := new(mockAttestationClient) - mockClient.On("GetAttestation", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil, errors.New("attestation failed")) - - attestationProvider, err := NewAttestationProvider(mockClient, attestation.SNPvTPM) - require.NoError(t, err) - - provider := NewAttestedProvider(attestationProvider, DefaultCertificateSubject()) - - nonce := make([]byte, 64) - _, err = rand.Read(nonce) - require.NoError(t, err) - - serverName := hex.EncodeToString(nonce) + ".nonce" - clientHello := &tls.ClientHelloInfo{ServerName: serverName} - - _, err = provider.GetCertificate(clientHello) - assert.Error(t, err) - assert.Contains(t, err.Error(), "failed to get attestation") - }) -} - -// TestNewProvider tests the factory function. -func TestNewProvider(t *testing.T) { - mockClient := new(mockAttestationClient) - - t.Run("SelfSignedProvider", func(t *testing.T) { - provider, err := NewProvider(mockClient, attestation.SNPvTPM, "", "", nil) - assert.NoError(t, err) - assert.NotNil(t, provider) - }) - - t.Run("CASignedProviderWithSDK", func(t *testing.T) { - mockSDK := sdkmocks.NewSDK(t) - - provider, err := NewProvider(mockClient, attestation.SNPvTPM, "test-token", "test-cvm-id", mockSDK) - assert.NoError(t, err) - assert.NotNil(t, provider) - }) - - t.Run("SelfSignedProviderNilSDK", func(t *testing.T) { - provider, err := NewProvider(mockClient, attestation.SNPvTPM, "test-token", "test-cvm-id", nil) - assert.NoError(t, err) - assert.NotNil(t, provider) - }) - - t.Run("InvalidPlatformType", func(t *testing.T) { - _, err := NewProvider(mockClient, attestation.PlatformType(999), "", "", nil) - assert.Error(t, err) - }) -} - -// TestCertificateVerifier tests certificate verification. -func TestCertificateVerifier(t *testing.T) { - // Setup test policy - tempDir, err := os.MkdirTemp("", "policy") - require.NoError(t, err) - defer os.RemoveAll(tempDir) - - attestationPB := prepVerifyAttReport(t) - err = setAttestationPolicy(attestationPB, tempDir) - require.NoError(t, err) - - t.Run("NewCertificateVerifier", func(t *testing.T) { - rootCAs := x509.NewCertPool() - verifier := certificateVerifier{rootCAs: rootCAs} - - assert.Equal(t, rootCAs, verifier.rootCAs) - }) - - t.Run("VerifyPeerCertificateNoCertificates", func(t *testing.T) { - verifier := NewCertificateVerifier(nil) - err := verifier.VerifyPeerCertificate([][]byte{}, nil, []byte("nonce")) - - assert.Error(t, err) - assert.Contains(t, err.Error(), "no certificates provided") - }) - - t.Run("VerifyPeerCertificateInvalidCert", func(t *testing.T) { - verifier := NewCertificateVerifier(nil) - err := verifier.VerifyPeerCertificate([][]byte{[]byte("invalid")}, nil, []byte("nonce")) - - assert.Error(t, err) - assert.Contains(t, err.Error(), "failed to parse x509 certificate") - }) - - t.Run("VerifyPeerCertificateNoAttestationExtension", func(t *testing.T) { - cert := createSelfSignedCert(t) - verifier := NewCertificateVerifier(nil) - - err := verifier.VerifyPeerCertificate([][]byte{cert.Raw}, nil, []byte("nonce")) - - assert.Error(t, err) - assert.Contains(t, err.Error(), "attestation extension not found") - }) -} - -// TestExtractNonceFromSNI tests nonce extraction from SNI. -func TestExtractNonceFromSNI(t *testing.T) { - t.Run("ValidNonce", func(t *testing.T) { - nonce := make([]byte, 64) - _, err := rand.Read(nonce) - require.NoError(t, err) - - serverName := hex.EncodeToString(nonce) + ".nonce" - - extractedNonce, err := extractNonceFromSNI(serverName) - - assert.NoError(t, err) - assert.Equal(t, nonce, extractedNonce) - }) - - t.Run("InvalidServerName", func(t *testing.T) { - _, err := extractNonceFromSNI("invalid-server-name") - assert.Error(t, err) - }) - - t.Run("InvalidNonceLength", func(t *testing.T) { - shortNonce := make([]byte, 32) // Too short - serverName := hex.EncodeToString(shortNonce) + ".nonce" - - _, err := extractNonceFromSNI(serverName) - assert.Error(t, err) - assert.Contains(t, err.Error(), "invalid nonce length") - }) - - t.Run("InvalidHexEncoding", func(t *testing.T) { - serverName := "invalid-hex-encoding.nonce" - - _, err := extractNonceFromSNI(serverName) - assert.Error(t, err) - assert.Contains(t, err.Error(), "failed to decode nonce") - }) - - t.Run("MissingNonceSuffix", func(t *testing.T) { - nonce := make([]byte, 64) - _, err := rand.Read(nonce) - require.NoError(t, err) - - serverName := hex.EncodeToString(nonce) + ".invalid" - - _, err = extractNonceFromSNI(serverName) - assert.Error(t, err) - }) -} - -// TestHasNonceSuffix tests the nonce suffix checking. -func TestHasNonceSuffix(t *testing.T) { - t.Run("ValidSuffix", func(t *testing.T) { - assert.True(t, hasNonceSuffix("test.nonce")) - }) - - t.Run("InvalidSuffix", func(t *testing.T) { - assert.False(t, hasNonceSuffix("test.invalid")) - }) - - t.Run("TooShort", func(t *testing.T) { - assert.False(t, hasNonceSuffix(".non")) - }) - - t.Run("EmptyString", func(t *testing.T) { - assert.False(t, hasNonceSuffix("")) - }) -} - -// TestOIDFunctions tests OID-related functions. -func TestPlatformVerifier(t *testing.T) { - tempDir, err := os.MkdirTemp("", "policy") - require.NoError(t, err) - defer os.RemoveAll(tempDir) - - attestationPB := prepVerifyAttReport(t) - err = setAttestationPolicy(attestationPB, tempDir) - require.NoError(t, err) - - oldPath := attestation.AttestationPolicyPath - t.Cleanup(func() { - attestation.AttestationPolicyPath = oldPath - }) - - cases := []struct { - name string - platformType attestation.PlatformType - expectedError bool - }{ - {"SNPvTPM", attestation.SNPvTPM, false}, - {"Azure", attestation.Azure, false}, - {"TDX", attestation.TDX, false}, // Expected success with new verifier logic - {"Invalid", attestation.PlatformType(999), true}, - } - - for _, c := range cases { - t.Run(c.name, func(t *testing.T) { - verifier, err := platformVerifier(c.platformType) - - if c.expectedError { - assert.Error(t, err) - assert.Nil(t, verifier) - } else { - assert.NoError(t, err) - assert.NotNil(t, verifier) - } - }) - } -} - -func TestGetOID(t *testing.T) { - cases := []struct { - name string - platformType attestation.PlatformType - expectedOID asn1.ObjectIdentifier - expectedError bool - }{ - {"SNPvTPM", attestation.SNPvTPM, SNPvTPMOID, false}, - {"Azure", attestation.Azure, AzureOID, false}, - {"TDX", attestation.TDX, TDXOID, false}, - {"Invalid", attestation.PlatformType(999), nil, true}, - } - - for _, c := range cases { - t.Run(c.name, func(t *testing.T) { - oid, err := OID(c.platformType) - - if c.expectedError { - assert.Error(t, err) - assert.Nil(t, oid) - } else { - assert.NoError(t, err) - assert.Equal(t, c.expectedOID, oid) - } - }) - } -} - -func TestPlatformTypeFromOID(t *testing.T) { - cases := []struct { - name string - oid asn1.ObjectIdentifier - expectedType attestation.PlatformType - expectedError bool - }{ - {"SNPvTPM", SNPvTPMOID, attestation.SNPvTPM, false}, - {"Azure", AzureOID, attestation.Azure, false}, - {"TDX", TDXOID, attestation.TDX, false}, - {"Invalid", asn1.ObjectIdentifier{1, 2, 3}, attestation.PlatformType(0), true}, - } - - for _, c := range cases { - t.Run(c.name, func(t *testing.T) { - pType, err := platformTypeFromOID(c.oid) - - if c.expectedError { - assert.Error(t, err) - assert.Equal(t, attestation.PlatformType(0), pType) - } else { - assert.NoError(t, err) - assert.Equal(t, c.expectedType, pType) - } - }) - } -} - -// TestVerifyCertificateExtension tests certificate extension verification. -func TestVerifyCertificateExtension(t *testing.T) { - tempDir, err := os.MkdirTemp("", "policy") - require.NoError(t, err) - defer os.RemoveAll(tempDir) - - attestationPB := prepVerifyAttReport(t) - err = setAttestationPolicy(attestationPB, tempDir) - require.NoError(t, err) - - oldPath := attestation.AttestationPolicyPath - t.Cleanup(func() { - attestation.AttestationPolicyPath = oldPath - }) - - privateKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) - require.NoError(t, err) - - pubKeyDER, err := x509.MarshalPKIXPublicKey(&privateKey.PublicKey) - require.NoError(t, err) - - nonce := make([]byte, 64) - _, err = rand.Read(nonce) - require.NoError(t, err) - - teeNonce := append(pubKeyDER, nonce...) - hashNonce := sha3.Sum512(teeNonce) - - cases := []struct { - name string - extension []byte - pubKey []byte - nonce []byte - platformType attestation.PlatformType - expectError bool - }{ - { - name: "ValidExtensionSNPvTPM", - extension: hashNonce[:], - pubKey: pubKeyDER, - nonce: nonce, - platformType: attestation.SNPvTPM, - expectError: true, // Expected due to invalid attestation data - }, - { - name: "InvalidPlatformType", - extension: hashNonce[:], - pubKey: pubKeyDER, - nonce: nonce, - platformType: attestation.PlatformType(999), - expectError: true, - }, - { - name: "EmptyExtension", - extension: []byte{}, - pubKey: pubKeyDER, - nonce: nonce, - platformType: attestation.SNPvTPM, - expectError: true, - }, - { - name: "EmptyPublicKey", - extension: hashNonce[:], - pubKey: []byte{}, - nonce: nonce, - platformType: attestation.SNPvTPM, - expectError: true, - }, - { - name: "EmptyNonce", - extension: hashNonce[:], - pubKey: pubKeyDER, - nonce: []byte{}, - platformType: attestation.SNPvTPM, - expectError: true, - }, - } - - for _, c := range cases { - t.Run(c.name, func(t *testing.T) { - v := certificateVerifier{} - err := v.verifyCertificateExtension(c.extension, c.pubKey, c.nonce, c.platformType) - if c.expectError { - assert.Error(t, err) - } else { - assert.NoError(t, err) - } - }) - } -} - -// Helper functions - -func prepVerifyAttReport(t *testing.T) *sevsnp.Attestation { - // Return a dummy attestation report to avoid parsing issues with stale binary - return &sevsnp.Attestation{ - Report: &sevsnp.Report{ - FamilyId: make([]byte, 16), - ImageId: make([]byte, 16), - Measurement: make([]byte, 48), - HostData: make([]byte, 32), - ReportIdMa: make([]byte, 32), - Policy: 0, // Valid policy? Or ignore - }, - } -} - -func setAttestationPolicy(rr *sevsnp.Attestation, policyDirectory string) error { - // Create a dummy CoRIM - c := corim.NewUnsignedCorim() - c.SetID("cocos-test-id") - - corimBytes, err := c.ToCBOR() - if err != nil { - return err - } - - policyPath := filepath.Join(policyDirectory, "attestation_policy.json") - - err = os.WriteFile(policyPath, corimBytes, 0o644) - if err != nil { - return nil - } - - attestation.AttestationPolicyPath = policyPath - - return nil -} - -// TestCertificateVerification unified test suite for certificate verification. -func TestCertificateVerification(t *testing.T) { - // Setup common test data - selfSignedCert := createSelfSignedCert(t) - leafCert, rootCert := generateCertificateChain(t) - rootCAs := createCertPool(rootCert) - emptyPool := x509.NewCertPool() - - t.Run("SelfSignedCertificates", func(t *testing.T) { - testCases := []testCase{ - { - name: "ValidSelfSignedCertificate", - cert: selfSignedCert, - rootCAs: nil, - expectError: false, - }, - { - name: "EmptyCertificate", - cert: &x509.Certificate{}, - rootCAs: nil, - expectError: true, - errorMsg: "x509: missing ASN.1 contents; use ParseCertificate", - }, - } - - runCertificateVerificationTests(t, testCases) - }) - - t.Run("CertificateChainVerification", func(t *testing.T) { - testCases := []testCase{ - { - name: "ValidCertificateWithRootCA", - cert: leafCert, - rootCAs: rootCAs, - expectError: false, - }, - { - name: "SelfSignedCertificate", - cert: rootCert, - rootCAs: nil, // Self-signed verification - expectError: false, - }, - { - name: "InvalidCertificateWithEmptyPool", - cert: rootCert, - rootCAs: emptyPool, - expectError: true, - }, - } - - runCertificateVerificationTests(t, testCases) - }) - - t.Run("ATLSPeerCertificateVerification", func(t *testing.T) { - nonce := generateNonce(t) - - testCases := []atlsTestCase{ - { - name: "InvalidCertificateData", - rawCerts: [][]byte{[]byte("invalid cert data")}, - nonce: nonce, - rootCAs: rootCAs, - expectError: true, - errorMsg: "failed to parse x509 certificate", - }, - { - name: "ValidCertificateNoAttestationExtension", - rawCerts: [][]byte{leafCert.Raw}, - nonce: nonce, - rootCAs: rootCAs, - expectError: true, - errorMsg: "attestation extension not found in certificate", - }, - } - - runATLSVerificationTests(t, testCases) - }) -} - -// TestAttestedCAProvider tests the CA-signed certificate provider. -func TestAttestedCAProvider(t *testing.T) { - mockClient := new(mockAttestationClient) - attestationProvider, err := NewAttestationProvider(mockClient, attestation.SNPvTPM) - require.NoError(t, err) - - subject := DefaultCertificateSubject() - cvmID := "test-cvm-id" - agentToken := "test-token" - - t.Run("NewAttestedCAProvider", func(t *testing.T) { - provider := NewAttestedCAProvider(attestationProvider, subject, nil, cvmID, agentToken) - assert.NotNil(t, provider) - }) - - t.Run("SetTTL", func(t *testing.T) { - provider := NewAttestedCAProvider(attestationProvider, subject, nil, cvmID, agentToken) - - newTTL := time.Hour * 48 - provider.(*attestedCertificateProvider).SetTTL(newTTL) - - attestedProvider := provider.(*attestedCertificateProvider) - assert.Equal(t, newTTL, attestedProvider.ttl) - }) -} - -// TestCASignedCertificateErrors tests error cases in CA-signed certificate generation. -func TestCASignedCertificateErrors(t *testing.T) { - mockClient := new(mockAttestationClient) - attestationProvider, err := NewAttestationProvider(mockClient, attestation.SNPvTPM) - require.NoError(t, err) - - subject := DefaultCertificateSubject() - cvmID := "test-cvm-id" - agentToken := "test-token" - - cases := []struct { - name string - certificate string - sdkError error - expectedError string - }{ - {"SDKIssueError", "", errors.NewSDKError(errors.New("SDK error")), "SDK error"}, - {"InvalidPEMWithRemainingData", "-----BEGIN CERTIFICATE-----\\nVGVzdA==\\n-----END CERTIFICATE-----\\nExtra data here", nil, "unexpected remaining data"}, - {"NoPEMBlockFound", "", nil, "no PEM block found"}, - } - - for _, c := range cases { - t.Run(c.name, func(t *testing.T) { - mockSDK := sdkmocks.NewSDK(t) - expectedCSR := certs.CSR{CSR: []byte("test-csr")} - mockSDK.On("CreateCSR", mock.Anything, mock.Anything, mock.Anything).Return(expectedCSR, errors.SDKError(nil)) - mockSDK.On("IssueFromCSRInternal", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(certssdk.Certificate{Certificate: c.certificate}, c.sdkError) - - provider := NewAttestedCAProvider(attestationProvider, subject, mockSDK, cvmID, agentToken) - attestedProvider := provider.(*attestedCertificateProvider) - - privateKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) - require.NoError(t, err) - - extension := pkix.Extension{ - Id: SNPvTPMOID, - Value: []byte("test-data"), - } - - _, err = attestedProvider.generateCASignedCertificate(t.Context(), privateKey, extension) - assert.Error(t, err) - assert.Contains(t, err.Error(), c.expectedError) - }) - } -} - -// TestGetCertificateErrors tests error paths in certificate generation. -func TestGetCertificateErrors(t *testing.T) { - t.Run("InvalidServerNameFormat", func(t *testing.T) { - mockClient := new(mockAttestationClient) - attestationProvider, err := NewAttestationProvider(mockClient, attestation.SNPvTPM) - require.NoError(t, err) - - provider := NewAttestedProvider(attestationProvider, DefaultCertificateSubject()) - - clientHello := &tls.ClientHelloInfo{ - ServerName: "invalid-format", - } - - _, err = provider.GetCertificate(clientHello) - assert.Error(t, err) - assert.Contains(t, err.Error(), "failed to extract nonce") - }) - - t.Run("AttestationProviderError", func(t *testing.T) { - mockClient := new(mockAttestationClient) - mockClient.On("GetAttestation", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil, errors.New("attestation failed")) - - attestationProvider, err := NewAttestationProvider(mockClient, attestation.SNPvTPM) - require.NoError(t, err) - - provider := NewAttestedProvider(attestationProvider, DefaultCertificateSubject()) - - nonce := make([]byte, 64) - _, err = rand.Read(nonce) - require.NoError(t, err) - - serverName := hex.EncodeToString(nonce) + ".nonce" - clientHello := &tls.ClientHelloInfo{ServerName: serverName} - - _, err = provider.GetCertificate(clientHello) - assert.Error(t, err) - assert.Contains(t, err.Error(), "failed to get attestation") - }) - - t.Run("CASignedCertificateError", func(t *testing.T) { - mockClient := new(mockAttestationClient) - mockClient.On("GetAttestation", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return([]byte("test-attestation"), nil) - - attestationProvider, err := NewAttestationProvider(mockClient, attestation.SNPvTPM) - require.NoError(t, err) - - mockSDK := sdkmocks.NewSDK(t) - expectedCSR := certs.CSR{CSR: []byte("test-csr")} - sdkErr := errors.NewSDKError(errors.New("CA error")) - mockSDK.On("CreateCSR", mock.Anything, mock.Anything, mock.Anything).Return(expectedCSR, errors.SDKError(nil)) - mockSDK.On("IssueFromCSRInternal", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(certssdk.Certificate{}, sdkErr) - - provider := NewAttestedCAProvider(attestationProvider, DefaultCertificateSubject(), mockSDK, "test-cvm", "test-token") - - nonce := make([]byte, 64) - _, err = rand.Read(nonce) - require.NoError(t, err) - - serverName := hex.EncodeToString(nonce) + ".nonce" - clientHello := &tls.ClientHelloInfo{ServerName: serverName} - - _, err = provider.GetCertificate(clientHello) - assert.Error(t, err) - assert.Contains(t, err.Error(), "failed to generate certificate") - }) -} - -// TestCertificateVerificationEdgeCases tests edge cases in certificate verification. -func TestCertificateVerificationEdgeCases(t *testing.T) { - tempDir, err := os.MkdirTemp("", "policy") - require.NoError(t, err) - defer os.RemoveAll(tempDir) - - attestationPB := prepVerifyAttReport(t) - err = setAttestationPolicy(attestationPB, tempDir) - require.NoError(t, err) - - t.Run("VerifyPeerCertificateWithMultipleCerts", func(t *testing.T) { - verifier := NewCertificateVerifier(nil) - cert1 := createSelfSignedCert(t) - cert2 := createSelfSignedCert(t) - nonce := generateNonce(t) - - err := verifier.VerifyPeerCertificate([][]byte{cert1.Raw, cert2.Raw}, nil, nonce) - assert.Error(t, err) - assert.Contains(t, err.Error(), "attestation extension not found") - }) - - t.Run("VerifyAttestationExtensionWithNoExtensions", func(t *testing.T) { - cert := createSelfSignedCert(t) - verifier := certificateVerifier{} - nonce := generateNonce(t) - - err := verifier.verifyAttestationExtension(cert, nonce) - assert.Error(t, err) - assert.Contains(t, err.Error(), "attestation extension not found") - }) - - t.Run("VerifyAttestationExtensionWithWrongOID", func(t *testing.T) { - wrongOID := asn1.ObjectIdentifier{1, 2, 3, 4, 5} - extension := pkix.Extension{ - Id: wrongOID, - Value: []byte("test-data"), - } - - cert := generateTestCertificateWithExtensions(t, []pkix.Extension{extension}) - verifier := certificateVerifier{} - nonce := generateNonce(t) - - err := verifier.verifyAttestationExtension(cert, nonce) - assert.Error(t, err) - assert.Contains(t, err.Error(), "attestation extension not found") - }) - - t.Run("VerifyCertificateExtensionPlatformVerifierError", func(t *testing.T) { - verifier := certificateVerifier{} - invalidPlatformType := attestation.PlatformType(999) - - err := verifier.verifyCertificateExtension([]byte("test-extension"), []byte("test-pubkey"), []byte("test-nonce"), invalidPlatformType) - assert.Error(t, err) - // The error occurs during EAT token decoding before platform type validation - assert.Contains(t, err.Error(), "failed to decode EAT token") - }) -} - -// TestCertificateWithAttestationExtension tests certificates with attestation extensions. -func TestCertificateWithAttestationExtension(t *testing.T) { - tempDir, err := os.MkdirTemp("", "policy") - require.NoError(t, err) - defer os.RemoveAll(tempDir) - - attestationPB := prepVerifyAttReport(t) - err = setAttestationPolicy(attestationPB, tempDir) - require.NoError(t, err) - - t.Run("CertificateWithValidAttestationExtension", func(t *testing.T) { - // Create certificate with attestation extension - privateKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) - require.NoError(t, err) - - _, err = x509.MarshalPKIXPublicKey(&privateKey.PublicKey) - require.NoError(t, err) - - nonce := make([]byte, 64) - _, err = rand.Read(nonce) - require.NoError(t, err) - - extension := pkix.Extension{ - Id: SNPvTPMOID, - Value: []byte("test-attestation-data"), - } - - template := &x509.Certificate{ - SerialNumber: big.NewInt(1), - Subject: pkix.Name{ - Organization: []string{"Test Org"}, - }, - NotBefore: time.Now(), - NotAfter: time.Now().Add(24 * time.Hour), - KeyUsage: x509.KeyUsageDigitalSignature, - ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, - ExtraExtensions: []pkix.Extension{extension}, - } - - certDER, err := x509.CreateCertificate(rand.Reader, template, template, &privateKey.PublicKey, privateKey) - require.NoError(t, err) - - cert, err := x509.ParseCertificate(certDER) - require.NoError(t, err) - - verifier := certificateVerifier{} - err = verifier.verifyAttestationExtension(cert, nonce) - - // Expect error due to invalid attestation data, but extension should be found - assert.Error(t, err) - assert.NotContains(t, err.Error(), "attestation extension not found") - }) -} - -// TestIntegrationScenarios tests end-to-end integration scenarios. -func TestIntegrationScenarios(t *testing.T) { - tempDir, err := os.MkdirTemp("", "policy") - require.NoError(t, err) - defer os.RemoveAll(tempDir) - - attestationPB := prepVerifyAttReport(t) - err = setAttestationPolicy(attestationPB, tempDir) - require.NoError(t, err) - - t.Run("FullSelfSignedFlow", func(t *testing.T) { - // Setup mock client - mockClient := new(mockAttestationClient) - mockClient.On("GetAttestation", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return([]byte("mock-attestation"), nil) - - // Create provider - provider, err := NewProvider(mockClient, attestation.SNPvTPM, "", "", nil) - require.NoError(t, err) - - // Generate certificate - nonce := make([]byte, 64) - _, err = rand.Read(nonce) - require.NoError(t, err) - - serverName := hex.EncodeToString(nonce) + ".nonce" - clientHello := &tls.ClientHelloInfo{ServerName: serverName} - - cert, err := provider.GetCertificate(clientHello) - assert.NoError(t, err) - assert.NotNil(t, cert) - assert.NotEmpty(t, cert.Certificate) - assert.NotNil(t, cert.PrivateKey) - - // Verify the generated certificate - parsedCert, err := x509.ParseCertificate(cert.Certificate[0]) - require.NoError(t, err) - - // Check for attestation extension - found := false - for _, ext := range parsedCert.Extensions { - if ext.Id.Equal(SNPvTPMOID) { - found = true - break - } - } - assert.True(t, found, "Attestation extension should be present") - }) - - t.Run("FullCASignedFlow", func(t *testing.T) { - mockSDK := sdkmocks.NewSDK(t) - expectedCSR := certs.CSR{CSR: []byte("test-csr")} - expectedCert := certssdk.Certificate{Certificate: generateTestCertPEM(t)} - mockSDK.On("CreateCSR", mock.Anything, mock.Anything, mock.Anything).Return(expectedCSR, errors.SDKError(nil)) - mockSDK.On("IssueFromCSRInternal", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(expectedCert, errors.SDKError(nil)) - - mockClient := new(mockAttestationClient) - mockClient.On("GetAttestation", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return([]byte("mock-attestation"), nil) - - provider, err := NewProvider(mockClient, attestation.SNPvTPM, "test-token", "test-cvm-id", mockSDK) - require.NoError(t, err) - - nonce := make([]byte, 64) - _, err = rand.Read(nonce) - require.NoError(t, err) - - serverName := hex.EncodeToString(nonce) + ".nonce" - clientHello := &tls.ClientHelloInfo{ServerName: serverName} - - cert, err := provider.GetCertificate(clientHello) - require.NoError(t, err) - require.NotNil(t, cert) - require.NotEmpty(t, cert.Certificate) - require.NotNil(t, cert.PrivateKey) - - parsedCert, err := x509.ParseCertificate(cert.Certificate[0]) - require.NoError(t, err) - - assert.NotNil(t, parsedCert.Subject) - - mockClient.AssertExpectations(t) - mockSDK.AssertExpectations(t) - }) -} - -// TestConcurrentAccess tests concurrent access scenarios. -func TestConcurrentAccess(t *testing.T) { - mockClient := new(mockAttestationClient) - mockClient.On("GetAttestation", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return([]byte("mock-attestation"), nil) - - provider, err := NewProvider(mockClient, attestation.SNPvTPM, "", "", nil) - require.NoError(t, err) - - const numGoroutines = 10 - errors := make(chan error, numGoroutines) - - for i := 0; i < numGoroutines; i++ { - go func(id int) { - nonce := make([]byte, 64) - _, err := rand.Read(nonce) - if err != nil { - errors <- err - return - } - - serverName := hex.EncodeToString(nonce) + ".nonce" - clientHello := &tls.ClientHelloInfo{ServerName: serverName} - - cert, err := provider.GetCertificate(clientHello) - if err != nil { - errors <- err - return - } - - if cert == nil { - errors <- fmt.Errorf("nil certificate returned for goroutine %d", id) - return - } - - errors <- nil - }(i) - } - - // Collect results - for i := 0; i < numGoroutines; i++ { - err := <-errors - assert.NoError(t, err) - } -} - -// TestEdgeCasesAndBoundaries tests edge cases and boundary conditions. -func TestEdgeCasesAndBoundaries(t *testing.T) { - t.Run("LargeNonce", func(t *testing.T) { - largeNonce := make([]byte, 1024) // Much larger than expected - _, err := rand.Read(largeNonce) - require.NoError(t, err) - - serverName := hex.EncodeToString(largeNonce) + ".nonce" - _, err = extractNonceFromSNI(serverName) - assert.Error(t, err) // Should fail due to invalid length - }) - - t.Run("MaxLengthServerName", func(t *testing.T) { - // Create very long server name - nonce := make([]byte, 64) - _, err := rand.Read(nonce) - require.NoError(t, err) - - longPrefix := strings.Repeat("a", 200) - serverName := longPrefix + hex.EncodeToString(nonce) + ".nonce" - - _, err = extractNonceFromSNI(serverName) - assert.Error(t, err) // Should fail due to invalid format - }) - - t.Run("MinimalValidNonce", func(t *testing.T) { - nonce := make([]byte, 64) // Exactly the required length - _, err := rand.Read(nonce) - require.NoError(t, err) - - serverName := hex.EncodeToString(nonce) + ".nonce" - extractedNonce, err := extractNonceFromSNI(serverName) - - assert.NoError(t, err) - assert.Equal(t, nonce, extractedNonce) - }) -} - -// Unified test case structures. -type testCase struct { - name string - cert *x509.Certificate - rootCAs *x509.CertPool - expectError bool - errorMsg string -} - -type atlsTestCase struct { - name string - rawCerts [][]byte - nonce []byte - rootCAs *x509.CertPool - expectError bool - errorMsg string -} - -// Unified test runners. -func runCertificateVerificationTests(t *testing.T, testCases []testCase) { - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - v := certificateVerifier{ - rootCAs: tc.rootCAs, - } - err := v.verifyCertificateSignature(tc.cert) - - if tc.expectError { - assert.Error(t, err) - if tc.errorMsg != "" { - if tc.errorMsg == "x509: missing ASN.1 contents; use ParseCertificate" { - // For specific error matching - assert.True(t, errors.Contains(err, errors.New(tc.errorMsg)), - fmt.Sprintf("expected error %q, got %v", tc.errorMsg, err)) - } else { - assert.Contains(t, err.Error(), tc.errorMsg) - } - } - } else { - assert.NoError(t, err) - } - }) - } -} - -func runATLSVerificationTests(t *testing.T, testCases []atlsTestCase) { - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - v := certificateVerifier{ - rootCAs: tc.rootCAs, - } - err := v.VerifyPeerCertificate(tc.rawCerts, nil, tc.nonce) - - if tc.expectError { - assert.Error(t, err) - if tc.errorMsg != "" { - assert.Contains(t, err.Error(), tc.errorMsg) - } - } else { - assert.NoError(t, err) - } - }) - } -} - -// Unified certificate creation utilities. -func createSelfSignedCert(t *testing.T) *x509.Certificate { - privateKey := generateRSAKey(t) - - template := x509.Certificate{ - SerialNumber: big.NewInt(1), - Subject: pkix.Name{ - Organization: []string{"Test Org"}, - }, - NotBefore: time.Now(), - NotAfter: time.Now().Add(24 * time.Hour), - KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature, - ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, - BasicConstraintsValid: true, - } - - return createCertificateFromTemplate(t, &template, &template, &privateKey.PublicKey, privateKey) -} - -func generateCertificateChain(t *testing.T) (leafCert, rootCert *x509.Certificate) { - // Generate root certificate - rootKey := generateRSAKey(t) - rootTemplate := x509.Certificate{ - SerialNumber: big.NewInt(1), - Subject: pkix.Name{ - Organization: []string{"Test Root CA"}, - Country: []string{"US"}, - }, - NotBefore: time.Now(), - NotAfter: time.Now().Add(24 * time.Hour), - KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign, - ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, - BasicConstraintsValid: true, - IsCA: true, - } - - rootCert = createCertificateFromTemplate(t, &rootTemplate, &rootTemplate, &rootKey.PublicKey, rootKey) - - // Generate leaf certificate signed by root - leafKey := generateRSAKey(t) - leafTemplate := x509.Certificate{ - SerialNumber: big.NewInt(2), - Subject: pkix.Name{ - Organization: []string{"Test Leaf"}, - Country: []string{"US"}, - }, - NotBefore: time.Now(), - NotAfter: time.Now().Add(24 * time.Hour), - KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature, - ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, - } - - leafCert = createCertificateFromTemplate(t, &leafTemplate, &rootTemplate, &leafKey.PublicKey, rootKey) - - return leafCert, rootCert -} - -// Helper functions for consistency. -func generateRSAKey(t *testing.T) *rsa.PrivateKey { - privateKey, err := rsa.GenerateKey(rand.Reader, 2048) - require.NoError(t, err) - return privateKey -} - -func createCertificateFromTemplate(t *testing.T, template, parent *x509.Certificate, pub interface{}, priv interface{}) *x509.Certificate { - certDER, err := x509.CreateCertificate(rand.Reader, template, parent, pub, priv) - require.NoError(t, err) - - cert, err := x509.ParseCertificate(certDER) - require.NoError(t, err) - - return cert -} - -func createCertPool(certs ...*x509.Certificate) *x509.CertPool { - pool := x509.NewCertPool() - for _, cert := range certs { - pool.AddCert(cert) - } - return pool -} - -func generateNonce(t *testing.T) []byte { - nonce := make([]byte, 64) - _, err := rand.Read(nonce) - require.NoError(t, err) - return nonce -} diff --git a/pkg/atls/attestation_provider.go b/pkg/atls/attestation_provider.go deleted file mode 100644 index 9a3fbbba..00000000 --- a/pkg/atls/attestation_provider.go +++ /dev/null @@ -1,82 +0,0 @@ -// Copyright (c) Ultraviolet -// SPDX-License-Identifier: Apache-2.0 -package atls - -import ( - "context" - "encoding/asn1" - "fmt" - - "github.com/ultravioletrs/cocos/pkg/attestation" - attestation_client "github.com/ultravioletrs/cocos/pkg/clients/grpc/attestation" - "golang.org/x/crypto/sha3" -) - -// AttestationProvider defines the interface for platform attestation operations. -type AttestationProvider interface { - Attest(pubKey []byte, nonce []byte) ([]byte, error) - OID() asn1.ObjectIdentifier - PlatformType() attestation.PlatformType -} - -// PlatformAttestationProvider handles platform attestation operations. -type platformAttestationProvider struct { - attClient attestation_client.Client - oid asn1.ObjectIdentifier - platformType attestation.PlatformType -} - -// NewAttestationProvider creates a new attestation provider for the given platform type. -func NewAttestationProvider(attClient attestation_client.Client, platformType attestation.PlatformType) (AttestationProvider, error) { - oid, err := OID(platformType) - if err != nil { - return nil, fmt.Errorf("failed to get OID: %w", err) - } - - return &platformAttestationProvider{ - attClient: attClient, - oid: oid, - platformType: platformType, - }, nil -} - -func (p *platformAttestationProvider) Attest(pubKey []byte, nonce []byte) ([]byte, error) { - teeNonce := append(pubKey, nonce...) - hashNonce := sha3.Sum512(teeNonce) - - var reportData [64]byte - copy(reportData[:], hashNonce[:]) - - var nonceArray [32]byte - copy(nonceArray[:], hashNonce[:32]) - - // Get signed EAT token from attestation service - // The attestation service maintains a persistent signing key and returns a pre-signed token - eatToken, err := p.attClient.GetAttestation(context.Background(), reportData, nonceArray, p.platformType) - if err != nil { - return nil, fmt.Errorf("failed to get attestation from service: %w", err) - } - - return eatToken, nil -} - -func (p *platformAttestationProvider) OID() asn1.ObjectIdentifier { - return p.oid -} - -func (p *platformAttestationProvider) PlatformType() attestation.PlatformType { - return p.platformType -} - -func OID(platformType attestation.PlatformType) (asn1.ObjectIdentifier, error) { - switch platformType { - case attestation.SNPvTPM: - return SNPvTPMOID, nil - case attestation.Azure: - return AzureOID, nil - case attestation.TDX: - return TDXOID, nil - default: - return nil, fmt.Errorf("unsupported platform type: %d", platformType) - } -} diff --git a/pkg/atls/certificate_provider.go b/pkg/atls/certificate_provider.go deleted file mode 100644 index 75e9ccfb..00000000 --- a/pkg/atls/certificate_provider.go +++ /dev/null @@ -1,190 +0,0 @@ -// Copyright (c) Ultraviolet -// SPDX-License-Identifier: Apache-2.0 -package atls - -import ( - "context" - "crypto/ecdsa" - "crypto/elliptic" - "crypto/rand" - "crypto/tls" - "crypto/x509" - "crypto/x509/pkix" - "encoding/pem" - "fmt" - "math/big" - "strings" - "time" - - "github.com/absmach/certs" - sdk "github.com/absmach/certs/sdk" - "github.com/ultravioletrs/cocos/pkg/attestation" - attestation_client "github.com/ultravioletrs/cocos/pkg/clients/grpc/attestation" -) - -// CertificateProvider defines the interface for providing TLS certificates. -type CertificateProvider interface { - GetCertificate(clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) -} - -// AttestedCertificateProvider provides attested TLS certificates. -type attestedCertificateProvider struct { - attestationProvider AttestationProvider - certsSDK sdk.SDK - agentToken string - subject CertificateSubject - useCA bool - cvmID string - ttl time.Duration - notAfterYears int -} - -// NewAttestedProvider creates a new attested certificate provider for self-signed certificates. -func NewAttestedProvider( - attestationProvider AttestationProvider, - subject CertificateSubject, -) CertificateProvider { - return &attestedCertificateProvider{ - attestationProvider: attestationProvider, - subject: subject, - useCA: false, - notAfterYears: defaultNotAfterYears, - } -} - -// NewAttestedCAProvider creates a new attested certificate provider for CA-signed certificates. -func NewAttestedCAProvider( - attestationProvider AttestationProvider, - subject CertificateSubject, - certsSDK sdk.SDK, cvmID, agentToken string, -) CertificateProvider { - return &attestedCertificateProvider{ - attestationProvider: attestationProvider, - subject: subject, - certsSDK: certsSDK, - agentToken: agentToken, - useCA: true, - cvmID: cvmID, - ttl: time.Hour * 24 * 365, // Default 1 year - } -} - -// SetTTL sets the certificate TTL for CA-signed certificates. -func (p *attestedCertificateProvider) SetTTL(ttl time.Duration) { - p.ttl = ttl -} - -func (p *attestedCertificateProvider) GetCertificate(clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) { - privateKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) - if err != nil { - return nil, fmt.Errorf("failed to generate private key: %w", err) - } - - pubKeyDER, err := x509.MarshalPKIXPublicKey(&privateKey.PublicKey) - if err != nil { - return nil, fmt.Errorf("failed to marshal public key: %w", err) - } - - nonce, err := extractNonceFromSNI(clientHello.ServerName) - if err != nil { - return nil, fmt.Errorf("failed to extract nonce: %w", err) - } - - attestationData, err := p.attestationProvider.Attest(pubKeyDER, nonce) - if err != nil { - return nil, fmt.Errorf("failed to get attestation: %w", err) - } - - extension := pkix.Extension{ - Id: p.attestationProvider.OID(), - Value: attestationData, - } - - var certDERBytes []byte - if p.useCA { - certDERBytes, err = p.generateCASignedCertificate(clientHello.Context(), privateKey, extension) - } else { - certDERBytes, err = p.generateSelfSignedCertificate(privateKey, extension) - } - - if err != nil { - return nil, fmt.Errorf("failed to generate certificate: %w", err) - } - - return &tls.Certificate{ - Certificate: [][]byte{certDERBytes}, - PrivateKey: privateKey, - }, nil -} - -func (p *attestedCertificateProvider) generateSelfSignedCertificate(privateKey *ecdsa.PrivateKey, extension pkix.Extension) ([]byte, error) { - certTemplate := &x509.Certificate{ - SerialNumber: big.NewInt(time.Now().Unix()), - Subject: pkix.Name{ - Organization: []string{p.subject.Organization}, - Country: []string{p.subject.Country}, - Province: []string{p.subject.Province}, - Locality: []string{p.subject.Locality}, - StreetAddress: []string{p.subject.StreetAddress}, - PostalCode: []string{p.subject.PostalCode}, - }, - NotBefore: time.Now(), - NotAfter: time.Now().AddDate(p.notAfterYears, 0, 0), - KeyUsage: x509.KeyUsageDigitalSignature, - ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, - BasicConstraintsValid: true, - ExtraExtensions: []pkix.Extension{extension}, - } - - return x509.CreateCertificate(rand.Reader, certTemplate, certTemplate, &privateKey.PublicKey, privateKey) -} - -func (p *attestedCertificateProvider) generateCASignedCertificate(ctx context.Context, privateKey *ecdsa.PrivateKey, extension pkix.Extension) ([]byte, error) { - csrMetadata := certs.CSRMetadata{ - Organization: []string{p.subject.Organization}, - Country: []string{p.subject.Country}, - CommonName: p.subject.CommonName, - Province: []string{p.subject.Province}, - Locality: []string{p.subject.Locality}, - StreetAddress: []string{p.subject.StreetAddress}, - PostalCode: []string{p.subject.PostalCode}, - ExtraExtensions: []pkix.Extension{extension}, - } - - csr, sdkerr := p.certsSDK.CreateCSR(ctx, csrMetadata, privateKey) - if sdkerr != nil { - return nil, fmt.Errorf("failed to create CSR: %w", sdkerr) - } - - cert, err := p.certsSDK.IssueFromCSRInternal(ctx, p.cvmID, p.ttl.String(), string(csr.CSR), p.agentToken) - if err != nil { - return nil, err - } - - cleanCertificateString := strings.ReplaceAll(cert.Certificate, "\\n", "\n") - block, rest := pem.Decode([]byte(cleanCertificateString)) - - if len(rest) != 0 { - return nil, fmt.Errorf("failed to decode certificate PEM: unexpected remaining data") - } - if block == nil { - return nil, fmt.Errorf("failed to decode certificate PEM: no PEM block found") - } - - return block.Bytes, nil -} - -func NewProvider(attClient attestation_client.Client, platformType attestation.PlatformType, agentToken, cvmID string, certsSDK sdk.SDK) (CertificateProvider, error) { - attestationProvider, err := NewAttestationProvider(attClient, platformType) - if err != nil { - return nil, fmt.Errorf("failed to create attestation provider: %w", err) - } - - subject := DefaultCertificateSubject() - - if certsSDK != nil { - return NewAttestedCAProvider(attestationProvider, subject, certsSDK, cvmID, agentToken), nil - } - - return NewAttestedProvider(attestationProvider, subject), nil -} diff --git a/pkg/atls/certificate_verifier.go b/pkg/atls/certificate_verifier.go deleted file mode 100644 index 49928c4b..00000000 --- a/pkg/atls/certificate_verifier.go +++ /dev/null @@ -1,210 +0,0 @@ -// Copyright (c) Ultraviolet -// SPDX-License-Identifier: Apache-2.0 -package atls - -import ( - "crypto/x509" - "encoding/asn1" - "fmt" - "log/slog" - "os" - "time" - - "github.com/ultravioletrs/cocos/pkg/attestation" - "github.com/ultravioletrs/cocos/pkg/attestation/azure" - "github.com/ultravioletrs/cocos/pkg/attestation/eat" - "github.com/ultravioletrs/cocos/pkg/attestation/tdx" - "github.com/ultravioletrs/cocos/pkg/attestation/vtpm" - "github.com/veraison/corim/corim" - "golang.org/x/crypto/sha3" -) - -type CertificateVerifier interface { - VerifyPeerCertificate(rawCerts [][]byte, verifiedChains [][]*x509.Certificate, nonce []byte) error -} - -// CertificateVerifier handles certificate verification operations. -type certificateVerifier struct { - rootCAs *x509.CertPool - verifierProvider func(attestation.PlatformType) (attestation.Verifier, error) -} - -func NewCertificateVerifier(rootCAs *x509.CertPool) CertificateVerifier { - return &certificateVerifier{ - rootCAs: rootCAs, - verifierProvider: platformVerifier, - } -} - -func (v *certificateVerifier) VerifyPeerCertificate(rawCerts [][]byte, _ [][]*x509.Certificate, nonce []byte) error { - slog.Debug("Starting peer certificate verification for aTLS") - if len(rawCerts) == 0 { - err := fmt.Errorf("no certificates provided") - slog.Error("aTLS handshake failed", "reason", err.Error()) - return err - } - - cert, err := x509.ParseCertificate(rawCerts[0]) - if err != nil { - err = fmt.Errorf("failed to parse x509 certificate: %w", err) - slog.Error("aTLS handshake failed", "reason", err.Error()) - return err - } - slog.Debug("Successfully parsed peer x509 certificate", "subject", cert.Subject.String()) - - if err := v.verifyCertificateSignature(cert); err != nil { - err = fmt.Errorf("certificate signature verification failed: %w", err) - slog.Error("aTLS handshake failed", "reason", err.Error()) - return err - } - slog.Debug("Successfully verified peer certificate signature") - - err = v.verifyAttestationExtension(cert, nonce) - if err != nil { - slog.Error("aTLS handshake failed", "reason", err.Error()) - return err - } - slog.Debug("Successfully verified aTLS attestation extension") - return nil -} - -func (v *certificateVerifier) verifyCertificateSignature(cert *x509.Certificate) error { - rootCAs := v.rootCAs - if rootCAs == nil { - rootCAs = x509.NewCertPool() - rootCAs.AddCert(cert) - } - - opts := x509.VerifyOptions{ - Roots: rootCAs, - CurrentTime: time.Now(), - } - - _, err := cert.Verify(opts) - return err -} - -func (v *certificateVerifier) verifyAttestationExtension(cert *x509.Certificate, nonce []byte) error { - for _, ext := range cert.Extensions { - if platformType, err := platformTypeFromOID(ext.Id); err == nil { - slog.Debug("Found attestation extension in peer certificate", "platform_type", platformType) - pubKeyDER, err := x509.MarshalPKIXPublicKey(cert.PublicKey) - if err != nil { - return fmt.Errorf("failed to marshal public key: %w", err) - } - return v.verifyCertificateExtension(ext.Value, pubKeyDER, nonce, platformType) - } - } - return fmt.Errorf("attestation extension not found in certificate") -} - -func (v *certificateVerifier) verifyCertificateExtension(extension []byte, pubKey []byte, nonce []byte, platformType attestation.PlatformType) error { - // Decode EAT token from certificate extension - // Note: We don't have the public key for verification here, so we decode without verification - // The signature was created by the attester, and we trust the TEE hardware verification - claims, err := eat.DecodeCBOR(extension, nil) - if err != nil { - return fmt.Errorf("failed to decode EAT token: %w", err) - } - - // Verify nonce matches - teeNonce := append(pubKey, nonce...) - hashNonce := sha3.Sum512(teeNonce) - // The attestation provider truncates the 64-byte hash to 32 bytes for the EAT token nonce claim - // This matches the Attestation Service API and standard cryptographic nonce sizes. - expectedNonce := hashNonce[:32] - - // Compare nonces (EAT nonce should match our computed nonce) - if len(claims.Nonce) != len(expectedNonce) { - err := fmt.Errorf("nonce length mismatch: expected %d, got %d", len(expectedNonce), len(claims.Nonce)) - slog.Error("aTLS handshake failed", "reason", err.Error()) - return err - } - - nonceMatch := true - for i := range claims.Nonce { - if claims.Nonce[i] != expectedNonce[i] { - nonceMatch = false - break - } - } - - if !nonceMatch { - err := fmt.Errorf("nonce mismatch in EAT token") - slog.Error("aTLS handshake failed", "reason", err.Error()) - return err - } - - // Get platform verifier - verifier, err := v.verifierProvider(platformType) - if err != nil { - return fmt.Errorf("failed to get platform verifier: %w", err) - } - - // Load and parse CoRIM - if attestation.AttestationPolicyPath == "" { - return fmt.Errorf("attestation policy path is not set") - } - - corimBytes, err := os.ReadFile(attestation.AttestationPolicyPath) - if err != nil { - return fmt.Errorf("failed to read CoRIM file: %w", err) - } - - // Try extracting from COSE Sign1 first - var unsignedCorim *corim.UnsignedCorim - - var sc corim.SignedCorim - if err := sc.FromCOSE(corimBytes); err == nil { - // It's a COSE Sign1 message - unsignedCorim = &sc.UnsignedCorim - } else { - // Try parsing as unsigned CoRIM directly - var uc corim.UnsignedCorim - if err := uc.FromCBOR(corimBytes); err != nil { - return fmt.Errorf("failed to parse CoRIM (tried both signed and unsigned): %w", err) - } - unsignedCorim = &uc - } - - // Re-wrap in Corim struct expected by Verifiers - // Since verifiers expect the struct from the removed internal package, - // we need to update verifiers to accept veraison/corim types - // For now, we pass the unsignedCorim directly - if err = verifier.VerifyWithCoRIM(claims.RawReport, unsignedCorim); err != nil { - return fmt.Errorf("failed to verify attestation with CoRIM: %w", err) - } - - slog.Debug("CoRIM verification passed for aTLS peer certificate") - return nil -} - -func platformTypeFromOID(oid asn1.ObjectIdentifier) (attestation.PlatformType, error) { - switch { - case oid.Equal(SNPvTPMOID): - return attestation.SNPvTPM, nil - case oid.Equal(AzureOID): - return attestation.Azure, nil - case oid.Equal(TDXOID): - return attestation.TDX, nil - default: - return 0, fmt.Errorf("unsupported OID: %v", oid) - } -} - -func platformVerifier(platformType attestation.PlatformType) (attestation.Verifier, error) { - var verifier attestation.Verifier - - switch platformType { - case attestation.SNPvTPM: - verifier = vtpm.NewVerifier(nil) - case attestation.Azure: - verifier = azure.NewVerifier(nil) - case attestation.TDX: - verifier = tdx.NewVerifier() - default: - return nil, fmt.Errorf("unsupported platform type: %d", platformType) - } - - return verifier, nil -} diff --git a/pkg/atls/certificate_verifier_test.go b/pkg/atls/certificate_verifier_test.go deleted file mode 100644 index cdf7ed4c..00000000 --- a/pkg/atls/certificate_verifier_test.go +++ /dev/null @@ -1,338 +0,0 @@ -// Copyright (c) Ultraviolet -// SPDX-License-Identifier: Apache-2.0 - -package atls - -import ( - "crypto/ecdsa" - "crypto/elliptic" - "crypto/rand" - "crypto/x509" - "crypto/x509/pkix" - "math/big" - "os" - "path/filepath" - "testing" - "time" - - "github.com/fxamacker/cbor/v2" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - "github.com/ultravioletrs/cocos/pkg/attestation" - "github.com/ultravioletrs/cocos/pkg/attestation/eat" - "github.com/veraison/corim/corim" - "golang.org/x/crypto/sha3" -) - -type mockVerifier struct { - verifyWithCoRIMFunc func(report []byte, manifest *corim.UnsignedCorim) error -} - -func (m *mockVerifier) VerifyWithCoRIM(report []byte, manifest *corim.UnsignedCorim) error { - if m.verifyWithCoRIMFunc != nil { - return m.verifyWithCoRIMFunc(report, manifest) - } - return nil -} - -func TestVerifyPeerCertificate_Success(t *testing.T) { - // Setup keys and cert templates - caKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) - require.NoError(t, err) - - caTemplate := &x509.Certificate{ - SerialNumber: big.NewInt(1), - Subject: pkix.Name{CommonName: "Test CA"}, - NotBefore: time.Now().Add(-1 * time.Hour), - NotAfter: time.Now().Add(1 * time.Hour), - KeyUsage: x509.KeyUsageCertSign | x509.KeyUsageCRLSign, - BasicConstraintsValid: true, - IsCA: true, - } - caCertDER, err := x509.CreateCertificate(rand.Reader, caTemplate, caTemplate, &caKey.PublicKey, caKey) - require.NoError(t, err) - caCert, err := x509.ParseCertificate(caCertDER) - require.NoError(t, err) - - rootCAs := x509.NewCertPool() - rootCAs.AddCert(caCert) - - // Create verifier with mock platform verifier - verifier := NewCertificateVerifier(rootCAs).(*certificateVerifier) - verifier.verifierProvider = func(pt attestation.PlatformType) (attestation.Verifier, error) { - return &mockVerifier{ - verifyWithCoRIMFunc: func(report []byte, manifest *corim.UnsignedCorim) error { - return nil - }, - }, nil - } - - peerKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) - require.NoError(t, err) - - // Prepare EAT Claims - nonce := []byte("test-nonce") - peerPubKeyDER, err := x509.MarshalPKIXPublicKey(&peerKey.PublicKey) - require.NoError(t, err) - - teeNonce := append(peerPubKeyDER, nonce...) - hashNonce := sha3.Sum512(teeNonce) - - claims := eat.EATClaims{ - Nonce: hashNonce[:32], - RawReport: []byte("mock-report"), - } - eatBytes, err := cbor.Marshal(claims) - require.NoError(t, err) - - // Create Peer Cert with EAT extension - peerTemplate := &x509.Certificate{ - SerialNumber: big.NewInt(2), - Subject: pkix.Name{CommonName: "Test Peer"}, - NotBefore: time.Now().Add(-1 * time.Hour), - NotAfter: time.Now().Add(1 * time.Hour), - KeyUsage: x509.KeyUsageDigitalSignature, - ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, - ExtraExtensions: []pkix.Extension{ - { - Id: SNPvTPMOID, // Use SNPvTPMOID as default testing OID - Value: eatBytes, - }, - }, - } - peerCertDER, err := x509.CreateCertificate(rand.Reader, peerTemplate, caCert, &peerKey.PublicKey, caKey) - require.NoError(t, err) - - // Create dummy CoRIM file - tempDir, err := os.MkdirTemp("", "policy") - require.NoError(t, err) - defer os.RemoveAll(tempDir) - - c := corim.NewUnsignedCorim() - c.SetID("cocos-test-id") - corimBytes, err := c.ToCBOR() - require.NoError(t, err) - - policyPath := filepath.Join(tempDir, "attestation_policy.json") - err = os.WriteFile(policyPath, corimBytes, 0o644) - require.NoError(t, err) - - oldPolicyPath := attestation.AttestationPolicyPath - attestation.AttestationPolicyPath = policyPath - t.Cleanup(func() { - attestation.AttestationPolicyPath = oldPolicyPath - }) - - err = verifier.VerifyPeerCertificate([][]byte{peerCertDER}, nil, nonce) - assert.NoError(t, err) -} - -func TestVerifyPeerCertificate_AzureSuccess(t *testing.T) { - caKey, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) - caTemplate := &x509.Certificate{ - SerialNumber: big.NewInt(1), - Subject: pkix.Name{CommonName: "Test CA"}, - NotBefore: time.Now().Add(-1 * time.Hour), - NotAfter: time.Now().Add(1 * time.Hour), - KeyUsage: x509.KeyUsageCertSign | x509.KeyUsageCRLSign, - BasicConstraintsValid: true, - IsCA: true, - } - caCertDER, _ := x509.CreateCertificate(rand.Reader, caTemplate, caTemplate, &caKey.PublicKey, caKey) - caCert, _ := x509.ParseCertificate(caCertDER) - - rootCAs := x509.NewCertPool() - rootCAs.AddCert(caCert) - - verifier := NewCertificateVerifier(rootCAs).(*certificateVerifier) - verifier.verifierProvider = func(pt attestation.PlatformType) (attestation.Verifier, error) { - return &mockVerifier{}, nil - } - - peerKey, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) - nonce := []byte("test-nonce") - peerPubKeyDER, _ := x509.MarshalPKIXPublicKey(&peerKey.PublicKey) - teeNonce := append(peerPubKeyDER, nonce...) - hashNonce := sha3.Sum512(teeNonce) - - claims := eat.EATClaims{Nonce: hashNonce[:32], RawReport: []byte("rep")} - eatBytes, _ := cbor.Marshal(claims) - - peerTemplate := &x509.Certificate{ - SerialNumber: big.NewInt(2), - Subject: pkix.Name{CommonName: "Azure Peer"}, - NotBefore: time.Now().Add(-1 * time.Hour), - NotAfter: time.Now().Add(1 * time.Hour), - ExtraExtensions: []pkix.Extension{{Id: AzureOID, Value: eatBytes}}, - } - peerCertDER, _ := x509.CreateCertificate(rand.Reader, peerTemplate, caCert, &peerKey.PublicKey, caKey) - - tempDir := t.TempDir() - c := corim.NewUnsignedCorim() - c.SetID("cocos-test-id") - corimBytes, _ := c.ToCBOR() - policyPath := filepath.Join(tempDir, "policy.cbor") - _ = os.WriteFile(policyPath, corimBytes, 0o644) - - oldPolicyPath := attestation.AttestationPolicyPath - attestation.AttestationPolicyPath = policyPath - t.Cleanup(func() { attestation.AttestationPolicyPath = oldPolicyPath }) - - err := verifier.VerifyPeerCertificate([][]byte{peerCertDER}, nil, nonce) - assert.NoError(t, err) -} - -func TestVerifyPeerCertificate_TDXSuccess(t *testing.T) { - caKey, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) - caTemplate := &x509.Certificate{ - SerialNumber: big.NewInt(1), - Subject: pkix.Name{CommonName: "Test CA"}, - NotBefore: time.Now().Add(-1 * time.Hour), - NotAfter: time.Now().Add(1 * time.Hour), - KeyUsage: x509.KeyUsageCertSign | x509.KeyUsageCRLSign, - BasicConstraintsValid: true, - IsCA: true, - } - caCertDER, _ := x509.CreateCertificate(rand.Reader, caTemplate, caTemplate, &caKey.PublicKey, caKey) - caCert, _ := x509.ParseCertificate(caCertDER) - - rootCAs := x509.NewCertPool() - rootCAs.AddCert(caCert) - - verifier := NewCertificateVerifier(rootCAs).(*certificateVerifier) - verifier.verifierProvider = func(pt attestation.PlatformType) (attestation.Verifier, error) { - return &mockVerifier{}, nil - } - - peerKey, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) - nonce := []byte("test-nonce") - peerPubKeyDER, _ := x509.MarshalPKIXPublicKey(&peerKey.PublicKey) - teeNonce := append(peerPubKeyDER, nonce...) - hashNonce := sha3.Sum512(teeNonce) - - claims := eat.EATClaims{Nonce: hashNonce[:32], RawReport: []byte("rep")} - eatBytes, _ := cbor.Marshal(claims) - - peerTemplate := &x509.Certificate{ - SerialNumber: big.NewInt(3), - Subject: pkix.Name{CommonName: "TDX Peer"}, - NotBefore: time.Now().Add(-1 * time.Hour), - NotAfter: time.Now().Add(1 * time.Hour), - ExtraExtensions: []pkix.Extension{{Id: TDXOID, Value: eatBytes}}, - } - peerCertDER, _ := x509.CreateCertificate(rand.Reader, peerTemplate, caCert, &peerKey.PublicKey, caKey) - - tempDir := t.TempDir() - c := corim.NewUnsignedCorim() - c.SetID("cocos-test-id") - corimBytes, _ := c.ToCBOR() - policyPath := filepath.Join(tempDir, "policy.cbor") - _ = os.WriteFile(policyPath, corimBytes, 0o644) - - oldPolicyPath := attestation.AttestationPolicyPath - attestation.AttestationPolicyPath = policyPath - t.Cleanup(func() { attestation.AttestationPolicyPath = oldPolicyPath }) - - err := verifier.VerifyPeerCertificate([][]byte{peerCertDER}, nil, nonce) - assert.NoError(t, err) -} - -func TestVerifyPeerCertificate_Failures_More(t *testing.T) { - caKey, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) - caTemplate := &x509.Certificate{ - SerialNumber: big.NewInt(1), - Subject: pkix.Name{CommonName: "Test CA"}, - NotBefore: time.Now().Add(-1 * time.Hour), - NotAfter: time.Now().Add(1 * time.Hour), - KeyUsage: x509.KeyUsageCertSign | x509.KeyUsageCRLSign, - BasicConstraintsValid: true, - IsCA: true, - } - caCertDER, _ := x509.CreateCertificate(rand.Reader, caTemplate, caTemplate, &caKey.PublicKey, caKey) - caCert, _ := x509.ParseCertificate(caCertDER) - rootCAs := x509.NewCertPool() - rootCAs.AddCert(caCert) - - verifier := NewCertificateVerifier(rootCAs).(*certificateVerifier) - - // Case 1: Invalid OID - peerKey, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) - peerTemplate := &x509.Certificate{ - SerialNumber: big.NewInt(4), - NotBefore: time.Now(), - NotAfter: time.Now().Add(time.Hour), - ExtraExtensions: []pkix.Extension{{Id: []int{1, 2, 3}, Value: []byte("val")}}, - } - certDER, _ := x509.CreateCertificate(rand.Reader, peerTemplate, caCert, &peerKey.PublicKey, caKey) - err := verifier.VerifyPeerCertificate([][]byte{certDER}, nil, []byte("nonce")) - assert.ErrorContains(t, err, "attestation extension not found") - - // Case 2: Policy path not set - attestation.AttestationPolicyPath = "" - peerPubKeyDER, _ := x509.MarshalPKIXPublicKey(&peerKey.PublicKey) - nonce := []byte("nonce") - teeNonce := append(peerPubKeyDER, nonce...) - hashNonce := sha3.Sum512(teeNonce) - claims := eat.EATClaims{Nonce: hashNonce[:32], RawReport: []byte("rep")} - eatBytes, _ := cbor.Marshal(claims) - peerTemplate.ExtraExtensions = []pkix.Extension{{Id: SNPvTPMOID, Value: eatBytes}} - certDERWithExt, _ := x509.CreateCertificate(rand.Reader, peerTemplate, caCert, &peerKey.PublicKey, caKey) - - err = verifier.VerifyPeerCertificate([][]byte{certDERWithExt}, nil, nonce) - assert.ErrorContains(t, err, "attestation policy path is not set") -} - -func TestVerifyPeerCertificate_Failures_Ext(t *testing.T) { - rootCAs := x509.NewCertPool() - verifier := NewCertificateVerifier(rootCAs).(*certificateVerifier) - - // Case 1: No certificates - err := verifier.VerifyPeerCertificate([][]byte{}, nil, []byte("nonce")) - assert.ErrorContains(t, err, "no certificates provided") - - // Case 2: Nonce length mismatch - peerKey, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) - nonce := []byte("nonce") - claims := eat.EATClaims{Nonce: []byte("short"), RawReport: []byte("rep")} - eatBytes, _ := cbor.Marshal(claims) - - caKey, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) - caTemplate := &x509.Certificate{ - SerialNumber: big.NewInt(1), - Subject: pkix.Name{CommonName: "CA"}, - IsCA: true, - BasicConstraintsValid: true, - NotBefore: time.Now().Add(-time.Hour), - NotAfter: time.Now().Add(time.Hour), - } - caCertDER, _ := x509.CreateCertificate(rand.Reader, caTemplate, caTemplate, &caKey.PublicKey, caKey) - caCert, _ := x509.ParseCertificate(caCertDER) - rootCAs.AddCert(caCert) - - peerTemplate := &x509.Certificate{ - SerialNumber: big.NewInt(5), - NotBefore: time.Now().Add(-time.Hour), - NotAfter: time.Now().Add(time.Hour), - ExtraExtensions: []pkix.Extension{{Id: SNPvTPMOID, Value: eatBytes}}, - } - certDER, _ := x509.CreateCertificate(rand.Reader, peerTemplate, caCert, &peerKey.PublicKey, caKey) - err = verifier.VerifyPeerCertificate([][]byte{certDER}, nil, nonce) - assert.ErrorContains(t, err, "nonce length mismatch") - - // Case 3: Nonce mismatch - peerPubKeyDER, _ := x509.MarshalPKIXPublicKey(&peerKey.PublicKey) - wrongTeeNonce := append(peerPubKeyDER, []byte("wrong-nonce")...) - wrongHashNonce := sha3.Sum512(wrongTeeNonce) - claims.Nonce = wrongHashNonce[:32] - eatBytes, _ = cbor.Marshal(claims) - peerTemplate.ExtraExtensions = []pkix.Extension{{Id: SNPvTPMOID, Value: eatBytes}} - certDER, _ = x509.CreateCertificate(rand.Reader, peerTemplate, caCert, &peerKey.PublicKey, caKey) - err = verifier.VerifyPeerCertificate([][]byte{certDER}, nil, nonce) - assert.ErrorContains(t, err, "nonce mismatch in EAT token") - - // Case 4: Invalid EAT (CBOR decoder failure) - peerTemplate.ExtraExtensions = []pkix.Extension{{Id: SNPvTPMOID, Value: []byte("invalid-cbor")}} - certDER, _ = x509.CreateCertificate(rand.Reader, peerTemplate, caCert, &peerKey.PublicKey, caKey) - err = verifier.VerifyPeerCertificate([][]byte{certDER}, nil, nonce) - assert.ErrorContains(t, err, "failed to decode EAT token") -} diff --git a/pkg/atls/ea/authenticator.go b/pkg/atls/ea/authenticator.go new file mode 100644 index 00000000..93551f66 --- /dev/null +++ b/pkg/atls/ea/authenticator.go @@ -0,0 +1,370 @@ +// Copyright (c) Ultraviolet +// SPDX-License-Identifier: Apache-2.0 + +package ea + +import ( + "bytes" + "crypto/tls" + "crypto/x509" + "errors" + "fmt" + + eaattestation "github.com/ultravioletrs/cocos/pkg/atls/eaattestation" +) + +var ( + ErrTruncated = errors.New("ea: truncated input") + ErrInvalidLength = errors.New("ea: invalid length") + ErrUnsupportedHandshakeType = errors.New("ea: unsupported handshake type") + ErrNotTLS13 = errors.New("ea: not TLS 1.3") + ErrUnknownCipherSuite = errors.New("ea: unknown cipher suite") + ErrContextReuse = errors.New("ea: certificate_request_context already used") + ErrInvalidRole = errors.New("ea: invalid authenticator role") + + ErrUnsupportedSignatureScheme = errors.New("ea: unsupported signature scheme") + ErrSignatureMismatch = errors.New("ea: CertificateVerify signature mismatch") + ErrFinishedMismatch = errors.New("ea: Finished MAC mismatch") + ErrContextMismatch = errors.New("ea: certificate_request_context mismatch") + ErrBadRequest = errors.New("ea: bad authenticator request") +) + +type ValidationResult struct { + Context []byte + Chain []*x509.Certificate + CMWAttestation []byte + Attestation *eaattestation.VerifiedPayload + Empty bool +} + +func CreateAuthenticator(st *tls.ConnectionState, role Role, req *AuthenticatorRequest, identity tls.Certificate, leafEntryExtensions []Extension) ([]byte, error) { + return createAuthenticator(nil, st, role, req, nil, identity, leafEntryExtensions) +} + +func CreateAuthenticatorWithPolicy(st *tls.ConnectionState, role Role, req *AuthenticatorRequest, policy *SpontaneousAuthenticatorPolicy, identity tls.Certificate, leafEntryExtensions []Extension) ([]byte, error) { + return createAuthenticator(nil, st, role, req, policy, identity, leafEntryExtensions) +} + +func createAuthenticator(session *Session, st *tls.ConnectionState, role Role, req *AuthenticatorRequest, policy *SpontaneousAuthenticatorPolicy, identity tls.Certificate, leafEntryExtensions []Extension) ([]byte, error) { + if st.Version != tls.VersionTLS13 { + return nil, ErrNotTLS13 + } + if req == nil && role != RoleServer { + return nil, ErrInvalidRole + } + if err := validateCertificateEntryExtensions(leafEntryExtensions, req, policy); err != nil { + return nil, err + } + emptyAuthenticator := len(identity.Certificate) == 0 && identity.PrivateKey == nil + if !emptyAuthenticator && (len(identity.Certificate) == 0 || identity.PrivateKey == nil) { + return nil, ErrBadRequest + } + + var reqBytes []byte + var offered []uint16 + var ctx []byte + + if req != nil { + var err error + reqBytes, err = req.Marshal() + if err != nil { + return nil, err + } + ctx = append([]byte(nil), req.Context...) + if schemes, ok := req.SignatureSchemes(); ok { + offered = schemes + } else { + return nil, fmt.Errorf("%w: missing signature_algorithms", ErrBadRequest) + } + } else { + c, err := NewRandomContext(32) + if err != nil { + return nil, err + } + ctx = c + } + + hsCtx, h, err := ExportHandshakeContext(st, role) + if err != nil { + return nil, err + } + fk, _, err := ExportFinishedKey(st, role) + if err != nil { + return nil, err + } + + if emptyAuthenticator { + if req == nil { + return nil, ErrBadRequest + } + certBytes, err := (CertificateMessage{Context: ctx}).Marshal() + if err != nil { + return nil, err + } + th := hashConcat(h, hsCtx, reqBytes, certBytes) + verifyData := hmacSum(h, fk, th) + finBytes, err := (FinishedMessage{VerifyData: verifyData}).Marshal() + if err != nil { + return nil, err + } + if err := session.MarkContextUsed(ctx); err != nil { + return nil, err + } + return finBytes, nil + } + + scheme, err := chooseSignatureScheme(identity.PrivateKey, offered) + if err != nil { + return nil, err + } + if req == nil && !policyPermitsSignatureScheme(policy, scheme) { + return nil, ErrUnsupportedSignatureScheme + } + + entries := make([]CertificateEntry, 0, len(identity.Certificate)) + for i, der := range identity.Certificate { + exts := []Extension(nil) + if i == 0 && len(leafEntryExtensions) > 0 { + exts = leafEntryExtensions + } + entries = append(entries, CertificateEntry{CertDER: der, Extensions: exts}) + } + certBytes, err := (CertificateMessage{Context: ctx, Entries: entries}).Marshal() + if err != nil { + return nil, err + } + + th1 := hashConcat(h, hsCtx, reqBytes, certBytes) + sig, err := signCertVerify(identity.PrivateKey, scheme, th1) + if err != nil { + return nil, err + } + cvBytes, err := (CertificateVerifyMessage{Algorithm: scheme, Signature: sig}).Marshal() + if err != nil { + return nil, err + } + + th2 := hashConcat(h, hsCtx, reqBytes, certBytes, cvBytes) + verifyData := hmacSum(h, fk, th2) + finBytes, err := (FinishedMessage{VerifyData: verifyData}).Marshal() + if err != nil { + return nil, err + } + + if err := session.MarkContextUsed(ctx); err != nil { + return nil, err + } + return append(append(certBytes, cvBytes...), finBytes...), nil +} + +func ValidateAuthenticator(st *tls.ConnectionState, role Role, req *AuthenticatorRequest, authBytes []byte, verifyOpts *x509.VerifyOptions) (*ValidationResult, error) { + return validateAuthenticator(nil, st, role, req, nil, nil, authBytes, verifyOpts) +} + +func ValidateAuthenticatorWithPolicy(st *tls.ConnectionState, role Role, req *AuthenticatorRequest, policy *SpontaneousAuthenticatorPolicy, authBytes []byte, verifyOpts *x509.VerifyOptions) (*ValidationResult, error) { + return validateAuthenticator(nil, st, role, req, policy, nil, authBytes, verifyOpts) +} + +func ValidateAuthenticatorWithAttestation(st *tls.ConnectionState, role Role, req *AuthenticatorRequest, authBytes []byte, verifyOpts *x509.VerifyOptions, attPolicy eaattestation.VerificationPolicy) (*ValidationResult, error) { + return validateAuthenticator(nil, st, role, req, nil, &attPolicy, authBytes, verifyOpts) +} + +func ValidateAuthenticatorWithPolicies(st *tls.ConnectionState, role Role, req *AuthenticatorRequest, policy *SpontaneousAuthenticatorPolicy, authBytes []byte, verifyOpts *x509.VerifyOptions, attPolicy eaattestation.VerificationPolicy) (*ValidationResult, error) { + return validateAuthenticator(nil, st, role, req, policy, &attPolicy, authBytes, verifyOpts) +} + +func validateAuthenticator(session *Session, st *tls.ConnectionState, role Role, req *AuthenticatorRequest, policy *SpontaneousAuthenticatorPolicy, attPolicy *eaattestation.VerificationPolicy, authBytes []byte, verifyOpts *x509.VerifyOptions) (*ValidationResult, error) { + if st.Version != tls.VersionTLS13 { + return nil, ErrNotTLS13 + } + + hsCtx, h, err := ExportHandshakeContext(st, role) + if err != nil { + return nil, err + } + fk, _, err := ExportFinishedKey(st, role) + if err != nil { + return nil, err + } + + var reqBytes []byte + var offered []uint16 + var reqCtx []byte + if req != nil { + reqCtx = req.Context + reqBytes, err = req.Marshal() + if err != nil { + return nil, err + } + if schemes, ok := req.SignatureSchemes(); ok { + offered = schemes + } else { + return nil, fmt.Errorf("%w: missing signature_algorithms", ErrBadRequest) + } + } + + firstHm, rest, err := UnmarshalHandshakeMessage(authBytes) + if err != nil { + return nil, err + } + if firstHm.Type == HandshakeTypeFinished { + if req == nil || len(rest) != 0 { + return nil, ErrUnsupportedHandshakeType + } + finBytes, _ := MarshalHandshakeMessage(firstHm) + finMsg, _, err := UnmarshalFinishedMessage(finBytes) + if err != nil { + return nil, err + } + certBytes, err := (CertificateMessage{Context: reqCtx}).Marshal() + if err != nil { + return nil, err + } + th := hashConcat(h, hsCtx, reqBytes, certBytes) + expectedFin := hmacSum(h, fk, th) + if !constantTimeEqual(expectedFin, finMsg.VerifyData) { + return nil, ErrFinishedMismatch + } + if err := session.MarkContextUsed(reqCtx); err != nil { + return nil, err + } + return &ValidationResult{ + Context: append([]byte(nil), reqCtx...), + Empty: true, + }, nil + } + if firstHm.Type != HandshakeTypeCertificate { + return nil, ErrUnsupportedHandshakeType + } + certHm := firstHm + certBytes, _ := MarshalHandshakeMessage(certHm) + + cvHm, rest, err := UnmarshalHandshakeMessage(rest) + if err != nil || cvHm.Type != HandshakeTypeCertificateVerify { + return nil, ErrUnsupportedHandshakeType + } + cvBytes, _ := MarshalHandshakeMessage(cvHm) + + finHm, rest, err := UnmarshalHandshakeMessage(rest) + if err != nil || finHm.Type != HandshakeTypeFinished || len(rest) != 0 { + return nil, ErrInvalidLength + } + finBytes, _ := MarshalHandshakeMessage(finHm) + + certMsg, _, err := UnmarshalCertificateMessage(certBytes) + if err != nil { + return nil, err + } + cvMsg, _, err := UnmarshalCertificateVerifyMessage(cvBytes) + if err != nil { + return nil, err + } + finMsg, _, err := UnmarshalFinishedMessage(finBytes) + if err != nil { + return nil, err + } + + if req != nil && !bytes.Equal(certMsg.Context, reqCtx) { + return nil, ErrContextMismatch + } + if len(certMsg.Entries) == 0 { + // Empty authenticators are encoded as Finished-only. A Certificate + // message with zero entries followed by CertificateVerify/Finished is + // malformed and must not be accepted as an empty authenticator. + return nil, ErrUnsupportedHandshakeType + } + if err := ValidateCMWAttestationPlacement(certMsg.Entries); err != nil { + return nil, err + } + for _, entry := range certMsg.Entries { + if err := validateCertificateEntryExtensions(entry.Extensions, req, policy); err != nil { + return nil, err + } + } + + extracted, present, err := ExtractCMWAttestationFromExtensions(certMsg.Entries[0].Extensions) + if err != nil { + return nil, err + } + if present && req != nil && !RequestPermitsCertificateExtension(req, CMWAttestationExtensionType) { + return nil, ErrBadRequest + } + if present && req == nil && !PolicyPermitsCertificateExtension(policy, CMWAttestationExtensionType) { + return nil, ErrBadRequest + } + + chain := make([]*x509.Certificate, 0, len(certMsg.Entries)) + for _, e := range certMsg.Entries { + c, err := x509.ParseCertificate(e.CertDER) + if err != nil { + return nil, err + } + chain = append(chain, c) + } + leaf := chain[0] + + if req != nil { + ok := false + for _, s := range offered { + if s == cvMsg.Algorithm { + ok = true + break + } + } + if !ok { + return nil, ErrUnsupportedSignatureScheme + } + } + + th1 := hashConcat(h, hsCtx, reqBytes, certBytes) + if err := verifyCertVerify(leaf.PublicKey, cvMsg.Algorithm, th1, cvMsg.Signature); err != nil { + return nil, err + } + + th2 := hashConcat(h, hsCtx, reqBytes, certBytes, cvBytes) + expectedFin := hmacSum(h, fk, th2) + if !constantTimeEqual(expectedFin, finMsg.VerifyData) { + return nil, ErrFinishedMismatch + } + + if verifyOpts != nil { + opts := *verifyOpts + if opts.Intermediates == nil { + opts.Intermediates = x509.NewCertPool() + } + for _, ic := range chain[1:] { + opts.Intermediates.AddCert(ic) + } + if _, err := leaf.Verify(opts); err != nil { + return nil, err + } + } + if err := session.MarkContextUsed(certMsg.Context); err != nil { + return nil, err + } + + res := &ValidationResult{ + Context: append([]byte(nil), certMsg.Context...), + Chain: chain, + } + if present { + res.CMWAttestation = extracted + parsed, err := eaattestation.ParsePayload(extracted) + if err != nil { + return nil, err + } + var verifierPolicy eaattestation.VerificationPolicy + // A nil attestation policy is intentional: VerifyPayload then fails closed + // for any payload that carries evidence or attestation results without + // explicit verifiers being configured. + if attPolicy != nil { + verifierPolicy = *attPolicy + } + verified, err := eaattestation.VerifyPayload(st, eaattestation.ExporterLabelAttestation, certMsg.Context, leaf, parsed, verifierPolicy) + if err != nil { + return nil, err + } + res.Attestation = verified + } + return res, nil +} diff --git a/pkg/atls/ea/authenticator_test.go b/pkg/atls/ea/authenticator_test.go new file mode 100644 index 00000000..cd827142 --- /dev/null +++ b/pkg/atls/ea/authenticator_test.go @@ -0,0 +1,537 @@ +// Copyright (c) Ultraviolet +// SPDX-License-Identifier: Apache-2.0 + +package ea + +import ( + "bytes" + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "crypto/tls" + "crypto/x509" + "crypto/x509/pkix" + "encoding/asn1" + "math/big" + "net" + "testing" + "time" + + attestation "github.com/ultravioletrs/cocos/pkg/atls/eaattestation" +) + +func selfSignedCert(t *testing.T) tls.Certificate { + t.Helper() + priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + if err != nil { + t.Fatal(err) + } + template := &x509.Certificate{ + SerialNumber: big.NewInt(1), + Subject: pkix.Name{CommonName: "ea-phase3"}, + NotBefore: time.Now().Add(-time.Hour), + NotAfter: time.Now().Add(time.Hour), + KeyUsage: x509.KeyUsageDigitalSignature, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, + DNSNames: []string{"localhost"}, + } + der, err := x509.CreateCertificate(rand.Reader, template, template, &priv.PublicKey, priv) + if err != nil { + t.Fatal(err) + } + return tls.Certificate{Certificate: [][]byte{der}, PrivateKey: priv} +} + +func tlsPair(t *testing.T, cert tls.Certificate) (srv, cli *tls.Conn) { + t.Helper() + srvConf := &tls.Config{Certificates: []tls.Certificate{cert}, MinVersion: tls.VersionTLS13, MaxVersion: tls.VersionTLS13} + cliConf := &tls.Config{InsecureSkipVerify: true, MinVersion: tls.VersionTLS13, MaxVersion: tls.VersionTLS13} + a, b := net.Pipe() + srv = tls.Server(a, srvConf) + cli = tls.Client(b, cliConf) + errCh := make(chan error, 2) + go func() { errCh <- srv.Handshake() }() + go func() { errCh <- cli.Handshake() }() + for i := 0; i < 2; i++ { + if err := <-errCh; err != nil { + t.Fatalf("handshake: %v", err) + } + } + return srv, cli +} + +type acceptEvidenceVerifier struct{} + +func (acceptEvidenceVerifier) VerifyEvidence(evidence []byte) error { return nil } + +func TestDummyAttestationRoundTrip(t *testing.T) { + cert := selfSignedCert(t) + srv, cli := tlsPair(t, cert) + defer srv.Close() + defer cli.Close() + + ctx, _ := NewRandomContext(16) + req := &AuthenticatorRequest{ + Type: HandshakeTypeClientCertificateRequest, + Context: ctx, + Extensions: []Extension{ + {Type: SignatureAlgorithmsExtensionType, Data: []byte{0x00, 0x02, 0x04, 0x03}}, + CMWAttestationOfferExtension(), + }, + } + + leaf, _ := x509.ParseCertificate(cert.Certificate[0]) + srvState := srv.ConnectionState() + _, aikPubHash, binding, err := attestation.ComputeBinding(&srvState, attestation.ExporterLabelAttestation, ctx, leaf) + if err != nil { + t.Fatal(err) + } + payloadBytes, err := attestation.MarshalPayload(attestation.Payload{ + Version: 1, + Evidence: []byte("dummy-attestation-report"), + MediaType: "application/eat+cwt", + Binder: attestation.AttestationBinder{ + ExporterLabel: attestation.ExporterLabelAttestation, + AIKPubHash: aikPubHash, + Binding: binding, + }, + }) + if err != nil { + t.Fatal(err) + } + ext, err := CMWAttestationDataExtension(payloadBytes) + if err != nil { + t.Fatal(err) + } + + auth, err := CreateAuthenticator(&srvState, RoleServer, req, cert, []Extension{ext}) + if err != nil { + t.Fatal(err) + } + + cliState := cli.ConnectionState() + roots := x509.NewCertPool() + roots.AddCert(leaf) + + res, err := ValidateAuthenticatorWithAttestation(&cliState, RoleServer, req, auth, &x509.VerifyOptions{Roots: roots}, attestation.VerificationPolicy{ + EvidenceVerifier: acceptEvidenceVerifier{}, + }) + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(res.CMWAttestation, payloadBytes) { + t.Fatalf("cmw mismatch") + } + if res.Attestation == nil || !res.Attestation.BindingVerified || !res.Attestation.EvidenceVerified { + t.Fatalf("expected verified attestation result") + } +} + +func TestRejectIfNotOffered(t *testing.T) { + cert := selfSignedCert(t) + srv, cli := tlsPair(t, cert) + defer srv.Close() + defer cli.Close() + + ctx, _ := NewRandomContext(8) + req := &AuthenticatorRequest{ + Type: HandshakeTypeClientCertificateRequest, + Context: ctx, + Extensions: []Extension{ + {Type: SignatureAlgorithmsExtensionType, Data: []byte{0x00, 0x02, 0x04, 0x03}}, + }, + } + + ext, _ := CMWAttestationDataExtension([]byte("dummy")) + srvState := srv.ConnectionState() + if _, err := CreateAuthenticator(&srvState, RoleServer, req, cert, []Extension{ext}); err == nil { + t.Fatalf("expected error when cmw_attestation not offered") + } +} + +func TestAttestationFailsClosedWithoutVerifier(t *testing.T) { + cert := selfSignedCert(t) + srv, cli := tlsPair(t, cert) + defer srv.Close() + defer cli.Close() + + ctx, _ := NewRandomContext(16) + req := &AuthenticatorRequest{ + Type: HandshakeTypeClientCertificateRequest, + Context: ctx, + Extensions: []Extension{ + {Type: SignatureAlgorithmsExtensionType, Data: []byte{0x00, 0x02, 0x04, 0x03}}, + CMWAttestationOfferExtension(), + }, + } + + leaf, _ := x509.ParseCertificate(cert.Certificate[0]) + srvState := srv.ConnectionState() + _, aikPubHash, binding, err := attestation.ComputeBinding(&srvState, attestation.ExporterLabelAttestation, ctx, leaf) + if err != nil { + t.Fatal(err) + } + payloadBytes, err := attestation.MarshalPayload(attestation.Payload{ + Version: 1, + Evidence: []byte("dummy-attestation-report"), + Binder: attestation.AttestationBinder{ + ExporterLabel: attestation.ExporterLabelAttestation, + AIKPubHash: aikPubHash, + Binding: binding, + }, + }) + if err != nil { + t.Fatal(err) + } + ext, err := CMWAttestationDataExtension(payloadBytes) + if err != nil { + t.Fatal(err) + } + + auth, err := CreateAuthenticator(&srvState, RoleServer, req, cert, []Extension{ext}) + if err != nil { + t.Fatal(err) + } + + cliState := cli.ConnectionState() + if _, err := ValidateAuthenticator(&cliState, RoleServer, req, auth, nil); err != attestation.ErrEvidenceVerificationMissing { + t.Fatalf("got %v, want %v", err, attestation.ErrEvidenceVerificationMissing) + } +} + +func TestRejectCMWAttestationOnIntermediateEntry(t *testing.T) { + cert := selfSignedCert(t) + srv, cli := tlsPair(t, cert) + defer srv.Close() + defer cli.Close() + + ctx, _ := NewRandomContext(16) + req := &AuthenticatorRequest{ + Type: HandshakeTypeClientCertificateRequest, + Context: ctx, + Extensions: []Extension{ + {Type: SignatureAlgorithmsExtensionType, Data: []byte{0x00, 0x02, 0x04, 0x03}}, + CMWAttestationOfferExtension(), + }, + } + + leaf, _ := x509.ParseCertificate(cert.Certificate[0]) + srvState := srv.ConnectionState() + hsCtx, h, err := ExportHandshakeContext(&srvState, RoleServer) + if err != nil { + t.Fatal(err) + } + fk, _, err := ExportFinishedKey(&srvState, RoleServer) + if err != nil { + t.Fatal(err) + } + reqBytes, err := req.Marshal() + if err != nil { + t.Fatal(err) + } + _, aikPubHash, binding, err := attestation.ComputeBinding(&srvState, attestation.ExporterLabelAttestation, ctx, leaf) + if err != nil { + t.Fatal(err) + } + payloadBytes, err := attestation.MarshalPayload(attestation.Payload{ + Version: 1, + Evidence: []byte("dummy-attestation-report"), + MediaType: "application/eat+cwt", + Binder: attestation.AttestationBinder{ + ExporterLabel: attestation.ExporterLabelAttestation, + AIKPubHash: aikPubHash, + Binding: binding, + }, + }) + if err != nil { + t.Fatal(err) + } + ext, err := CMWAttestationDataExtension(payloadBytes) + if err != nil { + t.Fatal(err) + } + certBytes, err := (CertificateMessage{ + Context: ctx, + Entries: []CertificateEntry{ + {CertDER: cert.Certificate[0]}, + {CertDER: cert.Certificate[0], Extensions: []Extension{ext}}, + }, + }).Marshal() + if err != nil { + t.Fatal(err) + } + th1 := hashConcat(h, hsCtx, reqBytes, certBytes) + sig, err := signCertVerify(cert.PrivateKey, uint16(tls.ECDSAWithP256AndSHA256), th1) + if err != nil { + t.Fatal(err) + } + cvBytes, err := (CertificateVerifyMessage{ + Algorithm: uint16(tls.ECDSAWithP256AndSHA256), + Signature: sig, + }).Marshal() + if err != nil { + t.Fatal(err) + } + th2 := hashConcat(h, hsCtx, reqBytes, certBytes, cvBytes) + finBytes, err := (FinishedMessage{VerifyData: hmacSum(h, fk, th2)}).Marshal() + if err != nil { + t.Fatal(err) + } + auth := append(append(certBytes, cvBytes...), finBytes...) + + cliState := cli.ConnectionState() + if _, err := ValidateAuthenticatorWithAttestation(&cliState, RoleServer, req, auth, nil, attestation.VerificationPolicy{ + EvidenceVerifier: acceptEvidenceVerifier{}, + }); err != ErrBadRequest { + t.Fatalf("got %v, want %v", err, ErrBadRequest) + } +} + +func TestSessionRejectsContextReuse(t *testing.T) { + cert := selfSignedCert(t) + srv, cli := tlsPair(t, cert) + defer srv.Close() + defer cli.Close() + + ctx, _ := NewRandomContext(12) + req := &AuthenticatorRequest{ + Type: HandshakeTypeClientCertificateRequest, + Context: ctx, + Extensions: []Extension{ + {Type: SignatureAlgorithmsExtensionType, Data: []byte{0x00, 0x02, 0x04, 0x03}}, + }, + } + + createSession := NewSession() + srvState := srv.ConnectionState() + auth, err := createSession.CreateAuthenticator(&srvState, RoleServer, req, cert, nil) + if err != nil { + t.Fatal(err) + } + if _, err := createSession.CreateAuthenticator(&srvState, RoleServer, req, cert, nil); err != ErrContextReuse { + t.Fatalf("got %v, want %v", err, ErrContextReuse) + } + + validateSession := NewSession() + cliState := cli.ConnectionState() + roots := x509.NewCertPool() + leaf, _ := x509.ParseCertificate(cert.Certificate[0]) + roots.AddCert(leaf) + + if _, err := validateSession.ValidateAuthenticator(&cliState, RoleServer, req, auth, &x509.VerifyOptions{Roots: roots}); err != nil { + t.Fatal(err) + } + if _, err := validateSession.ValidateAuthenticator(&cliState, RoleServer, req, auth, &x509.VerifyOptions{Roots: roots}); err != ErrContextReuse { + t.Fatalf("got %v, want %v", err, ErrContextReuse) + } +} + +func TestEmptyAuthenticatorRoundTrip(t *testing.T) { + cert := selfSignedCert(t) + srv, cli := tlsPair(t, cert) + defer srv.Close() + defer cli.Close() + + ctx, _ := NewRandomContext(10) + req := &AuthenticatorRequest{ + Type: HandshakeTypeClientCertificateRequest, + Context: ctx, + Extensions: []Extension{ + {Type: SignatureAlgorithmsExtensionType, Data: []byte{0x00, 0x02, 0x04, 0x03}}, + }, + } + + srvState := srv.ConnectionState() + auth, err := CreateAuthenticator(&srvState, RoleServer, req, tls.Certificate{}, nil) + if err != nil { + t.Fatal(err) + } + + cliState := cli.ConnectionState() + res, err := ValidateAuthenticator(&cliState, RoleServer, req, auth, nil) + if err != nil { + t.Fatal(err) + } + if !res.Empty { + t.Fatalf("expected empty authenticator result") + } + if len(res.Chain) != 0 { + t.Fatalf("expected no certificate chain") + } +} + +func TestRejectCertificateMessageWithEmptyEntries(t *testing.T) { + cert := selfSignedCert(t) + srv, cli := tlsPair(t, cert) + defer srv.Close() + defer cli.Close() + + ctx, _ := NewRandomContext(10) + req := &AuthenticatorRequest{ + Type: HandshakeTypeClientCertificateRequest, + Context: ctx, + Extensions: []Extension{ + {Type: SignatureAlgorithmsExtensionType, Data: []byte{0x00, 0x02, 0x04, 0x03}}, + }, + } + + certBytes, err := (CertificateMessage{Context: ctx}).Marshal() + if err != nil { + t.Fatal(err) + } + cvBytes, err := (CertificateVerifyMessage{ + Algorithm: uint16(tls.ECDSAWithP256AndSHA256), + Signature: []byte{0x01}, + }).Marshal() + if err != nil { + t.Fatal(err) + } + finBytes, err := (FinishedMessage{VerifyData: []byte{0x01}}).Marshal() + if err != nil { + t.Fatal(err) + } + + auth := append(append(certBytes, cvBytes...), finBytes...) + + cliState := cli.ConnectionState() + if _, err := ValidateAuthenticator(&cliState, RoleServer, req, auth, nil); err != ErrUnsupportedHandshakeType { + t.Fatalf("got %v, want %v", err, ErrUnsupportedHandshakeType) + } +} + +func TestRejectSpontaneousClientAuthenticator(t *testing.T) { + cert := selfSignedCert(t) + srv, cli := tlsPair(t, cert) + defer srv.Close() + defer cli.Close() + + cliState := cli.ConnectionState() + if _, err := CreateAuthenticator(&cliState, RoleClient, nil, cert, nil); err != ErrInvalidRole { + t.Fatalf("got %v, want %v", err, ErrInvalidRole) + } +} + +func TestRequestParsers(t *testing.T) { + oidDER, err := asn1.Marshal(asn1.ObjectIdentifier{2, 5, 4, 3}) + if err != nil { + t.Fatal(err) + } + oidFilterPayload := append([]byte{byte(len(oidDER))}, oidDER...) + oidFilterPayload = append(oidFilterPayload, 0x00, 0x02, 'o', 'k') + req := AuthenticatorRequest{ + Type: HandshakeTypeCertificateRequest, + Context: []byte{1, 2, 3}, + Extensions: []Extension{ + {Type: SignatureAlgorithmsExtensionType, Data: []byte{0x00, 0x04, 0x04, 0x03, 0x08, 0x07}}, + {Type: SignatureAlgorithmsCertExtensionType, Data: []byte{0x00, 0x02, 0x08, 0x07}}, + {Type: CertificateAuthoritiesExtensionType, Data: []byte{0x00, 0x06, 0x00, 0x04, 't', 'e', 's', 't'}}, + {Type: OIDFiltersExtensionType, Data: append([]byte{0x00, byte(len(oidFilterPayload))}, oidFilterPayload...)}, + }, + } + + if got, ok := req.SignatureSchemes(); !ok || len(got) != 2 || got[0] != uint16(tls.ECDSAWithP256AndSHA256) || got[1] != uint16(tls.Ed25519) { + t.Fatalf("unexpected signature schemes: %v %v", got, ok) + } + if got, ok := req.SignatureSchemesCert(); !ok || len(got) != 1 || got[0] != uint16(tls.Ed25519) { + t.Fatalf("unexpected signature_algorithms_cert: %v %v", got, ok) + } + if got, ok := req.CertificateAuthorities(); !ok || len(got) != 1 || string(got[0]) != "test" { + t.Fatalf("unexpected certificate authorities: %q %v", got, ok) + } + if got, ok := req.OIDFilters(); !ok || len(got) != 1 || !got[0].OID.Equal(asn1.ObjectIdentifier{2, 5, 4, 3}) || string(got[0].Values) != "ok" { + t.Fatalf("unexpected oid filters: %#v %v", got, ok) + } +} + +func TestRejectLeafExtensionNotPermittedByRequest(t *testing.T) { + cert := selfSignedCert(t) + srv, _ := tlsPair(t, cert) + defer srv.Close() + + ctx, _ := NewRandomContext(8) + req := &AuthenticatorRequest{ + Type: HandshakeTypeClientCertificateRequest, + Context: ctx, + Extensions: []Extension{ + {Type: SignatureAlgorithmsExtensionType, Data: []byte{0x00, 0x02, 0x04, 0x03}}, + }, + } + + srvState := srv.ConnectionState() + ext := Extension{Type: 0x1234, Data: []byte{0x00}} + if _, err := CreateAuthenticator(&srvState, RoleServer, req, cert, []Extension{ext}); err == nil { + t.Fatalf("expected policy error for unpermitted leaf extension") + } +} + +func TestSpontaneousPolicyPermitsCertificateExtension(t *testing.T) { + cert := selfSignedCert(t) + srv, cli := tlsPair(t, cert) + defer srv.Close() + defer cli.Close() + + payloadBytes, err := attestation.MarshalPayload(attestation.Payload{ + Version: 1, + Evidence: []byte("dummy-attestation-report"), + MediaType: "application/eat+cwt", + Binder: attestation.AttestationBinder{ + ExporterLabel: attestation.ExporterLabelAttestation, + AIKPubHash: []byte("placeholder-aik"), + Binding: []byte("placeholder-binding"), + }, + }) + if err != nil { + t.Fatal(err) + } + ext, err := CMWAttestationDataExtension(payloadBytes) + if err != nil { + t.Fatal(err) + } + policy := &SpontaneousAuthenticatorPolicy{ + AllowedSignatureSchemes: []uint16{uint16(tls.ECDSAWithP256AndSHA256)}, + AllowedCertificateExtensions: []uint16{CMWAttestationExtensionType}, + } + + srvState := srv.ConnectionState() + auth, err := CreateAuthenticatorWithPolicy(&srvState, RoleServer, nil, policy, cert, []Extension{ext}) + if err != nil { + t.Fatal(err) + } + if len(auth) == 0 { + t.Fatalf("expected authenticator bytes") + } +} + +func TestSpontaneousPolicyRejectsCertificateExtension(t *testing.T) { + cert := selfSignedCert(t) + srv, _ := tlsPair(t, cert) + defer srv.Close() + + ext, err := CMWAttestationDataExtension([]byte("dummy-attestation-report")) + if err != nil { + t.Fatal(err) + } + policy := &SpontaneousAuthenticatorPolicy{ + AllowedSignatureSchemes: []uint16{uint16(tls.ECDSAWithP256AndSHA256)}, + } + + srvState := srv.ConnectionState() + if _, err := CreateAuthenticatorWithPolicy(&srvState, RoleServer, nil, policy, cert, []Extension{ext}); err == nil { + t.Fatalf("expected policy error for unpermitted spontaneous extension") + } +} + +func TestSpontaneousPolicyRejectsSignatureScheme(t *testing.T) { + cert := selfSignedCert(t) + srv, _ := tlsPair(t, cert) + defer srv.Close() + + policy := &SpontaneousAuthenticatorPolicy{ + AllowedSignatureSchemes: []uint16{uint16(tls.Ed25519)}, + } + + srvState := srv.ConnectionState() + if _, err := CreateAuthenticatorWithPolicy(&srvState, RoleServer, nil, policy, cert, nil); err != ErrUnsupportedSignatureScheme { + t.Fatalf("got %v, want %v", err, ErrUnsupportedSignatureScheme) + } +} diff --git a/pkg/atls/ea/certificate.go b/pkg/atls/ea/certificate.go new file mode 100644 index 00000000..e8281b5d --- /dev/null +++ b/pkg/atls/ea/certificate.go @@ -0,0 +1,95 @@ +// Copyright (c) Ultraviolet +// SPDX-License-Identifier: Apache-2.0 + +package ea + +type CertificateMessage struct { + Context []byte + Entries []CertificateEntry +} + +type CertificateEntry struct { + CertDER []byte + Extensions []Extension +} + +func (m CertificateMessage) Marshal() ([]byte, error) { + if len(m.Context) > 255 { + return nil, ErrInvalidLength + } + + var listPayload []byte + for _, e := range m.Entries { + if len(e.CertDER) == 0 || len(e.CertDER) > 0xFFFFFF { + return nil, ErrInvalidLength + } + extVec, err := MarshalExtensions(e.Extensions) + if err != nil { + return nil, err + } + entry := make([]byte, 3+len(e.CertDER)+len(extVec)) + putUint24(entry[0:3], uint32(len(e.CertDER))) + copy(entry[3:], e.CertDER) + copy(entry[3+len(e.CertDER):], extVec) + listPayload = append(listPayload, entry...) + } + + body := make([]byte, 1+len(m.Context)+3+len(listPayload)) + body[0] = byte(len(m.Context)) + copy(body[1:], m.Context) + putUint24(body[1+len(m.Context):1+len(m.Context)+3], uint32(len(listPayload))) + copy(body[1+len(m.Context)+3:], listPayload) + + return MarshalHandshakeMessage(HandshakeMessage{Type: HandshakeTypeCertificate, Body: body}) +} + +func UnmarshalCertificateMessage(handshakeBytes []byte) (CertificateMessage, []byte, error) { + hm, rest, err := UnmarshalHandshakeMessage(handshakeBytes) + if err != nil { + return CertificateMessage{}, nil, err + } + if len(rest) != 0 || hm.Type != HandshakeTypeCertificate { + return CertificateMessage{}, nil, ErrInvalidLength + } + if len(hm.Body) < 1 { + return CertificateMessage{}, nil, ErrTruncated + } + + ctxLen := int(hm.Body[0]) + if len(hm.Body) < 1+ctxLen+3 { + return CertificateMessage{}, nil, ErrTruncated + } + ctx := append([]byte(nil), hm.Body[1:1+ctxLen]...) + + listLen := int(readUint24(hm.Body[1+ctxLen : 1+ctxLen+3])) + if len(hm.Body) != 1+ctxLen+3+listLen { + return CertificateMessage{}, nil, ErrInvalidLength + } + list := hm.Body[1+ctxLen+3:] + + var entries []CertificateEntry + for i := 0; i < len(list); { + if len(list)-i < 3 { + return CertificateMessage{}, nil, ErrTruncated + } + certLen := int(readUint24(list[i : i+3])) + i += 3 + if certLen <= 0 || certLen > len(list)-i { + return CertificateMessage{}, nil, ErrInvalidLength + } + certDER := append([]byte(nil), list[i:i+certLen]...) + i += certLen + + exts, leftover, err := UnmarshalExtensions(list[i:]) + if err != nil { + return CertificateMessage{}, nil, err + } + consumed := len(list[i:]) - len(leftover) + i += consumed + + entries = append(entries, CertificateEntry{CertDER: certDER, Extensions: exts}) + } + + raw, _ := MarshalHandshakeMessage(hm) + return CertificateMessage{Context: ctx, Entries: entries}, raw, nil +} diff --git a/pkg/atls/ea/certverify.go b/pkg/atls/ea/certverify.go new file mode 100644 index 00000000..4a4b2b76 --- /dev/null +++ b/pkg/atls/ea/certverify.go @@ -0,0 +1,148 @@ +// Copyright (c) Ultraviolet +// SPDX-License-Identifier: Apache-2.0 + +package ea + +import ( + "bytes" + "crypto/ecdsa" + "crypto/ed25519" + "crypto/rand" + "crypto/rsa" +) + +type CertificateVerifyMessage struct { + Algorithm uint16 + Signature []byte +} + +func (m CertificateVerifyMessage) Marshal() ([]byte, error) { + if len(m.Signature) > 0xFFFF { + return nil, ErrInvalidLength + } + body := make([]byte, 4+len(m.Signature)) + putUint16(body[0:2], m.Algorithm) + putUint16(body[2:4], uint16(len(m.Signature))) + copy(body[4:], m.Signature) + return MarshalHandshakeMessage(HandshakeMessage{Type: HandshakeTypeCertificateVerify, Body: body}) +} + +func UnmarshalCertificateVerifyMessage(handshakeBytes []byte) (CertificateVerifyMessage, []byte, error) { + hm, rest, err := UnmarshalHandshakeMessage(handshakeBytes) + if err != nil { + return CertificateVerifyMessage{}, nil, err + } + if len(rest) != 0 || hm.Type != HandshakeTypeCertificateVerify { + return CertificateVerifyMessage{}, nil, ErrInvalidLength + } + if len(hm.Body) < 4 { + return CertificateVerifyMessage{}, nil, ErrTruncated + } + alg := readUint16(hm.Body[0:2]) + sigLen := int(readUint16(hm.Body[2:4])) + if len(hm.Body) != 4+sigLen { + return CertificateVerifyMessage{}, nil, ErrInvalidLength + } + sig := append([]byte(nil), hm.Body[4:]...) + raw, _ := MarshalHandshakeMessage(hm) + return CertificateVerifyMessage{Algorithm: alg, Signature: sig}, raw, nil +} + +var eaContextString = []byte("Exported Authenticator") + +func buildCertVerifyInput(transcriptHash []byte) []byte { + prefix := bytes.Repeat([]byte{0x20}, 64) + out := make([]byte, 0, len(prefix)+len(eaContextString)+1+len(transcriptHash)) + out = append(out, prefix...) + out = append(out, eaContextString...) + out = append(out, 0x00) + out = append(out, transcriptHash...) + return out +} + +func signCertVerify(priv any, scheme uint16, transcriptHash []byte) ([]byte, error) { + info, err := signatureSchemeInfo(scheme) + if err != nil { + return nil, err + } + msg := buildCertVerifyInput(transcriptHash) + + switch info.Alg { + case sigAlgECDSA: + k, ok := priv.(*ecdsa.PrivateKey) + if !ok { + return nil, ErrUnsupportedSignatureScheme + } + h := info.Hash.New() + h.Write(msg) + return ecdsa.SignASN1(rand.Reader, k, h.Sum(nil)) + + case sigAlgRSAPSS: + k, ok := priv.(*rsa.PrivateKey) + if !ok { + return nil, ErrUnsupportedSignatureScheme + } + h := info.Hash.New() + h.Write(msg) + return rsa.SignPSS(rand.Reader, k, info.Hash, h.Sum(nil), + &rsa.PSSOptions{SaltLength: rsa.PSSSaltLengthEqualsHash, Hash: info.Hash}) + + case sigAlgEd25519: + k, ok := priv.(ed25519.PrivateKey) + if !ok { + return nil, ErrUnsupportedSignatureScheme + } + return ed25519.Sign(k, msg), nil + + default: + return nil, ErrUnsupportedSignatureScheme + } +} + +func verifyCertVerify(pub any, scheme uint16, transcriptHash []byte, signature []byte) error { + info, err := signatureSchemeInfo(scheme) + if err != nil { + return err + } + msg := buildCertVerifyInput(transcriptHash) + + switch info.Alg { + case sigAlgECDSA: + k, ok := pub.(*ecdsa.PublicKey) + if !ok { + return ErrUnsupportedSignatureScheme + } + h := info.Hash.New() + h.Write(msg) + if !ecdsa.VerifyASN1(k, h.Sum(nil), signature) { + return ErrSignatureMismatch + } + return nil + + case sigAlgRSAPSS: + k, ok := pub.(*rsa.PublicKey) + if !ok { + return ErrUnsupportedSignatureScheme + } + h := info.Hash.New() + h.Write(msg) + if err := rsa.VerifyPSS(k, info.Hash, h.Sum(nil), signature, + &rsa.PSSOptions{SaltLength: rsa.PSSSaltLengthEqualsHash, Hash: info.Hash}); err != nil { + return ErrSignatureMismatch + } + return nil + + case sigAlgEd25519: + k, ok := pub.(ed25519.PublicKey) + if !ok { + return ErrUnsupportedSignatureScheme + } + if !ed25519.Verify(k, msg, signature) { + return ErrSignatureMismatch + } + return nil + + default: + return ErrUnsupportedSignatureScheme + } +} diff --git a/pkg/atls/ea/cmw_attestation.go b/pkg/atls/ea/cmw_attestation.go new file mode 100644 index 00000000..b96f6e58 --- /dev/null +++ b/pkg/atls/ea/cmw_attestation.go @@ -0,0 +1,51 @@ +// Copyright (c) Ultraviolet +// SPDX-License-Identifier: Apache-2.0 + +package ea + +const CMWAttestationExtensionType uint16 = 0xFF00 + +func CMWAttestationOfferExtension() Extension { + return Extension{Type: CMWAttestationExtensionType, Data: nil} +} + +func CMWAttestationDataExtension(cmw []byte) (Extension, error) { + if len(cmw) == 0 || len(cmw) > 0xFFFF { + return Extension{}, ErrInvalidLength + } + data := make([]byte, 2+len(cmw)) + putUint16(data[0:2], uint16(len(cmw))) + copy(data[2:], cmw) + return Extension{Type: CMWAttestationExtensionType, Data: data}, nil +} + +func ExtractCMWAttestationFromExtensions(exts []Extension) ([]byte, bool, error) { + for _, e := range exts { + if e.Type != CMWAttestationExtensionType { + continue + } + if len(e.Data) < 2 { + return nil, true, ErrInvalidLength + } + l := int(readUint16(e.Data[0:2])) + if l <= 0 || l != len(e.Data)-2 { + return nil, true, ErrInvalidLength + } + return append([]byte(nil), e.Data[2:]...), true, nil + } + return nil, false, nil +} + +func ValidateCMWAttestationPlacement(entries []CertificateEntry) error { + for i, entry := range entries { + for _, ext := range entry.Extensions { + if ext.Type != CMWAttestationExtensionType { + continue + } + if i != 0 { + return ErrBadRequest + } + } + } + return nil +} diff --git a/pkg/atls/ea/exporters.go b/pkg/atls/ea/exporters.go new file mode 100644 index 00000000..d5bdb0c7 --- /dev/null +++ b/pkg/atls/ea/exporters.go @@ -0,0 +1,67 @@ +// Copyright (c) Ultraviolet +// SPDX-License-Identifier: Apache-2.0 + +package ea + +import ( + "crypto" + "crypto/tls" + "fmt" +) + +const ( + LabelClientAuthenticatorHandshakeContext = "EXPORTER-client authenticator handshake context" + LabelServerAuthenticatorHandshakeContext = "EXPORTER-server authenticator handshake context" + LabelClientAuthenticatorFinishedKey = "EXPORTER-client authenticator finished key" + LabelServerAuthenticatorFinishedKey = "EXPORTER-server authenticator finished key" +) + +type Role uint8 + +const ( + RoleClient Role = iota + 1 + RoleServer +) + +func AuthenticatorHashTLS13(cipherSuite uint16) (crypto.Hash, error) { + switch cipherSuite { + case tls.TLS_AES_128_GCM_SHA256, tls.TLS_CHACHA20_POLY1305_SHA256: + return crypto.SHA256, nil + case tls.TLS_AES_256_GCM_SHA384: + return crypto.SHA384, nil + default: + return 0, fmt.Errorf("%w: %s (0x%04x)", ErrUnknownCipherSuite, tls.CipherSuiteName(cipherSuite), cipherSuite) + } +} + +func ExportHandshakeContext(st *tls.ConnectionState, role Role) ([]byte, crypto.Hash, error) { + if st.Version != tls.VersionTLS13 { + return nil, 0, ErrNotTLS13 + } + h, err := AuthenticatorHashTLS13(st.CipherSuite) + if err != nil { + return nil, 0, err + } + label := LabelClientAuthenticatorHandshakeContext + if role == RoleServer { + label = LabelServerAuthenticatorHandshakeContext + } + out, err := st.ExportKeyingMaterial(label, nil, h.Size()) + return out, h, err +} + +func ExportFinishedKey(st *tls.ConnectionState, role Role) ([]byte, crypto.Hash, error) { + if st.Version != tls.VersionTLS13 { + return nil, 0, ErrNotTLS13 + } + h, err := AuthenticatorHashTLS13(st.CipherSuite) + if err != nil { + return nil, 0, err + } + label := LabelClientAuthenticatorFinishedKey + if role == RoleServer { + label = LabelServerAuthenticatorFinishedKey + } + out, err := st.ExportKeyingMaterial(label, nil, h.Size()) + return out, h, err +} diff --git a/pkg/atls/ea/extensions.go b/pkg/atls/ea/extensions.go new file mode 100644 index 00000000..3625d55c --- /dev/null +++ b/pkg/atls/ea/extensions.go @@ -0,0 +1,61 @@ +// Copyright (c) Ultraviolet +// SPDX-License-Identifier: Apache-2.0 + +package ea + +type Extension struct { + Type uint16 + Data []byte +} + +func MarshalExtensions(exts []Extension) ([]byte, error) { + payloadLen := 0 + for _, e := range exts { + if len(e.Data) > 0xFFFF { + return nil, ErrInvalidLength + } + payloadLen += 4 + len(e.Data) + if payloadLen > 0xFFFF { + return nil, ErrInvalidLength + } + } + out := make([]byte, 2+payloadLen) + putUint16(out[0:2], uint16(payloadLen)) + off := 2 + for _, e := range exts { + putUint16(out[off:off+2], e.Type) + putUint16(out[off+2:off+4], uint16(len(e.Data))) + copy(out[off+4:], e.Data) + off += 4 + len(e.Data) + } + return out, nil +} + +func UnmarshalExtensions(b []byte) (exts []Extension, rest []byte, err error) { + if len(b) < 2 { + return nil, nil, ErrTruncated + } + total := int(readUint16(b[0:2])) + if len(b) < 2+total { + return nil, nil, ErrTruncated + } + payload := b[2 : 2+total] + rest = b[2+total:] + i := 0 + for i < len(payload) { + if len(payload)-i < 4 { + return nil, nil, ErrTruncated + } + typ := readUint16(payload[i : i+2]) + l := int(readUint16(payload[i+2 : i+4])) + i += 4 + if l < 0 || l > len(payload)-i { + return nil, nil, ErrInvalidLength + } + data := make([]byte, l) + copy(data, payload[i:i+l]) + i += l + exts = append(exts, Extension{Type: typ, Data: data}) + } + return exts, rest, nil +} diff --git a/pkg/atls/ea/finished.go b/pkg/atls/ea/finished.go new file mode 100644 index 00000000..ce0c2827 --- /dev/null +++ b/pkg/atls/ea/finished.go @@ -0,0 +1,27 @@ +// Copyright (c) Ultraviolet +// SPDX-License-Identifier: Apache-2.0 + +package ea + +type FinishedMessage struct { + VerifyData []byte +} + +func (m FinishedMessage) Marshal() ([]byte, error) { + if len(m.VerifyData) == 0 { + return nil, ErrInvalidLength + } + return MarshalHandshakeMessage(HandshakeMessage{Type: HandshakeTypeFinished, Body: append([]byte(nil), m.VerifyData...)}) +} + +func UnmarshalFinishedMessage(handshakeBytes []byte) (FinishedMessage, []byte, error) { + hm, rest, err := UnmarshalHandshakeMessage(handshakeBytes) + if err != nil { + return FinishedMessage{}, nil, err + } + if len(rest) != 0 || hm.Type != HandshakeTypeFinished || len(hm.Body) == 0 { + return FinishedMessage{}, nil, ErrInvalidLength + } + raw, _ := MarshalHandshakeMessage(hm) + return FinishedMessage{VerifyData: append([]byte(nil), hm.Body...)}, raw, nil +} diff --git a/pkg/atls/ea/handshake.go b/pkg/atls/ea/handshake.go new file mode 100644 index 00000000..c5b09ad2 --- /dev/null +++ b/pkg/atls/ea/handshake.go @@ -0,0 +1,55 @@ +// Copyright (c) Ultraviolet +// SPDX-License-Identifier: Apache-2.0 + +package ea + +const handshakeHeaderLen = 4 // 1 byte type + 3 byte uint24 len + +const ( + HandshakeTypeCertificate uint8 = 11 + HandshakeTypeCertificateRequest uint8 = 13 + HandshakeTypeCertificateVerify uint8 = 15 + HandshakeTypeClientCertificateRequest uint8 = 17 + HandshakeTypeFinished uint8 = 20 +) + +type HandshakeMessage struct { + Type uint8 + Body []byte +} + +func MarshalHandshakeMessage(m HandshakeMessage) ([]byte, error) { + if len(m.Body) > 0xFFFFFF { + return nil, ErrInvalidLength + } + out := make([]byte, handshakeHeaderLen+len(m.Body)) + out[0] = m.Type + putUint24(out[1:4], uint32(len(m.Body))) + copy(out[4:], m.Body) + return out, nil +} + +func UnmarshalHandshakeMessage(b []byte) (msg HandshakeMessage, rest []byte, err error) { + if len(b) < handshakeHeaderLen { + return HandshakeMessage{}, nil, ErrTruncated + } + t := b[0] + n := int(readUint24(b[1:4])) + if n < 0 || n > 0xFFFFFF { + return HandshakeMessage{}, nil, ErrInvalidLength + } + if len(b) < handshakeHeaderLen+n { + return HandshakeMessage{}, nil, ErrTruncated + } + body := make([]byte, n) + copy(body, b[4:4+n]) + return HandshakeMessage{Type: t, Body: body}, b[4+n:], nil +} + +func putUint24(dst []byte, v uint32) { dst[0] = byte(v >> 16); dst[1] = byte(v >> 8); dst[2] = byte(v) } +func readUint24(src []byte) uint32 { + return (uint32(src[0]) << 16) | (uint32(src[1]) << 8) | uint32(src[2]) +} + +func putUint16(dst []byte, v uint16) { dst[0] = byte(v >> 8); dst[1] = byte(v) } +func readUint16(src []byte) uint16 { return (uint16(src[0]) << 8) | uint16(src[1]) } diff --git a/pkg/atls/ea/policy.go b/pkg/atls/ea/policy.go new file mode 100644 index 00000000..b94692a8 --- /dev/null +++ b/pkg/atls/ea/policy.go @@ -0,0 +1,60 @@ +// Copyright (c) Ultraviolet +// SPDX-License-Identifier: Apache-2.0 + +package ea + +type SpontaneousAuthenticatorPolicy struct { + AllowedSignatureSchemes []uint16 + AllowedCertificateExtensions []uint16 +} + +func RequestPermitsCertificateExtension(req *AuthenticatorRequest, typ uint16) bool { + if req == nil { + return false + } + for _, e := range req.Extensions { + if e.Type == typ { + return true + } + } + return false +} + +func PolicyPermitsCertificateExtension(policy *SpontaneousAuthenticatorPolicy, typ uint16) bool { + if policy == nil { + return false + } + for _, allowed := range policy.AllowedCertificateExtensions { + if allowed == typ { + return true + } + } + return false +} + +func policyPermitsSignatureScheme(policy *SpontaneousAuthenticatorPolicy, scheme uint16) bool { + if policy == nil || len(policy.AllowedSignatureSchemes) == 0 { + return true + } + for _, allowed := range policy.AllowedSignatureSchemes { + if allowed == scheme { + return true + } + } + return false +} + +func validateCertificateEntryExtensions(exts []Extension, req *AuthenticatorRequest, policy *SpontaneousAuthenticatorPolicy) error { + for _, e := range exts { + if req != nil { + if !RequestPermitsCertificateExtension(req, e.Type) { + return ErrBadRequest + } + continue + } + if !PolicyPermitsCertificateExtension(policy, e.Type) { + return ErrBadRequest + } + } + return nil +} diff --git a/pkg/atls/ea/request.go b/pkg/atls/ea/request.go new file mode 100644 index 00000000..670f4193 --- /dev/null +++ b/pkg/atls/ea/request.go @@ -0,0 +1,214 @@ +// Copyright (c) Ultraviolet +// SPDX-License-Identifier: Apache-2.0 + +package ea + +import ( + "crypto/rand" + "encoding/asn1" +) + +const ( + SignatureAlgorithmsExtensionType uint16 = 0x000d + ServerNameExtensionType uint16 = 0x0000 + CertificateAuthoritiesExtensionType uint16 = 0x002f + OIDFiltersExtensionType uint16 = 0x0030 + SignatureAlgorithmsCertExtensionType uint16 = 0x0032 +) + +type AuthenticatorRequest struct { + Type uint8 + Context []byte + Extensions []Extension +} + +type OIDFilter struct { + OID asn1.ObjectIdentifier + Values []byte +} + +func NewRandomContext(n int) ([]byte, error) { + if n <= 0 || n > 255 { + return nil, ErrInvalidLength + } + b := make([]byte, n) + _, err := rand.Read(b) + return b, err +} + +func (r AuthenticatorRequest) Marshal() ([]byte, error) { + if r.Type != HandshakeTypeCertificateRequest && r.Type != HandshakeTypeClientCertificateRequest { + return nil, ErrUnsupportedHandshakeType + } + if len(r.Context) > 255 { + return nil, ErrInvalidLength + } + extVec, err := MarshalExtensions(r.Extensions) + if err != nil { + return nil, err + } + body := make([]byte, 1+len(r.Context)+len(extVec)) + body[0] = byte(len(r.Context)) + copy(body[1:], r.Context) + copy(body[1+len(r.Context):], extVec) + return MarshalHandshakeMessage(HandshakeMessage{Type: r.Type, Body: body}) +} + +func UnmarshalAuthenticatorRequest(handshakeBytes []byte) (AuthenticatorRequest, []byte, error) { + hm, rest, err := UnmarshalHandshakeMessage(handshakeBytes) + if err != nil { + return AuthenticatorRequest{}, nil, err + } + if hm.Type != HandshakeTypeCertificateRequest && hm.Type != HandshakeTypeClientCertificateRequest { + return AuthenticatorRequest{}, nil, ErrUnsupportedHandshakeType + } + if len(hm.Body) < 1 { + return AuthenticatorRequest{}, nil, ErrTruncated + } + ctxLen := int(hm.Body[0]) + if len(hm.Body) < 1+ctxLen { + return AuthenticatorRequest{}, nil, ErrTruncated + } + ctx := append([]byte(nil), hm.Body[1:1+ctxLen]...) + exts, leftover, err := UnmarshalExtensions(hm.Body[1+ctxLen:]) + if err != nil { + return AuthenticatorRequest{}, nil, err + } + if len(leftover) != 0 { + return AuthenticatorRequest{}, nil, ErrInvalidLength + } + return AuthenticatorRequest{ + Type: hm.Type, + Context: ctx, + Extensions: exts, + }, rest, nil +} + +func (r AuthenticatorRequest) SignatureSchemes() ([]uint16, bool) { + return parseSignatureSchemesExtension(r.Extensions, SignatureAlgorithmsExtensionType) +} + +func (r AuthenticatorRequest) SignatureSchemesCert() ([]uint16, bool) { + return parseSignatureSchemesExtension(r.Extensions, SignatureAlgorithmsCertExtensionType) +} + +func (r AuthenticatorRequest) CertificateAuthorities() ([][]byte, bool) { + for _, e := range r.Extensions { + if e.Type != CertificateAuthoritiesExtensionType { + continue + } + if len(e.Data) < 2 { + return nil, false + } + total := int(readUint16(e.Data[0:2])) + if total < 3 || len(e.Data) != 2+total { + return nil, false + } + var out [][]byte + for off := 2; off < len(e.Data); { + if len(e.Data)-off < 2 { + return nil, false + } + l := int(readUint16(e.Data[off : off+2])) + off += 2 + if l == 0 || l > len(e.Data)-off { + return nil, false + } + out = append(out, append([]byte(nil), e.Data[off:off+l]...)) + off += l + } + return out, true + } + return nil, false +} + +func (r AuthenticatorRequest) OIDFilters() ([]OIDFilter, bool) { + for _, e := range r.Extensions { + if e.Type != OIDFiltersExtensionType { + continue + } + if len(e.Data) < 2 { + return nil, false + } + total := int(readUint16(e.Data[0:2])) + if len(e.Data) != 2+total { + return nil, false + } + var out []OIDFilter + for off := 2; off < len(e.Data); { + if len(e.Data)-off < 1 { + return nil, false + } + oidLen := int(e.Data[off]) + off++ + if oidLen == 0 || oidLen > len(e.Data)-off { + return nil, false + } + rawOID := append([]byte(nil), e.Data[off:off+oidLen]...) + off += oidLen + var oid asn1.ObjectIdentifier + if _, err := asn1.Unmarshal(rawOID, &oid); err != nil { + return nil, false + } + if len(e.Data)-off < 2 { + return nil, false + } + valLen := int(readUint16(e.Data[off : off+2])) + off += 2 + if valLen > len(e.Data)-off { + return nil, false + } + values := append([]byte(nil), e.Data[off:off+valLen]...) + off += valLen + out = append(out, OIDFilter{OID: oid, Values: values}) + } + return out, true + } + return nil, false +} + +func parseSignatureSchemesExtension(exts []Extension, typ uint16) ([]uint16, bool) { + for _, e := range exts { + if e.Type != typ { + continue + } + if len(e.Data) < 2 { + return nil, false + } + vecLen := int(readUint16(e.Data[0:2])) + if vecLen < 2 || vecLen%2 != 0 || len(e.Data) != 2+vecLen { + return nil, false + } + out := make([]uint16, 0, vecLen/2) + for off := 2; off < len(e.Data); off += 2 { + out = append(out, readUint16(e.Data[off:off+2])) + } + return out, true + } + return nil, false +} + +func SignatureAlgorithmsExtension(schemes []uint16) (Extension, error) { + return marshalSignatureSchemesExtension(SignatureAlgorithmsExtensionType, schemes) +} + +func SignatureAlgorithmsCertExtension(schemes []uint16) (Extension, error) { + return marshalSignatureSchemesExtension(SignatureAlgorithmsCertExtensionType, schemes) +} + +func marshalSignatureSchemesExtension(typ uint16, schemes []uint16) (Extension, error) { + if len(schemes) == 0 { + return Extension{}, ErrInvalidLength + } + if len(schemes) > 0x7fff { + return Extension{}, ErrInvalidLength + } + data := make([]byte, 2+2*len(schemes)) + putUint16(data[0:2], uint16(2*len(schemes))) + off := 2 + for _, s := range schemes { + putUint16(data[off:off+2], s) + off += 2 + } + return Extension{Type: typ, Data: data}, nil +} diff --git a/pkg/atls/ea/session.go b/pkg/atls/ea/session.go new file mode 100644 index 00000000..77747e3d --- /dev/null +++ b/pkg/atls/ea/session.go @@ -0,0 +1,59 @@ +// Copyright (c) Ultraviolet +// SPDX-License-Identifier: Apache-2.0 + +package ea + +import ( + "crypto/tls" + "crypto/x509" + "sync" + + eaattestation "github.com/ultravioletrs/cocos/pkg/atls/eaattestation" +) + +type Session struct { + mu sync.Mutex + used map[string]struct{} +} + +func NewSession() *Session { + return &Session{used: make(map[string]struct{})} +} + +func (s *Session) MarkContextUsed(ctx []byte) error { + if s == nil { + return nil + } + key := string(ctx) + s.mu.Lock() + defer s.mu.Unlock() + if _, ok := s.used[key]; ok { + return ErrContextReuse + } + s.used[key] = struct{}{} + return nil +} + +func (s *Session) CreateAuthenticator(st *tls.ConnectionState, role Role, req *AuthenticatorRequest, identity tls.Certificate, leafEntryExtensions []Extension) ([]byte, error) { + return createAuthenticator(s, st, role, req, nil, identity, leafEntryExtensions) +} + +func (s *Session) ValidateAuthenticator(st *tls.ConnectionState, role Role, req *AuthenticatorRequest, authBytes []byte, verifyOpts *x509.VerifyOptions) (*ValidationResult, error) { + return validateAuthenticator(s, st, role, req, nil, nil, authBytes, verifyOpts) +} + +func (s *Session) CreateAuthenticatorWithPolicy(st *tls.ConnectionState, role Role, req *AuthenticatorRequest, policy *SpontaneousAuthenticatorPolicy, identity tls.Certificate, leafEntryExtensions []Extension) ([]byte, error) { + return createAuthenticator(s, st, role, req, policy, identity, leafEntryExtensions) +} + +func (s *Session) ValidateAuthenticatorWithPolicy(st *tls.ConnectionState, role Role, req *AuthenticatorRequest, policy *SpontaneousAuthenticatorPolicy, authBytes []byte, verifyOpts *x509.VerifyOptions) (*ValidationResult, error) { + return validateAuthenticator(s, st, role, req, policy, nil, authBytes, verifyOpts) +} + +func (s *Session) ValidateAuthenticatorWithAttestation(st *tls.ConnectionState, role Role, req *AuthenticatorRequest, authBytes []byte, verifyOpts *x509.VerifyOptions, attPolicy eaattestation.VerificationPolicy) (*ValidationResult, error) { + return validateAuthenticator(s, st, role, req, nil, &attPolicy, authBytes, verifyOpts) +} + +func (s *Session) ValidateAuthenticatorWithPolicies(st *tls.ConnectionState, role Role, req *AuthenticatorRequest, policy *SpontaneousAuthenticatorPolicy, authBytes []byte, verifyOpts *x509.VerifyOptions, attPolicy eaattestation.VerificationPolicy) (*ValidationResult, error) { + return validateAuthenticator(s, st, role, req, policy, &attPolicy, authBytes, verifyOpts) +} diff --git a/pkg/atls/ea/sigscheme.go b/pkg/atls/ea/sigscheme.go new file mode 100644 index 00000000..147fc44a --- /dev/null +++ b/pkg/atls/ea/sigscheme.go @@ -0,0 +1,82 @@ +// Copyright (c) Ultraviolet +// SPDX-License-Identifier: Apache-2.0 + +package ea + +import ( + "crypto" + "crypto/ecdsa" + "crypto/ed25519" + "crypto/rsa" + "crypto/tls" + "fmt" +) + +type sigAlg uint8 + +const ( + sigAlgECDSA sigAlg = iota + 1 + sigAlgRSAPSS + sigAlgEd25519 +) + +type sigSchemeInfo struct { + Scheme uint16 + Alg sigAlg + Hash crypto.Hash // 0 for Ed25519 +} + +func signatureSchemeInfo(s uint16) (sigSchemeInfo, error) { + switch s { + case uint16(tls.ECDSAWithP256AndSHA256): + return sigSchemeInfo{s, sigAlgECDSA, crypto.SHA256}, nil + case uint16(tls.PSSWithSHA256): + return sigSchemeInfo{s, sigAlgRSAPSS, crypto.SHA256}, nil + case uint16(tls.Ed25519): + return sigSchemeInfo{s, sigAlgEd25519, 0}, nil + default: + return sigSchemeInfo{}, fmt.Errorf("%w: 0x%04x", ErrUnsupportedSignatureScheme, s) + } +} + +func chooseSignatureScheme(priv any, offered []uint16) (uint16, error) { + compat := func(s uint16) bool { + info, err := signatureSchemeInfo(s) + if err != nil { + return false + } + switch info.Alg { + case sigAlgECDSA: + k, ok := priv.(*ecdsa.PrivateKey) + return ok && k.Curve.Params().Name == "P-256" + case sigAlgRSAPSS: + _, ok := priv.(*rsa.PrivateKey) + return ok + case sigAlgEd25519: + _, ok := priv.(ed25519.PrivateKey) + return ok + default: + return false + } + } + + if len(offered) > 0 { + for _, s := range offered { + if compat(s) { + return s, nil + } + } + return 0, ErrUnsupportedSignatureScheme + } + + if k, ok := priv.(*ecdsa.PrivateKey); ok && k.Curve.Params().Name == "P-256" { + return uint16(tls.ECDSAWithP256AndSHA256), nil + } + if _, ok := priv.(*rsa.PrivateKey); ok { + return uint16(tls.PSSWithSHA256), nil + } + if _, ok := priv.(ed25519.PrivateKey); ok { + return uint16(tls.Ed25519), nil + } + return 0, ErrUnsupportedSignatureScheme +} diff --git a/pkg/atls/ea/util.go b/pkg/atls/ea/util.go new file mode 100644 index 00000000..6ac10831 --- /dev/null +++ b/pkg/atls/ea/util.go @@ -0,0 +1,34 @@ +// Copyright (c) Ultraviolet +// SPDX-License-Identifier: Apache-2.0 + +package ea + +import ( + "crypto" + "crypto/hmac" + "crypto/subtle" +) + +func hashConcat(h crypto.Hash, chunks ...[]byte) []byte { + hs := h.New() + for _, c := range chunks { + if len(c) == 0 { + continue + } + hs.Write(c) + } + return hs.Sum(nil) +} + +func hmacSum(h crypto.Hash, key, data []byte) []byte { + m := hmac.New(h.New, key) + m.Write(data) + return m.Sum(nil) +} + +func constantTimeEqual(a, b []byte) bool { + if len(a) != len(b) { + return false + } + return subtle.ConstantTimeCompare(a, b) == 1 +} diff --git a/pkg/atls/eaattestation/binding.go b/pkg/atls/eaattestation/binding.go new file mode 100644 index 00000000..96a86e5f --- /dev/null +++ b/pkg/atls/eaattestation/binding.go @@ -0,0 +1,88 @@ +// Copyright (c) Ultraviolet +// SPDX-License-Identifier: Apache-2.0 + +package attestation + +import ( + "crypto" + "crypto/tls" + "crypto/x509" + "errors" + "fmt" +) + +const ( + ExporterLabelAttestation = "Attestation" + ExporterLabelAttestationBinding = "Attestation Binding" +) + +const ExportedAttestationValueLen = 32 + +var errNotTLS13 = errors.New("attestation: not TLS 1.3") + +func ExportAttestationValue(st *tls.ConnectionState, label string, contextValue []byte) ([]byte, crypto.Hash, error) { + if st.Version != tls.VersionTLS13 { + return nil, 0, errNotTLS13 + } + h, err := authenticatorHashTLS13(st.CipherSuite) + if err != nil { + return nil, 0, err + } + out, err := st.ExportKeyingMaterial(label, contextValue, ExportedAttestationValueLen) + if err != nil { + return nil, 0, err + } + return out, h, nil +} + +func PublicKeyBytes(leaf *x509.Certificate) ([]byte, error) { + if leaf == nil { + return nil, fmt.Errorf("nil leaf cert") + } + if len(leaf.RawSubjectPublicKeyInfo) > 0 { + return leaf.RawSubjectPublicKeyInfo, nil + } + b, err := x509.MarshalPKIXPublicKey(leaf.PublicKey) + if err != nil { + return nil, err + } + return b, nil +} + +func AIKPublicKeyHash(h crypto.Hash, pubKey []byte) []byte { + hs := h.New() + hs.Write(pubKey) + return hs.Sum(nil) +} + +func BindingValue(h crypto.Hash, pubKey, exportedValue []byte) []byte { + hs := h.New() + hs.Write(pubKey) + hs.Write(exportedValue) + return hs.Sum(nil) +} + +func ComputeBinding(st *tls.ConnectionState, label string, certificateRequestContext []byte, leaf *x509.Certificate) (exportedValue, aikPubHash, binding []byte, err error) { + exportedValue, h, err := ExportAttestationValue(st, label, certificateRequestContext) + if err != nil { + return nil, nil, nil, err + } + pub, err := PublicKeyBytes(leaf) + if err != nil { + return nil, nil, nil, err + } + aikPubHash = AIKPublicKeyHash(h, pub) + binding = BindingValue(h, pub, exportedValue) + return exportedValue, aikPubHash, binding, nil +} + +func authenticatorHashTLS13(cipherSuite uint16) (crypto.Hash, error) { + switch cipherSuite { + case tls.TLS_AES_128_GCM_SHA256, tls.TLS_CHACHA20_POLY1305_SHA256: + return crypto.SHA256, nil + case tls.TLS_AES_256_GCM_SHA384: + return crypto.SHA384, nil + default: + return 0, fmt.Errorf("attestation: unknown cipher suite: %s (0x%04x)", tls.CipherSuiteName(cipherSuite), cipherSuite) + } +} diff --git a/pkg/atls/eaattestation/binding_test.go b/pkg/atls/eaattestation/binding_test.go new file mode 100644 index 00000000..dcec1aea --- /dev/null +++ b/pkg/atls/eaattestation/binding_test.go @@ -0,0 +1,199 @@ +// Copyright (c) Ultraviolet +// SPDX-License-Identifier: Apache-2.0 + +package attestation + +import ( + "bytes" + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "crypto/tls" + "crypto/x509" + "crypto/x509/pkix" + "math/big" + "net" + "testing" + "time" +) + +type stubEvidenceVerifier struct { + called bool + err error +} + +func (s *stubEvidenceVerifier) VerifyEvidence(evidence []byte) error { + s.called = true + return s.err +} + +type stubResultsVerifier struct { + called bool + err error +} + +func (s *stubResultsVerifier) VerifyAttestationResults(results []byte) error { + s.called = true + return s.err +} + +func makeCert(t *testing.T) (tls.Certificate, *x509.Certificate) { + t.Helper() + priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + if err != nil { + t.Fatal(err) + } + template := &x509.Certificate{ + SerialNumber: big.NewInt(1), + Subject: pkix.Name{CommonName: "binding"}, + NotBefore: time.Now().Add(-time.Hour), + NotAfter: time.Now().Add(time.Hour), + KeyUsage: x509.KeyUsageDigitalSignature, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, + } + der, err := x509.CreateCertificate(rand.Reader, template, template, &priv.PublicKey, priv) + if err != nil { + t.Fatal(err) + } + leaf, err := x509.ParseCertificate(der) + if err != nil { + t.Fatal(err) + } + return tls.Certificate{Certificate: [][]byte{der}, PrivateKey: priv}, leaf +} + +func tls13Client(t *testing.T, cert tls.Certificate) (*tls.Conn, *tls.Conn) { + t.Helper() + srvConf := &tls.Config{Certificates: []tls.Certificate{cert}, MinVersion: tls.VersionTLS13, MaxVersion: tls.VersionTLS13} + cliConf := &tls.Config{InsecureSkipVerify: true, MinVersion: tls.VersionTLS13, MaxVersion: tls.VersionTLS13} + a, b := net.Pipe() + srv := tls.Server(a, srvConf) + cli := tls.Client(b, cliConf) + errCh := make(chan error, 2) + go func() { errCh <- srv.Handshake() }() + go func() { errCh <- cli.Handshake() }() + for i := 0; i < 2; i++ { + if err := <-errCh; err != nil { + t.Fatalf("handshake: %v", err) + } + } + return srv, cli +} + +func TestComputeBindingDeterministic(t *testing.T) { + cert, leaf := makeCert(t) + srv, cli := tls13Client(t, cert) + defer srv.Close() + defer cli.Close() + + st := cli.ConnectionState() + ctx := []byte{1, 2, 3, 4} + + ev1, aik1, b1, err := ComputeBinding(&st, ExporterLabelAttestation, ctx, leaf) + if err != nil { + t.Fatal(err) + } + ev2, aik2, b2, err := ComputeBinding(&st, ExporterLabelAttestation, ctx, leaf) + if err != nil { + t.Fatal(err) + } + + if len(ev1) != ExportedAttestationValueLen { + t.Fatalf("unexpected exported len: %d", len(ev1)) + } + if !bytes.Equal(ev1, ev2) || !bytes.Equal(aik1, aik2) || !bytes.Equal(b1, b2) { + t.Fatalf("expected deterministic outputs for same conn+context") + } + if bytes.Equal(aik1, b1) { + t.Fatalf("unexpected aik == binding") + } +} + +func TestPayloadRoundTrip(t *testing.T) { + payload := Payload{ + Version: 1, + MediaType: "application/eat+cwt", + Evidence: []byte("evidence"), + Binder: AttestationBinder{ + ExporterLabel: ExporterLabelAttestation, + AIKPubHash: []byte("aik"), + Binding: []byte("binding"), + }, + } + + raw, err := MarshalPayload(payload) + if err != nil { + t.Fatal(err) + } + parsed, err := ParsePayload(raw) + if err != nil { + t.Fatal(err) + } + if parsed.Version != 1 || parsed.MediaType != "application/eat+cwt" || string(parsed.Evidence) != "evidence" { + t.Fatalf("unexpected parsed payload: %#v", parsed) + } +} + +func TestVerifyPayloadSuccess(t *testing.T) { + cert, leaf := makeCert(t) + srv, cli := tls13Client(t, cert) + defer srv.Close() + defer cli.Close() + + st := cli.ConnectionState() + ctx := []byte{1, 2, 3, 4} + _, aik, binding, err := ComputeBinding(&st, ExporterLabelAttestation, ctx, leaf) + if err != nil { + t.Fatal(err) + } + + ev := &stubEvidenceVerifier{} + rv := &stubResultsVerifier{} + payload := &Payload{ + Version: 1, + Evidence: []byte("evidence"), + AttestationResults: []byte("results"), + Binder: AttestationBinder{ + ExporterLabel: ExporterLabelAttestation, + AIKPubHash: aik, + Binding: binding, + }, + } + + verified, err := VerifyPayload(&st, ExporterLabelAttestation, ctx, leaf, payload, VerificationPolicy{ + EvidenceVerifier: ev, + ResultsVerifier: rv, + }) + if err != nil { + t.Fatal(err) + } + if !verified.BindingVerified || !verified.EvidenceVerified || !verified.ResultsVerified { + t.Fatalf("unexpected verification result: %#v", verified) + } + if !ev.called || !rv.called { + t.Fatalf("expected both verifiers to be called") + } +} + +func TestVerifyBinderRejectsMismatch(t *testing.T) { + cert, leaf := makeCert(t) + srv, cli := tls13Client(t, cert) + defer srv.Close() + defer cli.Close() + + st := cli.ConnectionState() + ctx := []byte{1, 2, 3, 4} + _, aik, binding, err := ComputeBinding(&st, ExporterLabelAttestation, ctx, leaf) + if err != nil { + t.Fatal(err) + } + binding[0] ^= 0xff + + err = VerifyBinder(&st, ExporterLabelAttestation, ctx, leaf, AttestationBinder{ + AIKPubHash: aik, + Binding: binding, + }) + if err != ErrBindingMismatch { + t.Fatalf("got %v, want %v", err, ErrBindingMismatch) + } +} diff --git a/pkg/atls/eaattestation/types.go b/pkg/atls/eaattestation/types.go new file mode 100644 index 00000000..c1502a25 --- /dev/null +++ b/pkg/atls/eaattestation/types.go @@ -0,0 +1,87 @@ +// Copyright (c) Ultraviolet +// SPDX-License-Identifier: Apache-2.0 + +package attestation + +import ( + "encoding/json" + "errors" +) + +var ( + ErrMalformedPayload = errors.New("attestation: malformed payload") + ErrMissingStatement = errors.New("attestation: missing evidence or attestation results") + ErrMissingBinder = errors.New("attestation: missing attestation binder") + ErrAIKPubHashMismatch = errors.New("attestation: AIK public key hash mismatch") + ErrBindingMismatch = errors.New("attestation: attestation binding mismatch") + ErrEvidenceVerificationMissing = errors.New("attestation: evidence verifier not configured") + ErrResultsVerificationMissing = errors.New("attestation: attestation results verifier not configured") +) + +// Payload models the attestation document carried inside the EA certificate-entry extension. +// It intentionally separates the attestation statement from the TLS binding material. +type Payload struct { + Version int `json:"version"` + MediaType string `json:"media_type,omitempty"` + Evidence []byte `json:"evidence,omitempty"` + AttestationResults []byte `json:"attestation_results,omitempty"` + Binder AttestationBinder `json:"binder"` +} + +type AttestationBinder struct { + ExporterLabel string `json:"exporter_label,omitempty"` + AIKPubHash []byte `json:"aik_pub_hash,omitempty"` + Binding []byte `json:"binding,omitempty"` +} + +type VerifiedPayload struct { + Payload *Payload + EvidenceVerified bool + ResultsVerified bool + BindingVerified bool + UsedExporterLabel string +} + +func (p *Payload) Validate() error { + if p == nil { + return ErrMalformedPayload + } + if len(p.Evidence) == 0 && len(p.AttestationResults) == 0 { + return ErrMissingStatement + } + if len(p.Binder.AIKPubHash) == 0 || len(p.Binder.Binding) == 0 { + return ErrMissingBinder + } + return nil +} + +func (p *Payload) NormalizedExporterLabel(defaultLabel string) string { + if p == nil || p.Binder.ExporterLabel == "" { + return defaultLabel + } + return p.Binder.ExporterLabel +} + +func MarshalPayload(p Payload) ([]byte, error) { + if err := p.Validate(); err != nil { + return nil, err + } + if p.Version == 0 { + p.Version = 1 + } + return json.Marshal(p) +} + +func ParsePayload(raw []byte) (*Payload, error) { + if len(raw) == 0 { + return nil, ErrMalformedPayload + } + var p Payload + if err := json.Unmarshal(raw, &p); err != nil { + return nil, ErrMalformedPayload + } + if err := p.Validate(); err != nil { + return nil, err + } + return &p, nil +} diff --git a/pkg/atls/eaattestation/verify.go b/pkg/atls/eaattestation/verify.go new file mode 100644 index 00000000..cd27ffbc --- /dev/null +++ b/pkg/atls/eaattestation/verify.go @@ -0,0 +1,75 @@ +// Copyright (c) Ultraviolet +// SPDX-License-Identifier: Apache-2.0 + +package attestation + +import ( + "crypto/subtle" + "crypto/tls" + "crypto/x509" +) + +type EvidenceVerifier interface { + VerifyEvidence(evidence []byte) error +} + +type ResultsVerifier interface { + VerifyAttestationResults(results []byte) error +} + +type VerificationPolicy struct { + EvidenceVerifier EvidenceVerifier + ResultsVerifier ResultsVerifier +} + +func VerifyPayload(st *tls.ConnectionState, defaultLabel string, certificateRequestContext []byte, leaf *x509.Certificate, payload *Payload, policy VerificationPolicy) (*VerifiedPayload, error) { + if err := payload.Validate(); err != nil { + return nil, err + } + + verified := &VerifiedPayload{ + Payload: payload, + UsedExporterLabel: payload.NormalizedExporterLabel(defaultLabel), + } + + if len(payload.Evidence) > 0 && policy.EvidenceVerifier != nil { + if err := policy.EvidenceVerifier.VerifyEvidence(payload.Evidence); err != nil { + return nil, err + } + verified.EvidenceVerified = true + } else if len(payload.Evidence) > 0 { + return nil, ErrEvidenceVerificationMissing + } + if len(payload.AttestationResults) > 0 && policy.ResultsVerifier != nil { + if err := policy.ResultsVerifier.VerifyAttestationResults(payload.AttestationResults); err != nil { + return nil, err + } + verified.ResultsVerified = true + } else if len(payload.AttestationResults) > 0 { + return nil, ErrResultsVerificationMissing + } + if err := VerifyBinder(st, verified.UsedExporterLabel, certificateRequestContext, leaf, payload.Binder); err != nil { + return nil, err + } + verified.BindingVerified = true + return verified, nil +} + +func VerifyBinder(st *tls.ConnectionState, label string, certificateRequestContext []byte, leaf *x509.Certificate, binder AttestationBinder) error { + exportedValue, aikPubHash, binding, err := ComputeBinding(st, label, certificateRequestContext, leaf) + if err != nil { + return err + } + _ = exportedValue + if !equalBytes(aikPubHash, binder.AIKPubHash) { + return ErrAIKPubHashMismatch + } + if !equalBytes(binding, binder.Binding) { + return ErrBindingMismatch + } + return nil +} + +func equalBytes(a, b []byte) bool { + return subtle.ConstantTimeCompare(a, b) == 1 +} diff --git a/pkg/atls/evidence_verifier.go b/pkg/atls/evidence_verifier.go new file mode 100644 index 00000000..1f8fb66b --- /dev/null +++ b/pkg/atls/evidence_verifier.go @@ -0,0 +1,92 @@ +// Copyright (c) Ultraviolet +// SPDX-License-Identifier: Apache-2.0 + +package atls + +import ( + "fmt" + "os" + + eaattestation "github.com/ultravioletrs/cocos/pkg/atls/eaattestation" + cocosattestation "github.com/ultravioletrs/cocos/pkg/attestation" + "github.com/ultravioletrs/cocos/pkg/attestation/azure" + "github.com/ultravioletrs/cocos/pkg/attestation/eat" + "github.com/ultravioletrs/cocos/pkg/attestation/tdx" + "github.com/ultravioletrs/cocos/pkg/attestation/vtpm" + "github.com/veraison/corim/corim" +) + +type policyEvidenceVerifier struct { + policyPath string +} + +func NewEvidenceVerifier(policyPath string) eaattestation.EvidenceVerifier { + return &policyEvidenceVerifier{policyPath: policyPath} +} + +func (v *policyEvidenceVerifier) VerifyEvidence(evidence []byte) error { + if v.policyPath == "" { + return fmt.Errorf("atls: attestation policy path is not set") + } + claims, err := eat.DecodeCBOR(evidence, nil) + if err != nil { + return fmt.Errorf("atls: failed to decode EAT evidence: %w", err) + } + manifest, err := loadCoRIM(v.policyPath) + if err != nil { + return err + } + verifier, err := platformVerifier(platformTypeFromClaims(claims.PlatformType)) + if err != nil { + return err + } + return verifier.VerifyWithCoRIM(claims.RawReport, manifest) +} + +func loadCoRIM(path string) (*corim.UnsignedCorim, error) { + corimBytes, err := os.ReadFile(path) + if err != nil { + return nil, fmt.Errorf("atls: failed to read CoRIM file: %w", err) + } + + var sc corim.SignedCorim + if err := sc.FromCOSE(corimBytes); err == nil { + return &sc.UnsignedCorim, nil + } + + var uc corim.UnsignedCorim + if err := uc.FromCBOR(corimBytes); err != nil { + return nil, fmt.Errorf("atls: failed to parse CoRIM: %w", err) + } + return &uc, nil +} + +func platformTypeFromClaims(name string) cocosattestation.PlatformType { + switch name { + case "SNP": + return cocosattestation.SNP + case "TDX": + return cocosattestation.TDX + case "vTPM": + return cocosattestation.VTPM + case "SNP-vTPM": + return cocosattestation.SNPvTPM + case "Azure": + return cocosattestation.Azure + default: + return cocosattestation.NoCC + } +} + +func platformVerifier(platformType cocosattestation.PlatformType) (cocosattestation.Verifier, error) { + switch platformType { + case cocosattestation.SNP, cocosattestation.SNPvTPM, cocosattestation.VTPM: + return vtpm.NewVerifier(nil), nil + case cocosattestation.Azure: + return azure.NewVerifier(nil), nil + case cocosattestation.TDX: + return tdx.NewVerifier(), nil + default: + return nil, fmt.Errorf("atls: unsupported platform type: %d", platformType) + } +} diff --git a/pkg/atls/internal_transport/conn.go b/pkg/atls/internal_transport/conn.go new file mode 100644 index 00000000..00202056 --- /dev/null +++ b/pkg/atls/internal_transport/conn.go @@ -0,0 +1,280 @@ +// Copyright (c) Ultraviolet +// SPDX-License-Identifier: Apache-2.0 + +package internaltransport + +import ( + "context" + "crypto/tls" + "crypto/x509" + "fmt" + "net" + "time" + + "github.com/ultravioletrs/cocos/pkg/atls/ea" + eaattestation "github.com/ultravioletrs/cocos/pkg/atls/eaattestation" +) + +type Conn struct { + *tls.Conn + Request *ea.AuthenticatorRequest + ValidationResult *ea.ValidationResult +} + +type ClientConfig struct { + TLSConfig *tls.Config + Session *ea.Session + VerifyOptions *x509.VerifyOptions + AttestationPolicy eaattestation.VerificationPolicy + Request *ea.AuthenticatorRequest + RequestBuilder func() (*ea.AuthenticatorRequest, error) +} + +type ServerConfig struct { + TLSConfig *tls.Config + Session *ea.Session + Identity tls.Certificate + BuildLeafExtensions func(*tls.ConnectionState, *ea.AuthenticatorRequest, *x509.Certificate) ([]ea.Extension, error) +} + +func Dial(network, address string, cfg *ClientConfig) (*Conn, error) { + return DialWithDialer(new(net.Dialer), network, address, cfg) +} + +func DialContext(ctx context.Context, network, address string, cfg *ClientConfig) (*Conn, error) { + return DialContextWithDialer(ctx, new(net.Dialer), network, address, cfg) +} + +func DialWithDialer(d *net.Dialer, network, address string, cfg *ClientConfig) (*Conn, error) { + if cfg == nil || cfg.TLSConfig == nil { + return nil, fmt.Errorf("atls: missing client TLS config") + } + + rawConn, err := d.Dial(network, address) + if err != nil { + return nil, err + } + + tlsConn := tls.Client(rawConn, cfg.TLSConfig.Clone()) + conn, err := Client(tlsConn, cfg) + if err != nil { + _ = tlsConn.Close() + return nil, err + } + return conn, nil +} + +func DialContextWithDialer(ctx context.Context, d *net.Dialer, network, address string, cfg *ClientConfig) (*Conn, error) { + if cfg == nil || cfg.TLSConfig == nil { + return nil, fmt.Errorf("atls: missing client TLS config") + } + if ctx == nil { + return nil, fmt.Errorf("atls: missing client context") + } + + rawConn, err := d.DialContext(ctx, network, address) + if err != nil { + return nil, err + } + + tlsConn := tls.Client(rawConn, cfg.TLSConfig.Clone()) + conn, err := ClientContext(ctx, tlsConn, cfg) + if err != nil { + _ = tlsConn.Close() + return nil, err + } + return conn, nil +} + +func Client(tlsConn *tls.Conn, cfg *ClientConfig) (*Conn, error) { + return ClientContext(context.Background(), tlsConn, cfg) +} + +func ClientContext(ctx context.Context, tlsConn *tls.Conn, cfg *ClientConfig) (*Conn, error) { + if cfg == nil { + return nil, fmt.Errorf("atls: missing client config") + } + if ctx == nil { + return nil, fmt.Errorf("atls: missing client context") + } + var res *Conn + err := withConnContext(ctx, tlsConn, func() error { + if err := tlsConn.HandshakeContext(ctx); err != nil { + return err + } + + req, err := buildRequest(cfg) + if err != nil { + return err + } + reqBytes, err := req.Marshal() + if err != nil { + return err + } + + if err := writeFrame(tlsConn, frameTypeRequest, reqBytes); err != nil { + return err + } + frameType, authBytes, err := readFrame(tlsConn) + if err != nil { + return err + } + + if frameType != frameTypeAuthenticator { + return fmt.Errorf("atls: unexpected frame type %d", frameType) + } + + st := tlsConn.ConnectionState() + var validation *ea.ValidationResult + if cfg.Session != nil { + validation, err = cfg.Session.ValidateAuthenticatorWithAttestation(&st, ea.RoleServer, req, authBytes, cfg.VerifyOptions, cfg.AttestationPolicy) + } else { + validation, err = ea.ValidateAuthenticatorWithAttestation(&st, ea.RoleServer, req, authBytes, cfg.VerifyOptions, cfg.AttestationPolicy) + } + if err != nil { + return err + } + + res = &Conn{Conn: tlsConn, Request: req, ValidationResult: validation} + return nil + }) + if err != nil { + return nil, err + } + + return res, nil +} + +func Server(tlsConn *tls.Conn, cfg *ServerConfig) (*Conn, error) { + if cfg == nil { + return nil, fmt.Errorf("atls: missing server config") + } + + if err := tlsConn.Handshake(); err != nil { + return nil, err + } + + frameType, reqBytes, err := readFrame(tlsConn) + if err != nil { + return nil, err + } + if frameType != frameTypeRequest { + return nil, fmt.Errorf("atls: unexpected frame type %d", frameType) + } + req, rest, err := ea.UnmarshalAuthenticatorRequest(reqBytes) + if err != nil { + return nil, err + } + if len(rest) != 0 { + return nil, fmt.Errorf("atls: trailing request bytes") + } + + st := tlsConn.ConnectionState() + identity, err := resolveIdentity(cfg) + if err != nil { + return nil, err + } + exts, err := buildServerExtensions(cfg, &st, &req, identity) + if err != nil { + return nil, err + } + + var authBytes []byte + if cfg.Session != nil { + authBytes, err = cfg.Session.CreateAuthenticator(&st, ea.RoleServer, &req, identity, exts) + } else { + authBytes, err = ea.CreateAuthenticator(&st, ea.RoleServer, &req, identity, exts) + } + if err != nil { + return nil, err + } + + if err := writeFrame(tlsConn, frameTypeAuthenticator, authBytes); err != nil { + return nil, err + } + + return &Conn{Conn: tlsConn, Request: &req}, nil +} + +func buildRequest(cfg *ClientConfig) (*ea.AuthenticatorRequest, error) { + if cfg.RequestBuilder != nil { + return cfg.RequestBuilder() + } + if cfg.Request != nil { + return cfg.Request, nil + } + ctx, err := ea.NewRandomContext(32) + if err != nil { + return nil, err + } + sigExt, err := ea.SignatureAlgorithmsExtension([]uint16{uint16(tls.ECDSAWithP256AndSHA256)}) + if err != nil { + return nil, err + } + return &ea.AuthenticatorRequest{ + Type: ea.HandshakeTypeClientCertificateRequest, + Context: ctx, + Extensions: []ea.Extension{ + sigExt, + ea.CMWAttestationOfferExtension(), + }, + }, nil +} + +func resolveIdentity(cfg *ServerConfig) (tls.Certificate, error) { + if len(cfg.Identity.Certificate) > 0 && cfg.Identity.PrivateKey != nil { + return cfg.Identity, nil + } + if cfg.TLSConfig != nil && len(cfg.TLSConfig.Certificates) > 0 { + return cfg.TLSConfig.Certificates[0], nil + } + return tls.Certificate{}, fmt.Errorf("atls: missing server identity") +} + +func buildServerExtensions(cfg *ServerConfig, st *tls.ConnectionState, req *ea.AuthenticatorRequest, identity tls.Certificate) ([]ea.Extension, error) { + if cfg.BuildLeafExtensions == nil { + return nil, nil + } + if len(identity.Certificate) == 0 { + return nil, fmt.Errorf("atls: missing server leaf certificate") + } + leaf, err := x509.ParseCertificate(identity.Certificate[0]) + if err != nil { + return nil, err + } + return cfg.BuildLeafExtensions(st, req, leaf) +} + +func withConnContext(ctx context.Context, conn interface{ SetDeadline(time.Time) error }, fn func() error) error { + if ctx == nil { + return fn() + } + + if deadline, ok := ctx.Deadline(); ok { + if err := conn.SetDeadline(deadline); err != nil { + return err + } + defer func() { + _ = conn.SetDeadline(time.Time{}) + }() + } else { + done := make(chan struct{}) + go func() { + select { + case <-ctx.Done(): + _ = conn.SetDeadline(time.Now()) + case <-done: + } + }() + defer func() { + close(done) + _ = conn.SetDeadline(time.Time{}) + }() + } + + err := fn() + if ctxErr := ctx.Err(); ctxErr != nil { + return ctxErr + } + return err +} diff --git a/pkg/atls/internal_transport/conn_test.go b/pkg/atls/internal_transport/conn_test.go new file mode 100644 index 00000000..86435026 --- /dev/null +++ b/pkg/atls/internal_transport/conn_test.go @@ -0,0 +1,100 @@ +// Copyright (c) Ultraviolet +// SPDX-License-Identifier: Apache-2.0 + +package internaltransport + +import ( + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "crypto/tls" + "crypto/x509" + "crypto/x509/pkix" + "math/big" + "net" + "testing" + "time" +) + +func selfSignedCert(t *testing.T) tls.Certificate { + t.Helper() + + priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + if err != nil { + t.Fatal(err) + } + + template := &x509.Certificate{ + SerialNumber: big.NewInt(1), + Subject: pkix.Name{CommonName: "internal-transport"}, + NotBefore: time.Now().Add(-time.Hour), + NotAfter: time.Now().Add(time.Hour), + KeyUsage: x509.KeyUsageDigitalSignature, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, + DNSNames: []string{"localhost"}, + } + + der, err := x509.CreateCertificate(rand.Reader, template, template, &priv.PublicKey, priv) + if err != nil { + t.Fatal(err) + } + + return tls.Certificate{ + Certificate: [][]byte{der}, + PrivateKey: priv, + } +} + +func TestServerAllowsIdentityWithoutTLSConfig(t *testing.T) { + cert := selfSignedCert(t) + a, b := net.Pipe() + + serverTLS := tls.Server(a, &tls.Config{ + Certificates: []tls.Certificate{cert}, + MinVersion: tls.VersionTLS13, + MaxVersion: tls.VersionTLS13, + }) + clientTLS := tls.Client(b, &tls.Config{ + InsecureSkipVerify: true, + MinVersion: tls.VersionTLS13, + MaxVersion: tls.VersionTLS13, + }) + + type result struct { + conn *Conn + err error + } + serverCh := make(chan result, 1) + clientCh := make(chan result, 1) + + go func() { + conn, err := Server(serverTLS, &ServerConfig{ + Identity: cert, + }) + serverCh <- result{conn: conn, err: err} + }() + + go func() { + conn, err := Client(clientTLS, &ClientConfig{ + TLSConfig: &tls.Config{ + InsecureSkipVerify: true, + MinVersion: tls.VersionTLS13, + MaxVersion: tls.VersionTLS13, + }, + }) + clientCh <- result{conn: conn, err: err} + }() + + srvRes := <-serverCh + cliRes := <-clientCh + + if srvRes.err != nil { + t.Fatalf("server failed: %v", srvRes.err) + } + if cliRes.err != nil { + t.Fatalf("client failed: %v", cliRes.err) + } + + defer srvRes.conn.Close() + defer cliRes.conn.Close() +} diff --git a/pkg/atls/internal_transport/listener.go b/pkg/atls/internal_transport/listener.go new file mode 100644 index 00000000..ae94de14 --- /dev/null +++ b/pkg/atls/internal_transport/listener.go @@ -0,0 +1,49 @@ +// Copyright (c) Ultraviolet +// SPDX-License-Identifier: Apache-2.0 + +package internaltransport + +import ( + "crypto/tls" + "fmt" + "net" +) + +type Listener struct { + raw net.Listener + cfg *ServerConfig +} + +func Listen(network, address string, cfg *ServerConfig) (*Listener, error) { + if cfg == nil || cfg.TLSConfig == nil { + return nil, fmt.Errorf("atls: missing server TLS config") + } + raw, err := net.Listen(network, address) + if err != nil { + return nil, err + } + return &Listener{raw: raw, cfg: cfg}, nil +} + +func (l *Listener) Accept() (net.Conn, error) { + rawConn, err := l.raw.Accept() + if err != nil { + return nil, err + } + tlsConn := tls.Server(rawConn, l.cfg.TLSConfig.Clone()) + conn, err := Server(tlsConn, l.cfg) + if err != nil { + _ = tlsConn.Close() + return nil, err + } + + return conn, nil +} + +func (l *Listener) Close() error { + return l.raw.Close() +} + +func (l *Listener) Addr() net.Addr { + return l.raw.Addr() +} diff --git a/pkg/atls/internal_transport/protocol.go b/pkg/atls/internal_transport/protocol.go new file mode 100644 index 00000000..cdfd0b1d --- /dev/null +++ b/pkg/atls/internal_transport/protocol.go @@ -0,0 +1,62 @@ +// Copyright (c) Ultraviolet +// SPDX-License-Identifier: Apache-2.0 + +package internaltransport + +import ( + "encoding/binary" + "fmt" + "io" +) + +const ( + frameTypeRequest uint8 = iota + 1 + frameTypeAuthenticator + maxFramePayloadLen = 16 << 20 // 16 MiB +) + +func writeFrame(w io.Writer, typ uint8, payload []byte) error { + if len(payload) > maxFramePayloadLen { + return fmt.Errorf("atls: frame payload too large: %d", len(payload)) + } + header := make([]byte, 5) + header[0] = typ + binary.BigEndian.PutUint32(header[1:5], uint32(len(payload))) + if err := writeAll(w, header); err != nil { + return err + } + if len(payload) == 0 { + return nil + } + return writeAll(w, payload) +} + +func writeAll(w io.Writer, b []byte) error { + for len(b) > 0 { + n, err := w.Write(b) + if err != nil { + return err + } + b = b[n:] + } + return nil +} + +func readFrame(r io.Reader) (uint8, []byte, error) { + header := make([]byte, 5) + if _, err := io.ReadFull(r, header); err != nil { + return 0, nil, err + } + if header[0] == 0 { + return 0, nil, fmt.Errorf("atls: invalid frame type") + } + n := binary.BigEndian.Uint32(header[1:5]) + if n > maxFramePayloadLen { + return 0, nil, fmt.Errorf("atls: frame payload too large: %d", n) + } + payload := make([]byte, n) + if _, err := io.ReadFull(r, payload); err != nil { + return 0, nil, err + } + return header[0], payload, nil +} diff --git a/pkg/atls/internal_transport/protocol_test.go b/pkg/atls/internal_transport/protocol_test.go new file mode 100644 index 00000000..bc0b696e --- /dev/null +++ b/pkg/atls/internal_transport/protocol_test.go @@ -0,0 +1,40 @@ +// Copyright (c) Ultraviolet +// SPDX-License-Identifier: Apache-2.0 + +package internaltransport + +import ( + "bytes" + "encoding/binary" + "strings" + "testing" +) + +func TestReadFrameRejectsOversizedPayload(t *testing.T) { + var buf bytes.Buffer + header := make([]byte, 5) + header[0] = frameTypeRequest + binary.BigEndian.PutUint32(header[1:5], maxFramePayloadLen+1) + if _, err := buf.Write(header); err != nil { + t.Fatal(err) + } + + _, _, err := readFrame(&buf) + if err == nil { + t.Fatal("expected oversized frame error") + } + if !strings.Contains(err.Error(), "frame payload too large") { + t.Fatalf("unexpected error: %v", err) + } +} + +func TestWriteFrameRejectsOversizedPayload(t *testing.T) { + payload := make([]byte, maxFramePayloadLen+1) + err := writeFrame(bytes.NewBuffer(nil), frameTypeAuthenticator, payload) + if err == nil { + t.Fatal("expected oversized frame error") + } + if !strings.Contains(err.Error(), "frame payload too large") { + t.Fatalf("unexpected error: %v", err) + } +} diff --git a/pkg/atls/mocks/certificateprovider.go b/pkg/atls/mocks/certificateprovider.go index 77524bc5..1e1b63f6 100644 --- a/pkg/atls/mocks/certificateprovider.go +++ b/pkg/atls/mocks/certificateprovider.go @@ -1,103 +1,53 @@ +// Code generated manually for tests. // Copyright (c) Ultraviolet // SPDX-License-Identifier: Apache-2.0 -// Code generated by mockery; DO NOT EDIT. -// github.com/vektra/mockery -// template: testify - package mocks import ( "crypto/tls" + "crypto/x509" - mock "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/mock" + "github.com/ultravioletrs/cocos/pkg/atls/ea" ) -// NewCertificateProvider creates a new instance of CertificateProvider. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. -// The first argument is typically a *testing.T value. -func NewCertificateProvider(t interface { - mock.TestingT - Cleanup(func()) -}) *CertificateProvider { - mock := &CertificateProvider{} - mock.Mock.Test(t) - - t.Cleanup(func() { mock.AssertExpectations(t) }) - - return mock -} - -// CertificateProvider is an autogenerated mock type for the CertificateProvider type type CertificateProvider struct { mock.Mock } -type CertificateProvider_Expecter struct { - mock *mock.Mock -} +func (_m *CertificateProvider) BuildLeafExtensions(st *tls.ConnectionState, req *ea.AuthenticatorRequest, leaf *x509.Certificate) ([]ea.Extension, error) { + ret := _m.Called(st, req, leaf) -func (_m *CertificateProvider) EXPECT() *CertificateProvider_Expecter { - return &CertificateProvider_Expecter{mock: &_m.Mock} -} - -// GetCertificate provides a mock function for the type CertificateProvider -func (_mock *CertificateProvider) GetCertificate(clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) { - ret := _mock.Called(clientHello) - - if len(ret) == 0 { - panic("no return value specified for GetCertificate") - } - - var r0 *tls.Certificate + var r0 []ea.Extension var r1 error - if returnFunc, ok := ret.Get(0).(func(*tls.ClientHelloInfo) (*tls.Certificate, error)); ok { - return returnFunc(clientHello) + + if rf, ok := ret.Get(0).(func(*tls.ConnectionState, *ea.AuthenticatorRequest, *x509.Certificate) ([]ea.Extension, error)); ok { + return rf(st, req, leaf) } - if returnFunc, ok := ret.Get(0).(func(*tls.ClientHelloInfo) *tls.Certificate); ok { - r0 = returnFunc(clientHello) - } else { - if ret.Get(0) != nil { - r0 = ret.Get(0).(*tls.Certificate) - } + if ret.Get(0) != nil { + r0 = ret.Get(0).([]ea.Extension) } - if returnFunc, ok := ret.Get(1).(func(*tls.ClientHelloInfo) error); ok { - r1 = returnFunc(clientHello) + + if rf, ok := ret.Get(1).(func(*tls.ConnectionState, *ea.AuthenticatorRequest, *x509.Certificate) error); ok { + r1 = rf(st, req, leaf) } else { r1 = ret.Error(1) } + return r0, r1 } -// CertificateProvider_GetCertificate_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetCertificate' -type CertificateProvider_GetCertificate_Call struct { - *mock.Call -} +func NewCertificateProvider(t interface { + mock.TestingT + Cleanup(func()) +}) *CertificateProvider { + mockProvider := &CertificateProvider{} + mockProvider.Mock.Test(t) -// GetCertificate is a helper method to define mock.On call -// - clientHello *tls.ClientHelloInfo -func (_e *CertificateProvider_Expecter) GetCertificate(clientHello interface{}) *CertificateProvider_GetCertificate_Call { - return &CertificateProvider_GetCertificate_Call{Call: _e.mock.On("GetCertificate", clientHello)} -} - -func (_c *CertificateProvider_GetCertificate_Call) Run(run func(clientHello *tls.ClientHelloInfo)) *CertificateProvider_GetCertificate_Call { - _c.Call.Run(func(args mock.Arguments) { - var arg0 *tls.ClientHelloInfo - if args[0] != nil { - arg0 = args[0].(*tls.ClientHelloInfo) - } - run( - arg0, - ) + t.Cleanup(func() { + mockProvider.AssertExpectations(t) }) - return _c -} -func (_c *CertificateProvider_GetCertificate_Call) Return(certificate *tls.Certificate, err error) *CertificateProvider_GetCertificate_Call { - _c.Call.Return(certificate, err) - return _c -} - -func (_c *CertificateProvider_GetCertificate_Call) RunAndReturn(run func(clientHello *tls.ClientHelloInfo) (*tls.Certificate, error)) *CertificateProvider_GetCertificate_Call { - _c.Call.Return(run) - return _c + return mockProvider } diff --git a/pkg/atls/provider.go b/pkg/atls/provider.go new file mode 100644 index 00000000..4235d673 --- /dev/null +++ b/pkg/atls/provider.go @@ -0,0 +1,84 @@ +// Copyright (c) Ultraviolet +// SPDX-License-Identifier: Apache-2.0 + +package atls + +import ( + "context" + "crypto/sha256" + "crypto/sha512" + "crypto/tls" + "crypto/x509" + "fmt" + + "github.com/absmach/certs/sdk" + "github.com/ultravioletrs/cocos/pkg/atls/ea" + eaattestation "github.com/ultravioletrs/cocos/pkg/atls/eaattestation" + cocosattestation "github.com/ultravioletrs/cocos/pkg/attestation" + attestationclient "github.com/ultravioletrs/cocos/pkg/clients/grpc/attestation" +) + +// CertificateProvider is kept for compatibility with existing cocos call sites. +// In the EA-based implementation it provides the leaf certificate-entry extensions +// carried in the exported authenticator instead of generating TLS certificates. +type CertificateProvider interface { + BuildLeafExtensions(st *tls.ConnectionState, req *ea.AuthenticatorRequest, leaf *x509.Certificate) ([]ea.Extension, error) +} + +type provider struct { + attClient attestationclient.Client + platformType cocosattestation.PlatformType +} + +func NewProvider(attClient attestationclient.Client, platformType cocosattestation.PlatformType, _ string, _ string, _ sdk.SDK) (CertificateProvider, error) { + if attClient == nil { + return nil, fmt.Errorf("atls: missing attestation client") + } + if platformType == cocosattestation.NoCC { + return nil, fmt.Errorf("atls: confidential computing platform not available") + } + return &provider{ + attClient: attClient, + platformType: platformType, + }, nil +} + +func (p *provider) BuildLeafExtensions(st *tls.ConnectionState, req *ea.AuthenticatorRequest, leaf *x509.Certificate) ([]ea.Extension, error) { + if st == nil || req == nil || leaf == nil { + return nil, fmt.Errorf("atls: missing state, request, or leaf certificate") + } + exportedValue, aikPubHash, binding, err := eaattestation.ComputeBinding(st, eaattestation.ExporterLabelAttestation, req.Context, leaf) + if err != nil { + return nil, err + } + + reportData := sha512.Sum512(binding) + nonceBytes := sha256.Sum256(exportedValue) + var nonce [32]byte + copy(nonce[:], nonceBytes[:]) + + evidence, err := p.attClient.GetAttestation(context.Background(), reportData, nonce, p.platformType) + if err != nil { + return nil, fmt.Errorf("atls: failed to fetch attestation evidence: %w", err) + } + + payloadBytes, err := eaattestation.MarshalPayload(eaattestation.Payload{ + Version: 1, + MediaType: "application/eat+cwt", + Evidence: evidence, + Binder: eaattestation.AttestationBinder{ + ExporterLabel: eaattestation.ExporterLabelAttestation, + AIKPubHash: aikPubHash, + Binding: binding, + }, + }) + if err != nil { + return nil, err + } + + ext, err := ea.CMWAttestationDataExtension(payloadBytes) + if err != nil { + return nil, err + } + return []ea.Extension{ext}, nil +} diff --git a/pkg/atls/server_tls.go b/pkg/atls/server_tls.go new file mode 100644 index 00000000..4383e439 --- /dev/null +++ b/pkg/atls/server_tls.go @@ -0,0 +1,93 @@ +// Copyright (c) Ultraviolet +// SPDX-License-Identifier: Apache-2.0 + +package atls + +import ( + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "crypto/tls" + "crypto/x509" + "crypto/x509/pkix" + "fmt" + "math/big" + "time" +) + +// BuildServerTLSConfig prepares the base TLS configuration used by the EA/aTLS +// transport. If no certificate/key pair is configured, it falls back to an +// ephemeral self-signed identity bound by the exported authenticator. +func BuildServerTLSConfig(certFile, keyFile, serverCAFile, clientCAFile string) (*tls.Config, tls.Certificate, bool, error) { + if certFile != "" || keyFile != "" { + tlsSetup, err := setupRegularTLS(certFile, keyFile, serverCAFile, clientCAFile) + if err != nil { + return nil, tls.Certificate{}, false, err + } + tlsSetup.config.MinVersion = tls.VersionTLS13 + return tlsSetup.config, tlsSetup.config.Certificates[0], tlsSetup.mtls, nil + } + + identity, err := generateEphemeralIdentity() + if err != nil { + return nil, tls.Certificate{}, false, fmt.Errorf("failed to generate ephemeral TLS identity: %w", err) + } + + tlsConfig := &tls.Config{ + MinVersion: tls.VersionTLS13, + ClientAuth: tls.NoClientCert, + Certificates: []tls.Certificate{identity}, + } + + mtls, err := configureCertificateAuthorities(tlsConfig, serverCAFile, clientCAFile) + if err != nil { + return nil, tls.Certificate{}, false, err + } + if mtls { + tlsConfig.ClientAuth = tls.RequireAndVerifyClientCert + } + + return tlsConfig, identity, mtls, nil +} + +func generateEphemeralIdentity() (tls.Certificate, error) { + priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + if err != nil { + return tls.Certificate{}, err + } + + serialLimit := new(big.Int).Lsh(big.NewInt(1), 128) + serialNumber, err := rand.Int(rand.Reader, serialLimit) + if err != nil { + return tls.Certificate{}, err + } + + template := &x509.Certificate{ + SerialNumber: serialNumber, + Subject: pkix.Name{ + CommonName: "cocos-atls-ephemeral", + Organization: []string{"Ultraviolet"}, + }, + NotBefore: time.Now().Add(-1 * time.Hour), + NotAfter: time.Now().Add(24 * time.Hour), + KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, + BasicConstraintsValid: true, + DNSNames: []string{"localhost"}, + } + + der, err := x509.CreateCertificate(rand.Reader, template, template, &priv.PublicKey, priv) + if err != nil { + return tls.Certificate{}, err + } + leaf, err := x509.ParseCertificate(der) + if err != nil { + return tls.Certificate{}, err + } + + return tls.Certificate{ + Certificate: [][]byte{der}, + PrivateKey: priv, + Leaf: leaf, + }, nil +} diff --git a/pkg/atls/tls_helpers.go b/pkg/atls/tls_helpers.go new file mode 100644 index 00000000..f65a20e9 --- /dev/null +++ b/pkg/atls/tls_helpers.go @@ -0,0 +1,104 @@ +// Copyright (c) Ultraviolet +// SPDX-License-Identifier: Apache-2.0 + +package atls + +import ( + "crypto/tls" + "crypto/x509" + "fmt" + "os" + "strings" +) + +type tlsSetupResult struct { + config *tls.Config + mtls bool +} + +func readFileOrData(input string) ([]byte, error) { + if len(input) < 1000 && !strings.Contains(input, "\n") { + data, err := os.ReadFile(input) + if err == nil { + return data, nil + } + return nil, err + } + return []byte(input), nil +} + +func loadX509KeyPair(certFile, keyFile string) (tls.Certificate, error) { + cert, err := readFileOrData(certFile) + if err != nil { + return tls.Certificate{}, fmt.Errorf("failed to read cert: %w", err) + } + + key, err := readFileOrData(keyFile) + if err != nil { + return tls.Certificate{}, fmt.Errorf("failed to read key: %w", err) + } + + return tls.X509KeyPair(cert, key) +} + +func loadCertFile(certFile string) ([]byte, error) { + if certFile == "" { + return []byte{}, nil + } + return readFileOrData(certFile) +} + +func configureCertificateAuthorities(tlsConfig *tls.Config, serverCAFile, clientCAFile string) (bool, error) { + rootCA, err := loadCertFile(serverCAFile) + if err != nil { + return false, fmt.Errorf("failed to load server ca file: %w", err) + } + if len(rootCA) > 0 { + if tlsConfig.RootCAs == nil { + tlsConfig.RootCAs = x509.NewCertPool() + } + if !tlsConfig.RootCAs.AppendCertsFromPEM(rootCA) { + return false, fmt.Errorf("failed to append server ca to tls.Config") + } + } + + clientCA, err := loadCertFile(clientCAFile) + if err != nil { + return false, fmt.Errorf("failed to load client ca file: %w", err) + } + if len(clientCA) == 0 { + return false, nil + } + + if tlsConfig.ClientCAs == nil { + tlsConfig.ClientCAs = x509.NewCertPool() + } + if !tlsConfig.ClientCAs.AppendCertsFromPEM(clientCA) { + return false, fmt.Errorf("failed to append client ca to tls.Config") + } + + return true, nil +} + +func setupRegularTLS(certFile, keyFile, serverCAFile, clientCAFile string) (*tlsSetupResult, error) { + certificate, err := loadX509KeyPair(certFile, keyFile) + if err != nil { + return nil, fmt.Errorf("failed to load auth certificates: %w", err) + } + + tlsConfig := &tls.Config{ + MinVersion: tls.VersionTLS13, + ClientAuth: tls.NoClientCert, + Certificates: []tls.Certificate{certificate}, + } + + mtls, err := configureCertificateAuthorities(tlsConfig, serverCAFile, clientCAFile) + if err != nil { + return nil, err + } + if mtls { + tlsConfig.ClientAuth = tls.RequireAndVerifyClientCert + } + + return &tlsSetupResult{config: tlsConfig, mtls: mtls}, nil +} diff --git a/pkg/atls/transport.go b/pkg/atls/transport.go new file mode 100644 index 00000000..bdaac1a3 --- /dev/null +++ b/pkg/atls/transport.go @@ -0,0 +1,83 @@ +// Copyright (c) Ultraviolet +// SPDX-License-Identifier: Apache-2.0 + +package atls + +import ( + "context" + "crypto/tls" + "crypto/x509" + "net" + + "github.com/ultravioletrs/cocos/pkg/atls/ea" + eaattestation "github.com/ultravioletrs/cocos/pkg/atls/eaattestation" + internaltransport "github.com/ultravioletrs/cocos/pkg/atls/internal_transport" +) + +type Conn = internaltransport.Conn + +type Listener = internaltransport.Listener + +type ClientConfig = internaltransport.ClientConfig + +type ServerConfig = internaltransport.ServerConfig + +type AuthenticatorRequest = ea.AuthenticatorRequest + +func Dial(network, address string, cfg *ClientConfig) (*Conn, error) { + return internaltransport.Dial(network, address, cfg) +} + +func DialContext(ctx context.Context, network, address string, cfg *ClientConfig) (*Conn, error) { + return internaltransport.DialContext(ctx, network, address, cfg) +} + +func DialWithDialer(dialer *net.Dialer, network, address string, cfg *ClientConfig) (*Conn, error) { + return internaltransport.DialWithDialer(dialer, network, address, cfg) +} + +func Client(tlsConn *tls.Conn, cfg *ClientConfig) (*Conn, error) { + return internaltransport.Client(tlsConn, cfg) +} + +func Server(tlsConn *tls.Conn, cfg *ServerConfig) (*Conn, error) { + return internaltransport.Server(tlsConn, cfg) +} + +func Listen(network, address string, cfg *ServerConfig) (*Listener, error) { + return internaltransport.Listen(network, address, cfg) +} + +func NewRequest(context []byte) (*ea.AuthenticatorRequest, error) { + sigExt, err := ea.SignatureAlgorithmsExtension([]uint16{uint16(tls.ECDSAWithP256AndSHA256)}) + if err != nil { + return nil, err + } + return &ea.AuthenticatorRequest{ + Type: ea.HandshakeTypeClientCertificateRequest, + Context: context, + Extensions: []ea.Extension{ + sigExt, + ea.CMWAttestationOfferExtension(), + }, + }, nil +} + +func NewRandomRequest(contextLen int) (*ea.AuthenticatorRequest, error) { + context, err := ea.NewRandomContext(contextLen) + if err != nil { + return nil, err + } + return NewRequest(context) +} + +func VerifyOptionsFromTLSConfig(cfg *tls.Config) *x509.VerifyOptions { + if cfg == nil || cfg.InsecureSkipVerify || cfg.RootCAs == nil { + return nil + } + return &x509.VerifyOptions{Roots: cfg.RootCAs} +} + +func VerificationPolicyFromEvidenceVerifier(v eaattestation.EvidenceVerifier) eaattestation.VerificationPolicy { + return eaattestation.VerificationPolicy{EvidenceVerifier: v} +} diff --git a/pkg/atls/transport_test.go b/pkg/atls/transport_test.go new file mode 100644 index 00000000..ac3784ba --- /dev/null +++ b/pkg/atls/transport_test.go @@ -0,0 +1,78 @@ +// Copyright (c) Ultraviolet +// SPDX-License-Identifier: Apache-2.0 + +package atls + +import ( + "crypto/tls" + "crypto/x509" + "testing" +) + +func TestVerifyOptionsFromTLSConfig(t *testing.T) { + t.Run("nil config", func(t *testing.T) { + if got := VerifyOptionsFromTLSConfig(nil); got != nil { + t.Fatalf("expected nil verify options, got %#v", got) + } + }) + + t.Run("skip verify disables ea chain validation", func(t *testing.T) { + got := VerifyOptionsFromTLSConfig(&tls.Config{ + InsecureSkipVerify: true, + MinVersion: tls.VersionTLS13, + }) + if got != nil { + t.Fatalf("expected nil verify options for insecure skip verify, got %#v", got) + } + }) + + t.Run("missing roots disables ea chain validation", func(t *testing.T) { + got := VerifyOptionsFromTLSConfig(&tls.Config{ + MinVersion: tls.VersionTLS13, + }) + if got != nil { + t.Fatalf("expected nil verify options when roots are not configured, got %#v", got) + } + }) + + t.Run("configured roots are propagated", func(t *testing.T) { + roots := x509.NewCertPool() + got := VerifyOptionsFromTLSConfig(&tls.Config{ + RootCAs: roots, + MinVersion: tls.VersionTLS13, + }) + if got == nil { + t.Fatal("expected verify options, got nil") + } + if got.Roots != roots { + t.Fatal("expected verify options to reuse configured root CAs") + } + }) +} + +func TestNewRandomRequest(t *testing.T) { + req1, err := NewRandomRequest(32) + if err != nil { + t.Fatalf("first request failed: %v", err) + } + req2, err := NewRandomRequest(32) + if err != nil { + t.Fatalf("second request failed: %v", err) + } + + if len(req1.Context) != 32 { + t.Fatalf("expected first request context length 32, got %d", len(req1.Context)) + } + if len(req2.Context) != 32 { + t.Fatalf("expected second request context length 32, got %d", len(req2.Context)) + } + if len(req1.Extensions) == 0 { + t.Fatal("expected first request to carry extensions") + } + if len(req2.Extensions) == 0 { + t.Fatal("expected second request to carry extensions") + } + if string(req1.Context) == string(req2.Context) { + t.Fatal("expected random request contexts to differ") + } +} diff --git a/pkg/clients/clients.go b/pkg/clients/clients.go index 9a77e5da..d0f66f6c 100644 --- a/pkg/clients/clients.go +++ b/pkg/clients/clients.go @@ -3,11 +3,18 @@ package clients -import "time" +import ( + "encoding/hex" + "errors" + "fmt" + "time" +) var ( _ ClientConfiguration = (*AttestedClientConfig)(nil) _ ClientConfiguration = (*StandardClientConfig)(nil) + + ErrInvalidAttestationRequestContext = errors.New("invalid attestation request context") ) type ClientConfiguration interface { @@ -29,6 +36,13 @@ type AttestedClientConfig struct { AttestationPolicy string `env:"ATTESTATION_POLICY" envDefault:""` AttestedTLS bool `env:"ATTESTED_TLS" envDefault:"false"` ProductName string `env:"PRODUCT_NAME" envDefault:"Milan"` + // AttestationRequestContextHex, when set, is decoded from hex and used as + // the exported authenticator certificate_request_context. This lets the + // caller provide the background-check freshness value directly. + AttestationRequestContextHex string `env:"ATTESTATION_REQUEST_CONTEXT" envDefault:""` + // AttestationRequestContext allows callers inside the same process to pass + // raw request-context bytes directly instead of using the hex string form. + AttestationRequestContext []byte `env:"-"` } func (c AttestedClientConfig) Config() StandardClientConfig { @@ -38,3 +52,20 @@ func (c AttestedClientConfig) Config() StandardClientConfig { func (c StandardClientConfig) Config() StandardClientConfig { return c } + +func (c AttestedClientConfig) RequestContext() ([]byte, error) { + if len(c.AttestationRequestContext) > 0 { + return append([]byte(nil), c.AttestationRequestContext...), nil + } + if c.AttestationRequestContextHex == "" { + return nil, nil + } + requestContext, err := hex.DecodeString(c.AttestationRequestContextHex) + if err != nil { + return nil, fmt.Errorf("%w: %w", ErrInvalidAttestationRequestContext, err) + } + if len(requestContext) == 0 { + return nil, fmt.Errorf("%w: decoded value is empty", ErrInvalidAttestationRequestContext) + } + return requestContext, nil +} diff --git a/pkg/clients/grpc/connect_test.go b/pkg/clients/grpc/connect_test.go index ea524994..96d7fd99 100644 --- a/pkg/clients/grpc/connect_test.go +++ b/pkg/clients/grpc/connect_test.go @@ -103,6 +103,22 @@ func TestNewClient(t *testing.T) { wantErr: false, err: nil, }, + { + name: "Success agent client with aTLS and custom request context", + agentCfg: clients.AttestedClientConfig{ + StandardClientConfig: clients.StandardClientConfig{ + URL: "localhost:7001", + ServerCAFile: caCertFile, + ClientCert: clientCertFile, + ClientKey: clientKeyFile, + }, + AttestedTLS: true, + AttestationPolicy: policyFile.Name(), + AttestationRequestContextHex: "01020304", + }, + wantErr: false, + err: nil, + }, { name: "Failed agent client with aTLS", agentCfg: clients.AttestedClientConfig{ @@ -118,6 +134,22 @@ func TestNewClient(t *testing.T) { wantErr: true, err: fmt.Errorf("failed to stat attestation policy file"), }, + { + name: "Failed agent client with invalid attestation request context", + agentCfg: clients.AttestedClientConfig{ + StandardClientConfig: clients.StandardClientConfig{ + URL: "localhost:7001", + ServerCAFile: caCertFile, + ClientCert: clientCertFile, + ClientKey: clientKeyFile, + }, + AttestedTLS: true, + AttestationPolicy: policyFile.Name(), + AttestationRequestContextHex: "xyz", + }, + wantErr: true, + err: clients.ErrInvalidAttestationRequestContext, + }, { name: "Fail with invalid ServerCAFile", cfg: clients.StandardClientConfig{ diff --git a/pkg/clients/grpc/grpc.go b/pkg/clients/grpc/grpc.go index b68e7414..59ef819e 100644 --- a/pkg/clients/grpc/grpc.go +++ b/pkg/clients/grpc/grpc.go @@ -4,7 +4,13 @@ package grpc import ( + "context" + stdtls "crypto/tls" + "net" + "strings" + "github.com/absmach/supermq/pkg/errors" + "github.com/ultravioletrs/cocos/pkg/atls" "github.com/ultravioletrs/cocos/pkg/clients" "github.com/ultravioletrs/cocos/pkg/tls" "go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc" @@ -77,7 +83,43 @@ func connect(cfg clients.ClientConfiguration) (*grpc.ClientConn, tls.Security, e return nil, security, err } - opts = append(opts, grpc.WithTransportCredentials(credentials.NewTLS(result.Config))) + tlsConfig := result.Config.Clone() + tlsConfig.MinVersion = stdtls.VersionTLS13 + tlsConfig.NextProtos = []string{"h2"} + + atlsConfig := &atls.ClientConfig{ + TLSConfig: tlsConfig, + VerifyOptions: atls.VerifyOptionsFromTLSConfig(tlsConfig), + AttestationPolicy: atls.VerificationPolicyFromEvidenceVerifier(atls.NewEvidenceVerifier(agcfg.AttestationPolicy)), + } + requestContext, err := agcfg.RequestContext() + if err != nil { + return nil, security, err + } + if len(requestContext) > 0 { + req, err := atls.NewRequest(requestContext) + if err != nil { + return nil, security, err + } + atlsConfig.Request = req + } else { + atlsConfig.RequestBuilder = func() (*atls.AuthenticatorRequest, error) { + return atls.NewRandomRequest(32) + } + } + + opts = append(opts, + grpc.WithTransportCredentials(insecure.NewCredentials()), + grpc.WithContextDialer(func(ctx context.Context, addr string) (net.Conn, error) { + network, target := dialTarget(addr) + conn, err := atls.DialContext(ctx, network, target, atlsConfig) + if err != nil { + return nil, err + } + + return conn, nil + }), + ) security = result.Security } else { conf := cfg.Config() @@ -96,6 +138,16 @@ func connect(cfg clients.ClientConfiguration) (*grpc.ClientConn, tls.Security, e return conn, security, nil } +func dialTarget(addr string) (string, string) { + if strings.HasPrefix(addr, "unix://") { + return "unix", strings.TrimPrefix(addr, "unix://") + } + if strings.HasPrefix(addr, "/") { + return "unix", addr + } + return "tcp", addr +} + func loadTLSConfig(serverCAFile, clientCert, clientKey string) (credentials.TransportCredentials, tls.Security, error) { result, err := tls.LoadBasicConfig(serverCAFile, clientCert, clientKey) if err != nil { diff --git a/pkg/clients/http/client.go b/pkg/clients/http/client.go index 97fd3933..1089c259 100644 --- a/pkg/clients/http/client.go +++ b/pkg/clients/http/client.go @@ -4,9 +4,14 @@ package http import ( + "context" + stdtls "crypto/tls" + "net" "net/http" + "strings" "time" + "github.com/ultravioletrs/cocos/pkg/atls" "github.com/ultravioletrs/cocos/pkg/clients" "github.com/ultravioletrs/cocos/pkg/tls" ) @@ -70,7 +75,33 @@ func createTransport(cfg clients.ClientConfiguration) (*http.Transport, tls.Secu return nil, security, err } - transport.TLSClientConfig = result.Config + tlsConfig := result.Config.Clone() + tlsConfig.MinVersion = stdtls.VersionTLS13 + atlsConfig := &atls.ClientConfig{ + TLSConfig: tlsConfig, + VerifyOptions: atls.VerifyOptionsFromTLSConfig(tlsConfig), + AttestationPolicy: atls.VerificationPolicyFromEvidenceVerifier(atls.NewEvidenceVerifier(agcfg.AttestationPolicy)), + } + requestContext, err := agcfg.RequestContext() + if err != nil { + return nil, security, err + } + if len(requestContext) > 0 { + req, err := atls.NewRequest(requestContext) + if err != nil { + return nil, security, err + } + atlsConfig.Request = req + } else { + atlsConfig.RequestBuilder = func() (*atls.AuthenticatorRequest, error) { + return atls.NewRandomRequest(32) + } + } + + transport.DialTLSContext = func(ctx context.Context, network, addr string) (net.Conn, error) { + dialNetwork, target := httpDialTarget(network, addr) + return atls.DialContext(ctx, dialNetwork, target, atlsConfig) + } security = result.Security } else { conf := cfg.Config() @@ -89,3 +120,13 @@ func createTransport(cfg clients.ClientConfiguration) (*http.Transport, tls.Secu return transport, security, nil } + +func httpDialTarget(network, addr string) (string, string) { + if strings.HasPrefix(addr, "unix://") { + return "unix", strings.TrimPrefix(addr, "unix://") + } + if strings.HasPrefix(addr, "/") { + return "unix", addr + } + return network, addr +} diff --git a/pkg/clients/http/client_test.go b/pkg/clients/http/client_test.go index 3686b8c9..7bb27618 100644 --- a/pkg/clients/http/client_test.go +++ b/pkg/clients/http/client_test.go @@ -5,6 +5,7 @@ package http import ( "net/http" + "os" "testing" "time" @@ -210,6 +211,62 @@ func TestCreateTransport_ATLSError(t *testing.T) { assert.Contains(t, err.Error(), "failed to stat attestation policy") } +func TestCreateTransport_ATLSCustomRequestContext(t *testing.T) { + policyFile, err := os.CreateTemp("", "attestation_policy.json") + assert.NoError(t, err) + _, err = policyFile.WriteString("{}") + assert.NoError(t, err) + assert.NoError(t, policyFile.Close()) + t.Cleanup(func() { + _ = os.Remove(policyFile.Name()) + }) + + config := &clients.AttestedClientConfig{ + StandardClientConfig: clients.StandardClientConfig{ + URL: "https://agent.example.com", + Timeout: 60 * time.Second, + }, + AttestationPolicy: policyFile.Name(), + AttestedTLS: true, + AttestationRequestContextHex: "01020304", + } + + transport, security, err := createTransport(config) + + assert.NoError(t, err) + assert.NotNil(t, transport) + assert.Equal(t, tls.WithATLS, security) + assert.NotNil(t, transport.DialTLSContext) +} + +func TestCreateTransport_ATLSInvalidRequestContext(t *testing.T) { + policyFile, err := os.CreateTemp("", "attestation_policy.json") + assert.NoError(t, err) + _, err = policyFile.WriteString("{}") + assert.NoError(t, err) + assert.NoError(t, policyFile.Close()) + t.Cleanup(func() { + _ = os.Remove(policyFile.Name()) + }) + + config := &clients.AttestedClientConfig{ + StandardClientConfig: clients.StandardClientConfig{ + URL: "https://agent.example.com", + Timeout: 60 * time.Second, + }, + AttestationPolicy: policyFile.Name(), + AttestedTLS: true, + AttestationRequestContextHex: "xyz", + } + + transport, security, err := createTransport(config) + + assert.Error(t, err) + assert.Nil(t, transport) + assert.Equal(t, tls.WithoutTLS, security) + assert.Contains(t, err.Error(), "invalid attestation request context") +} + func TestCreateTransport_BasicTLSError(t *testing.T) { config := clients.StandardClientConfig{ URL: "https://example.com", diff --git a/pkg/ingress/proxy.go b/pkg/ingress/proxy.go index 678fa8bb..648197ad 100644 --- a/pkg/ingress/proxy.go +++ b/pkg/ingress/proxy.go @@ -11,6 +11,7 @@ import ( "net/http" "net/http/httputil" "net/url" + "os" "sync" "github.com/ultravioletrs/cocos/pkg/atls" @@ -18,6 +19,8 @@ import ( "golang.org/x/net/http2/h2c" ) +const unix = "unix" + // ProxyConfig contains configuration for starting a proxy instance. type ProxyConfig struct { Port string @@ -81,12 +84,12 @@ func (p *proxyServer) Start(cfg ProxyConfig, ctx ProxyContext) error { var rp *httputil.ReverseProxy // Check if backend is Unix socket or TCP - if p.backendURL.Scheme == "unix" { + if p.backendURL.Scheme == unix { // For Unix socket backend, we need to manually configure the reverse proxy // because NewSingleHostReverseProxy doesn't support unix:// scheme targetURL := &url.URL{ Scheme: "http", - Host: "unix", + Host: unix, } rp = httputil.NewSingleHostReverseProxy(targetURL) @@ -96,7 +99,7 @@ func (p *proxyServer) Start(cfg ProxyConfig, ctx ProxyContext) error { originalDirector(req) // Set the URL to point to the backend service req.URL.Scheme = "http" - req.URL.Host = "unix" + req.URL.Host = unix } // Configure Transport for Unix socket with HTTP/2 @@ -105,7 +108,7 @@ func (p *proxyServer) Start(cfg ProxyConfig, ctx ProxyContext) error { DialTLSContext: func(ctx context.Context, network, addr string, cfg *tls.Config) (net.Conn, error) { var d net.Dialer // Use Unix socket path from URL - return d.DialContext(ctx, "unix", p.backendURL.Path) + return d.DialContext(ctx, unix, p.backendURL.Path) }, } } else { @@ -130,6 +133,8 @@ func (p *proxyServer) Start(cfg ProxyConfig, ctx ProxyContext) error { // Configure TLS var tlsConfig *tls.Config + var listener net.Listener + var err error contextDesc := fmt.Sprintf("context %s", ctx.ID) if ctx.Name != "" { contextDesc = fmt.Sprintf("%s (%s)", ctx.Name, ctx.ID) @@ -139,23 +144,31 @@ func (p *proxyServer) Start(cfg ProxyConfig, ctx ProxyContext) error { if p.certProvider == nil { return fmt.Errorf("attested TLS requested for ingress proxy but no certificate provider available. Please ensure a CC platform is detected (not NoCC), aTLS is enabled, and the attestation service is running") } - tlsConfig = &tls.Config{ - GetCertificate: p.certProvider.GetCertificate, - ClientAuth: tls.NoClientCert, - NextProtos: []string{"h2", "http/1.1"}, + tlsConfig, identity, mtls, err := atls.BuildServerTLSConfig(cfg.CertFile, cfg.KeyFile, cfg.ServerCAFile, cfg.ClientCAFile) + if err != nil { + return fmt.Errorf("failed to setup attested TLS: %w", err) + } + tlsConfig.NextProtos = []string{"h2", "http/1.1"} + + listener, err = p.attestedListener(addr, tlsConfig, identity) + if err != nil { + return fmt.Errorf("failed to listen: %w", err) } - mtls, err := ConfigureCertificateAuthorities(tlsConfig, cfg.ServerCAFile, cfg.ClientCAFile) - if err != nil { - return fmt.Errorf("failed to configure certificate authorities: %w", err) - } + p.started = true + go func() { + serveErr := p.httpServer.Serve(listener) + if serveErr != nil && serveErr != http.ErrServerClosed { + p.logger.Error(fmt.Sprintf("ingress-proxy server error: %s", serveErr)) + } + }() if mtls { - tlsConfig.ClientAuth = tls.RequireAndVerifyClientCert p.logger.Info(fmt.Sprintf("ingress-proxy listening at %s with Attested mTLS for %s", addr, contextDesc)) } else { p.logger.Info(fmt.Sprintf("ingress-proxy listening at %s with Attested TLS for %s", addr, contextDesc)) } + return nil } else if cfg.CertFile != "" && cfg.KeyFile != "" { // Regular TLS tlsSetup, err := SetupRegularTLS(cfg.CertFile, cfg.KeyFile, cfg.ServerCAFile, cfg.ClientCAFile) @@ -174,31 +187,48 @@ func (p *proxyServer) Start(cfg ProxyConfig, ctx ProxyContext) error { p.logger.Info(fmt.Sprintf("ingress-proxy listening at %s without TLS for %s", addr, contextDesc)) } + if tlsConfig != nil { + tcpListener, listenErr := net.Listen("tcp", addr) + if listenErr != nil { + return fmt.Errorf("failed to listen: %w", listenErr) + } + listener = tls.NewListener(tcpListener, tlsConfig) + } else { + listener, err = net.Listen("tcp", addr) + if err != nil { + return fmt.Errorf("failed to listen: %w", err) + } + } + p.started = true // Start server in goroutine go func() { - var err error - if tlsConfig != nil { - ln, listenErr := net.Listen("tcp", addr) - if listenErr != nil { - p.logger.Error(fmt.Sprintf("failed to listen: %s", listenErr)) - return - } - tlsLn := tls.NewListener(ln, tlsConfig) - err = p.httpServer.Serve(tlsLn) - } else { - err = p.httpServer.ListenAndServe() - } - - if err != nil && err != http.ErrServerClosed { - p.logger.Error(fmt.Sprintf("ingress-proxy server error: %s", err)) + serveErr := p.httpServer.Serve(listener) + if serveErr != nil && serveErr != http.ErrServerClosed { + p.logger.Error(fmt.Sprintf("ingress-proxy server error: %s", serveErr)) } }() return nil } +func (p *proxyServer) attestedListener(addr string, tlsConfig *tls.Config, identity tls.Certificate) (net.Listener, error) { + network := "tcp" + address := addr + if len(addr) > 0 && addr[0] == '/' { + network = unix + address = addr + _ = os.Remove(address) + } + + return atls.Listen(network, address, &atls.ServerConfig{ + TLSConfig: tlsConfig, + Identity: identity, + BuildLeafExtensions: p.certProvider.BuildLeafExtensions, + }) +} + // Stop stops the proxy server. func (p *proxyServer) Stop() error { p.mu.Lock() diff --git a/pkg/ingress/proxy_test.go b/pkg/ingress/proxy_test.go index 4b4f84f7..227d8ba7 100644 --- a/pkg/ingress/proxy_test.go +++ b/pkg/ingress/proxy_test.go @@ -18,6 +18,7 @@ import ( "net/url" "os" "path/filepath" + "strings" "testing" "time" @@ -113,6 +114,9 @@ func TestProxyStartWithoutPort(t *testing.T) { ctx := ProxyContext{ID: "test-2"} err := ps.Start(cfg, ctx) + if err != nil && strings.Contains(err.Error(), "address already in use") { + t.Skip("default ingress port 7002 is already in use") + } require.NoError(t, err) defer func() { _ = ps.Stop() }() time.Sleep(100 * time.Millisecond) @@ -141,6 +145,33 @@ func TestProxyStartAlreadyStarted(t *testing.T) { assert.Equal(t, "proxy server already started", err.Error()) } +func TestProxyStartReturnsListenerError(t *testing.T) { + logger := slog.New(slog.NewTextHandler(os.Stdout, nil)) + ps := NewProxyServer(logger, getBackendURL(), nil).(*proxyServer) + + occupied, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + defer occupied.Close() + + port := occupied.Addr().(*net.TCPAddr).Port + cfg := ProxyConfig{Port: fmt.Sprintf("%d", port)} + ctx := ProxyContext{ID: "test-bind-failure"} + + err = ps.Start(cfg, ctx) + require.Error(t, err) + assert.Contains(t, err.Error(), "failed to listen") + assert.False(t, ps.started) + + retry, retryErr := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, retryErr) + retryPort := retry.Addr().(*net.TCPAddr).Port + retry.Close() + + err = ps.Start(ProxyConfig{Port: fmt.Sprintf("%d", retryPort)}, ctx) + require.NoError(t, err) + defer func() { _ = ps.Stop() }() +} + // TestProxyStartAfterStopped tests error when starting after stop. func TestProxyStartAfterStopped(t *testing.T) { logger := slog.New(slog.NewTextHandler(os.Stdout, nil)) @@ -473,5 +504,5 @@ func TestProxyAttestedTLSInvalidCA(t *testing.T) { err := ps.Start(cfg, ctx) assert.Error(t, err) - assert.Contains(t, err.Error(), "failed to configure certificate authorities") + assert.Contains(t, err.Error(), "failed to setup attested TLS") } diff --git a/pkg/tls/tls.go b/pkg/tls/tls.go index 51bd4d8d..f2076a3e 100644 --- a/pkg/tls/tls.go +++ b/pkg/tls/tls.go @@ -4,15 +4,12 @@ package tls import ( - "crypto/rand" "crypto/tls" "crypto/x509" - "encoding/hex" "encoding/pem" "os" "github.com/absmach/supermq/pkg/errors" - "github.com/ultravioletrs/cocos/pkg/atls" "github.com/ultravioletrs/cocos/pkg/attestation" ) @@ -125,21 +122,12 @@ func LoadATLSConfig(attestationPolicy, serverCAFile, clientCert, clientKey strin security = WithMATLS } - nonce := make([]byte, 64) - if _, err := rand.Read(nonce); err != nil { - return nil, errors.Wrap(errors.New("failed to generate nonce"), err) - } - - encoded := hex.EncodeToString(nonce) - sni := encoded + ".nonce" - tlsConfig := &tls.Config{ - InsecureSkipVerify: true, - RootCAs: rootCAs, - ServerName: sni, - VerifyPeerCertificate: func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error { - return atls.NewCertificateVerifier(rootCAs).VerifyPeerCertificate(rawCerts, verifiedChains, nonce) - }, + MinVersion: tls.VersionTLS13, + RootCAs: rootCAs, + } + if rootCAs == nil { + tlsConfig.InsecureSkipVerify = true } if clientCert != "" || clientKey != "" { diff --git a/pkg/tls/tls_test.go b/pkg/tls/tls_test.go index d37d3e98..3cbb4856 100644 --- a/pkg/tls/tls_test.go +++ b/pkg/tls/tls_test.go @@ -6,6 +6,7 @@ package tls import ( "crypto/rand" "crypto/rsa" + stdtls "crypto/tls" "crypto/x509" "crypto/x509/pkix" "encoding/pem" @@ -296,10 +297,11 @@ func TestLoadATLSConfig(t *testing.T) { assert.NotNil(t, result.Config) // Verify TLS config properties - assert.True(t, result.Config.InsecureSkipVerify) - assert.NotNil(t, result.Config.VerifyPeerCertificate) - assert.NotEmpty(t, result.Config.ServerName) - assert.Contains(t, result.Config.ServerName, ".nonce") + assert.Equal(t, tt.serverCAFile == "", result.Config.InsecureSkipVerify) + assert.Equal(t, uint16(stdtls.VersionTLS13), result.Config.MinVersion) + if tt.serverCAFile != "" { + assert.NotNil(t, result.Config.RootCAs) + } }) } }