Files
cocos/pkg/server/http/http_test.go
T
Sammy Kerata Oina c758b3b216 NOISSUE - Refactor aTLS and gRPC server to use CertificateProvider interface (#522)
* Refactor aTLS and gRPC server to use CertificateProvider interface

- Removed unused test cases and mock dependencies in atls_test.go.
- Updated TestGetPlatformVerifier to use CertificateVerifier struct.
- Introduced CertificateProvider interface for better abstraction in TLS handling.
- Refactored gRPC server to accept CertificateProvider and configure TLS accordingly.
- Simplified TLS configuration logic in both gRPC and HTTP servers.
- Removed unnecessary parameters from server initialization in tests and main function.
- Enhanced logging for TLS configurations.

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* Fix comments for consistency and clarity in atls.go

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* Update expected error messages in VM command tests for clarity

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* Enhance tests by integrating mock providers and improving error messages for clarity

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* Add comprehensive tests for certificate generation and attestation providers

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* Implement certificate and attestation providers with unified generation logic

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* Refactor certificate and attestation provider structures for consistency; implement CertificateVerifier interface and related methods

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* Refactor attestation and certificate provider methods for consistency; rename methods and update related logic

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

---------

Signed-off-by: Sammy Oina <sammyoina@gmail.com>
2025-09-23 14:49:23 +02:00

412 lines
10 KiB
Go

// Copyright (c) Ultraviolet
// SPDX-License-Identifier: Apache-2.0
package http
import (
"context"
"crypto/tls"
"log/slog"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
"github.com/ultravioletrs/cocos/pkg/atls"
"github.com/ultravioletrs/cocos/pkg/atls/mocks"
"github.com/ultravioletrs/cocos/pkg/server"
)
// Mock implementations for testing.

// mockHandler is an http.Handler whose invocations are recorded through the
// embedded testify mock.Mock. Every request it serves gets a 200 status and
// the body "test response".
//
// NOTE(review): no test in this file registers an expectation with On, so if
// ServeHTTP were actually invoked, testify would treat the Called(w, r) as an
// unexpected call. The visible tests never send a request to it.
type mockHandler struct {
	mock.Mock
}

// ServeHTTP records the call and then writes a fixed 200 response.
// A failed body write panics because the handler has no other way to report it.
func (h *mockHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	h.Called(w, r)

	w.WriteHeader(http.StatusOK)
	payload := []byte("test response")
	_, writeErr := w.Write(payload)
	if writeErr != nil {
		panic(writeErr)
	}
}
// mockBaseConfig carries the TLS file paths and the host/port pair that a
// test wants its server configuration to report. Its zero value reports
// empty strings everywhere.
type mockBaseConfig struct {
	certFile, keyFile           string
	serverCAFile, clientCAFile  string
	host, port                  string
}

// GetCertFile returns the configured certificate file path.
func (b *mockBaseConfig) GetCertFile() string {
	return b.certFile
}

// GetKeyFile returns the configured private-key file path.
func (b *mockBaseConfig) GetKeyFile() string {
	return b.keyFile
}

// GetServerCAFile returns the configured server CA file path.
func (b *mockBaseConfig) GetServerCAFile() string {
	return b.serverCAFile
}

// GetClientCAFile returns the configured client CA file path.
func (b *mockBaseConfig) GetClientCAFile() string {
	return b.clientCAFile
}
// mockServerConfig is a minimal server configuration used by these tests.
// Only GetBaseConfig reflects the values stored in baseConfig; GetHost and
// GetPort always report "localhost" and "8080" regardless of baseConfig.
type mockServerConfig struct {
	baseConfig *mockBaseConfig
}

// GetHost returns a fixed host name; it does not consult baseConfig.
func (m *mockServerConfig) GetHost() string { return "localhost" }

// GetPort returns a fixed port; it does not consult baseConfig.
func (m *mockServerConfig) GetPort() string { return "8080" }

// GetBaseConfig copies the paths and the host/port from baseConfig into a
// server.ServerConfig. A nil baseConfig yields the zero ServerConfig instead
// of a nil-pointer panic, so &mockServerConfig{} works as an empty
// configuration. It is equivalent to passing an empty &mockBaseConfig{}.
func (m *mockServerConfig) GetBaseConfig() server.ServerConfig {
	if m.baseConfig == nil {
		return server.ServerConfig{}
	}
	b := m.baseConfig
	return server.ServerConfig{
		Config: server.Config{
			CertFile:     b.certFile,
			KeyFile:      b.keyFile,
			ServerCAFile: b.serverCAFile,
			ClientCAFile: b.clientCAFile,
			Host:         b.host,
			Port:         b.port,
		},
	}
}
// TestNewServer checks that NewServer returns an *httpServer whose embedded
// http.Server is populated and carries the handler that was passed in.
func TestNewServer(t *testing.T) {
	var (
		ctx     = context.Background()
		cancel  = func() {}
		config  = &mockServerConfig{baseConfig: &mockBaseConfig{}}
		handler = &mockHandler{}
	)

	// Named srv (not server) so the imported server package is not shadowed.
	srv := NewServer(ctx, cancel, "test-server", config, handler, slog.Default(), nil)
	assert.NotNil(t, srv)

	httpSrv, ok := srv.(*httpServer)
	require.True(t, ok)
	assert.NotNil(t, httpSrv.server)
	assert.Equal(t, handler, httpSrv.server.Handler)
}
// TestHttpServer_shouldUseAttestedTLS checks which configurations enable
// attested TLS. It must be enabled only when the configuration is an
// AgentConfig with AttestedTLS set and a certificate provider is supplied.
func TestHttpServer_shouldUseAttestedTLS(t *testing.T) {
	certProvider := new(mocks.CertificateProvider)

	cases := []struct {
		name         string
		config       server.ServerConfiguration
		certProvider atls.CertificateProvider
		expected     bool
	}{
		{
			name:         "should use attested TLS when config is AgentConfig and AttestedTLS is true and certProvider is not empty",
			config:       server.AgentConfig{AttestedTLS: true},
			certProvider: certProvider,
			expected:     true,
		},
		{
			name:         "should not use attested TLS when certProvider is empty",
			config:       server.AgentConfig{AttestedTLS: true},
			certProvider: nil,
			expected:     false,
		},
		{
			name:         "should not use attested TLS when AttestedTLS is false",
			config:       server.AgentConfig{AttestedTLS: false},
			certProvider: certProvider,
			expected:     false,
		},
		{
			name:         "should not use attested TLS when config is not AgentConfig",
			config:       &mockServerConfig{baseConfig: &mockBaseConfig{}},
			certProvider: certProvider,
			expected:     false,
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			srv := NewServer(context.Background(), func() {}, "test", tc.config, &mockHandler{}, slog.Default(), tc.certProvider)
			got := srv.(*httpServer).shouldUseAttestedTLS()
			assert.Equal(t, tc.expected, got)
		})
	}
}
// TestHttpServer_shouldUseRegularTLS checks when regular TLS is selected. It
// must be selected when at least one of the certificate or key file paths is
// non-empty, and not selected when both are empty.
func TestHttpServer_shouldUseRegularTLS(t *testing.T) {
	cases := []struct {
		name     string
		certFile string
		keyFile  string
		expected bool
	}{
		{name: "should use regular TLS when both cert and key files are provided", certFile: "cert.pem", keyFile: "key.pem", expected: true},
		{name: "should use regular TLS when only cert file is provided", certFile: "cert.pem", keyFile: "", expected: true},
		{name: "should use regular TLS when only key file is provided", certFile: "", keyFile: "key.pem", expected: true},
		{name: "should not use regular TLS when neither cert nor key files are provided", certFile: "", keyFile: "", expected: false},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			cfg := &mockServerConfig{
				baseConfig: &mockBaseConfig{certFile: tc.certFile, keyFile: tc.keyFile},
			}
			srv := NewServer(context.Background(), func() {}, "test", cfg, &mockHandler{}, slog.Default(), nil)
			assert.Equal(t, tc.expected, srv.(*httpServer).shouldUseRegularTLS())
		})
	}
}
// TestHttpServer_Stop verifies that Stop returns no error when the server's
// underlying http.Server is one that is actively serving. That server comes
// from an httptest.Server, which owns its own listener, so the test never
// binds a fixed port.
func TestHttpServer_Stop(t *testing.T) {
	ctx, cancel := context.WithCancel(context.Background())
	// cancel was previously only handed to NewServer and never invoked, which
	// leaked the context. Always release it when the test finishes.
	defer cancel()

	config := &mockServerConfig{
		baseConfig: &mockBaseConfig{},
	}
	handler := &mockHandler{}
	srv := NewServer(ctx, cancel, "test-server", config, handler, slog.Default(), nil)
	httpSrv, ok := srv.(*httpServer)
	require.True(t, ok)

	// Start a test server that we control, then make httpSrv manage its
	// http.Server so Stop acts on a live, listening server.
	testServer := httptest.NewServer(handler)
	defer testServer.Close()
	httpSrv.server = testServer.Config

	err := httpSrv.Stop()
	assert.NoError(t, err)
}
// TestHttpServer_logAttestedTLSStart is a smoke test. It checks that
// logAttestedTLSStart does not panic for either the mTLS or plain TLS
// variant. Log output is not captured or inspected.
func TestHttpServer_logAttestedTLSStart(t *testing.T) {
	for _, tc := range []struct {
		name string
		mtls bool
	}{
		{name: "log attested mTLS start", mtls: true},
		{name: "log attested TLS start", mtls: false},
	} {
		t.Run(tc.name, func(t *testing.T) {
			cfg := &mockServerConfig{baseConfig: &mockBaseConfig{}}
			srv := NewServer(context.Background(), func() {}, "test-server", cfg, &mockHandler{}, slog.Default(), nil)
			httpSrv := srv.(*httpServer)

			// Capturing the log output would allow asserting on its content;
			// for now only the absence of a panic is checked.
			assert.NotPanics(t, func() { httpSrv.logAttestedTLSStart(tc.mtls) })
		})
	}
}
// TestHttpServer_logRegularTLSStart is a smoke test. It checks that
// logRegularTLSStart does not panic for either the mTLS or plain TLS variant
// when all certificate, key and CA paths are configured.
func TestHttpServer_logRegularTLSStart(t *testing.T) {
	for _, tc := range []struct {
		name string
		mtls bool
	}{
		{name: "log regular mTLS start", mtls: true},
		{name: "log regular TLS start", mtls: false},
	} {
		t.Run(tc.name, func(t *testing.T) {
			cfg := &mockServerConfig{
				baseConfig: &mockBaseConfig{
					certFile:     "cert.pem",
					keyFile:      "key.pem",
					serverCAFile: "server-ca.pem",
					clientCAFile: "client-ca.pem",
				},
			}
			srv := NewServer(context.Background(), func() {}, "test-server", cfg, &mockHandler{}, slog.Default(), nil)
			httpSrv := srv.(*httpServer)

			// Only the absence of a panic is asserted; log content is not checked.
			assert.NotPanics(t, func() { httpSrv.logRegularTLSStart(tc.mtls) })
		})
	}
}
// TestHttpServer_startWithoutTLS checks that startWithoutTLS reports an error
// when the server cannot start listening.
//
// The swapped-in httptest http.Server has an empty Addr, which
// ListenAndServe interprets as ":http" (port 80). That made the expected
// error depend on privileges: permission denied when unprivileged, but a
// successful bind (and no error) when running as root. Pointing Addr at the
// port the httptest server already holds makes the bind fail the same way
// everywhere.
//
// NOTE(review): this assumes startWithoutTLS listens on httpSrv.server.Addr.
// If it listens on a different address field, the override has no effect and
// the behaviour is the same as before this change.
func TestHttpServer_startWithoutTLS(t *testing.T) {
	ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
	defer cancel()

	config := &mockServerConfig{
		baseConfig: &mockBaseConfig{},
	}
	handler := &mockHandler{}
	srv := NewServer(ctx, cancel, "test-server", config, handler, slog.Default(), nil)
	httpSrv, ok := srv.(*httpServer)
	require.True(t, ok)

	// Use a test server so no fixed port is bound; its listener occupies the
	// address we then ask httpSrv to bind, forcing an "address in use" error.
	testServer := httptest.NewServer(handler)
	defer testServer.Close()
	httpSrv.server = testServer.Config
	httpSrv.server.Addr = testServer.Listener.Addr().String()

	err := httpSrv.startWithoutTLS()
	assert.Error(t, err)
}
// TestHttpServer_Protocol checks the Protocol field after a setup hook has
// configured it for the plain-HTTP and HTTPS cases.
//
// NOTE(review): each setup hook assigns Protocol itself, and the test reads
// back the same value. It therefore does not exercise how NewServer or the
// start methods choose the protocol.
func TestHttpServer_Protocol(t *testing.T) {
	type protocolCase struct {
		name          string
		setupTLS      func(*httpServer)
		expectedProto string
	}

	cases := []protocolCase{
		{
			name:          "HTTP protocol without TLS",
			setupTLS:      func(s *httpServer) { s.Protocol = httpProtocol },
			expectedProto: httpProtocol,
		},
		{
			name: "HTTPS protocol with TLS",
			setupTLS: func(s *httpServer) {
				s.Protocol = httpsProtocol
				s.server.TLSConfig = &tls.Config{}
			},
			expectedProto: httpsProtocol,
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			cfg := &mockServerConfig{baseConfig: &mockBaseConfig{}}
			srv := NewServer(context.Background(), func() {}, "test-server", cfg, &mockHandler{}, slog.Default(), nil)
			httpSrv := srv.(*httpServer)

			tc.setupTLS(httpSrv)
			assert.Equal(t, tc.expectedProto, httpSrv.Protocol)
		})
	}
}
// TestHttpServer_ContextCancellation cancels the server's context before
// serving. It expects listenAndServe(false) to return no error, which covers
// the context-cancelled path where Stop() succeeds.
func TestHttpServer_ContextCancellation(t *testing.T) {
	ctx, cancel := context.WithCancel(context.Background())
	cfg := &mockServerConfig{baseConfig: &mockBaseConfig{}}
	srv := NewServer(ctx, cancel, "test-server", cfg, &mockHandler{}, slog.Default(), nil)
	httpSrv := srv.(*httpServer)

	// The context is already done by the time listenAndServe runs.
	cancel()

	err := httpSrv.listenAndServe(false)
	assert.NoError(t, err)
}
// TestHttpServer_TLSConfiguration assigns a TLS 1.2-minimum tls.Config to the
// server's http.Server and reads it back.
//
// NOTE(review): the test sets TLSConfig itself, so it only confirms the field
// is assignable. It does not check any TLS setup performed by the server code.
func TestHttpServer_TLSConfiguration(t *testing.T) {
	cfg := &mockServerConfig{
		baseConfig: &mockBaseConfig{certFile: "cert.pem", keyFile: "key.pem"},
	}
	srv := NewServer(context.Background(), func() {}, "test-server", cfg, &mockHandler{}, slog.Default(), nil)
	httpSrv := srv.(*httpServer)

	httpSrv.server.TLSConfig = &tls.Config{MinVersion: tls.VersionTLS12}

	got := httpSrv.server.TLSConfig
	assert.NotNil(t, got)
	assert.Equal(t, uint16(tls.VersionTLS12), got.MinVersion)
}
// TestHttpServer_Lifecycle is an integration-style check of the server
// lifecycle. A freshly built server must have an http.Server listening on
// "localhost:8080", and Stop must succeed even though Start was never called.
//
// NOTE(review): the base config's host/port are "localhost"/"8080", and
// mockServerConfig.GetHost/GetPort return the same values. This test
// therefore cannot tell which of them NewServer uses to build Addr.
func TestHttpServer_Lifecycle(t *testing.T) {
	ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond)
	defer cancel()

	cfg := &mockServerConfig{
		baseConfig: &mockBaseConfig{host: "localhost", port: "8080"},
	}
	srv := NewServer(ctx, cancel, "test-server", cfg, &mockHandler{}, slog.Default(), nil)

	httpSrv, ok := srv.(*httpServer)
	require.True(t, ok)
	assert.NotNil(t, httpSrv.server)
	assert.Equal(t, "localhost:8080", httpSrv.server.Addr)

	// Stopping a server that was never started must neither panic nor fail.
	assert.NoError(t, httpSrv.Stop())
}