diff --git a/agent/cvms/server/cvm.go b/agent/cvms/server/cvm.go index 2c989e64..3212ba9c 100644 --- a/agent/cvms/server/cvm.go +++ b/agent/cvms/server/cvm.go @@ -4,17 +4,17 @@ package server import ( - context "context" "fmt" "log/slog" + "net" + "os" "github.com/ultravioletrs/cocos/agent" agentgrpc "github.com/ultravioletrs/cocos/agent/api/grpc" "github.com/ultravioletrs/cocos/agent/auth" - "github.com/ultravioletrs/cocos/pkg/atls" - "github.com/ultravioletrs/cocos/pkg/server" - grpcserver "github.com/ultravioletrs/cocos/pkg/server/grpc" + "go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc" "google.golang.org/grpc" + "google.golang.org/grpc/credentials/insecure" "google.golang.org/grpc/reflection" ) @@ -29,55 +29,68 @@ type AgentServer interface { } type agentServer struct { - gs server.Server - logger *slog.Logger - svc agent.Service - host string - certProvider atls.CertificateProvider + gs *grpc.Server + logger *slog.Logger + svc agent.Service + host string } -func NewServer(logger *slog.Logger, svc agent.Service, host string, certProvider atls.CertificateProvider) AgentServer { +func NewServer(logger *slog.Logger, svc agent.Service, host string) AgentServer { return &agentServer{ - logger: logger, - svc: svc, - host: host, - certProvider: certProvider, + logger: logger, + svc: svc, + host: host, } } func (as *agentServer) Start(cfg agent.AgentConfig, cmp agent.Computation) error { - agentGrpcServerConfig := server.AgentConfig{ - ServerConfig: server.ServerConfig{ - Config: server.Config{ - Host: defSvcGRPCSocket, - Port: "", - CertFile: cfg.CertFile, - KeyFile: cfg.KeyFile, - ServerCAFile: cfg.ServerCAFile, - ClientCAFile: cfg.ClientCAFile, - }, - }, - AttestedTLS: cfg.AttestedTls, - } - - registerAgentServiceServer := func(srv *grpc.Server) { - reflection.Register(srv) - agent.RegisterAgentServiceServer(srv, agentgrpc.NewServer(as.svc)) - } - authSvc, err := auth.New(cmp) if err != nil { as.logger.WithGroup(cmp.ID).Error(fmt.Sprintf("failed to create auth service %s", err.Error())) return err } - ctx, cancel := context.WithCancel(context.Background()) + grpcServerOptions := []grpc.ServerOption{ + grpc.StatsHandler(otelgrpc.NewServerHandler()), + } - as.gs = grpcserver.New(ctx, cancel, svcName, agentGrpcServerConfig, registerAgentServiceServer, as.logger, authSvc, as.certProvider) + // Add authentication interceptors + unary, stream := agentgrpc.NewAuthInterceptor(authSvc) + grpcServerOptions = append(grpcServerOptions, grpc.UnaryInterceptor(unary)) + grpcServerOptions = append(grpcServerOptions, grpc.StreamInterceptor(stream)) + + // Internal Unix socket is pure plaintext HTTP/2; Ingress Proxy handles external aTLS termination + grpcServerOptions = append(grpcServerOptions, grpc.Creds(insecure.NewCredentials())) + + as.gs = grpc.NewServer(grpcServerOptions...) + + reflection.Register(as.gs) + agent.RegisterAgentServiceServer(as.gs, agentgrpc.NewServer(as.svc)) + + socketPath := as.host + if socketPath == "" || socketPath == "0.0.0.0" { + socketPath = defSvcGRPCSocket + } + + var listener net.Listener + if socketPath[0] == '/' || socketPath[0] == '.' { + // Remove existing socket file if it exists + _ = os.Remove(socketPath) + listener, err = net.Listen("unix", socketPath) + } else { + listener, err = net.Listen("tcp", socketPath) + } + + if err != nil { + as.logger.Error(fmt.Sprintf("failed to listen on %s: %s", socketPath, err)) + return err + } + + as.logger.Info(fmt.Sprintf("agent service gRPC server listening at %s without TLS", socketPath)) go func() { - err := as.gs.Start() - if err != nil { + err := as.gs.Serve(listener) + if err != nil && err != grpc.ErrServerStopped { as.logger.Error(fmt.Sprintf("failed to start grpc server %s", err.Error())) } }() @@ -86,8 +99,8 @@ func (as *agentServer) Start(cfg agent.AgentConfig, cmp agent.Computation) error } func (as *agentServer) Stop() error { - if as.gs == nil { - return nil + if as.gs != nil { + as.gs.GracefulStop() } - return as.gs.Stop() + return nil } diff --git a/agent/cvms/server/cvm_test.go b/agent/cvms/server/cvm_test.go index ddfbc2a5..1d1a8ad2 100644 --- a/agent/cvms/server/cvm_test.go +++ b/agent/cvms/server/cvm_test.go @@ -21,7 +21,7 @@ import ( func setupTest(t *testing.T) (*slog.Logger, *mocks.Service, string, []byte) { logger := slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelError})) mockSvc := new(mocks.Service) - host := "localhost" + host := "localhost:0" privateKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) assert.NoError(t, err, "Failed to generate ECDSA key") @@ -70,7 +70,7 @@ func TestNewServer(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - server := NewServer(tt.logger, tt.svc, tt.host, nil) + server := NewServer(tt.logger, tt.svc, tt.host) assert.NotNil(t, server) @@ -194,7 +194,7 @@ func TestAgentServer_Start(t *testing.T) { t.Run(tt.name, func(t *testing.T) { tt.setupMocks(svc) - server := NewServer(logger, svc, host, nil) + server := NewServer(logger, svc, host) err := server.Start(tt.cfg, tt.cmp) @@ -268,7 +268,7 @@ func TestAgentServer_Stop(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - server := NewServer(logger, svc, host, nil) + server := NewServer(logger, svc, host) err := tt.setupServer(server) if err != nil { @@ -296,7 +296,7 @@ func TestAgentServer_Stop(t *testing.T) { func TestAgentServer_StopMultipleTimes(t *testing.T) { logger, svc, host, pubKey := setupTest(t) - server := NewServer(logger, svc, host, nil) + server := NewServer(logger, svc, host) // Start the server cfg := agent.AgentConfig{} @@ -340,7 +340,7 @@ func TestAgentServer_StopMultipleTimes(t *testing.T) { func TestAgentServer_StartAfterStop(t *testing.T) { logger, svc, host, pubKey := setupTest(t) - server := NewServer(logger, svc, host, nil) + server := NewServer(logger, svc, host) cfg := agent.AgentConfig{} cmp := agent.Computation{ @@ -488,7 +488,7 @@ func TestAgentServer_ConfigValidation(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - server := NewServer(logger, svc, host, nil) + server := NewServer(logger, svc, host) err := server.Start(tt.config, tt.cmp) diff --git a/algorithm.lin-reg-py.enc b/algorithm.lin-reg-py.enc deleted file mode 100644 index 3e2a8f20..00000000 Binary files a/algorithm.lin-reg-py.enc and /dev/null differ diff --git a/cmd/agent/main.go b/cmd/agent/main.go index 2b400a5b..e0e3b100 100644 --- a/cmd/agent/main.go +++ b/cmd/agent/main.go @@ -59,7 +59,6 @@ type config struct { AgentOSDistro string `env:"AGENT_OS_DISTRO" envDefault:"UVC"` AgentOSType string `env:"AGENT_OS_TYPE" envDefault:"UVC"` AttestationServiceSocket string `env:"ATTESTATION_SERVICE_SOCKET" envDefault:"/run/cocos/attestation.sock"` - EnableATLS bool `env:"AGENT_ENABLE_ATLS" envDefault:"true"` } func main() { @@ -139,6 +138,7 @@ func main() { }) ccPlatform := attestation.CCPlatform() + logger.Info(fmt.Sprintf("Detected confidential computing platform: %v", ccPlatform)) azureConfig := azure.NewEnvConfigFromAgent( cfg.AgentOSBuild, @@ -209,7 +209,8 @@ func main() { } var certProvider atls.CertificateProvider - if cfg.EnableATLS && ccPlatform != attestation.NoCC { + if ccPlatform != attestation.NoCC { + logger.Info(fmt.Sprintf("Initializing aTLS for platform %v with attestation service at %s", ccPlatform, cfg.AttestationServiceSocket)) var certsSDK sdk.SDK if cfg.CAUrl != "" { certsSDK = sdk.NewSDK(sdk.Config{ @@ -218,10 +219,12 @@ func main() { } certProvider, err = atls.NewProvider(attClient, ccPlatform, cfg.CertsToken, cfg.CVMId, certsSDK) if err != nil { - logger.Error(fmt.Sprintf("failed to create certificate provider: %s", err)) - exitCode = 1 - return + logger.Error(fmt.Sprintf("failed to create certificate provider for aTLS: %s. Continuing without attested TLS.", err)) + } else { + logger.Info("Successfully created aTLS certificate provider") } + } else { + logger.Warn("No Confidential Computing platform detected (NoCC). Certificate provider remains nil; aTLS will not be available for computations.") } // Create ingress proxy server @@ -240,7 +243,7 @@ func main() { return } - mc, err := cvmsapi.NewClient(pc, svc, cvmsQueue, logger, server.NewServer(logger, svc, cfg.AgentGrpcHost, certProvider), ingressProxy, storageDir, reconnectFn, cvmGRPCClient) + mc, err := cvmsapi.NewClient(pc, svc, cvmsQueue, logger, server.NewServer(logger, svc, cfg.AgentGrpcHost), ingressProxy, storageDir, reconnectFn, cvmGRPCClient) if err != nil { logger.Error(err.Error()) exitCode = 1 diff --git a/cmd/computation-runner/main.go b/cmd/computation-runner/main.go index 1ded568b..db124c4f 100644 --- a/cmd/computation-runner/main.go +++ b/cmd/computation-runner/main.go @@ -13,9 +13,12 @@ import ( mglog "github.com/absmach/supermq/logger" "github.com/caarlos0/env/v11" + "github.com/ultravioletrs/cocos/agent/cvms" + logpb "github.com/ultravioletrs/cocos/agent/log" pb "github.com/ultravioletrs/cocos/agent/runner" runnerevents "github.com/ultravioletrs/cocos/agent/runner/events" "github.com/ultravioletrs/cocos/agent/runner/service" + agentlogger "github.com/ultravioletrs/cocos/internal/logger" logclient "github.com/ultravioletrs/cocos/pkg/clients/grpc/log" "golang.org/x/sync/errgroup" "google.golang.org/grpc" @@ -51,16 +54,43 @@ func main() { return } - logger := slog.New(slog.NewJSONHandler(os.Stdout, &slog.HandlerOptions{Level: level})) + logQueue := make(chan *cvms.ClientStreamMessage, 1000) + handler := agentlogger.NewProtoHandler(os.Stdout, &slog.HandlerOptions{Level: level}, logQueue) + logger := slog.New(handler) // Connect to Log Forwarder logClient, err := logclient.NewClient(cfg.LogForwarder) if err != nil { - logger.Warn(fmt.Sprintf("failed to connect to log-forwarder: %s. Events will not be forwarded.", err)) + logger.Warn(fmt.Sprintf("failed to connect to log-forwarder: %s. Logs and events will not be forwarded.", err)) } else { defer logClient.Close() } + g.Go(func() error { + for { + select { + case <-ctx.Done(): + return nil + case msg := <-logQueue: + if logClient == nil { + continue + } + switch m := msg.Message.(type) { + case *cvms.ClientStreamMessage_AgentLog: + err := logClient.SendLog(ctx, &logpb.LogEntry{ + Message: m.AgentLog.Message, + ComputationId: m.AgentLog.ComputationId, + Level: m.AgentLog.Level, + Timestamp: m.AgentLog.Timestamp, + }) + if err != nil { + logger.Error("failed to send log", "error", err) + } + } + } + } + }) + eventSvc := runnerevents.NewAdapter(logClient, svcName) // Remove existing socket if it exists diff --git a/cmd/egress-proxy/main.go b/cmd/egress-proxy/main.go index cd2aec5b..d0d27093 100644 --- a/cmd/egress-proxy/main.go +++ b/cmd/egress-proxy/main.go @@ -5,14 +5,18 @@ package main import ( "context" "fmt" + "log/slog" "os" "os/signal" "syscall" - mglog "github.com/absmach/supermq/logger" "github.com/caarlos0/env/v11" "github.com/spf13/cobra" "github.com/spf13/pflag" + "github.com/ultravioletrs/cocos/agent/cvms" + logpb "github.com/ultravioletrs/cocos/agent/log" + agentlogger "github.com/ultravioletrs/cocos/internal/logger" + logclient "github.com/ultravioletrs/cocos/pkg/clients/grpc/log" "github.com/ultravioletrs/cocos/pkg/egress" "golang.org/x/sync/errgroup" ) @@ -22,8 +26,9 @@ const ( ) type config struct { - Level string `env:"COCOS_LOG_LEVEL" envAlternate:"AGENT_LOG_LEVEL" envDefault:"info"` - Port string `env:"COCOS_PROXY_PORT" envDefault:"3128"` + Level string `env:"COCOS_LOG_LEVEL" envAlternate:"AGENT_LOG_LEVEL" envDefault:"info"` + Port string `env:"COCOS_PROXY_PORT" envDefault:"3128"` + LogForwarder string `env:"LOG_FORWARDER_SOCKET" envDefault:"/run/cocos/log.sock"` } func main() { @@ -51,9 +56,20 @@ func main() { } func run(cfg config) error { - logger, err := mglog.New(os.Stdout, cfg.Level) + var level slog.Level + if err := level.UnmarshalText([]byte(cfg.Level)); err != nil { + return fmt.Errorf("invalid log level: %w", err) + } + + logQueue := make(chan *cvms.ClientStreamMessage, 1000) + handler := agentlogger.NewProtoHandler(os.Stdout, &slog.HandlerOptions{Level: level}, logQueue) + logger := slog.New(handler) + + logClient, err := logclient.NewClient(cfg.LogForwarder) if err != nil { - return fmt.Errorf("failed to create logger: %w", err) + logger.Warn(fmt.Sprintf("failed to connect to log-forwarder: %s. Logs will not be forwarded.", err)) + } else { + defer logClient.Close() } ctx, cancel := context.WithCancel(context.Background()) @@ -61,6 +77,31 @@ func run(cfg config) error { g, ctx := errgroup.WithContext(ctx) + g.Go(func() error { + for { + select { + case <-ctx.Done(): + return nil + case msg := <-logQueue: + if logClient == nil { + continue + } + switch m := msg.Message.(type) { + case *cvms.ClientStreamMessage_AgentLog: + err := logClient.SendLog(ctx, &logpb.LogEntry{ + Message: m.AgentLog.Message, + ComputationId: m.AgentLog.ComputationId, + Level: m.AgentLog.Level, + Timestamp: m.AgentLog.Timestamp, + }) + if err != nil { + logger.Error("failed to send log", "error", err) + } + } + } + } + }) + proxy := egress.NewProxy(logger, ":"+cfg.Port) g.Go(func() error { diff --git a/cmd/ingress-proxy/main.go b/cmd/ingress-proxy/main.go index fcb2a435..180b18a3 100644 --- a/cmd/ingress-proxy/main.go +++ b/cmd/ingress-proxy/main.go @@ -5,20 +5,24 @@ package main import ( "context" "fmt" + "log/slog" "net/url" "os" "os/signal" "syscall" "github.com/absmach/certs/sdk" - mglog "github.com/absmach/supermq/logger" "github.com/caarlos0/env/v11" "github.com/spf13/cobra" "github.com/spf13/pflag" + "github.com/ultravioletrs/cocos/agent/cvms" + logpb "github.com/ultravioletrs/cocos/agent/log" + agentlogger "github.com/ultravioletrs/cocos/internal/logger" "github.com/ultravioletrs/cocos/pkg/atls" "github.com/ultravioletrs/cocos/pkg/attestation" "github.com/ultravioletrs/cocos/pkg/attestation/azure" attestation_client "github.com/ultravioletrs/cocos/pkg/clients/grpc/attestation" + logclient "github.com/ultravioletrs/cocos/pkg/clients/grpc/log" "github.com/ultravioletrs/cocos/pkg/ingress" "golang.org/x/sync/errgroup" ) @@ -39,6 +43,7 @@ type config struct { AgentOSBuild string `env:"AGENT_OS_BUILD" envDefault:"UVC"` AgentOSDistro string `env:"AGENT_OS_DISTRO" envDefault:"UVC"` AgentOSType string `env:"AGENT_OS_TYPE" envDefault:"UVC"` + LogForwarder string `env:"LOG_FORWARDER_SOCKET" envDefault:"/run/cocos/log.sock"` } func main() { @@ -66,11 +71,52 @@ func main() { } func run(cfg config) error { - logger, err := mglog.New(os.Stdout, cfg.LogLevel) - if err != nil { - return fmt.Errorf("failed to create logger: %w", err) + var level slog.Level + if err := level.UnmarshalText([]byte(cfg.LogLevel)); err != nil { + return fmt.Errorf("invalid log level: %w", err) } + logQueue := make(chan *cvms.ClientStreamMessage, 1000) + handler := agentlogger.NewProtoHandler(os.Stdout, &slog.HandlerOptions{Level: level}, logQueue) + logger := slog.New(handler) + + logClient, err := logclient.NewClient(cfg.LogForwarder) + if err != nil { + logger.Warn(fmt.Sprintf("failed to connect to log-forwarder: %s. Logs will not be forwarded.", err)) + } else { + defer logClient.Close() + } + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + g, ctx := errgroup.WithContext(ctx) + + g.Go(func() error { + for { + select { + case <-ctx.Done(): + return nil + case msg := <-logQueue: + if logClient == nil { + continue + } + switch m := msg.Message.(type) { + case *cvms.ClientStreamMessage_AgentLog: + err := logClient.SendLog(ctx, &logpb.LogEntry{ + Message: m.AgentLog.Message, + ComputationId: m.AgentLog.ComputationId, + Level: m.AgentLog.Level, + Timestamp: m.AgentLog.Timestamp, + }) + if err != nil { + logger.Error("failed to send log", "error", err) + } + } + } + } + }) + backendURL, err := url.Parse(cfg.Backend) if err != nil { return fmt.Errorf("failed to parse backend URL: %w", err) @@ -111,11 +157,6 @@ func run(cfg config) error { logger.Warn("No Confidential Computing platform detected. ATLS will not be available.") } - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - g, ctx := errgroup.WithContext(ctx) - // Create proxy server (but don't start it yet - it will be started per-computation) _ = ingress.NewProxyServer(logger, backendURL, certProvider) diff --git a/cmd/manager/main.go b/cmd/manager/main.go index b6185b2e..249875ea 100644 --- a/cmd/manager/main.go +++ b/cmd/manager/main.go @@ -16,6 +16,7 @@ import ( "github.com/absmach/supermq/pkg/jaeger" "github.com/absmach/supermq/pkg/prometheus" smqserver "github.com/absmach/supermq/pkg/server" + grpcserver "github.com/absmach/supermq/pkg/server/grpc" httpserver "github.com/absmach/supermq/pkg/server/http" "github.com/absmach/supermq/pkg/uuid" "github.com/caarlos0/env/v11" @@ -26,8 +27,6 @@ import ( "github.com/ultravioletrs/cocos/manager/api/http" "github.com/ultravioletrs/cocos/manager/qemu" "github.com/ultravioletrs/cocos/manager/tracing" - "github.com/ultravioletrs/cocos/pkg/server" - grpcserver "github.com/ultravioletrs/cocos/pkg/server/grpc" "go.opentelemetry.io/otel/trace" "golang.org/x/sync/errgroup" "google.golang.org/grpc" @@ -113,7 +112,7 @@ func main() { args := qemuCfg.ConstructQemuArgs() logger.Info(strings.Join(args, " ")) - managerGRPCConfig := server.ServerConfig{} + managerGRPCConfig := smqserver.Config{} if err := env.ParseWithOptions(&managerGRPCConfig, env.Options{Prefix: envPrefixGRPC}); err != nil { logger.Error(fmt.Sprintf("failed to load %s gRPC client configuration : %s", svcName, err)) exitCode = 1 @@ -145,7 +144,7 @@ func main() { manager.RegisterManagerServiceServer(srv, managergrpc.NewServer(svc)) } - gs := grpcserver.New(ctx, cancel, svcName, managerGRPCConfig, registerManagerServiceServer, logger, nil, nil) + gs := grpcserver.NewServer(ctx, cancel, svcName, managerGRPCConfig, registerManagerServiceServer, logger) hs := httpserver.NewServer(ctx, cancel, svcName, httpServerConfig, http.MakeHandler(chi.NewMux(), svcName, cfg.InstanceID), logger) @@ -158,7 +157,7 @@ func main() { }) g.Go(func() error { - return server.StopHandler(ctx, cancel, logger, svcName, gs, hs) + return smqserver.StopSignalHandler(ctx, cancel, logger, svcName, gs, hs) }) if err := g.Wait(); err != nil { diff --git a/encryption.key b/encryption.key deleted file mode 100644 index 840cae66..00000000 --- a/encryption.key +++ /dev/null @@ -1 +0,0 @@ -bbf3a1198ee889f77a227fe01e329864fd6a37a2d23135ea8e2c5a2ebc07f0d3 diff --git a/hal/linux/package/agent/agent.mk b/hal/linux/package/agent/agent.mk index e4d8ab53..66a9eb87 100644 --- a/hal/linux/package/agent/agent.mk +++ b/hal/linux/package/agent/agent.mk @@ -4,8 +4,8 @@ # ################################################################################ -AGENT_VERSION = 913bbccf3a22053e1979da004c732007336fc890 -AGENT_SITE = $(call github,sammyoina,cocos-ai,$(AGENT_VERSION)) +AGENT_VERSION = main +AGENT_SITE = $(call github,ultravioletrs,cocos,$(AGENT_VERSION)) define AGENT_BUILD_CMDS $(MAKE) -C $(@D) agent EMBED_ENABLED=$(AGENT_EMBED_ENABLED) diff --git a/hal/linux/package/attestation-service/attestation-service.mk b/hal/linux/package/attestation-service/attestation-service.mk index 70983079..aae86836 100644 --- a/hal/linux/package/attestation-service/attestation-service.mk +++ b/hal/linux/package/attestation-service/attestation-service.mk @@ -4,8 +4,8 @@ # ################################################################################ -ATTESTATION_SERVICE_VERSION = 913bbccf3a22053e1979da004c732007336fc890 -ATTESTATION_SERVICE_SITE = $(call github,sammyoina,cocos-ai,$(ATTESTATION_SERVICE_VERSION)) +ATTESTATION_SERVICE_VERSION = main +ATTESTATION_SERVICE_SITE = $(call github,ultravioletrs,cocos,$(ATTESTATION_SERVICE_VERSION)) define ATTESTATION_SERVICE_BUILD_CMDS $(MAKE) -C $(@D) attestation-service diff --git a/hal/linux/package/computation-runner/computation-runner.mk b/hal/linux/package/computation-runner/computation-runner.mk index d1a40319..1850695a 100644 --- a/hal/linux/package/computation-runner/computation-runner.mk +++ b/hal/linux/package/computation-runner/computation-runner.mk @@ -4,8 +4,8 @@ # ################################################################################ -COMPUTATION_RUNNER_VERSION = 913bbccf3a22053e1979da004c732007336fc890 -COMPUTATION_RUNNER_SITE = $(call github,sammyoina,cocos-ai,$(COMPUTATION_RUNNER_VERSION)) +COMPUTATION_RUNNER_VERSION = main +COMPUTATION_RUNNER_SITE = $(call github,ultravioletrs,cocos,$(COMPUTATION_RUNNER_VERSION)) define COMPUTATION_RUNNER_BUILD_CMDS $(MAKE) -C $(@D) computation-runner diff --git a/hal/linux/package/egress-proxy/egress-proxy.mk b/hal/linux/package/egress-proxy/egress-proxy.mk index 34cab06c..eb5e5866 100644 --- a/hal/linux/package/egress-proxy/egress-proxy.mk +++ b/hal/linux/package/egress-proxy/egress-proxy.mk @@ -4,8 +4,8 @@ # ################################################################################ -EGRESS_PROXY_VERSION = 913bbccf3a22053e1979da004c732007336fc890 -EGRESS_PROXY_SITE = $(call github,sammyoina,cocos-ai,$(EGRESS_PROXY_VERSION)) +EGRESS_PROXY_VERSION = main +EGRESS_PROXY_SITE = $(call github,ultravioletrs,cocos,$(EGRESS_PROXY_VERSION)) define EGRESS_PROXY_BUILD_CMDS $(MAKE) -C $(@D) egress-proxy diff --git a/hal/linux/package/ingress-proxy/ingress-proxy.mk b/hal/linux/package/ingress-proxy/ingress-proxy.mk index b9604106..e407a572 100644 --- a/hal/linux/package/ingress-proxy/ingress-proxy.mk +++ b/hal/linux/package/ingress-proxy/ingress-proxy.mk @@ -4,8 +4,8 @@ # ################################################################################ -INGRESS_PROXY_VERSION = 913bbccf3a22053e1979da004c732007336fc890 -INGRESS_PROXY_SITE = $(call github,sammyoina,cocos-ai,$(INGRESS_PROXY_VERSION)) +INGRESS_PROXY_VERSION = main +INGRESS_PROXY_SITE = $(call github,ultravioletrs,cocos,$(INGRESS_PROXY_VERSION)) define INGRESS_PROXY_BUILD_CMDS $(MAKE) -C $(@D) ingress-proxy diff --git a/hal/linux/package/log-forwarder/log-forwarder.mk b/hal/linux/package/log-forwarder/log-forwarder.mk index 0c2af471..47545880 100644 --- a/hal/linux/package/log-forwarder/log-forwarder.mk +++ b/hal/linux/package/log-forwarder/log-forwarder.mk @@ -4,8 +4,8 @@ # ################################################################################ -LOG_FORWARDER_VERSION = 913bbccf3a22053e1979da004c732007336fc890 -LOG_FORWARDER_SITE = $(call github,sammyoina,cocos-ai,$(LOG_FORWARDER_VERSION)) +LOG_FORWARDER_VERSION = main +LOG_FORWARDER_SITE = $(call github,ultravioletrs,cocos,$(LOG_FORWARDER_VERSION)) define LOG_FORWARDER_BUILD_CMDS $(MAKE) -C $(@D) log-forwarder diff --git a/init/systemd/cocos-agent.service b/init/systemd/cocos-agent.service index 602cf681..ef74eaf3 100644 --- a/init/systemd/cocos-agent.service +++ b/init/systemd/cocos-agent.service @@ -9,7 +9,6 @@ WorkingDirectory=/cocos Environment="HTTP_PROXY=http://localhost:3128" Environment="HTTPS_PROXY=http://localhost:3128" Environment="NO_PROXY=localhost,127.0.0.1,.local,/run/cocos/" -Environment="AGENT_ENABLE_ATLS=false" StandardOutput=file:/var/log/cocos/agent.stdout StandardError=file:/var/log/cocos/agent.stderr EnvironmentFile=/etc/cocos/environment diff --git a/kbs-admin.key b/kbs-admin.key deleted file mode 100644 index 62418798..00000000 --- a/kbs-admin.key +++ /dev/null @@ -1,3 +0,0 @@ ------BEGIN PRIVATE KEY----- -MC4CAQAwBQYDK2VwBCIEICgJcXfNueGCu8jFFNGBXm9r25OGBEc0OEqCUVjyI4fY ------END PRIVATE KEY----- diff --git a/kbs-admin.pub b/kbs-admin.pub deleted file mode 100644 index 8fc02406..00000000 --- a/kbs-admin.pub +++ /dev/null @@ -1,3 +0,0 @@ ------BEGIN PUBLIC KEY----- -MCowBQYDK2VwAyEAPbPOfwsJkxpNBluGOg/lgNVE/o0AEM7J11wvkXvHXSw= ------END PUBLIC KEY----- diff --git a/pkg/atls/certificate_verifier.go b/pkg/atls/certificate_verifier.go index 4eb07f0e..49928c4b 100644 --- a/pkg/atls/certificate_verifier.go +++ b/pkg/atls/certificate_verifier.go @@ -6,6 +6,7 @@ import ( "crypto/x509" "encoding/asn1" "fmt" + "log/slog" "os" "time" @@ -36,20 +37,35 @@ func NewCertificateVerifier(rootCAs *x509.CertPool) CertificateVerifier { } func (v *certificateVerifier) VerifyPeerCertificate(rawCerts [][]byte, _ [][]*x509.Certificate, nonce []byte) error { + slog.Debug("Starting peer certificate verification for aTLS") if len(rawCerts) == 0 { - return fmt.Errorf("no certificates provided") + err := fmt.Errorf("no certificates provided") + slog.Error("aTLS handshake failed", "reason", err.Error()) + return err } cert, err := x509.ParseCertificate(rawCerts[0]) if err != nil { - return fmt.Errorf("failed to parse x509 certificate: %w", err) + err = fmt.Errorf("failed to parse x509 certificate: %w", err) + slog.Error("aTLS handshake failed", "reason", err.Error()) + return err } + slog.Debug("Successfully parsed peer x509 certificate", "subject", cert.Subject.String()) if err := v.verifyCertificateSignature(cert); err != nil { - return fmt.Errorf("certificate signature verification failed: %w", err) + err = fmt.Errorf("certificate signature verification failed: %w", err) + slog.Error("aTLS handshake failed", "reason", err.Error()) + return err } + slog.Debug("Successfully verified peer certificate signature") - return v.verifyAttestationExtension(cert, nonce) + err = v.verifyAttestationExtension(cert, nonce) + if err != nil { + slog.Error("aTLS handshake failed", "reason", err.Error()) + return err + } + slog.Debug("Successfully verified aTLS attestation extension") + return nil } func (v *certificateVerifier) verifyCertificateSignature(cert *x509.Certificate) error { @@ -71,6 +87,7 @@ func (v *certificateVerifier) verifyCertificateSignature(cert *x509.Certificate) func (v *certificateVerifier) verifyAttestationExtension(cert *x509.Certificate, nonce []byte) error { for _, ext := range cert.Extensions { if platformType, err := platformTypeFromOID(ext.Id); err == nil { + slog.Debug("Found attestation extension in peer certificate", "platform_type", platformType) pubKeyDER, err := x509.MarshalPKIXPublicKey(cert.PublicKey) if err != nil { return fmt.Errorf("failed to marshal public key: %w", err) @@ -93,22 +110,29 @@ func (v *certificateVerifier) verifyCertificateExtension(extension []byte, pubKe // Verify nonce matches teeNonce := append(pubKey, nonce...) hashNonce := sha3.Sum512(teeNonce) + // The attestation provider truncates the 64-byte hash to 32 bytes for the EAT token nonce claim + // This matches the Attestation Service API and standard cryptographic nonce sizes. + expectedNonce := hashNonce[:32] // Compare nonces (EAT nonce should match our computed nonce) - if len(claims.Nonce) != len(hashNonce) { - return fmt.Errorf("nonce length mismatch: expected %d, got %d", len(hashNonce), len(claims.Nonce)) + if len(claims.Nonce) != len(expectedNonce) { + err := fmt.Errorf("nonce length mismatch: expected %d, got %d", len(expectedNonce), len(claims.Nonce)) + slog.Error("aTLS handshake failed", "reason", err.Error()) + return err } nonceMatch := true for i := range claims.Nonce { - if claims.Nonce[i] != hashNonce[i] { + if claims.Nonce[i] != expectedNonce[i] { nonceMatch = false break } } if !nonceMatch { - return fmt.Errorf("nonce mismatch in EAT token") + err := fmt.Errorf("nonce mismatch in EAT token") + slog.Error("aTLS handshake failed", "reason", err.Error()) + return err } // Get platform verifier @@ -151,6 +175,7 @@ func (v *certificateVerifier) verifyCertificateExtension(extension []byte, pubKe return fmt.Errorf("failed to verify attestation with CoRIM: %w", err) } + slog.Debug("CoRIM verification passed for aTLS peer certificate") return nil } diff --git a/pkg/atls/certificate_verifier_test.go b/pkg/atls/certificate_verifier_test.go index e4cc374b..cdf7ed4c 100644 --- a/pkg/atls/certificate_verifier_test.go +++ b/pkg/atls/certificate_verifier_test.go @@ -79,7 +79,7 @@ func TestVerifyPeerCertificate_Success(t *testing.T) { hashNonce := sha3.Sum512(teeNonce) claims := eat.EATClaims{ - Nonce: hashNonce[:], + Nonce: hashNonce[:32], RawReport: []byte("mock-report"), } eatBytes, err := cbor.Marshal(claims) @@ -155,7 +155,7 @@ func TestVerifyPeerCertificate_AzureSuccess(t *testing.T) { teeNonce := append(peerPubKeyDER, nonce...) hashNonce := sha3.Sum512(teeNonce) - claims := eat.EATClaims{Nonce: hashNonce[:], RawReport: []byte("rep")} + claims := eat.EATClaims{Nonce: hashNonce[:32], RawReport: []byte("rep")} eatBytes, _ := cbor.Marshal(claims) peerTemplate := &x509.Certificate{ @@ -210,7 +210,7 @@ func TestVerifyPeerCertificate_TDXSuccess(t *testing.T) { teeNonce := append(peerPubKeyDER, nonce...) hashNonce := sha3.Sum512(teeNonce) - claims := eat.EATClaims{Nonce: hashNonce[:], RawReport: []byte("rep")} + claims := eat.EATClaims{Nonce: hashNonce[:32], RawReport: []byte("rep")} eatBytes, _ := cbor.Marshal(claims) peerTemplate := &x509.Certificate{ @@ -273,7 +273,7 @@ func TestVerifyPeerCertificate_Failures_More(t *testing.T) { nonce := []byte("nonce") teeNonce := append(peerPubKeyDER, nonce...) hashNonce := sha3.Sum512(teeNonce) - claims := eat.EATClaims{Nonce: hashNonce[:], RawReport: []byte("rep")} + claims := eat.EATClaims{Nonce: hashNonce[:32], RawReport: []byte("rep")} eatBytes, _ := cbor.Marshal(claims) peerTemplate.ExtraExtensions = []pkix.Extension{{Id: SNPvTPMOID, Value: eatBytes}} certDERWithExt, _ := x509.CreateCertificate(rand.Reader, peerTemplate, caCert, &peerKey.PublicKey, caKey) @@ -323,7 +323,7 @@ func TestVerifyPeerCertificate_Failures_Ext(t *testing.T) { peerPubKeyDER, _ := x509.MarshalPKIXPublicKey(&peerKey.PublicKey) wrongTeeNonce := append(peerPubKeyDER, []byte("wrong-nonce")...) wrongHashNonce := sha3.Sum512(wrongTeeNonce) - claims.Nonce = wrongHashNonce[:] + claims.Nonce = wrongHashNonce[:32] eatBytes, _ = cbor.Marshal(claims) peerTemplate.ExtraExtensions = []pkix.Extension{{Id: SNPvTPMOID, Value: eatBytes}} certDER, _ = x509.CreateCertificate(rand.Reader, peerTemplate, caCert, &peerKey.PublicKey, caKey) diff --git a/pkg/clients/grpc/log/client.go b/pkg/clients/grpc/log/client.go index 1dc01d6b..a6c49e37 100644 --- a/pkg/clients/grpc/log/client.go +++ b/pkg/clients/grpc/log/client.go @@ -45,7 +45,7 @@ func (c *client) SendLog(ctx context.Context, entry *log.LogEntry) error { } // Retry with exponential backoff for concurrent request handling - maxRetries := 3 + maxRetries := 10 for attempt := 0; attempt < maxRetries; attempt++ { ctx, cancel := context.WithTimeout(ctx, 5*time.Second) _, err := c.client.SendLog(ctx, entry) @@ -57,8 +57,11 @@ func (c *client) SendLog(ctx context.Context, entry *log.LogEntry) error { // Don't retry on last attempt if attempt < maxRetries-1 { - // Exponential backoff: 10ms, 20ms, 40ms - backoff := time.Duration(10*(1< 2*time.Second { + backoff = 2 * time.Second + } time.Sleep(backoff) } } @@ -76,7 +79,7 @@ func (c *client) SendEvent(ctx context.Context, entry *log.EventEntry) error { } // Retry with exponential backoff for concurrent request handling - maxRetries := 3 + maxRetries := 10 for attempt := 0; attempt < maxRetries; attempt++ { ctx, cancel := context.WithTimeout(ctx, 5*time.Second) _, err := c.client.SendEvent(ctx, entry) @@ -88,8 +91,11 @@ func (c *client) SendEvent(ctx context.Context, entry *log.EventEntry) error { // Don't retry on last attempt if attempt < maxRetries-1 { - // Exponential backoff: 10ms, 20ms, 40ms - backoff := time.Duration(10*(1< 2*time.Second { + backoff = 2 * time.Second + } time.Sleep(backoff) } } diff --git a/pkg/clients/grpc/log/client_test.go b/pkg/clients/grpc/log/client_test.go index a1d932f7..3ad4d380 100644 --- a/pkg/clients/grpc/log/client_test.go +++ b/pkg/clients/grpc/log/client_test.go @@ -488,8 +488,8 @@ func TestClientSendLogAllRetriesFail(t *testing.T) { grpcServer := grpc.NewServer() mockServer := &retryMockLogCollectorServer{ - failCount: 10, // Fail all attempts - maxFailCount: 10, + failCount: 20, // Fail all attempts + maxFailCount: 20, } log.RegisterLogCollectorServer(grpcServer, mockServer) @@ -513,8 +513,8 @@ func TestClientSendLogAllRetriesFail(t *testing.T) { // Should fail after all retries err = client.SendLog(ctx, entry) assert.Error(t, err) - // 3 retries + 1 final attempt = 4 calls - assert.Equal(t, 4, mockServer.callCount) + // 10 retries + 1 final attempt = 11 calls + assert.Equal(t, 11, mockServer.callCount) } func TestClientSendEventAllRetriesFail(t *testing.T) { @@ -527,8 +527,8 @@ func TestClientSendEventAllRetriesFail(t *testing.T) { grpcServer := grpc.NewServer() mockServer := &retryMockLogCollectorServer{ - failCount: 10, - maxFailCount: 10, + failCount: 20, + maxFailCount: 20, } log.RegisterLogCollectorServer(grpcServer, mockServer) @@ -551,8 +551,8 @@ func TestClientSendEventAllRetriesFail(t *testing.T) { // Should fail after all retries err = client.SendEvent(ctx, entry) assert.Error(t, err) - // 3 retries + 1 final attempt = 4 calls - assert.Equal(t, 4, mockServer.eventCallCount) + // 10 retries + 1 final attempt = 11 calls + assert.Equal(t, 11, mockServer.eventCallCount) } // retryMockLogCollectorServer is a mock server that fails a specified number of times. diff --git a/pkg/ingress/proxy.go b/pkg/ingress/proxy.go index fb0deeef..678fa8bb 100644 --- a/pkg/ingress/proxy.go +++ b/pkg/ingress/proxy.go @@ -14,7 +14,6 @@ import ( "sync" "github.com/ultravioletrs/cocos/pkg/atls" - "github.com/ultravioletrs/cocos/pkg/server" "golang.org/x/net/http2" "golang.org/x/net/http2/h2c" ) @@ -138,7 +137,7 @@ func (p *proxyServer) Start(cfg ProxyConfig, ctx ProxyContext) error { if cfg.AttestedTLS { if p.certProvider == nil { - return fmt.Errorf("attested TLS requested but no certificate provider available") + return fmt.Errorf("attested TLS requested for ingress proxy but no certificate provider available. Please ensure a CC platform is detected (not NoCC), aTLS is enabled, and the attestation service is running") } tlsConfig = &tls.Config{ GetCertificate: p.certProvider.GetCertificate, @@ -146,7 +145,7 @@ func (p *proxyServer) Start(cfg ProxyConfig, ctx ProxyContext) error { NextProtos: []string{"h2", "http/1.1"}, } - mtls, err := server.ConfigureCertificateAuthorities(tlsConfig, cfg.ServerCAFile, cfg.ClientCAFile) + mtls, err := ConfigureCertificateAuthorities(tlsConfig, cfg.ServerCAFile, cfg.ClientCAFile) if err != nil { return fmt.Errorf("failed to configure certificate authorities: %w", err) } @@ -159,7 +158,7 @@ func (p *proxyServer) Start(cfg ProxyConfig, ctx ProxyContext) error { } } else if cfg.CertFile != "" && cfg.KeyFile != "" { // Regular TLS - tlsSetup, err := server.SetupRegularTLS(cfg.CertFile, cfg.KeyFile, cfg.ServerCAFile, cfg.ClientCAFile) + tlsSetup, err := SetupRegularTLS(cfg.CertFile, cfg.KeyFile, cfg.ServerCAFile, cfg.ClientCAFile) if err != nil { return fmt.Errorf("failed to setup TLS: %w", err) } diff --git a/pkg/ingress/proxy_test.go b/pkg/ingress/proxy_test.go index 7c2eda5f..4b4f84f7 100644 --- a/pkg/ingress/proxy_test.go +++ b/pkg/ingress/proxy_test.go @@ -402,7 +402,7 @@ func TestProxyAttestedTLSMissingProvider(t *testing.T) { err := ps.Start(cfg, ctx) assert.Error(t, err) - assert.Equal(t, "attested TLS requested but no certificate provider available", err.Error()) + assert.Equal(t, "attested TLS requested for ingress proxy but no certificate provider available. Please ensure a CC platform is detected (not NoCC), aTLS is enabled, and the attestation service is running", err.Error()) } func TestProxyMTLS(t *testing.T) { diff --git a/pkg/server/tlsutil.go b/pkg/ingress/tls.go similarity index 90% rename from pkg/server/tlsutil.go rename to pkg/ingress/tls.go index c7dbb7d2..e4424502 100644 --- a/pkg/server/tlsutil.go +++ b/pkg/ingress/tls.go @@ -1,7 +1,7 @@ // Copyright (c) Ultraviolet // SPDX-License-Identifier: Apache-2.0 -package server +package ingress import ( "crypto/tls" @@ -144,18 +144,3 @@ func SetupRegularTLS(certFile, keyFile, serverCAFile, clientCAFile string) (*TLS return &TLSSetupResult{Config: tlsConfig, MTLS: mtls}, nil } - -// BuildMTLSDescription builds a description string for mTLS configuration. -func BuildMTLSDescription(serverCAFile, clientCAFile string) string { - var parts []string - - if serverCAFile != "" { - parts = append(parts, fmt.Sprintf("root ca %s", serverCAFile)) - } - - if clientCAFile != "" { - parts = append(parts, fmt.Sprintf("client ca %s", clientCAFile)) - } - - return strings.Join(parts, " ") -} diff --git a/pkg/server/doc.go b/pkg/server/doc.go deleted file mode 100644 index 4c14fc39..00000000 --- a/pkg/server/doc.go +++ /dev/null @@ -1,5 +0,0 @@ -// Copyright (c) Ultraviolet -// SPDX-License-Identifier: Apache-2.0 - -// Package server contains the gRPC server implementation. -package server diff --git a/pkg/server/grpc/doc.go b/pkg/server/grpc/doc.go deleted file mode 100644 index 323efa1b..00000000 --- a/pkg/server/grpc/doc.go +++ /dev/null @@ -1,5 +0,0 @@ -// Copyright (c) Ultraviolet -// SPDX-License-Identifier: Apache-2.0 - -// Package grpc contains the gRPC server implementation. -package grpc diff --git a/pkg/server/grpc/grpc.go b/pkg/server/grpc/grpc.go deleted file mode 100644 index f6d92018..00000000 --- a/pkg/server/grpc/grpc.go +++ /dev/null @@ -1,265 +0,0 @@ -// Copyright (c) Ultraviolet -// SPDX-License-Identifier: Apache-2.0 - -package grpc - -import ( - "context" - "crypto/tls" - "fmt" - "log/slog" - "net" - "os" - "sync" - "time" - - agentgrpc "github.com/ultravioletrs/cocos/agent/api/grpc" - "github.com/ultravioletrs/cocos/agent/auth" - "github.com/ultravioletrs/cocos/pkg/atls" - "github.com/ultravioletrs/cocos/pkg/server" - "go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc" - "google.golang.org/grpc" - "google.golang.org/grpc/credentials" - "google.golang.org/grpc/credentials/insecure" - "google.golang.org/grpc/health" - grpchealth "google.golang.org/grpc/health/grpc_health_v1" -) - -const ( - stopWaitTime = 5 * time.Second -) - -type Server struct { - server.BaseServer - mu sync.RWMutex - server *grpc.Server - health *health.Server - registerService serviceRegister - authSvc auth.Authenticator - certProvider atls.CertificateProvider - attestedTLSEnabled bool - started bool - stopped bool -} - -type serviceRegister func(srv *grpc.Server) - -var _ server.Server = (*Server)(nil) - -func New( - ctx context.Context, cancel context.CancelFunc, name string, config server.ServerConfiguration, - registerService serviceRegister, logger *slog.Logger, authSvc auth.Authenticator, certProvider atls.CertificateProvider, -) server.Server { - base := config.GetBaseConfig() - listenFullAddress := fmt.Sprintf("%s:%s", base.Host, base.Port) - - var attestedTLS bool - - if agentConfig, ok := config.(server.AgentConfig); ok && agentConfig.AttestedTLS { - if certProvider == nil { - logger.Error("Failed to create certificate provider") - } else { - attestedTLS = true - } - } - - return &Server{ - BaseServer: server.BaseServer{ - Ctx: ctx, - Cancel: cancel, - Name: name, - Address: listenFullAddress, - Config: config, - Logger: logger, - }, - registerService: registerService, - authSvc: authSvc, - certProvider: certProvider, - attestedTLSEnabled: attestedTLS, - } -} - -func (s *Server) Start() error { - s.mu.Lock() - if s.started { - s.mu.Unlock() - return fmt.Errorf("server already started") - } - if s.stopped { - s.mu.Unlock() - return fmt.Errorf("server already stopped") - } - s.started = true - s.mu.Unlock() - - errCh := make(chan error) - grpcServerOptions := []grpc.ServerOption{ - grpc.StatsHandler(otelgrpc.NewServerHandler()), - } - - // Add authentication interceptors if auth service is available - if s.authSvc != nil { - unary, stream := agentgrpc.NewAuthInterceptor(s.authSvc) - grpcServerOptions = append(grpcServerOptions, grpc.UnaryInterceptor(unary)) - grpcServerOptions = append(grpcServerOptions, grpc.StreamInterceptor(stream)) - } - - // Configure credentials - creds, err := s.configureCredentials() - if err != nil { - return fmt.Errorf("failed to configure credentials: %w", err) - } - - grpcServerOptions = append(grpcServerOptions, creds) - - // Create listener - detect Unix socket vs TCP - var listener net.Listener - baseConfig := s.Config.GetBaseConfig() - - // Check if this is a Unix socket path (starts with /) - if len(baseConfig.Host) > 0 && baseConfig.Host[0] == '/' { - // Unix socket - // Remove existing socket file if it exists - _ = os.Remove(baseConfig.Host) - - listener, err = net.Listen("unix", baseConfig.Host) - if err != nil { - return fmt.Errorf("failed to listen on Unix socket %s: %w", baseConfig.Host, err) - } - } else { - // TCP socket - listener, err = net.Listen("tcp", s.Address) - if err != nil { - return fmt.Errorf("failed to listen on port %s: %w", s.Address, err) - } - } - - // Create and configure server - s.mu.Lock() - s.server = grpc.NewServer(grpcServerOptions...) - s.health = health.NewServer() - grpchealth.RegisterHealthServer(s.server, s.health) - s.registerService(s.server) - s.health.SetServingStatus(s.Name, grpchealth.HealthCheckResponse_SERVING) - s.mu.Unlock() - - // Start server - go func() { - s.mu.RLock() - server := s.server - s.mu.RUnlock() - - if server != nil { - errCh <- server.Serve(listener) - } - }() - - select { - case <-s.Ctx.Done(): - return s.Stop() - case err := <-errCh: - s.Cancel() - return err - } -} - -func (s *Server) configureCredentials() (grpc.ServerOption, error) { - baseConfig := s.Config.GetBaseConfig() - - // Check if attested TLS should be used - if s.shouldUseAttestedTLS() { - return s.configureAttestedTLS(baseConfig.Config) - } - - // Check if regular TLS should be used - if s.shouldUseRegularTLS(baseConfig.Config) { - return s.configureRegularTLS(baseConfig.Config) - } - - // Use insecure credentials - // Determine address for logging - addr := s.Address - if len(baseConfig.Host) > 0 && baseConfig.Host[0] == '/' { - addr = baseConfig.Host - } - s.Logger.Info(fmt.Sprintf("%s service gRPC server listening at %s without TLS", s.Name, addr)) - return grpc.Creds(insecure.NewCredentials()), nil -} - -func (s *Server) shouldUseAttestedTLS() bool { - return s.attestedTLSEnabled && s.certProvider != nil -} - -func (s *Server) shouldUseRegularTLS(config server.Config) bool { - return config.CertFile != "" || config.KeyFile != "" -} - -func (s *Server) configureAttestedTLS(config server.Config) (grpc.ServerOption, error) { - tlsConfig := &tls.Config{ - ClientAuth: tls.NoClientCert, - GetCertificate: s.certProvider.GetCertificate, - } - - mtls, err := server.ConfigureCertificateAuthorities(tlsConfig, config.ServerCAFile, config.ClientCAFile) - if err != nil { - return nil, fmt.Errorf("failed to configure certificate authorities: %w", err) - } - - if mtls { - tlsConfig.ClientAuth = tls.RequireAndVerifyClientCert - s.Logger.Info(fmt.Sprintf("%s service gRPC server listening at %s with Attested mTLS", s.Name, s.Address)) - } else { - s.Logger.Info(fmt.Sprintf("%s service gRPC server listening at %s with Attested TLS", s.Name, s.Address)) - } - - return grpc.Creds(credentials.NewTLS(tlsConfig)), nil -} - -func (s *Server) configureRegularTLS(config server.Config) (grpc.ServerOption, error) { - tlsSetup, err := server.SetupRegularTLS(config.CertFile, config.KeyFile, config.ServerCAFile, config.ClientCAFile) - if err != nil { - return nil, fmt.Errorf("failed to setup TLS: %w", err) - } - - if tlsSetup.MTLS { - mtlsCA := server.BuildMTLSDescription(config.ServerCAFile, config.ClientCAFile) - s.Logger.Info(fmt.Sprintf("%s service gRPC server listening at %s with TLS/mTLS cert %s , key %s and %s", - s.Name, s.Address, config.CertFile, config.KeyFile, mtlsCA)) - } else { - s.Logger.Info(fmt.Sprintf("%s service gRPC server listening at %s with TLS cert %s and key %s", - s.Name, s.Address, config.CertFile, config.KeyFile)) - } - - return grpc.Creds(credentials.NewTLS(tlsSetup.Config)), nil -} - -func (s *Server) Stop() error { - s.mu.Lock() - defer s.mu.Unlock() - - if s.stopped { - return nil - } - s.stopped = true - - defer s.Cancel() - - c := make(chan bool) - go func() { - defer close(c) - if s.health != nil { - s.health.Shutdown() - } - if s.server != nil { - s.server.GracefulStop() - } - }() - - select { - case <-c: - case <-time.After(stopWaitTime): - } - - s.Logger.Info(fmt.Sprintf("%s gRPC service shutdown at %s", s.Name, s.Address)) - return nil -} diff --git a/pkg/server/grpc/grpc_test.go b/pkg/server/grpc/grpc_test.go deleted file mode 100644 index e5e96d98..00000000 --- a/pkg/server/grpc/grpc_test.go +++ /dev/null @@ -1,527 +0,0 @@ -// Copyright (c) Ultraviolet -// SPDX-License-Identifier: Apache-2.0 -package grpc - -import ( - "context" - "crypto/rand" - "crypto/rsa" - "crypto/x509" - "crypto/x509/pkix" - "encoding/pem" - "fmt" - "log/slog" - "math/big" - "os" - "strings" - "sync" - "testing" - "time" - - "github.com/stretchr/testify/assert" - authmocks "github.com/ultravioletrs/cocos/agent/auth/mocks" - "github.com/ultravioletrs/cocos/pkg/atls/mocks" - "github.com/ultravioletrs/cocos/pkg/attestation/vtpm" - "github.com/ultravioletrs/cocos/pkg/server" - "google.golang.org/grpc" - "google.golang.org/grpc/test/bufconn" -) - -const bufSize = 1024 * 1024 - -var lis *bufconn.Listener - -func init() { - lis = bufconn.Listen(bufSize) -} - -func TestNew(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - config := server.AgentConfig{ - ServerConfig: server.ServerConfig{ - Config: server.Config{ - Host: "localhost", - Port: "50051", - }, - }, - } - logger := slog.Default() - authSvc := new(authmocks.Authenticator) - - srv := New(ctx, cancel, "TestServer", config, func(srv *grpc.Server) {}, logger, authSvc, nil) - - assert.NotNil(t, srv) - assert.IsType(t, &Server{}, srv) -} - -func TestServerStartWithTLSFile(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) - - cert, key, err := generateSelfSignedCert() - assert.NoError(t, err) - - certFile, err := os.CreateTemp("", "cert*.pem") - assert.NoError(t, err) - - keyFile, err := os.CreateTemp("", "key*.pem") - assert.NoError(t, err) - - t.Cleanup(func() { - os.Remove(certFile.Name()) - os.Remove(keyFile.Name()) - }) - - _, err = certFile.Write(cert) - assert.NoError(t, err) - - _, err = keyFile.Write(key) - assert.NoError(t, err) - - err = certFile.Close() - assert.NoError(t, err) - err = keyFile.Close() - assert.NoError(t, err) - - config := server.AgentConfig{ - ServerConfig: server.ServerConfig{ - Config: server.Config{ - Host: "localhost", - Port: "0", - CertFile: certFile.Name(), - KeyFile: keyFile.Name(), - }, - }, - } - - logBuffer := &ThreadSafeBuffer{} - logger := slog.New(slog.NewTextHandler(logBuffer, &slog.HandlerOptions{Level: slog.LevelDebug})) - authSvc := new(authmocks.Authenticator) - - srv := New(ctx, cancel, "TestServer", config, func(srv *grpc.Server) {}, logger, authSvc, nil) - - var wg sync.WaitGroup - wg.Add(1) - - go func() { - wg.Done() - err := srv.Start() - assert.NoError(t, err) - }() - - wg.Wait() - - time.Sleep(200 * time.Millisecond) - - cancel() - - time.Sleep(200 * time.Millisecond) - - logContent := logBuffer.String() - fmt.Println(logContent) - assert.Contains(t, logContent, "TestServer service gRPC server listening at localhost:0 with TLS") -} - -func TestServerStartWithmTLSFile(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) - - caCertFile, clientCertFile, clientKeyFile, err := createCertificatesFiles() - assert.NoError(t, err) - - config := server.AgentConfig{ - ServerConfig: server.ServerConfig{ - Config: server.Config{ - Host: "localhost", - Port: "0", - CertFile: string(clientCertFile), - KeyFile: string(clientKeyFile), - ServerCAFile: caCertFile, - }, - }, - } - - logBuffer := &ThreadSafeBuffer{} - logger := slog.New(slog.NewTextHandler(logBuffer, &slog.HandlerOptions{Level: slog.LevelDebug})) - authSvc := new(authmocks.Authenticator) - - srv := New(ctx, cancel, "TestServer", config, func(srv *grpc.Server) {}, logger, authSvc, nil) - - var wg sync.WaitGroup - wg.Add(1) - - go func() { - wg.Done() - err := srv.Start() - assert.NoError(t, err) - }() - - wg.Wait() - - time.Sleep(200 * time.Millisecond) - - cancel() - - time.Sleep(200 * time.Millisecond) - - logContent := logBuffer.String() - fmt.Println(logContent) - assert.Contains(t, logContent, "TestServer service gRPC server listening at localhost:0 with TLS") -} - -func TestServerStop(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) - - config := server.AgentConfig{ - ServerConfig: server.ServerConfig{ - Config: server.Config{ - Host: "localhost", - Port: "0", - }, - }, - } - buf := &ThreadSafeBuffer{} - logger := slog.New(slog.NewTextHandler(buf, &slog.HandlerOptions{Level: slog.LevelDebug})) - authSvc := new(authmocks.Authenticator) - - srv := New(ctx, cancel, "TestServer", config, func(srv *grpc.Server) {}, logger, authSvc, nil) - - go func() { - err := srv.Start() - assert.NoError(t, err) - }() - - time.Sleep(100 * time.Millisecond) - - cancel() - - time.Sleep(100 * time.Millisecond) - - err := srv.Stop() - assert.NoError(t, err) - - assert.Contains(t, buf.String(), "TestServer gRPC service shutdown at localhost:0") -} - -func generateSelfSignedCert() ([]byte, []byte, error) { - key, err := rsa.GenerateKey(rand.Reader, 2048) - if err != nil { - return nil, nil, err - } - - cert, err := generateSelfSignedCertFromKey(key) - if err != nil { - return nil, nil, err - } - - return cert, pem.EncodeToMemory(&pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(key)}), nil -} - -func generateSelfSignedCertFromKey(key *rsa.PrivateKey) ([]byte, error) { - template := x509.Certificate{ - SerialNumber: big.NewInt(1), - Subject: pkix.Name{ - Organization: []string{"Test"}, - }, - NotBefore: time.Now(), - NotAfter: time.Now().AddDate(1, 0, 0), - KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature, - ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, - BasicConstraintsValid: true, - } - - certBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, &key.PublicKey, key) - if err != nil { - return nil, err - } - - return pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: certBytes}), nil -} - -type ThreadSafeBuffer struct { - buffer strings.Builder - mu sync.Mutex -} - -func (b *ThreadSafeBuffer) Write(p []byte) (n int, err error) { - b.mu.Lock() - defer b.mu.Unlock() - return b.buffer.Write(p) -} - -func (b *ThreadSafeBuffer) String() string { - b.mu.Lock() - defer b.mu.Unlock() - return b.buffer.String() -} - -func TestServerInitializationAndStartup(t *testing.T) { - vtpm.ExternalTPM = &vtpm.DummyRWC{} - - testCases := []struct { - name string - config server.AgentConfig - expectedLog string - expectError bool - setupCallback func(*testing.T, *server.AgentConfig, *ThreadSafeBuffer) - }{ - { - name: "Non-TLS Server Startup", - config: server.AgentConfig{ - ServerConfig: server.ServerConfig{ - Config: server.Config{ - Host: "localhost", - Port: "0", - }, - }, - }, - expectedLog: "TestServer service gRPC server listening at localhost:0 without TLS", - }, - { - name: "TLS Server Startup with Self-Signed Certificate", - config: server.AgentConfig{ - ServerConfig: server.ServerConfig{ - Config: server.Config{ - Host: "localhost", - Port: "0", - }, - }, - }, - setupCallback: setupTLSConfig, - expectedLog: "TestServer service gRPC server listening at localhost:0 with TLS", - }, - { - name: "TLS Server Startup with Invalid Certificates", - config: server.AgentConfig{ - ServerConfig: server.ServerConfig{ - Config: server.Config{ - Host: "localhost", - Port: "0", - CertFile: "invalid", - KeyFile: "invalid", - }, - }, - }, - expectError: true, - expectedLog: "failed to load auth certificates", - }, - { - name: "maTLS Server Startup", - config: server.AgentConfig{ - ServerConfig: server.ServerConfig{ - Config: server.Config{ - Host: "localhost", - Port: "0", - ServerCAFile: "", - ClientCAFile: "", - }, - }, - AttestedTLS: true, - }, - setupCallback: setupMTLSConfig, - expectError: false, - expectedLog: "with Attested mTLS", - }, - { - name: "maTLS Server Startup with Invalid Server CA file", - config: server.AgentConfig{ - ServerConfig: server.ServerConfig{ - Config: server.Config{ - Host: "localhost", - Port: "0", - ServerCAFile: "invalid", - }, - }, - AttestedTLS: true, - }, - setupCallback: setupInvalidRootCAConfig, - expectError: true, - expectedLog: "failed to load server ca file", - }, - { - name: "maTLS Server Startup with Invalid Clinet CA file", - config: server.AgentConfig{ - ServerConfig: server.ServerConfig{ - Config: server.Config{ - Host: "localhost", - Port: "0", - ServerCAFile: "invalid", - }, - }, - AttestedTLS: true, - }, - setupCallback: setupInvalidClientCAConfig, - expectError: true, - expectedLog: "failed to load client ca file", - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - if tc.setupCallback != nil { - tc.setupCallback(t, &tc.config, nil) - } - - logBuffer := &ThreadSafeBuffer{} - logger := slog.New(slog.NewTextHandler(logBuffer, &slog.HandlerOptions{Level: slog.LevelDebug})) - authSvc := new(authmocks.Authenticator) - - mockCertProvider := new(mocks.CertificateProvider) - - srv := New(ctx, cancel, "TestServer", tc.config, func(srv *grpc.Server) {}, logger, authSvc, mockCertProvider) - var wg sync.WaitGroup - wg.Add(1) - - go func() { - wg.Done() - err := srv.Start() - if tc.expectError { - assert.Error(t, err) - assert.Contains(t, err.Error(), tc.expectedLog) - } else { - assert.NoError(t, err) - } - }() - - wg.Wait() - - time.Sleep(200 * time.Millisecond) - - cancel() - - time.Sleep(200 * time.Millisecond) - - if !tc.expectError { - logContent := logBuffer.String() - fmt.Println(logContent) - assert.Contains(t, logContent, tc.expectedLog) - } - }) - } -} - -func setupTLSConfig(t *testing.T, config *server.AgentConfig, _ *ThreadSafeBuffer) { - cert, key, err := generateSelfSignedCert() - assert.NoError(t, err) - - config.CertFile = string(cert) - config.KeyFile = string(key) -} - -func setupMTLSConfig(t *testing.T, config *server.AgentConfig, _ *ThreadSafeBuffer) { - cert, key, err := generateSelfSignedCert() - assert.NoError(t, err) - - config.CertFile = string(cert) - config.KeyFile = string(key) - config.ServerCAFile = string(cert) - config.ClientCAFile = string(cert) -} - -func setupInvalidRootCAConfig(t *testing.T, config *server.AgentConfig, _ *ThreadSafeBuffer) { - cert, key, err := generateSelfSignedCert() - assert.NoError(t, err) - - config.CertFile = string(cert) - config.KeyFile = string(key) - config.ServerCAFile = "invalid" - config.ClientCAFile = string(cert) -} - -func setupInvalidClientCAConfig(t *testing.T, config *server.AgentConfig, _ *ThreadSafeBuffer) { - cert, key, err := generateSelfSignedCert() - assert.NoError(t, err) - - config.CertFile = string(cert) - config.KeyFile = string(key) - config.ClientCAFile = "invalid" - config.ServerCAFile = string(cert) -} - -func createCertificatesFiles() (string, string, string, error) { - caKey, err := rsa.GenerateKey(rand.Reader, 2048) - if err != nil { - return "", "", "", err - } - - caTemplate := x509.Certificate{ - SerialNumber: big.NewInt(1), - Subject: pkix.Name{ - Organization: []string{"Test Org"}, - }, - NotBefore: time.Now(), - NotAfter: time.Now().Add(time.Hour * 24), - KeyUsage: x509.KeyUsageCertSign | x509.KeyUsageCRLSign, - BasicConstraintsValid: true, - IsCA: true, - } - - caCertDER, err := x509.CreateCertificate(rand.Reader, &caTemplate, &caTemplate, &caKey.PublicKey, caKey) - if err != nil { - return "", "", "", err - } - - caCertFile, err := createTempFile(pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: caCertDER})) - if err != nil { - return "", "", "", err - } - - clientKey, err := rsa.GenerateKey(rand.Reader, 2048) - if err != nil { - return "", "", "", err - } - - clientTemplate := x509.Certificate{ - SerialNumber: big.NewInt(2), - Subject: pkix.Name{ - Organization: []string{"Test Org"}, - }, - NotBefore: time.Now(), - NotAfter: time.Now().Add(time.Hour * 24), - KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature, - ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth}, - BasicConstraintsValid: true, - } - - clientCertDER, err := x509.CreateCertificate(rand.Reader, &clientTemplate, &caTemplate, &clientKey.PublicKey, caKey) - if err != nil { - return "", "", "", err - } - - clientCertFile, err := createTempFile(pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: clientCertDER})) - if err != nil { - return "", "", "", err - } - - clientKeyFile, err := createTempFile(pem.EncodeToMemory(&pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(clientKey)})) - if err != nil { - return "", "", "", err - } - - return caCertFile, clientCertFile, clientKeyFile, nil -} - -func createTempFile(data []byte) (string, error) { - file, err := createTempFileHandle() - if err != nil { - return "", err - } - - _, err = file.Write(data) - if err != nil { - return "", err - } - - err = file.Close() - if err != nil { - return "", err - } - - return file.Name(), nil -} - -func createTempFileHandle() (*os.File, error) { - return os.CreateTemp("", "test") -} diff --git a/pkg/server/http/http.go b/pkg/server/http/http.go deleted file mode 100644 index f2c43385..00000000 --- a/pkg/server/http/http.go +++ /dev/null @@ -1,177 +0,0 @@ -// Copyright (c) Ultraviolet -// SPDX-License-Identifier: Apache-2.0 - -package http - -import ( - "context" - "crypto/tls" - "fmt" - "log/slog" - "net/http" - - smqserver "github.com/absmach/supermq/pkg/server" - "github.com/ultravioletrs/cocos/pkg/atls" - "github.com/ultravioletrs/cocos/pkg/server" -) - -const ( - httpProtocol = "http" - httpsProtocol = "https" -) - -type httpServer struct { - server.BaseServer - - server *http.Server - certProvider atls.CertificateProvider - attestedTLSEnabled bool -} - -var _ server.Server = (*httpServer)(nil) - -func NewServer( - ctx context.Context, cancel context.CancelFunc, name string, config server.ServerConfiguration, - handler http.Handler, logger *slog.Logger, certProvider atls.CertificateProvider, -) server.Server { - baseServer := server.NewBaseServer(ctx, cancel, name, config, logger) - hserver := &http.Server{Addr: baseServer.Address, Handler: handler} - - var attestedTLS bool - - if agentConfig, ok := config.(server.AgentConfig); ok && agentConfig.AttestedTLS { - if certProvider == nil { - logger.Error("Failed to create certificate provider") - } else { - attestedTLS = true - } - } - - return &httpServer{ - BaseServer: baseServer, - server: hserver, - certProvider: certProvider, - attestedTLSEnabled: attestedTLS, - } -} - -func (s *httpServer) Start() error { - s.Protocol = httpProtocol - - if s.shouldUseAttestedTLS() { - return s.startWithAttestedTLS() - } - - if s.shouldUseRegularTLS() { - return s.startWithRegularTLS() - } - - return s.startWithoutTLS() -} - -func (s *httpServer) Stop() error { - defer s.Cancel() - - ctx, cancel := context.WithTimeout(context.Background(), smqserver.StopWaitTime) - defer cancel() - - if err := s.server.Shutdown(ctx); err != nil { - s.Logger.Error(fmt.Sprintf( - "%s service %s server error occurred during shutdown at %s: %s", s.Name, s.Protocol, s.Address, err)) - return fmt.Errorf("%s service %s server error occurred during shutdown at %s: %w", s.Name, s.Protocol, s.Address, err) - } - - s.Logger.Info(fmt.Sprintf("%s %s service shutdown of http at %s", s.Name, s.Protocol, s.Address)) - return nil -} - -func (s *httpServer) shouldUseAttestedTLS() bool { - return s.attestedTLSEnabled && s.certProvider != nil -} - -func (s *httpServer) shouldUseRegularTLS() bool { - return s.Config.GetBaseConfig().CertFile != "" || s.Config.GetBaseConfig().KeyFile != "" -} - -func (s *httpServer) startWithAttestedTLS() error { - tlsConfig := &tls.Config{ - ClientAuth: tls.NoClientCert, - GetCertificate: s.certProvider.GetCertificate, - } - - baseConfig := s.Config.GetBaseConfig() - mtls, err := server.ConfigureCertificateAuthorities(tlsConfig, baseConfig.ServerCAFile, baseConfig.ClientCAFile) - if err != nil { - return fmt.Errorf("failed to configure certificate authorities: %w", err) - } - - if mtls { - tlsConfig.ClientAuth = tls.RequireAndVerifyClientCert - } - - s.server.TLSConfig = tlsConfig - s.Protocol = httpsProtocol - - s.logAttestedTLSStart(mtls) - return s.listenAndServe(true) -} - -func (s *httpServer) startWithRegularTLS() error { - baseConfig := s.Config.GetBaseConfig() - tlsSetup, err := server.SetupRegularTLS(baseConfig.CertFile, baseConfig.KeyFile, baseConfig.ServerCAFile, baseConfig.ClientCAFile) - if err != nil { - return fmt.Errorf("failed to setup TLS: %w", err) - } - - s.server.TLSConfig = tlsSetup.Config - s.Protocol = httpsProtocol - - s.logRegularTLSStart(tlsSetup.MTLS) - return s.listenAndServe(true) -} - -func (s *httpServer) startWithoutTLS() error { - s.Logger.Info(fmt.Sprintf("%s service %s server listening at %s without TLS", s.Name, s.Protocol, s.Address)) - return s.listenAndServe(false) -} - -func (s *httpServer) logAttestedTLSStart(mtls bool) { - if mtls { - s.Logger.Info(fmt.Sprintf("%s service %s server listening at %s with Attested mTLS", s.Name, s.Protocol, s.Address)) - } else { - s.Logger.Info(fmt.Sprintf("%s service %s server listening at %s with Attested TLS", s.Name, s.Protocol, s.Address)) - } -} - -func (s *httpServer) logRegularTLSStart(mtls bool) { - baseConfig := s.Config.GetBaseConfig() - if mtls { - s.Logger.Info(fmt.Sprintf( - "%s service %s server listening at %s with TLS/mTLS cert %s , key %s and CAs %s, %s", - s.Name, s.Protocol, s.Address, baseConfig.CertFile, baseConfig.KeyFile, - baseConfig.ServerCAFile, baseConfig.ClientCAFile)) - } else { - s.Logger.Info(fmt.Sprintf("%s service %s server listening at %s with TLS cert %s and key %s", - s.Name, s.Protocol, s.Address, baseConfig.CertFile, baseConfig.KeyFile)) - } -} - -func (s *httpServer) listenAndServe(useTLS bool) error { - errCh := make(chan error, 1) - - go func() { - if useTLS { - cfg := s.Config.GetBaseConfig() - errCh <- s.server.ListenAndServeTLS(cfg.CertFile, cfg.KeyFile) - } else { - errCh <- s.server.ListenAndServe() - } - }() - - select { - case <-s.Ctx.Done(): - return s.Stop() - case err := <-errCh: - return err - } -} diff --git a/pkg/server/http/http_test.go b/pkg/server/http/http_test.go deleted file mode 100644 index 3529032b..00000000 --- a/pkg/server/http/http_test.go +++ /dev/null @@ -1,411 +0,0 @@ -// Copyright (c) Ultraviolet -// SPDX-License-Identifier: Apache-2.0 - -package http - -import ( - "context" - "crypto/tls" - "log/slog" - "net/http" - "net/http/httptest" - "testing" - "time" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/mock" - "github.com/stretchr/testify/require" - "github.com/ultravioletrs/cocos/pkg/atls" - "github.com/ultravioletrs/cocos/pkg/atls/mocks" - "github.com/ultravioletrs/cocos/pkg/server" -) - -// Mock implementations for testing. -type mockHandler struct { - mock.Mock -} - -func (m *mockHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { - m.Called(w, r) - w.WriteHeader(http.StatusOK) - if _, err := w.Write([]byte("test response")); err != nil { - panic(err) - } -} - -type mockBaseConfig struct { - certFile string - keyFile string - serverCAFile string - clientCAFile string - host string - port string -} - -func (m *mockBaseConfig) GetCertFile() string { return m.certFile } -func (m *mockBaseConfig) GetKeyFile() string { return m.keyFile } -func (m *mockBaseConfig) GetServerCAFile() string { return m.serverCAFile } -func (m *mockBaseConfig) GetClientCAFile() string { return m.clientCAFile } - -type mockServerConfig struct { - baseConfig *mockBaseConfig -} - -func (m *mockServerConfig) GetHost() string { return "localhost" } -func (m *mockServerConfig) GetPort() string { return "8080" } -func (m *mockServerConfig) GetBaseConfig() server.ServerConfig { - return server.ServerConfig{Config: server.Config{CertFile: m.baseConfig.certFile, KeyFile: m.baseConfig.keyFile, ServerCAFile: m.baseConfig.serverCAFile, ClientCAFile: m.baseConfig.clientCAFile, Host: m.baseConfig.host, Port: m.baseConfig.port}} -} - -func TestNewServer(t *testing.T) { - ctx := context.Background() - cancel := func() {} - name := "test-server" - config := &mockServerConfig{ - baseConfig: &mockBaseConfig{}, - } - handler := &mockHandler{} - logger := slog.Default() - - server := NewServer(ctx, cancel, name, config, handler, logger, nil) - - assert.NotNil(t, server) - httpSrv, ok := server.(*httpServer) - require.True(t, ok) - assert.NotNil(t, httpSrv.server) - assert.Equal(t, handler, httpSrv.server.Handler) -} - -func TestHttpServer_shouldUseAttestedTLS(t *testing.T) { - mockCertProvider := new(mocks.CertificateProvider) - tests := []struct { - name string - config server.ServerConfiguration - expected bool - certProvider atls.CertificateProvider - }{ - { - name: "should use attested TLS when config is AgentConfig and AttestedTLS is true and certProvider is not empty", - config: server.AgentConfig{ - AttestedTLS: true, - }, - certProvider: mockCertProvider, - expected: true, - }, - { - name: "should not use attested TLS when certProvider is empty", - config: server.AgentConfig{ - AttestedTLS: true, - }, - certProvider: nil, - expected: false, - }, - { - name: "should not use attested TLS when AttestedTLS is false", - config: server.AgentConfig{ - AttestedTLS: false, - }, - certProvider: mockCertProvider, - expected: false, - }, - { - name: "should not use attested TLS when config is not AgentConfig", - config: &mockServerConfig{ - baseConfig: &mockBaseConfig{}, - }, - certProvider: mockCertProvider, - expected: false, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - ctx := context.Background() - cancel := func() {} - - server := NewServer(ctx, cancel, "test", tt.config, &mockHandler{}, slog.Default(), tt.certProvider) - httpSrv := server.(*httpServer) - - result := httpSrv.shouldUseAttestedTLS() - assert.Equal(t, tt.expected, result) - }) - } -} - -func TestHttpServer_shouldUseRegularTLS(t *testing.T) { - tests := []struct { - name string - certFile string - keyFile string - expected bool - }{ - { - name: "should use regular TLS when both cert and key files are provided", - certFile: "cert.pem", - keyFile: "key.pem", - expected: true, - }, - { - name: "should use regular TLS when only cert file is provided", - certFile: "cert.pem", - keyFile: "", - expected: true, - }, - { - name: "should use regular TLS when only key file is provided", - certFile: "", - keyFile: "key.pem", - expected: true, - }, - { - name: "should not use regular TLS when neither cert nor key files are provided", - certFile: "", - keyFile: "", - expected: false, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - ctx := context.Background() - cancel := func() {} - config := &mockServerConfig{ - baseConfig: &mockBaseConfig{ - certFile: tt.certFile, - keyFile: tt.keyFile, - }, - } - - server := NewServer(ctx, cancel, "test", config, &mockHandler{}, slog.Default(), nil) - httpSrv := server.(*httpServer) - - result := httpSrv.shouldUseRegularTLS() - assert.Equal(t, tt.expected, result) - }) - } -} - -func TestHttpServer_Stop(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) - config := &mockServerConfig{ - baseConfig: &mockBaseConfig{}, - } - handler := &mockHandler{} - - server := NewServer(ctx, cancel, "test-server", config, handler, slog.Default(), nil) - httpSrv := server.(*httpServer) - - // Start a test server that we can control - testServer := httptest.NewServer(handler) - defer testServer.Close() - - // Replace the server's HTTP server with our test server's - httpSrv.server = testServer.Config - - err := httpSrv.Stop() - assert.NoError(t, err) -} - -func TestHttpServer_logAttestedTLSStart(t *testing.T) { - tests := []struct { - name string - mtls bool - }{ - { - name: "log attested mTLS start", - mtls: true, - }, - { - name: "log attested TLS start", - mtls: false, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - ctx := context.Background() - cancel := func() {} - config := &mockServerConfig{ - baseConfig: &mockBaseConfig{}, - } - - server := NewServer(ctx, cancel, "test-server", config, &mockHandler{}, slog.Default(), nil) - httpSrv := server.(*httpServer) - - // This test mainly ensures the method doesn't panic - // In a real scenario, you might want to capture log output - assert.NotPanics(t, func() { - httpSrv.logAttestedTLSStart(tt.mtls) - }) - }) - } -} - -func TestHttpServer_logRegularTLSStart(t *testing.T) { - tests := []struct { - name string - mtls bool - }{ - { - name: "log regular mTLS start", - mtls: true, - }, - { - name: "log regular TLS start", - mtls: false, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - ctx := context.Background() - cancel := func() {} - config := &mockServerConfig{ - baseConfig: &mockBaseConfig{ - certFile: "cert.pem", - keyFile: "key.pem", - serverCAFile: "server-ca.pem", - clientCAFile: "client-ca.pem", - }, - } - - server := NewServer(ctx, cancel, "test-server", config, &mockHandler{}, slog.Default(), nil) - httpSrv := server.(*httpServer) - - // This test mainly ensures the method doesn't panic - assert.NotPanics(t, func() { - httpSrv.logRegularTLSStart(tt.mtls) - }) - }) - } -} - -func TestHttpServer_startWithoutTLS(t *testing.T) { - ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) - defer cancel() - - config := &mockServerConfig{ - baseConfig: &mockBaseConfig{}, - } - handler := &mockHandler{} - - server := NewServer(ctx, cancel, "test-server", config, handler, slog.Default(), nil) - httpSrv := server.(*httpServer) - - // Use a test server to avoid binding to actual ports - testServer := httptest.NewServer(handler) - defer testServer.Close() - - httpSrv.server = testServer.Config - - err := httpSrv.startWithoutTLS() - // The error will be related to context cancellation or server shutdown - assert.Error(t, err) -} - -func TestHttpServer_Protocol(t *testing.T) { - tests := []struct { - name string - setupTLS func(*httpServer) - expectedProto string - }{ - { - name: "HTTP protocol without TLS", - setupTLS: func(s *httpServer) { - s.Protocol = httpProtocol - }, - expectedProto: httpProtocol, - }, - { - name: "HTTPS protocol with TLS", - setupTLS: func(s *httpServer) { - s.Protocol = httpsProtocol - s.server.TLSConfig = &tls.Config{} - }, - expectedProto: httpsProtocol, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - ctx := context.Background() - cancel := func() {} - config := &mockServerConfig{ - baseConfig: &mockBaseConfig{}, - } - - server := NewServer(ctx, cancel, "test-server", config, &mockHandler{}, slog.Default(), nil) - httpSrv := server.(*httpServer) - - tt.setupTLS(httpSrv) - - assert.Equal(t, tt.expectedProto, httpSrv.Protocol) - }) - } -} - -func TestHttpServer_ContextCancellation(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) - config := &mockServerConfig{ - baseConfig: &mockBaseConfig{}, - } - handler := &mockHandler{} - - server := NewServer(ctx, cancel, "test-server", config, handler, slog.Default(), nil) - httpSrv := server.(*httpServer) - - // Cancel the context immediately - cancel() - - // The listenAndServe method should handle context cancellation - err := httpSrv.listenAndServe(false) - assert.NoError(t, err) // Should return no error when context is cancelled and Stop() succeeds -} - -func TestHttpServer_TLSConfiguration(t *testing.T) { - ctx := context.Background() - cancel := func() {} - config := &mockServerConfig{ - baseConfig: &mockBaseConfig{ - certFile: "cert.pem", - keyFile: "key.pem", - }, - } - - server := NewServer(ctx, cancel, "test-server", config, &mockHandler{}, slog.Default(), nil) - httpSrv := server.(*httpServer) - - // Test TLS configuration setup - httpSrv.server.TLSConfig = &tls.Config{ - MinVersion: tls.VersionTLS12, - } - - assert.NotNil(t, httpSrv.server.TLSConfig) - assert.Equal(t, uint16(tls.VersionTLS12), httpSrv.server.TLSConfig.MinVersion) -} - -// Integration-style test for server lifecycle. -func TestHttpServer_Lifecycle(t *testing.T) { - ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond) - defer cancel() - - config := &mockServerConfig{ - baseConfig: &mockBaseConfig{ - host: "localhost", - port: "8080", - }, - } - handler := &mockHandler{} - - server := NewServer(ctx, cancel, "test-server", config, handler, slog.Default(), nil) - - // Test that server can be created and has expected initial state - httpSrv, ok := server.(*httpServer) - require.True(t, ok) - assert.NotNil(t, httpSrv.server) - assert.Equal(t, "localhost:8080", httpSrv.server.Addr) - - // Test Stop without Start (should not panic) - err := httpSrv.Stop() - assert.NoError(t, err) -} diff --git a/pkg/server/mocks/server.go b/pkg/server/mocks/server.go deleted file mode 100644 index 06b3ffcd..00000000 --- a/pkg/server/mocks/server.go +++ /dev/null @@ -1,127 +0,0 @@ -// Copyright (c) Ultraviolet -// SPDX-License-Identifier: Apache-2.0 - -// Code generated by mockery; DO NOT EDIT. -// github.com/vektra/mockery -// template: testify - -package mocks - -import ( - mock "github.com/stretchr/testify/mock" -) - -// NewServer creates a new instance of Server. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. -// The first argument is typically a *testing.T value. -func NewServer(t interface { - mock.TestingT - Cleanup(func()) -}) *Server { - mock := &Server{} - mock.Mock.Test(t) - - t.Cleanup(func() { mock.AssertExpectations(t) }) - - return mock -} - -// Server is an autogenerated mock type for the Server type -type Server struct { - mock.Mock -} - -type Server_Expecter struct { - mock *mock.Mock -} - -func (_m *Server) EXPECT() *Server_Expecter { - return &Server_Expecter{mock: &_m.Mock} -} - -// Start provides a mock function for the type Server -func (_mock *Server) Start() error { - ret := _mock.Called() - - if len(ret) == 0 { - panic("no return value specified for Start") - } - - var r0 error - if returnFunc, ok := ret.Get(0).(func() error); ok { - r0 = returnFunc() - } else { - r0 = ret.Error(0) - } - return r0 -} - -// Server_Start_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Start' -type Server_Start_Call struct { - *mock.Call -} - -// Start is a helper method to define mock.On call -func (_e *Server_Expecter) Start() *Server_Start_Call { - return &Server_Start_Call{Call: _e.mock.On("Start")} -} - -func (_c *Server_Start_Call) Run(run func()) *Server_Start_Call { - _c.Call.Run(func(args mock.Arguments) { - run() - }) - return _c -} - -func (_c *Server_Start_Call) Return(err error) *Server_Start_Call { - _c.Call.Return(err) - return _c -} - -func (_c *Server_Start_Call) RunAndReturn(run func() error) *Server_Start_Call { - _c.Call.Return(run) - return _c -} - -// Stop provides a mock function for the type Server -func (_mock *Server) Stop() error { - ret := _mock.Called() - - if len(ret) == 0 { - panic("no return value specified for Stop") - } - - var r0 error - if returnFunc, ok := ret.Get(0).(func() error); ok { - r0 = returnFunc() - } else { - r0 = ret.Error(0) - } - return r0 -} - -// Server_Stop_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Stop' -type Server_Stop_Call struct { - *mock.Call -} - -// Stop is a helper method to define mock.On call -func (_e *Server_Expecter) Stop() *Server_Stop_Call { - return &Server_Stop_Call{Call: _e.mock.On("Stop")} -} - -func (_c *Server_Stop_Call) Run(run func()) *Server_Stop_Call { - _c.Call.Run(func(args mock.Arguments) { - run() - }) - return _c -} - -func (_c *Server_Stop_Call) Return(err error) *Server_Stop_Call { - _c.Call.Return(err) - return _c -} - -func (_c *Server_Stop_Call) RunAndReturn(run func() error) *Server_Stop_Call { - _c.Call.Return(run) - return _c -} diff --git a/pkg/server/server.go b/pkg/server/server.go deleted file mode 100644 index 6a7097dc..00000000 --- a/pkg/server/server.go +++ /dev/null @@ -1,105 +0,0 @@ -// Copyright (c) Ultraviolet -// SPDX-License-Identifier: Apache-2.0 -package server - -import ( - "context" - "fmt" - "log/slog" - "os" - "os/signal" - "syscall" -) - -type Server interface { - Start() error - Stop() error -} - -type ServerConfiguration interface { - GetBaseConfig() ServerConfig -} - -type Config struct { - Host string `env:"HOST" envDefault:"localhost"` - Port string `env:"PORT" envDefault:"7001"` - ServerCAFile string `env:"SERVER_CA_CERTS" envDefault:""` - CertFile string `env:"SERVER_CERT" envDefault:""` - KeyFile string `env:"SERVER_KEY" envDefault:""` - ClientCAFile string `env:"CLIENT_CA_CERTS" envDefault:""` -} - -type ServerConfig struct { - Config -} -type AgentConfig struct { - ServerConfig - AttestedTLS bool `env:"ATTESTED_TLS" envDefault:"false"` -} - -type BaseServer struct { - Ctx context.Context - Cancel context.CancelFunc - Name string - Address string - Config ServerConfiguration - Logger *slog.Logger - Protocol string -} - -func (s ServerConfig) GetBaseConfig() ServerConfig { - return s -} - -func (a AgentConfig) GetBaseConfig() ServerConfig { - return a.ServerConfig -} - -func NewBaseServer( - ctx context.Context, cancel context.CancelFunc, name string, config ServerConfiguration, logger *slog.Logger, -) BaseServer { - cfg := config.GetBaseConfig() - address := fmt.Sprintf("%s:%s", cfg.Host, cfg.Port) - - return BaseServer{ - Ctx: ctx, - Cancel: cancel, - Name: name, - Address: address, - Config: config, - Logger: logger, - } -} - -func StopHandler(ctx context.Context, cancel context.CancelFunc, logger *slog.Logger, svcName string, servers ...Server) error { - var err error - c := make(chan os.Signal, 1) - signal.Notify(c, syscall.SIGINT, syscall.SIGABRT) - select { - case sig := <-c: - defer cancel() - err = stopAllServer(servers...) - if err != nil { - logger.Error(fmt.Sprintf("%s service error during shutdown: %v", svcName, err)) - } - logger.Info(fmt.Sprintf("%s service shutdown by signal: %s", svcName, sig)) - return err - case <-ctx.Done(): - return nil - } -} - -func stopAllServer(servers ...Server) error { - var errs []error - for _, server := range servers { - if err := server.Stop(); err != nil { - errs = append(errs, err) - } - } - - if len(errs) > 0 { - return fmt.Errorf("encountered errors while stopping servers: %v", errs) - } - - return nil -} diff --git a/pkg/server/server_test.go b/pkg/server/server_test.go deleted file mode 100644 index 1b7f8848..00000000 --- a/pkg/server/server_test.go +++ /dev/null @@ -1,138 +0,0 @@ -// Copyright (c) Ultraviolet -// SPDX-License-Identifier: Apache-2.0 -package server - -import ( - "context" - "errors" - "log/slog" - "os" - "syscall" - "testing" - "time" - - "github.com/ultravioletrs/cocos/pkg/server/mocks" -) - -func TestStopAllServer(t *testing.T) { - server1 := new(mocks.Server) - server2 := new(mocks.Server) - server1.On("Stop").Return(nil) - server2.On("Stop").Return(errors.New("failed to stop")) - tests := []struct { - name string - servers []Server - expectedError bool - }{ - { - name: "All servers stop successfully", - servers: []Server{ - server1, - server1, - }, - expectedError: false, - }, - { - name: "One server fails to stop", - servers: []Server{ - server1, - server2, - }, - expectedError: true, - }, - { - name: "No servers", - servers: []Server{}, - expectedError: false, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - err := stopAllServer(tt.servers...) - if (err != nil) != tt.expectedError { - t.Errorf("stopAllServer() error = %v, expectedError %v", err, tt.expectedError) - } - }) - } -} - -func TestStopHandler(t *testing.T) { - mockServer := new(mocks.Server) - mockServer.On("Stop").Return(nil) - tests := []struct { - name string - setupFunc func() (context.Context, context.CancelFunc, *slog.Logger, string, []Server) - triggerSignal bool - expectedError bool - expectCanceled bool - }{ - { - name: "Graceful shutdown on signal", - setupFunc: func() (context.Context, context.CancelFunc, *slog.Logger, string, []Server) { - ctx, cancel := context.WithCancel(context.Background()) - logger := slog.New(slog.NewTextHandler(os.Stdout, nil)) - return ctx, cancel, logger, "test", []Server{mockServer} - }, - triggerSignal: true, - expectedError: false, - expectCanceled: true, - }, - { - name: "Context canceled", - setupFunc: func() (context.Context, context.CancelFunc, *slog.Logger, string, []Server) { - ctx, cancel := context.WithCancel(context.Background()) - logger := slog.New(slog.NewTextHandler(os.Stdout, nil)) - go func() { - time.Sleep(100 * time.Millisecond) - cancel() - }() - return ctx, cancel, logger, "test", []Server{mockServer} - }, - triggerSignal: false, - expectedError: false, - expectCanceled: true, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - ctx, cancel, logger, svcName, servers := tt.setupFunc() - defer cancel() - - errChan := make(chan error) - go func() { - errChan <- StopHandler(ctx, cancel, logger, svcName, servers...) - }() - - if tt.triggerSignal { - // Simulate SIGINT - go func() { - time.Sleep(100 * time.Millisecond) - err := syscall.Kill(syscall.Getpid(), syscall.SIGINT) - if err != nil { - t.Errorf("failed to send signal: %v", err) - } - }() - } - - select { - case err := <-errChan: - if (err != nil) != tt.expectedError { - t.Errorf("StopHandler() error = %v, expectedError %v", err, tt.expectedError) - } - case <-time.After(2 * time.Second): - t.Error("StopHandler() timed out") - } - - if tt.expectCanceled { - select { - case <-ctx.Done(): - // Context was canceled as expected - default: - t.Error("Context was not canceled") - } - } - }) - } -} diff --git a/pkg/server/tlsutil_test.go b/pkg/server/tlsutil_test.go deleted file mode 100644 index 23e8d1d1..00000000 --- a/pkg/server/tlsutil_test.go +++ /dev/null @@ -1,741 +0,0 @@ -// Copyright (c) Ultraviolet -// SPDX-License-Identifier: Apache-2.0 -package server - -import ( - "crypto/rand" - "crypto/rsa" - "crypto/tls" - "crypto/x509" - "crypto/x509/pkix" - "encoding/pem" - "math/big" - "os" - "strings" - "testing" - "time" -) - -// Helper function to generate a test certificate and key. -func generateTestCert() (certPEM, keyPEM []byte, err error) { - // Generate private key - privateKey, err := rsa.GenerateKey(rand.Reader, 2048) - if err != nil { - return nil, nil, err - } - - // Create certificate template - template := x509.Certificate{ - SerialNumber: big.NewInt(1), - Subject: pkix.Name{ - Organization: []string{"Test Org"}, - Country: []string{"US"}, - Province: []string{""}, - Locality: []string{"Test City"}, - StreetAddress: []string{""}, - PostalCode: []string{""}, - }, - NotBefore: time.Now(), - NotAfter: time.Now().Add(365 * 24 * time.Hour), - KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature, - ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth, x509.ExtKeyUsageClientAuth}, - IPAddresses: nil, - } - - // Create certificate - certDER, err := x509.CreateCertificate(rand.Reader, &template, &template, &privateKey.PublicKey, privateKey) - if err != nil { - return nil, nil, err - } - - // Encode certificate - certPEM = pem.EncodeToMemory(&pem.Block{ - Type: "CERTIFICATE", - Bytes: certDER, - }) - - // Encode private key - privateKeyDER, err := x509.MarshalPKCS8PrivateKey(privateKey) - if err != nil { - return nil, nil, err - } - - keyPEM = pem.EncodeToMemory(&pem.Block{ - Type: "PRIVATE KEY", - Bytes: privateKeyDER, - }) - - return certPEM, keyPEM, nil -} - -// Helper function to create temporary files for testing. -func createTempFile(t *testing.T, content []byte) string { - tmpFile, err := os.CreateTemp("", "test-cert-*.pem") - if err != nil { - t.Fatalf("Failed to create temp file: %v", err) - } - defer tmpFile.Close() - - if _, err := tmpFile.Write(content); err != nil { - t.Fatalf("Failed to write temp file: %v", err) - } - - return tmpFile.Name() -} - -func TestLoadCertFile(t *testing.T) { - certPEM, _, err := generateTestCert() - if err != nil { - t.Fatalf("Failed to generate test cert: %v", err) - } - - tests := []struct { - name string - certFile string - wantErr bool - setup func() string - cleanup func(string) - }{ - { - name: "empty cert file path", - certFile: "", - wantErr: false, - }, - { - name: "valid cert file", - wantErr: false, - setup: func() string { - return createTempFile(t, certPEM) - }, - cleanup: func(path string) { - os.Remove(path) - }, - }, - { - name: "non-existent file", - certFile: "/non/existent/file.pem", - wantErr: true, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - certFile := tt.certFile - if tt.setup != nil { - certFile = tt.setup() - } - if tt.cleanup != nil { - defer tt.cleanup(certFile) - } - - data, err := LoadCertFile(certFile) - if (err != nil) != tt.wantErr { - t.Errorf("LoadCertFile() error = %v, wantErr %v", err, tt.wantErr) - return - } - - if tt.certFile != "" && !tt.wantErr && len(data) == 0 { - t.Errorf("LoadCertFile() with valid file should return data, got empty") - } - }) - } -} - -func TestReadFileOrData(t *testing.T) { - testData := "test certificate data" - tempFile := createTempFile(t, []byte(testData)) - defer os.Remove(tempFile) - - tests := []struct { - name string - input string - want string - wantErr bool - }{ - { - name: "file path", - input: tempFile, - want: testData, - wantErr: false, - }, - { - name: "raw data with newlines", - input: "-----BEGIN CERTIFICATE-----\nMIIC...\n-----END CERTIFICATE-----", - want: "-----BEGIN CERTIFICATE-----\nMIIC...\n-----END CERTIFICATE-----", - wantErr: false, - }, - { - name: "short raw data without newlines", - input: "short data", - want: "short data", - wantErr: true, - }, - { - name: "non-existent file path", - input: "/non/existent/file.pem", - want: "", - wantErr: true, - }, - { - name: "empty input", - input: "", - want: "", - wantErr: true, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got, err := ReadFileOrData(tt.input) - if (err != nil) != tt.wantErr { - t.Errorf("ReadFileOrData() error = %v, wantErr %v", err, tt.wantErr) - return - } - - if !tt.wantErr && string(got) != tt.want { - t.Errorf("ReadFileOrData() = %v, want %v", string(got), tt.want) - } - }) - } -} - -func TestLoadX509KeyPair(t *testing.T) { - certPEM, keyPEM, err := generateTestCert() - if err != nil { - t.Fatalf("Failed to generate test cert: %v", err) - } - - certFile := createTempFile(t, certPEM) - keyFile := createTempFile(t, keyPEM) - defer os.Remove(certFile) - defer os.Remove(keyFile) - - tests := []struct { - name string - certfile string - keyfile string - wantErr bool - }{ - { - name: "valid cert and key files", - certfile: certFile, - keyfile: keyFile, - wantErr: false, - }, - { - name: "valid cert and key data", - certfile: string(certPEM), - keyfile: string(keyPEM), - wantErr: false, - }, - { - name: "non-existent cert file", - certfile: "/non/existent/cert.pem", - keyfile: keyFile, - wantErr: true, - }, - { - name: "non-existent key file", - certfile: certFile, - keyfile: "/non/existent/key.pem", - wantErr: true, - }, - { - name: "invalid cert data", - certfile: "invalid cert data", - keyfile: string(keyPEM), - wantErr: true, - }, - { - name: "invalid key data", - certfile: string(certPEM), - keyfile: "invalid key data", - wantErr: true, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - cert, err := LoadX509KeyPair(tt.certfile, tt.keyfile) - if (err != nil) != tt.wantErr { - t.Errorf("LoadX509KeyPair() error = %v, wantErr %v", err, tt.wantErr) - return - } - - if !tt.wantErr && len(cert.Certificate) == 0 { - t.Errorf("LoadX509KeyPair() returned empty certificate") - } - }) - } -} - -func TestConfigureRootCA(t *testing.T) { - certPEM, _, err := generateTestCert() - if err != nil { - t.Fatalf("Failed to generate test cert: %v", err) - } - - caFile := createTempFile(t, certPEM) - defer os.Remove(caFile) - - tests := []struct { - name string - tlsConfig *tls.Config - serverCAFile string - wantErr bool - expectCA bool - }{ - { - name: "valid CA file", - tlsConfig: &tls.Config{}, - serverCAFile: caFile, - wantErr: false, - expectCA: true, - }, - { - name: "valid CA data", - tlsConfig: &tls.Config{}, - serverCAFile: string(certPEM), - wantErr: false, - expectCA: true, - }, - { - name: "empty CA file", - tlsConfig: &tls.Config{}, - serverCAFile: "", - wantErr: false, - expectCA: false, - }, - { - name: "non-existent CA file", - tlsConfig: &tls.Config{}, - serverCAFile: "/non/existent/ca.pem", - wantErr: true, - expectCA: false, - }, - { - name: "invalid CA data", - tlsConfig: &tls.Config{}, - serverCAFile: "invalid ca data", - wantErr: true, - expectCA: false, - }, - { - name: "existing RootCAs pool", - tlsConfig: &tls.Config{RootCAs: x509.NewCertPool()}, - serverCAFile: caFile, - wantErr: false, - expectCA: true, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - err := ConfigureRootCA(tt.tlsConfig, tt.serverCAFile) - if (err != nil) != tt.wantErr { - t.Errorf("ConfigureRootCA() error = %v, wantErr %v", err, tt.wantErr) - return - } - - if tt.expectCA && tt.tlsConfig.RootCAs == nil { - t.Errorf("ConfigureRootCA() should have created RootCAs pool") - } - - if !tt.expectCA && tt.tlsConfig.RootCAs != nil && tt.serverCAFile == "" { - t.Errorf("ConfigureRootCA() should not have created RootCAs pool for empty file") - } - }) - } -} - -func TestConfigureClientCA(t *testing.T) { - certPEM, _, err := generateTestCert() - if err != nil { - t.Fatalf("Failed to generate test cert: %v", err) - } - - caFile := createTempFile(t, certPEM) - defer os.Remove(caFile) - - tests := []struct { - name string - tlsConfig *tls.Config - clientCAFile string - wantConfigured bool - wantErr bool - }{ - { - name: "valid client CA file", - tlsConfig: &tls.Config{}, - clientCAFile: caFile, - wantConfigured: true, - wantErr: false, - }, - { - name: "valid client CA data", - tlsConfig: &tls.Config{}, - clientCAFile: string(certPEM), - wantConfigured: true, - wantErr: false, - }, - { - name: "empty client CA file", - tlsConfig: &tls.Config{}, - clientCAFile: "", - wantConfigured: false, - wantErr: false, - }, - { - name: "non-existent client CA file", - tlsConfig: &tls.Config{}, - clientCAFile: "/non/existent/ca.pem", - wantConfigured: false, - wantErr: true, - }, - { - name: "invalid client CA data", - tlsConfig: &tls.Config{}, - clientCAFile: "invalid ca data", - wantConfigured: false, - wantErr: true, - }, - { - name: "existing ClientCAs pool", - tlsConfig: &tls.Config{ClientCAs: x509.NewCertPool()}, - clientCAFile: caFile, - wantConfigured: true, - wantErr: false, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - configured, err := ConfigureClientCA(tt.tlsConfig, tt.clientCAFile) - if (err != nil) != tt.wantErr { - t.Errorf("ConfigureClientCA() error = %v, wantErr %v", err, tt.wantErr) - return - } - - if configured != tt.wantConfigured { - t.Errorf("ConfigureClientCA() configured = %v, want %v", configured, tt.wantConfigured) - } - - if tt.wantConfigured && tt.tlsConfig.ClientCAs == nil { - t.Errorf("ConfigureClientCA() should have created ClientCAs pool") - } - }) - } -} - -func TestConfigureCertificateAuthorities(t *testing.T) { - certPEM, _, err := generateTestCert() - if err != nil { - t.Fatalf("Failed to generate test cert: %v", err) - } - - caFile := createTempFile(t, certPEM) - defer os.Remove(caFile) - - tests := []struct { - name string - tlsConfig *tls.Config - serverCAFile string - clientCAFile string - wantMTLS bool - wantErr bool - }{ - { - name: "both server and client CA", - tlsConfig: &tls.Config{}, - serverCAFile: caFile, - clientCAFile: caFile, - wantMTLS: true, - wantErr: false, - }, - { - name: "only server CA", - tlsConfig: &tls.Config{}, - serverCAFile: caFile, - clientCAFile: "", - wantMTLS: false, - wantErr: false, - }, - { - name: "only client CA", - tlsConfig: &tls.Config{}, - serverCAFile: "", - clientCAFile: caFile, - wantMTLS: true, - wantErr: false, - }, - { - name: "no CAs", - tlsConfig: &tls.Config{}, - serverCAFile: "", - clientCAFile: "", - wantMTLS: false, - wantErr: false, - }, - { - name: "invalid server CA", - tlsConfig: &tls.Config{}, - serverCAFile: "/non/existent/server-ca.pem", - clientCAFile: caFile, - wantMTLS: false, - wantErr: true, - }, - { - name: "invalid client CA", - tlsConfig: &tls.Config{}, - serverCAFile: caFile, - clientCAFile: "/non/existent/client-ca.pem", - wantMTLS: false, - wantErr: true, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - mtls, err := ConfigureCertificateAuthorities(tt.tlsConfig, tt.serverCAFile, tt.clientCAFile) - if (err != nil) != tt.wantErr { - t.Errorf("ConfigureCertificateAuthorities() error = %v, wantErr %v", err, tt.wantErr) - return - } - - if mtls != tt.wantMTLS { - t.Errorf("ConfigureCertificateAuthorities() mtls = %v, want %v", mtls, tt.wantMTLS) - } - }) - } -} - -func TestSetupRegularTLS(t *testing.T) { - certPEM, keyPEM, err := generateTestCert() - if err != nil { - t.Fatalf("Failed to generate test cert: %v", err) - } - - certFile := createTempFile(t, certPEM) - keyFile := createTempFile(t, keyPEM) - caFile := createTempFile(t, certPEM) - defer func() { - os.Remove(certFile) - os.Remove(keyFile) - os.Remove(caFile) - }() - - tests := []struct { - name string - certFile string - keyFile string - serverCAFile string - clientCAFile string - wantMTLS bool - wantErr bool - expectedAuth tls.ClientAuthType - }{ - { - name: "regular TLS without mTLS", - certFile: certFile, - keyFile: keyFile, - serverCAFile: "", - clientCAFile: "", - wantMTLS: false, - wantErr: false, - expectedAuth: tls.NoClientCert, - }, - { - name: "TLS with mTLS", - certFile: certFile, - keyFile: keyFile, - serverCAFile: caFile, - clientCAFile: caFile, - wantMTLS: true, - wantErr: false, - expectedAuth: tls.RequireAndVerifyClientCert, - }, - { - name: "TLS with only server CA", - certFile: certFile, - keyFile: keyFile, - serverCAFile: caFile, - clientCAFile: "", - wantMTLS: false, - wantErr: false, - expectedAuth: tls.NoClientCert, - }, - { - name: "invalid certificate file", - certFile: "/non/existent/cert.pem", - keyFile: keyFile, - serverCAFile: "", - clientCAFile: "", - wantMTLS: false, - wantErr: true, - expectedAuth: tls.NoClientCert, - }, - { - name: "invalid key file", - certFile: certFile, - keyFile: "/non/existent/key.pem", - serverCAFile: "", - clientCAFile: "", - wantMTLS: false, - wantErr: true, - expectedAuth: tls.NoClientCert, - }, - { - name: "invalid server CA file", - certFile: certFile, - keyFile: keyFile, - serverCAFile: "/non/existent/server-ca.pem", - clientCAFile: "", - wantMTLS: false, - wantErr: true, - expectedAuth: tls.NoClientCert, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - result, err := SetupRegularTLS(tt.certFile, tt.keyFile, tt.serverCAFile, tt.clientCAFile) - if (err != nil) != tt.wantErr { - t.Errorf("SetupRegularTLS() error = %v, wantErr %v", err, tt.wantErr) - return - } - - if tt.wantErr { - return - } - - if result == nil { - t.Errorf("SetupRegularTLS() returned nil result") - return - } - - if result.MTLS != tt.wantMTLS { - t.Errorf("SetupRegularTLS() MTLS = %v, want %v", result.MTLS, tt.wantMTLS) - } - - if result.Config.ClientAuth != tt.expectedAuth { - t.Errorf("SetupRegularTLS() ClientAuth = %v, want %v", result.Config.ClientAuth, tt.expectedAuth) - } - - if len(result.Config.Certificates) == 0 { - t.Errorf("SetupRegularTLS() should have at least one certificate") - } - }) - } -} - -func TestBuildMTLSDescription(t *testing.T) { - tests := []struct { - name string - serverCAFile string - clientCAFile string - want string - }{ - { - name: "both server and client CA files", - serverCAFile: "/path/to/server-ca.pem", - clientCAFile: "/path/to/client-ca.pem", - want: "root ca /path/to/server-ca.pem client ca /path/to/client-ca.pem", - }, - { - name: "only server CA file", - serverCAFile: "/path/to/server-ca.pem", - clientCAFile: "", - want: "root ca /path/to/server-ca.pem", - }, - { - name: "only client CA file", - serverCAFile: "", - clientCAFile: "/path/to/client-ca.pem", - want: "client ca /path/to/client-ca.pem", - }, - { - name: "no CA files", - serverCAFile: "", - clientCAFile: "", - want: "", - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got := BuildMTLSDescription(tt.serverCAFile, tt.clientCAFile) - if got != tt.want { - t.Errorf("BuildMTLSDescription() = %v, want %v", got, tt.want) - } - }) - } -} - -func TestErrorConstants(t *testing.T) { - // Test that error constants are properly defined - if ErrAppendServerCA == nil { - t.Error("ErrAppendServerCA should not be nil") - } - - if ErrAppendClientCA == nil { - t.Error("ErrAppendClientCA should not be nil") - } - - if ErrAppendServerCA.Error() != "failed to append server ca to tls.Config" { - t.Errorf("ErrAppendServerCA message = %v, want 'failed to append server ca to tls.Config'", ErrAppendServerCA.Error()) - } - - if ErrAppendClientCA.Error() != "failed to append client ca to tls.Config" { - t.Errorf("ErrAppendClientCA message = %v, want 'failed to append client ca to tls.Config'", ErrAppendClientCA.Error()) - } -} - -func TestTLSSetupResult(t *testing.T) { - // Test that TLSSetupResult struct works as expected - config := &tls.Config{} - result := &TLSSetupResult{ - Config: config, - MTLS: true, - } - - if result.Config != config { - t.Error("TLSSetupResult Config field should match assigned value") - } - - if !result.MTLS { - t.Error("TLSSetupResult MTLS field should be true") - } -} - -func TestReadFileOrDataEdgeCases(t *testing.T) { - tests := []struct { - name string - input string - wantErr bool - }{ - { - name: "999 chars without newline (should try file)", - input: strings.Repeat("a", 999), - wantErr: true, // Should fail as file doesn't exist - }, - { - name: "1001 chars without newline (should treat as data)", - input: strings.Repeat("a", 1001), - wantErr: false, - }, - { - name: "short string with newline (should treat as data)", - input: "short\ndata", - wantErr: false, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - _, err := ReadFileOrData(tt.input) - if (err != nil) != tt.wantErr { - t.Errorf("ReadFileOrData() error = %v, wantErr %v", err, tt.wantErr) - } - }) - } -} diff --git a/test/cvms/main.go b/test/cvms/main.go index 5c523299..98b653ca 100644 --- a/test/cvms/main.go +++ b/test/cvms/main.go @@ -15,12 +15,12 @@ import ( "strings" mglog "github.com/absmach/supermq/logger" + smqserver "github.com/absmach/supermq/pkg/server" + grpcserver "github.com/absmach/supermq/pkg/server/grpc" "github.com/caarlos0/env/v11" "github.com/ultravioletrs/cocos/agent/cvms" cvmsgrpc "github.com/ultravioletrs/cocos/agent/cvms/api/grpc" "github.com/ultravioletrs/cocos/internal" - "github.com/ultravioletrs/cocos/pkg/server" - grpcserver "github.com/ultravioletrs/cocos/pkg/server/grpc" "golang.org/x/sync/errgroup" "google.golang.org/grpc" "google.golang.org/grpc/credentials" @@ -314,24 +314,22 @@ func main() { reflection.Register(srv) cvms.RegisterServiceServer(srv, cvmsgrpc.NewServer(incomingChan, &svc{logger: logger})) } - grpcServerConfig := server.ServerConfig{ - Config: server.Config{ - Port: defaultPort, - }, + grpcServerConfig := smqserver.Config{ + Port: defaultPort, } if err := env.ParseWithOptions(&grpcServerConfig, env.Options{}); err != nil { logger.Error(fmt.Sprintf("failed to load %s gRPC client configuration : %s", svcName, err)) return } - gs := grpcserver.New(ctx, cancel, svcName, grpcServerConfig, registerAgentServiceServer, logger, nil, nil) + gs := grpcserver.NewServer(ctx, cancel, svcName, grpcServerConfig, registerAgentServiceServer, logger) g.Go(func() error { return gs.Start() }) g.Go(func() error { - return server.StopHandler(ctx, cancel, logger, svcName, gs) + return smqserver.StopSignalHandler(ctx, cancel, logger, svcName, gs) }) if err := g.Wait(); err != nil {