mirror of
https://github.com/ultravioletrs/cocos.git
synced 2026-06-23 04:10:25 +00:00
COCOS-560 - EAT (#561)
* feat: Implement EAT (Evidence Attestation Token) generation and verification for attestation responses, replacing raw quotes with EAT tokens in the attestation service and protobuf. Signed-off-by: Sammy Oina <sammyoina@gmail.com> * style: standardize comment formatting and fix a debug log format specifier. Signed-off-by: Sammy Oina <sammyoina@gmail.com> * fix pkg test Signed-off-by: Sammy Oina <sammyoina@gmail.com> * feat: Introduce named constants for OEM IDs and use them in attestation claim extraction. Signed-off-by: SammyOina <sammyoina@gmail.com> * feat: Implement and test minimum length validation for EAT nonce in `NewEATClaims`. Signed-off-by: SammyOina <sammyoina@gmail.com> * feat: Add EATClaims.Sanitize method and integrate it into the validator to enforce claim dependencies. Signed-off-by: SammyOina <sammyoina@gmail.com> * feat: Add Signature field to SNPExtensions and TDXExtensions for enhanced claim validation Signed-off-by: Sammy Oina <sammyoina@gmail.com> * feat: Update dependencies and improve code structure in attestation package Signed-off-by: Sammy Oina <sammyoina@gmail.com> * feat: Introduce comprehensive test suites for EAT, ATLS, TDX, Azure SNP, and vTPM attestation, and improve EAT decoder robustness. Signed-off-by: Sammy Oina <sammyoina@gmail.com> * feat: Add encryption and admin keys, an encrypted algorithm file, and update go.mod to use go-jose/v4. Signed-off-by: Sammy Oina <sammyoina@gmail.com> * feat: add new encryption and KBS admin keys while improving TDX attestation test error handling. Signed-off-by: Sammy Oina <sammyoina@gmail.com> * feat: Add new KBS admin and encryption keys, an encrypted linear regression algorithm, and refactor TDX test error message checks. Signed-off-by: Sammy Oina <sammyoina@gmail.com> * feat: Implement Azure SNP attestation policy, update certificate verification, and add key management. Signed-off-by: Sammy Oina <sammyoina@gmail.com> * refactor: replace hardcoded string literals with variables in Azure SNP attestation tests. Signed-off-by: Sammy Oina <sammyoina@gmail.com> * feat: Refactor TDX EAT claims to use individual RTMR fields with `tdx_` prefixes and add an `IntUse` field. Signed-off-by: Sammy Oina <sammyoina@gmail.com> --------- Signed-off-by: Sammy Oina <sammyoina@gmail.com> Signed-off-by: SammyOina <sammyoina@gmail.com>
This commit is contained in:
committed by
GitHub
parent
a3265bc346
commit
de50b6d2d4
+74
-48
@@ -3,6 +3,7 @@
|
||||
package atls
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/ecdsa"
|
||||
"crypto/elliptic"
|
||||
"crypto/rand"
|
||||
@@ -32,7 +33,6 @@ import (
|
||||
"github.com/stretchr/testify/mock"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/ultravioletrs/cocos/pkg/attestation"
|
||||
"github.com/ultravioletrs/cocos/pkg/attestation/mocks"
|
||||
"github.com/ultravioletrs/cocos/pkg/attestation/vtpm"
|
||||
"golang.org/x/crypto/sha3"
|
||||
"google.golang.org/protobuf/encoding/protojson"
|
||||
@@ -44,6 +44,32 @@ const (
|
||||
|
||||
var policy = attestation.Config{Config: &check.Config{Policy: &check.Policy{}, RootOfTrust: &check.RootOfTrust{}}, PcrConfig: &attestation.PcrConfig{}}
|
||||
|
||||
// mockAttestationClient is a simple mock for testing.
|
||||
type mockAttestationClient struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
func (m *mockAttestationClient) GetAttestation(ctx context.Context, reportData [64]byte, nonce [32]byte, attType attestation.PlatformType) ([]byte, error) {
|
||||
args := m.Called(ctx, reportData, nonce, attType)
|
||||
if args.Get(0) == nil {
|
||||
return nil, args.Error(1)
|
||||
}
|
||||
return args.Get(0).([]byte), args.Error(1)
|
||||
}
|
||||
|
||||
func (m *mockAttestationClient) GetAzureToken(ctx context.Context, nonce [32]byte) ([]byte, error) {
|
||||
args := m.Called(ctx, nonce)
|
||||
if args.Get(0) == nil {
|
||||
return nil, args.Error(1)
|
||||
}
|
||||
return args.Get(0).([]byte), args.Error(1)
|
||||
}
|
||||
|
||||
func (m *mockAttestationClient) Close() error {
|
||||
args := m.Called()
|
||||
return args.Error(0)
|
||||
}
|
||||
|
||||
func generateTestCertPEM(t *testing.T) string {
|
||||
return generateTestCertPEMWithSubject(t, "test")
|
||||
}
|
||||
@@ -133,9 +159,8 @@ func TestUnifiedCertificateGenerator(t *testing.T) {
|
||||
|
||||
// TestPlatformAttestationProvider tests the platform attestation provider.
|
||||
func TestPlatformAttestationProvider(t *testing.T) {
|
||||
mockProvider := new(mocks.Provider)
|
||||
|
||||
t.Run("NewAttestationProvider", func(t *testing.T) {
|
||||
mockClient := new(mockAttestationClient)
|
||||
cases := []struct {
|
||||
name string
|
||||
platformType attestation.PlatformType
|
||||
@@ -149,7 +174,7 @@ func TestPlatformAttestationProvider(t *testing.T) {
|
||||
|
||||
for _, c := range cases {
|
||||
t.Run(c.name, func(t *testing.T) {
|
||||
provider, err := NewAttestationProvider(mockProvider, c.platformType)
|
||||
provider, err := NewAttestationProvider(mockClient, c.platformType)
|
||||
|
||||
if c.expectError {
|
||||
assert.Error(t, err)
|
||||
@@ -164,10 +189,11 @@ func TestPlatformAttestationProvider(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("GetAttestation", func(t *testing.T) {
|
||||
mockClient := new(mockAttestationClient)
|
||||
expectedAttestation := []byte("test-attestation")
|
||||
mockProvider.On("Attestation", mock.Anything, mock.Anything).Return(expectedAttestation, nil)
|
||||
mockClient.On("GetAttestation", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(expectedAttestation, nil)
|
||||
|
||||
provider, err := NewAttestationProvider(mockProvider, attestation.SNPvTPM)
|
||||
provider, err := NewAttestationProvider(mockClient, attestation.SNPvTPM)
|
||||
require.NoError(t, err)
|
||||
|
||||
pubKey := []byte("test-pubkey")
|
||||
@@ -177,14 +203,14 @@ func TestPlatformAttestationProvider(t *testing.T) {
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, expectedAttestation, attestation)
|
||||
mockProvider.AssertExpectations(t)
|
||||
mockClient.AssertExpectations(t)
|
||||
})
|
||||
|
||||
t.Run("GetAttestationError", func(t *testing.T) {
|
||||
mockProvider := new(mocks.Provider)
|
||||
mockProvider.On("Attestation", mock.Anything, mock.Anything).Return(nil, errors.New("attestation failed"))
|
||||
mockClient := new(mockAttestationClient)
|
||||
mockClient.On("GetAttestation", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil, errors.New("attestation failed"))
|
||||
|
||||
provider, err := NewAttestationProvider(mockProvider, attestation.SNPvTPM)
|
||||
provider, err := NewAttestationProvider(mockClient, attestation.SNPvTPM)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = provider.Attest([]byte("pubkey"), []byte("nonce"))
|
||||
@@ -194,12 +220,11 @@ func TestPlatformAttestationProvider(t *testing.T) {
|
||||
|
||||
// TestAttestedCertificateProvider tests the attested certificate provider.
|
||||
func TestAttestedCertificateProvider(t *testing.T) {
|
||||
mockProvider := new(mocks.Provider)
|
||||
|
||||
t.Run("GetCertificateSuccess", func(t *testing.T) {
|
||||
mockProvider.On("Attestation", mock.Anything, mock.Anything).Return([]byte("test-attestation"), nil)
|
||||
mockClient := new(mockAttestationClient)
|
||||
mockClient.On("GetAttestation", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return([]byte("test-attestation"), nil)
|
||||
|
||||
attestationProvider, err := NewAttestationProvider(mockProvider, attestation.SNPvTPM)
|
||||
attestationProvider, err := NewAttestationProvider(mockClient, attestation.SNPvTPM)
|
||||
require.NoError(t, err)
|
||||
|
||||
subject := DefaultCertificateSubject()
|
||||
@@ -223,8 +248,8 @@ func TestAttestedCertificateProvider(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("InvalidServerName", func(t *testing.T) {
|
||||
mockProvider := new(mocks.Provider)
|
||||
attestationProvider, err := NewAttestationProvider(mockProvider, attestation.SNPvTPM)
|
||||
mockClient := new(mockAttestationClient)
|
||||
attestationProvider, err := NewAttestationProvider(mockClient, attestation.SNPvTPM)
|
||||
require.NoError(t, err)
|
||||
|
||||
provider := NewAttestedProvider(attestationProvider, DefaultCertificateSubject())
|
||||
@@ -237,10 +262,10 @@ func TestAttestedCertificateProvider(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("AttestationError", func(t *testing.T) {
|
||||
mockProvider := new(mocks.Provider)
|
||||
mockProvider.On("Attestation", mock.Anything, mock.Anything).Return(nil, errors.New("attestation failed"))
|
||||
mockClient := new(mockAttestationClient)
|
||||
mockClient.On("GetAttestation", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil, errors.New("attestation failed"))
|
||||
|
||||
attestationProvider, err := NewAttestationProvider(mockProvider, attestation.SNPvTPM)
|
||||
attestationProvider, err := NewAttestationProvider(mockClient, attestation.SNPvTPM)
|
||||
require.NoError(t, err)
|
||||
|
||||
provider := NewAttestedProvider(attestationProvider, DefaultCertificateSubject())
|
||||
@@ -260,10 +285,10 @@ func TestAttestedCertificateProvider(t *testing.T) {
|
||||
|
||||
// TestNewProvider tests the factory function.
|
||||
func TestNewProvider(t *testing.T) {
|
||||
mockProvider := new(mocks.Provider)
|
||||
mockClient := new(mockAttestationClient)
|
||||
|
||||
t.Run("SelfSignedProvider", func(t *testing.T) {
|
||||
provider, err := NewProvider(mockProvider, attestation.SNPvTPM, "", "", nil)
|
||||
provider, err := NewProvider(mockClient, attestation.SNPvTPM, "", "", nil)
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, provider)
|
||||
})
|
||||
@@ -271,19 +296,19 @@ func TestNewProvider(t *testing.T) {
|
||||
t.Run("CASignedProviderWithSDK", func(t *testing.T) {
|
||||
mockSDK := sdkmocks.NewSDK(t)
|
||||
|
||||
provider, err := NewProvider(mockProvider, attestation.SNPvTPM, "test-token", "test-cvm-id", mockSDK)
|
||||
provider, err := NewProvider(mockClient, attestation.SNPvTPM, "test-token", "test-cvm-id", mockSDK)
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, provider)
|
||||
})
|
||||
|
||||
t.Run("SelfSignedProviderNilSDK", func(t *testing.T) {
|
||||
provider, err := NewProvider(mockProvider, attestation.SNPvTPM, "test-token", "test-cvm-id", nil)
|
||||
provider, err := NewProvider(mockClient, attestation.SNPvTPM, "test-token", "test-cvm-id", nil)
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, provider)
|
||||
})
|
||||
|
||||
t.Run("InvalidPlatformType", func(t *testing.T) {
|
||||
_, err := NewProvider(mockProvider, attestation.PlatformType(999), "", "", nil)
|
||||
_, err := NewProvider(mockClient, attestation.PlatformType(999), "", "", nil)
|
||||
assert.Error(t, err)
|
||||
})
|
||||
}
|
||||
@@ -714,8 +739,8 @@ func TestCertificateVerification(t *testing.T) {
|
||||
|
||||
// TestAttestedCAProvider tests the CA-signed certificate provider.
|
||||
func TestAttestedCAProvider(t *testing.T) {
|
||||
mockProvider := new(mocks.Provider)
|
||||
attestationProvider, err := NewAttestationProvider(mockProvider, attestation.SNPvTPM)
|
||||
mockClient := new(mockAttestationClient)
|
||||
attestationProvider, err := NewAttestationProvider(mockClient, attestation.SNPvTPM)
|
||||
require.NoError(t, err)
|
||||
|
||||
subject := DefaultCertificateSubject()
|
||||
@@ -740,8 +765,8 @@ func TestAttestedCAProvider(t *testing.T) {
|
||||
|
||||
// TestCASignedCertificateErrors tests error cases in CA-signed certificate generation.
|
||||
func TestCASignedCertificateErrors(t *testing.T) {
|
||||
mockProvider := new(mocks.Provider)
|
||||
attestationProvider, err := NewAttestationProvider(mockProvider, attestation.SNPvTPM)
|
||||
mockClient := new(mockAttestationClient)
|
||||
attestationProvider, err := NewAttestationProvider(mockClient, attestation.SNPvTPM)
|
||||
require.NoError(t, err)
|
||||
|
||||
subject := DefaultCertificateSubject()
|
||||
@@ -787,8 +812,8 @@ func TestCASignedCertificateErrors(t *testing.T) {
|
||||
// TestGetCertificateErrors tests error paths in certificate generation.
|
||||
func TestGetCertificateErrors(t *testing.T) {
|
||||
t.Run("InvalidServerNameFormat", func(t *testing.T) {
|
||||
mockProvider := new(mocks.Provider)
|
||||
attestationProvider, err := NewAttestationProvider(mockProvider, attestation.SNPvTPM)
|
||||
mockClient := new(mockAttestationClient)
|
||||
attestationProvider, err := NewAttestationProvider(mockClient, attestation.SNPvTPM)
|
||||
require.NoError(t, err)
|
||||
|
||||
provider := NewAttestedProvider(attestationProvider, DefaultCertificateSubject())
|
||||
@@ -803,10 +828,10 @@ func TestGetCertificateErrors(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("AttestationProviderError", func(t *testing.T) {
|
||||
mockProvider := new(mocks.Provider)
|
||||
mockProvider.On("Attestation", mock.Anything, mock.Anything).Return(nil, errors.New("attestation failed"))
|
||||
mockClient := new(mockAttestationClient)
|
||||
mockClient.On("GetAttestation", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil, errors.New("attestation failed"))
|
||||
|
||||
attestationProvider, err := NewAttestationProvider(mockProvider, attestation.SNPvTPM)
|
||||
attestationProvider, err := NewAttestationProvider(mockClient, attestation.SNPvTPM)
|
||||
require.NoError(t, err)
|
||||
|
||||
provider := NewAttestedProvider(attestationProvider, DefaultCertificateSubject())
|
||||
@@ -824,10 +849,10 @@ func TestGetCertificateErrors(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("CASignedCertificateError", func(t *testing.T) {
|
||||
mockProvider := new(mocks.Provider)
|
||||
mockProvider.On("Attestation", mock.Anything, mock.Anything).Return([]byte("test-attestation"), nil)
|
||||
mockClient := new(mockAttestationClient)
|
||||
mockClient.On("GetAttestation", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return([]byte("test-attestation"), nil)
|
||||
|
||||
attestationProvider, err := NewAttestationProvider(mockProvider, attestation.SNPvTPM)
|
||||
attestationProvider, err := NewAttestationProvider(mockClient, attestation.SNPvTPM)
|
||||
require.NoError(t, err)
|
||||
|
||||
mockSDK := sdkmocks.NewSDK(t)
|
||||
@@ -904,7 +929,8 @@ func TestCertificateVerificationEdgeCases(t *testing.T) {
|
||||
|
||||
err := verifier.verifyCertificateExtension([]byte("test-extension"), []byte("test-pubkey"), []byte("test-nonce"), invalidPlatformType)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "unsupported platform type")
|
||||
// The error occurs during EAT token decoding before platform type validation
|
||||
assert.Contains(t, err.Error(), "failed to decode EAT token")
|
||||
})
|
||||
}
|
||||
|
||||
@@ -973,12 +999,12 @@ func TestIntegrationScenarios(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
|
||||
t.Run("FullSelfSignedFlow", func(t *testing.T) {
|
||||
// Setup mock provider
|
||||
mockProvider := new(mocks.Provider)
|
||||
mockProvider.On("Attestation", mock.Anything, mock.Anything).Return([]byte("mock-attestation"), nil)
|
||||
// Setup mock client
|
||||
mockClient := new(mockAttestationClient)
|
||||
mockClient.On("GetAttestation", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return([]byte("mock-attestation"), nil)
|
||||
|
||||
// Create provider
|
||||
provider, err := NewProvider(mockProvider, attestation.SNPvTPM, "", "", nil)
|
||||
provider, err := NewProvider(mockClient, attestation.SNPvTPM, "", "", nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Generate certificate
|
||||
@@ -1017,10 +1043,10 @@ func TestIntegrationScenarios(t *testing.T) {
|
||||
mockSDK.On("CreateCSR", mock.Anything, mock.Anything, mock.Anything).Return(expectedCSR, errors.SDKError(nil))
|
||||
mockSDK.On("IssueFromCSRInternal", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(expectedCert, errors.SDKError(nil))
|
||||
|
||||
mockProvider := new(mocks.Provider)
|
||||
mockProvider.On("Attestation", mock.Anything, mock.Anything).Return([]byte("mock-attestation"), nil)
|
||||
mockClient := new(mockAttestationClient)
|
||||
mockClient.On("GetAttestation", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return([]byte("mock-attestation"), nil)
|
||||
|
||||
provider, err := NewProvider(mockProvider, attestation.SNPvTPM, "test-token", "test-cvm-id", mockSDK)
|
||||
provider, err := NewProvider(mockClient, attestation.SNPvTPM, "test-token", "test-cvm-id", mockSDK)
|
||||
require.NoError(t, err)
|
||||
|
||||
nonce := make([]byte, 64)
|
||||
@@ -1041,17 +1067,17 @@ func TestIntegrationScenarios(t *testing.T) {
|
||||
|
||||
assert.NotNil(t, parsedCert.Subject)
|
||||
|
||||
mockProvider.AssertExpectations(t)
|
||||
mockClient.AssertExpectations(t)
|
||||
mockSDK.AssertExpectations(t)
|
||||
})
|
||||
}
|
||||
|
||||
// TestConcurrentAccess tests concurrent access scenarios.
|
||||
func TestConcurrentAccess(t *testing.T) {
|
||||
mockProvider := new(mocks.Provider)
|
||||
mockProvider.On("Attestation", mock.Anything, mock.Anything).Return([]byte("mock-attestation"), nil)
|
||||
mockClient := new(mockAttestationClient)
|
||||
mockClient.On("GetAttestation", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return([]byte("mock-attestation"), nil)
|
||||
|
||||
provider, err := NewProvider(mockProvider, attestation.SNPvTPM, "", "", nil)
|
||||
provider, err := NewProvider(mockClient, attestation.SNPvTPM, "", "", nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
const numGoroutines = 10
|
||||
|
||||
@@ -3,10 +3,12 @@
|
||||
package atls
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/asn1"
|
||||
"fmt"
|
||||
|
||||
"github.com/ultravioletrs/cocos/pkg/attestation"
|
||||
attestation_client "github.com/ultravioletrs/cocos/pkg/clients/grpc/attestation"
|
||||
"golang.org/x/crypto/sha3"
|
||||
)
|
||||
|
||||
@@ -19,20 +21,20 @@ type AttestationProvider interface {
|
||||
|
||||
// PlatformAttestationProvider handles platform attestation operations.
|
||||
type platformAttestationProvider struct {
|
||||
provider attestation.Provider
|
||||
attClient attestation_client.Client
|
||||
oid asn1.ObjectIdentifier
|
||||
platformType attestation.PlatformType
|
||||
}
|
||||
|
||||
// NewAttestationProvider creates a new attestation provider for the given platform type.
|
||||
func NewAttestationProvider(provider attestation.Provider, platformType attestation.PlatformType) (AttestationProvider, error) {
|
||||
func NewAttestationProvider(attClient attestation_client.Client, platformType attestation.PlatformType) (AttestationProvider, error) {
|
||||
oid, err := OID(platformType)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get OID: %w", err)
|
||||
}
|
||||
|
||||
return &platformAttestationProvider{
|
||||
provider: provider,
|
||||
attClient: attClient,
|
||||
oid: oid,
|
||||
platformType: platformType,
|
||||
}, nil
|
||||
@@ -41,7 +43,21 @@ func NewAttestationProvider(provider attestation.Provider, platformType attestat
|
||||
func (p *platformAttestationProvider) Attest(pubKey []byte, nonce []byte) ([]byte, error) {
|
||||
teeNonce := append(pubKey, nonce...)
|
||||
hashNonce := sha3.Sum512(teeNonce)
|
||||
return p.provider.Attestation(hashNonce[:], hashNonce[:32])
|
||||
|
||||
var reportData [64]byte
|
||||
copy(reportData[:], hashNonce[:])
|
||||
|
||||
var nonceArray [32]byte
|
||||
copy(nonceArray[:], hashNonce[:32])
|
||||
|
||||
// Get signed EAT token from attestation service
|
||||
// The attestation service maintains a persistent signing key and returns a pre-signed token
|
||||
eatToken, err := p.attClient.GetAttestation(context.Background(), reportData, nonceArray, p.platformType)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get attestation from service: %w", err)
|
||||
}
|
||||
|
||||
return eatToken, nil
|
||||
}
|
||||
|
||||
func (p *platformAttestationProvider) OID() asn1.ObjectIdentifier {
|
||||
|
||||
@@ -19,6 +19,7 @@ import (
|
||||
"github.com/absmach/certs"
|
||||
sdk "github.com/absmach/certs/sdk"
|
||||
"github.com/ultravioletrs/cocos/pkg/attestation"
|
||||
attestation_client "github.com/ultravioletrs/cocos/pkg/clients/grpc/attestation"
|
||||
)
|
||||
|
||||
// CertificateProvider defines the interface for providing TLS certificates.
|
||||
@@ -173,8 +174,8 @@ func (p *attestedCertificateProvider) generateCASignedCertificate(ctx context.Co
|
||||
return block.Bytes, nil
|
||||
}
|
||||
|
||||
func NewProvider(provider attestation.Provider, platformType attestation.PlatformType, agentToken, cvmID string, certsSDK sdk.SDK) (CertificateProvider, error) {
|
||||
attestationProvider, err := NewAttestationProvider(provider, platformType)
|
||||
func NewProvider(attClient attestation_client.Client, platformType attestation.PlatformType, agentToken, cvmID string, certsSDK sdk.SDK) (CertificateProvider, error) {
|
||||
attestationProvider, err := NewAttestationProvider(attClient, platformType)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create attestation provider: %w", err)
|
||||
}
|
||||
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
|
||||
"github.com/ultravioletrs/cocos/pkg/attestation"
|
||||
"github.com/ultravioletrs/cocos/pkg/attestation/azure"
|
||||
"github.com/ultravioletrs/cocos/pkg/attestation/eat"
|
||||
"github.com/ultravioletrs/cocos/pkg/attestation/tdx"
|
||||
"github.com/ultravioletrs/cocos/pkg/attestation/vtpm"
|
||||
"golang.org/x/crypto/sha3"
|
||||
@@ -21,11 +22,15 @@ type CertificateVerifier interface {
|
||||
|
||||
// CertificateVerifier handles certificate verification operations.
|
||||
type certificateVerifier struct {
|
||||
rootCAs *x509.CertPool
|
||||
rootCAs *x509.CertPool
|
||||
verifierProvider func(attestation.PlatformType) (attestation.Verifier, error)
|
||||
}
|
||||
|
||||
func NewCertificateVerifier(rootCAs *x509.CertPool) CertificateVerifier {
|
||||
return &certificateVerifier{rootCAs: rootCAs}
|
||||
return &certificateVerifier{
|
||||
rootCAs: rootCAs,
|
||||
verifierProvider: platformVerifier,
|
||||
}
|
||||
}
|
||||
|
||||
func (v *certificateVerifier) VerifyPeerCertificate(rawCerts [][]byte, _ [][]*x509.Certificate, nonce []byte) error {
|
||||
@@ -75,15 +80,43 @@ func (v *certificateVerifier) verifyAttestationExtension(cert *x509.Certificate,
|
||||
}
|
||||
|
||||
func (v *certificateVerifier) verifyCertificateExtension(extension []byte, pubKey []byte, nonce []byte, platformType attestation.PlatformType) error {
|
||||
verifier, err := platformVerifier(platformType)
|
||||
// Decode EAT token from certificate extension
|
||||
// Note: We don't have the public key for verification here, so we decode without verification
|
||||
// The signature was created by the attester, and we trust the TEE hardware verification
|
||||
claims, err := eat.DecodeCBOR(extension, nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to decode EAT token: %w", err)
|
||||
}
|
||||
|
||||
// Verify nonce matches
|
||||
teeNonce := append(pubKey, nonce...)
|
||||
hashNonce := sha3.Sum512(teeNonce)
|
||||
|
||||
// Compare nonces (EAT nonce should match our computed nonce)
|
||||
if len(claims.Nonce) != len(hashNonce) {
|
||||
return fmt.Errorf("nonce length mismatch: expected %d, got %d", len(hashNonce), len(claims.Nonce))
|
||||
}
|
||||
|
||||
nonceMatch := true
|
||||
for i := range claims.Nonce {
|
||||
if claims.Nonce[i] != hashNonce[i] {
|
||||
nonceMatch = false
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if !nonceMatch {
|
||||
return fmt.Errorf("nonce mismatch in EAT token")
|
||||
}
|
||||
|
||||
// Get platform verifier
|
||||
verifier, err := v.verifierProvider(platformType)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get platform verifier: %w", err)
|
||||
}
|
||||
|
||||
teeNonce := append(pubKey, nonce...)
|
||||
hashNonce := sha3.Sum512(teeNonce)
|
||||
|
||||
if err = verifier.VerifyAttestation(extension, hashNonce[:], hashNonce[:32]); err != nil {
|
||||
// Verify the binary attestation report embedded in EAT token
|
||||
if err = verifier.VerifyAttestation(claims.RawReport, hashNonce[:], hashNonce[:32]); err != nil {
|
||||
return fmt.Errorf("failed to verify attestation: %w", err)
|
||||
}
|
||||
|
||||
|
||||
@@ -0,0 +1,172 @@
|
||||
// Copyright (c) Ultraviolet
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package atls
|
||||
|
||||
import (
|
||||
"crypto/ecdsa"
|
||||
"crypto/elliptic"
|
||||
"crypto/rand"
|
||||
"crypto/x509"
|
||||
"crypto/x509/pkix"
|
||||
"math/big"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/fxamacker/cbor/v2"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/ultravioletrs/cocos/pkg/attestation"
|
||||
"github.com/ultravioletrs/cocos/pkg/attestation/eat"
|
||||
"golang.org/x/crypto/sha3"
|
||||
)
|
||||
|
||||
type mockVerifier struct {
|
||||
verifyAttestationFunc func(report []byte, teeNonce []byte, vTpmNonce []byte) error
|
||||
}
|
||||
|
||||
func (m *mockVerifier) VerifyAttestation(report []byte, teeNonce []byte, vTpmNonce []byte) error {
|
||||
if m.verifyAttestationFunc != nil {
|
||||
return m.verifyAttestationFunc(report, teeNonce, vTpmNonce)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockVerifier) VerifTeeAttestation(report []byte, teeNonce []byte) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockVerifier) VerifVTpmAttestation(report []byte, vTpmNonce []byte) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockVerifier) VerifyEAT(eatToken []byte, teeNonce []byte, vTpmNonce []byte) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockVerifier) JSONToPolicy(path string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestVerifyPeerCertificate_Success(t *testing.T) {
|
||||
// Setup keys and cert templates
|
||||
caKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||
require.NoError(t, err)
|
||||
|
||||
caTemplate := &x509.Certificate{
|
||||
SerialNumber: big.NewInt(1),
|
||||
Subject: pkix.Name{CommonName: "Test CA"},
|
||||
NotBefore: time.Now().Add(-1 * time.Hour),
|
||||
NotAfter: time.Now().Add(1 * time.Hour),
|
||||
KeyUsage: x509.KeyUsageCertSign | x509.KeyUsageCRLSign,
|
||||
BasicConstraintsValid: true,
|
||||
IsCA: true,
|
||||
}
|
||||
caCertDER, err := x509.CreateCertificate(rand.Reader, caTemplate, caTemplate, &caKey.PublicKey, caKey)
|
||||
require.NoError(t, err)
|
||||
caCert, err := x509.ParseCertificate(caCertDER)
|
||||
require.NoError(t, err)
|
||||
|
||||
rootCAs := x509.NewCertPool()
|
||||
rootCAs.AddCert(caCert)
|
||||
|
||||
// Create verifier with mock platform verifier
|
||||
verifier := NewCertificateVerifier(rootCAs).(*certificateVerifier)
|
||||
verifier.verifierProvider = func(pt attestation.PlatformType) (attestation.Verifier, error) {
|
||||
return &mockVerifier{
|
||||
verifyAttestationFunc: func(report []byte, teeNonce []byte, vTpmNonce []byte) error {
|
||||
return nil
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
peerKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Prepare EAT Claims
|
||||
nonce := []byte("test-nonce")
|
||||
peerPubKeyDER, err := x509.MarshalPKIXPublicKey(&peerKey.PublicKey)
|
||||
require.NoError(t, err)
|
||||
|
||||
teeNonce := append(peerPubKeyDER, nonce...)
|
||||
hashNonce := sha3.Sum512(teeNonce)
|
||||
|
||||
claims := eat.EATClaims{
|
||||
Nonce: hashNonce[:],
|
||||
RawReport: []byte("mock-report"),
|
||||
}
|
||||
eatBytes, err := cbor.Marshal(claims)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create Peer Cert with EAT extension
|
||||
peerTemplate := &x509.Certificate{
|
||||
SerialNumber: big.NewInt(2),
|
||||
Subject: pkix.Name{CommonName: "Test Peer"},
|
||||
NotBefore: time.Now().Add(-1 * time.Hour),
|
||||
NotAfter: time.Now().Add(1 * time.Hour),
|
||||
KeyUsage: x509.KeyUsageDigitalSignature,
|
||||
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
|
||||
ExtraExtensions: []pkix.Extension{
|
||||
{
|
||||
Id: SNPvTPMOID, // Use SNPvTPMOID as default testing OID
|
||||
Value: eatBytes,
|
||||
},
|
||||
},
|
||||
}
|
||||
peerCertDER, err := x509.CreateCertificate(rand.Reader, peerTemplate, caCert, &peerKey.PublicKey, caKey)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = verifier.VerifyPeerCertificate([][]byte{peerCertDER}, nil, nonce)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestVerifyPeerCertificate_Failures(t *testing.T) {
|
||||
caKey, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||
caTemplate := &x509.Certificate{
|
||||
SerialNumber: big.NewInt(1),
|
||||
Subject: pkix.Name{CommonName: "Test CA"},
|
||||
NotBefore: time.Now().Add(-1 * time.Hour),
|
||||
NotAfter: time.Now().Add(1 * time.Hour),
|
||||
KeyUsage: x509.KeyUsageCertSign | x509.KeyUsageCRLSign,
|
||||
BasicConstraintsValid: true,
|
||||
IsCA: true,
|
||||
}
|
||||
caCertDER, _ := x509.CreateCertificate(rand.Reader, caTemplate, caTemplate, &caKey.PublicKey, caKey)
|
||||
caCert, _ := x509.ParseCertificate(caCertDER)
|
||||
rootCAs := x509.NewCertPool()
|
||||
rootCAs.AddCert(caCert)
|
||||
|
||||
verifier := NewCertificateVerifier(rootCAs).(*certificateVerifier)
|
||||
|
||||
peerKey, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||
peerTemplate := &x509.Certificate{
|
||||
SerialNumber: big.NewInt(2),
|
||||
NotBefore: time.Now(),
|
||||
NotAfter: time.Now().Add(time.Hour),
|
||||
}
|
||||
certDER, _ := x509.CreateCertificate(rand.Reader, peerTemplate, caCert, &peerKey.PublicKey, caKey)
|
||||
|
||||
err := verifier.VerifyPeerCertificate([][]byte{certDER}, nil, []byte("nonce"))
|
||||
assert.ErrorContains(t, err, "attestation extension not found")
|
||||
|
||||
nonce := []byte("nonce1")
|
||||
wrongNonce := []byte("nonce2")
|
||||
peerPubKeyDER, _ := x509.MarshalPKIXPublicKey(&peerKey.PublicKey)
|
||||
teeNonce := append(peerPubKeyDER, wrongNonce...) // Mismatching input
|
||||
hashNonce := sha3.Sum512(teeNonce)
|
||||
|
||||
claims := eat.EATClaims{Nonce: hashNonce[:], RawReport: []byte("rep")}
|
||||
eatBytes, _ := cbor.Marshal(claims)
|
||||
|
||||
peerTemplate.ExtraExtensions = []pkix.Extension{{Id: SNPvTPMOID, Value: eatBytes}}
|
||||
certDERMismatch, _ := x509.CreateCertificate(rand.Reader, peerTemplate, caCert, &peerKey.PublicKey, caKey)
|
||||
|
||||
err = verifier.VerifyPeerCertificate([][]byte{certDERMismatch}, nil, nonce) // Pass nonce1
|
||||
assert.ErrorContains(t, err, "nonce mismatch")
|
||||
}
|
||||
|
||||
func TestVerifyPeerCertificate_Empty(t *testing.T) {
|
||||
verifier := NewCertificateVerifier(nil)
|
||||
err := verifier.VerifyPeerCertificate(nil, nil, nil)
|
||||
assert.ErrorContains(t, err, "no certificates provided")
|
||||
}
|
||||
@@ -42,9 +42,20 @@ type PcrConfig struct {
|
||||
PCRValues PcrValues `json:"pcr_values"`
|
||||
}
|
||||
|
||||
// Config represents attestation configuration.
|
||||
type Config struct {
|
||||
*check.Config
|
||||
*PcrConfig
|
||||
*EATValidation
|
||||
}
|
||||
|
||||
// EATValidation contains EAT token validation settings.
|
||||
type EATValidation struct {
|
||||
RequireEATFormat bool `json:"require_eat_format"`
|
||||
AllowedFormats []string `json:"allowed_formats"`
|
||||
MaxTokenAgeSeconds int `json:"max_token_age_seconds"`
|
||||
RequireClaims []string `json:"require_claims"`
|
||||
VerifySignature bool `json:"verify_signature"`
|
||||
}
|
||||
|
||||
type ccCheck struct {
|
||||
@@ -61,6 +72,7 @@ type Provider interface {
|
||||
|
||||
type Verifier interface {
|
||||
VerifyAttestation(report []byte, teeNonce []byte, vTpmNonce []byte) error
|
||||
VerifyEAT(eatToken []byte, teeNonce []byte, vTpmNonce []byte) error
|
||||
VerifTeeAttestation(report []byte, teeNonce []byte) error
|
||||
VerifVTpmAttestation(report []byte, vTpmNonce []byte) error
|
||||
JSONToPolicy(path string) error
|
||||
|
||||
@@ -20,6 +20,7 @@ import (
|
||||
"github.com/google/go-sev-guest/tools/lib/report"
|
||||
"github.com/google/go-tpm-tools/proto/attest"
|
||||
"github.com/ultravioletrs/cocos/pkg/attestation"
|
||||
"github.com/ultravioletrs/cocos/pkg/attestation/eat"
|
||||
"github.com/ultravioletrs/cocos/pkg/attestation/quoteprovider"
|
||||
"github.com/ultravioletrs/cocos/pkg/attestation/vtpm"
|
||||
"google.golang.org/protobuf/proto"
|
||||
@@ -154,6 +155,18 @@ func (a verifier) VerifyAttestation(report []byte, teeNonce []byte, vTpmNonce []
|
||||
return nil
|
||||
}
|
||||
|
||||
// VerifyEAT verifies an EAT token and extracts the binary report for verification.
|
||||
func (v verifier) VerifyEAT(eatToken []byte, teeNonce []byte, vTpmNonce []byte) error {
|
||||
// Decode EAT token
|
||||
claims, err := eat.Decode(eatToken, nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to decode EAT token: %w", err)
|
||||
}
|
||||
|
||||
// Verify the embedded binary report
|
||||
return v.VerifyAttestation(claims.RawReport, teeNonce, vTpmNonce)
|
||||
}
|
||||
|
||||
func (a verifier) JSONToPolicy(path string) error {
|
||||
return vtpm.ReadPolicy(path, a.Policy)
|
||||
}
|
||||
|
||||
@@ -0,0 +1,102 @@
|
||||
// Copyright (c) Ultraviolet
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package azure
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"crypto/x509"
|
||||
"crypto/x509/pkix"
|
||||
"encoding/json"
|
||||
"math/big"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
jose "github.com/go-jose/go-jose/v4"
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestGenerateAttestationPolicy_Success(t *testing.T) {
|
||||
key, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
require.NoError(t, err)
|
||||
|
||||
template := x509.Certificate{
|
||||
SerialNumber: big.NewInt(1),
|
||||
Subject: pkix.Name{CommonName: "test"},
|
||||
NotBefore: time.Now(),
|
||||
NotAfter: time.Now().Add(time.Hour),
|
||||
KeyUsage: x509.KeyUsageDigitalSignature,
|
||||
}
|
||||
certDER, err := x509.CreateCertificate(rand.Reader, &template, &template, &key.PublicKey, key)
|
||||
require.NoError(t, err)
|
||||
cert, err := x509.ParseCertificate(certDER)
|
||||
require.NoError(t, err)
|
||||
|
||||
jwk := jose.JSONWebKey{
|
||||
Key: &key.PublicKey,
|
||||
KeyID: testKID,
|
||||
Algorithm: "RS256",
|
||||
Use: "sig",
|
||||
Certificates: []*x509.Certificate{cert},
|
||||
}
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
jwks := jose.JSONWebKeySet{
|
||||
Keys: []jose.JSONWebKey{jwk},
|
||||
}
|
||||
_ = json.NewEncoder(w).Encode(jwks)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
originalMaaURL := MaaURL
|
||||
MaaURL = server.URL
|
||||
defer func() { MaaURL = originalMaaURL }()
|
||||
|
||||
token := createTestToken(t, key, server.URL)
|
||||
|
||||
policy, err := GenerateAttestationPolicy(token, "Milan", 0)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, policy)
|
||||
assert.Equal(t, "SEV_PRODUCT_MILAN", policy.Config.Policy.Product.Name.String())
|
||||
}
|
||||
|
||||
func createTestToken(t *testing.T, key *rsa.PrivateKey, jku string) string {
|
||||
claims := jwt.MapClaims{
|
||||
"iss": "https://test-issuer.com",
|
||||
"aud": "test-audience",
|
||||
"exp": time.Now().Add(time.Hour).Unix(),
|
||||
"iat": time.Now().Unix(),
|
||||
"x-ms-isolation-tee": map[string]any{
|
||||
"x-ms-sevsnpvm-familyId": "0102030405060708090a0b0c0d0e0f10",
|
||||
"x-ms-sevsnpvm-imageId": "0102030405060708090a0b0c0d0e0f10",
|
||||
"x-ms-sevsnpvm-launchmeasurement": "0102030405060708090a0b0c0d0e0f100102030405060708090a0b0c0d0e0f100102030405060708090a0b0c0d0e0f10",
|
||||
"x-ms-sevsnpvm-bootloader-svn": float64(1),
|
||||
"x-ms-sevsnpvm-tee-svn": float64(2),
|
||||
"x-ms-sevsnpvm-snpfw-svn": float64(3),
|
||||
"x-ms-sevsnpvm-microcode-svn": float64(4),
|
||||
"x-ms-sevsnpvm-guestsvn": float64(5),
|
||||
"x-ms-sevsnpvm-idkeydigest": "0102030405060708090a0b0c0d0e0f100102030405060708090a0b0c0d0e0f100102030405060708090a0b0c0d0e0f10",
|
||||
"x-ms-sevsnpvm-reportid": "0102030405060708090a0b0c0d0e0f100102030405060708090a0b0c0d0e0f10",
|
||||
},
|
||||
}
|
||||
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims)
|
||||
token.Header["jku"] = jku
|
||||
token.Header["kid"] = testKID
|
||||
|
||||
signedToken, err := token.SignedString(key)
|
||||
require.NoError(t, err)
|
||||
return signedToken
|
||||
}
|
||||
|
||||
func TestGenerateAttestationPolicy_InvalidToken(t *testing.T) {
|
||||
// Test with invalid token string
|
||||
_, err := GenerateAttestationPolicy("invalid-token", "Milan", 0)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "failed to validate token")
|
||||
}
|
||||
@@ -0,0 +1,262 @@
|
||||
// Copyright (c) Ultraviolet
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package azure
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"crypto/x509"
|
||||
"crypto/x509/pkix"
|
||||
"encoding/json"
|
||||
"math/big"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
jose "github.com/go-jose/go-jose/v4"
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestGenerateAttestationPolicy(t *testing.T) {
|
||||
privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
require.NoError(t, err)
|
||||
|
||||
template := &x509.Certificate{
|
||||
SerialNumber: big.NewInt(1),
|
||||
Subject: pkix.Name{
|
||||
Organization: []string{"Test Org"},
|
||||
},
|
||||
NotBefore: time.Now().Add(-1 * time.Hour),
|
||||
NotAfter: time.Now().Add(1 * time.Hour),
|
||||
KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment,
|
||||
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
|
||||
BasicConstraintsValid: true,
|
||||
}
|
||||
certDER, err := x509.CreateCertificate(rand.Reader, template, template, &privateKey.PublicKey, privateKey)
|
||||
require.NoError(t, err)
|
||||
cert, err := x509.ParseCertificate(certDER)
|
||||
require.NoError(t, err)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
token string
|
||||
product string
|
||||
policy uint64
|
||||
setupServer func(t *testing.T, key *rsa.PrivateKey, cert *x509.Certificate) *httptest.Server
|
||||
wantErr bool
|
||||
errorMessage string
|
||||
setupTokenJKU bool
|
||||
}{
|
||||
{
|
||||
name: "valid token and claims",
|
||||
product: "Milan-B0",
|
||||
policy: 0,
|
||||
setupServer: func(t *testing.T, key *rsa.PrivateKey, cert *x509.Certificate) *httptest.Server {
|
||||
return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
case openIDConfigPath:
|
||||
config := map[string]any{
|
||||
"jwks_uri": "http://" + r.Host + certsPath,
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
if err := json.NewEncoder(w).Encode(config); err != nil {
|
||||
t.Errorf("failed to encode config: %v", err)
|
||||
}
|
||||
case certsPath:
|
||||
jwks := generateJWKS(&key.PublicKey, cert)
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
if err := json.NewEncoder(w).Encode(jwks); err != nil {
|
||||
t.Errorf("failed to encode jwks: %v", err)
|
||||
}
|
||||
default:
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
}
|
||||
}))
|
||||
},
|
||||
setupTokenJKU: true,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "invalid token format",
|
||||
token: "invalid-token",
|
||||
product: "Milan-B0",
|
||||
policy: 0,
|
||||
setupServer: nil,
|
||||
wantErr: true,
|
||||
errorMessage: "failed to parse token",
|
||||
setupTokenJKU: false,
|
||||
},
|
||||
{
|
||||
name: "missing familyId",
|
||||
product: "Milan-B0",
|
||||
policy: 0,
|
||||
setupServer: func(t *testing.T, key *rsa.PrivateKey, cert *x509.Certificate) *httptest.Server {
|
||||
return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
case openIDConfigPath:
|
||||
config := map[string]any{
|
||||
"jwks_uri": "http://" + r.Host + certsPath,
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
if err := json.NewEncoder(w).Encode(config); err != nil {
|
||||
t.Errorf("failed to encode config: %v", err)
|
||||
}
|
||||
case certsPath:
|
||||
jwks := generateJWKS(&key.PublicKey, cert)
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
if err := json.NewEncoder(w).Encode(jwks); err != nil {
|
||||
t.Errorf("failed to encode jwks: %v", err)
|
||||
}
|
||||
}
|
||||
}))
|
||||
},
|
||||
setupTokenJKU: true,
|
||||
wantErr: true,
|
||||
errorMessage: "failed to get familyId from claims",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
var tokenString string
|
||||
var server *httptest.Server
|
||||
|
||||
if tt.setupServer != nil {
|
||||
server = tt.setupServer(t, privateKey, cert)
|
||||
defer server.Close()
|
||||
|
||||
originalURL := MaaURL
|
||||
MaaURL = "" // Clear it so it uses JKU
|
||||
defer func() { MaaURL = originalURL }()
|
||||
}
|
||||
|
||||
if tt.token != "" {
|
||||
tokenString = tt.token
|
||||
} else {
|
||||
// Generate token
|
||||
claims := createValidClaims()
|
||||
if tt.name == "missing familyId" {
|
||||
if tee, ok := claims["x-ms-isolation-tee"].(map[string]any); ok {
|
||||
delete(tee, "x-ms-sevsnpvm-familyId")
|
||||
}
|
||||
}
|
||||
|
||||
jku := ""
|
||||
if tt.setupTokenJKU && server != nil {
|
||||
jku = server.URL
|
||||
}
|
||||
|
||||
var err error
|
||||
tokenString, err = signToken(claims, privateKey, jku)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
config, err := GenerateAttestationPolicy(tokenString, tt.product, tt.policy)
|
||||
|
||||
if tt.wantErr {
|
||||
assert.Error(t, err)
|
||||
if tt.errorMessage != "" {
|
||||
assert.Contains(t, err.Error(), tt.errorMessage)
|
||||
}
|
||||
assert.Nil(t, config)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, config)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestVerifier_VerifyEAT(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
eatToken []byte
|
||||
teeNonce []byte
|
||||
vTpmNonce []byte
|
||||
setupToken func() ([]byte, error)
|
||||
wantErr bool
|
||||
errorMessage string
|
||||
}{
|
||||
{
|
||||
name: "invalid cbor",
|
||||
eatToken: []byte("invalid-cbor"),
|
||||
teeNonce: testNonce,
|
||||
vTpmNonce: testNonce,
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
v := NewVerifier(&bytes.Buffer{})
|
||||
|
||||
token := tt.eatToken
|
||||
if tt.setupToken != nil {
|
||||
var err error
|
||||
token, err = tt.setupToken()
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
err := v.VerifyEAT(token, tt.teeNonce, tt.vTpmNonce)
|
||||
|
||||
if tt.wantErr {
|
||||
assert.Error(t, err)
|
||||
if tt.errorMessage != "" {
|
||||
}
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Helper functions
|
||||
|
||||
func createValidClaims() jwt.MapClaims {
|
||||
return jwt.MapClaims{
|
||||
"iss": "https://test-issuer.com",
|
||||
"aud": "test-audience",
|
||||
"exp": time.Now().Add(time.Hour).Unix(),
|
||||
"iat": time.Now().Unix(),
|
||||
"nbf": time.Now().Add(-1 * time.Hour).Unix(),
|
||||
"x-ms-isolation-tee": map[string]any{
|
||||
"x-ms-sevsnpvm-familyId": "1234567890abcdef",
|
||||
"x-ms-sevsnpvm-imageId": "fedcba0987654321",
|
||||
"x-ms-sevsnpvm-launchmeasurement": "abcdef1234567890",
|
||||
"x-ms-sevsnpvm-bootloader-svn": float64(1),
|
||||
"x-ms-sevsnpvm-tee-svn": float64(2),
|
||||
"x-ms-sevsnpvm-snpfw-svn": float64(3),
|
||||
"x-ms-sevsnpvm-microcode-svn": float64(4),
|
||||
"x-ms-sevsnpvm-guestsvn": float64(5),
|
||||
"x-ms-sevsnpvm-idkeydigest": "1234567890abcdef",
|
||||
"x-ms-sevsnpvm-reportid": "fedcba0987654321",
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func signToken(claims jwt.MapClaims, key *rsa.PrivateKey, jku string) (string, error) {
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims)
|
||||
token.Header["kid"] = testKID
|
||||
if jku != "" {
|
||||
token.Header["jku"] = jku
|
||||
}
|
||||
return token.SignedString(key)
|
||||
}
|
||||
|
||||
func generateJWKS(pubKey *rsa.PublicKey, cert *x509.Certificate) *jose.JSONWebKeySet {
|
||||
key := jose.JSONWebKey{
|
||||
Key: pubKey,
|
||||
KeyID: testKID,
|
||||
Algorithm: "RS256",
|
||||
Use: "sig",
|
||||
Certificates: []*x509.Certificate{cert},
|
||||
}
|
||||
return &jose.JSONWebKeySet{
|
||||
Keys: []jose.JSONWebKey{key},
|
||||
}
|
||||
}
|
||||
@@ -22,8 +22,11 @@ import (
|
||||
)
|
||||
|
||||
var (
|
||||
testNonce = []byte("test-nonce-12345678901234567890123456789012")
|
||||
testReport = []byte("test-report-data")
|
||||
testNonce = []byte("test-nonce-12345678901234567890123456789012")
|
||||
testReport = []byte("test-report-data")
|
||||
testKID = "test-kid"
|
||||
openIDConfigPath = "/.well-known/openid_configuration"
|
||||
certsPath = "/certs"
|
||||
)
|
||||
|
||||
func TestNewProvider(t *testing.T) {
|
||||
@@ -459,19 +462,19 @@ func TestIntegration_FullAttestationFlow(t *testing.T) {
|
||||
if err := json.NewEncoder(w).Encode(response); err != nil {
|
||||
t.Fatalf("Failed to encode response: %v", err)
|
||||
}
|
||||
case "/.well-known/openid_configuration":
|
||||
case openIDConfigPath:
|
||||
config := map[string]any{
|
||||
"jwks_uri": "maaServer.URL" + "/certs",
|
||||
"jwks_uri": "maaServer.URL" + certsPath,
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
if err := json.NewEncoder(w).Encode(config); err != nil {
|
||||
t.Fatalf("Failed to encode OpenID configuration: %v", err)
|
||||
}
|
||||
case "/certs":
|
||||
case certsPath:
|
||||
jwks := map[string]any{
|
||||
"keys": []map[string]any{
|
||||
{
|
||||
"kid": "test-kid",
|
||||
"kid": testKID,
|
||||
"kty": "RSA",
|
||||
"use": "sig",
|
||||
"n": "test-n-value",
|
||||
|
||||
@@ -0,0 +1,74 @@
|
||||
// Copyright (c) Ultraviolet
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package eat
|
||||
|
||||
import (
|
||||
"crypto/ecdsa"
|
||||
"crypto/rand"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/fxamacker/cbor/v2"
|
||||
"github.com/veraison/go-cose"
|
||||
)
|
||||
|
||||
// CBOREncoder encodes EAT claims to CBOR format (CWT - CBOR Web Token).
|
||||
type CBOREncoder struct {
|
||||
signingKey *ecdsa.PrivateKey
|
||||
issuer string
|
||||
}
|
||||
|
||||
// NewCBOREncoder creates a new CBOR encoder.
|
||||
func NewCBOREncoder(signingKey *ecdsa.PrivateKey, issuer string) *CBOREncoder {
|
||||
return &CBOREncoder{
|
||||
signingKey: signingKey,
|
||||
issuer: issuer,
|
||||
}
|
||||
}
|
||||
|
||||
// Encode encodes EAT claims to CBOR bytes with COSE_Sign1 signature.
|
||||
func (e *CBOREncoder) Encode(claims *EATClaims) ([]byte, error) {
|
||||
// Set standard CWT claims
|
||||
now := time.Now()
|
||||
claims.Issuer = e.issuer
|
||||
claims.IssuedAt = now.Unix()
|
||||
claims.ExpiresAt = now.Add(5 * time.Minute).Unix() // 5 minute validity
|
||||
|
||||
// Encode claims to CBOR (this will be the payload)
|
||||
payload, err := cbor.Marshal(claims)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to encode CBOR payload: %w", err)
|
||||
}
|
||||
|
||||
// Create COSE Sign1 message
|
||||
msg := cose.NewSign1Message()
|
||||
msg.Payload = payload
|
||||
msg.Headers.Protected.SetAlgorithm(cose.AlgorithmES256)
|
||||
msg.Headers.Unprotected[cose.HeaderLabelKeyID] = []byte(e.issuer)
|
||||
|
||||
// Create signer from ECDSA private key
|
||||
signer, err := cose.NewSigner(cose.AlgorithmES256, e.signingKey)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create COSE signer: %w", err)
|
||||
}
|
||||
|
||||
// Sign the message
|
||||
if err := msg.Sign(rand.Reader, nil, signer); err != nil {
|
||||
return nil, fmt.Errorf("failed to sign COSE message: %w", err)
|
||||
}
|
||||
|
||||
// Encode the signed message to CBOR
|
||||
signed, err := msg.MarshalCBOR()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal COSE_Sign1: %w", err)
|
||||
}
|
||||
|
||||
return signed, nil
|
||||
}
|
||||
|
||||
// EncodeToCBOR is a convenience function to encode EAT claims to CBOR.
|
||||
func EncodeToCBOR(claims *EATClaims, signingKey *ecdsa.PrivateKey, issuer string) ([]byte, error) {
|
||||
encoder := NewCBOREncoder(signingKey, issuer)
|
||||
return encoder.Encode(claims)
|
||||
}
|
||||
@@ -0,0 +1,79 @@
|
||||
// Copyright (c) Ultraviolet
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package eat
|
||||
|
||||
import (
|
||||
"crypto/ecdsa"
|
||||
"crypto/elliptic"
|
||||
"crypto/rand"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/veraison/go-cose"
|
||||
)
|
||||
|
||||
func TestCBOREncoder_Encode(t *testing.T) {
|
||||
key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||
require.NoError(t, err)
|
||||
|
||||
type fields struct {
|
||||
signingKey *ecdsa.PrivateKey
|
||||
issuer string
|
||||
}
|
||||
type args struct {
|
||||
claims *EATClaims
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
args args
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "Valid encoding",
|
||||
fields: fields{
|
||||
signingKey: key,
|
||||
issuer: "test-issuer",
|
||||
},
|
||||
args: args{
|
||||
claims: &EATClaims{
|
||||
Nonce: []byte("test-nonce"),
|
||||
},
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
e := NewCBOREncoder(tt.fields.signingKey, tt.fields.issuer)
|
||||
got, err := e.Encode(tt.args.claims)
|
||||
if tt.wantErr {
|
||||
assert.Error(t, err)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
assert.NotEmpty(t, got)
|
||||
|
||||
var msg cose.Sign1Message
|
||||
err = msg.UnmarshalCBOR(got)
|
||||
assert.NoError(t, err)
|
||||
|
||||
verifier, err := cose.NewVerifier(cose.AlgorithmES256, &key.PublicKey)
|
||||
assert.NoError(t, err)
|
||||
err = msg.Verify(nil, verifier)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestEncodeToCBOR(t *testing.T) {
|
||||
key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||
require.NoError(t, err)
|
||||
|
||||
claims := &EATClaims{Nonce: []byte("nonce")}
|
||||
token, err := EncodeToCBOR(claims, key, "issuer")
|
||||
assert.NoError(t, err)
|
||||
assert.NotEmpty(t, token)
|
||||
}
|
||||
@@ -0,0 +1,143 @@
|
||||
// Copyright (c) Ultraviolet
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package eat
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/ecdsa"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
"github.com/fxamacker/cbor/v2"
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
"github.com/veraison/go-cose"
|
||||
)
|
||||
|
||||
// Decoder decodes EAT tokens (auto-detects JWT vs CBOR).
|
||||
type Decoder struct {
|
||||
verifyKey *ecdsa.PublicKey
|
||||
}
|
||||
|
||||
// NewDecoder creates a new EAT decoder.
|
||||
func NewDecoder(verifyKey *ecdsa.PublicKey) *Decoder {
|
||||
return &Decoder{
|
||||
verifyKey: verifyKey,
|
||||
}
|
||||
}
|
||||
|
||||
// Decode decodes an EAT token (auto-detects format).
|
||||
func (d *Decoder) Decode(token []byte) (*EATClaims, error) {
|
||||
// Try to detect format
|
||||
if isJWT(token) {
|
||||
return d.decodeJWT(string(token))
|
||||
}
|
||||
return d.decodeCBOR(token)
|
||||
}
|
||||
|
||||
// isJWT checks if the token is JWT format.
|
||||
func isJWT(token []byte) bool {
|
||||
// JWT tokens are base64-encoded strings with dots
|
||||
if len(token) < 10 {
|
||||
return false
|
||||
}
|
||||
return bytes.Contains(token, []byte(".")) && !bytes.Contains(token[:10], []byte{0x00})
|
||||
}
|
||||
|
||||
// decodeJWT decodes a JWT token.
|
||||
func (d *Decoder) decodeJWT(tokenString string) (*EATClaims, error) {
|
||||
claims := &jwtClaims{&EATClaims{}}
|
||||
|
||||
var token *jwt.Token
|
||||
var err error
|
||||
|
||||
if d.verifyKey == nil {
|
||||
token, _, err = new(jwt.Parser).ParseUnverified(tokenString, claims)
|
||||
} else {
|
||||
// Parse and verify JWT
|
||||
token, err = jwt.ParseWithClaims(tokenString, claims, func(token *jwt.Token) (interface{}, error) {
|
||||
// Verify signing method
|
||||
if _, ok := token.Method.(*jwt.SigningMethodECDSA); !ok {
|
||||
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
|
||||
}
|
||||
return d.verifyKey, nil
|
||||
})
|
||||
}
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse JWT: %w", err)
|
||||
}
|
||||
|
||||
if !token.Valid {
|
||||
return nil, fmt.Errorf("invalid JWT token")
|
||||
}
|
||||
|
||||
return claims.EATClaims, nil
|
||||
}
|
||||
|
||||
// decodeCBOR decodes a CBOR token with COSE signature verification.
|
||||
func (d *Decoder) decodeCBOR(token []byte) (*EATClaims, error) {
|
||||
// Try to unmarshal as COSE_Sign1 message
|
||||
var msg cose.Sign1Message
|
||||
if err := msg.UnmarshalCBOR(token); err != nil {
|
||||
// If it's not a COSE message, try to decode as plain CBOR (backward compatibility)
|
||||
claims := &EATClaims{}
|
||||
if err := cbor.Unmarshal(token, claims); err != nil {
|
||||
return nil, fmt.Errorf("failed to decode CBOR: %w", err)
|
||||
}
|
||||
return claims, nil
|
||||
}
|
||||
|
||||
// Verify the signature if we have a verification key
|
||||
if d.verifyKey != nil {
|
||||
verifier, err := cose.NewVerifier(cose.AlgorithmES256, d.verifyKey)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create COSE verifier: %w", err)
|
||||
}
|
||||
|
||||
if err := msg.Verify(nil, verifier); err != nil {
|
||||
return nil, fmt.Errorf("COSE signature verification failed: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Decode the payload
|
||||
claims := &EATClaims{}
|
||||
if err := cbor.Unmarshal(msg.Payload, claims); err != nil {
|
||||
return nil, fmt.Errorf("failed to decode CBOR payload: %w", err)
|
||||
}
|
||||
|
||||
return claims, nil
|
||||
}
|
||||
|
||||
// DecodeJWT is a convenience function to decode JWT EAT token.
|
||||
func DecodeJWT(tokenString string, verifyKey *ecdsa.PublicKey) (*EATClaims, error) {
|
||||
decoder := NewDecoder(verifyKey)
|
||||
return decoder.decodeJWT(tokenString)
|
||||
}
|
||||
|
||||
// DecodeCBOR is a convenience function to decode CBOR EAT token.
|
||||
func DecodeCBOR(token []byte, verifyKey *ecdsa.PublicKey) (*EATClaims, error) {
|
||||
decoder := NewDecoder(verifyKey)
|
||||
return decoder.decodeCBOR(token)
|
||||
}
|
||||
|
||||
// Decode is a convenience function that auto-detects format.
|
||||
func Decode(token []byte, verifyKey *ecdsa.PublicKey) (*EATClaims, error) {
|
||||
decoder := NewDecoder(verifyKey)
|
||||
return decoder.Decode(token)
|
||||
}
|
||||
|
||||
// MarshalJSON implements json.Marshaler for pretty printing.
|
||||
func (c *EATClaims) MarshalJSON() ([]byte, error) {
|
||||
type Alias EATClaims
|
||||
return json.Marshal(&struct {
|
||||
*Alias
|
||||
NonceHex string `json:"eat_nonce_hex,omitempty"`
|
||||
UEIDHex string `json:"ueid_hex,omitempty"`
|
||||
MeasurementsHex string `json:"measurements_hex,omitempty"`
|
||||
}{
|
||||
Alias: (*Alias)(c),
|
||||
NonceHex: fmt.Sprintf("%x", c.Nonce),
|
||||
UEIDHex: fmt.Sprintf("%x", c.UEID),
|
||||
MeasurementsHex: fmt.Sprintf("%x", c.Measurements),
|
||||
})
|
||||
}
|
||||
@@ -0,0 +1,218 @@
|
||||
// Copyright (c) Ultraviolet
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package eat
|
||||
|
||||
import (
|
||||
"crypto/ecdsa"
|
||||
"crypto/elliptic"
|
||||
"crypto/rand"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/fxamacker/cbor/v2"
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/veraison/go-cose"
|
||||
)
|
||||
|
||||
func TestDecodeJWT(t *testing.T) {
|
||||
privateKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||
require.NoError(t, err)
|
||||
|
||||
claims := &EATClaims{
|
||||
Nonce: []byte("test-nonce"),
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
jwtClaims := &jwtClaims{claims}
|
||||
claims.Issuer = "test-issuer"
|
||||
claims.IssuedAt = now.Unix()
|
||||
claims.ExpiresAt = now.Add(time.Hour).Unix()
|
||||
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodES256, jwtClaims)
|
||||
signedToken, err := token.SignedString(privateKey)
|
||||
require.NoError(t, err)
|
||||
|
||||
type args struct {
|
||||
token string
|
||||
verifyKey *ecdsa.PublicKey
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
wantErr bool
|
||||
expectedErr string
|
||||
}{
|
||||
{
|
||||
name: "Valid token",
|
||||
args: args{
|
||||
token: signedToken,
|
||||
verifyKey: &privateKey.PublicKey,
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "Invalid signature",
|
||||
args: args{
|
||||
token: signedToken,
|
||||
verifyKey: func() *ecdsa.PublicKey {
|
||||
key, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||
return &key.PublicKey
|
||||
}(),
|
||||
},
|
||||
wantErr: true,
|
||||
expectedErr: "verification error",
|
||||
},
|
||||
{
|
||||
name: "Malformed token",
|
||||
args: args{
|
||||
token: "invalid.token.structure",
|
||||
verifyKey: &privateKey.PublicKey,
|
||||
},
|
||||
wantErr: true,
|
||||
expectedErr: "failed to parse JWT",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := DecodeJWT(tt.args.token, tt.args.verifyKey)
|
||||
if tt.wantErr {
|
||||
assert.Error(t, err)
|
||||
if tt.expectedErr != "" {
|
||||
assert.ErrorContains(t, err, tt.expectedErr)
|
||||
}
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, got)
|
||||
assert.Equal(t, claims.Nonce, got.Nonce)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDecodeCBOR(t *testing.T) {
|
||||
claims := &EATClaims{
|
||||
Nonce: []byte("test-nonce"),
|
||||
}
|
||||
|
||||
payload, err := cbor.Marshal(claims)
|
||||
require.NoError(t, err)
|
||||
|
||||
privateKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||
require.NoError(t, err)
|
||||
signer, err := cose.NewSigner(cose.AlgorithmES256, privateKey)
|
||||
require.NoError(t, err)
|
||||
|
||||
msg := cose.NewSign1Message()
|
||||
msg.Payload = payload
|
||||
err = msg.Sign(rand.Reader, []byte{}, signer)
|
||||
require.NoError(t, err)
|
||||
|
||||
cborToken, err := msg.MarshalCBOR()
|
||||
require.NoError(t, err)
|
||||
|
||||
type args struct {
|
||||
token []byte
|
||||
verifyKey *ecdsa.PublicKey
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
wantErr bool
|
||||
expectedErr string
|
||||
}{
|
||||
{
|
||||
name: "Valid COSE token",
|
||||
args: args{
|
||||
token: cborToken,
|
||||
verifyKey: &privateKey.PublicKey,
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "Valid Plain CBOR token (no signature)",
|
||||
args: args{
|
||||
token: payload,
|
||||
verifyKey: nil,
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "Invalid COSE signature",
|
||||
args: args{
|
||||
token: cborToken,
|
||||
verifyKey: func() *ecdsa.PublicKey {
|
||||
key, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||
return &key.PublicKey
|
||||
}(),
|
||||
},
|
||||
wantErr: true,
|
||||
expectedErr: "verification failed",
|
||||
},
|
||||
{
|
||||
name: "Malformed CBOR",
|
||||
args: args{
|
||||
token: []byte("invalid cbor"),
|
||||
verifyKey: nil,
|
||||
},
|
||||
wantErr: true,
|
||||
expectedErr: "failed to decode CBOR",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := DecodeCBOR(tt.args.token, tt.args.verifyKey)
|
||||
if tt.wantErr {
|
||||
assert.Error(t, err)
|
||||
if tt.expectedErr != "" {
|
||||
assert.ErrorContains(t, err, tt.expectedErr)
|
||||
}
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, got)
|
||||
assert.Equal(t, claims.Nonce, got.Nonce)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDecodeAutoDetect(t *testing.T) {
|
||||
key, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||
claims := &EATClaims{Nonce: []byte("jwt")}
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodES256, &jwtClaims{claims})
|
||||
jwtString, _ := token.SignedString(key)
|
||||
|
||||
got, err := Decode([]byte(jwtString), &key.PublicKey)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, []byte("jwt"), got.Nonce)
|
||||
|
||||
claimsCBOR := &EATClaims{Nonce: []byte("cbor")}
|
||||
cborBytes, _ := cbor.Marshal(claimsCBOR)
|
||||
gotCBOR, err := Decode(cborBytes, nil)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, []byte("cbor"), gotCBOR.Nonce)
|
||||
}
|
||||
|
||||
func TestIsJWT(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
token []byte
|
||||
want bool
|
||||
}{
|
||||
{"Empty", []byte{}, false},
|
||||
{"JWT like", []byte("header.payload.signature"), true},
|
||||
{"CBOR (binary)", []byte{0x00, 0x01}, false},
|
||||
{"Text but not JWT", []byte("not a jwt"), false},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := isJWT(tt.token); got != tt.want {
|
||||
t.Errorf("isJWT() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,186 @@
|
||||
// Copyright (c) Ultraviolet
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package eat
|
||||
|
||||
import (
|
||||
"errors"
|
||||
|
||||
"github.com/ultravioletrs/cocos/pkg/attestation"
|
||||
)
|
||||
|
||||
// EATClaims represents the Entity Attestation Token claims following RFC 9711.
|
||||
type EATClaims struct {
|
||||
// Standard JWT/CWT claims
|
||||
Issuer string `json:"iss,omitempty" cbor:"1,keyasint,omitempty"`
|
||||
Subject string `json:"sub,omitempty" cbor:"2,keyasint,omitempty"`
|
||||
IssuedAt int64 `json:"iat,omitempty" cbor:"6,keyasint,omitempty"`
|
||||
ExpiresAt int64 `json:"exp,omitempty" cbor:"4,keyasint,omitempty"`
|
||||
|
||||
// Core EAT claims (RFC 9711)
|
||||
Nonce []byte `json:"eat_nonce" cbor:"10,keyasint"` // Freshness/replay protection
|
||||
UEID []byte `json:"ueid" cbor:"256,keyasint"` // Universal Entity ID
|
||||
OEMID int `json:"oemid,omitempty" cbor:"258,keyasint,omitempty"` // Hardware OEM ID
|
||||
HWModel []byte `json:"hwmodel,omitempty" cbor:"259,keyasint,omitempty"` // Hardware model
|
||||
HWVersion string `json:"hwversion,omitempty" cbor:"260,keyasint,omitempty"` // Hardware version
|
||||
SWName string `json:"swname,omitempty" cbor:"270,keyasint,omitempty"` // Software name
|
||||
SWVersion string `json:"swversion,omitempty" cbor:"271,keyasint,omitempty"` // Software version
|
||||
DebugStatus int `json:"dbgstat" cbor:"263,keyasint"` // Debug status
|
||||
IntUse int `json:"intuse,omitempty" cbor:"262,keyasint,omitempty"` // Intended use
|
||||
Measurements []byte `json:"measurements" cbor:"265,keyasint"` // Software measurements
|
||||
|
||||
// Platform type indicator
|
||||
PlatformType string `json:"platform_type"`
|
||||
|
||||
// Submodules for vTPM and other components
|
||||
Submods map[string]interface{} `json:"submods,omitempty" cbor:"266,keyasint,omitempty"`
|
||||
|
||||
// Platform-specific extensions (custom claims)
|
||||
SNPExtensions *SNPExtensions `json:"x-cocos-sevsnp,omitempty"`
|
||||
TDXExtensions *TDXExtensions `json:"x-cocos-tdx,omitempty"`
|
||||
VTPMExtensions *VTPMExtensions `json:"x-cocos-vtpm,omitempty"`
|
||||
|
||||
// Original binary report (for verification)
|
||||
RawReport []byte `json:"raw_report,omitempty"`
|
||||
}
|
||||
|
||||
// SNPExtensions contains AMD SEV-SNP specific claims.
|
||||
type SNPExtensions struct {
|
||||
Measurement []byte `json:"measurement"` // SNP MEASUREMENT field
|
||||
TCB string `json:"tcb"` // TCB version info
|
||||
PlatformInfo uint64 `json:"platform_info"` // PLATFORM_INFO
|
||||
Policy uint64 `json:"policy"` // POLICY field
|
||||
FamilyID []byte `json:"family_id,omitempty"` // Family ID
|
||||
ImageID []byte `json:"image_id,omitempty"` // Image ID
|
||||
VMPL int `json:"vmpl,omitempty"` // VM Privilege Level
|
||||
SignatureAlgo int `json:"signature_algo,omitempty"` // Signature algorithm
|
||||
CurrentTCB uint64 `json:"current_tcb,omitempty"` // Current TCB
|
||||
ReportedTCB uint64 `json:"reported_tcb,omitempty"` // Reported TCB
|
||||
ChipID []byte `json:"chip_id,omitempty"` // Chip ID
|
||||
CommittedTCB uint64 `json:"committed_tcb,omitempty"` // Committed TCB
|
||||
LaunchTCB uint64 `json:"launch_tcb,omitempty"` // Launch TCB
|
||||
Signature []byte `json:"signature,omitempty"` // Signature
|
||||
}
|
||||
|
||||
// TDXExtensions contains Intel TDX specific claims.
|
||||
type TDXExtensions struct {
|
||||
MRTD []byte `json:"tdx_mrtd"` // MRTD measurement
|
||||
RTMR0 []byte `json:"tdx_rtmr0"` // Runtime measurement register 0
|
||||
RTMR1 []byte `json:"tdx_rtmr1"` // Runtime measurement register 1
|
||||
RTMR2 []byte `json:"tdx_rtmr2"` // Runtime measurement register 2
|
||||
RTMR3 []byte `json:"tdx_rtmr3"` // Runtime measurement register 3
|
||||
XFAM uint64 `json:"tdx_xfam"` // Extended features available mask
|
||||
TDAttributes uint64 `json:"tdx_td_attributes"` // TD attributes
|
||||
MRConfigID []byte `json:"tdx_mrconfigid,omitempty"` // MR Config ID
|
||||
MROwner []byte `json:"tdx_mrowner,omitempty"` // MR Owner
|
||||
MROwnerConfig []byte `json:"tdx_mrownerconfig,omitempty"` // MR Owner Config
|
||||
MRSEAM []byte `json:"tdx_mrseam,omitempty"` // MR SEAM
|
||||
TDXModule *TDXModuleInfo `json:"tdx_module,omitempty"` // TDX module info
|
||||
Signature []byte `json:"tdx_signature,omitempty"` // Quote Signature
|
||||
}
|
||||
|
||||
// TDXModuleInfo contains TDX module version information.
|
||||
type TDXModuleInfo struct {
|
||||
Major uint8 `json:"major"`
|
||||
Minor uint8 `json:"minor"`
|
||||
BuildNum uint16 `json:"build_num"`
|
||||
BuildDate uint32 `json:"build_date"`
|
||||
}
|
||||
|
||||
// VTPMExtensions contains vTPM specific claims.
|
||||
type VTPMExtensions struct {
|
||||
PCRs map[string]string `json:"pcrs"` // PCR values (SHA256/SHA384)
|
||||
EventLog []byte `json:"event_log,omitempty"` // Event log
|
||||
Quote []byte `json:"quote,omitempty"` // TPM quote
|
||||
}
|
||||
|
||||
// DebugStatus constants (RFC 9711 Section 4.2.6).
|
||||
const (
|
||||
DebugEnabled = 0 // Debug is enabled
|
||||
DebugDisabled = 1 // Debug is disabled
|
||||
DebugDisabledSinceBoot = 2 // Debug is disabled since boot
|
||||
DebugPermanentDisable = 3 // Debug is permanently disabled
|
||||
DebugFullPermanentDisable = 4 // Debug is fully and permanently disabled
|
||||
)
|
||||
|
||||
// IntUse constants (RFC 9711 Section 4.2.5).
|
||||
const (
|
||||
IntUseGenericFresh = 1 // General purpose, fresh token
|
||||
)
|
||||
|
||||
// MinNonceLength defines the minimum length for EAT nonce in bytes.
|
||||
const MinNonceLength = 8
|
||||
|
||||
// NewEATClaims creates EAT claims from binary attestation report.
|
||||
func NewEATClaims(report []byte, nonce []byte, platformType attestation.PlatformType) (*EATClaims, error) {
|
||||
if len(nonce) < MinNonceLength {
|
||||
return nil, errors.New("eat_nonce must be at least 8 bytes long")
|
||||
}
|
||||
claims := &EATClaims{
|
||||
Nonce: nonce,
|
||||
PlatformType: getPlatformTypeName(platformType),
|
||||
RawReport: report,
|
||||
DebugStatus: DebugDisabledSinceBoot, // Default to disabled since boot
|
||||
IntUse: IntUseGenericFresh, // Default to general purpose, fresh token
|
||||
}
|
||||
|
||||
// Extract platform-specific claims
|
||||
if err := extractPlatformClaims(claims, report, platformType); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return claims, nil
|
||||
}
|
||||
|
||||
// extractPlatformClaims extracts platform-specific claims from binary report.
|
||||
func extractPlatformClaims(claims *EATClaims, report []byte, platformType attestation.PlatformType) error {
|
||||
switch platformType {
|
||||
case attestation.SNP, attestation.SNPvTPM:
|
||||
return extractSNPClaims(claims, report)
|
||||
case attestation.TDX:
|
||||
return extractTDXClaims(claims, report)
|
||||
case attestation.VTPM:
|
||||
return extractVTPMClaims(claims, report)
|
||||
case attestation.Azure:
|
||||
return extractAzureClaims(claims, report)
|
||||
default:
|
||||
// For unknown platforms, just store the raw report
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// getPlatformTypeName converts platform type to string name.
|
||||
func getPlatformTypeName(platformType attestation.PlatformType) string {
|
||||
switch platformType {
|
||||
case attestation.SNP:
|
||||
return "SNP"
|
||||
case attestation.TDX:
|
||||
return "TDX"
|
||||
case attestation.VTPM:
|
||||
return "vTPM"
|
||||
case attestation.SNPvTPM:
|
||||
return "SNP-vTPM"
|
||||
case attestation.Azure:
|
||||
return "Azure"
|
||||
case attestation.NoCC:
|
||||
return "NoCC"
|
||||
default:
|
||||
return "Unknown"
|
||||
}
|
||||
}
|
||||
|
||||
// Sanitize enforces dependency rules for claims.
|
||||
// HWModel requires OEMID.
|
||||
// HWVersion requires HWModel.
|
||||
func (c *EATClaims) Sanitize() {
|
||||
if c.OEMID == 0 {
|
||||
c.HWModel = nil
|
||||
c.HWVersion = ""
|
||||
}
|
||||
if len(c.HWModel) == 0 {
|
||||
c.HWVersion = ""
|
||||
}
|
||||
if c.SWName == "" {
|
||||
c.SWVersion = ""
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,141 @@
|
||||
// Copyright (c) Ultraviolet
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package eat
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/ultravioletrs/cocos/pkg/attestation"
|
||||
)
|
||||
|
||||
func TestNewEATClaims(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
nonce []byte
|
||||
expectedErr string
|
||||
}{
|
||||
{
|
||||
name: "Valid nonce",
|
||||
nonce: []byte("12345678"),
|
||||
expectedErr: "",
|
||||
},
|
||||
{
|
||||
name: "Nonce too short",
|
||||
nonce: []byte("1234567"),
|
||||
expectedErr: "eat_nonce must be at least 8 bytes long",
|
||||
},
|
||||
{
|
||||
name: "Empty nonce",
|
||||
nonce: []byte{},
|
||||
expectedErr: "eat_nonce must be at least 8 bytes long",
|
||||
},
|
||||
{
|
||||
name: "Nil nonce",
|
||||
nonce: nil,
|
||||
expectedErr: "eat_nonce must be at least 8 bytes long",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
_, err := NewEATClaims([]byte("dummy report"), tt.nonce, attestation.NoCC)
|
||||
if tt.expectedErr != "" {
|
||||
assert.EqualError(t, err, tt.expectedErr)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSanitize(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
claims *EATClaims
|
||||
expected *EATClaims
|
||||
}{
|
||||
{
|
||||
name: "All dependencies present",
|
||||
claims: &EATClaims{
|
||||
OEMID: 123,
|
||||
HWModel: []byte("ValidModel"),
|
||||
HWVersion: "1.0",
|
||||
},
|
||||
expected: &EATClaims{
|
||||
OEMID: 123,
|
||||
HWModel: []byte("ValidModel"),
|
||||
HWVersion: "1.0",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Missing OEMID clears HWModel and HWVersion",
|
||||
claims: &EATClaims{
|
||||
OEMID: 0,
|
||||
HWModel: []byte("ValidModel"),
|
||||
HWVersion: "1.0",
|
||||
},
|
||||
expected: &EATClaims{
|
||||
OEMID: 0,
|
||||
HWModel: nil,
|
||||
HWVersion: "",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Missing HWModel clears HWVersion",
|
||||
claims: &EATClaims{
|
||||
OEMID: 123,
|
||||
HWModel: nil,
|
||||
HWVersion: "1.0",
|
||||
},
|
||||
expected: &EATClaims{
|
||||
OEMID: 123,
|
||||
HWModel: nil,
|
||||
HWVersion: "",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Missing HWModel (empty bytes) clears HWVersion",
|
||||
claims: &EATClaims{
|
||||
OEMID: 123,
|
||||
HWModel: []byte{},
|
||||
HWVersion: "1.0",
|
||||
},
|
||||
expected: &EATClaims{
|
||||
OEMID: 123,
|
||||
HWModel: []byte{}, // Should remain empty slice
|
||||
HWVersion: "",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Independent fields unaffected",
|
||||
claims: &EATClaims{
|
||||
OEMID: 0,
|
||||
DebugStatus: DebugEnabled,
|
||||
},
|
||||
expected: &EATClaims{
|
||||
OEMID: 0,
|
||||
DebugStatus: DebugEnabled,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Missing SWName clears SWVersion",
|
||||
claims: &EATClaims{
|
||||
SWName: "",
|
||||
SWVersion: "1.0.0",
|
||||
},
|
||||
expected: &EATClaims{
|
||||
SWName: "",
|
||||
SWVersion: "",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
tt.claims.Sanitize()
|
||||
assert.Equal(t, tt.expected, tt.claims)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,150 @@
|
||||
// Copyright (c) Ultraviolet
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package eat
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
|
||||
"github.com/google/go-sev-guest/abi"
|
||||
tdxabi "github.com/google/go-tdx-guest/abi"
|
||||
tdxpb "github.com/google/go-tdx-guest/proto/tdx"
|
||||
)
|
||||
|
||||
// OEMID constants (Private Enterprise Numbers).
|
||||
const (
|
||||
OEMID_AMD = 3704 // https://www.iana.org/assignments/enterprise-numbers/?q=Advanced+Micro+Devices
|
||||
OEMID_INTEL = 343 // https://www.iana.org/assignments/enterprise-numbers/?q=Intel+Corporation
|
||||
OEMID_MICROSOFT = 311 // https://www.iana.org/assignments/enterprise-numbers/?q=Microsoft+Corporation
|
||||
)
|
||||
|
||||
// extractSNPClaims extracts AMD SEV-SNP specific claims from binary report.
|
||||
func extractSNPClaims(claims *EATClaims, report []byte) error {
|
||||
if len(report) < int(abi.ReportSize) {
|
||||
return fmt.Errorf("SNP report too small: got %d bytes, want at least %d", len(report), abi.ReportSize)
|
||||
}
|
||||
|
||||
// Parse SNP report structure
|
||||
snpReport, err := abi.ReportToProto(report[:abi.ReportSize])
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to parse SNP report: %w", err)
|
||||
}
|
||||
|
||||
// Extract SNP-specific fields
|
||||
claims.SNPExtensions = &SNPExtensions{
|
||||
Measurement: snpReport.Measurement,
|
||||
Policy: snpReport.Policy,
|
||||
FamilyID: snpReport.FamilyId,
|
||||
ImageID: snpReport.ImageId,
|
||||
VMPL: int(snpReport.Vmpl),
|
||||
SignatureAlgo: int(snpReport.SignatureAlgo),
|
||||
PlatformInfo: snpReport.PlatformInfo,
|
||||
ChipID: snpReport.ChipId,
|
||||
}
|
||||
|
||||
// Set TCB version info
|
||||
claims.SNPExtensions.CurrentTCB = snpReport.CurrentTcb
|
||||
claims.SNPExtensions.ReportedTCB = snpReport.ReportedTcb
|
||||
claims.SNPExtensions.CommittedTCB = snpReport.CommittedTcb
|
||||
claims.SNPExtensions.LaunchTCB = snpReport.LaunchTcb
|
||||
claims.SNPExtensions.TCB = fmt.Sprintf("current:%d,reported:%d", snpReport.CurrentTcb, snpReport.ReportedTcb)
|
||||
|
||||
// Set core EAT claims from SNP report
|
||||
claims.Measurements = snpReport.Measurement
|
||||
claims.UEID = snpReport.ChipId // Use ChipID as UEID
|
||||
claims.OEMID = OEMID_AMD // AMD's PEN (Private Enterprise Number)
|
||||
claims.SNPExtensions.Signature = snpReport.Signature
|
||||
|
||||
// Set hardware model (hash of product name)
|
||||
claims.HWModel = []byte(fmt.Sprintf("SEV-SNP-%d", snpReport.Version))
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// extractTDXClaims extracts Intel TDX specific claims from binary report.
|
||||
func extractTDXClaims(claims *EATClaims, report []byte) error {
|
||||
// Parse TDX quote using go-tdx-guest ABI
|
||||
decodedQuote, err := tdxabi.QuoteToProto(report)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to parse TDX quote: %w", err)
|
||||
}
|
||||
|
||||
quoteV4, ok := decodedQuote.(*tdxpb.QuoteV4)
|
||||
if !ok {
|
||||
return fmt.Errorf("unsupported TDX quote format")
|
||||
}
|
||||
|
||||
tdReport := quoteV4.GetTdQuoteBody()
|
||||
signedData := quoteV4.GetSignedData()
|
||||
|
||||
rtmrs := tdReport.GetRtmrs()
|
||||
var rtmr0, rtmr1, rtmr2, rtmr3 []byte
|
||||
if len(rtmrs) > 0 {
|
||||
rtmr0 = rtmrs[0]
|
||||
}
|
||||
if len(rtmrs) > 1 {
|
||||
rtmr1 = rtmrs[1]
|
||||
}
|
||||
if len(rtmrs) > 2 {
|
||||
rtmr2 = rtmrs[2]
|
||||
}
|
||||
if len(rtmrs) > 3 {
|
||||
rtmr3 = rtmrs[3]
|
||||
}
|
||||
|
||||
claims.TDXExtensions = &TDXExtensions{
|
||||
MRTD: tdReport.GetMrTd(),
|
||||
RTMR0: rtmr0,
|
||||
RTMR1: rtmr1,
|
||||
RTMR2: rtmr2,
|
||||
RTMR3: rtmr3,
|
||||
XFAM: binary.LittleEndian.Uint64(tdReport.GetXfam()),
|
||||
TDAttributes: binary.LittleEndian.Uint64(tdReport.GetTdAttributes()),
|
||||
MRConfigID: tdReport.GetMrConfigId(),
|
||||
MROwner: tdReport.GetMrOwner(),
|
||||
MROwnerConfig: tdReport.GetMrOwnerConfig(),
|
||||
MRSEAM: tdReport.GetMrSeam(),
|
||||
Signature: signedData.GetSignature(),
|
||||
}
|
||||
|
||||
// Set core EAT claims
|
||||
claims.Measurements = tdReport.GetMrTd()
|
||||
// Use first 32 bytes of MRTD as UEID, similar to other extractors
|
||||
if len(claims.Measurements) >= 32 {
|
||||
claims.UEID = claims.Measurements[:32]
|
||||
}
|
||||
claims.OEMID = OEMID_INTEL // Intel's PEN
|
||||
|
||||
// Set hardware model
|
||||
claims.HWModel = []byte("Intel-TDX")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// extractVTPMClaims extracts vTPM specific claims from binary report.
|
||||
func extractVTPMClaims(claims *EATClaims, report []byte) error {
|
||||
// vTPM report is typically a marshaled structure containing PCRs and quote
|
||||
// For now, store the entire report as the quote
|
||||
claims.VTPMExtensions = &VTPMExtensions{
|
||||
Quote: report,
|
||||
PCRs: make(map[string]string),
|
||||
}
|
||||
|
||||
// Set core EAT claims
|
||||
claims.Measurements = report[:32] // Use first 32 bytes as measurement
|
||||
claims.UEID = report[:16] // Use first 16 bytes as UEID
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// extractAzureClaims extracts Azure-specific claims from attestation token.
|
||||
func extractAzureClaims(claims *EATClaims, report []byte) error {
|
||||
// Azure provides JWT tokens, so the report is already in a structured format
|
||||
// For now, just store it as raw report
|
||||
claims.Measurements = report[:32] // Use first 32 bytes as measurement
|
||||
claims.UEID = report[:16] // Use first 16 bytes as UEID
|
||||
claims.OEMID = OEMID_MICROSOFT // Microsoft's PEN
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,147 @@
|
||||
// Copyright (c) Ultraviolet
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package eat
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/google/go-sev-guest/abi"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/ultravioletrs/cocos/pkg/attestation"
|
||||
)
|
||||
|
||||
func TestExtractSNPClaims(t *testing.T) {
|
||||
validReport := make([]byte, abi.ReportSize)
|
||||
validReport[0] = 1
|
||||
validReport[10] = 0x2 // Policy bit 17 set (byte 2 of Policy, bit 1)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
report []byte
|
||||
wantErr bool
|
||||
expectedErr string
|
||||
}{
|
||||
{
|
||||
name: "valid report size (minimal)",
|
||||
report: validReport,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "report too small",
|
||||
report: make([]byte, abi.ReportSize-1),
|
||||
wantErr: true,
|
||||
expectedErr: "SNP report too small",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
claims := &EATClaims{}
|
||||
err := extractSNPClaims(claims, tt.report)
|
||||
if tt.wantErr {
|
||||
assert.Error(t, err)
|
||||
if tt.expectedErr != "" {
|
||||
assert.Contains(t, err.Error(), tt.expectedErr)
|
||||
}
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, claims.SNPExtensions)
|
||||
assert.Equal(t, OEMID_AMD, claims.OEMID)
|
||||
assert.Equal(t, []byte(fmt.Sprintf("SEV-SNP-%d", 1)), claims.HWModel)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractTDXClaims(t *testing.T) {
|
||||
report := []byte("invalid-tdx-quote")
|
||||
claims := &EATClaims{}
|
||||
err := extractTDXClaims(claims, report)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "failed to parse TDX quote")
|
||||
}
|
||||
|
||||
func TestTDXExtensionsJSON(t *testing.T) {
|
||||
ext := &TDXExtensions{
|
||||
MRTD: []byte("mrtd_val"),
|
||||
RTMR0: []byte("rtmr0_val"),
|
||||
RTMR1: []byte("rtmr1_val"),
|
||||
RTMR2: []byte("rtmr2_val"),
|
||||
RTMR3: []byte("rtmr3_val"),
|
||||
XFAM: 123,
|
||||
TDAttributes: 456,
|
||||
TDXModule: &TDXModuleInfo{
|
||||
Major: 1,
|
||||
},
|
||||
}
|
||||
|
||||
claims := &EATClaims{
|
||||
TDXExtensions: ext,
|
||||
}
|
||||
|
||||
// Marshal to JSON
|
||||
data, err := json.Marshal(claims)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Verify JSON keys match Intel EAT profile
|
||||
jsonStr := string(data)
|
||||
assert.Contains(t, jsonStr, `"tdx_mrtd":"bXJ0ZF92YWw="`)
|
||||
assert.Contains(t, jsonStr, `"tdx_rtmr0":"cnRtcjBfdmFs"`) // base64 of "rtmr0_val"
|
||||
assert.Contains(t, jsonStr, `"tdx_rtmr1":"cnRtcjFfdmFs"`)
|
||||
assert.Contains(t, jsonStr, `"tdx_rtmr2":"cnRtcjJfdmFs"`)
|
||||
assert.Contains(t, jsonStr, `"tdx_rtmr3":"cnRtcjNfdmFs"`)
|
||||
assert.Contains(t, jsonStr, `"tdx_xfam":123`)
|
||||
assert.Contains(t, jsonStr, `"tdx_td_attributes":456`)
|
||||
assert.Contains(t, jsonStr, `"tdx_module":{"major":1,"minor":0,"build_num":0,"build_date":0}`)
|
||||
}
|
||||
|
||||
func TestExtractVTPMClaims(t *testing.T) {
|
||||
report := make([]byte, 32)
|
||||
copy(report, []byte("vtpm-report-with-enough-length-123"))
|
||||
|
||||
claims := &EATClaims{}
|
||||
err := extractVTPMClaims(claims, report)
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, claims.VTPMExtensions)
|
||||
assert.Equal(t, report, claims.VTPMExtensions.Quote)
|
||||
assert.Equal(t, report[:32], claims.Measurements)
|
||||
assert.Equal(t, report[:16], claims.UEID)
|
||||
}
|
||||
|
||||
func TestExtractAzureClaims(t *testing.T) {
|
||||
report := make([]byte, 32) // Needs at least 32 bytes for valid slicing
|
||||
for i := range report {
|
||||
report[i] = byte(i)
|
||||
}
|
||||
claims := &EATClaims{}
|
||||
err := extractAzureClaims(claims, report)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, report, claims.Measurements)
|
||||
assert.Equal(t, report[:16], claims.UEID)
|
||||
assert.Equal(t, OEMID_MICROSOFT, claims.OEMID)
|
||||
}
|
||||
|
||||
// Platform type helper.
|
||||
func TestGetPlatformTypeName(t *testing.T) {
|
||||
tests := []struct {
|
||||
pt attestation.PlatformType
|
||||
want string
|
||||
}{
|
||||
{attestation.SNP, "SNP"},
|
||||
{attestation.SNPvTPM, "SNP-vTPM"},
|
||||
{attestation.TDX, "TDX"},
|
||||
{attestation.VTPM, "vTPM"},
|
||||
{attestation.Azure, "Azure"},
|
||||
{attestation.NoCC, "NoCC"},
|
||||
{attestation.PlatformType(999), "Unknown"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.want, func(t *testing.T) {
|
||||
assert.Equal(t, tt.want, getPlatformTypeName(tt.pt))
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,21 @@
|
||||
// Copyright (c) Ultraviolet
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package eat
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/ultravioletrs/cocos/pkg/attestation"
|
||||
)
|
||||
|
||||
func TestIntUse(t *testing.T) {
|
||||
report := []byte("dummy-report")
|
||||
nonce := make([]byte, 8)
|
||||
|
||||
claims, err := NewEATClaims(report, nonce, attestation.NoCC)
|
||||
assert.NoError(t, err)
|
||||
|
||||
assert.Equal(t, IntUseGenericFresh, claims.IntUse)
|
||||
}
|
||||
@@ -0,0 +1,100 @@
|
||||
// Copyright (c) Ultraviolet
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package eat
|
||||
|
||||
import (
|
||||
"crypto/ecdsa"
|
||||
"crypto/elliptic"
|
||||
"crypto/rand"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
)
|
||||
|
||||
// JWTEncoder encodes EAT claims to JWT format.
|
||||
type JWTEncoder struct {
|
||||
signingKey *ecdsa.PrivateKey
|
||||
issuer string
|
||||
}
|
||||
|
||||
// NewJWTEncoder creates a new JWT encoder.
|
||||
func NewJWTEncoder(signingKey *ecdsa.PrivateKey, issuer string) *JWTEncoder {
|
||||
return &JWTEncoder{
|
||||
signingKey: signingKey,
|
||||
issuer: issuer,
|
||||
}
|
||||
}
|
||||
|
||||
// Encode encodes EAT claims to JWT string.
|
||||
func (e *JWTEncoder) Encode(claims *EATClaims) (string, error) {
|
||||
// Set standard JWT claims
|
||||
now := time.Now()
|
||||
claims.Issuer = e.issuer
|
||||
claims.IssuedAt = now.Unix()
|
||||
claims.ExpiresAt = now.Add(5 * time.Minute).Unix() // 5 minute validity
|
||||
|
||||
// Create JWT token with custom claims
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodES256, &jwtClaims{claims})
|
||||
|
||||
// Sign the token
|
||||
tokenString, err := token.SignedString(e.signingKey)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to sign JWT: %w", err)
|
||||
}
|
||||
|
||||
return tokenString, nil
|
||||
}
|
||||
|
||||
// jwtClaims wraps EATClaims for JWT encoding.
|
||||
type jwtClaims struct {
|
||||
*EATClaims
|
||||
}
|
||||
|
||||
// GetExpirationTime implements jwt.Claims interface.
|
||||
func (c *jwtClaims) GetExpirationTime() (*jwt.NumericDate, error) {
|
||||
if c.ExpiresAt == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
return jwt.NewNumericDate(time.Unix(c.ExpiresAt, 0)), nil
|
||||
}
|
||||
|
||||
// GetIssuedAt implements jwt.Claims interface.
|
||||
func (c *jwtClaims) GetIssuedAt() (*jwt.NumericDate, error) {
|
||||
if c.IssuedAt == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
return jwt.NewNumericDate(time.Unix(c.IssuedAt, 0)), nil
|
||||
}
|
||||
|
||||
// GetNotBefore implements jwt.Claims interface.
|
||||
func (c *jwtClaims) GetNotBefore() (*jwt.NumericDate, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// GetIssuer implements jwt.Claims interface.
|
||||
func (c *jwtClaims) GetIssuer() (string, error) {
|
||||
return c.Issuer, nil
|
||||
}
|
||||
|
||||
// GetSubject implements jwt.Claims interface.
|
||||
func (c *jwtClaims) GetSubject() (string, error) {
|
||||
return c.Subject, nil
|
||||
}
|
||||
|
||||
// GetAudience implements jwt.Claims interface.
|
||||
func (c *jwtClaims) GetAudience() (jwt.ClaimStrings, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// EncodeToJWT is a convenience function to encode EAT claims to JWT.
|
||||
func EncodeToJWT(claims *EATClaims, signingKey *ecdsa.PrivateKey, issuer string) (string, error) {
|
||||
encoder := NewJWTEncoder(signingKey, issuer)
|
||||
return encoder.Encode(claims)
|
||||
}
|
||||
|
||||
// GenerateSigningKey generates a new ECDSA signing key.
|
||||
func GenerateSigningKey() (*ecdsa.PrivateKey, error) {
|
||||
return ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||
}
|
||||
@@ -0,0 +1,135 @@
|
||||
// Copyright (c) Ultraviolet
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package eat
|
||||
|
||||
import (
|
||||
"crypto/ecdsa"
|
||||
"crypto/elliptic"
|
||||
"crypto/rand"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestJWTEncoder_Encode(t *testing.T) {
|
||||
key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||
require.NoError(t, err)
|
||||
|
||||
type fields struct {
|
||||
signingKey *ecdsa.PrivateKey
|
||||
issuer string
|
||||
}
|
||||
type args struct {
|
||||
claims *EATClaims
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
args args
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "Valid encoding",
|
||||
fields: fields{
|
||||
signingKey: key,
|
||||
issuer: "test-issuer",
|
||||
},
|
||||
args: args{
|
||||
claims: &EATClaims{
|
||||
Nonce: []byte("test-nonce"),
|
||||
},
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
e := NewJWTEncoder(tt.fields.signingKey, tt.fields.issuer)
|
||||
got, err := e.Encode(tt.args.claims)
|
||||
if tt.wantErr {
|
||||
assert.Error(t, err)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
assert.NotEmpty(t, got)
|
||||
|
||||
// Verify the generated token
|
||||
parsedToken, err := jwt.ParseWithClaims(got, &jwtClaims{&EATClaims{}}, func(token *jwt.Token) (interface{}, error) {
|
||||
return &key.PublicKey, nil
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, parsedToken.Valid)
|
||||
|
||||
claims, ok := parsedToken.Claims.(*jwtClaims)
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, tt.fields.issuer, claims.Issuer)
|
||||
assert.Equal(t, tt.args.claims.Nonce, claims.Nonce)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenerateSigningKey(t *testing.T) {
|
||||
key, err := GenerateSigningKey()
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, key)
|
||||
assert.Equal(t, elliptic.P256(), key.Curve)
|
||||
}
|
||||
|
||||
func TestEncodeToJWT(t *testing.T) {
|
||||
key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||
require.NoError(t, err)
|
||||
|
||||
claims := &EATClaims{Nonce: []byte("nonce")}
|
||||
token, err := EncodeToJWT(claims, key, "issuer")
|
||||
assert.NoError(t, err)
|
||||
assert.NotEmpty(t, token)
|
||||
}
|
||||
|
||||
func TestJwtClaimsINTERFACE(t *testing.T) {
|
||||
now := time.Now()
|
||||
claims := &EATClaims{
|
||||
Issuer: "iss",
|
||||
Subject: "sub",
|
||||
ExpiresAt: now.Add(time.Hour).Unix(),
|
||||
IssuedAt: now.Unix(),
|
||||
}
|
||||
jwtc := &jwtClaims{claims}
|
||||
|
||||
exp, err := jwtc.GetExpirationTime()
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, claims.ExpiresAt, exp.Unix())
|
||||
|
||||
iat, err := jwtc.GetIssuedAt()
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, claims.IssuedAt, iat.Unix())
|
||||
|
||||
iss, err := jwtc.GetIssuer()
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, claims.Issuer, iss)
|
||||
|
||||
sub, err := jwtc.GetSubject()
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, claims.Subject, sub)
|
||||
|
||||
nbf, err := jwtc.GetNotBefore()
|
||||
assert.NoError(t, err)
|
||||
assert.Nil(t, nbf)
|
||||
|
||||
aud, err := jwtc.GetAudience()
|
||||
assert.NoError(t, err)
|
||||
assert.Nil(t, aud)
|
||||
|
||||
// Test zero values
|
||||
emptyClaims := &jwtClaims{&EATClaims{}}
|
||||
exp, err = emptyClaims.GetExpirationTime()
|
||||
assert.NoError(t, err)
|
||||
assert.Nil(t, exp)
|
||||
|
||||
iat, err = emptyClaims.GetIssuedAt()
|
||||
assert.NoError(t, err)
|
||||
assert.Nil(t, iat)
|
||||
}
|
||||
@@ -0,0 +1,67 @@
|
||||
// Copyright (c) Ultraviolet
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package eat
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"time"
|
||||
)
|
||||
|
||||
// ValidateEATClaims validates EAT claims against policy.
|
||||
func ValidateEATClaims(claims *EATClaims, policy *EATValidationPolicy) error {
|
||||
if policy == nil {
|
||||
return nil // No policy, skip validation
|
||||
}
|
||||
|
||||
// Sanitize claims to enforce dependency rules
|
||||
claims.Sanitize()
|
||||
|
||||
// Check required claims
|
||||
for _, requiredClaim := range policy.RequireClaims {
|
||||
switch requiredClaim {
|
||||
case "eat_nonce":
|
||||
if len(claims.Nonce) == 0 {
|
||||
return fmt.Errorf("missing required claim: eat_nonce")
|
||||
}
|
||||
case "measurements":
|
||||
if len(claims.Measurements) == 0 {
|
||||
return fmt.Errorf("missing required claim: measurements")
|
||||
}
|
||||
case "platform_type":
|
||||
if claims.PlatformType == "" {
|
||||
return fmt.Errorf("missing required claim: platform_type")
|
||||
}
|
||||
case "ueid":
|
||||
if len(claims.UEID) == 0 {
|
||||
return fmt.Errorf("missing required claim: ueid")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Check token age
|
||||
if policy.MaxTokenAgeSeconds > 0 && claims.IssuedAt > 0 {
|
||||
tokenAge := time.Since(time.Unix(claims.IssuedAt, 0))
|
||||
if tokenAge.Seconds() > float64(policy.MaxTokenAgeSeconds) {
|
||||
return fmt.Errorf("token too old: %v seconds (max: %d)", tokenAge.Seconds(), policy.MaxTokenAgeSeconds)
|
||||
}
|
||||
}
|
||||
|
||||
// Check expiration
|
||||
if claims.ExpiresAt > 0 {
|
||||
if time.Now().Unix() > claims.ExpiresAt {
|
||||
return fmt.Errorf("token expired")
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// EATValidationPolicy contains validation rules for EAT tokens.
|
||||
type EATValidationPolicy struct {
|
||||
RequireEATFormat bool
|
||||
AllowedFormats []string
|
||||
MaxTokenAgeSeconds int
|
||||
RequireClaims []string
|
||||
VerifySignature bool
|
||||
}
|
||||
@@ -0,0 +1,106 @@
|
||||
// Copyright (c) Ultraviolet
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package eat
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestValidateEATClaims(t *testing.T) {
|
||||
now := time.Now()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
claims *EATClaims
|
||||
policy *EATValidationPolicy
|
||||
expectedErr string
|
||||
}{
|
||||
{
|
||||
name: "Nil policy",
|
||||
claims: &EATClaims{},
|
||||
policy: nil,
|
||||
},
|
||||
{
|
||||
name: "Valid claims conforming to policy",
|
||||
claims: &EATClaims{
|
||||
Nonce: []byte("nonce"),
|
||||
Measurements: []byte("meas"),
|
||||
IssuedAt: now.Unix(),
|
||||
ExpiresAt: now.Add(time.Hour).Unix(),
|
||||
},
|
||||
policy: &EATValidationPolicy{
|
||||
RequireClaims: []string{"eat_nonce", "measurements"},
|
||||
MaxTokenAgeSeconds: 300,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Missing nonce",
|
||||
claims: &EATClaims{
|
||||
Measurements: []byte("meas"),
|
||||
},
|
||||
policy: &EATValidationPolicy{
|
||||
RequireClaims: []string{"eat_nonce"},
|
||||
},
|
||||
expectedErr: "missing required claim: eat_nonce",
|
||||
},
|
||||
{
|
||||
name: "Missing measurements",
|
||||
claims: &EATClaims{
|
||||
Nonce: []byte("nonce"),
|
||||
},
|
||||
policy: &EATValidationPolicy{
|
||||
RequireClaims: []string{"measurements"},
|
||||
},
|
||||
expectedErr: "missing required claim: measurements",
|
||||
},
|
||||
{
|
||||
name: "Missing platform type",
|
||||
claims: &EATClaims{},
|
||||
policy: &EATValidationPolicy{
|
||||
RequireClaims: []string{"platform_type"},
|
||||
},
|
||||
expectedErr: "missing required claim: platform_type",
|
||||
},
|
||||
{
|
||||
name: "Missing UEID",
|
||||
claims: &EATClaims{},
|
||||
policy: &EATValidationPolicy{
|
||||
RequireClaims: []string{"ueid"},
|
||||
},
|
||||
expectedErr: "missing required claim: ueid",
|
||||
},
|
||||
{
|
||||
name: "Token too old",
|
||||
claims: &EATClaims{
|
||||
IssuedAt: now.Add(-2 * time.Hour).Unix(),
|
||||
},
|
||||
policy: &EATValidationPolicy{
|
||||
MaxTokenAgeSeconds: 3600, // 1 hour max age
|
||||
},
|
||||
expectedErr: "token too old",
|
||||
},
|
||||
{
|
||||
name: "Token expired",
|
||||
claims: &EATClaims{
|
||||
ExpiresAt: now.Add(-1 * time.Hour).Unix(),
|
||||
},
|
||||
policy: &EATValidationPolicy{},
|
||||
expectedErr: "token expired",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := ValidateEATClaims(tt.claims, tt.policy)
|
||||
if tt.expectedErr != "" {
|
||||
assert.ErrorContains(t, err, tt.expectedErr)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -265,3 +265,66 @@ func (_c *Verifier_VerifyAttestation_Call) RunAndReturn(run func(report []byte,
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
// VerifyEAT provides a mock function for the type Verifier
|
||||
func (_mock *Verifier) VerifyEAT(eatToken []byte, teeNonce []byte, vTpmNonce []byte) error {
|
||||
ret := _mock.Called(eatToken, teeNonce, vTpmNonce)
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for VerifyEAT")
|
||||
}
|
||||
|
||||
var r0 error
|
||||
if returnFunc, ok := ret.Get(0).(func([]byte, []byte, []byte) error); ok {
|
||||
r0 = returnFunc(eatToken, teeNonce, vTpmNonce)
|
||||
} else {
|
||||
r0 = ret.Error(0)
|
||||
}
|
||||
return r0
|
||||
}
|
||||
|
||||
// Verifier_VerifyEAT_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'VerifyEAT'
|
||||
type Verifier_VerifyEAT_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// VerifyEAT is a helper method to define mock.On call
|
||||
// - eatToken []byte
|
||||
// - teeNonce []byte
|
||||
// - vTpmNonce []byte
|
||||
func (_e *Verifier_Expecter) VerifyEAT(eatToken interface{}, teeNonce interface{}, vTpmNonce interface{}) *Verifier_VerifyEAT_Call {
|
||||
return &Verifier_VerifyEAT_Call{Call: _e.mock.On("VerifyEAT", eatToken, teeNonce, vTpmNonce)}
|
||||
}
|
||||
|
||||
func (_c *Verifier_VerifyEAT_Call) Run(run func(eatToken []byte, teeNonce []byte, vTpmNonce []byte)) *Verifier_VerifyEAT_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
var arg0 []byte
|
||||
if args[0] != nil {
|
||||
arg0 = args[0].([]byte)
|
||||
}
|
||||
var arg1 []byte
|
||||
if args[1] != nil {
|
||||
arg1 = args[1].([]byte)
|
||||
}
|
||||
var arg2 []byte
|
||||
if args[2] != nil {
|
||||
arg2 = args[2].([]byte)
|
||||
}
|
||||
run(
|
||||
arg0,
|
||||
arg1,
|
||||
arg2,
|
||||
)
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *Verifier_VerifyEAT_Call) Return(err error) *Verifier_VerifyEAT_Call {
|
||||
_c.Call.Return(err)
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *Verifier_VerifyEAT_Call) RunAndReturn(run func(eatToken []byte, teeNonce []byte, vTpmNonce []byte) error) *Verifier_VerifyEAT_Call {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
@@ -18,6 +18,7 @@ import (
|
||||
verifytdx "github.com/google/go-tdx-guest/verify"
|
||||
trusttdx "github.com/google/go-tdx-guest/verify/trust"
|
||||
"github.com/ultravioletrs/cocos/pkg/attestation"
|
||||
"github.com/ultravioletrs/cocos/pkg/attestation/eat"
|
||||
"google.golang.org/protobuf/encoding/protojson"
|
||||
)
|
||||
|
||||
@@ -141,6 +142,18 @@ func (v verifier) JSONToPolicy(path string) error {
|
||||
return ReadTDXAttestationPolicy(path, v.Policy)
|
||||
}
|
||||
|
||||
// VerifyEAT verifies an EAT token and extracts the binary report for verification.
|
||||
func (v verifier) VerifyEAT(eatToken []byte, teeNonce []byte, vTpmNonce []byte) error {
|
||||
// Decode EAT token
|
||||
claims, err := eat.Decode(eatToken, nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to decode EAT token: %w", err)
|
||||
}
|
||||
|
||||
// Verify the embedded binary report
|
||||
return v.VerifyAttestation(claims.RawReport, teeNonce, vTpmNonce)
|
||||
}
|
||||
|
||||
func ReadTDXAttestationPolicy(policyPath string, policy *checkconfig.Config) error {
|
||||
policyByte, err := os.ReadFile(policyPath)
|
||||
if err != nil {
|
||||
|
||||
@@ -0,0 +1,69 @@
|
||||
// Copyright (c) Ultraviolet
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package tdx
|
||||
|
||||
import (
|
||||
"crypto/ecdsa"
|
||||
"crypto/elliptic"
|
||||
"crypto/rand"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/ultravioletrs/cocos/pkg/attestation/eat"
|
||||
)
|
||||
|
||||
func TestVerifyEAT_TDX(t *testing.T) {
|
||||
key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||
require.NoError(t, err)
|
||||
|
||||
claims := &eat.EATClaims{
|
||||
Nonce: []byte("test-nonce"),
|
||||
IssuedAt: time.Now().Unix(),
|
||||
RawReport: []byte("dummy-report"),
|
||||
PlatformType: "TDX",
|
||||
}
|
||||
|
||||
jwtEncoder := eat.NewJWTEncoder(key, "issuer")
|
||||
token, err := jwtEncoder.Encode(claims)
|
||||
require.NoError(t, err)
|
||||
|
||||
vInterface := NewVerifier()
|
||||
|
||||
v, ok := vInterface.(verifier)
|
||||
require.True(t, ok)
|
||||
|
||||
err = v.VerifyEAT([]byte(token), []byte("tee-nonce"), []byte("vtpm-nonce"))
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "failed")
|
||||
}
|
||||
|
||||
func TestVerifyEAT_TDX_InvalidToken(t *testing.T) {
|
||||
vInterface := NewVerifier()
|
||||
v, ok := vInterface.(verifier)
|
||||
require.True(t, ok)
|
||||
|
||||
err := v.VerifyEAT([]byte("invalid-token"), nil, nil)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "failed to decode EAT token")
|
||||
}
|
||||
|
||||
func TestTeeAttestation_InvalidNonce(t *testing.T) {
|
||||
p := NewProvider()
|
||||
|
||||
nonce := make([]byte, 64)
|
||||
_, err := p.TeeAttestation(nonce)
|
||||
assert.Error(t, err)
|
||||
// Check for likely errors in non-TDX environment
|
||||
// Check for likely errors in non-TDX environment
|
||||
errMsg := err.Error()
|
||||
assert.True(t,
|
||||
strings.Contains(errMsg, "no such file or directory") ||
|
||||
strings.Contains(errMsg, "permission denied") ||
|
||||
strings.Contains(errMsg, "failed to open TDX device"),
|
||||
"unexpected error message: %s", errMsg,
|
||||
)
|
||||
}
|
||||
@@ -24,6 +24,7 @@ import (
|
||||
"github.com/google/go-tpm/legacy/tpm2"
|
||||
"github.com/google/go-tpm/tpmutil"
|
||||
"github.com/ultravioletrs/cocos/pkg/attestation"
|
||||
"github.com/ultravioletrs/cocos/pkg/attestation/eat"
|
||||
"github.com/ultravioletrs/cocos/pkg/attestation/quoteprovider"
|
||||
"golang.org/x/crypto/sha3"
|
||||
"google.golang.org/protobuf/encoding/protojson"
|
||||
@@ -181,10 +182,22 @@ func (v verifier) VerifyAttestation(report []byte, teeNonce []byte, vTpmNonce []
|
||||
return VTPMVerify(report, teeNonce, vTpmNonce, v.writer, v.Policy)
|
||||
}
|
||||
|
||||
func (v verifier) JSONToPolicy(path string) error {
|
||||
func (v *verifier) JSONToPolicy(path string) error {
|
||||
return ReadPolicy(path, v.Policy)
|
||||
}
|
||||
|
||||
// VerifyEAT verifies an EAT token and extracts the binary report for verification.
|
||||
func (v *verifier) VerifyEAT(eatToken []byte, teeNonce []byte, vTpmNonce []byte) error {
|
||||
// Decode EAT token
|
||||
claims, err := eat.Decode(eatToken, nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to decode EAT token: %w", err)
|
||||
}
|
||||
|
||||
// Verify the embedded binary report
|
||||
return v.VerifyAttestation(claims.RawReport, teeNonce, vTpmNonce)
|
||||
}
|
||||
|
||||
func Attest(teeNonce []byte, vTPMNonce []byte, teeAttestaion bool, vmpl uint) ([]byte, error) {
|
||||
attestation, err := FetchQuote(vTPMNonce)
|
||||
if err != nil {
|
||||
|
||||
@@ -0,0 +1,68 @@
|
||||
// Copyright (c) Ultraviolet
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package vtpm
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/ecdsa"
|
||||
"crypto/elliptic"
|
||||
"crypto/rand"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/ultravioletrs/cocos/pkg/attestation/eat"
|
||||
)
|
||||
|
||||
func TestVerifyEAT(t *testing.T) {
|
||||
key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||
require.NoError(t, err)
|
||||
|
||||
claims := &eat.EATClaims{
|
||||
Nonce: []byte("test-nonce"),
|
||||
IssuedAt: time.Now().Unix(),
|
||||
RawReport: []byte("dummy-report"), // This will be passed to VerifyAttestation
|
||||
PlatformType: "SNP-vTPM",
|
||||
}
|
||||
|
||||
jwtEncoder := eat.NewJWTEncoder(key, "issuer")
|
||||
token, err := jwtEncoder.Encode(claims)
|
||||
require.NoError(t, err)
|
||||
|
||||
writer := &mockWriter{}
|
||||
vInterface := NewVerifier(writer)
|
||||
v, ok := vInterface.(*verifier)
|
||||
require.True(t, ok)
|
||||
|
||||
err = v.VerifyEAT([]byte(token), []byte("tee-nonce"), []byte("vtpm-nonce"))
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "failed")
|
||||
}
|
||||
|
||||
func TestVerifyEAT_InvalidToken(t *testing.T) {
|
||||
writer := &mockWriter{}
|
||||
vInterface := NewVerifier(writer)
|
||||
v, ok := vInterface.(*verifier)
|
||||
require.True(t, ok)
|
||||
|
||||
err := v.VerifyEAT([]byte("invalid-token"), nil, nil)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "failed to decode EAT token")
|
||||
}
|
||||
|
||||
func TestProvider_Methods(t *testing.T) {
|
||||
p := NewProvider(true, 1)
|
||||
|
||||
originalExternalTPM := ExternalTPM
|
||||
defer func() { ExternalTPM = originalExternalTPM }()
|
||||
|
||||
ExternalTPM = &mockTPM{Buffer: &bytes.Buffer{}}
|
||||
|
||||
_, err := p.VTpmAttestation([]byte("nonce"))
|
||||
assert.Error(t, err)
|
||||
|
||||
_, err = p.TeeAttestation([]byte("nonce"))
|
||||
assert.Error(t, err)
|
||||
}
|
||||
@@ -68,7 +68,7 @@ func (c *client) GetAttestation(ctx context.Context, reportData [64]byte, nonce
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return resp.Quote, nil
|
||||
return resp.EatToken, nil
|
||||
}
|
||||
|
||||
func (c *client) GetAzureToken(ctx context.Context, nonce [32]byte) ([]byte, error) {
|
||||
|
||||
@@ -39,7 +39,7 @@ func (m *mockAttestationServer) FetchAttestation(ctx context.Context, req *attes
|
||||
}
|
||||
|
||||
return &attestation_v1.AttestationResponse{
|
||||
Quote: []byte("mock-attestation-quote"),
|
||||
EatToken: []byte("mock-attestation-quote"),
|
||||
}, nil
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user