NOISSUE - Introduce computation runner, log forwarder, ingress, and egress proxy services. (#559)

* feat: Introduce computation runner, log forwarder, ingress, and egress proxy services.

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* feat: Update Go environment variable parsing and build system to use new architecture and repository.

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* feat: Update package sources to `sammyoina/cocos-ai` at a specific commit, add log-forwarder pre-start hook, and rename proxy binaries.

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* chore: Update build system references to a specific commit and enhance logging for service connections and message processing.

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* build: Update package source repositories and versions, migrate client logging to slog, and adjust ingress/egress proxy build and install steps.

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* debug stuck

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* debug

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* debug

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* feat: add HTTP/2 support to egress proxy and update build system to use specific commit hashes

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* feat: enhance egress proxy CONNECT handling, update package sources, and add gRPC test utility

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* feat: Update build system for various services to a specific commit from a new repository, change agent gRPC port to 7001, and add a gRPC test client.

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* feat: Migrate agent-internal gRPC communication to Unix sockets, set ingress proxy to port 7002, and update build hashes.

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* refactor: Remove standalone ingress-proxy systemd service and update component versions.

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* fix: Prevent computation re-initialization in agent and update component versions across several packages.

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* feat: update package versions and enable h2c support in ingress proxy.

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* feat: refactor ingress proxy to support HTTP/2 over Unix sockets and update component versions.

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* feat: Update build system package sources to `ultravioletrs/cocos` and reduce agent logging verbosity.

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* refactor: improve error handling in proxy commands and remove unused gRPC test

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* test: add mock service state return value in handleRunReqChunks test

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* feat: add comprehensive tests for service and proxy components

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* fix linter

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* improve coverage

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* test: add gRPC client and ingress adapter tests, and update egress proxy tests.

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* improve coverage

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

---------

Signed-off-by: Sammy Oina <sammyoina@gmail.com>
This commit is contained in:
Sammy Kerata Oina
2026-02-09 12:38:21 +03:00
committed by GitHub
parent ee52551ca4
commit a3265bc346
57 changed files with 6529 additions and 162 deletions
+392
View File
@@ -0,0 +1,392 @@
// Copyright (c) Ultraviolet
// SPDX-License-Identifier: Apache-2.0
package attestation
import (
"context"
"net"
"path/filepath"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
attestation_v1 "github.com/ultravioletrs/cocos/internal/proto/attestation/v1"
"github.com/ultravioletrs/cocos/pkg/attestation"
"google.golang.org/grpc"
)
// mockAttestationServer is a mock implementation of the AttestationServiceServer.
type mockAttestationServer struct {
attestation_v1.UnimplementedAttestationServiceServer
fetchAttestationCalled bool
fetchAzureTokenCalled bool
lastReportData []byte
lastNonce []byte
lastPlatformType attestation_v1.PlatformType
attestationErr error
azureTokenErr error
}
func (m *mockAttestationServer) FetchAttestation(ctx context.Context, req *attestation_v1.AttestationRequest) (*attestation_v1.AttestationResponse, error) {
m.fetchAttestationCalled = true
m.lastReportData = req.ReportData
m.lastNonce = req.Nonce
m.lastPlatformType = req.PlatformType
if m.attestationErr != nil {
return nil, m.attestationErr
}
return &attestation_v1.AttestationResponse{
Quote: []byte("mock-attestation-quote"),
}, nil
}
func (m *mockAttestationServer) FetchAzureToken(ctx context.Context, req *attestation_v1.AzureTokenRequest) (*attestation_v1.AzureTokenResponse, error) {
m.fetchAzureTokenCalled = true
m.lastNonce = req.Nonce
if m.azureTokenErr != nil {
return nil, m.azureTokenErr
}
return &attestation_v1.AzureTokenResponse{
Token: []byte("mock-azure-token"),
}, nil
}
// TestNewClient tests creating a new attestation client.
func TestNewClient(t *testing.T) {
tmpDir := t.TempDir()
socketPath := filepath.Join(tmpDir, "attestation-test.sock")
listener, err := net.Listen("unix", socketPath)
require.NoError(t, err)
defer listener.Close()
grpcServer := grpc.NewServer()
mockServer := &mockAttestationServer{}
attestation_v1.RegisterAttestationServiceServer(grpcServer, mockServer)
go func() {
_ = grpcServer.Serve(listener)
}()
defer grpcServer.Stop()
time.Sleep(100 * time.Millisecond)
client, err := NewClient(socketPath)
require.NoError(t, err)
require.NotNil(t, client)
err = client.Close()
assert.NoError(t, err)
}
// TestGetAttestationSNP tests getting SNP attestation.
func TestGetAttestationSNP(t *testing.T) {
tmpDir := t.TempDir()
socketPath := filepath.Join(tmpDir, "attestation-snp.sock")
listener, err := net.Listen("unix", socketPath)
require.NoError(t, err)
defer listener.Close()
grpcServer := grpc.NewServer()
mockServer := &mockAttestationServer{}
attestation_v1.RegisterAttestationServiceServer(grpcServer, mockServer)
go func() {
_ = grpcServer.Serve(listener)
}()
defer grpcServer.Stop()
time.Sleep(100 * time.Millisecond)
client, err := NewClient(socketPath)
require.NoError(t, err)
defer client.Close()
ctx := context.Background()
var reportData [64]byte
var nonce [32]byte
copy(reportData[:], []byte("test-report-data"))
copy(nonce[:], []byte("test-nonce"))
quote, err := client.GetAttestation(ctx, reportData, nonce, attestation.SNP)
require.NoError(t, err)
assert.Equal(t, []byte("mock-attestation-quote"), quote)
assert.True(t, mockServer.fetchAttestationCalled)
assert.Equal(t, attestation_v1.PlatformType_PLATFORM_TYPE_SNP, mockServer.lastPlatformType)
assert.Equal(t, reportData[:], mockServer.lastReportData)
assert.Equal(t, nonce[:], mockServer.lastNonce)
}
// TestGetAttestationTDX tests getting TDX attestation.
func TestGetAttestationTDX(t *testing.T) {
tmpDir := t.TempDir()
socketPath := filepath.Join(tmpDir, "attestation-tdx.sock")
listener, err := net.Listen("unix", socketPath)
require.NoError(t, err)
defer listener.Close()
grpcServer := grpc.NewServer()
mockServer := &mockAttestationServer{}
attestation_v1.RegisterAttestationServiceServer(grpcServer, mockServer)
go func() {
_ = grpcServer.Serve(listener)
}()
defer grpcServer.Stop()
time.Sleep(100 * time.Millisecond)
client, err := NewClient(socketPath)
require.NoError(t, err)
defer client.Close()
ctx := context.Background()
var reportData [64]byte
var nonce [32]byte
quote, err := client.GetAttestation(ctx, reportData, nonce, attestation.TDX)
require.NoError(t, err)
assert.NotNil(t, quote)
assert.Equal(t, attestation_v1.PlatformType_PLATFORM_TYPE_TDX, mockServer.lastPlatformType)
}
// TestGetAttestationVTPM tests getting vTPM attestation.
func TestGetAttestationVTPM(t *testing.T) {
tmpDir := t.TempDir()
socketPath := filepath.Join(tmpDir, "attestation-vtpm.sock")
listener, err := net.Listen("unix", socketPath)
require.NoError(t, err)
defer listener.Close()
grpcServer := grpc.NewServer()
mockServer := &mockAttestationServer{}
attestation_v1.RegisterAttestationServiceServer(grpcServer, mockServer)
go func() {
_ = grpcServer.Serve(listener)
}()
defer grpcServer.Stop()
time.Sleep(100 * time.Millisecond)
client, err := NewClient(socketPath)
require.NoError(t, err)
defer client.Close()
ctx := context.Background()
var reportData [64]byte
var nonce [32]byte
quote, err := client.GetAttestation(ctx, reportData, nonce, attestation.VTPM)
require.NoError(t, err)
assert.NotNil(t, quote)
assert.Equal(t, attestation_v1.PlatformType_PLATFORM_TYPE_VTPM, mockServer.lastPlatformType)
}
// TestGetAttestationSNPvTPM tests getting SNP+vTPM attestation.
func TestGetAttestationSNPvTPM(t *testing.T) {
tmpDir := t.TempDir()
socketPath := filepath.Join(tmpDir, "attestation-snpvtpm.sock")
listener, err := net.Listen("unix", socketPath)
require.NoError(t, err)
defer listener.Close()
grpcServer := grpc.NewServer()
mockServer := &mockAttestationServer{}
attestation_v1.RegisterAttestationServiceServer(grpcServer, mockServer)
go func() {
_ = grpcServer.Serve(listener)
}()
defer grpcServer.Stop()
time.Sleep(100 * time.Millisecond)
client, err := NewClient(socketPath)
require.NoError(t, err)
defer client.Close()
ctx := context.Background()
var reportData [64]byte
var nonce [32]byte
quote, err := client.GetAttestation(ctx, reportData, nonce, attestation.SNPvTPM)
require.NoError(t, err)
assert.NotNil(t, quote)
assert.Equal(t, attestation_v1.PlatformType_PLATFORM_TYPE_SNP_VTPM, mockServer.lastPlatformType)
}
// TestGetAttestationUnspecified tests getting attestation with unspecified platform.
func TestGetAttestationUnspecified(t *testing.T) {
tmpDir := t.TempDir()
socketPath := filepath.Join(tmpDir, "attestation-unspec.sock")
listener, err := net.Listen("unix", socketPath)
require.NoError(t, err)
defer listener.Close()
grpcServer := grpc.NewServer()
mockServer := &mockAttestationServer{}
attestation_v1.RegisterAttestationServiceServer(grpcServer, mockServer)
go func() {
_ = grpcServer.Serve(listener)
}()
defer grpcServer.Stop()
time.Sleep(100 * time.Millisecond)
client, err := NewClient(socketPath)
require.NoError(t, err)
defer client.Close()
ctx := context.Background()
var reportData [64]byte
var nonce [32]byte
// Use an invalid platform type (999)
quote, err := client.GetAttestation(ctx, reportData, nonce, attestation.PlatformType(999))
require.NoError(t, err)
assert.NotNil(t, quote)
assert.Equal(t, attestation_v1.PlatformType_PLATFORM_TYPE_UNSPECIFIED, mockServer.lastPlatformType)
}
// TestGetAzureToken tests getting Azure token.
func TestGetAzureToken(t *testing.T) {
tmpDir := t.TempDir()
socketPath := filepath.Join(tmpDir, "attestation-azure.sock")
listener, err := net.Listen("unix", socketPath)
require.NoError(t, err)
defer listener.Close()
grpcServer := grpc.NewServer()
mockServer := &mockAttestationServer{}
attestation_v1.RegisterAttestationServiceServer(grpcServer, mockServer)
go func() {
_ = grpcServer.Serve(listener)
}()
defer grpcServer.Stop()
time.Sleep(100 * time.Millisecond)
client, err := NewClient(socketPath)
require.NoError(t, err)
defer client.Close()
ctx := context.Background()
var nonce [32]byte
copy(nonce[:], []byte("azure-nonce"))
token, err := client.GetAzureToken(ctx, nonce)
require.NoError(t, err)
assert.Equal(t, []byte("mock-azure-token"), token)
assert.True(t, mockServer.fetchAzureTokenCalled)
assert.Equal(t, nonce[:], mockServer.lastNonce)
}
// TestGetAttestationWithCanceledContext tests GetAttestation with canceled context.
func TestGetAttestationWithCanceledContext(t *testing.T) {
tmpDir := t.TempDir()
socketPath := filepath.Join(tmpDir, "attestation-cancel.sock")
listener, err := net.Listen("unix", socketPath)
require.NoError(t, err)
defer listener.Close()
grpcServer := grpc.NewServer()
mockServer := &mockAttestationServer{}
attestation_v1.RegisterAttestationServiceServer(grpcServer, mockServer)
go func() {
_ = grpcServer.Serve(listener)
}()
defer grpcServer.Stop()
time.Sleep(100 * time.Millisecond)
client, err := NewClient(socketPath)
require.NoError(t, err)
defer client.Close()
ctx, cancel := context.WithCancel(context.Background())
cancel()
var reportData [64]byte
var nonce [32]byte
_, err = client.GetAttestation(ctx, reportData, nonce, attestation.SNP)
assert.Error(t, err)
assert.Contains(t, err.Error(), "context canceled")
}
// TestClientClose tests closing the attestation client.
func TestClientClose(t *testing.T) {
tmpDir := t.TempDir()
socketPath := filepath.Join(tmpDir, "attestation-close.sock")
listener, err := net.Listen("unix", socketPath)
require.NoError(t, err)
defer listener.Close()
grpcServer := grpc.NewServer()
mockServer := &mockAttestationServer{}
attestation_v1.RegisterAttestationServiceServer(grpcServer, mockServer)
go func() {
_ = grpcServer.Serve(listener)
}()
defer grpcServer.Stop()
time.Sleep(100 * time.Millisecond)
client, err := NewClient(socketPath)
require.NoError(t, err)
err = client.Close()
assert.NoError(t, err)
}
// TestClientOperationsAfterClose tests operations after closing.
func TestClientOperationsAfterClose(t *testing.T) {
tmpDir := t.TempDir()
socketPath := filepath.Join(tmpDir, "attestation-after-close.sock")
listener, err := net.Listen("unix", socketPath)
require.NoError(t, err)
defer listener.Close()
grpcServer := grpc.NewServer()
mockServer := &mockAttestationServer{}
attestation_v1.RegisterAttestationServiceServer(grpcServer, mockServer)
go func() {
_ = grpcServer.Serve(listener)
}()
defer grpcServer.Stop()
time.Sleep(100 * time.Millisecond)
client, err := NewClient(socketPath)
require.NoError(t, err)
err = client.Close()
require.NoError(t, err)
ctx := context.Background()
var reportData [64]byte
var nonce [32]byte
_, err = client.GetAttestation(ctx, reportData, nonce, attestation.SNP)
assert.Error(t, err)
}
+64
View File
@@ -0,0 +1,64 @@
// Copyright (c) Ultraviolet
// SPDX-License-Identifier: Apache-2.0
package log
import (
"context"
"time"
"github.com/ultravioletrs/cocos/agent/log"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials/insecure"
"google.golang.org/protobuf/types/known/timestamppb"
)
type Client interface {
SendLog(ctx context.Context, entry *log.LogEntry) error
SendEvent(ctx context.Context, entry *log.EventEntry) error
Close() error
}
type client struct {
conn *grpc.ClientConn
client log.LogCollectorClient
}
func NewClient(socketPath string) (Client, error) {
conn, err := grpc.NewClient("unix://"+socketPath, grpc.WithTransportCredentials(insecure.NewCredentials()))
if err != nil {
return nil, err
}
return &client{
conn: conn,
client: log.NewLogCollectorClient(conn),
}, nil
}
func (c *client) Close() error {
return c.conn.Close()
}
func (c *client) SendLog(ctx context.Context, entry *log.LogEntry) error {
ctx, cancel := context.WithTimeout(ctx, 5*time.Second)
defer cancel()
if entry.Timestamp == nil {
entry.Timestamp = timestamppb.Now()
}
_, err := c.client.SendLog(ctx, entry)
return err
}
func (c *client) SendEvent(ctx context.Context, entry *log.EventEntry) error {
ctx, cancel := context.WithTimeout(ctx, 5*time.Second)
defer cancel()
if entry.Timestamp == nil {
entry.Timestamp = timestamppb.Now()
}
_, err := c.client.SendEvent(ctx, entry)
return err
}
+332
View File
@@ -0,0 +1,332 @@
// Copyright (c) Ultraviolet
// SPDX-License-Identifier: Apache-2.0
package log
import (
"context"
"net"
"path/filepath"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/ultravioletrs/cocos/agent/log"
"google.golang.org/grpc"
"google.golang.org/protobuf/types/known/emptypb"
"google.golang.org/protobuf/types/known/timestamppb"
)
// mockLogCollectorServer is a mock implementation of the LogCollectorServer.
type mockLogCollectorServer struct {
log.UnimplementedLogCollectorServer
sendLogCalled bool
sendEventCalled bool
lastLogEntry *log.LogEntry
lastEventEntry *log.EventEntry
sendLogErr error
sendEventErr error
}
func (m *mockLogCollectorServer) SendLog(ctx context.Context, entry *log.LogEntry) (*emptypb.Empty, error) {
m.sendLogCalled = true
m.lastLogEntry = entry
if m.sendLogErr != nil {
return nil, m.sendLogErr
}
return &emptypb.Empty{}, nil
}
func (m *mockLogCollectorServer) SendEvent(ctx context.Context, entry *log.EventEntry) (*emptypb.Empty, error) {
m.sendEventCalled = true
m.lastEventEntry = entry
if m.sendEventErr != nil {
return nil, m.sendEventErr
}
return &emptypb.Empty{}, nil
}
// TestNewClient tests creating a new log client.
func TestNewClient(t *testing.T) {
tmpDir := t.TempDir()
socketPath := filepath.Join(tmpDir, "log-test.sock")
listener, err := net.Listen("unix", socketPath)
require.NoError(t, err)
defer listener.Close()
grpcServer := grpc.NewServer()
mockServer := &mockLogCollectorServer{}
log.RegisterLogCollectorServer(grpcServer, mockServer)
go func() {
_ = grpcServer.Serve(listener)
}()
defer grpcServer.Stop()
time.Sleep(100 * time.Millisecond)
client, err := NewClient(socketPath)
require.NoError(t, err)
require.NotNil(t, client)
err = client.Close()
assert.NoError(t, err)
}
// TestClientSendLog tests sending a log entry.
func TestClientSendLog(t *testing.T) {
tmpDir := t.TempDir()
socketPath := filepath.Join(tmpDir, "log-sendlog.sock")
listener, err := net.Listen("unix", socketPath)
require.NoError(t, err)
defer listener.Close()
grpcServer := grpc.NewServer()
mockServer := &mockLogCollectorServer{}
log.RegisterLogCollectorServer(grpcServer, mockServer)
go func() {
_ = grpcServer.Serve(listener)
}()
defer grpcServer.Stop()
time.Sleep(100 * time.Millisecond)
client, err := NewClient(socketPath)
require.NoError(t, err)
defer client.Close()
ctx := context.Background()
entry := &log.LogEntry{
Level: "INFO",
Message: "test log message",
ComputationId: "test-computation",
}
err = client.SendLog(ctx, entry)
require.NoError(t, err)
assert.True(t, mockServer.sendLogCalled)
assert.Equal(t, "INFO", mockServer.lastLogEntry.Level)
assert.Equal(t, "test log message", mockServer.lastLogEntry.Message)
assert.Equal(t, "test-computation", mockServer.lastLogEntry.ComputationId)
assert.NotNil(t, mockServer.lastLogEntry.Timestamp)
}
// TestClientSendLogWithTimestamp tests sending a log entry with existing timestamp.
func TestClientSendLogWithTimestamp(t *testing.T) {
tmpDir := t.TempDir()
socketPath := filepath.Join(tmpDir, "log-timestamp.sock")
listener, err := net.Listen("unix", socketPath)
require.NoError(t, err)
defer listener.Close()
grpcServer := grpc.NewServer()
mockServer := &mockLogCollectorServer{}
log.RegisterLogCollectorServer(grpcServer, mockServer)
go func() {
_ = grpcServer.Serve(listener)
}()
defer grpcServer.Stop()
time.Sleep(100 * time.Millisecond)
client, err := NewClient(socketPath)
require.NoError(t, err)
defer client.Close()
ctx := context.Background()
customTime := time.Date(2024, 1, 1, 12, 0, 0, 0, time.UTC)
entry := &log.LogEntry{
Level: "ERROR",
Message: "test error",
ComputationId: "test",
Timestamp: timestamppb.New(customTime),
}
err = client.SendLog(ctx, entry)
require.NoError(t, err)
assert.True(t, mockServer.sendLogCalled)
assert.Equal(t, customTime.Unix(), mockServer.lastLogEntry.Timestamp.AsTime().Unix())
}
// TestClientSendEvent tests sending an event entry.
func TestClientSendEvent(t *testing.T) {
tmpDir := t.TempDir()
socketPath := filepath.Join(tmpDir, "log-sendevent.sock")
listener, err := net.Listen("unix", socketPath)
require.NoError(t, err)
defer listener.Close()
grpcServer := grpc.NewServer()
mockServer := &mockLogCollectorServer{}
log.RegisterLogCollectorServer(grpcServer, mockServer)
go func() {
_ = grpcServer.Serve(listener)
}()
defer grpcServer.Stop()
time.Sleep(100 * time.Millisecond)
client, err := NewClient(socketPath)
require.NoError(t, err)
defer client.Close()
ctx := context.Background()
entry := &log.EventEntry{
EventType: "computation.started",
ComputationId: "test-computation",
Originator: "agent",
Status: "started",
}
err = client.SendEvent(ctx, entry)
require.NoError(t, err)
assert.True(t, mockServer.sendEventCalled)
assert.Equal(t, "computation.started", mockServer.lastEventEntry.EventType)
assert.Equal(t, "agent", mockServer.lastEventEntry.Originator)
assert.NotNil(t, mockServer.lastEventEntry.Timestamp)
}
// TestClientSendEventWithTimestamp tests sending an event with existing timestamp.
func TestClientSendEventWithTimestamp(t *testing.T) {
tmpDir := t.TempDir()
socketPath := filepath.Join(tmpDir, "log-event-timestamp.sock")
listener, err := net.Listen("unix", socketPath)
require.NoError(t, err)
defer listener.Close()
grpcServer := grpc.NewServer()
mockServer := &mockLogCollectorServer{}
log.RegisterLogCollectorServer(grpcServer, mockServer)
go func() {
_ = grpcServer.Serve(listener)
}()
defer grpcServer.Stop()
time.Sleep(100 * time.Millisecond)
client, err := NewClient(socketPath)
require.NoError(t, err)
defer client.Close()
ctx := context.Background()
customTime := time.Date(2024, 6, 15, 10, 30, 0, 0, time.UTC)
entry := &log.EventEntry{
EventType: "test.event",
ComputationId: "test",
Timestamp: timestamppb.New(customTime),
}
err = client.SendEvent(ctx, entry)
require.NoError(t, err)
assert.True(t, mockServer.sendEventCalled)
assert.Equal(t, customTime.Unix(), mockServer.lastEventEntry.Timestamp.AsTime().Unix())
}
// TestClientSendLogWithCanceledContext tests SendLog with canceled context.
func TestClientSendLogWithCanceledContext(t *testing.T) {
tmpDir := t.TempDir()
socketPath := filepath.Join(tmpDir, "log-cancel.sock")
listener, err := net.Listen("unix", socketPath)
require.NoError(t, err)
defer listener.Close()
grpcServer := grpc.NewServer()
mockServer := &mockLogCollectorServer{}
log.RegisterLogCollectorServer(grpcServer, mockServer)
go func() {
_ = grpcServer.Serve(listener)
}()
defer grpcServer.Stop()
time.Sleep(100 * time.Millisecond)
client, err := NewClient(socketPath)
require.NoError(t, err)
defer client.Close()
ctx, cancel := context.WithCancel(context.Background())
cancel()
entry := &log.LogEntry{
Level: "INFO",
Message: "test",
}
err = client.SendLog(ctx, entry)
assert.Error(t, err)
assert.Contains(t, err.Error(), "context canceled")
}
// TestClientClose tests closing the client.
func TestClientClose(t *testing.T) {
tmpDir := t.TempDir()
socketPath := filepath.Join(tmpDir, "log-close.sock")
listener, err := net.Listen("unix", socketPath)
require.NoError(t, err)
defer listener.Close()
grpcServer := grpc.NewServer()
mockServer := &mockLogCollectorServer{}
log.RegisterLogCollectorServer(grpcServer, mockServer)
go func() {
_ = grpcServer.Serve(listener)
}()
defer grpcServer.Stop()
time.Sleep(100 * time.Millisecond)
client, err := NewClient(socketPath)
require.NoError(t, err)
err = client.Close()
assert.NoError(t, err)
}
// TestClientOperationsAfterClose tests operations after closing.
func TestClientOperationsAfterClose(t *testing.T) {
tmpDir := t.TempDir()
socketPath := filepath.Join(tmpDir, "log-after-close.sock")
listener, err := net.Listen("unix", socketPath)
require.NoError(t, err)
defer listener.Close()
grpcServer := grpc.NewServer()
mockServer := &mockLogCollectorServer{}
log.RegisterLogCollectorServer(grpcServer, mockServer)
go func() {
_ = grpcServer.Serve(listener)
}()
defer grpcServer.Stop()
time.Sleep(100 * time.Millisecond)
client, err := NewClient(socketPath)
require.NoError(t, err)
err = client.Close()
require.NoError(t, err)
ctx := context.Background()
entry := &log.LogEntry{
Level: "INFO",
Message: "test",
}
err = client.SendLog(ctx, entry)
assert.Error(t, err)
}
+52
View File
@@ -0,0 +1,52 @@
// Copyright (c) Ultraviolet
// SPDX-License-Identifier: Apache-2.0
package runner
import (
"context"
"time"
pb "github.com/ultravioletrs/cocos/agent/runner"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials/insecure"
"google.golang.org/protobuf/types/known/emptypb"
)
type Client interface {
Run(ctx context.Context, req *pb.RunRequest) (*pb.RunResponse, error)
Stop(ctx context.Context, req *pb.StopRequest) (*emptypb.Empty, error)
Close() error
}
type client struct {
conn *grpc.ClientConn
client pb.ComputationRunnerClient
}
func NewClient(socketPath string) (Client, error) {
conn, err := grpc.NewClient("unix://"+socketPath, grpc.WithTransportCredentials(insecure.NewCredentials()))
if err != nil {
return nil, err
}
return &client{
conn: conn,
client: pb.NewComputationRunnerClient(conn),
}, nil
}
func (c *client) Close() error {
return c.conn.Close()
}
func (c *client) Run(ctx context.Context, req *pb.RunRequest) (*pb.RunResponse, error) {
// Run might take long time, so we need unlimited timeout or rely on context cancellation
return c.client.Run(ctx, req)
}
func (c *client) Stop(ctx context.Context, req *pb.StopRequest) (*emptypb.Empty, error) {
ctx, cancel := context.WithTimeout(ctx, 10*time.Second)
defer cancel()
return c.client.Stop(ctx, req)
}
+349
View File
@@ -0,0 +1,349 @@
// Copyright (c) Ultraviolet
// SPDX-License-Identifier: Apache-2.0
package runner
import (
"context"
"net"
"os"
"path/filepath"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
pb "github.com/ultravioletrs/cocos/agent/runner"
"google.golang.org/grpc"
"google.golang.org/protobuf/types/known/emptypb"
)
// mockComputationRunnerServer is a mock implementation of the ComputationRunnerServer.
type mockComputationRunnerServer struct {
pb.UnimplementedComputationRunnerServer
runCalled bool
stopCalled bool
runErr error
stopErr error
}
func (m *mockComputationRunnerServer) Run(ctx context.Context, req *pb.RunRequest) (*pb.RunResponse, error) {
m.runCalled = true
if m.runErr != nil {
return nil, m.runErr
}
return &pb.RunResponse{
ComputationId: req.ComputationId,
Error: "",
}, nil
}
func (m *mockComputationRunnerServer) Stop(ctx context.Context, req *pb.StopRequest) (*emptypb.Empty, error) {
m.stopCalled = true
if m.stopErr != nil {
return nil, m.stopErr
}
return &emptypb.Empty{}, nil
}
// TestNewClient tests creating a new gRPC client.
func TestNewClient(t *testing.T) {
// Create a temporary directory for the socket
tmpDir := t.TempDir()
socketPath := filepath.Join(tmpDir, "test.sock")
// Start a mock gRPC server
listener, err := net.Listen("unix", socketPath)
require.NoError(t, err)
defer listener.Close()
grpcServer := grpc.NewServer()
mockServer := &mockComputationRunnerServer{}
pb.RegisterComputationRunnerServer(grpcServer, mockServer)
go func() {
_ = grpcServer.Serve(listener)
}()
defer grpcServer.Stop()
// Give server time to start
time.Sleep(100 * time.Millisecond)
// Create client
client, err := NewClient(socketPath)
require.NoError(t, err)
require.NotNil(t, client)
// Clean up
err = client.Close()
assert.NoError(t, err)
}
// TestNewClientInvalidSocket tests creating a client with invalid socket path.
func TestNewClientInvalidSocket(t *testing.T) {
// Use a non-existent socket path
socketPath := "/tmp/nonexistent-" + time.Now().Format("20060102150405") + ".sock"
// Create client - this should succeed as grpc.NewClient is lazy
client, err := NewClient(socketPath)
require.NoError(t, err)
require.NotNil(t, client)
// Close should work even if never connected
err = client.Close()
assert.NoError(t, err)
}
// TestClientRun tests the Run method.
func TestClientRun(t *testing.T) {
tmpDir := t.TempDir()
socketPath := filepath.Join(tmpDir, "test-run.sock")
listener, err := net.Listen("unix", socketPath)
require.NoError(t, err)
defer listener.Close()
grpcServer := grpc.NewServer()
mockServer := &mockComputationRunnerServer{}
pb.RegisterComputationRunnerServer(grpcServer, mockServer)
go func() {
_ = grpcServer.Serve(listener)
}()
defer grpcServer.Stop()
time.Sleep(100 * time.Millisecond)
client, err := NewClient(socketPath)
require.NoError(t, err)
defer client.Close()
// Call Run
ctx := context.Background()
req := &pb.RunRequest{
ComputationId: "test-computation",
}
resp, err := client.Run(ctx, req)
require.NoError(t, err)
assert.NotNil(t, resp)
assert.Equal(t, "test-computation", resp.ComputationId)
assert.True(t, mockServer.runCalled)
}
// TestClientRunWithCanceledContext tests Run with a canceled context.
func TestClientRunWithCanceledContext(t *testing.T) {
tmpDir := t.TempDir()
socketPath := filepath.Join(tmpDir, "test-run-cancel.sock")
listener, err := net.Listen("unix", socketPath)
require.NoError(t, err)
defer listener.Close()
grpcServer := grpc.NewServer()
mockServer := &mockComputationRunnerServer{}
pb.RegisterComputationRunnerServer(grpcServer, mockServer)
go func() {
_ = grpcServer.Serve(listener)
}()
defer grpcServer.Stop()
time.Sleep(100 * time.Millisecond)
client, err := NewClient(socketPath)
require.NoError(t, err)
defer client.Close()
// Create a canceled context
ctx, cancel := context.WithCancel(context.Background())
cancel()
req := &pb.RunRequest{
ComputationId: "test-computation",
}
_, err = client.Run(ctx, req)
assert.Error(t, err)
assert.Contains(t, err.Error(), "context canceled")
}
// TestClientStop tests the Stop method.
func TestClientStop(t *testing.T) {
tmpDir := t.TempDir()
socketPath := filepath.Join(tmpDir, "test-stop.sock")
listener, err := net.Listen("unix", socketPath)
require.NoError(t, err)
defer listener.Close()
grpcServer := grpc.NewServer()
mockServer := &mockComputationRunnerServer{}
pb.RegisterComputationRunnerServer(grpcServer, mockServer)
go func() {
_ = grpcServer.Serve(listener)
}()
defer grpcServer.Stop()
time.Sleep(100 * time.Millisecond)
client, err := NewClient(socketPath)
require.NoError(t, err)
defer client.Close()
// Call Stop
ctx := context.Background()
req := &pb.StopRequest{
ComputationId: "test-computation",
}
resp, err := client.Stop(ctx, req)
require.NoError(t, err)
assert.NotNil(t, resp)
assert.True(t, mockServer.stopCalled)
}
// TestClientStopWithTimeout tests Stop with context timeout.
func TestClientStopWithTimeout(t *testing.T) {
tmpDir := t.TempDir()
socketPath := filepath.Join(tmpDir, "test-stop-timeout.sock")
listener, err := net.Listen("unix", socketPath)
require.NoError(t, err)
defer listener.Close()
grpcServer := grpc.NewServer()
mockServer := &mockComputationRunnerServer{}
pb.RegisterComputationRunnerServer(grpcServer, mockServer)
go func() {
_ = grpcServer.Serve(listener)
}()
defer grpcServer.Stop()
time.Sleep(100 * time.Millisecond)
client, err := NewClient(socketPath)
require.NoError(t, err)
defer client.Close()
// Create a context that will timeout
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
req := &pb.StopRequest{
ComputationId: "test-computation",
}
// Stop should complete within the timeout
resp, err := client.Stop(ctx, req)
require.NoError(t, err)
assert.NotNil(t, resp)
assert.True(t, mockServer.stopCalled)
}
// TestClientClose tests the Close method.
func TestClientClose(t *testing.T) {
tmpDir := t.TempDir()
socketPath := filepath.Join(tmpDir, "test-close.sock")
listener, err := net.Listen("unix", socketPath)
require.NoError(t, err)
defer listener.Close()
grpcServer := grpc.NewServer()
mockServer := &mockComputationRunnerServer{}
pb.RegisterComputationRunnerServer(grpcServer, mockServer)
go func() {
_ = grpcServer.Serve(listener)
}()
defer grpcServer.Stop()
time.Sleep(100 * time.Millisecond)
client, err := NewClient(socketPath)
require.NoError(t, err)
// Close should succeed
err = client.Close()
assert.NoError(t, err)
}
// TestClientOperationsAfterClose tests that operations fail gracefully after close.
func TestClientOperationsAfterClose(t *testing.T) {
tmpDir := t.TempDir()
socketPath := filepath.Join(tmpDir, "test-after-close.sock")
listener, err := net.Listen("unix", socketPath)
require.NoError(t, err)
defer listener.Close()
grpcServer := grpc.NewServer()
mockServer := &mockComputationRunnerServer{}
pb.RegisterComputationRunnerServer(grpcServer, mockServer)
go func() {
_ = grpcServer.Serve(listener)
}()
defer grpcServer.Stop()
time.Sleep(100 * time.Millisecond)
client, err := NewClient(socketPath)
require.NoError(t, err)
// Close the client
err = client.Close()
require.NoError(t, err)
// Try to use the client after closing
ctx := context.Background()
runReq := &pb.RunRequest{ComputationId: "test"}
// This should fail because connection is closed
_, err = client.Run(ctx, runReq)
assert.Error(t, err)
}
// TestNewClientWithRelativePath tests creating client with relative socket path.
func TestNewClientWithRelativePath(t *testing.T) {
// Create temp directory
tmpDir := t.TempDir()
// Change to temp directory
oldWd, err := os.Getwd()
require.NoError(t, err)
defer func() {
err := os.Chdir(oldWd)
require.NoError(t, err)
}()
err = os.Chdir(tmpDir)
require.NoError(t, err)
socketPath := "relative-test.sock"
listener, err := net.Listen("unix", socketPath)
require.NoError(t, err)
defer listener.Close()
grpcServer := grpc.NewServer()
mockServer := &mockComputationRunnerServer{}
pb.RegisterComputationRunnerServer(grpcServer, mockServer)
go func() {
_ = grpcServer.Serve(listener)
}()
defer grpcServer.Stop()
time.Sleep(100 * time.Millisecond)
// Create client with relative path
client, err := NewClient(socketPath)
require.NoError(t, err)
require.NotNil(t, client)
err = client.Close()
assert.NoError(t, err)
}
+223
View File
@@ -0,0 +1,223 @@
// Copyright (c) Ultraviolet
// SPDX-License-Identifier: Apache-2.0
// Code generated by mockery; DO NOT EDIT.
// github.com/vektra/mockery
// template: testify
package mocks
import (
"context"
mock "github.com/stretchr/testify/mock"
"github.com/ultravioletrs/cocos/agent/runner"
"google.golang.org/protobuf/types/known/emptypb"
)
// NewClient creates a new instance of Client. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations.
// The first argument is typically a *testing.T value.
func NewClient(t interface {
mock.TestingT
Cleanup(func())
}) *Client {
mock := &Client{}
mock.Mock.Test(t)
t.Cleanup(func() { mock.AssertExpectations(t) })
return mock
}
// Client is an autogenerated mock type for the Client type
type Client struct {
mock.Mock
}
type Client_Expecter struct {
mock *mock.Mock
}
func (_m *Client) EXPECT() *Client_Expecter {
return &Client_Expecter{mock: &_m.Mock}
}
// Close provides a mock function for the type Client
func (_mock *Client) Close() error {
ret := _mock.Called()
if len(ret) == 0 {
panic("no return value specified for Close")
}
var r0 error
if returnFunc, ok := ret.Get(0).(func() error); ok {
r0 = returnFunc()
} else {
r0 = ret.Error(0)
}
return r0
}
// Client_Close_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Close'
type Client_Close_Call struct {
*mock.Call
}
// Close is a helper method to define mock.On call
func (_e *Client_Expecter) Close() *Client_Close_Call {
return &Client_Close_Call{Call: _e.mock.On("Close")}
}
func (_c *Client_Close_Call) Run(run func()) *Client_Close_Call {
_c.Call.Run(func(args mock.Arguments) {
run()
})
return _c
}
func (_c *Client_Close_Call) Return(err error) *Client_Close_Call {
_c.Call.Return(err)
return _c
}
func (_c *Client_Close_Call) RunAndReturn(run func() error) *Client_Close_Call {
_c.Call.Return(run)
return _c
}
// Run provides a mock function for the type Client
func (_mock *Client) Run(ctx context.Context, req *runner.RunRequest) (*runner.RunResponse, error) {
ret := _mock.Called(ctx, req)
if len(ret) == 0 {
panic("no return value specified for Run")
}
var r0 *runner.RunResponse
var r1 error
if returnFunc, ok := ret.Get(0).(func(context.Context, *runner.RunRequest) (*runner.RunResponse, error)); ok {
return returnFunc(ctx, req)
}
if returnFunc, ok := ret.Get(0).(func(context.Context, *runner.RunRequest) *runner.RunResponse); ok {
r0 = returnFunc(ctx, req)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(*runner.RunResponse)
}
}
if returnFunc, ok := ret.Get(1).(func(context.Context, *runner.RunRequest) error); ok {
r1 = returnFunc(ctx, req)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// Client_Run_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Run'
type Client_Run_Call struct {
*mock.Call
}
// Run is a helper method to define mock.On call
// - ctx context.Context
// - req *runner.RunRequest
func (_e *Client_Expecter) Run(ctx interface{}, req interface{}) *Client_Run_Call {
return &Client_Run_Call{Call: _e.mock.On("Run", ctx, req)}
}
func (_c *Client_Run_Call) Run(run func(ctx context.Context, req *runner.RunRequest)) *Client_Run_Call {
_c.Call.Run(func(args mock.Arguments) {
var arg0 context.Context
if args[0] != nil {
arg0 = args[0].(context.Context)
}
var arg1 *runner.RunRequest
if args[1] != nil {
arg1 = args[1].(*runner.RunRequest)
}
run(
arg0,
arg1,
)
})
return _c
}
func (_c *Client_Run_Call) Return(runResponse *runner.RunResponse, err error) *Client_Run_Call {
_c.Call.Return(runResponse, err)
return _c
}
func (_c *Client_Run_Call) RunAndReturn(run func(ctx context.Context, req *runner.RunRequest) (*runner.RunResponse, error)) *Client_Run_Call {
_c.Call.Return(run)
return _c
}
// Stop provides a mock function for the type Client
func (_mock *Client) Stop(ctx context.Context, req *runner.StopRequest) (*emptypb.Empty, error) {
ret := _mock.Called(ctx, req)
if len(ret) == 0 {
panic("no return value specified for Stop")
}
var r0 *emptypb.Empty
var r1 error
if returnFunc, ok := ret.Get(0).(func(context.Context, *runner.StopRequest) (*emptypb.Empty, error)); ok {
return returnFunc(ctx, req)
}
if returnFunc, ok := ret.Get(0).(func(context.Context, *runner.StopRequest) *emptypb.Empty); ok {
r0 = returnFunc(ctx, req)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(*emptypb.Empty)
}
}
if returnFunc, ok := ret.Get(1).(func(context.Context, *runner.StopRequest) error); ok {
r1 = returnFunc(ctx, req)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// Client_Stop_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Stop'
type Client_Stop_Call struct {
*mock.Call
}
// Stop is a helper method to define mock.On call
// - ctx context.Context
// - req *runner.StopRequest
func (_e *Client_Expecter) Stop(ctx interface{}, req interface{}) *Client_Stop_Call {
return &Client_Stop_Call{Call: _e.mock.On("Stop", ctx, req)}
}
func (_c *Client_Stop_Call) Run(run func(ctx context.Context, req *runner.StopRequest)) *Client_Stop_Call {
_c.Call.Run(func(args mock.Arguments) {
var arg0 context.Context
if args[0] != nil {
arg0 = args[0].(context.Context)
}
var arg1 *runner.StopRequest
if args[1] != nil {
arg1 = args[1].(*runner.StopRequest)
}
run(
arg0,
arg1,
)
})
return _c
}
func (_c *Client_Stop_Call) Return(empty *emptypb.Empty, err error) *Client_Stop_Call {
_c.Call.Return(empty, err)
return _c
}
func (_c *Client_Stop_Call) RunAndReturn(run func(ctx context.Context, req *runner.StopRequest) (*emptypb.Empty, error)) *Client_Stop_Call {
_c.Call.Return(run)
return _c
}
+248
View File
@@ -0,0 +1,248 @@
// Copyright (c) Ultraviolet
// SPDX-License-Identifier: Apache-2.0
package egress
import (
"context"
"crypto/tls"
"io"
"log/slog"
"net"
"net/http"
"net/http/httputil"
"net/url"
"sync"
"time"
"golang.org/x/net/http2"
)
// Proxy is an egress proxy server.
type Proxy struct {
logger *slog.Logger
server *http.Server
addr string
transport *http.Transport
}
// NewProxy creates a new egress proxy.
func NewProxy(logger *slog.Logger, addr string) *Proxy {
// Create HTTP/2 capable transport
transport := &http.Transport{
TLSClientConfig: &tls.Config{
InsecureSkipVerify: false,
},
ForceAttemptHTTP2: true,
}
// Enable HTTP/2
if err := http2.ConfigureTransport(transport); err != nil {
logger.Warn("Failed to configure HTTP/2 transport", "error", err)
}
p := &Proxy{
logger: logger,
addr: addr,
transport: transport,
}
p.server = &http.Server{
Addr: addr,
Handler: http.HandlerFunc(p.handle),
}
return p
}
// Start starts the proxy server.
func (p *Proxy) Start() error {
p.logger.Info("Starting egress proxy", "addr", p.addr)
return p.server.ListenAndServe()
}
// Stop stops the proxy server.
func (p *Proxy) Stop(ctx context.Context) error {
p.logger.Info("Stopping egress proxy")
return p.server.Shutdown(ctx)
}
func (p *Proxy) handle(w http.ResponseWriter, r *http.Request) {
if r.Method == http.MethodConnect {
p.handleConnect(w, r)
} else if r.ProtoMajor == 2 {
p.handleHTTP2(w, r)
} else {
p.handleHTTP(w, r)
}
}
func (p *Proxy) handleConnect(w http.ResponseWriter, r *http.Request) {
host := r.Host
p.logger.Info("CONNECT request received", "host", host)
// nolint:godox // TODO: Check allowlist here - allowlist implementation deferred
p.logger.Debug("Dialing destination", "host", host)
destConn, err := net.DialTimeout("tcp", host, 10*time.Second)
if err != nil {
p.logger.Error("Failed to dial destination", "host", host, "error", err)
http.Error(w, err.Error(), http.StatusServiceUnavailable)
return
}
defer destConn.Close()
p.logger.Info("Successfully connected to destination", "host", host)
p.logger.Debug("Hijacking client connection")
hijacker, ok := w.(http.Hijacker)
if !ok {
p.logger.Error("Hijacking not supported")
http.Error(w, "Hijacking not supported", http.StatusInternalServerError)
return
}
clientConn, _, err := hijacker.Hijack()
if err != nil {
p.logger.Error("Failed to hijack connection", "error", err)
return
}
defer clientConn.Close()
p.logger.Info("Successfully hijacked client connection", "host", host)
// Send 200 Connection Established response
p.logger.Debug("Sending 200 Connection Established")
_, err = clientConn.Write([]byte("HTTP/1.1 200 Connection Established\r\n\r\n"))
if err != nil {
p.logger.Error("Failed to send CONNECT response", "error", err)
return
}
p.logger.Info("Starting bidirectional pipe", "host", host)
p.pipe(clientConn, destConn)
p.logger.Info("Pipe completed", "host", host)
}
func (p *Proxy) handleHTTP(w http.ResponseWriter, r *http.Request) {
p.logger.Info("HTTP request", "method", r.Method, "url", r.URL.String())
// nolint:godox // TODO: Check allowlist here - allowlist implementation deferred
r.RequestURI = "" // RequestURI must be empty for Client.Do
// Remove hop-by-hop headers
delHopHeaders(r.Header)
// Create a client to send the request
client := &http.Client{
CheckRedirect: func(req *http.Request, via []*http.Request) error {
return http.ErrUseLastResponse
},
}
resp, err := client.Do(r)
if err != nil {
p.logger.Error("Failed to execute request", "error", err)
http.Error(w, err.Error(), http.StatusBadGateway)
return
}
defer resp.Body.Close()
// Copy headers
copyHeader(w.Header(), resp.Header)
w.WriteHeader(resp.StatusCode)
if _, err := io.Copy(w, resp.Body); err != nil {
p.logger.Error("Failed to copy response body", "error", err)
}
}
func (p *Proxy) handleHTTP2(w http.ResponseWriter, r *http.Request) {
p.logger.Info("HTTP/2 request", "method", r.Method, "host", r.Host, "path", r.URL.Path)
// nolint:godox // TODO: Check allowlist here - allowlist implementation deferred
// Parse the target URL from the request
targetURL := &url.URL{
Scheme: "http",
Host: r.Host,
}
// If the request has a full URL (absolute form), use it
if r.URL.IsAbs() {
targetURL = r.URL
}
// Create a reverse proxy with HTTP/2 transport
proxy := &httputil.ReverseProxy{
Director: func(req *http.Request) {
req.URL.Scheme = targetURL.Scheme
req.URL.Host = targetURL.Host
req.Host = targetURL.Host
// Preserve the original path and query
if !r.URL.IsAbs() {
req.URL.Path = r.URL.Path
req.URL.RawQuery = r.URL.RawQuery
}
// Remove hop-by-hop headers
delHopHeaders(req.Header)
},
Transport: p.transport,
ErrorHandler: func(w http.ResponseWriter, r *http.Request, err error) {
p.logger.Error("HTTP/2 proxy error", "error", err, "host", r.Host)
http.Error(w, err.Error(), http.StatusBadGateway)
},
}
proxy.ServeHTTP(w, r)
}
func (p *Proxy) pipe(src, dst net.Conn) {
var wg sync.WaitGroup
wg.Add(2)
go func() {
defer wg.Done()
n, err := io.Copy(dst, src)
p.logger.Debug("Pipe src->dst completed", "bytes", n, "error", err)
// Close write end of dst if possible, or just close it
if c, ok := dst.(*net.TCPConn); ok {
if err := c.CloseWrite(); err != nil {
p.logger.Debug("Failed to close write end of dst", "error", err)
}
}
}()
go func() {
defer wg.Done()
n, err := io.Copy(src, dst)
p.logger.Debug("Pipe dst->src completed", "bytes", n, "error", err)
if c, ok := src.(*net.TCPConn); ok {
if err := c.CloseWrite(); err != nil {
p.logger.Debug("Failed to close write end of src", "error", err)
}
}
}()
wg.Wait()
}
func copyHeader(dst, src http.Header) {
for k, vv := range src {
for _, v := range vv {
dst.Add(k, v)
}
}
}
func delHopHeaders(header http.Header) {
// Standard hop-by-hop headers
hopHeaders := []string{
"Connection",
"Keep-Alive",
"Proxy-Authenticate",
"Proxy-Authorization",
"Te",
"Trailers",
"Transfer-Encoding",
"Upgrade",
}
for _, h := range hopHeaders {
header.Del(h)
}
}
+795
View File
@@ -0,0 +1,795 @@
// Copyright (c) Ultraviolet
// SPDX-License-Identifier: Apache-2.0
package egress
import (
"context"
"fmt"
"io"
"log/slog"
"net"
"net/http"
"net/http/httptest"
"os"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestProxyHTTP(t *testing.T) {
// 1. Start a backend server
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
if _, err := w.Write([]byte("backend response")); err != nil {
t.Logf("Failed to write response: %v", err)
}
}))
defer backend.Close()
// 2. Start Proxy
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
proxy := NewProxy(logger, ":0")
// Listen on a random port
ln, err := net.Listen("tcp", ":0")
require.NoError(t, err)
proxy.server.Addr = ln.Addr().String()
go func() {
if err := proxy.server.Serve(ln); err != nil && err != http.ErrServerClosed {
t.Logf("Proxy server error: %v", err)
}
}()
defer func() {
if err := proxy.Stop(context.Background()); err != nil {
t.Logf("Failed to stop proxy: %v", err)
}
}()
// waiting for server start
time.Sleep(100 * time.Millisecond)
// 3. Make request via proxy
proxyURL := fmt.Sprintf("http://%s", ln.Addr().String())
os.Setenv("HTTP_PROXY", proxyURL)
defer os.Unsetenv("HTTP_PROXY")
client := &http.Client{
Transport: &http.Transport{
Proxy: http.ProxyFromEnvironment,
},
}
resp, err := client.Get(backend.URL)
require.NoError(t, err)
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
require.NoError(t, err)
assert.Equal(t, "backend response", string(body))
assert.Equal(t, http.StatusOK, resp.StatusCode)
}
func TestProxyConnect(t *testing.T) {
// 1. Start a backend TLS server
backend := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
if _, err := w.Write([]byte("secure backend response")); err != nil {
t.Logf("Failed to write response: %v", err)
}
}))
defer backend.Close()
// 2. Start Proxy
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
// Listen on a random port
ln, err := net.Listen("tcp", ":0")
require.NoError(t, err)
proxy := NewProxy(logger, ln.Addr().String())
proxy.server.Addr = ln.Addr().String()
go func() {
if err := proxy.server.Serve(ln); err != nil && err != http.ErrServerClosed {
t.Logf("Proxy server error: %v", err)
}
}()
defer func() {
if err := proxy.Stop(context.Background()); err != nil {
t.Logf("Failed to stop proxy: %v", err)
}
}()
time.Sleep(100 * time.Millisecond)
// 3. Configure client to use proxy
proxyURL := fmt.Sprintf("http://%s", ln.Addr().String())
os.Setenv("HTTPS_PROXY", proxyURL)
defer os.Unsetenv("HTTPS_PROXY")
client := backend.Client() // This client trusts the test cert
// But we need to update its transport proxy
if transport, ok := client.Transport.(*http.Transport); ok {
transport.Proxy = http.ProxyFromEnvironment
} else {
// Create new transport if needed, but backend.Client() returns transport with TLS config
tr := &http.Transport{
TLSClientConfig: client.Transport.(*http.Transport).TLSClientConfig,
Proxy: http.ProxyFromEnvironment,
}
client.Transport = tr
}
resp, err := client.Get(backend.URL)
require.NoError(t, err)
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
require.NoError(t, err)
assert.Equal(t, "secure backend response", string(body))
}
// TestProxyHTTP2 tests HTTP/2 requests through the proxy.
func TestProxyHTTP2(t *testing.T) {
// 1. Start a backend server
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
if _, err := w.Write([]byte("http2 response")); err != nil {
t.Logf("Failed to write response: %v", err)
}
}))
defer backend.Close()
// 2. Start Proxy
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
ln, err := net.Listen("tcp", ":0")
require.NoError(t, err)
proxy := NewProxy(logger, ln.Addr().String())
proxy.server.Addr = ln.Addr().String()
go func() {
if err := proxy.server.Serve(ln); err != nil && err != http.ErrServerClosed {
t.Logf("Proxy server error: %v", err)
}
}()
defer func() {
if err := proxy.Stop(context.Background()); err != nil {
t.Logf("Failed to stop proxy: %v", err)
}
}()
time.Sleep(100 * time.Millisecond)
// 3. Make HTTP/2 request via proxy
proxyURL := fmt.Sprintf("http://%s", ln.Addr().String())
client := &http.Client{
Transport: &http.Transport{
Proxy: http.ProxyFromEnvironment,
ForceAttemptHTTP2: true,
},
}
os.Setenv("HTTP_PROXY", proxyURL)
defer os.Unsetenv("HTTP_PROXY")
// This will be an HTTP/1.1 request unless explicitly configured for HTTP/2
resp, err := client.Get(backend.URL)
if err == nil {
defer resp.Body.Close()
body, _ := io.ReadAll(resp.Body)
assert.Equal(t, "http2 response", string(body))
}
}
// TestProxyHeaderHandling tests that headers are properly handled.
func TestProxyHeaderHandling(t *testing.T) {
// Start a backend server that echoes headers
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("X-Custom-Header", "custom-value")
w.WriteHeader(http.StatusOK)
if _, err := w.Write([]byte(r.Header.Get("X-Request-Header"))); err != nil {
t.Logf("Failed to write response: %v", err)
}
}))
defer backend.Close()
// Start Proxy
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
ln, err := net.Listen("tcp", ":0")
require.NoError(t, err)
proxy := NewProxy(logger, ln.Addr().String())
proxy.server.Addr = ln.Addr().String()
go func() {
if err := proxy.server.Serve(ln); err != nil && err != http.ErrServerClosed {
t.Logf("Proxy server error: %v", err)
}
}()
defer func() {
if err := proxy.Stop(context.Background()); err != nil {
t.Logf("Failed to stop proxy: %v", err)
}
}()
time.Sleep(100 * time.Millisecond)
// Make request with custom headers
proxyURL := fmt.Sprintf("http://%s", ln.Addr().String())
os.Setenv("HTTP_PROXY", proxyURL)
defer os.Unsetenv("HTTP_PROXY")
req, _ := http.NewRequest("GET", backend.URL, nil)
req.Header.Set("X-Request-Header", "request-value")
client := &http.Client{
Transport: &http.Transport{
Proxy: http.ProxyFromEnvironment,
},
}
resp, err := client.Do(req)
if err == nil {
defer resp.Body.Close()
assert.Equal(t, "custom-value", resp.Header.Get("X-Custom-Header"))
}
}
// TestProxyWithDifferentMethods tests different HTTP methods.
func TestProxyWithDifferentMethods(t *testing.T) {
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
if _, err := w.Write([]byte(r.Method)); err != nil {
t.Logf("Failed to write response: %v", err)
}
}))
defer backend.Close()
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
ln, err := net.Listen("tcp", ":0")
require.NoError(t, err)
proxy := NewProxy(logger, ln.Addr().String())
proxy.server.Addr = ln.Addr().String()
go func() {
if err := proxy.server.Serve(ln); err != nil && err != http.ErrServerClosed {
t.Logf("Proxy server error: %v", err)
}
}()
defer func() {
if err := proxy.Stop(context.Background()); err != nil {
t.Logf("Failed to stop proxy: %v", err)
}
}()
time.Sleep(100 * time.Millisecond)
proxyURL := fmt.Sprintf("http://%s", ln.Addr().String())
os.Setenv("HTTP_PROXY", proxyURL)
defer os.Unsetenv("HTTP_PROXY")
methods := []string{"GET", "POST", "PUT", "DELETE"}
for _, method := range methods {
req, _ := http.NewRequest(method, backend.URL, nil)
client := &http.Client{
Transport: &http.Transport{
Proxy: http.ProxyFromEnvironment,
},
}
resp, err := client.Do(req)
if err == nil {
defer resp.Body.Close()
body, _ := io.ReadAll(resp.Body)
assert.Equal(t, method, string(body))
}
}
}
// TestProxyErrorHandling tests error handling in the proxy.
func TestProxyErrorHandling(t *testing.T) {
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
ln, err := net.Listen("tcp", ":0")
require.NoError(t, err)
proxy := NewProxy(logger, ln.Addr().String())
proxy.server.Addr = ln.Addr().String()
go func() {
if err := proxy.server.Serve(ln); err != nil && err != http.ErrServerClosed {
t.Logf("Proxy server error: %v", err)
}
}()
defer func() {
if err := proxy.Stop(context.Background()); err != nil {
t.Logf("Failed to stop proxy: %v", err)
}
}()
time.Sleep(100 * time.Millisecond)
// Try to connect to a non-existent backend
proxyURL := fmt.Sprintf("http://%s", ln.Addr().String())
os.Setenv("HTTP_PROXY", proxyURL)
defer os.Unsetenv("HTTP_PROXY")
client := &http.Client{
Transport: &http.Transport{
Proxy: http.ProxyFromEnvironment,
},
}
// This should fail because the backend doesn't exist
resp, err := client.Get("http://nonexistent.example.com:99999")
if err != nil {
return
}
if resp != nil {
defer resp.Body.Close()
// Status should be error
assert.NotEqual(t, http.StatusOK, resp.StatusCode)
}
}
// TestProxyWithLargeBody tests proxy with large response body.
func TestProxyWithLargeBody(t *testing.T) {
largeBody := make([]byte, 1024*1024) // 1MB
for i := range largeBody {
largeBody[i] = byte(i % 256)
}
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
if _, err := w.Write(largeBody); err != nil {
t.Logf("Failed to write response: %v", err)
}
}))
defer backend.Close()
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
ln, err := net.Listen("tcp", ":0")
require.NoError(t, err)
proxy := NewProxy(logger, ln.Addr().String())
proxy.server.Addr = ln.Addr().String()
go func() {
if err := proxy.server.Serve(ln); err != nil && err != http.ErrServerClosed {
t.Logf("Proxy server error: %v", err)
}
}()
defer func() {
if err := proxy.Stop(context.Background()); err != nil {
t.Logf("Failed to stop proxy: %v", err)
}
}()
time.Sleep(100 * time.Millisecond)
proxyURL := fmt.Sprintf("http://%s", ln.Addr().String())
os.Setenv("HTTP_PROXY", proxyURL)
defer os.Unsetenv("HTTP_PROXY")
client := &http.Client{
Transport: &http.Transport{
Proxy: http.ProxyFromEnvironment,
},
}
resp, err := client.Get(backend.URL)
require.NoError(t, err)
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
require.NoError(t, err)
assert.Equal(t, len(largeBody), len(body))
}
// TestCopyHeader tests the copyHeader utility function.
func TestCopyHeader(t *testing.T) {
src := http.Header{}
src.Add("X-Custom-Header", "value1")
src.Add("X-Custom-Header", "value2")
src.Add("Content-Type", "application/json")
dst := http.Header{}
copyHeader(dst, src)
assert.Equal(t, []string{"value1", "value2"}, dst["X-Custom-Header"])
assert.Equal(t, []string{"application/json"}, dst["Content-Type"])
}
// TestDelHopHeaders tests the delHopHeaders utility function.
func TestDelHopHeaders(t *testing.T) {
header := http.Header{}
header.Set("Connection", "keep-alive")
header.Set("Keep-Alive", "timeout=5")
header.Set("Proxy-Authenticate", "Basic")
header.Set("Proxy-Authorization", "Bearer token")
header.Set("Te", "trailers")
header.Set("Trailers", "X-Custom")
header.Set("Transfer-Encoding", "chunked")
header.Set("Upgrade", "websocket")
header.Set("X-Custom-Header", "should-remain")
delHopHeaders(header)
// Hop-by-hop headers should be removed
assert.Empty(t, header.Get("Connection"))
assert.Empty(t, header.Get("Keep-Alive"))
assert.Empty(t, header.Get("Proxy-Authenticate"))
assert.Empty(t, header.Get("Proxy-Authorization"))
assert.Empty(t, header.Get("Te"))
assert.Empty(t, header.Get("Trailers"))
assert.Empty(t, header.Get("Transfer-Encoding"))
assert.Empty(t, header.Get("Upgrade"))
// Custom headers should remain
assert.Equal(t, "should-remain", header.Get("X-Custom-Header"))
}
// TestProxyConnectDialTimeout tests CONNECT with dial timeout.
func TestProxyConnectDialTimeout(t *testing.T) {
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
ln, err := net.Listen("tcp", ":0")
require.NoError(t, err)
proxy := NewProxy(logger, ln.Addr().String())
proxy.server.Addr = ln.Addr().String()
go func() {
if err := proxy.server.Serve(ln); err != nil && err != http.ErrServerClosed {
t.Logf("Proxy server error: %v", err)
}
}()
defer func() {
if err := proxy.Stop(context.Background()); err != nil {
t.Logf("Failed to stop proxy: %v", err)
}
}()
time.Sleep(100 * time.Millisecond)
// Try to CONNECT to a non-routable address (should timeout)
proxyURL := fmt.Sprintf("http://%s", ln.Addr().String())
os.Setenv("HTTPS_PROXY", proxyURL)
defer os.Unsetenv("HTTPS_PROXY")
// Create a client with very short timeout
client := &http.Client{
Transport: &http.Transport{
Proxy: http.ProxyFromEnvironment,
},
Timeout: 2 * time.Second,
}
// This should fail because 192.0.2.1 is a TEST-NET address (non-routable)
_, err = client.Get("https://192.0.2.1:9999/test")
assert.Error(t, err)
}
// TestProxyHTTPWithRedirect tests HTTP proxy handling redirects.
func TestProxyHTTPWithRedirect(t *testing.T) {
// Create a backend that redirects
redirectCount := 0
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if redirectCount == 0 {
redirectCount++
http.Redirect(w, r, "/redirected", http.StatusFound)
return
}
w.WriteHeader(http.StatusOK)
if _, err := w.Write([]byte("redirected response")); err != nil {
t.Logf("Failed to write response: %v", err)
}
}))
defer backend.Close()
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
ln, err := net.Listen("tcp", ":0")
require.NoError(t, err)
proxy := NewProxy(logger, ln.Addr().String())
proxy.server.Addr = ln.Addr().String()
go func() {
if err := proxy.server.Serve(ln); err != nil && err != http.ErrServerClosed {
t.Logf("Proxy server error: %v", err)
}
}()
defer func() {
if err := proxy.Stop(context.Background()); err != nil {
t.Logf("Failed to stop proxy: %v", err)
}
}()
time.Sleep(100 * time.Millisecond)
proxyURL := fmt.Sprintf("http://%s", ln.Addr().String())
os.Setenv("HTTP_PROXY", proxyURL)
defer os.Unsetenv("HTTP_PROXY")
client := &http.Client{
Transport: &http.Transport{
Proxy: http.ProxyFromEnvironment,
},
}
resp, err := client.Get(backend.URL)
require.NoError(t, err)
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
require.NoError(t, err)
assert.Equal(t, "redirected response", string(body))
}
// TestProxyStopContext tests proxy stop with context.
func TestProxyStopContext(t *testing.T) {
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
ln, err := net.Listen("tcp", ":0")
require.NoError(t, err)
proxy := NewProxy(logger, ln.Addr().String())
proxy.server.Addr = ln.Addr().String()
go func() {
if err := proxy.server.Serve(ln); err != nil && err != http.ErrServerClosed {
t.Logf("Proxy server error: %v", err)
}
}()
time.Sleep(100 * time.Millisecond)
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
err = proxy.Stop(ctx)
assert.NoError(t, err)
}
// TestProxyPipeWithRealConnections tests the pipe function with real TCP connections.
func TestProxyPipeWithRealConnections(t *testing.T) {
// Create two connected TCP connections
listener, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
defer listener.Close()
// Channel to receive the server connection
serverConnChan := make(chan net.Conn, 1)
go func() {
conn, err := listener.Accept()
if err == nil {
serverConnChan <- conn
}
}()
// Create client connection
clientConn, err := net.Dial("tcp", listener.Addr().String())
require.NoError(t, err)
defer clientConn.Close()
// Get server connection
serverConn := <-serverConnChan
defer serverConn.Close()
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
proxy := NewProxy(logger, ":0")
// Test data transfer
testData := []byte("test data for pipe")
// Start pipe in goroutine
go proxy.pipe(clientConn, serverConn)
// Write from client
_, err = clientConn.Write(testData)
require.NoError(t, err)
// Read from server
buf := make([]byte, len(testData))
if err := serverConn.SetReadDeadline(time.Now().Add(1 * time.Second)); err != nil {
t.Logf("Failed to set read deadline: %v", err)
}
n, err := serverConn.Read(buf)
require.NoError(t, err)
assert.Equal(t, testData, buf[:n])
// Close connections to trigger pipe completion
clientConn.Close()
serverConn.Close()
// Give pipe time to complete
time.Sleep(100 * time.Millisecond)
}
// TestProxyHTTP2ErrorPath tests HTTP/2 error handler.
func TestProxyHTTP2ErrorPath(t *testing.T) {
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
ln, err := net.Listen("tcp", ":0")
require.NoError(t, err)
proxy := NewProxy(logger, ln.Addr().String())
proxy.server.Addr = ln.Addr().String()
go func() {
if err := proxy.server.Serve(ln); err != nil && err != http.ErrServerClosed {
t.Logf("Proxy server error: %v", err)
}
}()
defer func() {
if err := proxy.Stop(context.Background()); err != nil {
t.Logf("Failed to stop proxy: %v", err)
}
}()
time.Sleep(100 * time.Millisecond)
// Create a request that will trigger HTTP/2 handling
req, err := http.NewRequest("GET", "http://"+ln.Addr().String()+"/test", nil)
require.NoError(t, err)
// Force HTTP/2 by setting the request protocol
req.ProtoMajor = 2
req.ProtoMinor = 0
req.Host = "nonexistent.invalid:9999" // This should cause an error
// Create a response recorder
rr := httptest.NewRecorder()
// Call the handler directly to test HTTP/2 error path
proxy.server.Handler.ServeHTTP(rr, req)
// Should get an error response
assert.Equal(t, http.StatusBadGateway, rr.Code)
}
// TestNewProxyHTTP2ConfigWarning tests NewProxy when HTTP/2 configuration might fail.
func TestNewProxyHTTP2ConfigWarning(t *testing.T) {
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
// Create proxy - HTTP/2 configuration should succeed normally
proxy := NewProxy(logger, ":0")
assert.NotNil(t, proxy)
assert.NotNil(t, proxy.transport)
assert.True(t, proxy.transport.ForceAttemptHTTP2)
}
// TestProxyHandleHTTPError tests HTTP handler error path.
func TestProxyHandleHTTPError(t *testing.T) {
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
ln, err := net.Listen("tcp", ":0")
require.NoError(t, err)
proxy := NewProxy(logger, ln.Addr().String())
proxy.server.Addr = ln.Addr().String()
go func() {
if err := proxy.server.Serve(ln); err != nil && err != http.ErrServerClosed {
t.Logf("Proxy server error: %v", err)
}
}()
defer func() {
if err := proxy.Stop(context.Background()); err != nil {
t.Logf("Failed to stop proxy: %v", err)
}
}()
time.Sleep(100 * time.Millisecond)
proxyURL := fmt.Sprintf("http://%s", ln.Addr().String())
os.Setenv("HTTP_PROXY", proxyURL)
defer os.Unsetenv("HTTP_PROXY")
client := &http.Client{
Transport: &http.Transport{
Proxy: http.ProxyFromEnvironment,
},
Timeout: 2 * time.Second,
}
// Try to connect to invalid backend
resp, err := client.Get("http://invalid.backend.test:99999/test")
if err == nil {
defer resp.Body.Close()
// Should get error status
assert.NotEqual(t, http.StatusOK, resp.StatusCode)
}
// Either error or bad gateway response is acceptable
}
// TestProxyConnectWriteError tests CONNECT with write error after hijacking.
func TestProxyConnectWriteError(t *testing.T) {
// This test is challenging because we need to trigger a write error
// after successful hijacking. We'll test the path by using a closed connection.
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
ln, err := net.Listen("tcp", ":0")
require.NoError(t, err)
proxy := NewProxy(logger, ln.Addr().String())
proxy.server.Addr = ln.Addr().String()
go func() {
if err := proxy.server.Serve(ln); err != nil && err != http.ErrServerClosed {
t.Logf("Proxy server error: %v", err)
}
}()
defer func() {
if err := proxy.Stop(context.Background()); err != nil {
t.Logf("Failed to stop proxy: %v", err)
}
}()
time.Sleep(100 * time.Millisecond)
// Create a backend server for CONNECT
backend := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))
defer backend.Close()
proxyURL := fmt.Sprintf("http://%s", ln.Addr().String())
os.Setenv("HTTPS_PROXY", proxyURL)
defer os.Unsetenv("HTTPS_PROXY")
client := backend.Client()
if transport, ok := client.Transport.(*http.Transport); ok {
transport.Proxy = http.ProxyFromEnvironment
}
// Make a request through CONNECT
_, err = client.Get(backend.URL)
// The request may succeed or fail, but we're testing the code path
if err != nil {
t.Logf("Request error (expected in some cases): %v", err)
}
}
// TestProxyHTTP2WithAbsoluteURL tests HTTP/2 handling with absolute URL.
func TestProxyHTTP2WithAbsoluteURL(t *testing.T) {
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
if _, err := w.Write([]byte("http2 absolute url response")); err != nil {
t.Logf("Failed to write response: %v", err)
}
}))
defer backend.Close()
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
ln, err := net.Listen("tcp", ":0")
require.NoError(t, err)
proxy := NewProxy(logger, ln.Addr().String())
proxy.server.Addr = ln.Addr().String()
go func() {
if err := proxy.server.Serve(ln); err != nil && err != http.ErrServerClosed {
t.Logf("Proxy server error: %v", err)
}
}()
defer func() {
if err := proxy.Stop(context.Background()); err != nil {
t.Logf("Failed to stop proxy: %v", err)
}
}()
time.Sleep(100 * time.Millisecond)
// Create request with absolute URL
req, err := http.NewRequest("GET", backend.URL+"/test", nil)
require.NoError(t, err)
req.ProtoMajor = 2
req.ProtoMinor = 0
rr := httptest.NewRecorder()
proxy.server.Handler.ServeHTTP(rr, req)
assert.Equal(t, http.StatusOK, rr.Code)
}
+25
View File
@@ -0,0 +1,25 @@
// Copyright (c) Ultraviolet
// SPDX-License-Identifier: Apache-2.0
package ingress
import "github.com/ultravioletrs/cocos/agent"
// AgentConfigToProxyConfig converts agent.AgentConfig to ProxyConfig.
func AgentConfigToProxyConfig(cfg agent.AgentConfig) ProxyConfig {
return ProxyConfig{
Port: "7002", // Ingress-proxy always uses port 7002
CertFile: cfg.CertFile,
KeyFile: cfg.KeyFile,
ServerCAFile: cfg.ServerCAFile,
ClientCAFile: cfg.ClientCAFile,
AttestedTLS: cfg.AttestedTls,
}
}
// ComputationToProxyContext converts agent.Computation to ProxyContext.
func ComputationToProxyContext(cmp agent.Computation) ProxyContext {
return ProxyContext{
ID: cmp.ID,
Name: cmp.Name,
}
}
+166
View File
@@ -0,0 +1,166 @@
// Copyright (c) Ultraviolet
// SPDX-License-Identifier: Apache-2.0
package ingress
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/ultravioletrs/cocos/agent"
)
// TestAgentConfigToProxyConfig tests conversion from AgentConfig to ProxyConfig.
func TestAgentConfigToProxyConfig(t *testing.T) {
tests := []struct {
name string
input agent.AgentConfig
expected ProxyConfig
}{
{
name: "basic config without TLS",
input: agent.AgentConfig{
CertFile: "",
KeyFile: "",
ServerCAFile: "",
ClientCAFile: "",
AttestedTls: false,
},
expected: ProxyConfig{
Port: "7002",
CertFile: "",
KeyFile: "",
ServerCAFile: "",
ClientCAFile: "",
AttestedTLS: false,
},
},
{
name: "config with regular TLS",
input: agent.AgentConfig{
CertFile: "/path/to/cert.pem",
KeyFile: "/path/to/key.pem",
ServerCAFile: "/path/to/server-ca.pem",
ClientCAFile: "/path/to/client-ca.pem",
AttestedTls: false,
},
expected: ProxyConfig{
Port: "7002",
CertFile: "/path/to/cert.pem",
KeyFile: "/path/to/key.pem",
ServerCAFile: "/path/to/server-ca.pem",
ClientCAFile: "/path/to/client-ca.pem",
AttestedTLS: false,
},
},
{
name: "config with attested TLS",
input: agent.AgentConfig{
CertFile: "",
KeyFile: "",
ServerCAFile: "/path/to/server-ca.pem",
ClientCAFile: "/path/to/client-ca.pem",
AttestedTls: true,
},
expected: ProxyConfig{
Port: "7002",
CertFile: "",
KeyFile: "",
ServerCAFile: "/path/to/server-ca.pem",
ClientCAFile: "/path/to/client-ca.pem",
AttestedTLS: true,
},
},
{
name: "config with mTLS",
input: agent.AgentConfig{
CertFile: "/path/to/cert.pem",
KeyFile: "/path/to/key.pem",
ServerCAFile: "/path/to/server-ca.pem",
ClientCAFile: "/path/to/client-ca.pem",
AttestedTls: false,
},
expected: ProxyConfig{
Port: "7002",
CertFile: "/path/to/cert.pem",
KeyFile: "/path/to/key.pem",
ServerCAFile: "/path/to/server-ca.pem",
ClientCAFile: "/path/to/client-ca.pem",
AttestedTLS: false,
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := AgentConfigToProxyConfig(tt.input)
assert.Equal(t, tt.expected, result)
})
}
}
// TestComputationToProxyContext tests conversion from Computation to ProxyContext.
func TestComputationToProxyContext(t *testing.T) {
tests := []struct {
name string
input agent.Computation
expected ProxyContext
}{
{
name: "computation with name",
input: agent.Computation{
ID: "comp-123",
Name: "test-computation",
Description: "A test computation",
},
expected: ProxyContext{
ID: "comp-123",
Name: "test-computation",
},
},
{
name: "computation without name",
input: agent.Computation{
ID: "comp-456",
Name: "",
Description: "Another test computation",
},
expected: ProxyContext{
ID: "comp-456",
Name: "",
},
},
{
name: "computation with special characters in name",
input: agent.Computation{
ID: "comp-789",
Name: "test-computation-with-dashes_and_underscores",
Description: "Computation with special chars",
},
expected: ProxyContext{
ID: "comp-789",
Name: "test-computation-with-dashes_and_underscores",
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := ComputationToProxyContext(tt.input)
assert.Equal(t, tt.expected, result)
})
}
}
// TestAgentConfigToProxyConfigPortIsFixed tests that port is always set to 7002.
func TestAgentConfigToProxyConfigPortIsFixed(t *testing.T) {
configs := []agent.AgentConfig{
{},
{CertFile: "/cert.pem", KeyFile: "/key.pem"},
{AttestedTls: true},
}
for i, cfg := range configs {
result := AgentConfigToProxyConfig(cfg)
assert.Equal(t, "7002", result.Port, "Port should always be 7002 for config %d", i)
}
}
+223
View File
@@ -0,0 +1,223 @@
// Copyright (c) Ultraviolet
// SPDX-License-Identifier: Apache-2.0
package ingress
import (
"context"
"crypto/tls"
"fmt"
"log/slog"
"net"
"net/http"
"net/http/httputil"
"net/url"
"sync"
"github.com/ultravioletrs/cocos/pkg/atls"
"github.com/ultravioletrs/cocos/pkg/server"
"golang.org/x/net/http2"
"golang.org/x/net/http2/h2c"
)
// ProxyConfig contains configuration for starting a proxy instance.
type ProxyConfig struct {
Port string
CertFile string
KeyFile string
ServerCAFile string
ClientCAFile string
AttestedTLS bool
}
// ProxyContext provides context information for logging and tracking.
type ProxyContext struct {
ID string
Name string
}
// ProxyServer manages ingress proxy instances.
type ProxyServer interface {
Start(cfg ProxyConfig, ctx ProxyContext) error
Stop() error
}
type proxyServer struct {
mu sync.RWMutex
logger *slog.Logger
backendURL *url.URL
certProvider atls.CertificateProvider
httpServer *http.Server
started bool
stopped bool
}
// NewProxyServer creates a new ingress proxy server manager.
func NewProxyServer(logger *slog.Logger, backendURL *url.URL, certProvider atls.CertificateProvider) ProxyServer {
return &proxyServer{
logger: logger,
backendURL: backendURL,
certProvider: certProvider,
}
}
// Start starts the proxy server with the given configuration.
func (p *proxyServer) Start(cfg ProxyConfig, ctx ProxyContext) error {
p.mu.Lock()
defer p.mu.Unlock()
if p.started {
return fmt.Errorf("proxy server already started")
}
if p.stopped {
return fmt.Errorf("proxy server already stopped")
}
if cfg.Port == "" {
cfg.Port = "7002"
}
addr := fmt.Sprintf("0.0.0.0:%s", cfg.Port)
// Configure Reverse Proxy
var rp *httputil.ReverseProxy
// Check if backend is Unix socket or TCP
if p.backendURL.Scheme == "unix" {
// For Unix socket backend, we need to manually configure the reverse proxy
// because NewSingleHostReverseProxy doesn't support unix:// scheme
targetURL := &url.URL{
Scheme: "http",
Host: "unix",
}
rp = httputil.NewSingleHostReverseProxy(targetURL)
// Override the Director to not modify the request
originalDirector := rp.Director
rp.Director = func(req *http.Request) {
originalDirector(req)
// Set the URL to point to the backend service
req.URL.Scheme = "http"
req.URL.Host = "unix"
}
// Configure Transport for Unix socket with HTTP/2
rp.Transport = &http2.Transport{
AllowHTTP: true,
DialTLSContext: func(ctx context.Context, network, addr string, cfg *tls.Config) (net.Conn, error) {
var d net.Dialer
// Use Unix socket path from URL
return d.DialContext(ctx, "unix", p.backendURL.Path)
},
}
} else {
// TCP backend
rp = httputil.NewSingleHostReverseProxy(p.backendURL)
rp.Transport = &http2.Transport{
AllowHTTP: true,
DialTLSContext: func(ctx context.Context, network, addr string, cfg *tls.Config) (net.Conn, error) {
var d net.Dialer
return d.DialContext(ctx, network, addr)
},
}
}
// Wrap handler with h2c for HTTP/2 cleartext support (required for gRPC without TLS)
h2cHandler := h2c.NewHandler(rp, &http2.Server{})
p.httpServer = &http.Server{
Addr: addr,
Handler: h2cHandler,
}
// Configure TLS
var tlsConfig *tls.Config
contextDesc := fmt.Sprintf("context %s", ctx.ID)
if ctx.Name != "" {
contextDesc = fmt.Sprintf("%s (%s)", ctx.Name, ctx.ID)
}
if cfg.AttestedTLS {
if p.certProvider == nil {
return fmt.Errorf("attested TLS requested but no certificate provider available")
}
tlsConfig = &tls.Config{
GetCertificate: p.certProvider.GetCertificate,
ClientAuth: tls.NoClientCert,
NextProtos: []string{"h2", "http/1.1"},
}
mtls, err := server.ConfigureCertificateAuthorities(tlsConfig, cfg.ServerCAFile, cfg.ClientCAFile)
if err != nil {
return fmt.Errorf("failed to configure certificate authorities: %w", err)
}
if mtls {
tlsConfig.ClientAuth = tls.RequireAndVerifyClientCert
p.logger.Info(fmt.Sprintf("ingress-proxy listening at %s with Attested mTLS for %s", addr, contextDesc))
} else {
p.logger.Info(fmt.Sprintf("ingress-proxy listening at %s with Attested TLS for %s", addr, contextDesc))
}
} else if cfg.CertFile != "" && cfg.KeyFile != "" {
// Regular TLS
tlsSetup, err := server.SetupRegularTLS(cfg.CertFile, cfg.KeyFile, cfg.ServerCAFile, cfg.ClientCAFile)
if err != nil {
return fmt.Errorf("failed to setup TLS: %w", err)
}
tlsConfig = tlsSetup.Config
tlsConfig.NextProtos = []string{"h2", "http/1.1"}
if tlsSetup.MTLS {
p.logger.Info(fmt.Sprintf("ingress-proxy listening at %s with mTLS for %s", addr, contextDesc))
} else {
p.logger.Info(fmt.Sprintf("ingress-proxy listening at %s with TLS for %s", addr, contextDesc))
}
} else {
p.logger.Info(fmt.Sprintf("ingress-proxy listening at %s without TLS for %s", addr, contextDesc))
}
p.started = true
// Start server in goroutine
go func() {
var err error
if tlsConfig != nil {
ln, listenErr := net.Listen("tcp", addr)
if listenErr != nil {
p.logger.Error(fmt.Sprintf("failed to listen: %s", listenErr))
return
}
tlsLn := tls.NewListener(ln, tlsConfig)
err = p.httpServer.Serve(tlsLn)
} else {
err = p.httpServer.ListenAndServe()
}
if err != nil && err != http.ErrServerClosed {
p.logger.Error(fmt.Sprintf("ingress-proxy server error: %s", err))
}
}()
return nil
}
// Stop stops the proxy server.
func (p *proxyServer) Stop() error {
p.mu.Lock()
defer p.mu.Unlock()
if p.stopped {
return nil
}
p.stopped = true
if p.httpServer != nil {
ctx, cancel := context.WithTimeout(context.Background(), 5*1000000000) // 5 seconds
defer cancel()
if err := p.httpServer.Shutdown(ctx); err != nil {
return fmt.Errorf("failed to shutdown server: %w", err)
}
p.logger.Info("ingress-proxy stopped")
}
return nil
}
+477
View File
@@ -0,0 +1,477 @@
// Copyright (c) Ultraviolet
// SPDX-License-Identifier: Apache-2.0
package ingress
import (
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"encoding/pem"
"fmt"
"log/slog"
"math/big"
"net"
"net/http"
"net/url"
"os"
"path/filepath"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/ultravioletrs/cocos/pkg/atls/mocks"
"golang.org/x/net/http2"
"golang.org/x/net/http2/h2c"
)
func createTempCert(t *testing.T) (certFile, keyFile string) {
t.Helper()
priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
require.NoError(t, err)
template := x509.Certificate{
SerialNumber: big.NewInt(1),
Subject: pkix.Name{
Organization: []string{"Test Org"},
},
NotBefore: time.Now(),
NotAfter: time.Now().Add(time.Hour),
KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
BasicConstraintsValid: true,
IPAddresses: []net.IP{net.ParseIP("127.0.0.1")},
}
derBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, &priv.PublicKey, priv)
require.NoError(t, err)
tmpDir := t.TempDir()
certPath := filepath.Join(tmpDir, "server.crt")
keyPath := filepath.Join(tmpDir, "server.key")
certOut, err := os.Create(certPath)
require.NoError(t, err)
err = pem.Encode(certOut, &pem.Block{Type: "CERTIFICATE", Bytes: derBytes})
require.NoError(t, err)
certOut.Close()
keyOut, err := os.Create(keyPath)
require.NoError(t, err)
b, err := x509.MarshalECPrivateKey(priv)
require.NoError(t, err)
err = pem.Encode(keyOut, &pem.Block{Type: "EC PRIVATE KEY", Bytes: b})
require.NoError(t, err)
keyOut.Close()
return certPath, keyPath
}
func getBackendURL() *url.URL {
u, _ := url.Parse("http://localhost:8080")
return u
}
// TestNewProxyServer tests the creation of a new proxy server.
func TestNewProxyServer(t *testing.T) {
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
ps := NewProxyServer(logger, getBackendURL(), nil)
require.NotNil(t, ps)
}
// TestProxyStartStop tests basic start and stop operations.
func TestProxyStartStop(t *testing.T) {
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
ps := NewProxyServer(logger, getBackendURL(), nil)
listener, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
port := listener.Addr().(*net.TCPAddr).Port
listener.Close()
cfg := ProxyConfig{Port: fmt.Sprintf("%d", port)}
ctx := ProxyContext{ID: "test-1", Name: "test-proxy"}
err = ps.Start(cfg, ctx)
require.NoError(t, err)
time.Sleep(100 * time.Millisecond)
err = ps.Stop()
require.NoError(t, err)
}
// TestProxyStartWithoutPort tests proxy without explicit port.
func TestProxyStartWithoutPort(t *testing.T) {
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
ps := NewProxyServer(logger, getBackendURL(), nil)
cfg := ProxyConfig{Port: ""}
ctx := ProxyContext{ID: "test-2"}
err := ps.Start(cfg, ctx)
require.NoError(t, err)
defer func() { _ = ps.Stop() }()
time.Sleep(100 * time.Millisecond)
}
// TestProxyStartAlreadyStarted tests error when starting twice.
func TestProxyStartAlreadyStarted(t *testing.T) {
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
ps := NewProxyServer(logger, getBackendURL(), nil)
listener, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
port := listener.Addr().(*net.TCPAddr).Port
listener.Close()
cfg := ProxyConfig{Port: fmt.Sprintf("%d", port)}
ctx := ProxyContext{ID: "test-3"}
err = ps.Start(cfg, ctx)
require.NoError(t, err)
defer func() { _ = ps.Stop() }()
time.Sleep(100 * time.Millisecond)
err = ps.Start(cfg, ctx)
assert.Error(t, err)
assert.Equal(t, "proxy server already started", err.Error())
}
// TestProxyStartAfterStopped tests error when starting after stop.
func TestProxyStartAfterStopped(t *testing.T) {
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
ps := NewProxyServer(logger, getBackendURL(), nil)
listener, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
port := listener.Addr().(*net.TCPAddr).Port
listener.Close()
cfg := ProxyConfig{Port: fmt.Sprintf("%d", port)}
ctx := ProxyContext{ID: "test-4"}
err = ps.Start(cfg, ctx)
require.NoError(t, err)
time.Sleep(100 * time.Millisecond)
err = ps.Stop()
require.NoError(t, err)
err = ps.Start(cfg, ctx)
assert.Error(t, err)
// After stop, attempts to start will fail with "already started" error first
assert.Contains(t, err.Error(), "proxy server already")
}
// TestProxyWithName tests proxy context with name.
func TestProxyWithName(t *testing.T) {
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
ps := NewProxyServer(logger, getBackendURL(), nil)
listener, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
port := listener.Addr().(*net.TCPAddr).Port
listener.Close()
cfg := ProxyConfig{Port: fmt.Sprintf("%d", port)}
ctx := ProxyContext{ID: "id-1", Name: "named-proxy"}
err = ps.Start(cfg, ctx)
require.NoError(t, err)
defer func() { _ = ps.Stop() }()
time.Sleep(100 * time.Millisecond)
}
// TestProxyWithoutName tests proxy context without name.
func TestProxyWithoutName(t *testing.T) {
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
ps := NewProxyServer(logger, getBackendURL(), nil)
listener, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
port := listener.Addr().(*net.TCPAddr).Port
listener.Close()
cfg := ProxyConfig{Port: fmt.Sprintf("%d", port)}
ctx := ProxyContext{ID: "id-only"}
err = ps.Start(cfg, ctx)
require.NoError(t, err)
defer func() { _ = ps.Stop() }()
time.Sleep(100 * time.Millisecond)
}
// TestProxyMultipleStops tests multiple stop calls.
func TestProxyMultipleStops(t *testing.T) {
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
ps := NewProxyServer(logger, getBackendURL(), nil)
listener, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
port := listener.Addr().(*net.TCPAddr).Port
listener.Close()
cfg := ProxyConfig{Port: fmt.Sprintf("%d", port)}
ctx := ProxyContext{ID: "test-multi-stop"}
err = ps.Start(cfg, ctx)
require.NoError(t, err)
time.Sleep(100 * time.Millisecond)
err = ps.Stop()
require.NoError(t, err)
err = ps.Stop()
require.NoError(t, err)
}
// TestProxyWithoutTLS tests proxy without TLS.
func TestProxyWithoutTLS(t *testing.T) {
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
ps := NewProxyServer(logger, getBackendURL(), nil)
listener, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
port := listener.Addr().(*net.TCPAddr).Port
listener.Close()
cfg := ProxyConfig{
Port: fmt.Sprintf("%d", port),
AttestedTLS: false,
}
ctx := ProxyContext{ID: "test-no-tls"}
err = ps.Start(cfg, ctx)
require.NoError(t, err)
defer func() { _ = ps.Stop() }()
time.Sleep(100 * time.Millisecond)
}
func TestProxyWithUnixBackend(t *testing.T) {
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
// Create a temp directory for socket
dir := t.TempDir()
sockPath := filepath.Join(dir, "backend.sock")
// Start a dummy backend on unix socket
l, err := net.Listen("unix", sockPath)
require.NoError(t, err)
defer l.Close()
backendCalled := false
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
backendCalled = true
w.WriteHeader(http.StatusOK)
})
h2s := &http2.Server{}
h2cHandler := h2c.NewHandler(handler, h2s)
go func() {
_ = http.Serve(l, h2cHandler)
}()
// Configure proxy to use this unix socket
backendURL, _ := url.Parse("unix://" + sockPath)
ps := NewProxyServer(logger, backendURL, nil)
// Find free port for proxy
listener, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
port := listener.Addr().(*net.TCPAddr).Port
listener.Close()
cfg := ProxyConfig{
Port: fmt.Sprintf("%d", port),
}
ctx := ProxyContext{ID: "test-unix"}
err = ps.Start(cfg, ctx)
require.NoError(t, err)
defer func() { _ = ps.Stop() }()
time.Sleep(100 * time.Millisecond)
// Make request to proxy
resp, err := http.Get(fmt.Sprintf("http://localhost:%d", port))
require.NoError(t, err)
defer resp.Body.Close()
assert.Equal(t, http.StatusOK, resp.StatusCode)
assert.True(t, backendCalled)
}
func TestProxyRegularTLS(t *testing.T) {
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
certFile, keyFile := createTempCert(t)
ps := NewProxyServer(logger, getBackendURL(), nil)
listener, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
port := listener.Addr().(*net.TCPAddr).Port
listener.Close()
cfg := ProxyConfig{
Port: fmt.Sprintf("%d", port),
CertFile: certFile,
KeyFile: keyFile,
}
ctx := ProxyContext{ID: "test-tls"}
err = ps.Start(cfg, ctx)
require.NoError(t, err)
defer func() { _ = ps.Stop() }()
time.Sleep(100 * time.Millisecond)
// Client with skipped verification
tr := &http.Transport{
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
}
client := &http.Client{Transport: tr}
resp, err := client.Get(fmt.Sprintf("https://localhost:%d", port))
// Backend is not running/reachable so 502 or error is expected from reverse proxy,
// but 502 means connection to proxy succeeded.
// If the proxy itself was not working, we'd get connection refused or similar.
require.NoError(t, err)
if err == nil {
resp.Body.Close()
}
}
func TestProxyRegularTLSInvalidFiles(t *testing.T) {
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
ps := NewProxyServer(logger, getBackendURL(), nil)
listener, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
port := listener.Addr().(*net.TCPAddr).Port
listener.Close()
cfg := ProxyConfig{
Port: fmt.Sprintf("%d", port),
CertFile: "non-existent.crt",
KeyFile: "non-existent.key",
}
ctx := ProxyContext{ID: "test-tls-fail"}
err = ps.Start(cfg, ctx)
assert.Error(t, err)
assert.Contains(t, err.Error(), "failed to setup TLS")
}
func TestProxyAttestedTLS(t *testing.T) {
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
mockProvider := mocks.NewCertificateProvider(t)
// We don't expect calls during Listen, only during handshake.
// But Start logic doesn't block waiting for handshake.
ps := NewProxyServer(logger, getBackendURL(), mockProvider)
listener, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
port := listener.Addr().(*net.TCPAddr).Port
listener.Close()
cfg := ProxyConfig{
Port: fmt.Sprintf("%d", port),
AttestedTLS: true,
}
ctx := ProxyContext{ID: "test-attested-tls"}
err = ps.Start(cfg, ctx)
require.NoError(t, err)
defer func() { _ = ps.Stop() }()
}
func TestProxyAttestedTLSMissingProvider(t *testing.T) {
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
ps := NewProxyServer(logger, getBackendURL(), nil)
cfg := ProxyConfig{
Port: "0",
AttestedTLS: true,
}
ctx := ProxyContext{ID: "test-attested-fail"}
err := ps.Start(cfg, ctx)
assert.Error(t, err)
assert.Equal(t, "attested TLS requested but no certificate provider available", err.Error())
}
func TestProxyMTLS(t *testing.T) {
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
certFile, _ := createTempCert(t)
mockProvider := mocks.NewCertificateProvider(t)
ps := NewProxyServer(logger, getBackendURL(), mockProvider)
listener, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
port := listener.Addr().(*net.TCPAddr).Port
listener.Close()
// Test case: AttestedTLS with ClientCAFile (mTLS)
// server.ConfigureCertificateAuthorities reads the file.
cfg := ProxyConfig{
Port: fmt.Sprintf("%d", port),
AttestedTLS: true,
ClientCAFile: certFile, // Use self-signed cert as CA
ServerCAFile: certFile, // Also for server CA
}
ctx := ProxyContext{ID: "test-mtls"}
err = ps.Start(cfg, ctx)
require.NoError(t, err)
defer func() { _ = ps.Stop() }()
}
func TestProxyRegularMTLS(t *testing.T) {
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
certFile, keyFile := createTempCert(t)
ps := NewProxyServer(logger, getBackendURL(), nil)
listener, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
port := listener.Addr().(*net.TCPAddr).Port
listener.Close()
cfg := ProxyConfig{
Port: fmt.Sprintf("%d", port),
CertFile: certFile,
KeyFile: keyFile,
ServerCAFile: certFile,
ClientCAFile: certFile,
}
ctx := ProxyContext{ID: "test-regular-mtls"}
err = ps.Start(cfg, ctx)
require.NoError(t, err)
defer func() { _ = ps.Stop() }()
}
func TestProxyAttestedTLSInvalidCA(t *testing.T) {
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
mockProvider := mocks.NewCertificateProvider(t)
ps := NewProxyServer(logger, getBackendURL(), mockProvider)
cfg := ProxyConfig{
Port: "0",
AttestedTLS: true,
ServerCAFile: "non-existent.pem",
}
ctx := ProxyContext{ID: "test-attested-invalid-ca"}
err := ps.Start(cfg, ctx)
assert.Error(t, err)
assert.Contains(t, err.Error(), "failed to configure certificate authorities")
}
+27 -5
View File
@@ -9,6 +9,7 @@ import (
"fmt"
"log/slog"
"net"
"os"
"sync"
"time"
@@ -111,10 +112,26 @@ func (s *Server) Start() error {
grpcServerOptions = append(grpcServerOptions, creds)
// Create listener
listener, err := net.Listen("tcp", s.Address)
if err != nil {
return fmt.Errorf("failed to listen on port %s: %w", s.Address, err)
// Create listener - detect Unix socket vs TCP
var listener net.Listener
baseConfig := s.Config.GetBaseConfig()
// Check if this is a Unix socket path (starts with /)
if len(baseConfig.Host) > 0 && baseConfig.Host[0] == '/' {
// Unix socket
// Remove existing socket file if it exists
_ = os.Remove(baseConfig.Host)
listener, err = net.Listen("unix", baseConfig.Host)
if err != nil {
return fmt.Errorf("failed to listen on Unix socket %s: %w", baseConfig.Host, err)
}
} else {
// TCP socket
listener, err = net.Listen("tcp", s.Address)
if err != nil {
return fmt.Errorf("failed to listen on port %s: %w", s.Address, err)
}
}
// Create and configure server
@@ -160,7 +177,12 @@ func (s *Server) configureCredentials() (grpc.ServerOption, error) {
}
// Use insecure credentials
s.Logger.Info(fmt.Sprintf("%s service gRPC server listening at %s without TLS", s.Name, s.Address))
// Determine address for logging
addr := s.Address
if len(baseConfig.Host) > 0 && baseConfig.Host[0] == '/' {
addr = baseConfig.Host
}
s.Logger.Info(fmt.Sprintf("%s service gRPC server listening at %s without TLS", s.Name, addr))
return grpc.Creds(insecure.NewCredentials()), nil
}