mirror of
https://github.com/ultravioletrs/cocos.git
synced 2026-06-23 04:10:25 +00:00
NOISSUE - Refactor http and grpc clients for reusability with Cube (#521)
* Implement gRPC server with TLS and mTLS support - Added gRPC server implementation in pkg/server/grpc. - Introduced server configuration options for TLS and mTLS. - Implemented health check service for gRPC. - Created tests for server initialization, startup, and shutdown scenarios. - Added mock server for testing purposes. - Implemented graceful shutdown handling for the server. - Included documentation for the server package. Signed-off-by: SammyOina <sammyoina@gmail.com> * Add TLS and ATLS support to gRPC and HTTP clients; refactor security handling Signed-off-by: SammyOina <sammyoina@gmail.com> * Refactor server configuration structure to use Config instead of BaseConfig Signed-off-by: SammyOina <sammyoina@gmail.com> * Fix comments for consistency and clarity in TLS-related code Signed-off-by: SammyOina <sammyoina@gmail.com> * Add comprehensive tests for TLS and ATLS configurations in clients package Signed-off-by: SammyOina <sammyoina@gmail.com> * Refactor file permission constants in client tests to use octal notation Signed-off-by: Sammy Oina <sammyoina@gmail.com> * Add tests for HTTP server's TLS configuration and lifecycle management Signed-off-by: Sammy Oina <sammyoina@gmail.com> * Add comprehensive tests for TLS certificate handling and configuration Signed-off-by: Sammy Oina <sammyoina@gmail.com> * Add comprehensive tests for HTTP client configuration and transport Signed-off-by: Sammy Oina <sammyoina@gmail.com> * Refactor AttestationReportSize constant declaration for clarity Signed-off-by: Sammy Oina <sammyoina@gmail.com> * Refactor client configuration structure and update gRPC client implementations - Consolidated client configuration types into a unified structure with BaseConfig. - Introduced AttestedClientConfig and StandardClientConfig for specific use cases. - Updated gRPC client creation functions to utilize new configuration types. - Refactored tests to align with the new configuration structure. - Removed redundant ClientConfiguration interface and related methods. - Simplified TLS configuration loading logic for both standard and attested clients. Signed-off-by: Sammy Oina <sammyoina@gmail.com> * Refactor client configuration structure and TLS handling - Introduced StandardClientConfig to replace BaseConfig, simplifying client configuration. - Updated AttestedClientConfig to embed StandardClientConfig instead of BaseConfig. - Modified ClientConfiguration interface to use Config() method instead of GetBaseConfig(). - Refactored various client tests to accommodate changes in configuration structure. - Added new TLS handling functions to support basic and attested TLS configurations. - Implemented comprehensive tests for TLS loading and configuration validation. - Removed deprecated methods and unnecessary code related to BaseConfig. Signed-off-by: Sammy Oina <sammyoina@gmail.com> --------- Signed-off-by: SammyOina <sammyoina@gmail.com> Signed-off-by: Sammy Oina <sammyoina@gmail.com>
This commit is contained in:
committed by
GitHub
parent
5377dd4d7f
commit
906d7877b2
+1
-1
@@ -68,7 +68,7 @@ packages:
|
||||
dir: '{{.InterfaceDir}}/mocks'
|
||||
structname: '{{.InterfaceName}}'
|
||||
filename: "{{.InterfaceName | lower}}.go"
|
||||
github.com/ultravioletrs/cocos/internal/server:
|
||||
github.com/ultravioletrs/cocos/pkg/server:
|
||||
interfaces:
|
||||
Server:
|
||||
config:
|
||||
|
||||
@@ -11,8 +11,8 @@ import (
|
||||
"github.com/ultravioletrs/cocos/agent"
|
||||
agentgrpc "github.com/ultravioletrs/cocos/agent/api/grpc"
|
||||
"github.com/ultravioletrs/cocos/agent/auth"
|
||||
"github.com/ultravioletrs/cocos/internal/server"
|
||||
grpcserver "github.com/ultravioletrs/cocos/internal/server/grpc"
|
||||
"github.com/ultravioletrs/cocos/pkg/server"
|
||||
grpcserver "github.com/ultravioletrs/cocos/pkg/server/grpc"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/reflection"
|
||||
)
|
||||
@@ -53,7 +53,7 @@ func (as *agentServer) Start(cfg agent.AgentConfig, cmp agent.Computation) error
|
||||
|
||||
agentGrpcServerConfig := server.AgentConfig{
|
||||
ServerConfig: server.ServerConfig{
|
||||
BaseConfig: server.BaseConfig{
|
||||
Config: server.Config{
|
||||
Host: as.host,
|
||||
Port: cfg.Port,
|
||||
CertFile: cfg.CertFile,
|
||||
|
||||
+4
-3
@@ -8,6 +8,7 @@ import (
|
||||
"github.com/spf13/cobra"
|
||||
"github.com/ultravioletrs/cocos/manager"
|
||||
"github.com/ultravioletrs/cocos/pkg/attestation/cmdconfig"
|
||||
"github.com/ultravioletrs/cocos/pkg/clients"
|
||||
"github.com/ultravioletrs/cocos/pkg/clients/grpc"
|
||||
"github.com/ultravioletrs/cocos/pkg/clients/grpc/agent"
|
||||
managergrpc "github.com/ultravioletrs/cocos/pkg/clients/grpc/manager"
|
||||
@@ -18,15 +19,15 @@ var Verbose bool
|
||||
|
||||
type CLI struct {
|
||||
agentSDK sdk.SDK
|
||||
agentConfig grpc.AgentClientConfig
|
||||
managerConfig grpc.ManagerClientConfig
|
||||
agentConfig clients.AttestedClientConfig
|
||||
managerConfig clients.StandardClientConfig
|
||||
client grpc.Client
|
||||
managerClient manager.ManagerServiceClient
|
||||
connectErr error
|
||||
measurement cmdconfig.MeasurementProvider
|
||||
}
|
||||
|
||||
func New(agentConfig grpc.AgentClientConfig, managerConfig grpc.ManagerClientConfig, measurement cmdconfig.MeasurementProvider) *CLI {
|
||||
func New(agentConfig clients.AttestedClientConfig, managerConfig clients.StandardClientConfig, measurement cmdconfig.MeasurementProvider) *CLI {
|
||||
return &CLI{
|
||||
agentConfig: agentConfig,
|
||||
managerConfig: managerConfig,
|
||||
|
||||
+2
-1
@@ -30,6 +30,7 @@ import (
|
||||
"github.com/ultravioletrs/cocos/pkg/attestation/azure"
|
||||
"github.com/ultravioletrs/cocos/pkg/attestation/tdx"
|
||||
"github.com/ultravioletrs/cocos/pkg/attestation/vtpm"
|
||||
"github.com/ultravioletrs/cocos/pkg/clients"
|
||||
pkggrpc "github.com/ultravioletrs/cocos/pkg/clients/grpc"
|
||||
cvmsgrpc "github.com/ultravioletrs/cocos/pkg/clients/grpc/cvm"
|
||||
"golang.org/x/sync/errgroup"
|
||||
@@ -111,7 +112,7 @@ func main() {
|
||||
provider = &attestation.EmptyProvider{}
|
||||
}
|
||||
|
||||
cvmGrpcConfig := pkggrpc.CVMClientConfig{}
|
||||
cvmGrpcConfig := clients.StandardClientConfig{}
|
||||
if err := env.ParseWithOptions(&cvmGrpcConfig, env.Options{Prefix: envPrefixCVMGRPC}); err != nil {
|
||||
logger.Error(fmt.Sprintf("failed to load %s gRPC client configuration : %s", svcName, err))
|
||||
exitCode = 1
|
||||
|
||||
+3
-3
@@ -15,7 +15,7 @@ import (
|
||||
"github.com/spf13/pflag"
|
||||
"github.com/ultravioletrs/cocos/cli"
|
||||
"github.com/ultravioletrs/cocos/pkg/attestation/cmdconfig"
|
||||
"github.com/ultravioletrs/cocos/pkg/clients/grpc"
|
||||
"github.com/ultravioletrs/cocos/pkg/clients"
|
||||
cmd "github.com/virtee/sev-snp-measure-go/sevsnpmeasure/cmd"
|
||||
)
|
||||
|
||||
@@ -94,14 +94,14 @@ func main() {
|
||||
return
|
||||
}
|
||||
|
||||
agentGRPCConfig := grpc.AgentClientConfig{}
|
||||
agentGRPCConfig := clients.AttestedClientConfig{}
|
||||
if err := env.ParseWithOptions(&agentGRPCConfig, env.Options{Prefix: envPrefixAgentGRPC}); err != nil {
|
||||
message := color.New(color.FgRed).Sprintf("failed to load %s gRPC client configuration : %s", svcName, err)
|
||||
rootCmd.Println(message)
|
||||
return
|
||||
}
|
||||
|
||||
managerGRPCConfig := grpc.ManagerClientConfig{}
|
||||
managerGRPCConfig := clients.StandardClientConfig{}
|
||||
if err := env.ParseWithOptions(&managerGRPCConfig, env.Options{Prefix: envPrefixManagerGRPC}); err != nil {
|
||||
message := color.New(color.FgRed).Sprintf("failed to load %s gRPC client configuration : %s", svcName, err)
|
||||
rootCmd.Println(message)
|
||||
|
||||
+2
-2
@@ -20,14 +20,14 @@ import (
|
||||
"github.com/absmach/supermq/pkg/uuid"
|
||||
"github.com/caarlos0/env/v11"
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/ultravioletrs/cocos/internal/server"
|
||||
grpcserver "github.com/ultravioletrs/cocos/internal/server/grpc"
|
||||
"github.com/ultravioletrs/cocos/manager"
|
||||
"github.com/ultravioletrs/cocos/manager/api"
|
||||
managergrpc "github.com/ultravioletrs/cocos/manager/api/grpc"
|
||||
"github.com/ultravioletrs/cocos/manager/api/http"
|
||||
"github.com/ultravioletrs/cocos/manager/qemu"
|
||||
"github.com/ultravioletrs/cocos/manager/tracing"
|
||||
"github.com/ultravioletrs/cocos/pkg/server"
|
||||
grpcserver "github.com/ultravioletrs/cocos/pkg/server/grpc"
|
||||
"go.opentelemetry.io/otel/trace"
|
||||
"golang.org/x/sync/errgroup"
|
||||
"google.golang.org/grpc"
|
||||
|
||||
+53
-5
@@ -47,9 +47,11 @@ const (
|
||||
)
|
||||
|
||||
var (
|
||||
SNPvTPMOID = asn1.ObjectIdentifier{2, 99999, 1, 0}
|
||||
AzureOID = asn1.ObjectIdentifier{2, 99999, 1, 1}
|
||||
TDXOID = asn1.ObjectIdentifier{2, 99999, 1, 2}
|
||||
SNPvTPMOID = asn1.ObjectIdentifier{2, 99999, 1, 0}
|
||||
AzureOID = asn1.ObjectIdentifier{2, 99999, 1, 1}
|
||||
TDXOID = asn1.ObjectIdentifier{2, 99999, 1, 2}
|
||||
errCertificateParse = errors.New("failed to parse x509 certificate")
|
||||
errAttVerification = errors.New("certificate is not self signed")
|
||||
)
|
||||
|
||||
type csrReq struct {
|
||||
@@ -103,7 +105,7 @@ func getOID(platformType attestation.PlatformType) (asn1.ObjectIdentifier, error
|
||||
}
|
||||
}
|
||||
|
||||
func GetPlatformTypeFromOID(oid asn1.ObjectIdentifier) (attestation.PlatformType, error) {
|
||||
func getPlatformTypeFromOID(oid asn1.ObjectIdentifier) (attestation.PlatformType, error) {
|
||||
switch {
|
||||
case oid.Equal(SNPvTPMOID):
|
||||
return attestation.SNPvTPM, nil
|
||||
@@ -116,7 +118,7 @@ func GetPlatformTypeFromOID(oid asn1.ObjectIdentifier) (attestation.PlatformType
|
||||
}
|
||||
}
|
||||
|
||||
func VerifyCertificateExtension(extension []byte, pubKey []byte, nonce []byte, pType attestation.PlatformType) error {
|
||||
func verifyCertificateExtension(extension []byte, pubKey []byte, nonce []byte, pType attestation.PlatformType) error {
|
||||
teeNonce := append(pubKey, nonce...)
|
||||
hashNonce := sha3.Sum512(teeNonce)
|
||||
|
||||
@@ -331,3 +333,49 @@ func processRequest(method, reqUrl string, data []byte, headers map[string]strin
|
||||
}
|
||||
return resp.Header, body, nil
|
||||
}
|
||||
|
||||
// VerifyPeerCertificateATLS verifies peer certificates for Attested TLS.
|
||||
func VerifyPeerCertificateATLS(rawCerts [][]byte, _ [][]*x509.Certificate, nonce []byte, rootCAs *x509.CertPool) error {
|
||||
cert, err := x509.ParseCertificate(rawCerts[0])
|
||||
if err != nil {
|
||||
return errors.Wrap(errCertificateParse, err)
|
||||
}
|
||||
|
||||
err = verifyCertificateSignature(cert, rootCAs)
|
||||
if err != nil {
|
||||
return errors.Wrap(errAttVerification, err)
|
||||
}
|
||||
|
||||
for _, ext := range cert.Extensions {
|
||||
pType, err := getPlatformTypeFromOID(ext.Id)
|
||||
if err == nil {
|
||||
pubKeyDER, err := x509.MarshalPKIXPublicKey(cert.PublicKey)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal public key to DER format: %w", err)
|
||||
}
|
||||
|
||||
return verifyCertificateExtension(ext.Value, pubKeyDER, nonce, pType)
|
||||
}
|
||||
}
|
||||
|
||||
return errors.New("attestation extension not found in certificate")
|
||||
}
|
||||
|
||||
// VerifyCertificateSignature verifies the certificate signature against root CAs.
|
||||
func verifyCertificateSignature(cert *x509.Certificate, rootCAs *x509.CertPool) error {
|
||||
if rootCAs == nil {
|
||||
rootCAs = x509.NewCertPool()
|
||||
rootCAs.AddCert(cert)
|
||||
}
|
||||
|
||||
opts := x509.VerifyOptions{
|
||||
Roots: rootCAs,
|
||||
CurrentTime: time.Now(),
|
||||
}
|
||||
|
||||
if _, err := cert.Verify(opts); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
+233
-2
@@ -6,17 +6,21 @@ import (
|
||||
"crypto/ecdsa"
|
||||
"crypto/elliptic"
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"crypto/x509/pkix"
|
||||
"encoding/asn1"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"math/big"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
certssdk "github.com/absmach/certs/sdk"
|
||||
"github.com/absmach/supermq/pkg/errors"
|
||||
@@ -217,7 +221,7 @@ func TestGetPlatformTypeFromOID(t *testing.T) {
|
||||
|
||||
for _, c := range cases {
|
||||
t.Run(c.name, func(t *testing.T) {
|
||||
pType, err := GetPlatformTypeFromOID(c.oid)
|
||||
pType, err := getPlatformTypeFromOID(c.oid)
|
||||
|
||||
if c.expectedError != nil {
|
||||
assert.Error(t, err)
|
||||
@@ -305,7 +309,7 @@ func TestVerifyCertificateExtension(t *testing.T) {
|
||||
|
||||
for _, c := range cases {
|
||||
t.Run(c.name, func(t *testing.T) {
|
||||
err := VerifyCertificateExtension(c.extension, c.pubKey, c.nonce, c.platformType)
|
||||
err := verifyCertificateExtension(c.extension, c.pubKey, c.nonce, c.platformType)
|
||||
if c.expectError {
|
||||
assert.Error(t, err)
|
||||
} else {
|
||||
@@ -622,3 +626,230 @@ func setAttestationPolicy(rr *sevsnp.Attestation, policyDirectory string) error
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// TestCertificateVerification unified test suite for certificate verification.
|
||||
func TestCertificateVerification(t *testing.T) {
|
||||
// Setup common test data
|
||||
selfSignedCert := createSelfSignedCert(t)
|
||||
leafCert, rootCert := generateCertificateChain(t)
|
||||
rootCAs := createCertPool(rootCert)
|
||||
emptyPool := x509.NewCertPool()
|
||||
|
||||
t.Run("SelfSignedCertificates", func(t *testing.T) {
|
||||
testCases := []testCase{
|
||||
{
|
||||
name: "ValidSelfSignedCertificate",
|
||||
cert: selfSignedCert,
|
||||
rootCAs: nil,
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "EmptyCertificate",
|
||||
cert: &x509.Certificate{},
|
||||
rootCAs: nil,
|
||||
expectError: true,
|
||||
errorMsg: "x509: missing ASN.1 contents; use ParseCertificate",
|
||||
},
|
||||
}
|
||||
|
||||
runCertificateVerificationTests(t, testCases)
|
||||
})
|
||||
|
||||
t.Run("CertificateChainVerification", func(t *testing.T) {
|
||||
testCases := []testCase{
|
||||
{
|
||||
name: "ValidCertificateWithRootCA",
|
||||
cert: leafCert,
|
||||
rootCAs: rootCAs,
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "SelfSignedCertificate",
|
||||
cert: rootCert,
|
||||
rootCAs: nil, // Self-signed verification
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "InvalidCertificateWithEmptyPool",
|
||||
cert: rootCert,
|
||||
rootCAs: emptyPool,
|
||||
expectError: true,
|
||||
},
|
||||
}
|
||||
|
||||
runCertificateVerificationTests(t, testCases)
|
||||
})
|
||||
|
||||
t.Run("ATLSPeerCertificateVerification", func(t *testing.T) {
|
||||
nonce := generateNonce(t)
|
||||
|
||||
testCases := []atlsTestCase{
|
||||
{
|
||||
name: "InvalidCertificateData",
|
||||
rawCerts: [][]byte{[]byte("invalid cert data")},
|
||||
nonce: nonce,
|
||||
rootCAs: rootCAs,
|
||||
expectError: true,
|
||||
errorMsg: "failed to parse x509 certificate",
|
||||
},
|
||||
{
|
||||
name: "ValidCertificateNoAttestationExtension",
|
||||
rawCerts: [][]byte{leafCert.Raw},
|
||||
nonce: nonce,
|
||||
rootCAs: rootCAs,
|
||||
expectError: true,
|
||||
errorMsg: "attestation extension not found in certificate",
|
||||
},
|
||||
}
|
||||
|
||||
runATLSVerificationTests(t, testCases)
|
||||
})
|
||||
}
|
||||
|
||||
// Unified test case structures.
|
||||
type testCase struct {
|
||||
name string
|
||||
cert *x509.Certificate
|
||||
rootCAs *x509.CertPool
|
||||
expectError bool
|
||||
errorMsg string
|
||||
}
|
||||
|
||||
type atlsTestCase struct {
|
||||
name string
|
||||
rawCerts [][]byte
|
||||
nonce []byte
|
||||
rootCAs *x509.CertPool
|
||||
expectError bool
|
||||
errorMsg string
|
||||
}
|
||||
|
||||
// Unified test runners.
|
||||
func runCertificateVerificationTests(t *testing.T, testCases []testCase) {
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
err := verifyCertificateSignature(tc.cert, tc.rootCAs)
|
||||
|
||||
if tc.expectError {
|
||||
assert.Error(t, err)
|
||||
if tc.errorMsg != "" {
|
||||
if tc.errorMsg == "x509: missing ASN.1 contents; use ParseCertificate" {
|
||||
// For specific error matching
|
||||
assert.True(t, errors.Contains(err, errors.New(tc.errorMsg)),
|
||||
fmt.Sprintf("expected error %q, got %v", tc.errorMsg, err))
|
||||
} else {
|
||||
assert.Contains(t, err.Error(), tc.errorMsg)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func runATLSVerificationTests(t *testing.T, testCases []atlsTestCase) {
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
err := VerifyPeerCertificateATLS(tc.rawCerts, nil, tc.nonce, tc.rootCAs)
|
||||
|
||||
if tc.expectError {
|
||||
assert.Error(t, err)
|
||||
if tc.errorMsg != "" {
|
||||
assert.Contains(t, err.Error(), tc.errorMsg)
|
||||
}
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Unified certificate creation utilities.
|
||||
func createSelfSignedCert(t *testing.T) *x509.Certificate {
|
||||
privateKey := generateRSAKey(t)
|
||||
|
||||
template := x509.Certificate{
|
||||
SerialNumber: big.NewInt(1),
|
||||
Subject: pkix.Name{
|
||||
Organization: []string{"Test Org"},
|
||||
},
|
||||
NotBefore: time.Now(),
|
||||
NotAfter: time.Now().Add(24 * time.Hour), // Consistent duration
|
||||
KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature,
|
||||
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
|
||||
BasicConstraintsValid: true,
|
||||
}
|
||||
|
||||
return createCertificateFromTemplate(t, &template, &template, &privateKey.PublicKey, privateKey)
|
||||
}
|
||||
|
||||
func generateCertificateChain(t *testing.T) (leafCert, rootCert *x509.Certificate) {
|
||||
// Generate root certificate
|
||||
rootKey := generateRSAKey(t)
|
||||
rootTemplate := x509.Certificate{
|
||||
SerialNumber: big.NewInt(1),
|
||||
Subject: pkix.Name{
|
||||
Organization: []string{"Test Root CA"},
|
||||
Country: []string{"US"},
|
||||
},
|
||||
NotBefore: time.Now(),
|
||||
NotAfter: time.Now().Add(24 * time.Hour),
|
||||
KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign,
|
||||
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
|
||||
BasicConstraintsValid: true,
|
||||
IsCA: true,
|
||||
}
|
||||
|
||||
rootCert = createCertificateFromTemplate(t, &rootTemplate, &rootTemplate, &rootKey.PublicKey, rootKey)
|
||||
|
||||
// Generate leaf certificate signed by root
|
||||
leafKey := generateRSAKey(t)
|
||||
leafTemplate := x509.Certificate{
|
||||
SerialNumber: big.NewInt(2),
|
||||
Subject: pkix.Name{
|
||||
Organization: []string{"Test Leaf"},
|
||||
Country: []string{"US"},
|
||||
},
|
||||
NotBefore: time.Now(),
|
||||
NotAfter: time.Now().Add(24 * time.Hour),
|
||||
KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature,
|
||||
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
|
||||
}
|
||||
|
||||
leafCert = createCertificateFromTemplate(t, &leafTemplate, &rootTemplate, &leafKey.PublicKey, rootKey)
|
||||
|
||||
return leafCert, rootCert
|
||||
}
|
||||
|
||||
// Helper functions for consistency.
|
||||
func generateRSAKey(t *testing.T) *rsa.PrivateKey {
|
||||
privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
require.NoError(t, err)
|
||||
return privateKey
|
||||
}
|
||||
|
||||
func createCertificateFromTemplate(t *testing.T, template, parent *x509.Certificate, pub interface{}, priv interface{}) *x509.Certificate {
|
||||
certDER, err := x509.CreateCertificate(rand.Reader, template, parent, pub, priv)
|
||||
require.NoError(t, err)
|
||||
|
||||
cert, err := x509.ParseCertificate(certDER)
|
||||
require.NoError(t, err)
|
||||
|
||||
return cert
|
||||
}
|
||||
|
||||
func createCertPool(certs ...*x509.Certificate) *x509.CertPool {
|
||||
pool := x509.NewCertPool()
|
||||
for _, cert := range certs {
|
||||
pool.AddCert(cert)
|
||||
}
|
||||
return pool
|
||||
}
|
||||
|
||||
func generateNonce(t *testing.T) []byte {
|
||||
nonce := make([]byte, 64)
|
||||
_, err := rand.Read(nonce)
|
||||
require.NoError(t, err)
|
||||
return nonce
|
||||
}
|
||||
|
||||
@@ -0,0 +1,40 @@
|
||||
// Copyright (c) Ultraviolet
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package clients
|
||||
|
||||
import "time"
|
||||
|
||||
var (
|
||||
_ ClientConfiguration = (*AttestedClientConfig)(nil)
|
||||
_ ClientConfiguration = (*StandardClientConfig)(nil)
|
||||
)
|
||||
|
||||
type ClientConfiguration interface {
|
||||
Config() StandardClientConfig
|
||||
}
|
||||
|
||||
// StandardClientConfig represents a basic client configuration without attested TLS.
|
||||
type StandardClientConfig struct {
|
||||
URL string `env:"URL" envDefault:"localhost:7001"`
|
||||
Timeout time.Duration `env:"TIMEOUT" envDefault:"60s"`
|
||||
ClientCert string `env:"CLIENT_CERT" envDefault:""`
|
||||
ClientKey string `env:"CLIENT_KEY" envDefault:""`
|
||||
ServerCAFile string `env:"SERVER_CA_CERTS" envDefault:""`
|
||||
}
|
||||
|
||||
// AttestedClientConfig represents a client configuration with attested TLS capabilities.
|
||||
type AttestedClientConfig struct {
|
||||
StandardClientConfig
|
||||
AttestationPolicy string `env:"ATTESTATION_POLICY" envDefault:""`
|
||||
AttestedTLS bool `env:"ATTESTED_TLS" envDefault:"false"`
|
||||
ProductName string `env:"PRODUCT_NAME" envDefault:"Milan"`
|
||||
}
|
||||
|
||||
func (c AttestedClientConfig) Config() StandardClientConfig {
|
||||
return c.StandardClientConfig
|
||||
}
|
||||
|
||||
func (c StandardClientConfig) Config() StandardClientConfig {
|
||||
return c
|
||||
}
|
||||
+1
-1
@@ -2,5 +2,5 @@
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
// Package clients contains the domain concept definitions needed to support
|
||||
// Agent Client functionality.
|
||||
// HTTP/gRPC Client functionality.
|
||||
package clients
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
|
||||
"github.com/absmach/supermq/pkg/errors"
|
||||
"github.com/ultravioletrs/cocos/agent"
|
||||
"github.com/ultravioletrs/cocos/pkg/clients"
|
||||
"github.com/ultravioletrs/cocos/pkg/clients/grpc"
|
||||
grpchealth "google.golang.org/grpc/health/grpc_health_v1"
|
||||
)
|
||||
@@ -14,13 +15,13 @@ import (
|
||||
var ErrAgentServiceUnavailable = errors.New("agent service is unavailable")
|
||||
|
||||
// NewAgentClient creates new agent gRPC client instance.
|
||||
func NewAgentClient(ctx context.Context, cfg grpc.AgentClientConfig) (grpc.Client, agent.AgentServiceClient, error) {
|
||||
func NewAgentClient(ctx context.Context, cfg clients.AttestedClientConfig) (grpc.Client, agent.AgentServiceClient, error) {
|
||||
client, err := grpc.NewClient(cfg)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
if client.Secure() != grpc.WithMATLS && client.Secure() != grpc.WithATLS && client.Secure() != grpc.WithTLS {
|
||||
if client.Secure() != clients.WithMATLS.String() && client.Secure() != clients.WithATLS.String() && client.Secure() != clients.WithTLS.String() {
|
||||
health := grpchealth.NewHealthClient(client.Connection())
|
||||
resp, err := health.Check(ctx, &grpchealth.HealthCheckRequest{
|
||||
Service: "agent",
|
||||
|
||||
@@ -15,7 +15,7 @@ import (
|
||||
"github.com/ultravioletrs/cocos/agent"
|
||||
agentgrpc "github.com/ultravioletrs/cocos/agent/api/grpc"
|
||||
"github.com/ultravioletrs/cocos/agent/mocks"
|
||||
pkggrpc "github.com/ultravioletrs/cocos/pkg/clients/grpc"
|
||||
"github.com/ultravioletrs/cocos/pkg/clients"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/health"
|
||||
grpchealth "google.golang.org/grpc/health/grpc_health_v1"
|
||||
@@ -78,14 +78,14 @@ func TestAgentClientIntegration(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
serverRunning bool
|
||||
config pkggrpc.AgentClientConfig
|
||||
config clients.AttestedClientConfig
|
||||
err error
|
||||
}{
|
||||
{
|
||||
name: "successful connection",
|
||||
serverRunning: true,
|
||||
config: pkggrpc.AgentClientConfig{
|
||||
BaseConfig: pkggrpc.BaseConfig{
|
||||
config: clients.AttestedClientConfig{
|
||||
StandardClientConfig: clients.StandardClientConfig{
|
||||
URL: testServer.listenAddr,
|
||||
Timeout: 1,
|
||||
},
|
||||
@@ -95,8 +95,8 @@ func TestAgentClientIntegration(t *testing.T) {
|
||||
{
|
||||
name: "server not healthy",
|
||||
serverRunning: false,
|
||||
config: pkggrpc.AgentClientConfig{
|
||||
BaseConfig: pkggrpc.BaseConfig{
|
||||
config: clients.AttestedClientConfig{
|
||||
StandardClientConfig: clients.StandardClientConfig{
|
||||
URL: "",
|
||||
Timeout: 1,
|
||||
},
|
||||
@@ -105,8 +105,8 @@ func TestAgentClientIntegration(t *testing.T) {
|
||||
},
|
||||
{
|
||||
name: "invalid config, missing AttestationPolicy with aTLS",
|
||||
config: pkggrpc.AgentClientConfig{
|
||||
BaseConfig: pkggrpc.BaseConfig{
|
||||
config: clients.AttestedClientConfig{
|
||||
StandardClientConfig: clients.StandardClientConfig{
|
||||
URL: testServer.listenAddr,
|
||||
Timeout: 1,
|
||||
},
|
||||
|
||||
@@ -1,133 +0,0 @@
|
||||
// Copyright (c) Ultraviolet
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package grpc
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"encoding/hex"
|
||||
"encoding/pem"
|
||||
"fmt"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"github.com/absmach/supermq/pkg/errors"
|
||||
"github.com/ultravioletrs/cocos/pkg/atls"
|
||||
"github.com/ultravioletrs/cocos/pkg/attestation"
|
||||
"google.golang.org/grpc/credentials"
|
||||
)
|
||||
|
||||
func setupATLS(cfg AgentClientConfig) (credentials.TransportCredentials, security, error) {
|
||||
security := withaTLS
|
||||
|
||||
info, err := os.Stat(cfg.AttestationPolicy)
|
||||
if err != nil {
|
||||
return nil, withoutTLS, errors.Wrap(fmt.Errorf("failed to stat attestation policy file"), err)
|
||||
}
|
||||
|
||||
if !info.Mode().IsRegular() {
|
||||
return nil, withoutTLS, fmt.Errorf("attestation policy file is not a regular file: %s", cfg.AttestationPolicy)
|
||||
}
|
||||
|
||||
attestation.AttestationPolicyPath = cfg.AttestationPolicy
|
||||
|
||||
var rootCAs *x509.CertPool = nil
|
||||
|
||||
if len(cfg.ServerCAFile) > 0 {
|
||||
// Read the certificate file
|
||||
certPEM, err := os.ReadFile(cfg.ServerCAFile)
|
||||
if err != nil {
|
||||
return nil, withoutTLS, errors.Wrap(fmt.Errorf("failed to read certificate file"), err)
|
||||
}
|
||||
|
||||
// Decode the PEM block
|
||||
block, _ := pem.Decode(certPEM)
|
||||
if block == nil {
|
||||
return nil, withoutTLS, fmt.Errorf("failed to decode PEM block")
|
||||
}
|
||||
|
||||
// Parse the certificate
|
||||
cert, err := x509.ParseCertificate(block.Bytes)
|
||||
if err != nil {
|
||||
return nil, withoutTLS, errors.Wrap(fmt.Errorf("failed to parse certificate"), err)
|
||||
}
|
||||
|
||||
rootCAs = x509.NewCertPool()
|
||||
rootCAs.AddCert(cert)
|
||||
|
||||
security = withmaTLS
|
||||
}
|
||||
|
||||
nonce := make([]byte, 64)
|
||||
if _, err := rand.Read(nonce); err != nil {
|
||||
return nil, withoutTLS, errors.Wrap(fmt.Errorf("failed to generate nonce"), err)
|
||||
}
|
||||
|
||||
encoded := hex.EncodeToString(nonce)
|
||||
sni := fmt.Sprintf("%s.nonce", encoded)
|
||||
|
||||
tlsConfig := &tls.Config{
|
||||
InsecureSkipVerify: true,
|
||||
RootCAs: rootCAs,
|
||||
ServerName: sni,
|
||||
VerifyPeerCertificate: func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error {
|
||||
return verifyPeerCertificateATLS(rawCerts, verifiedChains, nonce, rootCAs)
|
||||
},
|
||||
}
|
||||
|
||||
if cfg.ClientCert != "" || cfg.ClientKey != "" {
|
||||
certificate, err := tls.LoadX509KeyPair(cfg.ClientCert, cfg.ClientKey)
|
||||
if err != nil {
|
||||
return nil, withoutTLS, errors.Wrap(errFailedToLoadClientCertKey, err)
|
||||
}
|
||||
tlsConfig.Certificates = []tls.Certificate{certificate}
|
||||
}
|
||||
|
||||
return credentials.NewTLS(tlsConfig), security, nil
|
||||
}
|
||||
|
||||
func verifyPeerCertificateATLS(rawCerts [][]byte, verifiedChains [][]*x509.Certificate, nonce []byte, rootCAs *x509.CertPool) error {
|
||||
cert, err := x509.ParseCertificate(rawCerts[0])
|
||||
if err != nil {
|
||||
return errors.Wrap(errCertificateParse, err)
|
||||
}
|
||||
|
||||
err = checkIfCertificateSigned(cert, rootCAs)
|
||||
if err != nil {
|
||||
return errors.Wrap(errAttVerification, err)
|
||||
}
|
||||
|
||||
for _, ext := range cert.Extensions {
|
||||
pType, err := atls.GetPlatformTypeFromOID(ext.Id)
|
||||
if err == nil {
|
||||
pubKeyDER, err := x509.MarshalPKIXPublicKey(cert.PublicKey)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal public key to DER format: %w", err)
|
||||
}
|
||||
|
||||
return atls.VerifyCertificateExtension(ext.Value, pubKeyDER, nonce, pType)
|
||||
}
|
||||
}
|
||||
|
||||
return fmt.Errorf("attestation extension not found in certificate")
|
||||
}
|
||||
|
||||
func checkIfCertificateSigned(cert *x509.Certificate, rootCAs *x509.CertPool) error {
|
||||
if rootCAs == nil {
|
||||
rootCAs = x509.NewCertPool()
|
||||
rootCAs.AddCert(cert)
|
||||
}
|
||||
|
||||
opts := x509.VerifyOptions{
|
||||
Roots: rootCAs,
|
||||
CurrentTime: time.Now(),
|
||||
}
|
||||
|
||||
if _, err := cert.Verify(opts); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -21,6 +21,7 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/ultravioletrs/cocos/pkg/attestation"
|
||||
"github.com/ultravioletrs/cocos/pkg/attestation/vtpm"
|
||||
"github.com/ultravioletrs/cocos/pkg/clients"
|
||||
)
|
||||
|
||||
func TestNewClient(t *testing.T) {
|
||||
@@ -35,14 +36,14 @@ func TestNewClient(t *testing.T) {
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
cfg BaseConfig
|
||||
agentCfg AgentClientConfig
|
||||
cfg clients.StandardClientConfig
|
||||
agentCfg clients.AttestedClientConfig
|
||||
wantErr bool
|
||||
err error
|
||||
}{
|
||||
{
|
||||
name: "Success without TLS",
|
||||
cfg: BaseConfig{
|
||||
cfg: clients.StandardClientConfig{
|
||||
URL: "localhost:7001",
|
||||
},
|
||||
wantErr: false,
|
||||
@@ -50,7 +51,7 @@ func TestNewClient(t *testing.T) {
|
||||
},
|
||||
{
|
||||
name: "Success with TLS",
|
||||
cfg: BaseConfig{
|
||||
cfg: clients.StandardClientConfig{
|
||||
URL: "localhost:7001",
|
||||
ServerCAFile: caCertFile,
|
||||
},
|
||||
@@ -59,7 +60,7 @@ func TestNewClient(t *testing.T) {
|
||||
},
|
||||
{
|
||||
name: "Success with mTLS",
|
||||
cfg: BaseConfig{
|
||||
cfg: clients.StandardClientConfig{
|
||||
URL: "localhost:7001",
|
||||
ServerCAFile: caCertFile,
|
||||
ClientCert: clientCertFile,
|
||||
@@ -70,8 +71,8 @@ func TestNewClient(t *testing.T) {
|
||||
},
|
||||
{
|
||||
name: "Success agent client with mTLS",
|
||||
agentCfg: AgentClientConfig{
|
||||
BaseConfig: BaseConfig{
|
||||
agentCfg: clients.AttestedClientConfig{
|
||||
StandardClientConfig: clients.StandardClientConfig{
|
||||
URL: "localhost:7001",
|
||||
ServerCAFile: caCertFile,
|
||||
ClientCert: clientCertFile,
|
||||
@@ -83,8 +84,8 @@ func TestNewClient(t *testing.T) {
|
||||
},
|
||||
{
|
||||
name: "Success agent client with aTLS",
|
||||
agentCfg: AgentClientConfig{
|
||||
BaseConfig: BaseConfig{
|
||||
agentCfg: clients.AttestedClientConfig{
|
||||
StandardClientConfig: clients.StandardClientConfig{
|
||||
URL: "localhost:7001",
|
||||
ServerCAFile: caCertFile,
|
||||
ClientCert: clientCertFile,
|
||||
@@ -98,8 +99,8 @@ func TestNewClient(t *testing.T) {
|
||||
},
|
||||
{
|
||||
name: "Failed agent client with aTLS",
|
||||
agentCfg: AgentClientConfig{
|
||||
BaseConfig: BaseConfig{
|
||||
agentCfg: clients.AttestedClientConfig{
|
||||
StandardClientConfig: clients.StandardClientConfig{
|
||||
URL: "localhost:7001",
|
||||
ServerCAFile: caCertFile,
|
||||
ClientCert: clientCertFile,
|
||||
@@ -113,34 +114,34 @@ func TestNewClient(t *testing.T) {
|
||||
},
|
||||
{
|
||||
name: "Fail with invalid ServerCAFile",
|
||||
cfg: BaseConfig{
|
||||
cfg: clients.StandardClientConfig{
|
||||
URL: "localhost:7001",
|
||||
ServerCAFile: "nonexistent.pem",
|
||||
},
|
||||
wantErr: true,
|
||||
err: errFailedToLoadRootCA,
|
||||
err: clients.ErrFailedToLoadRootCA,
|
||||
},
|
||||
{
|
||||
name: "Fail with invalid ClientCert",
|
||||
cfg: BaseConfig{
|
||||
cfg: clients.StandardClientConfig{
|
||||
URL: "localhost:7001",
|
||||
ServerCAFile: caCertFile,
|
||||
ClientCert: "nonexistent.pem",
|
||||
ClientKey: clientKeyFile,
|
||||
},
|
||||
wantErr: true,
|
||||
err: errFailedToLoadClientCertKey,
|
||||
err: clients.ErrFailedToLoadClientCertKey,
|
||||
},
|
||||
{
|
||||
name: "Fail with invalid ClientKey",
|
||||
cfg: BaseConfig{
|
||||
cfg: clients.StandardClientConfig{
|
||||
URL: "localhost:7001",
|
||||
ServerCAFile: caCertFile,
|
||||
ClientCert: clientCertFile,
|
||||
ClientKey: "nonexistent.pem",
|
||||
},
|
||||
wantErr: true,
|
||||
err: errFailedToLoadClientCertKey,
|
||||
err: clients.ErrFailedToLoadClientCertKey,
|
||||
},
|
||||
}
|
||||
|
||||
@@ -168,39 +169,39 @@ func TestNewClient(t *testing.T) {
|
||||
func TestClientSecure(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
secure security
|
||||
secure clients.Security
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "Without TLS",
|
||||
secure: withoutTLS,
|
||||
secure: clients.WithoutTLS,
|
||||
expected: "without TLS",
|
||||
},
|
||||
{
|
||||
name: "With TLS",
|
||||
secure: withTLS,
|
||||
secure: clients.WithTLS,
|
||||
expected: "with TLS",
|
||||
},
|
||||
{
|
||||
name: "With mTLS",
|
||||
secure: withmTLS,
|
||||
secure: clients.WithMTLS,
|
||||
expected: "with mTLS",
|
||||
},
|
||||
{
|
||||
name: "With aTLS",
|
||||
secure: withaTLS,
|
||||
secure: clients.WithATLS,
|
||||
expected: "with aTLS",
|
||||
},
|
||||
{
|
||||
name: "With maTLS",
|
||||
secure: withmaTLS,
|
||||
expected: WithMATLS,
|
||||
secure: clients.WithMATLS,
|
||||
expected: "with maTLS",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
c := &client{secure: tt.secure}
|
||||
c := &client{security: tt.secure}
|
||||
assert.Equal(t, tt.expected, c.Secure())
|
||||
})
|
||||
}
|
||||
@@ -354,56 +355,3 @@ func createTempFile(data []byte) (string, error) {
|
||||
func createTempFileHandle() (*os.File, error) {
|
||||
return os.CreateTemp("", "test")
|
||||
}
|
||||
|
||||
func TestCheckIfCertificateSelfSigned(t *testing.T) {
|
||||
selfSignedCert := createSelfSignedCert(t)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
cert *x509.Certificate
|
||||
err error
|
||||
}{
|
||||
{
|
||||
name: "Self-signed certificate",
|
||||
cert: selfSignedCert,
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
name: "missing certificate contents",
|
||||
cert: &x509.Certificate{},
|
||||
err: errors.New("x509: missing ASN.1 contents; use ParseCertificate"),
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := checkIfCertificateSigned(tt.cert, nil)
|
||||
assert.True(t, errors.Contains(err, tt.err), fmt.Sprintf("expected error %v, got %v", tt.err, err))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func createSelfSignedCert(t *testing.T) *x509.Certificate {
|
||||
privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
require.NoError(t, err)
|
||||
|
||||
template := x509.Certificate{
|
||||
SerialNumber: big.NewInt(1),
|
||||
Subject: pkix.Name{
|
||||
Organization: []string{"Test Org"},
|
||||
},
|
||||
NotBefore: time.Now(),
|
||||
NotAfter: time.Now().Add(time.Hour * 24),
|
||||
KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature,
|
||||
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
|
||||
BasicConstraintsValid: true,
|
||||
}
|
||||
|
||||
certDER, err := x509.CreateCertificate(rand.Reader, &template, &template, &privateKey.PublicKey, privateKey)
|
||||
require.NoError(t, err)
|
||||
|
||||
cert, err := x509.ParseCertificate(certDER)
|
||||
require.NoError(t, err)
|
||||
|
||||
return cert
|
||||
}
|
||||
|
||||
@@ -4,11 +4,12 @@ package cvm
|
||||
|
||||
import (
|
||||
"github.com/ultravioletrs/cocos/agent/cvms"
|
||||
"github.com/ultravioletrs/cocos/pkg/clients"
|
||||
"github.com/ultravioletrs/cocos/pkg/clients/grpc"
|
||||
)
|
||||
|
||||
// NewManagerClient creates new manager gRPC client instance.
|
||||
func NewCVMClient(cfg grpc.CVMClientConfig) (grpc.Client, cvms.ServiceClient, error) {
|
||||
func NewCVMClient(cfg clients.StandardClientConfig) (grpc.Client, cvms.ServiceClient, error) {
|
||||
client, err := grpc.NewClient(cfg)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
|
||||
@@ -14,7 +14,7 @@ import (
|
||||
"github.com/ultravioletrs/cocos/agent"
|
||||
agentgrpc "github.com/ultravioletrs/cocos/agent/api/grpc"
|
||||
"github.com/ultravioletrs/cocos/agent/mocks"
|
||||
pkggrpc "github.com/ultravioletrs/cocos/pkg/clients/grpc"
|
||||
"github.com/ultravioletrs/cocos/pkg/clients"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/health"
|
||||
grpchealth "google.golang.org/grpc/health/grpc_health_v1"
|
||||
@@ -77,28 +77,24 @@ func TestAgentClientIntegration(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
serverRunning bool
|
||||
config pkggrpc.CVMClientConfig
|
||||
config clients.StandardClientConfig
|
||||
err error
|
||||
}{
|
||||
{
|
||||
name: "successful connection",
|
||||
serverRunning: true,
|
||||
config: pkggrpc.CVMClientConfig{
|
||||
BaseConfig: pkggrpc.BaseConfig{
|
||||
URL: testServer.listenAddr,
|
||||
Timeout: 1,
|
||||
},
|
||||
config: clients.StandardClientConfig{
|
||||
URL: testServer.listenAddr,
|
||||
Timeout: 1,
|
||||
},
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
name: "server not healthy",
|
||||
serverRunning: false,
|
||||
config: pkggrpc.CVMClientConfig{
|
||||
BaseConfig: pkggrpc.BaseConfig{
|
||||
URL: "",
|
||||
Timeout: 1,
|
||||
},
|
||||
config: clients.StandardClientConfig{
|
||||
URL: "",
|
||||
Timeout: 1,
|
||||
},
|
||||
err: errors.New("failed to connect to grpc server"),
|
||||
},
|
||||
|
||||
+29
-127
@@ -4,84 +4,19 @@
|
||||
package grpc
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"fmt"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"github.com/absmach/supermq/pkg/errors"
|
||||
"github.com/ultravioletrs/cocos/pkg/clients"
|
||||
"go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/credentials"
|
||||
"google.golang.org/grpc/credentials/insecure"
|
||||
)
|
||||
|
||||
type security int
|
||||
|
||||
const (
|
||||
withoutTLS security = iota
|
||||
withTLS
|
||||
withmTLS
|
||||
withaTLS
|
||||
withmaTLS
|
||||
)
|
||||
|
||||
const (
|
||||
AttestationReportSize = 0x4A0
|
||||
WithMATLS = "with maTLS"
|
||||
WithATLS = "with aTLS"
|
||||
WithTLS = "with TLS"
|
||||
)
|
||||
|
||||
var (
|
||||
errGrpcConnect = errors.New("failed to connect to grpc server")
|
||||
errGrpcClose = errors.New("failed to close grpc connection")
|
||||
errCertificateParse = errors.New("failed to parse x509 certificate")
|
||||
errAttVerification = errors.New("certificat is not self signed")
|
||||
errFailedToLoadClientCertKey = errors.New("failed to load client certificate and key")
|
||||
errFailedToLoadRootCA = errors.New("failed to load root ca file")
|
||||
errGrpcConnect = errors.New("failed to connect to grpc server")
|
||||
errGrpcClose = errors.New("failed to close grpc connection")
|
||||
)
|
||||
|
||||
type ClientConfiguration interface {
|
||||
GetBaseConfig() BaseConfig
|
||||
}
|
||||
|
||||
type BaseConfig struct {
|
||||
URL string `env:"URL" envDefault:"localhost:7001"`
|
||||
Timeout time.Duration `env:"TIMEOUT" envDefault:"60s"`
|
||||
ClientCert string `env:"CLIENT_CERT" envDefault:""`
|
||||
ClientKey string `env:"CLIENT_KEY" envDefault:""`
|
||||
ServerCAFile string `env:"SERVER_CA_CERTS" envDefault:""`
|
||||
}
|
||||
|
||||
type AgentClientConfig struct {
|
||||
BaseConfig
|
||||
AttestationPolicy string `env:"ATTESTATION_POLICY" envDefault:""`
|
||||
AttestedTLS bool `env:"ATTESTED_TLS" envDefault:"false"`
|
||||
ProductName string `env:"PRODUCT_NAME" envDefault:"Milan"`
|
||||
}
|
||||
|
||||
type ManagerClientConfig struct {
|
||||
BaseConfig
|
||||
}
|
||||
|
||||
type CVMClientConfig struct {
|
||||
BaseConfig
|
||||
}
|
||||
|
||||
func (a BaseConfig) GetBaseConfig() BaseConfig {
|
||||
return a
|
||||
}
|
||||
|
||||
func (a AgentClientConfig) GetBaseConfig() BaseConfig {
|
||||
return a.BaseConfig
|
||||
}
|
||||
|
||||
func (a CVMClientConfig) GetBaseConfig() BaseConfig {
|
||||
return a.BaseConfig
|
||||
}
|
||||
|
||||
type Client interface {
|
||||
Close() error
|
||||
Secure() string
|
||||
@@ -90,14 +25,14 @@ type Client interface {
|
||||
|
||||
type client struct {
|
||||
*grpc.ClientConn
|
||||
cfg ClientConfiguration
|
||||
secure security
|
||||
cfg clients.ClientConfiguration
|
||||
security clients.Security
|
||||
}
|
||||
|
||||
var _ Client = (*client)(nil)
|
||||
|
||||
func NewClient(cfg ClientConfiguration) (Client, error) {
|
||||
conn, secure, err := connect(cfg)
|
||||
func NewClient(cfg clients.ClientConfiguration) (Client, error) {
|
||||
conn, security, err := connect(cfg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -105,7 +40,7 @@ func NewClient(cfg ClientConfiguration) (Client, error) {
|
||||
return &client{
|
||||
ClientConn: conn,
|
||||
cfg: cfg,
|
||||
secure: secure,
|
||||
security: security,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -117,86 +52,53 @@ func (c *client) Close() error {
|
||||
}
|
||||
|
||||
func (c *client) Secure() string {
|
||||
switch c.secure {
|
||||
case withTLS:
|
||||
return WithTLS
|
||||
case withmTLS:
|
||||
return "with mTLS"
|
||||
case withaTLS:
|
||||
return "with aTLS"
|
||||
case withmaTLS:
|
||||
return WithMATLS
|
||||
default:
|
||||
return "without TLS"
|
||||
}
|
||||
return c.security.String()
|
||||
}
|
||||
|
||||
func (c *client) Connection() *grpc.ClientConn {
|
||||
return c.ClientConn
|
||||
}
|
||||
|
||||
func connect(cfg ClientConfiguration) (*grpc.ClientConn, security, error) {
|
||||
func connect(cfg clients.ClientConfiguration) (*grpc.ClientConn, clients.Security, error) {
|
||||
opts := []grpc.DialOption{
|
||||
grpc.WithStatsHandler(otelgrpc.NewClientHandler()),
|
||||
}
|
||||
secure := withoutTLS
|
||||
security := clients.WithoutTLS
|
||||
|
||||
if agcfg, ok := cfg.(AgentClientConfig); ok && agcfg.AttestedTLS {
|
||||
tc, sec, err := setupATLS(agcfg)
|
||||
if agcfg, ok := cfg.(clients.AttestedClientConfig); ok && agcfg.AttestedTLS {
|
||||
result, err := clients.LoadATLSConfig(agcfg)
|
||||
if err != nil {
|
||||
return nil, secure, err
|
||||
return nil, security, err
|
||||
}
|
||||
|
||||
opts = append(opts, grpc.WithTransportCredentials(tc))
|
||||
|
||||
secure = sec
|
||||
opts = append(opts, grpc.WithTransportCredentials(credentials.NewTLS(result.Config)))
|
||||
security = result.Security
|
||||
} else {
|
||||
conf := cfg.GetBaseConfig()
|
||||
conf := cfg.Config()
|
||||
transportCreds, sec, err := loadTLSConfig(conf.ServerCAFile, conf.ClientCert, conf.ClientKey)
|
||||
if err != nil {
|
||||
return nil, secure, err
|
||||
return nil, security, err
|
||||
}
|
||||
opts = append(opts, grpc.WithTransportCredentials(transportCreds))
|
||||
secure = sec
|
||||
security = sec
|
||||
}
|
||||
|
||||
conn, err := grpc.Dial(cfg.GetBaseConfig().URL, opts...)
|
||||
conn, err := grpc.NewClient(cfg.Config().URL, opts...)
|
||||
if err != nil {
|
||||
return nil, secure, errors.Wrap(errGrpcConnect, err)
|
||||
return nil, security, errors.Wrap(errGrpcConnect, err)
|
||||
}
|
||||
return conn, secure, nil
|
||||
return conn, security, nil
|
||||
}
|
||||
|
||||
func loadTLSConfig(serverCAFile, clientCert, clientKey string) (credentials.TransportCredentials, security, error) {
|
||||
tlsConfig := &tls.Config{}
|
||||
secure := withoutTLS
|
||||
tc := insecure.NewCredentials()
|
||||
|
||||
if serverCAFile != "" {
|
||||
rootCA, err := os.ReadFile(serverCAFile)
|
||||
if err != nil {
|
||||
return nil, secure, errors.Wrap(errFailedToLoadRootCA, err)
|
||||
}
|
||||
if len(rootCA) > 0 {
|
||||
capool := x509.NewCertPool()
|
||||
if !capool.AppendCertsFromPEM(rootCA) {
|
||||
return nil, secure, fmt.Errorf("failed to append root ca to tls.Config")
|
||||
}
|
||||
tlsConfig.RootCAs = capool
|
||||
secure = withTLS
|
||||
tc = credentials.NewTLS(tlsConfig)
|
||||
}
|
||||
func loadTLSConfig(serverCAFile, clientCert, clientKey string) (credentials.TransportCredentials, clients.Security, error) {
|
||||
result, err := clients.LoadBasicTLSConfig(serverCAFile, clientCert, clientKey)
|
||||
if err != nil {
|
||||
return nil, clients.WithoutTLS, err
|
||||
}
|
||||
|
||||
if clientCert != "" || clientKey != "" {
|
||||
certificate, err := tls.LoadX509KeyPair(clientCert, clientKey)
|
||||
if err != nil {
|
||||
return nil, secure, errors.Wrap(errFailedToLoadClientCertKey, err)
|
||||
}
|
||||
tlsConfig.Certificates = []tls.Certificate{certificate}
|
||||
secure = withmTLS
|
||||
tc = credentials.NewTLS(tlsConfig)
|
||||
if result.Security == clients.WithoutTLS || result.Config == nil {
|
||||
return insecure.NewCredentials(), result.Security, nil
|
||||
}
|
||||
|
||||
return tc, secure, nil
|
||||
return credentials.NewTLS(result.Config), result.Security, nil
|
||||
}
|
||||
|
||||
@@ -4,11 +4,12 @@ package manager
|
||||
|
||||
import (
|
||||
"github.com/ultravioletrs/cocos/manager"
|
||||
"github.com/ultravioletrs/cocos/pkg/clients"
|
||||
"github.com/ultravioletrs/cocos/pkg/clients/grpc"
|
||||
)
|
||||
|
||||
// NewManagerClient creates new manager gRPC client instance.
|
||||
func NewManagerClient(cfg grpc.ManagerClientConfig) (grpc.Client, manager.ManagerServiceClient, error) {
|
||||
func NewManagerClient(cfg clients.StandardClientConfig) (grpc.Client, manager.ManagerServiceClient, error) {
|
||||
client, err := grpc.NewClient(cfg)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
|
||||
@@ -7,21 +7,19 @@ import (
|
||||
|
||||
"github.com/absmach/supermq/pkg/errors"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/ultravioletrs/cocos/pkg/clients/grpc"
|
||||
"github.com/ultravioletrs/cocos/pkg/clients"
|
||||
)
|
||||
|
||||
func TestNewManagerClient(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
cfg grpc.ManagerClientConfig
|
||||
cfg clients.StandardClientConfig
|
||||
err error
|
||||
}{
|
||||
{
|
||||
name: "Valid config",
|
||||
cfg: grpc.ManagerClientConfig{
|
||||
BaseConfig: grpc.BaseConfig{
|
||||
URL: "localhost:7001",
|
||||
},
|
||||
cfg: clients.StandardClientConfig{
|
||||
URL: "localhost:7001",
|
||||
},
|
||||
err: nil,
|
||||
},
|
||||
|
||||
@@ -0,0 +1,85 @@
|
||||
// Copyright (c) Ultraviolet
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package http
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/ultravioletrs/cocos/pkg/clients"
|
||||
)
|
||||
|
||||
type Client interface {
|
||||
Transport() *http.Transport
|
||||
Secure() string
|
||||
Timeout() time.Duration
|
||||
}
|
||||
|
||||
type client struct {
|
||||
transport *http.Transport
|
||||
cfg clients.ClientConfiguration
|
||||
security clients.Security
|
||||
}
|
||||
|
||||
var _ Client = (*client)(nil)
|
||||
|
||||
func NewClient(cfg clients.ClientConfiguration) (Client, error) {
|
||||
transport, security, err := createTransport(cfg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &client{
|
||||
transport: transport,
|
||||
cfg: cfg,
|
||||
security: security,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (c *client) Transport() *http.Transport {
|
||||
return c.transport
|
||||
}
|
||||
|
||||
func (c *client) Secure() string {
|
||||
return c.security.String()
|
||||
}
|
||||
|
||||
func (c *client) Timeout() time.Duration {
|
||||
return c.cfg.Config().Timeout
|
||||
}
|
||||
|
||||
func createTransport(cfg clients.ClientConfiguration) (*http.Transport, clients.Security, error) {
|
||||
transport := &http.Transport{
|
||||
MaxIdleConns: 100,
|
||||
IdleConnTimeout: 90 * time.Second,
|
||||
TLSHandshakeTimeout: 10 * time.Second,
|
||||
}
|
||||
|
||||
security := clients.WithoutTLS
|
||||
|
||||
if agcfg, ok := cfg.(*clients.AttestedClientConfig); ok && agcfg.AttestedTLS {
|
||||
result, err := clients.LoadATLSConfig(*agcfg)
|
||||
if err != nil {
|
||||
return nil, security, err
|
||||
}
|
||||
|
||||
transport.TLSClientConfig = result.Config
|
||||
security = result.Security
|
||||
} else {
|
||||
conf := cfg.Config()
|
||||
|
||||
result, err := clients.LoadBasicTLSConfig(conf.ServerCAFile, conf.ClientCert, conf.ClientKey)
|
||||
if err != nil {
|
||||
return nil, security, err
|
||||
}
|
||||
|
||||
if result.Security != clients.WithoutTLS {
|
||||
transport.TLSClientConfig = result.Config
|
||||
}
|
||||
|
||||
security = result.Security
|
||||
}
|
||||
|
||||
return transport, security, nil
|
||||
}
|
||||
@@ -0,0 +1,292 @@
|
||||
// Copyright (c) Ultraviolet
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package http
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/ultravioletrs/cocos/pkg/clients"
|
||||
)
|
||||
|
||||
func TestConfig_Configuration(t *testing.T) {
|
||||
config := clients.StandardClientConfig{
|
||||
URL: "http://localhost:8080",
|
||||
Timeout: 30 * time.Second,
|
||||
ClientCert: "cert.pem",
|
||||
ClientKey: "key.pem",
|
||||
ServerCAFile: "ca.pem",
|
||||
}
|
||||
|
||||
result := config.Config()
|
||||
|
||||
assert.Equal(t, config, result)
|
||||
assert.Equal(t, "http://localhost:8080", result.URL)
|
||||
assert.Equal(t, 30*time.Second, result.Timeout)
|
||||
assert.Equal(t, "cert.pem", result.ClientCert)
|
||||
assert.Equal(t, "key.pem", result.ClientKey)
|
||||
assert.Equal(t, "ca.pem", result.ServerCAFile)
|
||||
}
|
||||
|
||||
func TestAgentClientConfig_Configuration(t *testing.T) {
|
||||
agentConfig := &clients.AttestedClientConfig{
|
||||
StandardClientConfig: clients.StandardClientConfig{
|
||||
URL: "https://agent.example.com",
|
||||
Timeout: 60 * time.Second,
|
||||
ClientCert: "agent-cert.pem",
|
||||
ClientKey: "agent-key.pem",
|
||||
ServerCAFile: "agent-ca.pem",
|
||||
},
|
||||
AttestationPolicy: "policy.json",
|
||||
AttestedTLS: true,
|
||||
ProductName: "Milan",
|
||||
}
|
||||
|
||||
result := agentConfig.Config()
|
||||
|
||||
assert.Equal(t, agentConfig.StandardClientConfig, result)
|
||||
assert.Equal(t, "https://agent.example.com", result.URL)
|
||||
assert.Equal(t, 60*time.Second, result.Timeout)
|
||||
assert.Equal(t, "agent-cert.pem", result.ClientCert)
|
||||
assert.Equal(t, "agent-key.pem", result.ClientKey)
|
||||
assert.Equal(t, "agent-ca.pem", result.ServerCAFile)
|
||||
}
|
||||
|
||||
func TestProxyClientConfig_Configuration(t *testing.T) {
|
||||
proxyConfig := clients.StandardClientConfig{
|
||||
URL: "http://proxy.example.com",
|
||||
Timeout: 45 * time.Second,
|
||||
ClientCert: "proxy-cert.pem",
|
||||
ClientKey: "proxy-key.pem",
|
||||
ServerCAFile: "proxy-ca.pem",
|
||||
}
|
||||
|
||||
result := proxyConfig
|
||||
|
||||
assert.Equal(t, proxyConfig, result)
|
||||
assert.Equal(t, "http://proxy.example.com", result.URL)
|
||||
assert.Equal(t, 45*time.Second, result.Timeout)
|
||||
}
|
||||
|
||||
func TestNewClient_Success(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
config clients.ClientConfiguration
|
||||
}{
|
||||
{
|
||||
name: "Basic config",
|
||||
config: clients.StandardClientConfig{
|
||||
URL: "http://localhost:8080",
|
||||
Timeout: 30 * time.Second,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Agent config without attested TLS",
|
||||
config: &clients.AttestedClientConfig{
|
||||
StandardClientConfig: clients.StandardClientConfig{
|
||||
URL: "https://agent.example.com",
|
||||
Timeout: 60 * time.Second,
|
||||
},
|
||||
AttestedTLS: false,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Proxy config",
|
||||
config: clients.StandardClientConfig{
|
||||
URL: "http://proxy.example.com",
|
||||
Timeout: 45 * time.Second,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
client, err := NewClient(tt.config)
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, client)
|
||||
assert.NotNil(t, client.Transport())
|
||||
assert.Equal(t, tt.config.Config().Timeout, client.Timeout())
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestClient_Transport(t *testing.T) {
|
||||
config := clients.StandardClientConfig{
|
||||
URL: "http://localhost:8080",
|
||||
Timeout: 30 * time.Second,
|
||||
}
|
||||
|
||||
client, err := NewClient(config)
|
||||
assert.NoError(t, err)
|
||||
|
||||
transport := client.Transport()
|
||||
|
||||
assert.NotNil(t, transport)
|
||||
assert.IsType(t, &http.Transport{}, transport)
|
||||
assert.Equal(t, 100, transport.MaxIdleConns)
|
||||
assert.Equal(t, 90*time.Second, transport.IdleConnTimeout)
|
||||
assert.Equal(t, 10*time.Second, transport.TLSHandshakeTimeout)
|
||||
}
|
||||
|
||||
func TestClient_Secure(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
config clients.ClientConfiguration
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "Without TLS",
|
||||
config: clients.StandardClientConfig{
|
||||
URL: "http://localhost:8080",
|
||||
Timeout: 30 * time.Second,
|
||||
},
|
||||
expected: clients.WithoutTLS.String(),
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
client, err := NewClient(tt.config)
|
||||
assert.NoError(t, err)
|
||||
|
||||
secure := client.Secure()
|
||||
assert.Equal(t, tt.expected, secure)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestClient_Timeout(t *testing.T) {
|
||||
expectedTimeout := 45 * time.Second
|
||||
config := clients.StandardClientConfig{
|
||||
URL: "http://localhost:8080",
|
||||
Timeout: expectedTimeout,
|
||||
}
|
||||
|
||||
client, err := NewClient(config)
|
||||
assert.NoError(t, err)
|
||||
|
||||
timeout := client.Timeout()
|
||||
assert.Equal(t, expectedTimeout, timeout)
|
||||
}
|
||||
|
||||
func TestCreateTransport_DefaultSettings(t *testing.T) {
|
||||
config := clients.StandardClientConfig{
|
||||
URL: "http://localhost:8080",
|
||||
Timeout: 30 * time.Second,
|
||||
}
|
||||
|
||||
transport, security, err := createTransport(config)
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, transport)
|
||||
assert.Equal(t, clients.WithoutTLS, security)
|
||||
assert.Equal(t, 100, transport.MaxIdleConns)
|
||||
assert.Equal(t, 90*time.Second, transport.IdleConnTimeout)
|
||||
assert.Equal(t, 10*time.Second, transport.TLSHandshakeTimeout)
|
||||
assert.Nil(t, transport.TLSClientConfig)
|
||||
}
|
||||
|
||||
func TestCreateTransport_ATLSError(t *testing.T) {
|
||||
config := &clients.AttestedClientConfig{
|
||||
StandardClientConfig: clients.StandardClientConfig{
|
||||
URL: "https://agent.example.com",
|
||||
Timeout: 60 * time.Second,
|
||||
},
|
||||
AttestationPolicy: "invalid",
|
||||
AttestedTLS: true,
|
||||
ProductName: "Milan",
|
||||
}
|
||||
|
||||
transport, security, err := createTransport(config)
|
||||
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, transport)
|
||||
assert.Equal(t, clients.WithoutTLS, security)
|
||||
assert.Contains(t, err.Error(), "failed to stat attestation policy")
|
||||
}
|
||||
|
||||
func TestCreateTransport_BasicTLSError(t *testing.T) {
|
||||
config := clients.StandardClientConfig{
|
||||
URL: "https://example.com",
|
||||
Timeout: 30 * time.Second,
|
||||
ServerCAFile: "invalid",
|
||||
}
|
||||
|
||||
transport, security, err := createTransport(config)
|
||||
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, transport)
|
||||
assert.Equal(t, clients.WithoutTLS, security)
|
||||
assert.Contains(t, err.Error(), "failed to load root ca file")
|
||||
}
|
||||
|
||||
func TestClientInterface_Implementation(t *testing.T) {
|
||||
config := clients.StandardClientConfig{
|
||||
URL: "http://localhost:8080",
|
||||
Timeout: 30 * time.Second,
|
||||
}
|
||||
|
||||
client, err := NewClient(config)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Verify that client implements the Client interface
|
||||
var _ Client = client
|
||||
|
||||
// Test all interface methods
|
||||
assert.NotNil(t, client.Transport())
|
||||
assert.NotEmpty(t, client.Secure())
|
||||
assert.Greater(t, client.Timeout(), time.Duration(0))
|
||||
}
|
||||
|
||||
func TestAgentClientConfig_FieldAccess(t *testing.T) {
|
||||
config := &clients.AttestedClientConfig{
|
||||
StandardClientConfig: clients.StandardClientConfig{
|
||||
URL: "https://agent.example.com",
|
||||
Timeout: 60 * time.Second,
|
||||
},
|
||||
AttestationPolicy: "test-policy",
|
||||
AttestedTLS: true,
|
||||
ProductName: "TestProduct",
|
||||
}
|
||||
|
||||
assert.Equal(t, "test-policy", config.AttestationPolicy)
|
||||
assert.True(t, config.AttestedTLS)
|
||||
assert.Equal(t, "TestProduct", config.ProductName)
|
||||
assert.Equal(t, "https://agent.example.com", config.URL)
|
||||
assert.Equal(t, 60*time.Second, config.Timeout)
|
||||
}
|
||||
|
||||
func TestProxyClientConfig_FieldAccess(t *testing.T) {
|
||||
config := clients.StandardClientConfig{
|
||||
URL: "http://proxy.example.com",
|
||||
Timeout: 45 * time.Second,
|
||||
ClientCert: "proxy-cert.pem",
|
||||
ClientKey: "proxy-key.pem",
|
||||
ServerCAFile: "proxy-ca.pem",
|
||||
}
|
||||
|
||||
assert.Equal(t, "http://proxy.example.com", config.URL)
|
||||
assert.Equal(t, 45*time.Second, config.Timeout)
|
||||
assert.Equal(t, "proxy-cert.pem", config.ClientCert)
|
||||
assert.Equal(t, "proxy-key.pem", config.ClientKey)
|
||||
assert.Equal(t, "proxy-ca.pem", config.ServerCAFile)
|
||||
}
|
||||
|
||||
func TestClientConfiguration_Interface(t *testing.T) {
|
||||
// Test that all config types implement ClientConfiguration interface
|
||||
var configs []clients.ClientConfiguration
|
||||
|
||||
configs = append(configs, clients.StandardClientConfig{})
|
||||
configs = append(configs, &clients.AttestedClientConfig{})
|
||||
|
||||
for i, config := range configs {
|
||||
t.Run(t.Name()+"_"+string(rune(i+'0')), func(t *testing.T) {
|
||||
result := config.Config()
|
||||
assert.IsType(t, clients.StandardClientConfig{}, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,180 @@
|
||||
// Copyright (c) Ultraviolet
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package clients
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"encoding/hex"
|
||||
"encoding/pem"
|
||||
"os"
|
||||
|
||||
"github.com/absmach/supermq/pkg/errors"
|
||||
"github.com/ultravioletrs/cocos/pkg/atls"
|
||||
"github.com/ultravioletrs/cocos/pkg/attestation"
|
||||
)
|
||||
|
||||
// Security represents the type of TLS security configuration.
|
||||
type Security int
|
||||
|
||||
const (
|
||||
WithoutTLS Security = iota
|
||||
WithTLS
|
||||
WithMTLS
|
||||
WithATLS
|
||||
WithMATLS
|
||||
)
|
||||
|
||||
// String returns a human-readable representation of the security level.
|
||||
func (s Security) String() string {
|
||||
switch s {
|
||||
case WithTLS:
|
||||
return "with TLS"
|
||||
case WithMTLS:
|
||||
return "with mTLS"
|
||||
case WithATLS:
|
||||
return "with aTLS"
|
||||
case WithMATLS:
|
||||
return "with maTLS"
|
||||
case WithoutTLS:
|
||||
return "without TLS"
|
||||
default:
|
||||
return "without TLS"
|
||||
}
|
||||
}
|
||||
|
||||
const AttestationReportSize = 0x4A0
|
||||
|
||||
var (
|
||||
ErrFailedToLoadClientCertKey = errors.New("failed to load client certificate and key")
|
||||
ErrFailedToLoadRootCA = errors.New("failed to load root ca file")
|
||||
errAttestationPolicyIrregular = errors.New("attestation policy file is not a regular file")
|
||||
)
|
||||
|
||||
// TLSResult contains the result of TLS configuration.
|
||||
type TLSResult struct {
|
||||
Config *tls.Config
|
||||
Security Security
|
||||
}
|
||||
|
||||
// LoadBasicTLSConfig loads standard TLS configuration (TLS/mTLS).
|
||||
func LoadBasicTLSConfig(serverCAFile, clientCert, clientKey string) (*TLSResult, error) {
|
||||
tlsConfig := &tls.Config{}
|
||||
security := WithoutTLS
|
||||
|
||||
// If no TLS configuration is provided, return nil config (no TLS)
|
||||
if serverCAFile == "" && clientCert == "" && clientKey == "" {
|
||||
return &TLSResult{Config: nil, Security: security}, nil
|
||||
}
|
||||
|
||||
if serverCAFile != "" {
|
||||
rootCA, err := os.ReadFile(serverCAFile)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(ErrFailedToLoadRootCA, err)
|
||||
}
|
||||
|
||||
if len(rootCA) > 0 {
|
||||
capool := x509.NewCertPool()
|
||||
if !capool.AppendCertsFromPEM(rootCA) {
|
||||
return nil, errors.New("failed to append root ca to tls.Config")
|
||||
}
|
||||
|
||||
tlsConfig.RootCAs = capool
|
||||
security = WithTLS
|
||||
}
|
||||
}
|
||||
|
||||
if clientCert != "" || clientKey != "" {
|
||||
certificate, err := tls.LoadX509KeyPair(clientCert, clientKey)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(ErrFailedToLoadClientCertKey, err)
|
||||
}
|
||||
|
||||
tlsConfig.Certificates = []tls.Certificate{certificate}
|
||||
security = WithMTLS
|
||||
}
|
||||
|
||||
return &TLSResult{Config: tlsConfig, Security: security}, nil
|
||||
}
|
||||
|
||||
// LoadATLSConfig configures Attested TLS.
|
||||
func LoadATLSConfig(cfg AttestedClientConfig) (*TLSResult, error) {
|
||||
security := WithATLS
|
||||
|
||||
info, err := os.Stat(cfg.AttestationPolicy)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(errors.New("failed to stat attestation policy file"), err)
|
||||
}
|
||||
|
||||
if !info.Mode().IsRegular() {
|
||||
return nil, errAttestationPolicyIrregular
|
||||
}
|
||||
|
||||
attestation.AttestationPolicyPath = cfg.AttestationPolicy
|
||||
|
||||
var rootCAs *x509.CertPool
|
||||
|
||||
if cfg.ServerCAFile != "" {
|
||||
rootCAs, err = loadRootCAs(cfg.ServerCAFile)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
security = WithMATLS
|
||||
}
|
||||
|
||||
nonce := make([]byte, 64)
|
||||
if _, err := rand.Read(nonce); err != nil {
|
||||
return nil, errors.Wrap(errors.New("failed to generate nonce"), err)
|
||||
}
|
||||
|
||||
encoded := hex.EncodeToString(nonce)
|
||||
sni := encoded + ".nonce"
|
||||
|
||||
tlsConfig := &tls.Config{
|
||||
InsecureSkipVerify: true,
|
||||
RootCAs: rootCAs,
|
||||
ServerName: sni,
|
||||
VerifyPeerCertificate: func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error {
|
||||
return atls.VerifyPeerCertificateATLS(rawCerts, verifiedChains, nonce, rootCAs)
|
||||
},
|
||||
}
|
||||
|
||||
if cfg.ClientCert != "" || cfg.ClientKey != "" {
|
||||
certificate, err := tls.LoadX509KeyPair(cfg.ClientCert, cfg.ClientKey)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(ErrFailedToLoadClientCertKey, err)
|
||||
}
|
||||
|
||||
tlsConfig.Certificates = []tls.Certificate{certificate}
|
||||
}
|
||||
|
||||
return &TLSResult{Config: tlsConfig, Security: security}, nil
|
||||
}
|
||||
|
||||
// loadRootCAs loads root CA certificates from a file.
|
||||
func loadRootCAs(serverCAFile string) (*x509.CertPool, error) {
|
||||
// Read the certificate file
|
||||
certPEM, err := os.ReadFile(serverCAFile)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(errors.New("failed to read certificate file"), err)
|
||||
}
|
||||
|
||||
// Decode the PEM block
|
||||
block, _ := pem.Decode(certPEM)
|
||||
if block == nil {
|
||||
return nil, errors.New("failed to decode PEM block")
|
||||
}
|
||||
|
||||
// Parse the certificate
|
||||
cert, err := x509.ParseCertificate(block.Bytes)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(errors.New("failed to parse certificate"), err)
|
||||
}
|
||||
|
||||
rootCAs := x509.NewCertPool()
|
||||
rootCAs.AddCert(cert)
|
||||
|
||||
return rootCAs, nil
|
||||
}
|
||||
@@ -0,0 +1,440 @@
|
||||
// Copyright (c) Ultraviolet
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package clients
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"crypto/x509"
|
||||
"crypto/x509/pkix"
|
||||
"encoding/pem"
|
||||
"math/big"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestSecurity_String(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
security Security
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "WithoutTLS",
|
||||
security: WithoutTLS,
|
||||
expected: "without TLS",
|
||||
},
|
||||
{
|
||||
name: "WithTLS",
|
||||
security: WithTLS,
|
||||
expected: "with TLS",
|
||||
},
|
||||
{
|
||||
name: "WithMTLS",
|
||||
security: WithMTLS,
|
||||
expected: "with mTLS",
|
||||
},
|
||||
{
|
||||
name: "WithATLS",
|
||||
security: WithATLS,
|
||||
expected: "with aTLS",
|
||||
},
|
||||
{
|
||||
name: "WithMATLS",
|
||||
security: WithMATLS,
|
||||
expected: "with maTLS",
|
||||
},
|
||||
{
|
||||
name: "InvalidSecurity",
|
||||
security: Security(999),
|
||||
expected: "without TLS",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
assert.Equal(t, tt.expected, tt.security.String())
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadBasicTLSConfig(t *testing.T) {
|
||||
// Create temporary directory for test files
|
||||
tmpDir := t.TempDir()
|
||||
|
||||
// Generate test certificate and key
|
||||
cert, key, caPEM := generateTestCertificates(t)
|
||||
|
||||
certFile := filepath.Join(tmpDir, "client.crt")
|
||||
keyFile := filepath.Join(tmpDir, "client.key")
|
||||
caFile := filepath.Join(tmpDir, "ca.crt")
|
||||
|
||||
require.NoError(t, os.WriteFile(certFile, cert, 0o644))
|
||||
require.NoError(t, os.WriteFile(keyFile, key, 0o644))
|
||||
require.NoError(t, os.WriteFile(caFile, caPEM, 0o644))
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
serverCAFile string
|
||||
clientCert string
|
||||
clientKey string
|
||||
expectedSec Security
|
||||
expectedConfig bool
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
name: "NoTLS",
|
||||
serverCAFile: "",
|
||||
clientCert: "",
|
||||
clientKey: "",
|
||||
expectedSec: WithoutTLS,
|
||||
expectedConfig: false,
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "TLSOnly",
|
||||
serverCAFile: caFile,
|
||||
clientCert: "",
|
||||
clientKey: "",
|
||||
expectedSec: WithTLS,
|
||||
expectedConfig: true,
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "MTLS",
|
||||
serverCAFile: caFile,
|
||||
clientCert: certFile,
|
||||
clientKey: keyFile,
|
||||
expectedSec: WithMTLS,
|
||||
expectedConfig: true,
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "MTLSWithoutCA",
|
||||
serverCAFile: "",
|
||||
clientCert: certFile,
|
||||
clientKey: keyFile,
|
||||
expectedSec: WithMTLS,
|
||||
expectedConfig: true,
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "InvalidCAFile",
|
||||
serverCAFile: filepath.Join(tmpDir, "nonexistent.crt"),
|
||||
clientCert: "",
|
||||
clientKey: "",
|
||||
expectedSec: WithoutTLS,
|
||||
expectedConfig: false,
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "InvalidCertFile",
|
||||
serverCAFile: "",
|
||||
clientCert: filepath.Join(tmpDir, "nonexistent.crt"),
|
||||
clientKey: keyFile,
|
||||
expectedSec: WithoutTLS,
|
||||
expectedConfig: false,
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "InvalidKeyFile",
|
||||
serverCAFile: "",
|
||||
clientCert: certFile,
|
||||
clientKey: filepath.Join(tmpDir, "nonexistent.key"),
|
||||
expectedSec: WithoutTLS,
|
||||
expectedConfig: false,
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "MismatchedCertKey",
|
||||
serverCAFile: "",
|
||||
clientCert: caFile, // Using CA file as cert (wrong format)
|
||||
clientKey: keyFile,
|
||||
expectedSec: WithoutTLS,
|
||||
expectedConfig: false,
|
||||
expectError: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result, err := LoadBasicTLSConfig(tt.serverCAFile, tt.clientCert, tt.clientKey)
|
||||
|
||||
if tt.expectError {
|
||||
assert.Error(t, err)
|
||||
return
|
||||
}
|
||||
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
assert.Equal(t, tt.expectedSec, result.Security)
|
||||
|
||||
if tt.expectedConfig {
|
||||
assert.NotNil(t, result.Config)
|
||||
} else {
|
||||
assert.Nil(t, result.Config)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadATLSConfig(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
|
||||
// Create test files
|
||||
cert, key, caPEM := generateTestCertificates(t)
|
||||
|
||||
certFile := filepath.Join(tmpDir, "client.crt")
|
||||
keyFile := filepath.Join(tmpDir, "client.key")
|
||||
caFile := filepath.Join(tmpDir, "ca.crt")
|
||||
policyFile := filepath.Join(tmpDir, "policy.json")
|
||||
|
||||
require.NoError(t, os.WriteFile(certFile, cert, 0o644))
|
||||
require.NoError(t, os.WriteFile(keyFile, key, 0o644))
|
||||
require.NoError(t, os.WriteFile(caFile, caPEM, 0o644))
|
||||
require.NoError(t, os.WriteFile(policyFile, []byte(`{"policy": "test"}`), 0o644))
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
config AttestedClientConfig
|
||||
expectedSec Security
|
||||
expectError bool
|
||||
errorMsg string
|
||||
}{
|
||||
{
|
||||
name: "ValidATLSConfig",
|
||||
config: AttestedClientConfig{
|
||||
StandardClientConfig: StandardClientConfig{
|
||||
ServerCAFile: "",
|
||||
},
|
||||
AttestationPolicy: policyFile,
|
||||
ProductName: "test-product",
|
||||
},
|
||||
expectedSec: WithATLS,
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "ValidMATLSConfig",
|
||||
config: AttestedClientConfig{
|
||||
StandardClientConfig: StandardClientConfig{
|
||||
ServerCAFile: caFile,
|
||||
},
|
||||
AttestationPolicy: policyFile,
|
||||
ProductName: "test-product",
|
||||
},
|
||||
expectedSec: WithMATLS,
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "ValidATLSWithClientCert",
|
||||
config: AttestedClientConfig{
|
||||
StandardClientConfig: StandardClientConfig{
|
||||
ClientCert: certFile,
|
||||
ClientKey: keyFile,
|
||||
},
|
||||
AttestationPolicy: policyFile,
|
||||
ProductName: "test-product",
|
||||
},
|
||||
expectedSec: WithATLS,
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "NonexistentPolicyFile",
|
||||
config: AttestedClientConfig{
|
||||
AttestationPolicy: filepath.Join(tmpDir, "nonexistent.json"),
|
||||
ProductName: "test-product",
|
||||
},
|
||||
expectedSec: WithoutTLS,
|
||||
expectError: true,
|
||||
errorMsg: "failed to stat attestation policy file",
|
||||
},
|
||||
{
|
||||
name: "PolicyFileIsDirectory",
|
||||
config: AttestedClientConfig{
|
||||
AttestationPolicy: tmpDir, // Directory instead of file
|
||||
ProductName: "test-product",
|
||||
},
|
||||
expectedSec: WithoutTLS,
|
||||
expectError: true,
|
||||
errorMsg: "attestation policy file is not a regular file",
|
||||
},
|
||||
{
|
||||
name: "InvalidCAFile",
|
||||
config: AttestedClientConfig{
|
||||
StandardClientConfig: StandardClientConfig{
|
||||
ServerCAFile: filepath.Join(tmpDir, "nonexistent.crt"),
|
||||
},
|
||||
AttestationPolicy: policyFile,
|
||||
ProductName: "test-product",
|
||||
},
|
||||
expectedSec: WithoutTLS,
|
||||
expectError: true,
|
||||
errorMsg: "failed to read certificate file",
|
||||
},
|
||||
{
|
||||
name: "InvalidClientCert",
|
||||
config: AttestedClientConfig{
|
||||
StandardClientConfig: StandardClientConfig{
|
||||
ClientCert: filepath.Join(tmpDir, "nonexistent.crt"),
|
||||
ClientKey: keyFile,
|
||||
},
|
||||
AttestationPolicy: policyFile,
|
||||
ProductName: "test-product",
|
||||
},
|
||||
expectedSec: WithoutTLS,
|
||||
expectError: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result, err := LoadATLSConfig(tt.config)
|
||||
|
||||
if tt.expectError {
|
||||
assert.Error(t, err)
|
||||
if tt.errorMsg != "" {
|
||||
assert.Contains(t, err.Error(), tt.errorMsg)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
assert.Equal(t, tt.expectedSec, result.Security)
|
||||
assert.NotNil(t, result.Config)
|
||||
|
||||
// Verify TLS config properties
|
||||
assert.True(t, result.Config.InsecureSkipVerify)
|
||||
assert.NotNil(t, result.Config.VerifyPeerCertificate)
|
||||
assert.NotEmpty(t, result.Config.ServerName)
|
||||
assert.Contains(t, result.Config.ServerName, ".nonce")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadRootCAs(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
|
||||
// Generate test certificate
|
||||
_, _, caPEM := generateTestCertificates(t)
|
||||
|
||||
validCAFile := filepath.Join(tmpDir, "valid_ca.crt")
|
||||
invalidCAFile := filepath.Join(tmpDir, "invalid_ca.crt")
|
||||
nonExistentFile := filepath.Join(tmpDir, "nonexistent.crt")
|
||||
|
||||
require.NoError(t, os.WriteFile(validCAFile, caPEM, 0o644))
|
||||
require.NoError(t, os.WriteFile(invalidCAFile, []byte("invalid pem data"), 0o644))
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
caFile string
|
||||
expectError bool
|
||||
errorMsg string
|
||||
}{
|
||||
{
|
||||
name: "ValidCAFile",
|
||||
caFile: validCAFile,
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "NonExistentFile",
|
||||
caFile: nonExistentFile,
|
||||
expectError: true,
|
||||
errorMsg: "failed to read certificate file",
|
||||
},
|
||||
{
|
||||
name: "InvalidPEMData",
|
||||
caFile: invalidCAFile,
|
||||
expectError: true,
|
||||
errorMsg: "failed to decode PEM block",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
rootCAs, err := loadRootCAs(tt.caFile)
|
||||
|
||||
if tt.expectError {
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, rootCAs)
|
||||
if tt.errorMsg != "" {
|
||||
assert.Contains(t, err.Error(), tt.errorMsg)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, rootCAs)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Helper functions for generating test certificates
|
||||
|
||||
func generateTestCertificates(t *testing.T) (certPEM, keyPEM, caPEM []byte) {
|
||||
// Generate CA certificate
|
||||
caKey, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
require.NoError(t, err)
|
||||
|
||||
caTemplate := x509.Certificate{
|
||||
SerialNumber: big.NewInt(1),
|
||||
Subject: pkix.Name{
|
||||
Organization: []string{"Test CA"},
|
||||
Country: []string{"US"},
|
||||
},
|
||||
NotBefore: time.Now(),
|
||||
NotAfter: time.Now().Add(24 * time.Hour),
|
||||
KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign,
|
||||
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
|
||||
BasicConstraintsValid: true,
|
||||
IsCA: true,
|
||||
}
|
||||
|
||||
caCertDER, err := x509.CreateCertificate(rand.Reader, &caTemplate, &caTemplate, &caKey.PublicKey, caKey)
|
||||
require.NoError(t, err)
|
||||
|
||||
caPEM = pem.EncodeToMemory(&pem.Block{
|
||||
Type: "CERTIFICATE",
|
||||
Bytes: caCertDER,
|
||||
})
|
||||
|
||||
// Generate client certificate
|
||||
clientKey, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
require.NoError(t, err)
|
||||
|
||||
clientTemplate := x509.Certificate{
|
||||
SerialNumber: big.NewInt(2),
|
||||
Subject: pkix.Name{
|
||||
Organization: []string{"Test Client"},
|
||||
Country: []string{"US"},
|
||||
},
|
||||
NotBefore: time.Now(),
|
||||
NotAfter: time.Now().Add(24 * time.Hour),
|
||||
KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature,
|
||||
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth},
|
||||
}
|
||||
|
||||
clientCertDER, err := x509.CreateCertificate(rand.Reader, &clientTemplate, &caTemplate, &clientKey.PublicKey, caKey)
|
||||
require.NoError(t, err)
|
||||
|
||||
certPEM = pem.EncodeToMemory(&pem.Block{
|
||||
Type: "CERTIFICATE",
|
||||
Bytes: clientCertDER,
|
||||
})
|
||||
|
||||
clientKeyDER, err := x509.MarshalPKCS8PrivateKey(clientKey)
|
||||
require.NoError(t, err)
|
||||
|
||||
keyPEM = pem.EncodeToMemory(&pem.Block{
|
||||
Type: "PRIVATE KEY",
|
||||
Bytes: clientKeyDER,
|
||||
})
|
||||
|
||||
return certPEM, keyPEM, caPEM
|
||||
}
|
||||
@@ -6,19 +6,16 @@ package grpc
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"net"
|
||||
"os"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
agentgrpc "github.com/ultravioletrs/cocos/agent/api/grpc"
|
||||
"github.com/ultravioletrs/cocos/agent/auth"
|
||||
"github.com/ultravioletrs/cocos/internal/server"
|
||||
"github.com/ultravioletrs/cocos/pkg/atls"
|
||||
"github.com/ultravioletrs/cocos/pkg/server"
|
||||
"go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/credentials"
|
||||
@@ -110,39 +107,13 @@ func (s *Server) Start() error {
|
||||
GetCertificate: atls.GetCertificate(s.caUrl, s.cvmId),
|
||||
}
|
||||
|
||||
var mtls bool
|
||||
mtls = false
|
||||
|
||||
// Loading Server CA file
|
||||
rootCA, err := loadCertFile(c.ServerCAFile)
|
||||
mtls, err := server.ConfigureCertificateAuthorities(tlsConfig, c.ServerCAFile, c.ClientCAFile)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to load server ca file: %w", err)
|
||||
}
|
||||
if len(rootCA) > 0 {
|
||||
if tlsConfig.RootCAs == nil {
|
||||
tlsConfig.RootCAs = x509.NewCertPool()
|
||||
}
|
||||
if !tlsConfig.RootCAs.AppendCertsFromPEM(rootCA) {
|
||||
return fmt.Errorf("failed to append server ca to tls.Config")
|
||||
}
|
||||
mtls = true
|
||||
return fmt.Errorf("failed to configure certificate authorities: %w", err)
|
||||
}
|
||||
|
||||
// Loading Client CA File
|
||||
clientCA, err := loadCertFile(c.ClientCAFile)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to load client ca file: %w", err)
|
||||
}
|
||||
if len(clientCA) > 0 {
|
||||
if tlsConfig.ClientCAs == nil {
|
||||
tlsConfig.ClientCAs = x509.NewCertPool()
|
||||
}
|
||||
if !tlsConfig.ClientCAs.AppendCertsFromPEM(clientCA) {
|
||||
return fmt.Errorf("failed to append client ca to tls.Config")
|
||||
}
|
||||
|
||||
if mtls {
|
||||
tlsConfig.ClientAuth = tls.RequireAndVerifyClientCert
|
||||
mtls = true
|
||||
}
|
||||
|
||||
creds = grpc.Creds(credentials.NewTLS(tlsConfig))
|
||||
@@ -155,52 +126,17 @@ func (s *Server) Start() error {
|
||||
} else {
|
||||
switch {
|
||||
case c.CertFile != "" || c.KeyFile != "":
|
||||
certificate, err := loadX509KeyPair(c.CertFile, c.KeyFile)
|
||||
tlsSetup, err := server.SetupRegularTLS(c.CertFile, c.KeyFile, c.ServerCAFile, c.ClientCAFile)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to load auth certificates: %w", err)
|
||||
}
|
||||
tlsConfig := &tls.Config{
|
||||
ClientAuth: tls.NoClientCert,
|
||||
Certificates: []tls.Certificate{certificate},
|
||||
return fmt.Errorf("failed to setup TLS: %w", err)
|
||||
}
|
||||
|
||||
var mtlsCA string
|
||||
// Loading Server CA file
|
||||
rootCA, err := loadCertFile(c.ServerCAFile)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to load root ca file: %w", err)
|
||||
}
|
||||
if len(rootCA) > 0 {
|
||||
if tlsConfig.RootCAs == nil {
|
||||
tlsConfig.RootCAs = x509.NewCertPool()
|
||||
}
|
||||
if !tlsConfig.RootCAs.AppendCertsFromPEM(rootCA) {
|
||||
return fmt.Errorf("failed to append root ca to tls.Config")
|
||||
}
|
||||
mtlsCA = fmt.Sprintf("root ca %s", c.ServerCAFile)
|
||||
}
|
||||
creds = grpc.Creds(credentials.NewTLS(tlsSetup.Config))
|
||||
|
||||
// Loading Client CA File
|
||||
clientCA, err := loadCertFile(c.ClientCAFile)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to load client ca file: %w", err)
|
||||
}
|
||||
if len(clientCA) > 0 {
|
||||
if tlsConfig.ClientCAs == nil {
|
||||
tlsConfig.ClientCAs = x509.NewCertPool()
|
||||
}
|
||||
if !tlsConfig.ClientCAs.AppendCertsFromPEM(clientCA) {
|
||||
return fmt.Errorf("failed to append client ca to tls.Config")
|
||||
}
|
||||
mtlsCA = fmt.Sprintf("%s client ca %s", mtlsCA, c.ClientCAFile)
|
||||
}
|
||||
creds = grpc.Creds(credentials.NewTLS(tlsConfig))
|
||||
switch {
|
||||
case mtlsCA != "":
|
||||
tlsConfig.ClientAuth = tls.RequireAndVerifyClientCert
|
||||
creds = grpc.Creds(credentials.NewTLS(tlsConfig))
|
||||
if tlsSetup.MTLS {
|
||||
mtlsCA := server.BuildMTLSDescription(c.ServerCAFile, c.ClientCAFile)
|
||||
s.Logger.Info(fmt.Sprintf("%s service gRPC server listening at %s with TLS/mTLS cert %s , key %s and %s", s.Name, s.Address, c.CertFile, c.KeyFile, mtlsCA))
|
||||
default:
|
||||
} else {
|
||||
s.Logger.Info(fmt.Sprintf("%s service gRPC server listening at %s with TLS cert %s and key %s", s.Name, s.Address, c.CertFile, c.KeyFile))
|
||||
}
|
||||
default:
|
||||
@@ -272,36 +208,3 @@ func (s *Server) Stop() error {
|
||||
s.Logger.Info(fmt.Sprintf("%s gRPC service shutdown at %s", s.Name, s.Address))
|
||||
return nil
|
||||
}
|
||||
|
||||
func loadCertFile(certFile string) ([]byte, error) {
|
||||
if certFile != "" {
|
||||
return readFileOrData(certFile)
|
||||
}
|
||||
return []byte{}, nil
|
||||
}
|
||||
|
||||
func readFileOrData(input string) ([]byte, error) {
|
||||
if len(input) < 1000 && !strings.Contains(input, "\n") {
|
||||
data, err := os.ReadFile(input)
|
||||
if err == nil {
|
||||
return data, nil
|
||||
} else {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
return []byte(input), nil
|
||||
}
|
||||
|
||||
func loadX509KeyPair(certfile, keyfile string) (tls.Certificate, error) {
|
||||
cert, err := readFileOrData(certfile)
|
||||
if err != nil {
|
||||
return tls.Certificate{}, fmt.Errorf("failed to read cert: %v", err)
|
||||
}
|
||||
|
||||
key, err := readFileOrData(keyfile)
|
||||
if err != nil {
|
||||
return tls.Certificate{}, fmt.Errorf("failed to read key: %v", err)
|
||||
}
|
||||
|
||||
return tls.X509KeyPair(cert, key)
|
||||
}
|
||||
@@ -20,8 +20,8 @@ import (
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
authmocks "github.com/ultravioletrs/cocos/agent/auth/mocks"
|
||||
"github.com/ultravioletrs/cocos/internal/server"
|
||||
"github.com/ultravioletrs/cocos/pkg/attestation/vtpm"
|
||||
"github.com/ultravioletrs/cocos/pkg/server"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/test/bufconn"
|
||||
)
|
||||
@@ -40,7 +40,7 @@ func TestNew(t *testing.T) {
|
||||
|
||||
config := server.AgentConfig{
|
||||
ServerConfig: server.ServerConfig{
|
||||
BaseConfig: server.BaseConfig{
|
||||
Config: server.Config{
|
||||
Host: "localhost",
|
||||
Port: "50051",
|
||||
},
|
||||
@@ -85,7 +85,7 @@ func TestServerStartWithTLSFile(t *testing.T) {
|
||||
|
||||
config := server.AgentConfig{
|
||||
ServerConfig: server.ServerConfig{
|
||||
BaseConfig: server.BaseConfig{
|
||||
Config: server.Config{
|
||||
Host: "localhost",
|
||||
Port: "0",
|
||||
CertFile: certFile.Name(),
|
||||
@@ -130,7 +130,7 @@ func TestServerStartWithmTLSFile(t *testing.T) {
|
||||
|
||||
config := server.AgentConfig{
|
||||
ServerConfig: server.ServerConfig{
|
||||
BaseConfig: server.BaseConfig{
|
||||
Config: server.Config{
|
||||
Host: "localhost",
|
||||
Port: "0",
|
||||
CertFile: string(clientCertFile),
|
||||
@@ -173,7 +173,7 @@ func TestServerStop(t *testing.T) {
|
||||
|
||||
config := server.AgentConfig{
|
||||
ServerConfig: server.ServerConfig{
|
||||
BaseConfig: server.BaseConfig{
|
||||
Config: server.Config{
|
||||
Host: "localhost",
|
||||
Port: "0",
|
||||
},
|
||||
@@ -268,7 +268,7 @@ func TestServerInitializationAndStartup(t *testing.T) {
|
||||
name: "Non-TLS Server Startup",
|
||||
config: server.AgentConfig{
|
||||
ServerConfig: server.ServerConfig{
|
||||
BaseConfig: server.BaseConfig{
|
||||
Config: server.Config{
|
||||
Host: "localhost",
|
||||
Port: "0",
|
||||
},
|
||||
@@ -280,7 +280,7 @@ func TestServerInitializationAndStartup(t *testing.T) {
|
||||
name: "TLS Server Startup with Self-Signed Certificate",
|
||||
config: server.AgentConfig{
|
||||
ServerConfig: server.ServerConfig{
|
||||
BaseConfig: server.BaseConfig{
|
||||
Config: server.Config{
|
||||
Host: "localhost",
|
||||
Port: "0",
|
||||
},
|
||||
@@ -293,7 +293,7 @@ func TestServerInitializationAndStartup(t *testing.T) {
|
||||
name: "TLS Server Startup with Invalid Certificates",
|
||||
config: server.AgentConfig{
|
||||
ServerConfig: server.ServerConfig{
|
||||
BaseConfig: server.BaseConfig{
|
||||
Config: server.Config{
|
||||
Host: "localhost",
|
||||
Port: "0",
|
||||
CertFile: "invalid",
|
||||
@@ -308,7 +308,7 @@ func TestServerInitializationAndStartup(t *testing.T) {
|
||||
name: "maTLS Server Startup",
|
||||
config: server.AgentConfig{
|
||||
ServerConfig: server.ServerConfig{
|
||||
BaseConfig: server.BaseConfig{
|
||||
Config: server.Config{
|
||||
Host: "localhost",
|
||||
Port: "0",
|
||||
ServerCAFile: "",
|
||||
@@ -325,7 +325,7 @@ func TestServerInitializationAndStartup(t *testing.T) {
|
||||
name: "maTLS Server Startup with Invalid Server CA file",
|
||||
config: server.AgentConfig{
|
||||
ServerConfig: server.ServerConfig{
|
||||
BaseConfig: server.BaseConfig{
|
||||
Config: server.Config{
|
||||
Host: "localhost",
|
||||
Port: "0",
|
||||
ServerCAFile: "invalid",
|
||||
@@ -341,7 +341,7 @@ func TestServerInitializationAndStartup(t *testing.T) {
|
||||
name: "maTLS Server Startup with Invalid Clinet CA file",
|
||||
config: server.AgentConfig{
|
||||
ServerConfig: server.ServerConfig{
|
||||
BaseConfig: server.BaseConfig{
|
||||
Config: server.Config{
|
||||
Host: "localhost",
|
||||
Port: "0",
|
||||
ServerCAFile: "invalid",
|
||||
@@ -0,0 +1,173 @@
|
||||
// Copyright (c) Ultraviolet
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package http
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
|
||||
smqserver "github.com/absmach/supermq/pkg/server"
|
||||
"github.com/ultravioletrs/cocos/pkg/atls"
|
||||
"github.com/ultravioletrs/cocos/pkg/server"
|
||||
)
|
||||
|
||||
const (
|
||||
httpProtocol = "http"
|
||||
httpsProtocol = "https"
|
||||
)
|
||||
|
||||
type httpServer struct {
|
||||
server.BaseServer
|
||||
|
||||
server *http.Server
|
||||
caURL string
|
||||
}
|
||||
|
||||
var _ server.Server = (*httpServer)(nil)
|
||||
|
||||
func NewServer(
|
||||
ctx context.Context, cancel context.CancelFunc, name string, config server.ServerConfiguration,
|
||||
handler http.Handler, logger *slog.Logger, caURL string,
|
||||
) server.Server {
|
||||
baseServer := server.NewBaseServer(ctx, cancel, name, config, logger)
|
||||
hserver := &http.Server{Addr: baseServer.Address, Handler: handler}
|
||||
|
||||
return &httpServer{
|
||||
BaseServer: baseServer,
|
||||
server: hserver,
|
||||
caURL: caURL,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *httpServer) Start() error {
|
||||
s.Protocol = httpProtocol
|
||||
|
||||
if s.shouldUseAttestedTLS() {
|
||||
return s.startWithAttestedTLS()
|
||||
}
|
||||
|
||||
if s.shouldUseRegularTLS() {
|
||||
return s.startWithRegularTLS()
|
||||
}
|
||||
|
||||
return s.startWithoutTLS()
|
||||
}
|
||||
|
||||
func (s *httpServer) Stop() error {
|
||||
defer s.Cancel()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), smqserver.StopWaitTime)
|
||||
defer cancel()
|
||||
|
||||
if err := s.server.Shutdown(ctx); err != nil {
|
||||
s.Logger.Error(fmt.Sprintf(
|
||||
"%s service %s server error occurred during shutdown at %s: %s", s.Name, s.Protocol, s.Address, err))
|
||||
|
||||
return fmt.Errorf("%s service %s server error occurred during shutdown at %s: %w", s.Name, s.Protocol, s.Address, err)
|
||||
}
|
||||
|
||||
s.Logger.Info(fmt.Sprintf("%s %s service shutdown of http at %s", s.Name, s.Protocol, s.Address))
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *httpServer) shouldUseAttestedTLS() bool {
|
||||
cfg, ok := s.Config.(server.AgentConfig)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
|
||||
return cfg.AttestedTLS && s.caURL != ""
|
||||
}
|
||||
|
||||
func (s *httpServer) shouldUseRegularTLS() bool {
|
||||
return s.Config.GetBaseConfig().CertFile != "" || s.Config.GetBaseConfig().KeyFile != ""
|
||||
}
|
||||
|
||||
func (s *httpServer) startWithAttestedTLS() error {
|
||||
tlsConfig := &tls.Config{
|
||||
ClientAuth: tls.NoClientCert,
|
||||
GetCertificate: atls.GetCertificate(s.caURL, ""),
|
||||
}
|
||||
|
||||
mtls, err := server.ConfigureCertificateAuthorities(tlsConfig, s.Config.GetBaseConfig().ServerCAFile, s.Config.GetBaseConfig().ClientCAFile)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to configure certificate authorities: %w", err)
|
||||
}
|
||||
|
||||
if mtls {
|
||||
tlsConfig.ClientAuth = tls.RequireAndVerifyClientCert
|
||||
}
|
||||
|
||||
s.server.TLSConfig = tlsConfig
|
||||
s.Protocol = httpsProtocol
|
||||
|
||||
s.logAttestedTLSStart(mtls)
|
||||
|
||||
return s.listenAndServe(true)
|
||||
}
|
||||
|
||||
func (s *httpServer) startWithRegularTLS() error {
|
||||
tlsSetup, err := server.SetupRegularTLS(s.Config.GetBaseConfig().CertFile, s.Config.GetBaseConfig().KeyFile, s.Config.GetBaseConfig().ServerCAFile, s.Config.GetBaseConfig().ClientCAFile)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to setup TLS: %w", err)
|
||||
}
|
||||
|
||||
s.server.TLSConfig = tlsSetup.Config
|
||||
s.Protocol = httpsProtocol
|
||||
|
||||
s.logRegularTLSStart(tlsSetup.MTLS)
|
||||
|
||||
return s.listenAndServe(true)
|
||||
}
|
||||
|
||||
func (s *httpServer) startWithoutTLS() error {
|
||||
s.Logger.Info(fmt.Sprintf("%s service %s server listening at %s without TLS", s.Name, s.Protocol, s.Address))
|
||||
|
||||
return s.listenAndServe(false)
|
||||
}
|
||||
|
||||
func (s *httpServer) logAttestedTLSStart(mtls bool) {
|
||||
if mtls {
|
||||
s.Logger.Info(fmt.Sprintf("%s service %s server listening at %s with Attested mTLS", s.Name, s.Protocol, s.Address))
|
||||
} else {
|
||||
s.Logger.Info(fmt.Sprintf("%s service %s server listening at %s with Attested TLS", s.Name, s.Protocol, s.Address))
|
||||
}
|
||||
}
|
||||
|
||||
func (s *httpServer) logRegularTLSStart(mtls bool) {
|
||||
if mtls {
|
||||
s.Logger.Info(fmt.Sprintf(
|
||||
"%s service %s server listening at %s with TLS/mTLS cert %s , key %s and CAs %s, %s",
|
||||
s.Name, s.Protocol, s.Address, s.Config.GetBaseConfig().CertFile, s.Config.GetBaseConfig().KeyFile,
|
||||
s.Config.GetBaseConfig().ServerCAFile, s.Config.GetBaseConfig().ClientCAFile))
|
||||
} else {
|
||||
s.Logger.Info(
|
||||
fmt.Sprintf("%s service %s server listening at %s with TLS cert %s and key %s",
|
||||
s.Name, s.Protocol, s.Address, s.Config.GetBaseConfig().CertFile, s.Config.GetBaseConfig().KeyFile))
|
||||
}
|
||||
}
|
||||
|
||||
func (s *httpServer) listenAndServe(useTLS bool) error {
|
||||
errCh := make(chan error, 1)
|
||||
|
||||
go func() {
|
||||
if useTLS {
|
||||
cfg := s.Config.GetBaseConfig()
|
||||
errCh <- s.server.ListenAndServeTLS(cfg.CertFile, cfg.KeyFile)
|
||||
} else {
|
||||
errCh <- s.server.ListenAndServe()
|
||||
}
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-s.Ctx.Done():
|
||||
return s.Stop()
|
||||
case err := <-errCh:
|
||||
return err
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,411 @@
|
||||
// Copyright (c) Ultraviolet
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package http
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/mock"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/ultravioletrs/cocos/pkg/server"
|
||||
)
|
||||
|
||||
// Mock implementations for testing.
|
||||
type mockHandler struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
func (m *mockHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
m.Called(w, r)
|
||||
w.WriteHeader(http.StatusOK)
|
||||
if _, err := w.Write([]byte("test response")); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
|
||||
type mockBaseConfig struct {
|
||||
certFile string
|
||||
keyFile string
|
||||
serverCAFile string
|
||||
clientCAFile string
|
||||
host string
|
||||
port string
|
||||
}
|
||||
|
||||
func (m *mockBaseConfig) GetCertFile() string { return m.certFile }
|
||||
func (m *mockBaseConfig) GetKeyFile() string { return m.keyFile }
|
||||
func (m *mockBaseConfig) GetServerCAFile() string { return m.serverCAFile }
|
||||
func (m *mockBaseConfig) GetClientCAFile() string { return m.clientCAFile }
|
||||
|
||||
type mockServerConfig struct {
|
||||
baseConfig *mockBaseConfig
|
||||
}
|
||||
|
||||
func (m *mockServerConfig) GetHost() string { return "localhost" }
|
||||
func (m *mockServerConfig) GetPort() string { return "8080" }
|
||||
func (m *mockServerConfig) GetBaseConfig() server.ServerConfig {
|
||||
return server.ServerConfig{Config: server.Config{CertFile: m.baseConfig.certFile, KeyFile: m.baseConfig.keyFile, ServerCAFile: m.baseConfig.serverCAFile, ClientCAFile: m.baseConfig.clientCAFile, Host: m.baseConfig.host, Port: m.baseConfig.port}}
|
||||
}
|
||||
|
||||
func TestNewServer(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
cancel := func() {}
|
||||
name := "test-server"
|
||||
config := &mockServerConfig{
|
||||
baseConfig: &mockBaseConfig{},
|
||||
}
|
||||
handler := &mockHandler{}
|
||||
logger := slog.Default()
|
||||
caURL := "https://ca.example.com"
|
||||
|
||||
server := NewServer(ctx, cancel, name, config, handler, logger, caURL)
|
||||
|
||||
assert.NotNil(t, server)
|
||||
httpSrv, ok := server.(*httpServer)
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, caURL, httpSrv.caURL)
|
||||
assert.NotNil(t, httpSrv.server)
|
||||
assert.Equal(t, handler, httpSrv.server.Handler)
|
||||
}
|
||||
|
||||
func TestHttpServer_shouldUseAttestedTLS(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
config server.ServerConfiguration
|
||||
caURL string
|
||||
attestedTLS bool
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
name: "should use attested TLS when config is AgentConfig and AttestedTLS is true and caURL is not empty",
|
||||
config: server.AgentConfig{
|
||||
AttestedTLS: true,
|
||||
},
|
||||
caURL: "https://ca.example.com",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "should not use attested TLS when caURL is empty",
|
||||
config: server.AgentConfig{
|
||||
AttestedTLS: true,
|
||||
},
|
||||
caURL: "",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "should not use attested TLS when AttestedTLS is false",
|
||||
config: server.AgentConfig{
|
||||
AttestedTLS: false,
|
||||
},
|
||||
caURL: "https://ca.example.com",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "should not use attested TLS when config is not AgentConfig",
|
||||
config: &mockServerConfig{
|
||||
baseConfig: &mockBaseConfig{},
|
||||
},
|
||||
caURL: "https://ca.example.com",
|
||||
expected: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
cancel := func() {}
|
||||
|
||||
server := NewServer(ctx, cancel, "test", tt.config, &mockHandler{}, slog.Default(), tt.caURL)
|
||||
httpSrv := server.(*httpServer)
|
||||
|
||||
result := httpSrv.shouldUseAttestedTLS()
|
||||
assert.Equal(t, tt.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestHttpServer_shouldUseRegularTLS(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
certFile string
|
||||
keyFile string
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
name: "should use regular TLS when both cert and key files are provided",
|
||||
certFile: "cert.pem",
|
||||
keyFile: "key.pem",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "should use regular TLS when only cert file is provided",
|
||||
certFile: "cert.pem",
|
||||
keyFile: "",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "should use regular TLS when only key file is provided",
|
||||
certFile: "",
|
||||
keyFile: "key.pem",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "should not use regular TLS when neither cert nor key files are provided",
|
||||
certFile: "",
|
||||
keyFile: "",
|
||||
expected: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
cancel := func() {}
|
||||
config := &mockServerConfig{
|
||||
baseConfig: &mockBaseConfig{
|
||||
certFile: tt.certFile,
|
||||
keyFile: tt.keyFile,
|
||||
},
|
||||
}
|
||||
|
||||
server := NewServer(ctx, cancel, "test", config, &mockHandler{}, slog.Default(), "")
|
||||
httpSrv := server.(*httpServer)
|
||||
|
||||
result := httpSrv.shouldUseRegularTLS()
|
||||
assert.Equal(t, tt.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestHttpServer_Stop(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
config := &mockServerConfig{
|
||||
baseConfig: &mockBaseConfig{},
|
||||
}
|
||||
handler := &mockHandler{}
|
||||
|
||||
server := NewServer(ctx, cancel, "test-server", config, handler, slog.Default(), "")
|
||||
httpSrv := server.(*httpServer)
|
||||
|
||||
// Start a test server that we can control
|
||||
testServer := httptest.NewServer(handler)
|
||||
defer testServer.Close()
|
||||
|
||||
// Replace the server's HTTP server with our test server's
|
||||
httpSrv.server = testServer.Config
|
||||
|
||||
err := httpSrv.Stop()
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestHttpServer_logAttestedTLSStart(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
mtls bool
|
||||
}{
|
||||
{
|
||||
name: "log attested mTLS start",
|
||||
mtls: true,
|
||||
},
|
||||
{
|
||||
name: "log attested TLS start",
|
||||
mtls: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
cancel := func() {}
|
||||
config := &mockServerConfig{
|
||||
baseConfig: &mockBaseConfig{},
|
||||
}
|
||||
|
||||
server := NewServer(ctx, cancel, "test-server", config, &mockHandler{}, slog.Default(), "")
|
||||
httpSrv := server.(*httpServer)
|
||||
|
||||
// This test mainly ensures the method doesn't panic
|
||||
// In a real scenario, you might want to capture log output
|
||||
assert.NotPanics(t, func() {
|
||||
httpSrv.logAttestedTLSStart(tt.mtls)
|
||||
})
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestHttpServer_logRegularTLSStart(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
mtls bool
|
||||
}{
|
||||
{
|
||||
name: "log regular mTLS start",
|
||||
mtls: true,
|
||||
},
|
||||
{
|
||||
name: "log regular TLS start",
|
||||
mtls: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
cancel := func() {}
|
||||
config := &mockServerConfig{
|
||||
baseConfig: &mockBaseConfig{
|
||||
certFile: "cert.pem",
|
||||
keyFile: "key.pem",
|
||||
serverCAFile: "server-ca.pem",
|
||||
clientCAFile: "client-ca.pem",
|
||||
},
|
||||
}
|
||||
|
||||
server := NewServer(ctx, cancel, "test-server", config, &mockHandler{}, slog.Default(), "")
|
||||
httpSrv := server.(*httpServer)
|
||||
|
||||
// This test mainly ensures the method doesn't panic
|
||||
assert.NotPanics(t, func() {
|
||||
httpSrv.logRegularTLSStart(tt.mtls)
|
||||
})
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestHttpServer_startWithoutTLS(t *testing.T) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
||||
defer cancel()
|
||||
|
||||
config := &mockServerConfig{
|
||||
baseConfig: &mockBaseConfig{},
|
||||
}
|
||||
handler := &mockHandler{}
|
||||
|
||||
server := NewServer(ctx, cancel, "test-server", config, handler, slog.Default(), "")
|
||||
httpSrv := server.(*httpServer)
|
||||
|
||||
// Use a test server to avoid binding to actual ports
|
||||
testServer := httptest.NewServer(handler)
|
||||
defer testServer.Close()
|
||||
|
||||
httpSrv.server = testServer.Config
|
||||
|
||||
err := httpSrv.startWithoutTLS()
|
||||
// The error will be related to context cancellation or server shutdown
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func TestHttpServer_Protocol(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
setupTLS func(*httpServer)
|
||||
expectedProto string
|
||||
}{
|
||||
{
|
||||
name: "HTTP protocol without TLS",
|
||||
setupTLS: func(s *httpServer) {
|
||||
s.Protocol = httpProtocol
|
||||
},
|
||||
expectedProto: httpProtocol,
|
||||
},
|
||||
{
|
||||
name: "HTTPS protocol with TLS",
|
||||
setupTLS: func(s *httpServer) {
|
||||
s.Protocol = httpsProtocol
|
||||
s.server.TLSConfig = &tls.Config{}
|
||||
},
|
||||
expectedProto: httpsProtocol,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
cancel := func() {}
|
||||
config := &mockServerConfig{
|
||||
baseConfig: &mockBaseConfig{},
|
||||
}
|
||||
|
||||
server := NewServer(ctx, cancel, "test-server", config, &mockHandler{}, slog.Default(), "")
|
||||
httpSrv := server.(*httpServer)
|
||||
|
||||
tt.setupTLS(httpSrv)
|
||||
|
||||
assert.Equal(t, tt.expectedProto, httpSrv.Protocol)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestHttpServer_ContextCancellation(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
config := &mockServerConfig{
|
||||
baseConfig: &mockBaseConfig{},
|
||||
}
|
||||
handler := &mockHandler{}
|
||||
|
||||
server := NewServer(ctx, cancel, "test-server", config, handler, slog.Default(), "")
|
||||
httpSrv := server.(*httpServer)
|
||||
|
||||
// Cancel the context immediately
|
||||
cancel()
|
||||
|
||||
// The listenAndServe method should handle context cancellation
|
||||
err := httpSrv.listenAndServe(false)
|
||||
assert.NoError(t, err) // Should return no error when context is cancelled and Stop() succeeds
|
||||
}
|
||||
|
||||
func TestHttpServer_TLSConfiguration(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
cancel := func() {}
|
||||
config := &mockServerConfig{
|
||||
baseConfig: &mockBaseConfig{
|
||||
certFile: "cert.pem",
|
||||
keyFile: "key.pem",
|
||||
},
|
||||
}
|
||||
|
||||
server := NewServer(ctx, cancel, "test-server", config, &mockHandler{}, slog.Default(), "")
|
||||
httpSrv := server.(*httpServer)
|
||||
|
||||
// Test TLS configuration setup
|
||||
httpSrv.server.TLSConfig = &tls.Config{
|
||||
MinVersion: tls.VersionTLS12,
|
||||
}
|
||||
|
||||
assert.NotNil(t, httpSrv.server.TLSConfig)
|
||||
assert.Equal(t, uint16(tls.VersionTLS12), httpSrv.server.TLSConfig.MinVersion)
|
||||
}
|
||||
|
||||
// Integration-style test for server lifecycle.
|
||||
func TestHttpServer_Lifecycle(t *testing.T) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond)
|
||||
defer cancel()
|
||||
|
||||
config := &mockServerConfig{
|
||||
baseConfig: &mockBaseConfig{
|
||||
host: "localhost",
|
||||
port: "8080",
|
||||
},
|
||||
}
|
||||
handler := &mockHandler{}
|
||||
|
||||
server := NewServer(ctx, cancel, "test-server", config, handler, slog.Default(), "")
|
||||
|
||||
// Test that server can be created and has expected initial state
|
||||
httpSrv, ok := server.(*httpServer)
|
||||
require.True(t, ok)
|
||||
assert.NotNil(t, httpSrv.server)
|
||||
assert.Equal(t, "localhost:8080", httpSrv.server.Addr)
|
||||
|
||||
// Test Stop without Start (should not panic)
|
||||
err := httpSrv.Stop()
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
@@ -20,7 +20,7 @@ type ServerConfiguration interface {
|
||||
GetBaseConfig() ServerConfig
|
||||
}
|
||||
|
||||
type BaseConfig struct {
|
||||
type Config struct {
|
||||
Host string `env:"HOST" envDefault:"localhost"`
|
||||
Port string `env:"PORT" envDefault:"7001"`
|
||||
ServerCAFile string `env:"SERVER_CA_CERTS" envDefault:""`
|
||||
@@ -30,7 +30,7 @@ type BaseConfig struct {
|
||||
}
|
||||
|
||||
type ServerConfig struct {
|
||||
BaseConfig
|
||||
Config
|
||||
}
|
||||
type AgentConfig struct {
|
||||
ServerConfig
|
||||
@@ -55,19 +55,20 @@ func (a AgentConfig) GetBaseConfig() ServerConfig {
|
||||
return a.ServerConfig
|
||||
}
|
||||
|
||||
func stopAllServer(servers ...Server) error {
|
||||
var errs []error
|
||||
for _, server := range servers {
|
||||
if err := server.Stop(); err != nil {
|
||||
errs = append(errs, err)
|
||||
}
|
||||
}
|
||||
func NewBaseServer(
|
||||
ctx context.Context, cancel context.CancelFunc, name string, config ServerConfiguration, logger *slog.Logger,
|
||||
) BaseServer {
|
||||
cfg := config.GetBaseConfig()
|
||||
address := fmt.Sprintf("%s:%s", cfg.Host, cfg.Port)
|
||||
|
||||
if len(errs) > 0 {
|
||||
return fmt.Errorf("encountered errors while stopping servers: %v", errs)
|
||||
return BaseServer{
|
||||
Ctx: ctx,
|
||||
Cancel: cancel,
|
||||
Name: name,
|
||||
Address: address,
|
||||
Config: config,
|
||||
Logger: logger,
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func StopHandler(ctx context.Context, cancel context.CancelFunc, logger *slog.Logger, svcName string, servers ...Server) error {
|
||||
@@ -87,3 +88,18 @@ func StopHandler(ctx context.Context, cancel context.CancelFunc, logger *slog.Lo
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func stopAllServer(servers ...Server) error {
|
||||
var errs []error
|
||||
for _, server := range servers {
|
||||
if err := server.Stop(); err != nil {
|
||||
errs = append(errs, err)
|
||||
}
|
||||
}
|
||||
|
||||
if len(errs) > 0 {
|
||||
return fmt.Errorf("encountered errors while stopping servers: %v", errs)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -11,7 +11,7 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/ultravioletrs/cocos/internal/server/mocks"
|
||||
"github.com/ultravioletrs/cocos/pkg/server/mocks"
|
||||
)
|
||||
|
||||
func TestStopAllServer(t *testing.T) {
|
||||
@@ -0,0 +1,161 @@
|
||||
// Copyright (c) Ultraviolet
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package server
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrAppendServerCA = errors.New("failed to append server ca to tls.Config")
|
||||
ErrAppendClientCA = errors.New("failed to append client ca to tls.Config")
|
||||
)
|
||||
|
||||
// TLSSetupResult contains the result of TLS configuration setup.
|
||||
type TLSSetupResult struct {
|
||||
Config *tls.Config
|
||||
MTLS bool
|
||||
}
|
||||
|
||||
// LoadCertFile loads certificate data from file path or returns empty byte slice if path is empty.
|
||||
func LoadCertFile(certFile string) ([]byte, error) {
|
||||
if certFile != "" {
|
||||
return ReadFileOrData(certFile)
|
||||
}
|
||||
return []byte{}, nil
|
||||
}
|
||||
|
||||
// ReadFileOrData reads data from file if input looks like a file path,
|
||||
// otherwise treats input as raw data.
|
||||
func ReadFileOrData(input string) ([]byte, error) {
|
||||
if len(input) < 1000 && !strings.Contains(input, "\n") {
|
||||
data, err := os.ReadFile(input)
|
||||
if err == nil {
|
||||
return data, nil
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return []byte(input), nil
|
||||
}
|
||||
|
||||
// LoadX509KeyPair loads X.509 key pair from certificate and key files or data.
|
||||
func LoadX509KeyPair(certfile, keyfile string) (tls.Certificate, error) {
|
||||
cert, err := ReadFileOrData(certfile)
|
||||
if err != nil {
|
||||
return tls.Certificate{}, fmt.Errorf("failed to read cert: %w", err)
|
||||
}
|
||||
|
||||
key, err := ReadFileOrData(keyfile)
|
||||
if err != nil {
|
||||
return tls.Certificate{}, fmt.Errorf("failed to read key: %w", err)
|
||||
}
|
||||
|
||||
return tls.X509KeyPair(cert, key)
|
||||
}
|
||||
|
||||
// ConfigureRootCA configures the root CA certificates for the TLS config.
|
||||
func ConfigureRootCA(tlsConfig *tls.Config, serverCAFile string) error {
|
||||
rootCA, err := LoadCertFile(serverCAFile)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to load server ca file: %w", err)
|
||||
}
|
||||
|
||||
if len(rootCA) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
if tlsConfig.RootCAs == nil {
|
||||
tlsConfig.RootCAs = x509.NewCertPool()
|
||||
}
|
||||
|
||||
if !tlsConfig.RootCAs.AppendCertsFromPEM(rootCA) {
|
||||
return ErrAppendServerCA
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ConfigureClientCA configures the client CA certificates for the TLS config
|
||||
// Returns true if client CA was configured, false otherwise.
|
||||
func ConfigureClientCA(tlsConfig *tls.Config, clientCAFile string) (bool, error) {
|
||||
clientCA, err := LoadCertFile(clientCAFile)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("failed to load client ca file: %w", err)
|
||||
}
|
||||
|
||||
if len(clientCA) == 0 {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
if tlsConfig.ClientCAs == nil {
|
||||
tlsConfig.ClientCAs = x509.NewCertPool()
|
||||
}
|
||||
|
||||
if !tlsConfig.ClientCAs.AppendCertsFromPEM(clientCA) {
|
||||
return false, ErrAppendClientCA
|
||||
}
|
||||
|
||||
return true, nil
|
||||
}
|
||||
|
||||
// ConfigureCertificateAuthorities configures both root and client CAs for the TLS config
|
||||
// Returns true if mTLS should be enabled (client CA is configured).
|
||||
func ConfigureCertificateAuthorities(tlsConfig *tls.Config, serverCAFile, clientCAFile string) (bool, error) {
|
||||
// Configure root CA
|
||||
if err := ConfigureRootCA(tlsConfig, serverCAFile); err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
// Configure client CA
|
||||
hasClientCA, err := ConfigureClientCA(tlsConfig, clientCAFile)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
return hasClientCA, nil
|
||||
}
|
||||
|
||||
// SetupRegularTLS sets up TLS configuration using regular certificates.
|
||||
func SetupRegularTLS(certFile, keyFile, serverCAFile, clientCAFile string) (*TLSSetupResult, error) {
|
||||
certificate, err := LoadX509KeyPair(certFile, keyFile)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to load auth certificates: %w", err)
|
||||
}
|
||||
|
||||
tlsConfig := &tls.Config{
|
||||
ClientAuth: tls.NoClientCert,
|
||||
Certificates: []tls.Certificate{certificate},
|
||||
}
|
||||
|
||||
mtls, err := ConfigureCertificateAuthorities(tlsConfig, serverCAFile, clientCAFile)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if mtls {
|
||||
tlsConfig.ClientAuth = tls.RequireAndVerifyClientCert
|
||||
}
|
||||
|
||||
return &TLSSetupResult{Config: tlsConfig, MTLS: mtls}, nil
|
||||
}
|
||||
|
||||
// BuildMTLSDescription builds a description string for mTLS configuration.
|
||||
func BuildMTLSDescription(serverCAFile, clientCAFile string) string {
|
||||
var parts []string
|
||||
|
||||
if serverCAFile != "" {
|
||||
parts = append(parts, fmt.Sprintf("root ca %s", serverCAFile))
|
||||
}
|
||||
|
||||
if clientCAFile != "" {
|
||||
parts = append(parts, fmt.Sprintf("client ca %s", clientCAFile))
|
||||
}
|
||||
|
||||
return strings.Join(parts, " ")
|
||||
}
|
||||
@@ -0,0 +1,741 @@
|
||||
// Copyright (c) Ultraviolet
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
package server
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"crypto/x509/pkix"
|
||||
"encoding/pem"
|
||||
"math/big"
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Helper function to generate a test certificate and key.
|
||||
func generateTestCert() (certPEM, keyPEM []byte, err error) {
|
||||
// Generate private key
|
||||
privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
// Create certificate template
|
||||
template := x509.Certificate{
|
||||
SerialNumber: big.NewInt(1),
|
||||
Subject: pkix.Name{
|
||||
Organization: []string{"Test Org"},
|
||||
Country: []string{"US"},
|
||||
Province: []string{""},
|
||||
Locality: []string{"Test City"},
|
||||
StreetAddress: []string{""},
|
||||
PostalCode: []string{""},
|
||||
},
|
||||
NotBefore: time.Now(),
|
||||
NotAfter: time.Now().Add(365 * 24 * time.Hour),
|
||||
KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature,
|
||||
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth, x509.ExtKeyUsageClientAuth},
|
||||
IPAddresses: nil,
|
||||
}
|
||||
|
||||
// Create certificate
|
||||
certDER, err := x509.CreateCertificate(rand.Reader, &template, &template, &privateKey.PublicKey, privateKey)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
// Encode certificate
|
||||
certPEM = pem.EncodeToMemory(&pem.Block{
|
||||
Type: "CERTIFICATE",
|
||||
Bytes: certDER,
|
||||
})
|
||||
|
||||
// Encode private key
|
||||
privateKeyDER, err := x509.MarshalPKCS8PrivateKey(privateKey)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
keyPEM = pem.EncodeToMemory(&pem.Block{
|
||||
Type: "PRIVATE KEY",
|
||||
Bytes: privateKeyDER,
|
||||
})
|
||||
|
||||
return certPEM, keyPEM, nil
|
||||
}
|
||||
|
||||
// Helper function to create temporary files for testing.
|
||||
func createTempFile(t *testing.T, content []byte) string {
|
||||
tmpFile, err := os.CreateTemp("", "test-cert-*.pem")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temp file: %v", err)
|
||||
}
|
||||
defer tmpFile.Close()
|
||||
|
||||
if _, err := tmpFile.Write(content); err != nil {
|
||||
t.Fatalf("Failed to write temp file: %v", err)
|
||||
}
|
||||
|
||||
return tmpFile.Name()
|
||||
}
|
||||
|
||||
func TestLoadCertFile(t *testing.T) {
|
||||
certPEM, _, err := generateTestCert()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to generate test cert: %v", err)
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
certFile string
|
||||
wantErr bool
|
||||
setup func() string
|
||||
cleanup func(string)
|
||||
}{
|
||||
{
|
||||
name: "empty cert file path",
|
||||
certFile: "",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "valid cert file",
|
||||
wantErr: false,
|
||||
setup: func() string {
|
||||
return createTempFile(t, certPEM)
|
||||
},
|
||||
cleanup: func(path string) {
|
||||
os.Remove(path)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "non-existent file",
|
||||
certFile: "/non/existent/file.pem",
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
certFile := tt.certFile
|
||||
if tt.setup != nil {
|
||||
certFile = tt.setup()
|
||||
}
|
||||
if tt.cleanup != nil {
|
||||
defer tt.cleanup(certFile)
|
||||
}
|
||||
|
||||
data, err := LoadCertFile(certFile)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("LoadCertFile() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
|
||||
if tt.certFile != "" && !tt.wantErr && len(data) == 0 {
|
||||
t.Errorf("LoadCertFile() with valid file should return data, got empty")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestReadFileOrData(t *testing.T) {
|
||||
testData := "test certificate data"
|
||||
tempFile := createTempFile(t, []byte(testData))
|
||||
defer os.Remove(tempFile)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
want string
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "file path",
|
||||
input: tempFile,
|
||||
want: testData,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "raw data with newlines",
|
||||
input: "-----BEGIN CERTIFICATE-----\nMIIC...\n-----END CERTIFICATE-----",
|
||||
want: "-----BEGIN CERTIFICATE-----\nMIIC...\n-----END CERTIFICATE-----",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "short raw data without newlines",
|
||||
input: "short data",
|
||||
want: "short data",
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "non-existent file path",
|
||||
input: "/non/existent/file.pem",
|
||||
want: "",
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "empty input",
|
||||
input: "",
|
||||
want: "",
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := ReadFileOrData(tt.input)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("ReadFileOrData() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
|
||||
if !tt.wantErr && string(got) != tt.want {
|
||||
t.Errorf("ReadFileOrData() = %v, want %v", string(got), tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadX509KeyPair(t *testing.T) {
|
||||
certPEM, keyPEM, err := generateTestCert()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to generate test cert: %v", err)
|
||||
}
|
||||
|
||||
certFile := createTempFile(t, certPEM)
|
||||
keyFile := createTempFile(t, keyPEM)
|
||||
defer os.Remove(certFile)
|
||||
defer os.Remove(keyFile)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
certfile string
|
||||
keyfile string
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "valid cert and key files",
|
||||
certfile: certFile,
|
||||
keyfile: keyFile,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "valid cert and key data",
|
||||
certfile: string(certPEM),
|
||||
keyfile: string(keyPEM),
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "non-existent cert file",
|
||||
certfile: "/non/existent/cert.pem",
|
||||
keyfile: keyFile,
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "non-existent key file",
|
||||
certfile: certFile,
|
||||
keyfile: "/non/existent/key.pem",
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "invalid cert data",
|
||||
certfile: "invalid cert data",
|
||||
keyfile: string(keyPEM),
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "invalid key data",
|
||||
certfile: string(certPEM),
|
||||
keyfile: "invalid key data",
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
cert, err := LoadX509KeyPair(tt.certfile, tt.keyfile)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("LoadX509KeyPair() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
|
||||
if !tt.wantErr && len(cert.Certificate) == 0 {
|
||||
t.Errorf("LoadX509KeyPair() returned empty certificate")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfigureRootCA(t *testing.T) {
|
||||
certPEM, _, err := generateTestCert()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to generate test cert: %v", err)
|
||||
}
|
||||
|
||||
caFile := createTempFile(t, certPEM)
|
||||
defer os.Remove(caFile)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
tlsConfig *tls.Config
|
||||
serverCAFile string
|
||||
wantErr bool
|
||||
expectCA bool
|
||||
}{
|
||||
{
|
||||
name: "valid CA file",
|
||||
tlsConfig: &tls.Config{},
|
||||
serverCAFile: caFile,
|
||||
wantErr: false,
|
||||
expectCA: true,
|
||||
},
|
||||
{
|
||||
name: "valid CA data",
|
||||
tlsConfig: &tls.Config{},
|
||||
serverCAFile: string(certPEM),
|
||||
wantErr: false,
|
||||
expectCA: true,
|
||||
},
|
||||
{
|
||||
name: "empty CA file",
|
||||
tlsConfig: &tls.Config{},
|
||||
serverCAFile: "",
|
||||
wantErr: false,
|
||||
expectCA: false,
|
||||
},
|
||||
{
|
||||
name: "non-existent CA file",
|
||||
tlsConfig: &tls.Config{},
|
||||
serverCAFile: "/non/existent/ca.pem",
|
||||
wantErr: true,
|
||||
expectCA: false,
|
||||
},
|
||||
{
|
||||
name: "invalid CA data",
|
||||
tlsConfig: &tls.Config{},
|
||||
serverCAFile: "invalid ca data",
|
||||
wantErr: true,
|
||||
expectCA: false,
|
||||
},
|
||||
{
|
||||
name: "existing RootCAs pool",
|
||||
tlsConfig: &tls.Config{RootCAs: x509.NewCertPool()},
|
||||
serverCAFile: caFile,
|
||||
wantErr: false,
|
||||
expectCA: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := ConfigureRootCA(tt.tlsConfig, tt.serverCAFile)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("ConfigureRootCA() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
|
||||
if tt.expectCA && tt.tlsConfig.RootCAs == nil {
|
||||
t.Errorf("ConfigureRootCA() should have created RootCAs pool")
|
||||
}
|
||||
|
||||
if !tt.expectCA && tt.tlsConfig.RootCAs != nil && tt.serverCAFile == "" {
|
||||
t.Errorf("ConfigureRootCA() should not have created RootCAs pool for empty file")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfigureClientCA(t *testing.T) {
|
||||
certPEM, _, err := generateTestCert()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to generate test cert: %v", err)
|
||||
}
|
||||
|
||||
caFile := createTempFile(t, certPEM)
|
||||
defer os.Remove(caFile)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
tlsConfig *tls.Config
|
||||
clientCAFile string
|
||||
wantConfigured bool
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "valid client CA file",
|
||||
tlsConfig: &tls.Config{},
|
||||
clientCAFile: caFile,
|
||||
wantConfigured: true,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "valid client CA data",
|
||||
tlsConfig: &tls.Config{},
|
||||
clientCAFile: string(certPEM),
|
||||
wantConfigured: true,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "empty client CA file",
|
||||
tlsConfig: &tls.Config{},
|
||||
clientCAFile: "",
|
||||
wantConfigured: false,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "non-existent client CA file",
|
||||
tlsConfig: &tls.Config{},
|
||||
clientCAFile: "/non/existent/ca.pem",
|
||||
wantConfigured: false,
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "invalid client CA data",
|
||||
tlsConfig: &tls.Config{},
|
||||
clientCAFile: "invalid ca data",
|
||||
wantConfigured: false,
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "existing ClientCAs pool",
|
||||
tlsConfig: &tls.Config{ClientCAs: x509.NewCertPool()},
|
||||
clientCAFile: caFile,
|
||||
wantConfigured: true,
|
||||
wantErr: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
configured, err := ConfigureClientCA(tt.tlsConfig, tt.clientCAFile)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("ConfigureClientCA() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
|
||||
if configured != tt.wantConfigured {
|
||||
t.Errorf("ConfigureClientCA() configured = %v, want %v", configured, tt.wantConfigured)
|
||||
}
|
||||
|
||||
if tt.wantConfigured && tt.tlsConfig.ClientCAs == nil {
|
||||
t.Errorf("ConfigureClientCA() should have created ClientCAs pool")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfigureCertificateAuthorities(t *testing.T) {
|
||||
certPEM, _, err := generateTestCert()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to generate test cert: %v", err)
|
||||
}
|
||||
|
||||
caFile := createTempFile(t, certPEM)
|
||||
defer os.Remove(caFile)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
tlsConfig *tls.Config
|
||||
serverCAFile string
|
||||
clientCAFile string
|
||||
wantMTLS bool
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "both server and client CA",
|
||||
tlsConfig: &tls.Config{},
|
||||
serverCAFile: caFile,
|
||||
clientCAFile: caFile,
|
||||
wantMTLS: true,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "only server CA",
|
||||
tlsConfig: &tls.Config{},
|
||||
serverCAFile: caFile,
|
||||
clientCAFile: "",
|
||||
wantMTLS: false,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "only client CA",
|
||||
tlsConfig: &tls.Config{},
|
||||
serverCAFile: "",
|
||||
clientCAFile: caFile,
|
||||
wantMTLS: true,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "no CAs",
|
||||
tlsConfig: &tls.Config{},
|
||||
serverCAFile: "",
|
||||
clientCAFile: "",
|
||||
wantMTLS: false,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "invalid server CA",
|
||||
tlsConfig: &tls.Config{},
|
||||
serverCAFile: "/non/existent/server-ca.pem",
|
||||
clientCAFile: caFile,
|
||||
wantMTLS: false,
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "invalid client CA",
|
||||
tlsConfig: &tls.Config{},
|
||||
serverCAFile: caFile,
|
||||
clientCAFile: "/non/existent/client-ca.pem",
|
||||
wantMTLS: false,
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
mtls, err := ConfigureCertificateAuthorities(tt.tlsConfig, tt.serverCAFile, tt.clientCAFile)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("ConfigureCertificateAuthorities() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
|
||||
if mtls != tt.wantMTLS {
|
||||
t.Errorf("ConfigureCertificateAuthorities() mtls = %v, want %v", mtls, tt.wantMTLS)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSetupRegularTLS(t *testing.T) {
|
||||
certPEM, keyPEM, err := generateTestCert()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to generate test cert: %v", err)
|
||||
}
|
||||
|
||||
certFile := createTempFile(t, certPEM)
|
||||
keyFile := createTempFile(t, keyPEM)
|
||||
caFile := createTempFile(t, certPEM)
|
||||
defer func() {
|
||||
os.Remove(certFile)
|
||||
os.Remove(keyFile)
|
||||
os.Remove(caFile)
|
||||
}()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
certFile string
|
||||
keyFile string
|
||||
serverCAFile string
|
||||
clientCAFile string
|
||||
wantMTLS bool
|
||||
wantErr bool
|
||||
expectedAuth tls.ClientAuthType
|
||||
}{
|
||||
{
|
||||
name: "regular TLS without mTLS",
|
||||
certFile: certFile,
|
||||
keyFile: keyFile,
|
||||
serverCAFile: "",
|
||||
clientCAFile: "",
|
||||
wantMTLS: false,
|
||||
wantErr: false,
|
||||
expectedAuth: tls.NoClientCert,
|
||||
},
|
||||
{
|
||||
name: "TLS with mTLS",
|
||||
certFile: certFile,
|
||||
keyFile: keyFile,
|
||||
serverCAFile: caFile,
|
||||
clientCAFile: caFile,
|
||||
wantMTLS: true,
|
||||
wantErr: false,
|
||||
expectedAuth: tls.RequireAndVerifyClientCert,
|
||||
},
|
||||
{
|
||||
name: "TLS with only server CA",
|
||||
certFile: certFile,
|
||||
keyFile: keyFile,
|
||||
serverCAFile: caFile,
|
||||
clientCAFile: "",
|
||||
wantMTLS: false,
|
||||
wantErr: false,
|
||||
expectedAuth: tls.NoClientCert,
|
||||
},
|
||||
{
|
||||
name: "invalid certificate file",
|
||||
certFile: "/non/existent/cert.pem",
|
||||
keyFile: keyFile,
|
||||
serverCAFile: "",
|
||||
clientCAFile: "",
|
||||
wantMTLS: false,
|
||||
wantErr: true,
|
||||
expectedAuth: tls.NoClientCert,
|
||||
},
|
||||
{
|
||||
name: "invalid key file",
|
||||
certFile: certFile,
|
||||
keyFile: "/non/existent/key.pem",
|
||||
serverCAFile: "",
|
||||
clientCAFile: "",
|
||||
wantMTLS: false,
|
||||
wantErr: true,
|
||||
expectedAuth: tls.NoClientCert,
|
||||
},
|
||||
{
|
||||
name: "invalid server CA file",
|
||||
certFile: certFile,
|
||||
keyFile: keyFile,
|
||||
serverCAFile: "/non/existent/server-ca.pem",
|
||||
clientCAFile: "",
|
||||
wantMTLS: false,
|
||||
wantErr: true,
|
||||
expectedAuth: tls.NoClientCert,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result, err := SetupRegularTLS(tt.certFile, tt.keyFile, tt.serverCAFile, tt.clientCAFile)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("SetupRegularTLS() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
|
||||
if tt.wantErr {
|
||||
return
|
||||
}
|
||||
|
||||
if result == nil {
|
||||
t.Errorf("SetupRegularTLS() returned nil result")
|
||||
return
|
||||
}
|
||||
|
||||
if result.MTLS != tt.wantMTLS {
|
||||
t.Errorf("SetupRegularTLS() MTLS = %v, want %v", result.MTLS, tt.wantMTLS)
|
||||
}
|
||||
|
||||
if result.Config.ClientAuth != tt.expectedAuth {
|
||||
t.Errorf("SetupRegularTLS() ClientAuth = %v, want %v", result.Config.ClientAuth, tt.expectedAuth)
|
||||
}
|
||||
|
||||
if len(result.Config.Certificates) == 0 {
|
||||
t.Errorf("SetupRegularTLS() should have at least one certificate")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildMTLSDescription(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
serverCAFile string
|
||||
clientCAFile string
|
||||
want string
|
||||
}{
|
||||
{
|
||||
name: "both server and client CA files",
|
||||
serverCAFile: "/path/to/server-ca.pem",
|
||||
clientCAFile: "/path/to/client-ca.pem",
|
||||
want: "root ca /path/to/server-ca.pem client ca /path/to/client-ca.pem",
|
||||
},
|
||||
{
|
||||
name: "only server CA file",
|
||||
serverCAFile: "/path/to/server-ca.pem",
|
||||
clientCAFile: "",
|
||||
want: "root ca /path/to/server-ca.pem",
|
||||
},
|
||||
{
|
||||
name: "only client CA file",
|
||||
serverCAFile: "",
|
||||
clientCAFile: "/path/to/client-ca.pem",
|
||||
want: "client ca /path/to/client-ca.pem",
|
||||
},
|
||||
{
|
||||
name: "no CA files",
|
||||
serverCAFile: "",
|
||||
clientCAFile: "",
|
||||
want: "",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := BuildMTLSDescription(tt.serverCAFile, tt.clientCAFile)
|
||||
if got != tt.want {
|
||||
t.Errorf("BuildMTLSDescription() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestErrorConstants(t *testing.T) {
|
||||
// Test that error constants are properly defined
|
||||
if ErrAppendServerCA == nil {
|
||||
t.Error("ErrAppendServerCA should not be nil")
|
||||
}
|
||||
|
||||
if ErrAppendClientCA == nil {
|
||||
t.Error("ErrAppendClientCA should not be nil")
|
||||
}
|
||||
|
||||
if ErrAppendServerCA.Error() != "failed to append server ca to tls.Config" {
|
||||
t.Errorf("ErrAppendServerCA message = %v, want 'failed to append server ca to tls.Config'", ErrAppendServerCA.Error())
|
||||
}
|
||||
|
||||
if ErrAppendClientCA.Error() != "failed to append client ca to tls.Config" {
|
||||
t.Errorf("ErrAppendClientCA message = %v, want 'failed to append client ca to tls.Config'", ErrAppendClientCA.Error())
|
||||
}
|
||||
}
|
||||
|
||||
func TestTLSSetupResult(t *testing.T) {
|
||||
// Test that TLSSetupResult struct works as expected
|
||||
config := &tls.Config{}
|
||||
result := &TLSSetupResult{
|
||||
Config: config,
|
||||
MTLS: true,
|
||||
}
|
||||
|
||||
if result.Config != config {
|
||||
t.Error("TLSSetupResult Config field should match assigned value")
|
||||
}
|
||||
|
||||
if !result.MTLS {
|
||||
t.Error("TLSSetupResult MTLS field should be true")
|
||||
}
|
||||
}
|
||||
|
||||
func TestReadFileOrDataEdgeCases(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "999 chars without newline (should try file)",
|
||||
input: strings.Repeat("a", 999),
|
||||
wantErr: true, // Should fail as file doesn't exist
|
||||
},
|
||||
{
|
||||
name: "1001 chars without newline (should treat as data)",
|
||||
input: strings.Repeat("a", 1001),
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "short string with newline (should treat as data)",
|
||||
input: "short\ndata",
|
||||
wantErr: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
_, err := ReadFileOrData(tt.input)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("ReadFileOrData() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
+3
-3
@@ -18,8 +18,8 @@ import (
|
||||
"github.com/ultravioletrs/cocos/agent/cvms"
|
||||
cvmsgrpc "github.com/ultravioletrs/cocos/agent/cvms/api/grpc"
|
||||
"github.com/ultravioletrs/cocos/internal"
|
||||
"github.com/ultravioletrs/cocos/internal/server"
|
||||
grpcserver "github.com/ultravioletrs/cocos/internal/server/grpc"
|
||||
"github.com/ultravioletrs/cocos/pkg/server"
|
||||
grpcserver "github.com/ultravioletrs/cocos/pkg/server/grpc"
|
||||
"golang.org/x/sync/errgroup"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/credentials"
|
||||
@@ -182,7 +182,7 @@ func main() {
|
||||
cvms.RegisterServiceServer(srv, cvmsgrpc.NewServer(incomingChan, &svc{logger: logger}))
|
||||
}
|
||||
grpcServerConfig := server.ServerConfig{
|
||||
BaseConfig: server.BaseConfig{
|
||||
Config: server.Config{
|
||||
Port: defaultPort,
|
||||
},
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user