NOISSUE - Refactor http and grpc clients for reusability with Cube (#521)

* Implement gRPC server with TLS and mTLS support

- Added gRPC server implementation in pkg/server/grpc.
- Introduced server configuration options for TLS and mTLS.
- Implemented health check service for gRPC.
- Created tests for server initialization, startup, and shutdown scenarios.
- Added mock server for testing purposes.
- Implemented graceful shutdown handling for the server.
- Included documentation for the server package.

Signed-off-by: SammyOina <sammyoina@gmail.com>

* Add TLS and ATLS support to gRPC and HTTP clients; refactor security handling

Signed-off-by: SammyOina <sammyoina@gmail.com>

* Refactor server configuration structure to use Config instead of BaseConfig

Signed-off-by: SammyOina <sammyoina@gmail.com>

* Fix comments for consistency and clarity in TLS-related code

Signed-off-by: SammyOina <sammyoina@gmail.com>

* Add comprehensive tests for TLS and ATLS configurations in clients package

Signed-off-by: SammyOina <sammyoina@gmail.com>

* Refactor file permission constants in client tests to use octal notation

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* Add tests for HTTP server's TLS configuration and lifecycle management

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* Add comprehensive tests for TLS certificate handling and configuration

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* Add comprehensive tests for HTTP client configuration and transport

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* Refactor AttestationReportSize constant declaration for clarity

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* Refactor client configuration structure and update gRPC client implementations

- Consolidated client configuration types into a unified structure with BaseConfig.
- Introduced AttestedClientConfig and StandardClientConfig for specific use cases.
- Updated gRPC client creation functions to utilize new configuration types.
- Refactored tests to align with the new configuration structure.
- Removed redundant ClientConfiguration interface and related methods.
- Simplified TLS configuration loading logic for both standard and attested clients.

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* Refactor client configuration structure and TLS handling

- Introduced StandardClientConfig to replace BaseConfig, simplifying client configuration.
- Updated AttestedClientConfig to embed StandardClientConfig instead of BaseConfig.
- Modified ClientConfiguration interface to use Config() method instead of GetBaseConfig().
- Refactored various client tests to accommodate changes in configuration structure.
- Added new TLS handling functions to support basic and attested TLS configurations.
- Implemented comprehensive tests for TLS loading and configuration validation.
- Removed deprecated methods and unnecessary code related to BaseConfig.

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

---------

Signed-off-by: SammyOina <sammyoina@gmail.com>
Signed-off-by: Sammy Oina <sammyoina@gmail.com>
This commit is contained in:
Sammy Kerata Oina
2025-09-18 18:10:20 +03:00
committed by GitHub
parent 5377dd4d7f
commit 906d7877b2
35 changed files with 2961 additions and 524 deletions
+1 -1
View File
@@ -68,7 +68,7 @@ packages:
dir: '{{.InterfaceDir}}/mocks'
structname: '{{.InterfaceName}}'
filename: "{{.InterfaceName | lower}}.go"
github.com/ultravioletrs/cocos/internal/server:
github.com/ultravioletrs/cocos/pkg/server:
interfaces:
Server:
config:
+3 -3
View File
@@ -11,8 +11,8 @@ import (
"github.com/ultravioletrs/cocos/agent"
agentgrpc "github.com/ultravioletrs/cocos/agent/api/grpc"
"github.com/ultravioletrs/cocos/agent/auth"
"github.com/ultravioletrs/cocos/internal/server"
grpcserver "github.com/ultravioletrs/cocos/internal/server/grpc"
"github.com/ultravioletrs/cocos/pkg/server"
grpcserver "github.com/ultravioletrs/cocos/pkg/server/grpc"
"google.golang.org/grpc"
"google.golang.org/grpc/reflection"
)
@@ -53,7 +53,7 @@ func (as *agentServer) Start(cfg agent.AgentConfig, cmp agent.Computation) error
agentGrpcServerConfig := server.AgentConfig{
ServerConfig: server.ServerConfig{
BaseConfig: server.BaseConfig{
Config: server.Config{
Host: as.host,
Port: cfg.Port,
CertFile: cfg.CertFile,
+4 -3
View File
@@ -8,6 +8,7 @@ import (
"github.com/spf13/cobra"
"github.com/ultravioletrs/cocos/manager"
"github.com/ultravioletrs/cocos/pkg/attestation/cmdconfig"
"github.com/ultravioletrs/cocos/pkg/clients"
"github.com/ultravioletrs/cocos/pkg/clients/grpc"
"github.com/ultravioletrs/cocos/pkg/clients/grpc/agent"
managergrpc "github.com/ultravioletrs/cocos/pkg/clients/grpc/manager"
@@ -18,15 +19,15 @@ var Verbose bool
type CLI struct {
agentSDK sdk.SDK
agentConfig grpc.AgentClientConfig
managerConfig grpc.ManagerClientConfig
agentConfig clients.AttestedClientConfig
managerConfig clients.StandardClientConfig
client grpc.Client
managerClient manager.ManagerServiceClient
connectErr error
measurement cmdconfig.MeasurementProvider
}
func New(agentConfig grpc.AgentClientConfig, managerConfig grpc.ManagerClientConfig, measurement cmdconfig.MeasurementProvider) *CLI {
func New(agentConfig clients.AttestedClientConfig, managerConfig clients.StandardClientConfig, measurement cmdconfig.MeasurementProvider) *CLI {
return &CLI{
agentConfig: agentConfig,
managerConfig: managerConfig,
+2 -1
View File
@@ -30,6 +30,7 @@ import (
"github.com/ultravioletrs/cocos/pkg/attestation/azure"
"github.com/ultravioletrs/cocos/pkg/attestation/tdx"
"github.com/ultravioletrs/cocos/pkg/attestation/vtpm"
"github.com/ultravioletrs/cocos/pkg/clients"
pkggrpc "github.com/ultravioletrs/cocos/pkg/clients/grpc"
cvmsgrpc "github.com/ultravioletrs/cocos/pkg/clients/grpc/cvm"
"golang.org/x/sync/errgroup"
@@ -111,7 +112,7 @@ func main() {
provider = &attestation.EmptyProvider{}
}
cvmGrpcConfig := pkggrpc.CVMClientConfig{}
cvmGrpcConfig := clients.StandardClientConfig{}
if err := env.ParseWithOptions(&cvmGrpcConfig, env.Options{Prefix: envPrefixCVMGRPC}); err != nil {
logger.Error(fmt.Sprintf("failed to load %s gRPC client configuration : %s", svcName, err))
exitCode = 1
+3 -3
View File
@@ -15,7 +15,7 @@ import (
"github.com/spf13/pflag"
"github.com/ultravioletrs/cocos/cli"
"github.com/ultravioletrs/cocos/pkg/attestation/cmdconfig"
"github.com/ultravioletrs/cocos/pkg/clients/grpc"
"github.com/ultravioletrs/cocos/pkg/clients"
cmd "github.com/virtee/sev-snp-measure-go/sevsnpmeasure/cmd"
)
@@ -94,14 +94,14 @@ func main() {
return
}
agentGRPCConfig := grpc.AgentClientConfig{}
agentGRPCConfig := clients.AttestedClientConfig{}
if err := env.ParseWithOptions(&agentGRPCConfig, env.Options{Prefix: envPrefixAgentGRPC}); err != nil {
message := color.New(color.FgRed).Sprintf("failed to load %s gRPC client configuration : %s", svcName, err)
rootCmd.Println(message)
return
}
managerGRPCConfig := grpc.ManagerClientConfig{}
managerGRPCConfig := clients.StandardClientConfig{}
if err := env.ParseWithOptions(&managerGRPCConfig, env.Options{Prefix: envPrefixManagerGRPC}); err != nil {
message := color.New(color.FgRed).Sprintf("failed to load %s gRPC client configuration : %s", svcName, err)
rootCmd.Println(message)
+2 -2
View File
@@ -20,14 +20,14 @@ import (
"github.com/absmach/supermq/pkg/uuid"
"github.com/caarlos0/env/v11"
"github.com/go-chi/chi/v5"
"github.com/ultravioletrs/cocos/internal/server"
grpcserver "github.com/ultravioletrs/cocos/internal/server/grpc"
"github.com/ultravioletrs/cocos/manager"
"github.com/ultravioletrs/cocos/manager/api"
managergrpc "github.com/ultravioletrs/cocos/manager/api/grpc"
"github.com/ultravioletrs/cocos/manager/api/http"
"github.com/ultravioletrs/cocos/manager/qemu"
"github.com/ultravioletrs/cocos/manager/tracing"
"github.com/ultravioletrs/cocos/pkg/server"
grpcserver "github.com/ultravioletrs/cocos/pkg/server/grpc"
"go.opentelemetry.io/otel/trace"
"golang.org/x/sync/errgroup"
"google.golang.org/grpc"
+53 -5
View File
@@ -47,9 +47,11 @@ const (
)
var (
SNPvTPMOID = asn1.ObjectIdentifier{2, 99999, 1, 0}
AzureOID = asn1.ObjectIdentifier{2, 99999, 1, 1}
TDXOID = asn1.ObjectIdentifier{2, 99999, 1, 2}
SNPvTPMOID = asn1.ObjectIdentifier{2, 99999, 1, 0}
AzureOID = asn1.ObjectIdentifier{2, 99999, 1, 1}
TDXOID = asn1.ObjectIdentifier{2, 99999, 1, 2}
errCertificateParse = errors.New("failed to parse x509 certificate")
errAttVerification = errors.New("certificate is not self signed")
)
type csrReq struct {
@@ -103,7 +105,7 @@ func getOID(platformType attestation.PlatformType) (asn1.ObjectIdentifier, error
}
}
func GetPlatformTypeFromOID(oid asn1.ObjectIdentifier) (attestation.PlatformType, error) {
func getPlatformTypeFromOID(oid asn1.ObjectIdentifier) (attestation.PlatformType, error) {
switch {
case oid.Equal(SNPvTPMOID):
return attestation.SNPvTPM, nil
@@ -116,7 +118,7 @@ func GetPlatformTypeFromOID(oid asn1.ObjectIdentifier) (attestation.PlatformType
}
}
func VerifyCertificateExtension(extension []byte, pubKey []byte, nonce []byte, pType attestation.PlatformType) error {
func verifyCertificateExtension(extension []byte, pubKey []byte, nonce []byte, pType attestation.PlatformType) error {
teeNonce := append(pubKey, nonce...)
hashNonce := sha3.Sum512(teeNonce)
@@ -331,3 +333,49 @@ func processRequest(method, reqUrl string, data []byte, headers map[string]strin
}
return resp.Header, body, nil
}
// VerifyPeerCertificateATLS verifies peer certificates for Attested TLS.
func VerifyPeerCertificateATLS(rawCerts [][]byte, _ [][]*x509.Certificate, nonce []byte, rootCAs *x509.CertPool) error {
cert, err := x509.ParseCertificate(rawCerts[0])
if err != nil {
return errors.Wrap(errCertificateParse, err)
}
err = verifyCertificateSignature(cert, rootCAs)
if err != nil {
return errors.Wrap(errAttVerification, err)
}
for _, ext := range cert.Extensions {
pType, err := getPlatformTypeFromOID(ext.Id)
if err == nil {
pubKeyDER, err := x509.MarshalPKIXPublicKey(cert.PublicKey)
if err != nil {
return fmt.Errorf("failed to marshal public key to DER format: %w", err)
}
return verifyCertificateExtension(ext.Value, pubKeyDER, nonce, pType)
}
}
return errors.New("attestation extension not found in certificate")
}
// VerifyCertificateSignature verifies the certificate signature against root CAs.
func verifyCertificateSignature(cert *x509.Certificate, rootCAs *x509.CertPool) error {
if rootCAs == nil {
rootCAs = x509.NewCertPool()
rootCAs.AddCert(cert)
}
opts := x509.VerifyOptions{
Roots: rootCAs,
CurrentTime: time.Now(),
}
if _, err := cert.Verify(opts); err != nil {
return err
}
return nil
}
+233 -2
View File
@@ -6,17 +6,21 @@ import (
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"crypto/rsa"
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"encoding/asn1"
"encoding/hex"
"encoding/json"
"fmt"
"math/big"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"testing"
"time"
certssdk "github.com/absmach/certs/sdk"
"github.com/absmach/supermq/pkg/errors"
@@ -217,7 +221,7 @@ func TestGetPlatformTypeFromOID(t *testing.T) {
for _, c := range cases {
t.Run(c.name, func(t *testing.T) {
pType, err := GetPlatformTypeFromOID(c.oid)
pType, err := getPlatformTypeFromOID(c.oid)
if c.expectedError != nil {
assert.Error(t, err)
@@ -305,7 +309,7 @@ func TestVerifyCertificateExtension(t *testing.T) {
for _, c := range cases {
t.Run(c.name, func(t *testing.T) {
err := VerifyCertificateExtension(c.extension, c.pubKey, c.nonce, c.platformType)
err := verifyCertificateExtension(c.extension, c.pubKey, c.nonce, c.platformType)
if c.expectError {
assert.Error(t, err)
} else {
@@ -622,3 +626,230 @@ func setAttestationPolicy(rr *sevsnp.Attestation, policyDirectory string) error
return nil
}
// TestCertificateVerification unified test suite for certificate verification.
func TestCertificateVerification(t *testing.T) {
// Setup common test data
selfSignedCert := createSelfSignedCert(t)
leafCert, rootCert := generateCertificateChain(t)
rootCAs := createCertPool(rootCert)
emptyPool := x509.NewCertPool()
t.Run("SelfSignedCertificates", func(t *testing.T) {
testCases := []testCase{
{
name: "ValidSelfSignedCertificate",
cert: selfSignedCert,
rootCAs: nil,
expectError: false,
},
{
name: "EmptyCertificate",
cert: &x509.Certificate{},
rootCAs: nil,
expectError: true,
errorMsg: "x509: missing ASN.1 contents; use ParseCertificate",
},
}
runCertificateVerificationTests(t, testCases)
})
t.Run("CertificateChainVerification", func(t *testing.T) {
testCases := []testCase{
{
name: "ValidCertificateWithRootCA",
cert: leafCert,
rootCAs: rootCAs,
expectError: false,
},
{
name: "SelfSignedCertificate",
cert: rootCert,
rootCAs: nil, // Self-signed verification
expectError: false,
},
{
name: "InvalidCertificateWithEmptyPool",
cert: rootCert,
rootCAs: emptyPool,
expectError: true,
},
}
runCertificateVerificationTests(t, testCases)
})
t.Run("ATLSPeerCertificateVerification", func(t *testing.T) {
nonce := generateNonce(t)
testCases := []atlsTestCase{
{
name: "InvalidCertificateData",
rawCerts: [][]byte{[]byte("invalid cert data")},
nonce: nonce,
rootCAs: rootCAs,
expectError: true,
errorMsg: "failed to parse x509 certificate",
},
{
name: "ValidCertificateNoAttestationExtension",
rawCerts: [][]byte{leafCert.Raw},
nonce: nonce,
rootCAs: rootCAs,
expectError: true,
errorMsg: "attestation extension not found in certificate",
},
}
runATLSVerificationTests(t, testCases)
})
}
// Unified test case structures.
type testCase struct {
name string
cert *x509.Certificate
rootCAs *x509.CertPool
expectError bool
errorMsg string
}
type atlsTestCase struct {
name string
rawCerts [][]byte
nonce []byte
rootCAs *x509.CertPool
expectError bool
errorMsg string
}
// Unified test runners.
func runCertificateVerificationTests(t *testing.T, testCases []testCase) {
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
err := verifyCertificateSignature(tc.cert, tc.rootCAs)
if tc.expectError {
assert.Error(t, err)
if tc.errorMsg != "" {
if tc.errorMsg == "x509: missing ASN.1 contents; use ParseCertificate" {
// For specific error matching
assert.True(t, errors.Contains(err, errors.New(tc.errorMsg)),
fmt.Sprintf("expected error %q, got %v", tc.errorMsg, err))
} else {
assert.Contains(t, err.Error(), tc.errorMsg)
}
}
} else {
assert.NoError(t, err)
}
})
}
}
func runATLSVerificationTests(t *testing.T, testCases []atlsTestCase) {
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
err := VerifyPeerCertificateATLS(tc.rawCerts, nil, tc.nonce, tc.rootCAs)
if tc.expectError {
assert.Error(t, err)
if tc.errorMsg != "" {
assert.Contains(t, err.Error(), tc.errorMsg)
}
} else {
assert.NoError(t, err)
}
})
}
}
// Unified certificate creation utilities.
func createSelfSignedCert(t *testing.T) *x509.Certificate {
privateKey := generateRSAKey(t)
template := x509.Certificate{
SerialNumber: big.NewInt(1),
Subject: pkix.Name{
Organization: []string{"Test Org"},
},
NotBefore: time.Now(),
NotAfter: time.Now().Add(24 * time.Hour), // Consistent duration
KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
BasicConstraintsValid: true,
}
return createCertificateFromTemplate(t, &template, &template, &privateKey.PublicKey, privateKey)
}
func generateCertificateChain(t *testing.T) (leafCert, rootCert *x509.Certificate) {
// Generate root certificate
rootKey := generateRSAKey(t)
rootTemplate := x509.Certificate{
SerialNumber: big.NewInt(1),
Subject: pkix.Name{
Organization: []string{"Test Root CA"},
Country: []string{"US"},
},
NotBefore: time.Now(),
NotAfter: time.Now().Add(24 * time.Hour),
KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
BasicConstraintsValid: true,
IsCA: true,
}
rootCert = createCertificateFromTemplate(t, &rootTemplate, &rootTemplate, &rootKey.PublicKey, rootKey)
// Generate leaf certificate signed by root
leafKey := generateRSAKey(t)
leafTemplate := x509.Certificate{
SerialNumber: big.NewInt(2),
Subject: pkix.Name{
Organization: []string{"Test Leaf"},
Country: []string{"US"},
},
NotBefore: time.Now(),
NotAfter: time.Now().Add(24 * time.Hour),
KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
}
leafCert = createCertificateFromTemplate(t, &leafTemplate, &rootTemplate, &leafKey.PublicKey, rootKey)
return leafCert, rootCert
}
// Helper functions for consistency.
func generateRSAKey(t *testing.T) *rsa.PrivateKey {
privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
require.NoError(t, err)
return privateKey
}
func createCertificateFromTemplate(t *testing.T, template, parent *x509.Certificate, pub interface{}, priv interface{}) *x509.Certificate {
certDER, err := x509.CreateCertificate(rand.Reader, template, parent, pub, priv)
require.NoError(t, err)
cert, err := x509.ParseCertificate(certDER)
require.NoError(t, err)
return cert
}
func createCertPool(certs ...*x509.Certificate) *x509.CertPool {
pool := x509.NewCertPool()
for _, cert := range certs {
pool.AddCert(cert)
}
return pool
}
func generateNonce(t *testing.T) []byte {
nonce := make([]byte, 64)
_, err := rand.Read(nonce)
require.NoError(t, err)
return nonce
}
+40
View File
@@ -0,0 +1,40 @@
// Copyright (c) Ultraviolet
// SPDX-License-Identifier: Apache-2.0
package clients
import "time"
var (
_ ClientConfiguration = (*AttestedClientConfig)(nil)
_ ClientConfiguration = (*StandardClientConfig)(nil)
)
type ClientConfiguration interface {
Config() StandardClientConfig
}
// StandardClientConfig represents a basic client configuration without attested TLS.
type StandardClientConfig struct {
URL string `env:"URL" envDefault:"localhost:7001"`
Timeout time.Duration `env:"TIMEOUT" envDefault:"60s"`
ClientCert string `env:"CLIENT_CERT" envDefault:""`
ClientKey string `env:"CLIENT_KEY" envDefault:""`
ServerCAFile string `env:"SERVER_CA_CERTS" envDefault:""`
}
// AttestedClientConfig represents a client configuration with attested TLS capabilities.
type AttestedClientConfig struct {
StandardClientConfig
AttestationPolicy string `env:"ATTESTATION_POLICY" envDefault:""`
AttestedTLS bool `env:"ATTESTED_TLS" envDefault:"false"`
ProductName string `env:"PRODUCT_NAME" envDefault:"Milan"`
}
func (c AttestedClientConfig) Config() StandardClientConfig {
return c.StandardClientConfig
}
func (c StandardClientConfig) Config() StandardClientConfig {
return c
}
+1 -1
View File
@@ -2,5 +2,5 @@
// SPDX-License-Identifier: Apache-2.0
// Package clients contains the domain concept definitions needed to support
// Agent Client functionality.
// HTTP/gRPC Client functionality.
package clients
+3 -2
View File
@@ -7,6 +7,7 @@ import (
"github.com/absmach/supermq/pkg/errors"
"github.com/ultravioletrs/cocos/agent"
"github.com/ultravioletrs/cocos/pkg/clients"
"github.com/ultravioletrs/cocos/pkg/clients/grpc"
grpchealth "google.golang.org/grpc/health/grpc_health_v1"
)
@@ -14,13 +15,13 @@ import (
var ErrAgentServiceUnavailable = errors.New("agent service is unavailable")
// NewAgentClient creates new agent gRPC client instance.
func NewAgentClient(ctx context.Context, cfg grpc.AgentClientConfig) (grpc.Client, agent.AgentServiceClient, error) {
func NewAgentClient(ctx context.Context, cfg clients.AttestedClientConfig) (grpc.Client, agent.AgentServiceClient, error) {
client, err := grpc.NewClient(cfg)
if err != nil {
return nil, nil, err
}
if client.Secure() != grpc.WithMATLS && client.Secure() != grpc.WithATLS && client.Secure() != grpc.WithTLS {
if client.Secure() != clients.WithMATLS.String() && client.Secure() != clients.WithATLS.String() && client.Secure() != clients.WithTLS.String() {
health := grpchealth.NewHealthClient(client.Connection())
resp, err := health.Check(ctx, &grpchealth.HealthCheckRequest{
Service: "agent",
+8 -8
View File
@@ -15,7 +15,7 @@ import (
"github.com/ultravioletrs/cocos/agent"
agentgrpc "github.com/ultravioletrs/cocos/agent/api/grpc"
"github.com/ultravioletrs/cocos/agent/mocks"
pkggrpc "github.com/ultravioletrs/cocos/pkg/clients/grpc"
"github.com/ultravioletrs/cocos/pkg/clients"
"google.golang.org/grpc"
"google.golang.org/grpc/health"
grpchealth "google.golang.org/grpc/health/grpc_health_v1"
@@ -78,14 +78,14 @@ func TestAgentClientIntegration(t *testing.T) {
tests := []struct {
name string
serverRunning bool
config pkggrpc.AgentClientConfig
config clients.AttestedClientConfig
err error
}{
{
name: "successful connection",
serverRunning: true,
config: pkggrpc.AgentClientConfig{
BaseConfig: pkggrpc.BaseConfig{
config: clients.AttestedClientConfig{
StandardClientConfig: clients.StandardClientConfig{
URL: testServer.listenAddr,
Timeout: 1,
},
@@ -95,8 +95,8 @@ func TestAgentClientIntegration(t *testing.T) {
{
name: "server not healthy",
serverRunning: false,
config: pkggrpc.AgentClientConfig{
BaseConfig: pkggrpc.BaseConfig{
config: clients.AttestedClientConfig{
StandardClientConfig: clients.StandardClientConfig{
URL: "",
Timeout: 1,
},
@@ -105,8 +105,8 @@ func TestAgentClientIntegration(t *testing.T) {
},
{
name: "invalid config, missing AttestationPolicy with aTLS",
config: pkggrpc.AgentClientConfig{
BaseConfig: pkggrpc.BaseConfig{
config: clients.AttestedClientConfig{
StandardClientConfig: clients.StandardClientConfig{
URL: testServer.listenAddr,
Timeout: 1,
},
-133
View File
@@ -1,133 +0,0 @@
// Copyright (c) Ultraviolet
// SPDX-License-Identifier: Apache-2.0
package grpc
import (
"crypto/rand"
"crypto/tls"
"crypto/x509"
"encoding/hex"
"encoding/pem"
"fmt"
"os"
"time"
"github.com/absmach/supermq/pkg/errors"
"github.com/ultravioletrs/cocos/pkg/atls"
"github.com/ultravioletrs/cocos/pkg/attestation"
"google.golang.org/grpc/credentials"
)
func setupATLS(cfg AgentClientConfig) (credentials.TransportCredentials, security, error) {
security := withaTLS
info, err := os.Stat(cfg.AttestationPolicy)
if err != nil {
return nil, withoutTLS, errors.Wrap(fmt.Errorf("failed to stat attestation policy file"), err)
}
if !info.Mode().IsRegular() {
return nil, withoutTLS, fmt.Errorf("attestation policy file is not a regular file: %s", cfg.AttestationPolicy)
}
attestation.AttestationPolicyPath = cfg.AttestationPolicy
var rootCAs *x509.CertPool = nil
if len(cfg.ServerCAFile) > 0 {
// Read the certificate file
certPEM, err := os.ReadFile(cfg.ServerCAFile)
if err != nil {
return nil, withoutTLS, errors.Wrap(fmt.Errorf("failed to read certificate file"), err)
}
// Decode the PEM block
block, _ := pem.Decode(certPEM)
if block == nil {
return nil, withoutTLS, fmt.Errorf("failed to decode PEM block")
}
// Parse the certificate
cert, err := x509.ParseCertificate(block.Bytes)
if err != nil {
return nil, withoutTLS, errors.Wrap(fmt.Errorf("failed to parse certificate"), err)
}
rootCAs = x509.NewCertPool()
rootCAs.AddCert(cert)
security = withmaTLS
}
nonce := make([]byte, 64)
if _, err := rand.Read(nonce); err != nil {
return nil, withoutTLS, errors.Wrap(fmt.Errorf("failed to generate nonce"), err)
}
encoded := hex.EncodeToString(nonce)
sni := fmt.Sprintf("%s.nonce", encoded)
tlsConfig := &tls.Config{
InsecureSkipVerify: true,
RootCAs: rootCAs,
ServerName: sni,
VerifyPeerCertificate: func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error {
return verifyPeerCertificateATLS(rawCerts, verifiedChains, nonce, rootCAs)
},
}
if cfg.ClientCert != "" || cfg.ClientKey != "" {
certificate, err := tls.LoadX509KeyPair(cfg.ClientCert, cfg.ClientKey)
if err != nil {
return nil, withoutTLS, errors.Wrap(errFailedToLoadClientCertKey, err)
}
tlsConfig.Certificates = []tls.Certificate{certificate}
}
return credentials.NewTLS(tlsConfig), security, nil
}
func verifyPeerCertificateATLS(rawCerts [][]byte, verifiedChains [][]*x509.Certificate, nonce []byte, rootCAs *x509.CertPool) error {
cert, err := x509.ParseCertificate(rawCerts[0])
if err != nil {
return errors.Wrap(errCertificateParse, err)
}
err = checkIfCertificateSigned(cert, rootCAs)
if err != nil {
return errors.Wrap(errAttVerification, err)
}
for _, ext := range cert.Extensions {
pType, err := atls.GetPlatformTypeFromOID(ext.Id)
if err == nil {
pubKeyDER, err := x509.MarshalPKIXPublicKey(cert.PublicKey)
if err != nil {
return fmt.Errorf("failed to marshal public key to DER format: %w", err)
}
return atls.VerifyCertificateExtension(ext.Value, pubKeyDER, nonce, pType)
}
}
return fmt.Errorf("attestation extension not found in certificate")
}
func checkIfCertificateSigned(cert *x509.Certificate, rootCAs *x509.CertPool) error {
if rootCAs == nil {
rootCAs = x509.NewCertPool()
rootCAs.AddCert(cert)
}
opts := x509.VerifyOptions{
Roots: rootCAs,
CurrentTime: time.Now(),
}
if _, err := cert.Verify(opts); err != nil {
return err
}
return nil
}
+26 -78
View File
@@ -21,6 +21,7 @@ import (
"github.com/stretchr/testify/require"
"github.com/ultravioletrs/cocos/pkg/attestation"
"github.com/ultravioletrs/cocos/pkg/attestation/vtpm"
"github.com/ultravioletrs/cocos/pkg/clients"
)
func TestNewClient(t *testing.T) {
@@ -35,14 +36,14 @@ func TestNewClient(t *testing.T) {
tests := []struct {
name string
cfg BaseConfig
agentCfg AgentClientConfig
cfg clients.StandardClientConfig
agentCfg clients.AttestedClientConfig
wantErr bool
err error
}{
{
name: "Success without TLS",
cfg: BaseConfig{
cfg: clients.StandardClientConfig{
URL: "localhost:7001",
},
wantErr: false,
@@ -50,7 +51,7 @@ func TestNewClient(t *testing.T) {
},
{
name: "Success with TLS",
cfg: BaseConfig{
cfg: clients.StandardClientConfig{
URL: "localhost:7001",
ServerCAFile: caCertFile,
},
@@ -59,7 +60,7 @@ func TestNewClient(t *testing.T) {
},
{
name: "Success with mTLS",
cfg: BaseConfig{
cfg: clients.StandardClientConfig{
URL: "localhost:7001",
ServerCAFile: caCertFile,
ClientCert: clientCertFile,
@@ -70,8 +71,8 @@ func TestNewClient(t *testing.T) {
},
{
name: "Success agent client with mTLS",
agentCfg: AgentClientConfig{
BaseConfig: BaseConfig{
agentCfg: clients.AttestedClientConfig{
StandardClientConfig: clients.StandardClientConfig{
URL: "localhost:7001",
ServerCAFile: caCertFile,
ClientCert: clientCertFile,
@@ -83,8 +84,8 @@ func TestNewClient(t *testing.T) {
},
{
name: "Success agent client with aTLS",
agentCfg: AgentClientConfig{
BaseConfig: BaseConfig{
agentCfg: clients.AttestedClientConfig{
StandardClientConfig: clients.StandardClientConfig{
URL: "localhost:7001",
ServerCAFile: caCertFile,
ClientCert: clientCertFile,
@@ -98,8 +99,8 @@ func TestNewClient(t *testing.T) {
},
{
name: "Failed agent client with aTLS",
agentCfg: AgentClientConfig{
BaseConfig: BaseConfig{
agentCfg: clients.AttestedClientConfig{
StandardClientConfig: clients.StandardClientConfig{
URL: "localhost:7001",
ServerCAFile: caCertFile,
ClientCert: clientCertFile,
@@ -113,34 +114,34 @@ func TestNewClient(t *testing.T) {
},
{
name: "Fail with invalid ServerCAFile",
cfg: BaseConfig{
cfg: clients.StandardClientConfig{
URL: "localhost:7001",
ServerCAFile: "nonexistent.pem",
},
wantErr: true,
err: errFailedToLoadRootCA,
err: clients.ErrFailedToLoadRootCA,
},
{
name: "Fail with invalid ClientCert",
cfg: BaseConfig{
cfg: clients.StandardClientConfig{
URL: "localhost:7001",
ServerCAFile: caCertFile,
ClientCert: "nonexistent.pem",
ClientKey: clientKeyFile,
},
wantErr: true,
err: errFailedToLoadClientCertKey,
err: clients.ErrFailedToLoadClientCertKey,
},
{
name: "Fail with invalid ClientKey",
cfg: BaseConfig{
cfg: clients.StandardClientConfig{
URL: "localhost:7001",
ServerCAFile: caCertFile,
ClientCert: clientCertFile,
ClientKey: "nonexistent.pem",
},
wantErr: true,
err: errFailedToLoadClientCertKey,
err: clients.ErrFailedToLoadClientCertKey,
},
}
@@ -168,39 +169,39 @@ func TestNewClient(t *testing.T) {
func TestClientSecure(t *testing.T) {
tests := []struct {
name string
secure security
secure clients.Security
expected string
}{
{
name: "Without TLS",
secure: withoutTLS,
secure: clients.WithoutTLS,
expected: "without TLS",
},
{
name: "With TLS",
secure: withTLS,
secure: clients.WithTLS,
expected: "with TLS",
},
{
name: "With mTLS",
secure: withmTLS,
secure: clients.WithMTLS,
expected: "with mTLS",
},
{
name: "With aTLS",
secure: withaTLS,
secure: clients.WithATLS,
expected: "with aTLS",
},
{
name: "With maTLS",
secure: withmaTLS,
expected: WithMATLS,
secure: clients.WithMATLS,
expected: "with maTLS",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
c := &client{secure: tt.secure}
c := &client{security: tt.secure}
assert.Equal(t, tt.expected, c.Secure())
})
}
@@ -354,56 +355,3 @@ func createTempFile(data []byte) (string, error) {
func createTempFileHandle() (*os.File, error) {
return os.CreateTemp("", "test")
}
func TestCheckIfCertificateSelfSigned(t *testing.T) {
selfSignedCert := createSelfSignedCert(t)
tests := []struct {
name string
cert *x509.Certificate
err error
}{
{
name: "Self-signed certificate",
cert: selfSignedCert,
err: nil,
},
{
name: "missing certificate contents",
cert: &x509.Certificate{},
err: errors.New("x509: missing ASN.1 contents; use ParseCertificate"),
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := checkIfCertificateSigned(tt.cert, nil)
assert.True(t, errors.Contains(err, tt.err), fmt.Sprintf("expected error %v, got %v", tt.err, err))
})
}
}
func createSelfSignedCert(t *testing.T) *x509.Certificate {
privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
require.NoError(t, err)
template := x509.Certificate{
SerialNumber: big.NewInt(1),
Subject: pkix.Name{
Organization: []string{"Test Org"},
},
NotBefore: time.Now(),
NotAfter: time.Now().Add(time.Hour * 24),
KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
BasicConstraintsValid: true,
}
certDER, err := x509.CreateCertificate(rand.Reader, &template, &template, &privateKey.PublicKey, privateKey)
require.NoError(t, err)
cert, err := x509.ParseCertificate(certDER)
require.NoError(t, err)
return cert
}
+2 -1
View File
@@ -4,11 +4,12 @@ package cvm
import (
"github.com/ultravioletrs/cocos/agent/cvms"
"github.com/ultravioletrs/cocos/pkg/clients"
"github.com/ultravioletrs/cocos/pkg/clients/grpc"
)
// NewManagerClient creates new manager gRPC client instance.
func NewCVMClient(cfg grpc.CVMClientConfig) (grpc.Client, cvms.ServiceClient, error) {
func NewCVMClient(cfg clients.StandardClientConfig) (grpc.Client, cvms.ServiceClient, error) {
client, err := grpc.NewClient(cfg)
if err != nil {
return nil, nil, err
+8 -12
View File
@@ -14,7 +14,7 @@ import (
"github.com/ultravioletrs/cocos/agent"
agentgrpc "github.com/ultravioletrs/cocos/agent/api/grpc"
"github.com/ultravioletrs/cocos/agent/mocks"
pkggrpc "github.com/ultravioletrs/cocos/pkg/clients/grpc"
"github.com/ultravioletrs/cocos/pkg/clients"
"google.golang.org/grpc"
"google.golang.org/grpc/health"
grpchealth "google.golang.org/grpc/health/grpc_health_v1"
@@ -77,28 +77,24 @@ func TestAgentClientIntegration(t *testing.T) {
tests := []struct {
name string
serverRunning bool
config pkggrpc.CVMClientConfig
config clients.StandardClientConfig
err error
}{
{
name: "successful connection",
serverRunning: true,
config: pkggrpc.CVMClientConfig{
BaseConfig: pkggrpc.BaseConfig{
URL: testServer.listenAddr,
Timeout: 1,
},
config: clients.StandardClientConfig{
URL: testServer.listenAddr,
Timeout: 1,
},
err: nil,
},
{
name: "server not healthy",
serverRunning: false,
config: pkggrpc.CVMClientConfig{
BaseConfig: pkggrpc.BaseConfig{
URL: "",
Timeout: 1,
},
config: clients.StandardClientConfig{
URL: "",
Timeout: 1,
},
err: errors.New("failed to connect to grpc server"),
},
+29 -127
View File
@@ -4,84 +4,19 @@
package grpc
import (
"crypto/tls"
"crypto/x509"
"fmt"
"os"
"time"
"github.com/absmach/supermq/pkg/errors"
"github.com/ultravioletrs/cocos/pkg/clients"
"go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/credentials/insecure"
)
type security int
const (
withoutTLS security = iota
withTLS
withmTLS
withaTLS
withmaTLS
)
const (
AttestationReportSize = 0x4A0
WithMATLS = "with maTLS"
WithATLS = "with aTLS"
WithTLS = "with TLS"
)
var (
errGrpcConnect = errors.New("failed to connect to grpc server")
errGrpcClose = errors.New("failed to close grpc connection")
errCertificateParse = errors.New("failed to parse x509 certificate")
errAttVerification = errors.New("certificat is not self signed")
errFailedToLoadClientCertKey = errors.New("failed to load client certificate and key")
errFailedToLoadRootCA = errors.New("failed to load root ca file")
errGrpcConnect = errors.New("failed to connect to grpc server")
errGrpcClose = errors.New("failed to close grpc connection")
)
type ClientConfiguration interface {
GetBaseConfig() BaseConfig
}
type BaseConfig struct {
URL string `env:"URL" envDefault:"localhost:7001"`
Timeout time.Duration `env:"TIMEOUT" envDefault:"60s"`
ClientCert string `env:"CLIENT_CERT" envDefault:""`
ClientKey string `env:"CLIENT_KEY" envDefault:""`
ServerCAFile string `env:"SERVER_CA_CERTS" envDefault:""`
}
type AgentClientConfig struct {
BaseConfig
AttestationPolicy string `env:"ATTESTATION_POLICY" envDefault:""`
AttestedTLS bool `env:"ATTESTED_TLS" envDefault:"false"`
ProductName string `env:"PRODUCT_NAME" envDefault:"Milan"`
}
type ManagerClientConfig struct {
BaseConfig
}
type CVMClientConfig struct {
BaseConfig
}
func (a BaseConfig) GetBaseConfig() BaseConfig {
return a
}
func (a AgentClientConfig) GetBaseConfig() BaseConfig {
return a.BaseConfig
}
func (a CVMClientConfig) GetBaseConfig() BaseConfig {
return a.BaseConfig
}
type Client interface {
Close() error
Secure() string
@@ -90,14 +25,14 @@ type Client interface {
type client struct {
*grpc.ClientConn
cfg ClientConfiguration
secure security
cfg clients.ClientConfiguration
security clients.Security
}
var _ Client = (*client)(nil)
func NewClient(cfg ClientConfiguration) (Client, error) {
conn, secure, err := connect(cfg)
func NewClient(cfg clients.ClientConfiguration) (Client, error) {
conn, security, err := connect(cfg)
if err != nil {
return nil, err
}
@@ -105,7 +40,7 @@ func NewClient(cfg ClientConfiguration) (Client, error) {
return &client{
ClientConn: conn,
cfg: cfg,
secure: secure,
security: security,
}, nil
}
@@ -117,86 +52,53 @@ func (c *client) Close() error {
}
func (c *client) Secure() string {
switch c.secure {
case withTLS:
return WithTLS
case withmTLS:
return "with mTLS"
case withaTLS:
return "with aTLS"
case withmaTLS:
return WithMATLS
default:
return "without TLS"
}
return c.security.String()
}
func (c *client) Connection() *grpc.ClientConn {
return c.ClientConn
}
func connect(cfg ClientConfiguration) (*grpc.ClientConn, security, error) {
func connect(cfg clients.ClientConfiguration) (*grpc.ClientConn, clients.Security, error) {
opts := []grpc.DialOption{
grpc.WithStatsHandler(otelgrpc.NewClientHandler()),
}
secure := withoutTLS
security := clients.WithoutTLS
if agcfg, ok := cfg.(AgentClientConfig); ok && agcfg.AttestedTLS {
tc, sec, err := setupATLS(agcfg)
if agcfg, ok := cfg.(clients.AttestedClientConfig); ok && agcfg.AttestedTLS {
result, err := clients.LoadATLSConfig(agcfg)
if err != nil {
return nil, secure, err
return nil, security, err
}
opts = append(opts, grpc.WithTransportCredentials(tc))
secure = sec
opts = append(opts, grpc.WithTransportCredentials(credentials.NewTLS(result.Config)))
security = result.Security
} else {
conf := cfg.GetBaseConfig()
conf := cfg.Config()
transportCreds, sec, err := loadTLSConfig(conf.ServerCAFile, conf.ClientCert, conf.ClientKey)
if err != nil {
return nil, secure, err
return nil, security, err
}
opts = append(opts, grpc.WithTransportCredentials(transportCreds))
secure = sec
security = sec
}
conn, err := grpc.Dial(cfg.GetBaseConfig().URL, opts...)
conn, err := grpc.NewClient(cfg.Config().URL, opts...)
if err != nil {
return nil, secure, errors.Wrap(errGrpcConnect, err)
return nil, security, errors.Wrap(errGrpcConnect, err)
}
return conn, secure, nil
return conn, security, nil
}
func loadTLSConfig(serverCAFile, clientCert, clientKey string) (credentials.TransportCredentials, security, error) {
tlsConfig := &tls.Config{}
secure := withoutTLS
tc := insecure.NewCredentials()
if serverCAFile != "" {
rootCA, err := os.ReadFile(serverCAFile)
if err != nil {
return nil, secure, errors.Wrap(errFailedToLoadRootCA, err)
}
if len(rootCA) > 0 {
capool := x509.NewCertPool()
if !capool.AppendCertsFromPEM(rootCA) {
return nil, secure, fmt.Errorf("failed to append root ca to tls.Config")
}
tlsConfig.RootCAs = capool
secure = withTLS
tc = credentials.NewTLS(tlsConfig)
}
func loadTLSConfig(serverCAFile, clientCert, clientKey string) (credentials.TransportCredentials, clients.Security, error) {
result, err := clients.LoadBasicTLSConfig(serverCAFile, clientCert, clientKey)
if err != nil {
return nil, clients.WithoutTLS, err
}
if clientCert != "" || clientKey != "" {
certificate, err := tls.LoadX509KeyPair(clientCert, clientKey)
if err != nil {
return nil, secure, errors.Wrap(errFailedToLoadClientCertKey, err)
}
tlsConfig.Certificates = []tls.Certificate{certificate}
secure = withmTLS
tc = credentials.NewTLS(tlsConfig)
if result.Security == clients.WithoutTLS || result.Config == nil {
return insecure.NewCredentials(), result.Security, nil
}
return tc, secure, nil
return credentials.NewTLS(result.Config), result.Security, nil
}
+2 -1
View File
@@ -4,11 +4,12 @@ package manager
import (
"github.com/ultravioletrs/cocos/manager"
"github.com/ultravioletrs/cocos/pkg/clients"
"github.com/ultravioletrs/cocos/pkg/clients/grpc"
)
// NewManagerClient creates new manager gRPC client instance.
func NewManagerClient(cfg grpc.ManagerClientConfig) (grpc.Client, manager.ManagerServiceClient, error) {
func NewManagerClient(cfg clients.StandardClientConfig) (grpc.Client, manager.ManagerServiceClient, error) {
client, err := grpc.NewClient(cfg)
if err != nil {
return nil, nil, err
+4 -6
View File
@@ -7,21 +7,19 @@ import (
"github.com/absmach/supermq/pkg/errors"
"github.com/stretchr/testify/assert"
"github.com/ultravioletrs/cocos/pkg/clients/grpc"
"github.com/ultravioletrs/cocos/pkg/clients"
)
func TestNewManagerClient(t *testing.T) {
tests := []struct {
name string
cfg grpc.ManagerClientConfig
cfg clients.StandardClientConfig
err error
}{
{
name: "Valid config",
cfg: grpc.ManagerClientConfig{
BaseConfig: grpc.BaseConfig{
URL: "localhost:7001",
},
cfg: clients.StandardClientConfig{
URL: "localhost:7001",
},
err: nil,
},
+85
View File
@@ -0,0 +1,85 @@
// Copyright (c) Ultraviolet
// SPDX-License-Identifier: Apache-2.0
package http
import (
"net/http"
"time"
"github.com/ultravioletrs/cocos/pkg/clients"
)
type Client interface {
Transport() *http.Transport
Secure() string
Timeout() time.Duration
}
type client struct {
transport *http.Transport
cfg clients.ClientConfiguration
security clients.Security
}
var _ Client = (*client)(nil)
func NewClient(cfg clients.ClientConfiguration) (Client, error) {
transport, security, err := createTransport(cfg)
if err != nil {
return nil, err
}
return &client{
transport: transport,
cfg: cfg,
security: security,
}, nil
}
func (c *client) Transport() *http.Transport {
return c.transport
}
func (c *client) Secure() string {
return c.security.String()
}
func (c *client) Timeout() time.Duration {
return c.cfg.Config().Timeout
}
func createTransport(cfg clients.ClientConfiguration) (*http.Transport, clients.Security, error) {
transport := &http.Transport{
MaxIdleConns: 100,
IdleConnTimeout: 90 * time.Second,
TLSHandshakeTimeout: 10 * time.Second,
}
security := clients.WithoutTLS
if agcfg, ok := cfg.(*clients.AttestedClientConfig); ok && agcfg.AttestedTLS {
result, err := clients.LoadATLSConfig(*agcfg)
if err != nil {
return nil, security, err
}
transport.TLSClientConfig = result.Config
security = result.Security
} else {
conf := cfg.Config()
result, err := clients.LoadBasicTLSConfig(conf.ServerCAFile, conf.ClientCert, conf.ClientKey)
if err != nil {
return nil, security, err
}
if result.Security != clients.WithoutTLS {
transport.TLSClientConfig = result.Config
}
security = result.Security
}
return transport, security, nil
}
+292
View File
@@ -0,0 +1,292 @@
// Copyright (c) Ultraviolet
// SPDX-License-Identifier: Apache-2.0
package http
import (
"net/http"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/ultravioletrs/cocos/pkg/clients"
)
func TestConfig_Configuration(t *testing.T) {
config := clients.StandardClientConfig{
URL: "http://localhost:8080",
Timeout: 30 * time.Second,
ClientCert: "cert.pem",
ClientKey: "key.pem",
ServerCAFile: "ca.pem",
}
result := config.Config()
assert.Equal(t, config, result)
assert.Equal(t, "http://localhost:8080", result.URL)
assert.Equal(t, 30*time.Second, result.Timeout)
assert.Equal(t, "cert.pem", result.ClientCert)
assert.Equal(t, "key.pem", result.ClientKey)
assert.Equal(t, "ca.pem", result.ServerCAFile)
}
func TestAgentClientConfig_Configuration(t *testing.T) {
agentConfig := &clients.AttestedClientConfig{
StandardClientConfig: clients.StandardClientConfig{
URL: "https://agent.example.com",
Timeout: 60 * time.Second,
ClientCert: "agent-cert.pem",
ClientKey: "agent-key.pem",
ServerCAFile: "agent-ca.pem",
},
AttestationPolicy: "policy.json",
AttestedTLS: true,
ProductName: "Milan",
}
result := agentConfig.Config()
assert.Equal(t, agentConfig.StandardClientConfig, result)
assert.Equal(t, "https://agent.example.com", result.URL)
assert.Equal(t, 60*time.Second, result.Timeout)
assert.Equal(t, "agent-cert.pem", result.ClientCert)
assert.Equal(t, "agent-key.pem", result.ClientKey)
assert.Equal(t, "agent-ca.pem", result.ServerCAFile)
}
func TestProxyClientConfig_Configuration(t *testing.T) {
proxyConfig := clients.StandardClientConfig{
URL: "http://proxy.example.com",
Timeout: 45 * time.Second,
ClientCert: "proxy-cert.pem",
ClientKey: "proxy-key.pem",
ServerCAFile: "proxy-ca.pem",
}
result := proxyConfig
assert.Equal(t, proxyConfig, result)
assert.Equal(t, "http://proxy.example.com", result.URL)
assert.Equal(t, 45*time.Second, result.Timeout)
}
func TestNewClient_Success(t *testing.T) {
tests := []struct {
name string
config clients.ClientConfiguration
}{
{
name: "Basic config",
config: clients.StandardClientConfig{
URL: "http://localhost:8080",
Timeout: 30 * time.Second,
},
},
{
name: "Agent config without attested TLS",
config: &clients.AttestedClientConfig{
StandardClientConfig: clients.StandardClientConfig{
URL: "https://agent.example.com",
Timeout: 60 * time.Second,
},
AttestedTLS: false,
},
},
{
name: "Proxy config",
config: clients.StandardClientConfig{
URL: "http://proxy.example.com",
Timeout: 45 * time.Second,
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
client, err := NewClient(tt.config)
assert.NoError(t, err)
assert.NotNil(t, client)
assert.NotNil(t, client.Transport())
assert.Equal(t, tt.config.Config().Timeout, client.Timeout())
})
}
}
func TestClient_Transport(t *testing.T) {
config := clients.StandardClientConfig{
URL: "http://localhost:8080",
Timeout: 30 * time.Second,
}
client, err := NewClient(config)
assert.NoError(t, err)
transport := client.Transport()
assert.NotNil(t, transport)
assert.IsType(t, &http.Transport{}, transport)
assert.Equal(t, 100, transport.MaxIdleConns)
assert.Equal(t, 90*time.Second, transport.IdleConnTimeout)
assert.Equal(t, 10*time.Second, transport.TLSHandshakeTimeout)
}
func TestClient_Secure(t *testing.T) {
tests := []struct {
name string
config clients.ClientConfiguration
expected string
}{
{
name: "Without TLS",
config: clients.StandardClientConfig{
URL: "http://localhost:8080",
Timeout: 30 * time.Second,
},
expected: clients.WithoutTLS.String(),
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
client, err := NewClient(tt.config)
assert.NoError(t, err)
secure := client.Secure()
assert.Equal(t, tt.expected, secure)
})
}
}
func TestClient_Timeout(t *testing.T) {
expectedTimeout := 45 * time.Second
config := clients.StandardClientConfig{
URL: "http://localhost:8080",
Timeout: expectedTimeout,
}
client, err := NewClient(config)
assert.NoError(t, err)
timeout := client.Timeout()
assert.Equal(t, expectedTimeout, timeout)
}
func TestCreateTransport_DefaultSettings(t *testing.T) {
config := clients.StandardClientConfig{
URL: "http://localhost:8080",
Timeout: 30 * time.Second,
}
transport, security, err := createTransport(config)
assert.NoError(t, err)
assert.NotNil(t, transport)
assert.Equal(t, clients.WithoutTLS, security)
assert.Equal(t, 100, transport.MaxIdleConns)
assert.Equal(t, 90*time.Second, transport.IdleConnTimeout)
assert.Equal(t, 10*time.Second, transport.TLSHandshakeTimeout)
assert.Nil(t, transport.TLSClientConfig)
}
func TestCreateTransport_ATLSError(t *testing.T) {
config := &clients.AttestedClientConfig{
StandardClientConfig: clients.StandardClientConfig{
URL: "https://agent.example.com",
Timeout: 60 * time.Second,
},
AttestationPolicy: "invalid",
AttestedTLS: true,
ProductName: "Milan",
}
transport, security, err := createTransport(config)
assert.Error(t, err)
assert.Nil(t, transport)
assert.Equal(t, clients.WithoutTLS, security)
assert.Contains(t, err.Error(), "failed to stat attestation policy")
}
func TestCreateTransport_BasicTLSError(t *testing.T) {
config := clients.StandardClientConfig{
URL: "https://example.com",
Timeout: 30 * time.Second,
ServerCAFile: "invalid",
}
transport, security, err := createTransport(config)
assert.Error(t, err)
assert.Nil(t, transport)
assert.Equal(t, clients.WithoutTLS, security)
assert.Contains(t, err.Error(), "failed to load root ca file")
}
func TestClientInterface_Implementation(t *testing.T) {
config := clients.StandardClientConfig{
URL: "http://localhost:8080",
Timeout: 30 * time.Second,
}
client, err := NewClient(config)
assert.NoError(t, err)
// Verify that client implements the Client interface
var _ Client = client
// Test all interface methods
assert.NotNil(t, client.Transport())
assert.NotEmpty(t, client.Secure())
assert.Greater(t, client.Timeout(), time.Duration(0))
}
func TestAgentClientConfig_FieldAccess(t *testing.T) {
config := &clients.AttestedClientConfig{
StandardClientConfig: clients.StandardClientConfig{
URL: "https://agent.example.com",
Timeout: 60 * time.Second,
},
AttestationPolicy: "test-policy",
AttestedTLS: true,
ProductName: "TestProduct",
}
assert.Equal(t, "test-policy", config.AttestationPolicy)
assert.True(t, config.AttestedTLS)
assert.Equal(t, "TestProduct", config.ProductName)
assert.Equal(t, "https://agent.example.com", config.URL)
assert.Equal(t, 60*time.Second, config.Timeout)
}
func TestProxyClientConfig_FieldAccess(t *testing.T) {
config := clients.StandardClientConfig{
URL: "http://proxy.example.com",
Timeout: 45 * time.Second,
ClientCert: "proxy-cert.pem",
ClientKey: "proxy-key.pem",
ServerCAFile: "proxy-ca.pem",
}
assert.Equal(t, "http://proxy.example.com", config.URL)
assert.Equal(t, 45*time.Second, config.Timeout)
assert.Equal(t, "proxy-cert.pem", config.ClientCert)
assert.Equal(t, "proxy-key.pem", config.ClientKey)
assert.Equal(t, "proxy-ca.pem", config.ServerCAFile)
}
func TestClientConfiguration_Interface(t *testing.T) {
// Test that all config types implement ClientConfiguration interface
var configs []clients.ClientConfiguration
configs = append(configs, clients.StandardClientConfig{})
configs = append(configs, &clients.AttestedClientConfig{})
for i, config := range configs {
t.Run(t.Name()+"_"+string(rune(i+'0')), func(t *testing.T) {
result := config.Config()
assert.IsType(t, clients.StandardClientConfig{}, result)
})
}
}
+180
View File
@@ -0,0 +1,180 @@
// Copyright (c) Ultraviolet
// SPDX-License-Identifier: Apache-2.0
package clients
import (
"crypto/rand"
"crypto/tls"
"crypto/x509"
"encoding/hex"
"encoding/pem"
"os"
"github.com/absmach/supermq/pkg/errors"
"github.com/ultravioletrs/cocos/pkg/atls"
"github.com/ultravioletrs/cocos/pkg/attestation"
)
// Security represents the type of TLS security configuration.
type Security int
const (
WithoutTLS Security = iota
WithTLS
WithMTLS
WithATLS
WithMATLS
)
// String returns a human-readable representation of the security level.
func (s Security) String() string {
switch s {
case WithTLS:
return "with TLS"
case WithMTLS:
return "with mTLS"
case WithATLS:
return "with aTLS"
case WithMATLS:
return "with maTLS"
case WithoutTLS:
return "without TLS"
default:
return "without TLS"
}
}
const AttestationReportSize = 0x4A0
var (
ErrFailedToLoadClientCertKey = errors.New("failed to load client certificate and key")
ErrFailedToLoadRootCA = errors.New("failed to load root ca file")
errAttestationPolicyIrregular = errors.New("attestation policy file is not a regular file")
)
// TLSResult contains the result of TLS configuration.
type TLSResult struct {
Config *tls.Config
Security Security
}
// LoadBasicTLSConfig loads standard TLS configuration (TLS/mTLS).
func LoadBasicTLSConfig(serverCAFile, clientCert, clientKey string) (*TLSResult, error) {
tlsConfig := &tls.Config{}
security := WithoutTLS
// If no TLS configuration is provided, return nil config (no TLS)
if serverCAFile == "" && clientCert == "" && clientKey == "" {
return &TLSResult{Config: nil, Security: security}, nil
}
if serverCAFile != "" {
rootCA, err := os.ReadFile(serverCAFile)
if err != nil {
return nil, errors.Wrap(ErrFailedToLoadRootCA, err)
}
if len(rootCA) > 0 {
capool := x509.NewCertPool()
if !capool.AppendCertsFromPEM(rootCA) {
return nil, errors.New("failed to append root ca to tls.Config")
}
tlsConfig.RootCAs = capool
security = WithTLS
}
}
if clientCert != "" || clientKey != "" {
certificate, err := tls.LoadX509KeyPair(clientCert, clientKey)
if err != nil {
return nil, errors.Wrap(ErrFailedToLoadClientCertKey, err)
}
tlsConfig.Certificates = []tls.Certificate{certificate}
security = WithMTLS
}
return &TLSResult{Config: tlsConfig, Security: security}, nil
}
// LoadATLSConfig configures Attested TLS.
func LoadATLSConfig(cfg AttestedClientConfig) (*TLSResult, error) {
security := WithATLS
info, err := os.Stat(cfg.AttestationPolicy)
if err != nil {
return nil, errors.Wrap(errors.New("failed to stat attestation policy file"), err)
}
if !info.Mode().IsRegular() {
return nil, errAttestationPolicyIrregular
}
attestation.AttestationPolicyPath = cfg.AttestationPolicy
var rootCAs *x509.CertPool
if cfg.ServerCAFile != "" {
rootCAs, err = loadRootCAs(cfg.ServerCAFile)
if err != nil {
return nil, err
}
security = WithMATLS
}
nonce := make([]byte, 64)
if _, err := rand.Read(nonce); err != nil {
return nil, errors.Wrap(errors.New("failed to generate nonce"), err)
}
encoded := hex.EncodeToString(nonce)
sni := encoded + ".nonce"
tlsConfig := &tls.Config{
InsecureSkipVerify: true,
RootCAs: rootCAs,
ServerName: sni,
VerifyPeerCertificate: func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error {
return atls.VerifyPeerCertificateATLS(rawCerts, verifiedChains, nonce, rootCAs)
},
}
if cfg.ClientCert != "" || cfg.ClientKey != "" {
certificate, err := tls.LoadX509KeyPair(cfg.ClientCert, cfg.ClientKey)
if err != nil {
return nil, errors.Wrap(ErrFailedToLoadClientCertKey, err)
}
tlsConfig.Certificates = []tls.Certificate{certificate}
}
return &TLSResult{Config: tlsConfig, Security: security}, nil
}
// loadRootCAs loads root CA certificates from a file.
func loadRootCAs(serverCAFile string) (*x509.CertPool, error) {
// Read the certificate file
certPEM, err := os.ReadFile(serverCAFile)
if err != nil {
return nil, errors.Wrap(errors.New("failed to read certificate file"), err)
}
// Decode the PEM block
block, _ := pem.Decode(certPEM)
if block == nil {
return nil, errors.New("failed to decode PEM block")
}
// Parse the certificate
cert, err := x509.ParseCertificate(block.Bytes)
if err != nil {
return nil, errors.Wrap(errors.New("failed to parse certificate"), err)
}
rootCAs := x509.NewCertPool()
rootCAs.AddCert(cert)
return rootCAs, nil
}
+440
View File
@@ -0,0 +1,440 @@
// Copyright (c) Ultraviolet
// SPDX-License-Identifier: Apache-2.0
package clients
import (
"crypto/rand"
"crypto/rsa"
"crypto/x509"
"crypto/x509/pkix"
"encoding/pem"
"math/big"
"os"
"path/filepath"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestSecurity_String(t *testing.T) {
tests := []struct {
name string
security Security
expected string
}{
{
name: "WithoutTLS",
security: WithoutTLS,
expected: "without TLS",
},
{
name: "WithTLS",
security: WithTLS,
expected: "with TLS",
},
{
name: "WithMTLS",
security: WithMTLS,
expected: "with mTLS",
},
{
name: "WithATLS",
security: WithATLS,
expected: "with aTLS",
},
{
name: "WithMATLS",
security: WithMATLS,
expected: "with maTLS",
},
{
name: "InvalidSecurity",
security: Security(999),
expected: "without TLS",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
assert.Equal(t, tt.expected, tt.security.String())
})
}
}
func TestLoadBasicTLSConfig(t *testing.T) {
// Create temporary directory for test files
tmpDir := t.TempDir()
// Generate test certificate and key
cert, key, caPEM := generateTestCertificates(t)
certFile := filepath.Join(tmpDir, "client.crt")
keyFile := filepath.Join(tmpDir, "client.key")
caFile := filepath.Join(tmpDir, "ca.crt")
require.NoError(t, os.WriteFile(certFile, cert, 0o644))
require.NoError(t, os.WriteFile(keyFile, key, 0o644))
require.NoError(t, os.WriteFile(caFile, caPEM, 0o644))
tests := []struct {
name string
serverCAFile string
clientCert string
clientKey string
expectedSec Security
expectedConfig bool
expectError bool
}{
{
name: "NoTLS",
serverCAFile: "",
clientCert: "",
clientKey: "",
expectedSec: WithoutTLS,
expectedConfig: false,
expectError: false,
},
{
name: "TLSOnly",
serverCAFile: caFile,
clientCert: "",
clientKey: "",
expectedSec: WithTLS,
expectedConfig: true,
expectError: false,
},
{
name: "MTLS",
serverCAFile: caFile,
clientCert: certFile,
clientKey: keyFile,
expectedSec: WithMTLS,
expectedConfig: true,
expectError: false,
},
{
name: "MTLSWithoutCA",
serverCAFile: "",
clientCert: certFile,
clientKey: keyFile,
expectedSec: WithMTLS,
expectedConfig: true,
expectError: false,
},
{
name: "InvalidCAFile",
serverCAFile: filepath.Join(tmpDir, "nonexistent.crt"),
clientCert: "",
clientKey: "",
expectedSec: WithoutTLS,
expectedConfig: false,
expectError: true,
},
{
name: "InvalidCertFile",
serverCAFile: "",
clientCert: filepath.Join(tmpDir, "nonexistent.crt"),
clientKey: keyFile,
expectedSec: WithoutTLS,
expectedConfig: false,
expectError: true,
},
{
name: "InvalidKeyFile",
serverCAFile: "",
clientCert: certFile,
clientKey: filepath.Join(tmpDir, "nonexistent.key"),
expectedSec: WithoutTLS,
expectedConfig: false,
expectError: true,
},
{
name: "MismatchedCertKey",
serverCAFile: "",
clientCert: caFile, // Using CA file as cert (wrong format)
clientKey: keyFile,
expectedSec: WithoutTLS,
expectedConfig: false,
expectError: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result, err := LoadBasicTLSConfig(tt.serverCAFile, tt.clientCert, tt.clientKey)
if tt.expectError {
assert.Error(t, err)
return
}
require.NoError(t, err)
require.NotNil(t, result)
assert.Equal(t, tt.expectedSec, result.Security)
if tt.expectedConfig {
assert.NotNil(t, result.Config)
} else {
assert.Nil(t, result.Config)
}
})
}
}
func TestLoadATLSConfig(t *testing.T) {
tmpDir := t.TempDir()
// Create test files
cert, key, caPEM := generateTestCertificates(t)
certFile := filepath.Join(tmpDir, "client.crt")
keyFile := filepath.Join(tmpDir, "client.key")
caFile := filepath.Join(tmpDir, "ca.crt")
policyFile := filepath.Join(tmpDir, "policy.json")
require.NoError(t, os.WriteFile(certFile, cert, 0o644))
require.NoError(t, os.WriteFile(keyFile, key, 0o644))
require.NoError(t, os.WriteFile(caFile, caPEM, 0o644))
require.NoError(t, os.WriteFile(policyFile, []byte(`{"policy": "test"}`), 0o644))
tests := []struct {
name string
config AttestedClientConfig
expectedSec Security
expectError bool
errorMsg string
}{
{
name: "ValidATLSConfig",
config: AttestedClientConfig{
StandardClientConfig: StandardClientConfig{
ServerCAFile: "",
},
AttestationPolicy: policyFile,
ProductName: "test-product",
},
expectedSec: WithATLS,
expectError: false,
},
{
name: "ValidMATLSConfig",
config: AttestedClientConfig{
StandardClientConfig: StandardClientConfig{
ServerCAFile: caFile,
},
AttestationPolicy: policyFile,
ProductName: "test-product",
},
expectedSec: WithMATLS,
expectError: false,
},
{
name: "ValidATLSWithClientCert",
config: AttestedClientConfig{
StandardClientConfig: StandardClientConfig{
ClientCert: certFile,
ClientKey: keyFile,
},
AttestationPolicy: policyFile,
ProductName: "test-product",
},
expectedSec: WithATLS,
expectError: false,
},
{
name: "NonexistentPolicyFile",
config: AttestedClientConfig{
AttestationPolicy: filepath.Join(tmpDir, "nonexistent.json"),
ProductName: "test-product",
},
expectedSec: WithoutTLS,
expectError: true,
errorMsg: "failed to stat attestation policy file",
},
{
name: "PolicyFileIsDirectory",
config: AttestedClientConfig{
AttestationPolicy: tmpDir, // Directory instead of file
ProductName: "test-product",
},
expectedSec: WithoutTLS,
expectError: true,
errorMsg: "attestation policy file is not a regular file",
},
{
name: "InvalidCAFile",
config: AttestedClientConfig{
StandardClientConfig: StandardClientConfig{
ServerCAFile: filepath.Join(tmpDir, "nonexistent.crt"),
},
AttestationPolicy: policyFile,
ProductName: "test-product",
},
expectedSec: WithoutTLS,
expectError: true,
errorMsg: "failed to read certificate file",
},
{
name: "InvalidClientCert",
config: AttestedClientConfig{
StandardClientConfig: StandardClientConfig{
ClientCert: filepath.Join(tmpDir, "nonexistent.crt"),
ClientKey: keyFile,
},
AttestationPolicy: policyFile,
ProductName: "test-product",
},
expectedSec: WithoutTLS,
expectError: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result, err := LoadATLSConfig(tt.config)
if tt.expectError {
assert.Error(t, err)
if tt.errorMsg != "" {
assert.Contains(t, err.Error(), tt.errorMsg)
}
return
}
require.NoError(t, err)
require.NotNil(t, result)
assert.Equal(t, tt.expectedSec, result.Security)
assert.NotNil(t, result.Config)
// Verify TLS config properties
assert.True(t, result.Config.InsecureSkipVerify)
assert.NotNil(t, result.Config.VerifyPeerCertificate)
assert.NotEmpty(t, result.Config.ServerName)
assert.Contains(t, result.Config.ServerName, ".nonce")
})
}
}
func TestLoadRootCAs(t *testing.T) {
tmpDir := t.TempDir()
// Generate test certificate
_, _, caPEM := generateTestCertificates(t)
validCAFile := filepath.Join(tmpDir, "valid_ca.crt")
invalidCAFile := filepath.Join(tmpDir, "invalid_ca.crt")
nonExistentFile := filepath.Join(tmpDir, "nonexistent.crt")
require.NoError(t, os.WriteFile(validCAFile, caPEM, 0o644))
require.NoError(t, os.WriteFile(invalidCAFile, []byte("invalid pem data"), 0o644))
tests := []struct {
name string
caFile string
expectError bool
errorMsg string
}{
{
name: "ValidCAFile",
caFile: validCAFile,
expectError: false,
},
{
name: "NonExistentFile",
caFile: nonExistentFile,
expectError: true,
errorMsg: "failed to read certificate file",
},
{
name: "InvalidPEMData",
caFile: invalidCAFile,
expectError: true,
errorMsg: "failed to decode PEM block",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
rootCAs, err := loadRootCAs(tt.caFile)
if tt.expectError {
assert.Error(t, err)
assert.Nil(t, rootCAs)
if tt.errorMsg != "" {
assert.Contains(t, err.Error(), tt.errorMsg)
}
return
}
assert.NoError(t, err)
assert.NotNil(t, rootCAs)
})
}
}
// Helper functions for generating test certificates
func generateTestCertificates(t *testing.T) (certPEM, keyPEM, caPEM []byte) {
// Generate CA certificate
caKey, err := rsa.GenerateKey(rand.Reader, 2048)
require.NoError(t, err)
caTemplate := x509.Certificate{
SerialNumber: big.NewInt(1),
Subject: pkix.Name{
Organization: []string{"Test CA"},
Country: []string{"US"},
},
NotBefore: time.Now(),
NotAfter: time.Now().Add(24 * time.Hour),
KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
BasicConstraintsValid: true,
IsCA: true,
}
caCertDER, err := x509.CreateCertificate(rand.Reader, &caTemplate, &caTemplate, &caKey.PublicKey, caKey)
require.NoError(t, err)
caPEM = pem.EncodeToMemory(&pem.Block{
Type: "CERTIFICATE",
Bytes: caCertDER,
})
// Generate client certificate
clientKey, err := rsa.GenerateKey(rand.Reader, 2048)
require.NoError(t, err)
clientTemplate := x509.Certificate{
SerialNumber: big.NewInt(2),
Subject: pkix.Name{
Organization: []string{"Test Client"},
Country: []string{"US"},
},
NotBefore: time.Now(),
NotAfter: time.Now().Add(24 * time.Hour),
KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth},
}
clientCertDER, err := x509.CreateCertificate(rand.Reader, &clientTemplate, &caTemplate, &clientKey.PublicKey, caKey)
require.NoError(t, err)
certPEM = pem.EncodeToMemory(&pem.Block{
Type: "CERTIFICATE",
Bytes: clientCertDER,
})
clientKeyDER, err := x509.MarshalPKCS8PrivateKey(clientKey)
require.NoError(t, err)
keyPEM = pem.EncodeToMemory(&pem.Block{
Type: "PRIVATE KEY",
Bytes: clientKeyDER,
})
return certPEM, keyPEM, caPEM
}
@@ -6,19 +6,16 @@ package grpc
import (
"context"
"crypto/tls"
"crypto/x509"
"fmt"
"log/slog"
"net"
"os"
"strings"
"sync"
"time"
agentgrpc "github.com/ultravioletrs/cocos/agent/api/grpc"
"github.com/ultravioletrs/cocos/agent/auth"
"github.com/ultravioletrs/cocos/internal/server"
"github.com/ultravioletrs/cocos/pkg/atls"
"github.com/ultravioletrs/cocos/pkg/server"
"go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials"
@@ -110,39 +107,13 @@ func (s *Server) Start() error {
GetCertificate: atls.GetCertificate(s.caUrl, s.cvmId),
}
var mtls bool
mtls = false
// Loading Server CA file
rootCA, err := loadCertFile(c.ServerCAFile)
mtls, err := server.ConfigureCertificateAuthorities(tlsConfig, c.ServerCAFile, c.ClientCAFile)
if err != nil {
return fmt.Errorf("failed to load server ca file: %w", err)
}
if len(rootCA) > 0 {
if tlsConfig.RootCAs == nil {
tlsConfig.RootCAs = x509.NewCertPool()
}
if !tlsConfig.RootCAs.AppendCertsFromPEM(rootCA) {
return fmt.Errorf("failed to append server ca to tls.Config")
}
mtls = true
return fmt.Errorf("failed to configure certificate authorities: %w", err)
}
// Loading Client CA File
clientCA, err := loadCertFile(c.ClientCAFile)
if err != nil {
return fmt.Errorf("failed to load client ca file: %w", err)
}
if len(clientCA) > 0 {
if tlsConfig.ClientCAs == nil {
tlsConfig.ClientCAs = x509.NewCertPool()
}
if !tlsConfig.ClientCAs.AppendCertsFromPEM(clientCA) {
return fmt.Errorf("failed to append client ca to tls.Config")
}
if mtls {
tlsConfig.ClientAuth = tls.RequireAndVerifyClientCert
mtls = true
}
creds = grpc.Creds(credentials.NewTLS(tlsConfig))
@@ -155,52 +126,17 @@ func (s *Server) Start() error {
} else {
switch {
case c.CertFile != "" || c.KeyFile != "":
certificate, err := loadX509KeyPair(c.CertFile, c.KeyFile)
tlsSetup, err := server.SetupRegularTLS(c.CertFile, c.KeyFile, c.ServerCAFile, c.ClientCAFile)
if err != nil {
return fmt.Errorf("failed to load auth certificates: %w", err)
}
tlsConfig := &tls.Config{
ClientAuth: tls.NoClientCert,
Certificates: []tls.Certificate{certificate},
return fmt.Errorf("failed to setup TLS: %w", err)
}
var mtlsCA string
// Loading Server CA file
rootCA, err := loadCertFile(c.ServerCAFile)
if err != nil {
return fmt.Errorf("failed to load root ca file: %w", err)
}
if len(rootCA) > 0 {
if tlsConfig.RootCAs == nil {
tlsConfig.RootCAs = x509.NewCertPool()
}
if !tlsConfig.RootCAs.AppendCertsFromPEM(rootCA) {
return fmt.Errorf("failed to append root ca to tls.Config")
}
mtlsCA = fmt.Sprintf("root ca %s", c.ServerCAFile)
}
creds = grpc.Creds(credentials.NewTLS(tlsSetup.Config))
// Loading Client CA File
clientCA, err := loadCertFile(c.ClientCAFile)
if err != nil {
return fmt.Errorf("failed to load client ca file: %w", err)
}
if len(clientCA) > 0 {
if tlsConfig.ClientCAs == nil {
tlsConfig.ClientCAs = x509.NewCertPool()
}
if !tlsConfig.ClientCAs.AppendCertsFromPEM(clientCA) {
return fmt.Errorf("failed to append client ca to tls.Config")
}
mtlsCA = fmt.Sprintf("%s client ca %s", mtlsCA, c.ClientCAFile)
}
creds = grpc.Creds(credentials.NewTLS(tlsConfig))
switch {
case mtlsCA != "":
tlsConfig.ClientAuth = tls.RequireAndVerifyClientCert
creds = grpc.Creds(credentials.NewTLS(tlsConfig))
if tlsSetup.MTLS {
mtlsCA := server.BuildMTLSDescription(c.ServerCAFile, c.ClientCAFile)
s.Logger.Info(fmt.Sprintf("%s service gRPC server listening at %s with TLS/mTLS cert %s , key %s and %s", s.Name, s.Address, c.CertFile, c.KeyFile, mtlsCA))
default:
} else {
s.Logger.Info(fmt.Sprintf("%s service gRPC server listening at %s with TLS cert %s and key %s", s.Name, s.Address, c.CertFile, c.KeyFile))
}
default:
@@ -272,36 +208,3 @@ func (s *Server) Stop() error {
s.Logger.Info(fmt.Sprintf("%s gRPC service shutdown at %s", s.Name, s.Address))
return nil
}
func loadCertFile(certFile string) ([]byte, error) {
if certFile != "" {
return readFileOrData(certFile)
}
return []byte{}, nil
}
func readFileOrData(input string) ([]byte, error) {
if len(input) < 1000 && !strings.Contains(input, "\n") {
data, err := os.ReadFile(input)
if err == nil {
return data, nil
} else {
return nil, err
}
}
return []byte(input), nil
}
func loadX509KeyPair(certfile, keyfile string) (tls.Certificate, error) {
cert, err := readFileOrData(certfile)
if err != nil {
return tls.Certificate{}, fmt.Errorf("failed to read cert: %v", err)
}
key, err := readFileOrData(keyfile)
if err != nil {
return tls.Certificate{}, fmt.Errorf("failed to read key: %v", err)
}
return tls.X509KeyPair(cert, key)
}
@@ -20,8 +20,8 @@ import (
"github.com/stretchr/testify/assert"
authmocks "github.com/ultravioletrs/cocos/agent/auth/mocks"
"github.com/ultravioletrs/cocos/internal/server"
"github.com/ultravioletrs/cocos/pkg/attestation/vtpm"
"github.com/ultravioletrs/cocos/pkg/server"
"google.golang.org/grpc"
"google.golang.org/grpc/test/bufconn"
)
@@ -40,7 +40,7 @@ func TestNew(t *testing.T) {
config := server.AgentConfig{
ServerConfig: server.ServerConfig{
BaseConfig: server.BaseConfig{
Config: server.Config{
Host: "localhost",
Port: "50051",
},
@@ -85,7 +85,7 @@ func TestServerStartWithTLSFile(t *testing.T) {
config := server.AgentConfig{
ServerConfig: server.ServerConfig{
BaseConfig: server.BaseConfig{
Config: server.Config{
Host: "localhost",
Port: "0",
CertFile: certFile.Name(),
@@ -130,7 +130,7 @@ func TestServerStartWithmTLSFile(t *testing.T) {
config := server.AgentConfig{
ServerConfig: server.ServerConfig{
BaseConfig: server.BaseConfig{
Config: server.Config{
Host: "localhost",
Port: "0",
CertFile: string(clientCertFile),
@@ -173,7 +173,7 @@ func TestServerStop(t *testing.T) {
config := server.AgentConfig{
ServerConfig: server.ServerConfig{
BaseConfig: server.BaseConfig{
Config: server.Config{
Host: "localhost",
Port: "0",
},
@@ -268,7 +268,7 @@ func TestServerInitializationAndStartup(t *testing.T) {
name: "Non-TLS Server Startup",
config: server.AgentConfig{
ServerConfig: server.ServerConfig{
BaseConfig: server.BaseConfig{
Config: server.Config{
Host: "localhost",
Port: "0",
},
@@ -280,7 +280,7 @@ func TestServerInitializationAndStartup(t *testing.T) {
name: "TLS Server Startup with Self-Signed Certificate",
config: server.AgentConfig{
ServerConfig: server.ServerConfig{
BaseConfig: server.BaseConfig{
Config: server.Config{
Host: "localhost",
Port: "0",
},
@@ -293,7 +293,7 @@ func TestServerInitializationAndStartup(t *testing.T) {
name: "TLS Server Startup with Invalid Certificates",
config: server.AgentConfig{
ServerConfig: server.ServerConfig{
BaseConfig: server.BaseConfig{
Config: server.Config{
Host: "localhost",
Port: "0",
CertFile: "invalid",
@@ -308,7 +308,7 @@ func TestServerInitializationAndStartup(t *testing.T) {
name: "maTLS Server Startup",
config: server.AgentConfig{
ServerConfig: server.ServerConfig{
BaseConfig: server.BaseConfig{
Config: server.Config{
Host: "localhost",
Port: "0",
ServerCAFile: "",
@@ -325,7 +325,7 @@ func TestServerInitializationAndStartup(t *testing.T) {
name: "maTLS Server Startup with Invalid Server CA file",
config: server.AgentConfig{
ServerConfig: server.ServerConfig{
BaseConfig: server.BaseConfig{
Config: server.Config{
Host: "localhost",
Port: "0",
ServerCAFile: "invalid",
@@ -341,7 +341,7 @@ func TestServerInitializationAndStartup(t *testing.T) {
name: "maTLS Server Startup with Invalid Clinet CA file",
config: server.AgentConfig{
ServerConfig: server.ServerConfig{
BaseConfig: server.BaseConfig{
Config: server.Config{
Host: "localhost",
Port: "0",
ServerCAFile: "invalid",
+173
View File
@@ -0,0 +1,173 @@
// Copyright (c) Ultraviolet
// SPDX-License-Identifier: Apache-2.0
package http
import (
"context"
"crypto/tls"
"fmt"
"log/slog"
"net/http"
smqserver "github.com/absmach/supermq/pkg/server"
"github.com/ultravioletrs/cocos/pkg/atls"
"github.com/ultravioletrs/cocos/pkg/server"
)
const (
httpProtocol = "http"
httpsProtocol = "https"
)
type httpServer struct {
server.BaseServer
server *http.Server
caURL string
}
var _ server.Server = (*httpServer)(nil)
func NewServer(
ctx context.Context, cancel context.CancelFunc, name string, config server.ServerConfiguration,
handler http.Handler, logger *slog.Logger, caURL string,
) server.Server {
baseServer := server.NewBaseServer(ctx, cancel, name, config, logger)
hserver := &http.Server{Addr: baseServer.Address, Handler: handler}
return &httpServer{
BaseServer: baseServer,
server: hserver,
caURL: caURL,
}
}
func (s *httpServer) Start() error {
s.Protocol = httpProtocol
if s.shouldUseAttestedTLS() {
return s.startWithAttestedTLS()
}
if s.shouldUseRegularTLS() {
return s.startWithRegularTLS()
}
return s.startWithoutTLS()
}
func (s *httpServer) Stop() error {
defer s.Cancel()
ctx, cancel := context.WithTimeout(context.Background(), smqserver.StopWaitTime)
defer cancel()
if err := s.server.Shutdown(ctx); err != nil {
s.Logger.Error(fmt.Sprintf(
"%s service %s server error occurred during shutdown at %s: %s", s.Name, s.Protocol, s.Address, err))
return fmt.Errorf("%s service %s server error occurred during shutdown at %s: %w", s.Name, s.Protocol, s.Address, err)
}
s.Logger.Info(fmt.Sprintf("%s %s service shutdown of http at %s", s.Name, s.Protocol, s.Address))
return nil
}
func (s *httpServer) shouldUseAttestedTLS() bool {
cfg, ok := s.Config.(server.AgentConfig)
if !ok {
return false
}
return cfg.AttestedTLS && s.caURL != ""
}
func (s *httpServer) shouldUseRegularTLS() bool {
return s.Config.GetBaseConfig().CertFile != "" || s.Config.GetBaseConfig().KeyFile != ""
}
func (s *httpServer) startWithAttestedTLS() error {
tlsConfig := &tls.Config{
ClientAuth: tls.NoClientCert,
GetCertificate: atls.GetCertificate(s.caURL, ""),
}
mtls, err := server.ConfigureCertificateAuthorities(tlsConfig, s.Config.GetBaseConfig().ServerCAFile, s.Config.GetBaseConfig().ClientCAFile)
if err != nil {
return fmt.Errorf("failed to configure certificate authorities: %w", err)
}
if mtls {
tlsConfig.ClientAuth = tls.RequireAndVerifyClientCert
}
s.server.TLSConfig = tlsConfig
s.Protocol = httpsProtocol
s.logAttestedTLSStart(mtls)
return s.listenAndServe(true)
}
func (s *httpServer) startWithRegularTLS() error {
tlsSetup, err := server.SetupRegularTLS(s.Config.GetBaseConfig().CertFile, s.Config.GetBaseConfig().KeyFile, s.Config.GetBaseConfig().ServerCAFile, s.Config.GetBaseConfig().ClientCAFile)
if err != nil {
return fmt.Errorf("failed to setup TLS: %w", err)
}
s.server.TLSConfig = tlsSetup.Config
s.Protocol = httpsProtocol
s.logRegularTLSStart(tlsSetup.MTLS)
return s.listenAndServe(true)
}
func (s *httpServer) startWithoutTLS() error {
s.Logger.Info(fmt.Sprintf("%s service %s server listening at %s without TLS", s.Name, s.Protocol, s.Address))
return s.listenAndServe(false)
}
func (s *httpServer) logAttestedTLSStart(mtls bool) {
if mtls {
s.Logger.Info(fmt.Sprintf("%s service %s server listening at %s with Attested mTLS", s.Name, s.Protocol, s.Address))
} else {
s.Logger.Info(fmt.Sprintf("%s service %s server listening at %s with Attested TLS", s.Name, s.Protocol, s.Address))
}
}
func (s *httpServer) logRegularTLSStart(mtls bool) {
if mtls {
s.Logger.Info(fmt.Sprintf(
"%s service %s server listening at %s with TLS/mTLS cert %s , key %s and CAs %s, %s",
s.Name, s.Protocol, s.Address, s.Config.GetBaseConfig().CertFile, s.Config.GetBaseConfig().KeyFile,
s.Config.GetBaseConfig().ServerCAFile, s.Config.GetBaseConfig().ClientCAFile))
} else {
s.Logger.Info(
fmt.Sprintf("%s service %s server listening at %s with TLS cert %s and key %s",
s.Name, s.Protocol, s.Address, s.Config.GetBaseConfig().CertFile, s.Config.GetBaseConfig().KeyFile))
}
}
func (s *httpServer) listenAndServe(useTLS bool) error {
errCh := make(chan error, 1)
go func() {
if useTLS {
cfg := s.Config.GetBaseConfig()
errCh <- s.server.ListenAndServeTLS(cfg.CertFile, cfg.KeyFile)
} else {
errCh <- s.server.ListenAndServe()
}
}()
select {
case <-s.Ctx.Done():
return s.Stop()
case err := <-errCh:
return err
}
}
+411
View File
@@ -0,0 +1,411 @@
// Copyright (c) Ultraviolet
// SPDX-License-Identifier: Apache-2.0
package http
import (
"context"
"crypto/tls"
"log/slog"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
"github.com/ultravioletrs/cocos/pkg/server"
)
// Mock implementations for testing.
type mockHandler struct {
mock.Mock
}
func (m *mockHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
m.Called(w, r)
w.WriteHeader(http.StatusOK)
if _, err := w.Write([]byte("test response")); err != nil {
panic(err)
}
}
type mockBaseConfig struct {
certFile string
keyFile string
serverCAFile string
clientCAFile string
host string
port string
}
func (m *mockBaseConfig) GetCertFile() string { return m.certFile }
func (m *mockBaseConfig) GetKeyFile() string { return m.keyFile }
func (m *mockBaseConfig) GetServerCAFile() string { return m.serverCAFile }
func (m *mockBaseConfig) GetClientCAFile() string { return m.clientCAFile }
type mockServerConfig struct {
baseConfig *mockBaseConfig
}
func (m *mockServerConfig) GetHost() string { return "localhost" }
func (m *mockServerConfig) GetPort() string { return "8080" }
func (m *mockServerConfig) GetBaseConfig() server.ServerConfig {
return server.ServerConfig{Config: server.Config{CertFile: m.baseConfig.certFile, KeyFile: m.baseConfig.keyFile, ServerCAFile: m.baseConfig.serverCAFile, ClientCAFile: m.baseConfig.clientCAFile, Host: m.baseConfig.host, Port: m.baseConfig.port}}
}
func TestNewServer(t *testing.T) {
ctx := context.Background()
cancel := func() {}
name := "test-server"
config := &mockServerConfig{
baseConfig: &mockBaseConfig{},
}
handler := &mockHandler{}
logger := slog.Default()
caURL := "https://ca.example.com"
server := NewServer(ctx, cancel, name, config, handler, logger, caURL)
assert.NotNil(t, server)
httpSrv, ok := server.(*httpServer)
require.True(t, ok)
assert.Equal(t, caURL, httpSrv.caURL)
assert.NotNil(t, httpSrv.server)
assert.Equal(t, handler, httpSrv.server.Handler)
}
func TestHttpServer_shouldUseAttestedTLS(t *testing.T) {
tests := []struct {
name string
config server.ServerConfiguration
caURL string
attestedTLS bool
expected bool
}{
{
name: "should use attested TLS when config is AgentConfig and AttestedTLS is true and caURL is not empty",
config: server.AgentConfig{
AttestedTLS: true,
},
caURL: "https://ca.example.com",
expected: true,
},
{
name: "should not use attested TLS when caURL is empty",
config: server.AgentConfig{
AttestedTLS: true,
},
caURL: "",
expected: false,
},
{
name: "should not use attested TLS when AttestedTLS is false",
config: server.AgentConfig{
AttestedTLS: false,
},
caURL: "https://ca.example.com",
expected: false,
},
{
name: "should not use attested TLS when config is not AgentConfig",
config: &mockServerConfig{
baseConfig: &mockBaseConfig{},
},
caURL: "https://ca.example.com",
expected: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ctx := context.Background()
cancel := func() {}
server := NewServer(ctx, cancel, "test", tt.config, &mockHandler{}, slog.Default(), tt.caURL)
httpSrv := server.(*httpServer)
result := httpSrv.shouldUseAttestedTLS()
assert.Equal(t, tt.expected, result)
})
}
}
func TestHttpServer_shouldUseRegularTLS(t *testing.T) {
tests := []struct {
name string
certFile string
keyFile string
expected bool
}{
{
name: "should use regular TLS when both cert and key files are provided",
certFile: "cert.pem",
keyFile: "key.pem",
expected: true,
},
{
name: "should use regular TLS when only cert file is provided",
certFile: "cert.pem",
keyFile: "",
expected: true,
},
{
name: "should use regular TLS when only key file is provided",
certFile: "",
keyFile: "key.pem",
expected: true,
},
{
name: "should not use regular TLS when neither cert nor key files are provided",
certFile: "",
keyFile: "",
expected: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ctx := context.Background()
cancel := func() {}
config := &mockServerConfig{
baseConfig: &mockBaseConfig{
certFile: tt.certFile,
keyFile: tt.keyFile,
},
}
server := NewServer(ctx, cancel, "test", config, &mockHandler{}, slog.Default(), "")
httpSrv := server.(*httpServer)
result := httpSrv.shouldUseRegularTLS()
assert.Equal(t, tt.expected, result)
})
}
}
func TestHttpServer_Stop(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
config := &mockServerConfig{
baseConfig: &mockBaseConfig{},
}
handler := &mockHandler{}
server := NewServer(ctx, cancel, "test-server", config, handler, slog.Default(), "")
httpSrv := server.(*httpServer)
// Start a test server that we can control
testServer := httptest.NewServer(handler)
defer testServer.Close()
// Replace the server's HTTP server with our test server's
httpSrv.server = testServer.Config
err := httpSrv.Stop()
assert.NoError(t, err)
}
func TestHttpServer_logAttestedTLSStart(t *testing.T) {
tests := []struct {
name string
mtls bool
}{
{
name: "log attested mTLS start",
mtls: true,
},
{
name: "log attested TLS start",
mtls: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ctx := context.Background()
cancel := func() {}
config := &mockServerConfig{
baseConfig: &mockBaseConfig{},
}
server := NewServer(ctx, cancel, "test-server", config, &mockHandler{}, slog.Default(), "")
httpSrv := server.(*httpServer)
// This test mainly ensures the method doesn't panic
// In a real scenario, you might want to capture log output
assert.NotPanics(t, func() {
httpSrv.logAttestedTLSStart(tt.mtls)
})
})
}
}
func TestHttpServer_logRegularTLSStart(t *testing.T) {
tests := []struct {
name string
mtls bool
}{
{
name: "log regular mTLS start",
mtls: true,
},
{
name: "log regular TLS start",
mtls: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ctx := context.Background()
cancel := func() {}
config := &mockServerConfig{
baseConfig: &mockBaseConfig{
certFile: "cert.pem",
keyFile: "key.pem",
serverCAFile: "server-ca.pem",
clientCAFile: "client-ca.pem",
},
}
server := NewServer(ctx, cancel, "test-server", config, &mockHandler{}, slog.Default(), "")
httpSrv := server.(*httpServer)
// This test mainly ensures the method doesn't panic
assert.NotPanics(t, func() {
httpSrv.logRegularTLSStart(tt.mtls)
})
})
}
}
func TestHttpServer_startWithoutTLS(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
defer cancel()
config := &mockServerConfig{
baseConfig: &mockBaseConfig{},
}
handler := &mockHandler{}
server := NewServer(ctx, cancel, "test-server", config, handler, slog.Default(), "")
httpSrv := server.(*httpServer)
// Use a test server to avoid binding to actual ports
testServer := httptest.NewServer(handler)
defer testServer.Close()
httpSrv.server = testServer.Config
err := httpSrv.startWithoutTLS()
// The error will be related to context cancellation or server shutdown
assert.Error(t, err)
}
func TestHttpServer_Protocol(t *testing.T) {
tests := []struct {
name string
setupTLS func(*httpServer)
expectedProto string
}{
{
name: "HTTP protocol without TLS",
setupTLS: func(s *httpServer) {
s.Protocol = httpProtocol
},
expectedProto: httpProtocol,
},
{
name: "HTTPS protocol with TLS",
setupTLS: func(s *httpServer) {
s.Protocol = httpsProtocol
s.server.TLSConfig = &tls.Config{}
},
expectedProto: httpsProtocol,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ctx := context.Background()
cancel := func() {}
config := &mockServerConfig{
baseConfig: &mockBaseConfig{},
}
server := NewServer(ctx, cancel, "test-server", config, &mockHandler{}, slog.Default(), "")
httpSrv := server.(*httpServer)
tt.setupTLS(httpSrv)
assert.Equal(t, tt.expectedProto, httpSrv.Protocol)
})
}
}
func TestHttpServer_ContextCancellation(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
config := &mockServerConfig{
baseConfig: &mockBaseConfig{},
}
handler := &mockHandler{}
server := NewServer(ctx, cancel, "test-server", config, handler, slog.Default(), "")
httpSrv := server.(*httpServer)
// Cancel the context immediately
cancel()
// The listenAndServe method should handle context cancellation
err := httpSrv.listenAndServe(false)
assert.NoError(t, err) // Should return no error when context is cancelled and Stop() succeeds
}
func TestHttpServer_TLSConfiguration(t *testing.T) {
ctx := context.Background()
cancel := func() {}
config := &mockServerConfig{
baseConfig: &mockBaseConfig{
certFile: "cert.pem",
keyFile: "key.pem",
},
}
server := NewServer(ctx, cancel, "test-server", config, &mockHandler{}, slog.Default(), "")
httpSrv := server.(*httpServer)
// Test TLS configuration setup
httpSrv.server.TLSConfig = &tls.Config{
MinVersion: tls.VersionTLS12,
}
assert.NotNil(t, httpSrv.server.TLSConfig)
assert.Equal(t, uint16(tls.VersionTLS12), httpSrv.server.TLSConfig.MinVersion)
}
// Integration-style test for server lifecycle.
func TestHttpServer_Lifecycle(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond)
defer cancel()
config := &mockServerConfig{
baseConfig: &mockBaseConfig{
host: "localhost",
port: "8080",
},
}
handler := &mockHandler{}
server := NewServer(ctx, cancel, "test-server", config, handler, slog.Default(), "")
// Test that server can be created and has expected initial state
httpSrv, ok := server.(*httpServer)
require.True(t, ok)
assert.NotNil(t, httpSrv.server)
assert.Equal(t, "localhost:8080", httpSrv.server.Addr)
// Test Stop without Start (should not panic)
err := httpSrv.Stop()
assert.NoError(t, err)
}
@@ -20,7 +20,7 @@ type ServerConfiguration interface {
GetBaseConfig() ServerConfig
}
type BaseConfig struct {
type Config struct {
Host string `env:"HOST" envDefault:"localhost"`
Port string `env:"PORT" envDefault:"7001"`
ServerCAFile string `env:"SERVER_CA_CERTS" envDefault:""`
@@ -30,7 +30,7 @@ type BaseConfig struct {
}
type ServerConfig struct {
BaseConfig
Config
}
type AgentConfig struct {
ServerConfig
@@ -55,19 +55,20 @@ func (a AgentConfig) GetBaseConfig() ServerConfig {
return a.ServerConfig
}
func stopAllServer(servers ...Server) error {
var errs []error
for _, server := range servers {
if err := server.Stop(); err != nil {
errs = append(errs, err)
}
}
func NewBaseServer(
ctx context.Context, cancel context.CancelFunc, name string, config ServerConfiguration, logger *slog.Logger,
) BaseServer {
cfg := config.GetBaseConfig()
address := fmt.Sprintf("%s:%s", cfg.Host, cfg.Port)
if len(errs) > 0 {
return fmt.Errorf("encountered errors while stopping servers: %v", errs)
return BaseServer{
Ctx: ctx,
Cancel: cancel,
Name: name,
Address: address,
Config: config,
Logger: logger,
}
return nil
}
func StopHandler(ctx context.Context, cancel context.CancelFunc, logger *slog.Logger, svcName string, servers ...Server) error {
@@ -87,3 +88,18 @@ func StopHandler(ctx context.Context, cancel context.CancelFunc, logger *slog.Lo
return nil
}
}
func stopAllServer(servers ...Server) error {
var errs []error
for _, server := range servers {
if err := server.Stop(); err != nil {
errs = append(errs, err)
}
}
if len(errs) > 0 {
return fmt.Errorf("encountered errors while stopping servers: %v", errs)
}
return nil
}
@@ -11,7 +11,7 @@ import (
"testing"
"time"
"github.com/ultravioletrs/cocos/internal/server/mocks"
"github.com/ultravioletrs/cocos/pkg/server/mocks"
)
func TestStopAllServer(t *testing.T) {
+161
View File
@@ -0,0 +1,161 @@
// Copyright (c) Ultraviolet
// SPDX-License-Identifier: Apache-2.0
package server
import (
"crypto/tls"
"crypto/x509"
"errors"
"fmt"
"os"
"strings"
)
var (
ErrAppendServerCA = errors.New("failed to append server ca to tls.Config")
ErrAppendClientCA = errors.New("failed to append client ca to tls.Config")
)
// TLSSetupResult contains the result of TLS configuration setup.
type TLSSetupResult struct {
Config *tls.Config
MTLS bool
}
// LoadCertFile loads certificate data from file path or returns empty byte slice if path is empty.
func LoadCertFile(certFile string) ([]byte, error) {
if certFile != "" {
return ReadFileOrData(certFile)
}
return []byte{}, nil
}
// ReadFileOrData reads data from file if input looks like a file path,
// otherwise treats input as raw data.
func ReadFileOrData(input string) ([]byte, error) {
if len(input) < 1000 && !strings.Contains(input, "\n") {
data, err := os.ReadFile(input)
if err == nil {
return data, nil
}
return nil, err
}
return []byte(input), nil
}
// LoadX509KeyPair loads X.509 key pair from certificate and key files or data.
func LoadX509KeyPair(certfile, keyfile string) (tls.Certificate, error) {
cert, err := ReadFileOrData(certfile)
if err != nil {
return tls.Certificate{}, fmt.Errorf("failed to read cert: %w", err)
}
key, err := ReadFileOrData(keyfile)
if err != nil {
return tls.Certificate{}, fmt.Errorf("failed to read key: %w", err)
}
return tls.X509KeyPair(cert, key)
}
// ConfigureRootCA configures the root CA certificates for the TLS config.
func ConfigureRootCA(tlsConfig *tls.Config, serverCAFile string) error {
rootCA, err := LoadCertFile(serverCAFile)
if err != nil {
return fmt.Errorf("failed to load server ca file: %w", err)
}
if len(rootCA) == 0 {
return nil
}
if tlsConfig.RootCAs == nil {
tlsConfig.RootCAs = x509.NewCertPool()
}
if !tlsConfig.RootCAs.AppendCertsFromPEM(rootCA) {
return ErrAppendServerCA
}
return nil
}
// ConfigureClientCA configures the client CA certificates for the TLS config
// Returns true if client CA was configured, false otherwise.
func ConfigureClientCA(tlsConfig *tls.Config, clientCAFile string) (bool, error) {
clientCA, err := LoadCertFile(clientCAFile)
if err != nil {
return false, fmt.Errorf("failed to load client ca file: %w", err)
}
if len(clientCA) == 0 {
return false, nil
}
if tlsConfig.ClientCAs == nil {
tlsConfig.ClientCAs = x509.NewCertPool()
}
if !tlsConfig.ClientCAs.AppendCertsFromPEM(clientCA) {
return false, ErrAppendClientCA
}
return true, nil
}
// ConfigureCertificateAuthorities configures both root and client CAs for the TLS config
// Returns true if mTLS should be enabled (client CA is configured).
func ConfigureCertificateAuthorities(tlsConfig *tls.Config, serverCAFile, clientCAFile string) (bool, error) {
// Configure root CA
if err := ConfigureRootCA(tlsConfig, serverCAFile); err != nil {
return false, err
}
// Configure client CA
hasClientCA, err := ConfigureClientCA(tlsConfig, clientCAFile)
if err != nil {
return false, err
}
return hasClientCA, nil
}
// SetupRegularTLS sets up TLS configuration using regular certificates.
func SetupRegularTLS(certFile, keyFile, serverCAFile, clientCAFile string) (*TLSSetupResult, error) {
certificate, err := LoadX509KeyPair(certFile, keyFile)
if err != nil {
return nil, fmt.Errorf("failed to load auth certificates: %w", err)
}
tlsConfig := &tls.Config{
ClientAuth: tls.NoClientCert,
Certificates: []tls.Certificate{certificate},
}
mtls, err := ConfigureCertificateAuthorities(tlsConfig, serverCAFile, clientCAFile)
if err != nil {
return nil, err
}
if mtls {
tlsConfig.ClientAuth = tls.RequireAndVerifyClientCert
}
return &TLSSetupResult{Config: tlsConfig, MTLS: mtls}, nil
}
// BuildMTLSDescription builds a description string for mTLS configuration.
func BuildMTLSDescription(serverCAFile, clientCAFile string) string {
var parts []string
if serverCAFile != "" {
parts = append(parts, fmt.Sprintf("root ca %s", serverCAFile))
}
if clientCAFile != "" {
parts = append(parts, fmt.Sprintf("client ca %s", clientCAFile))
}
return strings.Join(parts, " ")
}
+741
View File
@@ -0,0 +1,741 @@
// Copyright (c) Ultraviolet
// SPDX-License-Identifier: Apache-2.0
package server
import (
"crypto/rand"
"crypto/rsa"
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"encoding/pem"
"math/big"
"os"
"strings"
"testing"
"time"
)
// Helper function to generate a test certificate and key.
func generateTestCert() (certPEM, keyPEM []byte, err error) {
// Generate private key
privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
return nil, nil, err
}
// Create certificate template
template := x509.Certificate{
SerialNumber: big.NewInt(1),
Subject: pkix.Name{
Organization: []string{"Test Org"},
Country: []string{"US"},
Province: []string{""},
Locality: []string{"Test City"},
StreetAddress: []string{""},
PostalCode: []string{""},
},
NotBefore: time.Now(),
NotAfter: time.Now().Add(365 * 24 * time.Hour),
KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth, x509.ExtKeyUsageClientAuth},
IPAddresses: nil,
}
// Create certificate
certDER, err := x509.CreateCertificate(rand.Reader, &template, &template, &privateKey.PublicKey, privateKey)
if err != nil {
return nil, nil, err
}
// Encode certificate
certPEM = pem.EncodeToMemory(&pem.Block{
Type: "CERTIFICATE",
Bytes: certDER,
})
// Encode private key
privateKeyDER, err := x509.MarshalPKCS8PrivateKey(privateKey)
if err != nil {
return nil, nil, err
}
keyPEM = pem.EncodeToMemory(&pem.Block{
Type: "PRIVATE KEY",
Bytes: privateKeyDER,
})
return certPEM, keyPEM, nil
}
// Helper function to create temporary files for testing.
func createTempFile(t *testing.T, content []byte) string {
tmpFile, err := os.CreateTemp("", "test-cert-*.pem")
if err != nil {
t.Fatalf("Failed to create temp file: %v", err)
}
defer tmpFile.Close()
if _, err := tmpFile.Write(content); err != nil {
t.Fatalf("Failed to write temp file: %v", err)
}
return tmpFile.Name()
}
func TestLoadCertFile(t *testing.T) {
certPEM, _, err := generateTestCert()
if err != nil {
t.Fatalf("Failed to generate test cert: %v", err)
}
tests := []struct {
name string
certFile string
wantErr bool
setup func() string
cleanup func(string)
}{
{
name: "empty cert file path",
certFile: "",
wantErr: false,
},
{
name: "valid cert file",
wantErr: false,
setup: func() string {
return createTempFile(t, certPEM)
},
cleanup: func(path string) {
os.Remove(path)
},
},
{
name: "non-existent file",
certFile: "/non/existent/file.pem",
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
certFile := tt.certFile
if tt.setup != nil {
certFile = tt.setup()
}
if tt.cleanup != nil {
defer tt.cleanup(certFile)
}
data, err := LoadCertFile(certFile)
if (err != nil) != tt.wantErr {
t.Errorf("LoadCertFile() error = %v, wantErr %v", err, tt.wantErr)
return
}
if tt.certFile != "" && !tt.wantErr && len(data) == 0 {
t.Errorf("LoadCertFile() with valid file should return data, got empty")
}
})
}
}
func TestReadFileOrData(t *testing.T) {
testData := "test certificate data"
tempFile := createTempFile(t, []byte(testData))
defer os.Remove(tempFile)
tests := []struct {
name string
input string
want string
wantErr bool
}{
{
name: "file path",
input: tempFile,
want: testData,
wantErr: false,
},
{
name: "raw data with newlines",
input: "-----BEGIN CERTIFICATE-----\nMIIC...\n-----END CERTIFICATE-----",
want: "-----BEGIN CERTIFICATE-----\nMIIC...\n-----END CERTIFICATE-----",
wantErr: false,
},
{
name: "short raw data without newlines",
input: "short data",
want: "short data",
wantErr: true,
},
{
name: "non-existent file path",
input: "/non/existent/file.pem",
want: "",
wantErr: true,
},
{
name: "empty input",
input: "",
want: "",
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := ReadFileOrData(tt.input)
if (err != nil) != tt.wantErr {
t.Errorf("ReadFileOrData() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !tt.wantErr && string(got) != tt.want {
t.Errorf("ReadFileOrData() = %v, want %v", string(got), tt.want)
}
})
}
}
func TestLoadX509KeyPair(t *testing.T) {
certPEM, keyPEM, err := generateTestCert()
if err != nil {
t.Fatalf("Failed to generate test cert: %v", err)
}
certFile := createTempFile(t, certPEM)
keyFile := createTempFile(t, keyPEM)
defer os.Remove(certFile)
defer os.Remove(keyFile)
tests := []struct {
name string
certfile string
keyfile string
wantErr bool
}{
{
name: "valid cert and key files",
certfile: certFile,
keyfile: keyFile,
wantErr: false,
},
{
name: "valid cert and key data",
certfile: string(certPEM),
keyfile: string(keyPEM),
wantErr: false,
},
{
name: "non-existent cert file",
certfile: "/non/existent/cert.pem",
keyfile: keyFile,
wantErr: true,
},
{
name: "non-existent key file",
certfile: certFile,
keyfile: "/non/existent/key.pem",
wantErr: true,
},
{
name: "invalid cert data",
certfile: "invalid cert data",
keyfile: string(keyPEM),
wantErr: true,
},
{
name: "invalid key data",
certfile: string(certPEM),
keyfile: "invalid key data",
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
cert, err := LoadX509KeyPair(tt.certfile, tt.keyfile)
if (err != nil) != tt.wantErr {
t.Errorf("LoadX509KeyPair() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !tt.wantErr && len(cert.Certificate) == 0 {
t.Errorf("LoadX509KeyPair() returned empty certificate")
}
})
}
}
func TestConfigureRootCA(t *testing.T) {
certPEM, _, err := generateTestCert()
if err != nil {
t.Fatalf("Failed to generate test cert: %v", err)
}
caFile := createTempFile(t, certPEM)
defer os.Remove(caFile)
tests := []struct {
name string
tlsConfig *tls.Config
serverCAFile string
wantErr bool
expectCA bool
}{
{
name: "valid CA file",
tlsConfig: &tls.Config{},
serverCAFile: caFile,
wantErr: false,
expectCA: true,
},
{
name: "valid CA data",
tlsConfig: &tls.Config{},
serverCAFile: string(certPEM),
wantErr: false,
expectCA: true,
},
{
name: "empty CA file",
tlsConfig: &tls.Config{},
serverCAFile: "",
wantErr: false,
expectCA: false,
},
{
name: "non-existent CA file",
tlsConfig: &tls.Config{},
serverCAFile: "/non/existent/ca.pem",
wantErr: true,
expectCA: false,
},
{
name: "invalid CA data",
tlsConfig: &tls.Config{},
serverCAFile: "invalid ca data",
wantErr: true,
expectCA: false,
},
{
name: "existing RootCAs pool",
tlsConfig: &tls.Config{RootCAs: x509.NewCertPool()},
serverCAFile: caFile,
wantErr: false,
expectCA: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := ConfigureRootCA(tt.tlsConfig, tt.serverCAFile)
if (err != nil) != tt.wantErr {
t.Errorf("ConfigureRootCA() error = %v, wantErr %v", err, tt.wantErr)
return
}
if tt.expectCA && tt.tlsConfig.RootCAs == nil {
t.Errorf("ConfigureRootCA() should have created RootCAs pool")
}
if !tt.expectCA && tt.tlsConfig.RootCAs != nil && tt.serverCAFile == "" {
t.Errorf("ConfigureRootCA() should not have created RootCAs pool for empty file")
}
})
}
}
func TestConfigureClientCA(t *testing.T) {
certPEM, _, err := generateTestCert()
if err != nil {
t.Fatalf("Failed to generate test cert: %v", err)
}
caFile := createTempFile(t, certPEM)
defer os.Remove(caFile)
tests := []struct {
name string
tlsConfig *tls.Config
clientCAFile string
wantConfigured bool
wantErr bool
}{
{
name: "valid client CA file",
tlsConfig: &tls.Config{},
clientCAFile: caFile,
wantConfigured: true,
wantErr: false,
},
{
name: "valid client CA data",
tlsConfig: &tls.Config{},
clientCAFile: string(certPEM),
wantConfigured: true,
wantErr: false,
},
{
name: "empty client CA file",
tlsConfig: &tls.Config{},
clientCAFile: "",
wantConfigured: false,
wantErr: false,
},
{
name: "non-existent client CA file",
tlsConfig: &tls.Config{},
clientCAFile: "/non/existent/ca.pem",
wantConfigured: false,
wantErr: true,
},
{
name: "invalid client CA data",
tlsConfig: &tls.Config{},
clientCAFile: "invalid ca data",
wantConfigured: false,
wantErr: true,
},
{
name: "existing ClientCAs pool",
tlsConfig: &tls.Config{ClientCAs: x509.NewCertPool()},
clientCAFile: caFile,
wantConfigured: true,
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
configured, err := ConfigureClientCA(tt.tlsConfig, tt.clientCAFile)
if (err != nil) != tt.wantErr {
t.Errorf("ConfigureClientCA() error = %v, wantErr %v", err, tt.wantErr)
return
}
if configured != tt.wantConfigured {
t.Errorf("ConfigureClientCA() configured = %v, want %v", configured, tt.wantConfigured)
}
if tt.wantConfigured && tt.tlsConfig.ClientCAs == nil {
t.Errorf("ConfigureClientCA() should have created ClientCAs pool")
}
})
}
}
func TestConfigureCertificateAuthorities(t *testing.T) {
certPEM, _, err := generateTestCert()
if err != nil {
t.Fatalf("Failed to generate test cert: %v", err)
}
caFile := createTempFile(t, certPEM)
defer os.Remove(caFile)
tests := []struct {
name string
tlsConfig *tls.Config
serverCAFile string
clientCAFile string
wantMTLS bool
wantErr bool
}{
{
name: "both server and client CA",
tlsConfig: &tls.Config{},
serverCAFile: caFile,
clientCAFile: caFile,
wantMTLS: true,
wantErr: false,
},
{
name: "only server CA",
tlsConfig: &tls.Config{},
serverCAFile: caFile,
clientCAFile: "",
wantMTLS: false,
wantErr: false,
},
{
name: "only client CA",
tlsConfig: &tls.Config{},
serverCAFile: "",
clientCAFile: caFile,
wantMTLS: true,
wantErr: false,
},
{
name: "no CAs",
tlsConfig: &tls.Config{},
serverCAFile: "",
clientCAFile: "",
wantMTLS: false,
wantErr: false,
},
{
name: "invalid server CA",
tlsConfig: &tls.Config{},
serverCAFile: "/non/existent/server-ca.pem",
clientCAFile: caFile,
wantMTLS: false,
wantErr: true,
},
{
name: "invalid client CA",
tlsConfig: &tls.Config{},
serverCAFile: caFile,
clientCAFile: "/non/existent/client-ca.pem",
wantMTLS: false,
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
mtls, err := ConfigureCertificateAuthorities(tt.tlsConfig, tt.serverCAFile, tt.clientCAFile)
if (err != nil) != tt.wantErr {
t.Errorf("ConfigureCertificateAuthorities() error = %v, wantErr %v", err, tt.wantErr)
return
}
if mtls != tt.wantMTLS {
t.Errorf("ConfigureCertificateAuthorities() mtls = %v, want %v", mtls, tt.wantMTLS)
}
})
}
}
func TestSetupRegularTLS(t *testing.T) {
certPEM, keyPEM, err := generateTestCert()
if err != nil {
t.Fatalf("Failed to generate test cert: %v", err)
}
certFile := createTempFile(t, certPEM)
keyFile := createTempFile(t, keyPEM)
caFile := createTempFile(t, certPEM)
defer func() {
os.Remove(certFile)
os.Remove(keyFile)
os.Remove(caFile)
}()
tests := []struct {
name string
certFile string
keyFile string
serverCAFile string
clientCAFile string
wantMTLS bool
wantErr bool
expectedAuth tls.ClientAuthType
}{
{
name: "regular TLS without mTLS",
certFile: certFile,
keyFile: keyFile,
serverCAFile: "",
clientCAFile: "",
wantMTLS: false,
wantErr: false,
expectedAuth: tls.NoClientCert,
},
{
name: "TLS with mTLS",
certFile: certFile,
keyFile: keyFile,
serverCAFile: caFile,
clientCAFile: caFile,
wantMTLS: true,
wantErr: false,
expectedAuth: tls.RequireAndVerifyClientCert,
},
{
name: "TLS with only server CA",
certFile: certFile,
keyFile: keyFile,
serverCAFile: caFile,
clientCAFile: "",
wantMTLS: false,
wantErr: false,
expectedAuth: tls.NoClientCert,
},
{
name: "invalid certificate file",
certFile: "/non/existent/cert.pem",
keyFile: keyFile,
serverCAFile: "",
clientCAFile: "",
wantMTLS: false,
wantErr: true,
expectedAuth: tls.NoClientCert,
},
{
name: "invalid key file",
certFile: certFile,
keyFile: "/non/existent/key.pem",
serverCAFile: "",
clientCAFile: "",
wantMTLS: false,
wantErr: true,
expectedAuth: tls.NoClientCert,
},
{
name: "invalid server CA file",
certFile: certFile,
keyFile: keyFile,
serverCAFile: "/non/existent/server-ca.pem",
clientCAFile: "",
wantMTLS: false,
wantErr: true,
expectedAuth: tls.NoClientCert,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result, err := SetupRegularTLS(tt.certFile, tt.keyFile, tt.serverCAFile, tt.clientCAFile)
if (err != nil) != tt.wantErr {
t.Errorf("SetupRegularTLS() error = %v, wantErr %v", err, tt.wantErr)
return
}
if tt.wantErr {
return
}
if result == nil {
t.Errorf("SetupRegularTLS() returned nil result")
return
}
if result.MTLS != tt.wantMTLS {
t.Errorf("SetupRegularTLS() MTLS = %v, want %v", result.MTLS, tt.wantMTLS)
}
if result.Config.ClientAuth != tt.expectedAuth {
t.Errorf("SetupRegularTLS() ClientAuth = %v, want %v", result.Config.ClientAuth, tt.expectedAuth)
}
if len(result.Config.Certificates) == 0 {
t.Errorf("SetupRegularTLS() should have at least one certificate")
}
})
}
}
func TestBuildMTLSDescription(t *testing.T) {
tests := []struct {
name string
serverCAFile string
clientCAFile string
want string
}{
{
name: "both server and client CA files",
serverCAFile: "/path/to/server-ca.pem",
clientCAFile: "/path/to/client-ca.pem",
want: "root ca /path/to/server-ca.pem client ca /path/to/client-ca.pem",
},
{
name: "only server CA file",
serverCAFile: "/path/to/server-ca.pem",
clientCAFile: "",
want: "root ca /path/to/server-ca.pem",
},
{
name: "only client CA file",
serverCAFile: "",
clientCAFile: "/path/to/client-ca.pem",
want: "client ca /path/to/client-ca.pem",
},
{
name: "no CA files",
serverCAFile: "",
clientCAFile: "",
want: "",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := BuildMTLSDescription(tt.serverCAFile, tt.clientCAFile)
if got != tt.want {
t.Errorf("BuildMTLSDescription() = %v, want %v", got, tt.want)
}
})
}
}
func TestErrorConstants(t *testing.T) {
// Test that error constants are properly defined
if ErrAppendServerCA == nil {
t.Error("ErrAppendServerCA should not be nil")
}
if ErrAppendClientCA == nil {
t.Error("ErrAppendClientCA should not be nil")
}
if ErrAppendServerCA.Error() != "failed to append server ca to tls.Config" {
t.Errorf("ErrAppendServerCA message = %v, want 'failed to append server ca to tls.Config'", ErrAppendServerCA.Error())
}
if ErrAppendClientCA.Error() != "failed to append client ca to tls.Config" {
t.Errorf("ErrAppendClientCA message = %v, want 'failed to append client ca to tls.Config'", ErrAppendClientCA.Error())
}
}
func TestTLSSetupResult(t *testing.T) {
// Test that TLSSetupResult struct works as expected
config := &tls.Config{}
result := &TLSSetupResult{
Config: config,
MTLS: true,
}
if result.Config != config {
t.Error("TLSSetupResult Config field should match assigned value")
}
if !result.MTLS {
t.Error("TLSSetupResult MTLS field should be true")
}
}
func TestReadFileOrDataEdgeCases(t *testing.T) {
tests := []struct {
name string
input string
wantErr bool
}{
{
name: "999 chars without newline (should try file)",
input: strings.Repeat("a", 999),
wantErr: true, // Should fail as file doesn't exist
},
{
name: "1001 chars without newline (should treat as data)",
input: strings.Repeat("a", 1001),
wantErr: false,
},
{
name: "short string with newline (should treat as data)",
input: "short\ndata",
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
_, err := ReadFileOrData(tt.input)
if (err != nil) != tt.wantErr {
t.Errorf("ReadFileOrData() error = %v, wantErr %v", err, tt.wantErr)
}
})
}
}
+3 -3
View File
@@ -18,8 +18,8 @@ import (
"github.com/ultravioletrs/cocos/agent/cvms"
cvmsgrpc "github.com/ultravioletrs/cocos/agent/cvms/api/grpc"
"github.com/ultravioletrs/cocos/internal"
"github.com/ultravioletrs/cocos/internal/server"
grpcserver "github.com/ultravioletrs/cocos/internal/server/grpc"
"github.com/ultravioletrs/cocos/pkg/server"
grpcserver "github.com/ultravioletrs/cocos/pkg/server/grpc"
"golang.org/x/sync/errgroup"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials"
@@ -182,7 +182,7 @@ func main() {
cvms.RegisterServiceServer(srv, cvmsgrpc.NewServer(incomingChan, &svc{logger: logger}))
}
grpcServerConfig := server.ServerConfig{
BaseConfig: server.BaseConfig{
Config: server.Config{
Port: defaultPort,
},
}