NOISSUE - Fix loading of CA certs on agent (#321)

* debug connection

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* actual fix

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* remove debugs

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* remove test

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* add unit test

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* more tests

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* consolidate tests

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* fix client auth

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* debug

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* better handling

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

---------

Signed-off-by: Sammy Oina <sammyoina@gmail.com>
This commit is contained in:
Sammy Kerata Oina
2024-12-04 16:03:41 +03:00
committed by GitHub
parent 0864eb69c9
commit 92a4f8bd32
6 changed files with 251 additions and 65 deletions
+3 -1
View File
@@ -5,6 +5,7 @@ package cli
import (
"context"
"github.com/spf13/cobra"
"github.com/ultravioletrs/cocos/pkg/clients/grpc"
"github.com/ultravioletrs/cocos/pkg/clients/grpc/agent"
"github.com/ultravioletrs/cocos/pkg/sdk"
@@ -25,12 +26,13 @@ func New(config grpc.Config) *CLI {
}
}
func (c *CLI) InitializeSDK() error {
func (c *CLI) InitializeSDK(cmd *cobra.Command) error {
agentGRPCClient, agentClient, err := agent.NewAgentClient(context.Background(), c.config)
if err != nil {
c.connectErr = err
return err
}
cmd.Println("🔗 Connected to agent using ", agentGRPCClient.Secure())
c.client = agentGRPCClient
c.agentSDK = sdk.NewAgentSDK(agentClient)
+1 -1
View File
@@ -100,7 +100,7 @@ func main() {
cliSVC := cli.New(agentGRPCConfig)
if err := cliSVC.InitializeSDK(); err == nil {
if err := cliSVC.InitializeSDK(rootCmd); err == nil {
defer cliSVC.Close()
}
+19 -17
View File
@@ -127,7 +127,7 @@ func (s *Server) Start() error {
return fmt.Errorf("failed to load auth certificates: %w", err)
}
tlsConfig := &tls.Config{
ClientAuth: tls.RequireAndVerifyClientCert,
ClientAuth: tls.NoClientCert,
Certificates: []tls.Certificate{certificate},
}
@@ -161,12 +161,17 @@ func (s *Server) Start() error {
}
mtlsCA = fmt.Sprintf("%s client ca %s", mtlsCA, s.Config.ClientCAFile)
}
if mtlsCA != "" {
tlsConfig.ClientAuth = tls.RequireAndVerifyClientCert
}
creds = grpc.Creds(credentials.NewTLS(tlsConfig))
switch {
case mtlsCA != "":
s.Logger.Info(fmt.Sprintf("%s service gRPC server listening at %s with TLS/mTLS cert %s , key %s and %s", s.Name, s.Address, s.Config.CertFile, s.Config.KeyFile, mtlsCA))
s.Logger.Info(fmt.Sprintf("%s service gRPC server listening at %s with TLS/mTLS", s.Name, s.Address))
default:
s.Logger.Info(fmt.Sprintf("%s service gRPC server listening at %s with TLS cert %s and key %s", s.Name, s.Address, s.Config.CertFile, s.Config.KeyFile))
s.Logger.Info(fmt.Sprintf("%s service gRPC server listening at %s with TLS", s.Name, s.Address))
}
listener, err = net.Listen("tcp", s.Address)
@@ -223,31 +228,28 @@ func (s *Server) Stop() error {
func loadCertFile(certFile string) ([]byte, error) {
if certFile != "" {
return os.ReadFile(certFile)
return readFileOrData(certFile)
}
return []byte{}, nil
}
func loadX509KeyPair(certfile, keyfile string) (tls.Certificate, error) {
var cert, key []byte
var err error
readFileOrData := func(input string) ([]byte, error) {
if len(input) < 1000 && !strings.Contains(input, "\n") {
data, err := os.ReadFile(input)
if err == nil {
return data, nil
}
func readFileOrData(input string) ([]byte, error) {
if len(input) < 1000 && !strings.Contains(input, "\n") {
data, err := os.ReadFile(input)
if err == nil {
return data, nil
}
return []byte(input), nil
}
return []byte(input), nil
}
cert, err = readFileOrData(certfile)
func loadX509KeyPair(certfile, keyfile string) (tls.Certificate, error) {
cert, err := readFileOrData(certfile)
if err != nil {
return tls.Certificate{}, fmt.Errorf("failed to read cert: %v", err)
}
key, err = readFileOrData(keyfile)
key, err := readFileOrData(keyfile)
if err != nil {
return tls.Certificate{}, fmt.Errorf("failed to read key: %v", err)
}
+225 -44
View File
@@ -12,6 +12,7 @@ import (
"fmt"
"log/slog"
"math/big"
"os"
"strings"
"sync"
"testing"
@@ -51,49 +52,39 @@ func TestNew(t *testing.T) {
assert.IsType(t, &Server{}, srv)
}
func TestServerStart(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
config := server.Config{
Host: "localhost",
Port: "0",
}
buf := &ThreadSafeBuffer{}
logger := slog.New(slog.NewTextHandler(buf, &slog.HandlerOptions{Level: slog.LevelDebug}))
qp := new(mocks.QuoteProvider)
authSvc := new(authmocks.Authenticator)
srv := New(ctx, cancel, "TestServer", config, func(srv *grpc.Server) {}, logger, qp, authSvc)
var wg sync.WaitGroup
wg.Add(1)
go func() {
wg.Done()
err := srv.Start()
assert.NoError(t, err)
}()
wg.Wait()
time.Sleep(100 * time.Millisecond)
cancel()
assert.Contains(t, buf.String(), "TestServer service gRPC server listening at localhost:0 without TLS")
}
func TestServerStartWithTLS(t *testing.T) {
func TestServerStartWithTLSFile(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
cert, key, err := generateSelfSignedCert()
assert.NoError(t, err)
certFile, err := os.CreateTemp("", "cert*.pem")
assert.NoError(t, err)
keyFile, err := os.CreateTemp("", "key*.pem")
assert.NoError(t, err)
t.Cleanup(func() {
os.Remove(certFile.Name())
os.Remove(keyFile.Name())
})
_, err = certFile.Write(cert)
assert.NoError(t, err)
_, err = keyFile.Write(key)
assert.NoError(t, err)
err = certFile.Close()
assert.NoError(t, err)
err = keyFile.Close()
assert.NoError(t, err)
config := server.Config{
Host: "localhost",
Port: "0",
CertFile: string(cert),
KeyFile: string(key),
CertFile: certFile.Name(),
KeyFile: keyFile.Name(),
}
logBuffer := &ThreadSafeBuffer{}
@@ -125,13 +116,41 @@ func TestServerStartWithTLS(t *testing.T) {
assert.Contains(t, logContent, "TestServer service gRPC server listening at localhost:0 with TLS")
}
func TestServerStartWithAttestedTLS(t *testing.T) {
func TestServerStartWithmTLSFile(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
cert, key, err := generateSelfSignedCert()
assert.NoError(t, err)
certFile, err := os.CreateTemp("", "cert*.pem")
assert.NoError(t, err)
keyFile, err := os.CreateTemp("", "key*.pem")
assert.NoError(t, err)
t.Cleanup(func() {
os.Remove(certFile.Name())
os.Remove(keyFile.Name())
})
_, err = certFile.Write(cert)
assert.NoError(t, err)
_, err = keyFile.Write(key)
assert.NoError(t, err)
err = certFile.Close()
assert.NoError(t, err)
err = keyFile.Close()
assert.NoError(t, err)
config := server.Config{
Host: "localhost",
Port: "0",
AttestedTLS: true,
Host: "localhost",
Port: "0",
CertFile: certFile.Name(),
KeyFile: keyFile.Name(),
ServerCAFile: certFile.Name(),
ClientCAFile: certFile.Name(),
}
logBuffer := &ThreadSafeBuffer{}
@@ -152,16 +171,15 @@ func TestServerStartWithAttestedTLS(t *testing.T) {
wg.Wait()
time.Sleep(100 * time.Millisecond)
time.Sleep(200 * time.Millisecond)
cancel()
time.Sleep(1000 * time.Millisecond)
time.Sleep(200 * time.Millisecond)
logContent := logBuffer.String()
assert.Contains(t, logContent, "TestServer service gRPC server listening at localhost:0 with Attested TLS")
qp.AssertExpectations(t)
fmt.Println(logContent)
assert.Contains(t, logContent, "TestServer service gRPC server listening at localhost:0 with TLS")
}
func TestServerStop(t *testing.T) {
@@ -246,3 +264,166 @@ func (b *ThreadSafeBuffer) String() string {
defer b.mu.Unlock()
return b.buffer.String()
}
func TestServerInitializationAndStartup(t *testing.T) {
testCases := []struct {
name string
config server.Config
expectedLog string
expectError bool
setupCallback func(*testing.T, *server.Config, *ThreadSafeBuffer)
}{
{
name: "Non-TLS Server Startup",
config: server.Config{
Host: "localhost",
Port: "0",
},
expectedLog: "TestServer service gRPC server listening at localhost:0 without TLS",
},
{
name: "TLS Server Startup with Self-Signed Certificate",
config: server.Config{
Host: "localhost",
Port: "0",
},
setupCallback: setupTLSConfig,
expectedLog: "TestServer service gRPC server listening at localhost:0 with TLS",
},
{
name: "TLS Server Startup with Invalid Certificates",
config: server.Config{
Host: "localhost",
Port: "0",
CertFile: "invalid",
KeyFile: "invalid",
},
expectError: true,
expectedLog: "failed to load auth certificates",
},
{
name: "mTLS Server Startup",
config: server.Config{
Host: "localhost",
Port: "0",
},
setupCallback: setupMTLSConfig,
expectedLog: "TestServer service gRPC server listening at localhost:0 with TLS",
},
{
name: "mTLS Server Startup with Invalid Root CA",
config: server.Config{
Host: "localhost",
Port: "0",
ServerCAFile: "invalid",
},
setupCallback: setupInvalidRootCAConfig,
expectError: true,
expectedLog: "failed to append root ca to tls.Config",
},
{
name: "mTLS Server Startup with Invalid Client CA",
config: server.Config{
Host: "localhost",
Port: "0",
ServerCAFile: "invalid",
},
setupCallback: setupInvalidClientCAConfig,
expectError: true,
expectedLog: "failed to append client ca to tls.Config",
},
{
name: "Attested TLS Server Startup",
config: server.Config{
Host: "localhost",
Port: "0",
AttestedTLS: true,
},
expectedLog: "TestServer service gRPC server listening at localhost:0 with Attested TLS",
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
if tc.setupCallback != nil {
tc.setupCallback(t, &tc.config, nil)
}
logBuffer := &ThreadSafeBuffer{}
logger := slog.New(slog.NewTextHandler(logBuffer, &slog.HandlerOptions{Level: slog.LevelDebug}))
qp := new(mocks.QuoteProvider)
authSvc := new(authmocks.Authenticator)
srv := New(ctx, cancel, "TestServer", tc.config, func(srv *grpc.Server) {}, logger, qp, authSvc)
var wg sync.WaitGroup
wg.Add(1)
go func() {
wg.Done()
err := srv.Start()
if tc.expectError {
assert.Error(t, err)
assert.Contains(t, err.Error(), tc.expectedLog)
} else {
assert.NoError(t, err)
}
}()
wg.Wait()
time.Sleep(200 * time.Millisecond)
cancel()
time.Sleep(200 * time.Millisecond)
if !tc.expectError {
logContent := logBuffer.String()
fmt.Println(logContent)
assert.Contains(t, logContent, tc.expectedLog)
}
})
}
}
func setupTLSConfig(t *testing.T, config *server.Config, _ *ThreadSafeBuffer) {
cert, key, err := generateSelfSignedCert()
assert.NoError(t, err)
config.CertFile = string(cert)
config.KeyFile = string(key)
}
func setupMTLSConfig(t *testing.T, config *server.Config, _ *ThreadSafeBuffer) {
cert, key, err := generateSelfSignedCert()
assert.NoError(t, err)
config.CertFile = string(cert)
config.KeyFile = string(key)
config.ServerCAFile = string(cert)
config.ClientCAFile = string(cert)
}
func setupInvalidRootCAConfig(t *testing.T, config *server.Config, _ *ThreadSafeBuffer) {
cert, key, err := generateSelfSignedCert()
assert.NoError(t, err)
config.CertFile = string(cert)
config.KeyFile = string(key)
config.ServerCAFile = "invalid"
config.ClientCAFile = string(cert)
}
func setupInvalidClientCAConfig(t *testing.T, config *server.Config, _ *ThreadSafeBuffer) {
cert, key, err := generateSelfSignedCert()
assert.NoError(t, err)
config.CertFile = string(cert)
config.KeyFile = string(key)
config.ClientCAFile = "invalid"
config.ServerCAFile = string(cert)
}
+1 -1
View File
@@ -20,7 +20,7 @@ func NewAgentClient(ctx context.Context, cfg grpc.Config) (grpc.Client, agent.Ag
return nil, nil, err
}
if client.Secure() != grpc.WithATLS {
if client.Secure() != grpc.WithATLS && client.Secure() != grpc.WithTLS {
health := grpchealth.NewHealthClient(client.Connection())
resp, err := health.Check(ctx, &grpchealth.HealthCheckRequest{
Service: "agent",
+2 -1
View File
@@ -35,6 +35,7 @@ const (
const (
AttestationReportSize = 0x4A0
WithATLS = "with aTLS"
WithTLS = "with TLS"
)
var (
@@ -102,7 +103,7 @@ func (c *client) Close() error {
func (c *client) Secure() string {
switch c.secure {
case withTLS:
return "with TLS"
return WithTLS
case withmTLS:
return "with mTLS"
case withaTLS: