NOISSUE - Update cocos to match certs changes (#520)
CI / checkproto (push) Has been cancelled
CI / lint (push) Has been cancelled
CI / test (agent) (push) Has been cancelled
CI / test (cli) (push) Has been cancelled
CI / test (cmd) (push) Has been cancelled
CI / test (internal) (push) Has been cancelled
CI / test (manager, true) (push) Has been cancelled
CI / test (pkg) (push) Has been cancelled
CI / upload-coverage (push) Has been cancelled

* pass domain id to agent environment

Signed-off-by: WashingtonKK <washingtonkigan@gmail.com>

* update generated files

Signed-off-by: WashingtonKK <washingtonkigan@gmail.com>

* use certs sdk directly

Signed-off-by: WashingtonKK <washingtonkigan@gmail.com>

* remove redundant variables

Signed-off-by: WashingtonKK <washingtonkigan@gmail.com>

* use agent certs token for csr

Signed-off-by: WashingtonKK <washingtonkigan@gmail.com>

* update certs and add token to create req

Signed-off-by: WashingtonKK <washingtonkigan@gmail.com>

* fix atls

Signed-off-by: WashingtonKK <washingtonkigan@gmail.com>

* add agent token to certificate provider

Signed-off-by: WashingtonKK <washingtonkigan@gmail.com>

* pass certs token to agent

Signed-off-by: WashingtonKK <washingtonkigan@gmail.com>

* use sdk for csr

Signed-off-by: WashingtonKK <washingtonkigan@gmail.com>

* update atls

Signed-off-by: WashingtonKK <washingtonkigan@gmail.com>

* fix tests

Signed-off-by: WashingtonKK <washingtonkigan@gmail.com>

* address comments

Signed-off-by: WashingtonKK <washingtonkigan@gmail.com>

* remove unused structs

Signed-off-by: WashingtonKK <washingtonkigan@gmail.com>

* update tests

Signed-off-by: WashingtonKK <washingtonkigan@gmail.com>

* lint

Signed-off-by: WashingtonKK <washingtonkigan@gmail.com>

* fix tests

Signed-off-by: WashingtonKK <washingtonkigan@gmail.com>

* lint

Signed-off-by: WashingtonKK <washingtonkigan@gmail.com>

* remove unused domain id

Signed-off-by: WashingtonKK <washingtonkigan@gmail.com>

* refactor tests and remove unused struct fields

Signed-off-by: WashingtonKK <washingtonkigan@gmail.com>

* refactor(atls): remove CAClient and inline CA certificate issuance

Signed-off-by: WashingtonKK <washingtonkigan@gmail.com>

* lint'

Signed-off-by: WashingtonKK <washingtonkigan@gmail.com>

* increase coverage

Signed-off-by: WashingtonKK <washingtonkigan@gmail.com>

* fix bug in certs sdk and certificate provider

Signed-off-by: WashingtonKK <washingtonkigan@gmail.com>

* update certs

Signed-off-by: WashingtonKK <washingtonkigan@gmail.com>

* fix pkg stress

Signed-off-by: WashingtonKK <washingtonkigan@gmail.com>

---------

Signed-off-by: WashingtonKK <washingtonkigan@gmail.com>
This commit is contained in:
Washington Kigani Kamadi
2025-10-06 22:12:18 +03:00
committed by GitHub
parent 0be724386b
commit 0ffc2d17cf
12 changed files with 441 additions and 261 deletions
+2 -101
View File
@@ -3,23 +3,9 @@
package atls
import (
"bytes"
"crypto/ecdsa"
"encoding/asn1"
"encoding/hex"
"encoding/json"
"encoding/pem"
"fmt"
"io"
"net/http"
"net/url"
"strings"
"time"
"github.com/absmach/certs"
certscli "github.com/absmach/certs/cli"
"github.com/absmach/certs/errors"
certssdk "github.com/absmach/certs/sdk"
)
const (
@@ -38,6 +24,7 @@ var (
// CertificateSubject contains certificate subject information.
type CertificateSubject struct {
Organization string
CommonName string
Country string
Province string
Locality string
@@ -49,6 +36,7 @@ type CertificateSubject struct {
func DefaultCertificateSubject() CertificateSubject {
return CertificateSubject{
Organization: "Ultraviolet",
CommonName: "Ultraviolet",
Country: "Serbia",
Province: "",
Locality: "Belgrade",
@@ -57,93 +45,6 @@ func DefaultCertificateSubject() CertificateSubject {
}
}
// CAClient handles communication with Certificate Authority.
type CAClient struct {
baseURL string
client *http.Client
}
type CSRRequest struct {
CSR string `json:"csr,omitempty"`
}
func NewCAClient(baseURL string) *CAClient {
return &CAClient{
baseURL: baseURL,
client: &http.Client{},
}
}
func (c *CAClient) RequestCertificate(csrMetadata certs.CSRMetadata, privateKey *ecdsa.PrivateKey, cvmID string, ttl time.Duration) ([]byte, error) {
csr, sdkerr := certscli.CreateCSR(csrMetadata, privateKey)
if sdkerr != nil {
return nil, fmt.Errorf("failed to create CSR: %w", sdkerr)
}
request := CSRRequest{CSR: string(csr.CSR)}
requestData, err := json.Marshal(request)
if err != nil {
return nil, fmt.Errorf("failed to marshal CSR request: %w", err)
}
endpoint := fmt.Sprintf("certs/csrs/%s", cvmID)
query := url.Values{}
query.Add("ttl", ttl.String())
requestURL := fmt.Sprintf("%s/%s?%s", c.baseURL, endpoint, query.Encode())
_, responseBody, err := c.processRequest(http.MethodPost, requestURL, requestData, nil, http.StatusOK)
if err != nil {
return nil, fmt.Errorf("failed to process CA request: %w", err)
}
var cert certssdk.Certificate
if err := json.Unmarshal(responseBody, &cert); err != nil {
return nil, fmt.Errorf("failed to unmarshal certificate response: %w", err)
}
cleanCertificateString := strings.ReplaceAll(cert.Certificate, "\\n", "\n")
block, rest := pem.Decode([]byte(cleanCertificateString))
if len(rest) != 0 {
return nil, fmt.Errorf("failed to decode certificate PEM: unexpected remaining data")
}
if block == nil {
return nil, fmt.Errorf("failed to decode certificate PEM: no PEM block found")
}
return block.Bytes, nil
}
func (c *CAClient) processRequest(method, reqURL string, data []byte, headers map[string]string, expectedRespCodes ...int) (http.Header, []byte, errors.SDKError) {
req, err := http.NewRequest(method, reqURL, bytes.NewReader(data))
if err != nil {
return make(http.Header), []byte{}, errors.NewSDKError(err)
}
req.Header.Add("Content-Type", "application/json")
for key, value := range headers {
req.Header.Add(key, value)
}
resp, err := c.client.Do(req)
if err != nil {
return make(http.Header), []byte{}, errors.NewSDKError(err)
}
defer resp.Body.Close()
sdkErr := errors.CheckError(resp, expectedRespCodes...)
if sdkErr != nil {
return make(http.Header), []byte{}, sdkErr
}
body, err := io.ReadAll(resp.Body)
if err != nil {
return make(http.Header), []byte{}, errors.NewSDKError(err)
}
return resp.Header, body, nil
}
func extractNonceFromSNI(serverName string) ([]byte, error) {
if len(serverName) < len(nonceSuffix) || !hasNonceSuffix(serverName) {
return nil, fmt.Errorf("invalid server name: %s", serverName)
+281 -87
View File
@@ -12,19 +12,18 @@ import (
"crypto/x509/pkix"
"encoding/asn1"
"encoding/hex"
"encoding/json"
"encoding/pem"
"fmt"
"math/big"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"strings"
"testing"
"time"
"github.com/absmach/certs"
certssdk "github.com/absmach/certs/sdk"
sdkmocks "github.com/absmach/certs/sdk/mocks"
"github.com/absmach/supermq/pkg/errors"
"github.com/google/go-sev-guest/abi"
"github.com/google/go-sev-guest/proto/check"
@@ -39,10 +38,70 @@ import (
"google.golang.org/protobuf/encoding/protojson"
)
const sevProductNameMilan = "Milan"
const (
sevProductNameMilan = "Milan"
)
var policy = attestation.Config{Config: &check.Config{Policy: &check.Policy{}, RootOfTrust: &check.RootOfTrust{}}, PcrConfig: &attestation.PcrConfig{}}
func generateTestCertPEM(t *testing.T) string {
return generateTestCertPEMWithSubject(t, "test")
}
func generateTestCertPEMWithSubject(t *testing.T, commonName string) string {
privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
require.NoError(t, err)
template := x509.Certificate{
SerialNumber: big.NewInt(1),
Subject: pkix.Name{
CommonName: commonName,
},
NotBefore: time.Now(),
NotAfter: time.Now().Add(365 * 24 * time.Hour),
KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
BasicConstraintsValid: true,
IsCA: true,
}
certDER, err := x509.CreateCertificate(rand.Reader, &template, &template, &privateKey.PublicKey, privateKey)
require.NoError(t, err)
certPEM := pem.EncodeToMemory(&pem.Block{
Type: "CERTIFICATE",
Bytes: certDER,
})
return strings.ReplaceAll(string(certPEM), "\n", "\\n")
}
func generateTestCertificateWithExtensions(t *testing.T, extensions []pkix.Extension) *x509.Certificate {
privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
require.NoError(t, err)
template := x509.Certificate{
SerialNumber: big.NewInt(1),
Subject: pkix.Name{
CommonName: "test",
},
NotBefore: time.Now(),
NotAfter: time.Now().Add(365 * 24 * time.Hour),
KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
BasicConstraintsValid: true,
ExtraExtensions: extensions,
}
certDER, err := x509.CreateCertificate(rand.Reader, &template, &template, &privateKey.PublicKey, privateKey)
require.NoError(t, err)
cert, err := x509.ParseCertificate(certDER)
require.NoError(t, err)
return cert
}
// TestCertificateSubject tests the CertificateSubject functionality.
func TestDefaultCertificateSubject(t *testing.T) {
subject := DefaultCertificateSubject()
@@ -58,15 +117,15 @@ func TestDefaultCertificateSubject(t *testing.T) {
// TestUnifiedCertificateGenerator tests the unified certificate generator.
func TestUnifiedCertificateGenerator(t *testing.T) {
t.Run("SelfSignedGenerator", func(t *testing.T) {
generator, err := NewProvider(nil, attestation.SNPvTPM, "", "")
generator, err := NewProvider(nil, attestation.SNPvTPM, "", "", nil)
assert.NoError(t, err)
assert.NotNil(t, generator)
})
t.Run("CASignedGenerator", func(t *testing.T) {
caURL := "https://example.com/ca"
cvmID := "test-cvm-id"
generator, err := NewProvider(nil, attestation.SNPvTPM, caURL, cvmID)
mockSDK := sdkmocks.NewSDK(t)
generator, err := NewProvider(nil, attestation.SNPvTPM, "test-token", "test-cvm-id", mockSDK)
assert.NoError(t, err)
assert.NotNil(t, generator)
})
@@ -204,19 +263,27 @@ func TestNewProvider(t *testing.T) {
mockProvider := new(mocks.Provider)
t.Run("SelfSignedProvider", func(t *testing.T) {
provider, err := NewProvider(mockProvider, attestation.SNPvTPM, "", "")
provider, err := NewProvider(mockProvider, attestation.SNPvTPM, "", "", nil)
assert.NoError(t, err)
assert.NotNil(t, provider)
})
t.Run("CASignedProvider", func(t *testing.T) {
provider, err := NewProvider(mockProvider, attestation.SNPvTPM, "https://example.com", "test-cvm")
t.Run("CASignedProviderWithSDK", func(t *testing.T) {
mockSDK := sdkmocks.NewSDK(t)
provider, err := NewProvider(mockProvider, attestation.SNPvTPM, "test-token", "test-cvm-id", mockSDK)
assert.NoError(t, err)
assert.NotNil(t, provider)
})
t.Run("SelfSignedProviderNilSDK", func(t *testing.T) {
provider, err := NewProvider(mockProvider, attestation.SNPvTPM, "test-token", "test-cvm-id", nil)
assert.NoError(t, err)
assert.NotNil(t, provider)
})
t.Run("InvalidPlatformType", func(t *testing.T) {
_, err := NewProvider(mockProvider, attestation.PlatformType(999), "", "")
_, err := NewProvider(mockProvider, attestation.PlatformType(999), "", "", nil)
assert.Error(t, err)
})
}
@@ -514,47 +581,6 @@ func TestVerifyCertificateExtension(t *testing.T) {
// Helper functions
func createMockCAServer(t *testing.T) *httptest.Server {
return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
w.WriteHeader(http.StatusMethodNotAllowed)
return
}
// Create a valid test certificate
privKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
require.NoError(t, err)
template := &x509.Certificate{
SerialNumber: big.NewInt(1),
Subject: pkix.Name{
Organization: []string{"Test CA"},
},
NotBefore: time.Now(),
NotAfter: time.Now().Add(24 * time.Hour),
KeyUsage: x509.KeyUsageDigitalSignature,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
}
certDER, err := x509.CreateCertificate(rand.Reader, template, template, &privKey.PublicKey, privKey)
require.NoError(t, err)
certPEM := pem.EncodeToMemory(&pem.Block{
Type: "CERTIFICATE",
Bytes: certDER,
})
mockCert := certssdk.Certificate{
Certificate: string(certPEM),
}
response, _ := json.Marshal(mockCert)
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
_, _ = w.Write(response)
}))
}
func prepVerifyAttReport(t *testing.T) *sevsnp.Attestation {
file, err := os.ReadFile("../../attestation.bin")
require.NoError(t, err)
@@ -686,40 +712,199 @@ func TestCertificateVerification(t *testing.T) {
})
}
// TestProcessRequestEdgeCases tests CAClient.processRequest edge cases.
func TestProcessRequestEdgeCases(t *testing.T) {
client := NewCAClient("http://example.com")
// TestAttestedCAProvider tests the CA-signed certificate provider.
func TestAttestedCAProvider(t *testing.T) {
mockProvider := new(mocks.Provider)
attestationProvider, err := NewAttestationProvider(mockProvider, attestation.SNPvTPM)
require.NoError(t, err)
t.Run("InvalidURL", func(t *testing.T) {
_, _, err := client.processRequest("GET", "://invalid-url", nil, nil, http.StatusOK)
assert.Error(t, err)
subject := DefaultCertificateSubject()
cvmID := "test-cvm-id"
agentToken := "test-token"
t.Run("NewAttestedCAProvider", func(t *testing.T) {
provider := NewAttestedCAProvider(attestationProvider, subject, nil, cvmID, agentToken)
assert.NotNil(t, provider)
})
t.Run("NetworkError", func(t *testing.T) {
_, _, err := client.processRequest("GET", "http://nonexistent-domain-12345.com", nil, nil, http.StatusOK)
t.Run("SetTTL", func(t *testing.T) {
provider := NewAttestedCAProvider(attestationProvider, subject, nil, cvmID, agentToken)
newTTL := time.Hour * 48
provider.(*attestedCertificateProvider).SetTTL(newTTL)
attestedProvider := provider.(*attestedCertificateProvider)
assert.Equal(t, newTTL, attestedProvider.ttl)
})
}
// TestCASignedCertificateErrors tests error cases in CA-signed certificate generation.
func TestCASignedCertificateErrors(t *testing.T) {
mockProvider := new(mocks.Provider)
attestationProvider, err := NewAttestationProvider(mockProvider, attestation.SNPvTPM)
require.NoError(t, err)
subject := DefaultCertificateSubject()
cvmID := "test-cvm-id"
agentToken := "test-token"
cases := []struct {
name string
certificate string
sdkError error
expectedError string
}{
{"SDKIssueError", "", errors.NewSDKError(errors.New("SDK error")), "SDK error"},
{"InvalidPEMWithRemainingData", "-----BEGIN CERTIFICATE-----\\nVGVzdA==\\n-----END CERTIFICATE-----\\nExtra data here", nil, "unexpected remaining data"},
{"NoPEMBlockFound", "", nil, "no PEM block found"},
}
for _, c := range cases {
t.Run(c.name, func(t *testing.T) {
mockSDK := sdkmocks.NewSDK(t)
expectedCSR := certs.CSR{CSR: []byte("test-csr")}
mockSDK.On("CreateCSR", mock.Anything, mock.Anything).Return(expectedCSR, errors.SDKError(nil))
mockSDK.On("IssueFromCSRInternal", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(certssdk.Certificate{Certificate: c.certificate}, c.sdkError)
provider := NewAttestedCAProvider(attestationProvider, subject, mockSDK, cvmID, agentToken)
attestedProvider := provider.(*attestedCertificateProvider)
privateKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
require.NoError(t, err)
extension := pkix.Extension{
Id: SNPvTPMOID,
Value: []byte("test-data"),
}
_, err = attestedProvider.generateCASignedCertificate(privateKey, extension)
assert.Error(t, err)
assert.Contains(t, err.Error(), c.expectedError)
})
}
}
// TestGetCertificateErrors tests error paths in certificate generation.
func TestGetCertificateErrors(t *testing.T) {
t.Run("InvalidServerNameFormat", func(t *testing.T) {
mockProvider := new(mocks.Provider)
attestationProvider, err := NewAttestationProvider(mockProvider, attestation.SNPvTPM)
require.NoError(t, err)
provider := NewAttestedProvider(attestationProvider, DefaultCertificateSubject())
clientHello := &tls.ClientHelloInfo{
ServerName: "invalid-format",
}
_, err = provider.GetCertificate(clientHello)
assert.Error(t, err)
assert.Contains(t, err.Error(), "failed to extract nonce")
})
t.Run("CustomHeaders", func(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "test-value", r.Header.Get("Custom-Header"))
w.WriteHeader(http.StatusOK)
}))
defer server.Close()
t.Run("AttestationProviderError", func(t *testing.T) {
mockProvider := new(mocks.Provider)
mockProvider.On("Attestation", mock.Anything, mock.Anything).Return(nil, errors.New("attestation failed"))
headers := map[string]string{"Custom-Header": "test-value"}
_, _, err := client.processRequest("GET", server.URL, nil, headers, http.StatusOK)
assert.NoError(t, err)
attestationProvider, err := NewAttestationProvider(mockProvider, attestation.SNPvTPM)
require.NoError(t, err)
provider := NewAttestedProvider(attestationProvider, DefaultCertificateSubject())
nonce := make([]byte, 64)
_, err = rand.Read(nonce)
require.NoError(t, err)
serverName := hex.EncodeToString(nonce) + ".nonce"
clientHello := &tls.ClientHelloInfo{ServerName: serverName}
_, err = provider.GetCertificate(clientHello)
assert.Error(t, err)
assert.Contains(t, err.Error(), "failed to get attestation")
})
t.Run("UnexpectedStatusCode", func(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusBadRequest)
}))
defer server.Close()
t.Run("CASignedCertificateError", func(t *testing.T) {
mockProvider := new(mocks.Provider)
mockProvider.On("Attestation", mock.Anything, mock.Anything).Return([]byte("test-attestation"), nil)
_, _, err := client.processRequest("GET", server.URL, nil, nil, http.StatusOK)
attestationProvider, err := NewAttestationProvider(mockProvider, attestation.SNPvTPM)
require.NoError(t, err)
mockSDK := sdkmocks.NewSDK(t)
expectedCSR := certs.CSR{CSR: []byte("test-csr")}
sdkErr := errors.NewSDKError(errors.New("CA error"))
mockSDK.On("CreateCSR", mock.Anything, mock.Anything).Return(expectedCSR, errors.SDKError(nil))
mockSDK.On("IssueFromCSRInternal", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(certssdk.Certificate{}, sdkErr)
provider := NewAttestedCAProvider(attestationProvider, DefaultCertificateSubject(), mockSDK, "test-cvm", "test-token")
nonce := make([]byte, 64)
_, err = rand.Read(nonce)
require.NoError(t, err)
serverName := hex.EncodeToString(nonce) + ".nonce"
clientHello := &tls.ClientHelloInfo{ServerName: serverName}
_, err = provider.GetCertificate(clientHello)
assert.Error(t, err)
assert.Contains(t, err.Error(), "failed to generate certificate")
})
}
// TestCertificateVerificationEdgeCases tests edge cases in certificate verification.
func TestCertificateVerificationEdgeCases(t *testing.T) {
tempDir, err := os.MkdirTemp("", "policy")
require.NoError(t, err)
defer os.RemoveAll(tempDir)
attestationPB := prepVerifyAttReport(t)
err = setAttestationPolicy(attestationPB, tempDir)
require.NoError(t, err)
t.Run("VerifyPeerCertificateWithMultipleCerts", func(t *testing.T) {
verifier := NewCertificateVerifier(nil)
cert1 := createSelfSignedCert(t)
cert2 := createSelfSignedCert(t)
nonce := generateNonce(t)
err := verifier.VerifyPeerCertificate([][]byte{cert1.Raw, cert2.Raw}, nil, nonce)
assert.Error(t, err)
assert.Contains(t, err.Error(), "attestation extension not found")
})
t.Run("VerifyAttestationExtensionWithNoExtensions", func(t *testing.T) {
cert := createSelfSignedCert(t)
verifier := certificateVerifier{}
nonce := generateNonce(t)
err := verifier.verifyAttestationExtension(cert, nonce)
assert.Error(t, err)
assert.Contains(t, err.Error(), "attestation extension not found")
})
t.Run("VerifyAttestationExtensionWithWrongOID", func(t *testing.T) {
wrongOID := asn1.ObjectIdentifier{1, 2, 3, 4, 5}
extension := pkix.Extension{
Id: wrongOID,
Value: []byte("test-data"),
}
cert := generateTestCertificateWithExtensions(t, []pkix.Extension{extension})
verifier := certificateVerifier{}
nonce := generateNonce(t)
err := verifier.verifyAttestationExtension(cert, nonce)
assert.Error(t, err)
assert.Contains(t, err.Error(), "attestation extension not found")
})
t.Run("VerifyCertificateExtensionPlatformVerifierError", func(t *testing.T) {
verifier := certificateVerifier{}
invalidPlatformType := attestation.PlatformType(999)
err := verifier.verifyCertificateExtension([]byte("test-extension"), []byte("test-pubkey"), []byte("test-nonce"), invalidPlatformType)
assert.Error(t, err)
assert.Contains(t, err.Error(), "unsupported platform type")
})
}
@@ -793,7 +978,7 @@ func TestIntegrationScenarios(t *testing.T) {
mockProvider.On("Attestation", mock.Anything, mock.Anything).Return([]byte("mock-attestation"), nil)
// Create provider
provider, err := NewProvider(mockProvider, attestation.SNPvTPM, "", "")
provider, err := NewProvider(mockProvider, attestation.SNPvTPM, "", "", nil)
require.NoError(t, err)
// Generate certificate
@@ -826,19 +1011,18 @@ func TestIntegrationScenarios(t *testing.T) {
})
t.Run("FullCASignedFlow", func(t *testing.T) {
// Setup mock CA server
mockServer := createMockCAServer(t)
defer mockServer.Close()
mockSDK := sdkmocks.NewSDK(t)
expectedCSR := certs.CSR{CSR: []byte("test-csr")}
expectedCert := certssdk.Certificate{Certificate: generateTestCertPEM(t)}
mockSDK.On("CreateCSR", mock.Anything, mock.Anything).Return(expectedCSR, errors.SDKError(nil))
mockSDK.On("IssueFromCSRInternal", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(expectedCert, errors.SDKError(nil))
// Setup mock provider
mockProvider := new(mocks.Provider)
mockProvider.On("Attestation", mock.Anything, mock.Anything).Return([]byte("mock-attestation"), nil)
// Create CA-signed provider
provider, err := NewProvider(mockProvider, attestation.SNPvTPM, mockServer.URL, "test-cvm")
provider, err := NewProvider(mockProvider, attestation.SNPvTPM, "test-token", "test-cvm-id", mockSDK)
require.NoError(t, err)
// Generate certificate
nonce := make([]byte, 64)
_, err = rand.Read(nonce)
require.NoError(t, err)
@@ -847,8 +1031,18 @@ func TestIntegrationScenarios(t *testing.T) {
clientHello := &tls.ClientHelloInfo{ServerName: serverName}
cert, err := provider.GetCertificate(clientHello)
assert.NoError(t, err)
assert.NotNil(t, cert)
require.NoError(t, err)
require.NotNil(t, cert)
require.NotEmpty(t, cert.Certificate)
require.NotNil(t, cert.PrivateKey)
parsedCert, err := x509.ParseCertificate(cert.Certificate[0])
require.NoError(t, err)
assert.NotNil(t, parsedCert.Subject)
mockProvider.AssertExpectations(t)
mockSDK.AssertExpectations(t)
})
}
@@ -857,7 +1051,7 @@ func TestConcurrentAccess(t *testing.T) {
mockProvider := new(mocks.Provider)
mockProvider.On("Attestation", mock.Anything, mock.Anything).Return([]byte("mock-attestation"), nil)
provider, err := NewProvider(mockProvider, attestation.SNPvTPM, "", "")
provider, err := NewProvider(mockProvider, attestation.SNPvTPM, "", "", nil)
require.NoError(t, err)
const numGoroutines = 10
+33 -7
View File
@@ -9,11 +9,14 @@ import (
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"encoding/pem"
"fmt"
"math/big"
"strings"
"time"
"github.com/absmach/certs"
sdk "github.com/absmach/certs/sdk"
"github.com/ultravioletrs/cocos/pkg/attestation"
)
@@ -25,7 +28,8 @@ type CertificateProvider interface {
// AttestedCertificateProvider provides attested TLS certificates.
type attestedCertificateProvider struct {
attestationProvider AttestationProvider
caClient *CAClient
certsSDK sdk.SDK
agentToken string
subject CertificateSubject
useCA bool
cvmID string
@@ -50,12 +54,13 @@ func NewAttestedProvider(
func NewAttestedCAProvider(
attestationProvider AttestationProvider,
subject CertificateSubject,
caURL, cvmID string,
certsSDK sdk.SDK, cvmID, agentToken string,
) CertificateProvider {
return &attestedCertificateProvider{
attestationProvider: attestationProvider,
subject: subject,
caClient: NewCAClient(caURL),
certsSDK: certsSDK,
agentToken: agentToken,
useCA: true,
cvmID: cvmID,
ttl: time.Hour * 24 * 365, // Default 1 year
@@ -136,6 +141,7 @@ func (p *attestedCertificateProvider) generateCASignedCertificate(privateKey *ec
csrMetadata := certs.CSRMetadata{
Organization: []string{p.subject.Organization},
Country: []string{p.subject.Country},
CommonName: p.subject.CommonName,
Province: []string{p.subject.Province},
Locality: []string{p.subject.Locality},
StreetAddress: []string{p.subject.StreetAddress},
@@ -143,10 +149,30 @@ func (p *attestedCertificateProvider) generateCASignedCertificate(privateKey *ec
ExtraExtensions: []pkix.Extension{extension},
}
return p.caClient.RequestCertificate(csrMetadata, privateKey, p.cvmID, p.ttl)
csr, sdkerr := p.certsSDK.CreateCSR(csrMetadata, privateKey)
if sdkerr != nil {
return nil, fmt.Errorf("failed to create CSR: %w", sdkerr)
}
cert, err := p.certsSDK.IssueFromCSRInternal(p.cvmID, p.ttl.String(), string(csr.CSR), p.agentToken)
if err != nil {
return nil, err
}
cleanCertificateString := strings.ReplaceAll(cert.Certificate, "\\n", "\n")
block, rest := pem.Decode([]byte(cleanCertificateString))
if len(rest) != 0 {
return nil, fmt.Errorf("failed to decode certificate PEM: unexpected remaining data")
}
if block == nil {
return nil, fmt.Errorf("failed to decode certificate PEM: no PEM block found")
}
return block.Bytes, nil
}
func NewProvider(provider attestation.Provider, platformType attestation.PlatformType, caURL, cvmID string) (CertificateProvider, error) {
func NewProvider(provider attestation.Provider, platformType attestation.PlatformType, agentToken, cvmID string, certsSDK sdk.SDK) (CertificateProvider, error) {
attestationProvider, err := NewAttestationProvider(provider, platformType)
if err != nil {
return nil, fmt.Errorf("failed to create attestation provider: %w", err)
@@ -154,8 +180,8 @@ func NewProvider(provider attestation.Provider, platformType attestation.Platfor
subject := DefaultCertificateSubject()
if caURL != "" && cvmID != "" {
return NewAttestedCAProvider(attestationProvider, subject, caURL, cvmID), nil
if certsSDK != nil {
return NewAttestedCAProvider(attestationProvider, subject, certsSDK, cvmID, agentToken), nil
}
return NewAttestedProvider(attestationProvider, subject), nil