mirror of
https://github.com/ultravioletrs/cocos.git
synced 2026-06-23 04:10:25 +00:00
NOISSUE - Fix loading of CA certs on agent (#321)
* debug connection Signed-off-by: Sammy Oina <sammyoina@gmail.com> * actual fix Signed-off-by: Sammy Oina <sammyoina@gmail.com> * remove debugs Signed-off-by: Sammy Oina <sammyoina@gmail.com> * remove test Signed-off-by: Sammy Oina <sammyoina@gmail.com> * add unit test Signed-off-by: Sammy Oina <sammyoina@gmail.com> * more tests Signed-off-by: Sammy Oina <sammyoina@gmail.com> * consolidate tests Signed-off-by: Sammy Oina <sammyoina@gmail.com> * fix client auth Signed-off-by: Sammy Oina <sammyoina@gmail.com> * debug Signed-off-by: Sammy Oina <sammyoina@gmail.com> * better handling Signed-off-by: Sammy Oina <sammyoina@gmail.com> --------- Signed-off-by: Sammy Oina <sammyoina@gmail.com>
This commit is contained in:
committed by
GitHub
parent
0864eb69c9
commit
92a4f8bd32
+3
-1
@@ -5,6 +5,7 @@ package cli
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
"github.com/ultravioletrs/cocos/pkg/clients/grpc"
|
||||
"github.com/ultravioletrs/cocos/pkg/clients/grpc/agent"
|
||||
"github.com/ultravioletrs/cocos/pkg/sdk"
|
||||
@@ -25,12 +26,13 @@ func New(config grpc.Config) *CLI {
|
||||
}
|
||||
}
|
||||
|
||||
func (c *CLI) InitializeSDK() error {
|
||||
func (c *CLI) InitializeSDK(cmd *cobra.Command) error {
|
||||
agentGRPCClient, agentClient, err := agent.NewAgentClient(context.Background(), c.config)
|
||||
if err != nil {
|
||||
c.connectErr = err
|
||||
return err
|
||||
}
|
||||
cmd.Println("🔗 Connected to agent using ", agentGRPCClient.Secure())
|
||||
c.client = agentGRPCClient
|
||||
|
||||
c.agentSDK = sdk.NewAgentSDK(agentClient)
|
||||
|
||||
+1
-1
@@ -100,7 +100,7 @@ func main() {
|
||||
|
||||
cliSVC := cli.New(agentGRPCConfig)
|
||||
|
||||
if err := cliSVC.InitializeSDK(); err == nil {
|
||||
if err := cliSVC.InitializeSDK(rootCmd); err == nil {
|
||||
defer cliSVC.Close()
|
||||
}
|
||||
|
||||
|
||||
@@ -127,7 +127,7 @@ func (s *Server) Start() error {
|
||||
return fmt.Errorf("failed to load auth certificates: %w", err)
|
||||
}
|
||||
tlsConfig := &tls.Config{
|
||||
ClientAuth: tls.RequireAndVerifyClientCert,
|
||||
ClientAuth: tls.NoClientCert,
|
||||
Certificates: []tls.Certificate{certificate},
|
||||
}
|
||||
|
||||
@@ -161,12 +161,17 @@ func (s *Server) Start() error {
|
||||
}
|
||||
mtlsCA = fmt.Sprintf("%s client ca %s", mtlsCA, s.Config.ClientCAFile)
|
||||
}
|
||||
|
||||
if mtlsCA != "" {
|
||||
tlsConfig.ClientAuth = tls.RequireAndVerifyClientCert
|
||||
}
|
||||
|
||||
creds = grpc.Creds(credentials.NewTLS(tlsConfig))
|
||||
switch {
|
||||
case mtlsCA != "":
|
||||
s.Logger.Info(fmt.Sprintf("%s service gRPC server listening at %s with TLS/mTLS cert %s , key %s and %s", s.Name, s.Address, s.Config.CertFile, s.Config.KeyFile, mtlsCA))
|
||||
s.Logger.Info(fmt.Sprintf("%s service gRPC server listening at %s with TLS/mTLS", s.Name, s.Address))
|
||||
default:
|
||||
s.Logger.Info(fmt.Sprintf("%s service gRPC server listening at %s with TLS cert %s and key %s", s.Name, s.Address, s.Config.CertFile, s.Config.KeyFile))
|
||||
s.Logger.Info(fmt.Sprintf("%s service gRPC server listening at %s with TLS", s.Name, s.Address))
|
||||
}
|
||||
|
||||
listener, err = net.Listen("tcp", s.Address)
|
||||
@@ -223,31 +228,28 @@ func (s *Server) Stop() error {
|
||||
|
||||
func loadCertFile(certFile string) ([]byte, error) {
|
||||
if certFile != "" {
|
||||
return os.ReadFile(certFile)
|
||||
return readFileOrData(certFile)
|
||||
}
|
||||
return []byte{}, nil
|
||||
}
|
||||
|
||||
func loadX509KeyPair(certfile, keyfile string) (tls.Certificate, error) {
|
||||
var cert, key []byte
|
||||
var err error
|
||||
|
||||
readFileOrData := func(input string) ([]byte, error) {
|
||||
if len(input) < 1000 && !strings.Contains(input, "\n") {
|
||||
data, err := os.ReadFile(input)
|
||||
if err == nil {
|
||||
return data, nil
|
||||
}
|
||||
func readFileOrData(input string) ([]byte, error) {
|
||||
if len(input) < 1000 && !strings.Contains(input, "\n") {
|
||||
data, err := os.ReadFile(input)
|
||||
if err == nil {
|
||||
return data, nil
|
||||
}
|
||||
return []byte(input), nil
|
||||
}
|
||||
return []byte(input), nil
|
||||
}
|
||||
|
||||
cert, err = readFileOrData(certfile)
|
||||
func loadX509KeyPair(certfile, keyfile string) (tls.Certificate, error) {
|
||||
cert, err := readFileOrData(certfile)
|
||||
if err != nil {
|
||||
return tls.Certificate{}, fmt.Errorf("failed to read cert: %v", err)
|
||||
}
|
||||
|
||||
key, err = readFileOrData(keyfile)
|
||||
key, err := readFileOrData(keyfile)
|
||||
if err != nil {
|
||||
return tls.Certificate{}, fmt.Errorf("failed to read key: %v", err)
|
||||
}
|
||||
|
||||
@@ -12,6 +12,7 @@ import (
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"math/big"
|
||||
"os"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
@@ -51,49 +52,39 @@ func TestNew(t *testing.T) {
|
||||
assert.IsType(t, &Server{}, srv)
|
||||
}
|
||||
|
||||
func TestServerStart(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
config := server.Config{
|
||||
Host: "localhost",
|
||||
Port: "0",
|
||||
}
|
||||
buf := &ThreadSafeBuffer{}
|
||||
logger := slog.New(slog.NewTextHandler(buf, &slog.HandlerOptions{Level: slog.LevelDebug}))
|
||||
qp := new(mocks.QuoteProvider)
|
||||
authSvc := new(authmocks.Authenticator)
|
||||
|
||||
srv := New(ctx, cancel, "TestServer", config, func(srv *grpc.Server) {}, logger, qp, authSvc)
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(1)
|
||||
|
||||
go func() {
|
||||
wg.Done()
|
||||
err := srv.Start()
|
||||
assert.NoError(t, err)
|
||||
}()
|
||||
|
||||
wg.Wait()
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
cancel()
|
||||
|
||||
assert.Contains(t, buf.String(), "TestServer service gRPC server listening at localhost:0 without TLS")
|
||||
}
|
||||
|
||||
func TestServerStartWithTLS(t *testing.T) {
|
||||
func TestServerStartWithTLSFile(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
cert, key, err := generateSelfSignedCert()
|
||||
assert.NoError(t, err)
|
||||
|
||||
certFile, err := os.CreateTemp("", "cert*.pem")
|
||||
assert.NoError(t, err)
|
||||
|
||||
keyFile, err := os.CreateTemp("", "key*.pem")
|
||||
assert.NoError(t, err)
|
||||
|
||||
t.Cleanup(func() {
|
||||
os.Remove(certFile.Name())
|
||||
os.Remove(keyFile.Name())
|
||||
})
|
||||
|
||||
_, err = certFile.Write(cert)
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, err = keyFile.Write(key)
|
||||
assert.NoError(t, err)
|
||||
|
||||
err = certFile.Close()
|
||||
assert.NoError(t, err)
|
||||
err = keyFile.Close()
|
||||
assert.NoError(t, err)
|
||||
|
||||
config := server.Config{
|
||||
Host: "localhost",
|
||||
Port: "0",
|
||||
CertFile: string(cert),
|
||||
KeyFile: string(key),
|
||||
CertFile: certFile.Name(),
|
||||
KeyFile: keyFile.Name(),
|
||||
}
|
||||
|
||||
logBuffer := &ThreadSafeBuffer{}
|
||||
@@ -125,13 +116,41 @@ func TestServerStartWithTLS(t *testing.T) {
|
||||
assert.Contains(t, logContent, "TestServer service gRPC server listening at localhost:0 with TLS")
|
||||
}
|
||||
|
||||
func TestServerStartWithAttestedTLS(t *testing.T) {
|
||||
func TestServerStartWithmTLSFile(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
cert, key, err := generateSelfSignedCert()
|
||||
assert.NoError(t, err)
|
||||
|
||||
certFile, err := os.CreateTemp("", "cert*.pem")
|
||||
assert.NoError(t, err)
|
||||
|
||||
keyFile, err := os.CreateTemp("", "key*.pem")
|
||||
assert.NoError(t, err)
|
||||
|
||||
t.Cleanup(func() {
|
||||
os.Remove(certFile.Name())
|
||||
os.Remove(keyFile.Name())
|
||||
})
|
||||
|
||||
_, err = certFile.Write(cert)
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, err = keyFile.Write(key)
|
||||
assert.NoError(t, err)
|
||||
|
||||
err = certFile.Close()
|
||||
assert.NoError(t, err)
|
||||
err = keyFile.Close()
|
||||
assert.NoError(t, err)
|
||||
|
||||
config := server.Config{
|
||||
Host: "localhost",
|
||||
Port: "0",
|
||||
AttestedTLS: true,
|
||||
Host: "localhost",
|
||||
Port: "0",
|
||||
CertFile: certFile.Name(),
|
||||
KeyFile: keyFile.Name(),
|
||||
ServerCAFile: certFile.Name(),
|
||||
ClientCAFile: certFile.Name(),
|
||||
}
|
||||
|
||||
logBuffer := &ThreadSafeBuffer{}
|
||||
@@ -152,16 +171,15 @@ func TestServerStartWithAttestedTLS(t *testing.T) {
|
||||
|
||||
wg.Wait()
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
|
||||
cancel()
|
||||
|
||||
time.Sleep(1000 * time.Millisecond)
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
|
||||
logContent := logBuffer.String()
|
||||
assert.Contains(t, logContent, "TestServer service gRPC server listening at localhost:0 with Attested TLS")
|
||||
|
||||
qp.AssertExpectations(t)
|
||||
fmt.Println(logContent)
|
||||
assert.Contains(t, logContent, "TestServer service gRPC server listening at localhost:0 with TLS")
|
||||
}
|
||||
|
||||
func TestServerStop(t *testing.T) {
|
||||
@@ -246,3 +264,166 @@ func (b *ThreadSafeBuffer) String() string {
|
||||
defer b.mu.Unlock()
|
||||
return b.buffer.String()
|
||||
}
|
||||
|
||||
func TestServerInitializationAndStartup(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
config server.Config
|
||||
expectedLog string
|
||||
expectError bool
|
||||
setupCallback func(*testing.T, *server.Config, *ThreadSafeBuffer)
|
||||
}{
|
||||
{
|
||||
name: "Non-TLS Server Startup",
|
||||
config: server.Config{
|
||||
Host: "localhost",
|
||||
Port: "0",
|
||||
},
|
||||
expectedLog: "TestServer service gRPC server listening at localhost:0 without TLS",
|
||||
},
|
||||
{
|
||||
name: "TLS Server Startup with Self-Signed Certificate",
|
||||
config: server.Config{
|
||||
Host: "localhost",
|
||||
Port: "0",
|
||||
},
|
||||
setupCallback: setupTLSConfig,
|
||||
expectedLog: "TestServer service gRPC server listening at localhost:0 with TLS",
|
||||
},
|
||||
{
|
||||
name: "TLS Server Startup with Invalid Certificates",
|
||||
config: server.Config{
|
||||
Host: "localhost",
|
||||
Port: "0",
|
||||
CertFile: "invalid",
|
||||
KeyFile: "invalid",
|
||||
},
|
||||
expectError: true,
|
||||
expectedLog: "failed to load auth certificates",
|
||||
},
|
||||
{
|
||||
name: "mTLS Server Startup",
|
||||
config: server.Config{
|
||||
Host: "localhost",
|
||||
Port: "0",
|
||||
},
|
||||
setupCallback: setupMTLSConfig,
|
||||
expectedLog: "TestServer service gRPC server listening at localhost:0 with TLS",
|
||||
},
|
||||
{
|
||||
name: "mTLS Server Startup with Invalid Root CA",
|
||||
config: server.Config{
|
||||
Host: "localhost",
|
||||
Port: "0",
|
||||
ServerCAFile: "invalid",
|
||||
},
|
||||
setupCallback: setupInvalidRootCAConfig,
|
||||
expectError: true,
|
||||
expectedLog: "failed to append root ca to tls.Config",
|
||||
},
|
||||
{
|
||||
name: "mTLS Server Startup with Invalid Client CA",
|
||||
config: server.Config{
|
||||
Host: "localhost",
|
||||
Port: "0",
|
||||
ServerCAFile: "invalid",
|
||||
},
|
||||
setupCallback: setupInvalidClientCAConfig,
|
||||
expectError: true,
|
||||
expectedLog: "failed to append client ca to tls.Config",
|
||||
},
|
||||
{
|
||||
name: "Attested TLS Server Startup",
|
||||
config: server.Config{
|
||||
Host: "localhost",
|
||||
Port: "0",
|
||||
AttestedTLS: true,
|
||||
},
|
||||
expectedLog: "TestServer service gRPC server listening at localhost:0 with Attested TLS",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
if tc.setupCallback != nil {
|
||||
tc.setupCallback(t, &tc.config, nil)
|
||||
}
|
||||
|
||||
logBuffer := &ThreadSafeBuffer{}
|
||||
logger := slog.New(slog.NewTextHandler(logBuffer, &slog.HandlerOptions{Level: slog.LevelDebug}))
|
||||
qp := new(mocks.QuoteProvider)
|
||||
authSvc := new(authmocks.Authenticator)
|
||||
|
||||
srv := New(ctx, cancel, "TestServer", tc.config, func(srv *grpc.Server) {}, logger, qp, authSvc)
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(1)
|
||||
|
||||
go func() {
|
||||
wg.Done()
|
||||
err := srv.Start()
|
||||
if tc.expectError {
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), tc.expectedLog)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
}()
|
||||
|
||||
wg.Wait()
|
||||
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
|
||||
cancel()
|
||||
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
|
||||
if !tc.expectError {
|
||||
logContent := logBuffer.String()
|
||||
fmt.Println(logContent)
|
||||
assert.Contains(t, logContent, tc.expectedLog)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func setupTLSConfig(t *testing.T, config *server.Config, _ *ThreadSafeBuffer) {
|
||||
cert, key, err := generateSelfSignedCert()
|
||||
assert.NoError(t, err)
|
||||
|
||||
config.CertFile = string(cert)
|
||||
config.KeyFile = string(key)
|
||||
}
|
||||
|
||||
func setupMTLSConfig(t *testing.T, config *server.Config, _ *ThreadSafeBuffer) {
|
||||
cert, key, err := generateSelfSignedCert()
|
||||
assert.NoError(t, err)
|
||||
|
||||
config.CertFile = string(cert)
|
||||
config.KeyFile = string(key)
|
||||
config.ServerCAFile = string(cert)
|
||||
config.ClientCAFile = string(cert)
|
||||
}
|
||||
|
||||
func setupInvalidRootCAConfig(t *testing.T, config *server.Config, _ *ThreadSafeBuffer) {
|
||||
cert, key, err := generateSelfSignedCert()
|
||||
assert.NoError(t, err)
|
||||
|
||||
config.CertFile = string(cert)
|
||||
config.KeyFile = string(key)
|
||||
config.ServerCAFile = "invalid"
|
||||
config.ClientCAFile = string(cert)
|
||||
}
|
||||
|
||||
func setupInvalidClientCAConfig(t *testing.T, config *server.Config, _ *ThreadSafeBuffer) {
|
||||
cert, key, err := generateSelfSignedCert()
|
||||
assert.NoError(t, err)
|
||||
|
||||
config.CertFile = string(cert)
|
||||
config.KeyFile = string(key)
|
||||
config.ClientCAFile = "invalid"
|
||||
config.ServerCAFile = string(cert)
|
||||
}
|
||||
|
||||
@@ -20,7 +20,7 @@ func NewAgentClient(ctx context.Context, cfg grpc.Config) (grpc.Client, agent.Ag
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
if client.Secure() != grpc.WithATLS {
|
||||
if client.Secure() != grpc.WithATLS && client.Secure() != grpc.WithTLS {
|
||||
health := grpchealth.NewHealthClient(client.Connection())
|
||||
resp, err := health.Check(ctx, &grpchealth.HealthCheckRequest{
|
||||
Service: "agent",
|
||||
|
||||
@@ -35,6 +35,7 @@ const (
|
||||
const (
|
||||
AttestationReportSize = 0x4A0
|
||||
WithATLS = "with aTLS"
|
||||
WithTLS = "with TLS"
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -102,7 +103,7 @@ func (c *client) Close() error {
|
||||
func (c *client) Secure() string {
|
||||
switch c.secure {
|
||||
case withTLS:
|
||||
return "with TLS"
|
||||
return WithTLS
|
||||
case withmTLS:
|
||||
return "with mTLS"
|
||||
case withaTLS:
|
||||
|
||||
Reference in New Issue
Block a user