mirror of
https://github.com/ultravioletrs/cocos.git
synced 2026-06-23 04:10:25 +00:00
NOISSUE - Refactor aTLS and gRPC server to use CertificateProvider interface (#522)
* Refactor ATLS and gRPC server to use CertificateProvider interface - Removed unused test cases and mock dependencies in atls_test.go. - Updated TestGetPlatformVerifier to use CertificateVerifier struct. - Introduced CertificateProvider interface for better abstraction in TLS handling. - Refactored gRPC server to accept CertificateProvider and configure TLS accordingly. - Simplified TLS configuration logic in both gRPC and HTTP servers. - Removed unnecessary parameters from server initialization in tests and main function. - Enhanced logging for TLS configurations. Signed-off-by: Sammy Oina <sammyoina@gmail.com> * Fix comments for consistency and clarity in atls.go Signed-off-by: Sammy Oina <sammyoina@gmail.com> * Update expected error messages in VM command tests for clarity Signed-off-by: Sammy Oina <sammyoina@gmail.com> * Enhance tests by integrating mock providers and improving error messages for clarity Signed-off-by: Sammy Oina <sammyoina@gmail.com> * Add comprehensive tests for certificate generation and attestation providers Signed-off-by: Sammy Oina <sammyoina@gmail.com> * Implement certificate and attestation providers with unified generation logic Signed-off-by: Sammy Oina <sammyoina@gmail.com> * Refactor certificate and attestation provider structures for consistency; implement CertificateVerifier interface and related methods Signed-off-by: Sammy Oina <sammyoina@gmail.com> * Refactor attestation and certificate provider methods for consistency; rename methods and update related logic Signed-off-by: Sammy Oina <sammyoina@gmail.com> --------- Signed-off-by: Sammy Oina <sammyoina@gmail.com>
This commit is contained in:
committed by
GitHub
parent
906d7877b2
commit
c758b3b216
@@ -139,3 +139,10 @@ packages:
|
||||
dir: '{{.InterfaceDir}}/mocks'
|
||||
structname: '{{.InterfaceName}}'
|
||||
filename: "{{.InterfaceName | lower}}.go"
|
||||
github.com/ultravioletrs/cocos/pkg/atls:
|
||||
interfaces:
|
||||
CertificateProvider:
|
||||
config:
|
||||
dir: '{{.InterfaceDir}}/mocks'
|
||||
structname: '{{.InterfaceName}}'
|
||||
filename: "{{.InterfaceName | lower}}.go"
|
||||
|
||||
+12
-13
@@ -11,6 +11,7 @@ import (
|
||||
"github.com/ultravioletrs/cocos/agent"
|
||||
agentgrpc "github.com/ultravioletrs/cocos/agent/api/grpc"
|
||||
"github.com/ultravioletrs/cocos/agent/auth"
|
||||
"github.com/ultravioletrs/cocos/pkg/atls"
|
||||
"github.com/ultravioletrs/cocos/pkg/server"
|
||||
grpcserver "github.com/ultravioletrs/cocos/pkg/server/grpc"
|
||||
"google.golang.org/grpc"
|
||||
@@ -28,21 +29,19 @@ type AgentServer interface {
|
||||
}
|
||||
|
||||
type agentServer struct {
|
||||
gs server.Server
|
||||
logger *slog.Logger
|
||||
svc agent.Service
|
||||
host string
|
||||
caUrl string
|
||||
cvmId string
|
||||
gs server.Server
|
||||
logger *slog.Logger
|
||||
svc agent.Service
|
||||
host string
|
||||
certProvider atls.CertificateProvider
|
||||
}
|
||||
|
||||
func NewServer(logger *slog.Logger, svc agent.Service, host string, caUrl string, cvmId string) AgentServer {
|
||||
func NewServer(logger *slog.Logger, svc agent.Service, host string, certProvider atls.CertificateProvider) AgentServer {
|
||||
return &agentServer{
|
||||
logger: logger,
|
||||
svc: svc,
|
||||
host: host,
|
||||
caUrl: caUrl,
|
||||
cvmId: cvmId,
|
||||
logger: logger,
|
||||
svc: svc,
|
||||
host: host,
|
||||
certProvider: certProvider,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -78,7 +77,7 @@ func (as *agentServer) Start(cfg agent.AgentConfig, cmp agent.Computation) error
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
as.gs = grpcserver.New(ctx, cancel, svcName, agentGrpcServerConfig, registerAgentServiceServer, as.logger, authSvc, as.caUrl, as.cvmId)
|
||||
as.gs = grpcserver.New(ctx, cancel, svcName, agentGrpcServerConfig, registerAgentServiceServer, as.logger, authSvc, as.certProvider)
|
||||
|
||||
go func() {
|
||||
err := as.gs.Start()
|
||||
|
||||
@@ -18,12 +18,10 @@ import (
|
||||
"github.com/ultravioletrs/cocos/agent/mocks"
|
||||
)
|
||||
|
||||
func setupTest(t *testing.T) (*slog.Logger, *mocks.Service, string, string, string, []byte) {
|
||||
func setupTest(t *testing.T) (*slog.Logger, *mocks.Service, string, []byte) {
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelError}))
|
||||
mockSvc := new(mocks.Service)
|
||||
host := "localhost"
|
||||
caUrl := "https://ca.example.com"
|
||||
cvmId := "test-cvm-id"
|
||||
|
||||
privateKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||
assert.NoError(t, err, "Failed to generate ECDSA key")
|
||||
@@ -31,19 +29,17 @@ func setupTest(t *testing.T) (*slog.Logger, *mocks.Service, string, string, stri
|
||||
pubkey, err := x509.MarshalPKIXPublicKey(privateKey.Public())
|
||||
assert.NoError(t, err, "Failed to marshal public key")
|
||||
|
||||
return logger, mockSvc, host, caUrl, cvmId, pubkey
|
||||
return logger, mockSvc, host, pubkey
|
||||
}
|
||||
|
||||
func TestNewServer(t *testing.T) {
|
||||
logger, svc, host, caUrl, cvmId, _ := setupTest(t)
|
||||
logger, svc, host, _ := setupTest(t)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
logger *slog.Logger
|
||||
svc agent.Service
|
||||
host string
|
||||
caUrl string
|
||||
cvmId string
|
||||
expected AgentServer
|
||||
}{
|
||||
{
|
||||
@@ -51,38 +47,30 @@ func TestNewServer(t *testing.T) {
|
||||
logger: logger,
|
||||
svc: svc,
|
||||
host: host,
|
||||
caUrl: caUrl,
|
||||
cvmId: cvmId,
|
||||
},
|
||||
{
|
||||
name: "server with empty host",
|
||||
logger: logger,
|
||||
svc: svc,
|
||||
host: "",
|
||||
caUrl: caUrl,
|
||||
cvmId: cvmId,
|
||||
},
|
||||
{
|
||||
name: "server with empty caUrl",
|
||||
logger: logger,
|
||||
svc: svc,
|
||||
host: host,
|
||||
caUrl: "",
|
||||
cvmId: cvmId,
|
||||
},
|
||||
{
|
||||
name: "server with empty cvmId",
|
||||
logger: logger,
|
||||
svc: svc,
|
||||
host: host,
|
||||
caUrl: caUrl,
|
||||
cvmId: "",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
server := NewServer(tt.logger, tt.svc, tt.host, tt.caUrl, tt.cvmId)
|
||||
server := NewServer(tt.logger, tt.svc, tt.host, nil)
|
||||
|
||||
assert.NotNil(t, server)
|
||||
|
||||
@@ -91,14 +79,12 @@ func TestNewServer(t *testing.T) {
|
||||
assert.Equal(t, tt.logger, agentSrv.logger)
|
||||
assert.Equal(t, tt.svc, agentSrv.svc)
|
||||
assert.Equal(t, tt.host, agentSrv.host)
|
||||
assert.Equal(t, tt.caUrl, agentSrv.caUrl)
|
||||
assert.Equal(t, tt.cvmId, agentSrv.cvmId)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAgentServer_Start(t *testing.T) {
|
||||
logger, svc, host, caUrl, cvmId, pubKey := setupTest(t)
|
||||
logger, svc, host, pubKey := setupTest(t)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
@@ -211,7 +197,7 @@ func TestAgentServer_Start(t *testing.T) {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
tt.setupMocks(svc)
|
||||
|
||||
server := NewServer(logger, svc, host, caUrl, cvmId)
|
||||
server := NewServer(logger, svc, host, nil)
|
||||
|
||||
err := server.Start(tt.cfg, tt.cmp)
|
||||
|
||||
@@ -238,7 +224,7 @@ func TestAgentServer_Start(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestAgentServer_Stop(t *testing.T) {
|
||||
logger, svc, host, caUrl, cvmId, pubKey := setupTest(t)
|
||||
logger, svc, host, pubKey := setupTest(t)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
@@ -287,7 +273,7 @@ func TestAgentServer_Stop(t *testing.T) {
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
server := NewServer(logger, svc, host, caUrl, cvmId)
|
||||
server := NewServer(logger, svc, host, nil)
|
||||
|
||||
err := tt.setupServer(server)
|
||||
if err != nil {
|
||||
@@ -314,8 +300,8 @@ func TestAgentServer_Stop(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestAgentServer_StopMultipleTimes(t *testing.T) {
|
||||
logger, svc, host, caUrl, cvmId, pubKey := setupTest(t)
|
||||
server := NewServer(logger, svc, host, caUrl, cvmId)
|
||||
logger, svc, host, pubKey := setupTest(t)
|
||||
server := NewServer(logger, svc, host, nil)
|
||||
|
||||
// Start the server
|
||||
cfg := agent.AgentConfig{Port: "7005"}
|
||||
@@ -358,8 +344,8 @@ func TestAgentServer_StopMultipleTimes(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestAgentServer_StartAfterStop(t *testing.T) {
|
||||
logger, svc, host, caUrl, cvmId, pubKey := setupTest(t)
|
||||
server := NewServer(logger, svc, host, caUrl, cvmId)
|
||||
logger, svc, host, pubKey := setupTest(t)
|
||||
server := NewServer(logger, svc, host, nil)
|
||||
|
||||
cfg := agent.AgentConfig{Port: "7006"}
|
||||
cmp := agent.Computation{
|
||||
@@ -425,7 +411,7 @@ func TestAgentServer_StartAfterStop(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestAgentServer_ConfigValidation(t *testing.T) {
|
||||
logger, svc, host, caUrl, cvmId, pubKey := setupTest(t)
|
||||
logger, svc, host, pubKey := setupTest(t)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
@@ -512,7 +498,7 @@ func TestAgentServer_ConfigValidation(t *testing.T) {
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
server := NewServer(logger, svc, host, caUrl, cvmId)
|
||||
server := NewServer(logger, svc, host, nil)
|
||||
|
||||
err := server.Start(tt.config, tt.cmp)
|
||||
|
||||
|
||||
+2
-2
@@ -113,7 +113,7 @@ func TestCLI_NewCreateVMCmd(t *testing.T) {
|
||||
flags: map[string]string{
|
||||
"server-url": "https://server.com",
|
||||
},
|
||||
expectedError: "Failed to connect to manager: failed to connect to grpc server : failed to exit idle mode: passthrough: received empty target in Build() ❌",
|
||||
expectedError: "failed to exit idle mode: dns resolver: missing address ❌",
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
@@ -252,7 +252,7 @@ func TestCLI_NewRemoveVMCmd(t *testing.T) {
|
||||
cli.connectErr = errors.New("connection failed")
|
||||
},
|
||||
args: []string{"vm-123"},
|
||||
expectedError: "Failed to connect to manager: failed to connect to grpc server : failed to exit idle mode: passthrough: received empty target in Build() ❌",
|
||||
expectedError: "failed to exit idle mode: dns resolver: missing address ❌",
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
|
||||
+13
-1
@@ -26,6 +26,7 @@ import (
|
||||
"github.com/ultravioletrs/cocos/agent/cvms/server"
|
||||
"github.com/ultravioletrs/cocos/agent/events"
|
||||
agentlogger "github.com/ultravioletrs/cocos/internal/logger"
|
||||
"github.com/ultravioletrs/cocos/pkg/atls"
|
||||
"github.com/ultravioletrs/cocos/pkg/attestation"
|
||||
"github.com/ultravioletrs/cocos/pkg/attestation/azure"
|
||||
"github.com/ultravioletrs/cocos/pkg/attestation/tdx"
|
||||
@@ -163,7 +164,18 @@ func main() {
|
||||
return
|
||||
}
|
||||
|
||||
mc, err := cvmsapi.NewClient(pc, svc, eventsLogsQueue, logger, server.NewServer(logger, svc, cfg.AgentGrpcHost, cfg.CAUrl, cfg.CVMId), storageDir, reconnectFn, cvmGRPCClient)
|
||||
var certProvider atls.CertificateProvider
|
||||
|
||||
if ccPlatform != attestation.NoCC {
|
||||
certProvider, err = atls.NewProvider(provider, ccPlatform, cfg.CVMId, cfg.CAUrl)
|
||||
if err != nil {
|
||||
logger.Error(fmt.Sprintf("failed to create certificate provider: %s", err))
|
||||
exitCode = 1
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
mc, err := cvmsapi.NewClient(pc, svc, eventsLogsQueue, logger, server.NewServer(logger, svc, cfg.AgentGrpcHost, certProvider), storageDir, reconnectFn, cvmGRPCClient)
|
||||
if err != nil {
|
||||
logger.Error(err.Error())
|
||||
exitCode = 1
|
||||
|
||||
+1
-1
@@ -146,7 +146,7 @@ func main() {
|
||||
manager.RegisterManagerServiceServer(srv, managergrpc.NewServer(svc))
|
||||
}
|
||||
|
||||
gs := grpcserver.New(ctx, cancel, svcName, managerGRPCConfig, registerManagerServiceServer, logger, nil, "", "")
|
||||
gs := grpcserver.New(ctx, cancel, svcName, managerGRPCConfig, registerManagerServiceServer, logger, nil, nil)
|
||||
|
||||
hs := httpserver.NewServer(ctx, cancel, svcName, httpServerConfig, http.MakeHandler(chi.NewMux(), svcName, cfg.InstanceID), logger)
|
||||
|
||||
|
||||
+90
-303
@@ -5,18 +5,12 @@ package atls
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/ecdsa"
|
||||
"crypto/elliptic"
|
||||
"crypto/rand"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"crypto/x509/pkix"
|
||||
"encoding/asn1"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"encoding/pem"
|
||||
"fmt"
|
||||
"io"
|
||||
"math/big"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
@@ -26,356 +20,149 @@ import (
|
||||
certscli "github.com/absmach/certs/cli"
|
||||
"github.com/absmach/certs/errors"
|
||||
certssdk "github.com/absmach/certs/sdk"
|
||||
"github.com/ultravioletrs/cocos/pkg/attestation"
|
||||
"github.com/ultravioletrs/cocos/pkg/attestation/azure"
|
||||
"github.com/ultravioletrs/cocos/pkg/attestation/tdx"
|
||||
"github.com/ultravioletrs/cocos/pkg/attestation/vtpm"
|
||||
"golang.org/x/crypto/sha3"
|
||||
)
|
||||
|
||||
const (
|
||||
vmpl2 = 2
|
||||
organization = "Ultraviolet"
|
||||
country = "Serbia"
|
||||
province = ""
|
||||
locality = "Belgrade"
|
||||
streetAddress = "Bulevar Arsenija Carnojevica 103"
|
||||
postalCode = "11000"
|
||||
notAfterYear = 1
|
||||
notAfterMonth = 0
|
||||
notAfterDay = 0
|
||||
defaultNotAfterYears = 1
|
||||
nonceLength = 64
|
||||
nonceSuffix = ".nonce"
|
||||
)
|
||||
|
||||
// Platform-specific OIDs for certificate extensions.
|
||||
var (
|
||||
SNPvTPMOID = asn1.ObjectIdentifier{2, 99999, 1, 0}
|
||||
AzureOID = asn1.ObjectIdentifier{2, 99999, 1, 1}
|
||||
TDXOID = asn1.ObjectIdentifier{2, 99999, 1, 2}
|
||||
errCertificateParse = errors.New("failed to parse x509 certificate")
|
||||
errAttVerification = errors.New("certificate is not self signed")
|
||||
SNPvTPMOID = asn1.ObjectIdentifier{2, 99999, 1, 0}
|
||||
AzureOID = asn1.ObjectIdentifier{2, 99999, 1, 1}
|
||||
TDXOID = asn1.ObjectIdentifier{2, 99999, 1, 2}
|
||||
)
|
||||
|
||||
type csrReq struct {
|
||||
// CertificateSubject contains certificate subject information.
|
||||
type CertificateSubject struct {
|
||||
Organization string
|
||||
Country string
|
||||
Province string
|
||||
Locality string
|
||||
StreetAddress string
|
||||
PostalCode string
|
||||
}
|
||||
|
||||
// DefaultCertificateSubject returns the default certificate subject for Ultraviolet.
|
||||
func DefaultCertificateSubject() CertificateSubject {
|
||||
return CertificateSubject{
|
||||
Organization: "Ultraviolet",
|
||||
Country: "Serbia",
|
||||
Province: "",
|
||||
Locality: "Belgrade",
|
||||
StreetAddress: "Bulevar Arsenija Carnojevica 103",
|
||||
PostalCode: "11000",
|
||||
}
|
||||
}
|
||||
|
||||
// CAClient handles communication with Certificate Authority.
|
||||
type CAClient struct {
|
||||
baseURL string
|
||||
client *http.Client
|
||||
}
|
||||
|
||||
type CSRRequest struct {
|
||||
CSR string `json:"csr,omitempty"`
|
||||
}
|
||||
|
||||
func getPlatformProvider(platformType attestation.PlatformType) (attestation.Provider, error) {
|
||||
switch platformType {
|
||||
case attestation.SNPvTPM:
|
||||
return vtpm.NewProvider(true, vmpl2), nil
|
||||
case attestation.Azure:
|
||||
return azure.NewProvider(), nil
|
||||
case attestation.TDX:
|
||||
return tdx.NewProvider(), nil
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported platform type: %d", platformType)
|
||||
func NewCAClient(baseURL string) *CAClient {
|
||||
return &CAClient{
|
||||
baseURL: baseURL,
|
||||
client: &http.Client{},
|
||||
}
|
||||
}
|
||||
|
||||
func getPlatformVerifier(platformType attestation.PlatformType) (attestation.Verifier, error) {
|
||||
var verifier attestation.Verifier
|
||||
|
||||
switch platformType {
|
||||
case attestation.SNPvTPM:
|
||||
verifier = vtpm.NewVerifier(nil)
|
||||
case attestation.Azure:
|
||||
verifier = azure.NewVerifier(nil)
|
||||
case attestation.TDX:
|
||||
verifier = tdx.NewVerifier()
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported platform type: %d", platformType)
|
||||
func (c *CAClient) RequestCertificate(csrMetadata certs.CSRMetadata, privateKey *ecdsa.PrivateKey, cvmID string, ttl time.Duration) ([]byte, error) {
|
||||
csr, sdkerr := certscli.CreateCSR(csrMetadata, privateKey)
|
||||
if sdkerr != nil {
|
||||
return nil, fmt.Errorf("failed to create CSR: %w", sdkerr)
|
||||
}
|
||||
|
||||
err := verifier.JSONToPolicy(attestation.AttestationPolicyPath)
|
||||
request := CSRRequest{CSR: string(csr.CSR)}
|
||||
requestData, err := json.Marshal(request)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, fmt.Errorf("failed to marshal CSR request: %w", err)
|
||||
}
|
||||
return verifier, nil
|
||||
}
|
||||
|
||||
func getOID(platformType attestation.PlatformType) (asn1.ObjectIdentifier, error) {
|
||||
switch platformType {
|
||||
case attestation.SNPvTPM:
|
||||
return SNPvTPMOID, nil
|
||||
case attestation.Azure:
|
||||
return AzureOID, nil
|
||||
case attestation.TDX:
|
||||
return TDXOID, nil
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported platform type: %d", platformType)
|
||||
}
|
||||
}
|
||||
endpoint := fmt.Sprintf("certs/csrs/%s", cvmID)
|
||||
query := url.Values{}
|
||||
query.Add("ttl", ttl.String())
|
||||
requestURL := fmt.Sprintf("%s/%s?%s", c.baseURL, endpoint, query.Encode())
|
||||
|
||||
func getPlatformTypeFromOID(oid asn1.ObjectIdentifier) (attestation.PlatformType, error) {
|
||||
switch {
|
||||
case oid.Equal(SNPvTPMOID):
|
||||
return attestation.SNPvTPM, nil
|
||||
case oid.Equal(AzureOID):
|
||||
return attestation.Azure, nil
|
||||
case oid.Equal(TDXOID):
|
||||
return attestation.TDX, nil
|
||||
default:
|
||||
return 0, fmt.Errorf("unsupported OID: %v", oid)
|
||||
}
|
||||
}
|
||||
|
||||
func verifyCertificateExtension(extension []byte, pubKey []byte, nonce []byte, pType attestation.PlatformType) error {
|
||||
teeNonce := append(pubKey, nonce...)
|
||||
hashNonce := sha3.Sum512(teeNonce)
|
||||
|
||||
verifier, err := getPlatformVerifier(pType)
|
||||
_, responseBody, err := c.processRequest(http.MethodPost, requestURL, requestData, nil, http.StatusOK)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get platform verifier: %w", err)
|
||||
return nil, fmt.Errorf("failed to process CA request: %w", err)
|
||||
}
|
||||
|
||||
if err = verifier.VerifyAttestation(extension, hashNonce[:], hashNonce[:vtpm.Nonce]); err != nil {
|
||||
fmt.Printf("failed to verify attestation: %v\n", err)
|
||||
return err
|
||||
var cert certssdk.Certificate
|
||||
if err := json.Unmarshal(responseBody, &cert); err != nil {
|
||||
return nil, fmt.Errorf("failed to unmarshal certificate response: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
cleanCertificateString := strings.ReplaceAll(cert.Certificate, "\\n", "\n")
|
||||
block, rest := pem.Decode([]byte(cleanCertificateString))
|
||||
|
||||
if len(rest) != 0 {
|
||||
return nil, fmt.Errorf("failed to decode certificate PEM: unexpected remaining data")
|
||||
}
|
||||
if block == nil {
|
||||
return nil, fmt.Errorf("failed to decode certificate PEM: no PEM block found")
|
||||
}
|
||||
|
||||
return block.Bytes, nil
|
||||
}
|
||||
|
||||
func GetCertificate(caUrl string, cvmId string) func(*tls.ClientHelloInfo) (*tls.Certificate, error) {
|
||||
pType := attestation.CCPlatform()
|
||||
|
||||
provider, err := getPlatformProvider(pType)
|
||||
if err != nil {
|
||||
return func(*tls.ClientHelloInfo) (*tls.Certificate, error) {
|
||||
return nil, fmt.Errorf("failed to get platform provider: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
teeOid, err := getOID(pType)
|
||||
if err != nil {
|
||||
return func(*tls.ClientHelloInfo) (*tls.Certificate, error) {
|
||||
return nil, fmt.Errorf("failed to get OID for platform type: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return func(clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) {
|
||||
curve := elliptic.P256()
|
||||
|
||||
privateKey, err := ecdsa.GenerateKey(curve, rand.Reader)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to generate private/public key: %w", err)
|
||||
}
|
||||
|
||||
pubKeyDER, err := x509.MarshalPKIXPublicKey(&privateKey.PublicKey)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal public key to DER format: %w", err)
|
||||
}
|
||||
|
||||
sniLength := len(clientHello.ServerName)
|
||||
if sniLength < 7 || clientHello.ServerName[sniLength-6:] != ".nonce" {
|
||||
return nil, fmt.Errorf("invalid server name: %s", clientHello.ServerName)
|
||||
}
|
||||
|
||||
nonceStr := clientHello.ServerName[:sniLength-6]
|
||||
nonce, err := hex.DecodeString(nonceStr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to decode nonce from server name: %w", err)
|
||||
}
|
||||
|
||||
if len(nonce) != 64 {
|
||||
return nil, fmt.Errorf("invalid nonce length: expected 64 bytes, got %d bytes", len(nonce))
|
||||
}
|
||||
|
||||
attestExtension, err := getCertificateExtension(provider, pubKeyDER, nonce, teeOid)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get certificate extension: %w", err)
|
||||
}
|
||||
|
||||
var certDERBytes []byte
|
||||
|
||||
if caUrl == "" && cvmId == "" {
|
||||
certTemplate := &x509.Certificate{
|
||||
SerialNumber: big.NewInt(202403311),
|
||||
Subject: pkix.Name{
|
||||
Organization: []string{organization},
|
||||
Country: []string{country},
|
||||
Province: []string{province},
|
||||
Locality: []string{locality},
|
||||
StreetAddress: []string{streetAddress},
|
||||
PostalCode: []string{postalCode},
|
||||
},
|
||||
NotBefore: time.Now(),
|
||||
NotAfter: time.Now().AddDate(notAfterYear, notAfterMonth, notAfterDay),
|
||||
KeyUsage: x509.KeyUsageDigitalSignature,
|
||||
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
|
||||
BasicConstraintsValid: true,
|
||||
ExtraExtensions: []pkix.Extension{attestExtension},
|
||||
}
|
||||
|
||||
DERBytes, err := x509.CreateCertificate(rand.Reader, certTemplate, certTemplate, &privateKey.PublicKey, privateKey)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create certificate: %w", err)
|
||||
}
|
||||
|
||||
certDERBytes = DERBytes
|
||||
} else {
|
||||
csrmd := certs.CSRMetadata{
|
||||
Organization: []string{organization},
|
||||
Country: []string{country},
|
||||
Province: []string{province},
|
||||
Locality: []string{locality},
|
||||
StreetAddress: []string{streetAddress},
|
||||
PostalCode: []string{postalCode},
|
||||
ExtraExtensions: []pkix.Extension{attestExtension},
|
||||
}
|
||||
|
||||
csr, err := certscli.CreateCSR(csrmd, privateKey)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create CSR: %w", err)
|
||||
}
|
||||
|
||||
csrData := string(csr.CSR)
|
||||
|
||||
r := csrReq{
|
||||
CSR: csrData,
|
||||
}
|
||||
|
||||
data, sdkErr := json.Marshal(r)
|
||||
if sdkErr != nil {
|
||||
return nil, fmt.Errorf("failed to marshal CSR request: %w", sdkErr)
|
||||
}
|
||||
|
||||
notBefore := time.Now()
|
||||
notAfter := time.Now().AddDate(notAfterYear, notAfterMonth, notAfterDay)
|
||||
ttlString := notAfter.Sub(notBefore).String()
|
||||
|
||||
query := url.Values{}
|
||||
query.Add("ttl", ttlString)
|
||||
query_string := query.Encode()
|
||||
|
||||
certsEndpoint := "certs"
|
||||
csrEndpoint := "csrs"
|
||||
endpoint := fmt.Sprintf("%s/%s/%s", certsEndpoint, csrEndpoint, cvmId)
|
||||
|
||||
url := fmt.Sprintf("%s/%s?%s", caUrl, endpoint, query_string)
|
||||
|
||||
_, body, err := processRequest(http.MethodPost, url, data, nil, http.StatusOK)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to process request: %w", err)
|
||||
}
|
||||
|
||||
var cert certssdk.Certificate
|
||||
if err := json.Unmarshal(body, &cert); err != nil {
|
||||
return nil, fmt.Errorf("failed to unmarshal certificate response: %w", err)
|
||||
}
|
||||
|
||||
cleanCertificateString := strings.ReplaceAll(cert.Certificate, "\\n", "\n")
|
||||
|
||||
block, rest := pem.Decode([]byte(cleanCertificateString))
|
||||
|
||||
if len(rest) != 0 {
|
||||
return nil, fmt.Errorf("failed to convert generated certificate to DER format: %s", cleanCertificateString)
|
||||
}
|
||||
|
||||
certDERBytes = block.Bytes
|
||||
}
|
||||
|
||||
return &tls.Certificate{
|
||||
Certificate: [][]byte{certDERBytes},
|
||||
PrivateKey: privateKey,
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
|
||||
func getCertificateExtension(provider attestation.Provider, pubKey []byte, nonce []byte, teeOid asn1.ObjectIdentifier) (pkix.Extension, error) {
|
||||
teeNonce := append(pubKey, nonce...)
|
||||
hashNonce := sha3.Sum512(teeNonce)
|
||||
|
||||
rawAttestation, err := provider.Attestation(hashNonce[:], hashNonce[:vtpm.Nonce])
|
||||
if err != nil {
|
||||
return pkix.Extension{}, fmt.Errorf("failed to get attestation: %w", err)
|
||||
}
|
||||
|
||||
return pkix.Extension{
|
||||
Id: teeOid,
|
||||
Value: rawAttestation,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func processRequest(method, reqUrl string, data []byte, headers map[string]string, expectedRespCodes ...int) (http.Header, []byte, errors.SDKError) {
|
||||
req, err := http.NewRequest(method, reqUrl, bytes.NewReader(data))
|
||||
func (c *CAClient) processRequest(method, reqURL string, data []byte, headers map[string]string, expectedRespCodes ...int) (http.Header, []byte, errors.SDKError) {
|
||||
req, err := http.NewRequest(method, reqURL, bytes.NewReader(data))
|
||||
if err != nil {
|
||||
return make(http.Header), []byte{}, errors.NewSDKError(err)
|
||||
}
|
||||
|
||||
// Sets a default value for the Content-Type.
|
||||
// Overridden if Content-Type is passed in the headers arguments.
|
||||
req.Header.Add("Content-Type", "application/json")
|
||||
|
||||
for key, value := range headers {
|
||||
req.Header.Add(key, value)
|
||||
}
|
||||
|
||||
client := &http.Client{
|
||||
Transport: &http.Transport{
|
||||
TLSClientConfig: &tls.Config{
|
||||
InsecureSkipVerify: true,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
resp, err := client.Do(req)
|
||||
resp, err := c.client.Do(req)
|
||||
if err != nil {
|
||||
return make(http.Header), []byte{}, errors.NewSDKError(err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
sdkerr := errors.CheckError(resp, expectedRespCodes...)
|
||||
if sdkerr != nil {
|
||||
return make(http.Header), []byte{}, sdkerr
|
||||
|
||||
sdkErr := errors.CheckError(resp, expectedRespCodes...)
|
||||
if sdkErr != nil {
|
||||
return make(http.Header), []byte{}, sdkErr
|
||||
}
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return make(http.Header), []byte{}, errors.NewSDKError(err)
|
||||
}
|
||||
|
||||
return resp.Header, body, nil
|
||||
}
|
||||
|
||||
// VerifyPeerCertificateATLS verifies peer certificates for Attested TLS.
|
||||
func VerifyPeerCertificateATLS(rawCerts [][]byte, _ [][]*x509.Certificate, nonce []byte, rootCAs *x509.CertPool) error {
|
||||
cert, err := x509.ParseCertificate(rawCerts[0])
|
||||
func extractNonceFromSNI(serverName string) ([]byte, error) {
|
||||
if len(serverName) < len(nonceSuffix) || !hasNonceSuffix(serverName) {
|
||||
return nil, fmt.Errorf("invalid server name: %s", serverName)
|
||||
}
|
||||
|
||||
nonceStr := serverName[:len(serverName)-len(nonceSuffix)]
|
||||
nonce, err := hex.DecodeString(nonceStr)
|
||||
if err != nil {
|
||||
return errors.Wrap(errCertificateParse, err)
|
||||
return nil, fmt.Errorf("failed to decode nonce: %w", err)
|
||||
}
|
||||
|
||||
err = verifyCertificateSignature(cert, rootCAs)
|
||||
if err != nil {
|
||||
return errors.Wrap(errAttVerification, err)
|
||||
if len(nonce) != nonceLength {
|
||||
return nil, fmt.Errorf("invalid nonce length: expected %d bytes, got %d bytes", nonceLength, len(nonce))
|
||||
}
|
||||
|
||||
for _, ext := range cert.Extensions {
|
||||
pType, err := getPlatformTypeFromOID(ext.Id)
|
||||
if err == nil {
|
||||
pubKeyDER, err := x509.MarshalPKIXPublicKey(cert.PublicKey)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal public key to DER format: %w", err)
|
||||
}
|
||||
|
||||
return verifyCertificateExtension(ext.Value, pubKeyDER, nonce, pType)
|
||||
}
|
||||
}
|
||||
|
||||
return errors.New("attestation extension not found in certificate")
|
||||
return nonce, nil
|
||||
}
|
||||
|
||||
// VerifyCertificateSignature verifies the certificate signature against root CAs.
|
||||
func verifyCertificateSignature(cert *x509.Certificate, rootCAs *x509.CertPool) error {
|
||||
if rootCAs == nil {
|
||||
rootCAs = x509.NewCertPool()
|
||||
rootCAs.AddCert(cert)
|
||||
}
|
||||
|
||||
opts := x509.VerifyOptions{
|
||||
Roots: rootCAs,
|
||||
CurrentTime: time.Now(),
|
||||
}
|
||||
|
||||
if _, err := cert.Verify(opts); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
func hasNonceSuffix(serverName string) bool {
|
||||
return len(serverName) >= len(nonceSuffix) &&
|
||||
serverName[len(serverName)-len(nonceSuffix):] == nonceSuffix
|
||||
}
|
||||
|
||||
+610
-375
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,66 @@
|
||||
// Copyright (c) Ultraviolet
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
package atls
|
||||
|
||||
import (
|
||||
"encoding/asn1"
|
||||
"fmt"
|
||||
|
||||
"github.com/ultravioletrs/cocos/pkg/attestation"
|
||||
"golang.org/x/crypto/sha3"
|
||||
)
|
||||
|
||||
// AttestationProvider defines the interface for platform attestation operations.
|
||||
type AttestationProvider interface {
|
||||
Attest(pubKey []byte, nonce []byte) ([]byte, error)
|
||||
OID() asn1.ObjectIdentifier
|
||||
PlatformType() attestation.PlatformType
|
||||
}
|
||||
|
||||
// PlatformAttestationProvider handles platform attestation operations.
|
||||
type platformAttestationProvider struct {
|
||||
provider attestation.Provider
|
||||
oid asn1.ObjectIdentifier
|
||||
platformType attestation.PlatformType
|
||||
}
|
||||
|
||||
// NewAttestationProvider creates a new attestation provider for the given platform type.
|
||||
func NewAttestationProvider(provider attestation.Provider, platformType attestation.PlatformType) (AttestationProvider, error) {
|
||||
oid, err := OID(platformType)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get OID: %w", err)
|
||||
}
|
||||
|
||||
return &platformAttestationProvider{
|
||||
provider: provider,
|
||||
oid: oid,
|
||||
platformType: platformType,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (p *platformAttestationProvider) Attest(pubKey []byte, nonce []byte) ([]byte, error) {
|
||||
teeNonce := append(pubKey, nonce...)
|
||||
hashNonce := sha3.Sum512(teeNonce)
|
||||
return p.provider.Attestation(hashNonce[:], hashNonce[:32])
|
||||
}
|
||||
|
||||
func (p *platformAttestationProvider) OID() asn1.ObjectIdentifier {
|
||||
return p.oid
|
||||
}
|
||||
|
||||
func (p *platformAttestationProvider) PlatformType() attestation.PlatformType {
|
||||
return p.platformType
|
||||
}
|
||||
|
||||
func OID(platformType attestation.PlatformType) (asn1.ObjectIdentifier, error) {
|
||||
switch platformType {
|
||||
case attestation.SNPvTPM:
|
||||
return SNPvTPMOID, nil
|
||||
case attestation.Azure:
|
||||
return AzureOID, nil
|
||||
case attestation.TDX:
|
||||
return TDXOID, nil
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported platform type: %d", platformType)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,162 @@
|
||||
// Copyright (c) Ultraviolet
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
package atls
|
||||
|
||||
import (
|
||||
"crypto/ecdsa"
|
||||
"crypto/elliptic"
|
||||
"crypto/rand"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"crypto/x509/pkix"
|
||||
"fmt"
|
||||
"math/big"
|
||||
"time"
|
||||
|
||||
"github.com/absmach/certs"
|
||||
"github.com/ultravioletrs/cocos/pkg/attestation"
|
||||
)
|
||||
|
||||
// CertificateProvider defines the interface for providing TLS certificates.
|
||||
type CertificateProvider interface {
|
||||
GetCertificate(clientHello *tls.ClientHelloInfo) (*tls.Certificate, error)
|
||||
}
|
||||
|
||||
// AttestedCertificateProvider provides attested TLS certificates.
|
||||
type attestedCertificateProvider struct {
|
||||
attestationProvider AttestationProvider
|
||||
caClient *CAClient
|
||||
subject CertificateSubject
|
||||
useCA bool
|
||||
cvmID string
|
||||
ttl time.Duration
|
||||
notAfterYears int
|
||||
}
|
||||
|
||||
// NewAttestedProvider creates a new attested certificate provider for self-signed certificates.
|
||||
func NewAttestedProvider(
|
||||
attestationProvider AttestationProvider,
|
||||
subject CertificateSubject,
|
||||
) CertificateProvider {
|
||||
return &attestedCertificateProvider{
|
||||
attestationProvider: attestationProvider,
|
||||
subject: subject,
|
||||
useCA: false,
|
||||
notAfterYears: defaultNotAfterYears,
|
||||
}
|
||||
}
|
||||
|
||||
// NewAttestedCAProvider creates a new attested certificate provider for CA-signed certificates.
|
||||
func NewAttestedCAProvider(
|
||||
attestationProvider AttestationProvider,
|
||||
subject CertificateSubject,
|
||||
caURL, cvmID string,
|
||||
) CertificateProvider {
|
||||
return &attestedCertificateProvider{
|
||||
attestationProvider: attestationProvider,
|
||||
subject: subject,
|
||||
caClient: NewCAClient(caURL),
|
||||
useCA: true,
|
||||
cvmID: cvmID,
|
||||
ttl: time.Hour * 24 * 365, // Default 1 year
|
||||
}
|
||||
}
|
||||
|
||||
// SetTTL sets the certificate TTL for CA-signed certificates.
|
||||
func (p *attestedCertificateProvider) SetTTL(ttl time.Duration) {
|
||||
p.ttl = ttl
|
||||
}
|
||||
|
||||
func (p *attestedCertificateProvider) GetCertificate(clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) {
|
||||
privateKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to generate private key: %w", err)
|
||||
}
|
||||
|
||||
pubKeyDER, err := x509.MarshalPKIXPublicKey(&privateKey.PublicKey)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal public key: %w", err)
|
||||
}
|
||||
|
||||
nonce, err := extractNonceFromSNI(clientHello.ServerName)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to extract nonce: %w", err)
|
||||
}
|
||||
|
||||
attestationData, err := p.attestationProvider.Attest(pubKeyDER, nonce)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get attestation: %w", err)
|
||||
}
|
||||
|
||||
extension := pkix.Extension{
|
||||
Id: p.attestationProvider.OID(),
|
||||
Value: attestationData,
|
||||
}
|
||||
|
||||
var certDERBytes []byte
|
||||
if p.useCA {
|
||||
certDERBytes, err = p.generateCASignedCertificate(privateKey, extension)
|
||||
} else {
|
||||
certDERBytes, err = p.generateSelfSignedCertificate(privateKey, extension)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to generate certificate: %w", err)
|
||||
}
|
||||
|
||||
return &tls.Certificate{
|
||||
Certificate: [][]byte{certDERBytes},
|
||||
PrivateKey: privateKey,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (p *attestedCertificateProvider) generateSelfSignedCertificate(privateKey *ecdsa.PrivateKey, extension pkix.Extension) ([]byte, error) {
|
||||
certTemplate := &x509.Certificate{
|
||||
SerialNumber: big.NewInt(time.Now().Unix()),
|
||||
Subject: pkix.Name{
|
||||
Organization: []string{p.subject.Organization},
|
||||
Country: []string{p.subject.Country},
|
||||
Province: []string{p.subject.Province},
|
||||
Locality: []string{p.subject.Locality},
|
||||
StreetAddress: []string{p.subject.StreetAddress},
|
||||
PostalCode: []string{p.subject.PostalCode},
|
||||
},
|
||||
NotBefore: time.Now(),
|
||||
NotAfter: time.Now().AddDate(p.notAfterYears, 0, 0),
|
||||
KeyUsage: x509.KeyUsageDigitalSignature,
|
||||
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
|
||||
BasicConstraintsValid: true,
|
||||
ExtraExtensions: []pkix.Extension{extension},
|
||||
}
|
||||
|
||||
return x509.CreateCertificate(rand.Reader, certTemplate, certTemplate, &privateKey.PublicKey, privateKey)
|
||||
}
|
||||
|
||||
func (p *attestedCertificateProvider) generateCASignedCertificate(privateKey *ecdsa.PrivateKey, extension pkix.Extension) ([]byte, error) {
|
||||
csrMetadata := certs.CSRMetadata{
|
||||
Organization: []string{p.subject.Organization},
|
||||
Country: []string{p.subject.Country},
|
||||
Province: []string{p.subject.Province},
|
||||
Locality: []string{p.subject.Locality},
|
||||
StreetAddress: []string{p.subject.StreetAddress},
|
||||
PostalCode: []string{p.subject.PostalCode},
|
||||
ExtraExtensions: []pkix.Extension{extension},
|
||||
}
|
||||
|
||||
return p.caClient.RequestCertificate(csrMetadata, privateKey, p.cvmID, p.ttl)
|
||||
}
|
||||
|
||||
func NewProvider(provider attestation.Provider, platformType attestation.PlatformType, caURL, cvmID string) (CertificateProvider, error) {
|
||||
attestationProvider, err := NewAttestationProvider(provider, platformType)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create attestation provider: %w", err)
|
||||
}
|
||||
|
||||
subject := DefaultCertificateSubject()
|
||||
|
||||
if caURL != "" && cvmID != "" {
|
||||
return NewAttestedCAProvider(attestationProvider, subject, caURL, cvmID), nil
|
||||
}
|
||||
|
||||
return NewAttestedProvider(attestationProvider, subject), nil
|
||||
}
|
||||
@@ -0,0 +1,125 @@
|
||||
// Copyright (c) Ultraviolet
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
package atls
|
||||
|
||||
import (
|
||||
"crypto/x509"
|
||||
"encoding/asn1"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/ultravioletrs/cocos/pkg/attestation"
|
||||
"github.com/ultravioletrs/cocos/pkg/attestation/azure"
|
||||
"github.com/ultravioletrs/cocos/pkg/attestation/tdx"
|
||||
"github.com/ultravioletrs/cocos/pkg/attestation/vtpm"
|
||||
"golang.org/x/crypto/sha3"
|
||||
)
|
||||
|
||||
type CertificateVerifier interface {
|
||||
VerifyPeerCertificate(rawCerts [][]byte, verifiedChains [][]*x509.Certificate, nonce []byte) error
|
||||
}
|
||||
|
||||
// CertificateVerifier handles certificate verification operations.
|
||||
type certificateVerifier struct {
|
||||
rootCAs *x509.CertPool
|
||||
}
|
||||
|
||||
func NewCertificateVerifier(rootCAs *x509.CertPool) CertificateVerifier {
|
||||
return &certificateVerifier{rootCAs: rootCAs}
|
||||
}
|
||||
|
||||
func (v *certificateVerifier) VerifyPeerCertificate(rawCerts [][]byte, _ [][]*x509.Certificate, nonce []byte) error {
|
||||
if len(rawCerts) == 0 {
|
||||
return fmt.Errorf("no certificates provided")
|
||||
}
|
||||
|
||||
cert, err := x509.ParseCertificate(rawCerts[0])
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to parse x509 certificate: %w", err)
|
||||
}
|
||||
|
||||
if err := v.verifyCertificateSignature(cert); err != nil {
|
||||
return fmt.Errorf("certificate signature verification failed: %w", err)
|
||||
}
|
||||
|
||||
return v.verifyAttestationExtension(cert, nonce)
|
||||
}
|
||||
|
||||
func (v *certificateVerifier) verifyCertificateSignature(cert *x509.Certificate) error {
|
||||
rootCAs := v.rootCAs
|
||||
if rootCAs == nil {
|
||||
rootCAs = x509.NewCertPool()
|
||||
rootCAs.AddCert(cert)
|
||||
}
|
||||
|
||||
opts := x509.VerifyOptions{
|
||||
Roots: rootCAs,
|
||||
CurrentTime: time.Now(),
|
||||
}
|
||||
|
||||
_, err := cert.Verify(opts)
|
||||
return err
|
||||
}
|
||||
|
||||
func (v *certificateVerifier) verifyAttestationExtension(cert *x509.Certificate, nonce []byte) error {
|
||||
for _, ext := range cert.Extensions {
|
||||
if platformType, err := platformTypeFromOID(ext.Id); err == nil {
|
||||
pubKeyDER, err := x509.MarshalPKIXPublicKey(cert.PublicKey)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal public key: %w", err)
|
||||
}
|
||||
return v.verifyCertificateExtension(ext.Value, pubKeyDER, nonce, platformType)
|
||||
}
|
||||
}
|
||||
return fmt.Errorf("attestation extension not found in certificate")
|
||||
}
|
||||
|
||||
func (v *certificateVerifier) verifyCertificateExtension(extension []byte, pubKey []byte, nonce []byte, platformType attestation.PlatformType) error {
|
||||
verifier, err := platformVerifier(platformType)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get platform verifier: %w", err)
|
||||
}
|
||||
|
||||
teeNonce := append(pubKey, nonce...)
|
||||
hashNonce := sha3.Sum512(teeNonce)
|
||||
|
||||
if err = verifier.VerifyAttestation(extension, hashNonce[:], hashNonce[:32]); err != nil {
|
||||
return fmt.Errorf("failed to verify attestation: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func platformTypeFromOID(oid asn1.ObjectIdentifier) (attestation.PlatformType, error) {
|
||||
switch {
|
||||
case oid.Equal(SNPvTPMOID):
|
||||
return attestation.SNPvTPM, nil
|
||||
case oid.Equal(AzureOID):
|
||||
return attestation.Azure, nil
|
||||
case oid.Equal(TDXOID):
|
||||
return attestation.TDX, nil
|
||||
default:
|
||||
return 0, fmt.Errorf("unsupported OID: %v", oid)
|
||||
}
|
||||
}
|
||||
|
||||
func platformVerifier(platformType attestation.PlatformType) (attestation.Verifier, error) {
|
||||
var verifier attestation.Verifier
|
||||
|
||||
switch platformType {
|
||||
case attestation.SNPvTPM:
|
||||
verifier = vtpm.NewVerifier(nil)
|
||||
case attestation.Azure:
|
||||
verifier = azure.NewVerifier(nil)
|
||||
case attestation.TDX:
|
||||
verifier = tdx.NewVerifier()
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported platform type: %d", platformType)
|
||||
}
|
||||
|
||||
err := verifier.JSONToPolicy(attestation.AttestationPolicyPath)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return verifier, nil
|
||||
}
|
||||
@@ -0,0 +1,103 @@
|
||||
// Copyright (c) Ultraviolet
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
// Code generated by mockery; DO NOT EDIT.
|
||||
// github.com/vektra/mockery
|
||||
// template: testify
|
||||
|
||||
package mocks
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
|
||||
mock "github.com/stretchr/testify/mock"
|
||||
)
|
||||
|
||||
// NewCertificateProvider creates a new instance of CertificateProvider. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations.
|
||||
// The first argument is typically a *testing.T value.
|
||||
func NewCertificateProvider(t interface {
|
||||
mock.TestingT
|
||||
Cleanup(func())
|
||||
}) *CertificateProvider {
|
||||
mock := &CertificateProvider{}
|
||||
mock.Mock.Test(t)
|
||||
|
||||
t.Cleanup(func() { mock.AssertExpectations(t) })
|
||||
|
||||
return mock
|
||||
}
|
||||
|
||||
// CertificateProvider is an autogenerated mock type for the CertificateProvider type
|
||||
type CertificateProvider struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
type CertificateProvider_Expecter struct {
|
||||
mock *mock.Mock
|
||||
}
|
||||
|
||||
func (_m *CertificateProvider) EXPECT() *CertificateProvider_Expecter {
|
||||
return &CertificateProvider_Expecter{mock: &_m.Mock}
|
||||
}
|
||||
|
||||
// GetCertificate provides a mock function for the type CertificateProvider
|
||||
func (_mock *CertificateProvider) GetCertificate(clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) {
|
||||
ret := _mock.Called(clientHello)
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for GetCertificate")
|
||||
}
|
||||
|
||||
var r0 *tls.Certificate
|
||||
var r1 error
|
||||
if returnFunc, ok := ret.Get(0).(func(*tls.ClientHelloInfo) (*tls.Certificate, error)); ok {
|
||||
return returnFunc(clientHello)
|
||||
}
|
||||
if returnFunc, ok := ret.Get(0).(func(*tls.ClientHelloInfo) *tls.Certificate); ok {
|
||||
r0 = returnFunc(clientHello)
|
||||
} else {
|
||||
if ret.Get(0) != nil {
|
||||
r0 = ret.Get(0).(*tls.Certificate)
|
||||
}
|
||||
}
|
||||
if returnFunc, ok := ret.Get(1).(func(*tls.ClientHelloInfo) error); ok {
|
||||
r1 = returnFunc(clientHello)
|
||||
} else {
|
||||
r1 = ret.Error(1)
|
||||
}
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
// CertificateProvider_GetCertificate_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetCertificate'
|
||||
type CertificateProvider_GetCertificate_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// GetCertificate is a helper method to define mock.On call
|
||||
// - clientHello *tls.ClientHelloInfo
|
||||
func (_e *CertificateProvider_Expecter) GetCertificate(clientHello interface{}) *CertificateProvider_GetCertificate_Call {
|
||||
return &CertificateProvider_GetCertificate_Call{Call: _e.mock.On("GetCertificate", clientHello)}
|
||||
}
|
||||
|
||||
func (_c *CertificateProvider_GetCertificate_Call) Run(run func(clientHello *tls.ClientHelloInfo)) *CertificateProvider_GetCertificate_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
var arg0 *tls.ClientHelloInfo
|
||||
if args[0] != nil {
|
||||
arg0 = args[0].(*tls.ClientHelloInfo)
|
||||
}
|
||||
run(
|
||||
arg0,
|
||||
)
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *CertificateProvider_GetCertificate_Call) Return(certificate *tls.Certificate, err error) *CertificateProvider_GetCertificate_Call {
|
||||
_c.Call.Return(certificate, err)
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *CertificateProvider_GetCertificate_Call) RunAndReturn(run func(clientHello *tls.ClientHelloInfo) (*tls.Certificate, error)) *CertificateProvider_GetCertificate_Call {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
@@ -101,7 +101,7 @@ func TestAgentClientIntegration(t *testing.T) {
|
||||
Timeout: 1,
|
||||
},
|
||||
},
|
||||
err: errors.New("failed to connect to grpc server"),
|
||||
err: errors.New("agent service is unavailable"),
|
||||
},
|
||||
{
|
||||
name: "invalid config, missing AttestationPolicy with aTLS",
|
||||
@@ -127,7 +127,7 @@ func TestAgentClientIntegration(t *testing.T) {
|
||||
}
|
||||
|
||||
client, agentClient, err := NewAgentClient(ctx, tt.config)
|
||||
assert.True(t, errors.Contains(err, tt.err))
|
||||
assert.True(t, errors.Contains(err, tt.err), fmt.Sprintf("expected error to contain: %v, got: %v", tt.err, err))
|
||||
if err != nil {
|
||||
assert.Nil(t, client)
|
||||
assert.Nil(t, agentClient)
|
||||
|
||||
@@ -89,15 +89,6 @@ func TestAgentClientIntegration(t *testing.T) {
|
||||
},
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
name: "server not healthy",
|
||||
serverRunning: false,
|
||||
config: clients.StandardClientConfig{
|
||||
URL: "",
|
||||
Timeout: 1,
|
||||
},
|
||||
err: errors.New("failed to connect to grpc server"),
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
@@ -109,7 +100,7 @@ func TestAgentClientIntegration(t *testing.T) {
|
||||
}
|
||||
|
||||
client, agentClient, err := NewCVMClient(tt.config)
|
||||
assert.True(t, errors.Contains(err, tt.err))
|
||||
assert.True(t, errors.Contains(err, tt.err), fmt.Sprintf("expected error to contain: %v, got: %v", tt.err, err))
|
||||
if err != nil {
|
||||
assert.Nil(t, client)
|
||||
assert.Nil(t, agentClient)
|
||||
|
||||
+1
-1
@@ -137,7 +137,7 @@ func LoadATLSConfig(cfg AttestedClientConfig) (*TLSResult, error) {
|
||||
RootCAs: rootCAs,
|
||||
ServerName: sni,
|
||||
VerifyPeerCertificate: func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error {
|
||||
return atls.VerifyPeerCertificateATLS(rawCerts, verifiedChains, nonce, rootCAs)
|
||||
return atls.NewCertificateVerifier(rootCAs).VerifyPeerCertificate(rawCerts, verifiedChains, nonce)
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
+104
-71
@@ -25,39 +25,43 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
stopWaitTime = 5 * time.Second
|
||||
organization = "Ultraviolet"
|
||||
country = "Serbia"
|
||||
province = ""
|
||||
locality = "Belgrade"
|
||||
streetAddress = "Bulevar Arsenija Carnojevica 103"
|
||||
postalCode = "11000"
|
||||
notAfterYear = 1
|
||||
notAfterMonth = 0
|
||||
notAfterDay = 0
|
||||
nonceSize = 32
|
||||
stopWaitTime = 5 * time.Second
|
||||
)
|
||||
|
||||
type Server struct {
|
||||
server.BaseServer
|
||||
mu sync.RWMutex
|
||||
server *grpc.Server
|
||||
health *health.Server
|
||||
registerService serviceRegister
|
||||
authSvc auth.Authenticator
|
||||
caUrl string
|
||||
cvmId string
|
||||
started bool
|
||||
stopped bool
|
||||
mu sync.RWMutex
|
||||
server *grpc.Server
|
||||
health *health.Server
|
||||
registerService serviceRegister
|
||||
authSvc auth.Authenticator
|
||||
certProvider atls.CertificateProvider
|
||||
attestedTLSEnabled bool
|
||||
started bool
|
||||
stopped bool
|
||||
}
|
||||
|
||||
type serviceRegister func(srv *grpc.Server)
|
||||
|
||||
var _ server.Server = (*Server)(nil)
|
||||
|
||||
func New(ctx context.Context, cancel context.CancelFunc, name string, config server.ServerConfiguration, registerService serviceRegister, logger *slog.Logger, authSvc auth.Authenticator, caUrl string, cvmId string) server.Server {
|
||||
func New(
|
||||
ctx context.Context, cancel context.CancelFunc, name string, config server.ServerConfiguration,
|
||||
registerService serviceRegister, logger *slog.Logger, authSvc auth.Authenticator, certProvider atls.CertificateProvider,
|
||||
) server.Server {
|
||||
base := config.GetBaseConfig()
|
||||
listenFullAddress := fmt.Sprintf("%s:%s", base.Host, base.Port)
|
||||
|
||||
var attestedTLS bool
|
||||
|
||||
if agentConfig, ok := config.(server.AgentConfig); ok && agentConfig.AttestedTLS {
|
||||
if certProvider == nil {
|
||||
logger.Error("Failed to create certificate provider")
|
||||
} else {
|
||||
attestedTLS = true
|
||||
}
|
||||
}
|
||||
|
||||
return &Server{
|
||||
BaseServer: server.BaseServer{
|
||||
Ctx: ctx,
|
||||
@@ -67,10 +71,10 @@ func New(ctx context.Context, cancel context.CancelFunc, name string, config ser
|
||||
Config: config,
|
||||
Logger: logger,
|
||||
},
|
||||
registerService: registerService,
|
||||
authSvc: authSvc,
|
||||
caUrl: caUrl,
|
||||
cvmId: cvmId,
|
||||
registerService: registerService,
|
||||
authSvc: authSvc,
|
||||
certProvider: certProvider,
|
||||
attestedTLSEnabled: attestedTLS,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -92,65 +96,28 @@ func (s *Server) Start() error {
|
||||
grpc.StatsHandler(otelgrpc.NewServerHandler()),
|
||||
}
|
||||
|
||||
// Add authentication interceptors if auth service is available
|
||||
if s.authSvc != nil {
|
||||
unary, stream := agentgrpc.NewAuthInterceptor(s.authSvc)
|
||||
grpcServerOptions = append(grpcServerOptions, grpc.UnaryInterceptor(unary))
|
||||
grpcServerOptions = append(grpcServerOptions, grpc.StreamInterceptor(stream))
|
||||
}
|
||||
|
||||
creds := grpc.Creds(insecure.NewCredentials())
|
||||
|
||||
c := s.Config.GetBaseConfig()
|
||||
if agCfg, ok := s.Config.(server.AgentConfig); ok && agCfg.AttestedTLS {
|
||||
tlsConfig := &tls.Config{
|
||||
ClientAuth: tls.NoClientCert,
|
||||
GetCertificate: atls.GetCertificate(s.caUrl, s.cvmId),
|
||||
}
|
||||
|
||||
mtls, err := server.ConfigureCertificateAuthorities(tlsConfig, c.ServerCAFile, c.ClientCAFile)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to configure certificate authorities: %w", err)
|
||||
}
|
||||
|
||||
if mtls {
|
||||
tlsConfig.ClientAuth = tls.RequireAndVerifyClientCert
|
||||
}
|
||||
|
||||
creds = grpc.Creds(credentials.NewTLS(tlsConfig))
|
||||
|
||||
if mtls {
|
||||
s.Logger.Info(fmt.Sprintf("%s service gRPC server listening at %s with Attested mTLS", s.Name, s.Address))
|
||||
} else {
|
||||
s.Logger.Info(fmt.Sprintf("%s service gRPC server listening at %s with Attested TLS", s.Name, s.Address))
|
||||
}
|
||||
} else {
|
||||
switch {
|
||||
case c.CertFile != "" || c.KeyFile != "":
|
||||
tlsSetup, err := server.SetupRegularTLS(c.CertFile, c.KeyFile, c.ServerCAFile, c.ClientCAFile)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to setup TLS: %w", err)
|
||||
}
|
||||
|
||||
creds = grpc.Creds(credentials.NewTLS(tlsSetup.Config))
|
||||
|
||||
if tlsSetup.MTLS {
|
||||
mtlsCA := server.BuildMTLSDescription(c.ServerCAFile, c.ClientCAFile)
|
||||
s.Logger.Info(fmt.Sprintf("%s service gRPC server listening at %s with TLS/mTLS cert %s , key %s and %s", s.Name, s.Address, c.CertFile, c.KeyFile, mtlsCA))
|
||||
} else {
|
||||
s.Logger.Info(fmt.Sprintf("%s service gRPC server listening at %s with TLS cert %s and key %s", s.Name, s.Address, c.CertFile, c.KeyFile))
|
||||
}
|
||||
default:
|
||||
s.Logger.Info(fmt.Sprintf("%s service gRPC server listening at %s without TLS", s.Name, s.Address))
|
||||
}
|
||||
// Configure credentials
|
||||
creds, err := s.configureCredentials()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to configure credentials: %w", err)
|
||||
}
|
||||
|
||||
grpcServerOptions = append(grpcServerOptions, creds)
|
||||
|
||||
// Create listener
|
||||
listener, err := net.Listen("tcp", s.Address)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to listen on port %s: %w", s.Address, err)
|
||||
}
|
||||
|
||||
grpcServerOptions = append(grpcServerOptions, creds)
|
||||
|
||||
// Create and configure server
|
||||
s.mu.Lock()
|
||||
s.server = grpc.NewServer(grpcServerOptions...)
|
||||
s.health = health.NewServer()
|
||||
@@ -159,6 +126,7 @@ func (s *Server) Start() error {
|
||||
s.health.SetServingStatus(s.Name, grpchealth.HealthCheckResponse_SERVING)
|
||||
s.mu.Unlock()
|
||||
|
||||
// Start server
|
||||
go func() {
|
||||
s.mu.RLock()
|
||||
server := s.server
|
||||
@@ -178,6 +146,71 @@ func (s *Server) Start() error {
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) configureCredentials() (grpc.ServerOption, error) {
|
||||
baseConfig := s.Config.GetBaseConfig()
|
||||
|
||||
// Check if attested TLS should be used
|
||||
if s.shouldUseAttestedTLS() {
|
||||
return s.configureAttestedTLS(baseConfig.Config)
|
||||
}
|
||||
|
||||
// Check if regular TLS should be used
|
||||
if s.shouldUseRegularTLS(baseConfig.Config) {
|
||||
return s.configureRegularTLS(baseConfig.Config)
|
||||
}
|
||||
|
||||
// Use insecure credentials
|
||||
s.Logger.Info(fmt.Sprintf("%s service gRPC server listening at %s without TLS", s.Name, s.Address))
|
||||
return grpc.Creds(insecure.NewCredentials()), nil
|
||||
}
|
||||
|
||||
func (s *Server) shouldUseAttestedTLS() bool {
|
||||
return s.attestedTLSEnabled && s.certProvider != nil
|
||||
}
|
||||
|
||||
func (s *Server) shouldUseRegularTLS(config server.Config) bool {
|
||||
return config.CertFile != "" || config.KeyFile != ""
|
||||
}
|
||||
|
||||
func (s *Server) configureAttestedTLS(config server.Config) (grpc.ServerOption, error) {
|
||||
tlsConfig := &tls.Config{
|
||||
ClientAuth: tls.NoClientCert,
|
||||
GetCertificate: s.certProvider.GetCertificate,
|
||||
}
|
||||
|
||||
mtls, err := server.ConfigureCertificateAuthorities(tlsConfig, config.ServerCAFile, config.ClientCAFile)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to configure certificate authorities: %w", err)
|
||||
}
|
||||
|
||||
if mtls {
|
||||
tlsConfig.ClientAuth = tls.RequireAndVerifyClientCert
|
||||
s.Logger.Info(fmt.Sprintf("%s service gRPC server listening at %s with Attested mTLS", s.Name, s.Address))
|
||||
} else {
|
||||
s.Logger.Info(fmt.Sprintf("%s service gRPC server listening at %s with Attested TLS", s.Name, s.Address))
|
||||
}
|
||||
|
||||
return grpc.Creds(credentials.NewTLS(tlsConfig)), nil
|
||||
}
|
||||
|
||||
func (s *Server) configureRegularTLS(config server.Config) (grpc.ServerOption, error) {
|
||||
tlsSetup, err := server.SetupRegularTLS(config.CertFile, config.KeyFile, config.ServerCAFile, config.ClientCAFile)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to setup TLS: %w", err)
|
||||
}
|
||||
|
||||
if tlsSetup.MTLS {
|
||||
mtlsCA := server.BuildMTLSDescription(config.ServerCAFile, config.ClientCAFile)
|
||||
s.Logger.Info(fmt.Sprintf("%s service gRPC server listening at %s with TLS/mTLS cert %s , key %s and %s",
|
||||
s.Name, s.Address, config.CertFile, config.KeyFile, mtlsCA))
|
||||
} else {
|
||||
s.Logger.Info(fmt.Sprintf("%s service gRPC server listening at %s with TLS cert %s and key %s",
|
||||
s.Name, s.Address, config.CertFile, config.KeyFile))
|
||||
}
|
||||
|
||||
return grpc.Creds(credentials.NewTLS(tlsSetup.Config)), nil
|
||||
}
|
||||
|
||||
func (s *Server) Stop() error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
@@ -20,6 +20,7 @@ import (
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
authmocks "github.com/ultravioletrs/cocos/agent/auth/mocks"
|
||||
"github.com/ultravioletrs/cocos/pkg/atls/mocks"
|
||||
"github.com/ultravioletrs/cocos/pkg/attestation/vtpm"
|
||||
"github.com/ultravioletrs/cocos/pkg/server"
|
||||
"google.golang.org/grpc"
|
||||
@@ -49,7 +50,7 @@ func TestNew(t *testing.T) {
|
||||
logger := slog.Default()
|
||||
authSvc := new(authmocks.Authenticator)
|
||||
|
||||
srv := New(ctx, cancel, "TestServer", config, func(srv *grpc.Server) {}, logger, authSvc, "", "")
|
||||
srv := New(ctx, cancel, "TestServer", config, func(srv *grpc.Server) {}, logger, authSvc, nil)
|
||||
|
||||
assert.NotNil(t, srv)
|
||||
assert.IsType(t, &Server{}, srv)
|
||||
@@ -98,7 +99,7 @@ func TestServerStartWithTLSFile(t *testing.T) {
|
||||
logger := slog.New(slog.NewTextHandler(logBuffer, &slog.HandlerOptions{Level: slog.LevelDebug}))
|
||||
authSvc := new(authmocks.Authenticator)
|
||||
|
||||
srv := New(ctx, cancel, "TestServer", config, func(srv *grpc.Server) {}, logger, authSvc, "", "")
|
||||
srv := New(ctx, cancel, "TestServer", config, func(srv *grpc.Server) {}, logger, authSvc, nil)
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(1)
|
||||
@@ -144,7 +145,7 @@ func TestServerStartWithmTLSFile(t *testing.T) {
|
||||
logger := slog.New(slog.NewTextHandler(logBuffer, &slog.HandlerOptions{Level: slog.LevelDebug}))
|
||||
authSvc := new(authmocks.Authenticator)
|
||||
|
||||
srv := New(ctx, cancel, "TestServer", config, func(srv *grpc.Server) {}, logger, authSvc, "", "")
|
||||
srv := New(ctx, cancel, "TestServer", config, func(srv *grpc.Server) {}, logger, authSvc, nil)
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(1)
|
||||
@@ -183,7 +184,7 @@ func TestServerStop(t *testing.T) {
|
||||
logger := slog.New(slog.NewTextHandler(buf, &slog.HandlerOptions{Level: slog.LevelDebug}))
|
||||
authSvc := new(authmocks.Authenticator)
|
||||
|
||||
srv := New(ctx, cancel, "TestServer", config, func(srv *grpc.Server) {}, logger, authSvc, "", "")
|
||||
srv := New(ctx, cancel, "TestServer", config, func(srv *grpc.Server) {}, logger, authSvc, nil)
|
||||
|
||||
go func() {
|
||||
err := srv.Start()
|
||||
@@ -367,7 +368,9 @@ func TestServerInitializationAndStartup(t *testing.T) {
|
||||
logger := slog.New(slog.NewTextHandler(logBuffer, &slog.HandlerOptions{Level: slog.LevelDebug}))
|
||||
authSvc := new(authmocks.Authenticator)
|
||||
|
||||
srv := New(ctx, cancel, "TestServer", tc.config, func(srv *grpc.Server) {}, logger, authSvc, "", "")
|
||||
mockCertProvider := new(mocks.CertificateProvider)
|
||||
|
||||
srv := New(ctx, cancel, "TestServer", tc.config, func(srv *grpc.Server) {}, logger, authSvc, mockCertProvider)
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(1)
|
||||
|
||||
|
||||
+29
-25
@@ -23,23 +23,35 @@ const (
|
||||
type httpServer struct {
|
||||
server.BaseServer
|
||||
|
||||
server *http.Server
|
||||
caURL string
|
||||
server *http.Server
|
||||
certProvider atls.CertificateProvider
|
||||
attestedTLSEnabled bool
|
||||
}
|
||||
|
||||
var _ server.Server = (*httpServer)(nil)
|
||||
|
||||
func NewServer(
|
||||
ctx context.Context, cancel context.CancelFunc, name string, config server.ServerConfiguration,
|
||||
handler http.Handler, logger *slog.Logger, caURL string,
|
||||
handler http.Handler, logger *slog.Logger, certProvider atls.CertificateProvider,
|
||||
) server.Server {
|
||||
baseServer := server.NewBaseServer(ctx, cancel, name, config, logger)
|
||||
hserver := &http.Server{Addr: baseServer.Address, Handler: handler}
|
||||
|
||||
var attestedTLS bool
|
||||
|
||||
if agentConfig, ok := config.(server.AgentConfig); ok && agentConfig.AttestedTLS {
|
||||
if certProvider == nil {
|
||||
logger.Error("Failed to create certificate provider")
|
||||
} else {
|
||||
attestedTLS = true
|
||||
}
|
||||
}
|
||||
|
||||
return &httpServer{
|
||||
BaseServer: baseServer,
|
||||
server: hserver,
|
||||
caURL: caURL,
|
||||
BaseServer: baseServer,
|
||||
server: hserver,
|
||||
certProvider: certProvider,
|
||||
attestedTLSEnabled: attestedTLS,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -66,22 +78,15 @@ func (s *httpServer) Stop() error {
|
||||
if err := s.server.Shutdown(ctx); err != nil {
|
||||
s.Logger.Error(fmt.Sprintf(
|
||||
"%s service %s server error occurred during shutdown at %s: %s", s.Name, s.Protocol, s.Address, err))
|
||||
|
||||
return fmt.Errorf("%s service %s server error occurred during shutdown at %s: %w", s.Name, s.Protocol, s.Address, err)
|
||||
}
|
||||
|
||||
s.Logger.Info(fmt.Sprintf("%s %s service shutdown of http at %s", s.Name, s.Protocol, s.Address))
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *httpServer) shouldUseAttestedTLS() bool {
|
||||
cfg, ok := s.Config.(server.AgentConfig)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
|
||||
return cfg.AttestedTLS && s.caURL != ""
|
||||
return s.attestedTLSEnabled && s.certProvider != nil
|
||||
}
|
||||
|
||||
func (s *httpServer) shouldUseRegularTLS() bool {
|
||||
@@ -91,10 +96,11 @@ func (s *httpServer) shouldUseRegularTLS() bool {
|
||||
func (s *httpServer) startWithAttestedTLS() error {
|
||||
tlsConfig := &tls.Config{
|
||||
ClientAuth: tls.NoClientCert,
|
||||
GetCertificate: atls.GetCertificate(s.caURL, ""),
|
||||
GetCertificate: s.certProvider.GetCertificate,
|
||||
}
|
||||
|
||||
mtls, err := server.ConfigureCertificateAuthorities(tlsConfig, s.Config.GetBaseConfig().ServerCAFile, s.Config.GetBaseConfig().ClientCAFile)
|
||||
baseConfig := s.Config.GetBaseConfig()
|
||||
mtls, err := server.ConfigureCertificateAuthorities(tlsConfig, baseConfig.ServerCAFile, baseConfig.ClientCAFile)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to configure certificate authorities: %w", err)
|
||||
}
|
||||
@@ -107,12 +113,12 @@ func (s *httpServer) startWithAttestedTLS() error {
|
||||
s.Protocol = httpsProtocol
|
||||
|
||||
s.logAttestedTLSStart(mtls)
|
||||
|
||||
return s.listenAndServe(true)
|
||||
}
|
||||
|
||||
func (s *httpServer) startWithRegularTLS() error {
|
||||
tlsSetup, err := server.SetupRegularTLS(s.Config.GetBaseConfig().CertFile, s.Config.GetBaseConfig().KeyFile, s.Config.GetBaseConfig().ServerCAFile, s.Config.GetBaseConfig().ClientCAFile)
|
||||
baseConfig := s.Config.GetBaseConfig()
|
||||
tlsSetup, err := server.SetupRegularTLS(baseConfig.CertFile, baseConfig.KeyFile, baseConfig.ServerCAFile, baseConfig.ClientCAFile)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to setup TLS: %w", err)
|
||||
}
|
||||
@@ -121,13 +127,11 @@ func (s *httpServer) startWithRegularTLS() error {
|
||||
s.Protocol = httpsProtocol
|
||||
|
||||
s.logRegularTLSStart(tlsSetup.MTLS)
|
||||
|
||||
return s.listenAndServe(true)
|
||||
}
|
||||
|
||||
func (s *httpServer) startWithoutTLS() error {
|
||||
s.Logger.Info(fmt.Sprintf("%s service %s server listening at %s without TLS", s.Name, s.Protocol, s.Address))
|
||||
|
||||
return s.listenAndServe(false)
|
||||
}
|
||||
|
||||
@@ -140,15 +144,15 @@ func (s *httpServer) logAttestedTLSStart(mtls bool) {
|
||||
}
|
||||
|
||||
func (s *httpServer) logRegularTLSStart(mtls bool) {
|
||||
baseConfig := s.Config.GetBaseConfig()
|
||||
if mtls {
|
||||
s.Logger.Info(fmt.Sprintf(
|
||||
"%s service %s server listening at %s with TLS/mTLS cert %s , key %s and CAs %s, %s",
|
||||
s.Name, s.Protocol, s.Address, s.Config.GetBaseConfig().CertFile, s.Config.GetBaseConfig().KeyFile,
|
||||
s.Config.GetBaseConfig().ServerCAFile, s.Config.GetBaseConfig().ClientCAFile))
|
||||
s.Name, s.Protocol, s.Address, baseConfig.CertFile, baseConfig.KeyFile,
|
||||
baseConfig.ServerCAFile, baseConfig.ClientCAFile))
|
||||
} else {
|
||||
s.Logger.Info(
|
||||
fmt.Sprintf("%s service %s server listening at %s with TLS cert %s and key %s",
|
||||
s.Name, s.Protocol, s.Address, s.Config.GetBaseConfig().CertFile, s.Config.GetBaseConfig().KeyFile))
|
||||
s.Logger.Info(fmt.Sprintf("%s service %s server listening at %s with TLS cert %s and key %s",
|
||||
s.Name, s.Protocol, s.Address, baseConfig.CertFile, baseConfig.KeyFile))
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -15,6 +15,8 @@ import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/mock"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/ultravioletrs/cocos/pkg/atls"
|
||||
"github.com/ultravioletrs/cocos/pkg/atls/mocks"
|
||||
"github.com/ultravioletrs/cocos/pkg/server"
|
||||
)
|
||||
|
||||
@@ -64,57 +66,55 @@ func TestNewServer(t *testing.T) {
|
||||
}
|
||||
handler := &mockHandler{}
|
||||
logger := slog.Default()
|
||||
caURL := "https://ca.example.com"
|
||||
|
||||
server := NewServer(ctx, cancel, name, config, handler, logger, caURL)
|
||||
server := NewServer(ctx, cancel, name, config, handler, logger, nil)
|
||||
|
||||
assert.NotNil(t, server)
|
||||
httpSrv, ok := server.(*httpServer)
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, caURL, httpSrv.caURL)
|
||||
assert.NotNil(t, httpSrv.server)
|
||||
assert.Equal(t, handler, httpSrv.server.Handler)
|
||||
}
|
||||
|
||||
func TestHttpServer_shouldUseAttestedTLS(t *testing.T) {
|
||||
mockCertProvider := new(mocks.CertificateProvider)
|
||||
tests := []struct {
|
||||
name string
|
||||
config server.ServerConfiguration
|
||||
caURL string
|
||||
attestedTLS bool
|
||||
expected bool
|
||||
name string
|
||||
config server.ServerConfiguration
|
||||
expected bool
|
||||
certProvider atls.CertificateProvider
|
||||
}{
|
||||
{
|
||||
name: "should use attested TLS when config is AgentConfig and AttestedTLS is true and caURL is not empty",
|
||||
name: "should use attested TLS when config is AgentConfig and AttestedTLS is true and certProvider is not empty",
|
||||
config: server.AgentConfig{
|
||||
AttestedTLS: true,
|
||||
},
|
||||
caURL: "https://ca.example.com",
|
||||
expected: true,
|
||||
certProvider: mockCertProvider,
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "should not use attested TLS when caURL is empty",
|
||||
name: "should not use attested TLS when certProvider is empty",
|
||||
config: server.AgentConfig{
|
||||
AttestedTLS: true,
|
||||
},
|
||||
caURL: "",
|
||||
expected: false,
|
||||
certProvider: nil,
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "should not use attested TLS when AttestedTLS is false",
|
||||
config: server.AgentConfig{
|
||||
AttestedTLS: false,
|
||||
},
|
||||
caURL: "https://ca.example.com",
|
||||
expected: false,
|
||||
certProvider: mockCertProvider,
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "should not use attested TLS when config is not AgentConfig",
|
||||
config: &mockServerConfig{
|
||||
baseConfig: &mockBaseConfig{},
|
||||
},
|
||||
caURL: "https://ca.example.com",
|
||||
expected: false,
|
||||
certProvider: mockCertProvider,
|
||||
expected: false,
|
||||
},
|
||||
}
|
||||
|
||||
@@ -123,7 +123,7 @@ func TestHttpServer_shouldUseAttestedTLS(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
cancel := func() {}
|
||||
|
||||
server := NewServer(ctx, cancel, "test", tt.config, &mockHandler{}, slog.Default(), tt.caURL)
|
||||
server := NewServer(ctx, cancel, "test", tt.config, &mockHandler{}, slog.Default(), tt.certProvider)
|
||||
httpSrv := server.(*httpServer)
|
||||
|
||||
result := httpSrv.shouldUseAttestedTLS()
|
||||
@@ -176,7 +176,7 @@ func TestHttpServer_shouldUseRegularTLS(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
server := NewServer(ctx, cancel, "test", config, &mockHandler{}, slog.Default(), "")
|
||||
server := NewServer(ctx, cancel, "test", config, &mockHandler{}, slog.Default(), nil)
|
||||
httpSrv := server.(*httpServer)
|
||||
|
||||
result := httpSrv.shouldUseRegularTLS()
|
||||
@@ -192,7 +192,7 @@ func TestHttpServer_Stop(t *testing.T) {
|
||||
}
|
||||
handler := &mockHandler{}
|
||||
|
||||
server := NewServer(ctx, cancel, "test-server", config, handler, slog.Default(), "")
|
||||
server := NewServer(ctx, cancel, "test-server", config, handler, slog.Default(), nil)
|
||||
httpSrv := server.(*httpServer)
|
||||
|
||||
// Start a test server that we can control
|
||||
@@ -229,7 +229,7 @@ func TestHttpServer_logAttestedTLSStart(t *testing.T) {
|
||||
baseConfig: &mockBaseConfig{},
|
||||
}
|
||||
|
||||
server := NewServer(ctx, cancel, "test-server", config, &mockHandler{}, slog.Default(), "")
|
||||
server := NewServer(ctx, cancel, "test-server", config, &mockHandler{}, slog.Default(), nil)
|
||||
httpSrv := server.(*httpServer)
|
||||
|
||||
// This test mainly ensures the method doesn't panic
|
||||
@@ -269,7 +269,7 @@ func TestHttpServer_logRegularTLSStart(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
server := NewServer(ctx, cancel, "test-server", config, &mockHandler{}, slog.Default(), "")
|
||||
server := NewServer(ctx, cancel, "test-server", config, &mockHandler{}, slog.Default(), nil)
|
||||
httpSrv := server.(*httpServer)
|
||||
|
||||
// This test mainly ensures the method doesn't panic
|
||||
@@ -289,7 +289,7 @@ func TestHttpServer_startWithoutTLS(t *testing.T) {
|
||||
}
|
||||
handler := &mockHandler{}
|
||||
|
||||
server := NewServer(ctx, cancel, "test-server", config, handler, slog.Default(), "")
|
||||
server := NewServer(ctx, cancel, "test-server", config, handler, slog.Default(), nil)
|
||||
httpSrv := server.(*httpServer)
|
||||
|
||||
// Use a test server to avoid binding to actual ports
|
||||
@@ -334,7 +334,7 @@ func TestHttpServer_Protocol(t *testing.T) {
|
||||
baseConfig: &mockBaseConfig{},
|
||||
}
|
||||
|
||||
server := NewServer(ctx, cancel, "test-server", config, &mockHandler{}, slog.Default(), "")
|
||||
server := NewServer(ctx, cancel, "test-server", config, &mockHandler{}, slog.Default(), nil)
|
||||
httpSrv := server.(*httpServer)
|
||||
|
||||
tt.setupTLS(httpSrv)
|
||||
@@ -351,7 +351,7 @@ func TestHttpServer_ContextCancellation(t *testing.T) {
|
||||
}
|
||||
handler := &mockHandler{}
|
||||
|
||||
server := NewServer(ctx, cancel, "test-server", config, handler, slog.Default(), "")
|
||||
server := NewServer(ctx, cancel, "test-server", config, handler, slog.Default(), nil)
|
||||
httpSrv := server.(*httpServer)
|
||||
|
||||
// Cancel the context immediately
|
||||
@@ -372,7 +372,7 @@ func TestHttpServer_TLSConfiguration(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
server := NewServer(ctx, cancel, "test-server", config, &mockHandler{}, slog.Default(), "")
|
||||
server := NewServer(ctx, cancel, "test-server", config, &mockHandler{}, slog.Default(), nil)
|
||||
httpSrv := server.(*httpServer)
|
||||
|
||||
// Test TLS configuration setup
|
||||
@@ -397,7 +397,7 @@ func TestHttpServer_Lifecycle(t *testing.T) {
|
||||
}
|
||||
handler := &mockHandler{}
|
||||
|
||||
server := NewServer(ctx, cancel, "test-server", config, handler, slog.Default(), "")
|
||||
server := NewServer(ctx, cancel, "test-server", config, handler, slog.Default(), nil)
|
||||
|
||||
// Test that server can be created and has expected initial state
|
||||
httpSrv, ok := server.(*httpServer)
|
||||
|
||||
+1
-15
@@ -40,8 +40,6 @@ var (
|
||||
attestedTLSString string
|
||||
attestedTLS bool
|
||||
pubKeyFile string
|
||||
caUrl string
|
||||
cvmId string
|
||||
clientCAFile string
|
||||
)
|
||||
|
||||
@@ -108,8 +106,6 @@ func main() {
|
||||
flagSet.StringVar(&pubKeyFile, "public-key-path", "", "Path to the public key file")
|
||||
flagSet.StringVar(&attestedTLSString, "attested-tls-bool", "", "Should aTLS be used, must be 'true' or 'false'")
|
||||
flagSet.StringVar(&dataPathString, "data-paths", "", "Paths to data sources, list of string separated with commas")
|
||||
flagSet.StringVar(&caUrl, "ca-url", "", "URL for certificate authority, must be specified if aTLS is used")
|
||||
flagSet.StringVar(&cvmId, "cvm-id", "", "UUID for a CVM, must be specified if aTLS is used")
|
||||
flagSet.StringVar(&clientCAFile, "client-ca-file", "", "Client CA root certificate file path")
|
||||
|
||||
flagSetParseError := flagSet.Parse(os.Args[1:])
|
||||
@@ -145,16 +141,6 @@ func main() {
|
||||
dataPaths = strings.Split(dataPathString, ",")
|
||||
}
|
||||
|
||||
if err == nil && caUrl != "" && !attestedTLS {
|
||||
parsingErrorString.WriteString("CA URL is only available with attested TLS\n")
|
||||
parsingError = true
|
||||
}
|
||||
|
||||
if err == nil && cvmId != "" && !attestedTLS {
|
||||
parsingErrorString.WriteString("CVM UUID is only available with attested TLS\n")
|
||||
parsingError = true
|
||||
}
|
||||
|
||||
if parsingError {
|
||||
parsingErrorString.WriteString("Usage :\n")
|
||||
flagSet.SetOutput(&parsingErrorString)
|
||||
@@ -191,7 +177,7 @@ func main() {
|
||||
return
|
||||
}
|
||||
|
||||
gs := grpcserver.New(ctx, cancel, svcName, grpcServerConfig, registerAgentServiceServer, logger, nil, caUrl, cvmId)
|
||||
gs := grpcserver.New(ctx, cancel, svcName, grpcServerConfig, registerAgentServiceServer, logger, nil, nil)
|
||||
|
||||
g.Go(func() error {
|
||||
return gs.Start()
|
||||
|
||||
Reference in New Issue
Block a user