NOISSUE - Implement structured logging with log forwarding for ingress-proxy and computation-runner, update component versions, and improve aTLS initialization and error handling. (#583)
CI / lint (push) Has been cancelled
CI / test (agent) (push) Has been cancelled
CI / test (cli) (push) Has been cancelled
CI / test (cmd) (push) Has been cancelled
CI / test (internal) (push) Has been cancelled
CI / test (manager, true) (push) Has been cancelled
CI / test (pkg) (push) Has been cancelled
CI / upload-coverage (push) Has been cancelled

* feat: Implement structured logging with log forwarding for `ingress-proxy` and `computation-runner`, update component versions, and improve aTLS initialization and error handling.

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* refactor: Remove explicit AGENT_ENABLE_ATLS configuration and update component versions.

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* fix: Correct aTLS nonce verification for truncated hashes, delegate internal CVM server TLS to Ingress Proxy, and update component versions.

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* chore: Update package build sources to ultravioletrs/cocos main branch and remove local development keys and encrypted algorithm.

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* Remove the `pkg/server` module, including its generic gRPC and HTTP server implementations.

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* chore: clarify nonce truncation in the certificate verifier.

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

---------

Signed-off-by: Sammy Oina <sammyoina@gmail.com>
This commit is contained in:
Sammy Kerata Oina
2026-03-23 21:05:15 +03:00
committed by GitHub
parent c1cbcec851
commit 42b05524c8
36 changed files with 282 additions and 2651 deletions
+53 -40
View File
@@ -4,17 +4,17 @@
package server
import (
context "context"
"fmt"
"log/slog"
"net"
"os"
"github.com/ultravioletrs/cocos/agent"
agentgrpc "github.com/ultravioletrs/cocos/agent/api/grpc"
"github.com/ultravioletrs/cocos/agent/auth"
"github.com/ultravioletrs/cocos/pkg/atls"
"github.com/ultravioletrs/cocos/pkg/server"
grpcserver "github.com/ultravioletrs/cocos/pkg/server/grpc"
"go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials/insecure"
"google.golang.org/grpc/reflection"
)
@@ -29,55 +29,68 @@ type AgentServer interface {
}
type agentServer struct {
gs server.Server
logger *slog.Logger
svc agent.Service
host string
certProvider atls.CertificateProvider
gs *grpc.Server
logger *slog.Logger
svc agent.Service
host string
}
func NewServer(logger *slog.Logger, svc agent.Service, host string, certProvider atls.CertificateProvider) AgentServer {
func NewServer(logger *slog.Logger, svc agent.Service, host string) AgentServer {
return &agentServer{
logger: logger,
svc: svc,
host: host,
certProvider: certProvider,
logger: logger,
svc: svc,
host: host,
}
}
func (as *agentServer) Start(cfg agent.AgentConfig, cmp agent.Computation) error {
agentGrpcServerConfig := server.AgentConfig{
ServerConfig: server.ServerConfig{
Config: server.Config{
Host: defSvcGRPCSocket,
Port: "",
CertFile: cfg.CertFile,
KeyFile: cfg.KeyFile,
ServerCAFile: cfg.ServerCAFile,
ClientCAFile: cfg.ClientCAFile,
},
},
AttestedTLS: cfg.AttestedTls,
}
registerAgentServiceServer := func(srv *grpc.Server) {
reflection.Register(srv)
agent.RegisterAgentServiceServer(srv, agentgrpc.NewServer(as.svc))
}
authSvc, err := auth.New(cmp)
if err != nil {
as.logger.WithGroup(cmp.ID).Error(fmt.Sprintf("failed to create auth service %s", err.Error()))
return err
}
ctx, cancel := context.WithCancel(context.Background())
grpcServerOptions := []grpc.ServerOption{
grpc.StatsHandler(otelgrpc.NewServerHandler()),
}
as.gs = grpcserver.New(ctx, cancel, svcName, agentGrpcServerConfig, registerAgentServiceServer, as.logger, authSvc, as.certProvider)
// Add authentication interceptors
unary, stream := agentgrpc.NewAuthInterceptor(authSvc)
grpcServerOptions = append(grpcServerOptions, grpc.UnaryInterceptor(unary))
grpcServerOptions = append(grpcServerOptions, grpc.StreamInterceptor(stream))
// Internal Unix socket is pure plaintext HTTP/2; Ingress Proxy handles external aTLS termination
grpcServerOptions = append(grpcServerOptions, grpc.Creds(insecure.NewCredentials()))
as.gs = grpc.NewServer(grpcServerOptions...)
reflection.Register(as.gs)
agent.RegisterAgentServiceServer(as.gs, agentgrpc.NewServer(as.svc))
socketPath := as.host
if socketPath == "" || socketPath == "0.0.0.0" {
socketPath = defSvcGRPCSocket
}
var listener net.Listener
if socketPath[0] == '/' || socketPath[0] == '.' {
// Remove existing socket file if it exists
_ = os.Remove(socketPath)
listener, err = net.Listen("unix", socketPath)
} else {
listener, err = net.Listen("tcp", socketPath)
}
if err != nil {
as.logger.Error(fmt.Sprintf("failed to listen on %s: %s", socketPath, err))
return err
}
as.logger.Info(fmt.Sprintf("agent service gRPC server listening at %s without TLS", socketPath))
go func() {
err := as.gs.Start()
if err != nil {
err := as.gs.Serve(listener)
if err != nil && err != grpc.ErrServerStopped {
as.logger.Error(fmt.Sprintf("failed to start grpc server %s", err.Error()))
}
}()
@@ -86,8 +99,8 @@ func (as *agentServer) Start(cfg agent.AgentConfig, cmp agent.Computation) error
}
func (as *agentServer) Stop() error {
if as.gs == nil {
return nil
if as.gs != nil {
as.gs.GracefulStop()
}
return as.gs.Stop()
return nil
}
+7 -7
View File
@@ -21,7 +21,7 @@ import (
func setupTest(t *testing.T) (*slog.Logger, *mocks.Service, string, []byte) {
logger := slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelError}))
mockSvc := new(mocks.Service)
host := "localhost"
host := "localhost:0"
privateKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
assert.NoError(t, err, "Failed to generate ECDSA key")
@@ -70,7 +70,7 @@ func TestNewServer(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
server := NewServer(tt.logger, tt.svc, tt.host, nil)
server := NewServer(tt.logger, tt.svc, tt.host)
assert.NotNil(t, server)
@@ -194,7 +194,7 @@ func TestAgentServer_Start(t *testing.T) {
t.Run(tt.name, func(t *testing.T) {
tt.setupMocks(svc)
server := NewServer(logger, svc, host, nil)
server := NewServer(logger, svc, host)
err := server.Start(tt.cfg, tt.cmp)
@@ -268,7 +268,7 @@ func TestAgentServer_Stop(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
server := NewServer(logger, svc, host, nil)
server := NewServer(logger, svc, host)
err := tt.setupServer(server)
if err != nil {
@@ -296,7 +296,7 @@ func TestAgentServer_Stop(t *testing.T) {
func TestAgentServer_StopMultipleTimes(t *testing.T) {
logger, svc, host, pubKey := setupTest(t)
server := NewServer(logger, svc, host, nil)
server := NewServer(logger, svc, host)
// Start the server
cfg := agent.AgentConfig{}
@@ -340,7 +340,7 @@ func TestAgentServer_StopMultipleTimes(t *testing.T) {
func TestAgentServer_StartAfterStop(t *testing.T) {
logger, svc, host, pubKey := setupTest(t)
server := NewServer(logger, svc, host, nil)
server := NewServer(logger, svc, host)
cfg := agent.AgentConfig{}
cmp := agent.Computation{
@@ -488,7 +488,7 @@ func TestAgentServer_ConfigValidation(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
server := NewServer(logger, svc, host, nil)
server := NewServer(logger, svc, host)
err := server.Start(tt.config, tt.cmp)
Binary file not shown.
+9 -6
View File
@@ -59,7 +59,6 @@ type config struct {
AgentOSDistro string `env:"AGENT_OS_DISTRO" envDefault:"UVC"`
AgentOSType string `env:"AGENT_OS_TYPE" envDefault:"UVC"`
AttestationServiceSocket string `env:"ATTESTATION_SERVICE_SOCKET" envDefault:"/run/cocos/attestation.sock"`
EnableATLS bool `env:"AGENT_ENABLE_ATLS" envDefault:"true"`
}
func main() {
@@ -139,6 +138,7 @@ func main() {
})
ccPlatform := attestation.CCPlatform()
logger.Info(fmt.Sprintf("Detected confidential computing platform: %v", ccPlatform))
azureConfig := azure.NewEnvConfigFromAgent(
cfg.AgentOSBuild,
@@ -209,7 +209,8 @@ func main() {
}
var certProvider atls.CertificateProvider
if cfg.EnableATLS && ccPlatform != attestation.NoCC {
if ccPlatform != attestation.NoCC {
logger.Info(fmt.Sprintf("Initializing aTLS for platform %v with attestation service at %s", ccPlatform, cfg.AttestationServiceSocket))
var certsSDK sdk.SDK
if cfg.CAUrl != "" {
certsSDK = sdk.NewSDK(sdk.Config{
@@ -218,10 +219,12 @@ func main() {
}
certProvider, err = atls.NewProvider(attClient, ccPlatform, cfg.CertsToken, cfg.CVMId, certsSDK)
if err != nil {
logger.Error(fmt.Sprintf("failed to create certificate provider: %s", err))
exitCode = 1
return
logger.Error(fmt.Sprintf("failed to create certificate provider for aTLS: %s. Continuing without attested TLS.", err))
} else {
logger.Info("Successfully created aTLS certificate provider")
}
} else {
logger.Warn("No Confidential Computing platform detected (NoCC). Certificate provider remains nil; aTLS will not be available for computations.")
}
// Create ingress proxy server
@@ -240,7 +243,7 @@ func main() {
return
}
mc, err := cvmsapi.NewClient(pc, svc, cvmsQueue, logger, server.NewServer(logger, svc, cfg.AgentGrpcHost, certProvider), ingressProxy, storageDir, reconnectFn, cvmGRPCClient)
mc, err := cvmsapi.NewClient(pc, svc, cvmsQueue, logger, server.NewServer(logger, svc, cfg.AgentGrpcHost), ingressProxy, storageDir, reconnectFn, cvmGRPCClient)
if err != nil {
logger.Error(err.Error())
exitCode = 1
+32 -2
View File
@@ -13,9 +13,12 @@ import (
mglog "github.com/absmach/supermq/logger"
"github.com/caarlos0/env/v11"
"github.com/ultravioletrs/cocos/agent/cvms"
logpb "github.com/ultravioletrs/cocos/agent/log"
pb "github.com/ultravioletrs/cocos/agent/runner"
runnerevents "github.com/ultravioletrs/cocos/agent/runner/events"
"github.com/ultravioletrs/cocos/agent/runner/service"
agentlogger "github.com/ultravioletrs/cocos/internal/logger"
logclient "github.com/ultravioletrs/cocos/pkg/clients/grpc/log"
"golang.org/x/sync/errgroup"
"google.golang.org/grpc"
@@ -51,16 +54,43 @@ func main() {
return
}
logger := slog.New(slog.NewJSONHandler(os.Stdout, &slog.HandlerOptions{Level: level}))
logQueue := make(chan *cvms.ClientStreamMessage, 1000)
handler := agentlogger.NewProtoHandler(os.Stdout, &slog.HandlerOptions{Level: level}, logQueue)
logger := slog.New(handler)
// Connect to Log Forwarder
logClient, err := logclient.NewClient(cfg.LogForwarder)
if err != nil {
logger.Warn(fmt.Sprintf("failed to connect to log-forwarder: %s. Events will not be forwarded.", err))
logger.Warn(fmt.Sprintf("failed to connect to log-forwarder: %s. Logs and events will not be forwarded.", err))
} else {
defer logClient.Close()
}
g.Go(func() error {
for {
select {
case <-ctx.Done():
return nil
case msg := <-logQueue:
if logClient == nil {
continue
}
switch m := msg.Message.(type) {
case *cvms.ClientStreamMessage_AgentLog:
err := logClient.SendLog(ctx, &logpb.LogEntry{
Message: m.AgentLog.Message,
ComputationId: m.AgentLog.ComputationId,
Level: m.AgentLog.Level,
Timestamp: m.AgentLog.Timestamp,
})
if err != nil {
logger.Error("failed to send log", "error", err)
}
}
}
}
})
eventSvc := runnerevents.NewAdapter(logClient, svcName)
// Remove existing socket if it exists
+46 -5
View File
@@ -5,14 +5,18 @@ package main
import (
"context"
"fmt"
"log/slog"
"os"
"os/signal"
"syscall"
mglog "github.com/absmach/supermq/logger"
"github.com/caarlos0/env/v11"
"github.com/spf13/cobra"
"github.com/spf13/pflag"
"github.com/ultravioletrs/cocos/agent/cvms"
logpb "github.com/ultravioletrs/cocos/agent/log"
agentlogger "github.com/ultravioletrs/cocos/internal/logger"
logclient "github.com/ultravioletrs/cocos/pkg/clients/grpc/log"
"github.com/ultravioletrs/cocos/pkg/egress"
"golang.org/x/sync/errgroup"
)
@@ -22,8 +26,9 @@ const (
)
type config struct {
Level string `env:"COCOS_LOG_LEVEL" envAlternate:"AGENT_LOG_LEVEL" envDefault:"info"`
Port string `env:"COCOS_PROXY_PORT" envDefault:"3128"`
Level string `env:"COCOS_LOG_LEVEL" envAlternate:"AGENT_LOG_LEVEL" envDefault:"info"`
Port string `env:"COCOS_PROXY_PORT" envDefault:"3128"`
LogForwarder string `env:"LOG_FORWARDER_SOCKET" envDefault:"/run/cocos/log.sock"`
}
func main() {
@@ -51,9 +56,20 @@ func main() {
}
func run(cfg config) error {
logger, err := mglog.New(os.Stdout, cfg.Level)
var level slog.Level
if err := level.UnmarshalText([]byte(cfg.Level)); err != nil {
return fmt.Errorf("invalid log level: %w", err)
}
logQueue := make(chan *cvms.ClientStreamMessage, 1000)
handler := agentlogger.NewProtoHandler(os.Stdout, &slog.HandlerOptions{Level: level}, logQueue)
logger := slog.New(handler)
logClient, err := logclient.NewClient(cfg.LogForwarder)
if err != nil {
return fmt.Errorf("failed to create logger: %w", err)
logger.Warn(fmt.Sprintf("failed to connect to log-forwarder: %s. Logs will not be forwarded.", err))
} else {
defer logClient.Close()
}
ctx, cancel := context.WithCancel(context.Background())
@@ -61,6 +77,31 @@ func run(cfg config) error {
g, ctx := errgroup.WithContext(ctx)
g.Go(func() error {
for {
select {
case <-ctx.Done():
return nil
case msg := <-logQueue:
if logClient == nil {
continue
}
switch m := msg.Message.(type) {
case *cvms.ClientStreamMessage_AgentLog:
err := logClient.SendLog(ctx, &logpb.LogEntry{
Message: m.AgentLog.Message,
ComputationId: m.AgentLog.ComputationId,
Level: m.AgentLog.Level,
Timestamp: m.AgentLog.Timestamp,
})
if err != nil {
logger.Error("failed to send log", "error", err)
}
}
}
}
})
proxy := egress.NewProxy(logger, ":"+cfg.Port)
g.Go(func() error {
+50 -9
View File
@@ -5,20 +5,24 @@ package main
import (
"context"
"fmt"
"log/slog"
"net/url"
"os"
"os/signal"
"syscall"
"github.com/absmach/certs/sdk"
mglog "github.com/absmach/supermq/logger"
"github.com/caarlos0/env/v11"
"github.com/spf13/cobra"
"github.com/spf13/pflag"
"github.com/ultravioletrs/cocos/agent/cvms"
logpb "github.com/ultravioletrs/cocos/agent/log"
agentlogger "github.com/ultravioletrs/cocos/internal/logger"
"github.com/ultravioletrs/cocos/pkg/atls"
"github.com/ultravioletrs/cocos/pkg/attestation"
"github.com/ultravioletrs/cocos/pkg/attestation/azure"
attestation_client "github.com/ultravioletrs/cocos/pkg/clients/grpc/attestation"
logclient "github.com/ultravioletrs/cocos/pkg/clients/grpc/log"
"github.com/ultravioletrs/cocos/pkg/ingress"
"golang.org/x/sync/errgroup"
)
@@ -39,6 +43,7 @@ type config struct {
AgentOSBuild string `env:"AGENT_OS_BUILD" envDefault:"UVC"`
AgentOSDistro string `env:"AGENT_OS_DISTRO" envDefault:"UVC"`
AgentOSType string `env:"AGENT_OS_TYPE" envDefault:"UVC"`
LogForwarder string `env:"LOG_FORWARDER_SOCKET" envDefault:"/run/cocos/log.sock"`
}
func main() {
@@ -66,11 +71,52 @@ func main() {
}
func run(cfg config) error {
logger, err := mglog.New(os.Stdout, cfg.LogLevel)
if err != nil {
return fmt.Errorf("failed to create logger: %w", err)
var level slog.Level
if err := level.UnmarshalText([]byte(cfg.LogLevel)); err != nil {
return fmt.Errorf("invalid log level: %w", err)
}
logQueue := make(chan *cvms.ClientStreamMessage, 1000)
handler := agentlogger.NewProtoHandler(os.Stdout, &slog.HandlerOptions{Level: level}, logQueue)
logger := slog.New(handler)
logClient, err := logclient.NewClient(cfg.LogForwarder)
if err != nil {
logger.Warn(fmt.Sprintf("failed to connect to log-forwarder: %s. Logs will not be forwarded.", err))
} else {
defer logClient.Close()
}
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
g, ctx := errgroup.WithContext(ctx)
g.Go(func() error {
for {
select {
case <-ctx.Done():
return nil
case msg := <-logQueue:
if logClient == nil {
continue
}
switch m := msg.Message.(type) {
case *cvms.ClientStreamMessage_AgentLog:
err := logClient.SendLog(ctx, &logpb.LogEntry{
Message: m.AgentLog.Message,
ComputationId: m.AgentLog.ComputationId,
Level: m.AgentLog.Level,
Timestamp: m.AgentLog.Timestamp,
})
if err != nil {
logger.Error("failed to send log", "error", err)
}
}
}
}
})
backendURL, err := url.Parse(cfg.Backend)
if err != nil {
return fmt.Errorf("failed to parse backend URL: %w", err)
@@ -111,11 +157,6 @@ func run(cfg config) error {
logger.Warn("No Confidential Computing platform detected. ATLS will not be available.")
}
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
g, ctx := errgroup.WithContext(ctx)
// Create proxy server (but don't start it yet - it will be started per-computation)
_ = ingress.NewProxyServer(logger, backendURL, certProvider)
+4 -5
View File
@@ -16,6 +16,7 @@ import (
"github.com/absmach/supermq/pkg/jaeger"
"github.com/absmach/supermq/pkg/prometheus"
smqserver "github.com/absmach/supermq/pkg/server"
grpcserver "github.com/absmach/supermq/pkg/server/grpc"
httpserver "github.com/absmach/supermq/pkg/server/http"
"github.com/absmach/supermq/pkg/uuid"
"github.com/caarlos0/env/v11"
@@ -26,8 +27,6 @@ import (
"github.com/ultravioletrs/cocos/manager/api/http"
"github.com/ultravioletrs/cocos/manager/qemu"
"github.com/ultravioletrs/cocos/manager/tracing"
"github.com/ultravioletrs/cocos/pkg/server"
grpcserver "github.com/ultravioletrs/cocos/pkg/server/grpc"
"go.opentelemetry.io/otel/trace"
"golang.org/x/sync/errgroup"
"google.golang.org/grpc"
@@ -113,7 +112,7 @@ func main() {
args := qemuCfg.ConstructQemuArgs()
logger.Info(strings.Join(args, " "))
managerGRPCConfig := server.ServerConfig{}
managerGRPCConfig := smqserver.Config{}
if err := env.ParseWithOptions(&managerGRPCConfig, env.Options{Prefix: envPrefixGRPC}); err != nil {
logger.Error(fmt.Sprintf("failed to load %s gRPC client configuration : %s", svcName, err))
exitCode = 1
@@ -145,7 +144,7 @@ func main() {
manager.RegisterManagerServiceServer(srv, managergrpc.NewServer(svc))
}
gs := grpcserver.New(ctx, cancel, svcName, managerGRPCConfig, registerManagerServiceServer, logger, nil, nil)
gs := grpcserver.NewServer(ctx, cancel, svcName, managerGRPCConfig, registerManagerServiceServer, logger)
hs := httpserver.NewServer(ctx, cancel, svcName, httpServerConfig, http.MakeHandler(chi.NewMux(), svcName, cfg.InstanceID), logger)
@@ -158,7 +157,7 @@ func main() {
})
g.Go(func() error {
return server.StopHandler(ctx, cancel, logger, svcName, gs, hs)
return smqserver.StopSignalHandler(ctx, cancel, logger, svcName, gs, hs)
})
if err := g.Wait(); err != nil {
-1
View File
@@ -1 +0,0 @@
bbf3a1198ee889f77a227fe01e329864fd6a37a2d23135ea8e2c5a2ebc07f0d3
+2 -2
View File
@@ -4,8 +4,8 @@
#
################################################################################
AGENT_VERSION = 913bbccf3a22053e1979da004c732007336fc890
AGENT_SITE = $(call github,sammyoina,cocos-ai,$(AGENT_VERSION))
AGENT_VERSION = main
AGENT_SITE = $(call github,ultravioletrs,cocos,$(AGENT_VERSION))
define AGENT_BUILD_CMDS
$(MAKE) -C $(@D) agent EMBED_ENABLED=$(AGENT_EMBED_ENABLED)
@@ -4,8 +4,8 @@
#
################################################################################
ATTESTATION_SERVICE_VERSION = 913bbccf3a22053e1979da004c732007336fc890
ATTESTATION_SERVICE_SITE = $(call github,sammyoina,cocos-ai,$(ATTESTATION_SERVICE_VERSION))
ATTESTATION_SERVICE_VERSION = main
ATTESTATION_SERVICE_SITE = $(call github,ultravioletrs,cocos,$(ATTESTATION_SERVICE_VERSION))
define ATTESTATION_SERVICE_BUILD_CMDS
$(MAKE) -C $(@D) attestation-service
@@ -4,8 +4,8 @@
#
################################################################################
COMPUTATION_RUNNER_VERSION = 913bbccf3a22053e1979da004c732007336fc890
COMPUTATION_RUNNER_SITE = $(call github,sammyoina,cocos-ai,$(COMPUTATION_RUNNER_VERSION))
COMPUTATION_RUNNER_VERSION = main
COMPUTATION_RUNNER_SITE = $(call github,ultravioletrs,cocos,$(COMPUTATION_RUNNER_VERSION))
define COMPUTATION_RUNNER_BUILD_CMDS
$(MAKE) -C $(@D) computation-runner
@@ -4,8 +4,8 @@
#
################################################################################
EGRESS_PROXY_VERSION = 913bbccf3a22053e1979da004c732007336fc890
EGRESS_PROXY_SITE = $(call github,sammyoina,cocos-ai,$(EGRESS_PROXY_VERSION))
EGRESS_PROXY_VERSION = main
EGRESS_PROXY_SITE = $(call github,ultravioletrs,cocos,$(EGRESS_PROXY_VERSION))
define EGRESS_PROXY_BUILD_CMDS
$(MAKE) -C $(@D) egress-proxy
@@ -4,8 +4,8 @@
#
################################################################################
INGRESS_PROXY_VERSION = 913bbccf3a22053e1979da004c732007336fc890
INGRESS_PROXY_SITE = $(call github,sammyoina,cocos-ai,$(INGRESS_PROXY_VERSION))
INGRESS_PROXY_VERSION = main
INGRESS_PROXY_SITE = $(call github,ultravioletrs,cocos,$(INGRESS_PROXY_VERSION))
define INGRESS_PROXY_BUILD_CMDS
$(MAKE) -C $(@D) ingress-proxy
@@ -4,8 +4,8 @@
#
################################################################################
LOG_FORWARDER_VERSION = 913bbccf3a22053e1979da004c732007336fc890
LOG_FORWARDER_SITE = $(call github,sammyoina,cocos-ai,$(LOG_FORWARDER_VERSION))
LOG_FORWARDER_VERSION = main
LOG_FORWARDER_SITE = $(call github,ultravioletrs,cocos,$(LOG_FORWARDER_VERSION))
define LOG_FORWARDER_BUILD_CMDS
$(MAKE) -C $(@D) log-forwarder
-1
View File
@@ -9,7 +9,6 @@ WorkingDirectory=/cocos
Environment="HTTP_PROXY=http://localhost:3128"
Environment="HTTPS_PROXY=http://localhost:3128"
Environment="NO_PROXY=localhost,127.0.0.1,.local,/run/cocos/"
Environment="AGENT_ENABLE_ATLS=false"
StandardOutput=file:/var/log/cocos/agent.stdout
StandardError=file:/var/log/cocos/agent.stderr
EnvironmentFile=/etc/cocos/environment
-3
View File
@@ -1,3 +0,0 @@
-----BEGIN PRIVATE KEY-----
MC4CAQAwBQYDK2VwBCIEICgJcXfNueGCu8jFFNGBXm9r25OGBEc0OEqCUVjyI4fY
-----END PRIVATE KEY-----
-3
View File
@@ -1,3 +0,0 @@
-----BEGIN PUBLIC KEY-----
MCowBQYDK2VwAyEAPbPOfwsJkxpNBluGOg/lgNVE/o0AEM7J11wvkXvHXSw=
-----END PUBLIC KEY-----
+33 -8
View File
@@ -6,6 +6,7 @@ import (
"crypto/x509"
"encoding/asn1"
"fmt"
"log/slog"
"os"
"time"
@@ -36,20 +37,35 @@ func NewCertificateVerifier(rootCAs *x509.CertPool) CertificateVerifier {
}
func (v *certificateVerifier) VerifyPeerCertificate(rawCerts [][]byte, _ [][]*x509.Certificate, nonce []byte) error {
slog.Debug("Starting peer certificate verification for aTLS")
if len(rawCerts) == 0 {
return fmt.Errorf("no certificates provided")
err := fmt.Errorf("no certificates provided")
slog.Error("aTLS handshake failed", "reason", err.Error())
return err
}
cert, err := x509.ParseCertificate(rawCerts[0])
if err != nil {
return fmt.Errorf("failed to parse x509 certificate: %w", err)
err = fmt.Errorf("failed to parse x509 certificate: %w", err)
slog.Error("aTLS handshake failed", "reason", err.Error())
return err
}
slog.Debug("Successfully parsed peer x509 certificate", "subject", cert.Subject.String())
if err := v.verifyCertificateSignature(cert); err != nil {
return fmt.Errorf("certificate signature verification failed: %w", err)
err = fmt.Errorf("certificate signature verification failed: %w", err)
slog.Error("aTLS handshake failed", "reason", err.Error())
return err
}
slog.Debug("Successfully verified peer certificate signature")
return v.verifyAttestationExtension(cert, nonce)
err = v.verifyAttestationExtension(cert, nonce)
if err != nil {
slog.Error("aTLS handshake failed", "reason", err.Error())
return err
}
slog.Debug("Successfully verified aTLS attestation extension")
return nil
}
func (v *certificateVerifier) verifyCertificateSignature(cert *x509.Certificate) error {
@@ -71,6 +87,7 @@ func (v *certificateVerifier) verifyCertificateSignature(cert *x509.Certificate)
func (v *certificateVerifier) verifyAttestationExtension(cert *x509.Certificate, nonce []byte) error {
for _, ext := range cert.Extensions {
if platformType, err := platformTypeFromOID(ext.Id); err == nil {
slog.Debug("Found attestation extension in peer certificate", "platform_type", platformType)
pubKeyDER, err := x509.MarshalPKIXPublicKey(cert.PublicKey)
if err != nil {
return fmt.Errorf("failed to marshal public key: %w", err)
@@ -93,22 +110,29 @@ func (v *certificateVerifier) verifyCertificateExtension(extension []byte, pubKe
// Verify nonce matches
teeNonce := append(pubKey, nonce...)
hashNonce := sha3.Sum512(teeNonce)
// The attestation provider truncates the 64-byte hash to 32 bytes for the EAT token nonce claim
// This matches the Attestation Service API and standard cryptographic nonce sizes.
expectedNonce := hashNonce[:32]
// Compare nonces (EAT nonce should match our computed nonce)
if len(claims.Nonce) != len(hashNonce) {
return fmt.Errorf("nonce length mismatch: expected %d, got %d", len(hashNonce), len(claims.Nonce))
if len(claims.Nonce) != len(expectedNonce) {
err := fmt.Errorf("nonce length mismatch: expected %d, got %d", len(expectedNonce), len(claims.Nonce))
slog.Error("aTLS handshake failed", "reason", err.Error())
return err
}
nonceMatch := true
for i := range claims.Nonce {
if claims.Nonce[i] != hashNonce[i] {
if claims.Nonce[i] != expectedNonce[i] {
nonceMatch = false
break
}
}
if !nonceMatch {
return fmt.Errorf("nonce mismatch in EAT token")
err := fmt.Errorf("nonce mismatch in EAT token")
slog.Error("aTLS handshake failed", "reason", err.Error())
return err
}
// Get platform verifier
@@ -151,6 +175,7 @@ func (v *certificateVerifier) verifyCertificateExtension(extension []byte, pubKe
return fmt.Errorf("failed to verify attestation with CoRIM: %w", err)
}
slog.Debug("CoRIM verification passed for aTLS peer certificate")
return nil
}
+5 -5
View File
@@ -79,7 +79,7 @@ func TestVerifyPeerCertificate_Success(t *testing.T) {
hashNonce := sha3.Sum512(teeNonce)
claims := eat.EATClaims{
Nonce: hashNonce[:],
Nonce: hashNonce[:32],
RawReport: []byte("mock-report"),
}
eatBytes, err := cbor.Marshal(claims)
@@ -155,7 +155,7 @@ func TestVerifyPeerCertificate_AzureSuccess(t *testing.T) {
teeNonce := append(peerPubKeyDER, nonce...)
hashNonce := sha3.Sum512(teeNonce)
claims := eat.EATClaims{Nonce: hashNonce[:], RawReport: []byte("rep")}
claims := eat.EATClaims{Nonce: hashNonce[:32], RawReport: []byte("rep")}
eatBytes, _ := cbor.Marshal(claims)
peerTemplate := &x509.Certificate{
@@ -210,7 +210,7 @@ func TestVerifyPeerCertificate_TDXSuccess(t *testing.T) {
teeNonce := append(peerPubKeyDER, nonce...)
hashNonce := sha3.Sum512(teeNonce)
claims := eat.EATClaims{Nonce: hashNonce[:], RawReport: []byte("rep")}
claims := eat.EATClaims{Nonce: hashNonce[:32], RawReport: []byte("rep")}
eatBytes, _ := cbor.Marshal(claims)
peerTemplate := &x509.Certificate{
@@ -273,7 +273,7 @@ func TestVerifyPeerCertificate_Failures_More(t *testing.T) {
nonce := []byte("nonce")
teeNonce := append(peerPubKeyDER, nonce...)
hashNonce := sha3.Sum512(teeNonce)
claims := eat.EATClaims{Nonce: hashNonce[:], RawReport: []byte("rep")}
claims := eat.EATClaims{Nonce: hashNonce[:32], RawReport: []byte("rep")}
eatBytes, _ := cbor.Marshal(claims)
peerTemplate.ExtraExtensions = []pkix.Extension{{Id: SNPvTPMOID, Value: eatBytes}}
certDERWithExt, _ := x509.CreateCertificate(rand.Reader, peerTemplate, caCert, &peerKey.PublicKey, caKey)
@@ -323,7 +323,7 @@ func TestVerifyPeerCertificate_Failures_Ext(t *testing.T) {
peerPubKeyDER, _ := x509.MarshalPKIXPublicKey(&peerKey.PublicKey)
wrongTeeNonce := append(peerPubKeyDER, []byte("wrong-nonce")...)
wrongHashNonce := sha3.Sum512(wrongTeeNonce)
claims.Nonce = wrongHashNonce[:]
claims.Nonce = wrongHashNonce[:32]
eatBytes, _ = cbor.Marshal(claims)
peerTemplate.ExtraExtensions = []pkix.Extension{{Id: SNPvTPMOID, Value: eatBytes}}
certDER, _ = x509.CreateCertificate(rand.Reader, peerTemplate, caCert, &peerKey.PublicKey, caKey)
+12 -6
View File
@@ -45,7 +45,7 @@ func (c *client) SendLog(ctx context.Context, entry *log.LogEntry) error {
}
// Retry with exponential backoff for concurrent request handling
maxRetries := 3
maxRetries := 10
for attempt := 0; attempt < maxRetries; attempt++ {
ctx, cancel := context.WithTimeout(ctx, 5*time.Second)
_, err := c.client.SendLog(ctx, entry)
@@ -57,8 +57,11 @@ func (c *client) SendLog(ctx context.Context, entry *log.LogEntry) error {
// Don't retry on last attempt
if attempt < maxRetries-1 {
// Exponential backoff: 10ms, 20ms, 40ms
backoff := time.Duration(10*(1<<uint(attempt))) * time.Millisecond
// Backoff: 100ms, 200ms, 400ms... max 2s
backoff := time.Duration(100*(1<<uint(attempt))) * time.Millisecond
if backoff > 2*time.Second {
backoff = 2 * time.Second
}
time.Sleep(backoff)
}
}
@@ -76,7 +79,7 @@ func (c *client) SendEvent(ctx context.Context, entry *log.EventEntry) error {
}
// Retry with exponential backoff for concurrent request handling
maxRetries := 3
maxRetries := 10
for attempt := 0; attempt < maxRetries; attempt++ {
ctx, cancel := context.WithTimeout(ctx, 5*time.Second)
_, err := c.client.SendEvent(ctx, entry)
@@ -88,8 +91,11 @@ func (c *client) SendEvent(ctx context.Context, entry *log.EventEntry) error {
// Don't retry on last attempt
if attempt < maxRetries-1 {
// Exponential backoff: 10ms, 20ms, 40ms
backoff := time.Duration(10*(1<<uint(attempt))) * time.Millisecond
// Backoff: 100ms, 200ms, 400ms... max 2s
backoff := time.Duration(100*(1<<uint(attempt))) * time.Millisecond
if backoff > 2*time.Second {
backoff = 2 * time.Second
}
time.Sleep(backoff)
}
}
+8 -8
View File
@@ -488,8 +488,8 @@ func TestClientSendLogAllRetriesFail(t *testing.T) {
grpcServer := grpc.NewServer()
mockServer := &retryMockLogCollectorServer{
failCount: 10, // Fail all attempts
maxFailCount: 10,
failCount: 20, // Fail all attempts
maxFailCount: 20,
}
log.RegisterLogCollectorServer(grpcServer, mockServer)
@@ -513,8 +513,8 @@ func TestClientSendLogAllRetriesFail(t *testing.T) {
// Should fail after all retries
err = client.SendLog(ctx, entry)
assert.Error(t, err)
// 3 retries + 1 final attempt = 4 calls
assert.Equal(t, 4, mockServer.callCount)
// 10 retries + 1 final attempt = 11 calls
assert.Equal(t, 11, mockServer.callCount)
}
func TestClientSendEventAllRetriesFail(t *testing.T) {
@@ -527,8 +527,8 @@ func TestClientSendEventAllRetriesFail(t *testing.T) {
grpcServer := grpc.NewServer()
mockServer := &retryMockLogCollectorServer{
failCount: 10,
maxFailCount: 10,
failCount: 20,
maxFailCount: 20,
}
log.RegisterLogCollectorServer(grpcServer, mockServer)
@@ -551,8 +551,8 @@ func TestClientSendEventAllRetriesFail(t *testing.T) {
// Should fail after all retries
err = client.SendEvent(ctx, entry)
assert.Error(t, err)
// 3 retries + 1 final attempt = 4 calls
assert.Equal(t, 4, mockServer.eventCallCount)
// 10 retries + 1 final attempt = 11 calls
assert.Equal(t, 11, mockServer.eventCallCount)
}
// retryMockLogCollectorServer is a mock server that fails a specified number of times.
+3 -4
View File
@@ -14,7 +14,6 @@ import (
"sync"
"github.com/ultravioletrs/cocos/pkg/atls"
"github.com/ultravioletrs/cocos/pkg/server"
"golang.org/x/net/http2"
"golang.org/x/net/http2/h2c"
)
@@ -138,7 +137,7 @@ func (p *proxyServer) Start(cfg ProxyConfig, ctx ProxyContext) error {
if cfg.AttestedTLS {
if p.certProvider == nil {
return fmt.Errorf("attested TLS requested but no certificate provider available")
return fmt.Errorf("attested TLS requested for ingress proxy but no certificate provider available. Please ensure a CC platform is detected (not NoCC), aTLS is enabled, and the attestation service is running")
}
tlsConfig = &tls.Config{
GetCertificate: p.certProvider.GetCertificate,
@@ -146,7 +145,7 @@ func (p *proxyServer) Start(cfg ProxyConfig, ctx ProxyContext) error {
NextProtos: []string{"h2", "http/1.1"},
}
mtls, err := server.ConfigureCertificateAuthorities(tlsConfig, cfg.ServerCAFile, cfg.ClientCAFile)
mtls, err := ConfigureCertificateAuthorities(tlsConfig, cfg.ServerCAFile, cfg.ClientCAFile)
if err != nil {
return fmt.Errorf("failed to configure certificate authorities: %w", err)
}
@@ -159,7 +158,7 @@ func (p *proxyServer) Start(cfg ProxyConfig, ctx ProxyContext) error {
}
} else if cfg.CertFile != "" && cfg.KeyFile != "" {
// Regular TLS
tlsSetup, err := server.SetupRegularTLS(cfg.CertFile, cfg.KeyFile, cfg.ServerCAFile, cfg.ClientCAFile)
tlsSetup, err := SetupRegularTLS(cfg.CertFile, cfg.KeyFile, cfg.ServerCAFile, cfg.ClientCAFile)
if err != nil {
return fmt.Errorf("failed to setup TLS: %w", err)
}
+1 -1
View File
@@ -402,7 +402,7 @@ func TestProxyAttestedTLSMissingProvider(t *testing.T) {
err := ps.Start(cfg, ctx)
assert.Error(t, err)
assert.Equal(t, "attested TLS requested but no certificate provider available", err.Error())
assert.Equal(t, "attested TLS requested for ingress proxy but no certificate provider available. Please ensure a CC platform is detected (not NoCC), aTLS is enabled, and the attestation service is running", err.Error())
}
func TestProxyMTLS(t *testing.T) {
+1 -16
View File
@@ -1,7 +1,7 @@
// Copyright (c) Ultraviolet
// SPDX-License-Identifier: Apache-2.0
package server
package ingress
import (
"crypto/tls"
@@ -144,18 +144,3 @@ func SetupRegularTLS(certFile, keyFile, serverCAFile, clientCAFile string) (*TLS
return &TLSSetupResult{Config: tlsConfig, MTLS: mtls}, nil
}
// BuildMTLSDescription builds a description string for mTLS configuration.
func BuildMTLSDescription(serverCAFile, clientCAFile string) string {
var parts []string
if serverCAFile != "" {
parts = append(parts, fmt.Sprintf("root ca %s", serverCAFile))
}
if clientCAFile != "" {
parts = append(parts, fmt.Sprintf("client ca %s", clientCAFile))
}
return strings.Join(parts, " ")
}
-5
View File
@@ -1,5 +0,0 @@
// Copyright (c) Ultraviolet
// SPDX-License-Identifier: Apache-2.0
// Package server contains the gRPC server implementation.
package server
-5
View File
@@ -1,5 +0,0 @@
// Copyright (c) Ultraviolet
// SPDX-License-Identifier: Apache-2.0
// Package grpc contains the gRPC server implementation.
package grpc
-265
View File
@@ -1,265 +0,0 @@
// Copyright (c) Ultraviolet
// SPDX-License-Identifier: Apache-2.0
package grpc
import (
"context"
"crypto/tls"
"fmt"
"log/slog"
"net"
"os"
"sync"
"time"
agentgrpc "github.com/ultravioletrs/cocos/agent/api/grpc"
"github.com/ultravioletrs/cocos/agent/auth"
"github.com/ultravioletrs/cocos/pkg/atls"
"github.com/ultravioletrs/cocos/pkg/server"
"go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/credentials/insecure"
"google.golang.org/grpc/health"
grpchealth "google.golang.org/grpc/health/grpc_health_v1"
)
const (
stopWaitTime = 5 * time.Second
)
type Server struct {
server.BaseServer
mu sync.RWMutex
server *grpc.Server
health *health.Server
registerService serviceRegister
authSvc auth.Authenticator
certProvider atls.CertificateProvider
attestedTLSEnabled bool
started bool
stopped bool
}
type serviceRegister func(srv *grpc.Server)
var _ server.Server = (*Server)(nil)
func New(
ctx context.Context, cancel context.CancelFunc, name string, config server.ServerConfiguration,
registerService serviceRegister, logger *slog.Logger, authSvc auth.Authenticator, certProvider atls.CertificateProvider,
) server.Server {
base := config.GetBaseConfig()
listenFullAddress := fmt.Sprintf("%s:%s", base.Host, base.Port)
var attestedTLS bool
if agentConfig, ok := config.(server.AgentConfig); ok && agentConfig.AttestedTLS {
if certProvider == nil {
logger.Error("Failed to create certificate provider")
} else {
attestedTLS = true
}
}
return &Server{
BaseServer: server.BaseServer{
Ctx: ctx,
Cancel: cancel,
Name: name,
Address: listenFullAddress,
Config: config,
Logger: logger,
},
registerService: registerService,
authSvc: authSvc,
certProvider: certProvider,
attestedTLSEnabled: attestedTLS,
}
}
func (s *Server) Start() error {
s.mu.Lock()
if s.started {
s.mu.Unlock()
return fmt.Errorf("server already started")
}
if s.stopped {
s.mu.Unlock()
return fmt.Errorf("server already stopped")
}
s.started = true
s.mu.Unlock()
errCh := make(chan error)
grpcServerOptions := []grpc.ServerOption{
grpc.StatsHandler(otelgrpc.NewServerHandler()),
}
// Add authentication interceptors if auth service is available
if s.authSvc != nil {
unary, stream := agentgrpc.NewAuthInterceptor(s.authSvc)
grpcServerOptions = append(grpcServerOptions, grpc.UnaryInterceptor(unary))
grpcServerOptions = append(grpcServerOptions, grpc.StreamInterceptor(stream))
}
// Configure credentials
creds, err := s.configureCredentials()
if err != nil {
return fmt.Errorf("failed to configure credentials: %w", err)
}
grpcServerOptions = append(grpcServerOptions, creds)
// Create listener - detect Unix socket vs TCP
var listener net.Listener
baseConfig := s.Config.GetBaseConfig()
// Check if this is a Unix socket path (starts with /)
if len(baseConfig.Host) > 0 && baseConfig.Host[0] == '/' {
// Unix socket
// Remove existing socket file if it exists
_ = os.Remove(baseConfig.Host)
listener, err = net.Listen("unix", baseConfig.Host)
if err != nil {
return fmt.Errorf("failed to listen on Unix socket %s: %w", baseConfig.Host, err)
}
} else {
// TCP socket
listener, err = net.Listen("tcp", s.Address)
if err != nil {
return fmt.Errorf("failed to listen on port %s: %w", s.Address, err)
}
}
// Create and configure server
s.mu.Lock()
s.server = grpc.NewServer(grpcServerOptions...)
s.health = health.NewServer()
grpchealth.RegisterHealthServer(s.server, s.health)
s.registerService(s.server)
s.health.SetServingStatus(s.Name, grpchealth.HealthCheckResponse_SERVING)
s.mu.Unlock()
// Start server
go func() {
s.mu.RLock()
server := s.server
s.mu.RUnlock()
if server != nil {
errCh <- server.Serve(listener)
}
}()
select {
case <-s.Ctx.Done():
return s.Stop()
case err := <-errCh:
s.Cancel()
return err
}
}
func (s *Server) configureCredentials() (grpc.ServerOption, error) {
baseConfig := s.Config.GetBaseConfig()
// Check if attested TLS should be used
if s.shouldUseAttestedTLS() {
return s.configureAttestedTLS(baseConfig.Config)
}
// Check if regular TLS should be used
if s.shouldUseRegularTLS(baseConfig.Config) {
return s.configureRegularTLS(baseConfig.Config)
}
// Use insecure credentials
// Determine address for logging
addr := s.Address
if len(baseConfig.Host) > 0 && baseConfig.Host[0] == '/' {
addr = baseConfig.Host
}
s.Logger.Info(fmt.Sprintf("%s service gRPC server listening at %s without TLS", s.Name, addr))
return grpc.Creds(insecure.NewCredentials()), nil
}
func (s *Server) shouldUseAttestedTLS() bool {
return s.attestedTLSEnabled && s.certProvider != nil
}
func (s *Server) shouldUseRegularTLS(config server.Config) bool {
return config.CertFile != "" || config.KeyFile != ""
}
func (s *Server) configureAttestedTLS(config server.Config) (grpc.ServerOption, error) {
tlsConfig := &tls.Config{
ClientAuth: tls.NoClientCert,
GetCertificate: s.certProvider.GetCertificate,
}
mtls, err := server.ConfigureCertificateAuthorities(tlsConfig, config.ServerCAFile, config.ClientCAFile)
if err != nil {
return nil, fmt.Errorf("failed to configure certificate authorities: %w", err)
}
if mtls {
tlsConfig.ClientAuth = tls.RequireAndVerifyClientCert
s.Logger.Info(fmt.Sprintf("%s service gRPC server listening at %s with Attested mTLS", s.Name, s.Address))
} else {
s.Logger.Info(fmt.Sprintf("%s service gRPC server listening at %s with Attested TLS", s.Name, s.Address))
}
return grpc.Creds(credentials.NewTLS(tlsConfig)), nil
}
func (s *Server) configureRegularTLS(config server.Config) (grpc.ServerOption, error) {
tlsSetup, err := server.SetupRegularTLS(config.CertFile, config.KeyFile, config.ServerCAFile, config.ClientCAFile)
if err != nil {
return nil, fmt.Errorf("failed to setup TLS: %w", err)
}
if tlsSetup.MTLS {
mtlsCA := server.BuildMTLSDescription(config.ServerCAFile, config.ClientCAFile)
s.Logger.Info(fmt.Sprintf("%s service gRPC server listening at %s with TLS/mTLS cert %s , key %s and %s",
s.Name, s.Address, config.CertFile, config.KeyFile, mtlsCA))
} else {
s.Logger.Info(fmt.Sprintf("%s service gRPC server listening at %s with TLS cert %s and key %s",
s.Name, s.Address, config.CertFile, config.KeyFile))
}
return grpc.Creds(credentials.NewTLS(tlsSetup.Config)), nil
}
func (s *Server) Stop() error {
s.mu.Lock()
defer s.mu.Unlock()
if s.stopped {
return nil
}
s.stopped = true
defer s.Cancel()
c := make(chan bool)
go func() {
defer close(c)
if s.health != nil {
s.health.Shutdown()
}
if s.server != nil {
s.server.GracefulStop()
}
}()
select {
case <-c:
case <-time.After(stopWaitTime):
}
s.Logger.Info(fmt.Sprintf("%s gRPC service shutdown at %s", s.Name, s.Address))
return nil
}
-527
View File
@@ -1,527 +0,0 @@
// Copyright (c) Ultraviolet
// SPDX-License-Identifier: Apache-2.0
package grpc
import (
"context"
"crypto/rand"
"crypto/rsa"
"crypto/x509"
"crypto/x509/pkix"
"encoding/pem"
"fmt"
"log/slog"
"math/big"
"os"
"strings"
"sync"
"testing"
"time"
"github.com/stretchr/testify/assert"
authmocks "github.com/ultravioletrs/cocos/agent/auth/mocks"
"github.com/ultravioletrs/cocos/pkg/atls/mocks"
"github.com/ultravioletrs/cocos/pkg/attestation/vtpm"
"github.com/ultravioletrs/cocos/pkg/server"
"google.golang.org/grpc"
"google.golang.org/grpc/test/bufconn"
)
const bufSize = 1024 * 1024
var lis *bufconn.Listener
func init() {
lis = bufconn.Listen(bufSize)
}
func TestNew(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
config := server.AgentConfig{
ServerConfig: server.ServerConfig{
Config: server.Config{
Host: "localhost",
Port: "50051",
},
},
}
logger := slog.Default()
authSvc := new(authmocks.Authenticator)
srv := New(ctx, cancel, "TestServer", config, func(srv *grpc.Server) {}, logger, authSvc, nil)
assert.NotNil(t, srv)
assert.IsType(t, &Server{}, srv)
}
func TestServerStartWithTLSFile(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
cert, key, err := generateSelfSignedCert()
assert.NoError(t, err)
certFile, err := os.CreateTemp("", "cert*.pem")
assert.NoError(t, err)
keyFile, err := os.CreateTemp("", "key*.pem")
assert.NoError(t, err)
t.Cleanup(func() {
os.Remove(certFile.Name())
os.Remove(keyFile.Name())
})
_, err = certFile.Write(cert)
assert.NoError(t, err)
_, err = keyFile.Write(key)
assert.NoError(t, err)
err = certFile.Close()
assert.NoError(t, err)
err = keyFile.Close()
assert.NoError(t, err)
config := server.AgentConfig{
ServerConfig: server.ServerConfig{
Config: server.Config{
Host: "localhost",
Port: "0",
CertFile: certFile.Name(),
KeyFile: keyFile.Name(),
},
},
}
logBuffer := &ThreadSafeBuffer{}
logger := slog.New(slog.NewTextHandler(logBuffer, &slog.HandlerOptions{Level: slog.LevelDebug}))
authSvc := new(authmocks.Authenticator)
srv := New(ctx, cancel, "TestServer", config, func(srv *grpc.Server) {}, logger, authSvc, nil)
var wg sync.WaitGroup
wg.Add(1)
go func() {
wg.Done()
err := srv.Start()
assert.NoError(t, err)
}()
wg.Wait()
time.Sleep(200 * time.Millisecond)
cancel()
time.Sleep(200 * time.Millisecond)
logContent := logBuffer.String()
fmt.Println(logContent)
assert.Contains(t, logContent, "TestServer service gRPC server listening at localhost:0 with TLS")
}
func TestServerStartWithmTLSFile(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
caCertFile, clientCertFile, clientKeyFile, err := createCertificatesFiles()
assert.NoError(t, err)
config := server.AgentConfig{
ServerConfig: server.ServerConfig{
Config: server.Config{
Host: "localhost",
Port: "0",
CertFile: string(clientCertFile),
KeyFile: string(clientKeyFile),
ServerCAFile: caCertFile,
},
},
}
logBuffer := &ThreadSafeBuffer{}
logger := slog.New(slog.NewTextHandler(logBuffer, &slog.HandlerOptions{Level: slog.LevelDebug}))
authSvc := new(authmocks.Authenticator)
srv := New(ctx, cancel, "TestServer", config, func(srv *grpc.Server) {}, logger, authSvc, nil)
var wg sync.WaitGroup
wg.Add(1)
go func() {
wg.Done()
err := srv.Start()
assert.NoError(t, err)
}()
wg.Wait()
time.Sleep(200 * time.Millisecond)
cancel()
time.Sleep(200 * time.Millisecond)
logContent := logBuffer.String()
fmt.Println(logContent)
assert.Contains(t, logContent, "TestServer service gRPC server listening at localhost:0 with TLS")
}
func TestServerStop(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
config := server.AgentConfig{
ServerConfig: server.ServerConfig{
Config: server.Config{
Host: "localhost",
Port: "0",
},
},
}
buf := &ThreadSafeBuffer{}
logger := slog.New(slog.NewTextHandler(buf, &slog.HandlerOptions{Level: slog.LevelDebug}))
authSvc := new(authmocks.Authenticator)
srv := New(ctx, cancel, "TestServer", config, func(srv *grpc.Server) {}, logger, authSvc, nil)
go func() {
err := srv.Start()
assert.NoError(t, err)
}()
time.Sleep(100 * time.Millisecond)
cancel()
time.Sleep(100 * time.Millisecond)
err := srv.Stop()
assert.NoError(t, err)
assert.Contains(t, buf.String(), "TestServer gRPC service shutdown at localhost:0")
}
func generateSelfSignedCert() ([]byte, []byte, error) {
key, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
return nil, nil, err
}
cert, err := generateSelfSignedCertFromKey(key)
if err != nil {
return nil, nil, err
}
return cert, pem.EncodeToMemory(&pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(key)}), nil
}
func generateSelfSignedCertFromKey(key *rsa.PrivateKey) ([]byte, error) {
template := x509.Certificate{
SerialNumber: big.NewInt(1),
Subject: pkix.Name{
Organization: []string{"Test"},
},
NotBefore: time.Now(),
NotAfter: time.Now().AddDate(1, 0, 0),
KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
BasicConstraintsValid: true,
}
certBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, &key.PublicKey, key)
if err != nil {
return nil, err
}
return pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: certBytes}), nil
}
type ThreadSafeBuffer struct {
buffer strings.Builder
mu sync.Mutex
}
func (b *ThreadSafeBuffer) Write(p []byte) (n int, err error) {
b.mu.Lock()
defer b.mu.Unlock()
return b.buffer.Write(p)
}
func (b *ThreadSafeBuffer) String() string {
b.mu.Lock()
defer b.mu.Unlock()
return b.buffer.String()
}
func TestServerInitializationAndStartup(t *testing.T) {
vtpm.ExternalTPM = &vtpm.DummyRWC{}
testCases := []struct {
name string
config server.AgentConfig
expectedLog string
expectError bool
setupCallback func(*testing.T, *server.AgentConfig, *ThreadSafeBuffer)
}{
{
name: "Non-TLS Server Startup",
config: server.AgentConfig{
ServerConfig: server.ServerConfig{
Config: server.Config{
Host: "localhost",
Port: "0",
},
},
},
expectedLog: "TestServer service gRPC server listening at localhost:0 without TLS",
},
{
name: "TLS Server Startup with Self-Signed Certificate",
config: server.AgentConfig{
ServerConfig: server.ServerConfig{
Config: server.Config{
Host: "localhost",
Port: "0",
},
},
},
setupCallback: setupTLSConfig,
expectedLog: "TestServer service gRPC server listening at localhost:0 with TLS",
},
{
name: "TLS Server Startup with Invalid Certificates",
config: server.AgentConfig{
ServerConfig: server.ServerConfig{
Config: server.Config{
Host: "localhost",
Port: "0",
CertFile: "invalid",
KeyFile: "invalid",
},
},
},
expectError: true,
expectedLog: "failed to load auth certificates",
},
{
name: "maTLS Server Startup",
config: server.AgentConfig{
ServerConfig: server.ServerConfig{
Config: server.Config{
Host: "localhost",
Port: "0",
ServerCAFile: "",
ClientCAFile: "",
},
},
AttestedTLS: true,
},
setupCallback: setupMTLSConfig,
expectError: false,
expectedLog: "with Attested mTLS",
},
{
name: "maTLS Server Startup with Invalid Server CA file",
config: server.AgentConfig{
ServerConfig: server.ServerConfig{
Config: server.Config{
Host: "localhost",
Port: "0",
ServerCAFile: "invalid",
},
},
AttestedTLS: true,
},
setupCallback: setupInvalidRootCAConfig,
expectError: true,
expectedLog: "failed to load server ca file",
},
{
name: "maTLS Server Startup with Invalid Clinet CA file",
config: server.AgentConfig{
ServerConfig: server.ServerConfig{
Config: server.Config{
Host: "localhost",
Port: "0",
ServerCAFile: "invalid",
},
},
AttestedTLS: true,
},
setupCallback: setupInvalidClientCAConfig,
expectError: true,
expectedLog: "failed to load client ca file",
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
if tc.setupCallback != nil {
tc.setupCallback(t, &tc.config, nil)
}
logBuffer := &ThreadSafeBuffer{}
logger := slog.New(slog.NewTextHandler(logBuffer, &slog.HandlerOptions{Level: slog.LevelDebug}))
authSvc := new(authmocks.Authenticator)
mockCertProvider := new(mocks.CertificateProvider)
srv := New(ctx, cancel, "TestServer", tc.config, func(srv *grpc.Server) {}, logger, authSvc, mockCertProvider)
var wg sync.WaitGroup
wg.Add(1)
go func() {
wg.Done()
err := srv.Start()
if tc.expectError {
assert.Error(t, err)
assert.Contains(t, err.Error(), tc.expectedLog)
} else {
assert.NoError(t, err)
}
}()
wg.Wait()
time.Sleep(200 * time.Millisecond)
cancel()
time.Sleep(200 * time.Millisecond)
if !tc.expectError {
logContent := logBuffer.String()
fmt.Println(logContent)
assert.Contains(t, logContent, tc.expectedLog)
}
})
}
}
func setupTLSConfig(t *testing.T, config *server.AgentConfig, _ *ThreadSafeBuffer) {
cert, key, err := generateSelfSignedCert()
assert.NoError(t, err)
config.CertFile = string(cert)
config.KeyFile = string(key)
}
func setupMTLSConfig(t *testing.T, config *server.AgentConfig, _ *ThreadSafeBuffer) {
cert, key, err := generateSelfSignedCert()
assert.NoError(t, err)
config.CertFile = string(cert)
config.KeyFile = string(key)
config.ServerCAFile = string(cert)
config.ClientCAFile = string(cert)
}
func setupInvalidRootCAConfig(t *testing.T, config *server.AgentConfig, _ *ThreadSafeBuffer) {
cert, key, err := generateSelfSignedCert()
assert.NoError(t, err)
config.CertFile = string(cert)
config.KeyFile = string(key)
config.ServerCAFile = "invalid"
config.ClientCAFile = string(cert)
}
func setupInvalidClientCAConfig(t *testing.T, config *server.AgentConfig, _ *ThreadSafeBuffer) {
cert, key, err := generateSelfSignedCert()
assert.NoError(t, err)
config.CertFile = string(cert)
config.KeyFile = string(key)
config.ClientCAFile = "invalid"
config.ServerCAFile = string(cert)
}
func createCertificatesFiles() (string, string, string, error) {
caKey, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
return "", "", "", err
}
caTemplate := x509.Certificate{
SerialNumber: big.NewInt(1),
Subject: pkix.Name{
Organization: []string{"Test Org"},
},
NotBefore: time.Now(),
NotAfter: time.Now().Add(time.Hour * 24),
KeyUsage: x509.KeyUsageCertSign | x509.KeyUsageCRLSign,
BasicConstraintsValid: true,
IsCA: true,
}
caCertDER, err := x509.CreateCertificate(rand.Reader, &caTemplate, &caTemplate, &caKey.PublicKey, caKey)
if err != nil {
return "", "", "", err
}
caCertFile, err := createTempFile(pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: caCertDER}))
if err != nil {
return "", "", "", err
}
clientKey, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
return "", "", "", err
}
clientTemplate := x509.Certificate{
SerialNumber: big.NewInt(2),
Subject: pkix.Name{
Organization: []string{"Test Org"},
},
NotBefore: time.Now(),
NotAfter: time.Now().Add(time.Hour * 24),
KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth},
BasicConstraintsValid: true,
}
clientCertDER, err := x509.CreateCertificate(rand.Reader, &clientTemplate, &caTemplate, &clientKey.PublicKey, caKey)
if err != nil {
return "", "", "", err
}
clientCertFile, err := createTempFile(pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: clientCertDER}))
if err != nil {
return "", "", "", err
}
clientKeyFile, err := createTempFile(pem.EncodeToMemory(&pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(clientKey)}))
if err != nil {
return "", "", "", err
}
return caCertFile, clientCertFile, clientKeyFile, nil
}
func createTempFile(data []byte) (string, error) {
file, err := createTempFileHandle()
if err != nil {
return "", err
}
_, err = file.Write(data)
if err != nil {
return "", err
}
err = file.Close()
if err != nil {
return "", err
}
return file.Name(), nil
}
func createTempFileHandle() (*os.File, error) {
return os.CreateTemp("", "test")
}
-177
View File
@@ -1,177 +0,0 @@
// Copyright (c) Ultraviolet
// SPDX-License-Identifier: Apache-2.0
package http
import (
"context"
"crypto/tls"
"fmt"
"log/slog"
"net/http"
smqserver "github.com/absmach/supermq/pkg/server"
"github.com/ultravioletrs/cocos/pkg/atls"
"github.com/ultravioletrs/cocos/pkg/server"
)
const (
httpProtocol = "http"
httpsProtocol = "https"
)
type httpServer struct {
server.BaseServer
server *http.Server
certProvider atls.CertificateProvider
attestedTLSEnabled bool
}
var _ server.Server = (*httpServer)(nil)
func NewServer(
ctx context.Context, cancel context.CancelFunc, name string, config server.ServerConfiguration,
handler http.Handler, logger *slog.Logger, certProvider atls.CertificateProvider,
) server.Server {
baseServer := server.NewBaseServer(ctx, cancel, name, config, logger)
hserver := &http.Server{Addr: baseServer.Address, Handler: handler}
var attestedTLS bool
if agentConfig, ok := config.(server.AgentConfig); ok && agentConfig.AttestedTLS {
if certProvider == nil {
logger.Error("Failed to create certificate provider")
} else {
attestedTLS = true
}
}
return &httpServer{
BaseServer: baseServer,
server: hserver,
certProvider: certProvider,
attestedTLSEnabled: attestedTLS,
}
}
func (s *httpServer) Start() error {
s.Protocol = httpProtocol
if s.shouldUseAttestedTLS() {
return s.startWithAttestedTLS()
}
if s.shouldUseRegularTLS() {
return s.startWithRegularTLS()
}
return s.startWithoutTLS()
}
func (s *httpServer) Stop() error {
defer s.Cancel()
ctx, cancel := context.WithTimeout(context.Background(), smqserver.StopWaitTime)
defer cancel()
if err := s.server.Shutdown(ctx); err != nil {
s.Logger.Error(fmt.Sprintf(
"%s service %s server error occurred during shutdown at %s: %s", s.Name, s.Protocol, s.Address, err))
return fmt.Errorf("%s service %s server error occurred during shutdown at %s: %w", s.Name, s.Protocol, s.Address, err)
}
s.Logger.Info(fmt.Sprintf("%s %s service shutdown of http at %s", s.Name, s.Protocol, s.Address))
return nil
}
func (s *httpServer) shouldUseAttestedTLS() bool {
return s.attestedTLSEnabled && s.certProvider != nil
}
func (s *httpServer) shouldUseRegularTLS() bool {
return s.Config.GetBaseConfig().CertFile != "" || s.Config.GetBaseConfig().KeyFile != ""
}
func (s *httpServer) startWithAttestedTLS() error {
tlsConfig := &tls.Config{
ClientAuth: tls.NoClientCert,
GetCertificate: s.certProvider.GetCertificate,
}
baseConfig := s.Config.GetBaseConfig()
mtls, err := server.ConfigureCertificateAuthorities(tlsConfig, baseConfig.ServerCAFile, baseConfig.ClientCAFile)
if err != nil {
return fmt.Errorf("failed to configure certificate authorities: %w", err)
}
if mtls {
tlsConfig.ClientAuth = tls.RequireAndVerifyClientCert
}
s.server.TLSConfig = tlsConfig
s.Protocol = httpsProtocol
s.logAttestedTLSStart(mtls)
return s.listenAndServe(true)
}
func (s *httpServer) startWithRegularTLS() error {
baseConfig := s.Config.GetBaseConfig()
tlsSetup, err := server.SetupRegularTLS(baseConfig.CertFile, baseConfig.KeyFile, baseConfig.ServerCAFile, baseConfig.ClientCAFile)
if err != nil {
return fmt.Errorf("failed to setup TLS: %w", err)
}
s.server.TLSConfig = tlsSetup.Config
s.Protocol = httpsProtocol
s.logRegularTLSStart(tlsSetup.MTLS)
return s.listenAndServe(true)
}
func (s *httpServer) startWithoutTLS() error {
s.Logger.Info(fmt.Sprintf("%s service %s server listening at %s without TLS", s.Name, s.Protocol, s.Address))
return s.listenAndServe(false)
}
func (s *httpServer) logAttestedTLSStart(mtls bool) {
if mtls {
s.Logger.Info(fmt.Sprintf("%s service %s server listening at %s with Attested mTLS", s.Name, s.Protocol, s.Address))
} else {
s.Logger.Info(fmt.Sprintf("%s service %s server listening at %s with Attested TLS", s.Name, s.Protocol, s.Address))
}
}
func (s *httpServer) logRegularTLSStart(mtls bool) {
baseConfig := s.Config.GetBaseConfig()
if mtls {
s.Logger.Info(fmt.Sprintf(
"%s service %s server listening at %s with TLS/mTLS cert %s , key %s and CAs %s, %s",
s.Name, s.Protocol, s.Address, baseConfig.CertFile, baseConfig.KeyFile,
baseConfig.ServerCAFile, baseConfig.ClientCAFile))
} else {
s.Logger.Info(fmt.Sprintf("%s service %s server listening at %s with TLS cert %s and key %s",
s.Name, s.Protocol, s.Address, baseConfig.CertFile, baseConfig.KeyFile))
}
}
func (s *httpServer) listenAndServe(useTLS bool) error {
errCh := make(chan error, 1)
go func() {
if useTLS {
cfg := s.Config.GetBaseConfig()
errCh <- s.server.ListenAndServeTLS(cfg.CertFile, cfg.KeyFile)
} else {
errCh <- s.server.ListenAndServe()
}
}()
select {
case <-s.Ctx.Done():
return s.Stop()
case err := <-errCh:
return err
}
}
-411
View File
@@ -1,411 +0,0 @@
// Copyright (c) Ultraviolet
// SPDX-License-Identifier: Apache-2.0
package http
import (
"context"
"crypto/tls"
"log/slog"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
"github.com/ultravioletrs/cocos/pkg/atls"
"github.com/ultravioletrs/cocos/pkg/atls/mocks"
"github.com/ultravioletrs/cocos/pkg/server"
)
// Mock implementations for testing.
type mockHandler struct {
mock.Mock
}
func (m *mockHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
m.Called(w, r)
w.WriteHeader(http.StatusOK)
if _, err := w.Write([]byte("test response")); err != nil {
panic(err)
}
}
type mockBaseConfig struct {
certFile string
keyFile string
serverCAFile string
clientCAFile string
host string
port string
}
func (m *mockBaseConfig) GetCertFile() string { return m.certFile }
func (m *mockBaseConfig) GetKeyFile() string { return m.keyFile }
func (m *mockBaseConfig) GetServerCAFile() string { return m.serverCAFile }
func (m *mockBaseConfig) GetClientCAFile() string { return m.clientCAFile }
type mockServerConfig struct {
baseConfig *mockBaseConfig
}
func (m *mockServerConfig) GetHost() string { return "localhost" }
func (m *mockServerConfig) GetPort() string { return "8080" }
func (m *mockServerConfig) GetBaseConfig() server.ServerConfig {
return server.ServerConfig{Config: server.Config{CertFile: m.baseConfig.certFile, KeyFile: m.baseConfig.keyFile, ServerCAFile: m.baseConfig.serverCAFile, ClientCAFile: m.baseConfig.clientCAFile, Host: m.baseConfig.host, Port: m.baseConfig.port}}
}
func TestNewServer(t *testing.T) {
ctx := context.Background()
cancel := func() {}
name := "test-server"
config := &mockServerConfig{
baseConfig: &mockBaseConfig{},
}
handler := &mockHandler{}
logger := slog.Default()
server := NewServer(ctx, cancel, name, config, handler, logger, nil)
assert.NotNil(t, server)
httpSrv, ok := server.(*httpServer)
require.True(t, ok)
assert.NotNil(t, httpSrv.server)
assert.Equal(t, handler, httpSrv.server.Handler)
}
func TestHttpServer_shouldUseAttestedTLS(t *testing.T) {
mockCertProvider := new(mocks.CertificateProvider)
tests := []struct {
name string
config server.ServerConfiguration
expected bool
certProvider atls.CertificateProvider
}{
{
name: "should use attested TLS when config is AgentConfig and AttestedTLS is true and certProvider is not empty",
config: server.AgentConfig{
AttestedTLS: true,
},
certProvider: mockCertProvider,
expected: true,
},
{
name: "should not use attested TLS when certProvider is empty",
config: server.AgentConfig{
AttestedTLS: true,
},
certProvider: nil,
expected: false,
},
{
name: "should not use attested TLS when AttestedTLS is false",
config: server.AgentConfig{
AttestedTLS: false,
},
certProvider: mockCertProvider,
expected: false,
},
{
name: "should not use attested TLS when config is not AgentConfig",
config: &mockServerConfig{
baseConfig: &mockBaseConfig{},
},
certProvider: mockCertProvider,
expected: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ctx := context.Background()
cancel := func() {}
server := NewServer(ctx, cancel, "test", tt.config, &mockHandler{}, slog.Default(), tt.certProvider)
httpSrv := server.(*httpServer)
result := httpSrv.shouldUseAttestedTLS()
assert.Equal(t, tt.expected, result)
})
}
}
func TestHttpServer_shouldUseRegularTLS(t *testing.T) {
tests := []struct {
name string
certFile string
keyFile string
expected bool
}{
{
name: "should use regular TLS when both cert and key files are provided",
certFile: "cert.pem",
keyFile: "key.pem",
expected: true,
},
{
name: "should use regular TLS when only cert file is provided",
certFile: "cert.pem",
keyFile: "",
expected: true,
},
{
name: "should use regular TLS when only key file is provided",
certFile: "",
keyFile: "key.pem",
expected: true,
},
{
name: "should not use regular TLS when neither cert nor key files are provided",
certFile: "",
keyFile: "",
expected: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ctx := context.Background()
cancel := func() {}
config := &mockServerConfig{
baseConfig: &mockBaseConfig{
certFile: tt.certFile,
keyFile: tt.keyFile,
},
}
server := NewServer(ctx, cancel, "test", config, &mockHandler{}, slog.Default(), nil)
httpSrv := server.(*httpServer)
result := httpSrv.shouldUseRegularTLS()
assert.Equal(t, tt.expected, result)
})
}
}
func TestHttpServer_Stop(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
config := &mockServerConfig{
baseConfig: &mockBaseConfig{},
}
handler := &mockHandler{}
server := NewServer(ctx, cancel, "test-server", config, handler, slog.Default(), nil)
httpSrv := server.(*httpServer)
// Start a test server that we can control
testServer := httptest.NewServer(handler)
defer testServer.Close()
// Replace the server's HTTP server with our test server's
httpSrv.server = testServer.Config
err := httpSrv.Stop()
assert.NoError(t, err)
}
func TestHttpServer_logAttestedTLSStart(t *testing.T) {
tests := []struct {
name string
mtls bool
}{
{
name: "log attested mTLS start",
mtls: true,
},
{
name: "log attested TLS start",
mtls: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ctx := context.Background()
cancel := func() {}
config := &mockServerConfig{
baseConfig: &mockBaseConfig{},
}
server := NewServer(ctx, cancel, "test-server", config, &mockHandler{}, slog.Default(), nil)
httpSrv := server.(*httpServer)
// This test mainly ensures the method doesn't panic
// In a real scenario, you might want to capture log output
assert.NotPanics(t, func() {
httpSrv.logAttestedTLSStart(tt.mtls)
})
})
}
}
func TestHttpServer_logRegularTLSStart(t *testing.T) {
tests := []struct {
name string
mtls bool
}{
{
name: "log regular mTLS start",
mtls: true,
},
{
name: "log regular TLS start",
mtls: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ctx := context.Background()
cancel := func() {}
config := &mockServerConfig{
baseConfig: &mockBaseConfig{
certFile: "cert.pem",
keyFile: "key.pem",
serverCAFile: "server-ca.pem",
clientCAFile: "client-ca.pem",
},
}
server := NewServer(ctx, cancel, "test-server", config, &mockHandler{}, slog.Default(), nil)
httpSrv := server.(*httpServer)
// This test mainly ensures the method doesn't panic
assert.NotPanics(t, func() {
httpSrv.logRegularTLSStart(tt.mtls)
})
})
}
}
func TestHttpServer_startWithoutTLS(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
defer cancel()
config := &mockServerConfig{
baseConfig: &mockBaseConfig{},
}
handler := &mockHandler{}
server := NewServer(ctx, cancel, "test-server", config, handler, slog.Default(), nil)
httpSrv := server.(*httpServer)
// Use a test server to avoid binding to actual ports
testServer := httptest.NewServer(handler)
defer testServer.Close()
httpSrv.server = testServer.Config
err := httpSrv.startWithoutTLS()
// The error will be related to context cancellation or server shutdown
assert.Error(t, err)
}
func TestHttpServer_Protocol(t *testing.T) {
tests := []struct {
name string
setupTLS func(*httpServer)
expectedProto string
}{
{
name: "HTTP protocol without TLS",
setupTLS: func(s *httpServer) {
s.Protocol = httpProtocol
},
expectedProto: httpProtocol,
},
{
name: "HTTPS protocol with TLS",
setupTLS: func(s *httpServer) {
s.Protocol = httpsProtocol
s.server.TLSConfig = &tls.Config{}
},
expectedProto: httpsProtocol,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ctx := context.Background()
cancel := func() {}
config := &mockServerConfig{
baseConfig: &mockBaseConfig{},
}
server := NewServer(ctx, cancel, "test-server", config, &mockHandler{}, slog.Default(), nil)
httpSrv := server.(*httpServer)
tt.setupTLS(httpSrv)
assert.Equal(t, tt.expectedProto, httpSrv.Protocol)
})
}
}
func TestHttpServer_ContextCancellation(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
config := &mockServerConfig{
baseConfig: &mockBaseConfig{},
}
handler := &mockHandler{}
server := NewServer(ctx, cancel, "test-server", config, handler, slog.Default(), nil)
httpSrv := server.(*httpServer)
// Cancel the context immediately
cancel()
// The listenAndServe method should handle context cancellation
err := httpSrv.listenAndServe(false)
assert.NoError(t, err) // Should return no error when context is cancelled and Stop() succeeds
}
func TestHttpServer_TLSConfiguration(t *testing.T) {
ctx := context.Background()
cancel := func() {}
config := &mockServerConfig{
baseConfig: &mockBaseConfig{
certFile: "cert.pem",
keyFile: "key.pem",
},
}
server := NewServer(ctx, cancel, "test-server", config, &mockHandler{}, slog.Default(), nil)
httpSrv := server.(*httpServer)
// Test TLS configuration setup
httpSrv.server.TLSConfig = &tls.Config{
MinVersion: tls.VersionTLS12,
}
assert.NotNil(t, httpSrv.server.TLSConfig)
assert.Equal(t, uint16(tls.VersionTLS12), httpSrv.server.TLSConfig.MinVersion)
}
// Integration-style test for server lifecycle.
func TestHttpServer_Lifecycle(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond)
defer cancel()
config := &mockServerConfig{
baseConfig: &mockBaseConfig{
host: "localhost",
port: "8080",
},
}
handler := &mockHandler{}
server := NewServer(ctx, cancel, "test-server", config, handler, slog.Default(), nil)
// Test that server can be created and has expected initial state
httpSrv, ok := server.(*httpServer)
require.True(t, ok)
assert.NotNil(t, httpSrv.server)
assert.Equal(t, "localhost:8080", httpSrv.server.Addr)
// Test Stop without Start (should not panic)
err := httpSrv.Stop()
assert.NoError(t, err)
}
-127
View File
@@ -1,127 +0,0 @@
// Copyright (c) Ultraviolet
// SPDX-License-Identifier: Apache-2.0
// Code generated by mockery; DO NOT EDIT.
// github.com/vektra/mockery
// template: testify
package mocks
import (
mock "github.com/stretchr/testify/mock"
)
// NewServer creates a new instance of Server. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations.
// The first argument is typically a *testing.T value.
func NewServer(t interface {
mock.TestingT
Cleanup(func())
}) *Server {
mock := &Server{}
mock.Mock.Test(t)
t.Cleanup(func() { mock.AssertExpectations(t) })
return mock
}
// Server is an autogenerated mock type for the Server type
type Server struct {
mock.Mock
}
type Server_Expecter struct {
mock *mock.Mock
}
func (_m *Server) EXPECT() *Server_Expecter {
return &Server_Expecter{mock: &_m.Mock}
}
// Start provides a mock function for the type Server
func (_mock *Server) Start() error {
ret := _mock.Called()
if len(ret) == 0 {
panic("no return value specified for Start")
}
var r0 error
if returnFunc, ok := ret.Get(0).(func() error); ok {
r0 = returnFunc()
} else {
r0 = ret.Error(0)
}
return r0
}
// Server_Start_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Start'
type Server_Start_Call struct {
*mock.Call
}
// Start is a helper method to define mock.On call
func (_e *Server_Expecter) Start() *Server_Start_Call {
return &Server_Start_Call{Call: _e.mock.On("Start")}
}
func (_c *Server_Start_Call) Run(run func()) *Server_Start_Call {
_c.Call.Run(func(args mock.Arguments) {
run()
})
return _c
}
func (_c *Server_Start_Call) Return(err error) *Server_Start_Call {
_c.Call.Return(err)
return _c
}
func (_c *Server_Start_Call) RunAndReturn(run func() error) *Server_Start_Call {
_c.Call.Return(run)
return _c
}
// Stop provides a mock function for the type Server
func (_mock *Server) Stop() error {
ret := _mock.Called()
if len(ret) == 0 {
panic("no return value specified for Stop")
}
var r0 error
if returnFunc, ok := ret.Get(0).(func() error); ok {
r0 = returnFunc()
} else {
r0 = ret.Error(0)
}
return r0
}
// Server_Stop_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Stop'
type Server_Stop_Call struct {
*mock.Call
}
// Stop is a helper method to define mock.On call
func (_e *Server_Expecter) Stop() *Server_Stop_Call {
return &Server_Stop_Call{Call: _e.mock.On("Stop")}
}
func (_c *Server_Stop_Call) Run(run func()) *Server_Stop_Call {
_c.Call.Run(func(args mock.Arguments) {
run()
})
return _c
}
func (_c *Server_Stop_Call) Return(err error) *Server_Stop_Call {
_c.Call.Return(err)
return _c
}
func (_c *Server_Stop_Call) RunAndReturn(run func() error) *Server_Stop_Call {
_c.Call.Return(run)
return _c
}
-105
View File
@@ -1,105 +0,0 @@
// Copyright (c) Ultraviolet
// SPDX-License-Identifier: Apache-2.0
package server
import (
"context"
"fmt"
"log/slog"
"os"
"os/signal"
"syscall"
)
type Server interface {
Start() error
Stop() error
}
type ServerConfiguration interface {
GetBaseConfig() ServerConfig
}
type Config struct {
Host string `env:"HOST" envDefault:"localhost"`
Port string `env:"PORT" envDefault:"7001"`
ServerCAFile string `env:"SERVER_CA_CERTS" envDefault:""`
CertFile string `env:"SERVER_CERT" envDefault:""`
KeyFile string `env:"SERVER_KEY" envDefault:""`
ClientCAFile string `env:"CLIENT_CA_CERTS" envDefault:""`
}
type ServerConfig struct {
Config
}
type AgentConfig struct {
ServerConfig
AttestedTLS bool `env:"ATTESTED_TLS" envDefault:"false"`
}
type BaseServer struct {
Ctx context.Context
Cancel context.CancelFunc
Name string
Address string
Config ServerConfiguration
Logger *slog.Logger
Protocol string
}
func (s ServerConfig) GetBaseConfig() ServerConfig {
return s
}
func (a AgentConfig) GetBaseConfig() ServerConfig {
return a.ServerConfig
}
func NewBaseServer(
ctx context.Context, cancel context.CancelFunc, name string, config ServerConfiguration, logger *slog.Logger,
) BaseServer {
cfg := config.GetBaseConfig()
address := fmt.Sprintf("%s:%s", cfg.Host, cfg.Port)
return BaseServer{
Ctx: ctx,
Cancel: cancel,
Name: name,
Address: address,
Config: config,
Logger: logger,
}
}
func StopHandler(ctx context.Context, cancel context.CancelFunc, logger *slog.Logger, svcName string, servers ...Server) error {
var err error
c := make(chan os.Signal, 1)
signal.Notify(c, syscall.SIGINT, syscall.SIGABRT)
select {
case sig := <-c:
defer cancel()
err = stopAllServer(servers...)
if err != nil {
logger.Error(fmt.Sprintf("%s service error during shutdown: %v", svcName, err))
}
logger.Info(fmt.Sprintf("%s service shutdown by signal: %s", svcName, sig))
return err
case <-ctx.Done():
return nil
}
}
func stopAllServer(servers ...Server) error {
var errs []error
for _, server := range servers {
if err := server.Stop(); err != nil {
errs = append(errs, err)
}
}
if len(errs) > 0 {
return fmt.Errorf("encountered errors while stopping servers: %v", errs)
}
return nil
}
-138
View File
@@ -1,138 +0,0 @@
// Copyright (c) Ultraviolet
// SPDX-License-Identifier: Apache-2.0
package server
import (
"context"
"errors"
"log/slog"
"os"
"syscall"
"testing"
"time"
"github.com/ultravioletrs/cocos/pkg/server/mocks"
)
func TestStopAllServer(t *testing.T) {
server1 := new(mocks.Server)
server2 := new(mocks.Server)
server1.On("Stop").Return(nil)
server2.On("Stop").Return(errors.New("failed to stop"))
tests := []struct {
name string
servers []Server
expectedError bool
}{
{
name: "All servers stop successfully",
servers: []Server{
server1,
server1,
},
expectedError: false,
},
{
name: "One server fails to stop",
servers: []Server{
server1,
server2,
},
expectedError: true,
},
{
name: "No servers",
servers: []Server{},
expectedError: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := stopAllServer(tt.servers...)
if (err != nil) != tt.expectedError {
t.Errorf("stopAllServer() error = %v, expectedError %v", err, tt.expectedError)
}
})
}
}
func TestStopHandler(t *testing.T) {
mockServer := new(mocks.Server)
mockServer.On("Stop").Return(nil)
tests := []struct {
name string
setupFunc func() (context.Context, context.CancelFunc, *slog.Logger, string, []Server)
triggerSignal bool
expectedError bool
expectCanceled bool
}{
{
name: "Graceful shutdown on signal",
setupFunc: func() (context.Context, context.CancelFunc, *slog.Logger, string, []Server) {
ctx, cancel := context.WithCancel(context.Background())
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
return ctx, cancel, logger, "test", []Server{mockServer}
},
triggerSignal: true,
expectedError: false,
expectCanceled: true,
},
{
name: "Context canceled",
setupFunc: func() (context.Context, context.CancelFunc, *slog.Logger, string, []Server) {
ctx, cancel := context.WithCancel(context.Background())
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
go func() {
time.Sleep(100 * time.Millisecond)
cancel()
}()
return ctx, cancel, logger, "test", []Server{mockServer}
},
triggerSignal: false,
expectedError: false,
expectCanceled: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ctx, cancel, logger, svcName, servers := tt.setupFunc()
defer cancel()
errChan := make(chan error)
go func() {
errChan <- StopHandler(ctx, cancel, logger, svcName, servers...)
}()
if tt.triggerSignal {
// Simulate SIGINT
go func() {
time.Sleep(100 * time.Millisecond)
err := syscall.Kill(syscall.Getpid(), syscall.SIGINT)
if err != nil {
t.Errorf("failed to send signal: %v", err)
}
}()
}
select {
case err := <-errChan:
if (err != nil) != tt.expectedError {
t.Errorf("StopHandler() error = %v, expectedError %v", err, tt.expectedError)
}
case <-time.After(2 * time.Second):
t.Error("StopHandler() timed out")
}
if tt.expectCanceled {
select {
case <-ctx.Done():
// Context was canceled as expected
default:
t.Error("Context was not canceled")
}
}
})
}
}
-741
View File
@@ -1,741 +0,0 @@
// Copyright (c) Ultraviolet
// SPDX-License-Identifier: Apache-2.0
package server
import (
"crypto/rand"
"crypto/rsa"
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"encoding/pem"
"math/big"
"os"
"strings"
"testing"
"time"
)
// Helper function to generate a test certificate and key.
func generateTestCert() (certPEM, keyPEM []byte, err error) {
// Generate private key
privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
return nil, nil, err
}
// Create certificate template
template := x509.Certificate{
SerialNumber: big.NewInt(1),
Subject: pkix.Name{
Organization: []string{"Test Org"},
Country: []string{"US"},
Province: []string{""},
Locality: []string{"Test City"},
StreetAddress: []string{""},
PostalCode: []string{""},
},
NotBefore: time.Now(),
NotAfter: time.Now().Add(365 * 24 * time.Hour),
KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth, x509.ExtKeyUsageClientAuth},
IPAddresses: nil,
}
// Create certificate
certDER, err := x509.CreateCertificate(rand.Reader, &template, &template, &privateKey.PublicKey, privateKey)
if err != nil {
return nil, nil, err
}
// Encode certificate
certPEM = pem.EncodeToMemory(&pem.Block{
Type: "CERTIFICATE",
Bytes: certDER,
})
// Encode private key
privateKeyDER, err := x509.MarshalPKCS8PrivateKey(privateKey)
if err != nil {
return nil, nil, err
}
keyPEM = pem.EncodeToMemory(&pem.Block{
Type: "PRIVATE KEY",
Bytes: privateKeyDER,
})
return certPEM, keyPEM, nil
}
// Helper function to create temporary files for testing.
func createTempFile(t *testing.T, content []byte) string {
tmpFile, err := os.CreateTemp("", "test-cert-*.pem")
if err != nil {
t.Fatalf("Failed to create temp file: %v", err)
}
defer tmpFile.Close()
if _, err := tmpFile.Write(content); err != nil {
t.Fatalf("Failed to write temp file: %v", err)
}
return tmpFile.Name()
}
func TestLoadCertFile(t *testing.T) {
certPEM, _, err := generateTestCert()
if err != nil {
t.Fatalf("Failed to generate test cert: %v", err)
}
tests := []struct {
name string
certFile string
wantErr bool
setup func() string
cleanup func(string)
}{
{
name: "empty cert file path",
certFile: "",
wantErr: false,
},
{
name: "valid cert file",
wantErr: false,
setup: func() string {
return createTempFile(t, certPEM)
},
cleanup: func(path string) {
os.Remove(path)
},
},
{
name: "non-existent file",
certFile: "/non/existent/file.pem",
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
certFile := tt.certFile
if tt.setup != nil {
certFile = tt.setup()
}
if tt.cleanup != nil {
defer tt.cleanup(certFile)
}
data, err := LoadCertFile(certFile)
if (err != nil) != tt.wantErr {
t.Errorf("LoadCertFile() error = %v, wantErr %v", err, tt.wantErr)
return
}
if tt.certFile != "" && !tt.wantErr && len(data) == 0 {
t.Errorf("LoadCertFile() with valid file should return data, got empty")
}
})
}
}
func TestReadFileOrData(t *testing.T) {
testData := "test certificate data"
tempFile := createTempFile(t, []byte(testData))
defer os.Remove(tempFile)
tests := []struct {
name string
input string
want string
wantErr bool
}{
{
name: "file path",
input: tempFile,
want: testData,
wantErr: false,
},
{
name: "raw data with newlines",
input: "-----BEGIN CERTIFICATE-----\nMIIC...\n-----END CERTIFICATE-----",
want: "-----BEGIN CERTIFICATE-----\nMIIC...\n-----END CERTIFICATE-----",
wantErr: false,
},
{
name: "short raw data without newlines",
input: "short data",
want: "short data",
wantErr: true,
},
{
name: "non-existent file path",
input: "/non/existent/file.pem",
want: "",
wantErr: true,
},
{
name: "empty input",
input: "",
want: "",
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := ReadFileOrData(tt.input)
if (err != nil) != tt.wantErr {
t.Errorf("ReadFileOrData() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !tt.wantErr && string(got) != tt.want {
t.Errorf("ReadFileOrData() = %v, want %v", string(got), tt.want)
}
})
}
}
func TestLoadX509KeyPair(t *testing.T) {
certPEM, keyPEM, err := generateTestCert()
if err != nil {
t.Fatalf("Failed to generate test cert: %v", err)
}
certFile := createTempFile(t, certPEM)
keyFile := createTempFile(t, keyPEM)
defer os.Remove(certFile)
defer os.Remove(keyFile)
tests := []struct {
name string
certfile string
keyfile string
wantErr bool
}{
{
name: "valid cert and key files",
certfile: certFile,
keyfile: keyFile,
wantErr: false,
},
{
name: "valid cert and key data",
certfile: string(certPEM),
keyfile: string(keyPEM),
wantErr: false,
},
{
name: "non-existent cert file",
certfile: "/non/existent/cert.pem",
keyfile: keyFile,
wantErr: true,
},
{
name: "non-existent key file",
certfile: certFile,
keyfile: "/non/existent/key.pem",
wantErr: true,
},
{
name: "invalid cert data",
certfile: "invalid cert data",
keyfile: string(keyPEM),
wantErr: true,
},
{
name: "invalid key data",
certfile: string(certPEM),
keyfile: "invalid key data",
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
cert, err := LoadX509KeyPair(tt.certfile, tt.keyfile)
if (err != nil) != tt.wantErr {
t.Errorf("LoadX509KeyPair() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !tt.wantErr && len(cert.Certificate) == 0 {
t.Errorf("LoadX509KeyPair() returned empty certificate")
}
})
}
}
func TestConfigureRootCA(t *testing.T) {
certPEM, _, err := generateTestCert()
if err != nil {
t.Fatalf("Failed to generate test cert: %v", err)
}
caFile := createTempFile(t, certPEM)
defer os.Remove(caFile)
tests := []struct {
name string
tlsConfig *tls.Config
serverCAFile string
wantErr bool
expectCA bool
}{
{
name: "valid CA file",
tlsConfig: &tls.Config{},
serverCAFile: caFile,
wantErr: false,
expectCA: true,
},
{
name: "valid CA data",
tlsConfig: &tls.Config{},
serverCAFile: string(certPEM),
wantErr: false,
expectCA: true,
},
{
name: "empty CA file",
tlsConfig: &tls.Config{},
serverCAFile: "",
wantErr: false,
expectCA: false,
},
{
name: "non-existent CA file",
tlsConfig: &tls.Config{},
serverCAFile: "/non/existent/ca.pem",
wantErr: true,
expectCA: false,
},
{
name: "invalid CA data",
tlsConfig: &tls.Config{},
serverCAFile: "invalid ca data",
wantErr: true,
expectCA: false,
},
{
name: "existing RootCAs pool",
tlsConfig: &tls.Config{RootCAs: x509.NewCertPool()},
serverCAFile: caFile,
wantErr: false,
expectCA: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := ConfigureRootCA(tt.tlsConfig, tt.serverCAFile)
if (err != nil) != tt.wantErr {
t.Errorf("ConfigureRootCA() error = %v, wantErr %v", err, tt.wantErr)
return
}
if tt.expectCA && tt.tlsConfig.RootCAs == nil {
t.Errorf("ConfigureRootCA() should have created RootCAs pool")
}
if !tt.expectCA && tt.tlsConfig.RootCAs != nil && tt.serverCAFile == "" {
t.Errorf("ConfigureRootCA() should not have created RootCAs pool for empty file")
}
})
}
}
func TestConfigureClientCA(t *testing.T) {
certPEM, _, err := generateTestCert()
if err != nil {
t.Fatalf("Failed to generate test cert: %v", err)
}
caFile := createTempFile(t, certPEM)
defer os.Remove(caFile)
tests := []struct {
name string
tlsConfig *tls.Config
clientCAFile string
wantConfigured bool
wantErr bool
}{
{
name: "valid client CA file",
tlsConfig: &tls.Config{},
clientCAFile: caFile,
wantConfigured: true,
wantErr: false,
},
{
name: "valid client CA data",
tlsConfig: &tls.Config{},
clientCAFile: string(certPEM),
wantConfigured: true,
wantErr: false,
},
{
name: "empty client CA file",
tlsConfig: &tls.Config{},
clientCAFile: "",
wantConfigured: false,
wantErr: false,
},
{
name: "non-existent client CA file",
tlsConfig: &tls.Config{},
clientCAFile: "/non/existent/ca.pem",
wantConfigured: false,
wantErr: true,
},
{
name: "invalid client CA data",
tlsConfig: &tls.Config{},
clientCAFile: "invalid ca data",
wantConfigured: false,
wantErr: true,
},
{
name: "existing ClientCAs pool",
tlsConfig: &tls.Config{ClientCAs: x509.NewCertPool()},
clientCAFile: caFile,
wantConfigured: true,
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
configured, err := ConfigureClientCA(tt.tlsConfig, tt.clientCAFile)
if (err != nil) != tt.wantErr {
t.Errorf("ConfigureClientCA() error = %v, wantErr %v", err, tt.wantErr)
return
}
if configured != tt.wantConfigured {
t.Errorf("ConfigureClientCA() configured = %v, want %v", configured, tt.wantConfigured)
}
if tt.wantConfigured && tt.tlsConfig.ClientCAs == nil {
t.Errorf("ConfigureClientCA() should have created ClientCAs pool")
}
})
}
}
func TestConfigureCertificateAuthorities(t *testing.T) {
certPEM, _, err := generateTestCert()
if err != nil {
t.Fatalf("Failed to generate test cert: %v", err)
}
caFile := createTempFile(t, certPEM)
defer os.Remove(caFile)
tests := []struct {
name string
tlsConfig *tls.Config
serverCAFile string
clientCAFile string
wantMTLS bool
wantErr bool
}{
{
name: "both server and client CA",
tlsConfig: &tls.Config{},
serverCAFile: caFile,
clientCAFile: caFile,
wantMTLS: true,
wantErr: false,
},
{
name: "only server CA",
tlsConfig: &tls.Config{},
serverCAFile: caFile,
clientCAFile: "",
wantMTLS: false,
wantErr: false,
},
{
name: "only client CA",
tlsConfig: &tls.Config{},
serverCAFile: "",
clientCAFile: caFile,
wantMTLS: true,
wantErr: false,
},
{
name: "no CAs",
tlsConfig: &tls.Config{},
serverCAFile: "",
clientCAFile: "",
wantMTLS: false,
wantErr: false,
},
{
name: "invalid server CA",
tlsConfig: &tls.Config{},
serverCAFile: "/non/existent/server-ca.pem",
clientCAFile: caFile,
wantMTLS: false,
wantErr: true,
},
{
name: "invalid client CA",
tlsConfig: &tls.Config{},
serverCAFile: caFile,
clientCAFile: "/non/existent/client-ca.pem",
wantMTLS: false,
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
mtls, err := ConfigureCertificateAuthorities(tt.tlsConfig, tt.serverCAFile, tt.clientCAFile)
if (err != nil) != tt.wantErr {
t.Errorf("ConfigureCertificateAuthorities() error = %v, wantErr %v", err, tt.wantErr)
return
}
if mtls != tt.wantMTLS {
t.Errorf("ConfigureCertificateAuthorities() mtls = %v, want %v", mtls, tt.wantMTLS)
}
})
}
}
func TestSetupRegularTLS(t *testing.T) {
certPEM, keyPEM, err := generateTestCert()
if err != nil {
t.Fatalf("Failed to generate test cert: %v", err)
}
certFile := createTempFile(t, certPEM)
keyFile := createTempFile(t, keyPEM)
caFile := createTempFile(t, certPEM)
defer func() {
os.Remove(certFile)
os.Remove(keyFile)
os.Remove(caFile)
}()
tests := []struct {
name string
certFile string
keyFile string
serverCAFile string
clientCAFile string
wantMTLS bool
wantErr bool
expectedAuth tls.ClientAuthType
}{
{
name: "regular TLS without mTLS",
certFile: certFile,
keyFile: keyFile,
serverCAFile: "",
clientCAFile: "",
wantMTLS: false,
wantErr: false,
expectedAuth: tls.NoClientCert,
},
{
name: "TLS with mTLS",
certFile: certFile,
keyFile: keyFile,
serverCAFile: caFile,
clientCAFile: caFile,
wantMTLS: true,
wantErr: false,
expectedAuth: tls.RequireAndVerifyClientCert,
},
{
name: "TLS with only server CA",
certFile: certFile,
keyFile: keyFile,
serverCAFile: caFile,
clientCAFile: "",
wantMTLS: false,
wantErr: false,
expectedAuth: tls.NoClientCert,
},
{
name: "invalid certificate file",
certFile: "/non/existent/cert.pem",
keyFile: keyFile,
serverCAFile: "",
clientCAFile: "",
wantMTLS: false,
wantErr: true,
expectedAuth: tls.NoClientCert,
},
{
name: "invalid key file",
certFile: certFile,
keyFile: "/non/existent/key.pem",
serverCAFile: "",
clientCAFile: "",
wantMTLS: false,
wantErr: true,
expectedAuth: tls.NoClientCert,
},
{
name: "invalid server CA file",
certFile: certFile,
keyFile: keyFile,
serverCAFile: "/non/existent/server-ca.pem",
clientCAFile: "",
wantMTLS: false,
wantErr: true,
expectedAuth: tls.NoClientCert,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result, err := SetupRegularTLS(tt.certFile, tt.keyFile, tt.serverCAFile, tt.clientCAFile)
if (err != nil) != tt.wantErr {
t.Errorf("SetupRegularTLS() error = %v, wantErr %v", err, tt.wantErr)
return
}
if tt.wantErr {
return
}
if result == nil {
t.Errorf("SetupRegularTLS() returned nil result")
return
}
if result.MTLS != tt.wantMTLS {
t.Errorf("SetupRegularTLS() MTLS = %v, want %v", result.MTLS, tt.wantMTLS)
}
if result.Config.ClientAuth != tt.expectedAuth {
t.Errorf("SetupRegularTLS() ClientAuth = %v, want %v", result.Config.ClientAuth, tt.expectedAuth)
}
if len(result.Config.Certificates) == 0 {
t.Errorf("SetupRegularTLS() should have at least one certificate")
}
})
}
}
func TestBuildMTLSDescription(t *testing.T) {
tests := []struct {
name string
serverCAFile string
clientCAFile string
want string
}{
{
name: "both server and client CA files",
serverCAFile: "/path/to/server-ca.pem",
clientCAFile: "/path/to/client-ca.pem",
want: "root ca /path/to/server-ca.pem client ca /path/to/client-ca.pem",
},
{
name: "only server CA file",
serverCAFile: "/path/to/server-ca.pem",
clientCAFile: "",
want: "root ca /path/to/server-ca.pem",
},
{
name: "only client CA file",
serverCAFile: "",
clientCAFile: "/path/to/client-ca.pem",
want: "client ca /path/to/client-ca.pem",
},
{
name: "no CA files",
serverCAFile: "",
clientCAFile: "",
want: "",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := BuildMTLSDescription(tt.serverCAFile, tt.clientCAFile)
if got != tt.want {
t.Errorf("BuildMTLSDescription() = %v, want %v", got, tt.want)
}
})
}
}
func TestErrorConstants(t *testing.T) {
// Test that error constants are properly defined
if ErrAppendServerCA == nil {
t.Error("ErrAppendServerCA should not be nil")
}
if ErrAppendClientCA == nil {
t.Error("ErrAppendClientCA should not be nil")
}
if ErrAppendServerCA.Error() != "failed to append server ca to tls.Config" {
t.Errorf("ErrAppendServerCA message = %v, want 'failed to append server ca to tls.Config'", ErrAppendServerCA.Error())
}
if ErrAppendClientCA.Error() != "failed to append client ca to tls.Config" {
t.Errorf("ErrAppendClientCA message = %v, want 'failed to append client ca to tls.Config'", ErrAppendClientCA.Error())
}
}
func TestTLSSetupResult(t *testing.T) {
// Test that TLSSetupResult struct works as expected
config := &tls.Config{}
result := &TLSSetupResult{
Config: config,
MTLS: true,
}
if result.Config != config {
t.Error("TLSSetupResult Config field should match assigned value")
}
if !result.MTLS {
t.Error("TLSSetupResult MTLS field should be true")
}
}
func TestReadFileOrDataEdgeCases(t *testing.T) {
tests := []struct {
name string
input string
wantErr bool
}{
{
name: "999 chars without newline (should try file)",
input: strings.Repeat("a", 999),
wantErr: true, // Should fail as file doesn't exist
},
{
name: "1001 chars without newline (should treat as data)",
input: strings.Repeat("a", 1001),
wantErr: false,
},
{
name: "short string with newline (should treat as data)",
input: "short\ndata",
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
_, err := ReadFileOrData(tt.input)
if (err != nil) != tt.wantErr {
t.Errorf("ReadFileOrData() error = %v, wantErr %v", err, tt.wantErr)
}
})
}
}
+6 -8
View File
@@ -15,12 +15,12 @@ import (
"strings"
mglog "github.com/absmach/supermq/logger"
smqserver "github.com/absmach/supermq/pkg/server"
grpcserver "github.com/absmach/supermq/pkg/server/grpc"
"github.com/caarlos0/env/v11"
"github.com/ultravioletrs/cocos/agent/cvms"
cvmsgrpc "github.com/ultravioletrs/cocos/agent/cvms/api/grpc"
"github.com/ultravioletrs/cocos/internal"
"github.com/ultravioletrs/cocos/pkg/server"
grpcserver "github.com/ultravioletrs/cocos/pkg/server/grpc"
"golang.org/x/sync/errgroup"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials"
@@ -314,24 +314,22 @@ func main() {
reflection.Register(srv)
cvms.RegisterServiceServer(srv, cvmsgrpc.NewServer(incomingChan, &svc{logger: logger}))
}
grpcServerConfig := server.ServerConfig{
Config: server.Config{
Port: defaultPort,
},
grpcServerConfig := smqserver.Config{
Port: defaultPort,
}
if err := env.ParseWithOptions(&grpcServerConfig, env.Options{}); err != nil {
logger.Error(fmt.Sprintf("failed to load %s gRPC client configuration : %s", svcName, err))
return
}
gs := grpcserver.New(ctx, cancel, svcName, grpcServerConfig, registerAgentServiceServer, logger, nil, nil)
gs := grpcserver.NewServer(ctx, cancel, svcName, grpcServerConfig, registerAgentServiceServer, logger)
g.Go(func() error {
return gs.Start()
})
g.Go(func() error {
return server.StopHandler(ctx, cancel, logger, svcName, gs)
return smqserver.StopSignalHandler(ctx, cancel, logger, svcName, gs)
})
if err := g.Wait(); err != nil {