mirror of
https://github.com/ultravioletrs/cocos.git
synced 2026-06-22 20:00:18 +00:00
COCOS-525-487 - Refactor attestation and atls (#562)
* Refactor attestation handling to remove quoteprovider dependency - Removed references to quoteprovider in various files, replacing them with vtpm where necessary. - Updated function signatures and implementations to use SEVNonce instead of quoteprovider.Nonce. - Introduced new vtpm package to handle SEV-related attestation logic, including fetching and verifying attestation reports. - Adjusted tests to reflect changes in the attestation logic and ensure compatibility with the new structure. - Deleted the now redundant quoteprovider/sev_test.go file. Signed-off-by: Sammy Oina <sammyoina@gmail.com> * fix: Add veraison/go-cose dependency to go.mod Signed-off-by: Sammy Oina <sammyoina@gmail.com> * feat: Introduce TLS package for enhanced security configuration and refactor client code to utilize new TLS utilities Signed-off-by: Sammy Oina <sammyoina@gmail.com> --------- Signed-off-by: Sammy Oina <sammyoina@gmail.com>
This commit is contained in:
committed by
GitHub
parent
de50b6d2d4
commit
207bfd99af
@@ -6,7 +6,6 @@ import (
|
||||
"errors"
|
||||
|
||||
"github.com/ultravioletrs/cocos/pkg/attestation"
|
||||
"github.com/ultravioletrs/cocos/pkg/attestation/quoteprovider"
|
||||
"github.com/ultravioletrs/cocos/pkg/attestation/vtpm"
|
||||
)
|
||||
|
||||
@@ -42,7 +41,7 @@ func (req resultReq) validate() error {
|
||||
}
|
||||
|
||||
type attestationReq struct {
|
||||
TeeNonce [quoteprovider.Nonce]byte
|
||||
TeeNonce [vtpm.SEVNonce]byte
|
||||
VtpmNonce [vtpm.Nonce]byte
|
||||
AttType attestation.PlatformType
|
||||
}
|
||||
|
||||
@@ -14,7 +14,6 @@ import (
|
||||
"github.com/go-kit/kit/transport/grpc"
|
||||
"github.com/ultravioletrs/cocos/agent"
|
||||
"github.com/ultravioletrs/cocos/pkg/attestation"
|
||||
"github.com/ultravioletrs/cocos/pkg/attestation/quoteprovider"
|
||||
"github.com/ultravioletrs/cocos/pkg/attestation/vtpm"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/metadata"
|
||||
@@ -134,7 +133,7 @@ func encodeResultResponse(_ context.Context, response any) (any, error) {
|
||||
func validateNonce(nonce []byte, maxLen int, target any) error {
|
||||
if len(nonce) > maxLen {
|
||||
switch maxLen {
|
||||
case quoteprovider.Nonce:
|
||||
case vtpm.SEVNonce:
|
||||
return ErrTEENonceLength
|
||||
case vtpm.Nonce:
|
||||
return ErrVTPMNonceLength
|
||||
@@ -144,7 +143,7 @@ func validateNonce(nonce []byte, maxLen int, target any) error {
|
||||
}
|
||||
|
||||
switch t := target.(type) {
|
||||
case *[quoteprovider.Nonce]byte:
|
||||
case *[vtpm.SEVNonce]byte:
|
||||
copy(t[:], nonce)
|
||||
case *[vtpm.Nonce]byte:
|
||||
copy(t[:], nonce)
|
||||
@@ -156,10 +155,10 @@ func validateNonce(nonce []byte, maxLen int, target any) error {
|
||||
|
||||
func decodeAttestationRequest(_ context.Context, grpcReq any) (any, error) {
|
||||
req := grpcReq.(*agent.AttestationRequest)
|
||||
var reportData [quoteprovider.Nonce]byte
|
||||
var reportData [vtpm.SEVNonce]byte
|
||||
var nonce [vtpm.Nonce]byte
|
||||
|
||||
if err := validateNonce(req.TeeNonce, quoteprovider.Nonce, &reportData); err != nil {
|
||||
if err := validateNonce(req.TeeNonce, vtpm.SEVNonce, &reportData); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
|
||||
@@ -12,7 +12,6 @@ import (
|
||||
"github.com/ultravioletrs/cocos/agent"
|
||||
"github.com/ultravioletrs/cocos/agent/mocks"
|
||||
"github.com/ultravioletrs/cocos/pkg/attestation"
|
||||
"github.com/ultravioletrs/cocos/pkg/attestation/quoteprovider"
|
||||
"github.com/ultravioletrs/cocos/pkg/attestation/vtpm"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/metadata"
|
||||
@@ -229,7 +228,7 @@ func TestAttestation(t *testing.T) {
|
||||
return len(resp.File) > 0
|
||||
})).Return(nil).Once()
|
||||
|
||||
reportData := [quoteprovider.Nonce]byte{}
|
||||
reportData := [vtpm.SEVNonce]byte{}
|
||||
vtpmNonce := [vtpm.Nonce]byte{}
|
||||
attestationType := attestation.SNP
|
||||
mockService.On("Attestation", mock.Anything, reportData, vtpmNonce, attestationType).Return(attestationData, nil)
|
||||
@@ -298,8 +297,8 @@ func TestValidateNonce(t *testing.T) {
|
||||
}{
|
||||
{
|
||||
name: "valid TEE nonce",
|
||||
nonce: make([]byte, quoteprovider.Nonce),
|
||||
maxLen: quoteprovider.Nonce,
|
||||
nonce: make([]byte, vtpm.SEVNonce),
|
||||
maxLen: vtpm.SEVNonce,
|
||||
shouldError: false,
|
||||
},
|
||||
{
|
||||
@@ -310,8 +309,8 @@ func TestValidateNonce(t *testing.T) {
|
||||
},
|
||||
{
|
||||
name: "TEE nonce too long",
|
||||
nonce: make([]byte, quoteprovider.Nonce+1),
|
||||
maxLen: quoteprovider.Nonce,
|
||||
nonce: make([]byte, vtpm.SEVNonce+1),
|
||||
maxLen: vtpm.SEVNonce,
|
||||
shouldError: true,
|
||||
expectedErr: ErrTEENonceLength,
|
||||
},
|
||||
@@ -326,8 +325,8 @@ func TestValidateNonce(t *testing.T) {
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if tt.maxLen == quoteprovider.Nonce {
|
||||
var target [quoteprovider.Nonce]byte
|
||||
if tt.maxLen == vtpm.SEVNonce {
|
||||
var target [vtpm.SEVNonce]byte
|
||||
err := validateNonce(tt.nonce, tt.maxLen, &target)
|
||||
if tt.shouldError {
|
||||
assert.Error(t, err)
|
||||
@@ -388,7 +387,7 @@ func TestEncodeResultResponse(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestDecodeAttestationRequest(t *testing.T) {
|
||||
teeNonce := make([]byte, quoteprovider.Nonce)
|
||||
teeNonce := make([]byte, vtpm.SEVNonce)
|
||||
vtpmNonce := make([]byte, vtpm.Nonce)
|
||||
|
||||
req := &agent.AttestationRequest{
|
||||
@@ -406,7 +405,7 @@ func TestDecodeAttestationRequest(t *testing.T) {
|
||||
|
||||
func TestDecodeAttestationRequestWithInvalidNonce(t *testing.T) {
|
||||
// Test with TEE nonce too long
|
||||
teeNonce := make([]byte, quoteprovider.Nonce+1)
|
||||
teeNonce := make([]byte, vtpm.SEVNonce+1)
|
||||
req := &agent.AttestationRequest{TeeNonce: teeNonce}
|
||||
|
||||
_, err := decodeAttestationRequest(context.Background(), req)
|
||||
|
||||
@@ -13,7 +13,6 @@ import (
|
||||
|
||||
"github.com/ultravioletrs/cocos/agent"
|
||||
"github.com/ultravioletrs/cocos/pkg/attestation"
|
||||
"github.com/ultravioletrs/cocos/pkg/attestation/quoteprovider"
|
||||
"github.com/ultravioletrs/cocos/pkg/attestation/vtpm"
|
||||
)
|
||||
|
||||
@@ -105,7 +104,7 @@ func (lm *loggingMiddleware) Result(ctx context.Context) (response []byte, err e
|
||||
return lm.svc.Result(ctx)
|
||||
}
|
||||
|
||||
func (lm *loggingMiddleware) Attestation(ctx context.Context, reportData [quoteprovider.Nonce]byte, nonce [vtpm.Nonce]byte, attType attestation.PlatformType) (response []byte, err error) {
|
||||
func (lm *loggingMiddleware) Attestation(ctx context.Context, reportData [vtpm.SEVNonce]byte, nonce [vtpm.Nonce]byte, attType attestation.PlatformType) (response []byte, err error) {
|
||||
defer func(begin time.Time) {
|
||||
message := fmt.Sprintf("Method Attestation took %s to complete", time.Since(begin))
|
||||
if err != nil {
|
||||
|
||||
@@ -12,7 +12,6 @@ import (
|
||||
"github.com/go-kit/kit/metrics"
|
||||
"github.com/ultravioletrs/cocos/agent"
|
||||
"github.com/ultravioletrs/cocos/pkg/attestation"
|
||||
"github.com/ultravioletrs/cocos/pkg/attestation/quoteprovider"
|
||||
"github.com/ultravioletrs/cocos/pkg/attestation/vtpm"
|
||||
)
|
||||
|
||||
@@ -91,7 +90,7 @@ func (ms *metricsMiddleware) Result(ctx context.Context) ([]byte, error) {
|
||||
return ms.svc.Result(ctx)
|
||||
}
|
||||
|
||||
func (ms *metricsMiddleware) Attestation(ctx context.Context, reportData [quoteprovider.Nonce]byte, nonce [vtpm.Nonce]byte, attType attestation.PlatformType) ([]byte, error) {
|
||||
func (ms *metricsMiddleware) Attestation(ctx context.Context, reportData [vtpm.SEVNonce]byte, nonce [vtpm.Nonce]byte, attType attestation.PlatformType) ([]byte, error) {
|
||||
defer func(begin time.Time) {
|
||||
ms.counter.With("method", "attestation").Add(1)
|
||||
ms.latency.With("method", "attestation").Observe(time.Since(begin).Seconds())
|
||||
|
||||
+2
-3
@@ -21,7 +21,6 @@ import (
|
||||
"github.com/ultravioletrs/cocos/agent/statemachine"
|
||||
"github.com/ultravioletrs/cocos/internal"
|
||||
"github.com/ultravioletrs/cocos/pkg/attestation"
|
||||
"github.com/ultravioletrs/cocos/pkg/attestation/quoteprovider"
|
||||
"github.com/ultravioletrs/cocos/pkg/attestation/vtpm"
|
||||
attestation_client "github.com/ultravioletrs/cocos/pkg/clients/grpc/attestation"
|
||||
runner_client "github.com/ultravioletrs/cocos/pkg/clients/grpc/runner"
|
||||
@@ -120,7 +119,7 @@ type Service interface {
|
||||
Algo(ctx context.Context, algorithm Algorithm) error
|
||||
Data(ctx context.Context, dataset Dataset) error
|
||||
Result(ctx context.Context) ([]byte, error)
|
||||
Attestation(ctx context.Context, reportData [quoteprovider.Nonce]byte, nonce [vtpm.Nonce]byte, attType attestation.PlatformType) ([]byte, error)
|
||||
Attestation(ctx context.Context, reportData [vtpm.SEVNonce]byte, nonce [vtpm.Nonce]byte, attType attestation.PlatformType) ([]byte, error)
|
||||
IMAMeasurements(ctx context.Context) ([]byte, []byte, error)
|
||||
AzureAttestationToken(ctx context.Context, nonce [vtpm.Nonce]byte) ([]byte, error)
|
||||
State() string
|
||||
@@ -418,7 +417,7 @@ func (as *agentService) Result(ctx context.Context) ([]byte, error) {
|
||||
return as.result, as.runError
|
||||
}
|
||||
|
||||
func (as *agentService) Attestation(ctx context.Context, reportData [quoteprovider.Nonce]byte, nonce [vtpm.Nonce]byte, attType attestation.PlatformType) ([]byte, error) {
|
||||
func (as *agentService) Attestation(ctx context.Context, reportData [vtpm.SEVNonce]byte, nonce [vtpm.Nonce]byte, attType attestation.PlatformType) ([]byte, error) {
|
||||
rawQuote, err := as.attestationClient.GetAttestation(ctx, reportData, nonce, attType)
|
||||
if err != nil {
|
||||
return []byte{}, errors.Wrap(ErrAttestationFailed, err)
|
||||
|
||||
@@ -24,7 +24,6 @@ import (
|
||||
"github.com/ultravioletrs/cocos/agent/statemachine"
|
||||
smmocks "github.com/ultravioletrs/cocos/agent/statemachine/mocks"
|
||||
"github.com/ultravioletrs/cocos/pkg/attestation"
|
||||
"github.com/ultravioletrs/cocos/pkg/attestation/quoteprovider"
|
||||
"github.com/ultravioletrs/cocos/pkg/attestation/vtpm"
|
||||
runnermocks "github.com/ultravioletrs/cocos/pkg/clients/grpc/runner/mocks"
|
||||
"golang.org/x/crypto/sha3"
|
||||
@@ -336,7 +335,7 @@ func TestAttestation(t *testing.T) {
|
||||
|
||||
cases := []struct {
|
||||
name string
|
||||
reportData [quoteprovider.Nonce]byte
|
||||
reportData [vtpm.SEVNonce]byte
|
||||
nonce [vtpm.Nonce]byte
|
||||
rawQuote []uint8
|
||||
platform attestation.PlatformType
|
||||
@@ -457,8 +456,8 @@ func TestAzureAttestationToken(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func generateReportData() [quoteprovider.Nonce]byte {
|
||||
bytes := make([]byte, quoteprovider.Nonce)
|
||||
func generateReportData() [vtpm.SEVNonce]byte {
|
||||
bytes := make([]byte, vtpm.SEVNonce)
|
||||
_, err := rand.Read(bytes)
|
||||
if err != nil {
|
||||
log.Fatalf("Failed to generate random bytes: %v", err)
|
||||
|
||||
Binary file not shown.
+3
-4
@@ -21,7 +21,6 @@ import (
|
||||
"github.com/spf13/pflag"
|
||||
"github.com/ultravioletrs/cocos/pkg/attestation"
|
||||
"github.com/ultravioletrs/cocos/pkg/attestation/azure"
|
||||
"github.com/ultravioletrs/cocos/pkg/attestation/quoteprovider"
|
||||
"github.com/ultravioletrs/cocos/pkg/attestation/tdx"
|
||||
"github.com/ultravioletrs/cocos/pkg/attestation/vtpm"
|
||||
"google.golang.org/protobuf/encoding/prototext"
|
||||
@@ -171,10 +170,10 @@ func (cli *CLI) NewGetAttestationCmd() *cobra.Command {
|
||||
return
|
||||
}
|
||||
|
||||
var fixedReportData [quoteprovider.Nonce]byte
|
||||
var fixedReportData [vtpm.SEVNonce]byte
|
||||
if attType == attestation.SNP || attType == attestation.SNPvTPM {
|
||||
if len(teeNonce) > quoteprovider.Nonce {
|
||||
msg := color.New(color.FgRed).Sprintf("nonce must be a hex encoded string of length lesser or equal %d bytes ❌ ", quoteprovider.Nonce)
|
||||
if len(teeNonce) > vtpm.SEVNonce {
|
||||
msg := color.New(color.FgRed).Sprintf("nonce must be a hex encoded string of length lesser or equal %d bytes ❌ ", vtpm.SEVNonce)
|
||||
cmd.Println(msg)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -21,7 +21,6 @@ import (
|
||||
"github.com/stretchr/testify/mock"
|
||||
"github.com/stretchr/testify/require"
|
||||
mmocks "github.com/ultravioletrs/cocos/pkg/attestation/cmdconfig/mocks"
|
||||
"github.com/ultravioletrs/cocos/pkg/attestation/quoteprovider"
|
||||
"github.com/ultravioletrs/cocos/pkg/attestation/vtpm"
|
||||
"github.com/ultravioletrs/cocos/pkg/sdk/mocks"
|
||||
)
|
||||
@@ -37,8 +36,8 @@ func TestNewAttestationCmd(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
cmd.SetOut(&buf)
|
||||
|
||||
reportData := bytes.Repeat([]byte{0x01}, quoteprovider.Nonce)
|
||||
mockSDK.On("Attestation", mock.Anything, [quoteprovider.Nonce]byte(reportData), mock.Anything).Return(nil)
|
||||
reportData := bytes.Repeat([]byte{0x01}, vtpm.SEVNonce)
|
||||
mockSDK.On("Attestation", mock.Anything, [vtpm.SEVNonce]byte(reportData), mock.Anything).Return(nil)
|
||||
|
||||
cmd.SetArgs([]string{hex.EncodeToString(reportData)})
|
||||
err := cmd.Execute()
|
||||
@@ -50,7 +49,7 @@ func TestNewGetAttestationCmd(t *testing.T) {
|
||||
validattestation, err := os.ReadFile("../attestation.bin")
|
||||
require.NoError(t, err)
|
||||
|
||||
teeNonce := hex.EncodeToString(bytes.Repeat([]byte{0x00}, quoteprovider.Nonce))
|
||||
teeNonce := hex.EncodeToString(bytes.Repeat([]byte{0x00}, vtpm.SEVNonce))
|
||||
vtpmNonce := hex.EncodeToString(bytes.Repeat([]byte{0x00}, vtpm.Nonce))
|
||||
tokenNonce := hex.EncodeToString(bytes.Repeat([]byte{0x00}, vtpm.Nonce))
|
||||
|
||||
@@ -184,7 +183,7 @@ func TestNewGetAttestationCmd(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
cmd.SetOut(&buf)
|
||||
|
||||
mockSDK.On("Attestation", mock.Anything, [quoteprovider.Nonce]byte(bytes.Repeat([]byte{0x00}, quoteprovider.Nonce)), [vtpm.Nonce]byte(bytes.Repeat([]byte{0x00}, vtpm.Nonce)), mock.Anything, mock.Anything).Return(tc.mockError).Run(func(args mock.Arguments) {
|
||||
mockSDK.On("Attestation", mock.Anything, [vtpm.SEVNonce]byte(bytes.Repeat([]byte{0x00}, vtpm.SEVNonce)), [vtpm.Nonce]byte(bytes.Repeat([]byte{0x00}, vtpm.Nonce)), mock.Anything, mock.Anything).Return(tc.mockError).Run(func(args mock.Arguments) {
|
||||
_, err := args.Get(4).(*os.File).Write(tc.mockResponse)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
@@ -891,7 +890,7 @@ func TestGetAttestationCmdEdgeCases(t *testing.T) {
|
||||
},
|
||||
{
|
||||
name: "TEE nonce too large",
|
||||
args: []string{"snp", "--tee", hex.EncodeToString(bytes.Repeat([]byte{0x00}, quoteprovider.Nonce+1))},
|
||||
args: []string{"snp", "--tee", hex.EncodeToString(bytes.Repeat([]byte{0x00}, vtpm.SEVNonce+1))},
|
||||
setupMock: func(sdk *mocks.SDK) {
|
||||
},
|
||||
expectedErr: "nonce must be a hex encoded string of length lesser or equal 64 bytes",
|
||||
@@ -912,7 +911,7 @@ func TestGetAttestationCmdEdgeCases(t *testing.T) {
|
||||
},
|
||||
{
|
||||
name: "successful TDX attestation",
|
||||
args: []string{"tdx", "--tee", hex.EncodeToString(bytes.Repeat([]byte{0x00}, quoteprovider.Nonce))},
|
||||
args: []string{"tdx", "--tee", hex.EncodeToString(bytes.Repeat([]byte{0x00}, vtpm.SEVNonce))},
|
||||
setupMock: func(sdk *mocks.SDK) {
|
||||
sdk.On("Attestation", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).
|
||||
Return(nil).Run(func(args mock.Arguments) {
|
||||
@@ -925,7 +924,7 @@ func TestGetAttestationCmdEdgeCases(t *testing.T) {
|
||||
},
|
||||
{
|
||||
name: "file creation error",
|
||||
args: []string{"snp", "--tee", hex.EncodeToString(bytes.Repeat([]byte{0x00}, quoteprovider.Nonce))},
|
||||
args: []string{"snp", "--tee", hex.EncodeToString(bytes.Repeat([]byte{0x00}, vtpm.SEVNonce))},
|
||||
setupMock: func(sdk *mocks.SDK) {
|
||||
},
|
||||
expectedErr: "Error creating attestation file",
|
||||
@@ -1380,7 +1379,7 @@ func TestContextCancellation(t *testing.T) {
|
||||
cmd.SetOut(&buf)
|
||||
cmd.SetErr(&buf)
|
||||
|
||||
teeNonceHex := hex.EncodeToString(bytes.Repeat([]byte{0x00}, quoteprovider.Nonce))
|
||||
teeNonceHex := hex.EncodeToString(bytes.Repeat([]byte{0x00}, vtpm.SEVNonce))
|
||||
cmd.SetArgs([]string{"snp", "--tee", teeNonceHex})
|
||||
|
||||
err := cmd.Execute()
|
||||
|
||||
@@ -18,7 +18,6 @@ import (
|
||||
"github.com/ultravioletrs/cocos/pkg/attestation"
|
||||
"github.com/ultravioletrs/cocos/pkg/attestation/azure"
|
||||
"github.com/ultravioletrs/cocos/pkg/attestation/eat"
|
||||
"github.com/ultravioletrs/cocos/pkg/attestation/quoteprovider"
|
||||
"github.com/ultravioletrs/cocos/pkg/attestation/tdx"
|
||||
"github.com/ultravioletrs/cocos/pkg/attestation/vtpm"
|
||||
"golang.org/x/sync/errgroup"
|
||||
@@ -89,7 +88,7 @@ func main() {
|
||||
}
|
||||
|
||||
if ccPlatform == attestation.SNP || ccPlatform == attestation.SNPvTPM {
|
||||
if err := quoteprovider.FetchCertificates(uint(cfg.Vmpl)); err != nil {
|
||||
if err := vtpm.FetchSEVCertificates(uint(cfg.Vmpl)); err != nil {
|
||||
logger.Error(fmt.Sprintf("failed to fetch certificates: %s", err))
|
||||
exitCode = 1
|
||||
return
|
||||
|
||||
@@ -0,0 +1 @@
|
||||
bbf3a1198ee889f77a227fe01e329864fd6a37a2d23135ea8e2c5a2ebc07f0d3
|
||||
@@ -0,0 +1,3 @@
|
||||
-----BEGIN PRIVATE KEY-----
|
||||
MC4CAQAwBQYDK2VwBCIEICgJcXfNueGCu8jFFNGBXm9r25OGBEc0OEqCUVjyI4fY
|
||||
-----END PRIVATE KEY-----
|
||||
@@ -0,0 +1,3 @@
|
||||
-----BEGIN PUBLIC KEY-----
|
||||
MCowBQYDK2VwAyEAPbPOfwsJkxpNBluGOg/lgNVE/o0AEM7J11wvkXvHXSw=
|
||||
-----END PUBLIC KEY-----
|
||||
@@ -21,7 +21,6 @@ import (
|
||||
"github.com/google/go-tpm-tools/proto/attest"
|
||||
"github.com/ultravioletrs/cocos/pkg/attestation"
|
||||
"github.com/ultravioletrs/cocos/pkg/attestation/eat"
|
||||
"github.com/ultravioletrs/cocos/pkg/attestation/quoteprovider"
|
||||
"github.com/ultravioletrs/cocos/pkg/attestation/vtpm"
|
||||
"google.golang.org/protobuf/proto"
|
||||
)
|
||||
@@ -130,7 +129,7 @@ func (a verifier) VerifTeeAttestation(report []byte, teeNonce []byte) error {
|
||||
return errors.Wrap(fmt.Errorf("failed to convert TEE report to proto"), err)
|
||||
}
|
||||
|
||||
return quoteprovider.VerifyAttestationReportTLS(attestationReport, teeNonce, a.Policy)
|
||||
return vtpm.VerifySEVAttestationReportTLS(attestationReport, teeNonce, a.Policy)
|
||||
}
|
||||
|
||||
func (a verifier) VerifVTpmAttestation(report []byte, vTpmNonce []byte) error {
|
||||
@@ -148,7 +147,7 @@ func (a verifier) VerifyAttestation(report []byte, teeNonce []byte, vTpmNonce []
|
||||
}
|
||||
|
||||
snpReport := quote.GetSevSnpAttestation()
|
||||
if err = quoteprovider.VerifyAttestationReportTLS(snpReport, nil, a.Policy); err != nil {
|
||||
if err = vtpm.VerifySEVAttestationReportTLS(snpReport, nil, a.Policy); err != nil {
|
||||
return fmt.Errorf("failed to verify vTPM attestation report: %w", err)
|
||||
}
|
||||
|
||||
@@ -266,7 +265,7 @@ func GenerateAttestationPolicy(token, product string, policy uint64) (*attestati
|
||||
return nil, fmt.Errorf("failed to decode reportID: %w", err)
|
||||
}
|
||||
|
||||
sevSnpProduct := quoteprovider.GetProductName(product)
|
||||
sevSnpProduct := vtpm.GetSEVProductName(product)
|
||||
|
||||
return &attestation.Config{
|
||||
Config: &check.Config{
|
||||
|
||||
@@ -1,377 +0,0 @@
|
||||
// Copyright (c) Ultraviolet
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
//go:build !embed
|
||||
|
||||
package quoteprovider
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path"
|
||||
"testing"
|
||||
|
||||
"github.com/google/go-sev-guest/proto/check"
|
||||
"github.com/google/go-sev-guest/proto/sevsnp"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestFillInAttestationLocal(t *testing.T) {
|
||||
originalHome := os.Getenv("HOME")
|
||||
defer func() {
|
||||
os.Setenv("HOME", originalHome)
|
||||
}()
|
||||
|
||||
tempDir, err := os.MkdirTemp("", "test_home")
|
||||
require.NoError(t, err)
|
||||
defer os.RemoveAll(tempDir)
|
||||
|
||||
os.Setenv("HOME", tempDir)
|
||||
|
||||
cocosDir := path.Join(tempDir, cocosDirectory, sevSnpProductMilan)
|
||||
err = os.MkdirAll(cocosDir, 0o755)
|
||||
require.NoError(t, err)
|
||||
|
||||
bundleContent := []byte("mock ASK ARK bundle")
|
||||
bundlePath := path.Join(cocosDir, arkAskBundleName)
|
||||
err = os.WriteFile(bundlePath, bundleContent, 0o644)
|
||||
require.NoError(t, err)
|
||||
|
||||
config := &check.Config{
|
||||
RootOfTrust: &check.RootOfTrust{
|
||||
ProductLine: sevSnpProductMilan,
|
||||
},
|
||||
Policy: &check.Policy{},
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
attestation *sevsnp.Attestation
|
||||
setupFunc func()
|
||||
expectedError bool
|
||||
errorContains string
|
||||
}{
|
||||
{
|
||||
name: "Empty attestation - creates new chain",
|
||||
attestation: &sevsnp.Attestation{
|
||||
CertificateChain: nil,
|
||||
},
|
||||
setupFunc: func() {},
|
||||
expectedError: true,
|
||||
errorContains: "could not find ASK or ASVK PEM block; could not find ARK PEM block",
|
||||
},
|
||||
{
|
||||
name: "Attestation with existing chain - no changes needed",
|
||||
attestation: &sevsnp.Attestation{
|
||||
CertificateChain: &sevsnp.CertificateChain{
|
||||
AskCert: []byte("existing ASK cert"),
|
||||
ArkCert: []byte("existing ARK cert"),
|
||||
},
|
||||
},
|
||||
setupFunc: func() {},
|
||||
expectedError: false,
|
||||
},
|
||||
{
|
||||
name: "Attestation with empty chain - tries to load from file",
|
||||
attestation: &sevsnp.Attestation{
|
||||
CertificateChain: &sevsnp.CertificateChain{},
|
||||
},
|
||||
setupFunc: func() {},
|
||||
expectedError: true,
|
||||
errorContains: "could not find ASK or ASVK PEM block; could not find ARK PEM block",
|
||||
},
|
||||
{
|
||||
name: "No bundle file exists - no error",
|
||||
attestation: &sevsnp.Attestation{
|
||||
CertificateChain: &sevsnp.CertificateChain{},
|
||||
},
|
||||
setupFunc: func() {
|
||||
os.Remove(bundlePath)
|
||||
},
|
||||
expectedError: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
os.Setenv("HOME", tempDir)
|
||||
if _, err := os.Stat(bundlePath); os.IsNotExist(err) {
|
||||
if err := os.WriteFile(bundlePath, bundleContent, 0o644); err != nil {
|
||||
t.Fatalf("Failed to write bundle file: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
tt.setupFunc()
|
||||
|
||||
err := fillInAttestationLocal(tt.attestation, config)
|
||||
|
||||
if tt.expectedError {
|
||||
assert.Error(t, err)
|
||||
if tt.errorContains != "" {
|
||||
assert.Contains(t, err.Error(), tt.errorContains)
|
||||
}
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetProductName(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
product string
|
||||
expected sevsnp.SevProduct_SevProductName
|
||||
}{
|
||||
{
|
||||
name: "Milan product",
|
||||
product: sevSnpProductMilan,
|
||||
expected: sevsnp.SevProduct_SEV_PRODUCT_MILAN,
|
||||
},
|
||||
{
|
||||
name: "Genoa product",
|
||||
product: sevSnpProductGenoa,
|
||||
expected: sevsnp.SevProduct_SEV_PRODUCT_GENOA,
|
||||
},
|
||||
{
|
||||
name: "Unknown product",
|
||||
product: "UnknownProduct",
|
||||
expected: sevsnp.SevProduct_SEV_PRODUCT_UNKNOWN,
|
||||
},
|
||||
{
|
||||
name: "Empty product",
|
||||
product: "",
|
||||
expected: sevsnp.SevProduct_SEV_PRODUCT_UNKNOWN,
|
||||
},
|
||||
{
|
||||
name: "Case sensitive - milan lowercase",
|
||||
product: "milan",
|
||||
expected: sevsnp.SevProduct_SEV_PRODUCT_UNKNOWN,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := GetProductName(tt.product)
|
||||
assert.Equal(t, tt.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestVerifyReport(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
attestation *sevsnp.Attestation
|
||||
config *check.Config
|
||||
expectedError bool
|
||||
errorContains string
|
||||
}{
|
||||
{
|
||||
name: "Invalid product line",
|
||||
attestation: &sevsnp.Attestation{
|
||||
CertificateChain: &sevsnp.CertificateChain{},
|
||||
},
|
||||
config: &check.Config{
|
||||
RootOfTrust: &check.RootOfTrust{
|
||||
ProductLine: "InvalidProduct",
|
||||
},
|
||||
Policy: &check.Policy{},
|
||||
},
|
||||
expectedError: true,
|
||||
errorContains: "product name must be",
|
||||
},
|
||||
{
|
||||
name: "Valid Milan product line",
|
||||
attestation: &sevsnp.Attestation{
|
||||
CertificateChain: &sevsnp.CertificateChain{
|
||||
AskCert: []byte("mock ask cert"),
|
||||
ArkCert: []byte("mock ark cert"),
|
||||
},
|
||||
},
|
||||
config: &check.Config{
|
||||
RootOfTrust: &check.RootOfTrust{
|
||||
ProductLine: sevSnpProductMilan,
|
||||
},
|
||||
Policy: &check.Policy{},
|
||||
},
|
||||
expectedError: true,
|
||||
errorContains: "attestation verification failed",
|
||||
},
|
||||
{
|
||||
name: "Valid Genoa product line",
|
||||
attestation: &sevsnp.Attestation{
|
||||
CertificateChain: &sevsnp.CertificateChain{
|
||||
AskCert: []byte("mock ask cert"),
|
||||
ArkCert: []byte("mock ark cert"),
|
||||
},
|
||||
},
|
||||
config: &check.Config{
|
||||
RootOfTrust: &check.RootOfTrust{
|
||||
ProductLine: sevSnpProductGenoa,
|
||||
},
|
||||
Policy: &check.Policy{},
|
||||
},
|
||||
expectedError: true,
|
||||
errorContains: "attestation verification failed",
|
||||
},
|
||||
{
|
||||
name: "Config with existing product policy",
|
||||
attestation: &sevsnp.Attestation{
|
||||
CertificateChain: &sevsnp.CertificateChain{
|
||||
AskCert: []byte("mock ask cert"),
|
||||
ArkCert: []byte("mock ark cert"),
|
||||
},
|
||||
},
|
||||
config: &check.Config{
|
||||
RootOfTrust: &check.RootOfTrust{
|
||||
ProductLine: sevSnpProductMilan,
|
||||
},
|
||||
Policy: &check.Policy{
|
||||
Product: &sevsnp.SevProduct{
|
||||
Name: sevsnp.SevProduct_SEV_PRODUCT_MILAN,
|
||||
},
|
||||
},
|
||||
},
|
||||
expectedError: true,
|
||||
errorContains: "attestation verification failed",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := verifyReport(tt.attestation, tt.config)
|
||||
|
||||
if tt.expectedError {
|
||||
assert.Error(t, err)
|
||||
if tt.errorContains != "" {
|
||||
assert.Contains(t, err.Error(), tt.errorContains)
|
||||
}
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateReport(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
attestation *sevsnp.Attestation
|
||||
config *check.Config
|
||||
expectedError bool
|
||||
errorContains string
|
||||
}{
|
||||
{
|
||||
name: "Basic validation test",
|
||||
attestation: &sevsnp.Attestation{
|
||||
CertificateChain: &sevsnp.CertificateChain{},
|
||||
},
|
||||
config: &check.Config{
|
||||
Policy: &check.Policy{
|
||||
Policy: 196608,
|
||||
},
|
||||
},
|
||||
expectedError: true,
|
||||
errorContains: "attestation validation failed",
|
||||
},
|
||||
{
|
||||
name: "Validation with report data",
|
||||
attestation: &sevsnp.Attestation{
|
||||
CertificateChain: &sevsnp.CertificateChain{},
|
||||
},
|
||||
config: &check.Config{
|
||||
Policy: &check.Policy{
|
||||
Policy: 196608,
|
||||
ReportData: []byte("test report datatest report datatest report datatest report data"),
|
||||
},
|
||||
},
|
||||
expectedError: true,
|
||||
errorContains: "attestation validation failed",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := validateReport(tt.attestation, tt.config)
|
||||
|
||||
if tt.expectedError {
|
||||
assert.Error(t, err)
|
||||
if tt.errorContains != "" {
|
||||
assert.Contains(t, err.Error(), tt.errorContains)
|
||||
}
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestFetchAttestation(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
reportData []byte
|
||||
vmpl uint
|
||||
expectedError bool
|
||||
errorContains string
|
||||
}{
|
||||
{
|
||||
name: "Report data too large",
|
||||
reportData: make([]byte, Nonce+1),
|
||||
vmpl: 0,
|
||||
expectedError: true,
|
||||
errorContains: "could not get quote provider",
|
||||
},
|
||||
{
|
||||
name: "Valid report data size",
|
||||
reportData: make([]byte, 32),
|
||||
vmpl: 0,
|
||||
expectedError: true,
|
||||
errorContains: "could not get quote provider",
|
||||
},
|
||||
{
|
||||
name: "Maximum valid report data size",
|
||||
reportData: make([]byte, Nonce),
|
||||
vmpl: 1,
|
||||
expectedError: true,
|
||||
errorContains: "could not get quote provider",
|
||||
},
|
||||
{
|
||||
name: "Empty report data",
|
||||
reportData: []byte{},
|
||||
vmpl: 0,
|
||||
expectedError: true,
|
||||
errorContains: "could not get quote provider",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result, err := FetchAttestation(tt.reportData, tt.vmpl)
|
||||
|
||||
if tt.expectedError {
|
||||
assert.Error(t, err)
|
||||
assert.Empty(t, result)
|
||||
if tt.errorContains != "" {
|
||||
assert.Contains(t, err.Error(), tt.errorContains)
|
||||
}
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
assert.NotEmpty(t, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetLeveledQuoteProvider(t *testing.T) {
|
||||
t.Run("GetLeveledQuoteProvider call", func(t *testing.T) {
|
||||
provider, err := GetLeveledQuoteProvider()
|
||||
|
||||
if err != nil {
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, provider)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, provider)
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -3,7 +3,7 @@
|
||||
|
||||
//go:build !embed
|
||||
|
||||
package quoteprovider
|
||||
package vtpm
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
@@ -32,7 +32,7 @@ const (
|
||||
cocosDirectory = "/cocos"
|
||||
arkAskBundleName = "ask_ark.pem"
|
||||
vcekName = "vcek.pem"
|
||||
Nonce = 64
|
||||
SEVNonce = 64
|
||||
sevSnpProductMilan = "Milan"
|
||||
sevSnpProductGenoa = "Genoa"
|
||||
)
|
||||
@@ -43,9 +43,9 @@ var (
|
||||
)
|
||||
|
||||
var (
|
||||
ErrProductLine = errors.New(fmt.Sprintf("product name must be %s or %s", sevSnpProductMilan, sevSnpProductGenoa))
|
||||
ErrAttVerification = errors.New("attestation verification failed")
|
||||
errAttValidation = errors.New("attestation validation failed")
|
||||
ErrSEVProductLine = errors.New(fmt.Sprintf("product name must be %s or %s", sevSnpProductMilan, sevSnpProductGenoa))
|
||||
ErrSEVAttVerification = errors.New("attestation verification failed")
|
||||
errSEVAttValidation = errors.New("attestation validation failed")
|
||||
)
|
||||
|
||||
func fillInAttestationLocal(attestation *sevsnp.Attestation, cfg *check.Config) error {
|
||||
@@ -77,16 +77,17 @@ func fillInAttestationLocal(attestation *sevsnp.Attestation, cfg *check.Config)
|
||||
return nil
|
||||
}
|
||||
|
||||
// verifyReport verifies the SEV-SNP attestation report.
|
||||
func verifyReport(attestationPB *sevsnp.Attestation, cfg *check.Config) error {
|
||||
sopts, err := verify.RootOfTrustToOptions(cfg.RootOfTrust)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get root of trust options: %v", errors.Wrap(ErrAttVerification, err))
|
||||
return fmt.Errorf("failed to get root of trust options: %v", errors.Wrap(ErrSEVAttVerification, err))
|
||||
}
|
||||
|
||||
if cfg.Policy.Product == nil {
|
||||
productName := GetProductName(cfg.RootOfTrust.ProductLine)
|
||||
productName := GetSEVProductName(cfg.RootOfTrust.ProductLine)
|
||||
if productName == sevsnp.SevProduct_SEV_PRODUCT_UNKNOWN {
|
||||
return ErrProductLine
|
||||
return ErrSEVProductLine
|
||||
}
|
||||
|
||||
sopts.Product = &sevsnp.SevProduct{
|
||||
@@ -107,30 +108,33 @@ func verifyReport(attestationPB *sevsnp.Attestation, cfg *check.Config) error {
|
||||
}
|
||||
|
||||
if err := verify.SnpAttestation(attestationPB, sopts); err != nil {
|
||||
return errors.Wrap(ErrAttVerification, err)
|
||||
return errors.Wrap(ErrSEVAttVerification, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// validateReport validates the SEV-SNP attestation report against policy.
|
||||
func validateReport(attestationPB *sevsnp.Attestation, cfg *check.Config) error {
|
||||
opts, err := validate.PolicyToOptions(cfg.Policy)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get policy for validation: %v", errors.Wrap(ErrAttVerification, err))
|
||||
return fmt.Errorf("failed to get policy for validation: %v", errors.Wrap(ErrSEVAttVerification, err))
|
||||
}
|
||||
|
||||
if err = validate.SnpAttestation(attestationPB, opts); err != nil {
|
||||
return errors.Wrap(errAttValidation, err)
|
||||
return errors.Wrap(errSEVAttValidation, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func GetLeveledQuoteProvider() (client.LeveledQuoteProvider, error) {
|
||||
// getLeveledQuoteProvider returns a leveled quote provider for SEV-SNP.
|
||||
func getLeveledQuoteProvider() (client.LeveledQuoteProvider, error) {
|
||||
return client.GetLeveledQuoteProvider()
|
||||
}
|
||||
|
||||
func VerifyAttestationReportTLS(attestationPB *sevsnp.Attestation, reportData []byte, policy *attestation.Config) error {
|
||||
// VerifySEVAttestationReportTLS verifies a SEV-SNP attestation report for TLS (exported for azure package).
|
||||
func VerifySEVAttestationReportTLS(attestationPB *sevsnp.Attestation, reportData []byte, policy *attestation.Config) error {
|
||||
config := policy.Config
|
||||
|
||||
// Certificate chain is populated based on the extra data that is appended to the SEV-SNP attestation report.
|
||||
@@ -141,10 +145,11 @@ func VerifyAttestationReportTLS(attestationPB *sevsnp.Attestation, reportData []
|
||||
config.Policy.ReportData = reportData[:]
|
||||
}
|
||||
|
||||
return VerifyAndValidate(attestationPB, config)
|
||||
return verifySEVAndValidate(attestationPB, config)
|
||||
}
|
||||
|
||||
func VerifyAndValidate(attestationPB *sevsnp.Attestation, cfg *check.Config) error {
|
||||
// verifySEVAndValidate performs both verification and validation of a SEV-SNP attestation.
|
||||
func verifySEVAndValidate(attestationPB *sevsnp.Attestation, cfg *check.Config) error {
|
||||
logger.Init("", false, false, io.Discard)
|
||||
|
||||
if err := verifyReport(attestationPB, cfg); err != nil {
|
||||
@@ -158,15 +163,16 @@ func VerifyAndValidate(attestationPB *sevsnp.Attestation, cfg *check.Config) err
|
||||
return nil
|
||||
}
|
||||
|
||||
func FetchAttestation(reportDataSlice []byte, vmpl uint) ([]byte, error) {
|
||||
var reportData [Nonce]byte
|
||||
// fetchSEVAttestation fetches a SEV-SNP attestation report.
|
||||
func fetchSEVAttestation(reportDataSlice []byte, vmpl uint) ([]byte, error) {
|
||||
var reportData [SEVNonce]byte
|
||||
|
||||
qp, err := GetLeveledQuoteProvider()
|
||||
qp, err := getLeveledQuoteProvider()
|
||||
if err != nil {
|
||||
return []byte{}, fmt.Errorf("could not get quote provider")
|
||||
}
|
||||
|
||||
if len(reportData) > Nonce {
|
||||
if len(reportData) > SEVNonce {
|
||||
return []byte{}, fmt.Errorf("attestation report size mismatch")
|
||||
}
|
||||
copy(reportData[:], reportDataSlice)
|
||||
@@ -206,7 +212,8 @@ func FetchAttestation(reportDataSlice []byte, vmpl uint) ([]byte, error) {
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func GetProductName(product string) sevsnp.SevProduct_SevProductName {
|
||||
// GetSEVProductName maps a product string to a SEV product name.
|
||||
func GetSEVProductName(product string) sevsnp.SevProduct_SevProductName {
|
||||
switch product {
|
||||
case sevSnpProductMilan:
|
||||
return sevsnp.SevProduct_SEV_PRODUCT_MILAN
|
||||
@@ -217,6 +224,7 @@ func GetProductName(product string) sevsnp.SevProduct_SevProductName {
|
||||
}
|
||||
}
|
||||
|
||||
// derToPem converts DER-encoded certificate to PEM format.
|
||||
func derToPem(der []byte) []byte {
|
||||
// Try to parse to make sure it's a certificate
|
||||
if _, err := x509.ParseCertificate(der); err != nil {
|
||||
@@ -226,15 +234,16 @@ func derToPem(der []byte) []byte {
|
||||
return pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: der})
|
||||
}
|
||||
|
||||
func FetchCertificates(vmpl uint) error {
|
||||
var reportData [Nonce]byte
|
||||
// FetchSEVCertificates fetches SEV-SNP certificates from KDS.
|
||||
func FetchSEVCertificates(vmpl uint) error {
|
||||
var reportData [SEVNonce]byte
|
||||
|
||||
qp, err := GetLeveledQuoteProvider()
|
||||
qp, err := getLeveledQuoteProvider()
|
||||
if err != nil {
|
||||
return fmt.Errorf("could not get quote provider")
|
||||
}
|
||||
|
||||
if len(reportData) > Nonce {
|
||||
if len(reportData) > SEVNonce {
|
||||
return fmt.Errorf("attestation report size mismatch")
|
||||
}
|
||||
|
||||
@@ -25,7 +25,6 @@ import (
|
||||
"github.com/google/go-tpm/tpmutil"
|
||||
"github.com/ultravioletrs/cocos/pkg/attestation"
|
||||
"github.com/ultravioletrs/cocos/pkg/attestation/eat"
|
||||
"github.com/ultravioletrs/cocos/pkg/attestation/quoteprovider"
|
||||
"golang.org/x/crypto/sha3"
|
||||
"google.golang.org/protobuf/encoding/protojson"
|
||||
"google.golang.org/protobuf/encoding/prototext"
|
||||
@@ -120,7 +119,7 @@ func (v provider) Attestation(teeNonce []byte, vTpmNonce []byte) ([]byte, error)
|
||||
}
|
||||
|
||||
func (v provider) TeeAttestation(teeNonce []byte) ([]byte, error) {
|
||||
return quoteprovider.FetchAttestation(teeNonce, v.vmpl)
|
||||
return fetchSEVAttestation(teeNonce, v.vmpl)
|
||||
}
|
||||
|
||||
func (v provider) VTpmAttestation(vTpmNonce []byte) ([]byte, error) {
|
||||
@@ -171,7 +170,7 @@ func (v verifier) VerifTeeAttestation(report []byte, teeNonce []byte) error {
|
||||
}
|
||||
|
||||
attestationReport := sevsnp.Attestation{Report: attestReport, CertificateChain: nil}
|
||||
return quoteprovider.VerifyAttestationReportTLS(&attestationReport, teeNonce, v.Policy)
|
||||
return VerifySEVAttestationReportTLS(&attestationReport, teeNonce, v.Policy)
|
||||
}
|
||||
|
||||
func (v verifier) VerifVTpmAttestation(report []byte, vTpmNonce []byte) error {
|
||||
@@ -234,7 +233,7 @@ func VTPMVerify(quote []byte, teeNonce []byte, vtpmNonce []byte, writer io.Write
|
||||
|
||||
attestData := sha3.Sum512(nonce)
|
||||
|
||||
if err := quoteprovider.VerifyAttestationReportTLS(attestation.GetSevSnpAttestation(), attestData[:], policy); err != nil {
|
||||
if err := VerifySEVAttestationReportTLS(attestation.GetSevSnpAttestation(), attestData[:], policy); err != nil {
|
||||
return fmt.Errorf("failed to verify TEE attestation report: %v", err)
|
||||
}
|
||||
|
||||
@@ -336,7 +335,7 @@ func addTEEAttestation(attestation *attest.Attestation, nonce []byte, vmpl uint)
|
||||
|
||||
attestData := sha3.Sum512(teeNonce)
|
||||
|
||||
rawTeeAttestation, err := quoteprovider.FetchAttestation(attestData[:], vmpl)
|
||||
rawTeeAttestation, err := fetchSEVAttestation(attestData[:], vmpl)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to fetch TEE attestation report: %v", err)
|
||||
}
|
||||
|
||||
@@ -22,12 +22,9 @@ import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/ultravioletrs/cocos/pkg/attestation"
|
||||
"github.com/ultravioletrs/cocos/pkg/attestation/quoteprovider"
|
||||
"google.golang.org/protobuf/encoding/protojson"
|
||||
)
|
||||
|
||||
const sevSnpProductMilan = "Milan"
|
||||
|
||||
var policy = attestation.Config{Config: &check.Config{Policy: &check.Policy{}, RootOfTrust: &check.RootOfTrust{}}, PcrConfig: &attestation.PcrConfig{}}
|
||||
|
||||
type mockTPM struct {
|
||||
@@ -679,13 +676,13 @@ func TestVerifyAttestationReportMalformedSignature(t *testing.T) {
|
||||
name: "Valid attestation, distorted signature",
|
||||
attestationReport: attestationPB,
|
||||
reportData: reportData,
|
||||
err: quoteprovider.ErrAttVerification,
|
||||
err: ErrSEVAttVerification,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := quoteprovider.VerifyAttestationReportTLS(tt.attestationReport, tt.reportData, &policy)
|
||||
err := VerifySEVAttestationReportTLS(tt.attestationReport, tt.reportData, &policy)
|
||||
assert.True(t, errors.Contains(err, tt.err), fmt.Sprintf("expected error %v, got %v", tt.err, err))
|
||||
})
|
||||
}
|
||||
@@ -713,13 +710,13 @@ func TestVerifyAttestationReportUnknownProduct(t *testing.T) {
|
||||
name: "Valid attestation, unknown product",
|
||||
attestationReport: attestationPB,
|
||||
reportData: reportData,
|
||||
err: quoteprovider.ErrProductLine,
|
||||
err: ErrSEVProductLine,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := quoteprovider.VerifyAttestationReportTLS(tt.attestationReport, tt.reportData, &policy)
|
||||
err := VerifySEVAttestationReportTLS(tt.attestationReport, tt.reportData, &policy)
|
||||
assert.True(t, errors.Contains(err, tt.err), fmt.Sprintf("expected error %v, got %v", tt.err, err))
|
||||
})
|
||||
}
|
||||
@@ -752,7 +749,7 @@ func TestVerifyAttestationReportSuccess(t *testing.T) {
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := quoteprovider.VerifyAttestationReportTLS(tt.attestationReport, tt.reportData, &policy)
|
||||
err := VerifySEVAttestationReportTLS(tt.attestationReport, tt.reportData, &policy)
|
||||
assert.True(t, errors.Contains(err, tt.err), fmt.Sprintf("expected error %v, got %v", tt.err, err))
|
||||
})
|
||||
}
|
||||
@@ -780,13 +777,13 @@ func TestVerifyAttestationReportMalformedPolicy(t *testing.T) {
|
||||
name: "Valid attestation, malformed policy (measurement)",
|
||||
attestationReport: attestationPB,
|
||||
reportData: reportData,
|
||||
err: quoteprovider.ErrAttVerification,
|
||||
err: ErrSEVAttVerification,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := quoteprovider.VerifyAttestationReportTLS(tt.attestationReport, tt.reportData, &policy)
|
||||
err := VerifySEVAttestationReportTLS(tt.attestationReport, tt.reportData, &policy)
|
||||
assert.True(t, errors.Contains(err, tt.err), fmt.Sprintf("expected error %v, got %v", tt.err, err))
|
||||
})
|
||||
}
|
||||
|
||||
+1
-1
@@ -2,5 +2,5 @@
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
// Package clients contains the domain concept definitions needed to support
|
||||
// HTTP/gRPC Client functionality.
|
||||
// client configuration for HTTP/gRPC connections.
|
||||
package clients
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"github.com/ultravioletrs/cocos/agent"
|
||||
"github.com/ultravioletrs/cocos/pkg/clients"
|
||||
"github.com/ultravioletrs/cocos/pkg/clients/grpc"
|
||||
"github.com/ultravioletrs/cocos/pkg/tls"
|
||||
grpchealth "google.golang.org/grpc/health/grpc_health_v1"
|
||||
)
|
||||
|
||||
@@ -21,7 +22,7 @@ func NewAgentClient(ctx context.Context, cfg clients.AttestedClientConfig) (grpc
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
if client.Secure() != clients.WithMATLS.String() && client.Secure() != clients.WithATLS.String() && client.Secure() != clients.WithTLS.String() {
|
||||
if client.Secure() != tls.WithMATLS.String() && client.Secure() != tls.WithATLS.String() && client.Secure() != tls.WithTLS.String() {
|
||||
health := grpchealth.NewHealthClient(client.Connection())
|
||||
resp, err := health.Check(ctx, &grpchealth.HealthCheckRequest{
|
||||
Service: "agent",
|
||||
|
||||
@@ -22,6 +22,7 @@ import (
|
||||
"github.com/ultravioletrs/cocos/pkg/attestation"
|
||||
"github.com/ultravioletrs/cocos/pkg/attestation/vtpm"
|
||||
"github.com/ultravioletrs/cocos/pkg/clients"
|
||||
"github.com/ultravioletrs/cocos/pkg/tls"
|
||||
)
|
||||
|
||||
func TestNewClient(t *testing.T) {
|
||||
@@ -119,7 +120,7 @@ func TestNewClient(t *testing.T) {
|
||||
ServerCAFile: "nonexistent.pem",
|
||||
},
|
||||
wantErr: true,
|
||||
err: clients.ErrFailedToLoadRootCA,
|
||||
err: tls.ErrFailedToLoadRootCA,
|
||||
},
|
||||
{
|
||||
name: "Fail with invalid ClientCert",
|
||||
@@ -130,7 +131,7 @@ func TestNewClient(t *testing.T) {
|
||||
ClientKey: clientKeyFile,
|
||||
},
|
||||
wantErr: true,
|
||||
err: clients.ErrFailedToLoadClientCertKey,
|
||||
err: tls.ErrFailedToLoadClientCertKey,
|
||||
},
|
||||
{
|
||||
name: "Fail with invalid ClientKey",
|
||||
@@ -141,7 +142,7 @@ func TestNewClient(t *testing.T) {
|
||||
ClientKey: "nonexistent.pem",
|
||||
},
|
||||
wantErr: true,
|
||||
err: clients.ErrFailedToLoadClientCertKey,
|
||||
err: tls.ErrFailedToLoadClientCertKey,
|
||||
},
|
||||
}
|
||||
|
||||
@@ -169,32 +170,32 @@ func TestNewClient(t *testing.T) {
|
||||
func TestClientSecure(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
secure clients.Security
|
||||
secure tls.Security
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "Without TLS",
|
||||
secure: clients.WithoutTLS,
|
||||
secure: tls.WithoutTLS,
|
||||
expected: "without TLS",
|
||||
},
|
||||
{
|
||||
name: "With TLS",
|
||||
secure: clients.WithTLS,
|
||||
secure: tls.WithTLS,
|
||||
expected: "with TLS",
|
||||
},
|
||||
{
|
||||
name: "With mTLS",
|
||||
secure: clients.WithMTLS,
|
||||
secure: tls.WithMTLS,
|
||||
expected: "with mTLS",
|
||||
},
|
||||
{
|
||||
name: "With aTLS",
|
||||
secure: clients.WithATLS,
|
||||
secure: tls.WithATLS,
|
||||
expected: "with aTLS",
|
||||
},
|
||||
{
|
||||
name: "With maTLS",
|
||||
secure: clients.WithMATLS,
|
||||
secure: tls.WithMATLS,
|
||||
expected: "with maTLS",
|
||||
},
|
||||
}
|
||||
|
||||
@@ -6,6 +6,7 @@ package grpc
|
||||
import (
|
||||
"github.com/absmach/supermq/pkg/errors"
|
||||
"github.com/ultravioletrs/cocos/pkg/clients"
|
||||
"github.com/ultravioletrs/cocos/pkg/tls"
|
||||
"go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/credentials"
|
||||
@@ -26,7 +27,7 @@ type Client interface {
|
||||
type client struct {
|
||||
*grpc.ClientConn
|
||||
cfg clients.ClientConfiguration
|
||||
security clients.Security
|
||||
security tls.Security
|
||||
}
|
||||
|
||||
var _ Client = (*client)(nil)
|
||||
@@ -59,14 +60,19 @@ func (c *client) Connection() *grpc.ClientConn {
|
||||
return c.ClientConn
|
||||
}
|
||||
|
||||
func connect(cfg clients.ClientConfiguration) (*grpc.ClientConn, clients.Security, error) {
|
||||
func connect(cfg clients.ClientConfiguration) (*grpc.ClientConn, tls.Security, error) {
|
||||
opts := []grpc.DialOption{
|
||||
grpc.WithStatsHandler(otelgrpc.NewClientHandler()),
|
||||
}
|
||||
security := clients.WithoutTLS
|
||||
security := tls.WithoutTLS
|
||||
|
||||
if agcfg, ok := cfg.(clients.AttestedClientConfig); ok && agcfg.AttestedTLS {
|
||||
result, err := clients.LoadATLSConfig(agcfg)
|
||||
result, err := tls.LoadATLSConfig(
|
||||
agcfg.AttestationPolicy,
|
||||
agcfg.ServerCAFile,
|
||||
agcfg.ClientCert,
|
||||
agcfg.ClientKey,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, security, err
|
||||
}
|
||||
@@ -90,13 +96,13 @@ func connect(cfg clients.ClientConfiguration) (*grpc.ClientConn, clients.Securit
|
||||
return conn, security, nil
|
||||
}
|
||||
|
||||
func loadTLSConfig(serverCAFile, clientCert, clientKey string) (credentials.TransportCredentials, clients.Security, error) {
|
||||
result, err := clients.LoadBasicTLSConfig(serverCAFile, clientCert, clientKey)
|
||||
func loadTLSConfig(serverCAFile, clientCert, clientKey string) (credentials.TransportCredentials, tls.Security, error) {
|
||||
result, err := tls.LoadBasicConfig(serverCAFile, clientCert, clientKey)
|
||||
if err != nil {
|
||||
return nil, clients.WithoutTLS, err
|
||||
return nil, tls.WithoutTLS, err
|
||||
}
|
||||
|
||||
if result.Security == clients.WithoutTLS || result.Config == nil {
|
||||
if result.Security == tls.WithoutTLS || result.Config == nil {
|
||||
return insecure.NewCredentials(), result.Security, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/ultravioletrs/cocos/pkg/clients"
|
||||
"github.com/ultravioletrs/cocos/pkg/tls"
|
||||
)
|
||||
|
||||
type Client interface {
|
||||
@@ -19,7 +20,7 @@ type Client interface {
|
||||
type client struct {
|
||||
transport *http.Transport
|
||||
cfg clients.ClientConfiguration
|
||||
security clients.Security
|
||||
security tls.Security
|
||||
}
|
||||
|
||||
var _ Client = (*client)(nil)
|
||||
@@ -49,17 +50,22 @@ func (c *client) Timeout() time.Duration {
|
||||
return c.cfg.Config().Timeout
|
||||
}
|
||||
|
||||
func createTransport(cfg clients.ClientConfiguration) (*http.Transport, clients.Security, error) {
|
||||
func createTransport(cfg clients.ClientConfiguration) (*http.Transport, tls.Security, error) {
|
||||
transport := &http.Transport{
|
||||
MaxIdleConns: 100,
|
||||
IdleConnTimeout: 90 * time.Second,
|
||||
TLSHandshakeTimeout: 10 * time.Second,
|
||||
}
|
||||
|
||||
security := clients.WithoutTLS
|
||||
security := tls.WithoutTLS
|
||||
|
||||
if agcfg, ok := cfg.(*clients.AttestedClientConfig); ok && agcfg.AttestedTLS {
|
||||
result, err := clients.LoadATLSConfig(*agcfg)
|
||||
result, err := tls.LoadATLSConfig(
|
||||
agcfg.AttestationPolicy,
|
||||
agcfg.ServerCAFile,
|
||||
agcfg.ClientCert,
|
||||
agcfg.ClientKey,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, security, err
|
||||
}
|
||||
@@ -69,12 +75,12 @@ func createTransport(cfg clients.ClientConfiguration) (*http.Transport, clients.
|
||||
} else {
|
||||
conf := cfg.Config()
|
||||
|
||||
result, err := clients.LoadBasicTLSConfig(conf.ServerCAFile, conf.ClientCert, conf.ClientKey)
|
||||
result, err := tls.LoadBasicConfig(conf.ServerCAFile, conf.ClientCert, conf.ClientKey)
|
||||
if err != nil {
|
||||
return nil, security, err
|
||||
}
|
||||
|
||||
if result.Security != clients.WithoutTLS {
|
||||
if result.Security != tls.WithoutTLS {
|
||||
transport.TLSClientConfig = result.Config
|
||||
}
|
||||
|
||||
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/ultravioletrs/cocos/pkg/clients"
|
||||
"github.com/ultravioletrs/cocos/pkg/tls"
|
||||
)
|
||||
|
||||
func TestConfig_Configuration(t *testing.T) {
|
||||
@@ -144,7 +145,7 @@ func TestClient_Secure(t *testing.T) {
|
||||
URL: "http://localhost:8080",
|
||||
Timeout: 30 * time.Second,
|
||||
},
|
||||
expected: clients.WithoutTLS.String(),
|
||||
expected: tls.WithoutTLS.String(),
|
||||
},
|
||||
}
|
||||
|
||||
@@ -183,7 +184,7 @@ func TestCreateTransport_DefaultSettings(t *testing.T) {
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, transport)
|
||||
assert.Equal(t, clients.WithoutTLS, security)
|
||||
assert.Equal(t, tls.WithoutTLS, security)
|
||||
assert.Equal(t, 100, transport.MaxIdleConns)
|
||||
assert.Equal(t, 90*time.Second, transport.IdleConnTimeout)
|
||||
assert.Equal(t, 10*time.Second, transport.TLSHandshakeTimeout)
|
||||
@@ -205,7 +206,7 @@ func TestCreateTransport_ATLSError(t *testing.T) {
|
||||
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, transport)
|
||||
assert.Equal(t, clients.WithoutTLS, security)
|
||||
assert.Equal(t, tls.WithoutTLS, security)
|
||||
assert.Contains(t, err.Error(), "failed to stat attestation policy")
|
||||
}
|
||||
|
||||
@@ -220,7 +221,7 @@ func TestCreateTransport_BasicTLSError(t *testing.T) {
|
||||
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, transport)
|
||||
assert.Equal(t, clients.WithoutTLS, security)
|
||||
assert.Equal(t, tls.WithoutTLS, security)
|
||||
assert.Contains(t, err.Error(), "failed to load root ca file")
|
||||
}
|
||||
|
||||
|
||||
@@ -19,7 +19,6 @@ import (
|
||||
"github.com/stretchr/testify/mock"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/ultravioletrs/cocos/agent"
|
||||
"github.com/ultravioletrs/cocos/pkg/attestation/quoteprovider"
|
||||
"github.com/ultravioletrs/cocos/pkg/attestation/vtpm"
|
||||
"github.com/ultravioletrs/cocos/pkg/sdk"
|
||||
"golang.org/x/crypto/sha3"
|
||||
@@ -388,7 +387,7 @@ func TestAttestation(t *testing.T) {
|
||||
cases := []struct {
|
||||
name string
|
||||
userKey any
|
||||
reportData [quoteprovider.Nonce]byte
|
||||
reportData [vtpm.SEVNonce]byte
|
||||
nonce [vtpm.Nonce]byte
|
||||
response *agent.AttestationResponse
|
||||
svcRes []byte
|
||||
@@ -397,7 +396,7 @@ func TestAttestation(t *testing.T) {
|
||||
{
|
||||
name: "fetch attestation report successfully",
|
||||
userKey: resultConsumerKey,
|
||||
reportData: [quoteprovider.Nonce]byte(reportData),
|
||||
reportData: [vtpm.SEVNonce]byte(reportData),
|
||||
nonce: [vtpm.Nonce]byte(nonce),
|
||||
response: &agent.AttestationResponse{
|
||||
File: report,
|
||||
@@ -408,7 +407,7 @@ func TestAttestation(t *testing.T) {
|
||||
{
|
||||
name: "fetch attestation report with different key type",
|
||||
userKey: resultConsumer1Key,
|
||||
reportData: [quoteprovider.Nonce]byte(reportData),
|
||||
reportData: [vtpm.SEVNonce]byte(reportData),
|
||||
nonce: [vtpm.Nonce]byte(nonce),
|
||||
response: &agent.AttestationResponse{
|
||||
File: report,
|
||||
@@ -419,7 +418,7 @@ func TestAttestation(t *testing.T) {
|
||||
{
|
||||
name: "failed to fetch attestation report",
|
||||
userKey: resultConsumerKey,
|
||||
reportData: [quoteprovider.Nonce]byte(reportData),
|
||||
reportData: [vtpm.SEVNonce]byte(reportData),
|
||||
nonce: [vtpm.Nonce]byte(nonce),
|
||||
response: &agent.AttestationResponse{
|
||||
File: []byte{},
|
||||
@@ -429,7 +428,7 @@ func TestAttestation(t *testing.T) {
|
||||
{
|
||||
name: "invalid report data",
|
||||
userKey: resultConsumerKey,
|
||||
reportData: [quoteprovider.Nonce]byte{},
|
||||
reportData: [vtpm.SEVNonce]byte{},
|
||||
nonce: [vtpm.Nonce]byte(nonce),
|
||||
response: &agent.AttestationResponse{
|
||||
File: []byte{},
|
||||
|
||||
@@ -0,0 +1,6 @@
|
||||
// Copyright (c) Ultraviolet
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
// Package tls provides TLS configuration utilities for client connections.
|
||||
// It supports standard TLS, mutual TLS (mTLS), and attested TLS (aTLS).
|
||||
package tls
|
||||
@@ -1,7 +1,7 @@
|
||||
// Copyright (c) Ultraviolet
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package clients
|
||||
package tls
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
@@ -53,20 +53,20 @@ var (
|
||||
errAttestationPolicyIrregular = errors.New("attestation policy file is not a regular file")
|
||||
)
|
||||
|
||||
// TLSResult contains the result of TLS configuration.
|
||||
type TLSResult struct {
|
||||
// Result contains the result of TLS configuration.
|
||||
type Result struct {
|
||||
Config *tls.Config
|
||||
Security Security
|
||||
}
|
||||
|
||||
// LoadBasicTLSConfig loads standard TLS configuration (TLS/mTLS).
|
||||
func LoadBasicTLSConfig(serverCAFile, clientCert, clientKey string) (*TLSResult, error) {
|
||||
// LoadBasicConfig loads standard TLS configuration (TLS/mTLS).
|
||||
func LoadBasicConfig(serverCAFile, clientCert, clientKey string) (*Result, error) {
|
||||
tlsConfig := &tls.Config{}
|
||||
security := WithoutTLS
|
||||
|
||||
// If no TLS configuration is provided, return nil config (no TLS)
|
||||
if serverCAFile == "" && clientCert == "" && clientKey == "" {
|
||||
return &TLSResult{Config: nil, Security: security}, nil
|
||||
return &Result{Config: nil, Security: security}, nil
|
||||
}
|
||||
|
||||
if serverCAFile != "" {
|
||||
@@ -96,14 +96,15 @@ func LoadBasicTLSConfig(serverCAFile, clientCert, clientKey string) (*TLSResult,
|
||||
security = WithMTLS
|
||||
}
|
||||
|
||||
return &TLSResult{Config: tlsConfig, Security: security}, nil
|
||||
return &Result{Config: tlsConfig, Security: security}, nil
|
||||
}
|
||||
|
||||
// LoadATLSConfig configures Attested TLS.
|
||||
func LoadATLSConfig(cfg AttestedClientConfig) (*TLSResult, error) {
|
||||
// Parameters are passed individually to avoid circular dependencies with the clients package.
|
||||
func LoadATLSConfig(attestationPolicy, serverCAFile, clientCert, clientKey string) (*Result, error) {
|
||||
security := WithATLS
|
||||
|
||||
info, err := os.Stat(cfg.AttestationPolicy)
|
||||
info, err := os.Stat(attestationPolicy)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(errors.New("failed to stat attestation policy file"), err)
|
||||
}
|
||||
@@ -112,12 +113,12 @@ func LoadATLSConfig(cfg AttestedClientConfig) (*TLSResult, error) {
|
||||
return nil, errAttestationPolicyIrregular
|
||||
}
|
||||
|
||||
attestation.AttestationPolicyPath = cfg.AttestationPolicy
|
||||
attestation.AttestationPolicyPath = attestationPolicy
|
||||
|
||||
var rootCAs *x509.CertPool
|
||||
|
||||
if cfg.ServerCAFile != "" {
|
||||
rootCAs, err = loadRootCAs(cfg.ServerCAFile)
|
||||
if serverCAFile != "" {
|
||||
rootCAs, err = loadRootCAs(serverCAFile)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -141,8 +142,8 @@ func LoadATLSConfig(cfg AttestedClientConfig) (*TLSResult, error) {
|
||||
},
|
||||
}
|
||||
|
||||
if cfg.ClientCert != "" || cfg.ClientKey != "" {
|
||||
certificate, err := tls.LoadX509KeyPair(cfg.ClientCert, cfg.ClientKey)
|
||||
if clientCert != "" || clientKey != "" {
|
||||
certificate, err := tls.LoadX509KeyPair(clientCert, clientKey)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(ErrFailedToLoadClientCertKey, err)
|
||||
}
|
||||
@@ -150,7 +151,7 @@ func LoadATLSConfig(cfg AttestedClientConfig) (*TLSResult, error) {
|
||||
tlsConfig.Certificates = []tls.Certificate{certificate}
|
||||
}
|
||||
|
||||
return &TLSResult{Config: tlsConfig, Security: security}, nil
|
||||
return &Result{Config: tlsConfig, Security: security}, nil
|
||||
}
|
||||
|
||||
// loadRootCAs loads root CA certificates from a file.
|
||||
@@ -1,7 +1,7 @@
|
||||
// Copyright (c) Ultraviolet
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package clients
|
||||
package tls
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
@@ -64,7 +64,7 @@ func TestSecurity_String(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadBasicTLSConfig(t *testing.T) {
|
||||
func TestLoadBasicConfig(t *testing.T) {
|
||||
// Create temporary directory for test files
|
||||
tmpDir := t.TempDir()
|
||||
|
||||
@@ -164,7 +164,7 @@ func TestLoadBasicTLSConfig(t *testing.T) {
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result, err := LoadBasicTLSConfig(tt.serverCAFile, tt.clientCert, tt.clientKey)
|
||||
result, err := LoadBasicConfig(tt.serverCAFile, tt.clientCert, tt.clientKey)
|
||||
|
||||
if tt.expectError {
|
||||
assert.Error(t, err)
|
||||
@@ -201,100 +201,86 @@ func TestLoadATLSConfig(t *testing.T) {
|
||||
require.NoError(t, os.WriteFile(policyFile, []byte(`{"policy": "test"}`), 0o644))
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
config AttestedClientConfig
|
||||
expectedSec Security
|
||||
expectError bool
|
||||
errorMsg string
|
||||
name string
|
||||
attestationPolicy string
|
||||
serverCAFile string
|
||||
clientCert string
|
||||
clientKey string
|
||||
expectedSec Security
|
||||
expectError bool
|
||||
errorMsg string
|
||||
}{
|
||||
{
|
||||
name: "ValidATLSConfig",
|
||||
config: AttestedClientConfig{
|
||||
StandardClientConfig: StandardClientConfig{
|
||||
ServerCAFile: "",
|
||||
},
|
||||
AttestationPolicy: policyFile,
|
||||
ProductName: "test-product",
|
||||
},
|
||||
expectedSec: WithATLS,
|
||||
expectError: false,
|
||||
name: "ValidATLSConfig",
|
||||
attestationPolicy: policyFile,
|
||||
serverCAFile: "",
|
||||
clientCert: "",
|
||||
clientKey: "",
|
||||
expectedSec: WithATLS,
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "ValidMATLSConfig",
|
||||
config: AttestedClientConfig{
|
||||
StandardClientConfig: StandardClientConfig{
|
||||
ServerCAFile: caFile,
|
||||
},
|
||||
AttestationPolicy: policyFile,
|
||||
ProductName: "test-product",
|
||||
},
|
||||
expectedSec: WithMATLS,
|
||||
expectError: false,
|
||||
name: "ValidMATLSConfig",
|
||||
attestationPolicy: policyFile,
|
||||
serverCAFile: caFile,
|
||||
clientCert: "",
|
||||
clientKey: "",
|
||||
expectedSec: WithMATLS,
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "ValidATLSWithClientCert",
|
||||
config: AttestedClientConfig{
|
||||
StandardClientConfig: StandardClientConfig{
|
||||
ClientCert: certFile,
|
||||
ClientKey: keyFile,
|
||||
},
|
||||
AttestationPolicy: policyFile,
|
||||
ProductName: "test-product",
|
||||
},
|
||||
expectedSec: WithATLS,
|
||||
expectError: false,
|
||||
name: "ValidATLSWithClientCert",
|
||||
attestationPolicy: policyFile,
|
||||
serverCAFile: "",
|
||||
clientCert: certFile,
|
||||
clientKey: keyFile,
|
||||
expectedSec: WithATLS,
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "NonexistentPolicyFile",
|
||||
config: AttestedClientConfig{
|
||||
AttestationPolicy: filepath.Join(tmpDir, "nonexistent.json"),
|
||||
ProductName: "test-product",
|
||||
},
|
||||
expectedSec: WithoutTLS,
|
||||
expectError: true,
|
||||
errorMsg: "failed to stat attestation policy file",
|
||||
name: "NonexistentPolicyFile",
|
||||
attestationPolicy: filepath.Join(tmpDir, "nonexistent.json"),
|
||||
serverCAFile: "",
|
||||
clientCert: "",
|
||||
clientKey: "",
|
||||
expectedSec: WithoutTLS,
|
||||
expectError: true,
|
||||
errorMsg: "failed to stat attestation policy file",
|
||||
},
|
||||
{
|
||||
name: "PolicyFileIsDirectory",
|
||||
config: AttestedClientConfig{
|
||||
AttestationPolicy: tmpDir, // Directory instead of file
|
||||
ProductName: "test-product",
|
||||
},
|
||||
expectedSec: WithoutTLS,
|
||||
expectError: true,
|
||||
errorMsg: "attestation policy file is not a regular file",
|
||||
name: "PolicyFileIsDirectory",
|
||||
attestationPolicy: tmpDir, // Directory instead of file
|
||||
serverCAFile: "",
|
||||
clientCert: "",
|
||||
clientKey: "",
|
||||
expectedSec: WithoutTLS,
|
||||
expectError: true,
|
||||
errorMsg: "attestation policy file is not a regular file",
|
||||
},
|
||||
{
|
||||
name: "InvalidCAFile",
|
||||
config: AttestedClientConfig{
|
||||
StandardClientConfig: StandardClientConfig{
|
||||
ServerCAFile: filepath.Join(tmpDir, "nonexistent.crt"),
|
||||
},
|
||||
AttestationPolicy: policyFile,
|
||||
ProductName: "test-product",
|
||||
},
|
||||
expectedSec: WithoutTLS,
|
||||
expectError: true,
|
||||
errorMsg: "failed to read certificate file",
|
||||
name: "InvalidCAFile",
|
||||
attestationPolicy: policyFile,
|
||||
serverCAFile: filepath.Join(tmpDir, "nonexistent.crt"),
|
||||
clientCert: "",
|
||||
clientKey: "",
|
||||
expectedSec: WithoutTLS,
|
||||
expectError: true,
|
||||
errorMsg: "failed to read certificate file",
|
||||
},
|
||||
{
|
||||
name: "InvalidClientCert",
|
||||
config: AttestedClientConfig{
|
||||
StandardClientConfig: StandardClientConfig{
|
||||
ClientCert: filepath.Join(tmpDir, "nonexistent.crt"),
|
||||
ClientKey: keyFile,
|
||||
},
|
||||
AttestationPolicy: policyFile,
|
||||
ProductName: "test-product",
|
||||
},
|
||||
expectedSec: WithoutTLS,
|
||||
expectError: true,
|
||||
name: "InvalidClientCert",
|
||||
attestationPolicy: policyFile,
|
||||
serverCAFile: "",
|
||||
clientCert: filepath.Join(tmpDir, "nonexistent.crt"),
|
||||
clientKey: keyFile,
|
||||
expectedSec: WithoutTLS,
|
||||
expectError: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result, err := LoadATLSConfig(tt.config)
|
||||
result, err := LoadATLSConfig(tt.attestationPolicy, tt.serverCAFile, tt.clientCert, tt.clientKey)
|
||||
|
||||
if tt.expectError {
|
||||
assert.Error(t, err)
|
||||
Reference in New Issue
Block a user