COCOS-525-487 - Refactor attestation and atls (#562)

* Refactor attestation handling to remove quoteprovider dependency

- Removed references to quoteprovider in various files, replacing them with vtpm where necessary.
- Updated function signatures and implementations to use SEVNonce instead of quoteprovider.Nonce.
- Introduced new vtpm package to handle SEV-related attestation logic, including fetching and verifying attestation reports.
- Adjusted tests to reflect changes in the attestation logic and ensure compatibility with the new structure.
- Deleted the now redundant quoteprovider/sev_test.go file.

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* fix: Add veraison/go-cose dependency to go.mod

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* feat: Introduce TLS package for enhanced security configuration and refactor client code to utilize new TLS utilities

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

---------

Signed-off-by: Sammy Oina <sammyoina@gmail.com>
This commit is contained in:
Sammy Kerata Oina
2026-02-18 13:53:04 +03:00
committed by GitHub
parent de50b6d2d4
commit 207bfd99af
29 changed files with 222 additions and 591 deletions
+1 -2
View File
@@ -6,7 +6,6 @@ import (
"errors"
"github.com/ultravioletrs/cocos/pkg/attestation"
"github.com/ultravioletrs/cocos/pkg/attestation/quoteprovider"
"github.com/ultravioletrs/cocos/pkg/attestation/vtpm"
)
@@ -42,7 +41,7 @@ func (req resultReq) validate() error {
}
type attestationReq struct {
TeeNonce [quoteprovider.Nonce]byte
TeeNonce [vtpm.SEVNonce]byte
VtpmNonce [vtpm.Nonce]byte
AttType attestation.PlatformType
}
+4 -5
View File
@@ -14,7 +14,6 @@ import (
"github.com/go-kit/kit/transport/grpc"
"github.com/ultravioletrs/cocos/agent"
"github.com/ultravioletrs/cocos/pkg/attestation"
"github.com/ultravioletrs/cocos/pkg/attestation/quoteprovider"
"github.com/ultravioletrs/cocos/pkg/attestation/vtpm"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/metadata"
@@ -134,7 +133,7 @@ func encodeResultResponse(_ context.Context, response any) (any, error) {
func validateNonce(nonce []byte, maxLen int, target any) error {
if len(nonce) > maxLen {
switch maxLen {
case quoteprovider.Nonce:
case vtpm.SEVNonce:
return ErrTEENonceLength
case vtpm.Nonce:
return ErrVTPMNonceLength
@@ -144,7 +143,7 @@ func validateNonce(nonce []byte, maxLen int, target any) error {
}
switch t := target.(type) {
case *[quoteprovider.Nonce]byte:
case *[vtpm.SEVNonce]byte:
copy(t[:], nonce)
case *[vtpm.Nonce]byte:
copy(t[:], nonce)
@@ -156,10 +155,10 @@ func validateNonce(nonce []byte, maxLen int, target any) error {
func decodeAttestationRequest(_ context.Context, grpcReq any) (any, error) {
req := grpcReq.(*agent.AttestationRequest)
var reportData [quoteprovider.Nonce]byte
var reportData [vtpm.SEVNonce]byte
var nonce [vtpm.Nonce]byte
if err := validateNonce(req.TeeNonce, quoteprovider.Nonce, &reportData); err != nil {
if err := validateNonce(req.TeeNonce, vtpm.SEVNonce, &reportData); err != nil {
return nil, err
}
+9 -10
View File
@@ -12,7 +12,6 @@ import (
"github.com/ultravioletrs/cocos/agent"
"github.com/ultravioletrs/cocos/agent/mocks"
"github.com/ultravioletrs/cocos/pkg/attestation"
"github.com/ultravioletrs/cocos/pkg/attestation/quoteprovider"
"github.com/ultravioletrs/cocos/pkg/attestation/vtpm"
"google.golang.org/grpc"
"google.golang.org/grpc/metadata"
@@ -229,7 +228,7 @@ func TestAttestation(t *testing.T) {
return len(resp.File) > 0
})).Return(nil).Once()
reportData := [quoteprovider.Nonce]byte{}
reportData := [vtpm.SEVNonce]byte{}
vtpmNonce := [vtpm.Nonce]byte{}
attestationType := attestation.SNP
mockService.On("Attestation", mock.Anything, reportData, vtpmNonce, attestationType).Return(attestationData, nil)
@@ -298,8 +297,8 @@ func TestValidateNonce(t *testing.T) {
}{
{
name: "valid TEE nonce",
nonce: make([]byte, quoteprovider.Nonce),
maxLen: quoteprovider.Nonce,
nonce: make([]byte, vtpm.SEVNonce),
maxLen: vtpm.SEVNonce,
shouldError: false,
},
{
@@ -310,8 +309,8 @@ func TestValidateNonce(t *testing.T) {
},
{
name: "TEE nonce too long",
nonce: make([]byte, quoteprovider.Nonce+1),
maxLen: quoteprovider.Nonce,
nonce: make([]byte, vtpm.SEVNonce+1),
maxLen: vtpm.SEVNonce,
shouldError: true,
expectedErr: ErrTEENonceLength,
},
@@ -326,8 +325,8 @@ func TestValidateNonce(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if tt.maxLen == quoteprovider.Nonce {
var target [quoteprovider.Nonce]byte
if tt.maxLen == vtpm.SEVNonce {
var target [vtpm.SEVNonce]byte
err := validateNonce(tt.nonce, tt.maxLen, &target)
if tt.shouldError {
assert.Error(t, err)
@@ -388,7 +387,7 @@ func TestEncodeResultResponse(t *testing.T) {
}
func TestDecodeAttestationRequest(t *testing.T) {
teeNonce := make([]byte, quoteprovider.Nonce)
teeNonce := make([]byte, vtpm.SEVNonce)
vtpmNonce := make([]byte, vtpm.Nonce)
req := &agent.AttestationRequest{
@@ -406,7 +405,7 @@ func TestDecodeAttestationRequest(t *testing.T) {
func TestDecodeAttestationRequestWithInvalidNonce(t *testing.T) {
// Test with TEE nonce too long
teeNonce := make([]byte, quoteprovider.Nonce+1)
teeNonce := make([]byte, vtpm.SEVNonce+1)
req := &agent.AttestationRequest{TeeNonce: teeNonce}
_, err := decodeAttestationRequest(context.Background(), req)
+1 -2
View File
@@ -13,7 +13,6 @@ import (
"github.com/ultravioletrs/cocos/agent"
"github.com/ultravioletrs/cocos/pkg/attestation"
"github.com/ultravioletrs/cocos/pkg/attestation/quoteprovider"
"github.com/ultravioletrs/cocos/pkg/attestation/vtpm"
)
@@ -105,7 +104,7 @@ func (lm *loggingMiddleware) Result(ctx context.Context) (response []byte, err e
return lm.svc.Result(ctx)
}
func (lm *loggingMiddleware) Attestation(ctx context.Context, reportData [quoteprovider.Nonce]byte, nonce [vtpm.Nonce]byte, attType attestation.PlatformType) (response []byte, err error) {
func (lm *loggingMiddleware) Attestation(ctx context.Context, reportData [vtpm.SEVNonce]byte, nonce [vtpm.Nonce]byte, attType attestation.PlatformType) (response []byte, err error) {
defer func(begin time.Time) {
message := fmt.Sprintf("Method Attestation took %s to complete", time.Since(begin))
if err != nil {
+1 -2
View File
@@ -12,7 +12,6 @@ import (
"github.com/go-kit/kit/metrics"
"github.com/ultravioletrs/cocos/agent"
"github.com/ultravioletrs/cocos/pkg/attestation"
"github.com/ultravioletrs/cocos/pkg/attestation/quoteprovider"
"github.com/ultravioletrs/cocos/pkg/attestation/vtpm"
)
@@ -91,7 +90,7 @@ func (ms *metricsMiddleware) Result(ctx context.Context) ([]byte, error) {
return ms.svc.Result(ctx)
}
func (ms *metricsMiddleware) Attestation(ctx context.Context, reportData [quoteprovider.Nonce]byte, nonce [vtpm.Nonce]byte, attType attestation.PlatformType) ([]byte, error) {
func (ms *metricsMiddleware) Attestation(ctx context.Context, reportData [vtpm.SEVNonce]byte, nonce [vtpm.Nonce]byte, attType attestation.PlatformType) ([]byte, error) {
defer func(begin time.Time) {
ms.counter.With("method", "attestation").Add(1)
ms.latency.With("method", "attestation").Observe(time.Since(begin).Seconds())
+2 -3
View File
@@ -21,7 +21,6 @@ import (
"github.com/ultravioletrs/cocos/agent/statemachine"
"github.com/ultravioletrs/cocos/internal"
"github.com/ultravioletrs/cocos/pkg/attestation"
"github.com/ultravioletrs/cocos/pkg/attestation/quoteprovider"
"github.com/ultravioletrs/cocos/pkg/attestation/vtpm"
attestation_client "github.com/ultravioletrs/cocos/pkg/clients/grpc/attestation"
runner_client "github.com/ultravioletrs/cocos/pkg/clients/grpc/runner"
@@ -120,7 +119,7 @@ type Service interface {
Algo(ctx context.Context, algorithm Algorithm) error
Data(ctx context.Context, dataset Dataset) error
Result(ctx context.Context) ([]byte, error)
Attestation(ctx context.Context, reportData [quoteprovider.Nonce]byte, nonce [vtpm.Nonce]byte, attType attestation.PlatformType) ([]byte, error)
Attestation(ctx context.Context, reportData [vtpm.SEVNonce]byte, nonce [vtpm.Nonce]byte, attType attestation.PlatformType) ([]byte, error)
IMAMeasurements(ctx context.Context) ([]byte, []byte, error)
AzureAttestationToken(ctx context.Context, nonce [vtpm.Nonce]byte) ([]byte, error)
State() string
@@ -418,7 +417,7 @@ func (as *agentService) Result(ctx context.Context) ([]byte, error) {
return as.result, as.runError
}
func (as *agentService) Attestation(ctx context.Context, reportData [quoteprovider.Nonce]byte, nonce [vtpm.Nonce]byte, attType attestation.PlatformType) ([]byte, error) {
func (as *agentService) Attestation(ctx context.Context, reportData [vtpm.SEVNonce]byte, nonce [vtpm.Nonce]byte, attType attestation.PlatformType) ([]byte, error) {
rawQuote, err := as.attestationClient.GetAttestation(ctx, reportData, nonce, attType)
if err != nil {
return []byte{}, errors.Wrap(ErrAttestationFailed, err)
+3 -4
View File
@@ -24,7 +24,6 @@ import (
"github.com/ultravioletrs/cocos/agent/statemachine"
smmocks "github.com/ultravioletrs/cocos/agent/statemachine/mocks"
"github.com/ultravioletrs/cocos/pkg/attestation"
"github.com/ultravioletrs/cocos/pkg/attestation/quoteprovider"
"github.com/ultravioletrs/cocos/pkg/attestation/vtpm"
runnermocks "github.com/ultravioletrs/cocos/pkg/clients/grpc/runner/mocks"
"golang.org/x/crypto/sha3"
@@ -336,7 +335,7 @@ func TestAttestation(t *testing.T) {
cases := []struct {
name string
reportData [quoteprovider.Nonce]byte
reportData [vtpm.SEVNonce]byte
nonce [vtpm.Nonce]byte
rawQuote []uint8
platform attestation.PlatformType
@@ -457,8 +456,8 @@ func TestAzureAttestationToken(t *testing.T) {
}
}
func generateReportData() [quoteprovider.Nonce]byte {
bytes := make([]byte, quoteprovider.Nonce)
func generateReportData() [vtpm.SEVNonce]byte {
bytes := make([]byte, vtpm.SEVNonce)
_, err := rand.Read(bytes)
if err != nil {
log.Fatalf("Failed to generate random bytes: %v", err)
Binary file not shown.
+3 -4
View File
@@ -21,7 +21,6 @@ import (
"github.com/spf13/pflag"
"github.com/ultravioletrs/cocos/pkg/attestation"
"github.com/ultravioletrs/cocos/pkg/attestation/azure"
"github.com/ultravioletrs/cocos/pkg/attestation/quoteprovider"
"github.com/ultravioletrs/cocos/pkg/attestation/tdx"
"github.com/ultravioletrs/cocos/pkg/attestation/vtpm"
"google.golang.org/protobuf/encoding/prototext"
@@ -171,10 +170,10 @@ func (cli *CLI) NewGetAttestationCmd() *cobra.Command {
return
}
var fixedReportData [quoteprovider.Nonce]byte
var fixedReportData [vtpm.SEVNonce]byte
if attType == attestation.SNP || attType == attestation.SNPvTPM {
if len(teeNonce) > quoteprovider.Nonce {
msg := color.New(color.FgRed).Sprintf("nonce must be a hex encoded string of length lesser or equal %d bytes ❌ ", quoteprovider.Nonce)
if len(teeNonce) > vtpm.SEVNonce {
msg := color.New(color.FgRed).Sprintf("nonce must be a hex encoded string of length lesser or equal %d bytes ❌ ", vtpm.SEVNonce)
cmd.Println(msg)
return
}
+8 -9
View File
@@ -21,7 +21,6 @@ import (
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
mmocks "github.com/ultravioletrs/cocos/pkg/attestation/cmdconfig/mocks"
"github.com/ultravioletrs/cocos/pkg/attestation/quoteprovider"
"github.com/ultravioletrs/cocos/pkg/attestation/vtpm"
"github.com/ultravioletrs/cocos/pkg/sdk/mocks"
)
@@ -37,8 +36,8 @@ func TestNewAttestationCmd(t *testing.T) {
var buf bytes.Buffer
cmd.SetOut(&buf)
reportData := bytes.Repeat([]byte{0x01}, quoteprovider.Nonce)
mockSDK.On("Attestation", mock.Anything, [quoteprovider.Nonce]byte(reportData), mock.Anything).Return(nil)
reportData := bytes.Repeat([]byte{0x01}, vtpm.SEVNonce)
mockSDK.On("Attestation", mock.Anything, [vtpm.SEVNonce]byte(reportData), mock.Anything).Return(nil)
cmd.SetArgs([]string{hex.EncodeToString(reportData)})
err := cmd.Execute()
@@ -50,7 +49,7 @@ func TestNewGetAttestationCmd(t *testing.T) {
validattestation, err := os.ReadFile("../attestation.bin")
require.NoError(t, err)
teeNonce := hex.EncodeToString(bytes.Repeat([]byte{0x00}, quoteprovider.Nonce))
teeNonce := hex.EncodeToString(bytes.Repeat([]byte{0x00}, vtpm.SEVNonce))
vtpmNonce := hex.EncodeToString(bytes.Repeat([]byte{0x00}, vtpm.Nonce))
tokenNonce := hex.EncodeToString(bytes.Repeat([]byte{0x00}, vtpm.Nonce))
@@ -184,7 +183,7 @@ func TestNewGetAttestationCmd(t *testing.T) {
var buf bytes.Buffer
cmd.SetOut(&buf)
mockSDK.On("Attestation", mock.Anything, [quoteprovider.Nonce]byte(bytes.Repeat([]byte{0x00}, quoteprovider.Nonce)), [vtpm.Nonce]byte(bytes.Repeat([]byte{0x00}, vtpm.Nonce)), mock.Anything, mock.Anything).Return(tc.mockError).Run(func(args mock.Arguments) {
mockSDK.On("Attestation", mock.Anything, [vtpm.SEVNonce]byte(bytes.Repeat([]byte{0x00}, vtpm.SEVNonce)), [vtpm.Nonce]byte(bytes.Repeat([]byte{0x00}, vtpm.Nonce)), mock.Anything, mock.Anything).Return(tc.mockError).Run(func(args mock.Arguments) {
_, err := args.Get(4).(*os.File).Write(tc.mockResponse)
require.NoError(t, err)
})
@@ -891,7 +890,7 @@ func TestGetAttestationCmdEdgeCases(t *testing.T) {
},
{
name: "TEE nonce too large",
args: []string{"snp", "--tee", hex.EncodeToString(bytes.Repeat([]byte{0x00}, quoteprovider.Nonce+1))},
args: []string{"snp", "--tee", hex.EncodeToString(bytes.Repeat([]byte{0x00}, vtpm.SEVNonce+1))},
setupMock: func(sdk *mocks.SDK) {
},
expectedErr: "nonce must be a hex encoded string of length lesser or equal 64 bytes",
@@ -912,7 +911,7 @@ func TestGetAttestationCmdEdgeCases(t *testing.T) {
},
{
name: "successful TDX attestation",
args: []string{"tdx", "--tee", hex.EncodeToString(bytes.Repeat([]byte{0x00}, quoteprovider.Nonce))},
args: []string{"tdx", "--tee", hex.EncodeToString(bytes.Repeat([]byte{0x00}, vtpm.SEVNonce))},
setupMock: func(sdk *mocks.SDK) {
sdk.On("Attestation", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).
Return(nil).Run(func(args mock.Arguments) {
@@ -925,7 +924,7 @@ func TestGetAttestationCmdEdgeCases(t *testing.T) {
},
{
name: "file creation error",
args: []string{"snp", "--tee", hex.EncodeToString(bytes.Repeat([]byte{0x00}, quoteprovider.Nonce))},
args: []string{"snp", "--tee", hex.EncodeToString(bytes.Repeat([]byte{0x00}, vtpm.SEVNonce))},
setupMock: func(sdk *mocks.SDK) {
},
expectedErr: "Error creating attestation file",
@@ -1380,7 +1379,7 @@ func TestContextCancellation(t *testing.T) {
cmd.SetOut(&buf)
cmd.SetErr(&buf)
teeNonceHex := hex.EncodeToString(bytes.Repeat([]byte{0x00}, quoteprovider.Nonce))
teeNonceHex := hex.EncodeToString(bytes.Repeat([]byte{0x00}, vtpm.SEVNonce))
cmd.SetArgs([]string{"snp", "--tee", teeNonceHex})
err := cmd.Execute()
+1 -2
View File
@@ -18,7 +18,6 @@ import (
"github.com/ultravioletrs/cocos/pkg/attestation"
"github.com/ultravioletrs/cocos/pkg/attestation/azure"
"github.com/ultravioletrs/cocos/pkg/attestation/eat"
"github.com/ultravioletrs/cocos/pkg/attestation/quoteprovider"
"github.com/ultravioletrs/cocos/pkg/attestation/tdx"
"github.com/ultravioletrs/cocos/pkg/attestation/vtpm"
"golang.org/x/sync/errgroup"
@@ -89,7 +88,7 @@ func main() {
}
if ccPlatform == attestation.SNP || ccPlatform == attestation.SNPvTPM {
if err := quoteprovider.FetchCertificates(uint(cfg.Vmpl)); err != nil {
if err := vtpm.FetchSEVCertificates(uint(cfg.Vmpl)); err != nil {
logger.Error(fmt.Sprintf("failed to fetch certificates: %s", err))
exitCode = 1
return
+1
View File
@@ -0,0 +1 @@
bbf3a1198ee889f77a227fe01e329864fd6a37a2d23135ea8e2c5a2ebc07f0d3
+3
View File
@@ -0,0 +1,3 @@
-----BEGIN PRIVATE KEY-----
MC4CAQAwBQYDK2VwBCIEICgJcXfNueGCu8jFFNGBXm9r25OGBEc0OEqCUVjyI4fY
-----END PRIVATE KEY-----
+3
View File
@@ -0,0 +1,3 @@
-----BEGIN PUBLIC KEY-----
MCowBQYDK2VwAyEAPbPOfwsJkxpNBluGOg/lgNVE/o0AEM7J11wvkXvHXSw=
-----END PUBLIC KEY-----
+3 -4
View File
@@ -21,7 +21,6 @@ import (
"github.com/google/go-tpm-tools/proto/attest"
"github.com/ultravioletrs/cocos/pkg/attestation"
"github.com/ultravioletrs/cocos/pkg/attestation/eat"
"github.com/ultravioletrs/cocos/pkg/attestation/quoteprovider"
"github.com/ultravioletrs/cocos/pkg/attestation/vtpm"
"google.golang.org/protobuf/proto"
)
@@ -130,7 +129,7 @@ func (a verifier) VerifTeeAttestation(report []byte, teeNonce []byte) error {
return errors.Wrap(fmt.Errorf("failed to convert TEE report to proto"), err)
}
return quoteprovider.VerifyAttestationReportTLS(attestationReport, teeNonce, a.Policy)
return vtpm.VerifySEVAttestationReportTLS(attestationReport, teeNonce, a.Policy)
}
func (a verifier) VerifVTpmAttestation(report []byte, vTpmNonce []byte) error {
@@ -148,7 +147,7 @@ func (a verifier) VerifyAttestation(report []byte, teeNonce []byte, vTpmNonce []
}
snpReport := quote.GetSevSnpAttestation()
if err = quoteprovider.VerifyAttestationReportTLS(snpReport, nil, a.Policy); err != nil {
if err = vtpm.VerifySEVAttestationReportTLS(snpReport, nil, a.Policy); err != nil {
return fmt.Errorf("failed to verify vTPM attestation report: %w", err)
}
@@ -266,7 +265,7 @@ func GenerateAttestationPolicy(token, product string, policy uint64) (*attestati
return nil, fmt.Errorf("failed to decode reportID: %w", err)
}
sevSnpProduct := quoteprovider.GetProductName(product)
sevSnpProduct := vtpm.GetSEVProductName(product)
return &attestation.Config{
Config: &check.Config{
-377
View File
@@ -1,377 +0,0 @@
// Copyright (c) Ultraviolet
// SPDX-License-Identifier: Apache-2.0
//go:build !embed
package quoteprovider
import (
"os"
"path"
"testing"
"github.com/google/go-sev-guest/proto/check"
"github.com/google/go-sev-guest/proto/sevsnp"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestFillInAttestationLocal(t *testing.T) {
originalHome := os.Getenv("HOME")
defer func() {
os.Setenv("HOME", originalHome)
}()
tempDir, err := os.MkdirTemp("", "test_home")
require.NoError(t, err)
defer os.RemoveAll(tempDir)
os.Setenv("HOME", tempDir)
cocosDir := path.Join(tempDir, cocosDirectory, sevSnpProductMilan)
err = os.MkdirAll(cocosDir, 0o755)
require.NoError(t, err)
bundleContent := []byte("mock ASK ARK bundle")
bundlePath := path.Join(cocosDir, arkAskBundleName)
err = os.WriteFile(bundlePath, bundleContent, 0o644)
require.NoError(t, err)
config := &check.Config{
RootOfTrust: &check.RootOfTrust{
ProductLine: sevSnpProductMilan,
},
Policy: &check.Policy{},
}
tests := []struct {
name string
attestation *sevsnp.Attestation
setupFunc func()
expectedError bool
errorContains string
}{
{
name: "Empty attestation - creates new chain",
attestation: &sevsnp.Attestation{
CertificateChain: nil,
},
setupFunc: func() {},
expectedError: true,
errorContains: "could not find ASK or ASVK PEM block; could not find ARK PEM block",
},
{
name: "Attestation with existing chain - no changes needed",
attestation: &sevsnp.Attestation{
CertificateChain: &sevsnp.CertificateChain{
AskCert: []byte("existing ASK cert"),
ArkCert: []byte("existing ARK cert"),
},
},
setupFunc: func() {},
expectedError: false,
},
{
name: "Attestation with empty chain - tries to load from file",
attestation: &sevsnp.Attestation{
CertificateChain: &sevsnp.CertificateChain{},
},
setupFunc: func() {},
expectedError: true,
errorContains: "could not find ASK or ASVK PEM block; could not find ARK PEM block",
},
{
name: "No bundle file exists - no error",
attestation: &sevsnp.Attestation{
CertificateChain: &sevsnp.CertificateChain{},
},
setupFunc: func() {
os.Remove(bundlePath)
},
expectedError: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
os.Setenv("HOME", tempDir)
if _, err := os.Stat(bundlePath); os.IsNotExist(err) {
if err := os.WriteFile(bundlePath, bundleContent, 0o644); err != nil {
t.Fatalf("Failed to write bundle file: %v", err)
}
}
tt.setupFunc()
err := fillInAttestationLocal(tt.attestation, config)
if tt.expectedError {
assert.Error(t, err)
if tt.errorContains != "" {
assert.Contains(t, err.Error(), tt.errorContains)
}
} else {
assert.NoError(t, err)
}
})
}
}
func TestGetProductName(t *testing.T) {
tests := []struct {
name string
product string
expected sevsnp.SevProduct_SevProductName
}{
{
name: "Milan product",
product: sevSnpProductMilan,
expected: sevsnp.SevProduct_SEV_PRODUCT_MILAN,
},
{
name: "Genoa product",
product: sevSnpProductGenoa,
expected: sevsnp.SevProduct_SEV_PRODUCT_GENOA,
},
{
name: "Unknown product",
product: "UnknownProduct",
expected: sevsnp.SevProduct_SEV_PRODUCT_UNKNOWN,
},
{
name: "Empty product",
product: "",
expected: sevsnp.SevProduct_SEV_PRODUCT_UNKNOWN,
},
{
name: "Case sensitive - milan lowercase",
product: "milan",
expected: sevsnp.SevProduct_SEV_PRODUCT_UNKNOWN,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := GetProductName(tt.product)
assert.Equal(t, tt.expected, result)
})
}
}
func TestVerifyReport(t *testing.T) {
tests := []struct {
name string
attestation *sevsnp.Attestation
config *check.Config
expectedError bool
errorContains string
}{
{
name: "Invalid product line",
attestation: &sevsnp.Attestation{
CertificateChain: &sevsnp.CertificateChain{},
},
config: &check.Config{
RootOfTrust: &check.RootOfTrust{
ProductLine: "InvalidProduct",
},
Policy: &check.Policy{},
},
expectedError: true,
errorContains: "product name must be",
},
{
name: "Valid Milan product line",
attestation: &sevsnp.Attestation{
CertificateChain: &sevsnp.CertificateChain{
AskCert: []byte("mock ask cert"),
ArkCert: []byte("mock ark cert"),
},
},
config: &check.Config{
RootOfTrust: &check.RootOfTrust{
ProductLine: sevSnpProductMilan,
},
Policy: &check.Policy{},
},
expectedError: true,
errorContains: "attestation verification failed",
},
{
name: "Valid Genoa product line",
attestation: &sevsnp.Attestation{
CertificateChain: &sevsnp.CertificateChain{
AskCert: []byte("mock ask cert"),
ArkCert: []byte("mock ark cert"),
},
},
config: &check.Config{
RootOfTrust: &check.RootOfTrust{
ProductLine: sevSnpProductGenoa,
},
Policy: &check.Policy{},
},
expectedError: true,
errorContains: "attestation verification failed",
},
{
name: "Config with existing product policy",
attestation: &sevsnp.Attestation{
CertificateChain: &sevsnp.CertificateChain{
AskCert: []byte("mock ask cert"),
ArkCert: []byte("mock ark cert"),
},
},
config: &check.Config{
RootOfTrust: &check.RootOfTrust{
ProductLine: sevSnpProductMilan,
},
Policy: &check.Policy{
Product: &sevsnp.SevProduct{
Name: sevsnp.SevProduct_SEV_PRODUCT_MILAN,
},
},
},
expectedError: true,
errorContains: "attestation verification failed",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := verifyReport(tt.attestation, tt.config)
if tt.expectedError {
assert.Error(t, err)
if tt.errorContains != "" {
assert.Contains(t, err.Error(), tt.errorContains)
}
} else {
assert.NoError(t, err)
}
})
}
}
func TestValidateReport(t *testing.T) {
tests := []struct {
name string
attestation *sevsnp.Attestation
config *check.Config
expectedError bool
errorContains string
}{
{
name: "Basic validation test",
attestation: &sevsnp.Attestation{
CertificateChain: &sevsnp.CertificateChain{},
},
config: &check.Config{
Policy: &check.Policy{
Policy: 196608,
},
},
expectedError: true,
errorContains: "attestation validation failed",
},
{
name: "Validation with report data",
attestation: &sevsnp.Attestation{
CertificateChain: &sevsnp.CertificateChain{},
},
config: &check.Config{
Policy: &check.Policy{
Policy: 196608,
ReportData: []byte("test report datatest report datatest report datatest report data"),
},
},
expectedError: true,
errorContains: "attestation validation failed",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := validateReport(tt.attestation, tt.config)
if tt.expectedError {
assert.Error(t, err)
if tt.errorContains != "" {
assert.Contains(t, err.Error(), tt.errorContains)
}
} else {
assert.NoError(t, err)
}
})
}
}
func TestFetchAttestation(t *testing.T) {
tests := []struct {
name string
reportData []byte
vmpl uint
expectedError bool
errorContains string
}{
{
name: "Report data too large",
reportData: make([]byte, Nonce+1),
vmpl: 0,
expectedError: true,
errorContains: "could not get quote provider",
},
{
name: "Valid report data size",
reportData: make([]byte, 32),
vmpl: 0,
expectedError: true,
errorContains: "could not get quote provider",
},
{
name: "Maximum valid report data size",
reportData: make([]byte, Nonce),
vmpl: 1,
expectedError: true,
errorContains: "could not get quote provider",
},
{
name: "Empty report data",
reportData: []byte{},
vmpl: 0,
expectedError: true,
errorContains: "could not get quote provider",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result, err := FetchAttestation(tt.reportData, tt.vmpl)
if tt.expectedError {
assert.Error(t, err)
assert.Empty(t, result)
if tt.errorContains != "" {
assert.Contains(t, err.Error(), tt.errorContains)
}
} else {
assert.NoError(t, err)
assert.NotEmpty(t, result)
}
})
}
}
func TestGetLeveledQuoteProvider(t *testing.T) {
t.Run("GetLeveledQuoteProvider call", func(t *testing.T) {
provider, err := GetLeveledQuoteProvider()
if err != nil {
assert.Error(t, err)
assert.Nil(t, provider)
} else {
assert.NoError(t, err)
assert.NotNil(t, provider)
}
})
}
@@ -3,7 +3,7 @@
//go:build !embed
package quoteprovider
package vtpm
import (
"crypto/rand"
@@ -32,7 +32,7 @@ const (
cocosDirectory = "/cocos"
arkAskBundleName = "ask_ark.pem"
vcekName = "vcek.pem"
Nonce = 64
SEVNonce = 64
sevSnpProductMilan = "Milan"
sevSnpProductGenoa = "Genoa"
)
@@ -43,9 +43,9 @@ var (
)
var (
ErrProductLine = errors.New(fmt.Sprintf("product name must be %s or %s", sevSnpProductMilan, sevSnpProductGenoa))
ErrAttVerification = errors.New("attestation verification failed")
errAttValidation = errors.New("attestation validation failed")
ErrSEVProductLine = errors.New(fmt.Sprintf("product name must be %s or %s", sevSnpProductMilan, sevSnpProductGenoa))
ErrSEVAttVerification = errors.New("attestation verification failed")
errSEVAttValidation = errors.New("attestation validation failed")
)
func fillInAttestationLocal(attestation *sevsnp.Attestation, cfg *check.Config) error {
@@ -77,16 +77,17 @@ func fillInAttestationLocal(attestation *sevsnp.Attestation, cfg *check.Config)
return nil
}
// verifyReport verifies the SEV-SNP attestation report.
func verifyReport(attestationPB *sevsnp.Attestation, cfg *check.Config) error {
sopts, err := verify.RootOfTrustToOptions(cfg.RootOfTrust)
if err != nil {
return fmt.Errorf("failed to get root of trust options: %v", errors.Wrap(ErrAttVerification, err))
return fmt.Errorf("failed to get root of trust options: %v", errors.Wrap(ErrSEVAttVerification, err))
}
if cfg.Policy.Product == nil {
productName := GetProductName(cfg.RootOfTrust.ProductLine)
productName := GetSEVProductName(cfg.RootOfTrust.ProductLine)
if productName == sevsnp.SevProduct_SEV_PRODUCT_UNKNOWN {
return ErrProductLine
return ErrSEVProductLine
}
sopts.Product = &sevsnp.SevProduct{
@@ -107,30 +108,33 @@ func verifyReport(attestationPB *sevsnp.Attestation, cfg *check.Config) error {
}
if err := verify.SnpAttestation(attestationPB, sopts); err != nil {
return errors.Wrap(ErrAttVerification, err)
return errors.Wrap(ErrSEVAttVerification, err)
}
return nil
}
// validateReport validates the SEV-SNP attestation report against policy.
func validateReport(attestationPB *sevsnp.Attestation, cfg *check.Config) error {
opts, err := validate.PolicyToOptions(cfg.Policy)
if err != nil {
return fmt.Errorf("failed to get policy for validation: %v", errors.Wrap(ErrAttVerification, err))
return fmt.Errorf("failed to get policy for validation: %v", errors.Wrap(ErrSEVAttVerification, err))
}
if err = validate.SnpAttestation(attestationPB, opts); err != nil {
return errors.Wrap(errAttValidation, err)
return errors.Wrap(errSEVAttValidation, err)
}
return nil
}
func GetLeveledQuoteProvider() (client.LeveledQuoteProvider, error) {
// getLeveledQuoteProvider returns a leveled quote provider for SEV-SNP.
func getLeveledQuoteProvider() (client.LeveledQuoteProvider, error) {
return client.GetLeveledQuoteProvider()
}
func VerifyAttestationReportTLS(attestationPB *sevsnp.Attestation, reportData []byte, policy *attestation.Config) error {
// VerifySEVAttestationReportTLS verifies a SEV-SNP attestation report for TLS (exported for azure package).
func VerifySEVAttestationReportTLS(attestationPB *sevsnp.Attestation, reportData []byte, policy *attestation.Config) error {
config := policy.Config
// Certificate chain is populated based on the extra data that is appended to the SEV-SNP attestation report.
@@ -141,10 +145,11 @@ func VerifyAttestationReportTLS(attestationPB *sevsnp.Attestation, reportData []
config.Policy.ReportData = reportData[:]
}
return VerifyAndValidate(attestationPB, config)
return verifySEVAndValidate(attestationPB, config)
}
func VerifyAndValidate(attestationPB *sevsnp.Attestation, cfg *check.Config) error {
// verifySEVAndValidate performs both verification and validation of a SEV-SNP attestation.
func verifySEVAndValidate(attestationPB *sevsnp.Attestation, cfg *check.Config) error {
logger.Init("", false, false, io.Discard)
if err := verifyReport(attestationPB, cfg); err != nil {
@@ -158,15 +163,16 @@ func VerifyAndValidate(attestationPB *sevsnp.Attestation, cfg *check.Config) err
return nil
}
func FetchAttestation(reportDataSlice []byte, vmpl uint) ([]byte, error) {
var reportData [Nonce]byte
// fetchSEVAttestation fetches a SEV-SNP attestation report.
func fetchSEVAttestation(reportDataSlice []byte, vmpl uint) ([]byte, error) {
var reportData [SEVNonce]byte
qp, err := GetLeveledQuoteProvider()
qp, err := getLeveledQuoteProvider()
if err != nil {
return []byte{}, fmt.Errorf("could not get quote provider")
}
if len(reportData) > Nonce {
if len(reportData) > SEVNonce {
return []byte{}, fmt.Errorf("attestation report size mismatch")
}
copy(reportData[:], reportDataSlice)
@@ -206,7 +212,8 @@ func FetchAttestation(reportDataSlice []byte, vmpl uint) ([]byte, error) {
return result, nil
}
func GetProductName(product string) sevsnp.SevProduct_SevProductName {
// GetSEVProductName maps a product string to a SEV product name.
func GetSEVProductName(product string) sevsnp.SevProduct_SevProductName {
switch product {
case sevSnpProductMilan:
return sevsnp.SevProduct_SEV_PRODUCT_MILAN
@@ -217,6 +224,7 @@ func GetProductName(product string) sevsnp.SevProduct_SevProductName {
}
}
// derToPem converts DER-encoded certificate to PEM format.
func derToPem(der []byte) []byte {
// Try to parse to make sure it's a certificate
if _, err := x509.ParseCertificate(der); err != nil {
@@ -226,15 +234,16 @@ func derToPem(der []byte) []byte {
return pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: der})
}
func FetchCertificates(vmpl uint) error {
var reportData [Nonce]byte
// FetchSEVCertificates fetches SEV-SNP certificates from KDS.
func FetchSEVCertificates(vmpl uint) error {
var reportData [SEVNonce]byte
qp, err := GetLeveledQuoteProvider()
qp, err := getLeveledQuoteProvider()
if err != nil {
return fmt.Errorf("could not get quote provider")
}
if len(reportData) > Nonce {
if len(reportData) > SEVNonce {
return fmt.Errorf("attestation report size mismatch")
}
+4 -5
View File
@@ -25,7 +25,6 @@ import (
"github.com/google/go-tpm/tpmutil"
"github.com/ultravioletrs/cocos/pkg/attestation"
"github.com/ultravioletrs/cocos/pkg/attestation/eat"
"github.com/ultravioletrs/cocos/pkg/attestation/quoteprovider"
"golang.org/x/crypto/sha3"
"google.golang.org/protobuf/encoding/protojson"
"google.golang.org/protobuf/encoding/prototext"
@@ -120,7 +119,7 @@ func (v provider) Attestation(teeNonce []byte, vTpmNonce []byte) ([]byte, error)
}
func (v provider) TeeAttestation(teeNonce []byte) ([]byte, error) {
return quoteprovider.FetchAttestation(teeNonce, v.vmpl)
return fetchSEVAttestation(teeNonce, v.vmpl)
}
func (v provider) VTpmAttestation(vTpmNonce []byte) ([]byte, error) {
@@ -171,7 +170,7 @@ func (v verifier) VerifTeeAttestation(report []byte, teeNonce []byte) error {
}
attestationReport := sevsnp.Attestation{Report: attestReport, CertificateChain: nil}
return quoteprovider.VerifyAttestationReportTLS(&attestationReport, teeNonce, v.Policy)
return VerifySEVAttestationReportTLS(&attestationReport, teeNonce, v.Policy)
}
func (v verifier) VerifVTpmAttestation(report []byte, vTpmNonce []byte) error {
@@ -234,7 +233,7 @@ func VTPMVerify(quote []byte, teeNonce []byte, vtpmNonce []byte, writer io.Write
attestData := sha3.Sum512(nonce)
if err := quoteprovider.VerifyAttestationReportTLS(attestation.GetSevSnpAttestation(), attestData[:], policy); err != nil {
if err := VerifySEVAttestationReportTLS(attestation.GetSevSnpAttestation(), attestData[:], policy); err != nil {
return fmt.Errorf("failed to verify TEE attestation report: %v", err)
}
@@ -336,7 +335,7 @@ func addTEEAttestation(attestation *attest.Attestation, nonce []byte, vmpl uint)
attestData := sha3.Sum512(teeNonce)
rawTeeAttestation, err := quoteprovider.FetchAttestation(attestData[:], vmpl)
rawTeeAttestation, err := fetchSEVAttestation(attestData[:], vmpl)
if err != nil {
return fmt.Errorf("failed to fetch TEE attestation report: %v", err)
}
+7 -10
View File
@@ -22,12 +22,9 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/ultravioletrs/cocos/pkg/attestation"
"github.com/ultravioletrs/cocos/pkg/attestation/quoteprovider"
"google.golang.org/protobuf/encoding/protojson"
)
const sevSnpProductMilan = "Milan"
var policy = attestation.Config{Config: &check.Config{Policy: &check.Policy{}, RootOfTrust: &check.RootOfTrust{}}, PcrConfig: &attestation.PcrConfig{}}
type mockTPM struct {
@@ -679,13 +676,13 @@ func TestVerifyAttestationReportMalformedSignature(t *testing.T) {
name: "Valid attestation, distorted signature",
attestationReport: attestationPB,
reportData: reportData,
err: quoteprovider.ErrAttVerification,
err: ErrSEVAttVerification,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := quoteprovider.VerifyAttestationReportTLS(tt.attestationReport, tt.reportData, &policy)
err := VerifySEVAttestationReportTLS(tt.attestationReport, tt.reportData, &policy)
assert.True(t, errors.Contains(err, tt.err), fmt.Sprintf("expected error %v, got %v", tt.err, err))
})
}
@@ -713,13 +710,13 @@ func TestVerifyAttestationReportUnknownProduct(t *testing.T) {
name: "Valid attestation, unknown product",
attestationReport: attestationPB,
reportData: reportData,
err: quoteprovider.ErrProductLine,
err: ErrSEVProductLine,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := quoteprovider.VerifyAttestationReportTLS(tt.attestationReport, tt.reportData, &policy)
err := VerifySEVAttestationReportTLS(tt.attestationReport, tt.reportData, &policy)
assert.True(t, errors.Contains(err, tt.err), fmt.Sprintf("expected error %v, got %v", tt.err, err))
})
}
@@ -752,7 +749,7 @@ func TestVerifyAttestationReportSuccess(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := quoteprovider.VerifyAttestationReportTLS(tt.attestationReport, tt.reportData, &policy)
err := VerifySEVAttestationReportTLS(tt.attestationReport, tt.reportData, &policy)
assert.True(t, errors.Contains(err, tt.err), fmt.Sprintf("expected error %v, got %v", tt.err, err))
})
}
@@ -780,13 +777,13 @@ func TestVerifyAttestationReportMalformedPolicy(t *testing.T) {
name: "Valid attestation, malformed policy (measurement)",
attestationReport: attestationPB,
reportData: reportData,
err: quoteprovider.ErrAttVerification,
err: ErrSEVAttVerification,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := quoteprovider.VerifyAttestationReportTLS(tt.attestationReport, tt.reportData, &policy)
err := VerifySEVAttestationReportTLS(tt.attestationReport, tt.reportData, &policy)
assert.True(t, errors.Contains(err, tt.err), fmt.Sprintf("expected error %v, got %v", tt.err, err))
})
}
+1 -1
View File
@@ -2,5 +2,5 @@
// SPDX-License-Identifier: Apache-2.0
// Package clients contains the domain concept definitions needed to support
// HTTP/gRPC Client functionality.
// client configuration for HTTP/gRPC connections.
package clients
+2 -1
View File
@@ -9,6 +9,7 @@ import (
"github.com/ultravioletrs/cocos/agent"
"github.com/ultravioletrs/cocos/pkg/clients"
"github.com/ultravioletrs/cocos/pkg/clients/grpc"
"github.com/ultravioletrs/cocos/pkg/tls"
grpchealth "google.golang.org/grpc/health/grpc_health_v1"
)
@@ -21,7 +22,7 @@ func NewAgentClient(ctx context.Context, cfg clients.AttestedClientConfig) (grpc
return nil, nil, err
}
if client.Secure() != clients.WithMATLS.String() && client.Secure() != clients.WithATLS.String() && client.Secure() != clients.WithTLS.String() {
if client.Secure() != tls.WithMATLS.String() && client.Secure() != tls.WithATLS.String() && client.Secure() != tls.WithTLS.String() {
health := grpchealth.NewHealthClient(client.Connection())
resp, err := health.Check(ctx, &grpchealth.HealthCheckRequest{
Service: "agent",
+10 -9
View File
@@ -22,6 +22,7 @@ import (
"github.com/ultravioletrs/cocos/pkg/attestation"
"github.com/ultravioletrs/cocos/pkg/attestation/vtpm"
"github.com/ultravioletrs/cocos/pkg/clients"
"github.com/ultravioletrs/cocos/pkg/tls"
)
func TestNewClient(t *testing.T) {
@@ -119,7 +120,7 @@ func TestNewClient(t *testing.T) {
ServerCAFile: "nonexistent.pem",
},
wantErr: true,
err: clients.ErrFailedToLoadRootCA,
err: tls.ErrFailedToLoadRootCA,
},
{
name: "Fail with invalid ClientCert",
@@ -130,7 +131,7 @@ func TestNewClient(t *testing.T) {
ClientKey: clientKeyFile,
},
wantErr: true,
err: clients.ErrFailedToLoadClientCertKey,
err: tls.ErrFailedToLoadClientCertKey,
},
{
name: "Fail with invalid ClientKey",
@@ -141,7 +142,7 @@ func TestNewClient(t *testing.T) {
ClientKey: "nonexistent.pem",
},
wantErr: true,
err: clients.ErrFailedToLoadClientCertKey,
err: tls.ErrFailedToLoadClientCertKey,
},
}
@@ -169,32 +170,32 @@ func TestNewClient(t *testing.T) {
func TestClientSecure(t *testing.T) {
tests := []struct {
name string
secure clients.Security
secure tls.Security
expected string
}{
{
name: "Without TLS",
secure: clients.WithoutTLS,
secure: tls.WithoutTLS,
expected: "without TLS",
},
{
name: "With TLS",
secure: clients.WithTLS,
secure: tls.WithTLS,
expected: "with TLS",
},
{
name: "With mTLS",
secure: clients.WithMTLS,
secure: tls.WithMTLS,
expected: "with mTLS",
},
{
name: "With aTLS",
secure: clients.WithATLS,
secure: tls.WithATLS,
expected: "with aTLS",
},
{
name: "With maTLS",
secure: clients.WithMATLS,
secure: tls.WithMATLS,
expected: "with maTLS",
},
}
+14 -8
View File
@@ -6,6 +6,7 @@ package grpc
import (
"github.com/absmach/supermq/pkg/errors"
"github.com/ultravioletrs/cocos/pkg/clients"
"github.com/ultravioletrs/cocos/pkg/tls"
"go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials"
@@ -26,7 +27,7 @@ type Client interface {
type client struct {
*grpc.ClientConn
cfg clients.ClientConfiguration
security clients.Security
security tls.Security
}
var _ Client = (*client)(nil)
@@ -59,14 +60,19 @@ func (c *client) Connection() *grpc.ClientConn {
return c.ClientConn
}
func connect(cfg clients.ClientConfiguration) (*grpc.ClientConn, clients.Security, error) {
func connect(cfg clients.ClientConfiguration) (*grpc.ClientConn, tls.Security, error) {
opts := []grpc.DialOption{
grpc.WithStatsHandler(otelgrpc.NewClientHandler()),
}
security := clients.WithoutTLS
security := tls.WithoutTLS
if agcfg, ok := cfg.(clients.AttestedClientConfig); ok && agcfg.AttestedTLS {
result, err := clients.LoadATLSConfig(agcfg)
result, err := tls.LoadATLSConfig(
agcfg.AttestationPolicy,
agcfg.ServerCAFile,
agcfg.ClientCert,
agcfg.ClientKey,
)
if err != nil {
return nil, security, err
}
@@ -90,13 +96,13 @@ func connect(cfg clients.ClientConfiguration) (*grpc.ClientConn, clients.Securit
return conn, security, nil
}
func loadTLSConfig(serverCAFile, clientCert, clientKey string) (credentials.TransportCredentials, clients.Security, error) {
result, err := clients.LoadBasicTLSConfig(serverCAFile, clientCert, clientKey)
func loadTLSConfig(serverCAFile, clientCert, clientKey string) (credentials.TransportCredentials, tls.Security, error) {
result, err := tls.LoadBasicConfig(serverCAFile, clientCert, clientKey)
if err != nil {
return nil, clients.WithoutTLS, err
return nil, tls.WithoutTLS, err
}
if result.Security == clients.WithoutTLS || result.Config == nil {
if result.Security == tls.WithoutTLS || result.Config == nil {
return insecure.NewCredentials(), result.Security, nil
}
+12 -6
View File
@@ -8,6 +8,7 @@ import (
"time"
"github.com/ultravioletrs/cocos/pkg/clients"
"github.com/ultravioletrs/cocos/pkg/tls"
)
type Client interface {
@@ -19,7 +20,7 @@ type Client interface {
type client struct {
transport *http.Transport
cfg clients.ClientConfiguration
security clients.Security
security tls.Security
}
var _ Client = (*client)(nil)
@@ -49,17 +50,22 @@ func (c *client) Timeout() time.Duration {
return c.cfg.Config().Timeout
}
func createTransport(cfg clients.ClientConfiguration) (*http.Transport, clients.Security, error) {
func createTransport(cfg clients.ClientConfiguration) (*http.Transport, tls.Security, error) {
transport := &http.Transport{
MaxIdleConns: 100,
IdleConnTimeout: 90 * time.Second,
TLSHandshakeTimeout: 10 * time.Second,
}
security := clients.WithoutTLS
security := tls.WithoutTLS
if agcfg, ok := cfg.(*clients.AttestedClientConfig); ok && agcfg.AttestedTLS {
result, err := clients.LoadATLSConfig(*agcfg)
result, err := tls.LoadATLSConfig(
agcfg.AttestationPolicy,
agcfg.ServerCAFile,
agcfg.ClientCert,
agcfg.ClientKey,
)
if err != nil {
return nil, security, err
}
@@ -69,12 +75,12 @@ func createTransport(cfg clients.ClientConfiguration) (*http.Transport, clients.
} else {
conf := cfg.Config()
result, err := clients.LoadBasicTLSConfig(conf.ServerCAFile, conf.ClientCert, conf.ClientKey)
result, err := tls.LoadBasicConfig(conf.ServerCAFile, conf.ClientCert, conf.ClientKey)
if err != nil {
return nil, security, err
}
if result.Security != clients.WithoutTLS {
if result.Security != tls.WithoutTLS {
transport.TLSClientConfig = result.Config
}
+5 -4
View File
@@ -10,6 +10,7 @@ import (
"github.com/stretchr/testify/assert"
"github.com/ultravioletrs/cocos/pkg/clients"
"github.com/ultravioletrs/cocos/pkg/tls"
)
func TestConfig_Configuration(t *testing.T) {
@@ -144,7 +145,7 @@ func TestClient_Secure(t *testing.T) {
URL: "http://localhost:8080",
Timeout: 30 * time.Second,
},
expected: clients.WithoutTLS.String(),
expected: tls.WithoutTLS.String(),
},
}
@@ -183,7 +184,7 @@ func TestCreateTransport_DefaultSettings(t *testing.T) {
assert.NoError(t, err)
assert.NotNil(t, transport)
assert.Equal(t, clients.WithoutTLS, security)
assert.Equal(t, tls.WithoutTLS, security)
assert.Equal(t, 100, transport.MaxIdleConns)
assert.Equal(t, 90*time.Second, transport.IdleConnTimeout)
assert.Equal(t, 10*time.Second, transport.TLSHandshakeTimeout)
@@ -205,7 +206,7 @@ func TestCreateTransport_ATLSError(t *testing.T) {
assert.Error(t, err)
assert.Nil(t, transport)
assert.Equal(t, clients.WithoutTLS, security)
assert.Equal(t, tls.WithoutTLS, security)
assert.Contains(t, err.Error(), "failed to stat attestation policy")
}
@@ -220,7 +221,7 @@ func TestCreateTransport_BasicTLSError(t *testing.T) {
assert.Error(t, err)
assert.Nil(t, transport)
assert.Equal(t, clients.WithoutTLS, security)
assert.Equal(t, tls.WithoutTLS, security)
assert.Contains(t, err.Error(), "failed to load root ca file")
}
+5 -6
View File
@@ -19,7 +19,6 @@ import (
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
"github.com/ultravioletrs/cocos/agent"
"github.com/ultravioletrs/cocos/pkg/attestation/quoteprovider"
"github.com/ultravioletrs/cocos/pkg/attestation/vtpm"
"github.com/ultravioletrs/cocos/pkg/sdk"
"golang.org/x/crypto/sha3"
@@ -388,7 +387,7 @@ func TestAttestation(t *testing.T) {
cases := []struct {
name string
userKey any
reportData [quoteprovider.Nonce]byte
reportData [vtpm.SEVNonce]byte
nonce [vtpm.Nonce]byte
response *agent.AttestationResponse
svcRes []byte
@@ -397,7 +396,7 @@ func TestAttestation(t *testing.T) {
{
name: "fetch attestation report successfully",
userKey: resultConsumerKey,
reportData: [quoteprovider.Nonce]byte(reportData),
reportData: [vtpm.SEVNonce]byte(reportData),
nonce: [vtpm.Nonce]byte(nonce),
response: &agent.AttestationResponse{
File: report,
@@ -408,7 +407,7 @@ func TestAttestation(t *testing.T) {
{
name: "fetch attestation report with different key type",
userKey: resultConsumer1Key,
reportData: [quoteprovider.Nonce]byte(reportData),
reportData: [vtpm.SEVNonce]byte(reportData),
nonce: [vtpm.Nonce]byte(nonce),
response: &agent.AttestationResponse{
File: report,
@@ -419,7 +418,7 @@ func TestAttestation(t *testing.T) {
{
name: "failed to fetch attestation report",
userKey: resultConsumerKey,
reportData: [quoteprovider.Nonce]byte(reportData),
reportData: [vtpm.SEVNonce]byte(reportData),
nonce: [vtpm.Nonce]byte(nonce),
response: &agent.AttestationResponse{
File: []byte{},
@@ -429,7 +428,7 @@ func TestAttestation(t *testing.T) {
{
name: "invalid report data",
userKey: resultConsumerKey,
reportData: [quoteprovider.Nonce]byte{},
reportData: [vtpm.SEVNonce]byte{},
nonce: [vtpm.Nonce]byte(nonce),
response: &agent.AttestationResponse{
File: []byte{},
+6
View File
@@ -0,0 +1,6 @@
// Copyright (c) Ultraviolet
// SPDX-License-Identifier: Apache-2.0
// Package tls provides TLS configuration utilities for client connections.
// It supports standard TLS, mutual TLS (mTLS), and attested TLS (aTLS).
package tls
+16 -15
View File
@@ -1,7 +1,7 @@
// Copyright (c) Ultraviolet
// SPDX-License-Identifier: Apache-2.0
package clients
package tls
import (
"crypto/rand"
@@ -53,20 +53,20 @@ var (
errAttestationPolicyIrregular = errors.New("attestation policy file is not a regular file")
)
// TLSResult contains the result of TLS configuration.
type TLSResult struct {
// Result contains the result of TLS configuration.
type Result struct {
Config *tls.Config
Security Security
}
// LoadBasicTLSConfig loads standard TLS configuration (TLS/mTLS).
func LoadBasicTLSConfig(serverCAFile, clientCert, clientKey string) (*TLSResult, error) {
// LoadBasicConfig loads standard TLS configuration (TLS/mTLS).
func LoadBasicConfig(serverCAFile, clientCert, clientKey string) (*Result, error) {
tlsConfig := &tls.Config{}
security := WithoutTLS
// If no TLS configuration is provided, return nil config (no TLS)
if serverCAFile == "" && clientCert == "" && clientKey == "" {
return &TLSResult{Config: nil, Security: security}, nil
return &Result{Config: nil, Security: security}, nil
}
if serverCAFile != "" {
@@ -96,14 +96,15 @@ func LoadBasicTLSConfig(serverCAFile, clientCert, clientKey string) (*TLSResult,
security = WithMTLS
}
return &TLSResult{Config: tlsConfig, Security: security}, nil
return &Result{Config: tlsConfig, Security: security}, nil
}
// LoadATLSConfig configures Attested TLS.
func LoadATLSConfig(cfg AttestedClientConfig) (*TLSResult, error) {
// Parameters are passed individually to avoid circular dependencies with the clients package.
func LoadATLSConfig(attestationPolicy, serverCAFile, clientCert, clientKey string) (*Result, error) {
security := WithATLS
info, err := os.Stat(cfg.AttestationPolicy)
info, err := os.Stat(attestationPolicy)
if err != nil {
return nil, errors.Wrap(errors.New("failed to stat attestation policy file"), err)
}
@@ -112,12 +113,12 @@ func LoadATLSConfig(cfg AttestedClientConfig) (*TLSResult, error) {
return nil, errAttestationPolicyIrregular
}
attestation.AttestationPolicyPath = cfg.AttestationPolicy
attestation.AttestationPolicyPath = attestationPolicy
var rootCAs *x509.CertPool
if cfg.ServerCAFile != "" {
rootCAs, err = loadRootCAs(cfg.ServerCAFile)
if serverCAFile != "" {
rootCAs, err = loadRootCAs(serverCAFile)
if err != nil {
return nil, err
}
@@ -141,8 +142,8 @@ func LoadATLSConfig(cfg AttestedClientConfig) (*TLSResult, error) {
},
}
if cfg.ClientCert != "" || cfg.ClientKey != "" {
certificate, err := tls.LoadX509KeyPair(cfg.ClientCert, cfg.ClientKey)
if clientCert != "" || clientKey != "" {
certificate, err := tls.LoadX509KeyPair(clientCert, clientKey)
if err != nil {
return nil, errors.Wrap(ErrFailedToLoadClientCertKey, err)
}
@@ -150,7 +151,7 @@ func LoadATLSConfig(cfg AttestedClientConfig) (*TLSResult, error) {
tlsConfig.Certificates = []tls.Certificate{certificate}
}
return &TLSResult{Config: tlsConfig, Security: security}, nil
return &Result{Config: tlsConfig, Security: security}, nil
}
// loadRootCAs loads root CA certificates from a file.
+64 -78
View File
@@ -1,7 +1,7 @@
// Copyright (c) Ultraviolet
// SPDX-License-Identifier: Apache-2.0
package clients
package tls
import (
"crypto/rand"
@@ -64,7 +64,7 @@ func TestSecurity_String(t *testing.T) {
}
}
func TestLoadBasicTLSConfig(t *testing.T) {
func TestLoadBasicConfig(t *testing.T) {
// Create temporary directory for test files
tmpDir := t.TempDir()
@@ -164,7 +164,7 @@ func TestLoadBasicTLSConfig(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result, err := LoadBasicTLSConfig(tt.serverCAFile, tt.clientCert, tt.clientKey)
result, err := LoadBasicConfig(tt.serverCAFile, tt.clientCert, tt.clientKey)
if tt.expectError {
assert.Error(t, err)
@@ -201,100 +201,86 @@ func TestLoadATLSConfig(t *testing.T) {
require.NoError(t, os.WriteFile(policyFile, []byte(`{"policy": "test"}`), 0o644))
tests := []struct {
name string
config AttestedClientConfig
expectedSec Security
expectError bool
errorMsg string
name string
attestationPolicy string
serverCAFile string
clientCert string
clientKey string
expectedSec Security
expectError bool
errorMsg string
}{
{
name: "ValidATLSConfig",
config: AttestedClientConfig{
StandardClientConfig: StandardClientConfig{
ServerCAFile: "",
},
AttestationPolicy: policyFile,
ProductName: "test-product",
},
expectedSec: WithATLS,
expectError: false,
name: "ValidATLSConfig",
attestationPolicy: policyFile,
serverCAFile: "",
clientCert: "",
clientKey: "",
expectedSec: WithATLS,
expectError: false,
},
{
name: "ValidMATLSConfig",
config: AttestedClientConfig{
StandardClientConfig: StandardClientConfig{
ServerCAFile: caFile,
},
AttestationPolicy: policyFile,
ProductName: "test-product",
},
expectedSec: WithMATLS,
expectError: false,
name: "ValidMATLSConfig",
attestationPolicy: policyFile,
serverCAFile: caFile,
clientCert: "",
clientKey: "",
expectedSec: WithMATLS,
expectError: false,
},
{
name: "ValidATLSWithClientCert",
config: AttestedClientConfig{
StandardClientConfig: StandardClientConfig{
ClientCert: certFile,
ClientKey: keyFile,
},
AttestationPolicy: policyFile,
ProductName: "test-product",
},
expectedSec: WithATLS,
expectError: false,
name: "ValidATLSWithClientCert",
attestationPolicy: policyFile,
serverCAFile: "",
clientCert: certFile,
clientKey: keyFile,
expectedSec: WithATLS,
expectError: false,
},
{
name: "NonexistentPolicyFile",
config: AttestedClientConfig{
AttestationPolicy: filepath.Join(tmpDir, "nonexistent.json"),
ProductName: "test-product",
},
expectedSec: WithoutTLS,
expectError: true,
errorMsg: "failed to stat attestation policy file",
name: "NonexistentPolicyFile",
attestationPolicy: filepath.Join(tmpDir, "nonexistent.json"),
serverCAFile: "",
clientCert: "",
clientKey: "",
expectedSec: WithoutTLS,
expectError: true,
errorMsg: "failed to stat attestation policy file",
},
{
name: "PolicyFileIsDirectory",
config: AttestedClientConfig{
AttestationPolicy: tmpDir, // Directory instead of file
ProductName: "test-product",
},
expectedSec: WithoutTLS,
expectError: true,
errorMsg: "attestation policy file is not a regular file",
name: "PolicyFileIsDirectory",
attestationPolicy: tmpDir, // Directory instead of file
serverCAFile: "",
clientCert: "",
clientKey: "",
expectedSec: WithoutTLS,
expectError: true,
errorMsg: "attestation policy file is not a regular file",
},
{
name: "InvalidCAFile",
config: AttestedClientConfig{
StandardClientConfig: StandardClientConfig{
ServerCAFile: filepath.Join(tmpDir, "nonexistent.crt"),
},
AttestationPolicy: policyFile,
ProductName: "test-product",
},
expectedSec: WithoutTLS,
expectError: true,
errorMsg: "failed to read certificate file",
name: "InvalidCAFile",
attestationPolicy: policyFile,
serverCAFile: filepath.Join(tmpDir, "nonexistent.crt"),
clientCert: "",
clientKey: "",
expectedSec: WithoutTLS,
expectError: true,
errorMsg: "failed to read certificate file",
},
{
name: "InvalidClientCert",
config: AttestedClientConfig{
StandardClientConfig: StandardClientConfig{
ClientCert: filepath.Join(tmpDir, "nonexistent.crt"),
ClientKey: keyFile,
},
AttestationPolicy: policyFile,
ProductName: "test-product",
},
expectedSec: WithoutTLS,
expectError: true,
name: "InvalidClientCert",
attestationPolicy: policyFile,
serverCAFile: "",
clientCert: filepath.Join(tmpDir, "nonexistent.crt"),
clientKey: keyFile,
expectedSec: WithoutTLS,
expectError: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result, err := LoadATLSConfig(tt.config)
result, err := LoadATLSConfig(tt.attestationPolicy, tt.serverCAFile, tt.clientCert, tt.clientKey)
if tt.expectError {
assert.Error(t, err)