From de50b6d2d45cbec1e422245c848c1ab465cec3f3 Mon Sep 17 00:00:00 2001 From: Sammy Kerata Oina <44265300+SammyOina@users.noreply.github.com> Date: Wed, 11 Feb 2026 18:16:35 +0300 Subject: [PATCH] COCOS-560 - EAT (#561) * feat: Implement EAT (Evidence Attestation Token) generation and verification for attestation responses, replacing raw quotes with EAT tokens in the attestation service and protobuf. Signed-off-by: Sammy Oina * style: standardize comment formatting and fix a debug log format specifier. Signed-off-by: Sammy Oina * fix pkg test Signed-off-by: Sammy Oina * feat: Introduce named constants for OEM IDs and use them in attestation claim extraction. Signed-off-by: SammyOina * feat: Implement and test minimum length validation for EAT nonce in `NewEATClaims`. Signed-off-by: SammyOina * feat: Add EATClaims.Sanitize method and integrate it into the validator to enforce claim dependencies. Signed-off-by: SammyOina * feat: Add Signature field to SNPExtensions and TDXExtensions for enhanced claim validation Signed-off-by: Sammy Oina * feat: Update dependencies and improve code structure in attestation package Signed-off-by: Sammy Oina * feat: Introduce comprehensive test suites for EAT, ATLS, TDX, Azure SNP, and vTPM attestation, and improve EAT decoder robustness. Signed-off-by: Sammy Oina * feat: Add encryption and admin keys, an encrypted algorithm file, and update go.mod to use go-jose/v4. Signed-off-by: Sammy Oina * feat: add new encryption and KBS admin keys while improving TDX attestation test error handling. Signed-off-by: Sammy Oina * feat: Add new KBS admin and encryption keys, an encrypted linear regression algorithm, and refactor TDX test error message checks. Signed-off-by: Sammy Oina * feat: Implement Azure SNP attestation policy, update certificate verification, and add key management. Signed-off-by: Sammy Oina * refactor: replace hardcoded string literals with variables in Azure SNP attestation tests. Signed-off-by: Sammy Oina * feat: Refactor TDX EAT claims to use individual RTMR fields with `tdx_` prefixes and add an `IntUse` field. Signed-off-by: Sammy Oina --------- Signed-off-by: Sammy Oina Signed-off-by: SammyOina --- cmd/agent/main.go | 3 +- cmd/attestation-service/main.go | 91 +++++- cmd/ingress-proxy/main.go | 11 +- go.mod | 7 +- go.sum | 6 + .../proto/attestation/v1/attestation.pb.go | 12 +- .../proto/attestation/v1/attestation.proto | 2 +- manager/attestation_policy_embed.go | 2 +- pkg/atls/atls_test.go | 122 ++++---- pkg/atls/attestation_provider.go | 24 +- pkg/atls/certificate_provider.go | 5 +- pkg/atls/certificate_verifier.go | 47 +++- pkg/atls/certificate_verifier_test.go | 172 ++++++++++++ pkg/attestation/attestation.go | 12 + pkg/attestation/azure/snp.go | 13 + pkg/attestation/azure/snp_coverage_test.go | 102 +++++++ pkg/attestation/azure/snp_policy_test.go | 262 ++++++++++++++++++ pkg/attestation/azure/snp_test.go | 15 +- pkg/attestation/eat/cbor_encoder.go | 74 +++++ pkg/attestation/eat/cbor_encoder_test.go | 79 ++++++ pkg/attestation/eat/decoder.go | 143 ++++++++++ pkg/attestation/eat/decoder_test.go | 218 +++++++++++++++ pkg/attestation/eat/eat.go | 186 +++++++++++++ pkg/attestation/eat/eat_test.go | 141 ++++++++++ pkg/attestation/eat/extractor.go | 150 ++++++++++ pkg/attestation/eat/extractor_test.go | 147 ++++++++++ pkg/attestation/eat/intuse_test.go | 21 ++ pkg/attestation/eat/jwt_encoder.go | 100 +++++++ pkg/attestation/eat/jwt_encoder_test.go | 135 +++++++++ pkg/attestation/eat/validator.go | 67 +++++ pkg/attestation/eat/validator_test.go | 106 +++++++ pkg/attestation/mocks/verifier.go | 63 +++++ pkg/attestation/tdx/tdx.go | 13 + pkg/attestation/tdx/tdx_coverage_test.go | 69 +++++ pkg/attestation/vtpm/vtpm.go | 15 +- pkg/attestation/vtpm/vtpm_coverage_test.go | 68 +++++ pkg/clients/grpc/attestation/client.go | 2 +- pkg/clients/grpc/attestation/client_test.go | 2 +- .../sev-snp/attestation_policy.json | 7 + .../sev-snp/attestation_policy_tdx.json | 50 ++-- 40 files changed, 2655 insertions(+), 109 deletions(-) create mode 100644 pkg/atls/certificate_verifier_test.go create mode 100644 pkg/attestation/azure/snp_coverage_test.go create mode 100644 pkg/attestation/azure/snp_policy_test.go create mode 100644 pkg/attestation/eat/cbor_encoder.go create mode 100644 pkg/attestation/eat/cbor_encoder_test.go create mode 100644 pkg/attestation/eat/decoder.go create mode 100644 pkg/attestation/eat/decoder_test.go create mode 100644 pkg/attestation/eat/eat.go create mode 100644 pkg/attestation/eat/eat_test.go create mode 100644 pkg/attestation/eat/extractor.go create mode 100644 pkg/attestation/eat/extractor_test.go create mode 100644 pkg/attestation/eat/intuse_test.go create mode 100644 pkg/attestation/eat/jwt_encoder.go create mode 100644 pkg/attestation/eat/jwt_encoder_test.go create mode 100644 pkg/attestation/eat/validator.go create mode 100644 pkg/attestation/eat/validator_test.go create mode 100644 pkg/attestation/tdx/tdx_coverage_test.go create mode 100644 pkg/attestation/vtpm/vtpm_coverage_test.go diff --git a/cmd/agent/main.go b/cmd/agent/main.go index 2ac994b9..2b400a5b 100644 --- a/cmd/agent/main.go +++ b/cmd/agent/main.go @@ -138,7 +138,6 @@ func main() { } }) - var provider attestation.Provider ccPlatform := attestation.CCPlatform() azureConfig := azure.NewEnvConfigFromAgent( @@ -217,7 +216,7 @@ func main() { CertsURL: cfg.CAUrl, }) } - certProvider, err = atls.NewProvider(provider, ccPlatform, cfg.CertsToken, cfg.CVMId, certsSDK) + certProvider, err = atls.NewProvider(attClient, ccPlatform, cfg.CertsToken, cfg.CVMId, certsSDK) if err != nil { logger.Error(fmt.Sprintf("failed to create certificate provider: %s", err)) exitCode = 1 diff --git a/cmd/attestation-service/main.go b/cmd/attestation-service/main.go index ebe49aa1..9e2488b8 100644 --- a/cmd/attestation-service/main.go +++ b/cmd/attestation-service/main.go @@ -4,6 +4,7 @@ package main import ( "context" + "crypto/ecdsa" "fmt" "log/slog" "net" @@ -16,6 +17,7 @@ import ( attestationpb "github.com/ultravioletrs/cocos/internal/proto/attestation/v1" "github.com/ultravioletrs/cocos/pkg/attestation" "github.com/ultravioletrs/cocos/pkg/attestation/azure" + "github.com/ultravioletrs/cocos/pkg/attestation/eat" "github.com/ultravioletrs/cocos/pkg/attestation/quoteprovider" "github.com/ultravioletrs/cocos/pkg/attestation/tdx" "github.com/ultravioletrs/cocos/pkg/attestation/vtpm" @@ -35,6 +37,8 @@ type config struct { AgentOSBuild string `env:"AGENT_OS_BUILD" envDefault:"UVC"` AgentOSDistro string `env:"AGENT_OS_DISTRO" envDefault:"UVC"` AgentOSType string `env:"AGENT_OS_TYPE" envDefault:"UVC"` + EATFormat string `env:"ATTESTATION_EAT_FORMAT" envDefault:"CBOR"` // JWT or CBOR + EATIssuer string `env:"ATTESTATION_EAT_ISSUER" envDefault:"cocos-attestation-service"` } func main() { @@ -121,10 +125,21 @@ func main() { return } + // Generate EAT signing key + signingKey, err := eat.GenerateSigningKey() + if err != nil { + logger.Error(fmt.Sprintf("failed to generate EAT signing key: %s", err)) + exitCode = 1 + return + } + grpcServer := grpc.NewServer() svc := &service{ - provider: provider, - logger: logger, + provider: provider, + logger: logger, + signingKey: signingKey, + eatFormat: cfg.EATFormat, + eatIssuer: cfg.EATIssuer, } attestationpb.RegisterAttestationServiceServer(grpcServer, svc) @@ -156,29 +171,37 @@ func main() { type service struct { attestationpb.UnimplementedAttestationServiceServer - provider attestation.Provider - logger *slog.Logger + provider attestation.Provider + logger *slog.Logger + signingKey *ecdsa.PrivateKey + eatFormat string + eatIssuer string } func (s *service) FetchAttestation(ctx context.Context, req *attestationpb.AttestationRequest) (*attestationpb.AttestationResponse, error) { - var quote []byte + var binaryReport []byte var err error + var platformType attestation.PlatformType + // Get binary attestation report based on platform type switch req.PlatformType { case attestationpb.PlatformType_PLATFORM_TYPE_SNP, attestationpb.PlatformType_PLATFORM_TYPE_TDX: var reportData [64]byte copy(reportData[:], req.ReportData) - quote, err = s.provider.TeeAttestation(reportData[:]) + binaryReport, err = s.provider.TeeAttestation(reportData[:]) + platformType = convertPlatformType(req.PlatformType) case attestationpb.PlatformType_PLATFORM_TYPE_VTPM: var nonce [32]byte copy(nonce[:], req.Nonce) - quote, err = s.provider.VTpmAttestation(nonce[:]) + binaryReport, err = s.provider.VTpmAttestation(nonce[:]) + platformType = attestation.VTPM case attestationpb.PlatformType_PLATFORM_TYPE_SNP_VTPM: var reportData [64]byte copy(reportData[:], req.ReportData) var nonce [32]byte copy(nonce[:], req.Nonce) - quote, err = s.provider.Attestation(reportData[:], nonce[:]) + binaryReport, err = s.provider.Attestation(reportData[:], nonce[:]) + platformType = attestation.SNPvTPM default: return nil, fmt.Errorf("unsupported platform type") } @@ -187,7 +210,57 @@ func (s *service) FetchAttestation(ctx context.Context, req *attestationpb.Attes return nil, err } - return &attestationpb.AttestationResponse{Quote: quote}, nil + // Create EAT claims from binary report + nonce := req.ReportData + if len(req.Nonce) > 0 { + nonce = req.Nonce + } + + claims, err := eat.NewEATClaims(binaryReport, nonce, platformType) + if err != nil { + s.logger.Error(fmt.Sprintf("failed to create EAT claims: %s", err)) + return nil, fmt.Errorf("failed to create EAT claims: %w", err) + } + + // Encode to EAT token based on configured format + var eatToken []byte + switch s.eatFormat { + case "JWT": + tokenString, err := eat.EncodeToJWT(claims, s.signingKey, s.eatIssuer) + if err != nil { + return nil, fmt.Errorf("failed to encode JWT: %w", err) + } + eatToken = []byte(tokenString) + case "CBOR": + eatToken, err = eat.EncodeToCBOR(claims, s.signingKey, s.eatIssuer) + if err != nil { + return nil, fmt.Errorf("failed to encode CBOR: %w", err) + } + default: + return nil, fmt.Errorf("unsupported EAT format: %s", s.eatFormat) + } + + s.logger.Debug(fmt.Sprintf("generated EAT token (%s format) for platform %v", s.eatFormat, platformType)) + + return &attestationpb.AttestationResponse{EatToken: eatToken}, nil +} + +// convertPlatformType converts protobuf platform type to internal platform type. +func convertPlatformType(pt attestationpb.PlatformType) attestation.PlatformType { + switch pt { + case attestationpb.PlatformType_PLATFORM_TYPE_SNP: + return attestation.SNP + case attestationpb.PlatformType_PLATFORM_TYPE_TDX: + return attestation.TDX + case attestationpb.PlatformType_PLATFORM_TYPE_VTPM: + return attestation.VTPM + case attestationpb.PlatformType_PLATFORM_TYPE_SNP_VTPM: + return attestation.SNPvTPM + case attestationpb.PlatformType_PLATFORM_TYPE_AZURE: + return attestation.Azure + default: + return attestation.NoCC + } } func (s *service) GetAzureToken(ctx context.Context, req *attestationpb.AzureTokenRequest) (*attestationpb.AzureTokenResponse, error) { diff --git a/cmd/ingress-proxy/main.go b/cmd/ingress-proxy/main.go index 5ce3e417..fcb2a435 100644 --- a/cmd/ingress-proxy/main.go +++ b/cmd/ingress-proxy/main.go @@ -18,6 +18,7 @@ import ( "github.com/ultravioletrs/cocos/pkg/atls" "github.com/ultravioletrs/cocos/pkg/attestation" "github.com/ultravioletrs/cocos/pkg/attestation/azure" + attestation_client "github.com/ultravioletrs/cocos/pkg/clients/grpc/attestation" "github.com/ultravioletrs/cocos/pkg/ingress" "golang.org/x/sync/errgroup" ) @@ -76,7 +77,6 @@ func run(cfg config) error { } // Initialize Certificate Provider - var provider attestation.Provider ccPlatform := attestation.CCPlatform() azureConfig := azure.NewEnvConfigFromAgent( @@ -90,13 +90,20 @@ func run(cfg config) error { var certProvider atls.CertificateProvider if ccPlatform != attestation.NoCC { + // Create attestation client + attClient, err := attestation_client.NewClient("/run/cocos/attestation.sock") + if err != nil { + return fmt.Errorf("failed to create attestation client: %w", err) + } + defer attClient.Close() + var certsSDK sdk.SDK if cfg.CAUrl != "" { certsSDK = sdk.NewSDK(sdk.Config{ CertsURL: cfg.CAUrl, }) } - certProvider, err = atls.NewProvider(provider, ccPlatform, cfg.CertsToken, cfg.CVMId, certsSDK) + certProvider, err = atls.NewProvider(attClient, ccPlatform, cfg.CertsToken, cfg.CVMId, certsSDK) if err != nil { return fmt.Errorf("failed to create certificate provider: %w", err) } diff --git a/go.mod b/go.mod index 2cb46194..fd6cae8e 100644 --- a/go.mod +++ b/go.mod @@ -25,9 +25,12 @@ require ( cloud.google.com/go/storage v1.57.2 github.com/absmach/supermq v0.18.4 github.com/caarlos0/env/v10 v10.0.0 + github.com/fxamacker/cbor/v2 v2.9.0 github.com/go-chi/chi/v5 v5.2.3 + github.com/go-jose/go-jose/v4 v4.1.3 github.com/golang-jwt/jwt/v5 v5.3.0 github.com/google/gce-tcb-verifier v0.3.1 + github.com/veraison/go-cose v1.3.0 ) require ( @@ -54,7 +57,6 @@ require ( github.com/envoyproxy/protoc-gen-validate v1.3.0 // indirect github.com/felixge/httpsnoop v1.0.4 // indirect github.com/go-jose/go-jose/v3 v3.0.4 // indirect - github.com/go-jose/go-jose/v4 v4.1.3 // indirect github.com/gofrs/uuid/v5 v5.4.0 // indirect github.com/google/certificate-transparency-go v1.1.8 // indirect github.com/google/go-attestation v0.5.1 // indirect @@ -73,6 +75,7 @@ require ( github.com/opencontainers/image-spec v1.1.0 // indirect github.com/planetscale/vtprotobuf v0.6.1-0.20240917153116-6f2963f01587 // indirect github.com/spiffe/go-spiffe/v2 v2.6.0 // indirect + github.com/x448/float16 v0.8.4 // indirect go.opentelemetry.io/auto/sdk v1.2.1 // indirect go.opentelemetry.io/contrib/detectors/gcp v1.38.0 // indirect go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.64.0 // indirect @@ -118,7 +121,7 @@ require ( go.opentelemetry.io/otel/metric v1.39.0 // indirect go.opentelemetry.io/proto/otlp v1.9.0 // indirect go.uber.org/multierr v1.11.0 // indirect - golang.org/x/net v0.48.0 // indirect + golang.org/x/net v0.48.0 golang.org/x/sys v0.40.0 // indirect golang.org/x/term v0.39.0 golang.org/x/text v0.33.0 // indirect diff --git a/go.sum b/go.sum index 08647cee..f4588bd4 100644 --- a/go.sum +++ b/go.sum @@ -96,6 +96,8 @@ github.com/fatih/color v1.18.0 h1:S8gINlzdQ840/4pfAwic/ZE0djQEH3wM94VfqLTZcOM= github.com/fatih/color v1.18.0/go.mod h1:4FelSpRwEGDpQ12mAdzqdOukCy4u8WUtOY6lkT/6HfU= github.com/felixge/httpsnoop v1.0.4 h1:NFTV2Zj1bL4mc9sqWACXbQFVBBg2W3GPvqp8/ESS2Wg= github.com/felixge/httpsnoop v1.0.4/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U= +github.com/fxamacker/cbor/v2 v2.9.0 h1:NpKPmjDBgUfBms6tr6JZkTHtfFGcMKsw3eGcmD/sapM= +github.com/fxamacker/cbor/v2 v2.9.0/go.mod h1:vM4b+DJCtHn+zz7h3FFp/hDAI9WNWCsZj23V5ytsSxQ= github.com/go-chi/chi/v5 v5.2.3 h1:WQIt9uxdsAbgIYgid+BpYc+liqQZGMHRaUwp0JUcvdE= github.com/go-chi/chi/v5 v5.2.3/go.mod h1:L2yAIGWB3H+phAw1NxKwWM+7eUH/lU8pOMm5hHcoops= github.com/go-gorp/gorp/v3 v3.1.0 h1:ItKF/Vbuj31dmV4jxA1qblpSwkl9g1typ24xoe70IGs= @@ -274,6 +276,10 @@ github.com/stretchr/objx v0.5.3/go.mod h1:rDQraq+vQZU7Fde9LOZLr8Tax6zZvy4kuNKF+Q github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= +github.com/veraison/go-cose v1.3.0 h1:2/H5w8kdSpQJyVtIhx8gmwPJ2uSz1PkyWFx0idbd7rk= +github.com/veraison/go-cose v1.3.0/go.mod h1:df09OV91aHoQWLmy1KsDdYiagtXgyAwAl8vFeFn1gMc= +github.com/x448/float16 v0.8.4 h1:qLwI1I70+NjRFUR3zs1JPUCgaCXSh3SW62uAKT1mSBM= +github.com/x448/float16 v0.8.4/go.mod h1:14CWIYCyZA/cWjXOioeEpHeN/83MdbZDRQHoFcYsOfg= github.com/xeipuuv/gojsonpointer v0.0.0-20190905194746-02993c407bfb h1:zGWFAtiMcyryUHoUjUJX0/lt1H2+i2Ka2n+D3DImSNo= github.com/xeipuuv/gojsonpointer v0.0.0-20190905194746-02993c407bfb/go.mod h1:N2zxlSyiKSe5eX1tZViRH5QA0qijqEDrYZiPEAiq3wU= github.com/xeipuuv/gojsonreference v0.0.0-20180127040603-bd5ef7bd5415 h1:EzJWgHovont7NscjpAxXsDA8S8BMYve8Y5+7cuRE7R0= diff --git a/internal/proto/attestation/v1/attestation.pb.go b/internal/proto/attestation/v1/attestation.pb.go index 7ff78887..18560e9a 100644 --- a/internal/proto/attestation/v1/attestation.pb.go +++ b/internal/proto/attestation/v1/attestation.pb.go @@ -144,7 +144,7 @@ func (x *AttestationRequest) GetPlatformType() PlatformType { type AttestationResponse struct { state protoimpl.MessageState `protogen:"open.v1"` - Quote []byte `protobuf:"bytes,1,opt,name=quote,proto3" json:"quote,omitempty"` + EatToken []byte `protobuf:"bytes,1,opt,name=eat_token,json=eatToken,proto3" json:"eat_token,omitempty"` // EAT token (JWT or CBOR format) unknownFields protoimpl.UnknownFields sizeCache protoimpl.SizeCache } @@ -179,9 +179,9 @@ func (*AttestationResponse) Descriptor() ([]byte, []int) { return file_internal_proto_attestation_v1_attestation_proto_rawDescGZIP(), []int{1} } -func (x *AttestationResponse) GetQuote() []byte { +func (x *AttestationResponse) GetEatToken() []byte { if x != nil { - return x.Quote + return x.EatToken } return nil } @@ -283,9 +283,9 @@ const file_internal_proto_attestation_v1_attestation_proto_rawDesc = "" + "\vreport_data\x18\x01 \x01(\fR\n" + "reportData\x12\x14\n" + "\x05nonce\x18\x02 \x01(\fR\x05nonce\x12A\n" + - "\rplatform_type\x18\x03 \x01(\x0e2\x1c.attestation.v1.PlatformTypeR\fplatformType\"+\n" + - "\x13AttestationResponse\x12\x14\n" + - "\x05quote\x18\x01 \x01(\fR\x05quote\")\n" + + "\rplatform_type\x18\x03 \x01(\x0e2\x1c.attestation.v1.PlatformTypeR\fplatformType\"2\n" + + "\x13AttestationResponse\x12\x1b\n" + + "\teat_token\x18\x01 \x01(\fR\beatToken\")\n" + "\x11AzureTokenRequest\x12\x14\n" + "\x05nonce\x18\x01 \x01(\fR\x05nonce\"*\n" + "\x12AzureTokenResponse\x12\x14\n" + diff --git a/internal/proto/attestation/v1/attestation.proto b/internal/proto/attestation/v1/attestation.proto index f72fae32..0499991d 100644 --- a/internal/proto/attestation/v1/attestation.proto +++ b/internal/proto/attestation/v1/attestation.proto @@ -16,7 +16,7 @@ message AttestationRequest { } message AttestationResponse { - bytes quote = 1; + bytes eat_token = 1; // EAT token (JWT or CBOR format) } message AzureTokenRequest { diff --git a/manager/attestation_policy_embed.go b/manager/attestation_policy_embed.go index f9573104..9f4ebb13 100644 --- a/manager/attestation_policy_embed.go +++ b/manager/attestation_policy_embed.go @@ -8,7 +8,7 @@ package manager import ( "context" - attestationPolicy "github.com/ultravioletrs/cocos/scripts/attestation_policy" + attestationPolicy "github.com/ultravioletrs/cocos/scripts/attestation_policy/sev-snp" ) func (ms *managerService) FetchAttestationPolicy(_ context.Context, _ string) ([]byte, error) { diff --git a/pkg/atls/atls_test.go b/pkg/atls/atls_test.go index 1eb639a6..18163cf7 100644 --- a/pkg/atls/atls_test.go +++ b/pkg/atls/atls_test.go @@ -3,6 +3,7 @@ package atls import ( + "context" "crypto/ecdsa" "crypto/elliptic" "crypto/rand" @@ -32,7 +33,6 @@ import ( "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" "github.com/ultravioletrs/cocos/pkg/attestation" - "github.com/ultravioletrs/cocos/pkg/attestation/mocks" "github.com/ultravioletrs/cocos/pkg/attestation/vtpm" "golang.org/x/crypto/sha3" "google.golang.org/protobuf/encoding/protojson" @@ -44,6 +44,32 @@ const ( var policy = attestation.Config{Config: &check.Config{Policy: &check.Policy{}, RootOfTrust: &check.RootOfTrust{}}, PcrConfig: &attestation.PcrConfig{}} +// mockAttestationClient is a simple mock for testing. +type mockAttestationClient struct { + mock.Mock +} + +func (m *mockAttestationClient) GetAttestation(ctx context.Context, reportData [64]byte, nonce [32]byte, attType attestation.PlatformType) ([]byte, error) { + args := m.Called(ctx, reportData, nonce, attType) + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).([]byte), args.Error(1) +} + +func (m *mockAttestationClient) GetAzureToken(ctx context.Context, nonce [32]byte) ([]byte, error) { + args := m.Called(ctx, nonce) + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).([]byte), args.Error(1) +} + +func (m *mockAttestationClient) Close() error { + args := m.Called() + return args.Error(0) +} + func generateTestCertPEM(t *testing.T) string { return generateTestCertPEMWithSubject(t, "test") } @@ -133,9 +159,8 @@ func TestUnifiedCertificateGenerator(t *testing.T) { // TestPlatformAttestationProvider tests the platform attestation provider. func TestPlatformAttestationProvider(t *testing.T) { - mockProvider := new(mocks.Provider) - t.Run("NewAttestationProvider", func(t *testing.T) { + mockClient := new(mockAttestationClient) cases := []struct { name string platformType attestation.PlatformType @@ -149,7 +174,7 @@ func TestPlatformAttestationProvider(t *testing.T) { for _, c := range cases { t.Run(c.name, func(t *testing.T) { - provider, err := NewAttestationProvider(mockProvider, c.platformType) + provider, err := NewAttestationProvider(mockClient, c.platformType) if c.expectError { assert.Error(t, err) @@ -164,10 +189,11 @@ func TestPlatformAttestationProvider(t *testing.T) { }) t.Run("GetAttestation", func(t *testing.T) { + mockClient := new(mockAttestationClient) expectedAttestation := []byte("test-attestation") - mockProvider.On("Attestation", mock.Anything, mock.Anything).Return(expectedAttestation, nil) + mockClient.On("GetAttestation", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(expectedAttestation, nil) - provider, err := NewAttestationProvider(mockProvider, attestation.SNPvTPM) + provider, err := NewAttestationProvider(mockClient, attestation.SNPvTPM) require.NoError(t, err) pubKey := []byte("test-pubkey") @@ -177,14 +203,14 @@ func TestPlatformAttestationProvider(t *testing.T) { assert.NoError(t, err) assert.Equal(t, expectedAttestation, attestation) - mockProvider.AssertExpectations(t) + mockClient.AssertExpectations(t) }) t.Run("GetAttestationError", func(t *testing.T) { - mockProvider := new(mocks.Provider) - mockProvider.On("Attestation", mock.Anything, mock.Anything).Return(nil, errors.New("attestation failed")) + mockClient := new(mockAttestationClient) + mockClient.On("GetAttestation", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil, errors.New("attestation failed")) - provider, err := NewAttestationProvider(mockProvider, attestation.SNPvTPM) + provider, err := NewAttestationProvider(mockClient, attestation.SNPvTPM) require.NoError(t, err) _, err = provider.Attest([]byte("pubkey"), []byte("nonce")) @@ -194,12 +220,11 @@ func TestPlatformAttestationProvider(t *testing.T) { // TestAttestedCertificateProvider tests the attested certificate provider. func TestAttestedCertificateProvider(t *testing.T) { - mockProvider := new(mocks.Provider) - t.Run("GetCertificateSuccess", func(t *testing.T) { - mockProvider.On("Attestation", mock.Anything, mock.Anything).Return([]byte("test-attestation"), nil) + mockClient := new(mockAttestationClient) + mockClient.On("GetAttestation", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return([]byte("test-attestation"), nil) - attestationProvider, err := NewAttestationProvider(mockProvider, attestation.SNPvTPM) + attestationProvider, err := NewAttestationProvider(mockClient, attestation.SNPvTPM) require.NoError(t, err) subject := DefaultCertificateSubject() @@ -223,8 +248,8 @@ func TestAttestedCertificateProvider(t *testing.T) { }) t.Run("InvalidServerName", func(t *testing.T) { - mockProvider := new(mocks.Provider) - attestationProvider, err := NewAttestationProvider(mockProvider, attestation.SNPvTPM) + mockClient := new(mockAttestationClient) + attestationProvider, err := NewAttestationProvider(mockClient, attestation.SNPvTPM) require.NoError(t, err) provider := NewAttestedProvider(attestationProvider, DefaultCertificateSubject()) @@ -237,10 +262,10 @@ func TestAttestedCertificateProvider(t *testing.T) { }) t.Run("AttestationError", func(t *testing.T) { - mockProvider := new(mocks.Provider) - mockProvider.On("Attestation", mock.Anything, mock.Anything).Return(nil, errors.New("attestation failed")) + mockClient := new(mockAttestationClient) + mockClient.On("GetAttestation", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil, errors.New("attestation failed")) - attestationProvider, err := NewAttestationProvider(mockProvider, attestation.SNPvTPM) + attestationProvider, err := NewAttestationProvider(mockClient, attestation.SNPvTPM) require.NoError(t, err) provider := NewAttestedProvider(attestationProvider, DefaultCertificateSubject()) @@ -260,10 +285,10 @@ func TestAttestedCertificateProvider(t *testing.T) { // TestNewProvider tests the factory function. func TestNewProvider(t *testing.T) { - mockProvider := new(mocks.Provider) + mockClient := new(mockAttestationClient) t.Run("SelfSignedProvider", func(t *testing.T) { - provider, err := NewProvider(mockProvider, attestation.SNPvTPM, "", "", nil) + provider, err := NewProvider(mockClient, attestation.SNPvTPM, "", "", nil) assert.NoError(t, err) assert.NotNil(t, provider) }) @@ -271,19 +296,19 @@ func TestNewProvider(t *testing.T) { t.Run("CASignedProviderWithSDK", func(t *testing.T) { mockSDK := sdkmocks.NewSDK(t) - provider, err := NewProvider(mockProvider, attestation.SNPvTPM, "test-token", "test-cvm-id", mockSDK) + provider, err := NewProvider(mockClient, attestation.SNPvTPM, "test-token", "test-cvm-id", mockSDK) assert.NoError(t, err) assert.NotNil(t, provider) }) t.Run("SelfSignedProviderNilSDK", func(t *testing.T) { - provider, err := NewProvider(mockProvider, attestation.SNPvTPM, "test-token", "test-cvm-id", nil) + provider, err := NewProvider(mockClient, attestation.SNPvTPM, "test-token", "test-cvm-id", nil) assert.NoError(t, err) assert.NotNil(t, provider) }) t.Run("InvalidPlatformType", func(t *testing.T) { - _, err := NewProvider(mockProvider, attestation.PlatformType(999), "", "", nil) + _, err := NewProvider(mockClient, attestation.PlatformType(999), "", "", nil) assert.Error(t, err) }) } @@ -714,8 +739,8 @@ func TestCertificateVerification(t *testing.T) { // TestAttestedCAProvider tests the CA-signed certificate provider. func TestAttestedCAProvider(t *testing.T) { - mockProvider := new(mocks.Provider) - attestationProvider, err := NewAttestationProvider(mockProvider, attestation.SNPvTPM) + mockClient := new(mockAttestationClient) + attestationProvider, err := NewAttestationProvider(mockClient, attestation.SNPvTPM) require.NoError(t, err) subject := DefaultCertificateSubject() @@ -740,8 +765,8 @@ func TestAttestedCAProvider(t *testing.T) { // TestCASignedCertificateErrors tests error cases in CA-signed certificate generation. func TestCASignedCertificateErrors(t *testing.T) { - mockProvider := new(mocks.Provider) - attestationProvider, err := NewAttestationProvider(mockProvider, attestation.SNPvTPM) + mockClient := new(mockAttestationClient) + attestationProvider, err := NewAttestationProvider(mockClient, attestation.SNPvTPM) require.NoError(t, err) subject := DefaultCertificateSubject() @@ -787,8 +812,8 @@ func TestCASignedCertificateErrors(t *testing.T) { // TestGetCertificateErrors tests error paths in certificate generation. func TestGetCertificateErrors(t *testing.T) { t.Run("InvalidServerNameFormat", func(t *testing.T) { - mockProvider := new(mocks.Provider) - attestationProvider, err := NewAttestationProvider(mockProvider, attestation.SNPvTPM) + mockClient := new(mockAttestationClient) + attestationProvider, err := NewAttestationProvider(mockClient, attestation.SNPvTPM) require.NoError(t, err) provider := NewAttestedProvider(attestationProvider, DefaultCertificateSubject()) @@ -803,10 +828,10 @@ func TestGetCertificateErrors(t *testing.T) { }) t.Run("AttestationProviderError", func(t *testing.T) { - mockProvider := new(mocks.Provider) - mockProvider.On("Attestation", mock.Anything, mock.Anything).Return(nil, errors.New("attestation failed")) + mockClient := new(mockAttestationClient) + mockClient.On("GetAttestation", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil, errors.New("attestation failed")) - attestationProvider, err := NewAttestationProvider(mockProvider, attestation.SNPvTPM) + attestationProvider, err := NewAttestationProvider(mockClient, attestation.SNPvTPM) require.NoError(t, err) provider := NewAttestedProvider(attestationProvider, DefaultCertificateSubject()) @@ -824,10 +849,10 @@ func TestGetCertificateErrors(t *testing.T) { }) t.Run("CASignedCertificateError", func(t *testing.T) { - mockProvider := new(mocks.Provider) - mockProvider.On("Attestation", mock.Anything, mock.Anything).Return([]byte("test-attestation"), nil) + mockClient := new(mockAttestationClient) + mockClient.On("GetAttestation", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return([]byte("test-attestation"), nil) - attestationProvider, err := NewAttestationProvider(mockProvider, attestation.SNPvTPM) + attestationProvider, err := NewAttestationProvider(mockClient, attestation.SNPvTPM) require.NoError(t, err) mockSDK := sdkmocks.NewSDK(t) @@ -904,7 +929,8 @@ func TestCertificateVerificationEdgeCases(t *testing.T) { err := verifier.verifyCertificateExtension([]byte("test-extension"), []byte("test-pubkey"), []byte("test-nonce"), invalidPlatformType) assert.Error(t, err) - assert.Contains(t, err.Error(), "unsupported platform type") + // The error occurs during EAT token decoding before platform type validation + assert.Contains(t, err.Error(), "failed to decode EAT token") }) } @@ -973,12 +999,12 @@ func TestIntegrationScenarios(t *testing.T) { require.NoError(t, err) t.Run("FullSelfSignedFlow", func(t *testing.T) { - // Setup mock provider - mockProvider := new(mocks.Provider) - mockProvider.On("Attestation", mock.Anything, mock.Anything).Return([]byte("mock-attestation"), nil) + // Setup mock client + mockClient := new(mockAttestationClient) + mockClient.On("GetAttestation", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return([]byte("mock-attestation"), nil) // Create provider - provider, err := NewProvider(mockProvider, attestation.SNPvTPM, "", "", nil) + provider, err := NewProvider(mockClient, attestation.SNPvTPM, "", "", nil) require.NoError(t, err) // Generate certificate @@ -1017,10 +1043,10 @@ func TestIntegrationScenarios(t *testing.T) { mockSDK.On("CreateCSR", mock.Anything, mock.Anything, mock.Anything).Return(expectedCSR, errors.SDKError(nil)) mockSDK.On("IssueFromCSRInternal", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(expectedCert, errors.SDKError(nil)) - mockProvider := new(mocks.Provider) - mockProvider.On("Attestation", mock.Anything, mock.Anything).Return([]byte("mock-attestation"), nil) + mockClient := new(mockAttestationClient) + mockClient.On("GetAttestation", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return([]byte("mock-attestation"), nil) - provider, err := NewProvider(mockProvider, attestation.SNPvTPM, "test-token", "test-cvm-id", mockSDK) + provider, err := NewProvider(mockClient, attestation.SNPvTPM, "test-token", "test-cvm-id", mockSDK) require.NoError(t, err) nonce := make([]byte, 64) @@ -1041,17 +1067,17 @@ func TestIntegrationScenarios(t *testing.T) { assert.NotNil(t, parsedCert.Subject) - mockProvider.AssertExpectations(t) + mockClient.AssertExpectations(t) mockSDK.AssertExpectations(t) }) } // TestConcurrentAccess tests concurrent access scenarios. func TestConcurrentAccess(t *testing.T) { - mockProvider := new(mocks.Provider) - mockProvider.On("Attestation", mock.Anything, mock.Anything).Return([]byte("mock-attestation"), nil) + mockClient := new(mockAttestationClient) + mockClient.On("GetAttestation", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return([]byte("mock-attestation"), nil) - provider, err := NewProvider(mockProvider, attestation.SNPvTPM, "", "", nil) + provider, err := NewProvider(mockClient, attestation.SNPvTPM, "", "", nil) require.NoError(t, err) const numGoroutines = 10 diff --git a/pkg/atls/attestation_provider.go b/pkg/atls/attestation_provider.go index d19626d3..9a3fbbba 100644 --- a/pkg/atls/attestation_provider.go +++ b/pkg/atls/attestation_provider.go @@ -3,10 +3,12 @@ package atls import ( + "context" "encoding/asn1" "fmt" "github.com/ultravioletrs/cocos/pkg/attestation" + attestation_client "github.com/ultravioletrs/cocos/pkg/clients/grpc/attestation" "golang.org/x/crypto/sha3" ) @@ -19,20 +21,20 @@ type AttestationProvider interface { // PlatformAttestationProvider handles platform attestation operations. type platformAttestationProvider struct { - provider attestation.Provider + attClient attestation_client.Client oid asn1.ObjectIdentifier platformType attestation.PlatformType } // NewAttestationProvider creates a new attestation provider for the given platform type. -func NewAttestationProvider(provider attestation.Provider, platformType attestation.PlatformType) (AttestationProvider, error) { +func NewAttestationProvider(attClient attestation_client.Client, platformType attestation.PlatformType) (AttestationProvider, error) { oid, err := OID(platformType) if err != nil { return nil, fmt.Errorf("failed to get OID: %w", err) } return &platformAttestationProvider{ - provider: provider, + attClient: attClient, oid: oid, platformType: platformType, }, nil @@ -41,7 +43,21 @@ func NewAttestationProvider(provider attestation.Provider, platformType attestat func (p *platformAttestationProvider) Attest(pubKey []byte, nonce []byte) ([]byte, error) { teeNonce := append(pubKey, nonce...) hashNonce := sha3.Sum512(teeNonce) - return p.provider.Attestation(hashNonce[:], hashNonce[:32]) + + var reportData [64]byte + copy(reportData[:], hashNonce[:]) + + var nonceArray [32]byte + copy(nonceArray[:], hashNonce[:32]) + + // Get signed EAT token from attestation service + // The attestation service maintains a persistent signing key and returns a pre-signed token + eatToken, err := p.attClient.GetAttestation(context.Background(), reportData, nonceArray, p.platformType) + if err != nil { + return nil, fmt.Errorf("failed to get attestation from service: %w", err) + } + + return eatToken, nil } func (p *platformAttestationProvider) OID() asn1.ObjectIdentifier { diff --git a/pkg/atls/certificate_provider.go b/pkg/atls/certificate_provider.go index 5c3a6eed..75e9ccfb 100644 --- a/pkg/atls/certificate_provider.go +++ b/pkg/atls/certificate_provider.go @@ -19,6 +19,7 @@ import ( "github.com/absmach/certs" sdk "github.com/absmach/certs/sdk" "github.com/ultravioletrs/cocos/pkg/attestation" + attestation_client "github.com/ultravioletrs/cocos/pkg/clients/grpc/attestation" ) // CertificateProvider defines the interface for providing TLS certificates. @@ -173,8 +174,8 @@ func (p *attestedCertificateProvider) generateCASignedCertificate(ctx context.Co return block.Bytes, nil } -func NewProvider(provider attestation.Provider, platformType attestation.PlatformType, agentToken, cvmID string, certsSDK sdk.SDK) (CertificateProvider, error) { - attestationProvider, err := NewAttestationProvider(provider, platformType) +func NewProvider(attClient attestation_client.Client, platformType attestation.PlatformType, agentToken, cvmID string, certsSDK sdk.SDK) (CertificateProvider, error) { + attestationProvider, err := NewAttestationProvider(attClient, platformType) if err != nil { return nil, fmt.Errorf("failed to create attestation provider: %w", err) } diff --git a/pkg/atls/certificate_verifier.go b/pkg/atls/certificate_verifier.go index 8a444f49..566735de 100644 --- a/pkg/atls/certificate_verifier.go +++ b/pkg/atls/certificate_verifier.go @@ -10,6 +10,7 @@ import ( "github.com/ultravioletrs/cocos/pkg/attestation" "github.com/ultravioletrs/cocos/pkg/attestation/azure" + "github.com/ultravioletrs/cocos/pkg/attestation/eat" "github.com/ultravioletrs/cocos/pkg/attestation/tdx" "github.com/ultravioletrs/cocos/pkg/attestation/vtpm" "golang.org/x/crypto/sha3" @@ -21,11 +22,15 @@ type CertificateVerifier interface { // CertificateVerifier handles certificate verification operations. type certificateVerifier struct { - rootCAs *x509.CertPool + rootCAs *x509.CertPool + verifierProvider func(attestation.PlatformType) (attestation.Verifier, error) } func NewCertificateVerifier(rootCAs *x509.CertPool) CertificateVerifier { - return &certificateVerifier{rootCAs: rootCAs} + return &certificateVerifier{ + rootCAs: rootCAs, + verifierProvider: platformVerifier, + } } func (v *certificateVerifier) VerifyPeerCertificate(rawCerts [][]byte, _ [][]*x509.Certificate, nonce []byte) error { @@ -75,15 +80,43 @@ func (v *certificateVerifier) verifyAttestationExtension(cert *x509.Certificate, } func (v *certificateVerifier) verifyCertificateExtension(extension []byte, pubKey []byte, nonce []byte, platformType attestation.PlatformType) error { - verifier, err := platformVerifier(platformType) + // Decode EAT token from certificate extension + // Note: We don't have the public key for verification here, so we decode without verification + // The signature was created by the attester, and we trust the TEE hardware verification + claims, err := eat.DecodeCBOR(extension, nil) + if err != nil { + return fmt.Errorf("failed to decode EAT token: %w", err) + } + + // Verify nonce matches + teeNonce := append(pubKey, nonce...) + hashNonce := sha3.Sum512(teeNonce) + + // Compare nonces (EAT nonce should match our computed nonce) + if len(claims.Nonce) != len(hashNonce) { + return fmt.Errorf("nonce length mismatch: expected %d, got %d", len(hashNonce), len(claims.Nonce)) + } + + nonceMatch := true + for i := range claims.Nonce { + if claims.Nonce[i] != hashNonce[i] { + nonceMatch = false + break + } + } + + if !nonceMatch { + return fmt.Errorf("nonce mismatch in EAT token") + } + + // Get platform verifier + verifier, err := v.verifierProvider(platformType) if err != nil { return fmt.Errorf("failed to get platform verifier: %w", err) } - teeNonce := append(pubKey, nonce...) - hashNonce := sha3.Sum512(teeNonce) - - if err = verifier.VerifyAttestation(extension, hashNonce[:], hashNonce[:32]); err != nil { + // Verify the binary attestation report embedded in EAT token + if err = verifier.VerifyAttestation(claims.RawReport, hashNonce[:], hashNonce[:32]); err != nil { return fmt.Errorf("failed to verify attestation: %w", err) } diff --git a/pkg/atls/certificate_verifier_test.go b/pkg/atls/certificate_verifier_test.go new file mode 100644 index 00000000..2593c4df --- /dev/null +++ b/pkg/atls/certificate_verifier_test.go @@ -0,0 +1,172 @@ +// Copyright (c) Ultraviolet +// SPDX-License-Identifier: Apache-2.0 + +package atls + +import ( + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "crypto/x509" + "crypto/x509/pkix" + "math/big" + "testing" + "time" + + "github.com/fxamacker/cbor/v2" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/ultravioletrs/cocos/pkg/attestation" + "github.com/ultravioletrs/cocos/pkg/attestation/eat" + "golang.org/x/crypto/sha3" +) + +type mockVerifier struct { + verifyAttestationFunc func(report []byte, teeNonce []byte, vTpmNonce []byte) error +} + +func (m *mockVerifier) VerifyAttestation(report []byte, teeNonce []byte, vTpmNonce []byte) error { + if m.verifyAttestationFunc != nil { + return m.verifyAttestationFunc(report, teeNonce, vTpmNonce) + } + return nil +} + +func (m *mockVerifier) VerifTeeAttestation(report []byte, teeNonce []byte) error { + return nil +} + +func (m *mockVerifier) VerifVTpmAttestation(report []byte, vTpmNonce []byte) error { + return nil +} + +func (m *mockVerifier) VerifyEAT(eatToken []byte, teeNonce []byte, vTpmNonce []byte) error { + return nil +} + +func (m *mockVerifier) JSONToPolicy(path string) error { + return nil +} + +func TestVerifyPeerCertificate_Success(t *testing.T) { + // Setup keys and cert templates + caKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + require.NoError(t, err) + + caTemplate := &x509.Certificate{ + SerialNumber: big.NewInt(1), + Subject: pkix.Name{CommonName: "Test CA"}, + NotBefore: time.Now().Add(-1 * time.Hour), + NotAfter: time.Now().Add(1 * time.Hour), + KeyUsage: x509.KeyUsageCertSign | x509.KeyUsageCRLSign, + BasicConstraintsValid: true, + IsCA: true, + } + caCertDER, err := x509.CreateCertificate(rand.Reader, caTemplate, caTemplate, &caKey.PublicKey, caKey) + require.NoError(t, err) + caCert, err := x509.ParseCertificate(caCertDER) + require.NoError(t, err) + + rootCAs := x509.NewCertPool() + rootCAs.AddCert(caCert) + + // Create verifier with mock platform verifier + verifier := NewCertificateVerifier(rootCAs).(*certificateVerifier) + verifier.verifierProvider = func(pt attestation.PlatformType) (attestation.Verifier, error) { + return &mockVerifier{ + verifyAttestationFunc: func(report []byte, teeNonce []byte, vTpmNonce []byte) error { + return nil + }, + }, nil + } + + peerKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + require.NoError(t, err) + + // Prepare EAT Claims + nonce := []byte("test-nonce") + peerPubKeyDER, err := x509.MarshalPKIXPublicKey(&peerKey.PublicKey) + require.NoError(t, err) + + teeNonce := append(peerPubKeyDER, nonce...) + hashNonce := sha3.Sum512(teeNonce) + + claims := eat.EATClaims{ + Nonce: hashNonce[:], + RawReport: []byte("mock-report"), + } + eatBytes, err := cbor.Marshal(claims) + require.NoError(t, err) + + // Create Peer Cert with EAT extension + peerTemplate := &x509.Certificate{ + SerialNumber: big.NewInt(2), + Subject: pkix.Name{CommonName: "Test Peer"}, + NotBefore: time.Now().Add(-1 * time.Hour), + NotAfter: time.Now().Add(1 * time.Hour), + KeyUsage: x509.KeyUsageDigitalSignature, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, + ExtraExtensions: []pkix.Extension{ + { + Id: SNPvTPMOID, // Use SNPvTPMOID as default testing OID + Value: eatBytes, + }, + }, + } + peerCertDER, err := x509.CreateCertificate(rand.Reader, peerTemplate, caCert, &peerKey.PublicKey, caKey) + require.NoError(t, err) + + err = verifier.VerifyPeerCertificate([][]byte{peerCertDER}, nil, nonce) + assert.NoError(t, err) +} + +func TestVerifyPeerCertificate_Failures(t *testing.T) { + caKey, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + caTemplate := &x509.Certificate{ + SerialNumber: big.NewInt(1), + Subject: pkix.Name{CommonName: "Test CA"}, + NotBefore: time.Now().Add(-1 * time.Hour), + NotAfter: time.Now().Add(1 * time.Hour), + KeyUsage: x509.KeyUsageCertSign | x509.KeyUsageCRLSign, + BasicConstraintsValid: true, + IsCA: true, + } + caCertDER, _ := x509.CreateCertificate(rand.Reader, caTemplate, caTemplate, &caKey.PublicKey, caKey) + caCert, _ := x509.ParseCertificate(caCertDER) + rootCAs := x509.NewCertPool() + rootCAs.AddCert(caCert) + + verifier := NewCertificateVerifier(rootCAs).(*certificateVerifier) + + peerKey, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + peerTemplate := &x509.Certificate{ + SerialNumber: big.NewInt(2), + NotBefore: time.Now(), + NotAfter: time.Now().Add(time.Hour), + } + certDER, _ := x509.CreateCertificate(rand.Reader, peerTemplate, caCert, &peerKey.PublicKey, caKey) + + err := verifier.VerifyPeerCertificate([][]byte{certDER}, nil, []byte("nonce")) + assert.ErrorContains(t, err, "attestation extension not found") + + nonce := []byte("nonce1") + wrongNonce := []byte("nonce2") + peerPubKeyDER, _ := x509.MarshalPKIXPublicKey(&peerKey.PublicKey) + teeNonce := append(peerPubKeyDER, wrongNonce...) // Mismatching input + hashNonce := sha3.Sum512(teeNonce) + + claims := eat.EATClaims{Nonce: hashNonce[:], RawReport: []byte("rep")} + eatBytes, _ := cbor.Marshal(claims) + + peerTemplate.ExtraExtensions = []pkix.Extension{{Id: SNPvTPMOID, Value: eatBytes}} + certDERMismatch, _ := x509.CreateCertificate(rand.Reader, peerTemplate, caCert, &peerKey.PublicKey, caKey) + + err = verifier.VerifyPeerCertificate([][]byte{certDERMismatch}, nil, nonce) // Pass nonce1 + assert.ErrorContains(t, err, "nonce mismatch") +} + +func TestVerifyPeerCertificate_Empty(t *testing.T) { + verifier := NewCertificateVerifier(nil) + err := verifier.VerifyPeerCertificate(nil, nil, nil) + assert.ErrorContains(t, err, "no certificates provided") +} diff --git a/pkg/attestation/attestation.go b/pkg/attestation/attestation.go index fa44b1b2..57a8cc1c 100644 --- a/pkg/attestation/attestation.go +++ b/pkg/attestation/attestation.go @@ -42,9 +42,20 @@ type PcrConfig struct { PCRValues PcrValues `json:"pcr_values"` } +// Config represents attestation configuration. type Config struct { *check.Config *PcrConfig + *EATValidation +} + +// EATValidation contains EAT token validation settings. +type EATValidation struct { + RequireEATFormat bool `json:"require_eat_format"` + AllowedFormats []string `json:"allowed_formats"` + MaxTokenAgeSeconds int `json:"max_token_age_seconds"` + RequireClaims []string `json:"require_claims"` + VerifySignature bool `json:"verify_signature"` } type ccCheck struct { @@ -61,6 +72,7 @@ type Provider interface { type Verifier interface { VerifyAttestation(report []byte, teeNonce []byte, vTpmNonce []byte) error + VerifyEAT(eatToken []byte, teeNonce []byte, vTpmNonce []byte) error VerifTeeAttestation(report []byte, teeNonce []byte) error VerifVTpmAttestation(report []byte, vTpmNonce []byte) error JSONToPolicy(path string) error diff --git a/pkg/attestation/azure/snp.go b/pkg/attestation/azure/snp.go index 8494303d..d2ab3dad 100644 --- a/pkg/attestation/azure/snp.go +++ b/pkg/attestation/azure/snp.go @@ -20,6 +20,7 @@ import ( "github.com/google/go-sev-guest/tools/lib/report" "github.com/google/go-tpm-tools/proto/attest" "github.com/ultravioletrs/cocos/pkg/attestation" + "github.com/ultravioletrs/cocos/pkg/attestation/eat" "github.com/ultravioletrs/cocos/pkg/attestation/quoteprovider" "github.com/ultravioletrs/cocos/pkg/attestation/vtpm" "google.golang.org/protobuf/proto" @@ -154,6 +155,18 @@ func (a verifier) VerifyAttestation(report []byte, teeNonce []byte, vTpmNonce [] return nil } +// VerifyEAT verifies an EAT token and extracts the binary report for verification. +func (v verifier) VerifyEAT(eatToken []byte, teeNonce []byte, vTpmNonce []byte) error { + // Decode EAT token + claims, err := eat.Decode(eatToken, nil) + if err != nil { + return fmt.Errorf("failed to decode EAT token: %w", err) + } + + // Verify the embedded binary report + return v.VerifyAttestation(claims.RawReport, teeNonce, vTpmNonce) +} + func (a verifier) JSONToPolicy(path string) error { return vtpm.ReadPolicy(path, a.Policy) } diff --git a/pkg/attestation/azure/snp_coverage_test.go b/pkg/attestation/azure/snp_coverage_test.go new file mode 100644 index 00000000..aaf05e91 --- /dev/null +++ b/pkg/attestation/azure/snp_coverage_test.go @@ -0,0 +1,102 @@ +// Copyright (c) Ultraviolet +// SPDX-License-Identifier: Apache-2.0 + +package azure + +import ( + "crypto/rand" + "crypto/rsa" + "crypto/x509" + "crypto/x509/pkix" + "encoding/json" + "math/big" + "net/http" + "net/http/httptest" + "testing" + "time" + + jose "github.com/go-jose/go-jose/v4" + "github.com/golang-jwt/jwt/v5" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestGenerateAttestationPolicy_Success(t *testing.T) { + key, err := rsa.GenerateKey(rand.Reader, 2048) + require.NoError(t, err) + + template := x509.Certificate{ + SerialNumber: big.NewInt(1), + Subject: pkix.Name{CommonName: "test"}, + NotBefore: time.Now(), + NotAfter: time.Now().Add(time.Hour), + KeyUsage: x509.KeyUsageDigitalSignature, + } + certDER, err := x509.CreateCertificate(rand.Reader, &template, &template, &key.PublicKey, key) + require.NoError(t, err) + cert, err := x509.ParseCertificate(certDER) + require.NoError(t, err) + + jwk := jose.JSONWebKey{ + Key: &key.PublicKey, + KeyID: testKID, + Algorithm: "RS256", + Use: "sig", + Certificates: []*x509.Certificate{cert}, + } + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + jwks := jose.JSONWebKeySet{ + Keys: []jose.JSONWebKey{jwk}, + } + _ = json.NewEncoder(w).Encode(jwks) + })) + defer server.Close() + + originalMaaURL := MaaURL + MaaURL = server.URL + defer func() { MaaURL = originalMaaURL }() + + token := createTestToken(t, key, server.URL) + + policy, err := GenerateAttestationPolicy(token, "Milan", 0) + require.NoError(t, err) + require.NotNil(t, policy) + assert.Equal(t, "SEV_PRODUCT_MILAN", policy.Config.Policy.Product.Name.String()) +} + +func createTestToken(t *testing.T, key *rsa.PrivateKey, jku string) string { + claims := jwt.MapClaims{ + "iss": "https://test-issuer.com", + "aud": "test-audience", + "exp": time.Now().Add(time.Hour).Unix(), + "iat": time.Now().Unix(), + "x-ms-isolation-tee": map[string]any{ + "x-ms-sevsnpvm-familyId": "0102030405060708090a0b0c0d0e0f10", + "x-ms-sevsnpvm-imageId": "0102030405060708090a0b0c0d0e0f10", + "x-ms-sevsnpvm-launchmeasurement": "0102030405060708090a0b0c0d0e0f100102030405060708090a0b0c0d0e0f100102030405060708090a0b0c0d0e0f10", + "x-ms-sevsnpvm-bootloader-svn": float64(1), + "x-ms-sevsnpvm-tee-svn": float64(2), + "x-ms-sevsnpvm-snpfw-svn": float64(3), + "x-ms-sevsnpvm-microcode-svn": float64(4), + "x-ms-sevsnpvm-guestsvn": float64(5), + "x-ms-sevsnpvm-idkeydigest": "0102030405060708090a0b0c0d0e0f100102030405060708090a0b0c0d0e0f100102030405060708090a0b0c0d0e0f10", + "x-ms-sevsnpvm-reportid": "0102030405060708090a0b0c0d0e0f100102030405060708090a0b0c0d0e0f10", + }, + } + + token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims) + token.Header["jku"] = jku + token.Header["kid"] = testKID + + signedToken, err := token.SignedString(key) + require.NoError(t, err) + return signedToken +} + +func TestGenerateAttestationPolicy_InvalidToken(t *testing.T) { + // Test with invalid token string + _, err := GenerateAttestationPolicy("invalid-token", "Milan", 0) + assert.Error(t, err) + assert.Contains(t, err.Error(), "failed to validate token") +} diff --git a/pkg/attestation/azure/snp_policy_test.go b/pkg/attestation/azure/snp_policy_test.go new file mode 100644 index 00000000..255445b5 --- /dev/null +++ b/pkg/attestation/azure/snp_policy_test.go @@ -0,0 +1,262 @@ +// Copyright (c) Ultraviolet +// SPDX-License-Identifier: Apache-2.0 + +package azure + +import ( + "bytes" + "crypto/rand" + "crypto/rsa" + "crypto/x509" + "crypto/x509/pkix" + "encoding/json" + "math/big" + "net/http" + "net/http/httptest" + "testing" + "time" + + jose "github.com/go-jose/go-jose/v4" + "github.com/golang-jwt/jwt/v5" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestGenerateAttestationPolicy(t *testing.T) { + privateKey, err := rsa.GenerateKey(rand.Reader, 2048) + require.NoError(t, err) + + template := &x509.Certificate{ + SerialNumber: big.NewInt(1), + Subject: pkix.Name{ + Organization: []string{"Test Org"}, + }, + NotBefore: time.Now().Add(-1 * time.Hour), + NotAfter: time.Now().Add(1 * time.Hour), + KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, + BasicConstraintsValid: true, + } + certDER, err := x509.CreateCertificate(rand.Reader, template, template, &privateKey.PublicKey, privateKey) + require.NoError(t, err) + cert, err := x509.ParseCertificate(certDER) + require.NoError(t, err) + + tests := []struct { + name string + token string + product string + policy uint64 + setupServer func(t *testing.T, key *rsa.PrivateKey, cert *x509.Certificate) *httptest.Server + wantErr bool + errorMessage string + setupTokenJKU bool + }{ + { + name: "valid token and claims", + product: "Milan-B0", + policy: 0, + setupServer: func(t *testing.T, key *rsa.PrivateKey, cert *x509.Certificate) *httptest.Server { + return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case openIDConfigPath: + config := map[string]any{ + "jwks_uri": "http://" + r.Host + certsPath, + } + w.Header().Set("Content-Type", "application/json") + if err := json.NewEncoder(w).Encode(config); err != nil { + t.Errorf("failed to encode config: %v", err) + } + case certsPath: + jwks := generateJWKS(&key.PublicKey, cert) + w.Header().Set("Content-Type", "application/json") + if err := json.NewEncoder(w).Encode(jwks); err != nil { + t.Errorf("failed to encode jwks: %v", err) + } + default: + w.WriteHeader(http.StatusNotFound) + } + })) + }, + setupTokenJKU: true, + wantErr: false, + }, + { + name: "invalid token format", + token: "invalid-token", + product: "Milan-B0", + policy: 0, + setupServer: nil, + wantErr: true, + errorMessage: "failed to parse token", + setupTokenJKU: false, + }, + { + name: "missing familyId", + product: "Milan-B0", + policy: 0, + setupServer: func(t *testing.T, key *rsa.PrivateKey, cert *x509.Certificate) *httptest.Server { + return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case openIDConfigPath: + config := map[string]any{ + "jwks_uri": "http://" + r.Host + certsPath, + } + w.Header().Set("Content-Type", "application/json") + if err := json.NewEncoder(w).Encode(config); err != nil { + t.Errorf("failed to encode config: %v", err) + } + case certsPath: + jwks := generateJWKS(&key.PublicKey, cert) + w.Header().Set("Content-Type", "application/json") + if err := json.NewEncoder(w).Encode(jwks); err != nil { + t.Errorf("failed to encode jwks: %v", err) + } + } + })) + }, + setupTokenJKU: true, + wantErr: true, + errorMessage: "failed to get familyId from claims", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var tokenString string + var server *httptest.Server + + if tt.setupServer != nil { + server = tt.setupServer(t, privateKey, cert) + defer server.Close() + + originalURL := MaaURL + MaaURL = "" // Clear it so it uses JKU + defer func() { MaaURL = originalURL }() + } + + if tt.token != "" { + tokenString = tt.token + } else { + // Generate token + claims := createValidClaims() + if tt.name == "missing familyId" { + if tee, ok := claims["x-ms-isolation-tee"].(map[string]any); ok { + delete(tee, "x-ms-sevsnpvm-familyId") + } + } + + jku := "" + if tt.setupTokenJKU && server != nil { + jku = server.URL + } + + var err error + tokenString, err = signToken(claims, privateKey, jku) + require.NoError(t, err) + } + + config, err := GenerateAttestationPolicy(tokenString, tt.product, tt.policy) + + if tt.wantErr { + assert.Error(t, err) + if tt.errorMessage != "" { + assert.Contains(t, err.Error(), tt.errorMessage) + } + assert.Nil(t, config) + } else { + assert.NoError(t, err) + assert.NotNil(t, config) + } + }) + } +} + +func TestVerifier_VerifyEAT(t *testing.T) { + tests := []struct { + name string + eatToken []byte + teeNonce []byte + vTpmNonce []byte + setupToken func() ([]byte, error) + wantErr bool + errorMessage string + }{ + { + name: "invalid cbor", + eatToken: []byte("invalid-cbor"), + teeNonce: testNonce, + vTpmNonce: testNonce, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + v := NewVerifier(&bytes.Buffer{}) + + token := tt.eatToken + if tt.setupToken != nil { + var err error + token, err = tt.setupToken() + require.NoError(t, err) + } + + err := v.VerifyEAT(token, tt.teeNonce, tt.vTpmNonce) + + if tt.wantErr { + assert.Error(t, err) + if tt.errorMessage != "" { + } + } else { + assert.NoError(t, err) + } + }) + } +} + +// Helper functions + +func createValidClaims() jwt.MapClaims { + return jwt.MapClaims{ + "iss": "https://test-issuer.com", + "aud": "test-audience", + "exp": time.Now().Add(time.Hour).Unix(), + "iat": time.Now().Unix(), + "nbf": time.Now().Add(-1 * time.Hour).Unix(), + "x-ms-isolation-tee": map[string]any{ + "x-ms-sevsnpvm-familyId": "1234567890abcdef", + "x-ms-sevsnpvm-imageId": "fedcba0987654321", + "x-ms-sevsnpvm-launchmeasurement": "abcdef1234567890", + "x-ms-sevsnpvm-bootloader-svn": float64(1), + "x-ms-sevsnpvm-tee-svn": float64(2), + "x-ms-sevsnpvm-snpfw-svn": float64(3), + "x-ms-sevsnpvm-microcode-svn": float64(4), + "x-ms-sevsnpvm-guestsvn": float64(5), + "x-ms-sevsnpvm-idkeydigest": "1234567890abcdef", + "x-ms-sevsnpvm-reportid": "fedcba0987654321", + }, + } +} + +func signToken(claims jwt.MapClaims, key *rsa.PrivateKey, jku string) (string, error) { + token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims) + token.Header["kid"] = testKID + if jku != "" { + token.Header["jku"] = jku + } + return token.SignedString(key) +} + +func generateJWKS(pubKey *rsa.PublicKey, cert *x509.Certificate) *jose.JSONWebKeySet { + key := jose.JSONWebKey{ + Key: pubKey, + KeyID: testKID, + Algorithm: "RS256", + Use: "sig", + Certificates: []*x509.Certificate{cert}, + } + return &jose.JSONWebKeySet{ + Keys: []jose.JSONWebKey{key}, + } +} diff --git a/pkg/attestation/azure/snp_test.go b/pkg/attestation/azure/snp_test.go index f0c5587f..28a4e558 100644 --- a/pkg/attestation/azure/snp_test.go +++ b/pkg/attestation/azure/snp_test.go @@ -22,8 +22,11 @@ import ( ) var ( - testNonce = []byte("test-nonce-12345678901234567890123456789012") - testReport = []byte("test-report-data") + testNonce = []byte("test-nonce-12345678901234567890123456789012") + testReport = []byte("test-report-data") + testKID = "test-kid" + openIDConfigPath = "/.well-known/openid_configuration" + certsPath = "/certs" ) func TestNewProvider(t *testing.T) { @@ -459,19 +462,19 @@ func TestIntegration_FullAttestationFlow(t *testing.T) { if err := json.NewEncoder(w).Encode(response); err != nil { t.Fatalf("Failed to encode response: %v", err) } - case "/.well-known/openid_configuration": + case openIDConfigPath: config := map[string]any{ - "jwks_uri": "maaServer.URL" + "/certs", + "jwks_uri": "maaServer.URL" + certsPath, } w.Header().Set("Content-Type", "application/json") if err := json.NewEncoder(w).Encode(config); err != nil { t.Fatalf("Failed to encode OpenID configuration: %v", err) } - case "/certs": + case certsPath: jwks := map[string]any{ "keys": []map[string]any{ { - "kid": "test-kid", + "kid": testKID, "kty": "RSA", "use": "sig", "n": "test-n-value", diff --git a/pkg/attestation/eat/cbor_encoder.go b/pkg/attestation/eat/cbor_encoder.go new file mode 100644 index 00000000..81c76cc2 --- /dev/null +++ b/pkg/attestation/eat/cbor_encoder.go @@ -0,0 +1,74 @@ +// Copyright (c) Ultraviolet +// SPDX-License-Identifier: Apache-2.0 + +package eat + +import ( + "crypto/ecdsa" + "crypto/rand" + "fmt" + "time" + + "github.com/fxamacker/cbor/v2" + "github.com/veraison/go-cose" +) + +// CBOREncoder encodes EAT claims to CBOR format (CWT - CBOR Web Token). +type CBOREncoder struct { + signingKey *ecdsa.PrivateKey + issuer string +} + +// NewCBOREncoder creates a new CBOR encoder. +func NewCBOREncoder(signingKey *ecdsa.PrivateKey, issuer string) *CBOREncoder { + return &CBOREncoder{ + signingKey: signingKey, + issuer: issuer, + } +} + +// Encode encodes EAT claims to CBOR bytes with COSE_Sign1 signature. +func (e *CBOREncoder) Encode(claims *EATClaims) ([]byte, error) { + // Set standard CWT claims + now := time.Now() + claims.Issuer = e.issuer + claims.IssuedAt = now.Unix() + claims.ExpiresAt = now.Add(5 * time.Minute).Unix() // 5 minute validity + + // Encode claims to CBOR (this will be the payload) + payload, err := cbor.Marshal(claims) + if err != nil { + return nil, fmt.Errorf("failed to encode CBOR payload: %w", err) + } + + // Create COSE Sign1 message + msg := cose.NewSign1Message() + msg.Payload = payload + msg.Headers.Protected.SetAlgorithm(cose.AlgorithmES256) + msg.Headers.Unprotected[cose.HeaderLabelKeyID] = []byte(e.issuer) + + // Create signer from ECDSA private key + signer, err := cose.NewSigner(cose.AlgorithmES256, e.signingKey) + if err != nil { + return nil, fmt.Errorf("failed to create COSE signer: %w", err) + } + + // Sign the message + if err := msg.Sign(rand.Reader, nil, signer); err != nil { + return nil, fmt.Errorf("failed to sign COSE message: %w", err) + } + + // Encode the signed message to CBOR + signed, err := msg.MarshalCBOR() + if err != nil { + return nil, fmt.Errorf("failed to marshal COSE_Sign1: %w", err) + } + + return signed, nil +} + +// EncodeToCBOR is a convenience function to encode EAT claims to CBOR. +func EncodeToCBOR(claims *EATClaims, signingKey *ecdsa.PrivateKey, issuer string) ([]byte, error) { + encoder := NewCBOREncoder(signingKey, issuer) + return encoder.Encode(claims) +} diff --git a/pkg/attestation/eat/cbor_encoder_test.go b/pkg/attestation/eat/cbor_encoder_test.go new file mode 100644 index 00000000..e0c59cd9 --- /dev/null +++ b/pkg/attestation/eat/cbor_encoder_test.go @@ -0,0 +1,79 @@ +// Copyright (c) Ultraviolet +// SPDX-License-Identifier: Apache-2.0 + +package eat + +import ( + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/veraison/go-cose" +) + +func TestCBOREncoder_Encode(t *testing.T) { + key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + require.NoError(t, err) + + type fields struct { + signingKey *ecdsa.PrivateKey + issuer string + } + type args struct { + claims *EATClaims + } + tests := []struct { + name string + fields fields + args args + wantErr bool + }{ + { + name: "Valid encoding", + fields: fields{ + signingKey: key, + issuer: "test-issuer", + }, + args: args{ + claims: &EATClaims{ + Nonce: []byte("test-nonce"), + }, + }, + wantErr: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + e := NewCBOREncoder(tt.fields.signingKey, tt.fields.issuer) + got, err := e.Encode(tt.args.claims) + if tt.wantErr { + assert.Error(t, err) + } else { + assert.NoError(t, err) + assert.NotEmpty(t, got) + + var msg cose.Sign1Message + err = msg.UnmarshalCBOR(got) + assert.NoError(t, err) + + verifier, err := cose.NewVerifier(cose.AlgorithmES256, &key.PublicKey) + assert.NoError(t, err) + err = msg.Verify(nil, verifier) + assert.NoError(t, err) + } + }) + } +} + +func TestEncodeToCBOR(t *testing.T) { + key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + require.NoError(t, err) + + claims := &EATClaims{Nonce: []byte("nonce")} + token, err := EncodeToCBOR(claims, key, "issuer") + assert.NoError(t, err) + assert.NotEmpty(t, token) +} diff --git a/pkg/attestation/eat/decoder.go b/pkg/attestation/eat/decoder.go new file mode 100644 index 00000000..517fb425 --- /dev/null +++ b/pkg/attestation/eat/decoder.go @@ -0,0 +1,143 @@ +// Copyright (c) Ultraviolet +// SPDX-License-Identifier: Apache-2.0 + +package eat + +import ( + "bytes" + "crypto/ecdsa" + "encoding/json" + "fmt" + + "github.com/fxamacker/cbor/v2" + "github.com/golang-jwt/jwt/v5" + "github.com/veraison/go-cose" +) + +// Decoder decodes EAT tokens (auto-detects JWT vs CBOR). +type Decoder struct { + verifyKey *ecdsa.PublicKey +} + +// NewDecoder creates a new EAT decoder. +func NewDecoder(verifyKey *ecdsa.PublicKey) *Decoder { + return &Decoder{ + verifyKey: verifyKey, + } +} + +// Decode decodes an EAT token (auto-detects format). +func (d *Decoder) Decode(token []byte) (*EATClaims, error) { + // Try to detect format + if isJWT(token) { + return d.decodeJWT(string(token)) + } + return d.decodeCBOR(token) +} + +// isJWT checks if the token is JWT format. +func isJWT(token []byte) bool { + // JWT tokens are base64-encoded strings with dots + if len(token) < 10 { + return false + } + return bytes.Contains(token, []byte(".")) && !bytes.Contains(token[:10], []byte{0x00}) +} + +// decodeJWT decodes a JWT token. +func (d *Decoder) decodeJWT(tokenString string) (*EATClaims, error) { + claims := &jwtClaims{&EATClaims{}} + + var token *jwt.Token + var err error + + if d.verifyKey == nil { + token, _, err = new(jwt.Parser).ParseUnverified(tokenString, claims) + } else { + // Parse and verify JWT + token, err = jwt.ParseWithClaims(tokenString, claims, func(token *jwt.Token) (interface{}, error) { + // Verify signing method + if _, ok := token.Method.(*jwt.SigningMethodECDSA); !ok { + return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"]) + } + return d.verifyKey, nil + }) + } + if err != nil { + return nil, fmt.Errorf("failed to parse JWT: %w", err) + } + + if !token.Valid { + return nil, fmt.Errorf("invalid JWT token") + } + + return claims.EATClaims, nil +} + +// decodeCBOR decodes a CBOR token with COSE signature verification. +func (d *Decoder) decodeCBOR(token []byte) (*EATClaims, error) { + // Try to unmarshal as COSE_Sign1 message + var msg cose.Sign1Message + if err := msg.UnmarshalCBOR(token); err != nil { + // If it's not a COSE message, try to decode as plain CBOR (backward compatibility) + claims := &EATClaims{} + if err := cbor.Unmarshal(token, claims); err != nil { + return nil, fmt.Errorf("failed to decode CBOR: %w", err) + } + return claims, nil + } + + // Verify the signature if we have a verification key + if d.verifyKey != nil { + verifier, err := cose.NewVerifier(cose.AlgorithmES256, d.verifyKey) + if err != nil { + return nil, fmt.Errorf("failed to create COSE verifier: %w", err) + } + + if err := msg.Verify(nil, verifier); err != nil { + return nil, fmt.Errorf("COSE signature verification failed: %w", err) + } + } + + // Decode the payload + claims := &EATClaims{} + if err := cbor.Unmarshal(msg.Payload, claims); err != nil { + return nil, fmt.Errorf("failed to decode CBOR payload: %w", err) + } + + return claims, nil +} + +// DecodeJWT is a convenience function to decode JWT EAT token. +func DecodeJWT(tokenString string, verifyKey *ecdsa.PublicKey) (*EATClaims, error) { + decoder := NewDecoder(verifyKey) + return decoder.decodeJWT(tokenString) +} + +// DecodeCBOR is a convenience function to decode CBOR EAT token. +func DecodeCBOR(token []byte, verifyKey *ecdsa.PublicKey) (*EATClaims, error) { + decoder := NewDecoder(verifyKey) + return decoder.decodeCBOR(token) +} + +// Decode is a convenience function that auto-detects format. +func Decode(token []byte, verifyKey *ecdsa.PublicKey) (*EATClaims, error) { + decoder := NewDecoder(verifyKey) + return decoder.Decode(token) +} + +// MarshalJSON implements json.Marshaler for pretty printing. +func (c *EATClaims) MarshalJSON() ([]byte, error) { + type Alias EATClaims + return json.Marshal(&struct { + *Alias + NonceHex string `json:"eat_nonce_hex,omitempty"` + UEIDHex string `json:"ueid_hex,omitempty"` + MeasurementsHex string `json:"measurements_hex,omitempty"` + }{ + Alias: (*Alias)(c), + NonceHex: fmt.Sprintf("%x", c.Nonce), + UEIDHex: fmt.Sprintf("%x", c.UEID), + MeasurementsHex: fmt.Sprintf("%x", c.Measurements), + }) +} diff --git a/pkg/attestation/eat/decoder_test.go b/pkg/attestation/eat/decoder_test.go new file mode 100644 index 00000000..0c3dd653 --- /dev/null +++ b/pkg/attestation/eat/decoder_test.go @@ -0,0 +1,218 @@ +// Copyright (c) Ultraviolet +// SPDX-License-Identifier: Apache-2.0 + +package eat + +import ( + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "testing" + "time" + + "github.com/fxamacker/cbor/v2" + "github.com/golang-jwt/jwt/v5" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/veraison/go-cose" +) + +func TestDecodeJWT(t *testing.T) { + privateKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + require.NoError(t, err) + + claims := &EATClaims{ + Nonce: []byte("test-nonce"), + } + + now := time.Now() + jwtClaims := &jwtClaims{claims} + claims.Issuer = "test-issuer" + claims.IssuedAt = now.Unix() + claims.ExpiresAt = now.Add(time.Hour).Unix() + + token := jwt.NewWithClaims(jwt.SigningMethodES256, jwtClaims) + signedToken, err := token.SignedString(privateKey) + require.NoError(t, err) + + type args struct { + token string + verifyKey *ecdsa.PublicKey + } + tests := []struct { + name string + args args + wantErr bool + expectedErr string + }{ + { + name: "Valid token", + args: args{ + token: signedToken, + verifyKey: &privateKey.PublicKey, + }, + wantErr: false, + }, + { + name: "Invalid signature", + args: args{ + token: signedToken, + verifyKey: func() *ecdsa.PublicKey { + key, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + return &key.PublicKey + }(), + }, + wantErr: true, + expectedErr: "verification error", + }, + { + name: "Malformed token", + args: args{ + token: "invalid.token.structure", + verifyKey: &privateKey.PublicKey, + }, + wantErr: true, + expectedErr: "failed to parse JWT", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := DecodeJWT(tt.args.token, tt.args.verifyKey) + if tt.wantErr { + assert.Error(t, err) + if tt.expectedErr != "" { + assert.ErrorContains(t, err, tt.expectedErr) + } + } else { + assert.NoError(t, err) + assert.NotNil(t, got) + assert.Equal(t, claims.Nonce, got.Nonce) + } + }) + } +} + +func TestDecodeCBOR(t *testing.T) { + claims := &EATClaims{ + Nonce: []byte("test-nonce"), + } + + payload, err := cbor.Marshal(claims) + require.NoError(t, err) + + privateKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + require.NoError(t, err) + signer, err := cose.NewSigner(cose.AlgorithmES256, privateKey) + require.NoError(t, err) + + msg := cose.NewSign1Message() + msg.Payload = payload + err = msg.Sign(rand.Reader, []byte{}, signer) + require.NoError(t, err) + + cborToken, err := msg.MarshalCBOR() + require.NoError(t, err) + + type args struct { + token []byte + verifyKey *ecdsa.PublicKey + } + tests := []struct { + name string + args args + wantErr bool + expectedErr string + }{ + { + name: "Valid COSE token", + args: args{ + token: cborToken, + verifyKey: &privateKey.PublicKey, + }, + wantErr: false, + }, + { + name: "Valid Plain CBOR token (no signature)", + args: args{ + token: payload, + verifyKey: nil, + }, + wantErr: false, + }, + { + name: "Invalid COSE signature", + args: args{ + token: cborToken, + verifyKey: func() *ecdsa.PublicKey { + key, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + return &key.PublicKey + }(), + }, + wantErr: true, + expectedErr: "verification failed", + }, + { + name: "Malformed CBOR", + args: args{ + token: []byte("invalid cbor"), + verifyKey: nil, + }, + wantErr: true, + expectedErr: "failed to decode CBOR", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := DecodeCBOR(tt.args.token, tt.args.verifyKey) + if tt.wantErr { + assert.Error(t, err) + if tt.expectedErr != "" { + assert.ErrorContains(t, err, tt.expectedErr) + } + } else { + assert.NoError(t, err) + assert.NotNil(t, got) + assert.Equal(t, claims.Nonce, got.Nonce) + } + }) + } +} + +func TestDecodeAutoDetect(t *testing.T) { + key, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + claims := &EATClaims{Nonce: []byte("jwt")} + token := jwt.NewWithClaims(jwt.SigningMethodES256, &jwtClaims{claims}) + jwtString, _ := token.SignedString(key) + + got, err := Decode([]byte(jwtString), &key.PublicKey) + assert.NoError(t, err) + assert.Equal(t, []byte("jwt"), got.Nonce) + + claimsCBOR := &EATClaims{Nonce: []byte("cbor")} + cborBytes, _ := cbor.Marshal(claimsCBOR) + gotCBOR, err := Decode(cborBytes, nil) + assert.NoError(t, err) + assert.Equal(t, []byte("cbor"), gotCBOR.Nonce) +} + +func TestIsJWT(t *testing.T) { + tests := []struct { + name string + token []byte + want bool + }{ + {"Empty", []byte{}, false}, + {"JWT like", []byte("header.payload.signature"), true}, + {"CBOR (binary)", []byte{0x00, 0x01}, false}, + {"Text but not JWT", []byte("not a jwt"), false}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := isJWT(tt.token); got != tt.want { + t.Errorf("isJWT() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/pkg/attestation/eat/eat.go b/pkg/attestation/eat/eat.go new file mode 100644 index 00000000..a312ad86 --- /dev/null +++ b/pkg/attestation/eat/eat.go @@ -0,0 +1,186 @@ +// Copyright (c) Ultraviolet +// SPDX-License-Identifier: Apache-2.0 + +package eat + +import ( + "errors" + + "github.com/ultravioletrs/cocos/pkg/attestation" +) + +// EATClaims represents the Entity Attestation Token claims following RFC 9711. +type EATClaims struct { + // Standard JWT/CWT claims + Issuer string `json:"iss,omitempty" cbor:"1,keyasint,omitempty"` + Subject string `json:"sub,omitempty" cbor:"2,keyasint,omitempty"` + IssuedAt int64 `json:"iat,omitempty" cbor:"6,keyasint,omitempty"` + ExpiresAt int64 `json:"exp,omitempty" cbor:"4,keyasint,omitempty"` + + // Core EAT claims (RFC 9711) + Nonce []byte `json:"eat_nonce" cbor:"10,keyasint"` // Freshness/replay protection + UEID []byte `json:"ueid" cbor:"256,keyasint"` // Universal Entity ID + OEMID int `json:"oemid,omitempty" cbor:"258,keyasint,omitempty"` // Hardware OEM ID + HWModel []byte `json:"hwmodel,omitempty" cbor:"259,keyasint,omitempty"` // Hardware model + HWVersion string `json:"hwversion,omitempty" cbor:"260,keyasint,omitempty"` // Hardware version + SWName string `json:"swname,omitempty" cbor:"270,keyasint,omitempty"` // Software name + SWVersion string `json:"swversion,omitempty" cbor:"271,keyasint,omitempty"` // Software version + DebugStatus int `json:"dbgstat" cbor:"263,keyasint"` // Debug status + IntUse int `json:"intuse,omitempty" cbor:"262,keyasint,omitempty"` // Intended use + Measurements []byte `json:"measurements" cbor:"265,keyasint"` // Software measurements + + // Platform type indicator + PlatformType string `json:"platform_type"` + + // Submodules for vTPM and other components + Submods map[string]interface{} `json:"submods,omitempty" cbor:"266,keyasint,omitempty"` + + // Platform-specific extensions (custom claims) + SNPExtensions *SNPExtensions `json:"x-cocos-sevsnp,omitempty"` + TDXExtensions *TDXExtensions `json:"x-cocos-tdx,omitempty"` + VTPMExtensions *VTPMExtensions `json:"x-cocos-vtpm,omitempty"` + + // Original binary report (for verification) + RawReport []byte `json:"raw_report,omitempty"` +} + +// SNPExtensions contains AMD SEV-SNP specific claims. +type SNPExtensions struct { + Measurement []byte `json:"measurement"` // SNP MEASUREMENT field + TCB string `json:"tcb"` // TCB version info + PlatformInfo uint64 `json:"platform_info"` // PLATFORM_INFO + Policy uint64 `json:"policy"` // POLICY field + FamilyID []byte `json:"family_id,omitempty"` // Family ID + ImageID []byte `json:"image_id,omitempty"` // Image ID + VMPL int `json:"vmpl,omitempty"` // VM Privilege Level + SignatureAlgo int `json:"signature_algo,omitempty"` // Signature algorithm + CurrentTCB uint64 `json:"current_tcb,omitempty"` // Current TCB + ReportedTCB uint64 `json:"reported_tcb,omitempty"` // Reported TCB + ChipID []byte `json:"chip_id,omitempty"` // Chip ID + CommittedTCB uint64 `json:"committed_tcb,omitempty"` // Committed TCB + LaunchTCB uint64 `json:"launch_tcb,omitempty"` // Launch TCB + Signature []byte `json:"signature,omitempty"` // Signature +} + +// TDXExtensions contains Intel TDX specific claims. +type TDXExtensions struct { + MRTD []byte `json:"tdx_mrtd"` // MRTD measurement + RTMR0 []byte `json:"tdx_rtmr0"` // Runtime measurement register 0 + RTMR1 []byte `json:"tdx_rtmr1"` // Runtime measurement register 1 + RTMR2 []byte `json:"tdx_rtmr2"` // Runtime measurement register 2 + RTMR3 []byte `json:"tdx_rtmr3"` // Runtime measurement register 3 + XFAM uint64 `json:"tdx_xfam"` // Extended features available mask + TDAttributes uint64 `json:"tdx_td_attributes"` // TD attributes + MRConfigID []byte `json:"tdx_mrconfigid,omitempty"` // MR Config ID + MROwner []byte `json:"tdx_mrowner,omitempty"` // MR Owner + MROwnerConfig []byte `json:"tdx_mrownerconfig,omitempty"` // MR Owner Config + MRSEAM []byte `json:"tdx_mrseam,omitempty"` // MR SEAM + TDXModule *TDXModuleInfo `json:"tdx_module,omitempty"` // TDX module info + Signature []byte `json:"tdx_signature,omitempty"` // Quote Signature +} + +// TDXModuleInfo contains TDX module version information. +type TDXModuleInfo struct { + Major uint8 `json:"major"` + Minor uint8 `json:"minor"` + BuildNum uint16 `json:"build_num"` + BuildDate uint32 `json:"build_date"` +} + +// VTPMExtensions contains vTPM specific claims. +type VTPMExtensions struct { + PCRs map[string]string `json:"pcrs"` // PCR values (SHA256/SHA384) + EventLog []byte `json:"event_log,omitempty"` // Event log + Quote []byte `json:"quote,omitempty"` // TPM quote +} + +// DebugStatus constants (RFC 9711 Section 4.2.6). +const ( + DebugEnabled = 0 // Debug is enabled + DebugDisabled = 1 // Debug is disabled + DebugDisabledSinceBoot = 2 // Debug is disabled since boot + DebugPermanentDisable = 3 // Debug is permanently disabled + DebugFullPermanentDisable = 4 // Debug is fully and permanently disabled +) + +// IntUse constants (RFC 9711 Section 4.2.5). +const ( + IntUseGenericFresh = 1 // General purpose, fresh token +) + +// MinNonceLength defines the minimum length for EAT nonce in bytes. +const MinNonceLength = 8 + +// NewEATClaims creates EAT claims from binary attestation report. +func NewEATClaims(report []byte, nonce []byte, platformType attestation.PlatformType) (*EATClaims, error) { + if len(nonce) < MinNonceLength { + return nil, errors.New("eat_nonce must be at least 8 bytes long") + } + claims := &EATClaims{ + Nonce: nonce, + PlatformType: getPlatformTypeName(platformType), + RawReport: report, + DebugStatus: DebugDisabledSinceBoot, // Default to disabled since boot + IntUse: IntUseGenericFresh, // Default to general purpose, fresh token + } + + // Extract platform-specific claims + if err := extractPlatformClaims(claims, report, platformType); err != nil { + return nil, err + } + + return claims, nil +} + +// extractPlatformClaims extracts platform-specific claims from binary report. +func extractPlatformClaims(claims *EATClaims, report []byte, platformType attestation.PlatformType) error { + switch platformType { + case attestation.SNP, attestation.SNPvTPM: + return extractSNPClaims(claims, report) + case attestation.TDX: + return extractTDXClaims(claims, report) + case attestation.VTPM: + return extractVTPMClaims(claims, report) + case attestation.Azure: + return extractAzureClaims(claims, report) + default: + // For unknown platforms, just store the raw report + return nil + } +} + +// getPlatformTypeName converts platform type to string name. +func getPlatformTypeName(platformType attestation.PlatformType) string { + switch platformType { + case attestation.SNP: + return "SNP" + case attestation.TDX: + return "TDX" + case attestation.VTPM: + return "vTPM" + case attestation.SNPvTPM: + return "SNP-vTPM" + case attestation.Azure: + return "Azure" + case attestation.NoCC: + return "NoCC" + default: + return "Unknown" + } +} + +// Sanitize enforces dependency rules for claims. +// HWModel requires OEMID. +// HWVersion requires HWModel. +func (c *EATClaims) Sanitize() { + if c.OEMID == 0 { + c.HWModel = nil + c.HWVersion = "" + } + if len(c.HWModel) == 0 { + c.HWVersion = "" + } + if c.SWName == "" { + c.SWVersion = "" + } +} diff --git a/pkg/attestation/eat/eat_test.go b/pkg/attestation/eat/eat_test.go new file mode 100644 index 00000000..f94cf0a4 --- /dev/null +++ b/pkg/attestation/eat/eat_test.go @@ -0,0 +1,141 @@ +// Copyright (c) Ultraviolet +// SPDX-License-Identifier: Apache-2.0 + +package eat + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/ultravioletrs/cocos/pkg/attestation" +) + +func TestNewEATClaims(t *testing.T) { + tests := []struct { + name string + nonce []byte + expectedErr string + }{ + { + name: "Valid nonce", + nonce: []byte("12345678"), + expectedErr: "", + }, + { + name: "Nonce too short", + nonce: []byte("1234567"), + expectedErr: "eat_nonce must be at least 8 bytes long", + }, + { + name: "Empty nonce", + nonce: []byte{}, + expectedErr: "eat_nonce must be at least 8 bytes long", + }, + { + name: "Nil nonce", + nonce: nil, + expectedErr: "eat_nonce must be at least 8 bytes long", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, err := NewEATClaims([]byte("dummy report"), tt.nonce, attestation.NoCC) + if tt.expectedErr != "" { + assert.EqualError(t, err, tt.expectedErr) + } else { + assert.NoError(t, err) + } + }) + } +} + +func TestSanitize(t *testing.T) { + tests := []struct { + name string + claims *EATClaims + expected *EATClaims + }{ + { + name: "All dependencies present", + claims: &EATClaims{ + OEMID: 123, + HWModel: []byte("ValidModel"), + HWVersion: "1.0", + }, + expected: &EATClaims{ + OEMID: 123, + HWModel: []byte("ValidModel"), + HWVersion: "1.0", + }, + }, + { + name: "Missing OEMID clears HWModel and HWVersion", + claims: &EATClaims{ + OEMID: 0, + HWModel: []byte("ValidModel"), + HWVersion: "1.0", + }, + expected: &EATClaims{ + OEMID: 0, + HWModel: nil, + HWVersion: "", + }, + }, + { + name: "Missing HWModel clears HWVersion", + claims: &EATClaims{ + OEMID: 123, + HWModel: nil, + HWVersion: "1.0", + }, + expected: &EATClaims{ + OEMID: 123, + HWModel: nil, + HWVersion: "", + }, + }, + { + name: "Missing HWModel (empty bytes) clears HWVersion", + claims: &EATClaims{ + OEMID: 123, + HWModel: []byte{}, + HWVersion: "1.0", + }, + expected: &EATClaims{ + OEMID: 123, + HWModel: []byte{}, // Should remain empty slice + HWVersion: "", + }, + }, + { + name: "Independent fields unaffected", + claims: &EATClaims{ + OEMID: 0, + DebugStatus: DebugEnabled, + }, + expected: &EATClaims{ + OEMID: 0, + DebugStatus: DebugEnabled, + }, + }, + { + name: "Missing SWName clears SWVersion", + claims: &EATClaims{ + SWName: "", + SWVersion: "1.0.0", + }, + expected: &EATClaims{ + SWName: "", + SWVersion: "", + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tt.claims.Sanitize() + assert.Equal(t, tt.expected, tt.claims) + }) + } +} diff --git a/pkg/attestation/eat/extractor.go b/pkg/attestation/eat/extractor.go new file mode 100644 index 00000000..a213cef1 --- /dev/null +++ b/pkg/attestation/eat/extractor.go @@ -0,0 +1,150 @@ +// Copyright (c) Ultraviolet +// SPDX-License-Identifier: Apache-2.0 + +package eat + +import ( + "encoding/binary" + "fmt" + + "github.com/google/go-sev-guest/abi" + tdxabi "github.com/google/go-tdx-guest/abi" + tdxpb "github.com/google/go-tdx-guest/proto/tdx" +) + +// OEMID constants (Private Enterprise Numbers). +const ( + OEMID_AMD = 3704 // https://www.iana.org/assignments/enterprise-numbers/?q=Advanced+Micro+Devices + OEMID_INTEL = 343 // https://www.iana.org/assignments/enterprise-numbers/?q=Intel+Corporation + OEMID_MICROSOFT = 311 // https://www.iana.org/assignments/enterprise-numbers/?q=Microsoft+Corporation +) + +// extractSNPClaims extracts AMD SEV-SNP specific claims from binary report. +func extractSNPClaims(claims *EATClaims, report []byte) error { + if len(report) < int(abi.ReportSize) { + return fmt.Errorf("SNP report too small: got %d bytes, want at least %d", len(report), abi.ReportSize) + } + + // Parse SNP report structure + snpReport, err := abi.ReportToProto(report[:abi.ReportSize]) + if err != nil { + return fmt.Errorf("failed to parse SNP report: %w", err) + } + + // Extract SNP-specific fields + claims.SNPExtensions = &SNPExtensions{ + Measurement: snpReport.Measurement, + Policy: snpReport.Policy, + FamilyID: snpReport.FamilyId, + ImageID: snpReport.ImageId, + VMPL: int(snpReport.Vmpl), + SignatureAlgo: int(snpReport.SignatureAlgo), + PlatformInfo: snpReport.PlatformInfo, + ChipID: snpReport.ChipId, + } + + // Set TCB version info + claims.SNPExtensions.CurrentTCB = snpReport.CurrentTcb + claims.SNPExtensions.ReportedTCB = snpReport.ReportedTcb + claims.SNPExtensions.CommittedTCB = snpReport.CommittedTcb + claims.SNPExtensions.LaunchTCB = snpReport.LaunchTcb + claims.SNPExtensions.TCB = fmt.Sprintf("current:%d,reported:%d", snpReport.CurrentTcb, snpReport.ReportedTcb) + + // Set core EAT claims from SNP report + claims.Measurements = snpReport.Measurement + claims.UEID = snpReport.ChipId // Use ChipID as UEID + claims.OEMID = OEMID_AMD // AMD's PEN (Private Enterprise Number) + claims.SNPExtensions.Signature = snpReport.Signature + + // Set hardware model (hash of product name) + claims.HWModel = []byte(fmt.Sprintf("SEV-SNP-%d", snpReport.Version)) + + return nil +} + +// extractTDXClaims extracts Intel TDX specific claims from binary report. +func extractTDXClaims(claims *EATClaims, report []byte) error { + // Parse TDX quote using go-tdx-guest ABI + decodedQuote, err := tdxabi.QuoteToProto(report) + if err != nil { + return fmt.Errorf("failed to parse TDX quote: %w", err) + } + + quoteV4, ok := decodedQuote.(*tdxpb.QuoteV4) + if !ok { + return fmt.Errorf("unsupported TDX quote format") + } + + tdReport := quoteV4.GetTdQuoteBody() + signedData := quoteV4.GetSignedData() + + rtmrs := tdReport.GetRtmrs() + var rtmr0, rtmr1, rtmr2, rtmr3 []byte + if len(rtmrs) > 0 { + rtmr0 = rtmrs[0] + } + if len(rtmrs) > 1 { + rtmr1 = rtmrs[1] + } + if len(rtmrs) > 2 { + rtmr2 = rtmrs[2] + } + if len(rtmrs) > 3 { + rtmr3 = rtmrs[3] + } + + claims.TDXExtensions = &TDXExtensions{ + MRTD: tdReport.GetMrTd(), + RTMR0: rtmr0, + RTMR1: rtmr1, + RTMR2: rtmr2, + RTMR3: rtmr3, + XFAM: binary.LittleEndian.Uint64(tdReport.GetXfam()), + TDAttributes: binary.LittleEndian.Uint64(tdReport.GetTdAttributes()), + MRConfigID: tdReport.GetMrConfigId(), + MROwner: tdReport.GetMrOwner(), + MROwnerConfig: tdReport.GetMrOwnerConfig(), + MRSEAM: tdReport.GetMrSeam(), + Signature: signedData.GetSignature(), + } + + // Set core EAT claims + claims.Measurements = tdReport.GetMrTd() + // Use first 32 bytes of MRTD as UEID, similar to other extractors + if len(claims.Measurements) >= 32 { + claims.UEID = claims.Measurements[:32] + } + claims.OEMID = OEMID_INTEL // Intel's PEN + + // Set hardware model + claims.HWModel = []byte("Intel-TDX") + + return nil +} + +// extractVTPMClaims extracts vTPM specific claims from binary report. +func extractVTPMClaims(claims *EATClaims, report []byte) error { + // vTPM report is typically a marshaled structure containing PCRs and quote + // For now, store the entire report as the quote + claims.VTPMExtensions = &VTPMExtensions{ + Quote: report, + PCRs: make(map[string]string), + } + + // Set core EAT claims + claims.Measurements = report[:32] // Use first 32 bytes as measurement + claims.UEID = report[:16] // Use first 16 bytes as UEID + + return nil +} + +// extractAzureClaims extracts Azure-specific claims from attestation token. +func extractAzureClaims(claims *EATClaims, report []byte) error { + // Azure provides JWT tokens, so the report is already in a structured format + // For now, just store it as raw report + claims.Measurements = report[:32] // Use first 32 bytes as measurement + claims.UEID = report[:16] // Use first 16 bytes as UEID + claims.OEMID = OEMID_MICROSOFT // Microsoft's PEN + + return nil +} diff --git a/pkg/attestation/eat/extractor_test.go b/pkg/attestation/eat/extractor_test.go new file mode 100644 index 00000000..f212f849 --- /dev/null +++ b/pkg/attestation/eat/extractor_test.go @@ -0,0 +1,147 @@ +// Copyright (c) Ultraviolet +// SPDX-License-Identifier: Apache-2.0 + +package eat + +import ( + "encoding/json" + "fmt" + "testing" + + "github.com/google/go-sev-guest/abi" + "github.com/stretchr/testify/assert" + "github.com/ultravioletrs/cocos/pkg/attestation" +) + +func TestExtractSNPClaims(t *testing.T) { + validReport := make([]byte, abi.ReportSize) + validReport[0] = 1 + validReport[10] = 0x2 // Policy bit 17 set (byte 2 of Policy, bit 1) + + tests := []struct { + name string + report []byte + wantErr bool + expectedErr string + }{ + { + name: "valid report size (minimal)", + report: validReport, + wantErr: false, + }, + { + name: "report too small", + report: make([]byte, abi.ReportSize-1), + wantErr: true, + expectedErr: "SNP report too small", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + claims := &EATClaims{} + err := extractSNPClaims(claims, tt.report) + if tt.wantErr { + assert.Error(t, err) + if tt.expectedErr != "" { + assert.Contains(t, err.Error(), tt.expectedErr) + } + } else { + assert.NoError(t, err) + assert.NotNil(t, claims.SNPExtensions) + assert.Equal(t, OEMID_AMD, claims.OEMID) + assert.Equal(t, []byte(fmt.Sprintf("SEV-SNP-%d", 1)), claims.HWModel) + } + }) + } +} + +func TestExtractTDXClaims(t *testing.T) { + report := []byte("invalid-tdx-quote") + claims := &EATClaims{} + err := extractTDXClaims(claims, report) + assert.Error(t, err) + assert.Contains(t, err.Error(), "failed to parse TDX quote") +} + +func TestTDXExtensionsJSON(t *testing.T) { + ext := &TDXExtensions{ + MRTD: []byte("mrtd_val"), + RTMR0: []byte("rtmr0_val"), + RTMR1: []byte("rtmr1_val"), + RTMR2: []byte("rtmr2_val"), + RTMR3: []byte("rtmr3_val"), + XFAM: 123, + TDAttributes: 456, + TDXModule: &TDXModuleInfo{ + Major: 1, + }, + } + + claims := &EATClaims{ + TDXExtensions: ext, + } + + // Marshal to JSON + data, err := json.Marshal(claims) + assert.NoError(t, err) + + // Verify JSON keys match Intel EAT profile + jsonStr := string(data) + assert.Contains(t, jsonStr, `"tdx_mrtd":"bXJ0ZF92YWw="`) + assert.Contains(t, jsonStr, `"tdx_rtmr0":"cnRtcjBfdmFs"`) // base64 of "rtmr0_val" + assert.Contains(t, jsonStr, `"tdx_rtmr1":"cnRtcjFfdmFs"`) + assert.Contains(t, jsonStr, `"tdx_rtmr2":"cnRtcjJfdmFs"`) + assert.Contains(t, jsonStr, `"tdx_rtmr3":"cnRtcjNfdmFs"`) + assert.Contains(t, jsonStr, `"tdx_xfam":123`) + assert.Contains(t, jsonStr, `"tdx_td_attributes":456`) + assert.Contains(t, jsonStr, `"tdx_module":{"major":1,"minor":0,"build_num":0,"build_date":0}`) +} + +func TestExtractVTPMClaims(t *testing.T) { + report := make([]byte, 32) + copy(report, []byte("vtpm-report-with-enough-length-123")) + + claims := &EATClaims{} + err := extractVTPMClaims(claims, report) + assert.NoError(t, err) + assert.NotNil(t, claims.VTPMExtensions) + assert.Equal(t, report, claims.VTPMExtensions.Quote) + assert.Equal(t, report[:32], claims.Measurements) + assert.Equal(t, report[:16], claims.UEID) +} + +func TestExtractAzureClaims(t *testing.T) { + report := make([]byte, 32) // Needs at least 32 bytes for valid slicing + for i := range report { + report[i] = byte(i) + } + claims := &EATClaims{} + err := extractAzureClaims(claims, report) + assert.NoError(t, err) + assert.Equal(t, report, claims.Measurements) + assert.Equal(t, report[:16], claims.UEID) + assert.Equal(t, OEMID_MICROSOFT, claims.OEMID) +} + +// Platform type helper. +func TestGetPlatformTypeName(t *testing.T) { + tests := []struct { + pt attestation.PlatformType + want string + }{ + {attestation.SNP, "SNP"}, + {attestation.SNPvTPM, "SNP-vTPM"}, + {attestation.TDX, "TDX"}, + {attestation.VTPM, "vTPM"}, + {attestation.Azure, "Azure"}, + {attestation.NoCC, "NoCC"}, + {attestation.PlatformType(999), "Unknown"}, + } + + for _, tt := range tests { + t.Run(tt.want, func(t *testing.T) { + assert.Equal(t, tt.want, getPlatformTypeName(tt.pt)) + }) + } +} diff --git a/pkg/attestation/eat/intuse_test.go b/pkg/attestation/eat/intuse_test.go new file mode 100644 index 00000000..89308ab9 --- /dev/null +++ b/pkg/attestation/eat/intuse_test.go @@ -0,0 +1,21 @@ +// Copyright (c) Ultraviolet +// SPDX-License-Identifier: Apache-2.0 + +package eat + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/ultravioletrs/cocos/pkg/attestation" +) + +func TestIntUse(t *testing.T) { + report := []byte("dummy-report") + nonce := make([]byte, 8) + + claims, err := NewEATClaims(report, nonce, attestation.NoCC) + assert.NoError(t, err) + + assert.Equal(t, IntUseGenericFresh, claims.IntUse) +} diff --git a/pkg/attestation/eat/jwt_encoder.go b/pkg/attestation/eat/jwt_encoder.go new file mode 100644 index 00000000..8a3670b0 --- /dev/null +++ b/pkg/attestation/eat/jwt_encoder.go @@ -0,0 +1,100 @@ +// Copyright (c) Ultraviolet +// SPDX-License-Identifier: Apache-2.0 + +package eat + +import ( + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "fmt" + "time" + + "github.com/golang-jwt/jwt/v5" +) + +// JWTEncoder encodes EAT claims to JWT format. +type JWTEncoder struct { + signingKey *ecdsa.PrivateKey + issuer string +} + +// NewJWTEncoder creates a new JWT encoder. +func NewJWTEncoder(signingKey *ecdsa.PrivateKey, issuer string) *JWTEncoder { + return &JWTEncoder{ + signingKey: signingKey, + issuer: issuer, + } +} + +// Encode encodes EAT claims to JWT string. +func (e *JWTEncoder) Encode(claims *EATClaims) (string, error) { + // Set standard JWT claims + now := time.Now() + claims.Issuer = e.issuer + claims.IssuedAt = now.Unix() + claims.ExpiresAt = now.Add(5 * time.Minute).Unix() // 5 minute validity + + // Create JWT token with custom claims + token := jwt.NewWithClaims(jwt.SigningMethodES256, &jwtClaims{claims}) + + // Sign the token + tokenString, err := token.SignedString(e.signingKey) + if err != nil { + return "", fmt.Errorf("failed to sign JWT: %w", err) + } + + return tokenString, nil +} + +// jwtClaims wraps EATClaims for JWT encoding. +type jwtClaims struct { + *EATClaims +} + +// GetExpirationTime implements jwt.Claims interface. +func (c *jwtClaims) GetExpirationTime() (*jwt.NumericDate, error) { + if c.ExpiresAt == 0 { + return nil, nil + } + return jwt.NewNumericDate(time.Unix(c.ExpiresAt, 0)), nil +} + +// GetIssuedAt implements jwt.Claims interface. +func (c *jwtClaims) GetIssuedAt() (*jwt.NumericDate, error) { + if c.IssuedAt == 0 { + return nil, nil + } + return jwt.NewNumericDate(time.Unix(c.IssuedAt, 0)), nil +} + +// GetNotBefore implements jwt.Claims interface. +func (c *jwtClaims) GetNotBefore() (*jwt.NumericDate, error) { + return nil, nil +} + +// GetIssuer implements jwt.Claims interface. +func (c *jwtClaims) GetIssuer() (string, error) { + return c.Issuer, nil +} + +// GetSubject implements jwt.Claims interface. +func (c *jwtClaims) GetSubject() (string, error) { + return c.Subject, nil +} + +// GetAudience implements jwt.Claims interface. +func (c *jwtClaims) GetAudience() (jwt.ClaimStrings, error) { + return nil, nil +} + +// EncodeToJWT is a convenience function to encode EAT claims to JWT. +func EncodeToJWT(claims *EATClaims, signingKey *ecdsa.PrivateKey, issuer string) (string, error) { + encoder := NewJWTEncoder(signingKey, issuer) + return encoder.Encode(claims) +} + +// GenerateSigningKey generates a new ECDSA signing key. +func GenerateSigningKey() (*ecdsa.PrivateKey, error) { + return ecdsa.GenerateKey(elliptic.P256(), rand.Reader) +} diff --git a/pkg/attestation/eat/jwt_encoder_test.go b/pkg/attestation/eat/jwt_encoder_test.go new file mode 100644 index 00000000..14fe154c --- /dev/null +++ b/pkg/attestation/eat/jwt_encoder_test.go @@ -0,0 +1,135 @@ +// Copyright (c) Ultraviolet +// SPDX-License-Identifier: Apache-2.0 + +package eat + +import ( + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "testing" + "time" + + "github.com/golang-jwt/jwt/v5" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestJWTEncoder_Encode(t *testing.T) { + key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + require.NoError(t, err) + + type fields struct { + signingKey *ecdsa.PrivateKey + issuer string + } + type args struct { + claims *EATClaims + } + tests := []struct { + name string + fields fields + args args + wantErr bool + }{ + { + name: "Valid encoding", + fields: fields{ + signingKey: key, + issuer: "test-issuer", + }, + args: args{ + claims: &EATClaims{ + Nonce: []byte("test-nonce"), + }, + }, + wantErr: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + e := NewJWTEncoder(tt.fields.signingKey, tt.fields.issuer) + got, err := e.Encode(tt.args.claims) + if tt.wantErr { + assert.Error(t, err) + } else { + assert.NoError(t, err) + assert.NotEmpty(t, got) + + // Verify the generated token + parsedToken, err := jwt.ParseWithClaims(got, &jwtClaims{&EATClaims{}}, func(token *jwt.Token) (interface{}, error) { + return &key.PublicKey, nil + }) + assert.NoError(t, err) + assert.True(t, parsedToken.Valid) + + claims, ok := parsedToken.Claims.(*jwtClaims) + assert.True(t, ok) + assert.Equal(t, tt.fields.issuer, claims.Issuer) + assert.Equal(t, tt.args.claims.Nonce, claims.Nonce) + } + }) + } +} + +func TestGenerateSigningKey(t *testing.T) { + key, err := GenerateSigningKey() + assert.NoError(t, err) + assert.NotNil(t, key) + assert.Equal(t, elliptic.P256(), key.Curve) +} + +func TestEncodeToJWT(t *testing.T) { + key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + require.NoError(t, err) + + claims := &EATClaims{Nonce: []byte("nonce")} + token, err := EncodeToJWT(claims, key, "issuer") + assert.NoError(t, err) + assert.NotEmpty(t, token) +} + +func TestJwtClaimsINTERFACE(t *testing.T) { + now := time.Now() + claims := &EATClaims{ + Issuer: "iss", + Subject: "sub", + ExpiresAt: now.Add(time.Hour).Unix(), + IssuedAt: now.Unix(), + } + jwtc := &jwtClaims{claims} + + exp, err := jwtc.GetExpirationTime() + assert.NoError(t, err) + assert.Equal(t, claims.ExpiresAt, exp.Unix()) + + iat, err := jwtc.GetIssuedAt() + assert.NoError(t, err) + assert.Equal(t, claims.IssuedAt, iat.Unix()) + + iss, err := jwtc.GetIssuer() + assert.NoError(t, err) + assert.Equal(t, claims.Issuer, iss) + + sub, err := jwtc.GetSubject() + assert.NoError(t, err) + assert.Equal(t, claims.Subject, sub) + + nbf, err := jwtc.GetNotBefore() + assert.NoError(t, err) + assert.Nil(t, nbf) + + aud, err := jwtc.GetAudience() + assert.NoError(t, err) + assert.Nil(t, aud) + + // Test zero values + emptyClaims := &jwtClaims{&EATClaims{}} + exp, err = emptyClaims.GetExpirationTime() + assert.NoError(t, err) + assert.Nil(t, exp) + + iat, err = emptyClaims.GetIssuedAt() + assert.NoError(t, err) + assert.Nil(t, iat) +} diff --git a/pkg/attestation/eat/validator.go b/pkg/attestation/eat/validator.go new file mode 100644 index 00000000..f5a16d1a --- /dev/null +++ b/pkg/attestation/eat/validator.go @@ -0,0 +1,67 @@ +// Copyright (c) Ultraviolet +// SPDX-License-Identifier: Apache-2.0 + +package eat + +import ( + "fmt" + "time" +) + +// ValidateEATClaims validates EAT claims against policy. +func ValidateEATClaims(claims *EATClaims, policy *EATValidationPolicy) error { + if policy == nil { + return nil // No policy, skip validation + } + + // Sanitize claims to enforce dependency rules + claims.Sanitize() + + // Check required claims + for _, requiredClaim := range policy.RequireClaims { + switch requiredClaim { + case "eat_nonce": + if len(claims.Nonce) == 0 { + return fmt.Errorf("missing required claim: eat_nonce") + } + case "measurements": + if len(claims.Measurements) == 0 { + return fmt.Errorf("missing required claim: measurements") + } + case "platform_type": + if claims.PlatformType == "" { + return fmt.Errorf("missing required claim: platform_type") + } + case "ueid": + if len(claims.UEID) == 0 { + return fmt.Errorf("missing required claim: ueid") + } + } + } + + // Check token age + if policy.MaxTokenAgeSeconds > 0 && claims.IssuedAt > 0 { + tokenAge := time.Since(time.Unix(claims.IssuedAt, 0)) + if tokenAge.Seconds() > float64(policy.MaxTokenAgeSeconds) { + return fmt.Errorf("token too old: %v seconds (max: %d)", tokenAge.Seconds(), policy.MaxTokenAgeSeconds) + } + } + + // Check expiration + if claims.ExpiresAt > 0 { + if time.Now().Unix() > claims.ExpiresAt { + return fmt.Errorf("token expired") + } + } + + return nil +} + +// EATValidationPolicy contains validation rules for EAT tokens. +type EATValidationPolicy struct { + RequireEATFormat bool + AllowedFormats []string + MaxTokenAgeSeconds int + RequireClaims []string + VerifySignature bool +} diff --git a/pkg/attestation/eat/validator_test.go b/pkg/attestation/eat/validator_test.go new file mode 100644 index 00000000..2f4e8781 --- /dev/null +++ b/pkg/attestation/eat/validator_test.go @@ -0,0 +1,106 @@ +// Copyright (c) Ultraviolet +// SPDX-License-Identifier: Apache-2.0 + +package eat + +import ( + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestValidateEATClaims(t *testing.T) { + now := time.Now() + + tests := []struct { + name string + claims *EATClaims + policy *EATValidationPolicy + expectedErr string + }{ + { + name: "Nil policy", + claims: &EATClaims{}, + policy: nil, + }, + { + name: "Valid claims conforming to policy", + claims: &EATClaims{ + Nonce: []byte("nonce"), + Measurements: []byte("meas"), + IssuedAt: now.Unix(), + ExpiresAt: now.Add(time.Hour).Unix(), + }, + policy: &EATValidationPolicy{ + RequireClaims: []string{"eat_nonce", "measurements"}, + MaxTokenAgeSeconds: 300, + }, + }, + { + name: "Missing nonce", + claims: &EATClaims{ + Measurements: []byte("meas"), + }, + policy: &EATValidationPolicy{ + RequireClaims: []string{"eat_nonce"}, + }, + expectedErr: "missing required claim: eat_nonce", + }, + { + name: "Missing measurements", + claims: &EATClaims{ + Nonce: []byte("nonce"), + }, + policy: &EATValidationPolicy{ + RequireClaims: []string{"measurements"}, + }, + expectedErr: "missing required claim: measurements", + }, + { + name: "Missing platform type", + claims: &EATClaims{}, + policy: &EATValidationPolicy{ + RequireClaims: []string{"platform_type"}, + }, + expectedErr: "missing required claim: platform_type", + }, + { + name: "Missing UEID", + claims: &EATClaims{}, + policy: &EATValidationPolicy{ + RequireClaims: []string{"ueid"}, + }, + expectedErr: "missing required claim: ueid", + }, + { + name: "Token too old", + claims: &EATClaims{ + IssuedAt: now.Add(-2 * time.Hour).Unix(), + }, + policy: &EATValidationPolicy{ + MaxTokenAgeSeconds: 3600, // 1 hour max age + }, + expectedErr: "token too old", + }, + { + name: "Token expired", + claims: &EATClaims{ + ExpiresAt: now.Add(-1 * time.Hour).Unix(), + }, + policy: &EATValidationPolicy{}, + expectedErr: "token expired", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := ValidateEATClaims(tt.claims, tt.policy) + if tt.expectedErr != "" { + assert.ErrorContains(t, err, tt.expectedErr) + } else { + assert.NoError(t, err) + } + }) + } +} diff --git a/pkg/attestation/mocks/verifier.go b/pkg/attestation/mocks/verifier.go index 4d95db55..131589b9 100644 --- a/pkg/attestation/mocks/verifier.go +++ b/pkg/attestation/mocks/verifier.go @@ -265,3 +265,66 @@ func (_c *Verifier_VerifyAttestation_Call) RunAndReturn(run func(report []byte, _c.Call.Return(run) return _c } + +// VerifyEAT provides a mock function for the type Verifier +func (_mock *Verifier) VerifyEAT(eatToken []byte, teeNonce []byte, vTpmNonce []byte) error { + ret := _mock.Called(eatToken, teeNonce, vTpmNonce) + + if len(ret) == 0 { + panic("no return value specified for VerifyEAT") + } + + var r0 error + if returnFunc, ok := ret.Get(0).(func([]byte, []byte, []byte) error); ok { + r0 = returnFunc(eatToken, teeNonce, vTpmNonce) + } else { + r0 = ret.Error(0) + } + return r0 +} + +// Verifier_VerifyEAT_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'VerifyEAT' +type Verifier_VerifyEAT_Call struct { + *mock.Call +} + +// VerifyEAT is a helper method to define mock.On call +// - eatToken []byte +// - teeNonce []byte +// - vTpmNonce []byte +func (_e *Verifier_Expecter) VerifyEAT(eatToken interface{}, teeNonce interface{}, vTpmNonce interface{}) *Verifier_VerifyEAT_Call { + return &Verifier_VerifyEAT_Call{Call: _e.mock.On("VerifyEAT", eatToken, teeNonce, vTpmNonce)} +} + +func (_c *Verifier_VerifyEAT_Call) Run(run func(eatToken []byte, teeNonce []byte, vTpmNonce []byte)) *Verifier_VerifyEAT_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 []byte + if args[0] != nil { + arg0 = args[0].([]byte) + } + var arg1 []byte + if args[1] != nil { + arg1 = args[1].([]byte) + } + var arg2 []byte + if args[2] != nil { + arg2 = args[2].([]byte) + } + run( + arg0, + arg1, + arg2, + ) + }) + return _c +} + +func (_c *Verifier_VerifyEAT_Call) Return(err error) *Verifier_VerifyEAT_Call { + _c.Call.Return(err) + return _c +} + +func (_c *Verifier_VerifyEAT_Call) RunAndReturn(run func(eatToken []byte, teeNonce []byte, vTpmNonce []byte) error) *Verifier_VerifyEAT_Call { + _c.Call.Return(run) + return _c +} diff --git a/pkg/attestation/tdx/tdx.go b/pkg/attestation/tdx/tdx.go index e59ba33c..d8bac151 100644 --- a/pkg/attestation/tdx/tdx.go +++ b/pkg/attestation/tdx/tdx.go @@ -18,6 +18,7 @@ import ( verifytdx "github.com/google/go-tdx-guest/verify" trusttdx "github.com/google/go-tdx-guest/verify/trust" "github.com/ultravioletrs/cocos/pkg/attestation" + "github.com/ultravioletrs/cocos/pkg/attestation/eat" "google.golang.org/protobuf/encoding/protojson" ) @@ -141,6 +142,18 @@ func (v verifier) JSONToPolicy(path string) error { return ReadTDXAttestationPolicy(path, v.Policy) } +// VerifyEAT verifies an EAT token and extracts the binary report for verification. +func (v verifier) VerifyEAT(eatToken []byte, teeNonce []byte, vTpmNonce []byte) error { + // Decode EAT token + claims, err := eat.Decode(eatToken, nil) + if err != nil { + return fmt.Errorf("failed to decode EAT token: %w", err) + } + + // Verify the embedded binary report + return v.VerifyAttestation(claims.RawReport, teeNonce, vTpmNonce) +} + func ReadTDXAttestationPolicy(policyPath string, policy *checkconfig.Config) error { policyByte, err := os.ReadFile(policyPath) if err != nil { diff --git a/pkg/attestation/tdx/tdx_coverage_test.go b/pkg/attestation/tdx/tdx_coverage_test.go new file mode 100644 index 00000000..0e57c2e5 --- /dev/null +++ b/pkg/attestation/tdx/tdx_coverage_test.go @@ -0,0 +1,69 @@ +// Copyright (c) Ultraviolet +// SPDX-License-Identifier: Apache-2.0 + +package tdx + +import ( + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "strings" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/ultravioletrs/cocos/pkg/attestation/eat" +) + +func TestVerifyEAT_TDX(t *testing.T) { + key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + require.NoError(t, err) + + claims := &eat.EATClaims{ + Nonce: []byte("test-nonce"), + IssuedAt: time.Now().Unix(), + RawReport: []byte("dummy-report"), + PlatformType: "TDX", + } + + jwtEncoder := eat.NewJWTEncoder(key, "issuer") + token, err := jwtEncoder.Encode(claims) + require.NoError(t, err) + + vInterface := NewVerifier() + + v, ok := vInterface.(verifier) + require.True(t, ok) + + err = v.VerifyEAT([]byte(token), []byte("tee-nonce"), []byte("vtpm-nonce")) + assert.Error(t, err) + assert.Contains(t, err.Error(), "failed") +} + +func TestVerifyEAT_TDX_InvalidToken(t *testing.T) { + vInterface := NewVerifier() + v, ok := vInterface.(verifier) + require.True(t, ok) + + err := v.VerifyEAT([]byte("invalid-token"), nil, nil) + assert.Error(t, err) + assert.Contains(t, err.Error(), "failed to decode EAT token") +} + +func TestTeeAttestation_InvalidNonce(t *testing.T) { + p := NewProvider() + + nonce := make([]byte, 64) + _, err := p.TeeAttestation(nonce) + assert.Error(t, err) + // Check for likely errors in non-TDX environment + // Check for likely errors in non-TDX environment + errMsg := err.Error() + assert.True(t, + strings.Contains(errMsg, "no such file or directory") || + strings.Contains(errMsg, "permission denied") || + strings.Contains(errMsg, "failed to open TDX device"), + "unexpected error message: %s", errMsg, + ) +} diff --git a/pkg/attestation/vtpm/vtpm.go b/pkg/attestation/vtpm/vtpm.go index eea40d78..827eca5e 100644 --- a/pkg/attestation/vtpm/vtpm.go +++ b/pkg/attestation/vtpm/vtpm.go @@ -24,6 +24,7 @@ import ( "github.com/google/go-tpm/legacy/tpm2" "github.com/google/go-tpm/tpmutil" "github.com/ultravioletrs/cocos/pkg/attestation" + "github.com/ultravioletrs/cocos/pkg/attestation/eat" "github.com/ultravioletrs/cocos/pkg/attestation/quoteprovider" "golang.org/x/crypto/sha3" "google.golang.org/protobuf/encoding/protojson" @@ -181,10 +182,22 @@ func (v verifier) VerifyAttestation(report []byte, teeNonce []byte, vTpmNonce [] return VTPMVerify(report, teeNonce, vTpmNonce, v.writer, v.Policy) } -func (v verifier) JSONToPolicy(path string) error { +func (v *verifier) JSONToPolicy(path string) error { return ReadPolicy(path, v.Policy) } +// VerifyEAT verifies an EAT token and extracts the binary report for verification. +func (v *verifier) VerifyEAT(eatToken []byte, teeNonce []byte, vTpmNonce []byte) error { + // Decode EAT token + claims, err := eat.Decode(eatToken, nil) + if err != nil { + return fmt.Errorf("failed to decode EAT token: %w", err) + } + + // Verify the embedded binary report + return v.VerifyAttestation(claims.RawReport, teeNonce, vTpmNonce) +} + func Attest(teeNonce []byte, vTPMNonce []byte, teeAttestaion bool, vmpl uint) ([]byte, error) { attestation, err := FetchQuote(vTPMNonce) if err != nil { diff --git a/pkg/attestation/vtpm/vtpm_coverage_test.go b/pkg/attestation/vtpm/vtpm_coverage_test.go new file mode 100644 index 00000000..83c0ef31 --- /dev/null +++ b/pkg/attestation/vtpm/vtpm_coverage_test.go @@ -0,0 +1,68 @@ +// Copyright (c) Ultraviolet +// SPDX-License-Identifier: Apache-2.0 + +package vtpm + +import ( + "bytes" + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/ultravioletrs/cocos/pkg/attestation/eat" +) + +func TestVerifyEAT(t *testing.T) { + key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + require.NoError(t, err) + + claims := &eat.EATClaims{ + Nonce: []byte("test-nonce"), + IssuedAt: time.Now().Unix(), + RawReport: []byte("dummy-report"), // This will be passed to VerifyAttestation + PlatformType: "SNP-vTPM", + } + + jwtEncoder := eat.NewJWTEncoder(key, "issuer") + token, err := jwtEncoder.Encode(claims) + require.NoError(t, err) + + writer := &mockWriter{} + vInterface := NewVerifier(writer) + v, ok := vInterface.(*verifier) + require.True(t, ok) + + err = v.VerifyEAT([]byte(token), []byte("tee-nonce"), []byte("vtpm-nonce")) + assert.Error(t, err) + assert.Contains(t, err.Error(), "failed") +} + +func TestVerifyEAT_InvalidToken(t *testing.T) { + writer := &mockWriter{} + vInterface := NewVerifier(writer) + v, ok := vInterface.(*verifier) + require.True(t, ok) + + err := v.VerifyEAT([]byte("invalid-token"), nil, nil) + assert.Error(t, err) + assert.Contains(t, err.Error(), "failed to decode EAT token") +} + +func TestProvider_Methods(t *testing.T) { + p := NewProvider(true, 1) + + originalExternalTPM := ExternalTPM + defer func() { ExternalTPM = originalExternalTPM }() + + ExternalTPM = &mockTPM{Buffer: &bytes.Buffer{}} + + _, err := p.VTpmAttestation([]byte("nonce")) + assert.Error(t, err) + + _, err = p.TeeAttestation([]byte("nonce")) + assert.Error(t, err) +} diff --git a/pkg/clients/grpc/attestation/client.go b/pkg/clients/grpc/attestation/client.go index 3f21134f..43a8dc53 100644 --- a/pkg/clients/grpc/attestation/client.go +++ b/pkg/clients/grpc/attestation/client.go @@ -68,7 +68,7 @@ func (c *client) GetAttestation(ctx context.Context, reportData [64]byte, nonce return nil, err } - return resp.Quote, nil + return resp.EatToken, nil } func (c *client) GetAzureToken(ctx context.Context, nonce [32]byte) ([]byte, error) { diff --git a/pkg/clients/grpc/attestation/client_test.go b/pkg/clients/grpc/attestation/client_test.go index ebdd214c..e517010e 100644 --- a/pkg/clients/grpc/attestation/client_test.go +++ b/pkg/clients/grpc/attestation/client_test.go @@ -39,7 +39,7 @@ func (m *mockAttestationServer) FetchAttestation(ctx context.Context, req *attes } return &attestation_v1.AttestationResponse{ - Quote: []byte("mock-attestation-quote"), + EatToken: []byte("mock-attestation-quote"), }, nil } diff --git a/scripts/attestation_policy/sev-snp/attestation_policy.json b/scripts/attestation_policy/sev-snp/attestation_policy.json index b6d2a5c8..4b533f3c 100644 --- a/scripts/attestation_policy/sev-snp/attestation_policy.json +++ b/scripts/attestation_policy/sev-snp/attestation_policy.json @@ -48,5 +48,12 @@ "disallow_network": false, "product": "Milan", "product_line": "Milan" + }, + "eat_validation": { + "require_eat_format": true, + "allowed_formats": ["CBOR", "JWT"], + "max_token_age_seconds": 300, + "require_claims": ["eat_nonce", "measurements", "platform_type"], + "verify_signature": true } } diff --git a/scripts/attestation_policy/sev-snp/attestation_policy_tdx.json b/scripts/attestation_policy/sev-snp/attestation_policy_tdx.json index 327c6be3..4d83ac4e 100644 --- a/scripts/attestation_policy/sev-snp/attestation_policy_tdx.json +++ b/scripts/attestation_policy/sev-snp/attestation_policy_tdx.json @@ -1,25 +1,43 @@ { - "policy": { - "headerPolicy": { - "qeVendorId": "k5pyM/ecTKmUCg2zlX8GBw==" + "policy": { + "headerPolicy": { + "qeVendorId": "k5pyM/ecTKmUCg2zlX8GBw==" }, - "tdQuoteBodyPolicy": { - "minimumTeeTcbSvn": "BgEDAAAAAAAAAAAAAAAAAA==", - "mrSeam": "WzjjOmSHlYtyw8Eqk46qXj/UUQxRruq1jH1ezuQdfENkidbI5PkvFgt8rTQgewDB", - "tdAttributes": "AAAAEAAAAAA=", - "xfam": "5wIGAAAAAAA=", - "mrTd": "kesrRNFB1Ozgnwx1wsU9JHo8aO3X+v6KNSDJQqYEpAfeA65txfh/J0KLJTiHMRi3", - "rtmrs": [ + "tdQuoteBodyPolicy": { + "minimumTeeTcbSvn": "BgEDAAAAAAAAAAAAAAAAAA==", + "mrSeam": "WzjjOmSHlYtyw8Eqk46qXj/UUQxRruq1jH1ezuQdfENkidbI5PkvFgt8rTQgewDB", + "tdAttributes": "AAAAEAAAAAA=", + "xfam": "5wIGAAAAAAA=", + "mrTd": "kesrRNFB1Ozgnwx1wsU9JHo8aO3X+v6KNSDJQqYEpAfeA65txfh/J0KLJTiHMRi3", + "rtmrs": [ "TP/tWJG9nf1AuPrfS7mKBpBw05ffiZHYnbu01Tjr8cKeG+lNDwuxder+DJxTSSqW", "fxoATOAep76VY2mWwKB4XWWoQqgJZNYdiHXJk14DN2iKJP5tg8AoeRoGhxJg2BO3", "fYilkkTRM83nhg1ZUY4WsULRfwyN3v2rcv5+wbSl9Rro1zqhcPMCeCCcL/CCAUqx", "AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA" - ] + ], + "policy": { + "mr_seam": "AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=", + "tdx_module": "AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=", + "mr_td": "AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=" + }, + "eat_validation": { + "require_eat_format": true, + "allowed_formats": [ + "CBOR", + "JWT" + ], + "max_token_age_seconds": 300, + "require_claims": [ + "eat_nonce", + "measurements", + "platform_type" + ], + "verify_signature": true + } } }, - "rootOfTrust": { - "checkCrl": true, - "getCollateral": true + "rootOfTrust": { + "checkCrl": true, + "getCollateral": true } -} - +} \ No newline at end of file