From a3265bc3468ddfd72cb924244af055e10de71f37 Mon Sep 17 00:00:00 2001 From: Sammy Kerata Oina <44265300+SammyOina@users.noreply.github.com> Date: Mon, 9 Feb 2026 12:38:21 +0300 Subject: [PATCH] NOISSUE - Introduce computation runner, log forwarder, ingress, and egress proxy services. (#559) * feat: Introduce computation runner, log forwarder, ingress, and egress proxy services. Signed-off-by: Sammy Oina * feat: Update Go environment variable parsing and build system to use new architecture and repository. Signed-off-by: Sammy Oina * feat: Update package sources to `sammyoina/cocos-ai` at a specific commit, add log-forwarder pre-start hook, and rename proxy binaries. Signed-off-by: Sammy Oina * chore: Update build system references to a specific commit and enhance logging for service connections and message processing. Signed-off-by: Sammy Oina * build: Update package source repositories and versions, migrate client logging to slog, and adjust ingress/egress proxy build and install steps. Signed-off-by: Sammy Oina * debug stuck Signed-off-by: Sammy Oina * debug Signed-off-by: Sammy Oina * debug Signed-off-by: Sammy Oina * feat: add HTTP/2 support to egress proxy and update build system to use specific commit hashes Signed-off-by: Sammy Oina * feat: enhance egress proxy CONNECT handling, update package sources, and add gRPC test utility Signed-off-by: Sammy Oina * feat: Update build system for various services to a specific commit from a new repository, change agent gRPC port to 7001, and add a gRPC test client. Signed-off-by: Sammy Oina * feat: Migrate agent-internal gRPC communication to Unix sockets, set ingress proxy to port 7002, and update build hashes. Signed-off-by: Sammy Oina * refactor: Remove standalone ingress-proxy systemd service and update component versions. Signed-off-by: Sammy Oina * fix: Prevent computation re-initialization in agent and update component versions across several packages. 
Signed-off-by: Sammy Oina * feat: update package versions and enable h2c support in ingress proxy. Signed-off-by: Sammy Oina * feat: refactor ingress proxy to support HTTP/2 over Unix sockets and update component versions. Signed-off-by: Sammy Oina * feat: Update build system package sources to `ultravioletrs/cocos` and reduce agent logging verbosity. Signed-off-by: Sammy Oina * refactor: improve error handling in proxy commands and remove unused gRPC test Signed-off-by: Sammy Oina * test: add mock service state return value in handleRunReqChunks test Signed-off-by: Sammy Oina * feat: add comprehensive tests for service and proxy components Signed-off-by: Sammy Oina * fix linter Signed-off-by: Sammy Oina * improve coverage Signed-off-by: Sammy Oina * test: add gRPC client and ingress adapter tests, and update egress proxy tests. Signed-off-by: Sammy Oina * improve coverage Signed-off-by: Sammy Oina --------- Signed-off-by: Sammy Oina --- .mockery.yml | 7 + Makefile | 4 +- agent/computations.go | 1 - agent/computations_test.go | 13 +- agent/cvms/api/grpc/client.go | 41 +- agent/cvms/api/grpc/client_test.go | 264 +++++- agent/cvms/api/grpc/server.go | 10 +- agent/cvms/server/cvm.go | 12 +- agent/cvms/server/cvm_test.go | 36 +- agent/log/log.pb.go | 261 ++++++ agent/log/log.proto | 32 + agent/log/log_grpc.pb.go | 163 ++++ agent/log/service/service.go | 59 ++ agent/log/service/service_test.go | 303 +++++++ agent/runner/events/adapter.go | 38 + agent/runner/events/adapter_test.go | 138 +++ agent/runner/runner.pb.go | 341 ++++++++ agent/runner/runner.proto | 38 + agent/runner/runner_grpc.pb.go | 163 ++++ agent/runner/service/service.go | 141 ++++ agent/runner/service/service_test.go | 271 ++++++ agent/service.go | 93 +- agent/service_test.go | 93 +- cmd/agent/main.go | 103 ++- cmd/computation-runner/main.go | 123 +++ cmd/egress-proxy/main.go | 88 ++ cmd/ingress-proxy/main.go | 137 +++ cmd/log-forwarder/main.go | 155 ++++ hal/linux/Config.in | 4 + 
hal/linux/package/agent/Config.in | 4 + .../package/computation-runner/Config.in | 5 + .../computation-runner/computation-runner.mk | 22 + hal/linux/package/egress-proxy/Config.in | 6 + .../package/egress-proxy/egress-proxy.mk | 22 + hal/linux/package/ingress-proxy/Config.in | 4 + .../package/ingress-proxy/ingress-proxy.mk | 22 + hal/linux/package/log-forwarder/Config.in | 4 + .../package/log-forwarder/log-forwarder.mk | 22 + init/systemd/cocos-agent.service | 11 +- init/systemd/computation-runner.service | 20 + init/systemd/egress-proxy.service | 15 + init/systemd/log-forwarder.service | 17 + internal/logger/protohandler.go | 1 + pkg/clients/grpc/attestation/client_test.go | 392 +++++++++ pkg/clients/grpc/log/client.go | 64 ++ pkg/clients/grpc/log/client_test.go | 332 ++++++++ pkg/clients/grpc/runner/client.go | 52 ++ pkg/clients/grpc/runner/client_test.go | 349 ++++++++ pkg/clients/grpc/runner/mocks/client.go | 223 +++++ pkg/egress/proxy.go | 248 ++++++ pkg/egress/proxy_test.go | 795 ++++++++++++++++++ pkg/ingress/adapter.go | 25 + pkg/ingress/adapter_test.go | 166 ++++ pkg/ingress/proxy.go | 223 +++++ pkg/ingress/proxy_test.go | 477 +++++++++++ pkg/server/grpc/grpc.go | 32 +- test/cvms/main.go | 6 + 57 files changed, 6529 insertions(+), 162 deletions(-) create mode 100644 agent/log/log.pb.go create mode 100644 agent/log/log.proto create mode 100644 agent/log/log_grpc.pb.go create mode 100644 agent/log/service/service.go create mode 100644 agent/log/service/service_test.go create mode 100644 agent/runner/events/adapter.go create mode 100644 agent/runner/events/adapter_test.go create mode 100644 agent/runner/runner.pb.go create mode 100644 agent/runner/runner.proto create mode 100644 agent/runner/runner_grpc.pb.go create mode 100644 agent/runner/service/service.go create mode 100644 agent/runner/service/service_test.go create mode 100644 cmd/computation-runner/main.go create mode 100644 cmd/egress-proxy/main.go create mode 100644 cmd/ingress-proxy/main.go create 
mode 100644 cmd/log-forwarder/main.go create mode 100644 hal/linux/package/computation-runner/Config.in create mode 100644 hal/linux/package/computation-runner/computation-runner.mk create mode 100644 hal/linux/package/egress-proxy/Config.in create mode 100644 hal/linux/package/egress-proxy/egress-proxy.mk create mode 100644 hal/linux/package/ingress-proxy/Config.in create mode 100644 hal/linux/package/ingress-proxy/ingress-proxy.mk create mode 100644 hal/linux/package/log-forwarder/Config.in create mode 100644 hal/linux/package/log-forwarder/log-forwarder.mk create mode 100644 init/systemd/computation-runner.service create mode 100644 init/systemd/egress-proxy.service create mode 100644 init/systemd/log-forwarder.service create mode 100644 pkg/clients/grpc/attestation/client_test.go create mode 100644 pkg/clients/grpc/log/client.go create mode 100644 pkg/clients/grpc/log/client_test.go create mode 100644 pkg/clients/grpc/runner/client.go create mode 100644 pkg/clients/grpc/runner/client_test.go create mode 100644 pkg/clients/grpc/runner/mocks/client.go create mode 100644 pkg/egress/proxy.go create mode 100644 pkg/egress/proxy_test.go create mode 100644 pkg/ingress/adapter.go create mode 100644 pkg/ingress/adapter_test.go create mode 100644 pkg/ingress/proxy.go create mode 100644 pkg/ingress/proxy_test.go diff --git a/.mockery.yml b/.mockery.yml index ca0e0070..003f2253 100644 --- a/.mockery.yml +++ b/.mockery.yml @@ -146,3 +146,10 @@ packages: dir: '{{.InterfaceDir}}/mocks' structname: '{{.InterfaceName}}' filename: "{{.InterfaceName | lower}}.go" + github.com/ultravioletrs/cocos/pkg/clients/grpc/runner: + interfaces: + Client: + config: + dir: '{{.InterfaceDir}}/mocks' + structname: '{{.InterfaceName}}' + filename: "{{.InterfaceName | lower}}.go" diff --git a/Makefile b/Makefile index ccbb35ed..7f0d8b33 100644 --- a/Makefile +++ b/Makefile @@ -1,5 +1,5 @@ BUILD_DIR = build -SERVICES = manager agent cli attestation-service +SERVICES = manager agent cli 
attestation-service log-forwarder computation-runner egress-proxy ingress-proxy ATTESTATION_POLICY = attestation_policy CGO_ENABLED ?= 0 GOARCH ?= amd64 @@ -41,6 +41,8 @@ protoc: protoc -I. --go_out=. --go_opt=paths=source_relative --go-grpc_out=. --go-grpc_opt=paths=source_relative agent/events/events.proto protoc -I. --go_out=. --go_opt=paths=source_relative --go-grpc_out=. --go-grpc_opt=paths=source_relative agent/cvms/cvms.proto protoc -I. --go_out=. --go_opt=paths=source_relative --go-grpc_out=. --go-grpc_opt=paths=source_relative internal/proto/attestation/v1/attestation.proto + protoc -I. --go_out=. --go_opt=paths=source_relative --go-grpc_out=. --go-grpc_opt=paths=source_relative agent/log/log.proto + protoc -I. --go_out=. --go_opt=paths=source_relative --go-grpc_out=. --go-grpc_opt=paths=source_relative agent/runner/runner.proto mocks: mockery --config ./.mockery.yml diff --git a/agent/computations.go b/agent/computations.go index ed831577..00191db6 100644 --- a/agent/computations.go +++ b/agent/computations.go @@ -13,7 +13,6 @@ import ( var _ fmt.Stringer = (*Datasets)(nil) type AgentConfig struct { - Port string `json:"port,omitempty"` CertFile string `json:"cert_file,omitempty"` KeyFile string `json:"server_key,omitempty"` ServerCAFile string `json:"server_ca_file,omitempty"` diff --git a/agent/computations_test.go b/agent/computations_test.go index 052375a6..1406eed6 100644 --- a/agent/computations_test.go +++ b/agent/computations_test.go @@ -105,16 +105,15 @@ func TestDecompressToContext(t *testing.T) { } func TestAgentConfigJSON(t *testing.T) { - config := AgentConfig{ - Port: "8080", + cfg := AgentConfig{ CertFile: "cert.pem", KeyFile: "key.pem", - ServerCAFile: "server_ca.pem", - ClientCAFile: "client_ca.pem", + ServerCAFile: "server-ca.pem", + ClientCAFile: "client-ca.pem", AttestedTls: true, } - data, err := json.Marshal(config) + data, err := json.Marshal(cfg) if err != nil { t.Fatalf("Failed to marshal AgentConfig: %v", err) } @@ -125,7 +124,7 
@@ func TestAgentConfigJSON(t *testing.T) { t.Fatalf("Failed to unmarshal AgentConfig: %v", err) } - if !reflect.DeepEqual(config, unmarshaledConfig) { - t.Errorf("Unmarshaled config does not match original. Got %+v, want %+v", unmarshaledConfig, config) + if !reflect.DeepEqual(cfg, unmarshaledConfig) { + t.Errorf("Unmarshaled config does not match original. Got %+v, want %+v", unmarshaledConfig, cfg) } } diff --git a/agent/cvms/api/grpc/client.go b/agent/cvms/api/grpc/client.go index 561f42ba..b25f0485 100644 --- a/agent/cvms/api/grpc/client.go +++ b/agent/cvms/api/grpc/client.go @@ -5,6 +5,7 @@ package grpc import ( "context" "encoding/json" + "fmt" "log/slog" "sync" "time" @@ -17,6 +18,7 @@ import ( "github.com/ultravioletrs/cocos/pkg/attestation" "github.com/ultravioletrs/cocos/pkg/attestation/vtpm" "github.com/ultravioletrs/cocos/pkg/clients/grpc" + "github.com/ultravioletrs/cocos/pkg/ingress" "golang.org/x/sync/errgroup" "google.golang.org/protobuf/proto" ) @@ -44,13 +46,14 @@ type CVMSClient struct { logger *slog.Logger runReqManager *runRequestManager sp server.AgentServer + ingressProxy ingress.ProxyServer storage storage.Storage reconnectFn func(context.Context) (grpc.Client, cvms.Service_ProcessClient, error) grpcClient grpc.Client } // NewClient returns new gRPC client instance. 
-func NewClient(stream cvms.Service_ProcessClient, svc agent.Service, messageQueue chan *cvms.ClientStreamMessage, logger *slog.Logger, sp server.AgentServer, storageDir string, reconnectFn func(context.Context) (grpc.Client, cvms.Service_ProcessClient, error), grpcClient grpc.Client) (*CVMSClient, error) { +func NewClient(stream cvms.Service_ProcessClient, svc agent.Service, messageQueue chan *cvms.ClientStreamMessage, logger *slog.Logger, sp server.AgentServer, ingressProxy ingress.ProxyServer, storageDir string, reconnectFn func(context.Context) (grpc.Client, cvms.Service_ProcessClient, error), grpcClient grpc.Client) (*CVMSClient, error) { store, err := storage.NewFileStorage(storageDir) if err != nil { return nil, err @@ -63,6 +66,7 @@ func NewClient(stream cvms.Service_ProcessClient, svc agent.Service, messageQueu logger: logger, runReqManager: newRunRequestManager(), sp: sp, + ingressProxy: ingressProxy, storage: store, reconnectFn: reconnectFn, grpcClient: grpcClient, @@ -205,14 +209,17 @@ func (client *CVMSClient) handleAgentStateReq(mes *cvms.ServerStreamMessage_Agen } func (client *CVMSClient) handleRunReqChunks(ctx context.Context, msg *cvms.ServerStreamMessage_RunReqChunks) error { + client.logger.Debug("Received RunReq chunk", "id", msg.RunReqChunks.Id, "size", len(msg.RunReqChunks.Data), "isLast", msg.RunReqChunks.IsLast) buffer, complete := client.runReqManager.addChunk(msg.RunReqChunks.Id, msg.RunReqChunks.Data, msg.RunReqChunks.IsLast) if complete { + client.logger.Info("Received complete computation run request", "id", msg.RunReqChunks.Id, "totalSize", len(buffer)) var runReq cvms.ComputationRunReq if err := proto.Unmarshal(buffer, &runReq); err != nil { return errors.Wrap(err, errCorruptedManifest) } + client.logger.Info("Starting computation execution", "computationId", runReq.Id, "name", runReq.Name) go client.executeRun(ctx, &runReq) } @@ -246,6 +253,15 @@ func (client *CVMSClient) executeRun(ctx context.Context, runReq *cvms.Computati }) } + 
// Check if the agent is in the correct state to initialize a new computation. + // If the agent is already processing this computation (e.g., after a reconnection), + // skip initialization to avoid state errors. + currentState := client.svc.State() + if currentState != "ReceivingManifest" { + client.logger.Info("Agent already processing computation, skipping initialization", "state", currentState, "computationId", runReq.Id) + return + } + if err := client.svc.InitComputation(ctx, ac); err != nil { client.logger.Warn(err.Error()) return @@ -267,7 +283,6 @@ func (client *CVMSClient) executeRun(ctx context.Context, runReq *cvms.Computati } if err := client.sp.Start(agent.AgentConfig{ - Port: runReq.AgentConfig.Port, CertFile: runReq.AgentConfig.CertFile, KeyFile: runReq.AgentConfig.KeyFile, ServerCAFile: runReq.AgentConfig.ServerCaFile, @@ -278,6 +293,22 @@ func (client *CVMSClient) executeRun(ctx context.Context, runReq *cvms.Computati runRes.RunRes.Error = err.Error() } + // Start ingress proxy if available + if client.ingressProxy != nil { + if err := client.ingressProxy.Start( + ingress.AgentConfigToProxyConfig(agent.AgentConfig{ + CertFile: runReq.AgentConfig.CertFile, + KeyFile: runReq.AgentConfig.KeyFile, + ServerCAFile: runReq.AgentConfig.ServerCaFile, + ClientCAFile: runReq.AgentConfig.ClientCaFile, + AttestedTls: runReq.AgentConfig.AttestedTls, + }), + ingress.ComputationToProxyContext(ac), + ); err != nil { + client.logger.Warn(fmt.Sprintf("failed to start ingress proxy: %s", err.Error())) + } + } + defer func() { if ccPlatform == attestation.Azure || ccPlatform == attestation.SNPvTPM { cmpJson, err := json.Marshal(ac) @@ -309,6 +340,12 @@ func (client *CVMSClient) handleStopComputation(ctx context.Context, mes *cvms.S if err := client.sp.Stop(); err != nil { msg.StopComputationRes.Message = err.Error() } + // Stop ingress proxy if available + if client.ingressProxy != nil { + if err := client.ingressProxy.Stop(); err != nil { + 
client.logger.Warn(fmt.Sprintf("failed to stop ingress proxy: %s", err.Error())) + } + } client.mu.Unlock() client.sendMessage(&cvms.ClientStreamMessage{Message: msg}) diff --git a/agent/cvms/api/grpc/client_test.go b/agent/cvms/api/grpc/client_test.go index 7d832f3f..8a2cc478 100644 --- a/agent/cvms/api/grpc/client_test.go +++ b/agent/cvms/api/grpc/client_test.go @@ -11,10 +11,12 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" "github.com/ultravioletrs/cocos/agent/cvms" + "github.com/ultravioletrs/cocos/agent/cvms/api/grpc/storage" servermocks "github.com/ultravioletrs/cocos/agent/cvms/server/mocks" "github.com/ultravioletrs/cocos/agent/mocks" pkggrpc "github.com/ultravioletrs/cocos/pkg/clients/grpc" clientmocks "github.com/ultravioletrs/cocos/pkg/clients/grpc/mocks" + "github.com/ultravioletrs/cocos/pkg/ingress" "golang.org/x/crypto/sha3" "google.golang.org/grpc" "google.golang.org/protobuf/proto" @@ -35,6 +37,21 @@ func (m *mockStream) Send(msg *cvms.ClientStreamMessage) error { return args.Error(0) } +// mockIngressProxy is a mock implementation of the ingress proxy. 
+type mockIngressProxy struct { + mock.Mock +} + +func (m *mockIngressProxy) Start(config ingress.ProxyConfig, ctx ingress.ProxyContext) error { + args := m.Called(config, ctx) + return args.Error(0) +} + +func (m *mockIngressProxy) Stop() error { + args := m.Called() + return args.Error(0) +} + func TestManagerClient_Process(t *testing.T) { tests := []struct { name string @@ -121,7 +138,7 @@ func TestManagerClient_Process(t *testing.T) { grpcClient := new(clientmocks.Client) - client, err := NewClient(mockStream, mockSvc, messageQueue, logger, mockServerSvc, t.TempDir(), func(ctx context.Context) (pkggrpc.Client, cvms.Service_ProcessClient, error) { return nil, nil, nil }, grpcClient) + client, err := NewClient(mockStream, mockSvc, messageQueue, logger, mockServerSvc, nil, t.TempDir(), func(ctx context.Context) (pkggrpc.Client, cvms.Service_ProcessClient, error) { return nil, nil, nil }, grpcClient) assert.NoError(t, err) ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) @@ -151,7 +168,7 @@ func TestManagerClient_handleRunReqChunks(t *testing.T) { logger := mglog.NewMock() grpcClient := new(clientmocks.Client) - client, err := NewClient(mockStream, mockSvc, messageQueue, logger, mockServerSvc, t.TempDir(), func(ctx context.Context) (pkggrpc.Client, cvms.Service_ProcessClient, error) { return nil, nil, nil }, grpcClient) + client, err := NewClient(mockStream, mockSvc, messageQueue, logger, mockServerSvc, nil, t.TempDir(), func(ctx context.Context) (pkggrpc.Client, cvms.Service_ProcessClient, error) { return nil, nil, nil }, grpcClient) assert.NoError(t, err) runReq := &cvms.ComputationRunReq{ @@ -187,6 +204,7 @@ func TestManagerClient_handleRunReqChunks(t *testing.T) { }, } + mockSvc.On("State").Return("ReceivingManifest") mockSvc.On("InitComputation", mock.Anything, mock.Anything).Return(nil) mockServerSvc.On("Start", mock.Anything, mock.Anything, mock.Anything).Return(nil) @@ -216,7 +234,7 @@ func 
TestManagerClient_handleStopComputation(t *testing.T) { logger := mglog.NewMock() grpcClient := new(clientmocks.Client) - client, err := NewClient(mockStream, mockSvc, messageQueue, logger, mockServerSvc, t.TempDir(), func(ctx context.Context) (pkggrpc.Client, cvms.Service_ProcessClient, error) { return nil, nil, nil }, grpcClient) + client, err := NewClient(mockStream, mockSvc, messageQueue, logger, mockServerSvc, nil, t.TempDir(), func(ctx context.Context) (pkggrpc.Client, cvms.Service_ProcessClient, error) { return nil, nil, nil }, grpcClient) assert.NoError(t, err) stopReq := &cvms.ServerStreamMessage_StopComputation{ @@ -255,3 +273,243 @@ func TestManagerClient_timeoutRequest(t *testing.T) { assert.Len(t, rm.requests, 0) } + +// TestManagerClient_sendPendingMessages tests sending pending messages on reconnection. +func TestManagerClient_sendPendingMessages(t *testing.T) { + mockStream := new(mockStream) + mockSvc := new(mocks.Service) + mockServerSvc := new(servermocks.AgentServer) + messageQueue := make(chan *cvms.ClientStreamMessage, 10) + logger := mglog.NewMock() + grpcClient := new(clientmocks.Client) + + client, err := NewClient(mockStream, mockSvc, messageQueue, logger, mockServerSvc, nil, t.TempDir(), func(ctx context.Context) (pkggrpc.Client, cvms.Service_ProcessClient, error) { return nil, nil, nil }, grpcClient) + assert.NoError(t, err) + + // Add a pending message to storage + testMsg := &cvms.ClientStreamMessage{ + Message: &cvms.ClientStreamMessage_RunRes{ + RunRes: &cvms.RunResponse{ + ComputationId: "test-id", + }, + }, + } + err = client.storage.Add(testMsg) + assert.NoError(t, err) + + // Mock successful send + mockStream.On("Send", mock.Anything).Return(nil).Once() + + // Load and send pending messages + pending, err := client.storage.Load() + assert.NoError(t, err) + assert.Len(t, pending, 1) + + client.sendPendingMessages(pending) + + mockStream.AssertExpectations(t) +} + +// TestManagerClient_sendPendingMessagesWithError tests pending 
message send failure. +func TestManagerClient_sendPendingMessagesWithError(t *testing.T) { + mockStream := new(mockStream) + mockSvc := new(mocks.Service) + mockServerSvc := new(servermocks.AgentServer) + messageQueue := make(chan *cvms.ClientStreamMessage, 10) + logger := mglog.NewMock() + grpcClient := new(clientmocks.Client) + + client, err := NewClient(mockStream, mockSvc, messageQueue, logger, mockServerSvc, nil, t.TempDir(), func(ctx context.Context) (pkggrpc.Client, cvms.Service_ProcessClient, error) { return nil, nil, nil }, grpcClient) + assert.NoError(t, err) + + testMsg := &cvms.ClientStreamMessage{ + Message: &cvms.ClientStreamMessage_RunRes{ + RunRes: &cvms.RunResponse{ + ComputationId: "test-id", + }, + }, + } + + // Mock failed send + mockStream.On("Send", mock.Anything).Return(assert.AnError) + + pending := []storage.Message{ + { + Message: testMsg, + Time: time.Now(), + }, + } + + client.sendPendingMessages(pending) + + mockStream.AssertExpectations(t) +} + +// TestManagerClient_addChunkTimeout tests chunk timeout in runRequestManager. +func TestManagerClient_addChunkTimeout(t *testing.T) { + rm := newRunRequestManager() + + // Add first chunk + chunk1 := []byte("chunk1") + buffer, complete := rm.addChunk("test-id", chunk1, false) + assert.Nil(t, buffer) + assert.False(t, complete) + + // Verify request exists + rm.mu.Lock() + assert.Contains(t, rm.requests, "test-id") + rm.mu.Unlock() + + // Wait for timeout + time.Sleep(35 * time.Second) // runReqTimeout is 30 seconds + + // Verify request was removed + rm.mu.Lock() + assert.NotContains(t, rm.requests, "test-id") + rm.mu.Unlock() +} + +// TestManagerClient_addChunkMultiple tests adding multiple chunks. 
+func TestManagerClient_addChunkMultiple(t *testing.T) { + rm := newRunRequestManager() + + chunk1 := []byte("chunk1") + chunk2 := []byte("chunk2") + chunk3 := []byte("chunk3") + + // Add chunks + buffer, complete := rm.addChunk("test-id", chunk1, false) + assert.Nil(t, buffer) + assert.False(t, complete) + + buffer, complete = rm.addChunk("test-id", chunk2, false) + assert.Nil(t, buffer) + assert.False(t, complete) + + buffer, complete = rm.addChunk("test-id", chunk3, true) + assert.NotNil(t, buffer) + assert.True(t, complete) + + expected := append(append(chunk1, chunk2...), chunk3...) + assert.Equal(t, expected, buffer) +} + +// TestManagerClient_handleStopComputationWithIngressProxy tests stop with ingress proxy. +func TestManagerClient_handleStopComputationWithIngressProxy(t *testing.T) { + mockStream := new(mockStream) + mockSvc := new(mocks.Service) + mockServerSvc := new(servermocks.AgentServer) + mockIngressProxy := new(mockIngressProxy) + messageQueue := make(chan *cvms.ClientStreamMessage, 10) + logger := mglog.NewMock() + grpcClient := new(clientmocks.Client) + + client, err := NewClient(mockStream, mockSvc, messageQueue, logger, mockServerSvc, mockIngressProxy, t.TempDir(), func(ctx context.Context) (pkggrpc.Client, cvms.Service_ProcessClient, error) { return nil, nil, nil }, grpcClient) + assert.NoError(t, err) + + stopReq := &cvms.ServerStreamMessage_StopComputation{ + StopComputation: &cvms.StopComputation{ + ComputationId: "test-comp-id", + }, + } + + mockSvc.On("StopComputation", mock.Anything).Return(nil) + mockServerSvc.On("Stop").Return(nil) + mockIngressProxy.On("Stop").Return(nil) + + client.handleStopComputation(context.Background(), stopReq) + + time.Sleep(50 * time.Millisecond) + + mockSvc.AssertExpectations(t) + mockServerSvc.AssertExpectations(t) + mockIngressProxy.AssertExpectations(t) + assert.Len(t, messageQueue, 1) +} + +// TestManagerClient_handleStopComputationWithIngressProxyError tests stop with ingress proxy error. 
+func TestManagerClient_handleStopComputationWithIngressProxyError(t *testing.T) { + mockStream := new(mockStream) + mockSvc := new(mocks.Service) + mockServerSvc := new(servermocks.AgentServer) + mockIngressProxy := new(mockIngressProxy) + messageQueue := make(chan *cvms.ClientStreamMessage, 10) + logger := mglog.NewMock() + grpcClient := new(clientmocks.Client) + + client, err := NewClient(mockStream, mockSvc, messageQueue, logger, mockServerSvc, mockIngressProxy, t.TempDir(), func(ctx context.Context) (pkggrpc.Client, cvms.Service_ProcessClient, error) { return nil, nil, nil }, grpcClient) + assert.NoError(t, err) + + stopReq := &cvms.ServerStreamMessage_StopComputation{ + StopComputation: &cvms.StopComputation{ + ComputationId: "test-comp-id", + }, + } + + mockSvc.On("StopComputation", mock.Anything).Return(nil) + mockServerSvc.On("Stop").Return(nil) + mockIngressProxy.On("Stop").Return(assert.AnError) + + client.handleStopComputation(context.Background(), stopReq) + + time.Sleep(50 * time.Millisecond) + + mockIngressProxy.AssertExpectations(t) +} + +// TestManagerClient_sendMessage tests sendMessage with timeout. 
+func TestManagerClient_sendMessage(t *testing.T) { + mockStream := new(mockStream) + mockSvc := new(mocks.Service) + mockServerSvc := new(servermocks.AgentServer) + messageQueue := make(chan *cvms.ClientStreamMessage, 1) + logger := mglog.NewMock() + grpcClient := new(clientmocks.Client) + + client, err := NewClient(mockStream, mockSvc, messageQueue, logger, mockServerSvc, nil, t.TempDir(), func(ctx context.Context) (pkggrpc.Client, cvms.Service_ProcessClient, error) { return nil, nil, nil }, grpcClient) + assert.NoError(t, err) + + msg := &cvms.ClientStreamMessage{ + Message: &cvms.ClientStreamMessage_RunRes{ + RunRes: &cvms.RunResponse{ + ComputationId: "test-id", + }, + }, + } + + client.sendMessage(msg) + + select { + case received := <-messageQueue: + assert.Equal(t, msg, received) + case <-time.After(1 * time.Second): + t.Fatal("Message not received") + } +} + +// TestManagerClient_sendMessageTimeout tests sendMessage timeout when queue is full. +func TestManagerClient_sendMessageTimeout(t *testing.T) { + mockStream := new(mockStream) + mockSvc := new(mocks.Service) + mockServerSvc := new(servermocks.AgentServer) + messageQueue := make(chan *cvms.ClientStreamMessage) // No buffer + logger := mglog.NewMock() + grpcClient := new(clientmocks.Client) + + client, err := NewClient(mockStream, mockSvc, messageQueue, logger, mockServerSvc, nil, t.TempDir(), func(ctx context.Context) (pkggrpc.Client, cvms.Service_ProcessClient, error) { return nil, nil, nil }, grpcClient) + assert.NoError(t, err) + + msg := &cvms.ClientStreamMessage{ + Message: &cvms.ClientStreamMessage_RunRes{ + RunRes: &cvms.RunResponse{ + ComputationId: "test-id", + }, + }, + } + + // Don't read from queue, so sendMessage will timeout + client.sendMessage(msg) + + // Should complete without blocking + time.Sleep(100 * time.Millisecond) +} diff --git a/agent/cvms/api/grpc/server.go b/agent/cvms/api/grpc/server.go index 53add3d1..dc9ad035 100644 --- a/agent/cvms/api/grpc/server.go +++ 
b/agent/cvms/api/grpc/server.go @@ -7,6 +7,7 @@ import ( "context" "errors" "io" + "log/slog" "time" "github.com/ultravioletrs/cocos/agent/cvms" @@ -52,16 +53,20 @@ func (s *grpcServer) Process(stream cvms.Service_ProcessServer) error { return errors.New("failed to get peer info") } + slog.Info("client connected to cvms server", "address", client.Addr.String()) + eg, ctx := errgroup.WithContext(stream.Context()) eg.Go(func() error { for { select { case <-ctx.Done(): + slog.Info("receive goroutine context done", "address", client.Addr.String()) return ctx.Err() default: req, err := stream.Recv() if err != nil { + slog.Error("failed to receive from stream", "address", client.Addr.String(), "error", err) return err } s.incoming <- req @@ -85,10 +90,13 @@ func (s *grpcServer) Process(stream cvms.Service_ProcessServer) error { } s.svc.Run(ctx, client.Addr.String(), sendMessage, client.AuthInfo) + slog.Info("send goroutine Run() returned", "address", client.Addr.String()) return nil }) - return eg.Wait() + err := eg.Wait() + slog.Info("stream closed", "address", client.Addr.String(), "error", err) + return err } func (s *grpcServer) sendRunReqInChunks(stream cvms.Service_ProcessServer, runReq *cvms.ComputationRunReq) error { diff --git a/agent/cvms/server/cvm.go b/agent/cvms/server/cvm.go index a7c20939..2c989e64 100644 --- a/agent/cvms/server/cvm.go +++ b/agent/cvms/server/cvm.go @@ -19,8 +19,8 @@ import ( ) const ( - svcName = "agent" - defSvcGRPCPort = "7002" + svcName = "agent" + defSvcGRPCSocket = "/run/cocos/agent.sock" ) type AgentServer interface { @@ -46,15 +46,11 @@ func NewServer(logger *slog.Logger, svc agent.Service, host string, certProvider } func (as *agentServer) Start(cfg agent.AgentConfig, cmp agent.Computation) error { - if cfg.Port == "" { - cfg.Port = defSvcGRPCPort - } - agentGrpcServerConfig := server.AgentConfig{ ServerConfig: server.ServerConfig{ Config: server.Config{ - Host: as.host, - Port: cfg.Port, + Host: defSvcGRPCSocket, + Port: "", 
CertFile: cfg.CertFile, KeyFile: cfg.KeyFile, ServerCAFile: cfg.ServerCAFile, diff --git a/agent/cvms/server/cvm_test.go b/agent/cvms/server/cvm_test.go index 1d7d9da3..ddfbc2a5 100644 --- a/agent/cvms/server/cvm_test.go +++ b/agent/cvms/server/cvm_test.go @@ -97,7 +97,6 @@ func TestAgentServer_Start(t *testing.T) { { name: "successful start with default port", cfg: agent.AgentConfig{ - Port: "", CertFile: "cert.pem", KeyFile: "key.pem", ServerCAFile: "server-ca.pem", @@ -131,7 +130,6 @@ func TestAgentServer_Start(t *testing.T) { { name: "successful start with custom port", cfg: agent.AgentConfig{ - Port: "8080", CertFile: "cert.pem", KeyFile: "key.pem", ServerCAFile: "server-ca.pem", @@ -165,7 +163,6 @@ func TestAgentServer_Start(t *testing.T) { { name: "start with minimal config", cfg: agent.AgentConfig{ - Port: "9090", AttestedTls: false, }, cmp: agent.Computation{ @@ -243,9 +240,7 @@ func TestAgentServer_Stop(t *testing.T) { { name: "stop started server", setupServer: func(server AgentServer) error { - cfg := agent.AgentConfig{ - Port: "7004", - } + cfg := agent.AgentConfig{} cmp := agent.Computation{ ID: "test-stop-computation", Name: "Stop Test", @@ -304,7 +299,7 @@ func TestAgentServer_StopMultipleTimes(t *testing.T) { server := NewServer(logger, svc, host, nil) // Start the server - cfg := agent.AgentConfig{Port: "7005"} + cfg := agent.AgentConfig{} cmp := agent.Computation{ ID: "test-multiple-stop", Name: "Multiple Stop Test", @@ -347,7 +342,7 @@ func TestAgentServer_StartAfterStop(t *testing.T) { logger, svc, host, pubKey := setupTest(t) server := NewServer(logger, svc, host, nil) - cfg := agent.AgentConfig{Port: "7006"} + cfg := agent.AgentConfig{} cmp := agent.Computation{ ID: "test-restart", Name: "Restart Test", @@ -378,7 +373,7 @@ func TestAgentServer_StartAfterStop(t *testing.T) { assert.NoError(t, err) // Start again with different config - cfg2 := agent.AgentConfig{Port: "7007"} + cfg2 := agent.AgentConfig{} cmp2 := agent.Computation{ ID: 
"test-restart-2", Name: "Restart Test 2", @@ -422,7 +417,6 @@ func TestAgentServer_ConfigValidation(t *testing.T) { { name: "valid config with all fields", config: agent.AgentConfig{ - Port: "8080", CertFile: "cert.pem", KeyFile: "key.pem", ServerCAFile: "server-ca.pem", @@ -451,10 +445,8 @@ func TestAgentServer_ConfigValidation(t *testing.T) { valid: true, }, { - name: "valid config with minimal fields", - config: agent.AgentConfig{ - Port: "9090", - }, + name: "valid config with minimal fields", + config: agent.AgentConfig{}, cmp: agent.Computation{ ID: "minimal-config-test", Name: "Minimal Config Test", @@ -477,10 +469,8 @@ func TestAgentServer_ConfigValidation(t *testing.T) { valid: true, }, { - name: "config with empty port uses default", - config: agent.AgentConfig{ - Port: "", - }, + name: "config with empty port uses default", + config: agent.AgentConfig{}, cmp: agent.Computation{ ID: "default-port-test", Name: "Default Port Test", @@ -505,11 +495,9 @@ func TestAgentServer_ConfigValidation(t *testing.T) { if tt.valid { assert.NoError(t, err) - // Verify default port is used when empty - if tt.config.Port == "" { - agentSrv := server.(*agentServer) - assert.NotNil(t, agentSrv.gs) - } + // Verify server started successfully + agentSrv := server.(*agentServer) + assert.NotNil(t, agentSrv.gs) time.Sleep(10 * time.Millisecond) if err := server.Stop(); err != nil { @@ -526,5 +514,5 @@ func TestAgentServer_ConfigValidation(t *testing.T) { func TestConstants(t *testing.T) { assert.Equal(t, "agent", svcName) - assert.Equal(t, "7002", defSvcGRPCPort) + assert.Equal(t, "/run/cocos/agent.sock", defSvcGRPCSocket) } diff --git a/agent/log/log.pb.go b/agent/log/log.pb.go new file mode 100644 index 00000000..0b4e405d --- /dev/null +++ b/agent/log/log.pb.go @@ -0,0 +1,261 @@ +// Copyright (c) Ultraviolet +// SPDX-License-Identifier: Apache-2.0 + +// Code generated by protoc-gen-go. DO NOT EDIT. 
+// versions: +// protoc-gen-go v1.36.8 +// protoc v6.33.1 +// source: agent/log/log.proto + +package log + +import ( + protoreflect "google.golang.org/protobuf/reflect/protoreflect" + protoimpl "google.golang.org/protobuf/runtime/protoimpl" + emptypb "google.golang.org/protobuf/types/known/emptypb" + timestamppb "google.golang.org/protobuf/types/known/timestamppb" + reflect "reflect" + sync "sync" + unsafe "unsafe" +) + +const ( + // Verify that this generated code is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion) + // Verify that runtime/protoimpl is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20) +) + +type LogEntry struct { + state protoimpl.MessageState `protogen:"open.v1"` + Message string `protobuf:"bytes,1,opt,name=message,proto3" json:"message,omitempty"` + ComputationId string `protobuf:"bytes,2,opt,name=computation_id,json=computationId,proto3" json:"computation_id,omitempty"` + Level string `protobuf:"bytes,3,opt,name=level,proto3" json:"level,omitempty"` + Timestamp *timestamppb.Timestamp `protobuf:"bytes,4,opt,name=timestamp,proto3" json:"timestamp,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *LogEntry) Reset() { + *x = LogEntry{} + mi := &file_agent_log_log_proto_msgTypes[0] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *LogEntry) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*LogEntry) ProtoMessage() {} + +func (x *LogEntry) ProtoReflect() protoreflect.Message { + mi := &file_agent_log_log_proto_msgTypes[0] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use LogEntry.ProtoReflect.Descriptor instead. 
+func (*LogEntry) Descriptor() ([]byte, []int) { + return file_agent_log_log_proto_rawDescGZIP(), []int{0} +} + +func (x *LogEntry) GetMessage() string { + if x != nil { + return x.Message + } + return "" +} + +func (x *LogEntry) GetComputationId() string { + if x != nil { + return x.ComputationId + } + return "" +} + +func (x *LogEntry) GetLevel() string { + if x != nil { + return x.Level + } + return "" +} + +func (x *LogEntry) GetTimestamp() *timestamppb.Timestamp { + if x != nil { + return x.Timestamp + } + return nil +} + +type EventEntry struct { + state protoimpl.MessageState `protogen:"open.v1"` + EventType string `protobuf:"bytes,1,opt,name=event_type,json=eventType,proto3" json:"event_type,omitempty"` + Timestamp *timestamppb.Timestamp `protobuf:"bytes,2,opt,name=timestamp,proto3" json:"timestamp,omitempty"` + ComputationId string `protobuf:"bytes,3,opt,name=computation_id,json=computationId,proto3" json:"computation_id,omitempty"` + Details []byte `protobuf:"bytes,4,opt,name=details,proto3" json:"details,omitempty"` // JSON payload + Originator string `protobuf:"bytes,5,opt,name=originator,proto3" json:"originator,omitempty"` + Status string `protobuf:"bytes,6,opt,name=status,proto3" json:"status,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *EventEntry) Reset() { + *x = EventEntry{} + mi := &file_agent_log_log_proto_msgTypes[1] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *EventEntry) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*EventEntry) ProtoMessage() {} + +func (x *EventEntry) ProtoReflect() protoreflect.Message { + mi := &file_agent_log_log_proto_msgTypes[1] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use EventEntry.ProtoReflect.Descriptor instead. 
+func (*EventEntry) Descriptor() ([]byte, []int) { + return file_agent_log_log_proto_rawDescGZIP(), []int{1} +} + +func (x *EventEntry) GetEventType() string { + if x != nil { + return x.EventType + } + return "" +} + +func (x *EventEntry) GetTimestamp() *timestamppb.Timestamp { + if x != nil { + return x.Timestamp + } + return nil +} + +func (x *EventEntry) GetComputationId() string { + if x != nil { + return x.ComputationId + } + return "" +} + +func (x *EventEntry) GetDetails() []byte { + if x != nil { + return x.Details + } + return nil +} + +func (x *EventEntry) GetOriginator() string { + if x != nil { + return x.Originator + } + return "" +} + +func (x *EventEntry) GetStatus() string { + if x != nil { + return x.Status + } + return "" +} + +var File_agent_log_log_proto protoreflect.FileDescriptor + +const file_agent_log_log_proto_rawDesc = "" + + "\n" + + "\x13agent/log/log.proto\x12\x03log\x1a\x1fgoogle/protobuf/timestamp.proto\x1a\x1bgoogle/protobuf/empty.proto\"\x9b\x01\n" + + "\bLogEntry\x12\x18\n" + + "\amessage\x18\x01 \x01(\tR\amessage\x12%\n" + + "\x0ecomputation_id\x18\x02 \x01(\tR\rcomputationId\x12\x14\n" + + "\x05level\x18\x03 \x01(\tR\x05level\x128\n" + + "\ttimestamp\x18\x04 \x01(\v2\x1a.google.protobuf.TimestampR\ttimestamp\"\xde\x01\n" + + "\n" + + "EventEntry\x12\x1d\n" + + "\n" + + "event_type\x18\x01 \x01(\tR\teventType\x128\n" + + "\ttimestamp\x18\x02 \x01(\v2\x1a.google.protobuf.TimestampR\ttimestamp\x12%\n" + + "\x0ecomputation_id\x18\x03 \x01(\tR\rcomputationId\x12\x18\n" + + "\adetails\x18\x04 \x01(\fR\adetails\x12\x1e\n" + + "\n" + + "originator\x18\x05 \x01(\tR\n" + + "originator\x12\x16\n" + + "\x06status\x18\x06 \x01(\tR\x06status2v\n" + + "\fLogCollector\x120\n" + + "\aSendLog\x12\r.log.LogEntry\x1a\x16.google.protobuf.Empty\x124\n" + + "\tSendEvent\x12\x0f.log.EventEntry\x1a\x16.google.protobuf.EmptyB\aZ\x05./logb\x06proto3" + +var ( + file_agent_log_log_proto_rawDescOnce sync.Once + file_agent_log_log_proto_rawDescData []byte +) 
+ +func file_agent_log_log_proto_rawDescGZIP() []byte { + file_agent_log_log_proto_rawDescOnce.Do(func() { + file_agent_log_log_proto_rawDescData = protoimpl.X.CompressGZIP(unsafe.Slice(unsafe.StringData(file_agent_log_log_proto_rawDesc), len(file_agent_log_log_proto_rawDesc))) + }) + return file_agent_log_log_proto_rawDescData +} + +var file_agent_log_log_proto_msgTypes = make([]protoimpl.MessageInfo, 2) +var file_agent_log_log_proto_goTypes = []any{ + (*LogEntry)(nil), // 0: log.LogEntry + (*EventEntry)(nil), // 1: log.EventEntry + (*timestamppb.Timestamp)(nil), // 2: google.protobuf.Timestamp + (*emptypb.Empty)(nil), // 3: google.protobuf.Empty +} +var file_agent_log_log_proto_depIdxs = []int32{ + 2, // 0: log.LogEntry.timestamp:type_name -> google.protobuf.Timestamp + 2, // 1: log.EventEntry.timestamp:type_name -> google.protobuf.Timestamp + 0, // 2: log.LogCollector.SendLog:input_type -> log.LogEntry + 1, // 3: log.LogCollector.SendEvent:input_type -> log.EventEntry + 3, // 4: log.LogCollector.SendLog:output_type -> google.protobuf.Empty + 3, // 5: log.LogCollector.SendEvent:output_type -> google.protobuf.Empty + 4, // [4:6] is the sub-list for method output_type + 2, // [2:4] is the sub-list for method input_type + 2, // [2:2] is the sub-list for extension type_name + 2, // [2:2] is the sub-list for extension extendee + 0, // [0:2] is the sub-list for field type_name +} + +func init() { file_agent_log_log_proto_init() } +func file_agent_log_log_proto_init() { + if File_agent_log_log_proto != nil { + return + } + type x struct{} + out := protoimpl.TypeBuilder{ + File: protoimpl.DescBuilder{ + GoPackagePath: reflect.TypeOf(x{}).PkgPath(), + RawDescriptor: unsafe.Slice(unsafe.StringData(file_agent_log_log_proto_rawDesc), len(file_agent_log_log_proto_rawDesc)), + NumEnums: 0, + NumMessages: 2, + NumExtensions: 0, + NumServices: 1, + }, + GoTypes: file_agent_log_log_proto_goTypes, + DependencyIndexes: file_agent_log_log_proto_depIdxs, + MessageInfos: 
file_agent_log_log_proto_msgTypes, + }.Build() + File_agent_log_log_proto = out.File + file_agent_log_log_proto_goTypes = nil + file_agent_log_log_proto_depIdxs = nil +} diff --git a/agent/log/log.proto b/agent/log/log.proto new file mode 100644 index 00000000..2643d506 --- /dev/null +++ b/agent/log/log.proto @@ -0,0 +1,32 @@ +// Copyright (c) Ultraviolet +// SPDX-License-Identifier: Apache-2.0 + +syntax = "proto3"; + +package log; + +option go_package = "./log"; + +import "google/protobuf/timestamp.proto"; +import "google/protobuf/empty.proto"; + +service LogCollector { + rpc SendLog(LogEntry) returns (google.protobuf.Empty); + rpc SendEvent(EventEntry) returns (google.protobuf.Empty); +} + +message LogEntry { + string message = 1; + string computation_id = 2; + string level = 3; + google.protobuf.Timestamp timestamp = 4; +} + +message EventEntry { + string event_type = 1; + google.protobuf.Timestamp timestamp = 2; + string computation_id = 3; + bytes details = 4; // JSON payload + string originator = 5; + string status = 6; +} diff --git a/agent/log/log_grpc.pb.go b/agent/log/log_grpc.pb.go new file mode 100644 index 00000000..41bf2d4a --- /dev/null +++ b/agent/log/log_grpc.pb.go @@ -0,0 +1,163 @@ +// Copyright (c) Ultraviolet +// SPDX-License-Identifier: Apache-2.0 + +// Code generated by protoc-gen-go-grpc. DO NOT EDIT. +// versions: +// - protoc-gen-go-grpc v1.5.1 +// - protoc v6.33.1 +// source: agent/log/log.proto + +package log + +import ( + context "context" + grpc "google.golang.org/grpc" + codes "google.golang.org/grpc/codes" + status "google.golang.org/grpc/status" + emptypb "google.golang.org/protobuf/types/known/emptypb" +) + +// This is a compile-time assertion to ensure that this generated file +// is compatible with the grpc package it is being compiled against. +// Requires gRPC-Go v1.64.0 or later. 
+const _ = grpc.SupportPackageIsVersion9 + +const ( + LogCollector_SendLog_FullMethodName = "/log.LogCollector/SendLog" + LogCollector_SendEvent_FullMethodName = "/log.LogCollector/SendEvent" +) + +// LogCollectorClient is the client API for LogCollector service. +// +// For semantics around ctx use and closing/ending streaming RPCs, please refer to https://pkg.go.dev/google.golang.org/grpc/?tab=doc#ClientConn.NewStream. +type LogCollectorClient interface { + SendLog(ctx context.Context, in *LogEntry, opts ...grpc.CallOption) (*emptypb.Empty, error) + SendEvent(ctx context.Context, in *EventEntry, opts ...grpc.CallOption) (*emptypb.Empty, error) +} + +type logCollectorClient struct { + cc grpc.ClientConnInterface +} + +func NewLogCollectorClient(cc grpc.ClientConnInterface) LogCollectorClient { + return &logCollectorClient{cc} +} + +func (c *logCollectorClient) SendLog(ctx context.Context, in *LogEntry, opts ...grpc.CallOption) (*emptypb.Empty, error) { + cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) + out := new(emptypb.Empty) + err := c.cc.Invoke(ctx, LogCollector_SendLog_FullMethodName, in, out, cOpts...) + if err != nil { + return nil, err + } + return out, nil +} + +func (c *logCollectorClient) SendEvent(ctx context.Context, in *EventEntry, opts ...grpc.CallOption) (*emptypb.Empty, error) { + cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) + out := new(emptypb.Empty) + err := c.cc.Invoke(ctx, LogCollector_SendEvent_FullMethodName, in, out, cOpts...) + if err != nil { + return nil, err + } + return out, nil +} + +// LogCollectorServer is the server API for LogCollector service. +// All implementations must embed UnimplementedLogCollectorServer +// for forward compatibility. 
+type LogCollectorServer interface { + SendLog(context.Context, *LogEntry) (*emptypb.Empty, error) + SendEvent(context.Context, *EventEntry) (*emptypb.Empty, error) + mustEmbedUnimplementedLogCollectorServer() +} + +// UnimplementedLogCollectorServer must be embedded to have +// forward compatible implementations. +// +// NOTE: this should be embedded by value instead of pointer to avoid a nil +// pointer dereference when methods are called. +type UnimplementedLogCollectorServer struct{} + +func (UnimplementedLogCollectorServer) SendLog(context.Context, *LogEntry) (*emptypb.Empty, error) { + return nil, status.Errorf(codes.Unimplemented, "method SendLog not implemented") +} +func (UnimplementedLogCollectorServer) SendEvent(context.Context, *EventEntry) (*emptypb.Empty, error) { + return nil, status.Errorf(codes.Unimplemented, "method SendEvent not implemented") +} +func (UnimplementedLogCollectorServer) mustEmbedUnimplementedLogCollectorServer() {} +func (UnimplementedLogCollectorServer) testEmbeddedByValue() {} + +// UnsafeLogCollectorServer may be embedded to opt out of forward compatibility for this service. +// Use of this interface is not recommended, as added methods to LogCollectorServer will +// result in compilation errors. +type UnsafeLogCollectorServer interface { + mustEmbedUnimplementedLogCollectorServer() +} + +func RegisterLogCollectorServer(s grpc.ServiceRegistrar, srv LogCollectorServer) { + // If the following call pancis, it indicates UnimplementedLogCollectorServer was + // embedded by pointer and is nil. This will cause panics if an + // unimplemented method is ever invoked, so we test this at initialization + // time to prevent it from happening at runtime later due to I/O. 
+ if t, ok := srv.(interface{ testEmbeddedByValue() }); ok { + t.testEmbeddedByValue() + } + s.RegisterService(&LogCollector_ServiceDesc, srv) +} + +func _LogCollector_SendLog_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(LogEntry) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(LogCollectorServer).SendLog(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: LogCollector_SendLog_FullMethodName, + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(LogCollectorServer).SendLog(ctx, req.(*LogEntry)) + } + return interceptor(ctx, in, info, handler) +} + +func _LogCollector_SendEvent_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(EventEntry) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(LogCollectorServer).SendEvent(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: LogCollector_SendEvent_FullMethodName, + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(LogCollectorServer).SendEvent(ctx, req.(*EventEntry)) + } + return interceptor(ctx, in, info, handler) +} + +// LogCollector_ServiceDesc is the grpc.ServiceDesc for LogCollector service. 
+// It's only intended for direct use with grpc.RegisterService,
+// and not to be introspected or modified (even as a copy)
+var LogCollector_ServiceDesc = grpc.ServiceDesc{
+	ServiceName: "log.LogCollector",
+	HandlerType: (*LogCollectorServer)(nil),
+	Methods: []grpc.MethodDesc{
+		{
+			MethodName: "SendLog",
+			Handler:    _LogCollector_SendLog_Handler,
+		},
+		{
+			MethodName: "SendEvent",
+			Handler:    _LogCollector_SendEvent_Handler,
+		},
+	},
+	Streams:  []grpc.StreamDesc{},
+	Metadata: "agent/log/log.proto",
+}
diff --git a/agent/log/service/service.go b/agent/log/service/service.go
new file mode 100644
index 00000000..78b72e86
--- /dev/null
+++ b/agent/log/service/service.go
@@ -0,0 +1,84 @@
+// Copyright (c) Ultraviolet
+// SPDX-License-Identifier: Apache-2.0
+package service
+
+import (
+	"context"
+	"log/slog"
+
+	"github.com/ultravioletrs/cocos/agent/cvms"
+	"github.com/ultravioletrs/cocos/agent/log"
+	"google.golang.org/protobuf/types/known/emptypb"
+)
+
+var _ log.LogCollectorServer = (*LogForwarder)(nil)
+
+// LogForwarder implements log.LogCollectorServer by converting incoming log
+// and event entries into cvms.ClientStreamMessage values and placing them on
+// logQueue.
+type LogForwarder struct {
+	log.UnimplementedLogCollectorServer
+	cvmsClient cvms.ServiceClient
+	logger     *slog.Logger
+	logQueue   chan *cvms.ClientStreamMessage
+}
+
+// New returns a LogForwarder that enqueues converted entries on queue.
+func New(logger *slog.Logger, cvmsClient cvms.ServiceClient, queue chan *cvms.ClientStreamMessage) *LogForwarder {
+	return &LogForwarder{
+		cvmsClient: cvmsClient,
+		logger:     logger,
+		logQueue:   queue,
+	}
+}
+
+// SendLog wraps req in an AgentLog stream message and enqueues it. It returns
+// ctx.Err() if ctx is done before the queue accepts the message.
+func (s *LogForwarder) SendLog(ctx context.Context, req *log.LogEntry) (*emptypb.Empty, error) {
+	msg := &cvms.ClientStreamMessage{
+		Message: &cvms.ClientStreamMessage_AgentLog{
+			AgentLog: &cvms.AgentLog{
+				Message:       req.Message,
+				ComputationId: req.ComputationId,
+				Level:         req.Level,
+				Timestamp:     req.Timestamp,
+			},
+		},
+	}
+	if err := s.enqueue(ctx, msg); err != nil {
+		return nil, err
+	}
+	return &emptypb.Empty{}, nil
+}
+
+// SendEvent wraps req in an AgentEvent stream message and enqueues it. It
+// returns ctx.Err() if ctx is done before the queue accepts the message.
+func (s *LogForwarder) SendEvent(ctx context.Context, req *log.EventEntry) (*emptypb.Empty, error) {
+	msg := &cvms.ClientStreamMessage{
+		Message: &cvms.ClientStreamMessage_AgentEvent{
+			AgentEvent: &cvms.AgentEvent{
+				EventType:     req.EventType,
+				Timestamp:     req.Timestamp,
+				ComputationId: req.ComputationId,
+				Details:       req.Details,
+				Originator:    req.Originator,
+				Status:        req.Status,
+			},
+		},
+	}
+	if err := s.enqueue(ctx, msg); err != nil {
+		return nil, err
+	}
+	return &emptypb.Empty{}, nil
+}
+
+// enqueue places msg on logQueue, giving up when ctx is done so that a full
+// queue cannot keep an RPC handler blocked after its caller has gone away.
+func (s *LogForwarder) enqueue(ctx context.Context, msg *cvms.ClientStreamMessage) error {
+	select {
+	case s.logQueue <- msg:
+		return nil
+	case <-ctx.Done():
+		return ctx.Err()
+	}
+}
diff --git a/agent/log/service/service_test.go b/agent/log/service/service_test.go
new file mode 100644
index 00000000..9a570cb3
--- /dev/null
+++ b/agent/log/service/service_test.go
@@ -0,0 +1,303 @@
+// Copyright (c) Ultraviolet
+// SPDX-License-Identifier: Apache-2.0
+package service
+
+import (
+	"context"
+	"encoding/json"
+	"log/slog"
+	"os"
+	"testing"
+	"time"
+
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
+	"github.com/ultravioletrs/cocos/agent/cvms"
+	"github.com/ultravioletrs/cocos/agent/log"
+	"google.golang.org/protobuf/types/known/timestamppb"
+)
+
+// TestNewLogForwarder tests the creation of a new log forwarder.
+func TestNewLogForwarder(t *testing.T) {
+	logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
+	queue := make(chan *cvms.ClientStreamMessage, 10)
+
+	lf := New(logger, nil, queue)
+	require.NotNil(t, lf)
+	assert.NotNil(t, lf.logger)
+	assert.Nil(t, lf.cvmsClient)
+	assert.NotNil(t, lf.logQueue)
+}
+
+// TestSendLog tests sending a log entry.
+func TestSendLog(t *testing.T) { + logger := slog.New(slog.NewTextHandler(os.Stdout, nil)) + queue := make(chan *cvms.ClientStreamMessage, 10) + + lf := New(logger, nil, queue) + + req := &log.LogEntry{ + Message: "Test log message", + ComputationId: "computation-1", + Level: "INFO", + Timestamp: timestamppb.New(time.Now()), + } + + resp, err := lf.SendLog(context.Background(), req) + require.NoError(t, err) + require.NotNil(t, resp) + + // Verify message was queued + select { + case msg := <-queue: + require.NotNil(t, msg) + agentLog := msg.GetAgentLog() + assert.NotNil(t, agentLog) + assert.Equal(t, "Test log message", agentLog.Message) + assert.Equal(t, "computation-1", agentLog.ComputationId) + assert.Equal(t, "INFO", agentLog.Level) + default: + t.Fatal("No message in queue") + } +} + +// TestSendEvent tests sending an event entry. +func TestSendEvent(t *testing.T) { + logger := slog.New(slog.NewTextHandler(os.Stdout, nil)) + queue := make(chan *cvms.ClientStreamMessage, 10) + + lf := New(logger, nil, queue) + + details, err := json.Marshal(map[string]string{"key": "value"}) + require.NoError(t, err) + + req := &log.EventEntry{ + EventType: "COMPUTATION_STARTED", + Timestamp: timestamppb.New(time.Now()), + ComputationId: "computation-1", + Details: details, + Originator: "runner", + Status: "SUCCESS", + } + + resp, err := lf.SendEvent(context.Background(), req) + require.NoError(t, err) + require.NotNil(t, resp) + + // Verify message was queued + select { + case msg := <-queue: + require.NotNil(t, msg) + agentEvent := msg.GetAgentEvent() + assert.NotNil(t, agentEvent) + assert.Equal(t, "COMPUTATION_STARTED", agentEvent.EventType) + assert.Equal(t, "computation-1", agentEvent.ComputationId) + assert.Equal(t, "runner", agentEvent.Originator) + assert.Equal(t, "SUCCESS", agentEvent.Status) + default: + t.Fatal("No message in queue") + } +} + +// TestSendMultipleLogs tests sending multiple log entries. 
+func TestSendMultipleLogs(t *testing.T) { + logger := slog.New(slog.NewTextHandler(os.Stdout, nil)) + queue := make(chan *cvms.ClientStreamMessage, 100) + + lf := New(logger, nil, queue) + + for i := 0; i < 5; i++ { + req := &log.LogEntry{ + Message: "Log message", + ComputationId: "computation-1", + Level: "INFO", + Timestamp: timestamppb.New(time.Now()), + } + + resp, err := lf.SendLog(context.Background(), req) + require.NoError(t, err) + require.NotNil(t, resp) + } + + assert.Equal(t, 5, len(queue)) +} + +// TestSendEventWithVariousTypes tests sending events with different types. +func TestSendEventWithVariousTypes(t *testing.T) { + logger := slog.New(slog.NewTextHandler(os.Stdout, nil)) + queue := make(chan *cvms.ClientStreamMessage, 100) + + lf := New(logger, nil, queue) + + eventTypes := []string{"STARTED", "RUNNING", "COMPLETED", "FAILED"} + for _, eventType := range eventTypes { + details, err := json.Marshal(map[string]string{"type": eventType}) + require.NoError(t, err) + req := &log.EventEntry{ + EventType: eventType, + Timestamp: timestamppb.New(time.Now()), + ComputationId: "computation-1", + Details: details, + Originator: "runner", + Status: "OK", + } + + resp, err := lf.SendEvent(context.Background(), req) + require.NoError(t, err) + require.NotNil(t, resp) + } + + assert.Equal(t, 4, len(queue)) +} + +// TestSendLogWithEmptyMessage tests sending log with empty message. 
+func TestSendLogWithEmptyMessage(t *testing.T) { + logger := slog.New(slog.NewTextHandler(os.Stdout, nil)) + queue := make(chan *cvms.ClientStreamMessage, 10) + + lf := New(logger, nil, queue) + + req := &log.LogEntry{ + Message: "", + ComputationId: "computation-1", + Level: "INFO", + Timestamp: timestamppb.New(time.Now()), + } + + resp, err := lf.SendLog(context.Background(), req) + require.NoError(t, err) + require.NotNil(t, resp) + + select { + case msg := <-queue: + agentLog := msg.GetAgentLog() + assert.Equal(t, "", agentLog.Message) + default: + t.Fatal("No message in queue") + } +} + +// TestSendEventWithNilDetails tests sending event with nil details. +func TestSendEventWithNilDetails(t *testing.T) { + logger := slog.New(slog.NewTextHandler(os.Stdout, nil)) + queue := make(chan *cvms.ClientStreamMessage, 10) + + lf := New(logger, nil, queue) + + req := &log.EventEntry{ + EventType: "TEST_EVENT", + Timestamp: timestamppb.New(time.Now()), + ComputationId: "computation-1", + Details: nil, + Originator: "test", + Status: "OK", + } + + resp, err := lf.SendEvent(context.Background(), req) + require.NoError(t, err) + require.NotNil(t, resp) + + select { + case msg := <-queue: + agentEvent := msg.GetAgentEvent() + assert.Nil(t, agentEvent.Details) + default: + t.Fatal("No message in queue") + } +} + +// TestSendLogWithVariousLevels tests sending logs with various severity levels. 
+func TestSendLogWithVariousLevels(t *testing.T) { + logger := slog.New(slog.NewTextHandler(os.Stdout, nil)) + queue := make(chan *cvms.ClientStreamMessage, 100) + + lf := New(logger, nil, queue) + + levels := []string{"DEBUG", "INFO", "WARN", "ERROR"} + for _, level := range levels { + req := &log.LogEntry{ + Message: "Test " + level, + ComputationId: "computation-1", + Level: level, + Timestamp: timestamppb.New(time.Now()), + } + + resp, err := lf.SendLog(context.Background(), req) + require.NoError(t, err) + require.NotNil(t, resp) + } + + assert.Equal(t, 4, len(queue)) +} + +// TestSendLogWithDifferentComputationIds tests sending logs with different computation IDs. +func TestSendLogWithDifferentComputationIds(t *testing.T) { + logger := slog.New(slog.NewTextHandler(os.Stdout, nil)) + queue := make(chan *cvms.ClientStreamMessage, 100) + + lf := New(logger, nil, queue) + + for i := 0; i < 3; i++ { + req := &log.LogEntry{ + Message: "Message", + ComputationId: "computation-" + string(rune(48+i)), + Level: "INFO", + Timestamp: timestamppb.New(time.Now()), + } + + resp, err := lf.SendLog(context.Background(), req) + require.NoError(t, err) + require.NotNil(t, resp) + } + + assert.Equal(t, 3, len(queue)) +} + +// TestQueueBehavior tests that queue is properly used. +func TestQueueBehavior(t *testing.T) { + logger := slog.New(slog.NewTextHandler(os.Stdout, nil)) + queue := make(chan *cvms.ClientStreamMessage, 1) + + lf := New(logger, nil, queue) + + req := &log.LogEntry{ + Message: "Test", + ComputationId: "computation-1", + Level: "INFO", + Timestamp: timestamppb.New(time.Now()), + } + + resp, err := lf.SendLog(context.Background(), req) + require.NoError(t, err) + require.NotNil(t, resp) + + assert.Equal(t, 1, len(queue)) +} + +// TestConcurrentSendLog tests concurrent log sending. 
+func TestConcurrentSendLog(t *testing.T) {
+	const senders = 10
+	logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
+	queue := make(chan *cvms.ClientStreamMessage, 100)
+	lf := New(logger, nil, queue)
+
+	// Workers report over a channel: require/FailNow must only be called
+	// from the goroutine running the test function.
+	errs := make(chan error, senders)
+	for i := 0; i < senders; i++ {
+		go func() {
+			_, err := lf.SendLog(context.Background(), &log.LogEntry{
+				Message:       "Concurrent log",
+				ComputationId: "computation-1",
+				Level:         "INFO",
+				Timestamp:     timestamppb.New(time.Now()),
+			})
+			errs <- err
+		}()
+	}
+
+	for i := 0; i < senders; i++ {
+		require.NoError(t, <-errs)
+	}
+	assert.Equal(t, senders, len(queue))
+}
diff --git a/agent/runner/events/adapter.go b/agent/runner/events/adapter.go
new file mode 100644
index 00000000..7b238faa
--- /dev/null
+++ b/agent/runner/events/adapter.go
@@ -0,0 +1,38 @@
+// Copyright (c) Ultraviolet
+// SPDX-License-Identifier: Apache-2.0
+package events
+
+import (
+	"context"
+	"encoding/json"
+	"log/slog"
+
+	"github.com/ultravioletrs/cocos/agent/events"
+	logpb "github.com/ultravioletrs/cocos/agent/log"
+	logclient "github.com/ultravioletrs/cocos/pkg/clients/grpc/log"
+)
+
+type adapter struct {
+	client logclient.Client
+	svc    string
+}
+
+func NewAdapter(client logclient.Client, svc string) events.Service {
+	return &adapter{
+		client: client,
+		svc:    svc,
+	}
+}
+
+func (a *adapter) SendEvent(cmpID, event, status string, details json.RawMessage) {
+	err := a.client.SendEvent(context.Background(), &logpb.EventEntry{
+		EventType:     event,
+		ComputationId: cmpID,
+		Details:       details,
+		Originator:    a.svc,
+		Status:        status,
+	})
+	if err != nil {
+		slog.Error("failed to send event to log-forwarder", "error", err)
+	}
+}
diff --git a/agent/runner/events/adapter_test.go b/agent/runner/events/adapter_test.go
new file mode 100644
index 00000000..a5610f24
--- /dev/null
+++ b/agent/runner/events/adapter_test.go
@@ -0,0 +1,138 @@
+// Copyright (c) Ultraviolet
+// SPDX-License-Identifier: Apache-2.0
+package events
+
+import (
+	"context"
+	"encoding/json"
+	"testing"
+
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/mock"
+	logpb "github.com/ultravioletrs/cocos/agent/log"
+)
+
+const testServiceName = "test-service"
+
+// mockLogClient is a mock implementation of the log client.
+type mockLogClient struct {
+	mock.Mock
+}
+
+func (m *mockLogClient) SendLog(ctx context.Context, entry *logpb.LogEntry) error {
+	args := m.Called(ctx, entry)
+	return args.Error(0)
+}
+
+func (m *mockLogClient) SendEvent(ctx context.Context, entry *logpb.EventEntry) error {
+	args := m.Called(ctx, entry)
+	return args.Error(0)
+}
+
+func (m *mockLogClient) Close() error {
+	args := m.Called()
+	return args.Error(0)
+}
+
+// TestNewAdapter tests creating a new adapter.
+func TestNewAdapter(t *testing.T) {
+	mockClient := new(mockLogClient)
+	svc := testServiceName
+
+	adapter := NewAdapter(mockClient, svc)
+
+	assert.NotNil(t, adapter)
+}
+
+// TestSendEvent tests sending an event successfully.
+func TestSendEvent(t *testing.T) {
+	mockClient := new(mockLogClient)
+	svc := testServiceName
+	adapter := NewAdapter(mockClient, svc)
+
+	cmpID := "test-computation-id"
+	event := "computation.started"
+	status := "success"
+	details := json.RawMessage(`{"key": "value"}`)
+
+	expectedEntry := &logpb.EventEntry{
+		EventType:     event,
+		ComputationId: cmpID,
+		Details:       details,
+		Originator:    svc,
+		Status:        status,
+	}
+
+	mockClient.On("SendEvent", mock.Anything, expectedEntry).Return(nil)
+
+	adapter.SendEvent(cmpID, event, status, details)
+
+	mockClient.AssertExpectations(t)
+	mockClient.AssertCalled(t, "SendEvent", mock.Anything, expectedEntry)
+}
+
+// TestSendEventWithError tests sending an event when client returns an error.
+func TestSendEventWithError(t *testing.T) { + mockClient := new(mockLogClient) + svc := testServiceName + adapter := NewAdapter(mockClient, svc) + + cmpID := "test-computation-id" + event := "computation.failed" + status := "error" + details := json.RawMessage(`{"error": "something went wrong"}`) + + mockClient.On("SendEvent", mock.Anything, mock.Anything).Return(assert.AnError) + + // This should not panic even when error occurs + adapter.SendEvent(cmpID, event, status, details) + + mockClient.AssertExpectations(t) + mockClient.AssertCalled(t, "SendEvent", mock.Anything, mock.Anything) +} + +// TestSendEventWithNilDetails tests sending an event with nil details. +func TestSendEventWithNilDetails(t *testing.T) { + mockClient := new(mockLogClient) + svc := "runner-service" + adapter := NewAdapter(mockClient, svc) + + cmpID := "comp-123" + event := "test.event" + status := "pending" + + expectedEntry := &logpb.EventEntry{ + EventType: event, + ComputationId: cmpID, + Details: nil, + Originator: svc, + Status: status, + } + + mockClient.On("SendEvent", mock.Anything, expectedEntry).Return(nil) + + adapter.SendEvent(cmpID, event, status, nil) + + mockClient.AssertExpectations(t) +} + +// TestSendEventWithEmptyStrings tests sending an event with empty strings. +func TestSendEventWithEmptyStrings(t *testing.T) { + mockClient := new(mockLogClient) + svc := testServiceName + adapter := NewAdapter(mockClient, svc) + + expectedEntry := &logpb.EventEntry{ + EventType: "", + ComputationId: "", + Details: nil, + Originator: svc, + Status: "", + } + + mockClient.On("SendEvent", mock.Anything, expectedEntry).Return(nil) + + adapter.SendEvent("", "", "", nil) + + mockClient.AssertExpectations(t) +} diff --git a/agent/runner/runner.pb.go b/agent/runner/runner.pb.go new file mode 100644 index 00000000..68c30a09 --- /dev/null +++ b/agent/runner/runner.pb.go @@ -0,0 +1,341 @@ +// Copyright (c) Ultraviolet +// SPDX-License-Identifier: Apache-2.0 + +// Code generated by protoc-gen-go. 
DO NOT EDIT. +// versions: +// protoc-gen-go v1.36.8 +// protoc v6.33.1 +// source: agent/runner/runner.proto + +package runner + +import ( + protoreflect "google.golang.org/protobuf/reflect/protoreflect" + protoimpl "google.golang.org/protobuf/runtime/protoimpl" + emptypb "google.golang.org/protobuf/types/known/emptypb" + reflect "reflect" + sync "sync" + unsafe "unsafe" +) + +const ( + // Verify that this generated code is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion) + // Verify that runtime/protoimpl is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20) +) + +type RunRequest struct { + state protoimpl.MessageState `protogen:"open.v1"` + ComputationId string `protobuf:"bytes,1,opt,name=computation_id,json=computationId,proto3" json:"computation_id,omitempty"` + AlgoType string `protobuf:"bytes,2,opt,name=algo_type,json=algoType,proto3" json:"algo_type,omitempty"` // "binary", "python", "wasm", "docker" + Algorithm []byte `protobuf:"bytes,3,opt,name=algorithm,proto3" json:"algorithm,omitempty"` // The algorithm binary/script content + Requirements []byte `protobuf:"bytes,4,opt,name=requirements,proto3" json:"requirements,omitempty"` // Python requirements.txt content + Args []string `protobuf:"bytes,5,rep,name=args,proto3" json:"args,omitempty"` + Datasets []*Dataset `protobuf:"bytes,6,rep,name=datasets,proto3" json:"datasets,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *RunRequest) Reset() { + *x = RunRequest{} + mi := &file_agent_runner_runner_proto_msgTypes[0] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *RunRequest) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*RunRequest) ProtoMessage() {} + +func (x *RunRequest) ProtoReflect() protoreflect.Message { + mi := &file_agent_runner_runner_proto_msgTypes[0] + if x != nil { + ms := 
protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use RunRequest.ProtoReflect.Descriptor instead. +func (*RunRequest) Descriptor() ([]byte, []int) { + return file_agent_runner_runner_proto_rawDescGZIP(), []int{0} +} + +func (x *RunRequest) GetComputationId() string { + if x != nil { + return x.ComputationId + } + return "" +} + +func (x *RunRequest) GetAlgoType() string { + if x != nil { + return x.AlgoType + } + return "" +} + +func (x *RunRequest) GetAlgorithm() []byte { + if x != nil { + return x.Algorithm + } + return nil +} + +func (x *RunRequest) GetRequirements() []byte { + if x != nil { + return x.Requirements + } + return nil +} + +func (x *RunRequest) GetArgs() []string { + if x != nil { + return x.Args + } + return nil +} + +func (x *RunRequest) GetDatasets() []*Dataset { + if x != nil { + return x.Datasets + } + return nil +} + +type Dataset struct { + state protoimpl.MessageState `protogen:"open.v1"` + Filename string `protobuf:"bytes,1,opt,name=filename,proto3" json:"filename,omitempty"` + Hash []byte `protobuf:"bytes,2,opt,name=hash,proto3" json:"hash,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *Dataset) Reset() { + *x = Dataset{} + mi := &file_agent_runner_runner_proto_msgTypes[1] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *Dataset) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*Dataset) ProtoMessage() {} + +func (x *Dataset) ProtoReflect() protoreflect.Message { + mi := &file_agent_runner_runner_proto_msgTypes[1] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use Dataset.ProtoReflect.Descriptor instead. 
+func (*Dataset) Descriptor() ([]byte, []int) { + return file_agent_runner_runner_proto_rawDescGZIP(), []int{1} +} + +func (x *Dataset) GetFilename() string { + if x != nil { + return x.Filename + } + return "" +} + +func (x *Dataset) GetHash() []byte { + if x != nil { + return x.Hash + } + return nil +} + +type RunResponse struct { + state protoimpl.MessageState `protogen:"open.v1"` + ComputationId string `protobuf:"bytes,1,opt,name=computation_id,json=computationId,proto3" json:"computation_id,omitempty"` + Error string `protobuf:"bytes,2,opt,name=error,proto3" json:"error,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *RunResponse) Reset() { + *x = RunResponse{} + mi := &file_agent_runner_runner_proto_msgTypes[2] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *RunResponse) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*RunResponse) ProtoMessage() {} + +func (x *RunResponse) ProtoReflect() protoreflect.Message { + mi := &file_agent_runner_runner_proto_msgTypes[2] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use RunResponse.ProtoReflect.Descriptor instead. 
+func (*RunResponse) Descriptor() ([]byte, []int) { + return file_agent_runner_runner_proto_rawDescGZIP(), []int{2} +} + +func (x *RunResponse) GetComputationId() string { + if x != nil { + return x.ComputationId + } + return "" +} + +func (x *RunResponse) GetError() string { + if x != nil { + return x.Error + } + return "" +} + +type StopRequest struct { + state protoimpl.MessageState `protogen:"open.v1"` + ComputationId string `protobuf:"bytes,1,opt,name=computation_id,json=computationId,proto3" json:"computation_id,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *StopRequest) Reset() { + *x = StopRequest{} + mi := &file_agent_runner_runner_proto_msgTypes[3] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *StopRequest) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*StopRequest) ProtoMessage() {} + +func (x *StopRequest) ProtoReflect() protoreflect.Message { + mi := &file_agent_runner_runner_proto_msgTypes[3] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use StopRequest.ProtoReflect.Descriptor instead. 
+func (*StopRequest) Descriptor() ([]byte, []int) { + return file_agent_runner_runner_proto_rawDescGZIP(), []int{3} +} + +func (x *StopRequest) GetComputationId() string { + if x != nil { + return x.ComputationId + } + return "" +} + +var File_agent_runner_runner_proto protoreflect.FileDescriptor + +const file_agent_runner_runner_proto_rawDesc = "" + + "\n" + + "\x19agent/runner/runner.proto\x12\x06runner\x1a\x1bgoogle/protobuf/empty.proto\"\xd3\x01\n" + + "\n" + + "RunRequest\x12%\n" + + "\x0ecomputation_id\x18\x01 \x01(\tR\rcomputationId\x12\x1b\n" + + "\talgo_type\x18\x02 \x01(\tR\balgoType\x12\x1c\n" + + "\talgorithm\x18\x03 \x01(\fR\talgorithm\x12\"\n" + + "\frequirements\x18\x04 \x01(\fR\frequirements\x12\x12\n" + + "\x04args\x18\x05 \x03(\tR\x04args\x12+\n" + + "\bdatasets\x18\x06 \x03(\v2\x0f.runner.DatasetR\bdatasets\"9\n" + + "\aDataset\x12\x1a\n" + + "\bfilename\x18\x01 \x01(\tR\bfilename\x12\x12\n" + + "\x04hash\x18\x02 \x01(\fR\x04hash\"J\n" + + "\vRunResponse\x12%\n" + + "\x0ecomputation_id\x18\x01 \x01(\tR\rcomputationId\x12\x14\n" + + "\x05error\x18\x02 \x01(\tR\x05error\"4\n" + + "\vStopRequest\x12%\n" + + "\x0ecomputation_id\x18\x01 \x01(\tR\rcomputationId2x\n" + + "\x11ComputationRunner\x12.\n" + + "\x03Run\x12\x12.runner.RunRequest\x1a\x13.runner.RunResponse\x123\n" + + "\x04Stop\x12\x13.runner.StopRequest\x1a\x16.google.protobuf.EmptyB\n" + + "Z\b./runnerb\x06proto3" + +var ( + file_agent_runner_runner_proto_rawDescOnce sync.Once + file_agent_runner_runner_proto_rawDescData []byte +) + +func file_agent_runner_runner_proto_rawDescGZIP() []byte { + file_agent_runner_runner_proto_rawDescOnce.Do(func() { + file_agent_runner_runner_proto_rawDescData = protoimpl.X.CompressGZIP(unsafe.Slice(unsafe.StringData(file_agent_runner_runner_proto_rawDesc), len(file_agent_runner_runner_proto_rawDesc))) + }) + return file_agent_runner_runner_proto_rawDescData +} + +var file_agent_runner_runner_proto_msgTypes = make([]protoimpl.MessageInfo, 4) +var 
file_agent_runner_runner_proto_goTypes = []any{ + (*RunRequest)(nil), // 0: runner.RunRequest + (*Dataset)(nil), // 1: runner.Dataset + (*RunResponse)(nil), // 2: runner.RunResponse + (*StopRequest)(nil), // 3: runner.StopRequest + (*emptypb.Empty)(nil), // 4: google.protobuf.Empty +} +var file_agent_runner_runner_proto_depIdxs = []int32{ + 1, // 0: runner.RunRequest.datasets:type_name -> runner.Dataset + 0, // 1: runner.ComputationRunner.Run:input_type -> runner.RunRequest + 3, // 2: runner.ComputationRunner.Stop:input_type -> runner.StopRequest + 2, // 3: runner.ComputationRunner.Run:output_type -> runner.RunResponse + 4, // 4: runner.ComputationRunner.Stop:output_type -> google.protobuf.Empty + 3, // [3:5] is the sub-list for method output_type + 1, // [1:3] is the sub-list for method input_type + 1, // [1:1] is the sub-list for extension type_name + 1, // [1:1] is the sub-list for extension extendee + 0, // [0:1] is the sub-list for field type_name +} + +func init() { file_agent_runner_runner_proto_init() } +func file_agent_runner_runner_proto_init() { + if File_agent_runner_runner_proto != nil { + return + } + type x struct{} + out := protoimpl.TypeBuilder{ + File: protoimpl.DescBuilder{ + GoPackagePath: reflect.TypeOf(x{}).PkgPath(), + RawDescriptor: unsafe.Slice(unsafe.StringData(file_agent_runner_runner_proto_rawDesc), len(file_agent_runner_runner_proto_rawDesc)), + NumEnums: 0, + NumMessages: 4, + NumExtensions: 0, + NumServices: 1, + }, + GoTypes: file_agent_runner_runner_proto_goTypes, + DependencyIndexes: file_agent_runner_runner_proto_depIdxs, + MessageInfos: file_agent_runner_runner_proto_msgTypes, + }.Build() + File_agent_runner_runner_proto = out.File + file_agent_runner_runner_proto_goTypes = nil + file_agent_runner_runner_proto_depIdxs = nil +} diff --git a/agent/runner/runner.proto b/agent/runner/runner.proto new file mode 100644 index 00000000..6f39a20f --- /dev/null +++ b/agent/runner/runner.proto @@ -0,0 +1,38 @@ +// Copyright (c) Ultraviolet 
+// SPDX-License-Identifier: Apache-2.0 + +syntax = "proto3"; + +package runner; + +option go_package = "./runner"; + +import "google/protobuf/empty.proto"; + +service ComputationRunner { + rpc Run(RunRequest) returns (RunResponse); + rpc Stop(StopRequest) returns (google.protobuf.Empty); +} + +message RunRequest { + string computation_id = 1; + string algo_type = 2; // "binary", "python", "wasm", "docker" + bytes algorithm = 3; // The algorithm binary/script content + bytes requirements = 4; // Python requirements.txt content + repeated string args = 5; + repeated Dataset datasets = 6; +} + +message Dataset { + string filename = 1; + bytes hash = 2; +} + +message RunResponse { + string computation_id = 1; + string error = 2; +} + +message StopRequest { + string computation_id = 1; +} diff --git a/agent/runner/runner_grpc.pb.go b/agent/runner/runner_grpc.pb.go new file mode 100644 index 00000000..4ec12046 --- /dev/null +++ b/agent/runner/runner_grpc.pb.go @@ -0,0 +1,163 @@ +// Copyright (c) Ultraviolet +// SPDX-License-Identifier: Apache-2.0 + +// Code generated by protoc-gen-go-grpc. DO NOT EDIT. +// versions: +// - protoc-gen-go-grpc v1.5.1 +// - protoc v6.33.1 +// source: agent/runner/runner.proto + +package runner + +import ( + context "context" + grpc "google.golang.org/grpc" + codes "google.golang.org/grpc/codes" + status "google.golang.org/grpc/status" + emptypb "google.golang.org/protobuf/types/known/emptypb" +) + +// This is a compile-time assertion to ensure that this generated file +// is compatible with the grpc package it is being compiled against. +// Requires gRPC-Go v1.64.0 or later. +const _ = grpc.SupportPackageIsVersion9 + +const ( + ComputationRunner_Run_FullMethodName = "/runner.ComputationRunner/Run" + ComputationRunner_Stop_FullMethodName = "/runner.ComputationRunner/Stop" +) + +// ComputationRunnerClient is the client API for ComputationRunner service. 
+// +// For semantics around ctx use and closing/ending streaming RPCs, please refer to https://pkg.go.dev/google.golang.org/grpc/?tab=doc#ClientConn.NewStream. +type ComputationRunnerClient interface { + Run(ctx context.Context, in *RunRequest, opts ...grpc.CallOption) (*RunResponse, error) + Stop(ctx context.Context, in *StopRequest, opts ...grpc.CallOption) (*emptypb.Empty, error) +} + +type computationRunnerClient struct { + cc grpc.ClientConnInterface +} + +func NewComputationRunnerClient(cc grpc.ClientConnInterface) ComputationRunnerClient { + return &computationRunnerClient{cc} +} + +func (c *computationRunnerClient) Run(ctx context.Context, in *RunRequest, opts ...grpc.CallOption) (*RunResponse, error) { + cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) + out := new(RunResponse) + err := c.cc.Invoke(ctx, ComputationRunner_Run_FullMethodName, in, out, cOpts...) + if err != nil { + return nil, err + } + return out, nil +} + +func (c *computationRunnerClient) Stop(ctx context.Context, in *StopRequest, opts ...grpc.CallOption) (*emptypb.Empty, error) { + cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) + out := new(emptypb.Empty) + err := c.cc.Invoke(ctx, ComputationRunner_Stop_FullMethodName, in, out, cOpts...) + if err != nil { + return nil, err + } + return out, nil +} + +// ComputationRunnerServer is the server API for ComputationRunner service. +// All implementations must embed UnimplementedComputationRunnerServer +// for forward compatibility. +type ComputationRunnerServer interface { + Run(context.Context, *RunRequest) (*RunResponse, error) + Stop(context.Context, *StopRequest) (*emptypb.Empty, error) + mustEmbedUnimplementedComputationRunnerServer() +} + +// UnimplementedComputationRunnerServer must be embedded to have +// forward compatible implementations. +// +// NOTE: this should be embedded by value instead of pointer to avoid a nil +// pointer dereference when methods are called. 
+type UnimplementedComputationRunnerServer struct{} + +func (UnimplementedComputationRunnerServer) Run(context.Context, *RunRequest) (*RunResponse, error) { + return nil, status.Errorf(codes.Unimplemented, "method Run not implemented") +} +func (UnimplementedComputationRunnerServer) Stop(context.Context, *StopRequest) (*emptypb.Empty, error) { + return nil, status.Errorf(codes.Unimplemented, "method Stop not implemented") +} +func (UnimplementedComputationRunnerServer) mustEmbedUnimplementedComputationRunnerServer() {} +func (UnimplementedComputationRunnerServer) testEmbeddedByValue() {} + +// UnsafeComputationRunnerServer may be embedded to opt out of forward compatibility for this service. +// Use of this interface is not recommended, as added methods to ComputationRunnerServer will +// result in compilation errors. +type UnsafeComputationRunnerServer interface { + mustEmbedUnimplementedComputationRunnerServer() +} + +func RegisterComputationRunnerServer(s grpc.ServiceRegistrar, srv ComputationRunnerServer) { + // If the following call pancis, it indicates UnimplementedComputationRunnerServer was + // embedded by pointer and is nil. This will cause panics if an + // unimplemented method is ever invoked, so we test this at initialization + // time to prevent it from happening at runtime later due to I/O. 
+ if t, ok := srv.(interface{ testEmbeddedByValue() }); ok { + t.testEmbeddedByValue() + } + s.RegisterService(&ComputationRunner_ServiceDesc, srv) +} + +func _ComputationRunner_Run_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(RunRequest) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(ComputationRunnerServer).Run(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: ComputationRunner_Run_FullMethodName, + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(ComputationRunnerServer).Run(ctx, req.(*RunRequest)) + } + return interceptor(ctx, in, info, handler) +} + +func _ComputationRunner_Stop_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(StopRequest) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(ComputationRunnerServer).Stop(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: ComputationRunner_Stop_FullMethodName, + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(ComputationRunnerServer).Stop(ctx, req.(*StopRequest)) + } + return interceptor(ctx, in, info, handler) +} + +// ComputationRunner_ServiceDesc is the grpc.ServiceDesc for ComputationRunner service. 
+// It's only intended for direct use with grpc.RegisterService,
+// and not to be introspected or modified (even as a copy)
+var ComputationRunner_ServiceDesc = grpc.ServiceDesc{
+	ServiceName: "runner.ComputationRunner",
+	HandlerType: (*ComputationRunnerServer)(nil),
+	Methods: []grpc.MethodDesc{
+		{
+			MethodName: "Run",
+			Handler:    _ComputationRunner_Run_Handler,
+		},
+		{
+			MethodName: "Stop",
+			Handler:    _ComputationRunner_Stop_Handler,
+		},
+	},
+	Streams:  []grpc.StreamDesc{},
+	Metadata: "agent/runner/runner.proto",
+}
diff --git a/agent/runner/service/service.go b/agent/runner/service/service.go
new file mode 100644
index 00000000..a5c159e9
--- /dev/null
+++ b/agent/runner/service/service.go
@@ -0,0 +1,191 @@
+// Copyright (c) Ultraviolet
+// SPDX-License-Identifier: Apache-2.0
+package service
+
+import (
+	"context"
+	"fmt"
+	"log/slog"
+	"os"
+	"path/filepath"
+	"sync"
+
+	"github.com/ultravioletrs/cocos/agent/algorithm"
+	"github.com/ultravioletrs/cocos/agent/algorithm/binary"
+	"github.com/ultravioletrs/cocos/agent/algorithm/docker"
+	"github.com/ultravioletrs/cocos/agent/algorithm/python"
+	"github.com/ultravioletrs/cocos/agent/algorithm/wasm"
+	"github.com/ultravioletrs/cocos/agent/events"
+	pb "github.com/ultravioletrs/cocos/agent/runner"
+	"google.golang.org/protobuf/types/known/emptypb"
+)
+
+const (
+	// algoFilePermission is the mode given to the algorithm file written by Run.
+	algoFilePermission = 0o700
+
+	// alreadyRunningMsg is the RunResponse.Error text returned when Run is
+	// called while another Run call is still in progress.
+	alreadyRunningMsg = "computation already running"
+)
+
+var _ pb.ComputationRunnerServer = (*RunnerService)(nil)
+
+// RunnerService implements pb.ComputationRunnerServer and executes at most
+// one computation at a time.
+type RunnerService struct {
+	pb.UnimplementedComputationRunnerServer
+	logger   *slog.Logger
+	eventSvc events.Service
+
+	mu          sync.Mutex          // guards running and currentAlgo
+	running     bool                // true from the moment a Run call is admitted until it returns
+	currentAlgo algorithm.Algorithm // algorithm handed to Stop; nil until Run has built one
+}
+
+// New returns a RunnerService that logs to logger and passes eventSvc to the
+// algorithms it creates.
+func New(logger *slog.Logger, eventSvc events.Service) *RunnerService {
+	return &RunnerService{
+		logger:   logger,
+		eventSvc: eventSvc,
+	}
+}
+
+// Run writes the algorithm carried by req to the "algo" file in the working
+// directory, builds the implementation selected by req.AlgoType and runs it,
+// returning only after the call to the algorithm's Run method has returned.
+//
+// A call made while another Run is in progress, and a failure of the algorithm
+// itself, are both reported through RunResponse.Error with a nil error. Setup
+// failures (file I/O, unsupported algorithm type) are returned as errors.
+func (s *RunnerService) Run(ctx context.Context, req *pb.RunRequest) (*pb.RunResponse, error) {
+	// Check and reserve the slot in one critical section. Checking currentAlgo
+	// alone is not enough: it is only set after the algorithm file has been
+	// written, so two concurrent calls could both pass the check and clobber
+	// each other's file.
+	s.mu.Lock()
+	if s.running || s.currentAlgo != nil {
+		s.mu.Unlock()
+		return &pb.RunResponse{
+			ComputationId: req.GetComputationId(),
+			Error:         alreadyRunningMsg,
+		}, nil
+	}
+	s.running = true
+	s.mu.Unlock()
+
+	defer func() {
+		s.mu.Lock()
+		s.running = false
+		s.currentAlgo = nil
+		s.mu.Unlock()
+	}()
+
+	currentDir, err := os.Getwd()
+	if err != nil {
+		return nil, fmt.Errorf("error getting current directory: %w", err)
+	}
+
+	algoPath := filepath.Join(currentDir, "algo")
+	if err := writeAlgoFile(algoPath, req.GetAlgorithm()); err != nil {
+		return nil, err
+	}
+
+	algo, err := s.newAlgorithm(req, algoPath)
+	if err != nil {
+		return nil, err
+	}
+
+	// Publish the algorithm so that Stop can reach it while Run is blocked below.
+	s.mu.Lock()
+	s.currentAlgo = algo
+	s.mu.Unlock()
+
+	if err := algo.Run(); err != nil {
+		s.logger.Error("computation failed", "error", err)
+		return &pb.RunResponse{
+			ComputationId: req.GetComputationId(),
+			Error:         err.Error(),
+		}, nil
+	}
+
+	return &pb.RunResponse{
+		ComputationId: req.GetComputationId(),
+	}, nil
+}
+
+// Stop stops the algorithm currently being run, if there is one. The
+// computation ID carried by req is not compared with the running computation.
+func (s *RunnerService) Stop(ctx context.Context, req *pb.StopRequest) (*emptypb.Empty, error) {
+	s.mu.Lock()
+	defer s.mu.Unlock()
+
+	if s.currentAlgo == nil {
+		return &emptypb.Empty{}, nil
+	}
+	if err := s.currentAlgo.Stop(); err != nil {
+		return nil, err
+	}
+	return &emptypb.Empty{}, nil
+}
+
+// newAlgorithm builds the algorithm implementation selected by req.AlgoType
+// for the algorithm file stored at algoPath.
+func (s *RunnerService) newAlgorithm(req *pb.RunRequest, algoPath string) (algorithm.Algorithm, error) {
+	id, args := req.GetComputationId(), req.GetArgs()
+
+	switch req.GetAlgoType() {
+	case string(algorithm.AlgoTypeBin):
+		return binary.NewAlgorithm(s.logger, s.eventSvc, algoPath, args, id), nil
+	case string(algorithm.AlgoTypePython):
+		var requirementsFile string
+		if reqs := req.GetRequirements(); len(reqs) > 0 {
+			path, err := writeRequirementsFile(reqs)
+			if err != nil {
+				return nil, err
+			}
+			requirementsFile = path
+		}
+		// RunRequest has no runtime field, so the default Python runtime is used.
+		return python.NewAlgorithm(s.logger, s.eventSvc, python.PyRuntime, requirementsFile, algoPath, args, id), nil
+	case string(algorithm.AlgoTypeWasm):
+		return wasm.NewAlgorithm(s.logger, s.eventSvc, args, algoPath, id), nil
+	case string(algorithm.AlgoTypeDocker):
+		return docker.NewAlgorithm(s.logger, s.eventSvc, algoPath, id), nil
+	default:
+		return nil, fmt.Errorf("unsupported algorithm type: %s", req.GetAlgoType())
+	}
+}
+
+// writeAlgoFile stores content at path and sets the file mode to
+// algoFilePermission.
+func writeAlgoFile(path string, content []byte) error {
+	if err := os.WriteFile(path, content, algoFilePermission); err != nil {
+		return fmt.Errorf("error writing algorithm to file: %w", err)
+	}
+	// WriteFile applies the mode only when it creates the file (and subject to
+	// the umask); a pre-existing file keeps its old mode, so set it explicitly.
+	if err := os.Chmod(path, algoFilePermission); err != nil {
+		return fmt.Errorf("error changing file permissions: %w", err)
+	}
+	return nil
+}
+
+// writeRequirementsFile stores content in a new temporary file and returns
+// the path of that file.
+func writeRequirementsFile(content []byte) (string, error) {
+	f, err := os.CreateTemp("", "requirements.txt")
+	if err != nil {
+		return "", fmt.Errorf("error creating requirements file: %w", err)
+	}
+	if _, err := f.Write(content); err != nil {
+		// Close so the descriptor is not leaked; the write error is the one reported.
+		_ = f.Close()
+		return "", fmt.Errorf("error writing requirements to file: %w", err)
+	}
+	if err := f.Close(); err != nil {
+		return "", fmt.Errorf("error closing file: %w", err)
+	}
+	return f.Name(), nil
+}
diff --git a/agent/runner/service/service_test.go b/agent/runner/service/service_test.go
new file mode 100644
index 00000000..9e1466ba
--- /dev/null
+++ b/agent/runner/service/service_test.go
@@ -0,0 +1,271 @@
+// Copyright (c) Ultraviolet
+// SPDX-License-Identifier: Apache-2.0
+package service
+
+import (
+	"context"
+	"encoding/json"
+	"log/slog"
+	"os"
+	"testing"
+	"time"
+
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
+	pb "github.com/ultravioletrs/cocos/agent/runner"
+)
+
+// MockEventService is a mock implementation of events.Service.
+type MockEventService struct {
+	events []any
+}
+
+func (m *MockEventService) SendEvent(cmpID, event, status string, details json.RawMessage) {
+	recorded := map[string]any{
+		"cmpID":   cmpID,
+		"event":   event,
+		"status":  status,
+		"details": details,
+	}
+	m.events = append(m.events, recorded)
+}
+
+// TestNewRunnerService checks that New stores its dependencies and starts idle.
+func TestNewRunnerService(t *testing.T) {
+	rs := New(slog.New(slog.NewTextHandler(os.Stdout, nil)), &MockEventService{})
+	require.NotNil(t, rs)
+
+	assert.NotNil(t, rs.logger)
+	assert.NotNil(t, rs.eventSvc)
+	// A fresh service holds no algorithm.
+	assert.Nil(t, rs.currentAlgo)
+}
+
+// TestRunWithBinaryAlgorithm runs a shell script submitted as a "bin" algorithm.
+func TestRunWithBinaryAlgorithm(t *testing.T) {
+	const id = "test-1"
+	logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
+	rs := New(logger, &MockEventService{})
+
+	resp, err := rs.Run(context.Background(), &pb.RunRequest{
+		ComputationId: id,
+		AlgoType:      "bin",
+		Algorithm:     []byte("#!/bin/bash\necho 'test'"),
+		Args:          []string{"arg1", "arg2"},
+	})
+
+	// Only the returned error and the echoed ID are asserted; resp.Error is not inspected.
+	require.NoError(t, err)
+	require.NotNil(t, resp)
+	assert.Equal(t, id, resp.ComputationId)
+}
+
+// TestRunWithPythonAlgorithm runs a Python script that comes with a requirements list.
+func TestRunWithPythonAlgorithm(t *testing.T) {
+	const id = "test-python"
+	logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
+	rs := New(logger, &MockEventService{})
+
+	resp, err := rs.Run(context.Background(), &pb.RunRequest{
+		ComputationId: id,
+		AlgoType:      "python",
+		Algorithm:     []byte("print('hello')"),
+		Args:          []string{},
+		Requirements:  []byte("numpy==1.21.0"),
+	})
+
+	// Only the returned error and the echoed ID are asserted; resp.Error is not inspected.
+	require.NoError(t, err)
+	require.NotNil(t, resp)
+	assert.Equal(t, id, resp.ComputationId)
+}
+
+// TestRunWithPythonAlgorithmNoRequirements tests running Python without requirements.
+func TestRunWithPythonAlgorithmNoRequirements(t *testing.T) {
+	const id = "test-python-noreq"
+	logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
+	rs := New(logger, &MockEventService{})
+
+	resp, err := rs.Run(context.Background(), &pb.RunRequest{
+		ComputationId: id,
+		AlgoType:      "python",
+		Algorithm:     []byte("print('hello')"),
+		Args:          []string{},
+	})
+
+	// Only the returned error and the echoed ID are asserted; resp.Error is not inspected.
+	require.NoError(t, err)
+	require.NotNil(t, resp)
+	assert.Equal(t, id, resp.ComputationId)
+}
+
+// TestRunWithWasmAlgorithm submits a four-byte "\0asm" payload as a "wasm" algorithm.
+func TestRunWithWasmAlgorithm(t *testing.T) {
+	const id = "test-wasm"
+	logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
+	rs := New(logger, &MockEventService{})
+
+	resp, err := rs.Run(context.Background(), &pb.RunRequest{
+		ComputationId: id,
+		AlgoType:      "wasm",
+		Algorithm:     []byte{0x00, 0x61, 0x73, 0x6d},
+		Args:          []string{},
+	})
+
+	// Only the returned error and the echoed ID are asserted; resp.Error is not inspected.
+	require.NoError(t, err)
+	require.NotNil(t, resp)
+	assert.Equal(t, id, resp.ComputationId)
+}
+
+// TestRunWithDockerAlgorithm submits Dockerfile-style text as a "docker" algorithm.
+func TestRunWithDockerAlgorithm(t *testing.T) {
+	const id = "test-docker"
+	logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
+	rs := New(logger, &MockEventService{})
+
+	resp, err := rs.Run(context.Background(), &pb.RunRequest{
+		ComputationId: id,
+		AlgoType:      "docker",
+		Algorithm:     []byte("FROM ubuntu:latest\nRUN echo 'test'"),
+		Args:          []string{},
+	})
+
+	// Only the returned error and the echoed ID are asserted; resp.Error is not inspected.
+	require.NoError(t, err)
+	require.NotNil(t, resp)
+	assert.Equal(t, id, resp.ComputationId)
+}
+
+// TestRunWithUnsupportedAlgorithmType tests running with unsupported algorithm type.
+func TestRunWithUnsupportedAlgorithmType(t *testing.T) {
+	logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
+	eventSvc := &MockEventService{}
+	rs := New(logger, eventSvc)
+
+	req := &pb.RunRequest{
+		ComputationId: "test-unsupported",
+		AlgoType:      "unsupported",
+		Algorithm:     []byte("test"),
+		Args:          []string{},
+	}
+
+	resp, err := rs.Run(context.Background(), req)
+	require.Error(t, err)
+	require.Nil(t, resp)
+}
+
+// TestRunAlreadyRunning tests running computation when one is already running.
+func TestRunAlreadyRunning(t *testing.T) {
+	logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
+	eventSvc := &MockEventService{}
+	rs := New(logger, eventSvc)
+
+	// Use a long-running bash script
+	req := &pb.RunRequest{
+		ComputationId: "test-running",
+		AlgoType:      "bin",
+		Algorithm:     []byte("#!/bin/bash\nsleep 30"),
+		Args:          []string{},
+	}
+
+	// Start first computation (will run for 30 seconds)
+	go func() {
+		_, _ = rs.Run(context.Background(), req)
+	}()
+
+	// Give it time to start
+	time.Sleep(500 * time.Millisecond)
+
+	// Try to run another immediately - should fail
+	resp, err := rs.Run(context.Background(), req)
+	require.NoError(t, err)
+	require.NotNil(t, resp)
+	assert.Equal(t, "computation already running", resp.Error)
+}
+
+// TestStopWhenRunning tests stopping a computation while it is still running.
+func TestStopWhenRunning(t *testing.T) {
+	logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
+	eventSvc := &MockEventService{}
+	rs := New(logger, eventSvc)
+
+	req := &pb.RunRequest{
+		ComputationId: "test-stop",
+		AlgoType:      "bin",
+		Algorithm:     []byte("#!/bin/bash\nsleep 10"),
+		Args:          []string{},
+	}
+
+	// Run only returns once the algorithm's Run call has returned, so start it
+	// in the background and give it time to start before calling Stop.
+	go func() {
+		_, _ = rs.Run(context.Background(), req)
+	}()
+	time.Sleep(500 * time.Millisecond)
+
+	stopResp, err := rs.Stop(context.Background(), &pb.StopRequest{ComputationId: "test-stop"})
+	require.NoError(t, err)
+	require.NotNil(t, stopResp)
+}
+
+// TestStopWhenNotRunning tests stopping when no computation is running.
+func TestStopWhenNotRunning(t *testing.T) {
+	logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
+	eventSvc := &MockEventService{}
+	rs := New(logger, eventSvc)
+
+	stopReq := &pb.StopRequest{
+		ComputationId: "test-not-running",
+	}
+
+	stopResp, err := rs.Stop(context.Background(), stopReq)
+	require.NoError(t, err)
+	require.NotNil(t, stopResp)
+}
+
+// TestConcurrentRun tests that a Run call is rejected while another one is in progress.
+func TestConcurrentRun(t *testing.T) {
+	logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
+	eventSvc := &MockEventService{}
+	rs := New(logger, eventSvc)
+
+	req := &pb.RunRequest{
+		ComputationId: "test-concurrent",
+		AlgoType:      "bin",
+		Algorithm:     []byte("#!/bin/bash\nsleep 15"),
+		Args:          []string{},
+	}
+
+	// Start first run in goroutine (will run for 15 seconds)
+	go func() {
+		_, _ = rs.Run(context.Background(), req)
+	}()
+	// Give it time to actually start
+	time.Sleep(500 * time.Millisecond)
+
+	// Concurrent attempt should fail; check resp2 before dereferencing it.
+	resp2, err := rs.Run(context.Background(), req)
+	require.NoError(t, err)
+	require.NotNil(t, resp2)
+	assert.Equal(t, "computation already running", resp2.Error)
+}
+
+// TestRunWithMultipleArgs tests running with multiple arguments.
+func TestRunWithMultipleArgs(t *testing.T) {
+	const id = "test-multi-args"
+	logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
+	rs := New(logger, &MockEventService{})
+
+	resp, err := rs.Run(context.Background(), &pb.RunRequest{
+		ComputationId: id,
+		AlgoType:      "bin",
+		Algorithm:     []byte("#!/bin/bash\necho $@"),
+		Args:          []string{"arg1", "arg2", "arg3", "arg4"},
+	})
+
+	// Only the returned error and the echoed ID are asserted; the script output is not inspected.
+	require.NoError(t, err)
+	require.NotNil(t, resp)
+	assert.Equal(t, id, resp.ComputationId)
+}
diff --git a/agent/service.go b/agent/service.go index 62748983..84db1f70 100644 --- a/agent/service.go +++ b/agent/service.go @@ -16,17 +16,15 @@ import ( "github.com/absmach/supermq/pkg/errors" "github.com/ultravioletrs/cocos/agent/algorithm" - "github.com/ultravioletrs/cocos/agent/algorithm/binary" - "github.com/ultravioletrs/cocos/agent/algorithm/docker" - "github.com/ultravioletrs/cocos/agent/algorithm/python" - "github.com/ultravioletrs/cocos/agent/algorithm/wasm" "github.com/ultravioletrs/cocos/agent/events" + runnerpb "github.com/ultravioletrs/cocos/agent/runner" "github.com/ultravioletrs/cocos/agent/statemachine" "github.com/ultravioletrs/cocos/internal" "github.com/ultravioletrs/cocos/pkg/attestation" "github.com/ultravioletrs/cocos/pkg/attestation/quoteprovider" "github.com/ultravioletrs/cocos/pkg/attestation/vtpm" attestation_client "github.com/ultravioletrs/cocos/pkg/clients/grpc/attestation" + runner_client "github.com/ultravioletrs/cocos/pkg/clients/grpc/runner" "golang.org/x/crypto/sha3" ) @@ -130,8 +128,12 @@ type Service interface { type agentService struct { mu sync.Mutex - computation Computation // Holds the current computation request details. - algorithm algorithm.Algorithm // Filepath to the algorithm received for the computation. + computation Computation // Holds the current computation request details.
+ runnerClient runner_client.Client + algoType string + algoArgs []string + algoRequirements []byte + algoReceived bool result []byte // Stores the result of the computation. sm statemachine.StateMachine // Manages the state transitions of the agent service. runError error // Stores any error encountered during the computation run. @@ -146,13 +148,14 @@ type agentService struct { var _ Service = (*agentService)(nil) // New instantiates the agent service implementation. -func New(ctx context.Context, logger *slog.Logger, eventSvc events.Service, attestationClient attestation_client.Client, vmlp int) Service { +func New(ctx context.Context, logger *slog.Logger, eventSvc events.Service, attestationClient attestation_client.Client, runnerClient runner_client.Client, vmlp int) Service { sm := statemachine.NewStateMachine(Idle) ctx, cancel := context.WithCancel(ctx) svc := &agentService{ sm: sm, eventSvc: eventSvc, attestationClient: attestationClient, + runnerClient: runnerClient, logger: logger, cancel: cancel, vmpl: vmlp, @@ -233,10 +236,9 @@ func (as *agentService) StopComputation(ctx context.Context) error { as.cancel() - if as.algorithm != nil { - if err := as.algorithm.Stop(); err != nil { - return fmt.Errorf("error stopping computation: %v", err) - } + if _, err := as.runnerClient.Stop(ctx, &runnerpb.StopRequest{ComputationId: as.computation.ID}); err != nil { + as.logger.Warn("failed to stop runner", "error", err) + // proceed to cleanup } if err := os.RemoveAll(algorithm.DatasetsDir); err != nil { @@ -250,7 +252,10 @@ func (as *agentService) StopComputation(ctx context.Context) error { as.sm.Reset(Idle) as.computation = Computation{} - as.algorithm = nil + as.algoReceived = false + as.algoType = "" + as.algoArgs = nil + as.algoRequirements = nil as.result = nil as.runError = nil as.resultsConsumed = false @@ -278,7 +283,7 @@ func (as *agentService) Algo(ctx context.Context, algo Algorithm) error { } as.mu.Lock() defer as.mu.Unlock() - if as.algorithm != nil { + 
if as.algoReceived { return ErrAllManifestItemsReceived } @@ -317,38 +322,16 @@ func (as *agentService) Algo(ctx context.Context, algo Algorithm) error { args := algorithm.AlgorithmArgsFromContext(ctx) - switch algoType { - case string(algorithm.AlgoTypeBin): - as.algorithm = binary.NewAlgorithm(as.logger, as.eventSvc, f.Name(), args, as.computation.ID) - case string(algorithm.AlgoTypePython): - var requirementsFile string - if len(algo.Requirements) > 0 { - fr, err := os.CreateTemp("", "requirements.txt") - if err != nil { - return fmt.Errorf("error creating requirments file: %v", err) - } - - if _, err := fr.Write(algo.Requirements); err != nil { - return fmt.Errorf("error writing requirements to file: %v", err) - } - if err := fr.Close(); err != nil { - return fmt.Errorf("error closing file: %v", err) - } - requirementsFile = fr.Name() - } - runtime := python.PythonRunTimeFromContext(ctx) - as.algorithm = python.NewAlgorithm(as.logger, as.eventSvc, runtime, requirementsFile, f.Name(), args, as.computation.ID) - case string(algorithm.AlgoTypeWasm): - as.algorithm = wasm.NewAlgorithm(as.logger, as.eventSvc, args, f.Name(), as.computation.ID) - case string(algorithm.AlgoTypeDocker): - as.algorithm = docker.NewAlgorithm(as.logger, as.eventSvc, f.Name(), as.computation.ID) - } + as.algoType = algoType + as.algoArgs = args + as.algoRequirements = algo.Requirements + as.algoReceived = true if err := os.Mkdir(algorithm.DatasetsDir, 0o755); err != nil { return fmt.Errorf("error creating datasets directory: %v", err) } - if as.algorithm != nil { + if as.algoReceived { as.sm.SendEvent(AlgorithmReceived) } @@ -478,14 +461,42 @@ func (as *agentService) runComputation(state statemachine.State) { } }() + // Read algo file + currentDir, _ := os.Getwd() + algoFile := filepath.Join(currentDir, "algo") + algoBytes, err := os.ReadFile(algoFile) + if err != nil { + as.runError = fmt.Errorf("failed to read algo file: %w", err) + as.logger.Warn(as.runError.Error()) + 
as.publishEvent(Failed.String())(state) + return + } + as.publishEvent(InProgress.String())(state) - if err := as.algorithm.Run(); err != nil { + + // Call Runner + resp, err := as.runnerClient.Run(context.Background(), &runnerpb.RunRequest{ + ComputationId: as.computation.ID, + AlgoType: as.algoType, + Algorithm: algoBytes, + Requirements: as.algoRequirements, + Args: as.algoArgs, + // Datasets implicit on shared FS + }) + if err != nil { as.runError = err as.logger.Warn(fmt.Sprintf("failed to run computation: %s", err.Error())) as.publishEvent(Failed.String())(state) return } + if resp.Error != "" { + as.runError = errors.New(resp.Error) + as.logger.Warn(fmt.Sprintf("failed to run computation: %s", resp.Error)) + as.publishEvent(Failed.String())(state) + return + } + results, err := internal.ZipDirectoryToMemory(algorithm.ResultsDir) if err != nil { as.runError = err diff --git a/agent/service_test.go b/agent/service_test.go index 749f2e4b..62f7c9ed 100644 --- a/agent/service_test.go +++ b/agent/service_test.go @@ -18,16 +18,18 @@ import ( "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" "github.com/ultravioletrs/cocos/agent/algorithm" - algomocks "github.com/ultravioletrs/cocos/agent/algorithm/mocks" "github.com/ultravioletrs/cocos/agent/algorithm/python" "github.com/ultravioletrs/cocos/agent/events/mocks" + runnerpb "github.com/ultravioletrs/cocos/agent/runner" "github.com/ultravioletrs/cocos/agent/statemachine" smmocks "github.com/ultravioletrs/cocos/agent/statemachine/mocks" "github.com/ultravioletrs/cocos/pkg/attestation" "github.com/ultravioletrs/cocos/pkg/attestation/quoteprovider" "github.com/ultravioletrs/cocos/pkg/attestation/vtpm" + runnermocks "github.com/ultravioletrs/cocos/pkg/clients/grpc/runner/mocks" "golang.org/x/crypto/sha3" "google.golang.org/grpc/metadata" + "google.golang.org/protobuf/types/known/emptypb" ) var ( @@ -123,7 +125,9 @@ func TestAlgo(t *testing.T) { ctx, cancel := context.WithCancel(ctx) defer cancel() 
client := new(MockAttestationClient) - svc := New(ctx, mglog.NewMock(), events, client, 0) + runnerCli := new(runnermocks.Client) + runnerCli.On("Run", mock.Anything, mock.Anything).Return(&runnerpb.RunResponse{}, nil) + svc := New(ctx, mglog.NewMock(), events, client, runnerCli, 0) err := svc.InitComputation(ctx, testComputation(t)) require.NoError(t, err) @@ -217,7 +221,9 @@ func TestData(t *testing.T) { defer cancel() client := new(MockAttestationClient) - svc := New(ctx, mglog.NewMock(), events, client, 0) + runnerCli := new(runnermocks.Client) + runnerCli.On("Run", mock.Anything, mock.Anything).Return(&runnerpb.RunResponse{}, nil) + svc := New(ctx, mglog.NewMock(), events, client, runnerCli, 0) err := svc.InitComputation(ctx, testComputation(t)) require.NoError(t, err) @@ -294,6 +300,7 @@ func TestResult(t *testing.T) { } client := new(MockAttestationClient) + runnerCli := new(runnermocks.Client) sm := new(smmocks.StateMachine) sm.On("Start", ctx).Return(nil) @@ -304,6 +311,7 @@ func TestResult(t *testing.T) { sm: sm, eventSvc: events, attestationClient: client, + runnerClient: runnerCli, computation: testComputation(t), } @@ -400,7 +408,8 @@ func TestAttestation(t *testing.T) { } defer getQuote.Unset() - svc := New(ctx, mglog.NewMock(), events, client, 0) + runnerCli := new(runnermocks.Client) + svc := New(ctx, mglog.NewMock(), events, client, runnerCli, 0) time.Sleep(300 * time.Millisecond) _, err := svc.Attestation(ctx, tc.reportData, tc.nonce, tc.platform) assert.True(t, errors.Contains(err, tc.err), "expected %v, got %v", tc.err, err) @@ -420,18 +429,7 @@ func TestAzureAttestationToken(t *testing.T) { name: "Azure token fetch successful", nonce: [32]byte{1, 2, 3}, // any test nonce token: []byte("mockToken"), - err: nil, // fixed expectation as err was ErrAttestationType in original but logic suggests success if token returns? Wait, orig test had ErrAttestationType? Ah, maybe provider mock returns error. 
- // Re-reading original test: - // err: ErrAttestationType - // provider.On(...).Return(tc.token, tc.err) - // svc.AzureAttestationToken... - // In original code, AzureAttestationToken checked `attestation.CCPlatform() != attestation.Azure`. - // Since test runs on non-azure, it returns ErrAttestationType. - // My new client calls GetAzureToken. The logic for checking platform moved to attestation-service. - // So `agent` just calls the client. - // So here we should expect whatever the client returns. - // Mock client returns tc.err. - // If I want to test success, I should set err: nil. + err: nil, }, { name: "Azure token fetch failed", @@ -450,7 +448,8 @@ func TestAzureAttestationToken(t *testing.T) { ctx := context.Background() - svc := New(ctx, mglog.NewMock(), events, client, 0) + runnerCli := new(runnermocks.Client) + svc := New(ctx, mglog.NewMock(), events, client, runnerCli, 0) _, err := svc.AzureAttestationToken(ctx, tc.nonce) assert.True(t, errors.Contains(err, tc.err), "expected error %v, got %v", tc.err, err) @@ -489,9 +488,6 @@ func testComputation(t *testing.T) Computation { } func TestStopComputation(t *testing.T) { - testDataDir := "test_datasets" - testResultsDir := "test_results" - cases := []struct { name string setupDirs bool @@ -511,22 +507,9 @@ func TestStopComputation(t *testing.T) { setupDirs: true, setupAlgo: true, algoStopErr: fmt.Errorf("algorithm stop failed"), - expectedErr: fmt.Errorf("error stopping computation: algorithm stop failed"), - }, - { - name: "Stop computation without algorithm", - setupDirs: true, - setupAlgo: false, - algoStopErr: nil, - expectedErr: nil, - }, - { - name: "Stop computation with missing directories", - setupDirs: false, - setupAlgo: false, - algoStopErr: nil, - expectedErr: nil, // os.RemoveAll doesn't error on non-existing directories + expectedErr: nil, // Warn only }, + // We log warnings but don't return error in StopComputation in new implementation for Stop failure. 
} for _, tc := range cases { @@ -539,7 +522,16 @@ func TestStopComputation(t *testing.T) { defer cancel() client := new(MockAttestationClient) - svc := New(ctx, mglog.NewMock(), events, client, 0).(*agentService) + runnerCli := new(runnermocks.Client) + + // Mock Stop call + var stopErr error + if tc.algoStopErr != nil { + stopErr = tc.algoStopErr + } + runnerCli.On("Stop", mock.Anything, mock.Anything).Return(&emptypb.Empty{}, stopErr) + + svc := New(ctx, mglog.NewMock(), events, client, runnerCli, 0).(*agentService) svc.computation = Computation{ ID: "test-computation", @@ -547,17 +539,17 @@ func TestStopComputation(t *testing.T) { } if tc.setupDirs { - err := os.MkdirAll(testDataDir, 0o755) + err := os.MkdirAll(algorithm.DatasetsDir, 0o755) require.NoError(t, err) - err = os.MkdirAll(testResultsDir, 0o755) + err = os.MkdirAll(algorithm.ResultsDir, 0o755) require.NoError(t, err) } - if tc.setupAlgo { - mockAlgo := new(algomocks.Algorithm) - mockAlgo.On("Stop").Return(tc.algoStopErr) - svc.algorithm = mockAlgo - } + // The test uses the real directories rather than test-local ones: + // algorithm.DatasetsDir and algorithm.ResultsDir from + // "github.com/ultravioletrs/cocos/agent/algorithm". + // NOTE(review): these appear to be the relative paths "datasets" and "results" - confirm. + // If so, the test creates them in the current working directory. 
err := svc.StopComputation(ctx) @@ -575,8 +567,8 @@ func TestStopComputation(t *testing.T) { events.AssertExpectations(t) - _ = os.RemoveAll(testDataDir) - _ = os.RemoveAll(testResultsDir) + _ = os.RemoveAll(algorithm.DatasetsDir) + _ = os.RemoveAll(algorithm.ResultsDir) }) } } @@ -608,7 +600,11 @@ func TestStopComputationIntegration(t *testing.T) { defer cancel() client := new(MockAttestationClient) - svc := New(ctx, mglog.NewMock(), events, client, 0) + runnerCli := new(runnermocks.Client) + runnerCli.On("Run", mock.Anything, mock.Anything).Return(&runnerpb.RunResponse{}, nil) + runnerCli.On("Stop", mock.Anything, mock.Anything).Return(&emptypb.Empty{}, nil) + + svc := New(ctx, mglog.NewMock(), events, client, runnerCli, 0) computation := Computation{ ID: "integration-test", @@ -647,7 +643,10 @@ func TestStopComputationConcurrent(t *testing.T) { defer cancel() client := new(MockAttestationClient) - svc := New(ctx, mglog.NewMock(), events, client, 0) + runnerCli := new(runnermocks.Client) + runnerCli.On("Stop", mock.Anything, mock.Anything).Return(&emptypb.Empty{}, nil) + + svc := New(ctx, mglog.NewMock(), events, client, runnerCli, 0) svc.(*agentService).computation = Computation{ ID: "concurrent-test", diff --git a/cmd/agent/main.go b/cmd/agent/main.go index 3c15f1dc..2ac994b9 100644 --- a/cmd/agent/main.go +++ b/cmd/agent/main.go @@ -11,6 +11,7 @@ import ( "fmt" "log" "log/slog" + "net/url" "os" "os/signal" "syscall" @@ -25,6 +26,7 @@ import ( cvmsapi "github.com/ultravioletrs/cocos/agent/cvms/api/grpc" "github.com/ultravioletrs/cocos/agent/cvms/server" "github.com/ultravioletrs/cocos/agent/events" + logpb "github.com/ultravioletrs/cocos/agent/log" agentlogger "github.com/ultravioletrs/cocos/internal/logger" "github.com/ultravioletrs/cocos/pkg/atls" "github.com/ultravioletrs/cocos/pkg/attestation" @@ -33,6 +35,9 @@ import ( pkggrpc "github.com/ultravioletrs/cocos/pkg/clients/grpc" attestation_client "github.com/ultravioletrs/cocos/pkg/clients/grpc/attestation" 
cvmsgrpc "github.com/ultravioletrs/cocos/pkg/clients/grpc/cvm" + logclient "github.com/ultravioletrs/cocos/pkg/clients/grpc/log" + runnerclient "github.com/ultravioletrs/cocos/pkg/clients/grpc/runner" + "github.com/ultravioletrs/cocos/pkg/ingress" "golang.org/x/sync/errgroup" ) @@ -54,6 +59,7 @@ type config struct { AgentOSDistro string `env:"AGENT_OS_DISTRO" envDefault:"UVC"` AgentOSType string `env:"AGENT_OS_TYPE" envDefault:"UVC"` AttestationServiceSocket string `env:"ATTESTATION_SERVICE_SOCKET" envDefault:"/run/cocos/attestation.sock"` + EnableATLS bool `env:"AGENT_ENABLE_ATLS" envDefault:"true"` } func main() { @@ -75,18 +81,63 @@ func main() { return } - eventsLogsQueue := make(chan *cvms.ClientStreamMessage, 1000) + logQueue := make(chan *cvms.ClientStreamMessage, 1000) + cvmsQueue := make(chan *cvms.ClientStreamMessage, 1000) - handler := agentlogger.NewProtoHandler(os.Stdout, &slog.HandlerOptions{Level: level}, eventsLogsQueue) + handler := agentlogger.NewProtoHandler(os.Stdout, &slog.HandlerOptions{Level: level}, logQueue) logger := slog.New(handler) - eventSvc, err := events.New(svcName, eventsLogsQueue) + eventSvc, err := events.New(svcName, logQueue) if err != nil { logger.Error(fmt.Sprintf("failed to create events service %s", err.Error())) exitCode = 1 return } + logClient, err := logclient.NewClient("/run/cocos/log.sock") + if err != nil { + logger.Warn(fmt.Sprintf("failed to create log client: %s. 
Logging will be local only until service is available.", err)) + } else { + defer logClient.Close() + } + + g.Go(func() error { + for { + select { + case <-ctx.Done(): + return nil + case msg := <-logQueue: + if logClient == nil { + continue + } + switch m := msg.Message.(type) { + case *cvms.ClientStreamMessage_AgentLog: + err := logClient.SendLog(ctx, &logpb.LogEntry{ + Message: m.AgentLog.Message, + ComputationId: m.AgentLog.ComputationId, + Level: m.AgentLog.Level, + Timestamp: m.AgentLog.Timestamp, + }) + if err != nil { + logger.Error("failed to send log", "error", err) + } + case *cvms.ClientStreamMessage_AgentEvent: + err := logClient.SendEvent(ctx, &logpb.EventEntry{ + EventType: m.AgentEvent.EventType, + Timestamp: m.AgentEvent.Timestamp, + ComputationId: m.AgentEvent.ComputationId, + Details: m.AgentEvent.Details, + Originator: m.AgentEvent.Originator, + Status: m.AgentEvent.Status, + }) + if err != nil { + logger.Error("failed to send event", "error", err) + } + } + } + } + }) + var provider attestation.Provider ccPlatform := attestation.CCPlatform() @@ -128,13 +179,6 @@ func main() { return grpcClient, pc, nil } - pc, err := cvmsClient.Process(ctx) - if err != nil { - logger.Error(err.Error()) - exitCode = 1 - return - } - if cfg.Vmpl < 0 || cfg.Vmpl > 3 { logger.Error("vmpl level must be in a range [0, 3]") exitCode = 1 @@ -149,7 +193,15 @@ func main() { } defer attClient.Close() - svc := newService(ctx, logger, eventSvc, attClient, cfg.Vmpl) + runnerClient, err := runnerclient.NewClient("/run/cocos/runner.sock") + if err != nil { + logger.Error(fmt.Sprintf("failed to create runner client: %s", err)) + exitCode = 1 + return + } + defer runnerClient.Close() + + svc := newService(ctx, logger, eventSvc, attClient, runnerClient, cfg.Vmpl) if err := os.MkdirAll(storageDir, 0o755); err != nil { logger.Error(fmt.Sprintf("failed to create storage directory: %s", err)) @@ -158,8 +210,7 @@ func main() { } var certProvider atls.CertificateProvider - - if 
ccPlatform != attestation.NoCC { + if cfg.EnableATLS && ccPlatform != attestation.NoCC { var certsSDK sdk.SDK if cfg.CAUrl != "" { certsSDK = sdk.NewSDK(sdk.Config{ @@ -174,7 +225,23 @@ func main() { } } - mc, err := cvmsapi.NewClient(pc, svc, eventsLogsQueue, logger, server.NewServer(logger, svc, cfg.AgentGrpcHost, certProvider), storageDir, reconnectFn, cvmGRPCClient) + // Create ingress proxy server + backendURL, err := url.Parse("unix:///run/cocos/agent.sock") + if err != nil { + logger.Error(fmt.Sprintf("failed to parse backend URL: %s", err)) + exitCode = 1 + return + } + ingressProxy := ingress.NewProxyServer(logger, backendURL, certProvider) + + pc, err := cvmsClient.Process(ctx) + if err != nil { + logger.Error(fmt.Sprintf("failed to connect to cvm server: %s", err)) + exitCode = 1 + return + } + + mc, err := cvmsapi.NewClient(pc, svc, cvmsQueue, logger, server.NewServer(logger, svc, cfg.AgentGrpcHost, certProvider), ingressProxy, storageDir, reconnectFn, cvmGRPCClient) if err != nil { logger.Error(err.Error()) exitCode = 1 @@ -214,7 +281,7 @@ func main() { exitCode = 1 return } - eventsLogsQueue <- &cvms.ClientStreamMessage{ + cvmsQueue <- &cvms.ClientStreamMessage{ Message: &cvms.ClientStreamMessage_AzureAttestationToken{ AzureAttestationToken: &cvms.AzureAttestationToken{ File: azureAttestationToken, @@ -224,7 +291,7 @@ func main() { } } - eventsLogsQueue <- &cvms.ClientStreamMessage{ + cvmsQueue <- &cvms.ClientStreamMessage{ Message: &cvms.ClientStreamMessage_VTPMattestationReport{ VTPMattestationReport: &cvms.AttestationResponse{ File: attest, @@ -238,8 +305,8 @@ func main() { } } -func newService(ctx context.Context, logger *slog.Logger, eventSvc events.Service, attClient attestation_client.Client, vmpl int) agent.Service { - svc := agent.New(ctx, logger, eventSvc, attClient, vmpl) +func newService(ctx context.Context, logger *slog.Logger, eventSvc events.Service, attClient attestation_client.Client, runnerClient runnerclient.Client, vmpl int) 
agent.Service { + svc := agent.New(ctx, logger, eventSvc, attClient, runnerClient, vmpl) svc = api.LoggingMiddleware(svc, logger) counter, latency := prometheus.MakeMetrics(svcName, "api") diff --git a/cmd/computation-runner/main.go b/cmd/computation-runner/main.go new file mode 100644 index 00000000..1ded568b --- /dev/null +++ b/cmd/computation-runner/main.go @@ -0,0 +1,123 @@ +// Copyright (c) Ultraviolet +// SPDX-License-Identifier: Apache-2.0 +package main + +import ( + "context" + "fmt" + "log/slog" + "net" + "os" + "os/signal" + "syscall" + + mglog "github.com/absmach/supermq/logger" + "github.com/caarlos0/env/v11" + pb "github.com/ultravioletrs/cocos/agent/runner" + runnerevents "github.com/ultravioletrs/cocos/agent/runner/events" + "github.com/ultravioletrs/cocos/agent/runner/service" + logclient "github.com/ultravioletrs/cocos/pkg/clients/grpc/log" + "golang.org/x/sync/errgroup" + "google.golang.org/grpc" +) + +const ( + svcName = "computation-runner" + socketPath = "/run/cocos/runner.sock" +) + +type config struct { + LogLevel string `env:"RUNNER_LOG_LEVEL" envAlternate:"AGENT_LOG_LEVEL" envDefault:"debug"` + LogForwarder string `env:"LOG_FORWARDER_SOCKET" envDefault:"/run/cocos/log.sock"` +} + +func main() { + ctx, cancel := context.WithCancel(context.Background()) + g, ctx := errgroup.WithContext(ctx) + + var cfg config + if err := env.Parse(&cfg); err != nil { + fmt.Printf("failed to load %s configuration : %s\n", svcName, err) + os.Exit(1) + } + + var exitCode int + defer mglog.ExitWithError(&exitCode) + + var level slog.Level + if err := level.UnmarshalText([]byte(cfg.LogLevel)); err != nil { + fmt.Println(err) + exitCode = 1 + return + } + + logger := slog.New(slog.NewJSONHandler(os.Stdout, &slog.HandlerOptions{Level: level})) + + // Connect to Log Forwarder + logClient, err := logclient.NewClient(cfg.LogForwarder) + if err != nil { + logger.Warn(fmt.Sprintf("failed to connect to log-forwarder: %s. 
Events will not be forwarded.", err)) + } else { + defer logClient.Close() + } + + eventSvc := runnerevents.NewAdapter(logClient, svcName) + + // Remove existing socket if it exists + if _, err := os.Stat(socketPath); err == nil { + if err := os.Remove(socketPath); err != nil { + logger.Error(fmt.Sprintf("failed to remove existing socket: %s", err)) + exitCode = 1 + return + } + } + + dir := socketPath[:len(socketPath)-len("/runner.sock")] + if err := os.MkdirAll(dir, 0o755); err != nil { + logger.Error(fmt.Sprintf("failed to create socket directory: %s", err)) + exitCode = 1 + return + } + + lis, err := net.Listen("unix", socketPath) + if err != nil { + logger.Error(fmt.Sprintf("failed to listen on socket: %s", err)) + exitCode = 1 + return + } + + if err := os.Chmod(socketPath, 0o777); err != nil { + logger.Error(fmt.Sprintf("failed to chmod socket: %s", err)) + exitCode = 1 + return + } + + grpcServer := grpc.NewServer() + svc := service.New(logger, eventSvc) + pb.RegisterComputationRunnerServer(grpcServer, svc) + + g.Go(func() error { + ch := make(chan os.Signal, 1) + signal.Notify(ch, syscall.SIGINT, syscall.SIGTERM) + defer signal.Stop(ch) + + select { + case <-ch: + logger.Info("Received signal, shutting down...") + cancel() + grpcServer.GracefulStop() + return nil + case <-ctx.Done(): + return ctx.Err() + } + }) + + g.Go(func() error { + logger.Info(fmt.Sprintf("%s started on %s", svcName, socketPath)) + return grpcServer.Serve(lis) + }) + + if err := g.Wait(); err != nil { + logger.Error(fmt.Sprintf("%s terminated: %s", svcName, err)) + } +} diff --git a/cmd/egress-proxy/main.go b/cmd/egress-proxy/main.go new file mode 100644 index 00000000..cd2aec5b --- /dev/null +++ b/cmd/egress-proxy/main.go @@ -0,0 +1,88 @@ +// Copyright (c) Ultraviolet +// SPDX-License-Identifier: Apache-2.0 +package main + +import ( + "context" + "fmt" + "os" + "os/signal" + "syscall" + + mglog "github.com/absmach/supermq/logger" + "github.com/caarlos0/env/v11" + 
"github.com/spf13/cobra" + "github.com/spf13/pflag" + "github.com/ultravioletrs/cocos/pkg/egress" + "golang.org/x/sync/errgroup" +) + +const ( + svcName = "egress-proxy" +) + +type config struct { + Level string `env:"COCOS_LOG_LEVEL" envAlternate:"AGENT_LOG_LEVEL" envDefault:"info"` + Port string `env:"COCOS_PROXY_PORT" envDefault:"3128"` +} + +func main() { + var cfg config + if err := env.Parse(&cfg); err != nil { + fmt.Fprintf(os.Stderr, "failed to load configuration: %s\n", err) + os.Exit(1) + } + + cmd := &cobra.Command{ + Use: svcName, + Short: "Egress Proxy Service", + RunE: func(cmd *cobra.Command, args []string) error { + return run(cfg) + }, + } + + pflag.StringVar(&cfg.Level, "log-level", cfg.Level, "Log level") + pflag.StringVar(&cfg.Port, "port", cfg.Port, "Proxy port") + + if err := cmd.Execute(); err != nil { + fmt.Fprintf(os.Stderr, "Error: %s\n", err) + os.Exit(1) + } +} + +func run(cfg config) error { + logger, err := mglog.New(os.Stdout, cfg.Level) + if err != nil { + return fmt.Errorf("failed to create logger: %w", err) + } + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + g, ctx := errgroup.WithContext(ctx) + + proxy := egress.NewProxy(logger, ":"+cfg.Port) + + g.Go(func() error { + return proxy.Start() + }) + + g.Go(func() error { + c := make(chan os.Signal, 1) + signal.Notify(c, syscall.SIGINT, syscall.SIGTERM) + select { + case s := <-c: + logger.Info(fmt.Sprintf("received signal %s, stopping", s)) + cancel() + return proxy.Stop(ctx) + case <-ctx.Done(): + return nil + } + }) + + if err := g.Wait(); err != nil { + return fmt.Errorf("server exit with error: %w", err) + } + + return nil +} diff --git a/cmd/ingress-proxy/main.go b/cmd/ingress-proxy/main.go new file mode 100644 index 00000000..5ce3e417 --- /dev/null +++ b/cmd/ingress-proxy/main.go @@ -0,0 +1,137 @@ +// Copyright (c) Ultraviolet +// SPDX-License-Identifier: Apache-2.0 +package main + +import ( + "context" + "fmt" + "net/url" + "os" + "os/signal" + 
"syscall" + + "github.com/absmach/certs/sdk" + mglog "github.com/absmach/supermq/logger" + "github.com/caarlos0/env/v11" + "github.com/spf13/cobra" + "github.com/spf13/pflag" + "github.com/ultravioletrs/cocos/pkg/atls" + "github.com/ultravioletrs/cocos/pkg/attestation" + "github.com/ultravioletrs/cocos/pkg/attestation/azure" + "github.com/ultravioletrs/cocos/pkg/ingress" + "golang.org/x/sync/errgroup" +) + +const ( + svcName = "ingress-proxy" +) + +type config struct { + LogLevel string `env:"COCOS_LOG_LEVEL" envAlternate:"AGENT_LOG_LEVEL" envDefault:"info"` + Backend string `env:"COCOS_INGRESS_BACKEND" envDefault:"http://localhost:7001"` + + // ATLS Config + CAUrl string `env:"AGENT_CVM_CA_URL" envDefault:""` + CVMId string `env:"AGENT_CVM_ID" envDefault:""` + CertsToken string `env:"AGENT_CERTS_TOKEN" envDefault:""` + AgentMaaURL string `env:"AGENT_MAA_URL" envDefault:"https://sharedeus2.eus2.attest.azure.net"` + AgentOSBuild string `env:"AGENT_OS_BUILD" envDefault:"UVC"` + AgentOSDistro string `env:"AGENT_OS_DISTRO" envDefault:"UVC"` + AgentOSType string `env:"AGENT_OS_TYPE" envDefault:"UVC"` +} + +func main() { + var cfg config + if err := env.Parse(&cfg); err != nil { + fmt.Fprintf(os.Stderr, "failed to load configuration: %s\n", err) + os.Exit(1) + } + + cmd := &cobra.Command{ + Use: svcName, + Short: "Ingress Proxy Service", + RunE: func(cmd *cobra.Command, args []string) error { + return run(cfg) + }, + } + + pflag.StringVar(&cfg.LogLevel, "log-level", cfg.LogLevel, "Log level") + pflag.StringVar(&cfg.Backend, "backend", cfg.Backend, "Backend URL") + + if err := cmd.Execute(); err != nil { + fmt.Fprintf(os.Stderr, "Error: %s\n", err) + os.Exit(1) + } +} + +func run(cfg config) error { + logger, err := mglog.New(os.Stdout, cfg.LogLevel) + if err != nil { + return fmt.Errorf("failed to create logger: %w", err) + } + + backendURL, err := url.Parse(cfg.Backend) + if err != nil { + return fmt.Errorf("failed to parse backend URL: %w", err) + } + + // Initialize 
Certificate Provider + var provider attestation.Provider + ccPlatform := attestation.CCPlatform() + + azureConfig := azure.NewEnvConfigFromAgent( + cfg.AgentOSBuild, + cfg.AgentOSType, + cfg.AgentOSDistro, + cfg.AgentMaaURL, + ) + azure.InitializeDefaultMAAVars(azureConfig) + + var certProvider atls.CertificateProvider + + if ccPlatform != attestation.NoCC { + var certsSDK sdk.SDK + if cfg.CAUrl != "" { + certsSDK = sdk.NewSDK(sdk.Config{ + CertsURL: cfg.CAUrl, + }) + } + certProvider, err = atls.NewProvider(provider, ccPlatform, cfg.CertsToken, cfg.CVMId, certsSDK) + if err != nil { + return fmt.Errorf("failed to create certificate provider: %w", err) + } + } else { + logger.Warn("No Confidential Computing platform detected. ATLS will not be available.") + } + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + g, ctx := errgroup.WithContext(ctx) + + // Create proxy server (but don't start it yet - it will be started per-computation) + _ = ingress.NewProxyServer(logger, backendURL, certProvider) + + // Note: The proxy server will be started dynamically when a computation is initiated + // via the Manager's ComputationRunReq message. For now, we just keep the service alive. 
+ logger.Info("ingress-proxy service initialized, waiting for computation requests...") + + g.Go(func() error { + c := make(chan os.Signal, 1) + signal.Notify(c, syscall.SIGINT, syscall.SIGTERM) + select { + case s := <-c: + logger.Info(fmt.Sprintf("received signal %s, stopping", s)) + cancel() + return nil + case <-ctx.Done(): + return nil + } + }) + + if err := g.Wait(); err != nil { + return fmt.Errorf("server exit with error: %w", err) + } + + return nil +} diff --git a/cmd/log-forwarder/main.go b/cmd/log-forwarder/main.go new file mode 100644 index 00000000..3a454a44 --- /dev/null +++ b/cmd/log-forwarder/main.go @@ -0,0 +1,155 @@ +// Copyright (c) Ultraviolet +// SPDX-License-Identifier: Apache-2.0 +package main + +import ( + "context" + "fmt" + "log/slog" + "net" + "os" + "os/signal" + "syscall" + + mglog "github.com/absmach/supermq/logger" + "github.com/caarlos0/env/v11" + "github.com/ultravioletrs/cocos/agent/cvms" + pb "github.com/ultravioletrs/cocos/agent/log" + "github.com/ultravioletrs/cocos/agent/log/service" + "github.com/ultravioletrs/cocos/pkg/clients" + cvmsgrpc "github.com/ultravioletrs/cocos/pkg/clients/grpc/cvm" + "golang.org/x/sync/errgroup" + "google.golang.org/grpc" +) + +const ( + svcName = "log-forwarder" + socketPath = "/run/cocos/log.sock" + envPrefixCVMGRPC = "AGENT_CVM_GRPC_" +) + +type config struct { + LogLevel string `env:"LOG_FORWARDER_LOG_LEVEL" envAlternate:"AGENT_LOG_LEVEL" envDefault:"debug"` +} + +func main() { + ctx, cancel := context.WithCancel(context.Background()) + g, ctx := errgroup.WithContext(ctx) + + var cfg config + if err := env.Parse(&cfg); err != nil { + fmt.Printf("failed to load %s configuration : %s\n", svcName, err) + os.Exit(1) + } + + var exitCode int + defer mglog.ExitWithError(&exitCode) + + var level slog.Level + if err := level.UnmarshalText([]byte(cfg.LogLevel)); err != nil { + fmt.Println(err) + exitCode = 1 + return + } + + logger := slog.New(slog.NewJSONHandler(os.Stdout, &slog.HandlerOptions{Level: 
level})) + + // Remove existing socket if it exists + if _, err := os.Stat(socketPath); err == nil { + if err := os.Remove(socketPath); err != nil { + logger.Error(fmt.Sprintf("failed to remove existing socket: %s", err)) + exitCode = 1 + return + } + } + + dir := socketPath[:len(socketPath)-len("/log.sock")] + if err := os.MkdirAll(dir, 0o755); err != nil { + logger.Error(fmt.Sprintf("failed to create socket directory: %s", err)) + exitCode = 1 + return + } + + lis, err := net.Listen("unix", socketPath) + if err != nil { + logger.Error(fmt.Sprintf("failed to listen on socket: %s", err)) + exitCode = 1 + return + } + + if err := os.Chmod(socketPath, 0o777); err != nil { + logger.Error(fmt.Sprintf("failed to chmod socket: %s", err)) + exitCode = 1 + return + } + + // Connect to Manager + cvmGrpcConfig := clients.StandardClientConfig{} + if err := env.ParseWithOptions(&cvmGrpcConfig, env.Options{Prefix: envPrefixCVMGRPC}); err != nil { + logger.Error(fmt.Sprintf("failed to load %s gRPC client configuration : %s", svcName, err)) + exitCode = 1 + return + } + + cvmClient, cvmsClient, err := cvmsgrpc.NewCVMClient(cvmGrpcConfig) + if err != nil { + logger.Error(fmt.Sprintf("failed to connect to CVM manager: %s", err)) + exitCode = 1 + return + } + defer cvmClient.Close() + + // Create stream to Manager + stream, err := cvmsClient.Process(ctx) + if err != nil { + logger.Error(fmt.Sprintf("failed to create stream to manager: %s", err)) + exitCode = 1 + return + } + + logQueue := make(chan *cvms.ClientStreamMessage, 1000) + + grpcServer := grpc.NewServer() + svc := service.New(logger, cvmsClient, logQueue) + pb.RegisterLogCollectorServer(grpcServer, svc) + + // Log Consumer Goroutine + g.Go(func() error { + for { + select { + case <-ctx.Done(): + return nil + case msg := <-logQueue: + if err := stream.Send(msg); err != nil { + logger.Error(fmt.Sprintf("failed to send log to manager: %s", err)) + // Reconnect logic would go here + } + } + } + }) + + g.Go(func() error { + ch 
:= make(chan os.Signal, 1) + signal.Notify(ch, syscall.SIGINT, syscall.SIGTERM) + defer signal.Stop(ch) + + select { + case <-ch: + logger.Info("Received signal, shutting down...") + cancel() + grpcServer.GracefulStop() + return nil + case <-ctx.Done(): + return ctx.Err() + } + }) + + g.Go(func() error { + logger.Info(fmt.Sprintf("%s started on %s", svcName, socketPath)) + return grpcServer.Serve(lis) + }) + + if err := g.Wait(); err != nil { + logger.Error(fmt.Sprintf("%s terminated: %s", svcName, err)) + } +} diff --git a/hal/linux/Config.in b/hal/linux/Config.in index 96dd78c0..48534b95 100644 --- a/hal/linux/Config.in +++ b/hal/linux/Config.in @@ -1,3 +1,7 @@ source "$BR2_EXTERNAL_COCOS_PATH/package/agent/Config.in" source "$BR2_EXTERNAL_COCOS_PATH/package/attestation-service/Config.in" source "$BR2_EXTERNAL_COCOS_PATH/package/wasmedge/Config.in" +source "$BR2_EXTERNAL_COCOS_PATH/package/log-forwarder/Config.in" +source "$BR2_EXTERNAL_COCOS_PATH/package/computation-runner/Config.in" +source "$BR2_EXTERNAL_COCOS_PATH/package/egress-proxy/Config.in" +source "$BR2_EXTERNAL_COCOS_PATH/package/ingress-proxy/Config.in" diff --git a/hal/linux/package/agent/Config.in b/hal/linux/package/agent/Config.in index f042b58c..0e9093be 100644 --- a/hal/linux/package/agent/Config.in +++ b/hal/linux/package/agent/Config.in @@ -2,6 +2,10 @@ config BR2_PACKAGE_AGENT bool "agent" default y select BR2_PACKAGE_ATTESTATION_SERVICE + select BR2_PACKAGE_LOG_FORWARDER + select BR2_PACKAGE_COMPUTATION_RUNNER + select BR2_PACKAGE_INGRESS_PROXY + select BR2_PACKAGE_EGRESS_PROXY help Confidential Computing Agent is a state machine capable of receiving datasets and algorithm, running computations, and diff --git a/hal/linux/package/computation-runner/Config.in b/hal/linux/package/computation-runner/Config.in new file mode 100644 index 00000000..18718d79 --- /dev/null +++ b/hal/linux/package/computation-runner/Config.in @@ -0,0 +1,5 @@ +config BR2_PACKAGE_COMPUTATION_RUNNER + bool 
"computation-runner" + select BR2_PACKAGE_LOG_FORWARDER + help + Cocos AI Computation Runner service. diff --git a/hal/linux/package/computation-runner/computation-runner.mk b/hal/linux/package/computation-runner/computation-runner.mk new file mode 100644 index 00000000..1850695a --- /dev/null +++ b/hal/linux/package/computation-runner/computation-runner.mk @@ -0,0 +1,22 @@ +################################################################################ +# +# computation-runner +# +################################################################################ + +COMPUTATION_RUNNER_VERSION = main +COMPUTATION_RUNNER_SITE = $(call github,ultravioletrs,cocos,$(COMPUTATION_RUNNER_VERSION)) + +define COMPUTATION_RUNNER_BUILD_CMDS + $(MAKE) -C $(@D) computation-runner +endef + +define COMPUTATION_RUNNER_INSTALL_TARGET_CMDS + $(INSTALL) -D -m 0750 $(@D)/build/cocos-computation-runner $(TARGET_DIR)/usr/bin/computation-runner +endef + +define COMPUTATION_RUNNER_INSTALL_INIT_SYSTEMD + $(INSTALL) -D -m 0640 $(@D)/init/systemd/computation-runner.service $(TARGET_DIR)/usr/lib/systemd/system/computation-runner.service +endef + +$(eval $(generic-package)) diff --git a/hal/linux/package/egress-proxy/Config.in b/hal/linux/package/egress-proxy/Config.in new file mode 100644 index 00000000..945d2689 --- /dev/null +++ b/hal/linux/package/egress-proxy/Config.in @@ -0,0 +1,6 @@ +config BR2_PACKAGE_EGRESS_PROXY + bool "egress-proxy" + help + Cocos AI Egress Proxy Service. 
+ + https://github.com/ultravioletrs/cocos diff --git a/hal/linux/package/egress-proxy/egress-proxy.mk b/hal/linux/package/egress-proxy/egress-proxy.mk new file mode 100644 index 00000000..eb5e5866 --- /dev/null +++ b/hal/linux/package/egress-proxy/egress-proxy.mk @@ -0,0 +1,22 @@ +################################################################################ +# +# Cocos AI Egress Proxy +# +################################################################################ + +EGRESS_PROXY_VERSION = main +EGRESS_PROXY_SITE = $(call github,ultravioletrs,cocos,$(EGRESS_PROXY_VERSION)) + +define EGRESS_PROXY_BUILD_CMDS + $(MAKE) -C $(@D) egress-proxy +endef + +define EGRESS_PROXY_INSTALL_TARGET_CMDS + $(INSTALL) -D -m 0755 $(@D)/build/cocos-egress-proxy $(TARGET_DIR)/usr/bin/egress-proxy +endef + +define EGRESS_PROXY_INSTALL_INIT_SYSTEMD + $(INSTALL) -D -m 0644 $(@D)/init/systemd/egress-proxy.service $(TARGET_DIR)/usr/lib/systemd/system/egress-proxy.service +endef + +$(eval $(generic-package)) diff --git a/hal/linux/package/ingress-proxy/Config.in b/hal/linux/package/ingress-proxy/Config.in new file mode 100644 index 00000000..a0f6b36b --- /dev/null +++ b/hal/linux/package/ingress-proxy/Config.in @@ -0,0 +1,4 @@ +config BR2_PACKAGE_INGRESS_PROXY + bool "ingress-proxy" + help + Cocos Ingress Proxy service. 
diff --git a/hal/linux/package/ingress-proxy/ingress-proxy.mk b/hal/linux/package/ingress-proxy/ingress-proxy.mk new file mode 100644 index 00000000..e407a572 --- /dev/null +++ b/hal/linux/package/ingress-proxy/ingress-proxy.mk @@ -0,0 +1,22 @@ +################################################################################ +# +# ingress-proxy +# +################################################################################ + +INGRESS_PROXY_VERSION = main +INGRESS_PROXY_SITE = $(call github,ultravioletrs,cocos,$(INGRESS_PROXY_VERSION)) + +define INGRESS_PROXY_BUILD_CMDS + $(MAKE) -C $(@D) ingress-proxy +endef + +define INGRESS_PROXY_INSTALL_TARGET_CMDS + $(INSTALL) -D -m 0750 $(@D)/build/cocos-ingress-proxy $(TARGET_DIR)/usr/bin/ingress-proxy +endef + +# NOTE: The ingress-proxy is managed per-computation by the agent, not as a standalone +# systemd service. The binary is installed for use by the agent, but no systemd service +# is created. + +$(eval $(generic-package)) diff --git a/hal/linux/package/log-forwarder/Config.in b/hal/linux/package/log-forwarder/Config.in new file mode 100644 index 00000000..634daeb8 --- /dev/null +++ b/hal/linux/package/log-forwarder/Config.in @@ -0,0 +1,4 @@ +config BR2_PACKAGE_LOG_FORWARDER + bool "log-forwarder" + help + Cocos AI Log Forwarder service. 
diff --git a/hal/linux/package/log-forwarder/log-forwarder.mk b/hal/linux/package/log-forwarder/log-forwarder.mk new file mode 100644 index 00000000..47545880 --- /dev/null +++ b/hal/linux/package/log-forwarder/log-forwarder.mk @@ -0,0 +1,22 @@ +################################################################################ +# +# log-forwarder +# +################################################################################ + +LOG_FORWARDER_VERSION = main +LOG_FORWARDER_SITE = $(call github,ultravioletrs,cocos,$(LOG_FORWARDER_VERSION)) + +define LOG_FORWARDER_BUILD_CMDS + $(MAKE) -C $(@D) log-forwarder +endef + +define LOG_FORWARDER_INSTALL_TARGET_CMDS + $(INSTALL) -D -m 0750 $(@D)/build/cocos-log-forwarder $(TARGET_DIR)/usr/bin/log-forwarder +endef + +define LOG_FORWARDER_INSTALL_INIT_SYSTEMD + $(INSTALL) -D -m 0640 $(@D)/init/systemd/log-forwarder.service $(TARGET_DIR)/usr/lib/systemd/system/log-forwarder.service +endef + +$(eval $(generic-package)) diff --git a/init/systemd/cocos-agent.service b/init/systemd/cocos-agent.service index ff4b6493..0b809153 100644 --- a/init/systemd/cocos-agent.service +++ b/init/systemd/cocos-agent.service @@ -1,16 +1,19 @@ [Unit] Description=Cocos AI agent -After=network.target attestation-service.service -Requires=attestation-service.service +After=network.target attestation-service.service log-forwarder.service computation-runner.service egress-proxy.service +Requires=log-forwarder.service computation-runner.service egress-proxy.service Before=docker.service [Service] WorkingDirectory=/cocos +Environment="HTTP_PROXY=http://localhost:3128" +Environment="HTTPS_PROXY=http://localhost:3128" +Environment="NO_PROXY=localhost,127.0.0.1,.local,/run/cocos/" +Environment="AGENT_ENABLE_ATLS=false" StandardOutput=file:/var/log/cocos/agent.stdout StandardError=file:/var/log/cocos/agent.stderr EnvironmentFile=/etc/cocos/environment -ExecStartPre=/cocos_init/agent_setup.sh -ExecStart=/cocos_init/agent_start_script.sh 
+ExecStart=/usr/bin/cocos-agent --config-file=/etc/cocos/cocos-agent.conf Restart=always RestartSec=5s diff --git a/init/systemd/computation-runner.service b/init/systemd/computation-runner.service new file mode 100644 index 00000000..8ab218cd --- /dev/null +++ b/init/systemd/computation-runner.service @@ -0,0 +1,20 @@ +[Unit] +Description=Cocos AI Computation Runner +After=network.target log-forwarder.service +Before=cocos-agent.service +Requires=log-forwarder.service + +[Service] +WorkingDirectory=/cocos +StandardOutput=file:/var/log/cocos/runner.stdout +StandardError=file:/var/log/cocos/runner.stderr +EnvironmentFile=/etc/cocos/environment +Environment="HTTP_PROXY=http://localhost:3128" +Environment="HTTPS_PROXY=http://localhost:3128" +Environment="NO_PROXY=localhost,127.0.0.1,.local,/run/cocos/" +ExecStart=/usr/bin/computation-runner --socket-path=/run/cocos/runner.sock +Restart=always +RestartSec=5s + +[Install] +WantedBy=default.target diff --git a/init/systemd/egress-proxy.service b/init/systemd/egress-proxy.service new file mode 100644 index 00000000..6a942aa4 --- /dev/null +++ b/init/systemd/egress-proxy.service @@ -0,0 +1,15 @@ +[Unit] +Description=Cocos Egress Proxy Service +After=network.target + +[Service] +EnvironmentFile=/etc/cocos/environment +Environment="COCOS_LOG_LEVEL=debug" +ExecStart=/usr/bin/egress-proxy --port=3128 +Restart=always +RestartSec=5 +User=root +Group=root + +[Install] +WantedBy=multi-user.target diff --git a/init/systemd/log-forwarder.service b/init/systemd/log-forwarder.service new file mode 100644 index 00000000..2d040c7d --- /dev/null +++ b/init/systemd/log-forwarder.service @@ -0,0 +1,17 @@ +[Unit] +Description=Cocos AI Log Forwarder +After=network.target +Before=cocos-agent.service + +[Service] +WorkingDirectory=/cocos +StandardOutput=file:/var/log/cocos/log-forwarder.stdout +StandardError=file:/var/log/cocos/log-forwarder.stderr +EnvironmentFile=/etc/cocos/environment +ExecStartPre=/cocos_init/agent_setup.sh 
+ExecStart=/usr/bin/log-forwarder
+Restart=always
+RestartSec=5s
+
+[Install]
+WantedBy=default.target
diff --git a/internal/logger/protohandler.go b/internal/logger/protohandler.go
index 58146c7e..8d4b9daf 100644
--- a/internal/logger/protohandler.go
+++ b/internal/logger/protohandler.go
@@ -44,6 +44,7 @@ func (h *handler) Enabled(_ context.Context, l slog.Level) bool {
 }
 
 func (h *handler) Handle(_ context.Context, r slog.Record) error {
+	// Deliberately no slog call here: Handle runs for every record, so logging from inside it would duplicate each message (or recurse if this handler backs the default logger).
 	message := r.Message
 	timestamp := timestamppb.New(r.Time)
 	level := r.Level.String()
diff --git a/pkg/clients/grpc/attestation/client_test.go b/pkg/clients/grpc/attestation/client_test.go
new file mode 100644
index 00000000..ebdd214c
--- /dev/null
+++ b/pkg/clients/grpc/attestation/client_test.go
@@ -0,0 +1,392 @@
+// Copyright (c) Ultraviolet
+// SPDX-License-Identifier: Apache-2.0
+package attestation
+
+import (
+	"context"
+	"net"
+	"path/filepath"
+	"testing"
+	"time"
+
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
+	attestation_v1 "github.com/ultravioletrs/cocos/internal/proto/attestation/v1"
+	"github.com/ultravioletrs/cocos/pkg/attestation"
+	"google.golang.org/grpc"
+)
+
+// mockAttestationServer is a mock implementation of the AttestationServiceServer.
+type mockAttestationServer struct {
+	attestation_v1.UnimplementedAttestationServiceServer
+
+	// Set to true once the corresponding RPC has been invoked.
+	fetchAttestationCalled bool
+	fetchAzureTokenCalled  bool
+
+	// Arguments captured from the most recent request.
+	lastReportData   []byte
+	lastNonce        []byte
+	lastPlatformType attestation_v1.PlatformType
+
+	// Errors to inject; nil means the call succeeds.
+	attestationErr error
+	azureTokenErr  error
+}
+
+// FetchAttestation records the request fields, then returns the injected
+// error if one is set, or a fixed quote otherwise.
+func (s *mockAttestationServer) FetchAttestation(_ context.Context, in *attestation_v1.AttestationRequest) (*attestation_v1.AttestationResponse, error) {
+	s.fetchAttestationCalled = true
+	s.lastReportData, s.lastNonce, s.lastPlatformType = in.ReportData, in.Nonce, in.PlatformType
+
+	if err := s.attestationErr; err != nil {
+		return nil, err
+	}
+
+	resp := &attestation_v1.AttestationResponse{Quote: []byte("mock-attestation-quote")}
+	return resp, nil
+}
+
+// FetchAzureToken records the nonce, then returns the injected error if one
+// is set, or a fixed token otherwise.
+func (s *mockAttestationServer) FetchAzureToken(_ context.Context, in *attestation_v1.AzureTokenRequest) (*attestation_v1.AzureTokenResponse, error) {
+	s.fetchAzureTokenCalled = true
+	s.lastNonce = in.Nonce
+
+	if err := s.azureTokenErr; err != nil {
+		return nil, err
+	}
+
+	resp := &attestation_v1.AzureTokenResponse{Token: []byte("mock-azure-token")}
+	return resp, nil
+}
+
+// TestNewClient tests creating a new attestation client.
+func TestNewClient(t *testing.T) {
+	sock := filepath.Join(t.TempDir(), "attestation-test.sock")
+
+	lis, err := net.Listen("unix", sock)
+	require.NoError(t, err)
+	defer lis.Close()
+
+	// Serve the mock on the socket for the duration of the test.
+	srv := grpc.NewServer()
+	attestation_v1.RegisterAttestationServiceServer(srv, &mockAttestationServer{})
+	go func() { _ = srv.Serve(lis) }()
+	defer srv.Stop()
+
+	time.Sleep(100 * time.Millisecond)
+
+	cl, err := NewClient(sock)
+	require.NoError(t, err)
+	require.NotNil(t, cl)
+
+	assert.NoError(t, cl.Close())
+}
+
+// TestGetAttestationSNP tests getting SNP attestation.
+func TestGetAttestationSNP(t *testing.T) {
+	sock := filepath.Join(t.TempDir(), "attestation-snp.sock")
+
+	lis, err := net.Listen("unix", sock)
+	require.NoError(t, err)
+	defer lis.Close()
+
+	// Serve the mock on the socket for the duration of the test.
+	mock := &mockAttestationServer{}
+	srv := grpc.NewServer()
+	attestation_v1.RegisterAttestationServiceServer(srv, mock)
+	go func() { _ = srv.Serve(lis) }()
+	defer srv.Stop()
+
+	time.Sleep(100 * time.Millisecond)
+
+	cl, err := NewClient(sock)
+	require.NoError(t, err)
+	defer cl.Close()
+
+	var (
+		reportData [64]byte
+		nonce      [32]byte
+	)
+	copy(reportData[:], "test-report-data")
+	copy(nonce[:], "test-nonce")
+
+	quote, err := cl.GetAttestation(context.Background(), reportData, nonce, attestation.SNP)
+	require.NoError(t, err)
+
+	// The fixed quote comes back and the mock saw exactly what was sent.
+	assert.Equal(t, []byte("mock-attestation-quote"), quote)
+	assert.True(t, mock.fetchAttestationCalled)
+	assert.Equal(t, attestation_v1.PlatformType_PLATFORM_TYPE_SNP, mock.lastPlatformType)
+	assert.Equal(t, reportData[:], mock.lastReportData)
+	assert.Equal(t, nonce[:], mock.lastNonce)
+}
+
+// TestGetAttestationTDX tests getting TDX attestation.
+func TestGetAttestationTDX(t *testing.T) {
+	sock := filepath.Join(t.TempDir(), "attestation-tdx.sock")
+
+	lis, err := net.Listen("unix", sock)
+	require.NoError(t, err)
+	defer lis.Close()
+
+	// Serve the mock on the socket for the duration of the test.
+	mock := &mockAttestationServer{}
+	srv := grpc.NewServer()
+	attestation_v1.RegisterAttestationServiceServer(srv, mock)
+	go func() { _ = srv.Serve(lis) }()
+	defer srv.Stop()
+
+	time.Sleep(100 * time.Millisecond)
+
+	cl, err := NewClient(sock)
+	require.NoError(t, err)
+	defer cl.Close()
+
+	var (
+		reportData [64]byte
+		nonce      [32]byte
+	)
+
+	quote, err := cl.GetAttestation(context.Background(), reportData, nonce, attestation.TDX)
+	require.NoError(t, err)
+	assert.NotNil(t, quote)
+	assert.Equal(t, attestation_v1.PlatformType_PLATFORM_TYPE_TDX, mock.lastPlatformType)
+}
+
+// TestGetAttestationVTPM tests getting vTPM attestation.
+func TestGetAttestationVTPM(t *testing.T) {
+	sock := filepath.Join(t.TempDir(), "attestation-vtpm.sock")
+
+	lis, err := net.Listen("unix", sock)
+	require.NoError(t, err)
+	defer lis.Close()
+
+	// Serve the mock on the socket for the duration of the test.
+	mock := &mockAttestationServer{}
+	srv := grpc.NewServer()
+	attestation_v1.RegisterAttestationServiceServer(srv, mock)
+	go func() { _ = srv.Serve(lis) }()
+	defer srv.Stop()
+
+	time.Sleep(100 * time.Millisecond)
+
+	cl, err := NewClient(sock)
+	require.NoError(t, err)
+	defer cl.Close()
+
+	var (
+		reportData [64]byte
+		nonce      [32]byte
+	)
+
+	quote, err := cl.GetAttestation(context.Background(), reportData, nonce, attestation.VTPM)
+	require.NoError(t, err)
+	assert.NotNil(t, quote)
+	assert.Equal(t, attestation_v1.PlatformType_PLATFORM_TYPE_VTPM, mock.lastPlatformType)
+}
+
+// TestGetAttestationSNPvTPM tests getting SNP+vTPM attestation.
+func TestGetAttestationSNPvTPM(t *testing.T) {
+	sock := filepath.Join(t.TempDir(), "attestation-snpvtpm.sock")
+
+	lis, err := net.Listen("unix", sock)
+	require.NoError(t, err)
+	defer lis.Close()
+
+	// Serve the mock on the socket for the duration of the test.
+	mock := &mockAttestationServer{}
+	srv := grpc.NewServer()
+	attestation_v1.RegisterAttestationServiceServer(srv, mock)
+	go func() { _ = srv.Serve(lis) }()
+	defer srv.Stop()
+
+	time.Sleep(100 * time.Millisecond)
+
+	cl, err := NewClient(sock)
+	require.NoError(t, err)
+	defer cl.Close()
+
+	var (
+		reportData [64]byte
+		nonce      [32]byte
+	)
+
+	quote, err := cl.GetAttestation(context.Background(), reportData, nonce, attestation.SNPvTPM)
+	require.NoError(t, err)
+	assert.NotNil(t, quote)
+	assert.Equal(t, attestation_v1.PlatformType_PLATFORM_TYPE_SNP_VTPM, mock.lastPlatformType)
+}
+
+// TestGetAttestationUnspecified tests getting attestation with unspecified platform.
+func TestGetAttestationUnspecified(t *testing.T) {
+	sock := filepath.Join(t.TempDir(), "attestation-unspec.sock")
+
+	lis, err := net.Listen("unix", sock)
+	require.NoError(t, err)
+	defer lis.Close()
+
+	// Serve the mock on the socket for the duration of the test.
+	mock := &mockAttestationServer{}
+	srv := grpc.NewServer()
+	attestation_v1.RegisterAttestationServiceServer(srv, mock)
+	go func() { _ = srv.Serve(lis) }()
+	defer srv.Stop()
+
+	time.Sleep(100 * time.Millisecond)
+
+	cl, err := NewClient(sock)
+	require.NoError(t, err)
+	defer cl.Close()
+
+	var (
+		reportData [64]byte
+		nonce      [32]byte
+	)
+
+	// 999 is not a defined platform type; the server is expected to see UNSPECIFIED.
+	unknown := attestation.PlatformType(999)
+	quote, err := cl.GetAttestation(context.Background(), reportData, nonce, unknown)
+	require.NoError(t, err)
+	assert.NotNil(t, quote)
+	assert.Equal(t, attestation_v1.PlatformType_PLATFORM_TYPE_UNSPECIFIED, mock.lastPlatformType)
+}
+
+// TestGetAzureToken tests getting Azure token.
+func TestGetAzureToken(t *testing.T) {
+	sock := filepath.Join(t.TempDir(), "attestation-azure.sock")
+
+	lis, err := net.Listen("unix", sock)
+	require.NoError(t, err)
+	defer lis.Close()
+
+	// Serve the mock on the socket for the duration of the test.
+	mock := &mockAttestationServer{}
+	srv := grpc.NewServer()
+	attestation_v1.RegisterAttestationServiceServer(srv, mock)
+	go func() { _ = srv.Serve(lis) }()
+	defer srv.Stop()
+
+	time.Sleep(100 * time.Millisecond)
+
+	cl, err := NewClient(sock)
+	require.NoError(t, err)
+	defer cl.Close()
+
+	var nonce [32]byte
+	copy(nonce[:], "azure-nonce")
+
+	token, err := cl.GetAzureToken(context.Background(), nonce)
+	require.NoError(t, err)
+
+	// The fixed token comes back and the mock saw the nonce that was sent.
+	assert.Equal(t, []byte("mock-azure-token"), token)
+	assert.True(t, mock.fetchAzureTokenCalled)
+	assert.Equal(t, nonce[:], mock.lastNonce)
+}
+
+// TestGetAttestationWithCanceledContext tests GetAttestation with canceled context.
+func TestGetAttestationWithCanceledContext(t *testing.T) {
+	sock := filepath.Join(t.TempDir(), "attestation-cancel.sock")
+
+	lis, err := net.Listen("unix", sock)
+	require.NoError(t, err)
+	defer lis.Close()
+
+	// Serve the mock on the socket for the duration of the test.
+	srv := grpc.NewServer()
+	attestation_v1.RegisterAttestationServiceServer(srv, &mockAttestationServer{})
+	go func() { _ = srv.Serve(lis) }()
+	defer srv.Stop()
+
+	time.Sleep(100 * time.Millisecond)
+
+	cl, err := NewClient(sock)
+	require.NoError(t, err)
+	defer cl.Close()
+
+	// Cancel before the call so the RPC is issued with a dead context.
+	ctx, cancel := context.WithCancel(context.Background())
+	cancel()
+
+	var (
+		reportData [64]byte
+		nonce      [32]byte
+	)
+
+	_, err = cl.GetAttestation(ctx, reportData, nonce, attestation.SNP)
+	assert.Error(t, err)
+	assert.Contains(t, err.Error(), "context canceled")
+}
+
+// TestClientClose tests closing the attestation client.
+func TestClientClose(t *testing.T) {
+	sock := filepath.Join(t.TempDir(), "attestation-close.sock")
+
+	lis, err := net.Listen("unix", sock)
+	require.NoError(t, err)
+	defer lis.Close()
+
+	// Serve the mock on the socket for the duration of the test.
+	srv := grpc.NewServer()
+	attestation_v1.RegisterAttestationServiceServer(srv, &mockAttestationServer{})
+	go func() { _ = srv.Serve(lis) }()
+	defer srv.Stop()
+
+	time.Sleep(100 * time.Millisecond)
+
+	cl, err := NewClient(sock)
+	require.NoError(t, err)
+
+	assert.NoError(t, cl.Close())
+}
+
+// TestClientOperationsAfterClose tests operations after closing.
+func TestClientOperationsAfterClose(t *testing.T) {
+	sock := filepath.Join(t.TempDir(), "attestation-after-close.sock")
+
+	lis, err := net.Listen("unix", sock)
+	require.NoError(t, err)
+	defer lis.Close()
+
+	// Serve the mock on the socket for the duration of the test.
+	srv := grpc.NewServer()
+	attestation_v1.RegisterAttestationServiceServer(srv, &mockAttestationServer{})
+	go func() { _ = srv.Serve(lis) }()
+	defer srv.Stop()
+
+	time.Sleep(100 * time.Millisecond)
+
+	cl, err := NewClient(sock)
+	require.NoError(t, err)
+	require.NoError(t, cl.Close())
+
+	var (
+		reportData [64]byte
+		nonce      [32]byte
+	)
+
+	// Any RPC issued on the closed client must report an error.
+	_, err = cl.GetAttestation(context.Background(), reportData, nonce, attestation.SNP)
+	assert.Error(t, err)
+}
diff --git a/pkg/clients/grpc/log/client.go b/pkg/clients/grpc/log/client.go
new file mode 100644
index 00000000..017fff5b
--- /dev/null
+++ b/pkg/clients/grpc/log/client.go
@@ -0,0 +1,64 @@
+// Copyright (c) Ultraviolet
+// SPDX-License-Identifier: Apache-2.0
+package log
+
+import (
+	"context"
+	"time"
+
+	"github.com/ultravioletrs/cocos/agent/log"
+	"google.golang.org/grpc"
+	"google.golang.org/grpc/credentials/insecure"
+	"google.golang.org/protobuf/types/known/timestamppb"
+)
+
+// sendTimeout caps how long a single SendLog or SendEvent call may take.
+const sendTimeout = 5 * time.Second
+
+// Client sends log and event entries to the log collector over gRPC.
+type Client interface {
+	// SendLog delivers a single log entry.
+	SendLog(ctx context.Context, entry *log.LogEntry) error
+	// SendEvent delivers a single event entry.
+	SendEvent(ctx context.Context, entry *log.EventEntry) error
+	// Close closes the underlying gRPC connection.
+	Close() error
+}
+
+// client implements Client on top of a single gRPC connection.
+type client struct {
+	conn   *grpc.ClientConn
+	client log.LogCollectorClient
+}
+
+var _ Client = (*client)(nil)
+
+// NewClient returns a Client bound to the Unix domain socket at socketPath,
+// using plaintext (insecure) transport credentials.
+//
+// The target is built as "unix:<path>", which gRPC accepts for both absolute
+// and relative paths; the "unix://<path>" form is only valid for absolute
+// paths because a relative path would be parsed as a URL authority.
+func NewClient(socketPath string) (Client, error) {
+	conn, err := grpc.NewClient("unix:"+socketPath, grpc.WithTransportCredentials(insecure.NewCredentials()))
+	if err != nil {
+		return nil, err
+	}
+
+	return &client{
+		conn:   conn,
+		client: log.NewLogCollectorClient(conn),
+	}, nil
+}
+
+// Close closes the underlying gRPC connection.
+func (c *client) Close() error {
+	return c.conn.Close()
+}
+
+// SendLog forwards entry to the collector, giving the call at most sendTimeout.
+// An entry without a timestamp is stamped with the current time; note that
+// this sets the field on the caller's entry. entry must not be nil.
+func (c *client) SendLog(ctx context.Context, entry *log.LogEntry) error {
+	ctx, cancel := context.WithTimeout(ctx, sendTimeout)
+	defer cancel()
+
+	if entry.Timestamp == nil {
+		entry.Timestamp = timestamppb.Now()
+	}
+
+	_, err := c.client.SendLog(ctx, entry)
+	return err
+}
+
+// SendEvent forwards entry to the collector, giving the call at most
+// sendTimeout. An entry without a timestamp is stamped with the current time;
+// note that this sets the field on the caller's entry. entry must not be nil.
+func (c *client) SendEvent(ctx context.Context, entry *log.EventEntry) error {
+	ctx, cancel := context.WithTimeout(ctx, sendTimeout)
+	defer cancel()
+
+	if entry.Timestamp == nil {
+		entry.Timestamp = timestamppb.Now()
+	}
+
+	_, err := c.client.SendEvent(ctx, entry)
+	return err
+}
diff --git a/pkg/clients/grpc/log/client_test.go b/pkg/clients/grpc/log/client_test.go
new file mode 100644
index 00000000..f6183f15
--- /dev/null
+++ b/pkg/clients/grpc/log/client_test.go
@@ -0,0 +1,332 @@
+// Copyright (c) Ultraviolet
+// SPDX-License-Identifier: Apache-2.0
+package log
+
+import (
+	"context"
+	"net"
+	"path/filepath"
+	"testing"
+	"time"
+
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
+	"github.com/ultravioletrs/cocos/agent/log"
+	"google.golang.org/grpc"
+	"google.golang.org/protobuf/types/known/emptypb"
+	"google.golang.org/protobuf/types/known/timestamppb"
+)
+
+// mockLogCollectorServer is a mock implementation of the LogCollectorServer.
+type mockLogCollectorServer struct {
+	log.UnimplementedLogCollectorServer
+
+	// Set to true once the corresponding RPC has been invoked.
+	sendLogCalled   bool
+	sendEventCalled bool
+
+	// Most recent entries received.
+	lastLogEntry   *log.LogEntry
+	lastEventEntry *log.EventEntry
+
+	// Errors to inject; nil means the call succeeds.
+	sendLogErr   error
+	sendEventErr error
+}
+
+// SendLog records the entry, then returns the injected error if one is set.
+func (s *mockLogCollectorServer) SendLog(_ context.Context, in *log.LogEntry) (*emptypb.Empty, error) {
+	s.sendLogCalled, s.lastLogEntry = true, in
+	if err := s.sendLogErr; err != nil {
+		return nil, err
+	}
+	return &emptypb.Empty{}, nil
+}
+
+// SendEvent records the entry, then returns the injected error if one is set.
+func (s *mockLogCollectorServer) SendEvent(_ context.Context, in *log.EventEntry) (*emptypb.Empty, error) {
+	s.sendEventCalled, s.lastEventEntry = true, in
+	if err := s.sendEventErr; err != nil {
+		return nil, err
+	}
+	return &emptypb.Empty{}, nil
+}
+
+// TestNewClient tests creating a new log client.
+func TestNewClient(t *testing.T) {
+	sock := filepath.Join(t.TempDir(), "log-test.sock")
+
+	lis, err := net.Listen("unix", sock)
+	require.NoError(t, err)
+	defer lis.Close()
+
+	// Serve the mock on the socket for the duration of the test.
+	srv := grpc.NewServer()
+	log.RegisterLogCollectorServer(srv, &mockLogCollectorServer{})
+	go func() { _ = srv.Serve(lis) }()
+	defer srv.Stop()
+
+	time.Sleep(100 * time.Millisecond)
+
+	cl, err := NewClient(sock)
+	require.NoError(t, err)
+	require.NotNil(t, cl)
+
+	assert.NoError(t, cl.Close())
+}
+
+// TestClientSendLog tests sending a log entry.
+func TestClientSendLog(t *testing.T) {
+	sock := filepath.Join(t.TempDir(), "log-sendlog.sock")
+
+	lis, err := net.Listen("unix", sock)
+	require.NoError(t, err)
+	defer lis.Close()
+
+	// Serve the mock on the socket for the duration of the test.
+	mock := &mockLogCollectorServer{}
+	srv := grpc.NewServer()
+	log.RegisterLogCollectorServer(srv, mock)
+	go func() { _ = srv.Serve(lis) }()
+	defer srv.Stop()
+
+	time.Sleep(100 * time.Millisecond)
+
+	cl, err := NewClient(sock)
+	require.NoError(t, err)
+	defer cl.Close()
+
+	// No timestamp is set here, so the client is expected to supply one.
+	err = cl.SendLog(context.Background(), &log.LogEntry{
+		Level:         "INFO",
+		Message:       "test log message",
+		ComputationId: "test-computation",
+	})
+	require.NoError(t, err)
+
+	assert.True(t, mock.sendLogCalled)
+	got := mock.lastLogEntry
+	assert.Equal(t, "INFO", got.Level)
+	assert.Equal(t, "test log message", got.Message)
+	assert.Equal(t, "test-computation", got.ComputationId)
+	assert.NotNil(t, got.Timestamp)
+}
+
+// TestClientSendLogWithTimestamp tests sending a log entry with existing timestamp.
+func TestClientSendLogWithTimestamp(t *testing.T) {
+	sock := filepath.Join(t.TempDir(), "log-timestamp.sock")
+
+	lis, err := net.Listen("unix", sock)
+	require.NoError(t, err)
+	defer lis.Close()
+
+	// Serve the mock on the socket for the duration of the test.
+	mock := &mockLogCollectorServer{}
+	srv := grpc.NewServer()
+	log.RegisterLogCollectorServer(srv, mock)
+	go func() { _ = srv.Serve(lis) }()
+	defer srv.Stop()
+
+	time.Sleep(100 * time.Millisecond)
+
+	cl, err := NewClient(sock)
+	require.NoError(t, err)
+	defer cl.Close()
+
+	// A caller-supplied timestamp must reach the server unchanged.
+	customTime := time.Date(2024, 1, 1, 12, 0, 0, 0, time.UTC)
+	err = cl.SendLog(context.Background(), &log.LogEntry{
+		Level:         "ERROR",
+		Message:       "test error",
+		ComputationId: "test",
+		Timestamp:     timestamppb.New(customTime),
+	})
+	require.NoError(t, err)
+
+	assert.True(t, mock.sendLogCalled)
+	assert.Equal(t, customTime.Unix(), mock.lastLogEntry.Timestamp.AsTime().Unix())
+}
+
+// TestClientSendEvent tests sending an event entry.
+func TestClientSendEvent(t *testing.T) {
+	sock := filepath.Join(t.TempDir(), "log-sendevent.sock")
+
+	lis, err := net.Listen("unix", sock)
+	require.NoError(t, err)
+	defer lis.Close()
+
+	// Serve the mock on the socket for the duration of the test.
+	mock := &mockLogCollectorServer{}
+	srv := grpc.NewServer()
+	log.RegisterLogCollectorServer(srv, mock)
+	go func() { _ = srv.Serve(lis) }()
+	defer srv.Stop()
+
+	time.Sleep(100 * time.Millisecond)
+
+	cl, err := NewClient(sock)
+	require.NoError(t, err)
+	defer cl.Close()
+
+	// No timestamp is set here, so the client is expected to supply one.
+	err = cl.SendEvent(context.Background(), &log.EventEntry{
+		EventType:     "computation.started",
+		ComputationId: "test-computation",
+		Originator:    "agent",
+		Status:        "started",
+	})
+	require.NoError(t, err)
+
+	assert.True(t, mock.sendEventCalled)
+	got := mock.lastEventEntry
+	assert.Equal(t, "computation.started", got.EventType)
+	assert.Equal(t, "agent", got.Originator)
+	assert.NotNil(t, got.Timestamp)
+}
+
+// TestClientSendEventWithTimestamp tests sending an event with existing timestamp.
+func TestClientSendEventWithTimestamp(t *testing.T) {
+	sock := filepath.Join(t.TempDir(), "log-event-timestamp.sock")
+
+	lis, err := net.Listen("unix", sock)
+	require.NoError(t, err)
+	defer lis.Close()
+
+	// Serve the mock on the socket for the duration of the test.
+	mock := &mockLogCollectorServer{}
+	srv := grpc.NewServer()
+	log.RegisterLogCollectorServer(srv, mock)
+	go func() { _ = srv.Serve(lis) }()
+	defer srv.Stop()
+
+	time.Sleep(100 * time.Millisecond)
+
+	cl, err := NewClient(sock)
+	require.NoError(t, err)
+	defer cl.Close()
+
+	// A caller-supplied timestamp must reach the server unchanged.
+	customTime := time.Date(2024, 6, 15, 10, 30, 0, 0, time.UTC)
+	err = cl.SendEvent(context.Background(), &log.EventEntry{
+		EventType:     "test.event",
+		ComputationId: "test",
+		Timestamp:     timestamppb.New(customTime),
+	})
+	require.NoError(t, err)
+
+	assert.True(t, mock.sendEventCalled)
+	assert.Equal(t, customTime.Unix(), mock.lastEventEntry.Timestamp.AsTime().Unix())
+}
+
+// TestClientSendLogWithCanceledContext tests SendLog with canceled context.
+func TestClientSendLogWithCanceledContext(t *testing.T) {
+	sock := filepath.Join(t.TempDir(), "log-cancel.sock")
+
+	lis, err := net.Listen("unix", sock)
+	require.NoError(t, err)
+	defer lis.Close()
+
+	// Serve the mock on the socket for the duration of the test.
+	srv := grpc.NewServer()
+	log.RegisterLogCollectorServer(srv, &mockLogCollectorServer{})
+	go func() { _ = srv.Serve(lis) }()
+	defer srv.Stop()
+
+	time.Sleep(100 * time.Millisecond)
+
+	cl, err := NewClient(sock)
+	require.NoError(t, err)
+	defer cl.Close()
+
+	// Cancel before the call so the RPC is issued with a dead context.
+	ctx, cancel := context.WithCancel(context.Background())
+	cancel()
+
+	err = cl.SendLog(ctx, &log.LogEntry{
+		Level:   "INFO",
+		Message: "test",
+	})
+	assert.Error(t, err)
+	assert.Contains(t, err.Error(), "context canceled")
+}
+
+// TestClientClose tests closing the client.
+func TestClientClose(t *testing.T) {
+	sock := filepath.Join(t.TempDir(), "log-close.sock")
+
+	lis, err := net.Listen("unix", sock)
+	require.NoError(t, err)
+	defer lis.Close()
+
+	// Serve the mock on the socket for the duration of the test.
+	srv := grpc.NewServer()
+	log.RegisterLogCollectorServer(srv, &mockLogCollectorServer{})
+	go func() { _ = srv.Serve(lis) }()
+	defer srv.Stop()
+
+	time.Sleep(100 * time.Millisecond)
+
+	cl, err := NewClient(sock)
+	require.NoError(t, err)
+
+	assert.NoError(t, cl.Close())
+}
+
+// TestClientOperationsAfterClose tests operations after closing.
+func TestClientOperationsAfterClose(t *testing.T) {
+	sock := filepath.Join(t.TempDir(), "log-after-close.sock")
+
+	lis, err := net.Listen("unix", sock)
+	require.NoError(t, err)
+	defer lis.Close()
+
+	// Serve the mock on the socket for the duration of the test.
+	srv := grpc.NewServer()
+	log.RegisterLogCollectorServer(srv, &mockLogCollectorServer{})
+	go func() { _ = srv.Serve(lis) }()
+	defer srv.Stop()
+
+	time.Sleep(100 * time.Millisecond)
+
+	cl, err := NewClient(sock)
+	require.NoError(t, err)
+	require.NoError(t, cl.Close())
+
+	// Any RPC issued on the closed client must report an error.
+	err = cl.SendLog(context.Background(), &log.LogEntry{
+		Level:   "INFO",
+		Message: "test",
+	})
+	assert.Error(t, err)
+}
diff --git a/pkg/clients/grpc/runner/client.go b/pkg/clients/grpc/runner/client.go
new file mode 100644
index 00000000..8c9d3106
--- /dev/null
+++ b/pkg/clients/grpc/runner/client.go
@@ -0,0 +1,64 @@
+// Copyright (c) Ultraviolet
+// SPDX-License-Identifier: Apache-2.0
+package runner
+
+import (
+	"context"
+	"time"
+
+	pb "github.com/ultravioletrs/cocos/agent/runner"
+	"google.golang.org/grpc"
+	"google.golang.org/grpc/credentials/insecure"
+	"google.golang.org/protobuf/types/known/emptypb"
+)
+
+// Client is a handle to the computation runner gRPC service.
+type Client interface {
+	// Run forwards req to the runner's Run RPC, bounded only by ctx.
+	Run(ctx context.Context, req *pb.RunRequest) (*pb.RunResponse, error)
+	// Stop forwards req to the runner's Stop RPC with a 10 second cap.
+	Stop(ctx context.Context, req *pb.StopRequest) (*emptypb.Empty, error)
+	// Close closes the underlying gRPC connection.
+	Close() error
+}
+
+// client implements Client on top of a single gRPC connection.
+type client struct {
+	conn   *grpc.ClientConn
+	client pb.ComputationRunnerClient
+}
+
+// NewClient returns a Client bound to the Unix domain socket at socketPath,
+// using plaintext (insecure) transport credentials. The target is built as
+// "unix:<path>", which gRPC accepts for absolute and relative paths alike;
+// "unix://<path>" is only valid for absolute paths.
+func NewClient(socketPath string) (Client, error) {
+	conn, err := grpc.NewClient("unix:"+socketPath, grpc.WithTransportCredentials(insecure.NewCredentials()))
+	if err != nil {
+		return nil, err
+	}
+
+	return &client{
+		conn:   conn,
+		client: pb.NewComputationRunnerClient(conn),
+	}, nil
+}
+
+// Close closes the underlying gRPC connection.
+func (c *client) Close() error {
+	return c.conn.Close()
+}
+
+// Run forwards req to the runner's Run RPC. No client-side timeout is applied,
+// since a run may take a long time; cancel ctx to abort.
+func (c *client) Run(ctx context.Context, req *pb.RunRequest) (*pb.RunResponse, error) {
+	return c.client.Run(ctx, req)
+}
+
+// Stop forwards req to the runner's Stop RPC, waiting at most 10 seconds.
+func (c *client) Stop(ctx context.Context, req *pb.StopRequest) (*emptypb.Empty, error) {
+	ctx, cancel := context.WithTimeout(ctx, 10*time.Second)
+	defer cancel()
+
+	return c.client.Stop(ctx, req)
+}
diff --git a/pkg/clients/grpc/runner/client_test.go b/pkg/clients/grpc/runner/client_test.go
new file mode 100644
index 00000000..9eb44971
--- /dev/null
+++ b/pkg/clients/grpc/runner/client_test.go
@@ -0,0 +1,349 @@
+// Copyright (c) Ultraviolet
+// SPDX-License-Identifier: Apache-2.0
+package runner
+
+import (
+	"context"
+	"net"
+	"os"
+	"path/filepath"
+	"testing"
+	"time"
+
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
+	pb "github.com/ultravioletrs/cocos/agent/runner"
+	"google.golang.org/grpc"
+	"google.golang.org/protobuf/types/known/emptypb"
+)
+
+// mockComputationRunnerServer is a mock implementation of the ComputationRunnerServer.
+type mockComputationRunnerServer struct {
+	pb.UnimplementedComputationRunnerServer
+
+	// Set to true once the corresponding RPC has been invoked.
+	runCalled  bool
+	stopCalled bool
+
+	// Errors to inject; nil means the call succeeds.
+	runErr  error
+	stopErr error
+}
+
+// Run marks the call, then returns the injected error if one is set, or a
+// response echoing the request's computation ID with an empty error string.
+func (s *mockComputationRunnerServer) Run(_ context.Context, in *pb.RunRequest) (*pb.RunResponse, error) {
+	s.runCalled = true
+	if err := s.runErr; err != nil {
+		return nil, err
+	}
+
+	resp := &pb.RunResponse{ComputationId: in.ComputationId, Error: ""}
+	return resp, nil
+}
+
+// Stop marks the call, then returns the injected error if one is set.
+func (s *mockComputationRunnerServer) Stop(_ context.Context, _ *pb.StopRequest) (*emptypb.Empty, error) {
+	s.stopCalled = true
+	if err := s.stopErr; err != nil {
+		return nil, err
+	}
+	return &emptypb.Empty{}, nil
+}
+
+// TestNewClient tests creating a new gRPC client.
+func TestNewClient(t *testing.T) {
+	// The socket lives in a per-test temporary directory.
+	sock := filepath.Join(t.TempDir(), "test.sock")
+
+	// Bring up a mock gRPC server on that socket.
+	lis, err := net.Listen("unix", sock)
+	require.NoError(t, err)
+	defer lis.Close()
+
+	srv := grpc.NewServer()
+	pb.RegisterComputationRunnerServer(srv, &mockComputationRunnerServer{})
+	go func() { _ = srv.Serve(lis) }()
+	defer srv.Stop()
+
+	// Allow the server a moment to start serving.
+	time.Sleep(100 * time.Millisecond)
+
+	// Build the client.
+	cl, err := NewClient(sock)
+	require.NoError(t, err)
+	require.NotNil(t, cl)
+
+	// Tear it down again.
+	assert.NoError(t, cl.Close())
+}
+
+// TestNewClientInvalidSocket tests creating a client with invalid socket path.
+func TestNewClientInvalidSocket(t *testing.T) {
+	// A path inside a fresh per-test temp dir is guaranteed not to exist, without
+	// relying on a hard-coded /tmp or on the clock for uniqueness.
+	socketPath := filepath.Join(t.TempDir(), "nonexistent.sock")
+
+	// Create client - this should succeed as grpc.NewClient is lazy
+	client, err := NewClient(socketPath)
+	require.NoError(t, err)
+	require.NotNil(t, client)
+
+	// Close should work even if never connected
+	err = client.Close()
+	assert.NoError(t, err)
+}
+
+// TestClientRun tests the Run method.
+func TestClientRun(t *testing.T) {
+	tmpDir := t.TempDir()
+	socketPath := filepath.Join(tmpDir, "test-run.sock")
+
+	listener, err := net.Listen("unix", socketPath)
+	require.NoError(t, err)
+	defer listener.Close()
+
+	grpcServer := grpc.NewServer()
+	mockServer := &mockComputationRunnerServer{}
+	pb.RegisterComputationRunnerServer(grpcServer, mockServer)
+
+	go func() {
+		_ = grpcServer.Serve(listener)
+	}()
+	defer grpcServer.Stop()
+
+	time.Sleep(100 * time.Millisecond)
+
+	client, err := NewClient(socketPath)
+	require.NoError(t, err)
+	defer client.Close()
+
+	// Call Run and check that the mock echoes the computation ID back.
+	ctx := context.Background()
+	req := &pb.RunRequest{
+		ComputationId: "test-computation",
+	}
+
+	resp, err := client.Run(ctx, req)
+	require.NoError(t, err)
+	assert.NotNil(t, resp)
+	assert.Equal(t, "test-computation", resp.ComputationId)
+	assert.True(t, mockServer.runCalled)
+}
+
+// TestClientRunWithCanceledContext tests Run with a canceled context.
+func TestClientRunWithCanceledContext(t *testing.T) {
+	tmpDir := t.TempDir()
+	socketPath := filepath.Join(tmpDir, "test-run-cancel.sock")
+
+	// The socket accepts connections as soon as net.Listen returns, so the
+	// test does not need to sleep before dialing.
+	listener, err := net.Listen("unix", socketPath)
+	require.NoError(t, err)
+	defer listener.Close()
+
+	grpcServer := grpc.NewServer()
+	mockServer := &mockComputationRunnerServer{}
+	pb.RegisterComputationRunnerServer(grpcServer, mockServer)
+
+	go func() {
+		// Serve returns an error once the server is stopped; that is expected here.
+		_ = grpcServer.Serve(listener)
+	}()
+	defer grpcServer.Stop()
+
+	client, err := NewClient(socketPath)
+	require.NoError(t, err)
+	defer client.Close()
+
+	// Create a canceled context
+	ctx, cancel := context.WithCancel(context.Background())
+	cancel()
+
+	req := &pb.RunRequest{
+		ComputationId: "test-computation",
+	}
+
+	_, err = client.Run(ctx, req)
+	// require (not assert) so that a nil error stops the test here instead of
+	// panicking on err.Error() below.
+	require.Error(t, err)
+	assert.Contains(t, err.Error(), "context canceled")
+}
+
+// TestClientStop tests the Stop method.
+func TestClientStop(t *testing.T) {
+	tmpDir := t.TempDir()
+	socketPath := filepath.Join(tmpDir, "test-stop.sock")
+
+	listener, err := net.Listen("unix", socketPath)
+	require.NoError(t, err)
+	defer listener.Close()
+
+	grpcServer := grpc.NewServer()
+	mockServer := &mockComputationRunnerServer{}
+	pb.RegisterComputationRunnerServer(grpcServer, mockServer)
+
+	go func() {
+		// Serve returns an error once the server is stopped; that is expected here.
+		_ = grpcServer.Serve(listener)
+	}()
+	defer grpcServer.Stop()
+
+	client, err := NewClient(socketPath)
+	require.NoError(t, err)
+	defer client.Close()
+
+	// Call Stop
+	ctx := context.Background()
+	req := &pb.StopRequest{
+		ComputationId: "test-computation",
+	}
+
+	resp, err := client.Stop(ctx, req)
+	require.NoError(t, err)
+	assert.NotNil(t, resp)
+	assert.True(t, mockServer.stopCalled)
+}
+
+// TestClientStopWithTimeout tests Stop with context timeout.
+func TestClientStopWithTimeout(t *testing.T) {
+	tmpDir := t.TempDir()
+	socketPath := filepath.Join(tmpDir, "test-stop-timeout.sock")
+
+	// The socket accepts connections as soon as net.Listen returns, so the
+	// test does not need to sleep before dialing.
+	listener, err := net.Listen("unix", socketPath)
+	require.NoError(t, err)
+	defer listener.Close()
+
+	grpcServer := grpc.NewServer()
+	mockServer := &mockComputationRunnerServer{}
+	pb.RegisterComputationRunnerServer(grpcServer, mockServer)
+
+	go func() {
+		// Serve returns an error once the server is stopped; that is expected here.
+		_ = grpcServer.Serve(listener)
+	}()
+	defer grpcServer.Stop()
+
+	client, err := NewClient(socketPath)
+	require.NoError(t, err)
+	defer client.Close()
+
+	// Create a context with a deadline generous enough for the call to finish
+	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
+	defer cancel()
+
+	req := &pb.StopRequest{
+		ComputationId: "test-computation",
+	}
+
+	// Stop should complete within the timeout
+	resp, err := client.Stop(ctx, req)
+	require.NoError(t, err)
+	assert.NotNil(t, resp)
+	assert.True(t, mockServer.stopCalled)
+}
+
+// TestClientClose tests the Close method.
+func TestClientClose(t *testing.T) {
+	tmpDir := t.TempDir()
+	socketPath := filepath.Join(tmpDir, "test-close.sock")
+
+	listener, err := net.Listen("unix", socketPath)
+	require.NoError(t, err)
+	defer listener.Close()
+
+	grpcServer := grpc.NewServer()
+	mockServer := &mockComputationRunnerServer{}
+	pb.RegisterComputationRunnerServer(grpcServer, mockServer)
+
+	go func() {
+		// Serve returns an error once the server is stopped; that is expected here.
+		_ = grpcServer.Serve(listener)
+	}()
+	defer grpcServer.Stop()
+
+	client, err := NewClient(socketPath)
+	require.NoError(t, err)
+
+	// Close should succeed
+	err = client.Close()
+	assert.NoError(t, err)
+}
+
+// TestClientOperationsAfterClose tests that operations fail gracefully after close.
+func TestClientOperationsAfterClose(t *testing.T) { + tmpDir := t.TempDir() + socketPath := filepath.Join(tmpDir, "test-after-close.sock") + + listener, err := net.Listen("unix", socketPath) + require.NoError(t, err) + defer listener.Close() + + grpcServer := grpc.NewServer() + mockServer := &mockComputationRunnerServer{} + pb.RegisterComputationRunnerServer(grpcServer, mockServer) + + go func() { + _ = grpcServer.Serve(listener) + }() + defer grpcServer.Stop() + + time.Sleep(100 * time.Millisecond) + + client, err := NewClient(socketPath) + require.NoError(t, err) + + // Close the client + err = client.Close() + require.NoError(t, err) + + // Try to use the client after closing + ctx := context.Background() + runReq := &pb.RunRequest{ComputationId: "test"} + + // This should fail because connection is closed + _, err = client.Run(ctx, runReq) + assert.Error(t, err) +} + +// TestNewClientWithRelativePath tests creating client with relative socket path. +func TestNewClientWithRelativePath(t *testing.T) { + // Create temp directory + tmpDir := t.TempDir() + + // Change to temp directory + oldWd, err := os.Getwd() + require.NoError(t, err) + defer func() { + err := os.Chdir(oldWd) + require.NoError(t, err) + }() + + err = os.Chdir(tmpDir) + require.NoError(t, err) + + socketPath := "relative-test.sock" + + listener, err := net.Listen("unix", socketPath) + require.NoError(t, err) + defer listener.Close() + + grpcServer := grpc.NewServer() + mockServer := &mockComputationRunnerServer{} + pb.RegisterComputationRunnerServer(grpcServer, mockServer) + + go func() { + _ = grpcServer.Serve(listener) + }() + defer grpcServer.Stop() + + time.Sleep(100 * time.Millisecond) + + // Create client with relative path + client, err := NewClient(socketPath) + require.NoError(t, err) + require.NotNil(t, client) + + err = client.Close() + assert.NoError(t, err) +} diff --git a/pkg/clients/grpc/runner/mocks/client.go b/pkg/clients/grpc/runner/mocks/client.go new file mode 100644 index 
00000000..495bbdfe --- /dev/null +++ b/pkg/clients/grpc/runner/mocks/client.go @@ -0,0 +1,223 @@ +// Copyright (c) Ultraviolet +// SPDX-License-Identifier: Apache-2.0 + +// Code generated by mockery; DO NOT EDIT. +// github.com/vektra/mockery +// template: testify + +package mocks + +import ( + "context" + + mock "github.com/stretchr/testify/mock" + "github.com/ultravioletrs/cocos/agent/runner" + "google.golang.org/protobuf/types/known/emptypb" +) + +// NewClient creates a new instance of Client. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewClient(t interface { + mock.TestingT + Cleanup(func()) +}) *Client { + mock := &Client{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} + +// Client is an autogenerated mock type for the Client type +type Client struct { + mock.Mock +} + +type Client_Expecter struct { + mock *mock.Mock +} + +func (_m *Client) EXPECT() *Client_Expecter { + return &Client_Expecter{mock: &_m.Mock} +} + +// Close provides a mock function for the type Client +func (_mock *Client) Close() error { + ret := _mock.Called() + + if len(ret) == 0 { + panic("no return value specified for Close") + } + + var r0 error + if returnFunc, ok := ret.Get(0).(func() error); ok { + r0 = returnFunc() + } else { + r0 = ret.Error(0) + } + return r0 +} + +// Client_Close_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Close' +type Client_Close_Call struct { + *mock.Call +} + +// Close is a helper method to define mock.On call +func (_e *Client_Expecter) Close() *Client_Close_Call { + return &Client_Close_Call{Call: _e.mock.On("Close")} +} + +func (_c *Client_Close_Call) Run(run func()) *Client_Close_Call { + _c.Call.Run(func(args mock.Arguments) { + run() + }) + return _c +} + +func (_c *Client_Close_Call) Return(err error) *Client_Close_Call { + 
_c.Call.Return(err) + return _c +} + +func (_c *Client_Close_Call) RunAndReturn(run func() error) *Client_Close_Call { + _c.Call.Return(run) + return _c +} + +// Run provides a mock function for the type Client +func (_mock *Client) Run(ctx context.Context, req *runner.RunRequest) (*runner.RunResponse, error) { + ret := _mock.Called(ctx, req) + + if len(ret) == 0 { + panic("no return value specified for Run") + } + + var r0 *runner.RunResponse + var r1 error + if returnFunc, ok := ret.Get(0).(func(context.Context, *runner.RunRequest) (*runner.RunResponse, error)); ok { + return returnFunc(ctx, req) + } + if returnFunc, ok := ret.Get(0).(func(context.Context, *runner.RunRequest) *runner.RunResponse); ok { + r0 = returnFunc(ctx, req) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*runner.RunResponse) + } + } + if returnFunc, ok := ret.Get(1).(func(context.Context, *runner.RunRequest) error); ok { + r1 = returnFunc(ctx, req) + } else { + r1 = ret.Error(1) + } + return r0, r1 +} + +// Client_Run_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Run' +type Client_Run_Call struct { + *mock.Call +} + +// Run is a helper method to define mock.On call +// - ctx context.Context +// - req *runner.RunRequest +func (_e *Client_Expecter) Run(ctx interface{}, req interface{}) *Client_Run_Call { + return &Client_Run_Call{Call: _e.mock.On("Run", ctx, req)} +} + +func (_c *Client_Run_Call) Run(run func(ctx context.Context, req *runner.RunRequest)) *Client_Run_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 *runner.RunRequest + if args[1] != nil { + arg1 = args[1].(*runner.RunRequest) + } + run( + arg0, + arg1, + ) + }) + return _c +} + +func (_c *Client_Run_Call) Return(runResponse *runner.RunResponse, err error) *Client_Run_Call { + _c.Call.Return(runResponse, err) + return _c +} + +func (_c *Client_Run_Call) RunAndReturn(run func(ctx 
context.Context, req *runner.RunRequest) (*runner.RunResponse, error)) *Client_Run_Call { + _c.Call.Return(run) + return _c +} + +// Stop provides a mock function for the type Client +func (_mock *Client) Stop(ctx context.Context, req *runner.StopRequest) (*emptypb.Empty, error) { + ret := _mock.Called(ctx, req) + + if len(ret) == 0 { + panic("no return value specified for Stop") + } + + var r0 *emptypb.Empty + var r1 error + if returnFunc, ok := ret.Get(0).(func(context.Context, *runner.StopRequest) (*emptypb.Empty, error)); ok { + return returnFunc(ctx, req) + } + if returnFunc, ok := ret.Get(0).(func(context.Context, *runner.StopRequest) *emptypb.Empty); ok { + r0 = returnFunc(ctx, req) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*emptypb.Empty) + } + } + if returnFunc, ok := ret.Get(1).(func(context.Context, *runner.StopRequest) error); ok { + r1 = returnFunc(ctx, req) + } else { + r1 = ret.Error(1) + } + return r0, r1 +} + +// Client_Stop_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Stop' +type Client_Stop_Call struct { + *mock.Call +} + +// Stop is a helper method to define mock.On call +// - ctx context.Context +// - req *runner.StopRequest +func (_e *Client_Expecter) Stop(ctx interface{}, req interface{}) *Client_Stop_Call { + return &Client_Stop_Call{Call: _e.mock.On("Stop", ctx, req)} +} + +func (_c *Client_Stop_Call) Run(run func(ctx context.Context, req *runner.StopRequest)) *Client_Stop_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 *runner.StopRequest + if args[1] != nil { + arg1 = args[1].(*runner.StopRequest) + } + run( + arg0, + arg1, + ) + }) + return _c +} + +func (_c *Client_Stop_Call) Return(empty *emptypb.Empty, err error) *Client_Stop_Call { + _c.Call.Return(empty, err) + return _c +} + +func (_c *Client_Stop_Call) RunAndReturn(run func(ctx context.Context, req *runner.StopRequest) 
(*emptypb.Empty, error)) *Client_Stop_Call { + _c.Call.Return(run) + return _c +} diff --git a/pkg/egress/proxy.go b/pkg/egress/proxy.go new file mode 100644 index 00000000..f7d2e5fe --- /dev/null +++ b/pkg/egress/proxy.go @@ -0,0 +1,248 @@ +// Copyright (c) Ultraviolet +// SPDX-License-Identifier: Apache-2.0 +package egress + +import ( + "context" + "crypto/tls" + "io" + "log/slog" + "net" + "net/http" + "net/http/httputil" + "net/url" + "sync" + "time" + + "golang.org/x/net/http2" +) + +// Proxy is an egress proxy server. +type Proxy struct { + logger *slog.Logger + server *http.Server + addr string + transport *http.Transport +} + +// NewProxy creates a new egress proxy. +func NewProxy(logger *slog.Logger, addr string) *Proxy { + // Create HTTP/2 capable transport + transport := &http.Transport{ + TLSClientConfig: &tls.Config{ + InsecureSkipVerify: false, + }, + ForceAttemptHTTP2: true, + } + // Enable HTTP/2 + if err := http2.ConfigureTransport(transport); err != nil { + logger.Warn("Failed to configure HTTP/2 transport", "error", err) + } + + p := &Proxy{ + logger: logger, + addr: addr, + transport: transport, + } + p.server = &http.Server{ + Addr: addr, + Handler: http.HandlerFunc(p.handle), + } + return p +} + +// Start starts the proxy server. +func (p *Proxy) Start() error { + p.logger.Info("Starting egress proxy", "addr", p.addr) + return p.server.ListenAndServe() +} + +// Stop stops the proxy server. 
+func (p *Proxy) Stop(ctx context.Context) error {
+	p.logger.Info("Stopping egress proxy")
+	return p.server.Shutdown(ctx)
+}
+
+// handle routes a request to the CONNECT tunnel handler, the HTTP/2 reverse
+// proxy, or the plain HTTP forwarder, in that order of precedence.
+func (p *Proxy) handle(w http.ResponseWriter, r *http.Request) {
+	switch {
+	case r.Method == http.MethodConnect:
+		p.handleConnect(w, r)
+	case r.ProtoMajor == 2:
+		p.handleHTTP2(w, r)
+	default:
+		p.handleHTTP(w, r)
+	}
+}
+
+// handleConnect serves a CONNECT request: it dials r.Host over TCP, hijacks
+// the client connection, answers "200 Connection Established" and then relays
+// bytes in both directions until both sides are done.
+//
+// A dial failure is reported to the client as 503; a ResponseWriter that
+// cannot be hijacked is reported as 500.
+func (p *Proxy) handleConnect(w http.ResponseWriter, r *http.Request) {
+	host := r.Host
+	p.logger.Info("CONNECT request received", "host", host)
+
+	// nolint:godox // TODO: Check allowlist here - allowlist implementation deferred
+
+	p.logger.Debug("Dialing destination", "host", host)
+	destConn, err := net.DialTimeout("tcp", host, 10*time.Second)
+	if err != nil {
+		p.logger.Error("Failed to dial destination", "host", host, "error", err)
+		http.Error(w, err.Error(), http.StatusServiceUnavailable)
+		return
+	}
+	defer destConn.Close()
+	p.logger.Info("Successfully connected to destination", "host", host)
+
+	p.logger.Debug("Hijacking client connection")
+	hijacker, ok := w.(http.Hijacker)
+	if !ok {
+		p.logger.Error("Hijacking not supported")
+		http.Error(w, "Hijacking not supported", http.StatusInternalServerError)
+		return
+	}
+	// NOTE(review): the buffered reader returned by Hijack is discarded; any
+	// bytes the client sent before receiving the 200 would be lost.
+	clientConn, _, err := hijacker.Hijack()
+	if err != nil {
+		p.logger.Error("Failed to hijack connection", "error", err)
+		return
+	}
+	defer clientConn.Close()
+	p.logger.Info("Successfully hijacked client connection", "host", host)
+
+	// Send 200 Connection Established response
+	p.logger.Debug("Sending 200 Connection Established")
+	_, err = clientConn.Write([]byte("HTTP/1.1 200 Connection Established\r\n\r\n"))
+	if err != nil {
+		p.logger.Error("Failed to send CONNECT response", "error", err)
+		return
+	}
+	p.logger.Info("Starting bidirectional pipe", "host", host)
+
+	p.pipe(clientConn, destConn)
+	p.logger.Info("Pipe completed", "host", host)
+}
+
+// handleHTTP forwards a plain (non-CONNECT, non-HTTP/2) request to the URL it
+// names and streams the upstream response back to the client. Redirects are
+// returned to the client rather than followed, and an upstream failure is
+// reported as 502.
+func (p *Proxy) handleHTTP(w http.ResponseWriter, r *http.Request) {
+	p.logger.Info("HTTP request", "method", r.Method, "url", r.URL.String())
+
+	// nolint:godox // TODO: Check allowlist here - allowlist implementation deferred
+
+	r.RequestURI = "" // RequestURI must be empty for Client.Do
+
+	// Remove hop-by-hop headers
+	delHopHeaders(r.Header)
+
+	// Send through the proxy's own transport rather than the default one:
+	// the default transport honours HTTP_PROXY/HTTPS_PROXY, which could route
+	// the egress proxy's outbound traffic back into itself.
+	client := &http.Client{
+		Transport: p.transport,
+		CheckRedirect: func(req *http.Request, via []*http.Request) error {
+			return http.ErrUseLastResponse
+		},
+	}
+
+	resp, err := client.Do(r)
+	if err != nil {
+		p.logger.Error("Failed to execute request", "error", err)
+		http.Error(w, err.Error(), http.StatusBadGateway)
+		return
+	}
+	defer resp.Body.Close()
+
+	// Hop-by-hop headers describe the upstream connection only and must not
+	// be relayed on the client connection.
+	delHopHeaders(resp.Header)
+
+	// Copy headers
+	copyHeader(w.Header(), resp.Header)
+	w.WriteHeader(resp.StatusCode)
+	if _, err := io.Copy(w, resp.Body); err != nil {
+		p.logger.Error("Failed to copy response body", "error", err)
+	}
+}
+
+// handleHTTP2 forwards an HTTP/2 request through a reverse proxy that uses the
+// proxy's transport. The target is the request's absolute URL when it has one,
+// otherwise "http://" + r.Host with the original path and query. Upstream
+// failures are reported as 502.
+func (p *Proxy) handleHTTP2(w http.ResponseWriter, r *http.Request) {
+	p.logger.Info("HTTP/2 request", "method", r.Method, "host", r.Host, "path", r.URL.Path)
+
+	// nolint:godox // TODO: Check allowlist here - allowlist implementation deferred
+
+	// Parse the target URL from the request
+	targetURL := &url.URL{
+		Scheme: "http",
+		Host:   r.Host,
+	}
+
+	// If the request has a full URL (absolute form), use it
+	if r.URL.IsAbs() {
+		targetURL = r.URL
+	}
+
+	// Create a reverse proxy with HTTP/2 transport
+	proxy := &httputil.ReverseProxy{
+		Director: func(req *http.Request) {
+			req.URL.Scheme = targetURL.Scheme
+			req.URL.Host = targetURL.Host
+			req.Host = targetURL.Host
+
+			// Preserve the original path and query
+			if !r.URL.IsAbs() {
+				req.URL.Path = r.URL.Path
+				req.URL.RawQuery = r.URL.RawQuery
+			}
+
+			// Remove hop-by-hop headers
+			delHopHeaders(req.Header)
+		},
+		Transport: p.transport,
+		ErrorHandler: func(w http.ResponseWriter, r *http.Request, err error) {
+			p.logger.Error("HTTP/2 proxy error", "error", err, "host", r.Host)
+			http.Error(w, err.Error(), http.StatusBadGateway)
+		},
+	}
+
+	proxy.ServeHTTP(w, r)
+}
+
+// pipe copies data between src and dst in both directions and returns once
+// both copies have finished. When the receiving side is a *net.TCPConn its
+// write half is closed after the copy so the peer observes end-of-stream.
+func (p *Proxy) pipe(src, dst net.Conn) {
+	var wg sync.WaitGroup
+	wg.Add(2)
+
+	go func() {
+		defer wg.Done()
+		n, err := io.Copy(dst, src)
+		p.logger.Debug("Pipe src->dst completed", "bytes", n, "error", err)
+		// Close write end of dst if possible, or just close it
+		if c, ok := dst.(*net.TCPConn); ok {
+			if err := c.CloseWrite(); err != nil {
+				p.logger.Debug("Failed to close write end of dst", "error", err)
+			}
+		}
+	}()
+
+	go func() {
+		defer wg.Done()
+		n, err := io.Copy(src, dst)
+		p.logger.Debug("Pipe dst->src completed", "bytes", n, "error", err)
+		if c, ok := src.(*net.TCPConn); ok {
+			if err := c.CloseWrite(); err != nil {
+				p.logger.Debug("Failed to close write end of src", "error", err)
+			}
+		}
+	}()
+
+	wg.Wait()
+}
+
+// copyHeader appends every value of every header in src to dst.
+func copyHeader(dst, src http.Header) {
+	for k, vv := range src {
+		for _, v := range vv {
+			dst.Add(k, v)
+		}
+	}
+}
+
+// delHopHeaders removes the hop-by-hop headers from header in place.
+func delHopHeaders(header http.Header) {
+	// Standard hop-by-hop headers
+	hopHeaders := []string{
+		"Connection",
+		"Keep-Alive",
+		"Proxy-Authenticate",
+		"Proxy-Authorization",
+		"Te",
+		"Trailer",  // the header name actually defined by HTTP
+		"Trailers", // historical misspelling, still stripped for compatibility
+		"Transfer-Encoding",
+		"Upgrade",
+	}
+	for _, h := range hopHeaders {
+		header.Del(h)
+	}
+}
diff --git a/pkg/egress/proxy_test.go b/pkg/egress/proxy_test.go
new file mode 100644
index 00000000..9dd0e8f1
--- /dev/null
+++ b/pkg/egress/proxy_test.go
@@ -0,0 +1,795 @@
+// Copyright (c) Ultraviolet
+// SPDX-License-Identifier: Apache-2.0
+package egress
+
+import (
+	"context"
+	"fmt"
+	"io"
+	"log/slog"
+	"net"
+	"net/http"
+	"net/http/httptest"
+	"os"
+	"testing"
+	"time"
+
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
+)
+
+func TestProxyHTTP(t *testing.T) {
+	// 1. Start a backend server
+	backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		w.WriteHeader(http.StatusOK)
+		if _, err := w.Write([]byte("backend response")); err != nil {
+			t.Logf("Failed to write response: %v", err)
+		}
+	}))
+	defer backend.Close()
+
+	// 2. Start Proxy
+	logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
+	proxy := NewProxy(logger, ":0")
+
+	// Listen on a random loopback port. The port accepts connections as soon
+	// as net.Listen returns, so no sleep is needed before the first request.
+	ln, err := net.Listen("tcp", "127.0.0.1:0")
+	require.NoError(t, err)
+	proxy.server.Addr = ln.Addr().String()
+
+	go func() {
+		if err := proxy.server.Serve(ln); err != nil && err != http.ErrServerClosed {
+			t.Logf("Proxy server error: %v", err)
+		}
+	}()
+	defer func() {
+		if err := proxy.Stop(context.Background()); err != nil {
+			t.Logf("Failed to stop proxy: %v", err)
+		}
+	}()
+
+	// 3. Make request via proxy.
+	// http.ProxyFromEnvironment never proxies loopback targets and reads the
+	// environment only once per process, so the proxy is configured
+	// explicitly. http.NewRequest is used only to parse the proxy address
+	// into a URL value.
+	proxyReq, err := http.NewRequest(http.MethodGet, fmt.Sprintf("http://%s", ln.Addr().String()), nil)
+	require.NoError(t, err)
+
+	client := &http.Client{
+		Transport: &http.Transport{
+			Proxy: http.ProxyURL(proxyReq.URL),
+		},
+	}
+
+	resp, err := client.Get(backend.URL)
+	require.NoError(t, err)
+	defer resp.Body.Close()
+
+	body, err := io.ReadAll(resp.Body)
+	require.NoError(t, err)
+	assert.Equal(t, "backend response", string(body))
+	assert.Equal(t, http.StatusOK, resp.StatusCode)
+}
+
+func TestProxyConnect(t *testing.T) {
+	// 1. Start a backend TLS server
+	backend := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		w.WriteHeader(http.StatusOK)
+		if _, err := w.Write([]byte("secure backend response")); err != nil {
+			t.Logf("Failed to write response: %v", err)
+		}
+	}))
+	defer backend.Close()
+
+	// 2. Start Proxy
+	logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
+	// Listen on a random loopback port
+	ln, err := net.Listen("tcp", "127.0.0.1:0")
+	require.NoError(t, err)
+
+	proxy := NewProxy(logger, ln.Addr().String())
+	proxy.server.Addr = ln.Addr().String()
+
+	go func() {
+		if err := proxy.server.Serve(ln); err != nil && err != http.ErrServerClosed {
+			t.Logf("Proxy server error: %v", err)
+		}
+	}()
+	defer func() {
+		if err := proxy.Stop(context.Background()); err != nil {
+			t.Logf("Failed to stop proxy: %v", err)
+		}
+	}()
+
+	// 3. Configure client to use proxy.
+	// The proxy is set explicitly because http.ProxyFromEnvironment skips
+	// loopback targets. http.NewRequest is used only to parse the proxy
+	// address into a URL value.
+	proxyReq, err := http.NewRequest(http.MethodGet, fmt.Sprintf("http://%s", ln.Addr().String()), nil)
+	require.NoError(t, err)
+
+	client := backend.Client() // This client trusts the test cert
+	transport, ok := client.Transport.(*http.Transport)
+	require.True(t, ok, "httptest client transport should be *http.Transport")
+	transport.Proxy = http.ProxyURL(proxyReq.URL)
+
+	resp, err := client.Get(backend.URL)
+	require.NoError(t, err)
+	defer resp.Body.Close()
+
+	body, err := io.ReadAll(resp.Body)
+	require.NoError(t, err)
+	assert.Equal(t, "secure backend response", string(body))
+}
+
+// TestProxyHTTP2 sends a request through the proxy from a client that has
+// HTTP/2 enabled. The target URL is cleartext, so the exchange is HTTP/1.1
+// unless HTTP/2 is explicitly configured; the test checks that such a client
+// still receives the backend response.
+func TestProxyHTTP2(t *testing.T) {
+	// 1. Start a backend server
+	backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		w.WriteHeader(http.StatusOK)
+		if _, err := w.Write([]byte("http2 response")); err != nil {
+			t.Logf("Failed to write response: %v", err)
+		}
+	}))
+	defer backend.Close()
+
+	// 2. Start Proxy
+	logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
+	ln, err := net.Listen("tcp", "127.0.0.1:0")
+	require.NoError(t, err)
+
+	proxy := NewProxy(logger, ln.Addr().String())
+	proxy.server.Addr = ln.Addr().String()
+
+	go func() {
+		if err := proxy.server.Serve(ln); err != nil && err != http.ErrServerClosed {
+			t.Logf("Proxy server error: %v", err)
+		}
+	}()
+	defer func() {
+		if err := proxy.Stop(context.Background()); err != nil {
+			t.Logf("Failed to stop proxy: %v", err)
+		}
+	}()
+
+	// 3. Make request via proxy. The proxy is set explicitly because
+	// http.ProxyFromEnvironment skips loopback targets; http.NewRequest is
+	// used only to parse the proxy address into a URL value.
+	proxyReq, err := http.NewRequest(http.MethodGet, fmt.Sprintf("http://%s", ln.Addr().String()), nil)
+	require.NoError(t, err)
+
+	client := &http.Client{
+		Transport: &http.Transport{
+			Proxy:             http.ProxyURL(proxyReq.URL),
+			ForceAttemptHTTP2: true,
+		},
+	}
+
+	// A failed request must fail the test instead of silently skipping the
+	// assertions.
+	resp, err := client.Get(backend.URL)
+	require.NoError(t, err)
+	defer resp.Body.Close()
+
+	body, err := io.ReadAll(resp.Body)
+	require.NoError(t, err)
+	assert.Equal(t, "http2 response", string(body))
+}
+
+// TestProxyHeaderHandling tests that headers are properly handled.
+func TestProxyHeaderHandling(t *testing.T) {
+	// Start a backend server that sets a response header and echoes one
+	// request header back in the body
+	backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		w.Header().Set("X-Custom-Header", "custom-value")
+		w.WriteHeader(http.StatusOK)
+		if _, err := w.Write([]byte(r.Header.Get("X-Request-Header"))); err != nil {
+			t.Logf("Failed to write response: %v", err)
+		}
+	}))
+	defer backend.Close()
+
+	// Start Proxy on a loopback port. The port accepts connections as soon as
+	// net.Listen returns, so no sleep is needed before the first request.
+	logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
+	ln, err := net.Listen("tcp", "127.0.0.1:0")
+	require.NoError(t, err)
+
+	proxy := NewProxy(logger, ln.Addr().String())
+	proxy.server.Addr = ln.Addr().String()
+
+	go func() {
+		if err := proxy.server.Serve(ln); err != nil && err != http.ErrServerClosed {
+			t.Logf("Proxy server error: %v", err)
+		}
+	}()
+	defer func() {
+		if err := proxy.Stop(context.Background()); err != nil {
+			t.Logf("Failed to stop proxy: %v", err)
+		}
+	}()
+
+	// http.ProxyFromEnvironment never proxies loopback targets and reads the
+	// environment only once per process, so the proxy is configured
+	// explicitly. http.NewRequest is used only to parse the proxy address
+	// into a URL value.
+	proxyReq, err := http.NewRequest(http.MethodGet, fmt.Sprintf("http://%s", ln.Addr().String()), nil)
+	require.NoError(t, err)
+
+	// Make request with custom headers
+	req, err := http.NewRequest(http.MethodGet, backend.URL, nil)
+	require.NoError(t, err)
+	req.Header.Set("X-Request-Header", "request-value")
+
+	client := &http.Client{
+		Transport: &http.Transport{
+			Proxy: http.ProxyURL(proxyReq.URL),
+		},
+	}
+
+	// A failed request must fail the test instead of silently skipping the
+	// assertions.
+	resp, err := client.Do(req)
+	require.NoError(t, err)
+	defer resp.Body.Close()
+
+	body, err := io.ReadAll(resp.Body)
+	require.NoError(t, err)
+
+	// The response header must reach the client and the request header must
+	// reach the backend.
+	assert.Equal(t, "custom-value", resp.Header.Get("X-Custom-Header"))
+	assert.Equal(t, "request-value", string(body))
+}
+
+// TestProxyWithDifferentMethods tests different HTTP methods.
+func TestProxyWithDifferentMethods(t *testing.T) {
+	// The backend echoes the request method in the response body.
+	backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		w.WriteHeader(http.StatusOK)
+		if _, err := w.Write([]byte(r.Method)); err != nil {
+			t.Logf("Failed to write response: %v", err)
+		}
+	}))
+	defer backend.Close()
+
+	// The loopback port accepts connections as soon as net.Listen returns, so
+	// no sleep is needed before the first request.
+	logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
+	ln, err := net.Listen("tcp", "127.0.0.1:0")
+	require.NoError(t, err)
+
+	proxy := NewProxy(logger, ln.Addr().String())
+	proxy.server.Addr = ln.Addr().String()
+
+	go func() {
+		if err := proxy.server.Serve(ln); err != nil && err != http.ErrServerClosed {
+			t.Logf("Proxy server error: %v", err)
+		}
+	}()
+	defer func() {
+		if err := proxy.Stop(context.Background()); err != nil {
+			t.Logf("Failed to stop proxy: %v", err)
+		}
+	}()
+
+	// http.ProxyFromEnvironment never proxies loopback targets and reads the
+	// environment only once per process, so the proxy is configured
+	// explicitly. http.NewRequest is used only to parse the proxy address
+	// into a URL value.
+	proxyReq, err := http.NewRequest(http.MethodGet, fmt.Sprintf("http://%s", ln.Addr().String()), nil)
+	require.NoError(t, err)
+
+	// One client is shared by all methods so connections can be reused.
+	client := &http.Client{
+		Transport: &http.Transport{
+			Proxy: http.ProxyURL(proxyReq.URL),
+		},
+	}
+
+	methods := []string{"GET", "POST", "PUT", "DELETE"}
+	for _, method := range methods {
+		// Each method runs in its own subtest so the deferred body close
+		// fires per request rather than at the end of the whole test.
+		t.Run(method, func(t *testing.T) {
+			req, err := http.NewRequest(method, backend.URL, nil)
+			require.NoError(t, err)
+
+			resp, err := client.Do(req)
+			require.NoError(t, err)
+			defer resp.Body.Close()
+
+			body, err := io.ReadAll(resp.Body)
+			require.NoError(t, err)
+			assert.Equal(t, method, string(body))
+		})
+	}
+}
+
+// TestProxyErrorHandling tests error handling in the proxy.
+func TestProxyErrorHandling(t *testing.T) {
+	// The loopback port accepts connections as soon as net.Listen returns, so
+	// no sleep is needed before the first request.
+	logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
+	ln, err := net.Listen("tcp", "127.0.0.1:0")
+	require.NoError(t, err)
+
+	proxy := NewProxy(logger, ln.Addr().String())
+	proxy.server.Addr = ln.Addr().String()
+
+	go func() {
+		if err := proxy.server.Serve(ln); err != nil && err != http.ErrServerClosed {
+			t.Logf("Proxy server error: %v", err)
+		}
+	}()
+	defer func() {
+		if err := proxy.Stop(context.Background()); err != nil {
+			t.Logf("Failed to stop proxy: %v", err)
+		}
+	}()
+
+	// http.ProxyFromEnvironment reads the environment only once per process,
+	// so the proxy is configured explicitly. http.NewRequest is used only to
+	// parse the proxy address into a URL value.
+	proxyReq, err := http.NewRequest(http.MethodGet, fmt.Sprintf("http://%s", ln.Addr().String()), nil)
+	require.NoError(t, err)
+
+	client := &http.Client{
+		Transport: &http.Transport{
+			Proxy: http.ProxyURL(proxyReq.URL),
+		},
+	}
+
+	// The backend cannot be reached (the port is out of range), so the proxy
+	// itself must answer, and it reports upstream failures as 502.
+	resp, err := client.Get("http://nonexistent.example.com:99999")
+	require.NoError(t, err)
+	defer resp.Body.Close()
+
+	assert.Equal(t, http.StatusBadGateway, resp.StatusCode)
+}
+
+// TestProxyWithLargeBody tests proxy with large response body.
+func TestProxyWithLargeBody(t *testing.T) {
+	largeBody := make([]byte, 1024*1024) // 1MB
+	for i := range largeBody {
+		largeBody[i] = byte(i % 256)
+	}
+
+	backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		w.WriteHeader(http.StatusOK)
+		if _, err := w.Write(largeBody); err != nil {
+			t.Logf("Failed to write response: %v", err)
+		}
+	}))
+	defer backend.Close()
+
+	// The loopback port accepts connections as soon as net.Listen returns, so
+	// no sleep is needed before the first request.
+	logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
+	ln, err := net.Listen("tcp", "127.0.0.1:0")
+	require.NoError(t, err)
+
+	proxy := NewProxy(logger, ln.Addr().String())
+	proxy.server.Addr = ln.Addr().String()
+
+	go func() {
+		if err := proxy.server.Serve(ln); err != nil && err != http.ErrServerClosed {
+			t.Logf("Proxy server error: %v", err)
+		}
+	}()
+	defer func() {
+		if err := proxy.Stop(context.Background()); err != nil {
+			t.Logf("Failed to stop proxy: %v", err)
+		}
+	}()
+
+	// http.ProxyFromEnvironment never proxies loopback targets and reads the
+	// environment only once per process, so the proxy is configured
+	// explicitly. http.NewRequest is used only to parse the proxy address
+	// into a URL value.
+	proxyReq, err := http.NewRequest(http.MethodGet, fmt.Sprintf("http://%s", ln.Addr().String()), nil)
+	require.NoError(t, err)
+
+	client := &http.Client{
+		Transport: &http.Transport{
+			Proxy: http.ProxyURL(proxyReq.URL),
+		},
+	}
+
+	resp, err := client.Get(backend.URL)
+	require.NoError(t, err)
+	defer resp.Body.Close()
+
+	body, err := io.ReadAll(resp.Body)
+	require.NoError(t, err)
+	require.Equal(t, len(largeBody), len(body))
+	// Compare content as well as length; a boolean check keeps the failure
+	// output small for a 1MB payload.
+	assert.True(t, string(largeBody) == string(body), "response body differs from what the backend sent")
+}
+
+// TestCopyHeader tests the copyHeader utility function.
+func TestCopyHeader(t *testing.T) {
+	src := http.Header{}
+	src.Add("X-Custom-Header", "value1")
+	src.Add("X-Custom-Header", "value2")
+	src.Add("Content-Type", "application/json")
+
+	dst := http.Header{}
+	copyHeader(dst, src)
+
+	assert.Equal(t, []string{"value1", "value2"}, dst["X-Custom-Header"])
+	assert.Equal(t, []string{"application/json"}, dst["Content-Type"])
+}
+
+// TestDelHopHeaders tests the delHopHeaders utility function.
+func TestDelHopHeaders(t *testing.T) {
+	// Hop-by-hop headers and the values they are seeded with.
+	hopByHop := [][2]string{
+		{"Connection", "keep-alive"},
+		{"Keep-Alive", "timeout=5"},
+		{"Proxy-Authenticate", "Basic"},
+		{"Proxy-Authorization", "Bearer token"},
+		{"Te", "trailers"},
+		{"Trailers", "X-Custom"},
+		{"Transfer-Encoding", "chunked"},
+		{"Upgrade", "websocket"},
+	}
+
+	header := http.Header{}
+	for _, kv := range hopByHop {
+		header.Set(kv[0], kv[1])
+	}
+	header.Set("X-Custom-Header", "should-remain")
+
+	delHopHeaders(header)
+
+	// Every hop-by-hop header must be gone.
+	for _, kv := range hopByHop {
+		assert.Empty(t, header.Get(kv[0]))
+	}
+
+	// Headers outside the hop-by-hop set are left untouched.
+	assert.Equal(t, "should-remain", header.Get("X-Custom-Header"))
+}
+
+// TestProxyConnectDialTimeout tests CONNECT with dial timeout.
+func TestProxyConnectDialTimeout(t *testing.T) {
+	// The loopback port accepts connections as soon as net.Listen returns, so
+	// no sleep is needed before the first request.
+	logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
+	ln, err := net.Listen("tcp", "127.0.0.1:0")
+	require.NoError(t, err)
+
+	proxy := NewProxy(logger, ln.Addr().String())
+	proxy.server.Addr = ln.Addr().String()
+
+	go func() {
+		if err := proxy.server.Serve(ln); err != nil && err != http.ErrServerClosed {
+			t.Logf("Proxy server error: %v", err)
+		}
+	}()
+	defer func() {
+		// The CONNECT handler may still be blocked dialing the unreachable
+		// destination, so shutdown is bounded rather than waiting for it.
+		ctx, cancel := context.WithTimeout(context.Background(), time.Second)
+		defer cancel()
+		if err := proxy.Stop(ctx); err != nil {
+			t.Logf("Failed to stop proxy: %v", err)
+		}
+	}()
+
+	// http.ProxyFromEnvironment reads the environment only once per process,
+	// so the proxy is configured explicitly. http.NewRequest is used only to
+	// parse the proxy address into a URL value.
+	proxyReq, err := http.NewRequest(http.MethodGet, fmt.Sprintf("http://%s", ln.Addr().String()), nil)
+	require.NoError(t, err)
+
+	// Create a client with very short timeout
+	client := &http.Client{
+		Transport: &http.Transport{
+			Proxy: http.ProxyURL(proxyReq.URL),
+		},
+		Timeout: 2 * time.Second,
+	}
+
+	// This should fail because 192.0.2.1 is a TEST-NET address (non-routable)
+	resp, err := client.Get("https://192.0.2.1:9999/test")
+	if err == nil {
+		// Not expected, but avoid leaking the body if it ever happens.
+		resp.Body.Close()
+	}
+	assert.Error(t, err)
+}
+
+// TestProxyHTTPWithRedirect tests HTTP proxy handling redirects.
+func TestProxyHTTPWithRedirect(t *testing.T) {
+	// Create a backend that redirects the first request and answers the next
+	redirectCount := 0
+	backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		if redirectCount == 0 {
+			redirectCount++
+			http.Redirect(w, r, "/redirected", http.StatusFound)
+			return
+		}
+		w.WriteHeader(http.StatusOK)
+		if _, err := w.Write([]byte("redirected response")); err != nil {
+			t.Logf("Failed to write response: %v", err)
+		}
+	}))
+	defer backend.Close()
+
+	// The loopback port accepts connections as soon as net.Listen returns, so
+	// no sleep is needed before the first request.
+	logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
+	ln, err := net.Listen("tcp", "127.0.0.1:0")
+	require.NoError(t, err)
+
+	proxy := NewProxy(logger, ln.Addr().String())
+	proxy.server.Addr = ln.Addr().String()
+
+	go func() {
+		if err := proxy.server.Serve(ln); err != nil && err != http.ErrServerClosed {
+			t.Logf("Proxy server error: %v", err)
+		}
+	}()
+	defer func() {
+		if err := proxy.Stop(context.Background()); err != nil {
+			t.Logf("Failed to stop proxy: %v", err)
+		}
+	}()
+
+	// http.ProxyFromEnvironment never proxies loopback targets and reads the
+	// environment only once per process, so the proxy is configured
+	// explicitly. http.NewRequest is used only to parse the proxy address
+	// into a URL value.
+	proxyReq, err := http.NewRequest(http.MethodGet, fmt.Sprintf("http://%s", ln.Addr().String()), nil)
+	require.NoError(t, err)
+
+	client := &http.Client{
+		Transport: &http.Transport{
+			Proxy: http.ProxyURL(proxyReq.URL),
+		},
+	}
+
+	// The proxy hands the 302 back unfollowed; the client follows it, again
+	// through the proxy, and ends up with the final response.
+	resp, err := client.Get(backend.URL)
+	require.NoError(t, err)
+	defer resp.Body.Close()
+
+	body, err := io.ReadAll(resp.Body)
+	require.NoError(t, err)
+	assert.Equal(t, http.StatusOK, resp.StatusCode)
+	assert.Equal(t, "redirected response", string(body))
+}
+
+// TestProxyStopContext tests proxy stop with context.
+func TestProxyStopContext(t *testing.T) { + logger := slog.New(slog.NewTextHandler(os.Stdout, nil)) + ln, err := net.Listen("tcp", ":0") + require.NoError(t, err) + + proxy := NewProxy(logger, ln.Addr().String()) + proxy.server.Addr = ln.Addr().String() + + go func() { + if err := proxy.server.Serve(ln); err != nil && err != http.ErrServerClosed { + t.Logf("Proxy server error: %v", err) + } + }() + + time.Sleep(100 * time.Millisecond) + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + err = proxy.Stop(ctx) + assert.NoError(t, err) +} + +// TestProxyPipeWithRealConnections tests the pipe function with real TCP connections. +func TestProxyPipeWithRealConnections(t *testing.T) { + // Create two connected TCP connections + listener, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + defer listener.Close() + + // Channel to receive the server connection + serverConnChan := make(chan net.Conn, 1) + go func() { + conn, err := listener.Accept() + if err == nil { + serverConnChan <- conn + } + }() + + // Create client connection + clientConn, err := net.Dial("tcp", listener.Addr().String()) + require.NoError(t, err) + defer clientConn.Close() + + // Get server connection + serverConn := <-serverConnChan + defer serverConn.Close() + + logger := slog.New(slog.NewTextHandler(os.Stdout, nil)) + proxy := NewProxy(logger, ":0") + + // Test data transfer + testData := []byte("test data for pipe") + + // Start pipe in goroutine + go proxy.pipe(clientConn, serverConn) + + // Write from client + _, err = clientConn.Write(testData) + require.NoError(t, err) + + // Read from server + buf := make([]byte, len(testData)) + if err := serverConn.SetReadDeadline(time.Now().Add(1 * time.Second)); err != nil { + t.Logf("Failed to set read deadline: %v", err) + } + n, err := serverConn.Read(buf) + require.NoError(t, err) + assert.Equal(t, testData, buf[:n]) + + // Close connections to trigger pipe completion + clientConn.Close() + 
serverConn.Close() + + // Give pipe time to complete + time.Sleep(100 * time.Millisecond) +} + +// TestProxyHTTP2ErrorPath tests HTTP/2 error handler. +func TestProxyHTTP2ErrorPath(t *testing.T) { + logger := slog.New(slog.NewTextHandler(os.Stdout, nil)) + ln, err := net.Listen("tcp", ":0") + require.NoError(t, err) + + proxy := NewProxy(logger, ln.Addr().String()) + proxy.server.Addr = ln.Addr().String() + + go func() { + if err := proxy.server.Serve(ln); err != nil && err != http.ErrServerClosed { + t.Logf("Proxy server error: %v", err) + } + }() + defer func() { + if err := proxy.Stop(context.Background()); err != nil { + t.Logf("Failed to stop proxy: %v", err) + } + }() + + time.Sleep(100 * time.Millisecond) + + // Create a request that will trigger HTTP/2 handling + req, err := http.NewRequest("GET", "http://"+ln.Addr().String()+"/test", nil) + require.NoError(t, err) + + // Force HTTP/2 by setting the request protocol + req.ProtoMajor = 2 + req.ProtoMinor = 0 + req.Host = "nonexistent.invalid:9999" // This should cause an error + + // Create a response recorder + rr := httptest.NewRecorder() + + // Call the handler directly to test HTTP/2 error path + proxy.server.Handler.ServeHTTP(rr, req) + + // Should get an error response + assert.Equal(t, http.StatusBadGateway, rr.Code) +} + +// TestNewProxyHTTP2ConfigWarning tests NewProxy when HTTP/2 configuration might fail. +func TestNewProxyHTTP2ConfigWarning(t *testing.T) { + logger := slog.New(slog.NewTextHandler(os.Stdout, nil)) + + // Create proxy - HTTP/2 configuration should succeed normally + proxy := NewProxy(logger, ":0") + + assert.NotNil(t, proxy) + assert.NotNil(t, proxy.transport) + assert.True(t, proxy.transport.ForceAttemptHTTP2) +} + +// TestProxyHandleHTTPError tests HTTP handler error path. 
+func TestProxyHandleHTTPError(t *testing.T) { + logger := slog.New(slog.NewTextHandler(os.Stdout, nil)) + ln, err := net.Listen("tcp", ":0") + require.NoError(t, err) + + proxy := NewProxy(logger, ln.Addr().String()) + proxy.server.Addr = ln.Addr().String() + + go func() { + if err := proxy.server.Serve(ln); err != nil && err != http.ErrServerClosed { + t.Logf("Proxy server error: %v", err) + } + }() + defer func() { + if err := proxy.Stop(context.Background()); err != nil { + t.Logf("Failed to stop proxy: %v", err) + } + }() + + time.Sleep(100 * time.Millisecond) + + proxyURL := fmt.Sprintf("http://%s", ln.Addr().String()) + os.Setenv("HTTP_PROXY", proxyURL) + defer os.Unsetenv("HTTP_PROXY") + + client := &http.Client{ + Transport: &http.Transport{ + Proxy: http.ProxyFromEnvironment, + }, + Timeout: 2 * time.Second, + } + + // Try to connect to invalid backend + resp, err := client.Get("http://invalid.backend.test:99999/test") + if err == nil { + defer resp.Body.Close() + // Should get error status + assert.NotEqual(t, http.StatusOK, resp.StatusCode) + } + // Either error or bad gateway response is acceptable +} + +// TestProxyConnectWriteError tests CONNECT with write error after hijacking. +func TestProxyConnectWriteError(t *testing.T) { + // This test is challenging because we need to trigger a write error + // after successful hijacking. We'll test the path by using a closed connection. 
+ + logger := slog.New(slog.NewTextHandler(os.Stdout, nil)) + ln, err := net.Listen("tcp", ":0") + require.NoError(t, err) + + proxy := NewProxy(logger, ln.Addr().String()) + proxy.server.Addr = ln.Addr().String() + + go func() { + if err := proxy.server.Serve(ln); err != nil && err != http.ErrServerClosed { + t.Logf("Proxy server error: %v", err) + } + }() + defer func() { + if err := proxy.Stop(context.Background()); err != nil { + t.Logf("Failed to stop proxy: %v", err) + } + }() + + time.Sleep(100 * time.Millisecond) + + // Create a backend server for CONNECT + backend := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + defer backend.Close() + + proxyURL := fmt.Sprintf("http://%s", ln.Addr().String()) + os.Setenv("HTTPS_PROXY", proxyURL) + defer os.Unsetenv("HTTPS_PROXY") + + client := backend.Client() + if transport, ok := client.Transport.(*http.Transport); ok { + transport.Proxy = http.ProxyFromEnvironment + } + + // Make a request through CONNECT + _, err = client.Get(backend.URL) + // The request may succeed or fail, but we're testing the code path + if err != nil { + t.Logf("Request error (expected in some cases): %v", err) + } +} + +// TestProxyHTTP2WithAbsoluteURL tests HTTP/2 handling with absolute URL. 
+func TestProxyHTTP2WithAbsoluteURL(t *testing.T) { + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + if _, err := w.Write([]byte("http2 absolute url response")); err != nil { + t.Logf("Failed to write response: %v", err) + } + })) + defer backend.Close() + + logger := slog.New(slog.NewTextHandler(os.Stdout, nil)) + ln, err := net.Listen("tcp", ":0") + require.NoError(t, err) + + proxy := NewProxy(logger, ln.Addr().String()) + proxy.server.Addr = ln.Addr().String() + + go func() { + if err := proxy.server.Serve(ln); err != nil && err != http.ErrServerClosed { + t.Logf("Proxy server error: %v", err) + } + }() + defer func() { + if err := proxy.Stop(context.Background()); err != nil { + t.Logf("Failed to stop proxy: %v", err) + } + }() + + time.Sleep(100 * time.Millisecond) + + // Create request with absolute URL + req, err := http.NewRequest("GET", backend.URL+"/test", nil) + require.NoError(t, err) + req.ProtoMajor = 2 + req.ProtoMinor = 0 + + rr := httptest.NewRecorder() + proxy.server.Handler.ServeHTTP(rr, req) + + assert.Equal(t, http.StatusOK, rr.Code) +} diff --git a/pkg/ingress/adapter.go b/pkg/ingress/adapter.go new file mode 100644 index 00000000..2618888e --- /dev/null +++ b/pkg/ingress/adapter.go @@ -0,0 +1,25 @@ +// Copyright (c) Ultraviolet +// SPDX-License-Identifier: Apache-2.0 +package ingress + +import "github.com/ultravioletrs/cocos/agent" + +// AgentConfigToProxyConfig converts agent.AgentConfig to ProxyConfig. +func AgentConfigToProxyConfig(cfg agent.AgentConfig) ProxyConfig { + return ProxyConfig{ + Port: "7002", // Ingress-proxy always uses port 7002 + CertFile: cfg.CertFile, + KeyFile: cfg.KeyFile, + ServerCAFile: cfg.ServerCAFile, + ClientCAFile: cfg.ClientCAFile, + AttestedTLS: cfg.AttestedTls, + } +} + +// ComputationToProxyContext converts agent.Computation to ProxyContext. 
+func ComputationToProxyContext(cmp agent.Computation) ProxyContext { + return ProxyContext{ + ID: cmp.ID, + Name: cmp.Name, + } +} diff --git a/pkg/ingress/adapter_test.go b/pkg/ingress/adapter_test.go new file mode 100644 index 00000000..0f068896 --- /dev/null +++ b/pkg/ingress/adapter_test.go @@ -0,0 +1,166 @@ +// Copyright (c) Ultraviolet +// SPDX-License-Identifier: Apache-2.0 +package ingress + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/ultravioletrs/cocos/agent" +) + +// TestAgentConfigToProxyConfig tests conversion from AgentConfig to ProxyConfig. +func TestAgentConfigToProxyConfig(t *testing.T) { + tests := []struct { + name string + input agent.AgentConfig + expected ProxyConfig + }{ + { + name: "basic config without TLS", + input: agent.AgentConfig{ + CertFile: "", + KeyFile: "", + ServerCAFile: "", + ClientCAFile: "", + AttestedTls: false, + }, + expected: ProxyConfig{ + Port: "7002", + CertFile: "", + KeyFile: "", + ServerCAFile: "", + ClientCAFile: "", + AttestedTLS: false, + }, + }, + { + name: "config with regular TLS", + input: agent.AgentConfig{ + CertFile: "/path/to/cert.pem", + KeyFile: "/path/to/key.pem", + ServerCAFile: "/path/to/server-ca.pem", + ClientCAFile: "/path/to/client-ca.pem", + AttestedTls: false, + }, + expected: ProxyConfig{ + Port: "7002", + CertFile: "/path/to/cert.pem", + KeyFile: "/path/to/key.pem", + ServerCAFile: "/path/to/server-ca.pem", + ClientCAFile: "/path/to/client-ca.pem", + AttestedTLS: false, + }, + }, + { + name: "config with attested TLS", + input: agent.AgentConfig{ + CertFile: "", + KeyFile: "", + ServerCAFile: "/path/to/server-ca.pem", + ClientCAFile: "/path/to/client-ca.pem", + AttestedTls: true, + }, + expected: ProxyConfig{ + Port: "7002", + CertFile: "", + KeyFile: "", + ServerCAFile: "/path/to/server-ca.pem", + ClientCAFile: "/path/to/client-ca.pem", + AttestedTLS: true, + }, + }, + { + name: "config with mTLS", + input: agent.AgentConfig{ + CertFile: "/path/to/cert.pem", 
+ KeyFile: "/path/to/key.pem", + ServerCAFile: "/path/to/server-ca.pem", + ClientCAFile: "/path/to/client-ca.pem", + AttestedTls: false, + }, + expected: ProxyConfig{ + Port: "7002", + CertFile: "/path/to/cert.pem", + KeyFile: "/path/to/key.pem", + ServerCAFile: "/path/to/server-ca.pem", + ClientCAFile: "/path/to/client-ca.pem", + AttestedTLS: false, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := AgentConfigToProxyConfig(tt.input) + assert.Equal(t, tt.expected, result) + }) + } +} + +// TestComputationToProxyContext tests conversion from Computation to ProxyContext. +func TestComputationToProxyContext(t *testing.T) { + tests := []struct { + name string + input agent.Computation + expected ProxyContext + }{ + { + name: "computation with name", + input: agent.Computation{ + ID: "comp-123", + Name: "test-computation", + Description: "A test computation", + }, + expected: ProxyContext{ + ID: "comp-123", + Name: "test-computation", + }, + }, + { + name: "computation without name", + input: agent.Computation{ + ID: "comp-456", + Name: "", + Description: "Another test computation", + }, + expected: ProxyContext{ + ID: "comp-456", + Name: "", + }, + }, + { + name: "computation with special characters in name", + input: agent.Computation{ + ID: "comp-789", + Name: "test-computation-with-dashes_and_underscores", + Description: "Computation with special chars", + }, + expected: ProxyContext{ + ID: "comp-789", + Name: "test-computation-with-dashes_and_underscores", + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := ComputationToProxyContext(tt.input) + assert.Equal(t, tt.expected, result) + }) + } +} + +// TestAgentConfigToProxyConfigPortIsFixed tests that port is always set to 7002. 
+func TestAgentConfigToProxyConfigPortIsFixed(t *testing.T) { + configs := []agent.AgentConfig{ + {}, + {CertFile: "/cert.pem", KeyFile: "/key.pem"}, + {AttestedTls: true}, + } + + for i, cfg := range configs { + result := AgentConfigToProxyConfig(cfg) + assert.Equal(t, "7002", result.Port, "Port should always be 7002 for config %d", i) + } +} diff --git a/pkg/ingress/proxy.go b/pkg/ingress/proxy.go new file mode 100644 index 00000000..fb0deeef --- /dev/null +++ b/pkg/ingress/proxy.go @@ -0,0 +1,223 @@ +// Copyright (c) Ultraviolet +// SPDX-License-Identifier: Apache-2.0 +package ingress + +import ( + "context" + "crypto/tls" + "fmt" + "log/slog" + "net" + "net/http" + "net/http/httputil" + "net/url" + "sync" + + "github.com/ultravioletrs/cocos/pkg/atls" + "github.com/ultravioletrs/cocos/pkg/server" + "golang.org/x/net/http2" + "golang.org/x/net/http2/h2c" +) + +// ProxyConfig contains configuration for starting a proxy instance. +type ProxyConfig struct { + Port string + CertFile string + KeyFile string + ServerCAFile string + ClientCAFile string + AttestedTLS bool +} + +// ProxyContext provides context information for logging and tracking. +type ProxyContext struct { + ID string + Name string +} + +// ProxyServer manages ingress proxy instances. +type ProxyServer interface { + Start(cfg ProxyConfig, ctx ProxyContext) error + Stop() error +} + +type proxyServer struct { + mu sync.RWMutex + logger *slog.Logger + backendURL *url.URL + certProvider atls.CertificateProvider + httpServer *http.Server + started bool + stopped bool +} + +// NewProxyServer creates a new ingress proxy server manager. +func NewProxyServer(logger *slog.Logger, backendURL *url.URL, certProvider atls.CertificateProvider) ProxyServer { + return &proxyServer{ + logger: logger, + backendURL: backendURL, + certProvider: certProvider, + } +} + +// Start starts the proxy server with the given configuration. 
+func (p *proxyServer) Start(cfg ProxyConfig, ctx ProxyContext) error { + p.mu.Lock() + defer p.mu.Unlock() + + if p.started { + return fmt.Errorf("proxy server already started") + } + if p.stopped { + return fmt.Errorf("proxy server already stopped") + } + + if cfg.Port == "" { + cfg.Port = "7002" + } + + addr := fmt.Sprintf("0.0.0.0:%s", cfg.Port) + + // Configure Reverse Proxy + var rp *httputil.ReverseProxy + + // Check if backend is Unix socket or TCP + if p.backendURL.Scheme == "unix" { + // For Unix socket backend, we need to manually configure the reverse proxy + // because NewSingleHostReverseProxy doesn't support unix:// scheme + targetURL := &url.URL{ + Scheme: "http", + Host: "unix", + } + rp = httputil.NewSingleHostReverseProxy(targetURL) + + // Override the Director to not modify the request + originalDirector := rp.Director + rp.Director = func(req *http.Request) { + originalDirector(req) + // Set the URL to point to the backend service + req.URL.Scheme = "http" + req.URL.Host = "unix" + } + + // Configure Transport for Unix socket with HTTP/2 + rp.Transport = &http2.Transport{ + AllowHTTP: true, + DialTLSContext: func(ctx context.Context, network, addr string, cfg *tls.Config) (net.Conn, error) { + var d net.Dialer + // Use Unix socket path from URL + return d.DialContext(ctx, "unix", p.backendURL.Path) + }, + } + } else { + // TCP backend + rp = httputil.NewSingleHostReverseProxy(p.backendURL) + rp.Transport = &http2.Transport{ + AllowHTTP: true, + DialTLSContext: func(ctx context.Context, network, addr string, cfg *tls.Config) (net.Conn, error) { + var d net.Dialer + return d.DialContext(ctx, network, addr) + }, + } + } + + // Wrap handler with h2c for HTTP/2 cleartext support (required for gRPC without TLS) + h2cHandler := h2c.NewHandler(rp, &http2.Server{}) + + p.httpServer = &http.Server{ + Addr: addr, + Handler: h2cHandler, + } + + // Configure TLS + var tlsConfig *tls.Config + contextDesc := fmt.Sprintf("context %s", ctx.ID) + if ctx.Name != 
"" { + contextDesc = fmt.Sprintf("%s (%s)", ctx.Name, ctx.ID) + } + + if cfg.AttestedTLS { + if p.certProvider == nil { + return fmt.Errorf("attested TLS requested but no certificate provider available") + } + tlsConfig = &tls.Config{ + GetCertificate: p.certProvider.GetCertificate, + ClientAuth: tls.NoClientCert, + NextProtos: []string{"h2", "http/1.1"}, + } + + mtls, err := server.ConfigureCertificateAuthorities(tlsConfig, cfg.ServerCAFile, cfg.ClientCAFile) + if err != nil { + return fmt.Errorf("failed to configure certificate authorities: %w", err) + } + + if mtls { + tlsConfig.ClientAuth = tls.RequireAndVerifyClientCert + p.logger.Info(fmt.Sprintf("ingress-proxy listening at %s with Attested mTLS for %s", addr, contextDesc)) + } else { + p.logger.Info(fmt.Sprintf("ingress-proxy listening at %s with Attested TLS for %s", addr, contextDesc)) + } + } else if cfg.CertFile != "" && cfg.KeyFile != "" { + // Regular TLS + tlsSetup, err := server.SetupRegularTLS(cfg.CertFile, cfg.KeyFile, cfg.ServerCAFile, cfg.ClientCAFile) + if err != nil { + return fmt.Errorf("failed to setup TLS: %w", err) + } + tlsConfig = tlsSetup.Config + tlsConfig.NextProtos = []string{"h2", "http/1.1"} + + if tlsSetup.MTLS { + p.logger.Info(fmt.Sprintf("ingress-proxy listening at %s with mTLS for %s", addr, contextDesc)) + } else { + p.logger.Info(fmt.Sprintf("ingress-proxy listening at %s with TLS for %s", addr, contextDesc)) + } + } else { + p.logger.Info(fmt.Sprintf("ingress-proxy listening at %s without TLS for %s", addr, contextDesc)) + } + + p.started = true + + // Start server in goroutine + go func() { + var err error + if tlsConfig != nil { + ln, listenErr := net.Listen("tcp", addr) + if listenErr != nil { + p.logger.Error(fmt.Sprintf("failed to listen: %s", listenErr)) + return + } + tlsLn := tls.NewListener(ln, tlsConfig) + err = p.httpServer.Serve(tlsLn) + } else { + err = p.httpServer.ListenAndServe() + } + + if err != nil && err != http.ErrServerClosed { + 
p.logger.Error(fmt.Sprintf("ingress-proxy server error: %s", err)) + } + }() + + return nil +} + +// Stop stops the proxy server. +func (p *proxyServer) Stop() error { + p.mu.Lock() + defer p.mu.Unlock() + + if p.stopped { + return nil + } + p.stopped = true + + if p.httpServer != nil { + ctx, cancel := context.WithTimeout(context.Background(), 5*1000000000) // 5 seconds + defer cancel() + if err := p.httpServer.Shutdown(ctx); err != nil { + return fmt.Errorf("failed to shutdown server: %w", err) + } + p.logger.Info("ingress-proxy stopped") + } + + return nil +} diff --git a/pkg/ingress/proxy_test.go b/pkg/ingress/proxy_test.go new file mode 100644 index 00000000..7c2eda5f --- /dev/null +++ b/pkg/ingress/proxy_test.go @@ -0,0 +1,477 @@ +// Copyright (c) Ultraviolet +// SPDX-License-Identifier: Apache-2.0 +package ingress + +import ( + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "crypto/tls" + "crypto/x509" + "crypto/x509/pkix" + "encoding/pem" + "fmt" + "log/slog" + "math/big" + "net" + "net/http" + "net/url" + "os" + "path/filepath" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/ultravioletrs/cocos/pkg/atls/mocks" + "golang.org/x/net/http2" + "golang.org/x/net/http2/h2c" +) + +func createTempCert(t *testing.T) (certFile, keyFile string) { + t.Helper() + + priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + require.NoError(t, err) + + template := x509.Certificate{ + SerialNumber: big.NewInt(1), + Subject: pkix.Name{ + Organization: []string{"Test Org"}, + }, + NotBefore: time.Now(), + NotAfter: time.Now().Add(time.Hour), + + KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, + BasicConstraintsValid: true, + IPAddresses: []net.IP{net.ParseIP("127.0.0.1")}, + } + + derBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, &priv.PublicKey, priv) + require.NoError(t, err) + + tmpDir 
:= t.TempDir() + certPath := filepath.Join(tmpDir, "server.crt") + keyPath := filepath.Join(tmpDir, "server.key") + + certOut, err := os.Create(certPath) + require.NoError(t, err) + err = pem.Encode(certOut, &pem.Block{Type: "CERTIFICATE", Bytes: derBytes}) + require.NoError(t, err) + certOut.Close() + + keyOut, err := os.Create(keyPath) + require.NoError(t, err) + b, err := x509.MarshalECPrivateKey(priv) + require.NoError(t, err) + err = pem.Encode(keyOut, &pem.Block{Type: "EC PRIVATE KEY", Bytes: b}) + require.NoError(t, err) + keyOut.Close() + + return certPath, keyPath +} + +func getBackendURL() *url.URL { + u, _ := url.Parse("http://localhost:8080") + return u +} + +// TestNewProxyServer tests the creation of a new proxy server. +func TestNewProxyServer(t *testing.T) { + logger := slog.New(slog.NewTextHandler(os.Stdout, nil)) + ps := NewProxyServer(logger, getBackendURL(), nil) + require.NotNil(t, ps) +} + +// TestProxyStartStop tests basic start and stop operations. +func TestProxyStartStop(t *testing.T) { + logger := slog.New(slog.NewTextHandler(os.Stdout, nil)) + ps := NewProxyServer(logger, getBackendURL(), nil) + + listener, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + port := listener.Addr().(*net.TCPAddr).Port + listener.Close() + + cfg := ProxyConfig{Port: fmt.Sprintf("%d", port)} + ctx := ProxyContext{ID: "test-1", Name: "test-proxy"} + + err = ps.Start(cfg, ctx) + require.NoError(t, err) + time.Sleep(100 * time.Millisecond) + err = ps.Stop() + require.NoError(t, err) +} + +// TestProxyStartWithoutPort tests proxy without explicit port. 
+func TestProxyStartWithoutPort(t *testing.T) { + logger := slog.New(slog.NewTextHandler(os.Stdout, nil)) + ps := NewProxyServer(logger, getBackendURL(), nil) + + cfg := ProxyConfig{Port: ""} + ctx := ProxyContext{ID: "test-2"} + + err := ps.Start(cfg, ctx) + require.NoError(t, err) + defer func() { _ = ps.Stop() }() + time.Sleep(100 * time.Millisecond) +} + +// TestProxyStartAlreadyStarted tests error when starting twice. +func TestProxyStartAlreadyStarted(t *testing.T) { + logger := slog.New(slog.NewTextHandler(os.Stdout, nil)) + ps := NewProxyServer(logger, getBackendURL(), nil) + + listener, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + port := listener.Addr().(*net.TCPAddr).Port + listener.Close() + + cfg := ProxyConfig{Port: fmt.Sprintf("%d", port)} + ctx := ProxyContext{ID: "test-3"} + + err = ps.Start(cfg, ctx) + require.NoError(t, err) + defer func() { _ = ps.Stop() }() + time.Sleep(100 * time.Millisecond) + + err = ps.Start(cfg, ctx) + assert.Error(t, err) + assert.Equal(t, "proxy server already started", err.Error()) +} + +// TestProxyStartAfterStopped tests error when starting after stop. +func TestProxyStartAfterStopped(t *testing.T) { + logger := slog.New(slog.NewTextHandler(os.Stdout, nil)) + ps := NewProxyServer(logger, getBackendURL(), nil) + + listener, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + port := listener.Addr().(*net.TCPAddr).Port + listener.Close() + + cfg := ProxyConfig{Port: fmt.Sprintf("%d", port)} + ctx := ProxyContext{ID: "test-4"} + + err = ps.Start(cfg, ctx) + require.NoError(t, err) + time.Sleep(100 * time.Millisecond) + err = ps.Stop() + require.NoError(t, err) + + err = ps.Start(cfg, ctx) + assert.Error(t, err) + // After stop, attempts to start will fail with "already started" error first + assert.Contains(t, err.Error(), "proxy server already") +} + +// TestProxyWithName tests proxy context with name. 
+func TestProxyWithName(t *testing.T) { + logger := slog.New(slog.NewTextHandler(os.Stdout, nil)) + ps := NewProxyServer(logger, getBackendURL(), nil) + + listener, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + port := listener.Addr().(*net.TCPAddr).Port + listener.Close() + + cfg := ProxyConfig{Port: fmt.Sprintf("%d", port)} + ctx := ProxyContext{ID: "id-1", Name: "named-proxy"} + + err = ps.Start(cfg, ctx) + require.NoError(t, err) + defer func() { _ = ps.Stop() }() + time.Sleep(100 * time.Millisecond) +} + +// TestProxyWithoutName tests proxy context without name. +func TestProxyWithoutName(t *testing.T) { + logger := slog.New(slog.NewTextHandler(os.Stdout, nil)) + ps := NewProxyServer(logger, getBackendURL(), nil) + + listener, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + port := listener.Addr().(*net.TCPAddr).Port + listener.Close() + + cfg := ProxyConfig{Port: fmt.Sprintf("%d", port)} + ctx := ProxyContext{ID: "id-only"} + + err = ps.Start(cfg, ctx) + require.NoError(t, err) + defer func() { _ = ps.Stop() }() + time.Sleep(100 * time.Millisecond) +} + +// TestProxyMultipleStops tests multiple stop calls. +func TestProxyMultipleStops(t *testing.T) { + logger := slog.New(slog.NewTextHandler(os.Stdout, nil)) + ps := NewProxyServer(logger, getBackendURL(), nil) + + listener, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + port := listener.Addr().(*net.TCPAddr).Port + listener.Close() + + cfg := ProxyConfig{Port: fmt.Sprintf("%d", port)} + ctx := ProxyContext{ID: "test-multi-stop"} + + err = ps.Start(cfg, ctx) + require.NoError(t, err) + time.Sleep(100 * time.Millisecond) + + err = ps.Stop() + require.NoError(t, err) + + err = ps.Stop() + require.NoError(t, err) +} + +// TestProxyWithoutTLS tests proxy without TLS. 
+func TestProxyWithoutTLS(t *testing.T) { + logger := slog.New(slog.NewTextHandler(os.Stdout, nil)) + ps := NewProxyServer(logger, getBackendURL(), nil) + + listener, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + port := listener.Addr().(*net.TCPAddr).Port + listener.Close() + + cfg := ProxyConfig{ + Port: fmt.Sprintf("%d", port), + AttestedTLS: false, + } + ctx := ProxyContext{ID: "test-no-tls"} + + err = ps.Start(cfg, ctx) + require.NoError(t, err) + defer func() { _ = ps.Stop() }() + time.Sleep(100 * time.Millisecond) +} + +func TestProxyWithUnixBackend(t *testing.T) { + logger := slog.New(slog.NewTextHandler(os.Stdout, nil)) + + // Create a temp directory for socket + dir := t.TempDir() + sockPath := filepath.Join(dir, "backend.sock") + + // Start a dummy backend on unix socket + l, err := net.Listen("unix", sockPath) + require.NoError(t, err) + defer l.Close() + + backendCalled := false + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + backendCalled = true + w.WriteHeader(http.StatusOK) + }) + + h2s := &http2.Server{} + h2cHandler := h2c.NewHandler(handler, h2s) + + go func() { + _ = http.Serve(l, h2cHandler) + }() + + // Configure proxy to use this unix socket + backendURL, _ := url.Parse("unix://" + sockPath) + ps := NewProxyServer(logger, backendURL, nil) + + // Find free port for proxy + listener, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + port := listener.Addr().(*net.TCPAddr).Port + listener.Close() + + cfg := ProxyConfig{ + Port: fmt.Sprintf("%d", port), + } + ctx := ProxyContext{ID: "test-unix"} + + err = ps.Start(cfg, ctx) + require.NoError(t, err) + defer func() { _ = ps.Stop() }() + + time.Sleep(100 * time.Millisecond) + + // Make request to proxy + resp, err := http.Get(fmt.Sprintf("http://localhost:%d", port)) + require.NoError(t, err) + defer resp.Body.Close() + + assert.Equal(t, http.StatusOK, resp.StatusCode) + assert.True(t, backendCalled) +} + +func 
TestProxyRegularTLS(t *testing.T) { + logger := slog.New(slog.NewTextHandler(os.Stdout, nil)) + certFile, keyFile := createTempCert(t) + + ps := NewProxyServer(logger, getBackendURL(), nil) + + listener, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + port := listener.Addr().(*net.TCPAddr).Port + listener.Close() + + cfg := ProxyConfig{ + Port: fmt.Sprintf("%d", port), + CertFile: certFile, + KeyFile: keyFile, + } + ctx := ProxyContext{ID: "test-tls"} + + err = ps.Start(cfg, ctx) + require.NoError(t, err) + defer func() { _ = ps.Stop() }() + + time.Sleep(100 * time.Millisecond) + + // Client with skipped verification + tr := &http.Transport{ + TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, + } + client := &http.Client{Transport: tr} + + resp, err := client.Get(fmt.Sprintf("https://localhost:%d", port)) + // Backend is not running/reachable so 502 or error is expected from reverse proxy, + // but 502 means connection to proxy succeeded. + // If the proxy itself was not working, we'd get connection refused or similar. + require.NoError(t, err) + if err == nil { + resp.Body.Close() + } +} + +func TestProxyRegularTLSInvalidFiles(t *testing.T) { + logger := slog.New(slog.NewTextHandler(os.Stdout, nil)) + ps := NewProxyServer(logger, getBackendURL(), nil) + + listener, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + port := listener.Addr().(*net.TCPAddr).Port + listener.Close() + + cfg := ProxyConfig{ + Port: fmt.Sprintf("%d", port), + CertFile: "non-existent.crt", + KeyFile: "non-existent.key", + } + ctx := ProxyContext{ID: "test-tls-fail"} + + err = ps.Start(cfg, ctx) + assert.Error(t, err) + assert.Contains(t, err.Error(), "failed to setup TLS") +} + +func TestProxyAttestedTLS(t *testing.T) { + logger := slog.New(slog.NewTextHandler(os.Stdout, nil)) + mockProvider := mocks.NewCertificateProvider(t) + // We don't expect calls during Listen, only during handshake. + // But Start logic doesn't block waiting for handshake. 
+ + ps := NewProxyServer(logger, getBackendURL(), mockProvider) + + listener, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + port := listener.Addr().(*net.TCPAddr).Port + listener.Close() + + cfg := ProxyConfig{ + Port: fmt.Sprintf("%d", port), + AttestedTLS: true, + } + ctx := ProxyContext{ID: "test-attested-tls"} + + err = ps.Start(cfg, ctx) + require.NoError(t, err) + defer func() { _ = ps.Stop() }() +} + +func TestProxyAttestedTLSMissingProvider(t *testing.T) { + logger := slog.New(slog.NewTextHandler(os.Stdout, nil)) + ps := NewProxyServer(logger, getBackendURL(), nil) + + cfg := ProxyConfig{ + Port: "0", + AttestedTLS: true, + } + ctx := ProxyContext{ID: "test-attested-fail"} + + err := ps.Start(cfg, ctx) + assert.Error(t, err) + assert.Equal(t, "attested TLS requested but no certificate provider available", err.Error()) +} + +func TestProxyMTLS(t *testing.T) { + logger := slog.New(slog.NewTextHandler(os.Stdout, nil)) + certFile, _ := createTempCert(t) + + mockProvider := mocks.NewCertificateProvider(t) + + ps := NewProxyServer(logger, getBackendURL(), mockProvider) + + listener, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + port := listener.Addr().(*net.TCPAddr).Port + listener.Close() + + // Test case: AttestedTLS with ClientCAFile (mTLS) + // server.ConfigureCertificateAuthorities reads the file. 
+ cfg := ProxyConfig{ + Port: fmt.Sprintf("%d", port), + AttestedTLS: true, + ClientCAFile: certFile, // Use self-signed cert as CA + ServerCAFile: certFile, // Also for server CA + } + ctx := ProxyContext{ID: "test-mtls"} + + err = ps.Start(cfg, ctx) + require.NoError(t, err) + defer func() { _ = ps.Stop() }() +} + +func TestProxyRegularMTLS(t *testing.T) { + logger := slog.New(slog.NewTextHandler(os.Stdout, nil)) + certFile, keyFile := createTempCert(t) + + ps := NewProxyServer(logger, getBackendURL(), nil) + + listener, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + port := listener.Addr().(*net.TCPAddr).Port + listener.Close() + + cfg := ProxyConfig{ + Port: fmt.Sprintf("%d", port), + CertFile: certFile, + KeyFile: keyFile, + ServerCAFile: certFile, + ClientCAFile: certFile, + } + ctx := ProxyContext{ID: "test-regular-mtls"} + + err = ps.Start(cfg, ctx) + require.NoError(t, err) + defer func() { _ = ps.Stop() }() +} + +func TestProxyAttestedTLSInvalidCA(t *testing.T) { + logger := slog.New(slog.NewTextHandler(os.Stdout, nil)) + mockProvider := mocks.NewCertificateProvider(t) + + ps := NewProxyServer(logger, getBackendURL(), mockProvider) + + cfg := ProxyConfig{ + Port: "0", + AttestedTLS: true, + ServerCAFile: "non-existent.pem", + } + ctx := ProxyContext{ID: "test-attested-invalid-ca"} + + err := ps.Start(cfg, ctx) + assert.Error(t, err) + assert.Contains(t, err.Error(), "failed to configure certificate authorities") +} diff --git a/pkg/server/grpc/grpc.go b/pkg/server/grpc/grpc.go index 6775ac2b..f6d92018 100644 --- a/pkg/server/grpc/grpc.go +++ b/pkg/server/grpc/grpc.go @@ -9,6 +9,7 @@ import ( "fmt" "log/slog" "net" + "os" "sync" "time" @@ -111,10 +112,26 @@ func (s *Server) Start() error { grpcServerOptions = append(grpcServerOptions, creds) - // Create listener - listener, err := net.Listen("tcp", s.Address) - if err != nil { - return fmt.Errorf("failed to listen on port %s: %w", s.Address, err) + // Create listener - detect Unix 
socket vs TCP + var listener net.Listener + baseConfig := s.Config.GetBaseConfig() + + // Check if this is a Unix socket path (starts with /) + if len(baseConfig.Host) > 0 && baseConfig.Host[0] == '/' { + // Unix socket + // Remove existing socket file if it exists + _ = os.Remove(baseConfig.Host) + + listener, err = net.Listen("unix", baseConfig.Host) + if err != nil { + return fmt.Errorf("failed to listen on Unix socket %s: %w", baseConfig.Host, err) + } + } else { + // TCP socket + listener, err = net.Listen("tcp", s.Address) + if err != nil { + return fmt.Errorf("failed to listen on port %s: %w", s.Address, err) + } } // Create and configure server @@ -160,7 +177,12 @@ func (s *Server) configureCredentials() (grpc.ServerOption, error) { } // Use insecure credentials - s.Logger.Info(fmt.Sprintf("%s service gRPC server listening at %s without TLS", s.Name, s.Address)) + // Determine address for logging + addr := s.Address + if len(baseConfig.Host) > 0 && baseConfig.Host[0] == '/' { + addr = baseConfig.Host + } + s.Logger.Info(fmt.Sprintf("%s service gRPC server listening at %s without TLS", s.Name, addr)) return grpc.Creds(insecure.NewCredentials()), nil } diff --git a/test/cvms/main.go b/test/cvms/main.go index 03277625..4184e995 100644 --- a/test/cvms/main.go +++ b/test/cvms/main.go @@ -78,6 +78,7 @@ func (s *svc) Run(ctx context.Context, ipAddress string, sendMessage cvmsgrpc.Se return } + s.logger.Debug("sending computation run request") if err := sendMessage(&cvms.ServerStreamMessage{ Message: &cvms.ServerStreamMessage_RunReq{ RunReq: &cvms.ComputationRunReq{ @@ -98,6 +99,11 @@ func (s *svc) Run(ctx context.Context, ipAddress string, sendMessage cvmsgrpc.Se s.logger.Error(fmt.Sprintf("failed to send run request: %s", err)) return } + s.logger.Info("computation run request sent successfully") + + // Keep the connection alive + <-ctx.Done() + s.logger.Info("connection closed") } func main() {