COCOS-525-487 - Refactor attestation and atls (#562)

* Refactor attestation handling to remove quoteprovider dependency

- Removed references to quoteprovider in various files, replacing them with vtpm where necessary.
- Updated function signatures and implementations to use SEVNonce instead of quoteprovider.Nonce.
- Introduced new vtpm package to handle SEV-related attestation logic, including fetching and verifying attestation reports.
- Adjusted tests to reflect changes in the attestation logic and ensure compatibility with the new structure.
- Deleted the now redundant quoteprovider/sev_test.go file.

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* fix: Add veraison/go-cose dependency to go.mod

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* feat: Introduce TLS package for enhanced security configuration and refactor client code to utilize new TLS utilities

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

---------

Signed-off-by: Sammy Oina <sammyoina@gmail.com>
This commit is contained in:
Sammy Kerata Oina
2026-02-18 13:53:04 +03:00
committed by GitHub
parent de50b6d2d4
commit 207bfd99af
29 changed files with 222 additions and 591 deletions
+3 -4
View File
@@ -21,7 +21,6 @@ import (
"github.com/google/go-tpm-tools/proto/attest"
"github.com/ultravioletrs/cocos/pkg/attestation"
"github.com/ultravioletrs/cocos/pkg/attestation/eat"
"github.com/ultravioletrs/cocos/pkg/attestation/quoteprovider"
"github.com/ultravioletrs/cocos/pkg/attestation/vtpm"
"google.golang.org/protobuf/proto"
)
@@ -130,7 +129,7 @@ func (a verifier) VerifTeeAttestation(report []byte, teeNonce []byte) error {
return errors.Wrap(fmt.Errorf("failed to convert TEE report to proto"), err)
}
return quoteprovider.VerifyAttestationReportTLS(attestationReport, teeNonce, a.Policy)
return vtpm.VerifySEVAttestationReportTLS(attestationReport, teeNonce, a.Policy)
}
func (a verifier) VerifVTpmAttestation(report []byte, vTpmNonce []byte) error {
@@ -148,7 +147,7 @@ func (a verifier) VerifyAttestation(report []byte, teeNonce []byte, vTpmNonce []
}
snpReport := quote.GetSevSnpAttestation()
if err = quoteprovider.VerifyAttestationReportTLS(snpReport, nil, a.Policy); err != nil {
if err = vtpm.VerifySEVAttestationReportTLS(snpReport, nil, a.Policy); err != nil {
return fmt.Errorf("failed to verify vTPM attestation report: %w", err)
}
@@ -266,7 +265,7 @@ func GenerateAttestationPolicy(token, product string, policy uint64) (*attestati
return nil, fmt.Errorf("failed to decode reportID: %w", err)
}
sevSnpProduct := quoteprovider.GetProductName(product)
sevSnpProduct := vtpm.GetSEVProductName(product)
return &attestation.Config{
Config: &check.Config{
-377
View File
@@ -1,377 +0,0 @@
// Copyright (c) Ultraviolet
// SPDX-License-Identifier: Apache-2.0
//go:build !embed
package quoteprovider
import (
"os"
"path"
"testing"
"github.com/google/go-sev-guest/proto/check"
"github.com/google/go-sev-guest/proto/sevsnp"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestFillInAttestationLocal(t *testing.T) {
originalHome := os.Getenv("HOME")
defer func() {
os.Setenv("HOME", originalHome)
}()
tempDir, err := os.MkdirTemp("", "test_home")
require.NoError(t, err)
defer os.RemoveAll(tempDir)
os.Setenv("HOME", tempDir)
cocosDir := path.Join(tempDir, cocosDirectory, sevSnpProductMilan)
err = os.MkdirAll(cocosDir, 0o755)
require.NoError(t, err)
bundleContent := []byte("mock ASK ARK bundle")
bundlePath := path.Join(cocosDir, arkAskBundleName)
err = os.WriteFile(bundlePath, bundleContent, 0o644)
require.NoError(t, err)
config := &check.Config{
RootOfTrust: &check.RootOfTrust{
ProductLine: sevSnpProductMilan,
},
Policy: &check.Policy{},
}
tests := []struct {
name string
attestation *sevsnp.Attestation
setupFunc func()
expectedError bool
errorContains string
}{
{
name: "Empty attestation - creates new chain",
attestation: &sevsnp.Attestation{
CertificateChain: nil,
},
setupFunc: func() {},
expectedError: true,
errorContains: "could not find ASK or ASVK PEM block; could not find ARK PEM block",
},
{
name: "Attestation with existing chain - no changes needed",
attestation: &sevsnp.Attestation{
CertificateChain: &sevsnp.CertificateChain{
AskCert: []byte("existing ASK cert"),
ArkCert: []byte("existing ARK cert"),
},
},
setupFunc: func() {},
expectedError: false,
},
{
name: "Attestation with empty chain - tries to load from file",
attestation: &sevsnp.Attestation{
CertificateChain: &sevsnp.CertificateChain{},
},
setupFunc: func() {},
expectedError: true,
errorContains: "could not find ASK or ASVK PEM block; could not find ARK PEM block",
},
{
name: "No bundle file exists - no error",
attestation: &sevsnp.Attestation{
CertificateChain: &sevsnp.CertificateChain{},
},
setupFunc: func() {
os.Remove(bundlePath)
},
expectedError: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
os.Setenv("HOME", tempDir)
if _, err := os.Stat(bundlePath); os.IsNotExist(err) {
if err := os.WriteFile(bundlePath, bundleContent, 0o644); err != nil {
t.Fatalf("Failed to write bundle file: %v", err)
}
}
tt.setupFunc()
err := fillInAttestationLocal(tt.attestation, config)
if tt.expectedError {
assert.Error(t, err)
if tt.errorContains != "" {
assert.Contains(t, err.Error(), tt.errorContains)
}
} else {
assert.NoError(t, err)
}
})
}
}
func TestGetProductName(t *testing.T) {
tests := []struct {
name string
product string
expected sevsnp.SevProduct_SevProductName
}{
{
name: "Milan product",
product: sevSnpProductMilan,
expected: sevsnp.SevProduct_SEV_PRODUCT_MILAN,
},
{
name: "Genoa product",
product: sevSnpProductGenoa,
expected: sevsnp.SevProduct_SEV_PRODUCT_GENOA,
},
{
name: "Unknown product",
product: "UnknownProduct",
expected: sevsnp.SevProduct_SEV_PRODUCT_UNKNOWN,
},
{
name: "Empty product",
product: "",
expected: sevsnp.SevProduct_SEV_PRODUCT_UNKNOWN,
},
{
name: "Case sensitive - milan lowercase",
product: "milan",
expected: sevsnp.SevProduct_SEV_PRODUCT_UNKNOWN,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := GetProductName(tt.product)
assert.Equal(t, tt.expected, result)
})
}
}
func TestVerifyReport(t *testing.T) {
tests := []struct {
name string
attestation *sevsnp.Attestation
config *check.Config
expectedError bool
errorContains string
}{
{
name: "Invalid product line",
attestation: &sevsnp.Attestation{
CertificateChain: &sevsnp.CertificateChain{},
},
config: &check.Config{
RootOfTrust: &check.RootOfTrust{
ProductLine: "InvalidProduct",
},
Policy: &check.Policy{},
},
expectedError: true,
errorContains: "product name must be",
},
{
name: "Valid Milan product line",
attestation: &sevsnp.Attestation{
CertificateChain: &sevsnp.CertificateChain{
AskCert: []byte("mock ask cert"),
ArkCert: []byte("mock ark cert"),
},
},
config: &check.Config{
RootOfTrust: &check.RootOfTrust{
ProductLine: sevSnpProductMilan,
},
Policy: &check.Policy{},
},
expectedError: true,
errorContains: "attestation verification failed",
},
{
name: "Valid Genoa product line",
attestation: &sevsnp.Attestation{
CertificateChain: &sevsnp.CertificateChain{
AskCert: []byte("mock ask cert"),
ArkCert: []byte("mock ark cert"),
},
},
config: &check.Config{
RootOfTrust: &check.RootOfTrust{
ProductLine: sevSnpProductGenoa,
},
Policy: &check.Policy{},
},
expectedError: true,
errorContains: "attestation verification failed",
},
{
name: "Config with existing product policy",
attestation: &sevsnp.Attestation{
CertificateChain: &sevsnp.CertificateChain{
AskCert: []byte("mock ask cert"),
ArkCert: []byte("mock ark cert"),
},
},
config: &check.Config{
RootOfTrust: &check.RootOfTrust{
ProductLine: sevSnpProductMilan,
},
Policy: &check.Policy{
Product: &sevsnp.SevProduct{
Name: sevsnp.SevProduct_SEV_PRODUCT_MILAN,
},
},
},
expectedError: true,
errorContains: "attestation verification failed",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := verifyReport(tt.attestation, tt.config)
if tt.expectedError {
assert.Error(t, err)
if tt.errorContains != "" {
assert.Contains(t, err.Error(), tt.errorContains)
}
} else {
assert.NoError(t, err)
}
})
}
}
func TestValidateReport(t *testing.T) {
tests := []struct {
name string
attestation *sevsnp.Attestation
config *check.Config
expectedError bool
errorContains string
}{
{
name: "Basic validation test",
attestation: &sevsnp.Attestation{
CertificateChain: &sevsnp.CertificateChain{},
},
config: &check.Config{
Policy: &check.Policy{
Policy: 196608,
},
},
expectedError: true,
errorContains: "attestation validation failed",
},
{
name: "Validation with report data",
attestation: &sevsnp.Attestation{
CertificateChain: &sevsnp.CertificateChain{},
},
config: &check.Config{
Policy: &check.Policy{
Policy: 196608,
ReportData: []byte("test report datatest report datatest report datatest report data"),
},
},
expectedError: true,
errorContains: "attestation validation failed",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := validateReport(tt.attestation, tt.config)
if tt.expectedError {
assert.Error(t, err)
if tt.errorContains != "" {
assert.Contains(t, err.Error(), tt.errorContains)
}
} else {
assert.NoError(t, err)
}
})
}
}
func TestFetchAttestation(t *testing.T) {
tests := []struct {
name string
reportData []byte
vmpl uint
expectedError bool
errorContains string
}{
{
name: "Report data too large",
reportData: make([]byte, Nonce+1),
vmpl: 0,
expectedError: true,
errorContains: "could not get quote provider",
},
{
name: "Valid report data size",
reportData: make([]byte, 32),
vmpl: 0,
expectedError: true,
errorContains: "could not get quote provider",
},
{
name: "Maximum valid report data size",
reportData: make([]byte, Nonce),
vmpl: 1,
expectedError: true,
errorContains: "could not get quote provider",
},
{
name: "Empty report data",
reportData: []byte{},
vmpl: 0,
expectedError: true,
errorContains: "could not get quote provider",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result, err := FetchAttestation(tt.reportData, tt.vmpl)
if tt.expectedError {
assert.Error(t, err)
assert.Empty(t, result)
if tt.errorContains != "" {
assert.Contains(t, err.Error(), tt.errorContains)
}
} else {
assert.NoError(t, err)
assert.NotEmpty(t, result)
}
})
}
}
func TestGetLeveledQuoteProvider(t *testing.T) {
t.Run("GetLeveledQuoteProvider call", func(t *testing.T) {
provider, err := GetLeveledQuoteProvider()
if err != nil {
assert.Error(t, err)
assert.Nil(t, provider)
} else {
assert.NoError(t, err)
assert.NotNil(t, provider)
}
})
}
@@ -3,7 +3,7 @@
//go:build !embed
package quoteprovider
package vtpm
import (
"crypto/rand"
@@ -32,7 +32,7 @@ const (
cocosDirectory = "/cocos"
arkAskBundleName = "ask_ark.pem"
vcekName = "vcek.pem"
Nonce = 64
SEVNonce = 64
sevSnpProductMilan = "Milan"
sevSnpProductGenoa = "Genoa"
)
@@ -43,9 +43,9 @@ var (
)
var (
ErrProductLine = errors.New(fmt.Sprintf("product name must be %s or %s", sevSnpProductMilan, sevSnpProductGenoa))
ErrAttVerification = errors.New("attestation verification failed")
errAttValidation = errors.New("attestation validation failed")
ErrSEVProductLine = errors.New(fmt.Sprintf("product name must be %s or %s", sevSnpProductMilan, sevSnpProductGenoa))
ErrSEVAttVerification = errors.New("attestation verification failed")
errSEVAttValidation = errors.New("attestation validation failed")
)
func fillInAttestationLocal(attestation *sevsnp.Attestation, cfg *check.Config) error {
@@ -77,16 +77,17 @@ func fillInAttestationLocal(attestation *sevsnp.Attestation, cfg *check.Config)
return nil
}
// verifyReport verifies the SEV-SNP attestation report.
func verifyReport(attestationPB *sevsnp.Attestation, cfg *check.Config) error {
sopts, err := verify.RootOfTrustToOptions(cfg.RootOfTrust)
if err != nil {
return fmt.Errorf("failed to get root of trust options: %v", errors.Wrap(ErrAttVerification, err))
return fmt.Errorf("failed to get root of trust options: %v", errors.Wrap(ErrSEVAttVerification, err))
}
if cfg.Policy.Product == nil {
productName := GetProductName(cfg.RootOfTrust.ProductLine)
productName := GetSEVProductName(cfg.RootOfTrust.ProductLine)
if productName == sevsnp.SevProduct_SEV_PRODUCT_UNKNOWN {
return ErrProductLine
return ErrSEVProductLine
}
sopts.Product = &sevsnp.SevProduct{
@@ -107,30 +108,33 @@ func verifyReport(attestationPB *sevsnp.Attestation, cfg *check.Config) error {
}
if err := verify.SnpAttestation(attestationPB, sopts); err != nil {
return errors.Wrap(ErrAttVerification, err)
return errors.Wrap(ErrSEVAttVerification, err)
}
return nil
}
// validateReport validates the SEV-SNP attestation report against policy.
func validateReport(attestationPB *sevsnp.Attestation, cfg *check.Config) error {
opts, err := validate.PolicyToOptions(cfg.Policy)
if err != nil {
return fmt.Errorf("failed to get policy for validation: %v", errors.Wrap(ErrAttVerification, err))
return fmt.Errorf("failed to get policy for validation: %v", errors.Wrap(ErrSEVAttVerification, err))
}
if err = validate.SnpAttestation(attestationPB, opts); err != nil {
return errors.Wrap(errAttValidation, err)
return errors.Wrap(errSEVAttValidation, err)
}
return nil
}
func GetLeveledQuoteProvider() (client.LeveledQuoteProvider, error) {
// getLeveledQuoteProvider returns a leveled quote provider for SEV-SNP.
func getLeveledQuoteProvider() (client.LeveledQuoteProvider, error) {
return client.GetLeveledQuoteProvider()
}
func VerifyAttestationReportTLS(attestationPB *sevsnp.Attestation, reportData []byte, policy *attestation.Config) error {
// VerifySEVAttestationReportTLS verifies a SEV-SNP attestation report for TLS (exported for azure package).
func VerifySEVAttestationReportTLS(attestationPB *sevsnp.Attestation, reportData []byte, policy *attestation.Config) error {
config := policy.Config
// Certificate chain is populated based on the extra data that is appended to the SEV-SNP attestation report.
@@ -141,10 +145,11 @@ func VerifyAttestationReportTLS(attestationPB *sevsnp.Attestation, reportData []
config.Policy.ReportData = reportData[:]
}
return VerifyAndValidate(attestationPB, config)
return verifySEVAndValidate(attestationPB, config)
}
func VerifyAndValidate(attestationPB *sevsnp.Attestation, cfg *check.Config) error {
// verifySEVAndValidate performs both verification and validation of a SEV-SNP attestation.
func verifySEVAndValidate(attestationPB *sevsnp.Attestation, cfg *check.Config) error {
logger.Init("", false, false, io.Discard)
if err := verifyReport(attestationPB, cfg); err != nil {
@@ -158,15 +163,16 @@ func VerifyAndValidate(attestationPB *sevsnp.Attestation, cfg *check.Config) err
return nil
}
func FetchAttestation(reportDataSlice []byte, vmpl uint) ([]byte, error) {
var reportData [Nonce]byte
// fetchSEVAttestation fetches a SEV-SNP attestation report.
func fetchSEVAttestation(reportDataSlice []byte, vmpl uint) ([]byte, error) {
var reportData [SEVNonce]byte
qp, err := GetLeveledQuoteProvider()
qp, err := getLeveledQuoteProvider()
if err != nil {
return []byte{}, fmt.Errorf("could not get quote provider")
}
if len(reportData) > Nonce {
if len(reportData) > SEVNonce {
return []byte{}, fmt.Errorf("attestation report size mismatch")
}
copy(reportData[:], reportDataSlice)
@@ -206,7 +212,8 @@ func FetchAttestation(reportDataSlice []byte, vmpl uint) ([]byte, error) {
return result, nil
}
func GetProductName(product string) sevsnp.SevProduct_SevProductName {
// GetSEVProductName maps a product string to a SEV product name.
func GetSEVProductName(product string) sevsnp.SevProduct_SevProductName {
switch product {
case sevSnpProductMilan:
return sevsnp.SevProduct_SEV_PRODUCT_MILAN
@@ -217,6 +224,7 @@ func GetProductName(product string) sevsnp.SevProduct_SevProductName {
}
}
// derToPem converts DER-encoded certificate to PEM format.
func derToPem(der []byte) []byte {
// Try to parse to make sure it's a certificate
if _, err := x509.ParseCertificate(der); err != nil {
@@ -226,15 +234,16 @@ func derToPem(der []byte) []byte {
return pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: der})
}
func FetchCertificates(vmpl uint) error {
var reportData [Nonce]byte
// FetchSEVCertificates fetches SEV-SNP certificates from KDS.
func FetchSEVCertificates(vmpl uint) error {
var reportData [SEVNonce]byte
qp, err := GetLeveledQuoteProvider()
qp, err := getLeveledQuoteProvider()
if err != nil {
return fmt.Errorf("could not get quote provider")
}
if len(reportData) > Nonce {
if len(reportData) > SEVNonce {
return fmt.Errorf("attestation report size mismatch")
}
+4 -5
View File
@@ -25,7 +25,6 @@ import (
"github.com/google/go-tpm/tpmutil"
"github.com/ultravioletrs/cocos/pkg/attestation"
"github.com/ultravioletrs/cocos/pkg/attestation/eat"
"github.com/ultravioletrs/cocos/pkg/attestation/quoteprovider"
"golang.org/x/crypto/sha3"
"google.golang.org/protobuf/encoding/protojson"
"google.golang.org/protobuf/encoding/prototext"
@@ -120,7 +119,7 @@ func (v provider) Attestation(teeNonce []byte, vTpmNonce []byte) ([]byte, error)
}
func (v provider) TeeAttestation(teeNonce []byte) ([]byte, error) {
return quoteprovider.FetchAttestation(teeNonce, v.vmpl)
return fetchSEVAttestation(teeNonce, v.vmpl)
}
func (v provider) VTpmAttestation(vTpmNonce []byte) ([]byte, error) {
@@ -171,7 +170,7 @@ func (v verifier) VerifTeeAttestation(report []byte, teeNonce []byte) error {
}
attestationReport := sevsnp.Attestation{Report: attestReport, CertificateChain: nil}
return quoteprovider.VerifyAttestationReportTLS(&attestationReport, teeNonce, v.Policy)
return VerifySEVAttestationReportTLS(&attestationReport, teeNonce, v.Policy)
}
func (v verifier) VerifVTpmAttestation(report []byte, vTpmNonce []byte) error {
@@ -234,7 +233,7 @@ func VTPMVerify(quote []byte, teeNonce []byte, vtpmNonce []byte, writer io.Write
attestData := sha3.Sum512(nonce)
if err := quoteprovider.VerifyAttestationReportTLS(attestation.GetSevSnpAttestation(), attestData[:], policy); err != nil {
if err := VerifySEVAttestationReportTLS(attestation.GetSevSnpAttestation(), attestData[:], policy); err != nil {
return fmt.Errorf("failed to verify TEE attestation report: %v", err)
}
@@ -336,7 +335,7 @@ func addTEEAttestation(attestation *attest.Attestation, nonce []byte, vmpl uint)
attestData := sha3.Sum512(teeNonce)
rawTeeAttestation, err := quoteprovider.FetchAttestation(attestData[:], vmpl)
rawTeeAttestation, err := fetchSEVAttestation(attestData[:], vmpl)
if err != nil {
return fmt.Errorf("failed to fetch TEE attestation report: %v", err)
}
+7 -10
View File
@@ -22,12 +22,9 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/ultravioletrs/cocos/pkg/attestation"
"github.com/ultravioletrs/cocos/pkg/attestation/quoteprovider"
"google.golang.org/protobuf/encoding/protojson"
)
const sevSnpProductMilan = "Milan"
var policy = attestation.Config{Config: &check.Config{Policy: &check.Policy{}, RootOfTrust: &check.RootOfTrust{}}, PcrConfig: &attestation.PcrConfig{}}
type mockTPM struct {
@@ -679,13 +676,13 @@ func TestVerifyAttestationReportMalformedSignature(t *testing.T) {
name: "Valid attestation, distorted signature",
attestationReport: attestationPB,
reportData: reportData,
err: quoteprovider.ErrAttVerification,
err: ErrSEVAttVerification,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := quoteprovider.VerifyAttestationReportTLS(tt.attestationReport, tt.reportData, &policy)
err := VerifySEVAttestationReportTLS(tt.attestationReport, tt.reportData, &policy)
assert.True(t, errors.Contains(err, tt.err), fmt.Sprintf("expected error %v, got %v", tt.err, err))
})
}
@@ -713,13 +710,13 @@ func TestVerifyAttestationReportUnknownProduct(t *testing.T) {
name: "Valid attestation, unknown product",
attestationReport: attestationPB,
reportData: reportData,
err: quoteprovider.ErrProductLine,
err: ErrSEVProductLine,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := quoteprovider.VerifyAttestationReportTLS(tt.attestationReport, tt.reportData, &policy)
err := VerifySEVAttestationReportTLS(tt.attestationReport, tt.reportData, &policy)
assert.True(t, errors.Contains(err, tt.err), fmt.Sprintf("expected error %v, got %v", tt.err, err))
})
}
@@ -752,7 +749,7 @@ func TestVerifyAttestationReportSuccess(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := quoteprovider.VerifyAttestationReportTLS(tt.attestationReport, tt.reportData, &policy)
err := VerifySEVAttestationReportTLS(tt.attestationReport, tt.reportData, &policy)
assert.True(t, errors.Contains(err, tt.err), fmt.Sprintf("expected error %v, got %v", tt.err, err))
})
}
@@ -780,13 +777,13 @@ func TestVerifyAttestationReportMalformedPolicy(t *testing.T) {
name: "Valid attestation, malformed policy (measurement)",
attestationReport: attestationPB,
reportData: reportData,
err: quoteprovider.ErrAttVerification,
err: ErrSEVAttVerification,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := quoteprovider.VerifyAttestationReportTLS(tt.attestationReport, tt.reportData, &policy)
err := VerifySEVAttestationReportTLS(tt.attestationReport, tt.reportData, &policy)
assert.True(t, errors.Contains(err, tt.err), fmt.Sprintf("expected error %v, got %v", tt.err, err))
})
}
+1 -1
View File
@@ -2,5 +2,5 @@
// SPDX-License-Identifier: Apache-2.0
// Package clients contains the domain concept definitions needed to support
// HTTP/gRPC Client functionality.
// client configuration for HTTP/gRPC connections.
package clients
+2 -1
View File
@@ -9,6 +9,7 @@ import (
"github.com/ultravioletrs/cocos/agent"
"github.com/ultravioletrs/cocos/pkg/clients"
"github.com/ultravioletrs/cocos/pkg/clients/grpc"
"github.com/ultravioletrs/cocos/pkg/tls"
grpchealth "google.golang.org/grpc/health/grpc_health_v1"
)
@@ -21,7 +22,7 @@ func NewAgentClient(ctx context.Context, cfg clients.AttestedClientConfig) (grpc
return nil, nil, err
}
if client.Secure() != clients.WithMATLS.String() && client.Secure() != clients.WithATLS.String() && client.Secure() != clients.WithTLS.String() {
if client.Secure() != tls.WithMATLS.String() && client.Secure() != tls.WithATLS.String() && client.Secure() != tls.WithTLS.String() {
health := grpchealth.NewHealthClient(client.Connection())
resp, err := health.Check(ctx, &grpchealth.HealthCheckRequest{
Service: "agent",
+10 -9
View File
@@ -22,6 +22,7 @@ import (
"github.com/ultravioletrs/cocos/pkg/attestation"
"github.com/ultravioletrs/cocos/pkg/attestation/vtpm"
"github.com/ultravioletrs/cocos/pkg/clients"
"github.com/ultravioletrs/cocos/pkg/tls"
)
func TestNewClient(t *testing.T) {
@@ -119,7 +120,7 @@ func TestNewClient(t *testing.T) {
ServerCAFile: "nonexistent.pem",
},
wantErr: true,
err: clients.ErrFailedToLoadRootCA,
err: tls.ErrFailedToLoadRootCA,
},
{
name: "Fail with invalid ClientCert",
@@ -130,7 +131,7 @@ func TestNewClient(t *testing.T) {
ClientKey: clientKeyFile,
},
wantErr: true,
err: clients.ErrFailedToLoadClientCertKey,
err: tls.ErrFailedToLoadClientCertKey,
},
{
name: "Fail with invalid ClientKey",
@@ -141,7 +142,7 @@ func TestNewClient(t *testing.T) {
ClientKey: "nonexistent.pem",
},
wantErr: true,
err: clients.ErrFailedToLoadClientCertKey,
err: tls.ErrFailedToLoadClientCertKey,
},
}
@@ -169,32 +170,32 @@ func TestNewClient(t *testing.T) {
func TestClientSecure(t *testing.T) {
tests := []struct {
name string
secure clients.Security
secure tls.Security
expected string
}{
{
name: "Without TLS",
secure: clients.WithoutTLS,
secure: tls.WithoutTLS,
expected: "without TLS",
},
{
name: "With TLS",
secure: clients.WithTLS,
secure: tls.WithTLS,
expected: "with TLS",
},
{
name: "With mTLS",
secure: clients.WithMTLS,
secure: tls.WithMTLS,
expected: "with mTLS",
},
{
name: "With aTLS",
secure: clients.WithATLS,
secure: tls.WithATLS,
expected: "with aTLS",
},
{
name: "With maTLS",
secure: clients.WithMATLS,
secure: tls.WithMATLS,
expected: "with maTLS",
},
}
+14 -8
View File
@@ -6,6 +6,7 @@ package grpc
import (
"github.com/absmach/supermq/pkg/errors"
"github.com/ultravioletrs/cocos/pkg/clients"
"github.com/ultravioletrs/cocos/pkg/tls"
"go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials"
@@ -26,7 +27,7 @@ type Client interface {
type client struct {
*grpc.ClientConn
cfg clients.ClientConfiguration
security clients.Security
security tls.Security
}
var _ Client = (*client)(nil)
@@ -59,14 +60,19 @@ func (c *client) Connection() *grpc.ClientConn {
return c.ClientConn
}
func connect(cfg clients.ClientConfiguration) (*grpc.ClientConn, clients.Security, error) {
func connect(cfg clients.ClientConfiguration) (*grpc.ClientConn, tls.Security, error) {
opts := []grpc.DialOption{
grpc.WithStatsHandler(otelgrpc.NewClientHandler()),
}
security := clients.WithoutTLS
security := tls.WithoutTLS
if agcfg, ok := cfg.(clients.AttestedClientConfig); ok && agcfg.AttestedTLS {
result, err := clients.LoadATLSConfig(agcfg)
result, err := tls.LoadATLSConfig(
agcfg.AttestationPolicy,
agcfg.ServerCAFile,
agcfg.ClientCert,
agcfg.ClientKey,
)
if err != nil {
return nil, security, err
}
@@ -90,13 +96,13 @@ func connect(cfg clients.ClientConfiguration) (*grpc.ClientConn, clients.Securit
return conn, security, nil
}
func loadTLSConfig(serverCAFile, clientCert, clientKey string) (credentials.TransportCredentials, clients.Security, error) {
result, err := clients.LoadBasicTLSConfig(serverCAFile, clientCert, clientKey)
func loadTLSConfig(serverCAFile, clientCert, clientKey string) (credentials.TransportCredentials, tls.Security, error) {
result, err := tls.LoadBasicConfig(serverCAFile, clientCert, clientKey)
if err != nil {
return nil, clients.WithoutTLS, err
return nil, tls.WithoutTLS, err
}
if result.Security == clients.WithoutTLS || result.Config == nil {
if result.Security == tls.WithoutTLS || result.Config == nil {
return insecure.NewCredentials(), result.Security, nil
}
+12 -6
View File
@@ -8,6 +8,7 @@ import (
"time"
"github.com/ultravioletrs/cocos/pkg/clients"
"github.com/ultravioletrs/cocos/pkg/tls"
)
type Client interface {
@@ -19,7 +20,7 @@ type Client interface {
type client struct {
transport *http.Transport
cfg clients.ClientConfiguration
security clients.Security
security tls.Security
}
var _ Client = (*client)(nil)
@@ -49,17 +50,22 @@ func (c *client) Timeout() time.Duration {
return c.cfg.Config().Timeout
}
func createTransport(cfg clients.ClientConfiguration) (*http.Transport, clients.Security, error) {
func createTransport(cfg clients.ClientConfiguration) (*http.Transport, tls.Security, error) {
transport := &http.Transport{
MaxIdleConns: 100,
IdleConnTimeout: 90 * time.Second,
TLSHandshakeTimeout: 10 * time.Second,
}
security := clients.WithoutTLS
security := tls.WithoutTLS
if agcfg, ok := cfg.(*clients.AttestedClientConfig); ok && agcfg.AttestedTLS {
result, err := clients.LoadATLSConfig(*agcfg)
result, err := tls.LoadATLSConfig(
agcfg.AttestationPolicy,
agcfg.ServerCAFile,
agcfg.ClientCert,
agcfg.ClientKey,
)
if err != nil {
return nil, security, err
}
@@ -69,12 +75,12 @@ func createTransport(cfg clients.ClientConfiguration) (*http.Transport, clients.
} else {
conf := cfg.Config()
result, err := clients.LoadBasicTLSConfig(conf.ServerCAFile, conf.ClientCert, conf.ClientKey)
result, err := tls.LoadBasicConfig(conf.ServerCAFile, conf.ClientCert, conf.ClientKey)
if err != nil {
return nil, security, err
}
if result.Security != clients.WithoutTLS {
if result.Security != tls.WithoutTLS {
transport.TLSClientConfig = result.Config
}
+5 -4
View File
@@ -10,6 +10,7 @@ import (
"github.com/stretchr/testify/assert"
"github.com/ultravioletrs/cocos/pkg/clients"
"github.com/ultravioletrs/cocos/pkg/tls"
)
func TestConfig_Configuration(t *testing.T) {
@@ -144,7 +145,7 @@ func TestClient_Secure(t *testing.T) {
URL: "http://localhost:8080",
Timeout: 30 * time.Second,
},
expected: clients.WithoutTLS.String(),
expected: tls.WithoutTLS.String(),
},
}
@@ -183,7 +184,7 @@ func TestCreateTransport_DefaultSettings(t *testing.T) {
assert.NoError(t, err)
assert.NotNil(t, transport)
assert.Equal(t, clients.WithoutTLS, security)
assert.Equal(t, tls.WithoutTLS, security)
assert.Equal(t, 100, transport.MaxIdleConns)
assert.Equal(t, 90*time.Second, transport.IdleConnTimeout)
assert.Equal(t, 10*time.Second, transport.TLSHandshakeTimeout)
@@ -205,7 +206,7 @@ func TestCreateTransport_ATLSError(t *testing.T) {
assert.Error(t, err)
assert.Nil(t, transport)
assert.Equal(t, clients.WithoutTLS, security)
assert.Equal(t, tls.WithoutTLS, security)
assert.Contains(t, err.Error(), "failed to stat attestation policy")
}
@@ -220,7 +221,7 @@ func TestCreateTransport_BasicTLSError(t *testing.T) {
assert.Error(t, err)
assert.Nil(t, transport)
assert.Equal(t, clients.WithoutTLS, security)
assert.Equal(t, tls.WithoutTLS, security)
assert.Contains(t, err.Error(), "failed to load root ca file")
}
+5 -6
View File
@@ -19,7 +19,6 @@ import (
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
"github.com/ultravioletrs/cocos/agent"
"github.com/ultravioletrs/cocos/pkg/attestation/quoteprovider"
"github.com/ultravioletrs/cocos/pkg/attestation/vtpm"
"github.com/ultravioletrs/cocos/pkg/sdk"
"golang.org/x/crypto/sha3"
@@ -388,7 +387,7 @@ func TestAttestation(t *testing.T) {
cases := []struct {
name string
userKey any
reportData [quoteprovider.Nonce]byte
reportData [vtpm.SEVNonce]byte
nonce [vtpm.Nonce]byte
response *agent.AttestationResponse
svcRes []byte
@@ -397,7 +396,7 @@ func TestAttestation(t *testing.T) {
{
name: "fetch attestation report successfully",
userKey: resultConsumerKey,
reportData: [quoteprovider.Nonce]byte(reportData),
reportData: [vtpm.SEVNonce]byte(reportData),
nonce: [vtpm.Nonce]byte(nonce),
response: &agent.AttestationResponse{
File: report,
@@ -408,7 +407,7 @@ func TestAttestation(t *testing.T) {
{
name: "fetch attestation report with different key type",
userKey: resultConsumer1Key,
reportData: [quoteprovider.Nonce]byte(reportData),
reportData: [vtpm.SEVNonce]byte(reportData),
nonce: [vtpm.Nonce]byte(nonce),
response: &agent.AttestationResponse{
File: report,
@@ -419,7 +418,7 @@ func TestAttestation(t *testing.T) {
{
name: "failed to fetch attestation report",
userKey: resultConsumerKey,
reportData: [quoteprovider.Nonce]byte(reportData),
reportData: [vtpm.SEVNonce]byte(reportData),
nonce: [vtpm.Nonce]byte(nonce),
response: &agent.AttestationResponse{
File: []byte{},
@@ -429,7 +428,7 @@ func TestAttestation(t *testing.T) {
{
name: "invalid report data",
userKey: resultConsumerKey,
reportData: [quoteprovider.Nonce]byte{},
reportData: [vtpm.SEVNonce]byte{},
nonce: [vtpm.Nonce]byte(nonce),
response: &agent.AttestationResponse{
File: []byte{},
+6
View File
@@ -0,0 +1,6 @@
// Copyright (c) Ultraviolet
// SPDX-License-Identifier: Apache-2.0
// Package tls provides TLS configuration utilities for client connections.
// It supports standard TLS, mutual TLS (mTLS), and attested TLS (aTLS).
package tls
+16 -15
View File
@@ -1,7 +1,7 @@
// Copyright (c) Ultraviolet
// SPDX-License-Identifier: Apache-2.0
package clients
package tls
import (
"crypto/rand"
@@ -53,20 +53,20 @@ var (
errAttestationPolicyIrregular = errors.New("attestation policy file is not a regular file")
)
// TLSResult contains the result of TLS configuration.
type TLSResult struct {
// Result contains the result of TLS configuration.
type Result struct {
Config *tls.Config
Security Security
}
// LoadBasicTLSConfig loads standard TLS configuration (TLS/mTLS).
func LoadBasicTLSConfig(serverCAFile, clientCert, clientKey string) (*TLSResult, error) {
// LoadBasicConfig loads standard TLS configuration (TLS/mTLS).
func LoadBasicConfig(serverCAFile, clientCert, clientKey string) (*Result, error) {
tlsConfig := &tls.Config{}
security := WithoutTLS
// If no TLS configuration is provided, return nil config (no TLS)
if serverCAFile == "" && clientCert == "" && clientKey == "" {
return &TLSResult{Config: nil, Security: security}, nil
return &Result{Config: nil, Security: security}, nil
}
if serverCAFile != "" {
@@ -96,14 +96,15 @@ func LoadBasicTLSConfig(serverCAFile, clientCert, clientKey string) (*TLSResult,
security = WithMTLS
}
return &TLSResult{Config: tlsConfig, Security: security}, nil
return &Result{Config: tlsConfig, Security: security}, nil
}
// LoadATLSConfig configures Attested TLS.
func LoadATLSConfig(cfg AttestedClientConfig) (*TLSResult, error) {
// Parameters are passed individually to avoid circular dependencies with the clients package.
func LoadATLSConfig(attestationPolicy, serverCAFile, clientCert, clientKey string) (*Result, error) {
security := WithATLS
info, err := os.Stat(cfg.AttestationPolicy)
info, err := os.Stat(attestationPolicy)
if err != nil {
return nil, errors.Wrap(errors.New("failed to stat attestation policy file"), err)
}
@@ -112,12 +113,12 @@ func LoadATLSConfig(cfg AttestedClientConfig) (*TLSResult, error) {
return nil, errAttestationPolicyIrregular
}
attestation.AttestationPolicyPath = cfg.AttestationPolicy
attestation.AttestationPolicyPath = attestationPolicy
var rootCAs *x509.CertPool
if cfg.ServerCAFile != "" {
rootCAs, err = loadRootCAs(cfg.ServerCAFile)
if serverCAFile != "" {
rootCAs, err = loadRootCAs(serverCAFile)
if err != nil {
return nil, err
}
@@ -141,8 +142,8 @@ func LoadATLSConfig(cfg AttestedClientConfig) (*TLSResult, error) {
},
}
if cfg.ClientCert != "" || cfg.ClientKey != "" {
certificate, err := tls.LoadX509KeyPair(cfg.ClientCert, cfg.ClientKey)
if clientCert != "" || clientKey != "" {
certificate, err := tls.LoadX509KeyPair(clientCert, clientKey)
if err != nil {
return nil, errors.Wrap(ErrFailedToLoadClientCertKey, err)
}
@@ -150,7 +151,7 @@ func LoadATLSConfig(cfg AttestedClientConfig) (*TLSResult, error) {
tlsConfig.Certificates = []tls.Certificate{certificate}
}
return &TLSResult{Config: tlsConfig, Security: security}, nil
return &Result{Config: tlsConfig, Security: security}, nil
}
// loadRootCAs loads root CA certificates from a file.
+64 -78
View File
@@ -1,7 +1,7 @@
// Copyright (c) Ultraviolet
// SPDX-License-Identifier: Apache-2.0
package clients
package tls
import (
"crypto/rand"
@@ -64,7 +64,7 @@ func TestSecurity_String(t *testing.T) {
}
}
func TestLoadBasicTLSConfig(t *testing.T) {
func TestLoadBasicConfig(t *testing.T) {
// Create temporary directory for test files
tmpDir := t.TempDir()
@@ -164,7 +164,7 @@ func TestLoadBasicTLSConfig(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result, err := LoadBasicTLSConfig(tt.serverCAFile, tt.clientCert, tt.clientKey)
result, err := LoadBasicConfig(tt.serverCAFile, tt.clientCert, tt.clientKey)
if tt.expectError {
assert.Error(t, err)
@@ -201,100 +201,86 @@ func TestLoadATLSConfig(t *testing.T) {
require.NoError(t, os.WriteFile(policyFile, []byte(`{"policy": "test"}`), 0o644))
tests := []struct {
name string
config AttestedClientConfig
expectedSec Security
expectError bool
errorMsg string
name string
attestationPolicy string
serverCAFile string
clientCert string
clientKey string
expectedSec Security
expectError bool
errorMsg string
}{
{
name: "ValidATLSConfig",
config: AttestedClientConfig{
StandardClientConfig: StandardClientConfig{
ServerCAFile: "",
},
AttestationPolicy: policyFile,
ProductName: "test-product",
},
expectedSec: WithATLS,
expectError: false,
name: "ValidATLSConfig",
attestationPolicy: policyFile,
serverCAFile: "",
clientCert: "",
clientKey: "",
expectedSec: WithATLS,
expectError: false,
},
{
name: "ValidMATLSConfig",
config: AttestedClientConfig{
StandardClientConfig: StandardClientConfig{
ServerCAFile: caFile,
},
AttestationPolicy: policyFile,
ProductName: "test-product",
},
expectedSec: WithMATLS,
expectError: false,
name: "ValidMATLSConfig",
attestationPolicy: policyFile,
serverCAFile: caFile,
clientCert: "",
clientKey: "",
expectedSec: WithMATLS,
expectError: false,
},
{
name: "ValidATLSWithClientCert",
config: AttestedClientConfig{
StandardClientConfig: StandardClientConfig{
ClientCert: certFile,
ClientKey: keyFile,
},
AttestationPolicy: policyFile,
ProductName: "test-product",
},
expectedSec: WithATLS,
expectError: false,
name: "ValidATLSWithClientCert",
attestationPolicy: policyFile,
serverCAFile: "",
clientCert: certFile,
clientKey: keyFile,
expectedSec: WithATLS,
expectError: false,
},
{
name: "NonexistentPolicyFile",
config: AttestedClientConfig{
AttestationPolicy: filepath.Join(tmpDir, "nonexistent.json"),
ProductName: "test-product",
},
expectedSec: WithoutTLS,
expectError: true,
errorMsg: "failed to stat attestation policy file",
name: "NonexistentPolicyFile",
attestationPolicy: filepath.Join(tmpDir, "nonexistent.json"),
serverCAFile: "",
clientCert: "",
clientKey: "",
expectedSec: WithoutTLS,
expectError: true,
errorMsg: "failed to stat attestation policy file",
},
{
name: "PolicyFileIsDirectory",
config: AttestedClientConfig{
AttestationPolicy: tmpDir, // Directory instead of file
ProductName: "test-product",
},
expectedSec: WithoutTLS,
expectError: true,
errorMsg: "attestation policy file is not a regular file",
name: "PolicyFileIsDirectory",
attestationPolicy: tmpDir, // Directory instead of file
serverCAFile: "",
clientCert: "",
clientKey: "",
expectedSec: WithoutTLS,
expectError: true,
errorMsg: "attestation policy file is not a regular file",
},
{
name: "InvalidCAFile",
config: AttestedClientConfig{
StandardClientConfig: StandardClientConfig{
ServerCAFile: filepath.Join(tmpDir, "nonexistent.crt"),
},
AttestationPolicy: policyFile,
ProductName: "test-product",
},
expectedSec: WithoutTLS,
expectError: true,
errorMsg: "failed to read certificate file",
name: "InvalidCAFile",
attestationPolicy: policyFile,
serverCAFile: filepath.Join(tmpDir, "nonexistent.crt"),
clientCert: "",
clientKey: "",
expectedSec: WithoutTLS,
expectError: true,
errorMsg: "failed to read certificate file",
},
{
name: "InvalidClientCert",
config: AttestedClientConfig{
StandardClientConfig: StandardClientConfig{
ClientCert: filepath.Join(tmpDir, "nonexistent.crt"),
ClientKey: keyFile,
},
AttestationPolicy: policyFile,
ProductName: "test-product",
},
expectedSec: WithoutTLS,
expectError: true,
name: "InvalidClientCert",
attestationPolicy: policyFile,
serverCAFile: "",
clientCert: filepath.Join(tmpDir, "nonexistent.crt"),
clientKey: keyFile,
expectedSec: WithoutTLS,
expectError: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result, err := LoadATLSConfig(tt.config)
result, err := LoadATLSConfig(tt.attestationPolicy, tt.serverCAFile, tt.clientCert, tt.clientKey)
if tt.expectError {
assert.Error(t, err)