diff --git a/agent/api/grpc/requests.go b/agent/api/grpc/requests.go index 36ac45c1..6363d48b 100644 --- a/agent/api/grpc/requests.go +++ b/agent/api/grpc/requests.go @@ -6,7 +6,6 @@ import ( "errors" "github.com/ultravioletrs/cocos/pkg/attestation" - "github.com/ultravioletrs/cocos/pkg/attestation/quoteprovider" "github.com/ultravioletrs/cocos/pkg/attestation/vtpm" ) @@ -42,7 +41,7 @@ func (req resultReq) validate() error { } type attestationReq struct { - TeeNonce [quoteprovider.Nonce]byte + TeeNonce [vtpm.SEVNonce]byte VtpmNonce [vtpm.Nonce]byte AttType attestation.PlatformType } diff --git a/agent/api/grpc/server.go b/agent/api/grpc/server.go index 9159d1fb..f7f9b685 100644 --- a/agent/api/grpc/server.go +++ b/agent/api/grpc/server.go @@ -14,7 +14,6 @@ import ( "github.com/go-kit/kit/transport/grpc" "github.com/ultravioletrs/cocos/agent" "github.com/ultravioletrs/cocos/pkg/attestation" - "github.com/ultravioletrs/cocos/pkg/attestation/quoteprovider" "github.com/ultravioletrs/cocos/pkg/attestation/vtpm" "google.golang.org/grpc/codes" "google.golang.org/grpc/metadata" @@ -134,7 +133,7 @@ func encodeResultResponse(_ context.Context, response any) (any, error) { func validateNonce(nonce []byte, maxLen int, target any) error { if len(nonce) > maxLen { switch maxLen { - case quoteprovider.Nonce: + case vtpm.SEVNonce: return ErrTEENonceLength case vtpm.Nonce: return ErrVTPMNonceLength @@ -144,7 +143,7 @@ func validateNonce(nonce []byte, maxLen int, target any) error { } switch t := target.(type) { - case *[quoteprovider.Nonce]byte: + case *[vtpm.SEVNonce]byte: copy(t[:], nonce) case *[vtpm.Nonce]byte: copy(t[:], nonce) @@ -156,10 +155,10 @@ func validateNonce(nonce []byte, maxLen int, target any) error { func decodeAttestationRequest(_ context.Context, grpcReq any) (any, error) { req := grpcReq.(*agent.AttestationRequest) - var reportData [quoteprovider.Nonce]byte + var reportData [vtpm.SEVNonce]byte var nonce [vtpm.Nonce]byte - if err := validateNonce(req.TeeNonce, quoteprovider.Nonce, &reportData); err != nil { + if err := validateNonce(req.TeeNonce, vtpm.SEVNonce, &reportData); err != nil { return nil, err } diff --git a/agent/api/grpc/server_test.go b/agent/api/grpc/server_test.go index e788ff20..cf00dd39 100644 --- a/agent/api/grpc/server_test.go +++ b/agent/api/grpc/server_test.go @@ -12,7 +12,6 @@ import ( "github.com/ultravioletrs/cocos/agent" "github.com/ultravioletrs/cocos/agent/mocks" "github.com/ultravioletrs/cocos/pkg/attestation" - "github.com/ultravioletrs/cocos/pkg/attestation/quoteprovider" "github.com/ultravioletrs/cocos/pkg/attestation/vtpm" "google.golang.org/grpc" "google.golang.org/grpc/metadata" @@ -229,7 +228,7 @@ func TestAttestation(t *testing.T) { return len(resp.File) > 0 })).Return(nil).Once() - reportData := [quoteprovider.Nonce]byte{} + reportData := [vtpm.SEVNonce]byte{} vtpmNonce := [vtpm.Nonce]byte{} attestationType := attestation.SNP mockService.On("Attestation", mock.Anything, reportData, vtpmNonce, attestationType).Return(attestationData, nil) @@ -298,8 +297,8 @@ func TestValidateNonce(t *testing.T) { }{ { name: "valid TEE nonce", - nonce: make([]byte, quoteprovider.Nonce), - maxLen: quoteprovider.Nonce, + nonce: make([]byte, vtpm.SEVNonce), + maxLen: vtpm.SEVNonce, shouldError: false, }, { @@ -310,8 +309,8 @@ func TestValidateNonce(t *testing.T) { }, { name: "TEE nonce too long", - nonce: make([]byte, quoteprovider.Nonce+1), - maxLen: quoteprovider.Nonce, + nonce: make([]byte, vtpm.SEVNonce+1), + maxLen: vtpm.SEVNonce, shouldError: true, expectedErr: ErrTEENonceLength, }, @@ -326,8 +325,8 @@ func TestValidateNonce(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - if tt.maxLen == quoteprovider.Nonce { - var target [quoteprovider.Nonce]byte + if tt.maxLen == vtpm.SEVNonce { + var target [vtpm.SEVNonce]byte err := validateNonce(tt.nonce, tt.maxLen, &target) if tt.shouldError { assert.Error(t, err) @@ -388,7 +387,7 @@ func TestEncodeResultResponse(t *testing.T) { } func TestDecodeAttestationRequest(t *testing.T) { - teeNonce := make([]byte, quoteprovider.Nonce) + teeNonce := make([]byte, vtpm.SEVNonce) vtpmNonce := make([]byte, vtpm.Nonce) req := &agent.AttestationRequest{ @@ -406,7 +405,7 @@ func TestDecodeAttestationRequest(t *testing.T) { func TestDecodeAttestationRequestWithInvalidNonce(t *testing.T) { // Test with TEE nonce too long - teeNonce := make([]byte, quoteprovider.Nonce+1) + teeNonce := make([]byte, vtpm.SEVNonce+1) req := &agent.AttestationRequest{TeeNonce: teeNonce} _, err := decodeAttestationRequest(context.Background(), req) diff --git a/agent/api/logging.go b/agent/api/logging.go index 30aa77b2..3d9fc839 100644 --- a/agent/api/logging.go +++ b/agent/api/logging.go @@ -13,7 +13,6 @@ import ( "github.com/ultravioletrs/cocos/agent" "github.com/ultravioletrs/cocos/pkg/attestation" - "github.com/ultravioletrs/cocos/pkg/attestation/quoteprovider" "github.com/ultravioletrs/cocos/pkg/attestation/vtpm" ) @@ -105,7 +104,7 @@ func (lm *loggingMiddleware) Result(ctx context.Context) (response []byte, err e return lm.svc.Result(ctx) } -func (lm *loggingMiddleware) Attestation(ctx context.Context, reportData [quoteprovider.Nonce]byte, nonce [vtpm.Nonce]byte, attType attestation.PlatformType) (response []byte, err error) { +func (lm *loggingMiddleware) Attestation(ctx context.Context, reportData [vtpm.SEVNonce]byte, nonce [vtpm.Nonce]byte, attType attestation.PlatformType) (response []byte, err error) { defer func(begin time.Time) { message := fmt.Sprintf("Method Attestation took %s to complete", time.Since(begin)) if err != nil { diff --git a/agent/api/metrics.go b/agent/api/metrics.go index 4aa2557f..b4dc1f2e 100644 --- a/agent/api/metrics.go +++ b/agent/api/metrics.go @@ -12,7 +12,6 @@ import ( "github.com/go-kit/kit/metrics" "github.com/ultravioletrs/cocos/agent" "github.com/ultravioletrs/cocos/pkg/attestation" - "github.com/ultravioletrs/cocos/pkg/attestation/quoteprovider" "github.com/ultravioletrs/cocos/pkg/attestation/vtpm" ) @@ -91,7 +90,7 @@ func (ms *metricsMiddleware) Result(ctx context.Context) ([]byte, error) { return ms.svc.Result(ctx) } -func (ms *metricsMiddleware) Attestation(ctx context.Context, reportData [quoteprovider.Nonce]byte, nonce [vtpm.Nonce]byte, attType attestation.PlatformType) ([]byte, error) { +func (ms *metricsMiddleware) Attestation(ctx context.Context, reportData [vtpm.SEVNonce]byte, nonce [vtpm.Nonce]byte, attType attestation.PlatformType) ([]byte, error) { defer func(begin time.Time) { ms.counter.With("method", "attestation").Add(1) ms.latency.With("method", "attestation").Observe(time.Since(begin).Seconds()) diff --git a/agent/service.go b/agent/service.go index 84db1f70..b5573561 100644 --- a/agent/service.go +++ b/agent/service.go @@ -21,7 +21,6 @@ import ( "github.com/ultravioletrs/cocos/agent/statemachine" "github.com/ultravioletrs/cocos/internal" "github.com/ultravioletrs/cocos/pkg/attestation" - "github.com/ultravioletrs/cocos/pkg/attestation/quoteprovider" "github.com/ultravioletrs/cocos/pkg/attestation/vtpm" attestation_client "github.com/ultravioletrs/cocos/pkg/clients/grpc/attestation" runner_client "github.com/ultravioletrs/cocos/pkg/clients/grpc/runner" @@ -120,7 +119,7 @@ type Service interface { Algo(ctx context.Context, algorithm Algorithm) error Data(ctx context.Context, dataset Dataset) error Result(ctx context.Context) ([]byte, error) - Attestation(ctx context.Context, reportData [quoteprovider.Nonce]byte, nonce [vtpm.Nonce]byte, attType attestation.PlatformType) ([]byte, error) + Attestation(ctx context.Context, reportData [vtpm.SEVNonce]byte, nonce [vtpm.Nonce]byte, attType attestation.PlatformType) ([]byte, error) IMAMeasurements(ctx context.Context) ([]byte, []byte, error) AzureAttestationToken(ctx context.Context, nonce [vtpm.Nonce]byte) ([]byte, error) State() string @@ -418,7 +417,7 @@ func (as *agentService) Result(ctx context.Context) ([]byte, error) { return as.result, as.runError } -func (as *agentService) Attestation(ctx context.Context, reportData [quoteprovider.Nonce]byte, nonce [vtpm.Nonce]byte, attType attestation.PlatformType) ([]byte, error) { +func (as *agentService) Attestation(ctx context.Context, reportData [vtpm.SEVNonce]byte, nonce [vtpm.Nonce]byte, attType attestation.PlatformType) ([]byte, error) { rawQuote, err := as.attestationClient.GetAttestation(ctx, reportData, nonce, attType) if err != nil { return []byte{}, errors.Wrap(ErrAttestationFailed, err) diff --git a/agent/service_test.go b/agent/service_test.go index 62f7c9ed..3796751c 100644 --- a/agent/service_test.go +++ b/agent/service_test.go @@ -24,7 +24,6 @@ import ( "github.com/ultravioletrs/cocos/agent/statemachine" smmocks "github.com/ultravioletrs/cocos/agent/statemachine/mocks" "github.com/ultravioletrs/cocos/pkg/attestation" - "github.com/ultravioletrs/cocos/pkg/attestation/quoteprovider" "github.com/ultravioletrs/cocos/pkg/attestation/vtpm" runnermocks "github.com/ultravioletrs/cocos/pkg/clients/grpc/runner/mocks" "golang.org/x/crypto/sha3" @@ -336,7 +335,7 @@ func TestAttestation(t *testing.T) { cases := []struct { name string - reportData [quoteprovider.Nonce]byte + reportData [vtpm.SEVNonce]byte nonce [vtpm.Nonce]byte rawQuote []uint8 platform attestation.PlatformType @@ -457,8 +456,8 @@ func TestAzureAttestationToken(t *testing.T) { } } -func generateReportData() [quoteprovider.Nonce]byte { - bytes := make([]byte, quoteprovider.Nonce) +func generateReportData() [vtpm.SEVNonce]byte { + bytes := make([]byte, vtpm.SEVNonce) _, err := rand.Read(bytes) if err != nil { log.Fatalf("Failed to generate random bytes: %v", err) diff --git a/algorithm.lin-reg-py.enc b/algorithm.lin-reg-py.enc new file mode 100644 index 00000000..3e2a8f20 Binary files /dev/null and b/algorithm.lin-reg-py.enc differ diff --git a/cli/attestation.go b/cli/attestation.go index a05934fa..08140f43 100644 --- a/cli/attestation.go +++ b/cli/attestation.go @@ -21,7 +21,6 @@ import ( "github.com/spf13/pflag" "github.com/ultravioletrs/cocos/pkg/attestation" "github.com/ultravioletrs/cocos/pkg/attestation/azure" - "github.com/ultravioletrs/cocos/pkg/attestation/quoteprovider" "github.com/ultravioletrs/cocos/pkg/attestation/tdx" "github.com/ultravioletrs/cocos/pkg/attestation/vtpm" "google.golang.org/protobuf/encoding/prototext" @@ -171,10 +170,10 @@ func (cli *CLI) NewGetAttestationCmd() *cobra.Command { return } - var fixedReportData [quoteprovider.Nonce]byte + var fixedReportData [vtpm.SEVNonce]byte if attType == attestation.SNP || attType == attestation.SNPvTPM { - if len(teeNonce) > quoteprovider.Nonce { - msg := color.New(color.FgRed).Sprintf("nonce must be a hex encoded string of length lesser or equal %d bytes ❌ ", quoteprovider.Nonce) + if len(teeNonce) > vtpm.SEVNonce { + msg := color.New(color.FgRed).Sprintf("nonce must be a hex encoded string of length lesser or equal %d bytes ❌ ", vtpm.SEVNonce) cmd.Println(msg) return } diff --git a/cli/attestation_test.go b/cli/attestation_test.go index 2379b439..3eb4439e 100644 --- a/cli/attestation_test.go +++ b/cli/attestation_test.go @@ -21,7 +21,6 @@ import ( "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" mmocks "github.com/ultravioletrs/cocos/pkg/attestation/cmdconfig/mocks" - "github.com/ultravioletrs/cocos/pkg/attestation/quoteprovider" "github.com/ultravioletrs/cocos/pkg/attestation/vtpm" "github.com/ultravioletrs/cocos/pkg/sdk/mocks" ) @@ -37,8 +36,8 @@ func TestNewAttestationCmd(t *testing.T) { var buf bytes.Buffer cmd.SetOut(&buf) - reportData := bytes.Repeat([]byte{0x01}, quoteprovider.Nonce) - mockSDK.On("Attestation", mock.Anything, [quoteprovider.Nonce]byte(reportData), mock.Anything).Return(nil) + reportData := bytes.Repeat([]byte{0x01}, vtpm.SEVNonce) + mockSDK.On("Attestation", mock.Anything, [vtpm.SEVNonce]byte(reportData), mock.Anything).Return(nil) cmd.SetArgs([]string{hex.EncodeToString(reportData)}) err := cmd.Execute() @@ -50,7 +49,7 @@ func TestNewGetAttestationCmd(t *testing.T) { validattestation, err := os.ReadFile("../attestation.bin") require.NoError(t, err) - teeNonce := hex.EncodeToString(bytes.Repeat([]byte{0x00}, quoteprovider.Nonce)) + teeNonce := hex.EncodeToString(bytes.Repeat([]byte{0x00}, vtpm.SEVNonce)) vtpmNonce := hex.EncodeToString(bytes.Repeat([]byte{0x00}, vtpm.Nonce)) tokenNonce := hex.EncodeToString(bytes.Repeat([]byte{0x00}, vtpm.Nonce)) @@ -184,7 +183,7 @@ func TestNewGetAttestationCmd(t *testing.T) { var buf bytes.Buffer cmd.SetOut(&buf) - mockSDK.On("Attestation", mock.Anything, [quoteprovider.Nonce]byte(bytes.Repeat([]byte{0x00}, quoteprovider.Nonce)), [vtpm.Nonce]byte(bytes.Repeat([]byte{0x00}, vtpm.Nonce)), mock.Anything, mock.Anything).Return(tc.mockError).Run(func(args mock.Arguments) { + mockSDK.On("Attestation", mock.Anything, [vtpm.SEVNonce]byte(bytes.Repeat([]byte{0x00}, vtpm.SEVNonce)), [vtpm.Nonce]byte(bytes.Repeat([]byte{0x00}, vtpm.Nonce)), mock.Anything, mock.Anything).Return(tc.mockError).Run(func(args mock.Arguments) { _, err := args.Get(4).(*os.File).Write(tc.mockResponse) require.NoError(t, err) }) @@ -891,7 +890,7 @@ func TestGetAttestationCmdEdgeCases(t *testing.T) { }, { name: "TEE nonce too large", - args: []string{"snp", "--tee", hex.EncodeToString(bytes.Repeat([]byte{0x00}, quoteprovider.Nonce+1))}, + args: []string{"snp", "--tee", hex.EncodeToString(bytes.Repeat([]byte{0x00}, vtpm.SEVNonce+1))}, setupMock: func(sdk *mocks.SDK) { }, expectedErr: "nonce must be a hex encoded string of length lesser or equal 64 bytes", @@ -912,7 +911,7 @@ func TestGetAttestationCmdEdgeCases(t *testing.T) { }, { name: "successful TDX attestation", - args: []string{"tdx", "--tee", hex.EncodeToString(bytes.Repeat([]byte{0x00}, quoteprovider.Nonce))}, + args: []string{"tdx", "--tee", hex.EncodeToString(bytes.Repeat([]byte{0x00}, vtpm.SEVNonce))}, setupMock: func(sdk *mocks.SDK) { sdk.On("Attestation", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything). Return(nil).Run(func(args mock.Arguments) { @@ -925,7 +924,7 @@ func TestGetAttestationCmdEdgeCases(t *testing.T) { }, { name: "file creation error", - args: []string{"snp", "--tee", hex.EncodeToString(bytes.Repeat([]byte{0x00}, quoteprovider.Nonce))}, + args: []string{"snp", "--tee", hex.EncodeToString(bytes.Repeat([]byte{0x00}, vtpm.SEVNonce))}, setupMock: func(sdk *mocks.SDK) { }, expectedErr: "Error creating attestation file", @@ -1380,7 +1379,7 @@ func TestContextCancellation(t *testing.T) { cmd.SetOut(&buf) cmd.SetErr(&buf) - teeNonceHex := hex.EncodeToString(bytes.Repeat([]byte{0x00}, quoteprovider.Nonce)) + teeNonceHex := hex.EncodeToString(bytes.Repeat([]byte{0x00}, vtpm.SEVNonce)) cmd.SetArgs([]string{"snp", "--tee", teeNonceHex}) err := cmd.Execute() diff --git a/cmd/attestation-service/main.go b/cmd/attestation-service/main.go index 9e2488b8..5c9ff527 100644 --- a/cmd/attestation-service/main.go +++ b/cmd/attestation-service/main.go @@ -18,7 +18,6 @@ import ( "github.com/ultravioletrs/cocos/pkg/attestation" "github.com/ultravioletrs/cocos/pkg/attestation/azure" "github.com/ultravioletrs/cocos/pkg/attestation/eat" - "github.com/ultravioletrs/cocos/pkg/attestation/quoteprovider" "github.com/ultravioletrs/cocos/pkg/attestation/tdx" "github.com/ultravioletrs/cocos/pkg/attestation/vtpm" "golang.org/x/sync/errgroup" @@ -89,7 +88,7 @@ func main() { } if ccPlatform == attestation.SNP || ccPlatform == attestation.SNPvTPM { - if err := quoteprovider.FetchCertificates(uint(cfg.Vmpl)); err != nil { + if err := vtpm.FetchSEVCertificates(uint(cfg.Vmpl)); err != nil { logger.Error(fmt.Sprintf("failed to fetch certificates: %s", err)) exitCode = 1 return diff --git a/encryption.key b/encryption.key new file mode 100644 index 00000000..840cae66 --- /dev/null +++ b/encryption.key @@ -0,0 +1 @@ +bbf3a1198ee889f77a227fe01e329864fd6a37a2d23135ea8e2c5a2ebc07f0d3 diff --git a/kbs-admin.key b/kbs-admin.key new file mode 100644 index 00000000..62418798 --- /dev/null +++ b/kbs-admin.key @@ -0,0 +1,3 @@ +-----BEGIN PRIVATE KEY----- +MC4CAQAwBQYDK2VwBCIEICgJcXfNueGCu8jFFNGBXm9r25OGBEc0OEqCUVjyI4fY +-----END PRIVATE KEY----- diff --git a/kbs-admin.pub b/kbs-admin.pub new file mode 100644 index 00000000..8fc02406 --- /dev/null +++ b/kbs-admin.pub @@ -0,0 +1,3 @@ +-----BEGIN PUBLIC KEY----- +MCowBQYDK2VwAyEAPbPOfwsJkxpNBluGOg/lgNVE/o0AEM7J11wvkXvHXSw= +-----END PUBLIC KEY----- diff --git a/pkg/attestation/azure/snp.go b/pkg/attestation/azure/snp.go index d2ab3dad..a7ee9707 100644 --- a/pkg/attestation/azure/snp.go +++ b/pkg/attestation/azure/snp.go @@ -21,7 +21,6 @@ import ( "github.com/google/go-tpm-tools/proto/attest" "github.com/ultravioletrs/cocos/pkg/attestation" "github.com/ultravioletrs/cocos/pkg/attestation/eat" - "github.com/ultravioletrs/cocos/pkg/attestation/quoteprovider" "github.com/ultravioletrs/cocos/pkg/attestation/vtpm" "google.golang.org/protobuf/proto" ) @@ -130,7 +129,7 @@ func (a verifier) VerifTeeAttestation(report []byte, teeNonce []byte) error { return errors.Wrap(fmt.Errorf("failed to convert TEE report to proto"), err) } - return quoteprovider.VerifyAttestationReportTLS(attestationReport, teeNonce, a.Policy) + return vtpm.VerifySEVAttestationReportTLS(attestationReport, teeNonce, a.Policy) } func (a verifier) VerifVTpmAttestation(report []byte, vTpmNonce []byte) error { @@ -148,7 +147,7 @@ func (a verifier) VerifyAttestation(report []byte, teeNonce []byte, vTpmNonce [] } snpReport := quote.GetSevSnpAttestation() - if err = quoteprovider.VerifyAttestationReportTLS(snpReport, nil, a.Policy); err != nil { + if err = vtpm.VerifySEVAttestationReportTLS(snpReport, nil, a.Policy); err != nil { return fmt.Errorf("failed to verify vTPM attestation report: %w", err) } @@ -266,7 +265,7 @@ func GenerateAttestationPolicy(token, product string, policy uint64) (*attestati return nil, fmt.Errorf("failed to decode reportID: %w", err) } - sevSnpProduct := quoteprovider.GetProductName(product) + sevSnpProduct := vtpm.GetSEVProductName(product) return &attestation.Config{ Config: &check.Config{ diff --git a/pkg/attestation/quoteprovider/sev_test.go b/pkg/attestation/quoteprovider/sev_test.go deleted file mode 100644 index 975cb10f..00000000 --- a/pkg/attestation/quoteprovider/sev_test.go +++ /dev/null @@ -1,377 +0,0 @@ -// Copyright (c) Ultraviolet -// SPDX-License-Identifier: Apache-2.0 - -//go:build !embed - -package quoteprovider - -import ( - "os" - "path" - "testing" - - "github.com/google/go-sev-guest/proto/check" - "github.com/google/go-sev-guest/proto/sevsnp" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -func TestFillInAttestationLocal(t *testing.T) { - originalHome := os.Getenv("HOME") - defer func() { - os.Setenv("HOME", originalHome) - }() - - tempDir, err := os.MkdirTemp("", "test_home") - require.NoError(t, err) - defer os.RemoveAll(tempDir) - - os.Setenv("HOME", tempDir) - - cocosDir := path.Join(tempDir, cocosDirectory, sevSnpProductMilan) - err = os.MkdirAll(cocosDir, 0o755) - require.NoError(t, err) - - bundleContent := []byte("mock ASK ARK bundle") - bundlePath := path.Join(cocosDir, arkAskBundleName) - err = os.WriteFile(bundlePath, bundleContent, 0o644) - require.NoError(t, err) - - config := &check.Config{ - RootOfTrust: &check.RootOfTrust{ - ProductLine: sevSnpProductMilan, - }, - Policy: &check.Policy{}, - } - - tests := []struct { - name string - attestation *sevsnp.Attestation - setupFunc func() - expectedError bool - errorContains string - }{ - { - name: "Empty attestation - creates new chain", - attestation: &sevsnp.Attestation{ - CertificateChain: nil, - }, - setupFunc: func() {}, - expectedError: true, - errorContains: "could not find ASK or ASVK PEM block; could not find ARK PEM block", - }, - { - name: "Attestation with existing chain - no changes needed", - attestation: &sevsnp.Attestation{ - CertificateChain: &sevsnp.CertificateChain{ - AskCert: []byte("existing ASK cert"), - ArkCert: []byte("existing ARK cert"), - }, - }, - setupFunc: func() {}, - expectedError: false, - }, - { - name: "Attestation with empty chain - tries to load from file", - attestation: &sevsnp.Attestation{ - CertificateChain: &sevsnp.CertificateChain{}, - }, - setupFunc: func() {}, - expectedError: true, - errorContains: "could not find ASK or ASVK PEM block; could not find ARK PEM block", - }, - { - name: "No bundle file exists - no error", - attestation: &sevsnp.Attestation{ - CertificateChain: &sevsnp.CertificateChain{}, - }, - setupFunc: func() { - os.Remove(bundlePath) - }, - expectedError: false, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - os.Setenv("HOME", tempDir) - if _, err := os.Stat(bundlePath); os.IsNotExist(err) { - if err := os.WriteFile(bundlePath, bundleContent, 0o644); err != nil { - t.Fatalf("Failed to write bundle file: %v", err) - } - } - - tt.setupFunc() - - err := fillInAttestationLocal(tt.attestation, config) - - if tt.expectedError { - assert.Error(t, err) - if tt.errorContains != "" { - assert.Contains(t, err.Error(), tt.errorContains) - } - } else { - assert.NoError(t, err) - } - }) - } -} - -func TestGetProductName(t *testing.T) { - tests := []struct { - name string - product string - expected sevsnp.SevProduct_SevProductName - }{ - { - name: "Milan product", - product: sevSnpProductMilan, - expected: sevsnp.SevProduct_SEV_PRODUCT_MILAN, - }, - { - name: "Genoa product", - product: sevSnpProductGenoa, - expected: sevsnp.SevProduct_SEV_PRODUCT_GENOA, - }, - { - name: "Unknown product", - product: "UnknownProduct", - expected: sevsnp.SevProduct_SEV_PRODUCT_UNKNOWN, - }, - { - name: "Empty product", - product: "", - expected: sevsnp.SevProduct_SEV_PRODUCT_UNKNOWN, - }, - { - name: "Case sensitive - milan lowercase", - product: "milan", - expected: sevsnp.SevProduct_SEV_PRODUCT_UNKNOWN, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - result := GetProductName(tt.product) - assert.Equal(t, tt.expected, result) - }) - } -} - -func TestVerifyReport(t *testing.T) { - tests := []struct { - name string - attestation *sevsnp.Attestation - config *check.Config - expectedError bool - errorContains string - }{ - { - name: "Invalid product line", - attestation: &sevsnp.Attestation{ - CertificateChain: &sevsnp.CertificateChain{}, - }, - config: &check.Config{ - RootOfTrust: &check.RootOfTrust{ - ProductLine: "InvalidProduct", - }, - Policy: &check.Policy{}, - }, - expectedError: true, - errorContains: "product name must be", - }, - { - name: "Valid Milan product line", - attestation: &sevsnp.Attestation{ - CertificateChain: &sevsnp.CertificateChain{ - AskCert: []byte("mock ask cert"), - ArkCert: []byte("mock ark cert"), - }, - }, - config: &check.Config{ - RootOfTrust: &check.RootOfTrust{ - ProductLine: sevSnpProductMilan, - }, - Policy: &check.Policy{}, - }, - expectedError: true, - errorContains: "attestation verification failed", - }, - { - name: "Valid Genoa product line", - attestation: &sevsnp.Attestation{ - CertificateChain: &sevsnp.CertificateChain{ - AskCert: []byte("mock ask cert"), - ArkCert: []byte("mock ark cert"), - }, - }, - config: &check.Config{ - RootOfTrust: &check.RootOfTrust{ - ProductLine: sevSnpProductGenoa, - }, - Policy: &check.Policy{}, - }, - expectedError: true, - errorContains: "attestation verification failed", - }, - { - name: "Config with existing product policy", - attestation: &sevsnp.Attestation{ - CertificateChain: &sevsnp.CertificateChain{ - AskCert: []byte("mock ask cert"), - ArkCert: []byte("mock ark cert"), - }, - }, - config: &check.Config{ - RootOfTrust: &check.RootOfTrust{ - ProductLine: sevSnpProductMilan, - }, - Policy: &check.Policy{ - Product: &sevsnp.SevProduct{ - Name: sevsnp.SevProduct_SEV_PRODUCT_MILAN, - }, - }, - }, - expectedError: true, - errorContains: "attestation verification failed", - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - err := verifyReport(tt.attestation, tt.config) - - if tt.expectedError { - assert.Error(t, err) - if tt.errorContains != "" { - assert.Contains(t, err.Error(), tt.errorContains) - } - } else { - assert.NoError(t, err) - } - }) - } -} - -func TestValidateReport(t *testing.T) { - tests := []struct { - name string - attestation *sevsnp.Attestation - config *check.Config - expectedError bool - errorContains string - }{ - { - name: "Basic validation test", - attestation: &sevsnp.Attestation{ - CertificateChain: &sevsnp.CertificateChain{}, - }, - config: &check.Config{ - Policy: &check.Policy{ - Policy: 196608, - }, - }, - expectedError: true, - errorContains: "attestation validation failed", - }, - { - name: "Validation with report data", - attestation: &sevsnp.Attestation{ - CertificateChain: &sevsnp.CertificateChain{}, - }, - config: &check.Config{ - Policy: &check.Policy{ - Policy: 196608, - ReportData: []byte("test report datatest report datatest report datatest report data"), - }, - }, - expectedError: true, - errorContains: "attestation validation failed", - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - err := validateReport(tt.attestation, tt.config) - - if tt.expectedError { - assert.Error(t, err) - if tt.errorContains != "" { - assert.Contains(t, err.Error(), tt.errorContains) - } - } else { - assert.NoError(t, err) - } - }) - } -} - -func TestFetchAttestation(t *testing.T) { - tests := []struct { - name string - reportData []byte - vmpl uint - expectedError bool - errorContains string - }{ - { - name: "Report data too large", - reportData: make([]byte, Nonce+1), - vmpl: 0, - expectedError: true, - errorContains: "could not get quote provider", - }, - { - name: "Valid report data size", - reportData: make([]byte, 32), - vmpl: 0, - expectedError: true, - errorContains: "could not get quote provider", - }, - { - name: "Maximum valid report data size", - reportData: make([]byte, Nonce), - vmpl: 1, - expectedError: true, - errorContains: "could not get quote provider", - }, - { - name: "Empty report data", - reportData: []byte{}, - vmpl: 0, - expectedError: true, - errorContains: "could not get quote provider", - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - result, err := FetchAttestation(tt.reportData, tt.vmpl) - - if tt.expectedError { - assert.Error(t, err) - assert.Empty(t, result) - if tt.errorContains != "" { - assert.Contains(t, err.Error(), tt.errorContains) - } - } else { - assert.NoError(t, err) - assert.NotEmpty(t, result) - } - }) - } -} - -func TestGetLeveledQuoteProvider(t *testing.T) { - t.Run("GetLeveledQuoteProvider call", func(t *testing.T) { - provider, err := GetLeveledQuoteProvider() - - if err != nil { - assert.Error(t, err) - assert.Nil(t, provider) - } else { - assert.NoError(t, err) - assert.NotNil(t, provider) - } - }) -} diff --git a/pkg/attestation/quoteprovider/sev.go b/pkg/attestation/vtpm/sev.go similarity index 78% rename from pkg/attestation/quoteprovider/sev.go rename to pkg/attestation/vtpm/sev.go index 5dcf437c..b23b2ded 100644 --- a/pkg/attestation/quoteprovider/sev.go +++ b/pkg/attestation/vtpm/sev.go @@ -3,7 +3,7 @@ //go:build !embed -package quoteprovider +package vtpm import ( "crypto/rand" @@ -32,7 +32,7 @@ const ( cocosDirectory = "/cocos" arkAskBundleName = "ask_ark.pem" vcekName = "vcek.pem" - Nonce = 64 + SEVNonce = 64 sevSnpProductMilan = "Milan" sevSnpProductGenoa = "Genoa" ) @@ -43,9 +43,9 @@ var ( ) var ( - ErrProductLine = errors.New(fmt.Sprintf("product name must be %s or %s", sevSnpProductMilan, sevSnpProductGenoa)) - ErrAttVerification = errors.New("attestation verification failed") - errAttValidation = errors.New("attestation validation failed") + ErrSEVProductLine = errors.New(fmt.Sprintf("product name must be %s or %s", sevSnpProductMilan, sevSnpProductGenoa)) + ErrSEVAttVerification = errors.New("attestation verification failed") + errSEVAttValidation = errors.New("attestation validation failed") ) func fillInAttestationLocal(attestation *sevsnp.Attestation, cfg *check.Config) error { @@ -77,16 +77,17 @@ func fillInAttestationLocal(attestation *sevsnp.Attestation, cfg *check.Config) return nil } +// verifyReport verifies the SEV-SNP attestation report. func verifyReport(attestationPB *sevsnp.Attestation, cfg *check.Config) error { sopts, err := verify.RootOfTrustToOptions(cfg.RootOfTrust) if err != nil { - return fmt.Errorf("failed to get root of trust options: %v", errors.Wrap(ErrAttVerification, err)) + return fmt.Errorf("failed to get root of trust options: %v", errors.Wrap(ErrSEVAttVerification, err)) } if cfg.Policy.Product == nil { - productName := GetProductName(cfg.RootOfTrust.ProductLine) + productName := GetSEVProductName(cfg.RootOfTrust.ProductLine) if productName == sevsnp.SevProduct_SEV_PRODUCT_UNKNOWN { - return ErrProductLine + return ErrSEVProductLine } sopts.Product = &sevsnp.SevProduct{ @@ -107,30 +108,33 @@ func verifyReport(attestationPB *sevsnp.Attestation, cfg *check.Config) error { } if err := verify.SnpAttestation(attestationPB, sopts); err != nil { - return errors.Wrap(ErrAttVerification, err) + return errors.Wrap(ErrSEVAttVerification, err) } return nil } +// validateReport validates the SEV-SNP attestation report against policy. func validateReport(attestationPB *sevsnp.Attestation, cfg *check.Config) error { opts, err := validate.PolicyToOptions(cfg.Policy) if err != nil { - return fmt.Errorf("failed to get policy for validation: %v", errors.Wrap(ErrAttVerification, err)) + return fmt.Errorf("failed to get policy for validation: %v", errors.Wrap(ErrSEVAttVerification, err)) } if err = validate.SnpAttestation(attestationPB, opts); err != nil { - return errors.Wrap(errAttValidation, err) + return errors.Wrap(errSEVAttValidation, err) } return nil } -func GetLeveledQuoteProvider() (client.LeveledQuoteProvider, error) { +// getLeveledQuoteProvider returns a leveled quote provider for SEV-SNP. +func getLeveledQuoteProvider() (client.LeveledQuoteProvider, error) { return client.GetLeveledQuoteProvider() } -func VerifyAttestationReportTLS(attestationPB *sevsnp.Attestation, reportData []byte, policy *attestation.Config) error { +// VerifySEVAttestationReportTLS verifies a SEV-SNP attestation report for TLS (exported for azure package). +func VerifySEVAttestationReportTLS(attestationPB *sevsnp.Attestation, reportData []byte, policy *attestation.Config) error { config := policy.Config // Certificate chain is populated based on the extra data that is appended to the SEV-SNP attestation report. @@ -141,10 +145,11 @@ func VerifyAttestationReportTLS(attestationPB *sevsnp.Attestation, reportData [] config.Policy.ReportData = reportData[:] } - return VerifyAndValidate(attestationPB, config) + return verifySEVAndValidate(attestationPB, config) } -func VerifyAndValidate(attestationPB *sevsnp.Attestation, cfg *check.Config) error { +// verifySEVAndValidate performs both verification and validation of a SEV-SNP attestation. +func verifySEVAndValidate(attestationPB *sevsnp.Attestation, cfg *check.Config) error { logger.Init("", false, false, io.Discard) if err := verifyReport(attestationPB, cfg); err != nil { @@ -158,15 +163,16 @@ func VerifyAndValidate(attestationPB *sevsnp.Attestation, cfg *check.Config) err return nil } -func FetchAttestation(reportDataSlice []byte, vmpl uint) ([]byte, error) { - var reportData [Nonce]byte +// fetchSEVAttestation fetches a SEV-SNP attestation report. +func fetchSEVAttestation(reportDataSlice []byte, vmpl uint) ([]byte, error) { + var reportData [SEVNonce]byte - qp, err := GetLeveledQuoteProvider() + qp, err := getLeveledQuoteProvider() if err != nil { return []byte{}, fmt.Errorf("could not get quote provider") } - if len(reportData) > Nonce { + if len(reportData) > SEVNonce { return []byte{}, fmt.Errorf("attestation report size mismatch") } copy(reportData[:], reportDataSlice) @@ -206,7 +212,8 @@ func FetchAttestation(reportDataSlice []byte, vmpl uint) ([]byte, error) { return result, nil } -func GetProductName(product string) sevsnp.SevProduct_SevProductName { +// GetSEVProductName maps a product string to a SEV product name. +func GetSEVProductName(product string) sevsnp.SevProduct_SevProductName { switch product { case sevSnpProductMilan: return sevsnp.SevProduct_SEV_PRODUCT_MILAN @@ -217,6 +224,7 @@ func GetProductName(product string) sevsnp.SevProduct_SevProductName { } } +// derToPem converts DER-encoded certificate to PEM format. func derToPem(der []byte) []byte { // Try to parse to make sure it's a certificate if _, err := x509.ParseCertificate(der); err != nil { @@ -226,15 +234,16 @@ func derToPem(der []byte) []byte { return pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: der}) } -func FetchCertificates(vmpl uint) error { - var reportData [Nonce]byte +// FetchSEVCertificates fetches SEV-SNP certificates from KDS. +func FetchSEVCertificates(vmpl uint) error { + var reportData [SEVNonce]byte - qp, err := GetLeveledQuoteProvider() + qp, err := getLeveledQuoteProvider() if err != nil { return fmt.Errorf("could not get quote provider") } - if len(reportData) > Nonce { + if len(reportData) > SEVNonce { return fmt.Errorf("attestation report size mismatch") } diff --git a/pkg/attestation/vtpm/vtpm.go b/pkg/attestation/vtpm/vtpm.go index 827eca5e..a777d671 100644 --- a/pkg/attestation/vtpm/vtpm.go +++ b/pkg/attestation/vtpm/vtpm.go @@ -25,7 +25,6 @@ import ( "github.com/google/go-tpm/tpmutil" "github.com/ultravioletrs/cocos/pkg/attestation" "github.com/ultravioletrs/cocos/pkg/attestation/eat" - "github.com/ultravioletrs/cocos/pkg/attestation/quoteprovider" "golang.org/x/crypto/sha3" "google.golang.org/protobuf/encoding/protojson" "google.golang.org/protobuf/encoding/prototext" @@ -120,7 +119,7 @@ func (v provider) Attestation(teeNonce []byte, vTpmNonce []byte) ([]byte, error) } func (v provider) TeeAttestation(teeNonce []byte) ([]byte, error) { - return quoteprovider.FetchAttestation(teeNonce, v.vmpl) + return fetchSEVAttestation(teeNonce, v.vmpl) } func (v provider) VTpmAttestation(vTpmNonce []byte) ([]byte, error) { @@ -171,7 +170,7 @@ func (v verifier) VerifTeeAttestation(report []byte, teeNonce []byte) error { } attestationReport := sevsnp.Attestation{Report: attestReport, CertificateChain: nil} - return quoteprovider.VerifyAttestationReportTLS(&attestationReport, teeNonce, v.Policy) + return VerifySEVAttestationReportTLS(&attestationReport, teeNonce, v.Policy) } func (v verifier) VerifVTpmAttestation(report []byte, vTpmNonce []byte) error { @@ -234,7 +233,7 @@ func VTPMVerify(quote []byte, teeNonce []byte, vtpmNonce []byte, writer io.Write attestData := sha3.Sum512(nonce) - if err := quoteprovider.VerifyAttestationReportTLS(attestation.GetSevSnpAttestation(), attestData[:], policy); err != nil { + if err := VerifySEVAttestationReportTLS(attestation.GetSevSnpAttestation(), attestData[:], policy); err != nil { return fmt.Errorf("failed to verify TEE attestation report: %v", err) } @@ -336,7 +335,7 @@ func addTEEAttestation(attestation *attest.Attestation, nonce []byte, vmpl uint) attestData := sha3.Sum512(teeNonce) - rawTeeAttestation, err := quoteprovider.FetchAttestation(attestData[:], vmpl) + rawTeeAttestation, err := fetchSEVAttestation(attestData[:], vmpl) if err != nil { return fmt.Errorf("failed to fetch TEE attestation report: %v", err) } diff --git a/pkg/attestation/vtpm/vtpm_test.go b/pkg/attestation/vtpm/vtpm_test.go index 955664e9..16043a48 100644 --- a/pkg/attestation/vtpm/vtpm_test.go +++ b/pkg/attestation/vtpm/vtpm_test.go @@ -22,12 +22,9 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/ultravioletrs/cocos/pkg/attestation" - "github.com/ultravioletrs/cocos/pkg/attestation/quoteprovider" "google.golang.org/protobuf/encoding/protojson" ) -const sevSnpProductMilan = "Milan" - var policy = attestation.Config{Config: &check.Config{Policy: &check.Policy{}, RootOfTrust: &check.RootOfTrust{}}, PcrConfig: &attestation.PcrConfig{}} type mockTPM struct { @@ -679,13 +676,13 @@ func TestVerifyAttestationReportMalformedSignature(t *testing.T) { name: "Valid attestation, distorted signature", attestationReport: attestationPB, reportData: reportData, - err: quoteprovider.ErrAttVerification, + err: ErrSEVAttVerification, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - err := quoteprovider.VerifyAttestationReportTLS(tt.attestationReport, tt.reportData, &policy) + err := VerifySEVAttestationReportTLS(tt.attestationReport, tt.reportData, &policy) assert.True(t, errors.Contains(err, tt.err), fmt.Sprintf("expected error %v, got %v", tt.err, err)) }) } @@ -713,13 +710,13 @@ func TestVerifyAttestationReportUnknownProduct(t *testing.T) { name: "Valid attestation, unknown product", attestationReport: attestationPB, reportData: reportData, - err: quoteprovider.ErrProductLine, + err: ErrSEVProductLine, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - err := quoteprovider.VerifyAttestationReportTLS(tt.attestationReport, tt.reportData, &policy) + err := VerifySEVAttestationReportTLS(tt.attestationReport, tt.reportData, &policy) assert.True(t, errors.Contains(err, tt.err), fmt.Sprintf("expected error %v, got %v", tt.err, err)) }) } @@ -752,7 +749,7 @@ func TestVerifyAttestationReportSuccess(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - err := quoteprovider.VerifyAttestationReportTLS(tt.attestationReport, tt.reportData, &policy) + err := VerifySEVAttestationReportTLS(tt.attestationReport, tt.reportData, &policy) assert.True(t, errors.Contains(err, tt.err), fmt.Sprintf("expected error %v, got %v", tt.err, err)) }) } @@ -780,13 +777,13 @@ func TestVerifyAttestationReportMalformedPolicy(t *testing.T) { name: "Valid attestation, malformed policy (measurement)", attestationReport: attestationPB, reportData: reportData, - err: quoteprovider.ErrAttVerification, + err: ErrSEVAttVerification, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - err := quoteprovider.VerifyAttestationReportTLS(tt.attestationReport, tt.reportData, &policy) + err := VerifySEVAttestationReportTLS(tt.attestationReport, tt.reportData, &policy) assert.True(t, errors.Contains(err, tt.err), fmt.Sprintf("expected error %v, got %v", tt.err, err)) }) } diff --git a/pkg/clients/doc.go b/pkg/clients/doc.go index 1b921bd5..e949c6e2 100644 --- a/pkg/clients/doc.go +++ b/pkg/clients/doc.go @@ -2,5 +2,5 @@ // SPDX-License-Identifier: Apache-2.0 // Package clients contains the domain concept definitions needed to support -// HTTP/gRPC Client functionality. +// client configuration for HTTP/gRPC connections. package clients diff --git a/pkg/clients/grpc/agent/agent.go b/pkg/clients/grpc/agent/agent.go index 6408282c..c98d5643 100644 --- a/pkg/clients/grpc/agent/agent.go +++ b/pkg/clients/grpc/agent/agent.go @@ -9,6 +9,7 @@ import ( "github.com/ultravioletrs/cocos/agent" "github.com/ultravioletrs/cocos/pkg/clients" "github.com/ultravioletrs/cocos/pkg/clients/grpc" + "github.com/ultravioletrs/cocos/pkg/tls" grpchealth "google.golang.org/grpc/health/grpc_health_v1" ) @@ -21,7 +22,7 @@ func NewAgentClient(ctx context.Context, cfg clients.AttestedClientConfig) (grpc return nil, nil, err } - if client.Secure() != clients.WithMATLS.String() && client.Secure() != clients.WithATLS.String() && client.Secure() != clients.WithTLS.String() { + if client.Secure() != tls.WithMATLS.String() && client.Secure() != tls.WithATLS.String() && client.Secure() != tls.WithTLS.String() { health := grpchealth.NewHealthClient(client.Connection()) resp, err := health.Check(ctx, &grpchealth.HealthCheckRequest{ Service: "agent", diff --git a/pkg/clients/grpc/connect_test.go b/pkg/clients/grpc/connect_test.go index 2a3d38a4..f3fe8e7b 100644 --- a/pkg/clients/grpc/connect_test.go +++ b/pkg/clients/grpc/connect_test.go @@ -22,6 +22,7 @@ import ( "github.com/ultravioletrs/cocos/pkg/attestation" "github.com/ultravioletrs/cocos/pkg/attestation/vtpm" "github.com/ultravioletrs/cocos/pkg/clients" + "github.com/ultravioletrs/cocos/pkg/tls" ) func TestNewClient(t *testing.T) { @@ -119,7 +120,7 @@ func TestNewClient(t *testing.T) { ServerCAFile: "nonexistent.pem", }, wantErr: true, - err: clients.ErrFailedToLoadRootCA, + err: tls.ErrFailedToLoadRootCA, }, { name: "Fail with invalid ClientCert", @@ -130,7 +131,7 @@ func TestNewClient(t *testing.T) { ClientKey: clientKeyFile, }, wantErr: true, - err: clients.ErrFailedToLoadClientCertKey, + err: tls.ErrFailedToLoadClientCertKey, }, { name: "Fail with invalid ClientKey", @@ -141,7 +142,7 @@ func TestNewClient(t *testing.T) { ClientKey: "nonexistent.pem", }, wantErr: true, - err: clients.ErrFailedToLoadClientCertKey, + err: tls.ErrFailedToLoadClientCertKey, }, } @@ -169,32 +170,32 @@ func TestNewClient(t *testing.T) { func TestClientSecure(t *testing.T) { tests := []struct { name string - secure clients.Security + secure tls.Security expected string }{ { name: "Without TLS", - secure: clients.WithoutTLS, + secure: tls.WithoutTLS, expected: "without TLS", }, { name: "With TLS", - secure: clients.WithTLS, + secure: tls.WithTLS, expected: "with TLS", }, { name: "With mTLS", - secure: clients.WithMTLS, + secure: tls.WithMTLS, expected: "with mTLS", }, { name: "With aTLS", - secure: clients.WithATLS, + secure: tls.WithATLS, expected: "with aTLS", }, { name: "With maTLS", - secure: clients.WithMATLS, + secure: tls.WithMATLS, expected: "with maTLS", }, } diff --git a/pkg/clients/grpc/grpc.go b/pkg/clients/grpc/grpc.go index 76806c49..b68e7414 100644 --- a/pkg/clients/grpc/grpc.go +++ b/pkg/clients/grpc/grpc.go @@ -6,6 +6,7 @@ package grpc import ( "github.com/absmach/supermq/pkg/errors" "github.com/ultravioletrs/cocos/pkg/clients" + "github.com/ultravioletrs/cocos/pkg/tls" "go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc" "google.golang.org/grpc" "google.golang.org/grpc/credentials" @@ -26,7 +27,7 @@ type Client interface { type client struct { *grpc.ClientConn cfg clients.ClientConfiguration - security clients.Security + security tls.Security } var _ Client = (*client)(nil) @@ -59,14 +60,19 @@ func (c *client) Connection() *grpc.ClientConn { return c.ClientConn } -func connect(cfg clients.ClientConfiguration) (*grpc.ClientConn, clients.Security, error) { +func connect(cfg clients.ClientConfiguration) (*grpc.ClientConn, tls.Security, error) { opts := []grpc.DialOption{ grpc.WithStatsHandler(otelgrpc.NewClientHandler()), } - security := clients.WithoutTLS + security := tls.WithoutTLS if agcfg, ok := cfg.(clients.AttestedClientConfig); ok && agcfg.AttestedTLS { - result, err := clients.LoadATLSConfig(agcfg) + result, err := tls.LoadATLSConfig( + agcfg.AttestationPolicy, + agcfg.ServerCAFile, + agcfg.ClientCert, + agcfg.ClientKey, + ) if err != nil { return nil, security, err } @@ -90,13 +96,13 @@ func connect(cfg clients.ClientConfiguration) (*grpc.ClientConn, clients.Securit return conn, security, nil } -func loadTLSConfig(serverCAFile, clientCert, clientKey string) (credentials.TransportCredentials, clients.Security, error) { - result, err := clients.LoadBasicTLSConfig(serverCAFile, clientCert, clientKey) +func loadTLSConfig(serverCAFile, clientCert, clientKey string) (credentials.TransportCredentials, tls.Security, error) { + result, err := tls.LoadBasicConfig(serverCAFile, clientCert, clientKey) if err != nil { - return nil, clients.WithoutTLS, err + return nil, tls.WithoutTLS, err } - if result.Security == clients.WithoutTLS || result.Config == nil { + if result.Security == tls.WithoutTLS || result.Config == nil { return insecure.NewCredentials(), result.Security, nil } diff --git a/pkg/clients/http/client.go b/pkg/clients/http/client.go index c1f9759f..97fd3933 100644 --- a/pkg/clients/http/client.go +++ b/pkg/clients/http/client.go @@ -8,6 +8,7 @@ import ( "time" "github.com/ultravioletrs/cocos/pkg/clients" + "github.com/ultravioletrs/cocos/pkg/tls" ) type Client interface { @@ -19,7 +20,7 @@ type Client interface { type client struct { transport *http.Transport cfg clients.ClientConfiguration - security clients.Security + security tls.Security } var _ Client = (*client)(nil) @@ -49,17 +50,22 @@ func (c *client) Timeout() time.Duration { return c.cfg.Config().Timeout } -func createTransport(cfg clients.ClientConfiguration) (*http.Transport, clients.Security, error) { +func createTransport(cfg clients.ClientConfiguration) (*http.Transport, tls.Security, error) { transport := &http.Transport{ MaxIdleConns: 100, IdleConnTimeout: 90 * time.Second, TLSHandshakeTimeout: 10 * time.Second, } - security := clients.WithoutTLS + security := tls.WithoutTLS if agcfg, ok := cfg.(*clients.AttestedClientConfig); ok && agcfg.AttestedTLS { - result, err := clients.LoadATLSConfig(*agcfg) + result, err := tls.LoadATLSConfig( + agcfg.AttestationPolicy, + agcfg.ServerCAFile, + agcfg.ClientCert, + agcfg.ClientKey, + ) if err != nil { return nil, security, err } @@ -69,12 +75,12 @@ func createTransport(cfg clients.ClientConfiguration) (*http.Transport, clients. } else { conf := cfg.Config() - result, err := clients.LoadBasicTLSConfig(conf.ServerCAFile, conf.ClientCert, conf.ClientKey) + result, err := tls.LoadBasicConfig(conf.ServerCAFile, conf.ClientCert, conf.ClientKey) if err != nil { return nil, security, err } - if result.Security != clients.WithoutTLS { + if result.Security != tls.WithoutTLS { transport.TLSClientConfig = result.Config } diff --git a/pkg/clients/http/client_test.go b/pkg/clients/http/client_test.go index 85bbd90f..3686b8c9 100644 --- a/pkg/clients/http/client_test.go +++ b/pkg/clients/http/client_test.go @@ -10,6 +10,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/ultravioletrs/cocos/pkg/clients" + "github.com/ultravioletrs/cocos/pkg/tls" ) func TestConfig_Configuration(t *testing.T) { @@ -144,7 +145,7 @@ func TestClient_Secure(t *testing.T) { URL: "http://localhost:8080", Timeout: 30 * time.Second, }, - expected: clients.WithoutTLS.String(), + expected: tls.WithoutTLS.String(), }, } @@ -183,7 +184,7 @@ func TestCreateTransport_DefaultSettings(t *testing.T) { assert.NoError(t, err) assert.NotNil(t, transport) - assert.Equal(t, clients.WithoutTLS, security) + assert.Equal(t, tls.WithoutTLS, security) assert.Equal(t, 100, transport.MaxIdleConns) assert.Equal(t, 90*time.Second, transport.IdleConnTimeout) assert.Equal(t, 10*time.Second, transport.TLSHandshakeTimeout) @@ -205,7 +206,7 @@ func TestCreateTransport_ATLSError(t *testing.T) { assert.Error(t, err) assert.Nil(t, transport) - assert.Equal(t, clients.WithoutTLS, security) + assert.Equal(t, tls.WithoutTLS, security) assert.Contains(t, err.Error(), "failed to stat attestation policy") } @@ -220,7 +221,7 @@ func TestCreateTransport_BasicTLSError(t *testing.T) { assert.Error(t, err) assert.Nil(t, transport) - assert.Equal(t, clients.WithoutTLS, security) + assert.Equal(t, tls.WithoutTLS, security) assert.Contains(t, err.Error(), "failed to load root ca file") } diff --git a/pkg/sdk/agent_test.go b/pkg/sdk/agent_test.go index f20a4ddc..5d19530b 100644 --- a/pkg/sdk/agent_test.go +++ b/pkg/sdk/agent_test.go @@ -19,7 +19,6 @@ import ( "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" "github.com/ultravioletrs/cocos/agent" - "github.com/ultravioletrs/cocos/pkg/attestation/quoteprovider" "github.com/ultravioletrs/cocos/pkg/attestation/vtpm" "github.com/ultravioletrs/cocos/pkg/sdk" "golang.org/x/crypto/sha3" @@ -388,7 +387,7 @@ func TestAttestation(t *testing.T) { cases := []struct { name string userKey any - reportData [quoteprovider.Nonce]byte + reportData [vtpm.SEVNonce]byte nonce [vtpm.Nonce]byte response *agent.AttestationResponse svcRes []byte @@ -397,7 +396,7 @@ func TestAttestation(t *testing.T) { { name: "fetch attestation report successfully", userKey: resultConsumerKey, - reportData: [quoteprovider.Nonce]byte(reportData), + reportData: [vtpm.SEVNonce]byte(reportData), nonce: [vtpm.Nonce]byte(nonce), response: &agent.AttestationResponse{ File: report, @@ -408,7 +407,7 @@ func TestAttestation(t *testing.T) { { name: "fetch attestation report with different key type", userKey: resultConsumer1Key, - reportData: [quoteprovider.Nonce]byte(reportData), + reportData: [vtpm.SEVNonce]byte(reportData), nonce: [vtpm.Nonce]byte(nonce), response: &agent.AttestationResponse{ File: report, @@ -419,7 +418,7 @@ func TestAttestation(t *testing.T) { { name: "failed to fetch attestation report", userKey: resultConsumerKey, - reportData: [quoteprovider.Nonce]byte(reportData), + reportData: [vtpm.SEVNonce]byte(reportData), nonce: [vtpm.Nonce]byte(nonce), response: &agent.AttestationResponse{ File: []byte{}, @@ -429,7 +428,7 @@ func TestAttestation(t *testing.T) { { name: "invalid report data", userKey: resultConsumerKey, - reportData: [quoteprovider.Nonce]byte{}, + reportData: [vtpm.SEVNonce]byte{}, nonce: [vtpm.Nonce]byte(nonce), response: &agent.AttestationResponse{ File: []byte{}, diff --git a/pkg/tls/doc.go b/pkg/tls/doc.go new file mode 100644 index 00000000..fce8e36e --- /dev/null +++ b/pkg/tls/doc.go @@ -0,0 +1,6 @@ +// Copyright (c) Ultraviolet +// SPDX-License-Identifier: Apache-2.0 + +// Package tls provides TLS configuration utilities for client connections. +// It supports standard TLS, mutual TLS (mTLS), and attested TLS (aTLS). +package tls diff --git a/pkg/clients/tls.go b/pkg/tls/tls.go similarity index 80% rename from pkg/clients/tls.go rename to pkg/tls/tls.go index 7538b736..51bd4d8d 100644 --- a/pkg/clients/tls.go +++ b/pkg/tls/tls.go @@ -1,7 +1,7 @@ // Copyright (c) Ultraviolet // SPDX-License-Identifier: Apache-2.0 -package clients +package tls import ( "crypto/rand" @@ -53,20 +53,20 @@ var ( errAttestationPolicyIrregular = errors.New("attestation policy file is not a regular file") ) -// TLSResult contains the result of TLS configuration. -type TLSResult struct { +// Result contains the result of TLS configuration. +type Result struct { Config *tls.Config Security Security } -// LoadBasicTLSConfig loads standard TLS configuration (TLS/mTLS). -func LoadBasicTLSConfig(serverCAFile, clientCert, clientKey string) (*TLSResult, error) { +// LoadBasicConfig loads standard TLS configuration (TLS/mTLS). +func LoadBasicConfig(serverCAFile, clientCert, clientKey string) (*Result, error) { tlsConfig := &tls.Config{} security := WithoutTLS // If no TLS configuration is provided, return nil config (no TLS) if serverCAFile == "" && clientCert == "" && clientKey == "" { - return &TLSResult{Config: nil, Security: security}, nil + return &Result{Config: nil, Security: security}, nil } if serverCAFile != "" { @@ -96,14 +96,15 @@ func LoadBasicTLSConfig(serverCAFile, clientCert, clientKey string) (*TLSResult, security = WithMTLS } - return &TLSResult{Config: tlsConfig, Security: security}, nil + return &Result{Config: tlsConfig, Security: security}, nil } // LoadATLSConfig configures Attested TLS. -func LoadATLSConfig(cfg AttestedClientConfig) (*TLSResult, error) { +// Parameters are passed individually to avoid circular dependencies with the clients package. +func LoadATLSConfig(attestationPolicy, serverCAFile, clientCert, clientKey string) (*Result, error) { security := WithATLS - info, err := os.Stat(cfg.AttestationPolicy) + info, err := os.Stat(attestationPolicy) if err != nil { return nil, errors.Wrap(errors.New("failed to stat attestation policy file"), err) } @@ -112,12 +113,12 @@ func LoadATLSConfig(cfg AttestedClientConfig) (*TLSResult, error) { return nil, errAttestationPolicyIrregular } - attestation.AttestationPolicyPath = cfg.AttestationPolicy + attestation.AttestationPolicyPath = attestationPolicy var rootCAs *x509.CertPool - if cfg.ServerCAFile != "" { - rootCAs, err = loadRootCAs(cfg.ServerCAFile) + if serverCAFile != "" { + rootCAs, err = loadRootCAs(serverCAFile) if err != nil { return nil, err } @@ -141,8 +142,8 @@ func LoadATLSConfig(cfg AttestedClientConfig) (*TLSResult, error) { }, } - if cfg.ClientCert != "" || cfg.ClientKey != "" { - certificate, err := tls.LoadX509KeyPair(cfg.ClientCert, cfg.ClientKey) + if clientCert != "" || clientKey != "" { + certificate, err := tls.LoadX509KeyPair(clientCert, clientKey) if err != nil { return nil, errors.Wrap(ErrFailedToLoadClientCertKey, err) } @@ -150,7 +151,7 @@ func LoadATLSConfig(cfg AttestedClientConfig) (*TLSResult, error) { tlsConfig.Certificates = []tls.Certificate{certificate} } - return &TLSResult{Config: tlsConfig, Security: security}, nil + return &Result{Config: tlsConfig, Security: security}, nil } // loadRootCAs loads root CA certificates from a file. diff --git a/pkg/clients/tls_test.go b/pkg/tls/tls_test.go similarity index 77% rename from pkg/clients/tls_test.go rename to pkg/tls/tls_test.go index b9e82a66..d37d3e98 100644 --- a/pkg/clients/tls_test.go +++ b/pkg/tls/tls_test.go @@ -1,7 +1,7 @@ // Copyright (c) Ultraviolet // SPDX-License-Identifier: Apache-2.0 -package clients +package tls import ( "crypto/rand" @@ -64,7 +64,7 @@ func TestSecurity_String(t *testing.T) { } } -func TestLoadBasicTLSConfig(t *testing.T) { +func TestLoadBasicConfig(t *testing.T) { // Create temporary directory for test files tmpDir := t.TempDir() @@ -164,7 +164,7 @@ func TestLoadBasicTLSConfig(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - result, err := LoadBasicTLSConfig(tt.serverCAFile, tt.clientCert, tt.clientKey) + result, err := LoadBasicConfig(tt.serverCAFile, tt.clientCert, tt.clientKey) if tt.expectError { assert.Error(t, err) @@ -201,100 +201,86 @@ func TestLoadATLSConfig(t *testing.T) { require.NoError(t, os.WriteFile(policyFile, []byte(`{"policy": "test"}`), 0o644)) tests := []struct { - name string - config AttestedClientConfig - expectedSec Security - expectError bool - errorMsg string + name string + attestationPolicy string + serverCAFile string + clientCert string + clientKey string + expectedSec Security + expectError bool + errorMsg string }{ { - name: "ValidATLSConfig", - config: AttestedClientConfig{ - StandardClientConfig: StandardClientConfig{ - ServerCAFile: "", - }, - AttestationPolicy: policyFile, - ProductName: "test-product", - }, - expectedSec: WithATLS, - expectError: false, + name: "ValidATLSConfig", + attestationPolicy: policyFile, + serverCAFile: "", + clientCert: "", + clientKey: "", + expectedSec: WithATLS, + expectError: false, }, { - name: "ValidMATLSConfig", - config: AttestedClientConfig{ - StandardClientConfig: StandardClientConfig{ - ServerCAFile: caFile, - }, - AttestationPolicy: policyFile, - ProductName: "test-product", - }, - expectedSec: WithMATLS, - expectError: false, + name: "ValidMATLSConfig", + attestationPolicy: policyFile, + serverCAFile: caFile, + clientCert: "", + clientKey: "", + expectedSec: WithMATLS, + expectError: false, }, { - name: "ValidATLSWithClientCert", - config: AttestedClientConfig{ - StandardClientConfig: StandardClientConfig{ - ClientCert: certFile, - ClientKey: keyFile, - }, - AttestationPolicy: policyFile, - ProductName: "test-product", - }, - expectedSec: WithATLS, - expectError: false, + name: "ValidATLSWithClientCert", + attestationPolicy: policyFile, + serverCAFile: "", + clientCert: certFile, + clientKey: keyFile, + expectedSec: WithATLS, + expectError: false, }, { - name: "NonexistentPolicyFile", - config: AttestedClientConfig{ - AttestationPolicy: filepath.Join(tmpDir, "nonexistent.json"), - ProductName: "test-product", - }, - expectedSec: WithoutTLS, - expectError: true, - errorMsg: "failed to stat attestation policy file", + name: "NonexistentPolicyFile", + attestationPolicy: filepath.Join(tmpDir, "nonexistent.json"), + serverCAFile: "", + clientCert: "", + clientKey: "", + expectedSec: WithoutTLS, + expectError: true, + errorMsg: "failed to stat attestation policy file", }, { - name: "PolicyFileIsDirectory", - config: AttestedClientConfig{ - AttestationPolicy: tmpDir, // Directory instead of file - ProductName: "test-product", - }, - expectedSec: WithoutTLS, - expectError: true, - errorMsg: "attestation policy file is not a regular file", + name: "PolicyFileIsDirectory", + attestationPolicy: tmpDir, // Directory instead of file + serverCAFile: "", + clientCert: "", + clientKey: "", + expectedSec: WithoutTLS, + expectError: true, + errorMsg: "attestation policy file is not a regular file", }, { - name: "InvalidCAFile", - config: AttestedClientConfig{ - StandardClientConfig: StandardClientConfig{ - ServerCAFile: filepath.Join(tmpDir, "nonexistent.crt"), - }, - AttestationPolicy: policyFile, - ProductName: "test-product", - }, - expectedSec: WithoutTLS, - expectError: true, - errorMsg: "failed to read certificate file", + name: "InvalidCAFile", + attestationPolicy: policyFile, + serverCAFile: filepath.Join(tmpDir, "nonexistent.crt"), + clientCert: "", + clientKey: "", + expectedSec: WithoutTLS, + expectError: true, + errorMsg: "failed to read certificate file", }, { - name: "InvalidClientCert", - config: AttestedClientConfig{ - StandardClientConfig: StandardClientConfig{ - ClientCert: filepath.Join(tmpDir, "nonexistent.crt"), - ClientKey: keyFile, - }, - AttestationPolicy: policyFile, - ProductName: "test-product", - }, - expectedSec: WithoutTLS, - expectError: true, + name: "InvalidClientCert", + attestationPolicy: policyFile, + serverCAFile: "", + clientCert: filepath.Join(tmpDir, "nonexistent.crt"), + clientKey: keyFile, + expectedSec: WithoutTLS, + expectError: true, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - result, err := LoadATLSConfig(tt.config) + result, err := LoadATLSConfig(tt.attestationPolicy, tt.serverCAFile, tt.clientCert, tt.clientKey) if tt.expectError { assert.Error(t, err)