mirror of
https://github.com/ultravioletrs/cocos.git
synced 2026-06-23 04:10:25 +00:00
NOISSUE - Implement structured logging with log forwarding for ingress-proxy and computation-runner, update component versions, and improve aTLS initialization and error handling. (#583)
CI / lint (push) Has been cancelled
CI / test (agent) (push) Has been cancelled
CI / test (cli) (push) Has been cancelled
CI / test (cmd) (push) Has been cancelled
CI / test (internal) (push) Has been cancelled
CI / test (manager, true) (push) Has been cancelled
CI / test (pkg) (push) Has been cancelled
CI / upload-coverage (push) Has been cancelled
CI / lint (push) Has been cancelled
CI / test (agent) (push) Has been cancelled
CI / test (cli) (push) Has been cancelled
CI / test (cmd) (push) Has been cancelled
CI / test (internal) (push) Has been cancelled
CI / test (manager, true) (push) Has been cancelled
CI / test (pkg) (push) Has been cancelled
CI / upload-coverage (push) Has been cancelled
* feat: Implement structured logging with log forwarding for `ingress-proxy` and `computation-runner`, update component versions, and improve aTLS initialization and error handling. Signed-off-by: Sammy Oina <sammyoina@gmail.com> * refactor: Remove explicit AGENT_ENABLE_ATLS configuration and update component versions. Signed-off-by: Sammy Oina <sammyoina@gmail.com> * fix: Correct aTLS nonce verification for truncated hashes, delegate internal CVM server TLS to Ingress Proxy, and update component versions. Signed-off-by: Sammy Oina <sammyoina@gmail.com> * chore: Update package build sources to ultravioletrs/cocos main branch and remove local development keys and encrypted algorithm. Signed-off-by: Sammy Oina <sammyoina@gmail.com> * Remove the `pkg/server` module, including its generic gRPC and HTTP server implementations. Signed-off-by: Sammy Oina <sammyoina@gmail.com> * chore: clarify nonce truncation in the certificate verifier. Signed-off-by: Sammy Oina <sammyoina@gmail.com> --------- Signed-off-by: Sammy Oina <sammyoina@gmail.com>
This commit is contained in:
committed by
GitHub
parent
c1cbcec851
commit
42b05524c8
+53
-40
@@ -4,17 +4,17 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
context "context"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"net"
|
||||
"os"
|
||||
|
||||
"github.com/ultravioletrs/cocos/agent"
|
||||
agentgrpc "github.com/ultravioletrs/cocos/agent/api/grpc"
|
||||
"github.com/ultravioletrs/cocos/agent/auth"
|
||||
"github.com/ultravioletrs/cocos/pkg/atls"
|
||||
"github.com/ultravioletrs/cocos/pkg/server"
|
||||
grpcserver "github.com/ultravioletrs/cocos/pkg/server/grpc"
|
||||
"go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/credentials/insecure"
|
||||
"google.golang.org/grpc/reflection"
|
||||
)
|
||||
|
||||
@@ -29,55 +29,68 @@ type AgentServer interface {
|
||||
}
|
||||
|
||||
type agentServer struct {
|
||||
gs server.Server
|
||||
logger *slog.Logger
|
||||
svc agent.Service
|
||||
host string
|
||||
certProvider atls.CertificateProvider
|
||||
gs *grpc.Server
|
||||
logger *slog.Logger
|
||||
svc agent.Service
|
||||
host string
|
||||
}
|
||||
|
||||
func NewServer(logger *slog.Logger, svc agent.Service, host string, certProvider atls.CertificateProvider) AgentServer {
|
||||
func NewServer(logger *slog.Logger, svc agent.Service, host string) AgentServer {
|
||||
return &agentServer{
|
||||
logger: logger,
|
||||
svc: svc,
|
||||
host: host,
|
||||
certProvider: certProvider,
|
||||
logger: logger,
|
||||
svc: svc,
|
||||
host: host,
|
||||
}
|
||||
}
|
||||
|
||||
func (as *agentServer) Start(cfg agent.AgentConfig, cmp agent.Computation) error {
|
||||
agentGrpcServerConfig := server.AgentConfig{
|
||||
ServerConfig: server.ServerConfig{
|
||||
Config: server.Config{
|
||||
Host: defSvcGRPCSocket,
|
||||
Port: "",
|
||||
CertFile: cfg.CertFile,
|
||||
KeyFile: cfg.KeyFile,
|
||||
ServerCAFile: cfg.ServerCAFile,
|
||||
ClientCAFile: cfg.ClientCAFile,
|
||||
},
|
||||
},
|
||||
AttestedTLS: cfg.AttestedTls,
|
||||
}
|
||||
|
||||
registerAgentServiceServer := func(srv *grpc.Server) {
|
||||
reflection.Register(srv)
|
||||
agent.RegisterAgentServiceServer(srv, agentgrpc.NewServer(as.svc))
|
||||
}
|
||||
|
||||
authSvc, err := auth.New(cmp)
|
||||
if err != nil {
|
||||
as.logger.WithGroup(cmp.ID).Error(fmt.Sprintf("failed to create auth service %s", err.Error()))
|
||||
return err
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
grpcServerOptions := []grpc.ServerOption{
|
||||
grpc.StatsHandler(otelgrpc.NewServerHandler()),
|
||||
}
|
||||
|
||||
as.gs = grpcserver.New(ctx, cancel, svcName, agentGrpcServerConfig, registerAgentServiceServer, as.logger, authSvc, as.certProvider)
|
||||
// Add authentication interceptors
|
||||
unary, stream := agentgrpc.NewAuthInterceptor(authSvc)
|
||||
grpcServerOptions = append(grpcServerOptions, grpc.UnaryInterceptor(unary))
|
||||
grpcServerOptions = append(grpcServerOptions, grpc.StreamInterceptor(stream))
|
||||
|
||||
// Internal Unix socket is pure plaintext HTTP/2; Ingress Proxy handles external aTLS termination
|
||||
grpcServerOptions = append(grpcServerOptions, grpc.Creds(insecure.NewCredentials()))
|
||||
|
||||
as.gs = grpc.NewServer(grpcServerOptions...)
|
||||
|
||||
reflection.Register(as.gs)
|
||||
agent.RegisterAgentServiceServer(as.gs, agentgrpc.NewServer(as.svc))
|
||||
|
||||
socketPath := as.host
|
||||
if socketPath == "" || socketPath == "0.0.0.0" {
|
||||
socketPath = defSvcGRPCSocket
|
||||
}
|
||||
|
||||
var listener net.Listener
|
||||
if socketPath[0] == '/' || socketPath[0] == '.' {
|
||||
// Remove existing socket file if it exists
|
||||
_ = os.Remove(socketPath)
|
||||
listener, err = net.Listen("unix", socketPath)
|
||||
} else {
|
||||
listener, err = net.Listen("tcp", socketPath)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
as.logger.Error(fmt.Sprintf("failed to listen on %s: %s", socketPath, err))
|
||||
return err
|
||||
}
|
||||
|
||||
as.logger.Info(fmt.Sprintf("agent service gRPC server listening at %s without TLS", socketPath))
|
||||
|
||||
go func() {
|
||||
err := as.gs.Start()
|
||||
if err != nil {
|
||||
err := as.gs.Serve(listener)
|
||||
if err != nil && err != grpc.ErrServerStopped {
|
||||
as.logger.Error(fmt.Sprintf("failed to start grpc server %s", err.Error()))
|
||||
}
|
||||
}()
|
||||
@@ -86,8 +99,8 @@ func (as *agentServer) Start(cfg agent.AgentConfig, cmp agent.Computation) error
|
||||
}
|
||||
|
||||
func (as *agentServer) Stop() error {
|
||||
if as.gs == nil {
|
||||
return nil
|
||||
if as.gs != nil {
|
||||
as.gs.GracefulStop()
|
||||
}
|
||||
return as.gs.Stop()
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -21,7 +21,7 @@ import (
|
||||
func setupTest(t *testing.T) (*slog.Logger, *mocks.Service, string, []byte) {
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelError}))
|
||||
mockSvc := new(mocks.Service)
|
||||
host := "localhost"
|
||||
host := "localhost:0"
|
||||
|
||||
privateKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||
assert.NoError(t, err, "Failed to generate ECDSA key")
|
||||
@@ -70,7 +70,7 @@ func TestNewServer(t *testing.T) {
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
server := NewServer(tt.logger, tt.svc, tt.host, nil)
|
||||
server := NewServer(tt.logger, tt.svc, tt.host)
|
||||
|
||||
assert.NotNil(t, server)
|
||||
|
||||
@@ -194,7 +194,7 @@ func TestAgentServer_Start(t *testing.T) {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
tt.setupMocks(svc)
|
||||
|
||||
server := NewServer(logger, svc, host, nil)
|
||||
server := NewServer(logger, svc, host)
|
||||
|
||||
err := server.Start(tt.cfg, tt.cmp)
|
||||
|
||||
@@ -268,7 +268,7 @@ func TestAgentServer_Stop(t *testing.T) {
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
server := NewServer(logger, svc, host, nil)
|
||||
server := NewServer(logger, svc, host)
|
||||
|
||||
err := tt.setupServer(server)
|
||||
if err != nil {
|
||||
@@ -296,7 +296,7 @@ func TestAgentServer_Stop(t *testing.T) {
|
||||
|
||||
func TestAgentServer_StopMultipleTimes(t *testing.T) {
|
||||
logger, svc, host, pubKey := setupTest(t)
|
||||
server := NewServer(logger, svc, host, nil)
|
||||
server := NewServer(logger, svc, host)
|
||||
|
||||
// Start the server
|
||||
cfg := agent.AgentConfig{}
|
||||
@@ -340,7 +340,7 @@ func TestAgentServer_StopMultipleTimes(t *testing.T) {
|
||||
|
||||
func TestAgentServer_StartAfterStop(t *testing.T) {
|
||||
logger, svc, host, pubKey := setupTest(t)
|
||||
server := NewServer(logger, svc, host, nil)
|
||||
server := NewServer(logger, svc, host)
|
||||
|
||||
cfg := agent.AgentConfig{}
|
||||
cmp := agent.Computation{
|
||||
@@ -488,7 +488,7 @@ func TestAgentServer_ConfigValidation(t *testing.T) {
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
server := NewServer(logger, svc, host, nil)
|
||||
server := NewServer(logger, svc, host)
|
||||
|
||||
err := server.Start(tt.config, tt.cmp)
|
||||
|
||||
|
||||
Binary file not shown.
+9
-6
@@ -59,7 +59,6 @@ type config struct {
|
||||
AgentOSDistro string `env:"AGENT_OS_DISTRO" envDefault:"UVC"`
|
||||
AgentOSType string `env:"AGENT_OS_TYPE" envDefault:"UVC"`
|
||||
AttestationServiceSocket string `env:"ATTESTATION_SERVICE_SOCKET" envDefault:"/run/cocos/attestation.sock"`
|
||||
EnableATLS bool `env:"AGENT_ENABLE_ATLS" envDefault:"true"`
|
||||
}
|
||||
|
||||
func main() {
|
||||
@@ -139,6 +138,7 @@ func main() {
|
||||
})
|
||||
|
||||
ccPlatform := attestation.CCPlatform()
|
||||
logger.Info(fmt.Sprintf("Detected confidential computing platform: %v", ccPlatform))
|
||||
|
||||
azureConfig := azure.NewEnvConfigFromAgent(
|
||||
cfg.AgentOSBuild,
|
||||
@@ -209,7 +209,8 @@ func main() {
|
||||
}
|
||||
|
||||
var certProvider atls.CertificateProvider
|
||||
if cfg.EnableATLS && ccPlatform != attestation.NoCC {
|
||||
if ccPlatform != attestation.NoCC {
|
||||
logger.Info(fmt.Sprintf("Initializing aTLS for platform %v with attestation service at %s", ccPlatform, cfg.AttestationServiceSocket))
|
||||
var certsSDK sdk.SDK
|
||||
if cfg.CAUrl != "" {
|
||||
certsSDK = sdk.NewSDK(sdk.Config{
|
||||
@@ -218,10 +219,12 @@ func main() {
|
||||
}
|
||||
certProvider, err = atls.NewProvider(attClient, ccPlatform, cfg.CertsToken, cfg.CVMId, certsSDK)
|
||||
if err != nil {
|
||||
logger.Error(fmt.Sprintf("failed to create certificate provider: %s", err))
|
||||
exitCode = 1
|
||||
return
|
||||
logger.Error(fmt.Sprintf("failed to create certificate provider for aTLS: %s. Continuing without attested TLS.", err))
|
||||
} else {
|
||||
logger.Info("Successfully created aTLS certificate provider")
|
||||
}
|
||||
} else {
|
||||
logger.Warn("No Confidential Computing platform detected (NoCC). Certificate provider remains nil; aTLS will not be available for computations.")
|
||||
}
|
||||
|
||||
// Create ingress proxy server
|
||||
@@ -240,7 +243,7 @@ func main() {
|
||||
return
|
||||
}
|
||||
|
||||
mc, err := cvmsapi.NewClient(pc, svc, cvmsQueue, logger, server.NewServer(logger, svc, cfg.AgentGrpcHost, certProvider), ingressProxy, storageDir, reconnectFn, cvmGRPCClient)
|
||||
mc, err := cvmsapi.NewClient(pc, svc, cvmsQueue, logger, server.NewServer(logger, svc, cfg.AgentGrpcHost), ingressProxy, storageDir, reconnectFn, cvmGRPCClient)
|
||||
if err != nil {
|
||||
logger.Error(err.Error())
|
||||
exitCode = 1
|
||||
|
||||
@@ -13,9 +13,12 @@ import (
|
||||
|
||||
mglog "github.com/absmach/supermq/logger"
|
||||
"github.com/caarlos0/env/v11"
|
||||
"github.com/ultravioletrs/cocos/agent/cvms"
|
||||
logpb "github.com/ultravioletrs/cocos/agent/log"
|
||||
pb "github.com/ultravioletrs/cocos/agent/runner"
|
||||
runnerevents "github.com/ultravioletrs/cocos/agent/runner/events"
|
||||
"github.com/ultravioletrs/cocos/agent/runner/service"
|
||||
agentlogger "github.com/ultravioletrs/cocos/internal/logger"
|
||||
logclient "github.com/ultravioletrs/cocos/pkg/clients/grpc/log"
|
||||
"golang.org/x/sync/errgroup"
|
||||
"google.golang.org/grpc"
|
||||
@@ -51,16 +54,43 @@ func main() {
|
||||
return
|
||||
}
|
||||
|
||||
logger := slog.New(slog.NewJSONHandler(os.Stdout, &slog.HandlerOptions{Level: level}))
|
||||
logQueue := make(chan *cvms.ClientStreamMessage, 1000)
|
||||
handler := agentlogger.NewProtoHandler(os.Stdout, &slog.HandlerOptions{Level: level}, logQueue)
|
||||
logger := slog.New(handler)
|
||||
|
||||
// Connect to Log Forwarder
|
||||
logClient, err := logclient.NewClient(cfg.LogForwarder)
|
||||
if err != nil {
|
||||
logger.Warn(fmt.Sprintf("failed to connect to log-forwarder: %s. Events will not be forwarded.", err))
|
||||
logger.Warn(fmt.Sprintf("failed to connect to log-forwarder: %s. Logs and events will not be forwarded.", err))
|
||||
} else {
|
||||
defer logClient.Close()
|
||||
}
|
||||
|
||||
g.Go(func() error {
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil
|
||||
case msg := <-logQueue:
|
||||
if logClient == nil {
|
||||
continue
|
||||
}
|
||||
switch m := msg.Message.(type) {
|
||||
case *cvms.ClientStreamMessage_AgentLog:
|
||||
err := logClient.SendLog(ctx, &logpb.LogEntry{
|
||||
Message: m.AgentLog.Message,
|
||||
ComputationId: m.AgentLog.ComputationId,
|
||||
Level: m.AgentLog.Level,
|
||||
Timestamp: m.AgentLog.Timestamp,
|
||||
})
|
||||
if err != nil {
|
||||
logger.Error("failed to send log", "error", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
eventSvc := runnerevents.NewAdapter(logClient, svcName)
|
||||
|
||||
// Remove existing socket if it exists
|
||||
|
||||
@@ -5,14 +5,18 @@ package main
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"os"
|
||||
"os/signal"
|
||||
"syscall"
|
||||
|
||||
mglog "github.com/absmach/supermq/logger"
|
||||
"github.com/caarlos0/env/v11"
|
||||
"github.com/spf13/cobra"
|
||||
"github.com/spf13/pflag"
|
||||
"github.com/ultravioletrs/cocos/agent/cvms"
|
||||
logpb "github.com/ultravioletrs/cocos/agent/log"
|
||||
agentlogger "github.com/ultravioletrs/cocos/internal/logger"
|
||||
logclient "github.com/ultravioletrs/cocos/pkg/clients/grpc/log"
|
||||
"github.com/ultravioletrs/cocos/pkg/egress"
|
||||
"golang.org/x/sync/errgroup"
|
||||
)
|
||||
@@ -22,8 +26,9 @@ const (
|
||||
)
|
||||
|
||||
type config struct {
|
||||
Level string `env:"COCOS_LOG_LEVEL" envAlternate:"AGENT_LOG_LEVEL" envDefault:"info"`
|
||||
Port string `env:"COCOS_PROXY_PORT" envDefault:"3128"`
|
||||
Level string `env:"COCOS_LOG_LEVEL" envAlternate:"AGENT_LOG_LEVEL" envDefault:"info"`
|
||||
Port string `env:"COCOS_PROXY_PORT" envDefault:"3128"`
|
||||
LogForwarder string `env:"LOG_FORWARDER_SOCKET" envDefault:"/run/cocos/log.sock"`
|
||||
}
|
||||
|
||||
func main() {
|
||||
@@ -51,9 +56,20 @@ func main() {
|
||||
}
|
||||
|
||||
func run(cfg config) error {
|
||||
logger, err := mglog.New(os.Stdout, cfg.Level)
|
||||
var level slog.Level
|
||||
if err := level.UnmarshalText([]byte(cfg.Level)); err != nil {
|
||||
return fmt.Errorf("invalid log level: %w", err)
|
||||
}
|
||||
|
||||
logQueue := make(chan *cvms.ClientStreamMessage, 1000)
|
||||
handler := agentlogger.NewProtoHandler(os.Stdout, &slog.HandlerOptions{Level: level}, logQueue)
|
||||
logger := slog.New(handler)
|
||||
|
||||
logClient, err := logclient.NewClient(cfg.LogForwarder)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create logger: %w", err)
|
||||
logger.Warn(fmt.Sprintf("failed to connect to log-forwarder: %s. Logs will not be forwarded.", err))
|
||||
} else {
|
||||
defer logClient.Close()
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
@@ -61,6 +77,31 @@ func run(cfg config) error {
|
||||
|
||||
g, ctx := errgroup.WithContext(ctx)
|
||||
|
||||
g.Go(func() error {
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil
|
||||
case msg := <-logQueue:
|
||||
if logClient == nil {
|
||||
continue
|
||||
}
|
||||
switch m := msg.Message.(type) {
|
||||
case *cvms.ClientStreamMessage_AgentLog:
|
||||
err := logClient.SendLog(ctx, &logpb.LogEntry{
|
||||
Message: m.AgentLog.Message,
|
||||
ComputationId: m.AgentLog.ComputationId,
|
||||
Level: m.AgentLog.Level,
|
||||
Timestamp: m.AgentLog.Timestamp,
|
||||
})
|
||||
if err != nil {
|
||||
logger.Error("failed to send log", "error", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
proxy := egress.NewProxy(logger, ":"+cfg.Port)
|
||||
|
||||
g.Go(func() error {
|
||||
|
||||
@@ -5,20 +5,24 @@ package main
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"net/url"
|
||||
"os"
|
||||
"os/signal"
|
||||
"syscall"
|
||||
|
||||
"github.com/absmach/certs/sdk"
|
||||
mglog "github.com/absmach/supermq/logger"
|
||||
"github.com/caarlos0/env/v11"
|
||||
"github.com/spf13/cobra"
|
||||
"github.com/spf13/pflag"
|
||||
"github.com/ultravioletrs/cocos/agent/cvms"
|
||||
logpb "github.com/ultravioletrs/cocos/agent/log"
|
||||
agentlogger "github.com/ultravioletrs/cocos/internal/logger"
|
||||
"github.com/ultravioletrs/cocos/pkg/atls"
|
||||
"github.com/ultravioletrs/cocos/pkg/attestation"
|
||||
"github.com/ultravioletrs/cocos/pkg/attestation/azure"
|
||||
attestation_client "github.com/ultravioletrs/cocos/pkg/clients/grpc/attestation"
|
||||
logclient "github.com/ultravioletrs/cocos/pkg/clients/grpc/log"
|
||||
"github.com/ultravioletrs/cocos/pkg/ingress"
|
||||
"golang.org/x/sync/errgroup"
|
||||
)
|
||||
@@ -39,6 +43,7 @@ type config struct {
|
||||
AgentOSBuild string `env:"AGENT_OS_BUILD" envDefault:"UVC"`
|
||||
AgentOSDistro string `env:"AGENT_OS_DISTRO" envDefault:"UVC"`
|
||||
AgentOSType string `env:"AGENT_OS_TYPE" envDefault:"UVC"`
|
||||
LogForwarder string `env:"LOG_FORWARDER_SOCKET" envDefault:"/run/cocos/log.sock"`
|
||||
}
|
||||
|
||||
func main() {
|
||||
@@ -66,11 +71,52 @@ func main() {
|
||||
}
|
||||
|
||||
func run(cfg config) error {
|
||||
logger, err := mglog.New(os.Stdout, cfg.LogLevel)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create logger: %w", err)
|
||||
var level slog.Level
|
||||
if err := level.UnmarshalText([]byte(cfg.LogLevel)); err != nil {
|
||||
return fmt.Errorf("invalid log level: %w", err)
|
||||
}
|
||||
|
||||
logQueue := make(chan *cvms.ClientStreamMessage, 1000)
|
||||
handler := agentlogger.NewProtoHandler(os.Stdout, &slog.HandlerOptions{Level: level}, logQueue)
|
||||
logger := slog.New(handler)
|
||||
|
||||
logClient, err := logclient.NewClient(cfg.LogForwarder)
|
||||
if err != nil {
|
||||
logger.Warn(fmt.Sprintf("failed to connect to log-forwarder: %s. Logs will not be forwarded.", err))
|
||||
} else {
|
||||
defer logClient.Close()
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
g, ctx := errgroup.WithContext(ctx)
|
||||
|
||||
g.Go(func() error {
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil
|
||||
case msg := <-logQueue:
|
||||
if logClient == nil {
|
||||
continue
|
||||
}
|
||||
switch m := msg.Message.(type) {
|
||||
case *cvms.ClientStreamMessage_AgentLog:
|
||||
err := logClient.SendLog(ctx, &logpb.LogEntry{
|
||||
Message: m.AgentLog.Message,
|
||||
ComputationId: m.AgentLog.ComputationId,
|
||||
Level: m.AgentLog.Level,
|
||||
Timestamp: m.AgentLog.Timestamp,
|
||||
})
|
||||
if err != nil {
|
||||
logger.Error("failed to send log", "error", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
backendURL, err := url.Parse(cfg.Backend)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to parse backend URL: %w", err)
|
||||
@@ -111,11 +157,6 @@ func run(cfg config) error {
|
||||
logger.Warn("No Confidential Computing platform detected. ATLS will not be available.")
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
g, ctx := errgroup.WithContext(ctx)
|
||||
|
||||
// Create proxy server (but don't start it yet - it will be started per-computation)
|
||||
_ = ingress.NewProxyServer(logger, backendURL, certProvider)
|
||||
|
||||
|
||||
+4
-5
@@ -16,6 +16,7 @@ import (
|
||||
"github.com/absmach/supermq/pkg/jaeger"
|
||||
"github.com/absmach/supermq/pkg/prometheus"
|
||||
smqserver "github.com/absmach/supermq/pkg/server"
|
||||
grpcserver "github.com/absmach/supermq/pkg/server/grpc"
|
||||
httpserver "github.com/absmach/supermq/pkg/server/http"
|
||||
"github.com/absmach/supermq/pkg/uuid"
|
||||
"github.com/caarlos0/env/v11"
|
||||
@@ -26,8 +27,6 @@ import (
|
||||
"github.com/ultravioletrs/cocos/manager/api/http"
|
||||
"github.com/ultravioletrs/cocos/manager/qemu"
|
||||
"github.com/ultravioletrs/cocos/manager/tracing"
|
||||
"github.com/ultravioletrs/cocos/pkg/server"
|
||||
grpcserver "github.com/ultravioletrs/cocos/pkg/server/grpc"
|
||||
"go.opentelemetry.io/otel/trace"
|
||||
"golang.org/x/sync/errgroup"
|
||||
"google.golang.org/grpc"
|
||||
@@ -113,7 +112,7 @@ func main() {
|
||||
args := qemuCfg.ConstructQemuArgs()
|
||||
logger.Info(strings.Join(args, " "))
|
||||
|
||||
managerGRPCConfig := server.ServerConfig{}
|
||||
managerGRPCConfig := smqserver.Config{}
|
||||
if err := env.ParseWithOptions(&managerGRPCConfig, env.Options{Prefix: envPrefixGRPC}); err != nil {
|
||||
logger.Error(fmt.Sprintf("failed to load %s gRPC client configuration : %s", svcName, err))
|
||||
exitCode = 1
|
||||
@@ -145,7 +144,7 @@ func main() {
|
||||
manager.RegisterManagerServiceServer(srv, managergrpc.NewServer(svc))
|
||||
}
|
||||
|
||||
gs := grpcserver.New(ctx, cancel, svcName, managerGRPCConfig, registerManagerServiceServer, logger, nil, nil)
|
||||
gs := grpcserver.NewServer(ctx, cancel, svcName, managerGRPCConfig, registerManagerServiceServer, logger)
|
||||
|
||||
hs := httpserver.NewServer(ctx, cancel, svcName, httpServerConfig, http.MakeHandler(chi.NewMux(), svcName, cfg.InstanceID), logger)
|
||||
|
||||
@@ -158,7 +157,7 @@ func main() {
|
||||
})
|
||||
|
||||
g.Go(func() error {
|
||||
return server.StopHandler(ctx, cancel, logger, svcName, gs, hs)
|
||||
return smqserver.StopSignalHandler(ctx, cancel, logger, svcName, gs, hs)
|
||||
})
|
||||
|
||||
if err := g.Wait(); err != nil {
|
||||
|
||||
@@ -1 +0,0 @@
|
||||
bbf3a1198ee889f77a227fe01e329864fd6a37a2d23135ea8e2c5a2ebc07f0d3
|
||||
@@ -4,8 +4,8 @@
|
||||
#
|
||||
################################################################################
|
||||
|
||||
AGENT_VERSION = 913bbccf3a22053e1979da004c732007336fc890
|
||||
AGENT_SITE = $(call github,sammyoina,cocos-ai,$(AGENT_VERSION))
|
||||
AGENT_VERSION = main
|
||||
AGENT_SITE = $(call github,ultravioletrs,cocos,$(AGENT_VERSION))
|
||||
|
||||
define AGENT_BUILD_CMDS
|
||||
$(MAKE) -C $(@D) agent EMBED_ENABLED=$(AGENT_EMBED_ENABLED)
|
||||
|
||||
@@ -4,8 +4,8 @@
|
||||
#
|
||||
################################################################################
|
||||
|
||||
ATTESTATION_SERVICE_VERSION = 913bbccf3a22053e1979da004c732007336fc890
|
||||
ATTESTATION_SERVICE_SITE = $(call github,sammyoina,cocos-ai,$(ATTESTATION_SERVICE_VERSION))
|
||||
ATTESTATION_SERVICE_VERSION = main
|
||||
ATTESTATION_SERVICE_SITE = $(call github,ultravioletrs,cocos,$(ATTESTATION_SERVICE_VERSION))
|
||||
|
||||
define ATTESTATION_SERVICE_BUILD_CMDS
|
||||
$(MAKE) -C $(@D) attestation-service
|
||||
|
||||
@@ -4,8 +4,8 @@
|
||||
#
|
||||
################################################################################
|
||||
|
||||
COMPUTATION_RUNNER_VERSION = 913bbccf3a22053e1979da004c732007336fc890
|
||||
COMPUTATION_RUNNER_SITE = $(call github,sammyoina,cocos-ai,$(COMPUTATION_RUNNER_VERSION))
|
||||
COMPUTATION_RUNNER_VERSION = main
|
||||
COMPUTATION_RUNNER_SITE = $(call github,ultravioletrs,cocos,$(COMPUTATION_RUNNER_VERSION))
|
||||
|
||||
define COMPUTATION_RUNNER_BUILD_CMDS
|
||||
$(MAKE) -C $(@D) computation-runner
|
||||
|
||||
@@ -4,8 +4,8 @@
|
||||
#
|
||||
################################################################################
|
||||
|
||||
EGRESS_PROXY_VERSION = 913bbccf3a22053e1979da004c732007336fc890
|
||||
EGRESS_PROXY_SITE = $(call github,sammyoina,cocos-ai,$(EGRESS_PROXY_VERSION))
|
||||
EGRESS_PROXY_VERSION = main
|
||||
EGRESS_PROXY_SITE = $(call github,ultravioletrs,cocos,$(EGRESS_PROXY_VERSION))
|
||||
|
||||
define EGRESS_PROXY_BUILD_CMDS
|
||||
$(MAKE) -C $(@D) egress-proxy
|
||||
|
||||
@@ -4,8 +4,8 @@
|
||||
#
|
||||
################################################################################
|
||||
|
||||
INGRESS_PROXY_VERSION = 913bbccf3a22053e1979da004c732007336fc890
|
||||
INGRESS_PROXY_SITE = $(call github,sammyoina,cocos-ai,$(INGRESS_PROXY_VERSION))
|
||||
INGRESS_PROXY_VERSION = main
|
||||
INGRESS_PROXY_SITE = $(call github,ultravioletrs,cocos,$(INGRESS_PROXY_VERSION))
|
||||
|
||||
define INGRESS_PROXY_BUILD_CMDS
|
||||
$(MAKE) -C $(@D) ingress-proxy
|
||||
|
||||
@@ -4,8 +4,8 @@
|
||||
#
|
||||
################################################################################
|
||||
|
||||
LOG_FORWARDER_VERSION = 913bbccf3a22053e1979da004c732007336fc890
|
||||
LOG_FORWARDER_SITE = $(call github,sammyoina,cocos-ai,$(LOG_FORWARDER_VERSION))
|
||||
LOG_FORWARDER_VERSION = main
|
||||
LOG_FORWARDER_SITE = $(call github,ultravioletrs,cocos,$(LOG_FORWARDER_VERSION))
|
||||
|
||||
define LOG_FORWARDER_BUILD_CMDS
|
||||
$(MAKE) -C $(@D) log-forwarder
|
||||
|
||||
@@ -9,7 +9,6 @@ WorkingDirectory=/cocos
|
||||
Environment="HTTP_PROXY=http://localhost:3128"
|
||||
Environment="HTTPS_PROXY=http://localhost:3128"
|
||||
Environment="NO_PROXY=localhost,127.0.0.1,.local,/run/cocos/"
|
||||
Environment="AGENT_ENABLE_ATLS=false"
|
||||
StandardOutput=file:/var/log/cocos/agent.stdout
|
||||
StandardError=file:/var/log/cocos/agent.stderr
|
||||
EnvironmentFile=/etc/cocos/environment
|
||||
|
||||
@@ -1,3 +0,0 @@
|
||||
-----BEGIN PRIVATE KEY-----
|
||||
MC4CAQAwBQYDK2VwBCIEICgJcXfNueGCu8jFFNGBXm9r25OGBEc0OEqCUVjyI4fY
|
||||
-----END PRIVATE KEY-----
|
||||
@@ -1,3 +0,0 @@
|
||||
-----BEGIN PUBLIC KEY-----
|
||||
MCowBQYDK2VwAyEAPbPOfwsJkxpNBluGOg/lgNVE/o0AEM7J11wvkXvHXSw=
|
||||
-----END PUBLIC KEY-----
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"crypto/x509"
|
||||
"encoding/asn1"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
@@ -36,20 +37,35 @@ func NewCertificateVerifier(rootCAs *x509.CertPool) CertificateVerifier {
|
||||
}
|
||||
|
||||
func (v *certificateVerifier) VerifyPeerCertificate(rawCerts [][]byte, _ [][]*x509.Certificate, nonce []byte) error {
|
||||
slog.Debug("Starting peer certificate verification for aTLS")
|
||||
if len(rawCerts) == 0 {
|
||||
return fmt.Errorf("no certificates provided")
|
||||
err := fmt.Errorf("no certificates provided")
|
||||
slog.Error("aTLS handshake failed", "reason", err.Error())
|
||||
return err
|
||||
}
|
||||
|
||||
cert, err := x509.ParseCertificate(rawCerts[0])
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to parse x509 certificate: %w", err)
|
||||
err = fmt.Errorf("failed to parse x509 certificate: %w", err)
|
||||
slog.Error("aTLS handshake failed", "reason", err.Error())
|
||||
return err
|
||||
}
|
||||
slog.Debug("Successfully parsed peer x509 certificate", "subject", cert.Subject.String())
|
||||
|
||||
if err := v.verifyCertificateSignature(cert); err != nil {
|
||||
return fmt.Errorf("certificate signature verification failed: %w", err)
|
||||
err = fmt.Errorf("certificate signature verification failed: %w", err)
|
||||
slog.Error("aTLS handshake failed", "reason", err.Error())
|
||||
return err
|
||||
}
|
||||
slog.Debug("Successfully verified peer certificate signature")
|
||||
|
||||
return v.verifyAttestationExtension(cert, nonce)
|
||||
err = v.verifyAttestationExtension(cert, nonce)
|
||||
if err != nil {
|
||||
slog.Error("aTLS handshake failed", "reason", err.Error())
|
||||
return err
|
||||
}
|
||||
slog.Debug("Successfully verified aTLS attestation extension")
|
||||
return nil
|
||||
}
|
||||
|
||||
func (v *certificateVerifier) verifyCertificateSignature(cert *x509.Certificate) error {
|
||||
@@ -71,6 +87,7 @@ func (v *certificateVerifier) verifyCertificateSignature(cert *x509.Certificate)
|
||||
func (v *certificateVerifier) verifyAttestationExtension(cert *x509.Certificate, nonce []byte) error {
|
||||
for _, ext := range cert.Extensions {
|
||||
if platformType, err := platformTypeFromOID(ext.Id); err == nil {
|
||||
slog.Debug("Found attestation extension in peer certificate", "platform_type", platformType)
|
||||
pubKeyDER, err := x509.MarshalPKIXPublicKey(cert.PublicKey)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal public key: %w", err)
|
||||
@@ -93,22 +110,29 @@ func (v *certificateVerifier) verifyCertificateExtension(extension []byte, pubKe
|
||||
// Verify nonce matches
|
||||
teeNonce := append(pubKey, nonce...)
|
||||
hashNonce := sha3.Sum512(teeNonce)
|
||||
// The attestation provider truncates the 64-byte hash to 32 bytes for the EAT token nonce claim
|
||||
// This matches the Attestation Service API and standard cryptographic nonce sizes.
|
||||
expectedNonce := hashNonce[:32]
|
||||
|
||||
// Compare nonces (EAT nonce should match our computed nonce)
|
||||
if len(claims.Nonce) != len(hashNonce) {
|
||||
return fmt.Errorf("nonce length mismatch: expected %d, got %d", len(hashNonce), len(claims.Nonce))
|
||||
if len(claims.Nonce) != len(expectedNonce) {
|
||||
err := fmt.Errorf("nonce length mismatch: expected %d, got %d", len(expectedNonce), len(claims.Nonce))
|
||||
slog.Error("aTLS handshake failed", "reason", err.Error())
|
||||
return err
|
||||
}
|
||||
|
||||
nonceMatch := true
|
||||
for i := range claims.Nonce {
|
||||
if claims.Nonce[i] != hashNonce[i] {
|
||||
if claims.Nonce[i] != expectedNonce[i] {
|
||||
nonceMatch = false
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if !nonceMatch {
|
||||
return fmt.Errorf("nonce mismatch in EAT token")
|
||||
err := fmt.Errorf("nonce mismatch in EAT token")
|
||||
slog.Error("aTLS handshake failed", "reason", err.Error())
|
||||
return err
|
||||
}
|
||||
|
||||
// Get platform verifier
|
||||
@@ -151,6 +175,7 @@ func (v *certificateVerifier) verifyCertificateExtension(extension []byte, pubKe
|
||||
return fmt.Errorf("failed to verify attestation with CoRIM: %w", err)
|
||||
}
|
||||
|
||||
slog.Debug("CoRIM verification passed for aTLS peer certificate")
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@@ -79,7 +79,7 @@ func TestVerifyPeerCertificate_Success(t *testing.T) {
|
||||
hashNonce := sha3.Sum512(teeNonce)
|
||||
|
||||
claims := eat.EATClaims{
|
||||
Nonce: hashNonce[:],
|
||||
Nonce: hashNonce[:32],
|
||||
RawReport: []byte("mock-report"),
|
||||
}
|
||||
eatBytes, err := cbor.Marshal(claims)
|
||||
@@ -155,7 +155,7 @@ func TestVerifyPeerCertificate_AzureSuccess(t *testing.T) {
|
||||
teeNonce := append(peerPubKeyDER, nonce...)
|
||||
hashNonce := sha3.Sum512(teeNonce)
|
||||
|
||||
claims := eat.EATClaims{Nonce: hashNonce[:], RawReport: []byte("rep")}
|
||||
claims := eat.EATClaims{Nonce: hashNonce[:32], RawReport: []byte("rep")}
|
||||
eatBytes, _ := cbor.Marshal(claims)
|
||||
|
||||
peerTemplate := &x509.Certificate{
|
||||
@@ -210,7 +210,7 @@ func TestVerifyPeerCertificate_TDXSuccess(t *testing.T) {
|
||||
teeNonce := append(peerPubKeyDER, nonce...)
|
||||
hashNonce := sha3.Sum512(teeNonce)
|
||||
|
||||
claims := eat.EATClaims{Nonce: hashNonce[:], RawReport: []byte("rep")}
|
||||
claims := eat.EATClaims{Nonce: hashNonce[:32], RawReport: []byte("rep")}
|
||||
eatBytes, _ := cbor.Marshal(claims)
|
||||
|
||||
peerTemplate := &x509.Certificate{
|
||||
@@ -273,7 +273,7 @@ func TestVerifyPeerCertificate_Failures_More(t *testing.T) {
|
||||
nonce := []byte("nonce")
|
||||
teeNonce := append(peerPubKeyDER, nonce...)
|
||||
hashNonce := sha3.Sum512(teeNonce)
|
||||
claims := eat.EATClaims{Nonce: hashNonce[:], RawReport: []byte("rep")}
|
||||
claims := eat.EATClaims{Nonce: hashNonce[:32], RawReport: []byte("rep")}
|
||||
eatBytes, _ := cbor.Marshal(claims)
|
||||
peerTemplate.ExtraExtensions = []pkix.Extension{{Id: SNPvTPMOID, Value: eatBytes}}
|
||||
certDERWithExt, _ := x509.CreateCertificate(rand.Reader, peerTemplate, caCert, &peerKey.PublicKey, caKey)
|
||||
@@ -323,7 +323,7 @@ func TestVerifyPeerCertificate_Failures_Ext(t *testing.T) {
|
||||
peerPubKeyDER, _ := x509.MarshalPKIXPublicKey(&peerKey.PublicKey)
|
||||
wrongTeeNonce := append(peerPubKeyDER, []byte("wrong-nonce")...)
|
||||
wrongHashNonce := sha3.Sum512(wrongTeeNonce)
|
||||
claims.Nonce = wrongHashNonce[:]
|
||||
claims.Nonce = wrongHashNonce[:32]
|
||||
eatBytes, _ = cbor.Marshal(claims)
|
||||
peerTemplate.ExtraExtensions = []pkix.Extension{{Id: SNPvTPMOID, Value: eatBytes}}
|
||||
certDER, _ = x509.CreateCertificate(rand.Reader, peerTemplate, caCert, &peerKey.PublicKey, caKey)
|
||||
|
||||
@@ -45,7 +45,7 @@ func (c *client) SendLog(ctx context.Context, entry *log.LogEntry) error {
|
||||
}
|
||||
|
||||
// Retry with exponential backoff for concurrent request handling
|
||||
maxRetries := 3
|
||||
maxRetries := 10
|
||||
for attempt := 0; attempt < maxRetries; attempt++ {
|
||||
ctx, cancel := context.WithTimeout(ctx, 5*time.Second)
|
||||
_, err := c.client.SendLog(ctx, entry)
|
||||
@@ -57,8 +57,11 @@ func (c *client) SendLog(ctx context.Context, entry *log.LogEntry) error {
|
||||
|
||||
// Don't retry on last attempt
|
||||
if attempt < maxRetries-1 {
|
||||
// Exponential backoff: 10ms, 20ms, 40ms
|
||||
backoff := time.Duration(10*(1<<uint(attempt))) * time.Millisecond
|
||||
// Backoff: 100ms, 200ms, 400ms... max 2s
|
||||
backoff := time.Duration(100*(1<<uint(attempt))) * time.Millisecond
|
||||
if backoff > 2*time.Second {
|
||||
backoff = 2 * time.Second
|
||||
}
|
||||
time.Sleep(backoff)
|
||||
}
|
||||
}
|
||||
@@ -76,7 +79,7 @@ func (c *client) SendEvent(ctx context.Context, entry *log.EventEntry) error {
|
||||
}
|
||||
|
||||
// Retry with exponential backoff for concurrent request handling
|
||||
maxRetries := 3
|
||||
maxRetries := 10
|
||||
for attempt := 0; attempt < maxRetries; attempt++ {
|
||||
ctx, cancel := context.WithTimeout(ctx, 5*time.Second)
|
||||
_, err := c.client.SendEvent(ctx, entry)
|
||||
@@ -88,8 +91,11 @@ func (c *client) SendEvent(ctx context.Context, entry *log.EventEntry) error {
|
||||
|
||||
// Don't retry on last attempt
|
||||
if attempt < maxRetries-1 {
|
||||
// Exponential backoff: 10ms, 20ms, 40ms
|
||||
backoff := time.Duration(10*(1<<uint(attempt))) * time.Millisecond
|
||||
// Backoff: 100ms, 200ms, 400ms... max 2s
|
||||
backoff := time.Duration(100*(1<<uint(attempt))) * time.Millisecond
|
||||
if backoff > 2*time.Second {
|
||||
backoff = 2 * time.Second
|
||||
}
|
||||
time.Sleep(backoff)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -488,8 +488,8 @@ func TestClientSendLogAllRetriesFail(t *testing.T) {
|
||||
|
||||
grpcServer := grpc.NewServer()
|
||||
mockServer := &retryMockLogCollectorServer{
|
||||
failCount: 10, // Fail all attempts
|
||||
maxFailCount: 10,
|
||||
failCount: 20, // Fail all attempts
|
||||
maxFailCount: 20,
|
||||
}
|
||||
log.RegisterLogCollectorServer(grpcServer, mockServer)
|
||||
|
||||
@@ -513,8 +513,8 @@ func TestClientSendLogAllRetriesFail(t *testing.T) {
|
||||
// Should fail after all retries
|
||||
err = client.SendLog(ctx, entry)
|
||||
assert.Error(t, err)
|
||||
// 3 retries + 1 final attempt = 4 calls
|
||||
assert.Equal(t, 4, mockServer.callCount)
|
||||
// 10 retries + 1 final attempt = 11 calls
|
||||
assert.Equal(t, 11, mockServer.callCount)
|
||||
}
|
||||
|
||||
func TestClientSendEventAllRetriesFail(t *testing.T) {
|
||||
@@ -527,8 +527,8 @@ func TestClientSendEventAllRetriesFail(t *testing.T) {
|
||||
|
||||
grpcServer := grpc.NewServer()
|
||||
mockServer := &retryMockLogCollectorServer{
|
||||
failCount: 10,
|
||||
maxFailCount: 10,
|
||||
failCount: 20,
|
||||
maxFailCount: 20,
|
||||
}
|
||||
log.RegisterLogCollectorServer(grpcServer, mockServer)
|
||||
|
||||
@@ -551,8 +551,8 @@ func TestClientSendEventAllRetriesFail(t *testing.T) {
|
||||
// Should fail after all retries
|
||||
err = client.SendEvent(ctx, entry)
|
||||
assert.Error(t, err)
|
||||
// 3 retries + 1 final attempt = 4 calls
|
||||
assert.Equal(t, 4, mockServer.eventCallCount)
|
||||
// 10 retries + 1 final attempt = 11 calls
|
||||
assert.Equal(t, 11, mockServer.eventCallCount)
|
||||
}
|
||||
|
||||
// retryMockLogCollectorServer is a mock server that fails a specified number of times.
|
||||
|
||||
@@ -14,7 +14,6 @@ import (
|
||||
"sync"
|
||||
|
||||
"github.com/ultravioletrs/cocos/pkg/atls"
|
||||
"github.com/ultravioletrs/cocos/pkg/server"
|
||||
"golang.org/x/net/http2"
|
||||
"golang.org/x/net/http2/h2c"
|
||||
)
|
||||
@@ -138,7 +137,7 @@ func (p *proxyServer) Start(cfg ProxyConfig, ctx ProxyContext) error {
|
||||
|
||||
if cfg.AttestedTLS {
|
||||
if p.certProvider == nil {
|
||||
return fmt.Errorf("attested TLS requested but no certificate provider available")
|
||||
return fmt.Errorf("attested TLS requested for ingress proxy but no certificate provider available. Please ensure a CC platform is detected (not NoCC), aTLS is enabled, and the attestation service is running")
|
||||
}
|
||||
tlsConfig = &tls.Config{
|
||||
GetCertificate: p.certProvider.GetCertificate,
|
||||
@@ -146,7 +145,7 @@ func (p *proxyServer) Start(cfg ProxyConfig, ctx ProxyContext) error {
|
||||
NextProtos: []string{"h2", "http/1.1"},
|
||||
}
|
||||
|
||||
mtls, err := server.ConfigureCertificateAuthorities(tlsConfig, cfg.ServerCAFile, cfg.ClientCAFile)
|
||||
mtls, err := ConfigureCertificateAuthorities(tlsConfig, cfg.ServerCAFile, cfg.ClientCAFile)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to configure certificate authorities: %w", err)
|
||||
}
|
||||
@@ -159,7 +158,7 @@ func (p *proxyServer) Start(cfg ProxyConfig, ctx ProxyContext) error {
|
||||
}
|
||||
} else if cfg.CertFile != "" && cfg.KeyFile != "" {
|
||||
// Regular TLS
|
||||
tlsSetup, err := server.SetupRegularTLS(cfg.CertFile, cfg.KeyFile, cfg.ServerCAFile, cfg.ClientCAFile)
|
||||
tlsSetup, err := SetupRegularTLS(cfg.CertFile, cfg.KeyFile, cfg.ServerCAFile, cfg.ClientCAFile)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to setup TLS: %w", err)
|
||||
}
|
||||
|
||||
@@ -402,7 +402,7 @@ func TestProxyAttestedTLSMissingProvider(t *testing.T) {
|
||||
|
||||
err := ps.Start(cfg, ctx)
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, "attested TLS requested but no certificate provider available", err.Error())
|
||||
assert.Equal(t, "attested TLS requested for ingress proxy but no certificate provider available. Please ensure a CC platform is detected (not NoCC), aTLS is enabled, and the attestation service is running", err.Error())
|
||||
}
|
||||
|
||||
func TestProxyMTLS(t *testing.T) {
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
// Copyright (c) Ultraviolet
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package server
|
||||
package ingress
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
@@ -144,18 +144,3 @@ func SetupRegularTLS(certFile, keyFile, serverCAFile, clientCAFile string) (*TLS
|
||||
|
||||
return &TLSSetupResult{Config: tlsConfig, MTLS: mtls}, nil
|
||||
}
|
||||
|
||||
// BuildMTLSDescription builds a description string for mTLS configuration.
|
||||
func BuildMTLSDescription(serverCAFile, clientCAFile string) string {
|
||||
var parts []string
|
||||
|
||||
if serverCAFile != "" {
|
||||
parts = append(parts, fmt.Sprintf("root ca %s", serverCAFile))
|
||||
}
|
||||
|
||||
if clientCAFile != "" {
|
||||
parts = append(parts, fmt.Sprintf("client ca %s", clientCAFile))
|
||||
}
|
||||
|
||||
return strings.Join(parts, " ")
|
||||
}
|
||||
@@ -1,5 +0,0 @@
|
||||
// Copyright (c) Ultraviolet
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
// Package server contains the gRPC server implementation.
|
||||
package server
|
||||
@@ -1,5 +0,0 @@
|
||||
// Copyright (c) Ultraviolet
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
// Package grpc contains the gRPC server implementation.
|
||||
package grpc
|
||||
@@ -1,265 +0,0 @@
|
||||
// Copyright (c) Ultraviolet
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package grpc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"net"
|
||||
"os"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
agentgrpc "github.com/ultravioletrs/cocos/agent/api/grpc"
|
||||
"github.com/ultravioletrs/cocos/agent/auth"
|
||||
"github.com/ultravioletrs/cocos/pkg/atls"
|
||||
"github.com/ultravioletrs/cocos/pkg/server"
|
||||
"go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/credentials"
|
||||
"google.golang.org/grpc/credentials/insecure"
|
||||
"google.golang.org/grpc/health"
|
||||
grpchealth "google.golang.org/grpc/health/grpc_health_v1"
|
||||
)
|
||||
|
||||
const (
|
||||
stopWaitTime = 5 * time.Second
|
||||
)
|
||||
|
||||
type Server struct {
|
||||
server.BaseServer
|
||||
mu sync.RWMutex
|
||||
server *grpc.Server
|
||||
health *health.Server
|
||||
registerService serviceRegister
|
||||
authSvc auth.Authenticator
|
||||
certProvider atls.CertificateProvider
|
||||
attestedTLSEnabled bool
|
||||
started bool
|
||||
stopped bool
|
||||
}
|
||||
|
||||
type serviceRegister func(srv *grpc.Server)
|
||||
|
||||
var _ server.Server = (*Server)(nil)
|
||||
|
||||
func New(
|
||||
ctx context.Context, cancel context.CancelFunc, name string, config server.ServerConfiguration,
|
||||
registerService serviceRegister, logger *slog.Logger, authSvc auth.Authenticator, certProvider atls.CertificateProvider,
|
||||
) server.Server {
|
||||
base := config.GetBaseConfig()
|
||||
listenFullAddress := fmt.Sprintf("%s:%s", base.Host, base.Port)
|
||||
|
||||
var attestedTLS bool
|
||||
|
||||
if agentConfig, ok := config.(server.AgentConfig); ok && agentConfig.AttestedTLS {
|
||||
if certProvider == nil {
|
||||
logger.Error("Failed to create certificate provider")
|
||||
} else {
|
||||
attestedTLS = true
|
||||
}
|
||||
}
|
||||
|
||||
return &Server{
|
||||
BaseServer: server.BaseServer{
|
||||
Ctx: ctx,
|
||||
Cancel: cancel,
|
||||
Name: name,
|
||||
Address: listenFullAddress,
|
||||
Config: config,
|
||||
Logger: logger,
|
||||
},
|
||||
registerService: registerService,
|
||||
authSvc: authSvc,
|
||||
certProvider: certProvider,
|
||||
attestedTLSEnabled: attestedTLS,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) Start() error {
|
||||
s.mu.Lock()
|
||||
if s.started {
|
||||
s.mu.Unlock()
|
||||
return fmt.Errorf("server already started")
|
||||
}
|
||||
if s.stopped {
|
||||
s.mu.Unlock()
|
||||
return fmt.Errorf("server already stopped")
|
||||
}
|
||||
s.started = true
|
||||
s.mu.Unlock()
|
||||
|
||||
errCh := make(chan error)
|
||||
grpcServerOptions := []grpc.ServerOption{
|
||||
grpc.StatsHandler(otelgrpc.NewServerHandler()),
|
||||
}
|
||||
|
||||
// Add authentication interceptors if auth service is available
|
||||
if s.authSvc != nil {
|
||||
unary, stream := agentgrpc.NewAuthInterceptor(s.authSvc)
|
||||
grpcServerOptions = append(grpcServerOptions, grpc.UnaryInterceptor(unary))
|
||||
grpcServerOptions = append(grpcServerOptions, grpc.StreamInterceptor(stream))
|
||||
}
|
||||
|
||||
// Configure credentials
|
||||
creds, err := s.configureCredentials()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to configure credentials: %w", err)
|
||||
}
|
||||
|
||||
grpcServerOptions = append(grpcServerOptions, creds)
|
||||
|
||||
// Create listener - detect Unix socket vs TCP
|
||||
var listener net.Listener
|
||||
baseConfig := s.Config.GetBaseConfig()
|
||||
|
||||
// Check if this is a Unix socket path (starts with /)
|
||||
if len(baseConfig.Host) > 0 && baseConfig.Host[0] == '/' {
|
||||
// Unix socket
|
||||
// Remove existing socket file if it exists
|
||||
_ = os.Remove(baseConfig.Host)
|
||||
|
||||
listener, err = net.Listen("unix", baseConfig.Host)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to listen on Unix socket %s: %w", baseConfig.Host, err)
|
||||
}
|
||||
} else {
|
||||
// TCP socket
|
||||
listener, err = net.Listen("tcp", s.Address)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to listen on port %s: %w", s.Address, err)
|
||||
}
|
||||
}
|
||||
|
||||
// Create and configure server
|
||||
s.mu.Lock()
|
||||
s.server = grpc.NewServer(grpcServerOptions...)
|
||||
s.health = health.NewServer()
|
||||
grpchealth.RegisterHealthServer(s.server, s.health)
|
||||
s.registerService(s.server)
|
||||
s.health.SetServingStatus(s.Name, grpchealth.HealthCheckResponse_SERVING)
|
||||
s.mu.Unlock()
|
||||
|
||||
// Start server
|
||||
go func() {
|
||||
s.mu.RLock()
|
||||
server := s.server
|
||||
s.mu.RUnlock()
|
||||
|
||||
if server != nil {
|
||||
errCh <- server.Serve(listener)
|
||||
}
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-s.Ctx.Done():
|
||||
return s.Stop()
|
||||
case err := <-errCh:
|
||||
s.Cancel()
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) configureCredentials() (grpc.ServerOption, error) {
|
||||
baseConfig := s.Config.GetBaseConfig()
|
||||
|
||||
// Check if attested TLS should be used
|
||||
if s.shouldUseAttestedTLS() {
|
||||
return s.configureAttestedTLS(baseConfig.Config)
|
||||
}
|
||||
|
||||
// Check if regular TLS should be used
|
||||
if s.shouldUseRegularTLS(baseConfig.Config) {
|
||||
return s.configureRegularTLS(baseConfig.Config)
|
||||
}
|
||||
|
||||
// Use insecure credentials
|
||||
// Determine address for logging
|
||||
addr := s.Address
|
||||
if len(baseConfig.Host) > 0 && baseConfig.Host[0] == '/' {
|
||||
addr = baseConfig.Host
|
||||
}
|
||||
s.Logger.Info(fmt.Sprintf("%s service gRPC server listening at %s without TLS", s.Name, addr))
|
||||
return grpc.Creds(insecure.NewCredentials()), nil
|
||||
}
|
||||
|
||||
func (s *Server) shouldUseAttestedTLS() bool {
|
||||
return s.attestedTLSEnabled && s.certProvider != nil
|
||||
}
|
||||
|
||||
func (s *Server) shouldUseRegularTLS(config server.Config) bool {
|
||||
return config.CertFile != "" || config.KeyFile != ""
|
||||
}
|
||||
|
||||
func (s *Server) configureAttestedTLS(config server.Config) (grpc.ServerOption, error) {
|
||||
tlsConfig := &tls.Config{
|
||||
ClientAuth: tls.NoClientCert,
|
||||
GetCertificate: s.certProvider.GetCertificate,
|
||||
}
|
||||
|
||||
mtls, err := server.ConfigureCertificateAuthorities(tlsConfig, config.ServerCAFile, config.ClientCAFile)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to configure certificate authorities: %w", err)
|
||||
}
|
||||
|
||||
if mtls {
|
||||
tlsConfig.ClientAuth = tls.RequireAndVerifyClientCert
|
||||
s.Logger.Info(fmt.Sprintf("%s service gRPC server listening at %s with Attested mTLS", s.Name, s.Address))
|
||||
} else {
|
||||
s.Logger.Info(fmt.Sprintf("%s service gRPC server listening at %s with Attested TLS", s.Name, s.Address))
|
||||
}
|
||||
|
||||
return grpc.Creds(credentials.NewTLS(tlsConfig)), nil
|
||||
}
|
||||
|
||||
func (s *Server) configureRegularTLS(config server.Config) (grpc.ServerOption, error) {
|
||||
tlsSetup, err := server.SetupRegularTLS(config.CertFile, config.KeyFile, config.ServerCAFile, config.ClientCAFile)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to setup TLS: %w", err)
|
||||
}
|
||||
|
||||
if tlsSetup.MTLS {
|
||||
mtlsCA := server.BuildMTLSDescription(config.ServerCAFile, config.ClientCAFile)
|
||||
s.Logger.Info(fmt.Sprintf("%s service gRPC server listening at %s with TLS/mTLS cert %s , key %s and %s",
|
||||
s.Name, s.Address, config.CertFile, config.KeyFile, mtlsCA))
|
||||
} else {
|
||||
s.Logger.Info(fmt.Sprintf("%s service gRPC server listening at %s with TLS cert %s and key %s",
|
||||
s.Name, s.Address, config.CertFile, config.KeyFile))
|
||||
}
|
||||
|
||||
return grpc.Creds(credentials.NewTLS(tlsSetup.Config)), nil
|
||||
}
|
||||
|
||||
func (s *Server) Stop() error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
if s.stopped {
|
||||
return nil
|
||||
}
|
||||
s.stopped = true
|
||||
|
||||
defer s.Cancel()
|
||||
|
||||
c := make(chan bool)
|
||||
go func() {
|
||||
defer close(c)
|
||||
if s.health != nil {
|
||||
s.health.Shutdown()
|
||||
}
|
||||
if s.server != nil {
|
||||
s.server.GracefulStop()
|
||||
}
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-c:
|
||||
case <-time.After(stopWaitTime):
|
||||
}
|
||||
|
||||
s.Logger.Info(fmt.Sprintf("%s gRPC service shutdown at %s", s.Name, s.Address))
|
||||
return nil
|
||||
}
|
||||
@@ -1,527 +0,0 @@
|
||||
// Copyright (c) Ultraviolet
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
package grpc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"crypto/x509"
|
||||
"crypto/x509/pkix"
|
||||
"encoding/pem"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"math/big"
|
||||
"os"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
authmocks "github.com/ultravioletrs/cocos/agent/auth/mocks"
|
||||
"github.com/ultravioletrs/cocos/pkg/atls/mocks"
|
||||
"github.com/ultravioletrs/cocos/pkg/attestation/vtpm"
|
||||
"github.com/ultravioletrs/cocos/pkg/server"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/test/bufconn"
|
||||
)
|
||||
|
||||
const bufSize = 1024 * 1024
|
||||
|
||||
var lis *bufconn.Listener
|
||||
|
||||
func init() {
|
||||
lis = bufconn.Listen(bufSize)
|
||||
}
|
||||
|
||||
func TestNew(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
config := server.AgentConfig{
|
||||
ServerConfig: server.ServerConfig{
|
||||
Config: server.Config{
|
||||
Host: "localhost",
|
||||
Port: "50051",
|
||||
},
|
||||
},
|
||||
}
|
||||
logger := slog.Default()
|
||||
authSvc := new(authmocks.Authenticator)
|
||||
|
||||
srv := New(ctx, cancel, "TestServer", config, func(srv *grpc.Server) {}, logger, authSvc, nil)
|
||||
|
||||
assert.NotNil(t, srv)
|
||||
assert.IsType(t, &Server{}, srv)
|
||||
}
|
||||
|
||||
func TestServerStartWithTLSFile(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
cert, key, err := generateSelfSignedCert()
|
||||
assert.NoError(t, err)
|
||||
|
||||
certFile, err := os.CreateTemp("", "cert*.pem")
|
||||
assert.NoError(t, err)
|
||||
|
||||
keyFile, err := os.CreateTemp("", "key*.pem")
|
||||
assert.NoError(t, err)
|
||||
|
||||
t.Cleanup(func() {
|
||||
os.Remove(certFile.Name())
|
||||
os.Remove(keyFile.Name())
|
||||
})
|
||||
|
||||
_, err = certFile.Write(cert)
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, err = keyFile.Write(key)
|
||||
assert.NoError(t, err)
|
||||
|
||||
err = certFile.Close()
|
||||
assert.NoError(t, err)
|
||||
err = keyFile.Close()
|
||||
assert.NoError(t, err)
|
||||
|
||||
config := server.AgentConfig{
|
||||
ServerConfig: server.ServerConfig{
|
||||
Config: server.Config{
|
||||
Host: "localhost",
|
||||
Port: "0",
|
||||
CertFile: certFile.Name(),
|
||||
KeyFile: keyFile.Name(),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
logBuffer := &ThreadSafeBuffer{}
|
||||
logger := slog.New(slog.NewTextHandler(logBuffer, &slog.HandlerOptions{Level: slog.LevelDebug}))
|
||||
authSvc := new(authmocks.Authenticator)
|
||||
|
||||
srv := New(ctx, cancel, "TestServer", config, func(srv *grpc.Server) {}, logger, authSvc, nil)
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(1)
|
||||
|
||||
go func() {
|
||||
wg.Done()
|
||||
err := srv.Start()
|
||||
assert.NoError(t, err)
|
||||
}()
|
||||
|
||||
wg.Wait()
|
||||
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
|
||||
cancel()
|
||||
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
|
||||
logContent := logBuffer.String()
|
||||
fmt.Println(logContent)
|
||||
assert.Contains(t, logContent, "TestServer service gRPC server listening at localhost:0 with TLS")
|
||||
}
|
||||
|
||||
func TestServerStartWithmTLSFile(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
caCertFile, clientCertFile, clientKeyFile, err := createCertificatesFiles()
|
||||
assert.NoError(t, err)
|
||||
|
||||
config := server.AgentConfig{
|
||||
ServerConfig: server.ServerConfig{
|
||||
Config: server.Config{
|
||||
Host: "localhost",
|
||||
Port: "0",
|
||||
CertFile: string(clientCertFile),
|
||||
KeyFile: string(clientKeyFile),
|
||||
ServerCAFile: caCertFile,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
logBuffer := &ThreadSafeBuffer{}
|
||||
logger := slog.New(slog.NewTextHandler(logBuffer, &slog.HandlerOptions{Level: slog.LevelDebug}))
|
||||
authSvc := new(authmocks.Authenticator)
|
||||
|
||||
srv := New(ctx, cancel, "TestServer", config, func(srv *grpc.Server) {}, logger, authSvc, nil)
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(1)
|
||||
|
||||
go func() {
|
||||
wg.Done()
|
||||
err := srv.Start()
|
||||
assert.NoError(t, err)
|
||||
}()
|
||||
|
||||
wg.Wait()
|
||||
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
|
||||
cancel()
|
||||
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
|
||||
logContent := logBuffer.String()
|
||||
fmt.Println(logContent)
|
||||
assert.Contains(t, logContent, "TestServer service gRPC server listening at localhost:0 with TLS")
|
||||
}
|
||||
|
||||
func TestServerStop(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
config := server.AgentConfig{
|
||||
ServerConfig: server.ServerConfig{
|
||||
Config: server.Config{
|
||||
Host: "localhost",
|
||||
Port: "0",
|
||||
},
|
||||
},
|
||||
}
|
||||
buf := &ThreadSafeBuffer{}
|
||||
logger := slog.New(slog.NewTextHandler(buf, &slog.HandlerOptions{Level: slog.LevelDebug}))
|
||||
authSvc := new(authmocks.Authenticator)
|
||||
|
||||
srv := New(ctx, cancel, "TestServer", config, func(srv *grpc.Server) {}, logger, authSvc, nil)
|
||||
|
||||
go func() {
|
||||
err := srv.Start()
|
||||
assert.NoError(t, err)
|
||||
}()
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
cancel()
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
err := srv.Stop()
|
||||
assert.NoError(t, err)
|
||||
|
||||
assert.Contains(t, buf.String(), "TestServer gRPC service shutdown at localhost:0")
|
||||
}
|
||||
|
||||
func generateSelfSignedCert() ([]byte, []byte, error) {
|
||||
key, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
cert, err := generateSelfSignedCertFromKey(key)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
return cert, pem.EncodeToMemory(&pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(key)}), nil
|
||||
}
|
||||
|
||||
func generateSelfSignedCertFromKey(key *rsa.PrivateKey) ([]byte, error) {
|
||||
template := x509.Certificate{
|
||||
SerialNumber: big.NewInt(1),
|
||||
Subject: pkix.Name{
|
||||
Organization: []string{"Test"},
|
||||
},
|
||||
NotBefore: time.Now(),
|
||||
NotAfter: time.Now().AddDate(1, 0, 0),
|
||||
KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature,
|
||||
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
|
||||
BasicConstraintsValid: true,
|
||||
}
|
||||
|
||||
certBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, &key.PublicKey, key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: certBytes}), nil
|
||||
}
|
||||
|
||||
type ThreadSafeBuffer struct {
|
||||
buffer strings.Builder
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
func (b *ThreadSafeBuffer) Write(p []byte) (n int, err error) {
|
||||
b.mu.Lock()
|
||||
defer b.mu.Unlock()
|
||||
return b.buffer.Write(p)
|
||||
}
|
||||
|
||||
func (b *ThreadSafeBuffer) String() string {
|
||||
b.mu.Lock()
|
||||
defer b.mu.Unlock()
|
||||
return b.buffer.String()
|
||||
}
|
||||
|
||||
func TestServerInitializationAndStartup(t *testing.T) {
|
||||
vtpm.ExternalTPM = &vtpm.DummyRWC{}
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
config server.AgentConfig
|
||||
expectedLog string
|
||||
expectError bool
|
||||
setupCallback func(*testing.T, *server.AgentConfig, *ThreadSafeBuffer)
|
||||
}{
|
||||
{
|
||||
name: "Non-TLS Server Startup",
|
||||
config: server.AgentConfig{
|
||||
ServerConfig: server.ServerConfig{
|
||||
Config: server.Config{
|
||||
Host: "localhost",
|
||||
Port: "0",
|
||||
},
|
||||
},
|
||||
},
|
||||
expectedLog: "TestServer service gRPC server listening at localhost:0 without TLS",
|
||||
},
|
||||
{
|
||||
name: "TLS Server Startup with Self-Signed Certificate",
|
||||
config: server.AgentConfig{
|
||||
ServerConfig: server.ServerConfig{
|
||||
Config: server.Config{
|
||||
Host: "localhost",
|
||||
Port: "0",
|
||||
},
|
||||
},
|
||||
},
|
||||
setupCallback: setupTLSConfig,
|
||||
expectedLog: "TestServer service gRPC server listening at localhost:0 with TLS",
|
||||
},
|
||||
{
|
||||
name: "TLS Server Startup with Invalid Certificates",
|
||||
config: server.AgentConfig{
|
||||
ServerConfig: server.ServerConfig{
|
||||
Config: server.Config{
|
||||
Host: "localhost",
|
||||
Port: "0",
|
||||
CertFile: "invalid",
|
||||
KeyFile: "invalid",
|
||||
},
|
||||
},
|
||||
},
|
||||
expectError: true,
|
||||
expectedLog: "failed to load auth certificates",
|
||||
},
|
||||
{
|
||||
name: "maTLS Server Startup",
|
||||
config: server.AgentConfig{
|
||||
ServerConfig: server.ServerConfig{
|
||||
Config: server.Config{
|
||||
Host: "localhost",
|
||||
Port: "0",
|
||||
ServerCAFile: "",
|
||||
ClientCAFile: "",
|
||||
},
|
||||
},
|
||||
AttestedTLS: true,
|
||||
},
|
||||
setupCallback: setupMTLSConfig,
|
||||
expectError: false,
|
||||
expectedLog: "with Attested mTLS",
|
||||
},
|
||||
{
|
||||
name: "maTLS Server Startup with Invalid Server CA file",
|
||||
config: server.AgentConfig{
|
||||
ServerConfig: server.ServerConfig{
|
||||
Config: server.Config{
|
||||
Host: "localhost",
|
||||
Port: "0",
|
||||
ServerCAFile: "invalid",
|
||||
},
|
||||
},
|
||||
AttestedTLS: true,
|
||||
},
|
||||
setupCallback: setupInvalidRootCAConfig,
|
||||
expectError: true,
|
||||
expectedLog: "failed to load server ca file",
|
||||
},
|
||||
{
|
||||
name: "maTLS Server Startup with Invalid Clinet CA file",
|
||||
config: server.AgentConfig{
|
||||
ServerConfig: server.ServerConfig{
|
||||
Config: server.Config{
|
||||
Host: "localhost",
|
||||
Port: "0",
|
||||
ServerCAFile: "invalid",
|
||||
},
|
||||
},
|
||||
AttestedTLS: true,
|
||||
},
|
||||
setupCallback: setupInvalidClientCAConfig,
|
||||
expectError: true,
|
||||
expectedLog: "failed to load client ca file",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
if tc.setupCallback != nil {
|
||||
tc.setupCallback(t, &tc.config, nil)
|
||||
}
|
||||
|
||||
logBuffer := &ThreadSafeBuffer{}
|
||||
logger := slog.New(slog.NewTextHandler(logBuffer, &slog.HandlerOptions{Level: slog.LevelDebug}))
|
||||
authSvc := new(authmocks.Authenticator)
|
||||
|
||||
mockCertProvider := new(mocks.CertificateProvider)
|
||||
|
||||
srv := New(ctx, cancel, "TestServer", tc.config, func(srv *grpc.Server) {}, logger, authSvc, mockCertProvider)
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(1)
|
||||
|
||||
go func() {
|
||||
wg.Done()
|
||||
err := srv.Start()
|
||||
if tc.expectError {
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), tc.expectedLog)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
}()
|
||||
|
||||
wg.Wait()
|
||||
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
|
||||
cancel()
|
||||
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
|
||||
if !tc.expectError {
|
||||
logContent := logBuffer.String()
|
||||
fmt.Println(logContent)
|
||||
assert.Contains(t, logContent, tc.expectedLog)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func setupTLSConfig(t *testing.T, config *server.AgentConfig, _ *ThreadSafeBuffer) {
|
||||
cert, key, err := generateSelfSignedCert()
|
||||
assert.NoError(t, err)
|
||||
|
||||
config.CertFile = string(cert)
|
||||
config.KeyFile = string(key)
|
||||
}
|
||||
|
||||
func setupMTLSConfig(t *testing.T, config *server.AgentConfig, _ *ThreadSafeBuffer) {
|
||||
cert, key, err := generateSelfSignedCert()
|
||||
assert.NoError(t, err)
|
||||
|
||||
config.CertFile = string(cert)
|
||||
config.KeyFile = string(key)
|
||||
config.ServerCAFile = string(cert)
|
||||
config.ClientCAFile = string(cert)
|
||||
}
|
||||
|
||||
func setupInvalidRootCAConfig(t *testing.T, config *server.AgentConfig, _ *ThreadSafeBuffer) {
|
||||
cert, key, err := generateSelfSignedCert()
|
||||
assert.NoError(t, err)
|
||||
|
||||
config.CertFile = string(cert)
|
||||
config.KeyFile = string(key)
|
||||
config.ServerCAFile = "invalid"
|
||||
config.ClientCAFile = string(cert)
|
||||
}
|
||||
|
||||
func setupInvalidClientCAConfig(t *testing.T, config *server.AgentConfig, _ *ThreadSafeBuffer) {
|
||||
cert, key, err := generateSelfSignedCert()
|
||||
assert.NoError(t, err)
|
||||
|
||||
config.CertFile = string(cert)
|
||||
config.KeyFile = string(key)
|
||||
config.ClientCAFile = "invalid"
|
||||
config.ServerCAFile = string(cert)
|
||||
}
|
||||
|
||||
func createCertificatesFiles() (string, string, string, error) {
|
||||
caKey, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
if err != nil {
|
||||
return "", "", "", err
|
||||
}
|
||||
|
||||
caTemplate := x509.Certificate{
|
||||
SerialNumber: big.NewInt(1),
|
||||
Subject: pkix.Name{
|
||||
Organization: []string{"Test Org"},
|
||||
},
|
||||
NotBefore: time.Now(),
|
||||
NotAfter: time.Now().Add(time.Hour * 24),
|
||||
KeyUsage: x509.KeyUsageCertSign | x509.KeyUsageCRLSign,
|
||||
BasicConstraintsValid: true,
|
||||
IsCA: true,
|
||||
}
|
||||
|
||||
caCertDER, err := x509.CreateCertificate(rand.Reader, &caTemplate, &caTemplate, &caKey.PublicKey, caKey)
|
||||
if err != nil {
|
||||
return "", "", "", err
|
||||
}
|
||||
|
||||
caCertFile, err := createTempFile(pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: caCertDER}))
|
||||
if err != nil {
|
||||
return "", "", "", err
|
||||
}
|
||||
|
||||
clientKey, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
if err != nil {
|
||||
return "", "", "", err
|
||||
}
|
||||
|
||||
clientTemplate := x509.Certificate{
|
||||
SerialNumber: big.NewInt(2),
|
||||
Subject: pkix.Name{
|
||||
Organization: []string{"Test Org"},
|
||||
},
|
||||
NotBefore: time.Now(),
|
||||
NotAfter: time.Now().Add(time.Hour * 24),
|
||||
KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature,
|
||||
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth},
|
||||
BasicConstraintsValid: true,
|
||||
}
|
||||
|
||||
clientCertDER, err := x509.CreateCertificate(rand.Reader, &clientTemplate, &caTemplate, &clientKey.PublicKey, caKey)
|
||||
if err != nil {
|
||||
return "", "", "", err
|
||||
}
|
||||
|
||||
clientCertFile, err := createTempFile(pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: clientCertDER}))
|
||||
if err != nil {
|
||||
return "", "", "", err
|
||||
}
|
||||
|
||||
clientKeyFile, err := createTempFile(pem.EncodeToMemory(&pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(clientKey)}))
|
||||
if err != nil {
|
||||
return "", "", "", err
|
||||
}
|
||||
|
||||
return caCertFile, clientCertFile, clientKeyFile, nil
|
||||
}
|
||||
|
||||
func createTempFile(data []byte) (string, error) {
|
||||
file, err := createTempFileHandle()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
_, err = file.Write(data)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
err = file.Close()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return file.Name(), nil
|
||||
}
|
||||
|
||||
func createTempFileHandle() (*os.File, error) {
|
||||
return os.CreateTemp("", "test")
|
||||
}
|
||||
@@ -1,177 +0,0 @@
|
||||
// Copyright (c) Ultraviolet
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package http
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
|
||||
smqserver "github.com/absmach/supermq/pkg/server"
|
||||
"github.com/ultravioletrs/cocos/pkg/atls"
|
||||
"github.com/ultravioletrs/cocos/pkg/server"
|
||||
)
|
||||
|
||||
const (
|
||||
httpProtocol = "http"
|
||||
httpsProtocol = "https"
|
||||
)
|
||||
|
||||
type httpServer struct {
|
||||
server.BaseServer
|
||||
|
||||
server *http.Server
|
||||
certProvider atls.CertificateProvider
|
||||
attestedTLSEnabled bool
|
||||
}
|
||||
|
||||
var _ server.Server = (*httpServer)(nil)
|
||||
|
||||
func NewServer(
|
||||
ctx context.Context, cancel context.CancelFunc, name string, config server.ServerConfiguration,
|
||||
handler http.Handler, logger *slog.Logger, certProvider atls.CertificateProvider,
|
||||
) server.Server {
|
||||
baseServer := server.NewBaseServer(ctx, cancel, name, config, logger)
|
||||
hserver := &http.Server{Addr: baseServer.Address, Handler: handler}
|
||||
|
||||
var attestedTLS bool
|
||||
|
||||
if agentConfig, ok := config.(server.AgentConfig); ok && agentConfig.AttestedTLS {
|
||||
if certProvider == nil {
|
||||
logger.Error("Failed to create certificate provider")
|
||||
} else {
|
||||
attestedTLS = true
|
||||
}
|
||||
}
|
||||
|
||||
return &httpServer{
|
||||
BaseServer: baseServer,
|
||||
server: hserver,
|
||||
certProvider: certProvider,
|
||||
attestedTLSEnabled: attestedTLS,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *httpServer) Start() error {
|
||||
s.Protocol = httpProtocol
|
||||
|
||||
if s.shouldUseAttestedTLS() {
|
||||
return s.startWithAttestedTLS()
|
||||
}
|
||||
|
||||
if s.shouldUseRegularTLS() {
|
||||
return s.startWithRegularTLS()
|
||||
}
|
||||
|
||||
return s.startWithoutTLS()
|
||||
}
|
||||
|
||||
func (s *httpServer) Stop() error {
|
||||
defer s.Cancel()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), smqserver.StopWaitTime)
|
||||
defer cancel()
|
||||
|
||||
if err := s.server.Shutdown(ctx); err != nil {
|
||||
s.Logger.Error(fmt.Sprintf(
|
||||
"%s service %s server error occurred during shutdown at %s: %s", s.Name, s.Protocol, s.Address, err))
|
||||
return fmt.Errorf("%s service %s server error occurred during shutdown at %s: %w", s.Name, s.Protocol, s.Address, err)
|
||||
}
|
||||
|
||||
s.Logger.Info(fmt.Sprintf("%s %s service shutdown of http at %s", s.Name, s.Protocol, s.Address))
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *httpServer) shouldUseAttestedTLS() bool {
|
||||
return s.attestedTLSEnabled && s.certProvider != nil
|
||||
}
|
||||
|
||||
func (s *httpServer) shouldUseRegularTLS() bool {
|
||||
return s.Config.GetBaseConfig().CertFile != "" || s.Config.GetBaseConfig().KeyFile != ""
|
||||
}
|
||||
|
||||
func (s *httpServer) startWithAttestedTLS() error {
|
||||
tlsConfig := &tls.Config{
|
||||
ClientAuth: tls.NoClientCert,
|
||||
GetCertificate: s.certProvider.GetCertificate,
|
||||
}
|
||||
|
||||
baseConfig := s.Config.GetBaseConfig()
|
||||
mtls, err := server.ConfigureCertificateAuthorities(tlsConfig, baseConfig.ServerCAFile, baseConfig.ClientCAFile)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to configure certificate authorities: %w", err)
|
||||
}
|
||||
|
||||
if mtls {
|
||||
tlsConfig.ClientAuth = tls.RequireAndVerifyClientCert
|
||||
}
|
||||
|
||||
s.server.TLSConfig = tlsConfig
|
||||
s.Protocol = httpsProtocol
|
||||
|
||||
s.logAttestedTLSStart(mtls)
|
||||
return s.listenAndServe(true)
|
||||
}
|
||||
|
||||
func (s *httpServer) startWithRegularTLS() error {
|
||||
baseConfig := s.Config.GetBaseConfig()
|
||||
tlsSetup, err := server.SetupRegularTLS(baseConfig.CertFile, baseConfig.KeyFile, baseConfig.ServerCAFile, baseConfig.ClientCAFile)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to setup TLS: %w", err)
|
||||
}
|
||||
|
||||
s.server.TLSConfig = tlsSetup.Config
|
||||
s.Protocol = httpsProtocol
|
||||
|
||||
s.logRegularTLSStart(tlsSetup.MTLS)
|
||||
return s.listenAndServe(true)
|
||||
}
|
||||
|
||||
func (s *httpServer) startWithoutTLS() error {
|
||||
s.Logger.Info(fmt.Sprintf("%s service %s server listening at %s without TLS", s.Name, s.Protocol, s.Address))
|
||||
return s.listenAndServe(false)
|
||||
}
|
||||
|
||||
func (s *httpServer) logAttestedTLSStart(mtls bool) {
|
||||
if mtls {
|
||||
s.Logger.Info(fmt.Sprintf("%s service %s server listening at %s with Attested mTLS", s.Name, s.Protocol, s.Address))
|
||||
} else {
|
||||
s.Logger.Info(fmt.Sprintf("%s service %s server listening at %s with Attested TLS", s.Name, s.Protocol, s.Address))
|
||||
}
|
||||
}
|
||||
|
||||
func (s *httpServer) logRegularTLSStart(mtls bool) {
|
||||
baseConfig := s.Config.GetBaseConfig()
|
||||
if mtls {
|
||||
s.Logger.Info(fmt.Sprintf(
|
||||
"%s service %s server listening at %s with TLS/mTLS cert %s , key %s and CAs %s, %s",
|
||||
s.Name, s.Protocol, s.Address, baseConfig.CertFile, baseConfig.KeyFile,
|
||||
baseConfig.ServerCAFile, baseConfig.ClientCAFile))
|
||||
} else {
|
||||
s.Logger.Info(fmt.Sprintf("%s service %s server listening at %s with TLS cert %s and key %s",
|
||||
s.Name, s.Protocol, s.Address, baseConfig.CertFile, baseConfig.KeyFile))
|
||||
}
|
||||
}
|
||||
|
||||
func (s *httpServer) listenAndServe(useTLS bool) error {
|
||||
errCh := make(chan error, 1)
|
||||
|
||||
go func() {
|
||||
if useTLS {
|
||||
cfg := s.Config.GetBaseConfig()
|
||||
errCh <- s.server.ListenAndServeTLS(cfg.CertFile, cfg.KeyFile)
|
||||
} else {
|
||||
errCh <- s.server.ListenAndServe()
|
||||
}
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-s.Ctx.Done():
|
||||
return s.Stop()
|
||||
case err := <-errCh:
|
||||
return err
|
||||
}
|
||||
}
|
||||
@@ -1,411 +0,0 @@
|
||||
// Copyright (c) Ultraviolet
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package http
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/mock"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/ultravioletrs/cocos/pkg/atls"
|
||||
"github.com/ultravioletrs/cocos/pkg/atls/mocks"
|
||||
"github.com/ultravioletrs/cocos/pkg/server"
|
||||
)
|
||||
|
||||
// Mock implementations for testing.
|
||||
type mockHandler struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
func (m *mockHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
m.Called(w, r)
|
||||
w.WriteHeader(http.StatusOK)
|
||||
if _, err := w.Write([]byte("test response")); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
|
||||
type mockBaseConfig struct {
|
||||
certFile string
|
||||
keyFile string
|
||||
serverCAFile string
|
||||
clientCAFile string
|
||||
host string
|
||||
port string
|
||||
}
|
||||
|
||||
func (m *mockBaseConfig) GetCertFile() string { return m.certFile }
|
||||
func (m *mockBaseConfig) GetKeyFile() string { return m.keyFile }
|
||||
func (m *mockBaseConfig) GetServerCAFile() string { return m.serverCAFile }
|
||||
func (m *mockBaseConfig) GetClientCAFile() string { return m.clientCAFile }
|
||||
|
||||
type mockServerConfig struct {
|
||||
baseConfig *mockBaseConfig
|
||||
}
|
||||
|
||||
func (m *mockServerConfig) GetHost() string { return "localhost" }
|
||||
func (m *mockServerConfig) GetPort() string { return "8080" }
|
||||
func (m *mockServerConfig) GetBaseConfig() server.ServerConfig {
|
||||
return server.ServerConfig{Config: server.Config{CertFile: m.baseConfig.certFile, KeyFile: m.baseConfig.keyFile, ServerCAFile: m.baseConfig.serverCAFile, ClientCAFile: m.baseConfig.clientCAFile, Host: m.baseConfig.host, Port: m.baseConfig.port}}
|
||||
}
|
||||
|
||||
func TestNewServer(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
cancel := func() {}
|
||||
name := "test-server"
|
||||
config := &mockServerConfig{
|
||||
baseConfig: &mockBaseConfig{},
|
||||
}
|
||||
handler := &mockHandler{}
|
||||
logger := slog.Default()
|
||||
|
||||
server := NewServer(ctx, cancel, name, config, handler, logger, nil)
|
||||
|
||||
assert.NotNil(t, server)
|
||||
httpSrv, ok := server.(*httpServer)
|
||||
require.True(t, ok)
|
||||
assert.NotNil(t, httpSrv.server)
|
||||
assert.Equal(t, handler, httpSrv.server.Handler)
|
||||
}
|
||||
|
||||
func TestHttpServer_shouldUseAttestedTLS(t *testing.T) {
|
||||
mockCertProvider := new(mocks.CertificateProvider)
|
||||
tests := []struct {
|
||||
name string
|
||||
config server.ServerConfiguration
|
||||
expected bool
|
||||
certProvider atls.CertificateProvider
|
||||
}{
|
||||
{
|
||||
name: "should use attested TLS when config is AgentConfig and AttestedTLS is true and certProvider is not empty",
|
||||
config: server.AgentConfig{
|
||||
AttestedTLS: true,
|
||||
},
|
||||
certProvider: mockCertProvider,
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "should not use attested TLS when certProvider is empty",
|
||||
config: server.AgentConfig{
|
||||
AttestedTLS: true,
|
||||
},
|
||||
certProvider: nil,
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "should not use attested TLS when AttestedTLS is false",
|
||||
config: server.AgentConfig{
|
||||
AttestedTLS: false,
|
||||
},
|
||||
certProvider: mockCertProvider,
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "should not use attested TLS when config is not AgentConfig",
|
||||
config: &mockServerConfig{
|
||||
baseConfig: &mockBaseConfig{},
|
||||
},
|
||||
certProvider: mockCertProvider,
|
||||
expected: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
cancel := func() {}
|
||||
|
||||
server := NewServer(ctx, cancel, "test", tt.config, &mockHandler{}, slog.Default(), tt.certProvider)
|
||||
httpSrv := server.(*httpServer)
|
||||
|
||||
result := httpSrv.shouldUseAttestedTLS()
|
||||
assert.Equal(t, tt.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestHttpServer_shouldUseRegularTLS(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
certFile string
|
||||
keyFile string
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
name: "should use regular TLS when both cert and key files are provided",
|
||||
certFile: "cert.pem",
|
||||
keyFile: "key.pem",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "should use regular TLS when only cert file is provided",
|
||||
certFile: "cert.pem",
|
||||
keyFile: "",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "should use regular TLS when only key file is provided",
|
||||
certFile: "",
|
||||
keyFile: "key.pem",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "should not use regular TLS when neither cert nor key files are provided",
|
||||
certFile: "",
|
||||
keyFile: "",
|
||||
expected: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
cancel := func() {}
|
||||
config := &mockServerConfig{
|
||||
baseConfig: &mockBaseConfig{
|
||||
certFile: tt.certFile,
|
||||
keyFile: tt.keyFile,
|
||||
},
|
||||
}
|
||||
|
||||
server := NewServer(ctx, cancel, "test", config, &mockHandler{}, slog.Default(), nil)
|
||||
httpSrv := server.(*httpServer)
|
||||
|
||||
result := httpSrv.shouldUseRegularTLS()
|
||||
assert.Equal(t, tt.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestHttpServer_Stop(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
config := &mockServerConfig{
|
||||
baseConfig: &mockBaseConfig{},
|
||||
}
|
||||
handler := &mockHandler{}
|
||||
|
||||
server := NewServer(ctx, cancel, "test-server", config, handler, slog.Default(), nil)
|
||||
httpSrv := server.(*httpServer)
|
||||
|
||||
// Start a test server that we can control
|
||||
testServer := httptest.NewServer(handler)
|
||||
defer testServer.Close()
|
||||
|
||||
// Replace the server's HTTP server with our test server's
|
||||
httpSrv.server = testServer.Config
|
||||
|
||||
err := httpSrv.Stop()
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestHttpServer_logAttestedTLSStart(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
mtls bool
|
||||
}{
|
||||
{
|
||||
name: "log attested mTLS start",
|
||||
mtls: true,
|
||||
},
|
||||
{
|
||||
name: "log attested TLS start",
|
||||
mtls: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
cancel := func() {}
|
||||
config := &mockServerConfig{
|
||||
baseConfig: &mockBaseConfig{},
|
||||
}
|
||||
|
||||
server := NewServer(ctx, cancel, "test-server", config, &mockHandler{}, slog.Default(), nil)
|
||||
httpSrv := server.(*httpServer)
|
||||
|
||||
// This test mainly ensures the method doesn't panic
|
||||
// In a real scenario, you might want to capture log output
|
||||
assert.NotPanics(t, func() {
|
||||
httpSrv.logAttestedTLSStart(tt.mtls)
|
||||
})
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestHttpServer_logRegularTLSStart(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
mtls bool
|
||||
}{
|
||||
{
|
||||
name: "log regular mTLS start",
|
||||
mtls: true,
|
||||
},
|
||||
{
|
||||
name: "log regular TLS start",
|
||||
mtls: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
cancel := func() {}
|
||||
config := &mockServerConfig{
|
||||
baseConfig: &mockBaseConfig{
|
||||
certFile: "cert.pem",
|
||||
keyFile: "key.pem",
|
||||
serverCAFile: "server-ca.pem",
|
||||
clientCAFile: "client-ca.pem",
|
||||
},
|
||||
}
|
||||
|
||||
server := NewServer(ctx, cancel, "test-server", config, &mockHandler{}, slog.Default(), nil)
|
||||
httpSrv := server.(*httpServer)
|
||||
|
||||
// This test mainly ensures the method doesn't panic
|
||||
assert.NotPanics(t, func() {
|
||||
httpSrv.logRegularTLSStart(tt.mtls)
|
||||
})
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestHttpServer_startWithoutTLS(t *testing.T) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
||||
defer cancel()
|
||||
|
||||
config := &mockServerConfig{
|
||||
baseConfig: &mockBaseConfig{},
|
||||
}
|
||||
handler := &mockHandler{}
|
||||
|
||||
server := NewServer(ctx, cancel, "test-server", config, handler, slog.Default(), nil)
|
||||
httpSrv := server.(*httpServer)
|
||||
|
||||
// Use a test server to avoid binding to actual ports
|
||||
testServer := httptest.NewServer(handler)
|
||||
defer testServer.Close()
|
||||
|
||||
httpSrv.server = testServer.Config
|
||||
|
||||
err := httpSrv.startWithoutTLS()
|
||||
// The error will be related to context cancellation or server shutdown
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func TestHttpServer_Protocol(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
setupTLS func(*httpServer)
|
||||
expectedProto string
|
||||
}{
|
||||
{
|
||||
name: "HTTP protocol without TLS",
|
||||
setupTLS: func(s *httpServer) {
|
||||
s.Protocol = httpProtocol
|
||||
},
|
||||
expectedProto: httpProtocol,
|
||||
},
|
||||
{
|
||||
name: "HTTPS protocol with TLS",
|
||||
setupTLS: func(s *httpServer) {
|
||||
s.Protocol = httpsProtocol
|
||||
s.server.TLSConfig = &tls.Config{}
|
||||
},
|
||||
expectedProto: httpsProtocol,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
cancel := func() {}
|
||||
config := &mockServerConfig{
|
||||
baseConfig: &mockBaseConfig{},
|
||||
}
|
||||
|
||||
server := NewServer(ctx, cancel, "test-server", config, &mockHandler{}, slog.Default(), nil)
|
||||
httpSrv := server.(*httpServer)
|
||||
|
||||
tt.setupTLS(httpSrv)
|
||||
|
||||
assert.Equal(t, tt.expectedProto, httpSrv.Protocol)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestHttpServer_ContextCancellation(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
config := &mockServerConfig{
|
||||
baseConfig: &mockBaseConfig{},
|
||||
}
|
||||
handler := &mockHandler{}
|
||||
|
||||
server := NewServer(ctx, cancel, "test-server", config, handler, slog.Default(), nil)
|
||||
httpSrv := server.(*httpServer)
|
||||
|
||||
// Cancel the context immediately
|
||||
cancel()
|
||||
|
||||
// The listenAndServe method should handle context cancellation
|
||||
err := httpSrv.listenAndServe(false)
|
||||
assert.NoError(t, err) // Should return no error when context is cancelled and Stop() succeeds
|
||||
}
|
||||
|
||||
func TestHttpServer_TLSConfiguration(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
cancel := func() {}
|
||||
config := &mockServerConfig{
|
||||
baseConfig: &mockBaseConfig{
|
||||
certFile: "cert.pem",
|
||||
keyFile: "key.pem",
|
||||
},
|
||||
}
|
||||
|
||||
server := NewServer(ctx, cancel, "test-server", config, &mockHandler{}, slog.Default(), nil)
|
||||
httpSrv := server.(*httpServer)
|
||||
|
||||
// Test TLS configuration setup
|
||||
httpSrv.server.TLSConfig = &tls.Config{
|
||||
MinVersion: tls.VersionTLS12,
|
||||
}
|
||||
|
||||
assert.NotNil(t, httpSrv.server.TLSConfig)
|
||||
assert.Equal(t, uint16(tls.VersionTLS12), httpSrv.server.TLSConfig.MinVersion)
|
||||
}
|
||||
|
||||
// Integration-style test for server lifecycle.
|
||||
func TestHttpServer_Lifecycle(t *testing.T) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond)
|
||||
defer cancel()
|
||||
|
||||
config := &mockServerConfig{
|
||||
baseConfig: &mockBaseConfig{
|
||||
host: "localhost",
|
||||
port: "8080",
|
||||
},
|
||||
}
|
||||
handler := &mockHandler{}
|
||||
|
||||
server := NewServer(ctx, cancel, "test-server", config, handler, slog.Default(), nil)
|
||||
|
||||
// Test that server can be created and has expected initial state
|
||||
httpSrv, ok := server.(*httpServer)
|
||||
require.True(t, ok)
|
||||
assert.NotNil(t, httpSrv.server)
|
||||
assert.Equal(t, "localhost:8080", httpSrv.server.Addr)
|
||||
|
||||
// Test Stop without Start (should not panic)
|
||||
err := httpSrv.Stop()
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
@@ -1,127 +0,0 @@
|
||||
// Copyright (c) Ultraviolet
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
// Code generated by mockery; DO NOT EDIT.
|
||||
// github.com/vektra/mockery
|
||||
// template: testify
|
||||
|
||||
package mocks
|
||||
|
||||
import (
|
||||
mock "github.com/stretchr/testify/mock"
|
||||
)
|
||||
|
||||
// NewServer creates a new instance of Server. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations.
|
||||
// The first argument is typically a *testing.T value.
|
||||
func NewServer(t interface {
|
||||
mock.TestingT
|
||||
Cleanup(func())
|
||||
}) *Server {
|
||||
mock := &Server{}
|
||||
mock.Mock.Test(t)
|
||||
|
||||
t.Cleanup(func() { mock.AssertExpectations(t) })
|
||||
|
||||
return mock
|
||||
}
|
||||
|
||||
// Server is an autogenerated mock type for the Server type
|
||||
type Server struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
type Server_Expecter struct {
|
||||
mock *mock.Mock
|
||||
}
|
||||
|
||||
func (_m *Server) EXPECT() *Server_Expecter {
|
||||
return &Server_Expecter{mock: &_m.Mock}
|
||||
}
|
||||
|
||||
// Start provides a mock function for the type Server
|
||||
func (_mock *Server) Start() error {
|
||||
ret := _mock.Called()
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for Start")
|
||||
}
|
||||
|
||||
var r0 error
|
||||
if returnFunc, ok := ret.Get(0).(func() error); ok {
|
||||
r0 = returnFunc()
|
||||
} else {
|
||||
r0 = ret.Error(0)
|
||||
}
|
||||
return r0
|
||||
}
|
||||
|
||||
// Server_Start_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Start'
|
||||
type Server_Start_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// Start is a helper method to define mock.On call
|
||||
func (_e *Server_Expecter) Start() *Server_Start_Call {
|
||||
return &Server_Start_Call{Call: _e.mock.On("Start")}
|
||||
}
|
||||
|
||||
func (_c *Server_Start_Call) Run(run func()) *Server_Start_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
run()
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *Server_Start_Call) Return(err error) *Server_Start_Call {
|
||||
_c.Call.Return(err)
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *Server_Start_Call) RunAndReturn(run func() error) *Server_Start_Call {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
// Stop provides a mock function for the type Server
|
||||
func (_mock *Server) Stop() error {
|
||||
ret := _mock.Called()
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for Stop")
|
||||
}
|
||||
|
||||
var r0 error
|
||||
if returnFunc, ok := ret.Get(0).(func() error); ok {
|
||||
r0 = returnFunc()
|
||||
} else {
|
||||
r0 = ret.Error(0)
|
||||
}
|
||||
return r0
|
||||
}
|
||||
|
||||
// Server_Stop_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Stop'
|
||||
type Server_Stop_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// Stop is a helper method to define mock.On call
|
||||
func (_e *Server_Expecter) Stop() *Server_Stop_Call {
|
||||
return &Server_Stop_Call{Call: _e.mock.On("Stop")}
|
||||
}
|
||||
|
||||
func (_c *Server_Stop_Call) Run(run func()) *Server_Stop_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
run()
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *Server_Stop_Call) Return(err error) *Server_Stop_Call {
|
||||
_c.Call.Return(err)
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *Server_Stop_Call) RunAndReturn(run func() error) *Server_Stop_Call {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
@@ -1,105 +0,0 @@
|
||||
// Copyright (c) Ultraviolet
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"os"
|
||||
"os/signal"
|
||||
"syscall"
|
||||
)
|
||||
|
||||
type Server interface {
|
||||
Start() error
|
||||
Stop() error
|
||||
}
|
||||
|
||||
type ServerConfiguration interface {
|
||||
GetBaseConfig() ServerConfig
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
Host string `env:"HOST" envDefault:"localhost"`
|
||||
Port string `env:"PORT" envDefault:"7001"`
|
||||
ServerCAFile string `env:"SERVER_CA_CERTS" envDefault:""`
|
||||
CertFile string `env:"SERVER_CERT" envDefault:""`
|
||||
KeyFile string `env:"SERVER_KEY" envDefault:""`
|
||||
ClientCAFile string `env:"CLIENT_CA_CERTS" envDefault:""`
|
||||
}
|
||||
|
||||
type ServerConfig struct {
|
||||
Config
|
||||
}
|
||||
type AgentConfig struct {
|
||||
ServerConfig
|
||||
AttestedTLS bool `env:"ATTESTED_TLS" envDefault:"false"`
|
||||
}
|
||||
|
||||
type BaseServer struct {
|
||||
Ctx context.Context
|
||||
Cancel context.CancelFunc
|
||||
Name string
|
||||
Address string
|
||||
Config ServerConfiguration
|
||||
Logger *slog.Logger
|
||||
Protocol string
|
||||
}
|
||||
|
||||
func (s ServerConfig) GetBaseConfig() ServerConfig {
|
||||
return s
|
||||
}
|
||||
|
||||
func (a AgentConfig) GetBaseConfig() ServerConfig {
|
||||
return a.ServerConfig
|
||||
}
|
||||
|
||||
func NewBaseServer(
|
||||
ctx context.Context, cancel context.CancelFunc, name string, config ServerConfiguration, logger *slog.Logger,
|
||||
) BaseServer {
|
||||
cfg := config.GetBaseConfig()
|
||||
address := fmt.Sprintf("%s:%s", cfg.Host, cfg.Port)
|
||||
|
||||
return BaseServer{
|
||||
Ctx: ctx,
|
||||
Cancel: cancel,
|
||||
Name: name,
|
||||
Address: address,
|
||||
Config: config,
|
||||
Logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
func StopHandler(ctx context.Context, cancel context.CancelFunc, logger *slog.Logger, svcName string, servers ...Server) error {
|
||||
var err error
|
||||
c := make(chan os.Signal, 1)
|
||||
signal.Notify(c, syscall.SIGINT, syscall.SIGABRT)
|
||||
select {
|
||||
case sig := <-c:
|
||||
defer cancel()
|
||||
err = stopAllServer(servers...)
|
||||
if err != nil {
|
||||
logger.Error(fmt.Sprintf("%s service error during shutdown: %v", svcName, err))
|
||||
}
|
||||
logger.Info(fmt.Sprintf("%s service shutdown by signal: %s", svcName, sig))
|
||||
return err
|
||||
case <-ctx.Done():
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func stopAllServer(servers ...Server) error {
|
||||
var errs []error
|
||||
for _, server := range servers {
|
||||
if err := server.Stop(); err != nil {
|
||||
errs = append(errs, err)
|
||||
}
|
||||
}
|
||||
|
||||
if len(errs) > 0 {
|
||||
return fmt.Errorf("encountered errors while stopping servers: %v", errs)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -1,138 +0,0 @@
|
||||
// Copyright (c) Ultraviolet
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"log/slog"
|
||||
"os"
|
||||
"syscall"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/ultravioletrs/cocos/pkg/server/mocks"
|
||||
)
|
||||
|
||||
func TestStopAllServer(t *testing.T) {
|
||||
server1 := new(mocks.Server)
|
||||
server2 := new(mocks.Server)
|
||||
server1.On("Stop").Return(nil)
|
||||
server2.On("Stop").Return(errors.New("failed to stop"))
|
||||
tests := []struct {
|
||||
name string
|
||||
servers []Server
|
||||
expectedError bool
|
||||
}{
|
||||
{
|
||||
name: "All servers stop successfully",
|
||||
servers: []Server{
|
||||
server1,
|
||||
server1,
|
||||
},
|
||||
expectedError: false,
|
||||
},
|
||||
{
|
||||
name: "One server fails to stop",
|
||||
servers: []Server{
|
||||
server1,
|
||||
server2,
|
||||
},
|
||||
expectedError: true,
|
||||
},
|
||||
{
|
||||
name: "No servers",
|
||||
servers: []Server{},
|
||||
expectedError: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := stopAllServer(tt.servers...)
|
||||
if (err != nil) != tt.expectedError {
|
||||
t.Errorf("stopAllServer() error = %v, expectedError %v", err, tt.expectedError)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestStopHandler(t *testing.T) {
|
||||
mockServer := new(mocks.Server)
|
||||
mockServer.On("Stop").Return(nil)
|
||||
tests := []struct {
|
||||
name string
|
||||
setupFunc func() (context.Context, context.CancelFunc, *slog.Logger, string, []Server)
|
||||
triggerSignal bool
|
||||
expectedError bool
|
||||
expectCanceled bool
|
||||
}{
|
||||
{
|
||||
name: "Graceful shutdown on signal",
|
||||
setupFunc: func() (context.Context, context.CancelFunc, *slog.Logger, string, []Server) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
|
||||
return ctx, cancel, logger, "test", []Server{mockServer}
|
||||
},
|
||||
triggerSignal: true,
|
||||
expectedError: false,
|
||||
expectCanceled: true,
|
||||
},
|
||||
{
|
||||
name: "Context canceled",
|
||||
setupFunc: func() (context.Context, context.CancelFunc, *slog.Logger, string, []Server) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
|
||||
go func() {
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
cancel()
|
||||
}()
|
||||
return ctx, cancel, logger, "test", []Server{mockServer}
|
||||
},
|
||||
triggerSignal: false,
|
||||
expectedError: false,
|
||||
expectCanceled: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
ctx, cancel, logger, svcName, servers := tt.setupFunc()
|
||||
defer cancel()
|
||||
|
||||
errChan := make(chan error)
|
||||
go func() {
|
||||
errChan <- StopHandler(ctx, cancel, logger, svcName, servers...)
|
||||
}()
|
||||
|
||||
if tt.triggerSignal {
|
||||
// Simulate SIGINT
|
||||
go func() {
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
err := syscall.Kill(syscall.Getpid(), syscall.SIGINT)
|
||||
if err != nil {
|
||||
t.Errorf("failed to send signal: %v", err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
select {
|
||||
case err := <-errChan:
|
||||
if (err != nil) != tt.expectedError {
|
||||
t.Errorf("StopHandler() error = %v, expectedError %v", err, tt.expectedError)
|
||||
}
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Error("StopHandler() timed out")
|
||||
}
|
||||
|
||||
if tt.expectCanceled {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
// Context was canceled as expected
|
||||
default:
|
||||
t.Error("Context was not canceled")
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -1,741 +0,0 @@
|
||||
// Copyright (c) Ultraviolet
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
package server
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"crypto/x509/pkix"
|
||||
"encoding/pem"
|
||||
"math/big"
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Helper function to generate a test certificate and key.
|
||||
func generateTestCert() (certPEM, keyPEM []byte, err error) {
|
||||
// Generate private key
|
||||
privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
// Create certificate template
|
||||
template := x509.Certificate{
|
||||
SerialNumber: big.NewInt(1),
|
||||
Subject: pkix.Name{
|
||||
Organization: []string{"Test Org"},
|
||||
Country: []string{"US"},
|
||||
Province: []string{""},
|
||||
Locality: []string{"Test City"},
|
||||
StreetAddress: []string{""},
|
||||
PostalCode: []string{""},
|
||||
},
|
||||
NotBefore: time.Now(),
|
||||
NotAfter: time.Now().Add(365 * 24 * time.Hour),
|
||||
KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature,
|
||||
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth, x509.ExtKeyUsageClientAuth},
|
||||
IPAddresses: nil,
|
||||
}
|
||||
|
||||
// Create certificate
|
||||
certDER, err := x509.CreateCertificate(rand.Reader, &template, &template, &privateKey.PublicKey, privateKey)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
// Encode certificate
|
||||
certPEM = pem.EncodeToMemory(&pem.Block{
|
||||
Type: "CERTIFICATE",
|
||||
Bytes: certDER,
|
||||
})
|
||||
|
||||
// Encode private key
|
||||
privateKeyDER, err := x509.MarshalPKCS8PrivateKey(privateKey)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
keyPEM = pem.EncodeToMemory(&pem.Block{
|
||||
Type: "PRIVATE KEY",
|
||||
Bytes: privateKeyDER,
|
||||
})
|
||||
|
||||
return certPEM, keyPEM, nil
|
||||
}
|
||||
|
||||
// Helper function to create temporary files for testing.
|
||||
func createTempFile(t *testing.T, content []byte) string {
|
||||
tmpFile, err := os.CreateTemp("", "test-cert-*.pem")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temp file: %v", err)
|
||||
}
|
||||
defer tmpFile.Close()
|
||||
|
||||
if _, err := tmpFile.Write(content); err != nil {
|
||||
t.Fatalf("Failed to write temp file: %v", err)
|
||||
}
|
||||
|
||||
return tmpFile.Name()
|
||||
}
|
||||
|
||||
func TestLoadCertFile(t *testing.T) {
|
||||
certPEM, _, err := generateTestCert()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to generate test cert: %v", err)
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
certFile string
|
||||
wantErr bool
|
||||
setup func() string
|
||||
cleanup func(string)
|
||||
}{
|
||||
{
|
||||
name: "empty cert file path",
|
||||
certFile: "",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "valid cert file",
|
||||
wantErr: false,
|
||||
setup: func() string {
|
||||
return createTempFile(t, certPEM)
|
||||
},
|
||||
cleanup: func(path string) {
|
||||
os.Remove(path)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "non-existent file",
|
||||
certFile: "/non/existent/file.pem",
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
certFile := tt.certFile
|
||||
if tt.setup != nil {
|
||||
certFile = tt.setup()
|
||||
}
|
||||
if tt.cleanup != nil {
|
||||
defer tt.cleanup(certFile)
|
||||
}
|
||||
|
||||
data, err := LoadCertFile(certFile)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("LoadCertFile() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
|
||||
if tt.certFile != "" && !tt.wantErr && len(data) == 0 {
|
||||
t.Errorf("LoadCertFile() with valid file should return data, got empty")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestReadFileOrData(t *testing.T) {
|
||||
testData := "test certificate data"
|
||||
tempFile := createTempFile(t, []byte(testData))
|
||||
defer os.Remove(tempFile)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
want string
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "file path",
|
||||
input: tempFile,
|
||||
want: testData,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "raw data with newlines",
|
||||
input: "-----BEGIN CERTIFICATE-----\nMIIC...\n-----END CERTIFICATE-----",
|
||||
want: "-----BEGIN CERTIFICATE-----\nMIIC...\n-----END CERTIFICATE-----",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "short raw data without newlines",
|
||||
input: "short data",
|
||||
want: "short data",
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "non-existent file path",
|
||||
input: "/non/existent/file.pem",
|
||||
want: "",
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "empty input",
|
||||
input: "",
|
||||
want: "",
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := ReadFileOrData(tt.input)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("ReadFileOrData() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
|
||||
if !tt.wantErr && string(got) != tt.want {
|
||||
t.Errorf("ReadFileOrData() = %v, want %v", string(got), tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadX509KeyPair(t *testing.T) {
|
||||
certPEM, keyPEM, err := generateTestCert()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to generate test cert: %v", err)
|
||||
}
|
||||
|
||||
certFile := createTempFile(t, certPEM)
|
||||
keyFile := createTempFile(t, keyPEM)
|
||||
defer os.Remove(certFile)
|
||||
defer os.Remove(keyFile)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
certfile string
|
||||
keyfile string
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "valid cert and key files",
|
||||
certfile: certFile,
|
||||
keyfile: keyFile,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "valid cert and key data",
|
||||
certfile: string(certPEM),
|
||||
keyfile: string(keyPEM),
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "non-existent cert file",
|
||||
certfile: "/non/existent/cert.pem",
|
||||
keyfile: keyFile,
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "non-existent key file",
|
||||
certfile: certFile,
|
||||
keyfile: "/non/existent/key.pem",
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "invalid cert data",
|
||||
certfile: "invalid cert data",
|
||||
keyfile: string(keyPEM),
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "invalid key data",
|
||||
certfile: string(certPEM),
|
||||
keyfile: "invalid key data",
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
cert, err := LoadX509KeyPair(tt.certfile, tt.keyfile)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("LoadX509KeyPair() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
|
||||
if !tt.wantErr && len(cert.Certificate) == 0 {
|
||||
t.Errorf("LoadX509KeyPair() returned empty certificate")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfigureRootCA(t *testing.T) {
|
||||
certPEM, _, err := generateTestCert()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to generate test cert: %v", err)
|
||||
}
|
||||
|
||||
caFile := createTempFile(t, certPEM)
|
||||
defer os.Remove(caFile)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
tlsConfig *tls.Config
|
||||
serverCAFile string
|
||||
wantErr bool
|
||||
expectCA bool
|
||||
}{
|
||||
{
|
||||
name: "valid CA file",
|
||||
tlsConfig: &tls.Config{},
|
||||
serverCAFile: caFile,
|
||||
wantErr: false,
|
||||
expectCA: true,
|
||||
},
|
||||
{
|
||||
name: "valid CA data",
|
||||
tlsConfig: &tls.Config{},
|
||||
serverCAFile: string(certPEM),
|
||||
wantErr: false,
|
||||
expectCA: true,
|
||||
},
|
||||
{
|
||||
name: "empty CA file",
|
||||
tlsConfig: &tls.Config{},
|
||||
serverCAFile: "",
|
||||
wantErr: false,
|
||||
expectCA: false,
|
||||
},
|
||||
{
|
||||
name: "non-existent CA file",
|
||||
tlsConfig: &tls.Config{},
|
||||
serverCAFile: "/non/existent/ca.pem",
|
||||
wantErr: true,
|
||||
expectCA: false,
|
||||
},
|
||||
{
|
||||
name: "invalid CA data",
|
||||
tlsConfig: &tls.Config{},
|
||||
serverCAFile: "invalid ca data",
|
||||
wantErr: true,
|
||||
expectCA: false,
|
||||
},
|
||||
{
|
||||
name: "existing RootCAs pool",
|
||||
tlsConfig: &tls.Config{RootCAs: x509.NewCertPool()},
|
||||
serverCAFile: caFile,
|
||||
wantErr: false,
|
||||
expectCA: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := ConfigureRootCA(tt.tlsConfig, tt.serverCAFile)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("ConfigureRootCA() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
|
||||
if tt.expectCA && tt.tlsConfig.RootCAs == nil {
|
||||
t.Errorf("ConfigureRootCA() should have created RootCAs pool")
|
||||
}
|
||||
|
||||
if !tt.expectCA && tt.tlsConfig.RootCAs != nil && tt.serverCAFile == "" {
|
||||
t.Errorf("ConfigureRootCA() should not have created RootCAs pool for empty file")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfigureClientCA(t *testing.T) {
|
||||
certPEM, _, err := generateTestCert()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to generate test cert: %v", err)
|
||||
}
|
||||
|
||||
caFile := createTempFile(t, certPEM)
|
||||
defer os.Remove(caFile)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
tlsConfig *tls.Config
|
||||
clientCAFile string
|
||||
wantConfigured bool
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "valid client CA file",
|
||||
tlsConfig: &tls.Config{},
|
||||
clientCAFile: caFile,
|
||||
wantConfigured: true,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "valid client CA data",
|
||||
tlsConfig: &tls.Config{},
|
||||
clientCAFile: string(certPEM),
|
||||
wantConfigured: true,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "empty client CA file",
|
||||
tlsConfig: &tls.Config{},
|
||||
clientCAFile: "",
|
||||
wantConfigured: false,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "non-existent client CA file",
|
||||
tlsConfig: &tls.Config{},
|
||||
clientCAFile: "/non/existent/ca.pem",
|
||||
wantConfigured: false,
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "invalid client CA data",
|
||||
tlsConfig: &tls.Config{},
|
||||
clientCAFile: "invalid ca data",
|
||||
wantConfigured: false,
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "existing ClientCAs pool",
|
||||
tlsConfig: &tls.Config{ClientCAs: x509.NewCertPool()},
|
||||
clientCAFile: caFile,
|
||||
wantConfigured: true,
|
||||
wantErr: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
configured, err := ConfigureClientCA(tt.tlsConfig, tt.clientCAFile)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("ConfigureClientCA() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
|
||||
if configured != tt.wantConfigured {
|
||||
t.Errorf("ConfigureClientCA() configured = %v, want %v", configured, tt.wantConfigured)
|
||||
}
|
||||
|
||||
if tt.wantConfigured && tt.tlsConfig.ClientCAs == nil {
|
||||
t.Errorf("ConfigureClientCA() should have created ClientCAs pool")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfigureCertificateAuthorities(t *testing.T) {
|
||||
certPEM, _, err := generateTestCert()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to generate test cert: %v", err)
|
||||
}
|
||||
|
||||
caFile := createTempFile(t, certPEM)
|
||||
defer os.Remove(caFile)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
tlsConfig *tls.Config
|
||||
serverCAFile string
|
||||
clientCAFile string
|
||||
wantMTLS bool
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "both server and client CA",
|
||||
tlsConfig: &tls.Config{},
|
||||
serverCAFile: caFile,
|
||||
clientCAFile: caFile,
|
||||
wantMTLS: true,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "only server CA",
|
||||
tlsConfig: &tls.Config{},
|
||||
serverCAFile: caFile,
|
||||
clientCAFile: "",
|
||||
wantMTLS: false,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "only client CA",
|
||||
tlsConfig: &tls.Config{},
|
||||
serverCAFile: "",
|
||||
clientCAFile: caFile,
|
||||
wantMTLS: true,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "no CAs",
|
||||
tlsConfig: &tls.Config{},
|
||||
serverCAFile: "",
|
||||
clientCAFile: "",
|
||||
wantMTLS: false,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "invalid server CA",
|
||||
tlsConfig: &tls.Config{},
|
||||
serverCAFile: "/non/existent/server-ca.pem",
|
||||
clientCAFile: caFile,
|
||||
wantMTLS: false,
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "invalid client CA",
|
||||
tlsConfig: &tls.Config{},
|
||||
serverCAFile: caFile,
|
||||
clientCAFile: "/non/existent/client-ca.pem",
|
||||
wantMTLS: false,
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
mtls, err := ConfigureCertificateAuthorities(tt.tlsConfig, tt.serverCAFile, tt.clientCAFile)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("ConfigureCertificateAuthorities() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
|
||||
if mtls != tt.wantMTLS {
|
||||
t.Errorf("ConfigureCertificateAuthorities() mtls = %v, want %v", mtls, tt.wantMTLS)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSetupRegularTLS(t *testing.T) {
|
||||
certPEM, keyPEM, err := generateTestCert()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to generate test cert: %v", err)
|
||||
}
|
||||
|
||||
certFile := createTempFile(t, certPEM)
|
||||
keyFile := createTempFile(t, keyPEM)
|
||||
caFile := createTempFile(t, certPEM)
|
||||
defer func() {
|
||||
os.Remove(certFile)
|
||||
os.Remove(keyFile)
|
||||
os.Remove(caFile)
|
||||
}()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
certFile string
|
||||
keyFile string
|
||||
serverCAFile string
|
||||
clientCAFile string
|
||||
wantMTLS bool
|
||||
wantErr bool
|
||||
expectedAuth tls.ClientAuthType
|
||||
}{
|
||||
{
|
||||
name: "regular TLS without mTLS",
|
||||
certFile: certFile,
|
||||
keyFile: keyFile,
|
||||
serverCAFile: "",
|
||||
clientCAFile: "",
|
||||
wantMTLS: false,
|
||||
wantErr: false,
|
||||
expectedAuth: tls.NoClientCert,
|
||||
},
|
||||
{
|
||||
name: "TLS with mTLS",
|
||||
certFile: certFile,
|
||||
keyFile: keyFile,
|
||||
serverCAFile: caFile,
|
||||
clientCAFile: caFile,
|
||||
wantMTLS: true,
|
||||
wantErr: false,
|
||||
expectedAuth: tls.RequireAndVerifyClientCert,
|
||||
},
|
||||
{
|
||||
name: "TLS with only server CA",
|
||||
certFile: certFile,
|
||||
keyFile: keyFile,
|
||||
serverCAFile: caFile,
|
||||
clientCAFile: "",
|
||||
wantMTLS: false,
|
||||
wantErr: false,
|
||||
expectedAuth: tls.NoClientCert,
|
||||
},
|
||||
{
|
||||
name: "invalid certificate file",
|
||||
certFile: "/non/existent/cert.pem",
|
||||
keyFile: keyFile,
|
||||
serverCAFile: "",
|
||||
clientCAFile: "",
|
||||
wantMTLS: false,
|
||||
wantErr: true,
|
||||
expectedAuth: tls.NoClientCert,
|
||||
},
|
||||
{
|
||||
name: "invalid key file",
|
||||
certFile: certFile,
|
||||
keyFile: "/non/existent/key.pem",
|
||||
serverCAFile: "",
|
||||
clientCAFile: "",
|
||||
wantMTLS: false,
|
||||
wantErr: true,
|
||||
expectedAuth: tls.NoClientCert,
|
||||
},
|
||||
{
|
||||
name: "invalid server CA file",
|
||||
certFile: certFile,
|
||||
keyFile: keyFile,
|
||||
serverCAFile: "/non/existent/server-ca.pem",
|
||||
clientCAFile: "",
|
||||
wantMTLS: false,
|
||||
wantErr: true,
|
||||
expectedAuth: tls.NoClientCert,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result, err := SetupRegularTLS(tt.certFile, tt.keyFile, tt.serverCAFile, tt.clientCAFile)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("SetupRegularTLS() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
|
||||
if tt.wantErr {
|
||||
return
|
||||
}
|
||||
|
||||
if result == nil {
|
||||
t.Errorf("SetupRegularTLS() returned nil result")
|
||||
return
|
||||
}
|
||||
|
||||
if result.MTLS != tt.wantMTLS {
|
||||
t.Errorf("SetupRegularTLS() MTLS = %v, want %v", result.MTLS, tt.wantMTLS)
|
||||
}
|
||||
|
||||
if result.Config.ClientAuth != tt.expectedAuth {
|
||||
t.Errorf("SetupRegularTLS() ClientAuth = %v, want %v", result.Config.ClientAuth, tt.expectedAuth)
|
||||
}
|
||||
|
||||
if len(result.Config.Certificates) == 0 {
|
||||
t.Errorf("SetupRegularTLS() should have at least one certificate")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildMTLSDescription(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
serverCAFile string
|
||||
clientCAFile string
|
||||
want string
|
||||
}{
|
||||
{
|
||||
name: "both server and client CA files",
|
||||
serverCAFile: "/path/to/server-ca.pem",
|
||||
clientCAFile: "/path/to/client-ca.pem",
|
||||
want: "root ca /path/to/server-ca.pem client ca /path/to/client-ca.pem",
|
||||
},
|
||||
{
|
||||
name: "only server CA file",
|
||||
serverCAFile: "/path/to/server-ca.pem",
|
||||
clientCAFile: "",
|
||||
want: "root ca /path/to/server-ca.pem",
|
||||
},
|
||||
{
|
||||
name: "only client CA file",
|
||||
serverCAFile: "",
|
||||
clientCAFile: "/path/to/client-ca.pem",
|
||||
want: "client ca /path/to/client-ca.pem",
|
||||
},
|
||||
{
|
||||
name: "no CA files",
|
||||
serverCAFile: "",
|
||||
clientCAFile: "",
|
||||
want: "",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := BuildMTLSDescription(tt.serverCAFile, tt.clientCAFile)
|
||||
if got != tt.want {
|
||||
t.Errorf("BuildMTLSDescription() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestErrorConstants(t *testing.T) {
|
||||
// Test that error constants are properly defined
|
||||
if ErrAppendServerCA == nil {
|
||||
t.Error("ErrAppendServerCA should not be nil")
|
||||
}
|
||||
|
||||
if ErrAppendClientCA == nil {
|
||||
t.Error("ErrAppendClientCA should not be nil")
|
||||
}
|
||||
|
||||
if ErrAppendServerCA.Error() != "failed to append server ca to tls.Config" {
|
||||
t.Errorf("ErrAppendServerCA message = %v, want 'failed to append server ca to tls.Config'", ErrAppendServerCA.Error())
|
||||
}
|
||||
|
||||
if ErrAppendClientCA.Error() != "failed to append client ca to tls.Config" {
|
||||
t.Errorf("ErrAppendClientCA message = %v, want 'failed to append client ca to tls.Config'", ErrAppendClientCA.Error())
|
||||
}
|
||||
}
|
||||
|
||||
func TestTLSSetupResult(t *testing.T) {
|
||||
// Test that TLSSetupResult struct works as expected
|
||||
config := &tls.Config{}
|
||||
result := &TLSSetupResult{
|
||||
Config: config,
|
||||
MTLS: true,
|
||||
}
|
||||
|
||||
if result.Config != config {
|
||||
t.Error("TLSSetupResult Config field should match assigned value")
|
||||
}
|
||||
|
||||
if !result.MTLS {
|
||||
t.Error("TLSSetupResult MTLS field should be true")
|
||||
}
|
||||
}
|
||||
|
||||
func TestReadFileOrDataEdgeCases(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "999 chars without newline (should try file)",
|
||||
input: strings.Repeat("a", 999),
|
||||
wantErr: true, // Should fail as file doesn't exist
|
||||
},
|
||||
{
|
||||
name: "1001 chars without newline (should treat as data)",
|
||||
input: strings.Repeat("a", 1001),
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "short string with newline (should treat as data)",
|
||||
input: "short\ndata",
|
||||
wantErr: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
_, err := ReadFileOrData(tt.input)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("ReadFileOrData() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
+6
-8
@@ -15,12 +15,12 @@ import (
|
||||
"strings"
|
||||
|
||||
mglog "github.com/absmach/supermq/logger"
|
||||
smqserver "github.com/absmach/supermq/pkg/server"
|
||||
grpcserver "github.com/absmach/supermq/pkg/server/grpc"
|
||||
"github.com/caarlos0/env/v11"
|
||||
"github.com/ultravioletrs/cocos/agent/cvms"
|
||||
cvmsgrpc "github.com/ultravioletrs/cocos/agent/cvms/api/grpc"
|
||||
"github.com/ultravioletrs/cocos/internal"
|
||||
"github.com/ultravioletrs/cocos/pkg/server"
|
||||
grpcserver "github.com/ultravioletrs/cocos/pkg/server/grpc"
|
||||
"golang.org/x/sync/errgroup"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/credentials"
|
||||
@@ -314,24 +314,22 @@ func main() {
|
||||
reflection.Register(srv)
|
||||
cvms.RegisterServiceServer(srv, cvmsgrpc.NewServer(incomingChan, &svc{logger: logger}))
|
||||
}
|
||||
grpcServerConfig := server.ServerConfig{
|
||||
Config: server.Config{
|
||||
Port: defaultPort,
|
||||
},
|
||||
grpcServerConfig := smqserver.Config{
|
||||
Port: defaultPort,
|
||||
}
|
||||
if err := env.ParseWithOptions(&grpcServerConfig, env.Options{}); err != nil {
|
||||
logger.Error(fmt.Sprintf("failed to load %s gRPC client configuration : %s", svcName, err))
|
||||
return
|
||||
}
|
||||
|
||||
gs := grpcserver.New(ctx, cancel, svcName, grpcServerConfig, registerAgentServiceServer, logger, nil, nil)
|
||||
gs := grpcserver.NewServer(ctx, cancel, svcName, grpcServerConfig, registerAgentServiceServer, logger)
|
||||
|
||||
g.Go(func() error {
|
||||
return gs.Start()
|
||||
})
|
||||
|
||||
g.Go(func() error {
|
||||
return server.StopHandler(ctx, cancel, logger, svcName, gs)
|
||||
return smqserver.StopSignalHandler(ctx, cancel, logger, svcName, gs)
|
||||
})
|
||||
|
||||
if err := g.Wait(); err != nil {
|
||||
|
||||
Reference in New Issue
Block a user