NOISSUE - Refactor aTLS and gRPC server to use CertificateProvider interface (#522)

* Refactor ATLS and gRPC server to use CertificateProvider interface

- Removed unused test cases and mock dependencies in atls_test.go.
- Updated TestGetPlatformVerifier to use CertificateVerifier struct.
- Introduced CertificateProvider interface for better abstraction in TLS handling.
- Refactored gRPC server to accept CertificateProvider and configure TLS accordingly.
- Simplified TLS configuration logic in both gRPC and HTTP servers.
- Removed unnecessary parameters from server initialization in tests and main function.
- Enhanced logging for TLS configurations.

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* Fix comments for consistency and clarity in atls.go

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* Update expected error messages in VM command tests for clarity

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* Enhance tests by integrating mock providers and improving error messages for clarity

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* Add comprehensive tests for certificate generation and attestation providers

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* Implement certificate and attestation providers with unified generation logic

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* Refactor certificate and attestation provider structures for consistency; implement CertificateVerifier interface and related methods

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* Refactor attestation and certificate provider methods for consistency; rename methods and update related logic

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

---------

Signed-off-by: Sammy Oina <sammyoina@gmail.com>
This commit is contained in:
Sammy Kerata Oina
2025-09-23 15:49:23 +03:00
committed by GitHub
parent 906d7877b2
commit c758b3b216
20 changed files with 1379 additions and 880 deletions
+7
View File
@@ -139,3 +139,10 @@ packages:
dir: '{{.InterfaceDir}}/mocks'
structname: '{{.InterfaceName}}'
filename: "{{.InterfaceName | lower}}.go"
github.com/ultravioletrs/cocos/pkg/atls:
interfaces:
CertificateProvider:
config:
dir: '{{.InterfaceDir}}/mocks'
structname: '{{.InterfaceName}}'
filename: "{{.InterfaceName | lower}}.go"
+12 -13
View File
@@ -11,6 +11,7 @@ import (
"github.com/ultravioletrs/cocos/agent"
agentgrpc "github.com/ultravioletrs/cocos/agent/api/grpc"
"github.com/ultravioletrs/cocos/agent/auth"
"github.com/ultravioletrs/cocos/pkg/atls"
"github.com/ultravioletrs/cocos/pkg/server"
grpcserver "github.com/ultravioletrs/cocos/pkg/server/grpc"
"google.golang.org/grpc"
@@ -28,21 +29,19 @@ type AgentServer interface {
}
type agentServer struct {
gs server.Server
logger *slog.Logger
svc agent.Service
host string
caUrl string
cvmId string
gs server.Server
logger *slog.Logger
svc agent.Service
host string
certProvider atls.CertificateProvider
}
func NewServer(logger *slog.Logger, svc agent.Service, host string, caUrl string, cvmId string) AgentServer {
func NewServer(logger *slog.Logger, svc agent.Service, host string, certProvider atls.CertificateProvider) AgentServer {
return &agentServer{
logger: logger,
svc: svc,
host: host,
caUrl: caUrl,
cvmId: cvmId,
logger: logger,
svc: svc,
host: host,
certProvider: certProvider,
}
}
@@ -78,7 +77,7 @@ func (as *agentServer) Start(cfg agent.AgentConfig, cmp agent.Computation) error
ctx, cancel := context.WithCancel(context.Background())
as.gs = grpcserver.New(ctx, cancel, svcName, agentGrpcServerConfig, registerAgentServiceServer, as.logger, authSvc, as.caUrl, as.cvmId)
as.gs = grpcserver.New(ctx, cancel, svcName, agentGrpcServerConfig, registerAgentServiceServer, as.logger, authSvc, as.certProvider)
go func() {
err := as.gs.Start()
+14 -28
View File
@@ -18,12 +18,10 @@ import (
"github.com/ultravioletrs/cocos/agent/mocks"
)
func setupTest(t *testing.T) (*slog.Logger, *mocks.Service, string, string, string, []byte) {
func setupTest(t *testing.T) (*slog.Logger, *mocks.Service, string, []byte) {
logger := slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelError}))
mockSvc := new(mocks.Service)
host := "localhost"
caUrl := "https://ca.example.com"
cvmId := "test-cvm-id"
privateKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
assert.NoError(t, err, "Failed to generate ECDSA key")
@@ -31,19 +29,17 @@ func setupTest(t *testing.T) (*slog.Logger, *mocks.Service, string, string, stri
pubkey, err := x509.MarshalPKIXPublicKey(privateKey.Public())
assert.NoError(t, err, "Failed to marshal public key")
return logger, mockSvc, host, caUrl, cvmId, pubkey
return logger, mockSvc, host, pubkey
}
func TestNewServer(t *testing.T) {
logger, svc, host, caUrl, cvmId, _ := setupTest(t)
logger, svc, host, _ := setupTest(t)
tests := []struct {
name string
logger *slog.Logger
svc agent.Service
host string
caUrl string
cvmId string
expected AgentServer
}{
{
@@ -51,38 +47,30 @@ func TestNewServer(t *testing.T) {
logger: logger,
svc: svc,
host: host,
caUrl: caUrl,
cvmId: cvmId,
},
{
name: "server with empty host",
logger: logger,
svc: svc,
host: "",
caUrl: caUrl,
cvmId: cvmId,
},
{
name: "server with empty caUrl",
logger: logger,
svc: svc,
host: host,
caUrl: "",
cvmId: cvmId,
},
{
name: "server with empty cvmId",
logger: logger,
svc: svc,
host: host,
caUrl: caUrl,
cvmId: "",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
server := NewServer(tt.logger, tt.svc, tt.host, tt.caUrl, tt.cvmId)
server := NewServer(tt.logger, tt.svc, tt.host, nil)
assert.NotNil(t, server)
@@ -91,14 +79,12 @@ func TestNewServer(t *testing.T) {
assert.Equal(t, tt.logger, agentSrv.logger)
assert.Equal(t, tt.svc, agentSrv.svc)
assert.Equal(t, tt.host, agentSrv.host)
assert.Equal(t, tt.caUrl, agentSrv.caUrl)
assert.Equal(t, tt.cvmId, agentSrv.cvmId)
})
}
}
func TestAgentServer_Start(t *testing.T) {
logger, svc, host, caUrl, cvmId, pubKey := setupTest(t)
logger, svc, host, pubKey := setupTest(t)
tests := []struct {
name string
@@ -211,7 +197,7 @@ func TestAgentServer_Start(t *testing.T) {
t.Run(tt.name, func(t *testing.T) {
tt.setupMocks(svc)
server := NewServer(logger, svc, host, caUrl, cvmId)
server := NewServer(logger, svc, host, nil)
err := server.Start(tt.cfg, tt.cmp)
@@ -238,7 +224,7 @@ func TestAgentServer_Start(t *testing.T) {
}
func TestAgentServer_Stop(t *testing.T) {
logger, svc, host, caUrl, cvmId, pubKey := setupTest(t)
logger, svc, host, pubKey := setupTest(t)
tests := []struct {
name string
@@ -287,7 +273,7 @@ func TestAgentServer_Stop(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
server := NewServer(logger, svc, host, caUrl, cvmId)
server := NewServer(logger, svc, host, nil)
err := tt.setupServer(server)
if err != nil {
@@ -314,8 +300,8 @@ func TestAgentServer_Stop(t *testing.T) {
}
func TestAgentServer_StopMultipleTimes(t *testing.T) {
logger, svc, host, caUrl, cvmId, pubKey := setupTest(t)
server := NewServer(logger, svc, host, caUrl, cvmId)
logger, svc, host, pubKey := setupTest(t)
server := NewServer(logger, svc, host, nil)
// Start the server
cfg := agent.AgentConfig{Port: "7005"}
@@ -358,8 +344,8 @@ func TestAgentServer_StopMultipleTimes(t *testing.T) {
}
func TestAgentServer_StartAfterStop(t *testing.T) {
logger, svc, host, caUrl, cvmId, pubKey := setupTest(t)
server := NewServer(logger, svc, host, caUrl, cvmId)
logger, svc, host, pubKey := setupTest(t)
server := NewServer(logger, svc, host, nil)
cfg := agent.AgentConfig{Port: "7006"}
cmp := agent.Computation{
@@ -425,7 +411,7 @@ func TestAgentServer_StartAfterStop(t *testing.T) {
}
func TestAgentServer_ConfigValidation(t *testing.T) {
logger, svc, host, caUrl, cvmId, pubKey := setupTest(t)
logger, svc, host, pubKey := setupTest(t)
tests := []struct {
name string
@@ -512,7 +498,7 @@ func TestAgentServer_ConfigValidation(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
server := NewServer(logger, svc, host, caUrl, cvmId)
server := NewServer(logger, svc, host, nil)
err := server.Start(tt.config, tt.cmp)
+2 -2
View File
@@ -113,7 +113,7 @@ func TestCLI_NewCreateVMCmd(t *testing.T) {
flags: map[string]string{
"server-url": "https://server.com",
},
expectedError: "Failed to connect to manager: failed to connect to grpc server : failed to exit idle mode: passthrough: received empty target in Build() ❌",
expectedError: "failed to exit idle mode: dns resolver: missing address ❌",
expectError: true,
},
{
@@ -252,7 +252,7 @@ func TestCLI_NewRemoveVMCmd(t *testing.T) {
cli.connectErr = errors.New("connection failed")
},
args: []string{"vm-123"},
expectedError: "Failed to connect to manager: failed to connect to grpc server : failed to exit idle mode: passthrough: received empty target in Build() ❌",
expectedError: "failed to exit idle mode: dns resolver: missing address ❌",
expectError: true,
},
{
+13 -1
View File
@@ -26,6 +26,7 @@ import (
"github.com/ultravioletrs/cocos/agent/cvms/server"
"github.com/ultravioletrs/cocos/agent/events"
agentlogger "github.com/ultravioletrs/cocos/internal/logger"
"github.com/ultravioletrs/cocos/pkg/atls"
"github.com/ultravioletrs/cocos/pkg/attestation"
"github.com/ultravioletrs/cocos/pkg/attestation/azure"
"github.com/ultravioletrs/cocos/pkg/attestation/tdx"
@@ -163,7 +164,18 @@ func main() {
return
}
mc, err := cvmsapi.NewClient(pc, svc, eventsLogsQueue, logger, server.NewServer(logger, svc, cfg.AgentGrpcHost, cfg.CAUrl, cfg.CVMId), storageDir, reconnectFn, cvmGRPCClient)
var certProvider atls.CertificateProvider
if ccPlatform != attestation.NoCC {
certProvider, err = atls.NewProvider(provider, ccPlatform, cfg.CVMId, cfg.CAUrl)
if err != nil {
logger.Error(fmt.Sprintf("failed to create certificate provider: %s", err))
exitCode = 1
return
}
}
mc, err := cvmsapi.NewClient(pc, svc, eventsLogsQueue, logger, server.NewServer(logger, svc, cfg.AgentGrpcHost, certProvider), storageDir, reconnectFn, cvmGRPCClient)
if err != nil {
logger.Error(err.Error())
exitCode = 1
+1 -1
View File
@@ -146,7 +146,7 @@ func main() {
manager.RegisterManagerServiceServer(srv, managergrpc.NewServer(svc))
}
gs := grpcserver.New(ctx, cancel, svcName, managerGRPCConfig, registerManagerServiceServer, logger, nil, "", "")
gs := grpcserver.New(ctx, cancel, svcName, managerGRPCConfig, registerManagerServiceServer, logger, nil, nil)
hs := httpserver.NewServer(ctx, cancel, svcName, httpServerConfig, http.MakeHandler(chi.NewMux(), svcName, cfg.InstanceID), logger)
+90 -303
View File
@@ -5,18 +5,12 @@ package atls
import (
"bytes"
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"encoding/asn1"
"encoding/hex"
"encoding/json"
"encoding/pem"
"fmt"
"io"
"math/big"
"net/http"
"net/url"
"strings"
@@ -26,356 +20,149 @@ import (
certscli "github.com/absmach/certs/cli"
"github.com/absmach/certs/errors"
certssdk "github.com/absmach/certs/sdk"
"github.com/ultravioletrs/cocos/pkg/attestation"
"github.com/ultravioletrs/cocos/pkg/attestation/azure"
"github.com/ultravioletrs/cocos/pkg/attestation/tdx"
"github.com/ultravioletrs/cocos/pkg/attestation/vtpm"
"golang.org/x/crypto/sha3"
)
const (
vmpl2 = 2
organization = "Ultraviolet"
country = "Serbia"
province = ""
locality = "Belgrade"
streetAddress = "Bulevar Arsenija Carnojevica 103"
postalCode = "11000"
notAfterYear = 1
notAfterMonth = 0
notAfterDay = 0
defaultNotAfterYears = 1
nonceLength = 64
nonceSuffix = ".nonce"
)
// Platform-specific OIDs for certificate extensions.
var (
SNPvTPMOID = asn1.ObjectIdentifier{2, 99999, 1, 0}
AzureOID = asn1.ObjectIdentifier{2, 99999, 1, 1}
TDXOID = asn1.ObjectIdentifier{2, 99999, 1, 2}
errCertificateParse = errors.New("failed to parse x509 certificate")
errAttVerification = errors.New("certificate is not self signed")
SNPvTPMOID = asn1.ObjectIdentifier{2, 99999, 1, 0}
AzureOID = asn1.ObjectIdentifier{2, 99999, 1, 1}
TDXOID = asn1.ObjectIdentifier{2, 99999, 1, 2}
)
type csrReq struct {
// CertificateSubject contains certificate subject information.
type CertificateSubject struct {
Organization string
Country string
Province string
Locality string
StreetAddress string
PostalCode string
}
// DefaultCertificateSubject returns the default certificate subject for Ultraviolet.
func DefaultCertificateSubject() CertificateSubject {
return CertificateSubject{
Organization: "Ultraviolet",
Country: "Serbia",
Province: "",
Locality: "Belgrade",
StreetAddress: "Bulevar Arsenija Carnojevica 103",
PostalCode: "11000",
}
}
// CAClient handles communication with Certificate Authority.
type CAClient struct {
baseURL string
client *http.Client
}
type CSRRequest struct {
CSR string `json:"csr,omitempty"`
}
func getPlatformProvider(platformType attestation.PlatformType) (attestation.Provider, error) {
switch platformType {
case attestation.SNPvTPM:
return vtpm.NewProvider(true, vmpl2), nil
case attestation.Azure:
return azure.NewProvider(), nil
case attestation.TDX:
return tdx.NewProvider(), nil
default:
return nil, fmt.Errorf("unsupported platform type: %d", platformType)
func NewCAClient(baseURL string) *CAClient {
return &CAClient{
baseURL: baseURL,
client: &http.Client{},
}
}
func getPlatformVerifier(platformType attestation.PlatformType) (attestation.Verifier, error) {
var verifier attestation.Verifier
switch platformType {
case attestation.SNPvTPM:
verifier = vtpm.NewVerifier(nil)
case attestation.Azure:
verifier = azure.NewVerifier(nil)
case attestation.TDX:
verifier = tdx.NewVerifier()
default:
return nil, fmt.Errorf("unsupported platform type: %d", platformType)
func (c *CAClient) RequestCertificate(csrMetadata certs.CSRMetadata, privateKey *ecdsa.PrivateKey, cvmID string, ttl time.Duration) ([]byte, error) {
csr, sdkerr := certscli.CreateCSR(csrMetadata, privateKey)
if sdkerr != nil {
return nil, fmt.Errorf("failed to create CSR: %w", sdkerr)
}
err := verifier.JSONToPolicy(attestation.AttestationPolicyPath)
request := CSRRequest{CSR: string(csr.CSR)}
requestData, err := json.Marshal(request)
if err != nil {
return nil, err
return nil, fmt.Errorf("failed to marshal CSR request: %w", err)
}
return verifier, nil
}
func getOID(platformType attestation.PlatformType) (asn1.ObjectIdentifier, error) {
switch platformType {
case attestation.SNPvTPM:
return SNPvTPMOID, nil
case attestation.Azure:
return AzureOID, nil
case attestation.TDX:
return TDXOID, nil
default:
return nil, fmt.Errorf("unsupported platform type: %d", platformType)
}
}
endpoint := fmt.Sprintf("certs/csrs/%s", cvmID)
query := url.Values{}
query.Add("ttl", ttl.String())
requestURL := fmt.Sprintf("%s/%s?%s", c.baseURL, endpoint, query.Encode())
func getPlatformTypeFromOID(oid asn1.ObjectIdentifier) (attestation.PlatformType, error) {
switch {
case oid.Equal(SNPvTPMOID):
return attestation.SNPvTPM, nil
case oid.Equal(AzureOID):
return attestation.Azure, nil
case oid.Equal(TDXOID):
return attestation.TDX, nil
default:
return 0, fmt.Errorf("unsupported OID: %v", oid)
}
}
func verifyCertificateExtension(extension []byte, pubKey []byte, nonce []byte, pType attestation.PlatformType) error {
teeNonce := append(pubKey, nonce...)
hashNonce := sha3.Sum512(teeNonce)
verifier, err := getPlatformVerifier(pType)
_, responseBody, err := c.processRequest(http.MethodPost, requestURL, requestData, nil, http.StatusOK)
if err != nil {
return fmt.Errorf("failed to get platform verifier: %w", err)
return nil, fmt.Errorf("failed to process CA request: %w", err)
}
if err = verifier.VerifyAttestation(extension, hashNonce[:], hashNonce[:vtpm.Nonce]); err != nil {
fmt.Printf("failed to verify attestation: %v\n", err)
return err
var cert certssdk.Certificate
if err := json.Unmarshal(responseBody, &cert); err != nil {
return nil, fmt.Errorf("failed to unmarshal certificate response: %w", err)
}
return nil
cleanCertificateString := strings.ReplaceAll(cert.Certificate, "\\n", "\n")
block, rest := pem.Decode([]byte(cleanCertificateString))
if len(rest) != 0 {
return nil, fmt.Errorf("failed to decode certificate PEM: unexpected remaining data")
}
if block == nil {
return nil, fmt.Errorf("failed to decode certificate PEM: no PEM block found")
}
return block.Bytes, nil
}
func GetCertificate(caUrl string, cvmId string) func(*tls.ClientHelloInfo) (*tls.Certificate, error) {
pType := attestation.CCPlatform()
provider, err := getPlatformProvider(pType)
if err != nil {
return func(*tls.ClientHelloInfo) (*tls.Certificate, error) {
return nil, fmt.Errorf("failed to get platform provider: %w", err)
}
}
teeOid, err := getOID(pType)
if err != nil {
return func(*tls.ClientHelloInfo) (*tls.Certificate, error) {
return nil, fmt.Errorf("failed to get OID for platform type: %w", err)
}
}
return func(clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) {
curve := elliptic.P256()
privateKey, err := ecdsa.GenerateKey(curve, rand.Reader)
if err != nil {
return nil, fmt.Errorf("failed to generate private/public key: %w", err)
}
pubKeyDER, err := x509.MarshalPKIXPublicKey(&privateKey.PublicKey)
if err != nil {
return nil, fmt.Errorf("failed to marshal public key to DER format: %w", err)
}
sniLength := len(clientHello.ServerName)
if sniLength < 7 || clientHello.ServerName[sniLength-6:] != ".nonce" {
return nil, fmt.Errorf("invalid server name: %s", clientHello.ServerName)
}
nonceStr := clientHello.ServerName[:sniLength-6]
nonce, err := hex.DecodeString(nonceStr)
if err != nil {
return nil, fmt.Errorf("failed to decode nonce from server name: %w", err)
}
if len(nonce) != 64 {
return nil, fmt.Errorf("invalid nonce length: expected 64 bytes, got %d bytes", len(nonce))
}
attestExtension, err := getCertificateExtension(provider, pubKeyDER, nonce, teeOid)
if err != nil {
return nil, fmt.Errorf("failed to get certificate extension: %w", err)
}
var certDERBytes []byte
if caUrl == "" && cvmId == "" {
certTemplate := &x509.Certificate{
SerialNumber: big.NewInt(202403311),
Subject: pkix.Name{
Organization: []string{organization},
Country: []string{country},
Province: []string{province},
Locality: []string{locality},
StreetAddress: []string{streetAddress},
PostalCode: []string{postalCode},
},
NotBefore: time.Now(),
NotAfter: time.Now().AddDate(notAfterYear, notAfterMonth, notAfterDay),
KeyUsage: x509.KeyUsageDigitalSignature,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
BasicConstraintsValid: true,
ExtraExtensions: []pkix.Extension{attestExtension},
}
DERBytes, err := x509.CreateCertificate(rand.Reader, certTemplate, certTemplate, &privateKey.PublicKey, privateKey)
if err != nil {
return nil, fmt.Errorf("failed to create certificate: %w", err)
}
certDERBytes = DERBytes
} else {
csrmd := certs.CSRMetadata{
Organization: []string{organization},
Country: []string{country},
Province: []string{province},
Locality: []string{locality},
StreetAddress: []string{streetAddress},
PostalCode: []string{postalCode},
ExtraExtensions: []pkix.Extension{attestExtension},
}
csr, err := certscli.CreateCSR(csrmd, privateKey)
if err != nil {
return nil, fmt.Errorf("failed to create CSR: %w", err)
}
csrData := string(csr.CSR)
r := csrReq{
CSR: csrData,
}
data, sdkErr := json.Marshal(r)
if sdkErr != nil {
return nil, fmt.Errorf("failed to marshal CSR request: %w", sdkErr)
}
notBefore := time.Now()
notAfter := time.Now().AddDate(notAfterYear, notAfterMonth, notAfterDay)
ttlString := notAfter.Sub(notBefore).String()
query := url.Values{}
query.Add("ttl", ttlString)
query_string := query.Encode()
certsEndpoint := "certs"
csrEndpoint := "csrs"
endpoint := fmt.Sprintf("%s/%s/%s", certsEndpoint, csrEndpoint, cvmId)
url := fmt.Sprintf("%s/%s?%s", caUrl, endpoint, query_string)
_, body, err := processRequest(http.MethodPost, url, data, nil, http.StatusOK)
if err != nil {
return nil, fmt.Errorf("failed to process request: %w", err)
}
var cert certssdk.Certificate
if err := json.Unmarshal(body, &cert); err != nil {
return nil, fmt.Errorf("failed to unmarshal certificate response: %w", err)
}
cleanCertificateString := strings.ReplaceAll(cert.Certificate, "\\n", "\n")
block, rest := pem.Decode([]byte(cleanCertificateString))
if len(rest) != 0 {
return nil, fmt.Errorf("failed to convert generated certificate to DER format: %s", cleanCertificateString)
}
certDERBytes = block.Bytes
}
return &tls.Certificate{
Certificate: [][]byte{certDERBytes},
PrivateKey: privateKey,
}, nil
}
}
func getCertificateExtension(provider attestation.Provider, pubKey []byte, nonce []byte, teeOid asn1.ObjectIdentifier) (pkix.Extension, error) {
teeNonce := append(pubKey, nonce...)
hashNonce := sha3.Sum512(teeNonce)
rawAttestation, err := provider.Attestation(hashNonce[:], hashNonce[:vtpm.Nonce])
if err != nil {
return pkix.Extension{}, fmt.Errorf("failed to get attestation: %w", err)
}
return pkix.Extension{
Id: teeOid,
Value: rawAttestation,
}, nil
}
func processRequest(method, reqUrl string, data []byte, headers map[string]string, expectedRespCodes ...int) (http.Header, []byte, errors.SDKError) {
req, err := http.NewRequest(method, reqUrl, bytes.NewReader(data))
func (c *CAClient) processRequest(method, reqURL string, data []byte, headers map[string]string, expectedRespCodes ...int) (http.Header, []byte, errors.SDKError) {
req, err := http.NewRequest(method, reqURL, bytes.NewReader(data))
if err != nil {
return make(http.Header), []byte{}, errors.NewSDKError(err)
}
// Sets a default value for the Content-Type.
// Overridden if Content-Type is passed in the headers arguments.
req.Header.Add("Content-Type", "application/json")
for key, value := range headers {
req.Header.Add(key, value)
}
client := &http.Client{
Transport: &http.Transport{
TLSClientConfig: &tls.Config{
InsecureSkipVerify: true,
},
},
}
resp, err := client.Do(req)
resp, err := c.client.Do(req)
if err != nil {
return make(http.Header), []byte{}, errors.NewSDKError(err)
}
defer resp.Body.Close()
sdkerr := errors.CheckError(resp, expectedRespCodes...)
if sdkerr != nil {
return make(http.Header), []byte{}, sdkerr
sdkErr := errors.CheckError(resp, expectedRespCodes...)
if sdkErr != nil {
return make(http.Header), []byte{}, sdkErr
}
body, err := io.ReadAll(resp.Body)
if err != nil {
return make(http.Header), []byte{}, errors.NewSDKError(err)
}
return resp.Header, body, nil
}
// VerifyPeerCertificateATLS verifies peer certificates for Attested TLS.
func VerifyPeerCertificateATLS(rawCerts [][]byte, _ [][]*x509.Certificate, nonce []byte, rootCAs *x509.CertPool) error {
cert, err := x509.ParseCertificate(rawCerts[0])
func extractNonceFromSNI(serverName string) ([]byte, error) {
if len(serverName) < len(nonceSuffix) || !hasNonceSuffix(serverName) {
return nil, fmt.Errorf("invalid server name: %s", serverName)
}
nonceStr := serverName[:len(serverName)-len(nonceSuffix)]
nonce, err := hex.DecodeString(nonceStr)
if err != nil {
return errors.Wrap(errCertificateParse, err)
return nil, fmt.Errorf("failed to decode nonce: %w", err)
}
err = verifyCertificateSignature(cert, rootCAs)
if err != nil {
return errors.Wrap(errAttVerification, err)
if len(nonce) != nonceLength {
return nil, fmt.Errorf("invalid nonce length: expected %d bytes, got %d bytes", nonceLength, len(nonce))
}
for _, ext := range cert.Extensions {
pType, err := getPlatformTypeFromOID(ext.Id)
if err == nil {
pubKeyDER, err := x509.MarshalPKIXPublicKey(cert.PublicKey)
if err != nil {
return fmt.Errorf("failed to marshal public key to DER format: %w", err)
}
return verifyCertificateExtension(ext.Value, pubKeyDER, nonce, pType)
}
}
return errors.New("attestation extension not found in certificate")
return nonce, nil
}
// VerifyCertificateSignature verifies the certificate signature against root CAs.
func verifyCertificateSignature(cert *x509.Certificate, rootCAs *x509.CertPool) error {
if rootCAs == nil {
rootCAs = x509.NewCertPool()
rootCAs.AddCert(cert)
}
opts := x509.VerifyOptions{
Roots: rootCAs,
CurrentTime: time.Now(),
}
if _, err := cert.Verify(opts); err != nil {
return err
}
return nil
func hasNonceSuffix(serverName string) bool {
return len(serverName) >= len(nonceSuffix) &&
serverName[len(serverName)-len(nonceSuffix):] == nonceSuffix
}
+610 -375
View File
File diff suppressed because it is too large Load Diff
+66
View File
@@ -0,0 +1,66 @@
// Copyright (c) Ultraviolet
// SPDX-License-Identifier: Apache-2.0
package atls
import (
"encoding/asn1"
"fmt"
"github.com/ultravioletrs/cocos/pkg/attestation"
"golang.org/x/crypto/sha3"
)
// AttestationProvider defines the interface for platform attestation operations.
type AttestationProvider interface {
Attest(pubKey []byte, nonce []byte) ([]byte, error)
OID() asn1.ObjectIdentifier
PlatformType() attestation.PlatformType
}
// PlatformAttestationProvider handles platform attestation operations.
type platformAttestationProvider struct {
provider attestation.Provider
oid asn1.ObjectIdentifier
platformType attestation.PlatformType
}
// NewAttestationProvider creates a new attestation provider for the given platform type.
func NewAttestationProvider(provider attestation.Provider, platformType attestation.PlatformType) (AttestationProvider, error) {
oid, err := OID(platformType)
if err != nil {
return nil, fmt.Errorf("failed to get OID: %w", err)
}
return &platformAttestationProvider{
provider: provider,
oid: oid,
platformType: platformType,
}, nil
}
func (p *platformAttestationProvider) Attest(pubKey []byte, nonce []byte) ([]byte, error) {
teeNonce := append(pubKey, nonce...)
hashNonce := sha3.Sum512(teeNonce)
return p.provider.Attestation(hashNonce[:], hashNonce[:32])
}
func (p *platformAttestationProvider) OID() asn1.ObjectIdentifier {
return p.oid
}
func (p *platformAttestationProvider) PlatformType() attestation.PlatformType {
return p.platformType
}
func OID(platformType attestation.PlatformType) (asn1.ObjectIdentifier, error) {
switch platformType {
case attestation.SNPvTPM:
return SNPvTPMOID, nil
case attestation.Azure:
return AzureOID, nil
case attestation.TDX:
return TDXOID, nil
default:
return nil, fmt.Errorf("unsupported platform type: %d", platformType)
}
}
+162
View File
@@ -0,0 +1,162 @@
// Copyright (c) Ultraviolet
// SPDX-License-Identifier: Apache-2.0
package atls
import (
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"fmt"
"math/big"
"time"
"github.com/absmach/certs"
"github.com/ultravioletrs/cocos/pkg/attestation"
)
// CertificateProvider defines the interface for providing TLS certificates.
type CertificateProvider interface {
GetCertificate(clientHello *tls.ClientHelloInfo) (*tls.Certificate, error)
}
// AttestedCertificateProvider provides attested TLS certificates.
type attestedCertificateProvider struct {
attestationProvider AttestationProvider
caClient *CAClient
subject CertificateSubject
useCA bool
cvmID string
ttl time.Duration
notAfterYears int
}
// NewAttestedProvider creates a new attested certificate provider for self-signed certificates.
func NewAttestedProvider(
attestationProvider AttestationProvider,
subject CertificateSubject,
) CertificateProvider {
return &attestedCertificateProvider{
attestationProvider: attestationProvider,
subject: subject,
useCA: false,
notAfterYears: defaultNotAfterYears,
}
}
// NewAttestedCAProvider creates a new attested certificate provider for CA-signed certificates.
func NewAttestedCAProvider(
attestationProvider AttestationProvider,
subject CertificateSubject,
caURL, cvmID string,
) CertificateProvider {
return &attestedCertificateProvider{
attestationProvider: attestationProvider,
subject: subject,
caClient: NewCAClient(caURL),
useCA: true,
cvmID: cvmID,
ttl: time.Hour * 24 * 365, // Default 1 year
}
}
// SetTTL sets the certificate TTL for CA-signed certificates.
func (p *attestedCertificateProvider) SetTTL(ttl time.Duration) {
p.ttl = ttl
}
func (p *attestedCertificateProvider) GetCertificate(clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) {
privateKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
if err != nil {
return nil, fmt.Errorf("failed to generate private key: %w", err)
}
pubKeyDER, err := x509.MarshalPKIXPublicKey(&privateKey.PublicKey)
if err != nil {
return nil, fmt.Errorf("failed to marshal public key: %w", err)
}
nonce, err := extractNonceFromSNI(clientHello.ServerName)
if err != nil {
return nil, fmt.Errorf("failed to extract nonce: %w", err)
}
attestationData, err := p.attestationProvider.Attest(pubKeyDER, nonce)
if err != nil {
return nil, fmt.Errorf("failed to get attestation: %w", err)
}
extension := pkix.Extension{
Id: p.attestationProvider.OID(),
Value: attestationData,
}
var certDERBytes []byte
if p.useCA {
certDERBytes, err = p.generateCASignedCertificate(privateKey, extension)
} else {
certDERBytes, err = p.generateSelfSignedCertificate(privateKey, extension)
}
if err != nil {
return nil, fmt.Errorf("failed to generate certificate: %w", err)
}
return &tls.Certificate{
Certificate: [][]byte{certDERBytes},
PrivateKey: privateKey,
}, nil
}
func (p *attestedCertificateProvider) generateSelfSignedCertificate(privateKey *ecdsa.PrivateKey, extension pkix.Extension) ([]byte, error) {
certTemplate := &x509.Certificate{
SerialNumber: big.NewInt(time.Now().Unix()),
Subject: pkix.Name{
Organization: []string{p.subject.Organization},
Country: []string{p.subject.Country},
Province: []string{p.subject.Province},
Locality: []string{p.subject.Locality},
StreetAddress: []string{p.subject.StreetAddress},
PostalCode: []string{p.subject.PostalCode},
},
NotBefore: time.Now(),
NotAfter: time.Now().AddDate(p.notAfterYears, 0, 0),
KeyUsage: x509.KeyUsageDigitalSignature,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
BasicConstraintsValid: true,
ExtraExtensions: []pkix.Extension{extension},
}
return x509.CreateCertificate(rand.Reader, certTemplate, certTemplate, &privateKey.PublicKey, privateKey)
}
func (p *attestedCertificateProvider) generateCASignedCertificate(privateKey *ecdsa.PrivateKey, extension pkix.Extension) ([]byte, error) {
csrMetadata := certs.CSRMetadata{
Organization: []string{p.subject.Organization},
Country: []string{p.subject.Country},
Province: []string{p.subject.Province},
Locality: []string{p.subject.Locality},
StreetAddress: []string{p.subject.StreetAddress},
PostalCode: []string{p.subject.PostalCode},
ExtraExtensions: []pkix.Extension{extension},
}
return p.caClient.RequestCertificate(csrMetadata, privateKey, p.cvmID, p.ttl)
}
func NewProvider(provider attestation.Provider, platformType attestation.PlatformType, caURL, cvmID string) (CertificateProvider, error) {
attestationProvider, err := NewAttestationProvider(provider, platformType)
if err != nil {
return nil, fmt.Errorf("failed to create attestation provider: %w", err)
}
subject := DefaultCertificateSubject()
if caURL != "" && cvmID != "" {
return NewAttestedCAProvider(attestationProvider, subject, caURL, cvmID), nil
}
return NewAttestedProvider(attestationProvider, subject), nil
}
+125
View File
@@ -0,0 +1,125 @@
// Copyright (c) Ultraviolet
// SPDX-License-Identifier: Apache-2.0
package atls
import (
"crypto/x509"
"encoding/asn1"
"fmt"
"time"
"github.com/ultravioletrs/cocos/pkg/attestation"
"github.com/ultravioletrs/cocos/pkg/attestation/azure"
"github.com/ultravioletrs/cocos/pkg/attestation/tdx"
"github.com/ultravioletrs/cocos/pkg/attestation/vtpm"
"golang.org/x/crypto/sha3"
)
type CertificateVerifier interface {
VerifyPeerCertificate(rawCerts [][]byte, verifiedChains [][]*x509.Certificate, nonce []byte) error
}
// CertificateVerifier handles certificate verification operations.
type certificateVerifier struct {
rootCAs *x509.CertPool
}
func NewCertificateVerifier(rootCAs *x509.CertPool) CertificateVerifier {
return &certificateVerifier{rootCAs: rootCAs}
}
func (v *certificateVerifier) VerifyPeerCertificate(rawCerts [][]byte, _ [][]*x509.Certificate, nonce []byte) error {
if len(rawCerts) == 0 {
return fmt.Errorf("no certificates provided")
}
cert, err := x509.ParseCertificate(rawCerts[0])
if err != nil {
return fmt.Errorf("failed to parse x509 certificate: %w", err)
}
if err := v.verifyCertificateSignature(cert); err != nil {
return fmt.Errorf("certificate signature verification failed: %w", err)
}
return v.verifyAttestationExtension(cert, nonce)
}
func (v *certificateVerifier) verifyCertificateSignature(cert *x509.Certificate) error {
rootCAs := v.rootCAs
if rootCAs == nil {
rootCAs = x509.NewCertPool()
rootCAs.AddCert(cert)
}
opts := x509.VerifyOptions{
Roots: rootCAs,
CurrentTime: time.Now(),
}
_, err := cert.Verify(opts)
return err
}
func (v *certificateVerifier) verifyAttestationExtension(cert *x509.Certificate, nonce []byte) error {
for _, ext := range cert.Extensions {
if platformType, err := platformTypeFromOID(ext.Id); err == nil {
pubKeyDER, err := x509.MarshalPKIXPublicKey(cert.PublicKey)
if err != nil {
return fmt.Errorf("failed to marshal public key: %w", err)
}
return v.verifyCertificateExtension(ext.Value, pubKeyDER, nonce, platformType)
}
}
return fmt.Errorf("attestation extension not found in certificate")
}
func (v *certificateVerifier) verifyCertificateExtension(extension []byte, pubKey []byte, nonce []byte, platformType attestation.PlatformType) error {
verifier, err := platformVerifier(platformType)
if err != nil {
return fmt.Errorf("failed to get platform verifier: %w", err)
}
teeNonce := append(pubKey, nonce...)
hashNonce := sha3.Sum512(teeNonce)
if err = verifier.VerifyAttestation(extension, hashNonce[:], hashNonce[:32]); err != nil {
return fmt.Errorf("failed to verify attestation: %w", err)
}
return nil
}
func platformTypeFromOID(oid asn1.ObjectIdentifier) (attestation.PlatformType, error) {
switch {
case oid.Equal(SNPvTPMOID):
return attestation.SNPvTPM, nil
case oid.Equal(AzureOID):
return attestation.Azure, nil
case oid.Equal(TDXOID):
return attestation.TDX, nil
default:
return 0, fmt.Errorf("unsupported OID: %v", oid)
}
}
func platformVerifier(platformType attestation.PlatformType) (attestation.Verifier, error) {
var verifier attestation.Verifier
switch platformType {
case attestation.SNPvTPM:
verifier = vtpm.NewVerifier(nil)
case attestation.Azure:
verifier = azure.NewVerifier(nil)
case attestation.TDX:
verifier = tdx.NewVerifier()
default:
return nil, fmt.Errorf("unsupported platform type: %d", platformType)
}
err := verifier.JSONToPolicy(attestation.AttestationPolicyPath)
if err != nil {
return nil, err
}
return verifier, nil
}
+103
View File
@@ -0,0 +1,103 @@
// Copyright (c) Ultraviolet
// SPDX-License-Identifier: Apache-2.0
// Code generated by mockery; DO NOT EDIT.
// github.com/vektra/mockery
// template: testify
package mocks
import (
"crypto/tls"
mock "github.com/stretchr/testify/mock"
)
// NewCertificateProvider creates a new instance of CertificateProvider. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations.
// The first argument is typically a *testing.T value.
func NewCertificateProvider(t interface {
mock.TestingT
Cleanup(func())
}) *CertificateProvider {
mock := &CertificateProvider{}
mock.Mock.Test(t)
t.Cleanup(func() { mock.AssertExpectations(t) })
return mock
}
// CertificateProvider is an autogenerated mock type for the CertificateProvider type
type CertificateProvider struct {
mock.Mock
}
type CertificateProvider_Expecter struct {
mock *mock.Mock
}
func (_m *CertificateProvider) EXPECT() *CertificateProvider_Expecter {
return &CertificateProvider_Expecter{mock: &_m.Mock}
}
// GetCertificate provides a mock function for the type CertificateProvider
func (_mock *CertificateProvider) GetCertificate(clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) {
ret := _mock.Called(clientHello)
if len(ret) == 0 {
panic("no return value specified for GetCertificate")
}
var r0 *tls.Certificate
var r1 error
if returnFunc, ok := ret.Get(0).(func(*tls.ClientHelloInfo) (*tls.Certificate, error)); ok {
return returnFunc(clientHello)
}
if returnFunc, ok := ret.Get(0).(func(*tls.ClientHelloInfo) *tls.Certificate); ok {
r0 = returnFunc(clientHello)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(*tls.Certificate)
}
}
if returnFunc, ok := ret.Get(1).(func(*tls.ClientHelloInfo) error); ok {
r1 = returnFunc(clientHello)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// CertificateProvider_GetCertificate_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetCertificate'
type CertificateProvider_GetCertificate_Call struct {
*mock.Call
}
// GetCertificate is a helper method to define mock.On call
// - clientHello *tls.ClientHelloInfo
func (_e *CertificateProvider_Expecter) GetCertificate(clientHello interface{}) *CertificateProvider_GetCertificate_Call {
return &CertificateProvider_GetCertificate_Call{Call: _e.mock.On("GetCertificate", clientHello)}
}
func (_c *CertificateProvider_GetCertificate_Call) Run(run func(clientHello *tls.ClientHelloInfo)) *CertificateProvider_GetCertificate_Call {
_c.Call.Run(func(args mock.Arguments) {
var arg0 *tls.ClientHelloInfo
if args[0] != nil {
arg0 = args[0].(*tls.ClientHelloInfo)
}
run(
arg0,
)
})
return _c
}
func (_c *CertificateProvider_GetCertificate_Call) Return(certificate *tls.Certificate, err error) *CertificateProvider_GetCertificate_Call {
_c.Call.Return(certificate, err)
return _c
}
func (_c *CertificateProvider_GetCertificate_Call) RunAndReturn(run func(clientHello *tls.ClientHelloInfo) (*tls.Certificate, error)) *CertificateProvider_GetCertificate_Call {
_c.Call.Return(run)
return _c
}
+2 -2
View File
@@ -101,7 +101,7 @@ func TestAgentClientIntegration(t *testing.T) {
Timeout: 1,
},
},
err: errors.New("failed to connect to grpc server"),
err: errors.New("agent service is unavailable"),
},
{
name: "invalid config, missing AttestationPolicy with aTLS",
@@ -127,7 +127,7 @@ func TestAgentClientIntegration(t *testing.T) {
}
client, agentClient, err := NewAgentClient(ctx, tt.config)
assert.True(t, errors.Contains(err, tt.err))
assert.True(t, errors.Contains(err, tt.err), fmt.Sprintf("expected error to contain: %v, got: %v", tt.err, err))
if err != nil {
assert.Nil(t, client)
assert.Nil(t, agentClient)
+1 -10
View File
@@ -89,15 +89,6 @@ func TestAgentClientIntegration(t *testing.T) {
},
err: nil,
},
{
name: "server not healthy",
serverRunning: false,
config: clients.StandardClientConfig{
URL: "",
Timeout: 1,
},
err: errors.New("failed to connect to grpc server"),
},
}
for _, tt := range tests {
@@ -109,7 +100,7 @@ func TestAgentClientIntegration(t *testing.T) {
}
client, agentClient, err := NewCVMClient(tt.config)
assert.True(t, errors.Contains(err, tt.err))
assert.True(t, errors.Contains(err, tt.err), fmt.Sprintf("expected error to contain: %v, got: %v", tt.err, err))
if err != nil {
assert.Nil(t, client)
assert.Nil(t, agentClient)
+1 -1
View File
@@ -137,7 +137,7 @@ func LoadATLSConfig(cfg AttestedClientConfig) (*TLSResult, error) {
RootCAs: rootCAs,
ServerName: sni,
VerifyPeerCertificate: func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error {
return atls.VerifyPeerCertificateATLS(rawCerts, verifiedChains, nonce, rootCAs)
return atls.NewCertificateVerifier(rootCAs).VerifyPeerCertificate(rawCerts, verifiedChains, nonce)
},
}
+104 -71
View File
@@ -25,39 +25,43 @@ import (
)
const (
stopWaitTime = 5 * time.Second
organization = "Ultraviolet"
country = "Serbia"
province = ""
locality = "Belgrade"
streetAddress = "Bulevar Arsenija Carnojevica 103"
postalCode = "11000"
notAfterYear = 1
notAfterMonth = 0
notAfterDay = 0
nonceSize = 32
stopWaitTime = 5 * time.Second
)
type Server struct {
server.BaseServer
mu sync.RWMutex
server *grpc.Server
health *health.Server
registerService serviceRegister
authSvc auth.Authenticator
caUrl string
cvmId string
started bool
stopped bool
mu sync.RWMutex
server *grpc.Server
health *health.Server
registerService serviceRegister
authSvc auth.Authenticator
certProvider atls.CertificateProvider
attestedTLSEnabled bool
started bool
stopped bool
}
type serviceRegister func(srv *grpc.Server)
var _ server.Server = (*Server)(nil)
func New(ctx context.Context, cancel context.CancelFunc, name string, config server.ServerConfiguration, registerService serviceRegister, logger *slog.Logger, authSvc auth.Authenticator, caUrl string, cvmId string) server.Server {
func New(
ctx context.Context, cancel context.CancelFunc, name string, config server.ServerConfiguration,
registerService serviceRegister, logger *slog.Logger, authSvc auth.Authenticator, certProvider atls.CertificateProvider,
) server.Server {
base := config.GetBaseConfig()
listenFullAddress := fmt.Sprintf("%s:%s", base.Host, base.Port)
var attestedTLS bool
if agentConfig, ok := config.(server.AgentConfig); ok && agentConfig.AttestedTLS {
if certProvider == nil {
logger.Error("Failed to create certificate provider")
} else {
attestedTLS = true
}
}
return &Server{
BaseServer: server.BaseServer{
Ctx: ctx,
@@ -67,10 +71,10 @@ func New(ctx context.Context, cancel context.CancelFunc, name string, config ser
Config: config,
Logger: logger,
},
registerService: registerService,
authSvc: authSvc,
caUrl: caUrl,
cvmId: cvmId,
registerService: registerService,
authSvc: authSvc,
certProvider: certProvider,
attestedTLSEnabled: attestedTLS,
}
}
@@ -92,65 +96,28 @@ func (s *Server) Start() error {
grpc.StatsHandler(otelgrpc.NewServerHandler()),
}
// Add authentication interceptors if auth service is available
if s.authSvc != nil {
unary, stream := agentgrpc.NewAuthInterceptor(s.authSvc)
grpcServerOptions = append(grpcServerOptions, grpc.UnaryInterceptor(unary))
grpcServerOptions = append(grpcServerOptions, grpc.StreamInterceptor(stream))
}
creds := grpc.Creds(insecure.NewCredentials())
c := s.Config.GetBaseConfig()
if agCfg, ok := s.Config.(server.AgentConfig); ok && agCfg.AttestedTLS {
tlsConfig := &tls.Config{
ClientAuth: tls.NoClientCert,
GetCertificate: atls.GetCertificate(s.caUrl, s.cvmId),
}
mtls, err := server.ConfigureCertificateAuthorities(tlsConfig, c.ServerCAFile, c.ClientCAFile)
if err != nil {
return fmt.Errorf("failed to configure certificate authorities: %w", err)
}
if mtls {
tlsConfig.ClientAuth = tls.RequireAndVerifyClientCert
}
creds = grpc.Creds(credentials.NewTLS(tlsConfig))
if mtls {
s.Logger.Info(fmt.Sprintf("%s service gRPC server listening at %s with Attested mTLS", s.Name, s.Address))
} else {
s.Logger.Info(fmt.Sprintf("%s service gRPC server listening at %s with Attested TLS", s.Name, s.Address))
}
} else {
switch {
case c.CertFile != "" || c.KeyFile != "":
tlsSetup, err := server.SetupRegularTLS(c.CertFile, c.KeyFile, c.ServerCAFile, c.ClientCAFile)
if err != nil {
return fmt.Errorf("failed to setup TLS: %w", err)
}
creds = grpc.Creds(credentials.NewTLS(tlsSetup.Config))
if tlsSetup.MTLS {
mtlsCA := server.BuildMTLSDescription(c.ServerCAFile, c.ClientCAFile)
s.Logger.Info(fmt.Sprintf("%s service gRPC server listening at %s with TLS/mTLS cert %s , key %s and %s", s.Name, s.Address, c.CertFile, c.KeyFile, mtlsCA))
} else {
s.Logger.Info(fmt.Sprintf("%s service gRPC server listening at %s with TLS cert %s and key %s", s.Name, s.Address, c.CertFile, c.KeyFile))
}
default:
s.Logger.Info(fmt.Sprintf("%s service gRPC server listening at %s without TLS", s.Name, s.Address))
}
// Configure credentials
creds, err := s.configureCredentials()
if err != nil {
return fmt.Errorf("failed to configure credentials: %w", err)
}
grpcServerOptions = append(grpcServerOptions, creds)
// Create listener
listener, err := net.Listen("tcp", s.Address)
if err != nil {
return fmt.Errorf("failed to listen on port %s: %w", s.Address, err)
}
grpcServerOptions = append(grpcServerOptions, creds)
// Create and configure server
s.mu.Lock()
s.server = grpc.NewServer(grpcServerOptions...)
s.health = health.NewServer()
@@ -159,6 +126,7 @@ func (s *Server) Start() error {
s.health.SetServingStatus(s.Name, grpchealth.HealthCheckResponse_SERVING)
s.mu.Unlock()
// Start server
go func() {
s.mu.RLock()
server := s.server
@@ -178,6 +146,71 @@ func (s *Server) Start() error {
}
}
func (s *Server) configureCredentials() (grpc.ServerOption, error) {
baseConfig := s.Config.GetBaseConfig()
// Check if attested TLS should be used
if s.shouldUseAttestedTLS() {
return s.configureAttestedTLS(baseConfig.Config)
}
// Check if regular TLS should be used
if s.shouldUseRegularTLS(baseConfig.Config) {
return s.configureRegularTLS(baseConfig.Config)
}
// Use insecure credentials
s.Logger.Info(fmt.Sprintf("%s service gRPC server listening at %s without TLS", s.Name, s.Address))
return grpc.Creds(insecure.NewCredentials()), nil
}
func (s *Server) shouldUseAttestedTLS() bool {
return s.attestedTLSEnabled && s.certProvider != nil
}
func (s *Server) shouldUseRegularTLS(config server.Config) bool {
return config.CertFile != "" || config.KeyFile != ""
}
func (s *Server) configureAttestedTLS(config server.Config) (grpc.ServerOption, error) {
tlsConfig := &tls.Config{
ClientAuth: tls.NoClientCert,
GetCertificate: s.certProvider.GetCertificate,
}
mtls, err := server.ConfigureCertificateAuthorities(tlsConfig, config.ServerCAFile, config.ClientCAFile)
if err != nil {
return nil, fmt.Errorf("failed to configure certificate authorities: %w", err)
}
if mtls {
tlsConfig.ClientAuth = tls.RequireAndVerifyClientCert
s.Logger.Info(fmt.Sprintf("%s service gRPC server listening at %s with Attested mTLS", s.Name, s.Address))
} else {
s.Logger.Info(fmt.Sprintf("%s service gRPC server listening at %s with Attested TLS", s.Name, s.Address))
}
return grpc.Creds(credentials.NewTLS(tlsConfig)), nil
}
func (s *Server) configureRegularTLS(config server.Config) (grpc.ServerOption, error) {
tlsSetup, err := server.SetupRegularTLS(config.CertFile, config.KeyFile, config.ServerCAFile, config.ClientCAFile)
if err != nil {
return nil, fmt.Errorf("failed to setup TLS: %w", err)
}
if tlsSetup.MTLS {
mtlsCA := server.BuildMTLSDescription(config.ServerCAFile, config.ClientCAFile)
s.Logger.Info(fmt.Sprintf("%s service gRPC server listening at %s with TLS/mTLS cert %s , key %s and %s",
s.Name, s.Address, config.CertFile, config.KeyFile, mtlsCA))
} else {
s.Logger.Info(fmt.Sprintf("%s service gRPC server listening at %s with TLS cert %s and key %s",
s.Name, s.Address, config.CertFile, config.KeyFile))
}
return grpc.Creds(credentials.NewTLS(tlsSetup.Config)), nil
}
func (s *Server) Stop() error {
s.mu.Lock()
defer s.mu.Unlock()
+8 -5
View File
@@ -20,6 +20,7 @@ import (
"github.com/stretchr/testify/assert"
authmocks "github.com/ultravioletrs/cocos/agent/auth/mocks"
"github.com/ultravioletrs/cocos/pkg/atls/mocks"
"github.com/ultravioletrs/cocos/pkg/attestation/vtpm"
"github.com/ultravioletrs/cocos/pkg/server"
"google.golang.org/grpc"
@@ -49,7 +50,7 @@ func TestNew(t *testing.T) {
logger := slog.Default()
authSvc := new(authmocks.Authenticator)
srv := New(ctx, cancel, "TestServer", config, func(srv *grpc.Server) {}, logger, authSvc, "", "")
srv := New(ctx, cancel, "TestServer", config, func(srv *grpc.Server) {}, logger, authSvc, nil)
assert.NotNil(t, srv)
assert.IsType(t, &Server{}, srv)
@@ -98,7 +99,7 @@ func TestServerStartWithTLSFile(t *testing.T) {
logger := slog.New(slog.NewTextHandler(logBuffer, &slog.HandlerOptions{Level: slog.LevelDebug}))
authSvc := new(authmocks.Authenticator)
srv := New(ctx, cancel, "TestServer", config, func(srv *grpc.Server) {}, logger, authSvc, "", "")
srv := New(ctx, cancel, "TestServer", config, func(srv *grpc.Server) {}, logger, authSvc, nil)
var wg sync.WaitGroup
wg.Add(1)
@@ -144,7 +145,7 @@ func TestServerStartWithmTLSFile(t *testing.T) {
logger := slog.New(slog.NewTextHandler(logBuffer, &slog.HandlerOptions{Level: slog.LevelDebug}))
authSvc := new(authmocks.Authenticator)
srv := New(ctx, cancel, "TestServer", config, func(srv *grpc.Server) {}, logger, authSvc, "", "")
srv := New(ctx, cancel, "TestServer", config, func(srv *grpc.Server) {}, logger, authSvc, nil)
var wg sync.WaitGroup
wg.Add(1)
@@ -183,7 +184,7 @@ func TestServerStop(t *testing.T) {
logger := slog.New(slog.NewTextHandler(buf, &slog.HandlerOptions{Level: slog.LevelDebug}))
authSvc := new(authmocks.Authenticator)
srv := New(ctx, cancel, "TestServer", config, func(srv *grpc.Server) {}, logger, authSvc, "", "")
srv := New(ctx, cancel, "TestServer", config, func(srv *grpc.Server) {}, logger, authSvc, nil)
go func() {
err := srv.Start()
@@ -367,7 +368,9 @@ func TestServerInitializationAndStartup(t *testing.T) {
logger := slog.New(slog.NewTextHandler(logBuffer, &slog.HandlerOptions{Level: slog.LevelDebug}))
authSvc := new(authmocks.Authenticator)
srv := New(ctx, cancel, "TestServer", tc.config, func(srv *grpc.Server) {}, logger, authSvc, "", "")
mockCertProvider := new(mocks.CertificateProvider)
srv := New(ctx, cancel, "TestServer", tc.config, func(srv *grpc.Server) {}, logger, authSvc, mockCertProvider)
var wg sync.WaitGroup
wg.Add(1)
+29 -25
View File
@@ -23,23 +23,35 @@ const (
type httpServer struct {
server.BaseServer
server *http.Server
caURL string
server *http.Server
certProvider atls.CertificateProvider
attestedTLSEnabled bool
}
var _ server.Server = (*httpServer)(nil)
func NewServer(
ctx context.Context, cancel context.CancelFunc, name string, config server.ServerConfiguration,
handler http.Handler, logger *slog.Logger, caURL string,
handler http.Handler, logger *slog.Logger, certProvider atls.CertificateProvider,
) server.Server {
baseServer := server.NewBaseServer(ctx, cancel, name, config, logger)
hserver := &http.Server{Addr: baseServer.Address, Handler: handler}
var attestedTLS bool
if agentConfig, ok := config.(server.AgentConfig); ok && agentConfig.AttestedTLS {
if certProvider == nil {
logger.Error("Failed to create certificate provider")
} else {
attestedTLS = true
}
}
return &httpServer{
BaseServer: baseServer,
server: hserver,
caURL: caURL,
BaseServer: baseServer,
server: hserver,
certProvider: certProvider,
attestedTLSEnabled: attestedTLS,
}
}
@@ -66,22 +78,15 @@ func (s *httpServer) Stop() error {
if err := s.server.Shutdown(ctx); err != nil {
s.Logger.Error(fmt.Sprintf(
"%s service %s server error occurred during shutdown at %s: %s", s.Name, s.Protocol, s.Address, err))
return fmt.Errorf("%s service %s server error occurred during shutdown at %s: %w", s.Name, s.Protocol, s.Address, err)
}
s.Logger.Info(fmt.Sprintf("%s %s service shutdown of http at %s", s.Name, s.Protocol, s.Address))
return nil
}
func (s *httpServer) shouldUseAttestedTLS() bool {
cfg, ok := s.Config.(server.AgentConfig)
if !ok {
return false
}
return cfg.AttestedTLS && s.caURL != ""
return s.attestedTLSEnabled && s.certProvider != nil
}
func (s *httpServer) shouldUseRegularTLS() bool {
@@ -91,10 +96,11 @@ func (s *httpServer) shouldUseRegularTLS() bool {
func (s *httpServer) startWithAttestedTLS() error {
tlsConfig := &tls.Config{
ClientAuth: tls.NoClientCert,
GetCertificate: atls.GetCertificate(s.caURL, ""),
GetCertificate: s.certProvider.GetCertificate,
}
mtls, err := server.ConfigureCertificateAuthorities(tlsConfig, s.Config.GetBaseConfig().ServerCAFile, s.Config.GetBaseConfig().ClientCAFile)
baseConfig := s.Config.GetBaseConfig()
mtls, err := server.ConfigureCertificateAuthorities(tlsConfig, baseConfig.ServerCAFile, baseConfig.ClientCAFile)
if err != nil {
return fmt.Errorf("failed to configure certificate authorities: %w", err)
}
@@ -107,12 +113,12 @@ func (s *httpServer) startWithAttestedTLS() error {
s.Protocol = httpsProtocol
s.logAttestedTLSStart(mtls)
return s.listenAndServe(true)
}
func (s *httpServer) startWithRegularTLS() error {
tlsSetup, err := server.SetupRegularTLS(s.Config.GetBaseConfig().CertFile, s.Config.GetBaseConfig().KeyFile, s.Config.GetBaseConfig().ServerCAFile, s.Config.GetBaseConfig().ClientCAFile)
baseConfig := s.Config.GetBaseConfig()
tlsSetup, err := server.SetupRegularTLS(baseConfig.CertFile, baseConfig.KeyFile, baseConfig.ServerCAFile, baseConfig.ClientCAFile)
if err != nil {
return fmt.Errorf("failed to setup TLS: %w", err)
}
@@ -121,13 +127,11 @@ func (s *httpServer) startWithRegularTLS() error {
s.Protocol = httpsProtocol
s.logRegularTLSStart(tlsSetup.MTLS)
return s.listenAndServe(true)
}
func (s *httpServer) startWithoutTLS() error {
s.Logger.Info(fmt.Sprintf("%s service %s server listening at %s without TLS", s.Name, s.Protocol, s.Address))
return s.listenAndServe(false)
}
@@ -140,15 +144,15 @@ func (s *httpServer) logAttestedTLSStart(mtls bool) {
}
func (s *httpServer) logRegularTLSStart(mtls bool) {
baseConfig := s.Config.GetBaseConfig()
if mtls {
s.Logger.Info(fmt.Sprintf(
"%s service %s server listening at %s with TLS/mTLS cert %s , key %s and CAs %s, %s",
s.Name, s.Protocol, s.Address, s.Config.GetBaseConfig().CertFile, s.Config.GetBaseConfig().KeyFile,
s.Config.GetBaseConfig().ServerCAFile, s.Config.GetBaseConfig().ClientCAFile))
s.Name, s.Protocol, s.Address, baseConfig.CertFile, baseConfig.KeyFile,
baseConfig.ServerCAFile, baseConfig.ClientCAFile))
} else {
s.Logger.Info(
fmt.Sprintf("%s service %s server listening at %s with TLS cert %s and key %s",
s.Name, s.Protocol, s.Address, s.Config.GetBaseConfig().CertFile, s.Config.GetBaseConfig().KeyFile))
s.Logger.Info(fmt.Sprintf("%s service %s server listening at %s with TLS cert %s and key %s",
s.Name, s.Protocol, s.Address, baseConfig.CertFile, baseConfig.KeyFile))
}
}
+28 -28
View File
@@ -15,6 +15,8 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
"github.com/ultravioletrs/cocos/pkg/atls"
"github.com/ultravioletrs/cocos/pkg/atls/mocks"
"github.com/ultravioletrs/cocos/pkg/server"
)
@@ -64,57 +66,55 @@ func TestNewServer(t *testing.T) {
}
handler := &mockHandler{}
logger := slog.Default()
caURL := "https://ca.example.com"
server := NewServer(ctx, cancel, name, config, handler, logger, caURL)
server := NewServer(ctx, cancel, name, config, handler, logger, nil)
assert.NotNil(t, server)
httpSrv, ok := server.(*httpServer)
require.True(t, ok)
assert.Equal(t, caURL, httpSrv.caURL)
assert.NotNil(t, httpSrv.server)
assert.Equal(t, handler, httpSrv.server.Handler)
}
func TestHttpServer_shouldUseAttestedTLS(t *testing.T) {
mockCertProvider := new(mocks.CertificateProvider)
tests := []struct {
name string
config server.ServerConfiguration
caURL string
attestedTLS bool
expected bool
name string
config server.ServerConfiguration
expected bool
certProvider atls.CertificateProvider
}{
{
name: "should use attested TLS when config is AgentConfig and AttestedTLS is true and caURL is not empty",
name: "should use attested TLS when config is AgentConfig and AttestedTLS is true and certProvider is not empty",
config: server.AgentConfig{
AttestedTLS: true,
},
caURL: "https://ca.example.com",
expected: true,
certProvider: mockCertProvider,
expected: true,
},
{
name: "should not use attested TLS when caURL is empty",
name: "should not use attested TLS when certProvider is empty",
config: server.AgentConfig{
AttestedTLS: true,
},
caURL: "",
expected: false,
certProvider: nil,
expected: false,
},
{
name: "should not use attested TLS when AttestedTLS is false",
config: server.AgentConfig{
AttestedTLS: false,
},
caURL: "https://ca.example.com",
expected: false,
certProvider: mockCertProvider,
expected: false,
},
{
name: "should not use attested TLS when config is not AgentConfig",
config: &mockServerConfig{
baseConfig: &mockBaseConfig{},
},
caURL: "https://ca.example.com",
expected: false,
certProvider: mockCertProvider,
expected: false,
},
}
@@ -123,7 +123,7 @@ func TestHttpServer_shouldUseAttestedTLS(t *testing.T) {
ctx := context.Background()
cancel := func() {}
server := NewServer(ctx, cancel, "test", tt.config, &mockHandler{}, slog.Default(), tt.caURL)
server := NewServer(ctx, cancel, "test", tt.config, &mockHandler{}, slog.Default(), tt.certProvider)
httpSrv := server.(*httpServer)
result := httpSrv.shouldUseAttestedTLS()
@@ -176,7 +176,7 @@ func TestHttpServer_shouldUseRegularTLS(t *testing.T) {
},
}
server := NewServer(ctx, cancel, "test", config, &mockHandler{}, slog.Default(), "")
server := NewServer(ctx, cancel, "test", config, &mockHandler{}, slog.Default(), nil)
httpSrv := server.(*httpServer)
result := httpSrv.shouldUseRegularTLS()
@@ -192,7 +192,7 @@ func TestHttpServer_Stop(t *testing.T) {
}
handler := &mockHandler{}
server := NewServer(ctx, cancel, "test-server", config, handler, slog.Default(), "")
server := NewServer(ctx, cancel, "test-server", config, handler, slog.Default(), nil)
httpSrv := server.(*httpServer)
// Start a test server that we can control
@@ -229,7 +229,7 @@ func TestHttpServer_logAttestedTLSStart(t *testing.T) {
baseConfig: &mockBaseConfig{},
}
server := NewServer(ctx, cancel, "test-server", config, &mockHandler{}, slog.Default(), "")
server := NewServer(ctx, cancel, "test-server", config, &mockHandler{}, slog.Default(), nil)
httpSrv := server.(*httpServer)
// This test mainly ensures the method doesn't panic
@@ -269,7 +269,7 @@ func TestHttpServer_logRegularTLSStart(t *testing.T) {
},
}
server := NewServer(ctx, cancel, "test-server", config, &mockHandler{}, slog.Default(), "")
server := NewServer(ctx, cancel, "test-server", config, &mockHandler{}, slog.Default(), nil)
httpSrv := server.(*httpServer)
// This test mainly ensures the method doesn't panic
@@ -289,7 +289,7 @@ func TestHttpServer_startWithoutTLS(t *testing.T) {
}
handler := &mockHandler{}
server := NewServer(ctx, cancel, "test-server", config, handler, slog.Default(), "")
server := NewServer(ctx, cancel, "test-server", config, handler, slog.Default(), nil)
httpSrv := server.(*httpServer)
// Use a test server to avoid binding to actual ports
@@ -334,7 +334,7 @@ func TestHttpServer_Protocol(t *testing.T) {
baseConfig: &mockBaseConfig{},
}
server := NewServer(ctx, cancel, "test-server", config, &mockHandler{}, slog.Default(), "")
server := NewServer(ctx, cancel, "test-server", config, &mockHandler{}, slog.Default(), nil)
httpSrv := server.(*httpServer)
tt.setupTLS(httpSrv)
@@ -351,7 +351,7 @@ func TestHttpServer_ContextCancellation(t *testing.T) {
}
handler := &mockHandler{}
server := NewServer(ctx, cancel, "test-server", config, handler, slog.Default(), "")
server := NewServer(ctx, cancel, "test-server", config, handler, slog.Default(), nil)
httpSrv := server.(*httpServer)
// Cancel the context immediately
@@ -372,7 +372,7 @@ func TestHttpServer_TLSConfiguration(t *testing.T) {
},
}
server := NewServer(ctx, cancel, "test-server", config, &mockHandler{}, slog.Default(), "")
server := NewServer(ctx, cancel, "test-server", config, &mockHandler{}, slog.Default(), nil)
httpSrv := server.(*httpServer)
// Test TLS configuration setup
@@ -397,7 +397,7 @@ func TestHttpServer_Lifecycle(t *testing.T) {
}
handler := &mockHandler{}
server := NewServer(ctx, cancel, "test-server", config, handler, slog.Default(), "")
server := NewServer(ctx, cancel, "test-server", config, handler, slog.Default(), nil)
// Test that server can be created and has expected initial state
httpSrv, ok := server.(*httpServer)
+1 -15
View File
@@ -40,8 +40,6 @@ var (
attestedTLSString string
attestedTLS bool
pubKeyFile string
caUrl string
cvmId string
clientCAFile string
)
@@ -108,8 +106,6 @@ func main() {
flagSet.StringVar(&pubKeyFile, "public-key-path", "", "Path to the public key file")
flagSet.StringVar(&attestedTLSString, "attested-tls-bool", "", "Should aTLS be used, must be 'true' or 'false'")
flagSet.StringVar(&dataPathString, "data-paths", "", "Paths to data sources, list of string separated with commas")
flagSet.StringVar(&caUrl, "ca-url", "", "URL for certificate authority, must be specified if aTLS is used")
flagSet.StringVar(&cvmId, "cvm-id", "", "UUID for a CVM, must be specified if aTLS is used")
flagSet.StringVar(&clientCAFile, "client-ca-file", "", "Client CA root certificate file path")
flagSetParseError := flagSet.Parse(os.Args[1:])
@@ -145,16 +141,6 @@ func main() {
dataPaths = strings.Split(dataPathString, ",")
}
if err == nil && caUrl != "" && !attestedTLS {
parsingErrorString.WriteString("CA URL is only available with attested TLS\n")
parsingError = true
}
if err == nil && cvmId != "" && !attestedTLS {
parsingErrorString.WriteString("CVM UUID is only available with attested TLS\n")
parsingError = true
}
if parsingError {
parsingErrorString.WriteString("Usage :\n")
flagSet.SetOutput(&parsingErrorString)
@@ -191,7 +177,7 @@ func main() {
return
}
gs := grpcserver.New(ctx, cancel, svcName, grpcServerConfig, registerAgentServiceServer, logger, nil, caUrl, cvmId)
gs := grpcserver.New(ctx, cancel, svcName, grpcServerConfig, registerAgentServiceServer, logger, nil, nil)
g.Go(func() error {
return gs.Start()