mirror of
https://github.com/ultravioletrs/cocos.git
synced 2026-06-23 04:10:25 +00:00
COCOS-525-487 - Refactor attestation and atls (#562)
* Refactor attestation handling to remove quoteprovider dependency - Removed references to quoteprovider in various files, replacing them with vtpm where necessary. - Updated function signatures and implementations to use SEVNonce instead of quoteprovider.Nonce. - Introduced new vtpm package to handle SEV-related attestation logic, including fetching and verifying attestation reports. - Adjusted tests to reflect changes in the attestation logic and ensure compatibility with the new structure. - Deleted the now redundant quoteprovider/sev_test.go file. Signed-off-by: Sammy Oina <sammyoina@gmail.com> * fix: Add veraison/go-cose dependency to go.mod Signed-off-by: Sammy Oina <sammyoina@gmail.com> * feat: Introduce TLS package for enhanced security configuration and refactor client code to utilize new TLS utilities Signed-off-by: Sammy Oina <sammyoina@gmail.com> --------- Signed-off-by: Sammy Oina <sammyoina@gmail.com>
This commit is contained in:
committed by
GitHub
parent
de50b6d2d4
commit
207bfd99af
@@ -21,7 +21,6 @@ import (
|
||||
"github.com/google/go-tpm-tools/proto/attest"
|
||||
"github.com/ultravioletrs/cocos/pkg/attestation"
|
||||
"github.com/ultravioletrs/cocos/pkg/attestation/eat"
|
||||
"github.com/ultravioletrs/cocos/pkg/attestation/quoteprovider"
|
||||
"github.com/ultravioletrs/cocos/pkg/attestation/vtpm"
|
||||
"google.golang.org/protobuf/proto"
|
||||
)
|
||||
@@ -130,7 +129,7 @@ func (a verifier) VerifTeeAttestation(report []byte, teeNonce []byte) error {
|
||||
return errors.Wrap(fmt.Errorf("failed to convert TEE report to proto"), err)
|
||||
}
|
||||
|
||||
return quoteprovider.VerifyAttestationReportTLS(attestationReport, teeNonce, a.Policy)
|
||||
return vtpm.VerifySEVAttestationReportTLS(attestationReport, teeNonce, a.Policy)
|
||||
}
|
||||
|
||||
func (a verifier) VerifVTpmAttestation(report []byte, vTpmNonce []byte) error {
|
||||
@@ -148,7 +147,7 @@ func (a verifier) VerifyAttestation(report []byte, teeNonce []byte, vTpmNonce []
|
||||
}
|
||||
|
||||
snpReport := quote.GetSevSnpAttestation()
|
||||
if err = quoteprovider.VerifyAttestationReportTLS(snpReport, nil, a.Policy); err != nil {
|
||||
if err = vtpm.VerifySEVAttestationReportTLS(snpReport, nil, a.Policy); err != nil {
|
||||
return fmt.Errorf("failed to verify vTPM attestation report: %w", err)
|
||||
}
|
||||
|
||||
@@ -266,7 +265,7 @@ func GenerateAttestationPolicy(token, product string, policy uint64) (*attestati
|
||||
return nil, fmt.Errorf("failed to decode reportID: %w", err)
|
||||
}
|
||||
|
||||
sevSnpProduct := quoteprovider.GetProductName(product)
|
||||
sevSnpProduct := vtpm.GetSEVProductName(product)
|
||||
|
||||
return &attestation.Config{
|
||||
Config: &check.Config{
|
||||
|
||||
@@ -1,377 +0,0 @@
|
||||
// Copyright (c) Ultraviolet
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
//go:build !embed
|
||||
|
||||
package quoteprovider
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path"
|
||||
"testing"
|
||||
|
||||
"github.com/google/go-sev-guest/proto/check"
|
||||
"github.com/google/go-sev-guest/proto/sevsnp"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestFillInAttestationLocal(t *testing.T) {
|
||||
originalHome := os.Getenv("HOME")
|
||||
defer func() {
|
||||
os.Setenv("HOME", originalHome)
|
||||
}()
|
||||
|
||||
tempDir, err := os.MkdirTemp("", "test_home")
|
||||
require.NoError(t, err)
|
||||
defer os.RemoveAll(tempDir)
|
||||
|
||||
os.Setenv("HOME", tempDir)
|
||||
|
||||
cocosDir := path.Join(tempDir, cocosDirectory, sevSnpProductMilan)
|
||||
err = os.MkdirAll(cocosDir, 0o755)
|
||||
require.NoError(t, err)
|
||||
|
||||
bundleContent := []byte("mock ASK ARK bundle")
|
||||
bundlePath := path.Join(cocosDir, arkAskBundleName)
|
||||
err = os.WriteFile(bundlePath, bundleContent, 0o644)
|
||||
require.NoError(t, err)
|
||||
|
||||
config := &check.Config{
|
||||
RootOfTrust: &check.RootOfTrust{
|
||||
ProductLine: sevSnpProductMilan,
|
||||
},
|
||||
Policy: &check.Policy{},
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
attestation *sevsnp.Attestation
|
||||
setupFunc func()
|
||||
expectedError bool
|
||||
errorContains string
|
||||
}{
|
||||
{
|
||||
name: "Empty attestation - creates new chain",
|
||||
attestation: &sevsnp.Attestation{
|
||||
CertificateChain: nil,
|
||||
},
|
||||
setupFunc: func() {},
|
||||
expectedError: true,
|
||||
errorContains: "could not find ASK or ASVK PEM block; could not find ARK PEM block",
|
||||
},
|
||||
{
|
||||
name: "Attestation with existing chain - no changes needed",
|
||||
attestation: &sevsnp.Attestation{
|
||||
CertificateChain: &sevsnp.CertificateChain{
|
||||
AskCert: []byte("existing ASK cert"),
|
||||
ArkCert: []byte("existing ARK cert"),
|
||||
},
|
||||
},
|
||||
setupFunc: func() {},
|
||||
expectedError: false,
|
||||
},
|
||||
{
|
||||
name: "Attestation with empty chain - tries to load from file",
|
||||
attestation: &sevsnp.Attestation{
|
||||
CertificateChain: &sevsnp.CertificateChain{},
|
||||
},
|
||||
setupFunc: func() {},
|
||||
expectedError: true,
|
||||
errorContains: "could not find ASK or ASVK PEM block; could not find ARK PEM block",
|
||||
},
|
||||
{
|
||||
name: "No bundle file exists - no error",
|
||||
attestation: &sevsnp.Attestation{
|
||||
CertificateChain: &sevsnp.CertificateChain{},
|
||||
},
|
||||
setupFunc: func() {
|
||||
os.Remove(bundlePath)
|
||||
},
|
||||
expectedError: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
os.Setenv("HOME", tempDir)
|
||||
if _, err := os.Stat(bundlePath); os.IsNotExist(err) {
|
||||
if err := os.WriteFile(bundlePath, bundleContent, 0o644); err != nil {
|
||||
t.Fatalf("Failed to write bundle file: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
tt.setupFunc()
|
||||
|
||||
err := fillInAttestationLocal(tt.attestation, config)
|
||||
|
||||
if tt.expectedError {
|
||||
assert.Error(t, err)
|
||||
if tt.errorContains != "" {
|
||||
assert.Contains(t, err.Error(), tt.errorContains)
|
||||
}
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetProductName(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
product string
|
||||
expected sevsnp.SevProduct_SevProductName
|
||||
}{
|
||||
{
|
||||
name: "Milan product",
|
||||
product: sevSnpProductMilan,
|
||||
expected: sevsnp.SevProduct_SEV_PRODUCT_MILAN,
|
||||
},
|
||||
{
|
||||
name: "Genoa product",
|
||||
product: sevSnpProductGenoa,
|
||||
expected: sevsnp.SevProduct_SEV_PRODUCT_GENOA,
|
||||
},
|
||||
{
|
||||
name: "Unknown product",
|
||||
product: "UnknownProduct",
|
||||
expected: sevsnp.SevProduct_SEV_PRODUCT_UNKNOWN,
|
||||
},
|
||||
{
|
||||
name: "Empty product",
|
||||
product: "",
|
||||
expected: sevsnp.SevProduct_SEV_PRODUCT_UNKNOWN,
|
||||
},
|
||||
{
|
||||
name: "Case sensitive - milan lowercase",
|
||||
product: "milan",
|
||||
expected: sevsnp.SevProduct_SEV_PRODUCT_UNKNOWN,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := GetProductName(tt.product)
|
||||
assert.Equal(t, tt.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestVerifyReport(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
attestation *sevsnp.Attestation
|
||||
config *check.Config
|
||||
expectedError bool
|
||||
errorContains string
|
||||
}{
|
||||
{
|
||||
name: "Invalid product line",
|
||||
attestation: &sevsnp.Attestation{
|
||||
CertificateChain: &sevsnp.CertificateChain{},
|
||||
},
|
||||
config: &check.Config{
|
||||
RootOfTrust: &check.RootOfTrust{
|
||||
ProductLine: "InvalidProduct",
|
||||
},
|
||||
Policy: &check.Policy{},
|
||||
},
|
||||
expectedError: true,
|
||||
errorContains: "product name must be",
|
||||
},
|
||||
{
|
||||
name: "Valid Milan product line",
|
||||
attestation: &sevsnp.Attestation{
|
||||
CertificateChain: &sevsnp.CertificateChain{
|
||||
AskCert: []byte("mock ask cert"),
|
||||
ArkCert: []byte("mock ark cert"),
|
||||
},
|
||||
},
|
||||
config: &check.Config{
|
||||
RootOfTrust: &check.RootOfTrust{
|
||||
ProductLine: sevSnpProductMilan,
|
||||
},
|
||||
Policy: &check.Policy{},
|
||||
},
|
||||
expectedError: true,
|
||||
errorContains: "attestation verification failed",
|
||||
},
|
||||
{
|
||||
name: "Valid Genoa product line",
|
||||
attestation: &sevsnp.Attestation{
|
||||
CertificateChain: &sevsnp.CertificateChain{
|
||||
AskCert: []byte("mock ask cert"),
|
||||
ArkCert: []byte("mock ark cert"),
|
||||
},
|
||||
},
|
||||
config: &check.Config{
|
||||
RootOfTrust: &check.RootOfTrust{
|
||||
ProductLine: sevSnpProductGenoa,
|
||||
},
|
||||
Policy: &check.Policy{},
|
||||
},
|
||||
expectedError: true,
|
||||
errorContains: "attestation verification failed",
|
||||
},
|
||||
{
|
||||
name: "Config with existing product policy",
|
||||
attestation: &sevsnp.Attestation{
|
||||
CertificateChain: &sevsnp.CertificateChain{
|
||||
AskCert: []byte("mock ask cert"),
|
||||
ArkCert: []byte("mock ark cert"),
|
||||
},
|
||||
},
|
||||
config: &check.Config{
|
||||
RootOfTrust: &check.RootOfTrust{
|
||||
ProductLine: sevSnpProductMilan,
|
||||
},
|
||||
Policy: &check.Policy{
|
||||
Product: &sevsnp.SevProduct{
|
||||
Name: sevsnp.SevProduct_SEV_PRODUCT_MILAN,
|
||||
},
|
||||
},
|
||||
},
|
||||
expectedError: true,
|
||||
errorContains: "attestation verification failed",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := verifyReport(tt.attestation, tt.config)
|
||||
|
||||
if tt.expectedError {
|
||||
assert.Error(t, err)
|
||||
if tt.errorContains != "" {
|
||||
assert.Contains(t, err.Error(), tt.errorContains)
|
||||
}
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateReport(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
attestation *sevsnp.Attestation
|
||||
config *check.Config
|
||||
expectedError bool
|
||||
errorContains string
|
||||
}{
|
||||
{
|
||||
name: "Basic validation test",
|
||||
attestation: &sevsnp.Attestation{
|
||||
CertificateChain: &sevsnp.CertificateChain{},
|
||||
},
|
||||
config: &check.Config{
|
||||
Policy: &check.Policy{
|
||||
Policy: 196608,
|
||||
},
|
||||
},
|
||||
expectedError: true,
|
||||
errorContains: "attestation validation failed",
|
||||
},
|
||||
{
|
||||
name: "Validation with report data",
|
||||
attestation: &sevsnp.Attestation{
|
||||
CertificateChain: &sevsnp.CertificateChain{},
|
||||
},
|
||||
config: &check.Config{
|
||||
Policy: &check.Policy{
|
||||
Policy: 196608,
|
||||
ReportData: []byte("test report datatest report datatest report datatest report data"),
|
||||
},
|
||||
},
|
||||
expectedError: true,
|
||||
errorContains: "attestation validation failed",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := validateReport(tt.attestation, tt.config)
|
||||
|
||||
if tt.expectedError {
|
||||
assert.Error(t, err)
|
||||
if tt.errorContains != "" {
|
||||
assert.Contains(t, err.Error(), tt.errorContains)
|
||||
}
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestFetchAttestation(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
reportData []byte
|
||||
vmpl uint
|
||||
expectedError bool
|
||||
errorContains string
|
||||
}{
|
||||
{
|
||||
name: "Report data too large",
|
||||
reportData: make([]byte, Nonce+1),
|
||||
vmpl: 0,
|
||||
expectedError: true,
|
||||
errorContains: "could not get quote provider",
|
||||
},
|
||||
{
|
||||
name: "Valid report data size",
|
||||
reportData: make([]byte, 32),
|
||||
vmpl: 0,
|
||||
expectedError: true,
|
||||
errorContains: "could not get quote provider",
|
||||
},
|
||||
{
|
||||
name: "Maximum valid report data size",
|
||||
reportData: make([]byte, Nonce),
|
||||
vmpl: 1,
|
||||
expectedError: true,
|
||||
errorContains: "could not get quote provider",
|
||||
},
|
||||
{
|
||||
name: "Empty report data",
|
||||
reportData: []byte{},
|
||||
vmpl: 0,
|
||||
expectedError: true,
|
||||
errorContains: "could not get quote provider",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result, err := FetchAttestation(tt.reportData, tt.vmpl)
|
||||
|
||||
if tt.expectedError {
|
||||
assert.Error(t, err)
|
||||
assert.Empty(t, result)
|
||||
if tt.errorContains != "" {
|
||||
assert.Contains(t, err.Error(), tt.errorContains)
|
||||
}
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
assert.NotEmpty(t, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetLeveledQuoteProvider(t *testing.T) {
|
||||
t.Run("GetLeveledQuoteProvider call", func(t *testing.T) {
|
||||
provider, err := GetLeveledQuoteProvider()
|
||||
|
||||
if err != nil {
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, provider)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, provider)
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -3,7 +3,7 @@
|
||||
|
||||
//go:build !embed
|
||||
|
||||
package quoteprovider
|
||||
package vtpm
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
@@ -32,7 +32,7 @@ const (
|
||||
cocosDirectory = "/cocos"
|
||||
arkAskBundleName = "ask_ark.pem"
|
||||
vcekName = "vcek.pem"
|
||||
Nonce = 64
|
||||
SEVNonce = 64
|
||||
sevSnpProductMilan = "Milan"
|
||||
sevSnpProductGenoa = "Genoa"
|
||||
)
|
||||
@@ -43,9 +43,9 @@ var (
|
||||
)
|
||||
|
||||
var (
|
||||
ErrProductLine = errors.New(fmt.Sprintf("product name must be %s or %s", sevSnpProductMilan, sevSnpProductGenoa))
|
||||
ErrAttVerification = errors.New("attestation verification failed")
|
||||
errAttValidation = errors.New("attestation validation failed")
|
||||
ErrSEVProductLine = errors.New(fmt.Sprintf("product name must be %s or %s", sevSnpProductMilan, sevSnpProductGenoa))
|
||||
ErrSEVAttVerification = errors.New("attestation verification failed")
|
||||
errSEVAttValidation = errors.New("attestation validation failed")
|
||||
)
|
||||
|
||||
func fillInAttestationLocal(attestation *sevsnp.Attestation, cfg *check.Config) error {
|
||||
@@ -77,16 +77,17 @@ func fillInAttestationLocal(attestation *sevsnp.Attestation, cfg *check.Config)
|
||||
return nil
|
||||
}
|
||||
|
||||
// verifyReport verifies the SEV-SNP attestation report.
|
||||
func verifyReport(attestationPB *sevsnp.Attestation, cfg *check.Config) error {
|
||||
sopts, err := verify.RootOfTrustToOptions(cfg.RootOfTrust)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get root of trust options: %v", errors.Wrap(ErrAttVerification, err))
|
||||
return fmt.Errorf("failed to get root of trust options: %v", errors.Wrap(ErrSEVAttVerification, err))
|
||||
}
|
||||
|
||||
if cfg.Policy.Product == nil {
|
||||
productName := GetProductName(cfg.RootOfTrust.ProductLine)
|
||||
productName := GetSEVProductName(cfg.RootOfTrust.ProductLine)
|
||||
if productName == sevsnp.SevProduct_SEV_PRODUCT_UNKNOWN {
|
||||
return ErrProductLine
|
||||
return ErrSEVProductLine
|
||||
}
|
||||
|
||||
sopts.Product = &sevsnp.SevProduct{
|
||||
@@ -107,30 +108,33 @@ func verifyReport(attestationPB *sevsnp.Attestation, cfg *check.Config) error {
|
||||
}
|
||||
|
||||
if err := verify.SnpAttestation(attestationPB, sopts); err != nil {
|
||||
return errors.Wrap(ErrAttVerification, err)
|
||||
return errors.Wrap(ErrSEVAttVerification, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// validateReport validates the SEV-SNP attestation report against policy.
|
||||
func validateReport(attestationPB *sevsnp.Attestation, cfg *check.Config) error {
|
||||
opts, err := validate.PolicyToOptions(cfg.Policy)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get policy for validation: %v", errors.Wrap(ErrAttVerification, err))
|
||||
return fmt.Errorf("failed to get policy for validation: %v", errors.Wrap(ErrSEVAttVerification, err))
|
||||
}
|
||||
|
||||
if err = validate.SnpAttestation(attestationPB, opts); err != nil {
|
||||
return errors.Wrap(errAttValidation, err)
|
||||
return errors.Wrap(errSEVAttValidation, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func GetLeveledQuoteProvider() (client.LeveledQuoteProvider, error) {
|
||||
// getLeveledQuoteProvider returns a leveled quote provider for SEV-SNP.
|
||||
func getLeveledQuoteProvider() (client.LeveledQuoteProvider, error) {
|
||||
return client.GetLeveledQuoteProvider()
|
||||
}
|
||||
|
||||
func VerifyAttestationReportTLS(attestationPB *sevsnp.Attestation, reportData []byte, policy *attestation.Config) error {
|
||||
// VerifySEVAttestationReportTLS verifies a SEV-SNP attestation report for TLS (exported for azure package).
|
||||
func VerifySEVAttestationReportTLS(attestationPB *sevsnp.Attestation, reportData []byte, policy *attestation.Config) error {
|
||||
config := policy.Config
|
||||
|
||||
// Certificate chain is populated based on the extra data that is appended to the SEV-SNP attestation report.
|
||||
@@ -141,10 +145,11 @@ func VerifyAttestationReportTLS(attestationPB *sevsnp.Attestation, reportData []
|
||||
config.Policy.ReportData = reportData[:]
|
||||
}
|
||||
|
||||
return VerifyAndValidate(attestationPB, config)
|
||||
return verifySEVAndValidate(attestationPB, config)
|
||||
}
|
||||
|
||||
func VerifyAndValidate(attestationPB *sevsnp.Attestation, cfg *check.Config) error {
|
||||
// verifySEVAndValidate performs both verification and validation of a SEV-SNP attestation.
|
||||
func verifySEVAndValidate(attestationPB *sevsnp.Attestation, cfg *check.Config) error {
|
||||
logger.Init("", false, false, io.Discard)
|
||||
|
||||
if err := verifyReport(attestationPB, cfg); err != nil {
|
||||
@@ -158,15 +163,16 @@ func VerifyAndValidate(attestationPB *sevsnp.Attestation, cfg *check.Config) err
|
||||
return nil
|
||||
}
|
||||
|
||||
func FetchAttestation(reportDataSlice []byte, vmpl uint) ([]byte, error) {
|
||||
var reportData [Nonce]byte
|
||||
// fetchSEVAttestation fetches a SEV-SNP attestation report.
|
||||
func fetchSEVAttestation(reportDataSlice []byte, vmpl uint) ([]byte, error) {
|
||||
var reportData [SEVNonce]byte
|
||||
|
||||
qp, err := GetLeveledQuoteProvider()
|
||||
qp, err := getLeveledQuoteProvider()
|
||||
if err != nil {
|
||||
return []byte{}, fmt.Errorf("could not get quote provider")
|
||||
}
|
||||
|
||||
if len(reportData) > Nonce {
|
||||
if len(reportData) > SEVNonce {
|
||||
return []byte{}, fmt.Errorf("attestation report size mismatch")
|
||||
}
|
||||
copy(reportData[:], reportDataSlice)
|
||||
@@ -206,7 +212,8 @@ func FetchAttestation(reportDataSlice []byte, vmpl uint) ([]byte, error) {
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func GetProductName(product string) sevsnp.SevProduct_SevProductName {
|
||||
// GetSEVProductName maps a product string to a SEV product name.
|
||||
func GetSEVProductName(product string) sevsnp.SevProduct_SevProductName {
|
||||
switch product {
|
||||
case sevSnpProductMilan:
|
||||
return sevsnp.SevProduct_SEV_PRODUCT_MILAN
|
||||
@@ -217,6 +224,7 @@ func GetProductName(product string) sevsnp.SevProduct_SevProductName {
|
||||
}
|
||||
}
|
||||
|
||||
// derToPem converts DER-encoded certificate to PEM format.
|
||||
func derToPem(der []byte) []byte {
|
||||
// Try to parse to make sure it's a certificate
|
||||
if _, err := x509.ParseCertificate(der); err != nil {
|
||||
@@ -226,15 +234,16 @@ func derToPem(der []byte) []byte {
|
||||
return pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: der})
|
||||
}
|
||||
|
||||
func FetchCertificates(vmpl uint) error {
|
||||
var reportData [Nonce]byte
|
||||
// FetchSEVCertificates fetches SEV-SNP certificates from KDS.
|
||||
func FetchSEVCertificates(vmpl uint) error {
|
||||
var reportData [SEVNonce]byte
|
||||
|
||||
qp, err := GetLeveledQuoteProvider()
|
||||
qp, err := getLeveledQuoteProvider()
|
||||
if err != nil {
|
||||
return fmt.Errorf("could not get quote provider")
|
||||
}
|
||||
|
||||
if len(reportData) > Nonce {
|
||||
if len(reportData) > SEVNonce {
|
||||
return fmt.Errorf("attestation report size mismatch")
|
||||
}
|
||||
|
||||
@@ -25,7 +25,6 @@ import (
|
||||
"github.com/google/go-tpm/tpmutil"
|
||||
"github.com/ultravioletrs/cocos/pkg/attestation"
|
||||
"github.com/ultravioletrs/cocos/pkg/attestation/eat"
|
||||
"github.com/ultravioletrs/cocos/pkg/attestation/quoteprovider"
|
||||
"golang.org/x/crypto/sha3"
|
||||
"google.golang.org/protobuf/encoding/protojson"
|
||||
"google.golang.org/protobuf/encoding/prototext"
|
||||
@@ -120,7 +119,7 @@ func (v provider) Attestation(teeNonce []byte, vTpmNonce []byte) ([]byte, error)
|
||||
}
|
||||
|
||||
func (v provider) TeeAttestation(teeNonce []byte) ([]byte, error) {
|
||||
return quoteprovider.FetchAttestation(teeNonce, v.vmpl)
|
||||
return fetchSEVAttestation(teeNonce, v.vmpl)
|
||||
}
|
||||
|
||||
func (v provider) VTpmAttestation(vTpmNonce []byte) ([]byte, error) {
|
||||
@@ -171,7 +170,7 @@ func (v verifier) VerifTeeAttestation(report []byte, teeNonce []byte) error {
|
||||
}
|
||||
|
||||
attestationReport := sevsnp.Attestation{Report: attestReport, CertificateChain: nil}
|
||||
return quoteprovider.VerifyAttestationReportTLS(&attestationReport, teeNonce, v.Policy)
|
||||
return VerifySEVAttestationReportTLS(&attestationReport, teeNonce, v.Policy)
|
||||
}
|
||||
|
||||
func (v verifier) VerifVTpmAttestation(report []byte, vTpmNonce []byte) error {
|
||||
@@ -234,7 +233,7 @@ func VTPMVerify(quote []byte, teeNonce []byte, vtpmNonce []byte, writer io.Write
|
||||
|
||||
attestData := sha3.Sum512(nonce)
|
||||
|
||||
if err := quoteprovider.VerifyAttestationReportTLS(attestation.GetSevSnpAttestation(), attestData[:], policy); err != nil {
|
||||
if err := VerifySEVAttestationReportTLS(attestation.GetSevSnpAttestation(), attestData[:], policy); err != nil {
|
||||
return fmt.Errorf("failed to verify TEE attestation report: %v", err)
|
||||
}
|
||||
|
||||
@@ -336,7 +335,7 @@ func addTEEAttestation(attestation *attest.Attestation, nonce []byte, vmpl uint)
|
||||
|
||||
attestData := sha3.Sum512(teeNonce)
|
||||
|
||||
rawTeeAttestation, err := quoteprovider.FetchAttestation(attestData[:], vmpl)
|
||||
rawTeeAttestation, err := fetchSEVAttestation(attestData[:], vmpl)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to fetch TEE attestation report: %v", err)
|
||||
}
|
||||
|
||||
@@ -22,12 +22,9 @@ import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/ultravioletrs/cocos/pkg/attestation"
|
||||
"github.com/ultravioletrs/cocos/pkg/attestation/quoteprovider"
|
||||
"google.golang.org/protobuf/encoding/protojson"
|
||||
)
|
||||
|
||||
const sevSnpProductMilan = "Milan"
|
||||
|
||||
var policy = attestation.Config{Config: &check.Config{Policy: &check.Policy{}, RootOfTrust: &check.RootOfTrust{}}, PcrConfig: &attestation.PcrConfig{}}
|
||||
|
||||
type mockTPM struct {
|
||||
@@ -679,13 +676,13 @@ func TestVerifyAttestationReportMalformedSignature(t *testing.T) {
|
||||
name: "Valid attestation, distorted signature",
|
||||
attestationReport: attestationPB,
|
||||
reportData: reportData,
|
||||
err: quoteprovider.ErrAttVerification,
|
||||
err: ErrSEVAttVerification,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := quoteprovider.VerifyAttestationReportTLS(tt.attestationReport, tt.reportData, &policy)
|
||||
err := VerifySEVAttestationReportTLS(tt.attestationReport, tt.reportData, &policy)
|
||||
assert.True(t, errors.Contains(err, tt.err), fmt.Sprintf("expected error %v, got %v", tt.err, err))
|
||||
})
|
||||
}
|
||||
@@ -713,13 +710,13 @@ func TestVerifyAttestationReportUnknownProduct(t *testing.T) {
|
||||
name: "Valid attestation, unknown product",
|
||||
attestationReport: attestationPB,
|
||||
reportData: reportData,
|
||||
err: quoteprovider.ErrProductLine,
|
||||
err: ErrSEVProductLine,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := quoteprovider.VerifyAttestationReportTLS(tt.attestationReport, tt.reportData, &policy)
|
||||
err := VerifySEVAttestationReportTLS(tt.attestationReport, tt.reportData, &policy)
|
||||
assert.True(t, errors.Contains(err, tt.err), fmt.Sprintf("expected error %v, got %v", tt.err, err))
|
||||
})
|
||||
}
|
||||
@@ -752,7 +749,7 @@ func TestVerifyAttestationReportSuccess(t *testing.T) {
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := quoteprovider.VerifyAttestationReportTLS(tt.attestationReport, tt.reportData, &policy)
|
||||
err := VerifySEVAttestationReportTLS(tt.attestationReport, tt.reportData, &policy)
|
||||
assert.True(t, errors.Contains(err, tt.err), fmt.Sprintf("expected error %v, got %v", tt.err, err))
|
||||
})
|
||||
}
|
||||
@@ -780,13 +777,13 @@ func TestVerifyAttestationReportMalformedPolicy(t *testing.T) {
|
||||
name: "Valid attestation, malformed policy (measurement)",
|
||||
attestationReport: attestationPB,
|
||||
reportData: reportData,
|
||||
err: quoteprovider.ErrAttVerification,
|
||||
err: ErrSEVAttVerification,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := quoteprovider.VerifyAttestationReportTLS(tt.attestationReport, tt.reportData, &policy)
|
||||
err := VerifySEVAttestationReportTLS(tt.attestationReport, tt.reportData, &policy)
|
||||
assert.True(t, errors.Contains(err, tt.err), fmt.Sprintf("expected error %v, got %v", tt.err, err))
|
||||
})
|
||||
}
|
||||
|
||||
+1
-1
@@ -2,5 +2,5 @@
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
// Package clients contains the domain concept definitions needed to support
|
||||
// HTTP/gRPC Client functionality.
|
||||
// client configuration for HTTP/gRPC connections.
|
||||
package clients
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"github.com/ultravioletrs/cocos/agent"
|
||||
"github.com/ultravioletrs/cocos/pkg/clients"
|
||||
"github.com/ultravioletrs/cocos/pkg/clients/grpc"
|
||||
"github.com/ultravioletrs/cocos/pkg/tls"
|
||||
grpchealth "google.golang.org/grpc/health/grpc_health_v1"
|
||||
)
|
||||
|
||||
@@ -21,7 +22,7 @@ func NewAgentClient(ctx context.Context, cfg clients.AttestedClientConfig) (grpc
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
if client.Secure() != clients.WithMATLS.String() && client.Secure() != clients.WithATLS.String() && client.Secure() != clients.WithTLS.String() {
|
||||
if client.Secure() != tls.WithMATLS.String() && client.Secure() != tls.WithATLS.String() && client.Secure() != tls.WithTLS.String() {
|
||||
health := grpchealth.NewHealthClient(client.Connection())
|
||||
resp, err := health.Check(ctx, &grpchealth.HealthCheckRequest{
|
||||
Service: "agent",
|
||||
|
||||
@@ -22,6 +22,7 @@ import (
|
||||
"github.com/ultravioletrs/cocos/pkg/attestation"
|
||||
"github.com/ultravioletrs/cocos/pkg/attestation/vtpm"
|
||||
"github.com/ultravioletrs/cocos/pkg/clients"
|
||||
"github.com/ultravioletrs/cocos/pkg/tls"
|
||||
)
|
||||
|
||||
func TestNewClient(t *testing.T) {
|
||||
@@ -119,7 +120,7 @@ func TestNewClient(t *testing.T) {
|
||||
ServerCAFile: "nonexistent.pem",
|
||||
},
|
||||
wantErr: true,
|
||||
err: clients.ErrFailedToLoadRootCA,
|
||||
err: tls.ErrFailedToLoadRootCA,
|
||||
},
|
||||
{
|
||||
name: "Fail with invalid ClientCert",
|
||||
@@ -130,7 +131,7 @@ func TestNewClient(t *testing.T) {
|
||||
ClientKey: clientKeyFile,
|
||||
},
|
||||
wantErr: true,
|
||||
err: clients.ErrFailedToLoadClientCertKey,
|
||||
err: tls.ErrFailedToLoadClientCertKey,
|
||||
},
|
||||
{
|
||||
name: "Fail with invalid ClientKey",
|
||||
@@ -141,7 +142,7 @@ func TestNewClient(t *testing.T) {
|
||||
ClientKey: "nonexistent.pem",
|
||||
},
|
||||
wantErr: true,
|
||||
err: clients.ErrFailedToLoadClientCertKey,
|
||||
err: tls.ErrFailedToLoadClientCertKey,
|
||||
},
|
||||
}
|
||||
|
||||
@@ -169,32 +170,32 @@ func TestNewClient(t *testing.T) {
|
||||
func TestClientSecure(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
secure clients.Security
|
||||
secure tls.Security
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "Without TLS",
|
||||
secure: clients.WithoutTLS,
|
||||
secure: tls.WithoutTLS,
|
||||
expected: "without TLS",
|
||||
},
|
||||
{
|
||||
name: "With TLS",
|
||||
secure: clients.WithTLS,
|
||||
secure: tls.WithTLS,
|
||||
expected: "with TLS",
|
||||
},
|
||||
{
|
||||
name: "With mTLS",
|
||||
secure: clients.WithMTLS,
|
||||
secure: tls.WithMTLS,
|
||||
expected: "with mTLS",
|
||||
},
|
||||
{
|
||||
name: "With aTLS",
|
||||
secure: clients.WithATLS,
|
||||
secure: tls.WithATLS,
|
||||
expected: "with aTLS",
|
||||
},
|
||||
{
|
||||
name: "With maTLS",
|
||||
secure: clients.WithMATLS,
|
||||
secure: tls.WithMATLS,
|
||||
expected: "with maTLS",
|
||||
},
|
||||
}
|
||||
|
||||
@@ -6,6 +6,7 @@ package grpc
|
||||
import (
|
||||
"github.com/absmach/supermq/pkg/errors"
|
||||
"github.com/ultravioletrs/cocos/pkg/clients"
|
||||
"github.com/ultravioletrs/cocos/pkg/tls"
|
||||
"go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/credentials"
|
||||
@@ -26,7 +27,7 @@ type Client interface {
|
||||
type client struct {
|
||||
*grpc.ClientConn
|
||||
cfg clients.ClientConfiguration
|
||||
security clients.Security
|
||||
security tls.Security
|
||||
}
|
||||
|
||||
var _ Client = (*client)(nil)
|
||||
@@ -59,14 +60,19 @@ func (c *client) Connection() *grpc.ClientConn {
|
||||
return c.ClientConn
|
||||
}
|
||||
|
||||
func connect(cfg clients.ClientConfiguration) (*grpc.ClientConn, clients.Security, error) {
|
||||
func connect(cfg clients.ClientConfiguration) (*grpc.ClientConn, tls.Security, error) {
|
||||
opts := []grpc.DialOption{
|
||||
grpc.WithStatsHandler(otelgrpc.NewClientHandler()),
|
||||
}
|
||||
security := clients.WithoutTLS
|
||||
security := tls.WithoutTLS
|
||||
|
||||
if agcfg, ok := cfg.(clients.AttestedClientConfig); ok && agcfg.AttestedTLS {
|
||||
result, err := clients.LoadATLSConfig(agcfg)
|
||||
result, err := tls.LoadATLSConfig(
|
||||
agcfg.AttestationPolicy,
|
||||
agcfg.ServerCAFile,
|
||||
agcfg.ClientCert,
|
||||
agcfg.ClientKey,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, security, err
|
||||
}
|
||||
@@ -90,13 +96,13 @@ func connect(cfg clients.ClientConfiguration) (*grpc.ClientConn, clients.Securit
|
||||
return conn, security, nil
|
||||
}
|
||||
|
||||
func loadTLSConfig(serverCAFile, clientCert, clientKey string) (credentials.TransportCredentials, clients.Security, error) {
|
||||
result, err := clients.LoadBasicTLSConfig(serverCAFile, clientCert, clientKey)
|
||||
func loadTLSConfig(serverCAFile, clientCert, clientKey string) (credentials.TransportCredentials, tls.Security, error) {
|
||||
result, err := tls.LoadBasicConfig(serverCAFile, clientCert, clientKey)
|
||||
if err != nil {
|
||||
return nil, clients.WithoutTLS, err
|
||||
return nil, tls.WithoutTLS, err
|
||||
}
|
||||
|
||||
if result.Security == clients.WithoutTLS || result.Config == nil {
|
||||
if result.Security == tls.WithoutTLS || result.Config == nil {
|
||||
return insecure.NewCredentials(), result.Security, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/ultravioletrs/cocos/pkg/clients"
|
||||
"github.com/ultravioletrs/cocos/pkg/tls"
|
||||
)
|
||||
|
||||
type Client interface {
|
||||
@@ -19,7 +20,7 @@ type Client interface {
|
||||
type client struct {
|
||||
transport *http.Transport
|
||||
cfg clients.ClientConfiguration
|
||||
security clients.Security
|
||||
security tls.Security
|
||||
}
|
||||
|
||||
var _ Client = (*client)(nil)
|
||||
@@ -49,17 +50,22 @@ func (c *client) Timeout() time.Duration {
|
||||
return c.cfg.Config().Timeout
|
||||
}
|
||||
|
||||
func createTransport(cfg clients.ClientConfiguration) (*http.Transport, clients.Security, error) {
|
||||
func createTransport(cfg clients.ClientConfiguration) (*http.Transport, tls.Security, error) {
|
||||
transport := &http.Transport{
|
||||
MaxIdleConns: 100,
|
||||
IdleConnTimeout: 90 * time.Second,
|
||||
TLSHandshakeTimeout: 10 * time.Second,
|
||||
}
|
||||
|
||||
security := clients.WithoutTLS
|
||||
security := tls.WithoutTLS
|
||||
|
||||
if agcfg, ok := cfg.(*clients.AttestedClientConfig); ok && agcfg.AttestedTLS {
|
||||
result, err := clients.LoadATLSConfig(*agcfg)
|
||||
result, err := tls.LoadATLSConfig(
|
||||
agcfg.AttestationPolicy,
|
||||
agcfg.ServerCAFile,
|
||||
agcfg.ClientCert,
|
||||
agcfg.ClientKey,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, security, err
|
||||
}
|
||||
@@ -69,12 +75,12 @@ func createTransport(cfg clients.ClientConfiguration) (*http.Transport, clients.
|
||||
} else {
|
||||
conf := cfg.Config()
|
||||
|
||||
result, err := clients.LoadBasicTLSConfig(conf.ServerCAFile, conf.ClientCert, conf.ClientKey)
|
||||
result, err := tls.LoadBasicConfig(conf.ServerCAFile, conf.ClientCert, conf.ClientKey)
|
||||
if err != nil {
|
||||
return nil, security, err
|
||||
}
|
||||
|
||||
if result.Security != clients.WithoutTLS {
|
||||
if result.Security != tls.WithoutTLS {
|
||||
transport.TLSClientConfig = result.Config
|
||||
}
|
||||
|
||||
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/ultravioletrs/cocos/pkg/clients"
|
||||
"github.com/ultravioletrs/cocos/pkg/tls"
|
||||
)
|
||||
|
||||
func TestConfig_Configuration(t *testing.T) {
|
||||
@@ -144,7 +145,7 @@ func TestClient_Secure(t *testing.T) {
|
||||
URL: "http://localhost:8080",
|
||||
Timeout: 30 * time.Second,
|
||||
},
|
||||
expected: clients.WithoutTLS.String(),
|
||||
expected: tls.WithoutTLS.String(),
|
||||
},
|
||||
}
|
||||
|
||||
@@ -183,7 +184,7 @@ func TestCreateTransport_DefaultSettings(t *testing.T) {
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, transport)
|
||||
assert.Equal(t, clients.WithoutTLS, security)
|
||||
assert.Equal(t, tls.WithoutTLS, security)
|
||||
assert.Equal(t, 100, transport.MaxIdleConns)
|
||||
assert.Equal(t, 90*time.Second, transport.IdleConnTimeout)
|
||||
assert.Equal(t, 10*time.Second, transport.TLSHandshakeTimeout)
|
||||
@@ -205,7 +206,7 @@ func TestCreateTransport_ATLSError(t *testing.T) {
|
||||
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, transport)
|
||||
assert.Equal(t, clients.WithoutTLS, security)
|
||||
assert.Equal(t, tls.WithoutTLS, security)
|
||||
assert.Contains(t, err.Error(), "failed to stat attestation policy")
|
||||
}
|
||||
|
||||
@@ -220,7 +221,7 @@ func TestCreateTransport_BasicTLSError(t *testing.T) {
|
||||
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, transport)
|
||||
assert.Equal(t, clients.WithoutTLS, security)
|
||||
assert.Equal(t, tls.WithoutTLS, security)
|
||||
assert.Contains(t, err.Error(), "failed to load root ca file")
|
||||
}
|
||||
|
||||
|
||||
@@ -19,7 +19,6 @@ import (
|
||||
"github.com/stretchr/testify/mock"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/ultravioletrs/cocos/agent"
|
||||
"github.com/ultravioletrs/cocos/pkg/attestation/quoteprovider"
|
||||
"github.com/ultravioletrs/cocos/pkg/attestation/vtpm"
|
||||
"github.com/ultravioletrs/cocos/pkg/sdk"
|
||||
"golang.org/x/crypto/sha3"
|
||||
@@ -388,7 +387,7 @@ func TestAttestation(t *testing.T) {
|
||||
cases := []struct {
|
||||
name string
|
||||
userKey any
|
||||
reportData [quoteprovider.Nonce]byte
|
||||
reportData [vtpm.SEVNonce]byte
|
||||
nonce [vtpm.Nonce]byte
|
||||
response *agent.AttestationResponse
|
||||
svcRes []byte
|
||||
@@ -397,7 +396,7 @@ func TestAttestation(t *testing.T) {
|
||||
{
|
||||
name: "fetch attestation report successfully",
|
||||
userKey: resultConsumerKey,
|
||||
reportData: [quoteprovider.Nonce]byte(reportData),
|
||||
reportData: [vtpm.SEVNonce]byte(reportData),
|
||||
nonce: [vtpm.Nonce]byte(nonce),
|
||||
response: &agent.AttestationResponse{
|
||||
File: report,
|
||||
@@ -408,7 +407,7 @@ func TestAttestation(t *testing.T) {
|
||||
{
|
||||
name: "fetch attestation report with different key type",
|
||||
userKey: resultConsumer1Key,
|
||||
reportData: [quoteprovider.Nonce]byte(reportData),
|
||||
reportData: [vtpm.SEVNonce]byte(reportData),
|
||||
nonce: [vtpm.Nonce]byte(nonce),
|
||||
response: &agent.AttestationResponse{
|
||||
File: report,
|
||||
@@ -419,7 +418,7 @@ func TestAttestation(t *testing.T) {
|
||||
{
|
||||
name: "failed to fetch attestation report",
|
||||
userKey: resultConsumerKey,
|
||||
reportData: [quoteprovider.Nonce]byte(reportData),
|
||||
reportData: [vtpm.SEVNonce]byte(reportData),
|
||||
nonce: [vtpm.Nonce]byte(nonce),
|
||||
response: &agent.AttestationResponse{
|
||||
File: []byte{},
|
||||
@@ -429,7 +428,7 @@ func TestAttestation(t *testing.T) {
|
||||
{
|
||||
name: "invalid report data",
|
||||
userKey: resultConsumerKey,
|
||||
reportData: [quoteprovider.Nonce]byte{},
|
||||
reportData: [vtpm.SEVNonce]byte{},
|
||||
nonce: [vtpm.Nonce]byte(nonce),
|
||||
response: &agent.AttestationResponse{
|
||||
File: []byte{},
|
||||
|
||||
@@ -0,0 +1,6 @@
|
||||
// Copyright (c) Ultraviolet
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
// Package tls provides TLS configuration utilities for client connections.
|
||||
// It supports standard TLS, mutual TLS (mTLS), and attested TLS (aTLS).
|
||||
package tls
|
||||
@@ -1,7 +1,7 @@
|
||||
// Copyright (c) Ultraviolet
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package clients
|
||||
package tls
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
@@ -53,20 +53,20 @@ var (
|
||||
errAttestationPolicyIrregular = errors.New("attestation policy file is not a regular file")
|
||||
)
|
||||
|
||||
// TLSResult contains the result of TLS configuration.
|
||||
type TLSResult struct {
|
||||
// Result contains the result of TLS configuration.
|
||||
type Result struct {
|
||||
Config *tls.Config
|
||||
Security Security
|
||||
}
|
||||
|
||||
// LoadBasicTLSConfig loads standard TLS configuration (TLS/mTLS).
|
||||
func LoadBasicTLSConfig(serverCAFile, clientCert, clientKey string) (*TLSResult, error) {
|
||||
// LoadBasicConfig loads standard TLS configuration (TLS/mTLS).
|
||||
func LoadBasicConfig(serverCAFile, clientCert, clientKey string) (*Result, error) {
|
||||
tlsConfig := &tls.Config{}
|
||||
security := WithoutTLS
|
||||
|
||||
// If no TLS configuration is provided, return nil config (no TLS)
|
||||
if serverCAFile == "" && clientCert == "" && clientKey == "" {
|
||||
return &TLSResult{Config: nil, Security: security}, nil
|
||||
return &Result{Config: nil, Security: security}, nil
|
||||
}
|
||||
|
||||
if serverCAFile != "" {
|
||||
@@ -96,14 +96,15 @@ func LoadBasicTLSConfig(serverCAFile, clientCert, clientKey string) (*TLSResult,
|
||||
security = WithMTLS
|
||||
}
|
||||
|
||||
return &TLSResult{Config: tlsConfig, Security: security}, nil
|
||||
return &Result{Config: tlsConfig, Security: security}, nil
|
||||
}
|
||||
|
||||
// LoadATLSConfig configures Attested TLS.
|
||||
func LoadATLSConfig(cfg AttestedClientConfig) (*TLSResult, error) {
|
||||
// Parameters are passed individually to avoid circular dependencies with the clients package.
|
||||
func LoadATLSConfig(attestationPolicy, serverCAFile, clientCert, clientKey string) (*Result, error) {
|
||||
security := WithATLS
|
||||
|
||||
info, err := os.Stat(cfg.AttestationPolicy)
|
||||
info, err := os.Stat(attestationPolicy)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(errors.New("failed to stat attestation policy file"), err)
|
||||
}
|
||||
@@ -112,12 +113,12 @@ func LoadATLSConfig(cfg AttestedClientConfig) (*TLSResult, error) {
|
||||
return nil, errAttestationPolicyIrregular
|
||||
}
|
||||
|
||||
attestation.AttestationPolicyPath = cfg.AttestationPolicy
|
||||
attestation.AttestationPolicyPath = attestationPolicy
|
||||
|
||||
var rootCAs *x509.CertPool
|
||||
|
||||
if cfg.ServerCAFile != "" {
|
||||
rootCAs, err = loadRootCAs(cfg.ServerCAFile)
|
||||
if serverCAFile != "" {
|
||||
rootCAs, err = loadRootCAs(serverCAFile)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -141,8 +142,8 @@ func LoadATLSConfig(cfg AttestedClientConfig) (*TLSResult, error) {
|
||||
},
|
||||
}
|
||||
|
||||
if cfg.ClientCert != "" || cfg.ClientKey != "" {
|
||||
certificate, err := tls.LoadX509KeyPair(cfg.ClientCert, cfg.ClientKey)
|
||||
if clientCert != "" || clientKey != "" {
|
||||
certificate, err := tls.LoadX509KeyPair(clientCert, clientKey)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(ErrFailedToLoadClientCertKey, err)
|
||||
}
|
||||
@@ -150,7 +151,7 @@ func LoadATLSConfig(cfg AttestedClientConfig) (*TLSResult, error) {
|
||||
tlsConfig.Certificates = []tls.Certificate{certificate}
|
||||
}
|
||||
|
||||
return &TLSResult{Config: tlsConfig, Security: security}, nil
|
||||
return &Result{Config: tlsConfig, Security: security}, nil
|
||||
}
|
||||
|
||||
// loadRootCAs loads root CA certificates from a file.
|
||||
@@ -1,7 +1,7 @@
|
||||
// Copyright (c) Ultraviolet
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package clients
|
||||
package tls
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
@@ -64,7 +64,7 @@ func TestSecurity_String(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadBasicTLSConfig(t *testing.T) {
|
||||
func TestLoadBasicConfig(t *testing.T) {
|
||||
// Create temporary directory for test files
|
||||
tmpDir := t.TempDir()
|
||||
|
||||
@@ -164,7 +164,7 @@ func TestLoadBasicTLSConfig(t *testing.T) {
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result, err := LoadBasicTLSConfig(tt.serverCAFile, tt.clientCert, tt.clientKey)
|
||||
result, err := LoadBasicConfig(tt.serverCAFile, tt.clientCert, tt.clientKey)
|
||||
|
||||
if tt.expectError {
|
||||
assert.Error(t, err)
|
||||
@@ -201,100 +201,86 @@ func TestLoadATLSConfig(t *testing.T) {
|
||||
require.NoError(t, os.WriteFile(policyFile, []byte(`{"policy": "test"}`), 0o644))
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
config AttestedClientConfig
|
||||
expectedSec Security
|
||||
expectError bool
|
||||
errorMsg string
|
||||
name string
|
||||
attestationPolicy string
|
||||
serverCAFile string
|
||||
clientCert string
|
||||
clientKey string
|
||||
expectedSec Security
|
||||
expectError bool
|
||||
errorMsg string
|
||||
}{
|
||||
{
|
||||
name: "ValidATLSConfig",
|
||||
config: AttestedClientConfig{
|
||||
StandardClientConfig: StandardClientConfig{
|
||||
ServerCAFile: "",
|
||||
},
|
||||
AttestationPolicy: policyFile,
|
||||
ProductName: "test-product",
|
||||
},
|
||||
expectedSec: WithATLS,
|
||||
expectError: false,
|
||||
name: "ValidATLSConfig",
|
||||
attestationPolicy: policyFile,
|
||||
serverCAFile: "",
|
||||
clientCert: "",
|
||||
clientKey: "",
|
||||
expectedSec: WithATLS,
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "ValidMATLSConfig",
|
||||
config: AttestedClientConfig{
|
||||
StandardClientConfig: StandardClientConfig{
|
||||
ServerCAFile: caFile,
|
||||
},
|
||||
AttestationPolicy: policyFile,
|
||||
ProductName: "test-product",
|
||||
},
|
||||
expectedSec: WithMATLS,
|
||||
expectError: false,
|
||||
name: "ValidMATLSConfig",
|
||||
attestationPolicy: policyFile,
|
||||
serverCAFile: caFile,
|
||||
clientCert: "",
|
||||
clientKey: "",
|
||||
expectedSec: WithMATLS,
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "ValidATLSWithClientCert",
|
||||
config: AttestedClientConfig{
|
||||
StandardClientConfig: StandardClientConfig{
|
||||
ClientCert: certFile,
|
||||
ClientKey: keyFile,
|
||||
},
|
||||
AttestationPolicy: policyFile,
|
||||
ProductName: "test-product",
|
||||
},
|
||||
expectedSec: WithATLS,
|
||||
expectError: false,
|
||||
name: "ValidATLSWithClientCert",
|
||||
attestationPolicy: policyFile,
|
||||
serverCAFile: "",
|
||||
clientCert: certFile,
|
||||
clientKey: keyFile,
|
||||
expectedSec: WithATLS,
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "NonexistentPolicyFile",
|
||||
config: AttestedClientConfig{
|
||||
AttestationPolicy: filepath.Join(tmpDir, "nonexistent.json"),
|
||||
ProductName: "test-product",
|
||||
},
|
||||
expectedSec: WithoutTLS,
|
||||
expectError: true,
|
||||
errorMsg: "failed to stat attestation policy file",
|
||||
name: "NonexistentPolicyFile",
|
||||
attestationPolicy: filepath.Join(tmpDir, "nonexistent.json"),
|
||||
serverCAFile: "",
|
||||
clientCert: "",
|
||||
clientKey: "",
|
||||
expectedSec: WithoutTLS,
|
||||
expectError: true,
|
||||
errorMsg: "failed to stat attestation policy file",
|
||||
},
|
||||
{
|
||||
name: "PolicyFileIsDirectory",
|
||||
config: AttestedClientConfig{
|
||||
AttestationPolicy: tmpDir, // Directory instead of file
|
||||
ProductName: "test-product",
|
||||
},
|
||||
expectedSec: WithoutTLS,
|
||||
expectError: true,
|
||||
errorMsg: "attestation policy file is not a regular file",
|
||||
name: "PolicyFileIsDirectory",
|
||||
attestationPolicy: tmpDir, // Directory instead of file
|
||||
serverCAFile: "",
|
||||
clientCert: "",
|
||||
clientKey: "",
|
||||
expectedSec: WithoutTLS,
|
||||
expectError: true,
|
||||
errorMsg: "attestation policy file is not a regular file",
|
||||
},
|
||||
{
|
||||
name: "InvalidCAFile",
|
||||
config: AttestedClientConfig{
|
||||
StandardClientConfig: StandardClientConfig{
|
||||
ServerCAFile: filepath.Join(tmpDir, "nonexistent.crt"),
|
||||
},
|
||||
AttestationPolicy: policyFile,
|
||||
ProductName: "test-product",
|
||||
},
|
||||
expectedSec: WithoutTLS,
|
||||
expectError: true,
|
||||
errorMsg: "failed to read certificate file",
|
||||
name: "InvalidCAFile",
|
||||
attestationPolicy: policyFile,
|
||||
serverCAFile: filepath.Join(tmpDir, "nonexistent.crt"),
|
||||
clientCert: "",
|
||||
clientKey: "",
|
||||
expectedSec: WithoutTLS,
|
||||
expectError: true,
|
||||
errorMsg: "failed to read certificate file",
|
||||
},
|
||||
{
|
||||
name: "InvalidClientCert",
|
||||
config: AttestedClientConfig{
|
||||
StandardClientConfig: StandardClientConfig{
|
||||
ClientCert: filepath.Join(tmpDir, "nonexistent.crt"),
|
||||
ClientKey: keyFile,
|
||||
},
|
||||
AttestationPolicy: policyFile,
|
||||
ProductName: "test-product",
|
||||
},
|
||||
expectedSec: WithoutTLS,
|
||||
expectError: true,
|
||||
name: "InvalidClientCert",
|
||||
attestationPolicy: policyFile,
|
||||
serverCAFile: "",
|
||||
clientCert: filepath.Join(tmpDir, "nonexistent.crt"),
|
||||
clientKey: keyFile,
|
||||
expectedSec: WithoutTLS,
|
||||
expectError: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result, err := LoadATLSConfig(tt.config)
|
||||
result, err := LoadATLSConfig(tt.attestationPolicy, tt.serverCAFile, tt.clientCert, tt.clientKey)
|
||||
|
||||
if tt.expectError {
|
||||
assert.Error(t, err)
|
||||
Reference in New Issue
Block a user