mirror of
https://github.com/ultravioletrs/cocos.git
synced 2026-06-22 20:00:18 +00:00
NOISSUE - Introduce computation runner, log forwarder, ingress, and egress proxy services. (#559)
* feat: Introduce computation runner, log forwarder, ingress, and egress proxy services. Signed-off-by: Sammy Oina <sammyoina@gmail.com> * feat: Update Go environment variable parsing and build system to use new architecture and repository. Signed-off-by: Sammy Oina <sammyoina@gmail.com> * feat: Update package sources to `sammyoina/cocos-ai` at a specific commit, add log-forwarder pre-start hook, and rename proxy binaries. Signed-off-by: Sammy Oina <sammyoina@gmail.com> * chore: Update build system references to a specific commit and enhance logging for service connections and message processing. Signed-off-by: Sammy Oina <sammyoina@gmail.com> * build: Update package source repositories and versions, migrate client logging to slog, and adjust ingress/egress proxy build and install steps. Signed-off-by: Sammy Oina <sammyoina@gmail.com> * debug stuck Signed-off-by: Sammy Oina <sammyoina@gmail.com> * debug Signed-off-by: Sammy Oina <sammyoina@gmail.com> * debug Signed-off-by: Sammy Oina <sammyoina@gmail.com> * feat: add HTTP/2 support to egress proxy and update build system to use specific commit hashes Signed-off-by: Sammy Oina <sammyoina@gmail.com> * feat: enhance egress proxy CONNECT handling, update package sources, and add gRPC test utility Signed-off-by: Sammy Oina <sammyoina@gmail.com> * feat: Update build system for various services to a specific commit from a new repository, change agent gRPC port to 7001, and add a gRPC test client. Signed-off-by: Sammy Oina <sammyoina@gmail.com> * feat: Migrate agent-internal gRPC communication to Unix sockets, set ingress proxy to port 7002, and update build hashes. Signed-off-by: Sammy Oina <sammyoina@gmail.com> * refactor: Remove standalone ingress-proxy systemd service and update component versions. Signed-off-by: Sammy Oina <sammyoina@gmail.com> * fix: Prevent computation re-initialization in agent and update component versions across several packages. Signed-off-by: Sammy Oina <sammyoina@gmail.com> * feat: update package versions and enable h2c support in ingress proxy. Signed-off-by: Sammy Oina <sammyoina@gmail.com> * feat: refactor ingress proxy to support HTTP/2 over Unix sockets and update component versions. Signed-off-by: Sammy Oina <sammyoina@gmail.com> * feat: Update build system package sources to `ultravioletrs/cocos` and reduce agent logging verbosity. Signed-off-by: Sammy Oina <sammyoina@gmail.com> * refactor: improve error handling in proxy commands and remove unused gRPC test Signed-off-by: Sammy Oina <sammyoina@gmail.com> * test: add mock service state return value in handleRunReqChunks test Signed-off-by: Sammy Oina <sammyoina@gmail.com> * feat: add comprehensive tests for service and proxy components Signed-off-by: Sammy Oina <sammyoina@gmail.com> * fix linter Signed-off-by: Sammy Oina <sammyoina@gmail.com> * improve coverage Signed-off-by: Sammy Oina <sammyoina@gmail.com> * test: add gRPC client and ingress adapter tests, and update egress proxy tests. Signed-off-by: Sammy Oina <sammyoina@gmail.com> * improve coverage Signed-off-by: Sammy Oina <sammyoina@gmail.com> --------- Signed-off-by: Sammy Oina <sammyoina@gmail.com>
This commit is contained in:
committed by
GitHub
parent
ee52551ca4
commit
a3265bc346
@@ -146,3 +146,10 @@ packages:
|
||||
dir: '{{.InterfaceDir}}/mocks'
|
||||
structname: '{{.InterfaceName}}'
|
||||
filename: "{{.InterfaceName | lower}}.go"
|
||||
github.com/ultravioletrs/cocos/pkg/clients/grpc/runner:
|
||||
interfaces:
|
||||
Client:
|
||||
config:
|
||||
dir: '{{.InterfaceDir}}/mocks'
|
||||
structname: '{{.InterfaceName}}'
|
||||
filename: "{{.InterfaceName | lower}}.go"
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
BUILD_DIR = build
|
||||
SERVICES = manager agent cli attestation-service
|
||||
SERVICES = manager agent cli attestation-service log-forwarder computation-runner egress-proxy ingress-proxy
|
||||
ATTESTATION_POLICY = attestation_policy
|
||||
CGO_ENABLED ?= 0
|
||||
GOARCH ?= amd64
|
||||
@@ -41,6 +41,8 @@ protoc:
|
||||
protoc -I. --go_out=. --go_opt=paths=source_relative --go-grpc_out=. --go-grpc_opt=paths=source_relative agent/events/events.proto
|
||||
protoc -I. --go_out=. --go_opt=paths=source_relative --go-grpc_out=. --go-grpc_opt=paths=source_relative agent/cvms/cvms.proto
|
||||
protoc -I. --go_out=. --go_opt=paths=source_relative --go-grpc_out=. --go-grpc_opt=paths=source_relative internal/proto/attestation/v1/attestation.proto
|
||||
protoc -I. --go_out=. --go_opt=paths=source_relative --go-grpc_out=. --go-grpc_opt=paths=source_relative agent/log/log.proto
|
||||
protoc -I. --go_out=. --go_opt=paths=source_relative --go-grpc_out=. --go-grpc_opt=paths=source_relative agent/runner/runner.proto
|
||||
|
||||
mocks:
|
||||
mockery --config ./.mockery.yml
|
||||
|
||||
@@ -13,7 +13,6 @@ import (
|
||||
var _ fmt.Stringer = (*Datasets)(nil)
|
||||
|
||||
type AgentConfig struct {
|
||||
Port string `json:"port,omitempty"`
|
||||
CertFile string `json:"cert_file,omitempty"`
|
||||
KeyFile string `json:"server_key,omitempty"`
|
||||
ServerCAFile string `json:"server_ca_file,omitempty"`
|
||||
|
||||
@@ -105,16 +105,15 @@ func TestDecompressToContext(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestAgentConfigJSON(t *testing.T) {
|
||||
config := AgentConfig{
|
||||
Port: "8080",
|
||||
cfg := AgentConfig{
|
||||
CertFile: "cert.pem",
|
||||
KeyFile: "key.pem",
|
||||
ServerCAFile: "server_ca.pem",
|
||||
ClientCAFile: "client_ca.pem",
|
||||
ServerCAFile: "server-ca.pem",
|
||||
ClientCAFile: "client-ca.pem",
|
||||
AttestedTls: true,
|
||||
}
|
||||
|
||||
data, err := json.Marshal(config)
|
||||
data, err := json.Marshal(cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to marshal AgentConfig: %v", err)
|
||||
}
|
||||
@@ -125,7 +124,7 @@ func TestAgentConfigJSON(t *testing.T) {
|
||||
t.Fatalf("Failed to unmarshal AgentConfig: %v", err)
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(config, unmarshaledConfig) {
|
||||
t.Errorf("Unmarshaled config does not match original. Got %+v, want %+v", unmarshaledConfig, config)
|
||||
if !reflect.DeepEqual(cfg, unmarshaledConfig) {
|
||||
t.Errorf("Unmarshaled config does not match original. Got %+v, want %+v", unmarshaledConfig, cfg)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -5,6 +5,7 @@ package grpc
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"sync"
|
||||
"time"
|
||||
@@ -17,6 +18,7 @@ import (
|
||||
"github.com/ultravioletrs/cocos/pkg/attestation"
|
||||
"github.com/ultravioletrs/cocos/pkg/attestation/vtpm"
|
||||
"github.com/ultravioletrs/cocos/pkg/clients/grpc"
|
||||
"github.com/ultravioletrs/cocos/pkg/ingress"
|
||||
"golang.org/x/sync/errgroup"
|
||||
"google.golang.org/protobuf/proto"
|
||||
)
|
||||
@@ -44,13 +46,14 @@ type CVMSClient struct {
|
||||
logger *slog.Logger
|
||||
runReqManager *runRequestManager
|
||||
sp server.AgentServer
|
||||
ingressProxy ingress.ProxyServer
|
||||
storage storage.Storage
|
||||
reconnectFn func(context.Context) (grpc.Client, cvms.Service_ProcessClient, error)
|
||||
grpcClient grpc.Client
|
||||
}
|
||||
|
||||
// NewClient returns new gRPC client instance.
|
||||
func NewClient(stream cvms.Service_ProcessClient, svc agent.Service, messageQueue chan *cvms.ClientStreamMessage, logger *slog.Logger, sp server.AgentServer, storageDir string, reconnectFn func(context.Context) (grpc.Client, cvms.Service_ProcessClient, error), grpcClient grpc.Client) (*CVMSClient, error) {
|
||||
func NewClient(stream cvms.Service_ProcessClient, svc agent.Service, messageQueue chan *cvms.ClientStreamMessage, logger *slog.Logger, sp server.AgentServer, ingressProxy ingress.ProxyServer, storageDir string, reconnectFn func(context.Context) (grpc.Client, cvms.Service_ProcessClient, error), grpcClient grpc.Client) (*CVMSClient, error) {
|
||||
store, err := storage.NewFileStorage(storageDir)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -63,6 +66,7 @@ func NewClient(stream cvms.Service_ProcessClient, svc agent.Service, messageQueu
|
||||
logger: logger,
|
||||
runReqManager: newRunRequestManager(),
|
||||
sp: sp,
|
||||
ingressProxy: ingressProxy,
|
||||
storage: store,
|
||||
reconnectFn: reconnectFn,
|
||||
grpcClient: grpcClient,
|
||||
@@ -205,14 +209,17 @@ func (client *CVMSClient) handleAgentStateReq(mes *cvms.ServerStreamMessage_Agen
|
||||
}
|
||||
|
||||
func (client *CVMSClient) handleRunReqChunks(ctx context.Context, msg *cvms.ServerStreamMessage_RunReqChunks) error {
|
||||
client.logger.Debug("Received RunReq chunk", "id", msg.RunReqChunks.Id, "size", len(msg.RunReqChunks.Data), "isLast", msg.RunReqChunks.IsLast)
|
||||
buffer, complete := client.runReqManager.addChunk(msg.RunReqChunks.Id, msg.RunReqChunks.Data, msg.RunReqChunks.IsLast)
|
||||
|
||||
if complete {
|
||||
client.logger.Info("Received complete computation run request", "id", msg.RunReqChunks.Id, "totalSize", len(buffer))
|
||||
var runReq cvms.ComputationRunReq
|
||||
if err := proto.Unmarshal(buffer, &runReq); err != nil {
|
||||
return errors.Wrap(err, errCorruptedManifest)
|
||||
}
|
||||
|
||||
client.logger.Info("Starting computation execution", "computationId", runReq.Id, "name", runReq.Name)
|
||||
go client.executeRun(ctx, &runReq)
|
||||
}
|
||||
|
||||
@@ -246,6 +253,15 @@ func (client *CVMSClient) executeRun(ctx context.Context, runReq *cvms.Computati
|
||||
})
|
||||
}
|
||||
|
||||
// Check if the agent is in the correct state to initialize a new computation.
|
||||
// If the agent is already processing this computation (e.g., after a reconnection),
|
||||
// skip initialization to avoid state errors.
|
||||
currentState := client.svc.State()
|
||||
if currentState != "ReceivingManifest" {
|
||||
client.logger.Info("Agent already processing computation, skipping initialization", "state", currentState, "computationId", runReq.Id)
|
||||
return
|
||||
}
|
||||
|
||||
if err := client.svc.InitComputation(ctx, ac); err != nil {
|
||||
client.logger.Warn(err.Error())
|
||||
return
|
||||
@@ -267,7 +283,6 @@ func (client *CVMSClient) executeRun(ctx context.Context, runReq *cvms.Computati
|
||||
}
|
||||
|
||||
if err := client.sp.Start(agent.AgentConfig{
|
||||
Port: runReq.AgentConfig.Port,
|
||||
CertFile: runReq.AgentConfig.CertFile,
|
||||
KeyFile: runReq.AgentConfig.KeyFile,
|
||||
ServerCAFile: runReq.AgentConfig.ServerCaFile,
|
||||
@@ -278,6 +293,22 @@ func (client *CVMSClient) executeRun(ctx context.Context, runReq *cvms.Computati
|
||||
runRes.RunRes.Error = err.Error()
|
||||
}
|
||||
|
||||
// Start ingress proxy if available
|
||||
if client.ingressProxy != nil {
|
||||
if err := client.ingressProxy.Start(
|
||||
ingress.AgentConfigToProxyConfig(agent.AgentConfig{
|
||||
CertFile: runReq.AgentConfig.CertFile,
|
||||
KeyFile: runReq.AgentConfig.KeyFile,
|
||||
ServerCAFile: runReq.AgentConfig.ServerCaFile,
|
||||
ClientCAFile: runReq.AgentConfig.ClientCaFile,
|
||||
AttestedTls: runReq.AgentConfig.AttestedTls,
|
||||
}),
|
||||
ingress.ComputationToProxyContext(ac),
|
||||
); err != nil {
|
||||
client.logger.Warn(fmt.Sprintf("failed to start ingress proxy: %s", err.Error()))
|
||||
}
|
||||
}
|
||||
|
||||
defer func() {
|
||||
if ccPlatform == attestation.Azure || ccPlatform == attestation.SNPvTPM {
|
||||
cmpJson, err := json.Marshal(ac)
|
||||
@@ -309,6 +340,12 @@ func (client *CVMSClient) handleStopComputation(ctx context.Context, mes *cvms.S
|
||||
if err := client.sp.Stop(); err != nil {
|
||||
msg.StopComputationRes.Message = err.Error()
|
||||
}
|
||||
// Stop ingress proxy if available
|
||||
if client.ingressProxy != nil {
|
||||
if err := client.ingressProxy.Stop(); err != nil {
|
||||
client.logger.Warn(fmt.Sprintf("failed to stop ingress proxy: %s", err.Error()))
|
||||
}
|
||||
}
|
||||
client.mu.Unlock()
|
||||
|
||||
client.sendMessage(&cvms.ClientStreamMessage{Message: msg})
|
||||
|
||||
@@ -11,10 +11,12 @@ import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/mock"
|
||||
"github.com/ultravioletrs/cocos/agent/cvms"
|
||||
"github.com/ultravioletrs/cocos/agent/cvms/api/grpc/storage"
|
||||
servermocks "github.com/ultravioletrs/cocos/agent/cvms/server/mocks"
|
||||
"github.com/ultravioletrs/cocos/agent/mocks"
|
||||
pkggrpc "github.com/ultravioletrs/cocos/pkg/clients/grpc"
|
||||
clientmocks "github.com/ultravioletrs/cocos/pkg/clients/grpc/mocks"
|
||||
"github.com/ultravioletrs/cocos/pkg/ingress"
|
||||
"golang.org/x/crypto/sha3"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/protobuf/proto"
|
||||
@@ -35,6 +37,21 @@ func (m *mockStream) Send(msg *cvms.ClientStreamMessage) error {
|
||||
return args.Error(0)
|
||||
}
|
||||
|
||||
// mockIngressProxy is a mock implementation of the ingress proxy.
|
||||
type mockIngressProxy struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
func (m *mockIngressProxy) Start(config ingress.ProxyConfig, ctx ingress.ProxyContext) error {
|
||||
args := m.Called(config, ctx)
|
||||
return args.Error(0)
|
||||
}
|
||||
|
||||
func (m *mockIngressProxy) Stop() error {
|
||||
args := m.Called()
|
||||
return args.Error(0)
|
||||
}
|
||||
|
||||
func TestManagerClient_Process(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
@@ -121,7 +138,7 @@ func TestManagerClient_Process(t *testing.T) {
|
||||
|
||||
grpcClient := new(clientmocks.Client)
|
||||
|
||||
client, err := NewClient(mockStream, mockSvc, messageQueue, logger, mockServerSvc, t.TempDir(), func(ctx context.Context) (pkggrpc.Client, cvms.Service_ProcessClient, error) { return nil, nil, nil }, grpcClient)
|
||||
client, err := NewClient(mockStream, mockSvc, messageQueue, logger, mockServerSvc, nil, t.TempDir(), func(ctx context.Context) (pkggrpc.Client, cvms.Service_ProcessClient, error) { return nil, nil, nil }, grpcClient)
|
||||
assert.NoError(t, err)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
||||
@@ -151,7 +168,7 @@ func TestManagerClient_handleRunReqChunks(t *testing.T) {
|
||||
logger := mglog.NewMock()
|
||||
grpcClient := new(clientmocks.Client)
|
||||
|
||||
client, err := NewClient(mockStream, mockSvc, messageQueue, logger, mockServerSvc, t.TempDir(), func(ctx context.Context) (pkggrpc.Client, cvms.Service_ProcessClient, error) { return nil, nil, nil }, grpcClient)
|
||||
client, err := NewClient(mockStream, mockSvc, messageQueue, logger, mockServerSvc, nil, t.TempDir(), func(ctx context.Context) (pkggrpc.Client, cvms.Service_ProcessClient, error) { return nil, nil, nil }, grpcClient)
|
||||
assert.NoError(t, err)
|
||||
|
||||
runReq := &cvms.ComputationRunReq{
|
||||
@@ -187,6 +204,7 @@ func TestManagerClient_handleRunReqChunks(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
mockSvc.On("State").Return("ReceivingManifest")
|
||||
mockSvc.On("InitComputation", mock.Anything, mock.Anything).Return(nil)
|
||||
mockServerSvc.On("Start", mock.Anything, mock.Anything, mock.Anything).Return(nil)
|
||||
|
||||
@@ -216,7 +234,7 @@ func TestManagerClient_handleStopComputation(t *testing.T) {
|
||||
logger := mglog.NewMock()
|
||||
grpcClient := new(clientmocks.Client)
|
||||
|
||||
client, err := NewClient(mockStream, mockSvc, messageQueue, logger, mockServerSvc, t.TempDir(), func(ctx context.Context) (pkggrpc.Client, cvms.Service_ProcessClient, error) { return nil, nil, nil }, grpcClient)
|
||||
client, err := NewClient(mockStream, mockSvc, messageQueue, logger, mockServerSvc, nil, t.TempDir(), func(ctx context.Context) (pkggrpc.Client, cvms.Service_ProcessClient, error) { return nil, nil, nil }, grpcClient)
|
||||
assert.NoError(t, err)
|
||||
|
||||
stopReq := &cvms.ServerStreamMessage_StopComputation{
|
||||
@@ -255,3 +273,243 @@ func TestManagerClient_timeoutRequest(t *testing.T) {
|
||||
|
||||
assert.Len(t, rm.requests, 0)
|
||||
}
|
||||
|
||||
// TestManagerClient_sendPendingMessages tests sending pending messages on reconnection.
|
||||
func TestManagerClient_sendPendingMessages(t *testing.T) {
|
||||
mockStream := new(mockStream)
|
||||
mockSvc := new(mocks.Service)
|
||||
mockServerSvc := new(servermocks.AgentServer)
|
||||
messageQueue := make(chan *cvms.ClientStreamMessage, 10)
|
||||
logger := mglog.NewMock()
|
||||
grpcClient := new(clientmocks.Client)
|
||||
|
||||
client, err := NewClient(mockStream, mockSvc, messageQueue, logger, mockServerSvc, nil, t.TempDir(), func(ctx context.Context) (pkggrpc.Client, cvms.Service_ProcessClient, error) { return nil, nil, nil }, grpcClient)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Add a pending message to storage
|
||||
testMsg := &cvms.ClientStreamMessage{
|
||||
Message: &cvms.ClientStreamMessage_RunRes{
|
||||
RunRes: &cvms.RunResponse{
|
||||
ComputationId: "test-id",
|
||||
},
|
||||
},
|
||||
}
|
||||
err = client.storage.Add(testMsg)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Mock successful send
|
||||
mockStream.On("Send", mock.Anything).Return(nil).Once()
|
||||
|
||||
// Load and send pending messages
|
||||
pending, err := client.storage.Load()
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, pending, 1)
|
||||
|
||||
client.sendPendingMessages(pending)
|
||||
|
||||
mockStream.AssertExpectations(t)
|
||||
}
|
||||
|
||||
// TestManagerClient_sendPendingMessagesWithError tests pending message send failure.
|
||||
func TestManagerClient_sendPendingMessagesWithError(t *testing.T) {
|
||||
mockStream := new(mockStream)
|
||||
mockSvc := new(mocks.Service)
|
||||
mockServerSvc := new(servermocks.AgentServer)
|
||||
messageQueue := make(chan *cvms.ClientStreamMessage, 10)
|
||||
logger := mglog.NewMock()
|
||||
grpcClient := new(clientmocks.Client)
|
||||
|
||||
client, err := NewClient(mockStream, mockSvc, messageQueue, logger, mockServerSvc, nil, t.TempDir(), func(ctx context.Context) (pkggrpc.Client, cvms.Service_ProcessClient, error) { return nil, nil, nil }, grpcClient)
|
||||
assert.NoError(t, err)
|
||||
|
||||
testMsg := &cvms.ClientStreamMessage{
|
||||
Message: &cvms.ClientStreamMessage_RunRes{
|
||||
RunRes: &cvms.RunResponse{
|
||||
ComputationId: "test-id",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// Mock failed send
|
||||
mockStream.On("Send", mock.Anything).Return(assert.AnError)
|
||||
|
||||
pending := []storage.Message{
|
||||
{
|
||||
Message: testMsg,
|
||||
Time: time.Now(),
|
||||
},
|
||||
}
|
||||
|
||||
client.sendPendingMessages(pending)
|
||||
|
||||
mockStream.AssertExpectations(t)
|
||||
}
|
||||
|
||||
// TestManagerClient_addChunkTimeout tests chunk timeout in runRequestManager.
|
||||
func TestManagerClient_addChunkTimeout(t *testing.T) {
|
||||
rm := newRunRequestManager()
|
||||
|
||||
// Add first chunk
|
||||
chunk1 := []byte("chunk1")
|
||||
buffer, complete := rm.addChunk("test-id", chunk1, false)
|
||||
assert.Nil(t, buffer)
|
||||
assert.False(t, complete)
|
||||
|
||||
// Verify request exists
|
||||
rm.mu.Lock()
|
||||
assert.Contains(t, rm.requests, "test-id")
|
||||
rm.mu.Unlock()
|
||||
|
||||
// Wait for timeout
|
||||
time.Sleep(35 * time.Second) // runReqTimeout is 30 seconds
|
||||
|
||||
// Verify request was removed
|
||||
rm.mu.Lock()
|
||||
assert.NotContains(t, rm.requests, "test-id")
|
||||
rm.mu.Unlock()
|
||||
}
|
||||
|
||||
// TestManagerClient_addChunkMultiple tests adding multiple chunks.
|
||||
func TestManagerClient_addChunkMultiple(t *testing.T) {
|
||||
rm := newRunRequestManager()
|
||||
|
||||
chunk1 := []byte("chunk1")
|
||||
chunk2 := []byte("chunk2")
|
||||
chunk3 := []byte("chunk3")
|
||||
|
||||
// Add chunks
|
||||
buffer, complete := rm.addChunk("test-id", chunk1, false)
|
||||
assert.Nil(t, buffer)
|
||||
assert.False(t, complete)
|
||||
|
||||
buffer, complete = rm.addChunk("test-id", chunk2, false)
|
||||
assert.Nil(t, buffer)
|
||||
assert.False(t, complete)
|
||||
|
||||
buffer, complete = rm.addChunk("test-id", chunk3, true)
|
||||
assert.NotNil(t, buffer)
|
||||
assert.True(t, complete)
|
||||
|
||||
expected := append(append(chunk1, chunk2...), chunk3...)
|
||||
assert.Equal(t, expected, buffer)
|
||||
}
|
||||
|
||||
// TestManagerClient_handleStopComputationWithIngressProxy tests stop with ingress proxy.
|
||||
func TestManagerClient_handleStopComputationWithIngressProxy(t *testing.T) {
|
||||
mockStream := new(mockStream)
|
||||
mockSvc := new(mocks.Service)
|
||||
mockServerSvc := new(servermocks.AgentServer)
|
||||
mockIngressProxy := new(mockIngressProxy)
|
||||
messageQueue := make(chan *cvms.ClientStreamMessage, 10)
|
||||
logger := mglog.NewMock()
|
||||
grpcClient := new(clientmocks.Client)
|
||||
|
||||
client, err := NewClient(mockStream, mockSvc, messageQueue, logger, mockServerSvc, mockIngressProxy, t.TempDir(), func(ctx context.Context) (pkggrpc.Client, cvms.Service_ProcessClient, error) { return nil, nil, nil }, grpcClient)
|
||||
assert.NoError(t, err)
|
||||
|
||||
stopReq := &cvms.ServerStreamMessage_StopComputation{
|
||||
StopComputation: &cvms.StopComputation{
|
||||
ComputationId: "test-comp-id",
|
||||
},
|
||||
}
|
||||
|
||||
mockSvc.On("StopComputation", mock.Anything).Return(nil)
|
||||
mockServerSvc.On("Stop").Return(nil)
|
||||
mockIngressProxy.On("Stop").Return(nil)
|
||||
|
||||
client.handleStopComputation(context.Background(), stopReq)
|
||||
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
mockSvc.AssertExpectations(t)
|
||||
mockServerSvc.AssertExpectations(t)
|
||||
mockIngressProxy.AssertExpectations(t)
|
||||
assert.Len(t, messageQueue, 1)
|
||||
}
|
||||
|
||||
// TestManagerClient_handleStopComputationWithIngressProxyError tests stop with ingress proxy error.
|
||||
func TestManagerClient_handleStopComputationWithIngressProxyError(t *testing.T) {
|
||||
mockStream := new(mockStream)
|
||||
mockSvc := new(mocks.Service)
|
||||
mockServerSvc := new(servermocks.AgentServer)
|
||||
mockIngressProxy := new(mockIngressProxy)
|
||||
messageQueue := make(chan *cvms.ClientStreamMessage, 10)
|
||||
logger := mglog.NewMock()
|
||||
grpcClient := new(clientmocks.Client)
|
||||
|
||||
client, err := NewClient(mockStream, mockSvc, messageQueue, logger, mockServerSvc, mockIngressProxy, t.TempDir(), func(ctx context.Context) (pkggrpc.Client, cvms.Service_ProcessClient, error) { return nil, nil, nil }, grpcClient)
|
||||
assert.NoError(t, err)
|
||||
|
||||
stopReq := &cvms.ServerStreamMessage_StopComputation{
|
||||
StopComputation: &cvms.StopComputation{
|
||||
ComputationId: "test-comp-id",
|
||||
},
|
||||
}
|
||||
|
||||
mockSvc.On("StopComputation", mock.Anything).Return(nil)
|
||||
mockServerSvc.On("Stop").Return(nil)
|
||||
mockIngressProxy.On("Stop").Return(assert.AnError)
|
||||
|
||||
client.handleStopComputation(context.Background(), stopReq)
|
||||
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
mockIngressProxy.AssertExpectations(t)
|
||||
}
|
||||
|
||||
// TestManagerClient_sendMessage tests sendMessage with timeout.
|
||||
func TestManagerClient_sendMessage(t *testing.T) {
|
||||
mockStream := new(mockStream)
|
||||
mockSvc := new(mocks.Service)
|
||||
mockServerSvc := new(servermocks.AgentServer)
|
||||
messageQueue := make(chan *cvms.ClientStreamMessage, 1)
|
||||
logger := mglog.NewMock()
|
||||
grpcClient := new(clientmocks.Client)
|
||||
|
||||
client, err := NewClient(mockStream, mockSvc, messageQueue, logger, mockServerSvc, nil, t.TempDir(), func(ctx context.Context) (pkggrpc.Client, cvms.Service_ProcessClient, error) { return nil, nil, nil }, grpcClient)
|
||||
assert.NoError(t, err)
|
||||
|
||||
msg := &cvms.ClientStreamMessage{
|
||||
Message: &cvms.ClientStreamMessage_RunRes{
|
||||
RunRes: &cvms.RunResponse{
|
||||
ComputationId: "test-id",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
client.sendMessage(msg)
|
||||
|
||||
select {
|
||||
case received := <-messageQueue:
|
||||
assert.Equal(t, msg, received)
|
||||
case <-time.After(1 * time.Second):
|
||||
t.Fatal("Message not received")
|
||||
}
|
||||
}
|
||||
|
||||
// TestManagerClient_sendMessageTimeout tests sendMessage timeout when queue is full.
|
||||
func TestManagerClient_sendMessageTimeout(t *testing.T) {
|
||||
mockStream := new(mockStream)
|
||||
mockSvc := new(mocks.Service)
|
||||
mockServerSvc := new(servermocks.AgentServer)
|
||||
messageQueue := make(chan *cvms.ClientStreamMessage) // No buffer
|
||||
logger := mglog.NewMock()
|
||||
grpcClient := new(clientmocks.Client)
|
||||
|
||||
client, err := NewClient(mockStream, mockSvc, messageQueue, logger, mockServerSvc, nil, t.TempDir(), func(ctx context.Context) (pkggrpc.Client, cvms.Service_ProcessClient, error) { return nil, nil, nil }, grpcClient)
|
||||
assert.NoError(t, err)
|
||||
|
||||
msg := &cvms.ClientStreamMessage{
|
||||
Message: &cvms.ClientStreamMessage_RunRes{
|
||||
RunRes: &cvms.RunResponse{
|
||||
ComputationId: "test-id",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// Don't read from queue, so sendMessage will timeout
|
||||
client.sendMessage(msg)
|
||||
|
||||
// Should complete without blocking
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
}
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
"log/slog"
|
||||
"time"
|
||||
|
||||
"github.com/ultravioletrs/cocos/agent/cvms"
|
||||
@@ -52,16 +53,20 @@ func (s *grpcServer) Process(stream cvms.Service_ProcessServer) error {
|
||||
return errors.New("failed to get peer info")
|
||||
}
|
||||
|
||||
slog.Info("client connected to cvms server", "address", client.Addr.String())
|
||||
|
||||
eg, ctx := errgroup.WithContext(stream.Context())
|
||||
|
||||
eg.Go(func() error {
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
slog.Info("receive goroutine context done", "address", client.Addr.String())
|
||||
return ctx.Err()
|
||||
default:
|
||||
req, err := stream.Recv()
|
||||
if err != nil {
|
||||
slog.Error("failed to receive from stream", "address", client.Addr.String(), "error", err)
|
||||
return err
|
||||
}
|
||||
s.incoming <- req
|
||||
@@ -85,10 +90,13 @@ func (s *grpcServer) Process(stream cvms.Service_ProcessServer) error {
|
||||
}
|
||||
|
||||
s.svc.Run(ctx, client.Addr.String(), sendMessage, client.AuthInfo)
|
||||
slog.Info("send goroutine Run() returned", "address", client.Addr.String())
|
||||
return nil
|
||||
})
|
||||
|
||||
return eg.Wait()
|
||||
err := eg.Wait()
|
||||
slog.Info("stream closed", "address", client.Addr.String(), "error", err)
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *grpcServer) sendRunReqInChunks(stream cvms.Service_ProcessServer, runReq *cvms.ComputationRunReq) error {
|
||||
|
||||
@@ -19,8 +19,8 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
svcName = "agent"
|
||||
defSvcGRPCPort = "7002"
|
||||
svcName = "agent"
|
||||
defSvcGRPCSocket = "/run/cocos/agent.sock"
|
||||
)
|
||||
|
||||
type AgentServer interface {
|
||||
@@ -46,15 +46,11 @@ func NewServer(logger *slog.Logger, svc agent.Service, host string, certProvider
|
||||
}
|
||||
|
||||
func (as *agentServer) Start(cfg agent.AgentConfig, cmp agent.Computation) error {
|
||||
if cfg.Port == "" {
|
||||
cfg.Port = defSvcGRPCPort
|
||||
}
|
||||
|
||||
agentGrpcServerConfig := server.AgentConfig{
|
||||
ServerConfig: server.ServerConfig{
|
||||
Config: server.Config{
|
||||
Host: as.host,
|
||||
Port: cfg.Port,
|
||||
Host: defSvcGRPCSocket,
|
||||
Port: "",
|
||||
CertFile: cfg.CertFile,
|
||||
KeyFile: cfg.KeyFile,
|
||||
ServerCAFile: cfg.ServerCAFile,
|
||||
|
||||
@@ -97,7 +97,6 @@ func TestAgentServer_Start(t *testing.T) {
|
||||
{
|
||||
name: "successful start with default port",
|
||||
cfg: agent.AgentConfig{
|
||||
Port: "",
|
||||
CertFile: "cert.pem",
|
||||
KeyFile: "key.pem",
|
||||
ServerCAFile: "server-ca.pem",
|
||||
@@ -131,7 +130,6 @@ func TestAgentServer_Start(t *testing.T) {
|
||||
{
|
||||
name: "successful start with custom port",
|
||||
cfg: agent.AgentConfig{
|
||||
Port: "8080",
|
||||
CertFile: "cert.pem",
|
||||
KeyFile: "key.pem",
|
||||
ServerCAFile: "server-ca.pem",
|
||||
@@ -165,7 +163,6 @@ func TestAgentServer_Start(t *testing.T) {
|
||||
{
|
||||
name: "start with minimal config",
|
||||
cfg: agent.AgentConfig{
|
||||
Port: "9090",
|
||||
AttestedTls: false,
|
||||
},
|
||||
cmp: agent.Computation{
|
||||
@@ -243,9 +240,7 @@ func TestAgentServer_Stop(t *testing.T) {
|
||||
{
|
||||
name: "stop started server",
|
||||
setupServer: func(server AgentServer) error {
|
||||
cfg := agent.AgentConfig{
|
||||
Port: "7004",
|
||||
}
|
||||
cfg := agent.AgentConfig{}
|
||||
cmp := agent.Computation{
|
||||
ID: "test-stop-computation",
|
||||
Name: "Stop Test",
|
||||
@@ -304,7 +299,7 @@ func TestAgentServer_StopMultipleTimes(t *testing.T) {
|
||||
server := NewServer(logger, svc, host, nil)
|
||||
|
||||
// Start the server
|
||||
cfg := agent.AgentConfig{Port: "7005"}
|
||||
cfg := agent.AgentConfig{}
|
||||
cmp := agent.Computation{
|
||||
ID: "test-multiple-stop",
|
||||
Name: "Multiple Stop Test",
|
||||
@@ -347,7 +342,7 @@ func TestAgentServer_StartAfterStop(t *testing.T) {
|
||||
logger, svc, host, pubKey := setupTest(t)
|
||||
server := NewServer(logger, svc, host, nil)
|
||||
|
||||
cfg := agent.AgentConfig{Port: "7006"}
|
||||
cfg := agent.AgentConfig{}
|
||||
cmp := agent.Computation{
|
||||
ID: "test-restart",
|
||||
Name: "Restart Test",
|
||||
@@ -378,7 +373,7 @@ func TestAgentServer_StartAfterStop(t *testing.T) {
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Start again with different config
|
||||
cfg2 := agent.AgentConfig{Port: "7007"}
|
||||
cfg2 := agent.AgentConfig{}
|
||||
cmp2 := agent.Computation{
|
||||
ID: "test-restart-2",
|
||||
Name: "Restart Test 2",
|
||||
@@ -422,7 +417,6 @@ func TestAgentServer_ConfigValidation(t *testing.T) {
|
||||
{
|
||||
name: "valid config with all fields",
|
||||
config: agent.AgentConfig{
|
||||
Port: "8080",
|
||||
CertFile: "cert.pem",
|
||||
KeyFile: "key.pem",
|
||||
ServerCAFile: "server-ca.pem",
|
||||
@@ -451,10 +445,8 @@ func TestAgentServer_ConfigValidation(t *testing.T) {
|
||||
valid: true,
|
||||
},
|
||||
{
|
||||
name: "valid config with minimal fields",
|
||||
config: agent.AgentConfig{
|
||||
Port: "9090",
|
||||
},
|
||||
name: "valid config with minimal fields",
|
||||
config: agent.AgentConfig{},
|
||||
cmp: agent.Computation{
|
||||
ID: "minimal-config-test",
|
||||
Name: "Minimal Config Test",
|
||||
@@ -477,10 +469,8 @@ func TestAgentServer_ConfigValidation(t *testing.T) {
|
||||
valid: true,
|
||||
},
|
||||
{
|
||||
name: "config with empty port uses default",
|
||||
config: agent.AgentConfig{
|
||||
Port: "",
|
||||
},
|
||||
name: "config with empty port uses default",
|
||||
config: agent.AgentConfig{},
|
||||
cmp: agent.Computation{
|
||||
ID: "default-port-test",
|
||||
Name: "Default Port Test",
|
||||
@@ -505,11 +495,9 @@ func TestAgentServer_ConfigValidation(t *testing.T) {
|
||||
if tt.valid {
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Verify default port is used when empty
|
||||
if tt.config.Port == "" {
|
||||
agentSrv := server.(*agentServer)
|
||||
assert.NotNil(t, agentSrv.gs)
|
||||
}
|
||||
// Verify server started successfully
|
||||
agentSrv := server.(*agentServer)
|
||||
assert.NotNil(t, agentSrv.gs)
|
||||
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
if err := server.Stop(); err != nil {
|
||||
@@ -526,5 +514,5 @@ func TestAgentServer_ConfigValidation(t *testing.T) {
|
||||
|
||||
func TestConstants(t *testing.T) {
|
||||
assert.Equal(t, "agent", svcName)
|
||||
assert.Equal(t, "7002", defSvcGRPCPort)
|
||||
assert.Equal(t, "/run/cocos/agent.sock", defSvcGRPCSocket)
|
||||
}
|
||||
|
||||
@@ -0,0 +1,261 @@
|
||||
// Copyright (c) Ultraviolet
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
// Code generated by protoc-gen-go. DO NOT EDIT.
|
||||
// versions:
|
||||
// protoc-gen-go v1.36.8
|
||||
// protoc v6.33.1
|
||||
// source: agent/log/log.proto
|
||||
|
||||
package log
|
||||
|
||||
import (
|
||||
protoreflect "google.golang.org/protobuf/reflect/protoreflect"
|
||||
protoimpl "google.golang.org/protobuf/runtime/protoimpl"
|
||||
emptypb "google.golang.org/protobuf/types/known/emptypb"
|
||||
timestamppb "google.golang.org/protobuf/types/known/timestamppb"
|
||||
reflect "reflect"
|
||||
sync "sync"
|
||||
unsafe "unsafe"
|
||||
)
|
||||
|
||||
const (
|
||||
// Verify that this generated code is sufficiently up-to-date.
|
||||
_ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion)
|
||||
// Verify that runtime/protoimpl is sufficiently up-to-date.
|
||||
_ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20)
|
||||
)
|
||||
|
||||
type LogEntry struct {
|
||||
state protoimpl.MessageState `protogen:"open.v1"`
|
||||
Message string `protobuf:"bytes,1,opt,name=message,proto3" json:"message,omitempty"`
|
||||
ComputationId string `protobuf:"bytes,2,opt,name=computation_id,json=computationId,proto3" json:"computation_id,omitempty"`
|
||||
Level string `protobuf:"bytes,3,opt,name=level,proto3" json:"level,omitempty"`
|
||||
Timestamp *timestamppb.Timestamp `protobuf:"bytes,4,opt,name=timestamp,proto3" json:"timestamp,omitempty"`
|
||||
unknownFields protoimpl.UnknownFields
|
||||
sizeCache protoimpl.SizeCache
|
||||
}
|
||||
|
||||
func (x *LogEntry) Reset() {
|
||||
*x = LogEntry{}
|
||||
mi := &file_agent_log_log_proto_msgTypes[0]
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
|
||||
func (x *LogEntry) String() string {
|
||||
return protoimpl.X.MessageStringOf(x)
|
||||
}
|
||||
|
||||
func (*LogEntry) ProtoMessage() {}
|
||||
|
||||
func (x *LogEntry) ProtoReflect() protoreflect.Message {
|
||||
mi := &file_agent_log_log_proto_msgTypes[0]
|
||||
if x != nil {
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
if ms.LoadMessageInfo() == nil {
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
return ms
|
||||
}
|
||||
return mi.MessageOf(x)
|
||||
}
|
||||
|
||||
// Deprecated: Use LogEntry.ProtoReflect.Descriptor instead.
|
||||
func (*LogEntry) Descriptor() ([]byte, []int) {
|
||||
return file_agent_log_log_proto_rawDescGZIP(), []int{0}
|
||||
}
|
||||
|
||||
func (x *LogEntry) GetMessage() string {
|
||||
if x != nil {
|
||||
return x.Message
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (x *LogEntry) GetComputationId() string {
|
||||
if x != nil {
|
||||
return x.ComputationId
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (x *LogEntry) GetLevel() string {
|
||||
if x != nil {
|
||||
return x.Level
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (x *LogEntry) GetTimestamp() *timestamppb.Timestamp {
|
||||
if x != nil {
|
||||
return x.Timestamp
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type EventEntry struct {
|
||||
state protoimpl.MessageState `protogen:"open.v1"`
|
||||
EventType string `protobuf:"bytes,1,opt,name=event_type,json=eventType,proto3" json:"event_type,omitempty"`
|
||||
Timestamp *timestamppb.Timestamp `protobuf:"bytes,2,opt,name=timestamp,proto3" json:"timestamp,omitempty"`
|
||||
ComputationId string `protobuf:"bytes,3,opt,name=computation_id,json=computationId,proto3" json:"computation_id,omitempty"`
|
||||
Details []byte `protobuf:"bytes,4,opt,name=details,proto3" json:"details,omitempty"` // JSON payload
|
||||
Originator string `protobuf:"bytes,5,opt,name=originator,proto3" json:"originator,omitempty"`
|
||||
Status string `protobuf:"bytes,6,opt,name=status,proto3" json:"status,omitempty"`
|
||||
unknownFields protoimpl.UnknownFields
|
||||
sizeCache protoimpl.SizeCache
|
||||
}
|
||||
|
||||
func (x *EventEntry) Reset() {
|
||||
*x = EventEntry{}
|
||||
mi := &file_agent_log_log_proto_msgTypes[1]
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
|
||||
func (x *EventEntry) String() string {
|
||||
return protoimpl.X.MessageStringOf(x)
|
||||
}
|
||||
|
||||
func (*EventEntry) ProtoMessage() {}
|
||||
|
||||
func (x *EventEntry) ProtoReflect() protoreflect.Message {
|
||||
mi := &file_agent_log_log_proto_msgTypes[1]
|
||||
if x != nil {
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
if ms.LoadMessageInfo() == nil {
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
return ms
|
||||
}
|
||||
return mi.MessageOf(x)
|
||||
}
|
||||
|
||||
// Deprecated: Use EventEntry.ProtoReflect.Descriptor instead.
|
||||
func (*EventEntry) Descriptor() ([]byte, []int) {
|
||||
return file_agent_log_log_proto_rawDescGZIP(), []int{1}
|
||||
}
|
||||
|
||||
func (x *EventEntry) GetEventType() string {
|
||||
if x != nil {
|
||||
return x.EventType
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (x *EventEntry) GetTimestamp() *timestamppb.Timestamp {
|
||||
if x != nil {
|
||||
return x.Timestamp
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (x *EventEntry) GetComputationId() string {
|
||||
if x != nil {
|
||||
return x.ComputationId
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (x *EventEntry) GetDetails() []byte {
|
||||
if x != nil {
|
||||
return x.Details
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (x *EventEntry) GetOriginator() string {
|
||||
if x != nil {
|
||||
return x.Originator
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (x *EventEntry) GetStatus() string {
|
||||
if x != nil {
|
||||
return x.Status
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
var File_agent_log_log_proto protoreflect.FileDescriptor
|
||||
|
||||
const file_agent_log_log_proto_rawDesc = "" +
|
||||
"\n" +
|
||||
"\x13agent/log/log.proto\x12\x03log\x1a\x1fgoogle/protobuf/timestamp.proto\x1a\x1bgoogle/protobuf/empty.proto\"\x9b\x01\n" +
|
||||
"\bLogEntry\x12\x18\n" +
|
||||
"\amessage\x18\x01 \x01(\tR\amessage\x12%\n" +
|
||||
"\x0ecomputation_id\x18\x02 \x01(\tR\rcomputationId\x12\x14\n" +
|
||||
"\x05level\x18\x03 \x01(\tR\x05level\x128\n" +
|
||||
"\ttimestamp\x18\x04 \x01(\v2\x1a.google.protobuf.TimestampR\ttimestamp\"\xde\x01\n" +
|
||||
"\n" +
|
||||
"EventEntry\x12\x1d\n" +
|
||||
"\n" +
|
||||
"event_type\x18\x01 \x01(\tR\teventType\x128\n" +
|
||||
"\ttimestamp\x18\x02 \x01(\v2\x1a.google.protobuf.TimestampR\ttimestamp\x12%\n" +
|
||||
"\x0ecomputation_id\x18\x03 \x01(\tR\rcomputationId\x12\x18\n" +
|
||||
"\adetails\x18\x04 \x01(\fR\adetails\x12\x1e\n" +
|
||||
"\n" +
|
||||
"originator\x18\x05 \x01(\tR\n" +
|
||||
"originator\x12\x16\n" +
|
||||
"\x06status\x18\x06 \x01(\tR\x06status2v\n" +
|
||||
"\fLogCollector\x120\n" +
|
||||
"\aSendLog\x12\r.log.LogEntry\x1a\x16.google.protobuf.Empty\x124\n" +
|
||||
"\tSendEvent\x12\x0f.log.EventEntry\x1a\x16.google.protobuf.EmptyB\aZ\x05./logb\x06proto3"
|
||||
|
||||
var (
|
||||
file_agent_log_log_proto_rawDescOnce sync.Once
|
||||
file_agent_log_log_proto_rawDescData []byte
|
||||
)
|
||||
|
||||
func file_agent_log_log_proto_rawDescGZIP() []byte {
|
||||
file_agent_log_log_proto_rawDescOnce.Do(func() {
|
||||
file_agent_log_log_proto_rawDescData = protoimpl.X.CompressGZIP(unsafe.Slice(unsafe.StringData(file_agent_log_log_proto_rawDesc), len(file_agent_log_log_proto_rawDesc)))
|
||||
})
|
||||
return file_agent_log_log_proto_rawDescData
|
||||
}
|
||||
|
||||
var file_agent_log_log_proto_msgTypes = make([]protoimpl.MessageInfo, 2)
|
||||
var file_agent_log_log_proto_goTypes = []any{
|
||||
(*LogEntry)(nil), // 0: log.LogEntry
|
||||
(*EventEntry)(nil), // 1: log.EventEntry
|
||||
(*timestamppb.Timestamp)(nil), // 2: google.protobuf.Timestamp
|
||||
(*emptypb.Empty)(nil), // 3: google.protobuf.Empty
|
||||
}
|
||||
var file_agent_log_log_proto_depIdxs = []int32{
|
||||
2, // 0: log.LogEntry.timestamp:type_name -> google.protobuf.Timestamp
|
||||
2, // 1: log.EventEntry.timestamp:type_name -> google.protobuf.Timestamp
|
||||
0, // 2: log.LogCollector.SendLog:input_type -> log.LogEntry
|
||||
1, // 3: log.LogCollector.SendEvent:input_type -> log.EventEntry
|
||||
3, // 4: log.LogCollector.SendLog:output_type -> google.protobuf.Empty
|
||||
3, // 5: log.LogCollector.SendEvent:output_type -> google.protobuf.Empty
|
||||
4, // [4:6] is the sub-list for method output_type
|
||||
2, // [2:4] is the sub-list for method input_type
|
||||
2, // [2:2] is the sub-list for extension type_name
|
||||
2, // [2:2] is the sub-list for extension extendee
|
||||
0, // [0:2] is the sub-list for field type_name
|
||||
}
|
||||
|
||||
func init() { file_agent_log_log_proto_init() }
|
||||
func file_agent_log_log_proto_init() {
|
||||
if File_agent_log_log_proto != nil {
|
||||
return
|
||||
}
|
||||
type x struct{}
|
||||
out := protoimpl.TypeBuilder{
|
||||
File: protoimpl.DescBuilder{
|
||||
GoPackagePath: reflect.TypeOf(x{}).PkgPath(),
|
||||
RawDescriptor: unsafe.Slice(unsafe.StringData(file_agent_log_log_proto_rawDesc), len(file_agent_log_log_proto_rawDesc)),
|
||||
NumEnums: 0,
|
||||
NumMessages: 2,
|
||||
NumExtensions: 0,
|
||||
NumServices: 1,
|
||||
},
|
||||
GoTypes: file_agent_log_log_proto_goTypes,
|
||||
DependencyIndexes: file_agent_log_log_proto_depIdxs,
|
||||
MessageInfos: file_agent_log_log_proto_msgTypes,
|
||||
}.Build()
|
||||
File_agent_log_log_proto = out.File
|
||||
file_agent_log_log_proto_goTypes = nil
|
||||
file_agent_log_log_proto_depIdxs = nil
|
||||
}
|
||||
@@ -0,0 +1,32 @@
|
||||
// Copyright (c) Ultraviolet
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
syntax = "proto3";
|
||||
|
||||
package log;
|
||||
|
||||
option go_package = "./log";
|
||||
|
||||
import "google/protobuf/timestamp.proto";
|
||||
import "google/protobuf/empty.proto";
|
||||
|
||||
service LogCollector {
|
||||
rpc SendLog(LogEntry) returns (google.protobuf.Empty);
|
||||
rpc SendEvent(EventEntry) returns (google.protobuf.Empty);
|
||||
}
|
||||
|
||||
message LogEntry {
|
||||
string message = 1;
|
||||
string computation_id = 2;
|
||||
string level = 3;
|
||||
google.protobuf.Timestamp timestamp = 4;
|
||||
}
|
||||
|
||||
message EventEntry {
|
||||
string event_type = 1;
|
||||
google.protobuf.Timestamp timestamp = 2;
|
||||
string computation_id = 3;
|
||||
bytes details = 4; // JSON payload
|
||||
string originator = 5;
|
||||
string status = 6;
|
||||
}
|
||||
@@ -0,0 +1,163 @@
|
||||
// Copyright (c) Ultraviolet
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
// Code generated by protoc-gen-go-grpc. DO NOT EDIT.
|
||||
// versions:
|
||||
// - protoc-gen-go-grpc v1.5.1
|
||||
// - protoc v6.33.1
|
||||
// source: agent/log/log.proto
|
||||
|
||||
package log
|
||||
|
||||
import (
|
||||
context "context"
|
||||
grpc "google.golang.org/grpc"
|
||||
codes "google.golang.org/grpc/codes"
|
||||
status "google.golang.org/grpc/status"
|
||||
emptypb "google.golang.org/protobuf/types/known/emptypb"
|
||||
)
|
||||
|
||||
// This is a compile-time assertion to ensure that this generated file
|
||||
// is compatible with the grpc package it is being compiled against.
|
||||
// Requires gRPC-Go v1.64.0 or later.
|
||||
const _ = grpc.SupportPackageIsVersion9
|
||||
|
||||
const (
|
||||
LogCollector_SendLog_FullMethodName = "/log.LogCollector/SendLog"
|
||||
LogCollector_SendEvent_FullMethodName = "/log.LogCollector/SendEvent"
|
||||
)
|
||||
|
||||
// LogCollectorClient is the client API for LogCollector service.
|
||||
//
|
||||
// For semantics around ctx use and closing/ending streaming RPCs, please refer to https://pkg.go.dev/google.golang.org/grpc/?tab=doc#ClientConn.NewStream.
|
||||
type LogCollectorClient interface {
|
||||
SendLog(ctx context.Context, in *LogEntry, opts ...grpc.CallOption) (*emptypb.Empty, error)
|
||||
SendEvent(ctx context.Context, in *EventEntry, opts ...grpc.CallOption) (*emptypb.Empty, error)
|
||||
}
|
||||
|
||||
type logCollectorClient struct {
|
||||
cc grpc.ClientConnInterface
|
||||
}
|
||||
|
||||
func NewLogCollectorClient(cc grpc.ClientConnInterface) LogCollectorClient {
|
||||
return &logCollectorClient{cc}
|
||||
}
|
||||
|
||||
func (c *logCollectorClient) SendLog(ctx context.Context, in *LogEntry, opts ...grpc.CallOption) (*emptypb.Empty, error) {
|
||||
cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
|
||||
out := new(emptypb.Empty)
|
||||
err := c.cc.Invoke(ctx, LogCollector_SendLog_FullMethodName, in, out, cOpts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (c *logCollectorClient) SendEvent(ctx context.Context, in *EventEntry, opts ...grpc.CallOption) (*emptypb.Empty, error) {
|
||||
cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
|
||||
out := new(emptypb.Empty)
|
||||
err := c.cc.Invoke(ctx, LogCollector_SendEvent_FullMethodName, in, out, cOpts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
// LogCollectorServer is the server API for LogCollector service.
|
||||
// All implementations must embed UnimplementedLogCollectorServer
|
||||
// for forward compatibility.
|
||||
type LogCollectorServer interface {
|
||||
SendLog(context.Context, *LogEntry) (*emptypb.Empty, error)
|
||||
SendEvent(context.Context, *EventEntry) (*emptypb.Empty, error)
|
||||
mustEmbedUnimplementedLogCollectorServer()
|
||||
}
|
||||
|
||||
// UnimplementedLogCollectorServer must be embedded to have
|
||||
// forward compatible implementations.
|
||||
//
|
||||
// NOTE: this should be embedded by value instead of pointer to avoid a nil
|
||||
// pointer dereference when methods are called.
|
||||
type UnimplementedLogCollectorServer struct{}
|
||||
|
||||
func (UnimplementedLogCollectorServer) SendLog(context.Context, *LogEntry) (*emptypb.Empty, error) {
|
||||
return nil, status.Errorf(codes.Unimplemented, "method SendLog not implemented")
|
||||
}
|
||||
func (UnimplementedLogCollectorServer) SendEvent(context.Context, *EventEntry) (*emptypb.Empty, error) {
|
||||
return nil, status.Errorf(codes.Unimplemented, "method SendEvent not implemented")
|
||||
}
|
||||
func (UnimplementedLogCollectorServer) mustEmbedUnimplementedLogCollectorServer() {}
|
||||
func (UnimplementedLogCollectorServer) testEmbeddedByValue() {}
|
||||
|
||||
// UnsafeLogCollectorServer may be embedded to opt out of forward compatibility for this service.
|
||||
// Use of this interface is not recommended, as added methods to LogCollectorServer will
|
||||
// result in compilation errors.
|
||||
type UnsafeLogCollectorServer interface {
|
||||
mustEmbedUnimplementedLogCollectorServer()
|
||||
}
|
||||
|
||||
func RegisterLogCollectorServer(s grpc.ServiceRegistrar, srv LogCollectorServer) {
|
||||
// If the following call pancis, it indicates UnimplementedLogCollectorServer was
|
||||
// embedded by pointer and is nil. This will cause panics if an
|
||||
// unimplemented method is ever invoked, so we test this at initialization
|
||||
// time to prevent it from happening at runtime later due to I/O.
|
||||
if t, ok := srv.(interface{ testEmbeddedByValue() }); ok {
|
||||
t.testEmbeddedByValue()
|
||||
}
|
||||
s.RegisterService(&LogCollector_ServiceDesc, srv)
|
||||
}
|
||||
|
||||
func _LogCollector_SendLog_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
|
||||
in := new(LogEntry)
|
||||
if err := dec(in); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if interceptor == nil {
|
||||
return srv.(LogCollectorServer).SendLog(ctx, in)
|
||||
}
|
||||
info := &grpc.UnaryServerInfo{
|
||||
Server: srv,
|
||||
FullMethod: LogCollector_SendLog_FullMethodName,
|
||||
}
|
||||
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||
return srv.(LogCollectorServer).SendLog(ctx, req.(*LogEntry))
|
||||
}
|
||||
return interceptor(ctx, in, info, handler)
|
||||
}
|
||||
|
||||
func _LogCollector_SendEvent_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
|
||||
in := new(EventEntry)
|
||||
if err := dec(in); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if interceptor == nil {
|
||||
return srv.(LogCollectorServer).SendEvent(ctx, in)
|
||||
}
|
||||
info := &grpc.UnaryServerInfo{
|
||||
Server: srv,
|
||||
FullMethod: LogCollector_SendEvent_FullMethodName,
|
||||
}
|
||||
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||
return srv.(LogCollectorServer).SendEvent(ctx, req.(*EventEntry))
|
||||
}
|
||||
return interceptor(ctx, in, info, handler)
|
||||
}
|
||||
|
||||
// LogCollector_ServiceDesc is the grpc.ServiceDesc for LogCollector service.
|
||||
// It's only intended for direct use with grpc.RegisterService,
|
||||
// and not to be introspected or modified (even as a copy)
|
||||
var LogCollector_ServiceDesc = grpc.ServiceDesc{
|
||||
ServiceName: "log.LogCollector",
|
||||
HandlerType: (*LogCollectorServer)(nil),
|
||||
Methods: []grpc.MethodDesc{
|
||||
{
|
||||
MethodName: "SendLog",
|
||||
Handler: _LogCollector_SendLog_Handler,
|
||||
},
|
||||
{
|
||||
MethodName: "SendEvent",
|
||||
Handler: _LogCollector_SendEvent_Handler,
|
||||
},
|
||||
},
|
||||
Streams: []grpc.StreamDesc{},
|
||||
Metadata: "agent/log/log.proto",
|
||||
}
|
||||
@@ -0,0 +1,59 @@
|
||||
// Copyright (c) Ultraviolet
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log/slog"
|
||||
|
||||
"github.com/ultravioletrs/cocos/agent/cvms"
|
||||
"github.com/ultravioletrs/cocos/agent/log"
|
||||
"google.golang.org/protobuf/types/known/emptypb"
|
||||
)
|
||||
|
||||
var _ log.LogCollectorServer = (*LogForwarder)(nil)
|
||||
|
||||
type LogForwarder struct {
|
||||
log.UnimplementedLogCollectorServer
|
||||
cvmsClient cvms.ServiceClient
|
||||
logger *slog.Logger
|
||||
logQueue chan *cvms.ClientStreamMessage
|
||||
}
|
||||
|
||||
func New(logger *slog.Logger, cvmsClient cvms.ServiceClient, queue chan *cvms.ClientStreamMessage) *LogForwarder {
|
||||
return &LogForwarder{
|
||||
cvmsClient: cvmsClient,
|
||||
logger: logger,
|
||||
logQueue: queue,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *LogForwarder) SendLog(ctx context.Context, req *log.LogEntry) (*emptypb.Empty, error) {
|
||||
s.logQueue <- &cvms.ClientStreamMessage{
|
||||
Message: &cvms.ClientStreamMessage_AgentLog{
|
||||
AgentLog: &cvms.AgentLog{
|
||||
Message: req.Message,
|
||||
ComputationId: req.ComputationId,
|
||||
Level: req.Level,
|
||||
Timestamp: req.Timestamp,
|
||||
},
|
||||
},
|
||||
}
|
||||
return &emptypb.Empty{}, nil
|
||||
}
|
||||
|
||||
func (s *LogForwarder) SendEvent(ctx context.Context, req *log.EventEntry) (*emptypb.Empty, error) {
|
||||
s.logQueue <- &cvms.ClientStreamMessage{
|
||||
Message: &cvms.ClientStreamMessage_AgentEvent{
|
||||
AgentEvent: &cvms.AgentEvent{
|
||||
EventType: req.EventType,
|
||||
Timestamp: req.Timestamp,
|
||||
ComputationId: req.ComputationId,
|
||||
Details: req.Details,
|
||||
Originator: req.Originator,
|
||||
Status: req.Status,
|
||||
},
|
||||
},
|
||||
}
|
||||
return &emptypb.Empty{}, nil
|
||||
}
|
||||
@@ -0,0 +1,303 @@
|
||||
// Copyright (c) Ultraviolet
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"log/slog"
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/ultravioletrs/cocos/agent/cvms"
|
||||
"github.com/ultravioletrs/cocos/agent/log"
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
)
|
||||
|
||||
// TestNewLogForwarder tests the creation of a new log forwarder.
|
||||
func TestNewLogForwarder(t *testing.T) {
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
|
||||
queue := make(chan *cvms.ClientStreamMessage, 10)
|
||||
|
||||
lf := New(logger, nil, queue)
|
||||
require.NotNil(t, lf)
|
||||
assert.NotNil(t, lf.logger)
|
||||
assert.Nil(t, lf.cvmsClient)
|
||||
assert.NotNil(t, lf.logQueue)
|
||||
}
|
||||
|
||||
// TestSendLog tests sending a log entry.
|
||||
func TestSendLog(t *testing.T) {
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
|
||||
queue := make(chan *cvms.ClientStreamMessage, 10)
|
||||
|
||||
lf := New(logger, nil, queue)
|
||||
|
||||
req := &log.LogEntry{
|
||||
Message: "Test log message",
|
||||
ComputationId: "computation-1",
|
||||
Level: "INFO",
|
||||
Timestamp: timestamppb.New(time.Now()),
|
||||
}
|
||||
|
||||
resp, err := lf.SendLog(context.Background(), req)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, resp)
|
||||
|
||||
// Verify message was queued
|
||||
select {
|
||||
case msg := <-queue:
|
||||
require.NotNil(t, msg)
|
||||
agentLog := msg.GetAgentLog()
|
||||
assert.NotNil(t, agentLog)
|
||||
assert.Equal(t, "Test log message", agentLog.Message)
|
||||
assert.Equal(t, "computation-1", agentLog.ComputationId)
|
||||
assert.Equal(t, "INFO", agentLog.Level)
|
||||
default:
|
||||
t.Fatal("No message in queue")
|
||||
}
|
||||
}
|
||||
|
||||
// TestSendEvent tests sending an event entry.
|
||||
func TestSendEvent(t *testing.T) {
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
|
||||
queue := make(chan *cvms.ClientStreamMessage, 10)
|
||||
|
||||
lf := New(logger, nil, queue)
|
||||
|
||||
details, err := json.Marshal(map[string]string{"key": "value"})
|
||||
require.NoError(t, err)
|
||||
|
||||
req := &log.EventEntry{
|
||||
EventType: "COMPUTATION_STARTED",
|
||||
Timestamp: timestamppb.New(time.Now()),
|
||||
ComputationId: "computation-1",
|
||||
Details: details,
|
||||
Originator: "runner",
|
||||
Status: "SUCCESS",
|
||||
}
|
||||
|
||||
resp, err := lf.SendEvent(context.Background(), req)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, resp)
|
||||
|
||||
// Verify message was queued
|
||||
select {
|
||||
case msg := <-queue:
|
||||
require.NotNil(t, msg)
|
||||
agentEvent := msg.GetAgentEvent()
|
||||
assert.NotNil(t, agentEvent)
|
||||
assert.Equal(t, "COMPUTATION_STARTED", agentEvent.EventType)
|
||||
assert.Equal(t, "computation-1", agentEvent.ComputationId)
|
||||
assert.Equal(t, "runner", agentEvent.Originator)
|
||||
assert.Equal(t, "SUCCESS", agentEvent.Status)
|
||||
default:
|
||||
t.Fatal("No message in queue")
|
||||
}
|
||||
}
|
||||
|
||||
// TestSendMultipleLogs tests sending multiple log entries.
|
||||
func TestSendMultipleLogs(t *testing.T) {
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
|
||||
queue := make(chan *cvms.ClientStreamMessage, 100)
|
||||
|
||||
lf := New(logger, nil, queue)
|
||||
|
||||
for i := 0; i < 5; i++ {
|
||||
req := &log.LogEntry{
|
||||
Message: "Log message",
|
||||
ComputationId: "computation-1",
|
||||
Level: "INFO",
|
||||
Timestamp: timestamppb.New(time.Now()),
|
||||
}
|
||||
|
||||
resp, err := lf.SendLog(context.Background(), req)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, resp)
|
||||
}
|
||||
|
||||
assert.Equal(t, 5, len(queue))
|
||||
}
|
||||
|
||||
// TestSendEventWithVariousTypes tests sending events with different types.
|
||||
func TestSendEventWithVariousTypes(t *testing.T) {
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
|
||||
queue := make(chan *cvms.ClientStreamMessage, 100)
|
||||
|
||||
lf := New(logger, nil, queue)
|
||||
|
||||
eventTypes := []string{"STARTED", "RUNNING", "COMPLETED", "FAILED"}
|
||||
for _, eventType := range eventTypes {
|
||||
details, err := json.Marshal(map[string]string{"type": eventType})
|
||||
require.NoError(t, err)
|
||||
req := &log.EventEntry{
|
||||
EventType: eventType,
|
||||
Timestamp: timestamppb.New(time.Now()),
|
||||
ComputationId: "computation-1",
|
||||
Details: details,
|
||||
Originator: "runner",
|
||||
Status: "OK",
|
||||
}
|
||||
|
||||
resp, err := lf.SendEvent(context.Background(), req)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, resp)
|
||||
}
|
||||
|
||||
assert.Equal(t, 4, len(queue))
|
||||
}
|
||||
|
||||
// TestSendLogWithEmptyMessage tests sending log with empty message.
|
||||
func TestSendLogWithEmptyMessage(t *testing.T) {
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
|
||||
queue := make(chan *cvms.ClientStreamMessage, 10)
|
||||
|
||||
lf := New(logger, nil, queue)
|
||||
|
||||
req := &log.LogEntry{
|
||||
Message: "",
|
||||
ComputationId: "computation-1",
|
||||
Level: "INFO",
|
||||
Timestamp: timestamppb.New(time.Now()),
|
||||
}
|
||||
|
||||
resp, err := lf.SendLog(context.Background(), req)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, resp)
|
||||
|
||||
select {
|
||||
case msg := <-queue:
|
||||
agentLog := msg.GetAgentLog()
|
||||
assert.Equal(t, "", agentLog.Message)
|
||||
default:
|
||||
t.Fatal("No message in queue")
|
||||
}
|
||||
}
|
||||
|
||||
// TestSendEventWithNilDetails tests sending event with nil details.
|
||||
func TestSendEventWithNilDetails(t *testing.T) {
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
|
||||
queue := make(chan *cvms.ClientStreamMessage, 10)
|
||||
|
||||
lf := New(logger, nil, queue)
|
||||
|
||||
req := &log.EventEntry{
|
||||
EventType: "TEST_EVENT",
|
||||
Timestamp: timestamppb.New(time.Now()),
|
||||
ComputationId: "computation-1",
|
||||
Details: nil,
|
||||
Originator: "test",
|
||||
Status: "OK",
|
||||
}
|
||||
|
||||
resp, err := lf.SendEvent(context.Background(), req)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, resp)
|
||||
|
||||
select {
|
||||
case msg := <-queue:
|
||||
agentEvent := msg.GetAgentEvent()
|
||||
assert.Nil(t, agentEvent.Details)
|
||||
default:
|
||||
t.Fatal("No message in queue")
|
||||
}
|
||||
}
|
||||
|
||||
// TestSendLogWithVariousLevels tests sending logs with various severity levels.
|
||||
func TestSendLogWithVariousLevels(t *testing.T) {
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
|
||||
queue := make(chan *cvms.ClientStreamMessage, 100)
|
||||
|
||||
lf := New(logger, nil, queue)
|
||||
|
||||
levels := []string{"DEBUG", "INFO", "WARN", "ERROR"}
|
||||
for _, level := range levels {
|
||||
req := &log.LogEntry{
|
||||
Message: "Test " + level,
|
||||
ComputationId: "computation-1",
|
||||
Level: level,
|
||||
Timestamp: timestamppb.New(time.Now()),
|
||||
}
|
||||
|
||||
resp, err := lf.SendLog(context.Background(), req)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, resp)
|
||||
}
|
||||
|
||||
assert.Equal(t, 4, len(queue))
|
||||
}
|
||||
|
||||
// TestSendLogWithDifferentComputationIds tests sending logs with different computation IDs.
|
||||
func TestSendLogWithDifferentComputationIds(t *testing.T) {
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
|
||||
queue := make(chan *cvms.ClientStreamMessage, 100)
|
||||
|
||||
lf := New(logger, nil, queue)
|
||||
|
||||
for i := 0; i < 3; i++ {
|
||||
req := &log.LogEntry{
|
||||
Message: "Message",
|
||||
ComputationId: "computation-" + string(rune(48+i)),
|
||||
Level: "INFO",
|
||||
Timestamp: timestamppb.New(time.Now()),
|
||||
}
|
||||
|
||||
resp, err := lf.SendLog(context.Background(), req)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, resp)
|
||||
}
|
||||
|
||||
assert.Equal(t, 3, len(queue))
|
||||
}
|
||||
|
||||
// TestQueueBehavior tests that queue is properly used.
|
||||
func TestQueueBehavior(t *testing.T) {
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
|
||||
queue := make(chan *cvms.ClientStreamMessage, 1)
|
||||
|
||||
lf := New(logger, nil, queue)
|
||||
|
||||
req := &log.LogEntry{
|
||||
Message: "Test",
|
||||
ComputationId: "computation-1",
|
||||
Level: "INFO",
|
||||
Timestamp: timestamppb.New(time.Now()),
|
||||
}
|
||||
|
||||
resp, err := lf.SendLog(context.Background(), req)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, resp)
|
||||
|
||||
assert.Equal(t, 1, len(queue))
|
||||
}
|
||||
|
||||
// TestConcurrentSendLog tests concurrent log sending.
|
||||
func TestConcurrentSendLog(t *testing.T) {
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
|
||||
queue := make(chan *cvms.ClientStreamMessage, 100)
|
||||
|
||||
lf := New(logger, nil, queue)
|
||||
|
||||
for i := 0; i < 10; i++ {
|
||||
go func(id int) {
|
||||
req := &log.LogEntry{
|
||||
Message: "Concurrent log",
|
||||
ComputationId: "computation-1",
|
||||
Level: "INFO",
|
||||
Timestamp: timestamppb.New(time.Now()),
|
||||
}
|
||||
|
||||
_, err := lf.SendLog(context.Background(), req)
|
||||
require.NoError(t, err)
|
||||
}(i)
|
||||
}
|
||||
|
||||
// Give goroutines time to complete
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// Should have received all messages
|
||||
assert.True(t, len(queue) > 0)
|
||||
}
|
||||
@@ -0,0 +1,38 @@
|
||||
// Copyright (c) Ultraviolet
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
package events
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"log/slog"
|
||||
|
||||
"github.com/ultravioletrs/cocos/agent/events"
|
||||
logpb "github.com/ultravioletrs/cocos/agent/log"
|
||||
logclient "github.com/ultravioletrs/cocos/pkg/clients/grpc/log"
|
||||
)
|
||||
|
||||
type adapter struct {
|
||||
client logclient.Client
|
||||
svc string
|
||||
}
|
||||
|
||||
func NewAdapter(client logclient.Client, svc string) events.Service {
|
||||
return &adapter{
|
||||
client: client,
|
||||
svc: svc,
|
||||
}
|
||||
}
|
||||
|
||||
func (a *adapter) SendEvent(cmpID, event, status string, details json.RawMessage) {
|
||||
err := a.client.SendEvent(context.Background(), &logpb.EventEntry{
|
||||
EventType: event,
|
||||
ComputationId: cmpID,
|
||||
Details: details,
|
||||
Originator: a.svc,
|
||||
Status: status,
|
||||
})
|
||||
if err != nil {
|
||||
slog.Error("failed to send event to log-forwarder", "error", err)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,138 @@
|
||||
// Copyright (c) Ultraviolet
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
package events
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/mock"
|
||||
logpb "github.com/ultravioletrs/cocos/agent/log"
|
||||
)
|
||||
|
||||
const testServiceName = "test-service"
|
||||
|
||||
// mockLogClient is a mock implementation of the log client.
|
||||
type mockLogClient struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
func (m *mockLogClient) SendLog(ctx context.Context, entry *logpb.LogEntry) error {
|
||||
args := m.Called(ctx, entry)
|
||||
return args.Error(0)
|
||||
}
|
||||
|
||||
func (m *mockLogClient) SendEvent(ctx context.Context, entry *logpb.EventEntry) error {
|
||||
args := m.Called(ctx, entry)
|
||||
return args.Error(0)
|
||||
}
|
||||
|
||||
func (m *mockLogClient) Close() error {
|
||||
args := m.Called()
|
||||
return args.Error(0)
|
||||
}
|
||||
|
||||
// TestNewAdapter tests creating a new adapter.
|
||||
func TestNewAdapter(t *testing.T) {
|
||||
mockClient := new(mockLogClient)
|
||||
svc := testServiceName
|
||||
|
||||
adapter := NewAdapter(mockClient, svc)
|
||||
|
||||
assert.NotNil(t, adapter)
|
||||
}
|
||||
|
||||
// TestSendEvent tests sending an event successfully.
|
||||
func TestSendEvent(t *testing.T) {
|
||||
mockClient := new(mockLogClient)
|
||||
svc := testServiceName
|
||||
adapter := NewAdapter(mockClient, svc)
|
||||
|
||||
cmpID := "test-computation-id"
|
||||
event := "computation.started"
|
||||
status := "success"
|
||||
details := json.RawMessage(`{"key": "value"}`)
|
||||
|
||||
expectedEntry := &logpb.EventEntry{
|
||||
EventType: event,
|
||||
ComputationId: cmpID,
|
||||
Details: details,
|
||||
Originator: svc,
|
||||
Status: status,
|
||||
}
|
||||
|
||||
mockClient.On("SendEvent", mock.Anything, expectedEntry).Return(nil)
|
||||
|
||||
adapter.SendEvent(cmpID, event, status, details)
|
||||
|
||||
mockClient.AssertExpectations(t)
|
||||
mockClient.AssertCalled(t, "SendEvent", mock.Anything, expectedEntry)
|
||||
}
|
||||
|
||||
// TestSendEventWithError tests sending an event when client returns an error.
|
||||
func TestSendEventWithError(t *testing.T) {
|
||||
mockClient := new(mockLogClient)
|
||||
svc := testServiceName
|
||||
adapter := NewAdapter(mockClient, svc)
|
||||
|
||||
cmpID := "test-computation-id"
|
||||
event := "computation.failed"
|
||||
status := "error"
|
||||
details := json.RawMessage(`{"error": "something went wrong"}`)
|
||||
|
||||
mockClient.On("SendEvent", mock.Anything, mock.Anything).Return(assert.AnError)
|
||||
|
||||
// This should not panic even when error occurs
|
||||
adapter.SendEvent(cmpID, event, status, details)
|
||||
|
||||
mockClient.AssertExpectations(t)
|
||||
mockClient.AssertCalled(t, "SendEvent", mock.Anything, mock.Anything)
|
||||
}
|
||||
|
||||
// TestSendEventWithNilDetails tests sending an event with nil details.
|
||||
func TestSendEventWithNilDetails(t *testing.T) {
|
||||
mockClient := new(mockLogClient)
|
||||
svc := "runner-service"
|
||||
adapter := NewAdapter(mockClient, svc)
|
||||
|
||||
cmpID := "comp-123"
|
||||
event := "test.event"
|
||||
status := "pending"
|
||||
|
||||
expectedEntry := &logpb.EventEntry{
|
||||
EventType: event,
|
||||
ComputationId: cmpID,
|
||||
Details: nil,
|
||||
Originator: svc,
|
||||
Status: status,
|
||||
}
|
||||
|
||||
mockClient.On("SendEvent", mock.Anything, expectedEntry).Return(nil)
|
||||
|
||||
adapter.SendEvent(cmpID, event, status, nil)
|
||||
|
||||
mockClient.AssertExpectations(t)
|
||||
}
|
||||
|
||||
// TestSendEventWithEmptyStrings tests sending an event with empty strings.
|
||||
func TestSendEventWithEmptyStrings(t *testing.T) {
|
||||
mockClient := new(mockLogClient)
|
||||
svc := testServiceName
|
||||
adapter := NewAdapter(mockClient, svc)
|
||||
|
||||
expectedEntry := &logpb.EventEntry{
|
||||
EventType: "",
|
||||
ComputationId: "",
|
||||
Details: nil,
|
||||
Originator: svc,
|
||||
Status: "",
|
||||
}
|
||||
|
||||
mockClient.On("SendEvent", mock.Anything, expectedEntry).Return(nil)
|
||||
|
||||
adapter.SendEvent("", "", "", nil)
|
||||
|
||||
mockClient.AssertExpectations(t)
|
||||
}
|
||||
@@ -0,0 +1,341 @@
|
||||
// Copyright (c) Ultraviolet
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
// Code generated by protoc-gen-go. DO NOT EDIT.
|
||||
// versions:
|
||||
// protoc-gen-go v1.36.8
|
||||
// protoc v6.33.1
|
||||
// source: agent/runner/runner.proto
|
||||
|
||||
package runner
|
||||
|
||||
import (
|
||||
protoreflect "google.golang.org/protobuf/reflect/protoreflect"
|
||||
protoimpl "google.golang.org/protobuf/runtime/protoimpl"
|
||||
emptypb "google.golang.org/protobuf/types/known/emptypb"
|
||||
reflect "reflect"
|
||||
sync "sync"
|
||||
unsafe "unsafe"
|
||||
)
|
||||
|
||||
const (
|
||||
// Verify that this generated code is sufficiently up-to-date.
|
||||
_ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion)
|
||||
// Verify that runtime/protoimpl is sufficiently up-to-date.
|
||||
_ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20)
|
||||
)
|
||||
|
||||
type RunRequest struct {
|
||||
state protoimpl.MessageState `protogen:"open.v1"`
|
||||
ComputationId string `protobuf:"bytes,1,opt,name=computation_id,json=computationId,proto3" json:"computation_id,omitempty"`
|
||||
AlgoType string `protobuf:"bytes,2,opt,name=algo_type,json=algoType,proto3" json:"algo_type,omitempty"` // "binary", "python", "wasm", "docker"
|
||||
Algorithm []byte `protobuf:"bytes,3,opt,name=algorithm,proto3" json:"algorithm,omitempty"` // The algorithm binary/script content
|
||||
Requirements []byte `protobuf:"bytes,4,opt,name=requirements,proto3" json:"requirements,omitempty"` // Python requirements.txt content
|
||||
Args []string `protobuf:"bytes,5,rep,name=args,proto3" json:"args,omitempty"`
|
||||
Datasets []*Dataset `protobuf:"bytes,6,rep,name=datasets,proto3" json:"datasets,omitempty"`
|
||||
unknownFields protoimpl.UnknownFields
|
||||
sizeCache protoimpl.SizeCache
|
||||
}
|
||||
|
||||
func (x *RunRequest) Reset() {
|
||||
*x = RunRequest{}
|
||||
mi := &file_agent_runner_runner_proto_msgTypes[0]
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
|
||||
func (x *RunRequest) String() string {
|
||||
return protoimpl.X.MessageStringOf(x)
|
||||
}
|
||||
|
||||
func (*RunRequest) ProtoMessage() {}
|
||||
|
||||
func (x *RunRequest) ProtoReflect() protoreflect.Message {
|
||||
mi := &file_agent_runner_runner_proto_msgTypes[0]
|
||||
if x != nil {
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
if ms.LoadMessageInfo() == nil {
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
return ms
|
||||
}
|
||||
return mi.MessageOf(x)
|
||||
}
|
||||
|
||||
// Deprecated: Use RunRequest.ProtoReflect.Descriptor instead.
|
||||
func (*RunRequest) Descriptor() ([]byte, []int) {
|
||||
return file_agent_runner_runner_proto_rawDescGZIP(), []int{0}
|
||||
}
|
||||
|
||||
func (x *RunRequest) GetComputationId() string {
|
||||
if x != nil {
|
||||
return x.ComputationId
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (x *RunRequest) GetAlgoType() string {
|
||||
if x != nil {
|
||||
return x.AlgoType
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (x *RunRequest) GetAlgorithm() []byte {
|
||||
if x != nil {
|
||||
return x.Algorithm
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (x *RunRequest) GetRequirements() []byte {
|
||||
if x != nil {
|
||||
return x.Requirements
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (x *RunRequest) GetArgs() []string {
|
||||
if x != nil {
|
||||
return x.Args
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (x *RunRequest) GetDatasets() []*Dataset {
|
||||
if x != nil {
|
||||
return x.Datasets
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type Dataset struct {
|
||||
state protoimpl.MessageState `protogen:"open.v1"`
|
||||
Filename string `protobuf:"bytes,1,opt,name=filename,proto3" json:"filename,omitempty"`
|
||||
Hash []byte `protobuf:"bytes,2,opt,name=hash,proto3" json:"hash,omitempty"`
|
||||
unknownFields protoimpl.UnknownFields
|
||||
sizeCache protoimpl.SizeCache
|
||||
}
|
||||
|
||||
func (x *Dataset) Reset() {
|
||||
*x = Dataset{}
|
||||
mi := &file_agent_runner_runner_proto_msgTypes[1]
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
|
||||
func (x *Dataset) String() string {
|
||||
return protoimpl.X.MessageStringOf(x)
|
||||
}
|
||||
|
||||
func (*Dataset) ProtoMessage() {}
|
||||
|
||||
func (x *Dataset) ProtoReflect() protoreflect.Message {
|
||||
mi := &file_agent_runner_runner_proto_msgTypes[1]
|
||||
if x != nil {
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
if ms.LoadMessageInfo() == nil {
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
return ms
|
||||
}
|
||||
return mi.MessageOf(x)
|
||||
}
|
||||
|
||||
// Deprecated: Use Dataset.ProtoReflect.Descriptor instead.
|
||||
func (*Dataset) Descriptor() ([]byte, []int) {
|
||||
return file_agent_runner_runner_proto_rawDescGZIP(), []int{1}
|
||||
}
|
||||
|
||||
func (x *Dataset) GetFilename() string {
|
||||
if x != nil {
|
||||
return x.Filename
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (x *Dataset) GetHash() []byte {
|
||||
if x != nil {
|
||||
return x.Hash
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type RunResponse struct {
|
||||
state protoimpl.MessageState `protogen:"open.v1"`
|
||||
ComputationId string `protobuf:"bytes,1,opt,name=computation_id,json=computationId,proto3" json:"computation_id,omitempty"`
|
||||
Error string `protobuf:"bytes,2,opt,name=error,proto3" json:"error,omitempty"`
|
||||
unknownFields protoimpl.UnknownFields
|
||||
sizeCache protoimpl.SizeCache
|
||||
}
|
||||
|
||||
func (x *RunResponse) Reset() {
|
||||
*x = RunResponse{}
|
||||
mi := &file_agent_runner_runner_proto_msgTypes[2]
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
|
||||
func (x *RunResponse) String() string {
|
||||
return protoimpl.X.MessageStringOf(x)
|
||||
}
|
||||
|
||||
func (*RunResponse) ProtoMessage() {}
|
||||
|
||||
func (x *RunResponse) ProtoReflect() protoreflect.Message {
|
||||
mi := &file_agent_runner_runner_proto_msgTypes[2]
|
||||
if x != nil {
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
if ms.LoadMessageInfo() == nil {
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
return ms
|
||||
}
|
||||
return mi.MessageOf(x)
|
||||
}
|
||||
|
||||
// Deprecated: Use RunResponse.ProtoReflect.Descriptor instead.
|
||||
func (*RunResponse) Descriptor() ([]byte, []int) {
|
||||
return file_agent_runner_runner_proto_rawDescGZIP(), []int{2}
|
||||
}
|
||||
|
||||
func (x *RunResponse) GetComputationId() string {
|
||||
if x != nil {
|
||||
return x.ComputationId
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (x *RunResponse) GetError() string {
|
||||
if x != nil {
|
||||
return x.Error
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
type StopRequest struct {
|
||||
state protoimpl.MessageState `protogen:"open.v1"`
|
||||
ComputationId string `protobuf:"bytes,1,opt,name=computation_id,json=computationId,proto3" json:"computation_id,omitempty"`
|
||||
unknownFields protoimpl.UnknownFields
|
||||
sizeCache protoimpl.SizeCache
|
||||
}
|
||||
|
||||
func (x *StopRequest) Reset() {
|
||||
*x = StopRequest{}
|
||||
mi := &file_agent_runner_runner_proto_msgTypes[3]
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
|
||||
func (x *StopRequest) String() string {
|
||||
return protoimpl.X.MessageStringOf(x)
|
||||
}
|
||||
|
||||
func (*StopRequest) ProtoMessage() {}
|
||||
|
||||
func (x *StopRequest) ProtoReflect() protoreflect.Message {
|
||||
mi := &file_agent_runner_runner_proto_msgTypes[3]
|
||||
if x != nil {
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
if ms.LoadMessageInfo() == nil {
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
return ms
|
||||
}
|
||||
return mi.MessageOf(x)
|
||||
}
|
||||
|
||||
// Deprecated: Use StopRequest.ProtoReflect.Descriptor instead.
|
||||
func (*StopRequest) Descriptor() ([]byte, []int) {
|
||||
return file_agent_runner_runner_proto_rawDescGZIP(), []int{3}
|
||||
}
|
||||
|
||||
func (x *StopRequest) GetComputationId() string {
|
||||
if x != nil {
|
||||
return x.ComputationId
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
var File_agent_runner_runner_proto protoreflect.FileDescriptor
|
||||
|
||||
const file_agent_runner_runner_proto_rawDesc = "" +
|
||||
"\n" +
|
||||
"\x19agent/runner/runner.proto\x12\x06runner\x1a\x1bgoogle/protobuf/empty.proto\"\xd3\x01\n" +
|
||||
"\n" +
|
||||
"RunRequest\x12%\n" +
|
||||
"\x0ecomputation_id\x18\x01 \x01(\tR\rcomputationId\x12\x1b\n" +
|
||||
"\talgo_type\x18\x02 \x01(\tR\balgoType\x12\x1c\n" +
|
||||
"\talgorithm\x18\x03 \x01(\fR\talgorithm\x12\"\n" +
|
||||
"\frequirements\x18\x04 \x01(\fR\frequirements\x12\x12\n" +
|
||||
"\x04args\x18\x05 \x03(\tR\x04args\x12+\n" +
|
||||
"\bdatasets\x18\x06 \x03(\v2\x0f.runner.DatasetR\bdatasets\"9\n" +
|
||||
"\aDataset\x12\x1a\n" +
|
||||
"\bfilename\x18\x01 \x01(\tR\bfilename\x12\x12\n" +
|
||||
"\x04hash\x18\x02 \x01(\fR\x04hash\"J\n" +
|
||||
"\vRunResponse\x12%\n" +
|
||||
"\x0ecomputation_id\x18\x01 \x01(\tR\rcomputationId\x12\x14\n" +
|
||||
"\x05error\x18\x02 \x01(\tR\x05error\"4\n" +
|
||||
"\vStopRequest\x12%\n" +
|
||||
"\x0ecomputation_id\x18\x01 \x01(\tR\rcomputationId2x\n" +
|
||||
"\x11ComputationRunner\x12.\n" +
|
||||
"\x03Run\x12\x12.runner.RunRequest\x1a\x13.runner.RunResponse\x123\n" +
|
||||
"\x04Stop\x12\x13.runner.StopRequest\x1a\x16.google.protobuf.EmptyB\n" +
|
||||
"Z\b./runnerb\x06proto3"
|
||||
|
||||
var (
|
||||
file_agent_runner_runner_proto_rawDescOnce sync.Once
|
||||
file_agent_runner_runner_proto_rawDescData []byte
|
||||
)
|
||||
|
||||
func file_agent_runner_runner_proto_rawDescGZIP() []byte {
|
||||
file_agent_runner_runner_proto_rawDescOnce.Do(func() {
|
||||
file_agent_runner_runner_proto_rawDescData = protoimpl.X.CompressGZIP(unsafe.Slice(unsafe.StringData(file_agent_runner_runner_proto_rawDesc), len(file_agent_runner_runner_proto_rawDesc)))
|
||||
})
|
||||
return file_agent_runner_runner_proto_rawDescData
|
||||
}
|
||||
|
||||
var file_agent_runner_runner_proto_msgTypes = make([]protoimpl.MessageInfo, 4)
|
||||
var file_agent_runner_runner_proto_goTypes = []any{
|
||||
(*RunRequest)(nil), // 0: runner.RunRequest
|
||||
(*Dataset)(nil), // 1: runner.Dataset
|
||||
(*RunResponse)(nil), // 2: runner.RunResponse
|
||||
(*StopRequest)(nil), // 3: runner.StopRequest
|
||||
(*emptypb.Empty)(nil), // 4: google.protobuf.Empty
|
||||
}
|
||||
var file_agent_runner_runner_proto_depIdxs = []int32{
|
||||
1, // 0: runner.RunRequest.datasets:type_name -> runner.Dataset
|
||||
0, // 1: runner.ComputationRunner.Run:input_type -> runner.RunRequest
|
||||
3, // 2: runner.ComputationRunner.Stop:input_type -> runner.StopRequest
|
||||
2, // 3: runner.ComputationRunner.Run:output_type -> runner.RunResponse
|
||||
4, // 4: runner.ComputationRunner.Stop:output_type -> google.protobuf.Empty
|
||||
3, // [3:5] is the sub-list for method output_type
|
||||
1, // [1:3] is the sub-list for method input_type
|
||||
1, // [1:1] is the sub-list for extension type_name
|
||||
1, // [1:1] is the sub-list for extension extendee
|
||||
0, // [0:1] is the sub-list for field type_name
|
||||
}
|
||||
|
||||
func init() { file_agent_runner_runner_proto_init() }
|
||||
func file_agent_runner_runner_proto_init() {
|
||||
if File_agent_runner_runner_proto != nil {
|
||||
return
|
||||
}
|
||||
type x struct{}
|
||||
out := protoimpl.TypeBuilder{
|
||||
File: protoimpl.DescBuilder{
|
||||
GoPackagePath: reflect.TypeOf(x{}).PkgPath(),
|
||||
RawDescriptor: unsafe.Slice(unsafe.StringData(file_agent_runner_runner_proto_rawDesc), len(file_agent_runner_runner_proto_rawDesc)),
|
||||
NumEnums: 0,
|
||||
NumMessages: 4,
|
||||
NumExtensions: 0,
|
||||
NumServices: 1,
|
||||
},
|
||||
GoTypes: file_agent_runner_runner_proto_goTypes,
|
||||
DependencyIndexes: file_agent_runner_runner_proto_depIdxs,
|
||||
MessageInfos: file_agent_runner_runner_proto_msgTypes,
|
||||
}.Build()
|
||||
File_agent_runner_runner_proto = out.File
|
||||
file_agent_runner_runner_proto_goTypes = nil
|
||||
file_agent_runner_runner_proto_depIdxs = nil
|
||||
}
|
||||
@@ -0,0 +1,38 @@
|
||||
// Copyright (c) Ultraviolet
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
syntax = "proto3";
|
||||
|
||||
package runner;
|
||||
|
||||
option go_package = "./runner";
|
||||
|
||||
import "google/protobuf/empty.proto";
|
||||
|
||||
service ComputationRunner {
|
||||
rpc Run(RunRequest) returns (RunResponse);
|
||||
rpc Stop(StopRequest) returns (google.protobuf.Empty);
|
||||
}
|
||||
|
||||
message RunRequest {
|
||||
string computation_id = 1;
|
||||
string algo_type = 2; // "binary", "python", "wasm", "docker"
|
||||
bytes algorithm = 3; // The algorithm binary/script content
|
||||
bytes requirements = 4; // Python requirements.txt content
|
||||
repeated string args = 5;
|
||||
repeated Dataset datasets = 6;
|
||||
}
|
||||
|
||||
message Dataset {
|
||||
string filename = 1;
|
||||
bytes hash = 2;
|
||||
}
|
||||
|
||||
message RunResponse {
|
||||
string computation_id = 1;
|
||||
string error = 2;
|
||||
}
|
||||
|
||||
message StopRequest {
|
||||
string computation_id = 1;
|
||||
}
|
||||
@@ -0,0 +1,163 @@
|
||||
// Copyright (c) Ultraviolet
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
// Code generated by protoc-gen-go-grpc. DO NOT EDIT.
|
||||
// versions:
|
||||
// - protoc-gen-go-grpc v1.5.1
|
||||
// - protoc v6.33.1
|
||||
// source: agent/runner/runner.proto
|
||||
|
||||
package runner
|
||||
|
||||
import (
|
||||
context "context"
|
||||
grpc "google.golang.org/grpc"
|
||||
codes "google.golang.org/grpc/codes"
|
||||
status "google.golang.org/grpc/status"
|
||||
emptypb "google.golang.org/protobuf/types/known/emptypb"
|
||||
)
|
||||
|
||||
// This is a compile-time assertion to ensure that this generated file
|
||||
// is compatible with the grpc package it is being compiled against.
|
||||
// Requires gRPC-Go v1.64.0 or later.
|
||||
const _ = grpc.SupportPackageIsVersion9
|
||||
|
||||
const (
|
||||
ComputationRunner_Run_FullMethodName = "/runner.ComputationRunner/Run"
|
||||
ComputationRunner_Stop_FullMethodName = "/runner.ComputationRunner/Stop"
|
||||
)
|
||||
|
||||
// ComputationRunnerClient is the client API for ComputationRunner service.
|
||||
//
|
||||
// For semantics around ctx use and closing/ending streaming RPCs, please refer to https://pkg.go.dev/google.golang.org/grpc/?tab=doc#ClientConn.NewStream.
|
||||
type ComputationRunnerClient interface {
|
||||
Run(ctx context.Context, in *RunRequest, opts ...grpc.CallOption) (*RunResponse, error)
|
||||
Stop(ctx context.Context, in *StopRequest, opts ...grpc.CallOption) (*emptypb.Empty, error)
|
||||
}
|
||||
|
||||
type computationRunnerClient struct {
|
||||
cc grpc.ClientConnInterface
|
||||
}
|
||||
|
||||
func NewComputationRunnerClient(cc grpc.ClientConnInterface) ComputationRunnerClient {
|
||||
return &computationRunnerClient{cc}
|
||||
}
|
||||
|
||||
func (c *computationRunnerClient) Run(ctx context.Context, in *RunRequest, opts ...grpc.CallOption) (*RunResponse, error) {
|
||||
cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
|
||||
out := new(RunResponse)
|
||||
err := c.cc.Invoke(ctx, ComputationRunner_Run_FullMethodName, in, out, cOpts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (c *computationRunnerClient) Stop(ctx context.Context, in *StopRequest, opts ...grpc.CallOption) (*emptypb.Empty, error) {
|
||||
cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
|
||||
out := new(emptypb.Empty)
|
||||
err := c.cc.Invoke(ctx, ComputationRunner_Stop_FullMethodName, in, out, cOpts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
// ComputationRunnerServer is the server API for ComputationRunner service.
|
||||
// All implementations must embed UnimplementedComputationRunnerServer
|
||||
// for forward compatibility.
|
||||
type ComputationRunnerServer interface {
|
||||
Run(context.Context, *RunRequest) (*RunResponse, error)
|
||||
Stop(context.Context, *StopRequest) (*emptypb.Empty, error)
|
||||
mustEmbedUnimplementedComputationRunnerServer()
|
||||
}
|
||||
|
||||
// UnimplementedComputationRunnerServer must be embedded to have
|
||||
// forward compatible implementations.
|
||||
//
|
||||
// NOTE: this should be embedded by value instead of pointer to avoid a nil
|
||||
// pointer dereference when methods are called.
|
||||
type UnimplementedComputationRunnerServer struct{}
|
||||
|
||||
func (UnimplementedComputationRunnerServer) Run(context.Context, *RunRequest) (*RunResponse, error) {
|
||||
return nil, status.Errorf(codes.Unimplemented, "method Run not implemented")
|
||||
}
|
||||
func (UnimplementedComputationRunnerServer) Stop(context.Context, *StopRequest) (*emptypb.Empty, error) {
|
||||
return nil, status.Errorf(codes.Unimplemented, "method Stop not implemented")
|
||||
}
|
||||
func (UnimplementedComputationRunnerServer) mustEmbedUnimplementedComputationRunnerServer() {}
|
||||
func (UnimplementedComputationRunnerServer) testEmbeddedByValue() {}
|
||||
|
||||
// UnsafeComputationRunnerServer may be embedded to opt out of forward compatibility for this service.
|
||||
// Use of this interface is not recommended, as added methods to ComputationRunnerServer will
|
||||
// result in compilation errors.
|
||||
type UnsafeComputationRunnerServer interface {
|
||||
mustEmbedUnimplementedComputationRunnerServer()
|
||||
}
|
||||
|
||||
func RegisterComputationRunnerServer(s grpc.ServiceRegistrar, srv ComputationRunnerServer) {
|
||||
// If the following call pancis, it indicates UnimplementedComputationRunnerServer was
|
||||
// embedded by pointer and is nil. This will cause panics if an
|
||||
// unimplemented method is ever invoked, so we test this at initialization
|
||||
// time to prevent it from happening at runtime later due to I/O.
|
||||
if t, ok := srv.(interface{ testEmbeddedByValue() }); ok {
|
||||
t.testEmbeddedByValue()
|
||||
}
|
||||
s.RegisterService(&ComputationRunner_ServiceDesc, srv)
|
||||
}
|
||||
|
||||
func _ComputationRunner_Run_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
|
||||
in := new(RunRequest)
|
||||
if err := dec(in); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if interceptor == nil {
|
||||
return srv.(ComputationRunnerServer).Run(ctx, in)
|
||||
}
|
||||
info := &grpc.UnaryServerInfo{
|
||||
Server: srv,
|
||||
FullMethod: ComputationRunner_Run_FullMethodName,
|
||||
}
|
||||
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||
return srv.(ComputationRunnerServer).Run(ctx, req.(*RunRequest))
|
||||
}
|
||||
return interceptor(ctx, in, info, handler)
|
||||
}
|
||||
|
||||
func _ComputationRunner_Stop_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
|
||||
in := new(StopRequest)
|
||||
if err := dec(in); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if interceptor == nil {
|
||||
return srv.(ComputationRunnerServer).Stop(ctx, in)
|
||||
}
|
||||
info := &grpc.UnaryServerInfo{
|
||||
Server: srv,
|
||||
FullMethod: ComputationRunner_Stop_FullMethodName,
|
||||
}
|
||||
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||
return srv.(ComputationRunnerServer).Stop(ctx, req.(*StopRequest))
|
||||
}
|
||||
return interceptor(ctx, in, info, handler)
|
||||
}
|
||||
|
||||
// ComputationRunner_ServiceDesc is the grpc.ServiceDesc for ComputationRunner service.
|
||||
// It's only intended for direct use with grpc.RegisterService,
|
||||
// and not to be introspected or modified (even as a copy)
|
||||
var ComputationRunner_ServiceDesc = grpc.ServiceDesc{
|
||||
ServiceName: "runner.ComputationRunner",
|
||||
HandlerType: (*ComputationRunnerServer)(nil),
|
||||
Methods: []grpc.MethodDesc{
|
||||
{
|
||||
MethodName: "Run",
|
||||
Handler: _ComputationRunner_Run_Handler,
|
||||
},
|
||||
{
|
||||
MethodName: "Stop",
|
||||
Handler: _ComputationRunner_Stop_Handler,
|
||||
},
|
||||
},
|
||||
Streams: []grpc.StreamDesc{},
|
||||
Metadata: "agent/runner/runner.proto",
|
||||
}
|
||||
@@ -0,0 +1,141 @@
|
||||
// Copyright (c) Ultraviolet
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sync"
|
||||
|
||||
"github.com/ultravioletrs/cocos/agent/algorithm"
|
||||
"github.com/ultravioletrs/cocos/agent/algorithm/binary"
|
||||
"github.com/ultravioletrs/cocos/agent/algorithm/docker"
|
||||
"github.com/ultravioletrs/cocos/agent/algorithm/python"
|
||||
"github.com/ultravioletrs/cocos/agent/algorithm/wasm"
|
||||
"github.com/ultravioletrs/cocos/agent/events"
|
||||
pb "github.com/ultravioletrs/cocos/agent/runner"
|
||||
"google.golang.org/protobuf/types/known/emptypb"
|
||||
)
|
||||
|
||||
const (
|
||||
algoFilePermission = 0o700
|
||||
)
|
||||
|
||||
var _ pb.ComputationRunnerServer = (*RunnerService)(nil)
|
||||
|
||||
type RunnerService struct {
|
||||
pb.UnimplementedComputationRunnerServer
|
||||
logger *slog.Logger
|
||||
eventSvc events.Service
|
||||
currentAlgo algorithm.Algorithm
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
func New(logger *slog.Logger, eventSvc events.Service) *RunnerService {
|
||||
return &RunnerService{
|
||||
logger: logger,
|
||||
eventSvc: eventSvc,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *RunnerService) Run(ctx context.Context, req *pb.RunRequest) (*pb.RunResponse, error) {
|
||||
s.mu.Lock()
|
||||
if s.currentAlgo != nil {
|
||||
s.mu.Unlock()
|
||||
return &pb.RunResponse{
|
||||
ComputationId: req.ComputationId,
|
||||
Error: "computation already running",
|
||||
}, nil
|
||||
}
|
||||
s.mu.Unlock()
|
||||
|
||||
defer func() {
|
||||
s.mu.Lock()
|
||||
s.currentAlgo = nil
|
||||
s.mu.Unlock()
|
||||
}()
|
||||
|
||||
currentDir, err := os.Getwd()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error getting current directory: %v", err)
|
||||
}
|
||||
|
||||
// Write Algo File
|
||||
algoPath := filepath.Join(currentDir, "algo")
|
||||
f, err := os.Create(algoPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error creating algorithm file: %v", err)
|
||||
}
|
||||
if _, err := f.Write(req.Algorithm); err != nil {
|
||||
return nil, fmt.Errorf("error writing algorithm to file: %v", err)
|
||||
}
|
||||
if err := os.Chmod(algoPath, algoFilePermission); err != nil {
|
||||
return nil, fmt.Errorf("error changing file permissions: %v", err)
|
||||
}
|
||||
if err := f.Close(); err != nil {
|
||||
return nil, fmt.Errorf("error closing file: %v", err)
|
||||
}
|
||||
|
||||
var algo algorithm.Algorithm
|
||||
|
||||
switch req.AlgoType {
|
||||
case string(algorithm.AlgoTypeBin):
|
||||
algo = binary.NewAlgorithm(s.logger, s.eventSvc, algoPath, req.Args, req.ComputationId)
|
||||
case string(algorithm.AlgoTypePython):
|
||||
var requirementsFile string
|
||||
if len(req.Requirements) > 0 {
|
||||
fr, err := os.CreateTemp("", "requirements.txt")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error creating requirments file: %v", err)
|
||||
}
|
||||
if _, err := fr.Write(req.Requirements); err != nil {
|
||||
return nil, fmt.Errorf("error writing requirements to file: %v", err)
|
||||
}
|
||||
if err := fr.Close(); err != nil {
|
||||
return nil, fmt.Errorf("error closing file: %v", err)
|
||||
}
|
||||
requirementsFile = fr.Name()
|
||||
}
|
||||
// Assuming default python runtime if not specified in request (proto doesn't have runtime field yet)
|
||||
// We can add it or assume.
|
||||
runtime := python.PyRuntime
|
||||
algo = python.NewAlgorithm(s.logger, s.eventSvc, runtime, requirementsFile, algoPath, req.Args, req.ComputationId)
|
||||
case string(algorithm.AlgoTypeWasm):
|
||||
algo = wasm.NewAlgorithm(s.logger, s.eventSvc, req.Args, algoPath, req.ComputationId)
|
||||
case string(algorithm.AlgoTypeDocker):
|
||||
algo = docker.NewAlgorithm(s.logger, s.eventSvc, algoPath, req.ComputationId)
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported algorithm type: %s", req.AlgoType)
|
||||
}
|
||||
|
||||
s.mu.Lock()
|
||||
s.currentAlgo = algo
|
||||
s.mu.Unlock()
|
||||
|
||||
if err := algo.Run(); err != nil {
|
||||
s.logger.Error("computation failed", "error", err)
|
||||
return &pb.RunResponse{
|
||||
ComputationId: req.ComputationId,
|
||||
Error: err.Error(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
return &pb.RunResponse{
|
||||
ComputationId: req.ComputationId,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *RunnerService) Stop(ctx context.Context, req *pb.StopRequest) (*emptypb.Empty, error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
if s.currentAlgo != nil {
|
||||
if err := s.currentAlgo.Stop(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
return &emptypb.Empty{}, nil
|
||||
}
|
||||
@@ -0,0 +1,271 @@
|
||||
// Copyright (c) Ultraviolet
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"log/slog"
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
pb "github.com/ultravioletrs/cocos/agent/runner"
|
||||
)
|
||||
|
||||
// MockEventService is a mock implementation of events.Service.
|
||||
type MockEventService struct {
|
||||
events []interface{}
|
||||
}
|
||||
|
||||
func (m *MockEventService) SendEvent(cmpID, event, status string, details json.RawMessage) {
|
||||
m.events = append(m.events, map[string]interface{}{
|
||||
"cmpID": cmpID,
|
||||
"event": event,
|
||||
"status": status,
|
||||
"details": details,
|
||||
})
|
||||
}
|
||||
|
||||
// TestNewRunnerService tests the creation of a new runner service.
|
||||
func TestNewRunnerService(t *testing.T) {
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
|
||||
eventSvc := &MockEventService{}
|
||||
|
||||
rs := New(logger, eventSvc)
|
||||
require.NotNil(t, rs)
|
||||
assert.NotNil(t, rs.logger)
|
||||
assert.NotNil(t, rs.eventSvc)
|
||||
assert.Nil(t, rs.currentAlgo)
|
||||
}
|
||||
|
||||
// TestRunWithBinaryAlgorithm tests running a binary algorithm.
|
||||
func TestRunWithBinaryAlgorithm(t *testing.T) {
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
|
||||
eventSvc := &MockEventService{}
|
||||
rs := New(logger, eventSvc)
|
||||
|
||||
req := &pb.RunRequest{
|
||||
ComputationId: "test-1",
|
||||
AlgoType: "bin",
|
||||
Algorithm: []byte("#!/bin/bash\necho 'test'"),
|
||||
Args: []string{"arg1", "arg2"},
|
||||
}
|
||||
|
||||
resp, err := rs.Run(context.Background(), req)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, resp)
|
||||
assert.Equal(t, "test-1", resp.ComputationId)
|
||||
}
|
||||
|
||||
// TestRunWithPythonAlgorithm tests running a Python algorithm.
|
||||
func TestRunWithPythonAlgorithm(t *testing.T) {
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
|
||||
eventSvc := &MockEventService{}
|
||||
rs := New(logger, eventSvc)
|
||||
|
||||
req := &pb.RunRequest{
|
||||
ComputationId: "test-python",
|
||||
AlgoType: "python",
|
||||
Algorithm: []byte("print('hello')"),
|
||||
Args: []string{},
|
||||
Requirements: []byte("numpy==1.21.0"),
|
||||
}
|
||||
|
||||
resp, err := rs.Run(context.Background(), req)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, resp)
|
||||
assert.Equal(t, "test-python", resp.ComputationId)
|
||||
}
|
||||
|
||||
// TestRunWithPythonAlgorithmNoRequirements tests running Python without requirements.
|
||||
func TestRunWithPythonAlgorithmNoRequirements(t *testing.T) {
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
|
||||
eventSvc := &MockEventService{}
|
||||
rs := New(logger, eventSvc)
|
||||
|
||||
req := &pb.RunRequest{
|
||||
ComputationId: "test-python-noreq",
|
||||
AlgoType: "python",
|
||||
Algorithm: []byte("print('hello')"),
|
||||
Args: []string{},
|
||||
}
|
||||
|
||||
resp, err := rs.Run(context.Background(), req)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, resp)
|
||||
assert.Equal(t, "test-python-noreq", resp.ComputationId)
|
||||
}
|
||||
|
||||
// TestRunWithWasmAlgorithm tests running a WASM algorithm.
|
||||
func TestRunWithWasmAlgorithm(t *testing.T) {
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
|
||||
eventSvc := &MockEventService{}
|
||||
rs := New(logger, eventSvc)
|
||||
|
||||
req := &pb.RunRequest{
|
||||
ComputationId: "test-wasm",
|
||||
AlgoType: "wasm",
|
||||
Algorithm: []byte{0x00, 0x61, 0x73, 0x6d},
|
||||
Args: []string{},
|
||||
}
|
||||
|
||||
resp, err := rs.Run(context.Background(), req)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, resp)
|
||||
assert.Equal(t, "test-wasm", resp.ComputationId)
|
||||
}
|
||||
|
||||
// TestRunWithDockerAlgorithm tests running a Docker algorithm.
|
||||
func TestRunWithDockerAlgorithm(t *testing.T) {
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
|
||||
eventSvc := &MockEventService{}
|
||||
rs := New(logger, eventSvc)
|
||||
|
||||
req := &pb.RunRequest{
|
||||
ComputationId: "test-docker",
|
||||
AlgoType: "docker",
|
||||
Algorithm: []byte("FROM ubuntu:latest\nRUN echo 'test'"),
|
||||
Args: []string{},
|
||||
}
|
||||
|
||||
resp, err := rs.Run(context.Background(), req)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, resp)
|
||||
assert.Equal(t, "test-docker", resp.ComputationId)
|
||||
}
|
||||
|
||||
// TestRunWithUnsupportedAlgorithmType tests running with unsupported algorithm type.
|
||||
func TestRunWithUnsupportedAlgorithmType(t *testing.T) {
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
|
||||
eventSvc := &MockEventService{}
|
||||
rs := New(logger, eventSvc)
|
||||
|
||||
req := &pb.RunRequest{
|
||||
ComputationId: "test-unsupported",
|
||||
AlgoType: "unsupported",
|
||||
Algorithm: []byte("test"),
|
||||
Args: []string{},
|
||||
}
|
||||
|
||||
resp, err := rs.Run(context.Background(), req)
|
||||
require.Error(t, err)
|
||||
require.Nil(t, resp)
|
||||
}
|
||||
|
||||
// TestRunAlreadyRunning tests running computation when one is already running.
|
||||
func TestRunAlreadyRunning(t *testing.T) {
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
|
||||
eventSvc := &MockEventService{}
|
||||
rs := New(logger, eventSvc)
|
||||
|
||||
// Use a long-running bash script
|
||||
req := &pb.RunRequest{
|
||||
ComputationId: "test-running",
|
||||
AlgoType: "bin",
|
||||
Algorithm: []byte("#!/bin/bash\nsleep 30"),
|
||||
Args: []string{},
|
||||
}
|
||||
|
||||
// Start first computation (will run for 30 seconds)
|
||||
go func() {
|
||||
_, _ = rs.Run(context.Background(), req)
|
||||
}()
|
||||
|
||||
// Give it time to start
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
|
||||
// Try to run another immediately - should fail
|
||||
resp, err := rs.Run(context.Background(), req)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, resp)
|
||||
assert.Equal(t, "computation already running", resp.Error)
|
||||
}
|
||||
|
||||
// TestStopWhenRunning tests stopping a running computation.
|
||||
func TestStopWhenRunning(t *testing.T) {
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
|
||||
eventSvc := &MockEventService{}
|
||||
rs := New(logger, eventSvc)
|
||||
|
||||
req := &pb.RunRequest{
|
||||
ComputationId: "test-stop",
|
||||
AlgoType: "bin",
|
||||
Algorithm: []byte("#!/bin/bash\nsleep 10"),
|
||||
Args: []string{},
|
||||
}
|
||||
|
||||
_, err := rs.Run(context.Background(), req)
|
||||
require.NoError(t, err)
|
||||
|
||||
stopReq := &pb.StopRequest{
|
||||
ComputationId: "test-stop",
|
||||
}
|
||||
|
||||
stopResp, err := rs.Stop(context.Background(), stopReq)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, stopResp)
|
||||
}
|
||||
|
||||
// TestStopWhenNotRunning tests stopping when no computation is running.
|
||||
func TestStopWhenNotRunning(t *testing.T) {
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
|
||||
eventSvc := &MockEventService{}
|
||||
rs := New(logger, eventSvc)
|
||||
|
||||
stopReq := &pb.StopRequest{
|
||||
ComputationId: "test-not-running",
|
||||
}
|
||||
|
||||
stopResp, err := rs.Stop(context.Background(), stopReq)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, stopResp)
|
||||
}
|
||||
|
||||
// TestConcurrentRun tests that concurrent runs are properly serialized.
|
||||
func TestConcurrentRun(t *testing.T) {
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
|
||||
eventSvc := &MockEventService{}
|
||||
rs := New(logger, eventSvc)
|
||||
|
||||
req := &pb.RunRequest{
|
||||
ComputationId: "test-concurrent",
|
||||
AlgoType: "bin",
|
||||
Algorithm: []byte("#!/bin/bash\nsleep 15"),
|
||||
Args: []string{},
|
||||
}
|
||||
|
||||
// Start first run in goroutine (will run for 15 seconds)
|
||||
go func() {
|
||||
_, _ = rs.Run(context.Background(), req)
|
||||
}()
|
||||
|
||||
// Give it time to actually start
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
|
||||
// Concurrent attempt should fail
|
||||
resp2, err := rs.Run(context.Background(), req)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "computation already running", resp2.Error)
|
||||
}
|
||||
|
||||
// TestRunWithMultipleArgs tests running with multiple arguments.
|
||||
func TestRunWithMultipleArgs(t *testing.T) {
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
|
||||
eventSvc := &MockEventService{}
|
||||
rs := New(logger, eventSvc)
|
||||
|
||||
req := &pb.RunRequest{
|
||||
ComputationId: "test-multi-args",
|
||||
AlgoType: "bin",
|
||||
Algorithm: []byte("#!/bin/bash\necho $@"),
|
||||
Args: []string{"arg1", "arg2", "arg3", "arg4"},
|
||||
}
|
||||
|
||||
resp, err := rs.Run(context.Background(), req)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, resp)
|
||||
assert.Equal(t, "test-multi-args", resp.ComputationId)
|
||||
}
|
||||
+52
-41
@@ -16,17 +16,15 @@ import (
|
||||
|
||||
"github.com/absmach/supermq/pkg/errors"
|
||||
"github.com/ultravioletrs/cocos/agent/algorithm"
|
||||
"github.com/ultravioletrs/cocos/agent/algorithm/binary"
|
||||
"github.com/ultravioletrs/cocos/agent/algorithm/docker"
|
||||
"github.com/ultravioletrs/cocos/agent/algorithm/python"
|
||||
"github.com/ultravioletrs/cocos/agent/algorithm/wasm"
|
||||
"github.com/ultravioletrs/cocos/agent/events"
|
||||
runnerpb "github.com/ultravioletrs/cocos/agent/runner"
|
||||
"github.com/ultravioletrs/cocos/agent/statemachine"
|
||||
"github.com/ultravioletrs/cocos/internal"
|
||||
"github.com/ultravioletrs/cocos/pkg/attestation"
|
||||
"github.com/ultravioletrs/cocos/pkg/attestation/quoteprovider"
|
||||
"github.com/ultravioletrs/cocos/pkg/attestation/vtpm"
|
||||
attestation_client "github.com/ultravioletrs/cocos/pkg/clients/grpc/attestation"
|
||||
runner_client "github.com/ultravioletrs/cocos/pkg/clients/grpc/runner"
|
||||
"golang.org/x/crypto/sha3"
|
||||
)
|
||||
|
||||
@@ -130,8 +128,12 @@ type Service interface {
|
||||
|
||||
type agentService struct {
|
||||
mu sync.Mutex
|
||||
computation Computation // Holds the current computation request details.
|
||||
algorithm algorithm.Algorithm // Filepath to the algorithm received for the computation.
|
||||
computation Computation // Holds the current computation request details.
|
||||
runnerClient runner_client.Client
|
||||
algoType string
|
||||
algoArgs []string
|
||||
algoRequirements []byte
|
||||
algoReceived bool
|
||||
result []byte // Stores the result of the computation.
|
||||
sm statemachine.StateMachine // Manages the state transitions of the agent service.
|
||||
runError error // Stores any error encountered during the computation run.
|
||||
@@ -146,13 +148,14 @@ type agentService struct {
|
||||
var _ Service = (*agentService)(nil)
|
||||
|
||||
// New instantiates the agent service implementation.
|
||||
func New(ctx context.Context, logger *slog.Logger, eventSvc events.Service, attestationClient attestation_client.Client, vmlp int) Service {
|
||||
func New(ctx context.Context, logger *slog.Logger, eventSvc events.Service, attestationClient attestation_client.Client, runnerClient runner_client.Client, vmlp int) Service {
|
||||
sm := statemachine.NewStateMachine(Idle)
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
svc := &agentService{
|
||||
sm: sm,
|
||||
eventSvc: eventSvc,
|
||||
attestationClient: attestationClient,
|
||||
runnerClient: runnerClient,
|
||||
logger: logger,
|
||||
cancel: cancel,
|
||||
vmpl: vmlp,
|
||||
@@ -233,10 +236,9 @@ func (as *agentService) StopComputation(ctx context.Context) error {
|
||||
|
||||
as.cancel()
|
||||
|
||||
if as.algorithm != nil {
|
||||
if err := as.algorithm.Stop(); err != nil {
|
||||
return fmt.Errorf("error stopping computation: %v", err)
|
||||
}
|
||||
if _, err := as.runnerClient.Stop(ctx, &runnerpb.StopRequest{ComputationId: as.computation.ID}); err != nil {
|
||||
as.logger.Warn("failed to stop runner", "error", err)
|
||||
// proceed to cleanup
|
||||
}
|
||||
|
||||
if err := os.RemoveAll(algorithm.DatasetsDir); err != nil {
|
||||
@@ -250,7 +252,10 @@ func (as *agentService) StopComputation(ctx context.Context) error {
|
||||
as.sm.Reset(Idle)
|
||||
|
||||
as.computation = Computation{}
|
||||
as.algorithm = nil
|
||||
as.algoReceived = false
|
||||
as.algoType = ""
|
||||
as.algoArgs = nil
|
||||
as.algoRequirements = nil
|
||||
as.result = nil
|
||||
as.runError = nil
|
||||
as.resultsConsumed = false
|
||||
@@ -278,7 +283,7 @@ func (as *agentService) Algo(ctx context.Context, algo Algorithm) error {
|
||||
}
|
||||
as.mu.Lock()
|
||||
defer as.mu.Unlock()
|
||||
if as.algorithm != nil {
|
||||
if as.algoReceived {
|
||||
return ErrAllManifestItemsReceived
|
||||
}
|
||||
|
||||
@@ -317,38 +322,16 @@ func (as *agentService) Algo(ctx context.Context, algo Algorithm) error {
|
||||
|
||||
args := algorithm.AlgorithmArgsFromContext(ctx)
|
||||
|
||||
switch algoType {
|
||||
case string(algorithm.AlgoTypeBin):
|
||||
as.algorithm = binary.NewAlgorithm(as.logger, as.eventSvc, f.Name(), args, as.computation.ID)
|
||||
case string(algorithm.AlgoTypePython):
|
||||
var requirementsFile string
|
||||
if len(algo.Requirements) > 0 {
|
||||
fr, err := os.CreateTemp("", "requirements.txt")
|
||||
if err != nil {
|
||||
return fmt.Errorf("error creating requirments file: %v", err)
|
||||
}
|
||||
|
||||
if _, err := fr.Write(algo.Requirements); err != nil {
|
||||
return fmt.Errorf("error writing requirements to file: %v", err)
|
||||
}
|
||||
if err := fr.Close(); err != nil {
|
||||
return fmt.Errorf("error closing file: %v", err)
|
||||
}
|
||||
requirementsFile = fr.Name()
|
||||
}
|
||||
runtime := python.PythonRunTimeFromContext(ctx)
|
||||
as.algorithm = python.NewAlgorithm(as.logger, as.eventSvc, runtime, requirementsFile, f.Name(), args, as.computation.ID)
|
||||
case string(algorithm.AlgoTypeWasm):
|
||||
as.algorithm = wasm.NewAlgorithm(as.logger, as.eventSvc, args, f.Name(), as.computation.ID)
|
||||
case string(algorithm.AlgoTypeDocker):
|
||||
as.algorithm = docker.NewAlgorithm(as.logger, as.eventSvc, f.Name(), as.computation.ID)
|
||||
}
|
||||
as.algoType = algoType
|
||||
as.algoArgs = args
|
||||
as.algoRequirements = algo.Requirements
|
||||
as.algoReceived = true
|
||||
|
||||
if err := os.Mkdir(algorithm.DatasetsDir, 0o755); err != nil {
|
||||
return fmt.Errorf("error creating datasets directory: %v", err)
|
||||
}
|
||||
|
||||
if as.algorithm != nil {
|
||||
if as.algoReceived {
|
||||
as.sm.SendEvent(AlgorithmReceived)
|
||||
}
|
||||
|
||||
@@ -478,14 +461,42 @@ func (as *agentService) runComputation(state statemachine.State) {
|
||||
}
|
||||
}()
|
||||
|
||||
// Read algo file
|
||||
currentDir, _ := os.Getwd()
|
||||
algoFile := filepath.Join(currentDir, "algo")
|
||||
algoBytes, err := os.ReadFile(algoFile)
|
||||
if err != nil {
|
||||
as.runError = fmt.Errorf("failed to read algo file: %w", err)
|
||||
as.logger.Warn(as.runError.Error())
|
||||
as.publishEvent(Failed.String())(state)
|
||||
return
|
||||
}
|
||||
|
||||
as.publishEvent(InProgress.String())(state)
|
||||
if err := as.algorithm.Run(); err != nil {
|
||||
|
||||
// Call Runner
|
||||
resp, err := as.runnerClient.Run(context.Background(), &runnerpb.RunRequest{
|
||||
ComputationId: as.computation.ID,
|
||||
AlgoType: as.algoType,
|
||||
Algorithm: algoBytes,
|
||||
Requirements: as.algoRequirements,
|
||||
Args: as.algoArgs,
|
||||
// Datasets implicit on shared FS
|
||||
})
|
||||
if err != nil {
|
||||
as.runError = err
|
||||
as.logger.Warn(fmt.Sprintf("failed to run computation: %s", err.Error()))
|
||||
as.publishEvent(Failed.String())(state)
|
||||
return
|
||||
}
|
||||
|
||||
if resp.Error != "" {
|
||||
as.runError = errors.New(resp.Error)
|
||||
as.logger.Warn(fmt.Sprintf("failed to run computation: %s", resp.Error))
|
||||
as.publishEvent(Failed.String())(state)
|
||||
return
|
||||
}
|
||||
|
||||
results, err := internal.ZipDirectoryToMemory(algorithm.ResultsDir)
|
||||
if err != nil {
|
||||
as.runError = err
|
||||
|
||||
+46
-47
@@ -18,16 +18,18 @@ import (
|
||||
"github.com/stretchr/testify/mock"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/ultravioletrs/cocos/agent/algorithm"
|
||||
algomocks "github.com/ultravioletrs/cocos/agent/algorithm/mocks"
|
||||
"github.com/ultravioletrs/cocos/agent/algorithm/python"
|
||||
"github.com/ultravioletrs/cocos/agent/events/mocks"
|
||||
runnerpb "github.com/ultravioletrs/cocos/agent/runner"
|
||||
"github.com/ultravioletrs/cocos/agent/statemachine"
|
||||
smmocks "github.com/ultravioletrs/cocos/agent/statemachine/mocks"
|
||||
"github.com/ultravioletrs/cocos/pkg/attestation"
|
||||
"github.com/ultravioletrs/cocos/pkg/attestation/quoteprovider"
|
||||
"github.com/ultravioletrs/cocos/pkg/attestation/vtpm"
|
||||
runnermocks "github.com/ultravioletrs/cocos/pkg/clients/grpc/runner/mocks"
|
||||
"golang.org/x/crypto/sha3"
|
||||
"google.golang.org/grpc/metadata"
|
||||
"google.golang.org/protobuf/types/known/emptypb"
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -123,7 +125,9 @@ func TestAlgo(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
defer cancel()
|
||||
client := new(MockAttestationClient)
|
||||
svc := New(ctx, mglog.NewMock(), events, client, 0)
|
||||
runnerCli := new(runnermocks.Client)
|
||||
runnerCli.On("Run", mock.Anything, mock.Anything).Return(&runnerpb.RunResponse{}, nil)
|
||||
svc := New(ctx, mglog.NewMock(), events, client, runnerCli, 0)
|
||||
|
||||
err := svc.InitComputation(ctx, testComputation(t))
|
||||
require.NoError(t, err)
|
||||
@@ -217,7 +221,9 @@ func TestData(t *testing.T) {
|
||||
defer cancel()
|
||||
|
||||
client := new(MockAttestationClient)
|
||||
svc := New(ctx, mglog.NewMock(), events, client, 0)
|
||||
runnerCli := new(runnermocks.Client)
|
||||
runnerCli.On("Run", mock.Anything, mock.Anything).Return(&runnerpb.RunResponse{}, nil)
|
||||
svc := New(ctx, mglog.NewMock(), events, client, runnerCli, 0)
|
||||
|
||||
err := svc.InitComputation(ctx, testComputation(t))
|
||||
require.NoError(t, err)
|
||||
@@ -294,6 +300,7 @@ func TestResult(t *testing.T) {
|
||||
}
|
||||
|
||||
client := new(MockAttestationClient)
|
||||
runnerCli := new(runnermocks.Client)
|
||||
|
||||
sm := new(smmocks.StateMachine)
|
||||
sm.On("Start", ctx).Return(nil)
|
||||
@@ -304,6 +311,7 @@ func TestResult(t *testing.T) {
|
||||
sm: sm,
|
||||
eventSvc: events,
|
||||
attestationClient: client,
|
||||
runnerClient: runnerCli,
|
||||
computation: testComputation(t),
|
||||
}
|
||||
|
||||
@@ -400,7 +408,8 @@ func TestAttestation(t *testing.T) {
|
||||
}
|
||||
defer getQuote.Unset()
|
||||
|
||||
svc := New(ctx, mglog.NewMock(), events, client, 0)
|
||||
runnerCli := new(runnermocks.Client)
|
||||
svc := New(ctx, mglog.NewMock(), events, client, runnerCli, 0)
|
||||
time.Sleep(300 * time.Millisecond)
|
||||
_, err := svc.Attestation(ctx, tc.reportData, tc.nonce, tc.platform)
|
||||
assert.True(t, errors.Contains(err, tc.err), "expected %v, got %v", tc.err, err)
|
||||
@@ -420,18 +429,7 @@ func TestAzureAttestationToken(t *testing.T) {
|
||||
name: "Azure token fetch successful",
|
||||
nonce: [32]byte{1, 2, 3}, // any test nonce
|
||||
token: []byte("mockToken"),
|
||||
err: nil, // fixed expectation as err was ErrAttestationType in original but logic suggests success if token returns? Wait, orig test had ErrAttestationType? Ah, maybe provider mock returns error.
|
||||
// Re-reading original test:
|
||||
// err: ErrAttestationType
|
||||
// provider.On(...).Return(tc.token, tc.err)
|
||||
// svc.AzureAttestationToken...
|
||||
// In original code, AzureAttestationToken checked `attestation.CCPlatform() != attestation.Azure`.
|
||||
// Since test runs on non-azure, it returns ErrAttestationType.
|
||||
// My new client calls GetAzureToken. The logic for checking platform moved to attestation-service.
|
||||
// So `agent` just calls the client.
|
||||
// So here we should expect whatever the client returns.
|
||||
// Mock client returns tc.err.
|
||||
// If I want to test success, I should set err: nil.
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
name: "Azure token fetch failed",
|
||||
@@ -450,7 +448,8 @@ func TestAzureAttestationToken(t *testing.T) {
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
svc := New(ctx, mglog.NewMock(), events, client, 0)
|
||||
runnerCli := new(runnermocks.Client)
|
||||
svc := New(ctx, mglog.NewMock(), events, client, runnerCli, 0)
|
||||
|
||||
_, err := svc.AzureAttestationToken(ctx, tc.nonce)
|
||||
assert.True(t, errors.Contains(err, tc.err), "expected error %v, got %v", tc.err, err)
|
||||
@@ -489,9 +488,6 @@ func testComputation(t *testing.T) Computation {
|
||||
}
|
||||
|
||||
func TestStopComputation(t *testing.T) {
|
||||
testDataDir := "test_datasets"
|
||||
testResultsDir := "test_results"
|
||||
|
||||
cases := []struct {
|
||||
name string
|
||||
setupDirs bool
|
||||
@@ -511,22 +507,9 @@ func TestStopComputation(t *testing.T) {
|
||||
setupDirs: true,
|
||||
setupAlgo: true,
|
||||
algoStopErr: fmt.Errorf("algorithm stop failed"),
|
||||
expectedErr: fmt.Errorf("error stopping computation: algorithm stop failed"),
|
||||
},
|
||||
{
|
||||
name: "Stop computation without algorithm",
|
||||
setupDirs: true,
|
||||
setupAlgo: false,
|
||||
algoStopErr: nil,
|
||||
expectedErr: nil,
|
||||
},
|
||||
{
|
||||
name: "Stop computation with missing directories",
|
||||
setupDirs: false,
|
||||
setupAlgo: false,
|
||||
algoStopErr: nil,
|
||||
expectedErr: nil, // os.RemoveAll doesn't error on non-existing directories
|
||||
expectedErr: nil, // Warn only
|
||||
},
|
||||
// We log warnings but don't return error in StopComputation in new implementation for Stop failure.
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
@@ -539,7 +522,16 @@ func TestStopComputation(t *testing.T) {
|
||||
defer cancel()
|
||||
|
||||
client := new(MockAttestationClient)
|
||||
svc := New(ctx, mglog.NewMock(), events, client, 0).(*agentService)
|
||||
runnerCli := new(runnermocks.Client)
|
||||
|
||||
// Mock Stop call
|
||||
var stopErr error
|
||||
if tc.algoStopErr != nil {
|
||||
stopErr = tc.algoStopErr
|
||||
}
|
||||
runnerCli.On("Stop", mock.Anything, mock.Anything).Return(&emptypb.Empty{}, stopErr)
|
||||
|
||||
svc := New(ctx, mglog.NewMock(), events, client, runnerCli, 0).(*agentService)
|
||||
|
||||
svc.computation = Computation{
|
||||
ID: "test-computation",
|
||||
@@ -547,17 +539,17 @@ func TestStopComputation(t *testing.T) {
|
||||
}
|
||||
|
||||
if tc.setupDirs {
|
||||
err := os.MkdirAll(testDataDir, 0o755)
|
||||
err := os.MkdirAll(algorithm.DatasetsDir, 0o755)
|
||||
require.NoError(t, err)
|
||||
err = os.MkdirAll(testResultsDir, 0o755)
|
||||
err = os.MkdirAll(algorithm.ResultsDir, 0o755)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
if tc.setupAlgo {
|
||||
mockAlgo := new(algomocks.Algorithm)
|
||||
mockAlgo.On("Stop").Return(tc.algoStopErr)
|
||||
svc.algorithm = mockAlgo
|
||||
}
|
||||
// Use real dirs for test
|
||||
// algorithm.DatasetsDir refers to global var?
|
||||
// "github.com/ultravioletrs/cocos/agent/algorithm"
|
||||
// It uses hardcoded path "datasets" and "results" in current dir.
|
||||
// Tests create them in current dir.
|
||||
|
||||
err := svc.StopComputation(ctx)
|
||||
|
||||
@@ -575,8 +567,8 @@ func TestStopComputation(t *testing.T) {
|
||||
|
||||
events.AssertExpectations(t)
|
||||
|
||||
_ = os.RemoveAll(testDataDir)
|
||||
_ = os.RemoveAll(testResultsDir)
|
||||
_ = os.RemoveAll(algorithm.DatasetsDir)
|
||||
_ = os.RemoveAll(algorithm.ResultsDir)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -608,7 +600,11 @@ func TestStopComputationIntegration(t *testing.T) {
|
||||
defer cancel()
|
||||
|
||||
client := new(MockAttestationClient)
|
||||
svc := New(ctx, mglog.NewMock(), events, client, 0)
|
||||
runnerCli := new(runnermocks.Client)
|
||||
runnerCli.On("Run", mock.Anything, mock.Anything).Return(&runnerpb.RunResponse{}, nil)
|
||||
runnerCli.On("Stop", mock.Anything, mock.Anything).Return(&emptypb.Empty{}, nil)
|
||||
|
||||
svc := New(ctx, mglog.NewMock(), events, client, runnerCli, 0)
|
||||
|
||||
computation := Computation{
|
||||
ID: "integration-test",
|
||||
@@ -647,7 +643,10 @@ func TestStopComputationConcurrent(t *testing.T) {
|
||||
defer cancel()
|
||||
|
||||
client := new(MockAttestationClient)
|
||||
svc := New(ctx, mglog.NewMock(), events, client, 0)
|
||||
runnerCli := new(runnermocks.Client)
|
||||
runnerCli.On("Stop", mock.Anything, mock.Anything).Return(&emptypb.Empty{}, nil)
|
||||
|
||||
svc := New(ctx, mglog.NewMock(), events, client, runnerCli, 0)
|
||||
|
||||
svc.(*agentService).computation = Computation{
|
||||
ID: "concurrent-test",
|
||||
|
||||
+85
-18
@@ -11,6 +11,7 @@ import (
|
||||
"fmt"
|
||||
"log"
|
||||
"log/slog"
|
||||
"net/url"
|
||||
"os"
|
||||
"os/signal"
|
||||
"syscall"
|
||||
@@ -25,6 +26,7 @@ import (
|
||||
cvmsapi "github.com/ultravioletrs/cocos/agent/cvms/api/grpc"
|
||||
"github.com/ultravioletrs/cocos/agent/cvms/server"
|
||||
"github.com/ultravioletrs/cocos/agent/events"
|
||||
logpb "github.com/ultravioletrs/cocos/agent/log"
|
||||
agentlogger "github.com/ultravioletrs/cocos/internal/logger"
|
||||
"github.com/ultravioletrs/cocos/pkg/atls"
|
||||
"github.com/ultravioletrs/cocos/pkg/attestation"
|
||||
@@ -33,6 +35,9 @@ import (
|
||||
pkggrpc "github.com/ultravioletrs/cocos/pkg/clients/grpc"
|
||||
attestation_client "github.com/ultravioletrs/cocos/pkg/clients/grpc/attestation"
|
||||
cvmsgrpc "github.com/ultravioletrs/cocos/pkg/clients/grpc/cvm"
|
||||
logclient "github.com/ultravioletrs/cocos/pkg/clients/grpc/log"
|
||||
runnerclient "github.com/ultravioletrs/cocos/pkg/clients/grpc/runner"
|
||||
"github.com/ultravioletrs/cocos/pkg/ingress"
|
||||
"golang.org/x/sync/errgroup"
|
||||
)
|
||||
|
||||
@@ -54,6 +59,7 @@ type config struct {
|
||||
AgentOSDistro string `env:"AGENT_OS_DISTRO" envDefault:"UVC"`
|
||||
AgentOSType string `env:"AGENT_OS_TYPE" envDefault:"UVC"`
|
||||
AttestationServiceSocket string `env:"ATTESTATION_SERVICE_SOCKET" envDefault:"/run/cocos/attestation.sock"`
|
||||
EnableATLS bool `env:"AGENT_ENABLE_ATLS" envDefault:"true"`
|
||||
}
|
||||
|
||||
func main() {
|
||||
@@ -75,18 +81,63 @@ func main() {
|
||||
return
|
||||
}
|
||||
|
||||
eventsLogsQueue := make(chan *cvms.ClientStreamMessage, 1000)
|
||||
logQueue := make(chan *cvms.ClientStreamMessage, 1000)
|
||||
cvmsQueue := make(chan *cvms.ClientStreamMessage, 1000)
|
||||
|
||||
handler := agentlogger.NewProtoHandler(os.Stdout, &slog.HandlerOptions{Level: level}, eventsLogsQueue)
|
||||
handler := agentlogger.NewProtoHandler(os.Stdout, &slog.HandlerOptions{Level: level}, logQueue)
|
||||
logger := slog.New(handler)
|
||||
|
||||
eventSvc, err := events.New(svcName, eventsLogsQueue)
|
||||
eventSvc, err := events.New(svcName, logQueue)
|
||||
if err != nil {
|
||||
logger.Error(fmt.Sprintf("failed to create events service %s", err.Error()))
|
||||
exitCode = 1
|
||||
return
|
||||
}
|
||||
|
||||
logClient, err := logclient.NewClient("/run/cocos/log.sock")
|
||||
if err != nil {
|
||||
logger.Warn(fmt.Sprintf("failed to create log client: %s. Logging will be local only until service is available.", err))
|
||||
} else {
|
||||
defer logClient.Close()
|
||||
}
|
||||
|
||||
g.Go(func() error {
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil
|
||||
case msg := <-logQueue:
|
||||
if logClient == nil {
|
||||
continue
|
||||
}
|
||||
switch m := msg.Message.(type) {
|
||||
case *cvms.ClientStreamMessage_AgentLog:
|
||||
err := logClient.SendLog(ctx, &logpb.LogEntry{
|
||||
Message: m.AgentLog.Message,
|
||||
ComputationId: m.AgentLog.ComputationId,
|
||||
Level: m.AgentLog.Level,
|
||||
Timestamp: m.AgentLog.Timestamp,
|
||||
})
|
||||
if err != nil {
|
||||
logger.Error("failed to send log", "error", err)
|
||||
}
|
||||
case *cvms.ClientStreamMessage_AgentEvent:
|
||||
err := logClient.SendEvent(ctx, &logpb.EventEntry{
|
||||
EventType: m.AgentEvent.EventType,
|
||||
Timestamp: m.AgentEvent.Timestamp,
|
||||
ComputationId: m.AgentEvent.ComputationId,
|
||||
Details: m.AgentEvent.Details,
|
||||
Originator: m.AgentEvent.Originator,
|
||||
Status: m.AgentEvent.Status,
|
||||
})
|
||||
if err != nil {
|
||||
logger.Error("failed to send event", "error", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
var provider attestation.Provider
|
||||
ccPlatform := attestation.CCPlatform()
|
||||
|
||||
@@ -128,13 +179,6 @@ func main() {
|
||||
return grpcClient, pc, nil
|
||||
}
|
||||
|
||||
pc, err := cvmsClient.Process(ctx)
|
||||
if err != nil {
|
||||
logger.Error(err.Error())
|
||||
exitCode = 1
|
||||
return
|
||||
}
|
||||
|
||||
if cfg.Vmpl < 0 || cfg.Vmpl > 3 {
|
||||
logger.Error("vmpl level must be in a range [0, 3]")
|
||||
exitCode = 1
|
||||
@@ -149,7 +193,15 @@ func main() {
|
||||
}
|
||||
defer attClient.Close()
|
||||
|
||||
svc := newService(ctx, logger, eventSvc, attClient, cfg.Vmpl)
|
||||
runnerClient, err := runnerclient.NewClient("/run/cocos/runner.sock")
|
||||
if err != nil {
|
||||
logger.Error(fmt.Sprintf("failed to create runner client: %s", err))
|
||||
exitCode = 1
|
||||
return
|
||||
}
|
||||
defer runnerClient.Close()
|
||||
|
||||
svc := newService(ctx, logger, eventSvc, attClient, runnerClient, cfg.Vmpl)
|
||||
|
||||
if err := os.MkdirAll(storageDir, 0o755); err != nil {
|
||||
logger.Error(fmt.Sprintf("failed to create storage directory: %s", err))
|
||||
@@ -158,8 +210,7 @@ func main() {
|
||||
}
|
||||
|
||||
var certProvider atls.CertificateProvider
|
||||
|
||||
if ccPlatform != attestation.NoCC {
|
||||
if cfg.EnableATLS && ccPlatform != attestation.NoCC {
|
||||
var certsSDK sdk.SDK
|
||||
if cfg.CAUrl != "" {
|
||||
certsSDK = sdk.NewSDK(sdk.Config{
|
||||
@@ -174,7 +225,23 @@ func main() {
|
||||
}
|
||||
}
|
||||
|
||||
mc, err := cvmsapi.NewClient(pc, svc, eventsLogsQueue, logger, server.NewServer(logger, svc, cfg.AgentGrpcHost, certProvider), storageDir, reconnectFn, cvmGRPCClient)
|
||||
// Create ingress proxy server
|
||||
backendURL, err := url.Parse("unix:///run/cocos/agent.sock")
|
||||
if err != nil {
|
||||
logger.Error(fmt.Sprintf("failed to parse backend URL: %s", err))
|
||||
exitCode = 1
|
||||
return
|
||||
}
|
||||
ingressProxy := ingress.NewProxyServer(logger, backendURL, certProvider)
|
||||
|
||||
pc, err := cvmsClient.Process(ctx)
|
||||
if err != nil {
|
||||
logger.Error(fmt.Sprintf("failed to connect to cvm server: %s", err))
|
||||
exitCode = 1
|
||||
return
|
||||
}
|
||||
|
||||
mc, err := cvmsapi.NewClient(pc, svc, cvmsQueue, logger, server.NewServer(logger, svc, cfg.AgentGrpcHost, certProvider), ingressProxy, storageDir, reconnectFn, cvmGRPCClient)
|
||||
if err != nil {
|
||||
logger.Error(err.Error())
|
||||
exitCode = 1
|
||||
@@ -214,7 +281,7 @@ func main() {
|
||||
exitCode = 1
|
||||
return
|
||||
}
|
||||
eventsLogsQueue <- &cvms.ClientStreamMessage{
|
||||
cvmsQueue <- &cvms.ClientStreamMessage{
|
||||
Message: &cvms.ClientStreamMessage_AzureAttestationToken{
|
||||
AzureAttestationToken: &cvms.AzureAttestationToken{
|
||||
File: azureAttestationToken,
|
||||
@@ -224,7 +291,7 @@ func main() {
|
||||
}
|
||||
}
|
||||
|
||||
eventsLogsQueue <- &cvms.ClientStreamMessage{
|
||||
cvmsQueue <- &cvms.ClientStreamMessage{
|
||||
Message: &cvms.ClientStreamMessage_VTPMattestationReport{
|
||||
VTPMattestationReport: &cvms.AttestationResponse{
|
||||
File: attest,
|
||||
@@ -238,8 +305,8 @@ func main() {
|
||||
}
|
||||
}
|
||||
|
||||
func newService(ctx context.Context, logger *slog.Logger, eventSvc events.Service, attClient attestation_client.Client, vmpl int) agent.Service {
|
||||
svc := agent.New(ctx, logger, eventSvc, attClient, vmpl)
|
||||
func newService(ctx context.Context, logger *slog.Logger, eventSvc events.Service, attClient attestation_client.Client, runnerClient runnerclient.Client, vmpl int) agent.Service {
|
||||
svc := agent.New(ctx, logger, eventSvc, attClient, runnerClient, vmpl)
|
||||
|
||||
svc = api.LoggingMiddleware(svc, logger)
|
||||
counter, latency := prometheus.MakeMetrics(svcName, "api")
|
||||
|
||||
@@ -0,0 +1,123 @@
|
||||
// Copyright (c) Ultraviolet
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"net"
|
||||
"os"
|
||||
"os/signal"
|
||||
"syscall"
|
||||
|
||||
mglog "github.com/absmach/supermq/logger"
|
||||
"github.com/caarlos0/env/v11"
|
||||
pb "github.com/ultravioletrs/cocos/agent/runner"
|
||||
runnerevents "github.com/ultravioletrs/cocos/agent/runner/events"
|
||||
"github.com/ultravioletrs/cocos/agent/runner/service"
|
||||
logclient "github.com/ultravioletrs/cocos/pkg/clients/grpc/log"
|
||||
"golang.org/x/sync/errgroup"
|
||||
"google.golang.org/grpc"
|
||||
)
|
||||
|
||||
const (
|
||||
svcName = "computation-runner"
|
||||
socketPath = "/run/cocos/runner.sock"
|
||||
)
|
||||
|
||||
type config struct {
|
||||
LogLevel string `env:"RUNNER_LOG_LEVEL" envAlternate:"AGENT_LOG_LEVEL" envDefault:"debug"`
|
||||
LogForwarder string `env:"LOG_FORWARDER_SOCKET" envDefault:"/run/cocos/log.sock"`
|
||||
}
|
||||
|
||||
func main() {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
g, ctx := errgroup.WithContext(ctx)
|
||||
|
||||
var cfg config
|
||||
if err := env.Parse(&cfg); err != nil {
|
||||
fmt.Printf("failed to load %s configuration : %s\n", svcName, err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
var exitCode int
|
||||
defer mglog.ExitWithError(&exitCode)
|
||||
|
||||
var level slog.Level
|
||||
if err := level.UnmarshalText([]byte(cfg.LogLevel)); err != nil {
|
||||
fmt.Println(err)
|
||||
exitCode = 1
|
||||
return
|
||||
}
|
||||
|
||||
logger := slog.New(slog.NewJSONHandler(os.Stdout, &slog.HandlerOptions{Level: level}))
|
||||
|
||||
// Connect to Log Forwarder
|
||||
logClient, err := logclient.NewClient(cfg.LogForwarder)
|
||||
if err != nil {
|
||||
logger.Warn(fmt.Sprintf("failed to connect to log-forwarder: %s. Events will not be forwarded.", err))
|
||||
} else {
|
||||
defer logClient.Close()
|
||||
}
|
||||
|
||||
eventSvc := runnerevents.NewAdapter(logClient, svcName)
|
||||
|
||||
// Remove existing socket if it exists
|
||||
if _, err := os.Stat(socketPath); err == nil {
|
||||
if err := os.Remove(socketPath); err != nil {
|
||||
logger.Error(fmt.Sprintf("failed to remove existing socket: %s", err))
|
||||
exitCode = 1
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
dir := socketPath[:len(socketPath)-len("/runner.sock")]
|
||||
if err := os.MkdirAll(dir, 0o755); err != nil {
|
||||
logger.Error(fmt.Sprintf("failed to create socket directory: %s", err))
|
||||
exitCode = 1
|
||||
return
|
||||
}
|
||||
|
||||
lis, err := net.Listen("unix", socketPath)
|
||||
if err != nil {
|
||||
logger.Error(fmt.Sprintf("failed to listen on socket: %s", err))
|
||||
exitCode = 1
|
||||
return
|
||||
}
|
||||
|
||||
if err := os.Chmod(socketPath, 0o777); err != nil {
|
||||
logger.Error(fmt.Sprintf("failed to chmod socket: %s", err))
|
||||
exitCode = 1
|
||||
return
|
||||
}
|
||||
|
||||
grpcServer := grpc.NewServer()
|
||||
svc := service.New(logger, eventSvc)
|
||||
pb.RegisterComputationRunnerServer(grpcServer, svc)
|
||||
|
||||
g.Go(func() error {
|
||||
ch := make(chan os.Signal, 1)
|
||||
signal.Notify(ch, syscall.SIGINT, syscall.SIGTERM)
|
||||
defer signal.Stop(ch)
|
||||
|
||||
select {
|
||||
case <-ch:
|
||||
logger.Info("Received signal, shutting down...")
|
||||
cancel()
|
||||
grpcServer.GracefulStop()
|
||||
return nil
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
}
|
||||
})
|
||||
|
||||
g.Go(func() error {
|
||||
logger.Info(fmt.Sprintf("%s started on %s", svcName, socketPath))
|
||||
return grpcServer.Serve(lis)
|
||||
})
|
||||
|
||||
if err := g.Wait(); err != nil {
|
||||
logger.Error(fmt.Sprintf("%s terminated: %s", svcName, err))
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,88 @@
|
||||
// Copyright (c) Ultraviolet
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"os/signal"
|
||||
"syscall"
|
||||
|
||||
mglog "github.com/absmach/supermq/logger"
|
||||
"github.com/caarlos0/env/v11"
|
||||
"github.com/spf13/cobra"
|
||||
"github.com/spf13/pflag"
|
||||
"github.com/ultravioletrs/cocos/pkg/egress"
|
||||
"golang.org/x/sync/errgroup"
|
||||
)
|
||||
|
||||
const (
|
||||
svcName = "egress-proxy"
|
||||
)
|
||||
|
||||
type config struct {
|
||||
Level string `env:"COCOS_LOG_LEVEL" envAlternate:"AGENT_LOG_LEVEL" envDefault:"info"`
|
||||
Port string `env:"COCOS_PROXY_PORT" envDefault:"3128"`
|
||||
}
|
||||
|
||||
func main() {
|
||||
var cfg config
|
||||
if err := env.Parse(&cfg); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "failed to load configuration: %s\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
cmd := &cobra.Command{
|
||||
Use: svcName,
|
||||
Short: "Egress Proxy Service",
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
return run(cfg)
|
||||
},
|
||||
}
|
||||
|
||||
pflag.StringVar(&cfg.Level, "log-level", cfg.Level, "Log level")
|
||||
pflag.StringVar(&cfg.Port, "port", cfg.Port, "Proxy port")
|
||||
|
||||
if err := cmd.Execute(); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error: %s\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
}
|
||||
|
||||
func run(cfg config) error {
|
||||
logger, err := mglog.New(os.Stdout, cfg.Level)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create logger: %w", err)
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
g, ctx := errgroup.WithContext(ctx)
|
||||
|
||||
proxy := egress.NewProxy(logger, ":"+cfg.Port)
|
||||
|
||||
g.Go(func() error {
|
||||
return proxy.Start()
|
||||
})
|
||||
|
||||
g.Go(func() error {
|
||||
c := make(chan os.Signal, 1)
|
||||
signal.Notify(c, syscall.SIGINT, syscall.SIGTERM)
|
||||
select {
|
||||
case s := <-c:
|
||||
logger.Info(fmt.Sprintf("received signal %s, stopping", s))
|
||||
cancel()
|
||||
return proxy.Stop(ctx)
|
||||
case <-ctx.Done():
|
||||
return nil
|
||||
}
|
||||
})
|
||||
|
||||
if err := g.Wait(); err != nil {
|
||||
return fmt.Errorf("server exit with error: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,137 @@
|
||||
// Copyright (c) Ultraviolet
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"os"
|
||||
"os/signal"
|
||||
"syscall"
|
||||
|
||||
"github.com/absmach/certs/sdk"
|
||||
mglog "github.com/absmach/supermq/logger"
|
||||
"github.com/caarlos0/env/v11"
|
||||
"github.com/spf13/cobra"
|
||||
"github.com/spf13/pflag"
|
||||
"github.com/ultravioletrs/cocos/pkg/atls"
|
||||
"github.com/ultravioletrs/cocos/pkg/attestation"
|
||||
"github.com/ultravioletrs/cocos/pkg/attestation/azure"
|
||||
"github.com/ultravioletrs/cocos/pkg/ingress"
|
||||
"golang.org/x/sync/errgroup"
|
||||
)
|
||||
|
||||
const (
|
||||
svcName = "ingress-proxy"
|
||||
)
|
||||
|
||||
type config struct {
|
||||
LogLevel string `env:"COCOS_LOG_LEVEL" envAlternate:"AGENT_LOG_LEVEL" envDefault:"info"`
|
||||
Backend string `env:"COCOS_INGRESS_BACKEND" envDefault:"http://localhost:7001"`
|
||||
|
||||
// ATLS Config
|
||||
CAUrl string `env:"AGENT_CVM_CA_URL" envDefault:""`
|
||||
CVMId string `env:"AGENT_CVM_ID" envDefault:""`
|
||||
CertsToken string `env:"AGENT_CERTS_TOKEN" envDefault:""`
|
||||
AgentMaaURL string `env:"AGENT_MAA_URL" envDefault:"https://sharedeus2.eus2.attest.azure.net"`
|
||||
AgentOSBuild string `env:"AGENT_OS_BUILD" envDefault:"UVC"`
|
||||
AgentOSDistro string `env:"AGENT_OS_DISTRO" envDefault:"UVC"`
|
||||
AgentOSType string `env:"AGENT_OS_TYPE" envDefault:"UVC"`
|
||||
}
|
||||
|
||||
func main() {
|
||||
var cfg config
|
||||
if err := env.Parse(&cfg); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "failed to load configuration: %s\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
cmd := &cobra.Command{
|
||||
Use: svcName,
|
||||
Short: "Ingress Proxy Service",
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
return run(cfg)
|
||||
},
|
||||
}
|
||||
|
||||
pflag.StringVar(&cfg.LogLevel, "log-level", cfg.LogLevel, "Log level")
|
||||
pflag.StringVar(&cfg.Backend, "backend", cfg.Backend, "Backend URL")
|
||||
|
||||
if err := cmd.Execute(); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error: %s\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
}
|
||||
|
||||
func run(cfg config) error {
|
||||
logger, err := mglog.New(os.Stdout, cfg.LogLevel)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create logger: %w", err)
|
||||
}
|
||||
|
||||
backendURL, err := url.Parse(cfg.Backend)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to parse backend URL: %w", err)
|
||||
}
|
||||
|
||||
// Initialize Certificate Provider
|
||||
var provider attestation.Provider
|
||||
ccPlatform := attestation.CCPlatform()
|
||||
|
||||
azureConfig := azure.NewEnvConfigFromAgent(
|
||||
cfg.AgentOSBuild,
|
||||
cfg.AgentOSType,
|
||||
cfg.AgentOSDistro,
|
||||
cfg.AgentMaaURL,
|
||||
)
|
||||
azure.InitializeDefaultMAAVars(azureConfig)
|
||||
|
||||
var certProvider atls.CertificateProvider
|
||||
|
||||
if ccPlatform != attestation.NoCC {
|
||||
var certsSDK sdk.SDK
|
||||
if cfg.CAUrl != "" {
|
||||
certsSDK = sdk.NewSDK(sdk.Config{
|
||||
CertsURL: cfg.CAUrl,
|
||||
})
|
||||
}
|
||||
certProvider, err = atls.NewProvider(provider, ccPlatform, cfg.CertsToken, cfg.CVMId, certsSDK)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create certificate provider: %w", err)
|
||||
}
|
||||
} else {
|
||||
logger.Warn("No Confidential Computing platform detected. ATLS will not be available.")
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
g, ctx := errgroup.WithContext(ctx)
|
||||
|
||||
// Create proxy server (but don't start it yet - it will be started per-computation)
|
||||
_ = ingress.NewProxyServer(logger, backendURL, certProvider)
|
||||
|
||||
// Note: The proxy server will be started dynamically when a computation is initiated
|
||||
// via the Manager's ComputationRunReq message. For now, we just keep the service alive.
|
||||
logger.Info("ingress-proxy service initialized, waiting for computation requests...")
|
||||
|
||||
g.Go(func() error {
|
||||
c := make(chan os.Signal, 1)
|
||||
signal.Notify(c, syscall.SIGINT, syscall.SIGTERM)
|
||||
select {
|
||||
case s := <-c:
|
||||
logger.Info(fmt.Sprintf("received signal %s, stopping", s))
|
||||
cancel()
|
||||
return nil
|
||||
case <-ctx.Done():
|
||||
return nil
|
||||
}
|
||||
})
|
||||
|
||||
if err := g.Wait(); err != nil {
|
||||
return fmt.Errorf("server exit with error: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,155 @@
|
||||
// Copyright (c) Ultraviolet
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"net"
|
||||
"os"
|
||||
"os/signal"
|
||||
"syscall"
|
||||
|
||||
mglog "github.com/absmach/supermq/logger"
|
||||
"github.com/caarlos0/env/v11"
|
||||
"github.com/ultravioletrs/cocos/agent/cvms"
|
||||
pb "github.com/ultravioletrs/cocos/agent/log"
|
||||
"github.com/ultravioletrs/cocos/agent/log/service"
|
||||
"github.com/ultravioletrs/cocos/pkg/clients"
|
||||
cvmsgrpc "github.com/ultravioletrs/cocos/pkg/clients/grpc/cvm"
|
||||
"golang.org/x/sync/errgroup"
|
||||
"google.golang.org/grpc"
|
||||
)
|
||||
|
||||
const (
|
||||
svcName = "log-forwarder"
|
||||
socketPath = "/run/cocos/log.sock"
|
||||
envPrefixCVMGRPC = "AGENT_CVM_GRPC_"
|
||||
)
|
||||
|
||||
type config struct {
|
||||
LogLevel string `env:"LOG_FORWARDER_LOG_LEVEL" envAlternate:"AGENT_LOG_LEVEL" envDefault:"debug"`
|
||||
}
|
||||
|
||||
func main() {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
g, ctx := errgroup.WithContext(ctx)
|
||||
|
||||
var cfg config
|
||||
if err := env.Parse(&cfg); err != nil {
|
||||
fmt.Printf("failed to load %s configuration : %s\n", svcName, err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
var exitCode int
|
||||
defer mglog.ExitWithError(&exitCode)
|
||||
|
||||
var level slog.Level
|
||||
if err := level.UnmarshalText([]byte(cfg.LogLevel)); err != nil {
|
||||
fmt.Println(err)
|
||||
exitCode = 1
|
||||
return
|
||||
}
|
||||
|
||||
logger := slog.New(slog.NewJSONHandler(os.Stdout, &slog.HandlerOptions{Level: level}))
|
||||
|
||||
// Remove existing socket if it exists
|
||||
if _, err := os.Stat(socketPath); err == nil {
|
||||
if err := os.Remove(socketPath); err != nil {
|
||||
logger.Error(fmt.Sprintf("failed to remove existing socket: %s", err))
|
||||
exitCode = 1
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
dir := socketPath[:len(socketPath)-len("/log.sock")]
|
||||
if err := os.MkdirAll(dir, 0o755); err != nil {
|
||||
logger.Error(fmt.Sprintf("failed to create socket directory: %s", err))
|
||||
exitCode = 1
|
||||
return
|
||||
}
|
||||
|
||||
lis, err := net.Listen("unix", socketPath)
|
||||
if err != nil {
|
||||
logger.Error(fmt.Sprintf("failed to listen on socket: %s", err))
|
||||
exitCode = 1
|
||||
return
|
||||
}
|
||||
|
||||
if err := os.Chmod(socketPath, 0o777); err != nil {
|
||||
logger.Error(fmt.Sprintf("failed to chmod socket: %s", err))
|
||||
exitCode = 1
|
||||
return
|
||||
}
|
||||
|
||||
// Connect to Manager
|
||||
cvmGrpcConfig := clients.StandardClientConfig{}
|
||||
if err := env.ParseWithOptions(&cvmGrpcConfig, env.Options{Prefix: envPrefixCVMGRPC}); err != nil {
|
||||
logger.Error(fmt.Sprintf("failed to load %s gRPC client configuration : %s", svcName, err))
|
||||
exitCode = 1
|
||||
return
|
||||
}
|
||||
|
||||
cvmClient, cvmsClient, err := cvmsgrpc.NewCVMClient(cvmGrpcConfig)
|
||||
if err != nil {
|
||||
logger.Error(fmt.Sprintf("failed to connect to CVM manager: %s", err))
|
||||
exitCode = 1
|
||||
return
|
||||
}
|
||||
defer cvmClient.Close()
|
||||
|
||||
// Create stream to Manager
|
||||
stream, err := cvmsClient.Process(ctx)
|
||||
if err != nil {
|
||||
logger.Error(fmt.Sprintf("failed to create stream to manager: %s", err))
|
||||
exitCode = 1
|
||||
return
|
||||
}
|
||||
|
||||
logQueue := make(chan *cvms.ClientStreamMessage, 1000)
|
||||
|
||||
grpcServer := grpc.NewServer()
|
||||
svc := service.New(logger, cvmsClient, logQueue)
|
||||
pb.RegisterLogCollectorServer(grpcServer, svc)
|
||||
|
||||
// Log Consumer Goroutine
|
||||
g.Go(func() error {
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil
|
||||
case msg := <-logQueue:
|
||||
if err := stream.Send(msg); err != nil {
|
||||
logger.Error(fmt.Sprintf("failed to send log to manager: %s", err))
|
||||
// Reconnect logic would go here
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
g.Go(func() error {
|
||||
ch := make(chan os.Signal, 1)
|
||||
signal.Notify(ch, syscall.SIGINT, syscall.SIGTERM)
|
||||
defer signal.Stop(ch)
|
||||
|
||||
select {
|
||||
case <-ch:
|
||||
logger.Info("Received signal, shutting down...")
|
||||
cancel()
|
||||
grpcServer.GracefulStop()
|
||||
return nil
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
}
|
||||
})
|
||||
|
||||
g.Go(func() error {
|
||||
logger.Info(fmt.Sprintf("%s started on %s", svcName, socketPath))
|
||||
return grpcServer.Serve(lis)
|
||||
})
|
||||
|
||||
if err := g.Wait(); err != nil {
|
||||
logger.Error(fmt.Sprintf("%s terminated: %s", svcName, err))
|
||||
}
|
||||
}
|
||||
@@ -1,3 +1,7 @@
|
||||
source "$BR2_EXTERNAL_COCOS_PATH/package/agent/Config.in"
|
||||
source "$BR2_EXTERNAL_COCOS_PATH/package/attestation-service/Config.in"
|
||||
source "$BR2_EXTERNAL_COCOS_PATH/package/wasmedge/Config.in"
|
||||
source "$BR2_EXTERNAL_COCOS_PATH/package/log-forwarder/Config.in"
|
||||
source "$BR2_EXTERNAL_COCOS_PATH/package/computation-runner/Config.in"
|
||||
source "$BR2_EXTERNAL_COCOS_PATH/package/egress-proxy/Config.in"
|
||||
source "$BR2_EXTERNAL_COCOS_PATH/package/ingress-proxy/Config.in"
|
||||
|
||||
@@ -2,6 +2,10 @@ config BR2_PACKAGE_AGENT
|
||||
bool "agent"
|
||||
default y
|
||||
select BR2_PACKAGE_ATTESTATION_SERVICE
|
||||
select BR2_PACKAGE_LOG_FORWARDER
|
||||
select BR2_PACKAGE_COMPUTATION_RUNNER
|
||||
select BR2_PACKAGE_INGRESS_PROXY
|
||||
select BR2_PACKAGE_EGRESS_PROXY
|
||||
help
|
||||
Confidential Computing Agent is a state machine capable of
|
||||
receiving datasets and algorithm, running computations, and
|
||||
|
||||
@@ -0,0 +1,5 @@
|
||||
config BR2_PACKAGE_COMPUTATION_RUNNER
|
||||
bool "computation-runner"
|
||||
select BR2_PACKAGE_LOG_FORWARDER
|
||||
help
|
||||
Cocos AI Computation Runner service.
|
||||
@@ -0,0 +1,22 @@
|
||||
################################################################################
|
||||
#
|
||||
# computation-runner
|
||||
#
|
||||
################################################################################
|
||||
|
||||
COMPUTATION_RUNNER_VERSION = main
|
||||
COMPUTATION_RUNNER_SITE = $(call github,ultravioletrs,cocos,$(COMPUTATION_RUNNER_VERSION))
|
||||
|
||||
define COMPUTATION_RUNNER_BUILD_CMDS
|
||||
$(MAKE) -C $(@D) computation-runner
|
||||
endef
|
||||
|
||||
define COMPUTATION_RUNNER_INSTALL_TARGET_CMDS
|
||||
$(INSTALL) -D -m 0750 $(@D)/build/cocos-computation-runner $(TARGET_DIR)/usr/bin/computation-runner
|
||||
endef
|
||||
|
||||
define COMPUTATION_RUNNER_INSTALL_INIT_SYSTEMD
|
||||
$(INSTALL) -D -m 0640 $(@D)/init/systemd/computation-runner.service $(TARGET_DIR)/usr/lib/systemd/system/computation-runner.service
|
||||
endef
|
||||
|
||||
$(eval $(generic-package))
|
||||
@@ -0,0 +1,6 @@
|
||||
config BR2_PACKAGE_EGRESS_PROXY
|
||||
bool "egress-proxy"
|
||||
help
|
||||
Cocos AI Egress Proxy Service.
|
||||
|
||||
https://github.com/ultravioletrs/cocos
|
||||
@@ -0,0 +1,22 @@
|
||||
################################################################################
|
||||
#
|
||||
# Cocos AI Egress Proxy
|
||||
#
|
||||
################################################################################
|
||||
|
||||
EGRESS_PROXY_VERSION = main
|
||||
EGRESS_PROXY_SITE = $(call github,ultravioletrs,cocos,$(EGRESS_PROXY_VERSION))
|
||||
|
||||
define EGRESS_PROXY_BUILD_CMDS
|
||||
$(MAKE) -C $(@D) egress-proxy
|
||||
endef
|
||||
|
||||
define EGRESS_PROXY_INSTALL_TARGET_CMDS
|
||||
$(INSTALL) -D -m 0755 $(@D)/build/cocos-egress-proxy $(TARGET_DIR)/usr/bin/egress-proxy
|
||||
endef
|
||||
|
||||
define EGRESS_PROXY_INSTALL_INIT_SYSTEMD
|
||||
$(INSTALL) -D -m 0644 $(@D)/init/systemd/egress-proxy.service $(TARGET_DIR)/usr/lib/systemd/system/egress-proxy.service
|
||||
endef
|
||||
|
||||
$(eval $(generic-package))
|
||||
@@ -0,0 +1,4 @@
|
||||
config BR2_PACKAGE_INGRESS_PROXY
|
||||
bool "ingress-proxy"
|
||||
help
|
||||
Cocos Ingress Proxy service.
|
||||
@@ -0,0 +1,22 @@
|
||||
################################################################################
|
||||
#
|
||||
# ingress-proxy
|
||||
#
|
||||
################################################################################
|
||||
|
||||
INGRESS_PROXY_VERSION = main
|
||||
INGRESS_PROXY_SITE = $(call github,ultravioletrs,cocos,$(INGRESS_PROXY_VERSION))
|
||||
|
||||
define INGRESS_PROXY_BUILD_CMDS
|
||||
$(MAKE) -C $(@D) ingress-proxy
|
||||
endef
|
||||
|
||||
define INGRESS_PROXY_INSTALL_TARGET_CMDS
|
||||
$(INSTALL) -D -m 0750 $(@D)/build/cocos-ingress-proxy $(TARGET_DIR)/usr/bin/ingress-proxy
|
||||
endef
|
||||
|
||||
# NOTE: The ingress-proxy is managed per-computation by the agent, not as a standalone
|
||||
# systemd service. The binary is installed for use by the agent, but no systemd service
|
||||
# is created.
|
||||
|
||||
$(eval $(generic-package))
|
||||
@@ -0,0 +1,4 @@
|
||||
config BR2_PACKAGE_LOG_FORWARDER
|
||||
bool "log-forwarder"
|
||||
help
|
||||
Cocos AI Log Forwarder service.
|
||||
@@ -0,0 +1,22 @@
|
||||
################################################################################
|
||||
#
|
||||
# log-forwarder
|
||||
#
|
||||
################################################################################
|
||||
|
||||
LOG_FORWARDER_VERSION = main
|
||||
LOG_FORWARDER_SITE = $(call github,ultravioletrs,cocos,$(LOG_FORWARDER_VERSION))
|
||||
|
||||
define LOG_FORWARDER_BUILD_CMDS
|
||||
$(MAKE) -C $(@D) log-forwarder
|
||||
endef
|
||||
|
||||
define LOG_FORWARDER_INSTALL_TARGET_CMDS
|
||||
$(INSTALL) -D -m 0750 $(@D)/build/cocos-log-forwarder $(TARGET_DIR)/usr/bin/log-forwarder
|
||||
endef
|
||||
|
||||
define LOG_FORWARDER_INSTALL_INIT_SYSTEMD
|
||||
$(INSTALL) -D -m 0640 $(@D)/init/systemd/log-forwarder.service $(TARGET_DIR)/usr/lib/systemd/system/log-forwarder.service
|
||||
endef
|
||||
|
||||
$(eval $(generic-package))
|
||||
@@ -1,16 +1,19 @@
|
||||
[Unit]
|
||||
Description=Cocos AI agent
|
||||
After=network.target attestation-service.service
|
||||
Requires=attestation-service.service
|
||||
After=network.target attestation-service.service log-forwarder.service computation-runner.service egress-proxy.service
|
||||
Requires=log-forwarder.service computation-runner.service egress-proxy.service
|
||||
Before=docker.service
|
||||
|
||||
[Service]
|
||||
WorkingDirectory=/cocos
|
||||
Environment="HTTP_PROXY=http://localhost:3128"
|
||||
Environment="HTTPS_PROXY=http://localhost:3128"
|
||||
Environment="NO_PROXY=localhost,127.0.0.1,.local,/run/cocos/"
|
||||
Environment="AGENT_ENABLE_ATLS=false"
|
||||
StandardOutput=file:/var/log/cocos/agent.stdout
|
||||
StandardError=file:/var/log/cocos/agent.stderr
|
||||
EnvironmentFile=/etc/cocos/environment
|
||||
ExecStartPre=/cocos_init/agent_setup.sh
|
||||
ExecStart=/cocos_init/agent_start_script.sh
|
||||
ExecStart=/usr/bin/cocos-agent --config-file=/etc/cocos/cocos-agent.conf
|
||||
Restart=always
|
||||
RestartSec=5s
|
||||
|
||||
|
||||
@@ -0,0 +1,20 @@
|
||||
[Unit]
|
||||
Description=Cocos AI Computation Runner
|
||||
After=network.target log-forwarder.service
|
||||
Before=cocos-agent.service
|
||||
Requires=log-forwarder.service
|
||||
|
||||
[Service]
|
||||
WorkingDirectory=/cocos
|
||||
StandardOutput=file:/var/log/cocos/runner.stdout
|
||||
StandardError=file:/var/log/cocos/runner.stderr
|
||||
EnvironmentFile=/etc/cocos/environment
|
||||
Environment="HTTP_PROXY=http://localhost:3128"
|
||||
Environment="HTTPS_PROXY=http://localhost:3128"
|
||||
Environment="NO_PROXY=localhost,127.0.0.1,.local,/run/cocos/"
|
||||
ExecStart=/usr/bin/computation-runner --socket-path=/run/cocos/runner.sock
|
||||
Restart=always
|
||||
RestartSec=5s
|
||||
|
||||
[Install]
|
||||
WantedBy=default.target
|
||||
@@ -0,0 +1,15 @@
|
||||
[Unit]
|
||||
Description=Cocos Egress Proxy Service
|
||||
After=network.target
|
||||
|
||||
[Service]
|
||||
EnvironmentFile=/etc/cocos/environment
|
||||
Environment="COCOS_LOG_LEVEL=debug"
|
||||
ExecStart=/usr/bin/egress-proxy --port=3128
|
||||
Restart=always
|
||||
RestartSec=5
|
||||
User=root
|
||||
Group=root
|
||||
|
||||
[Install]
|
||||
WantedBy=multi-user.target
|
||||
@@ -0,0 +1,17 @@
|
||||
[Unit]
|
||||
Description=Cocos AI Log Forwarder
|
||||
After=network.target
|
||||
Before=cocos-agent.service
|
||||
|
||||
[Service]
|
||||
WorkingDirectory=/cocos
|
||||
StandardOutput=file:/var/log/cocos/log-forwarder.stdout
|
||||
StandardError=file:/var/log/cocos/log-forwarder.stderr
|
||||
EnvironmentFile=/etc/cocos/environment
|
||||
ExecStartPre=/cocos_init/agent_setup.sh
|
||||
ExecStart=/usr/bin/log-forwarder
|
||||
Restart=always
|
||||
RestartSec=5s
|
||||
|
||||
[Install]
|
||||
WantedBy=default.target
|
||||
@@ -44,6 +44,7 @@ func (h *handler) Enabled(_ context.Context, l slog.Level) bool {
|
||||
}
|
||||
|
||||
func (h *handler) Handle(_ context.Context, r slog.Record) error {
|
||||
slog.Info("logging message", "message", r.Message)
|
||||
message := r.Message
|
||||
timestamp := timestamppb.New(r.Time)
|
||||
level := r.Level.String()
|
||||
|
||||
@@ -0,0 +1,392 @@
|
||||
// Copyright (c) Ultraviolet
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
package attestation
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
attestation_v1 "github.com/ultravioletrs/cocos/internal/proto/attestation/v1"
|
||||
"github.com/ultravioletrs/cocos/pkg/attestation"
|
||||
"google.golang.org/grpc"
|
||||
)
|
||||
|
||||
// mockAttestationServer is a mock implementation of the AttestationServiceServer.
|
||||
type mockAttestationServer struct {
|
||||
attestation_v1.UnimplementedAttestationServiceServer
|
||||
fetchAttestationCalled bool
|
||||
fetchAzureTokenCalled bool
|
||||
lastReportData []byte
|
||||
lastNonce []byte
|
||||
lastPlatformType attestation_v1.PlatformType
|
||||
attestationErr error
|
||||
azureTokenErr error
|
||||
}
|
||||
|
||||
func (m *mockAttestationServer) FetchAttestation(ctx context.Context, req *attestation_v1.AttestationRequest) (*attestation_v1.AttestationResponse, error) {
|
||||
m.fetchAttestationCalled = true
|
||||
m.lastReportData = req.ReportData
|
||||
m.lastNonce = req.Nonce
|
||||
m.lastPlatformType = req.PlatformType
|
||||
|
||||
if m.attestationErr != nil {
|
||||
return nil, m.attestationErr
|
||||
}
|
||||
|
||||
return &attestation_v1.AttestationResponse{
|
||||
Quote: []byte("mock-attestation-quote"),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (m *mockAttestationServer) FetchAzureToken(ctx context.Context, req *attestation_v1.AzureTokenRequest) (*attestation_v1.AzureTokenResponse, error) {
|
||||
m.fetchAzureTokenCalled = true
|
||||
m.lastNonce = req.Nonce
|
||||
|
||||
if m.azureTokenErr != nil {
|
||||
return nil, m.azureTokenErr
|
||||
}
|
||||
|
||||
return &attestation_v1.AzureTokenResponse{
|
||||
Token: []byte("mock-azure-token"),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// TestNewClient tests creating a new attestation client.
|
||||
func TestNewClient(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
socketPath := filepath.Join(tmpDir, "attestation-test.sock")
|
||||
|
||||
listener, err := net.Listen("unix", socketPath)
|
||||
require.NoError(t, err)
|
||||
defer listener.Close()
|
||||
|
||||
grpcServer := grpc.NewServer()
|
||||
mockServer := &mockAttestationServer{}
|
||||
attestation_v1.RegisterAttestationServiceServer(grpcServer, mockServer)
|
||||
|
||||
go func() {
|
||||
_ = grpcServer.Serve(listener)
|
||||
}()
|
||||
defer grpcServer.Stop()
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
client, err := NewClient(socketPath)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, client)
|
||||
|
||||
err = client.Close()
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
// TestGetAttestationSNP tests getting SNP attestation.
|
||||
func TestGetAttestationSNP(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
socketPath := filepath.Join(tmpDir, "attestation-snp.sock")
|
||||
|
||||
listener, err := net.Listen("unix", socketPath)
|
||||
require.NoError(t, err)
|
||||
defer listener.Close()
|
||||
|
||||
grpcServer := grpc.NewServer()
|
||||
mockServer := &mockAttestationServer{}
|
||||
attestation_v1.RegisterAttestationServiceServer(grpcServer, mockServer)
|
||||
|
||||
go func() {
|
||||
_ = grpcServer.Serve(listener)
|
||||
}()
|
||||
defer grpcServer.Stop()
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
client, err := NewClient(socketPath)
|
||||
require.NoError(t, err)
|
||||
defer client.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
var reportData [64]byte
|
||||
var nonce [32]byte
|
||||
copy(reportData[:], []byte("test-report-data"))
|
||||
copy(nonce[:], []byte("test-nonce"))
|
||||
|
||||
quote, err := client.GetAttestation(ctx, reportData, nonce, attestation.SNP)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, []byte("mock-attestation-quote"), quote)
|
||||
assert.True(t, mockServer.fetchAttestationCalled)
|
||||
assert.Equal(t, attestation_v1.PlatformType_PLATFORM_TYPE_SNP, mockServer.lastPlatformType)
|
||||
assert.Equal(t, reportData[:], mockServer.lastReportData)
|
||||
assert.Equal(t, nonce[:], mockServer.lastNonce)
|
||||
}
|
||||
|
||||
// TestGetAttestationTDX tests getting TDX attestation.
|
||||
func TestGetAttestationTDX(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
socketPath := filepath.Join(tmpDir, "attestation-tdx.sock")
|
||||
|
||||
listener, err := net.Listen("unix", socketPath)
|
||||
require.NoError(t, err)
|
||||
defer listener.Close()
|
||||
|
||||
grpcServer := grpc.NewServer()
|
||||
mockServer := &mockAttestationServer{}
|
||||
attestation_v1.RegisterAttestationServiceServer(grpcServer, mockServer)
|
||||
|
||||
go func() {
|
||||
_ = grpcServer.Serve(listener)
|
||||
}()
|
||||
defer grpcServer.Stop()
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
client, err := NewClient(socketPath)
|
||||
require.NoError(t, err)
|
||||
defer client.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
var reportData [64]byte
|
||||
var nonce [32]byte
|
||||
|
||||
quote, err := client.GetAttestation(ctx, reportData, nonce, attestation.TDX)
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, quote)
|
||||
assert.Equal(t, attestation_v1.PlatformType_PLATFORM_TYPE_TDX, mockServer.lastPlatformType)
|
||||
}
|
||||
|
||||
// TestGetAttestationVTPM tests getting vTPM attestation.
|
||||
func TestGetAttestationVTPM(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
socketPath := filepath.Join(tmpDir, "attestation-vtpm.sock")
|
||||
|
||||
listener, err := net.Listen("unix", socketPath)
|
||||
require.NoError(t, err)
|
||||
defer listener.Close()
|
||||
|
||||
grpcServer := grpc.NewServer()
|
||||
mockServer := &mockAttestationServer{}
|
||||
attestation_v1.RegisterAttestationServiceServer(grpcServer, mockServer)
|
||||
|
||||
go func() {
|
||||
_ = grpcServer.Serve(listener)
|
||||
}()
|
||||
defer grpcServer.Stop()
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
client, err := NewClient(socketPath)
|
||||
require.NoError(t, err)
|
||||
defer client.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
var reportData [64]byte
|
||||
var nonce [32]byte
|
||||
|
||||
quote, err := client.GetAttestation(ctx, reportData, nonce, attestation.VTPM)
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, quote)
|
||||
assert.Equal(t, attestation_v1.PlatformType_PLATFORM_TYPE_VTPM, mockServer.lastPlatformType)
|
||||
}
|
||||
|
||||
// TestGetAttestationSNPvTPM tests getting SNP+vTPM attestation.
|
||||
func TestGetAttestationSNPvTPM(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
socketPath := filepath.Join(tmpDir, "attestation-snpvtpm.sock")
|
||||
|
||||
listener, err := net.Listen("unix", socketPath)
|
||||
require.NoError(t, err)
|
||||
defer listener.Close()
|
||||
|
||||
grpcServer := grpc.NewServer()
|
||||
mockServer := &mockAttestationServer{}
|
||||
attestation_v1.RegisterAttestationServiceServer(grpcServer, mockServer)
|
||||
|
||||
go func() {
|
||||
_ = grpcServer.Serve(listener)
|
||||
}()
|
||||
defer grpcServer.Stop()
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
client, err := NewClient(socketPath)
|
||||
require.NoError(t, err)
|
||||
defer client.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
var reportData [64]byte
|
||||
var nonce [32]byte
|
||||
|
||||
quote, err := client.GetAttestation(ctx, reportData, nonce, attestation.SNPvTPM)
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, quote)
|
||||
assert.Equal(t, attestation_v1.PlatformType_PLATFORM_TYPE_SNP_VTPM, mockServer.lastPlatformType)
|
||||
}
|
||||
|
||||
// TestGetAttestationUnspecified tests getting attestation with unspecified platform.
|
||||
func TestGetAttestationUnspecified(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
socketPath := filepath.Join(tmpDir, "attestation-unspec.sock")
|
||||
|
||||
listener, err := net.Listen("unix", socketPath)
|
||||
require.NoError(t, err)
|
||||
defer listener.Close()
|
||||
|
||||
grpcServer := grpc.NewServer()
|
||||
mockServer := &mockAttestationServer{}
|
||||
attestation_v1.RegisterAttestationServiceServer(grpcServer, mockServer)
|
||||
|
||||
go func() {
|
||||
_ = grpcServer.Serve(listener)
|
||||
}()
|
||||
defer grpcServer.Stop()
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
client, err := NewClient(socketPath)
|
||||
require.NoError(t, err)
|
||||
defer client.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
var reportData [64]byte
|
||||
var nonce [32]byte
|
||||
|
||||
// Use an invalid platform type (999)
|
||||
quote, err := client.GetAttestation(ctx, reportData, nonce, attestation.PlatformType(999))
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, quote)
|
||||
assert.Equal(t, attestation_v1.PlatformType_PLATFORM_TYPE_UNSPECIFIED, mockServer.lastPlatformType)
|
||||
}
|
||||
|
||||
// TestGetAzureToken tests getting Azure token.
|
||||
func TestGetAzureToken(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
socketPath := filepath.Join(tmpDir, "attestation-azure.sock")
|
||||
|
||||
listener, err := net.Listen("unix", socketPath)
|
||||
require.NoError(t, err)
|
||||
defer listener.Close()
|
||||
|
||||
grpcServer := grpc.NewServer()
|
||||
mockServer := &mockAttestationServer{}
|
||||
attestation_v1.RegisterAttestationServiceServer(grpcServer, mockServer)
|
||||
|
||||
go func() {
|
||||
_ = grpcServer.Serve(listener)
|
||||
}()
|
||||
defer grpcServer.Stop()
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
client, err := NewClient(socketPath)
|
||||
require.NoError(t, err)
|
||||
defer client.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
var nonce [32]byte
|
||||
copy(nonce[:], []byte("azure-nonce"))
|
||||
|
||||
token, err := client.GetAzureToken(ctx, nonce)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, []byte("mock-azure-token"), token)
|
||||
assert.True(t, mockServer.fetchAzureTokenCalled)
|
||||
assert.Equal(t, nonce[:], mockServer.lastNonce)
|
||||
}
|
||||
|
||||
// TestGetAttestationWithCanceledContext tests GetAttestation with canceled context.
|
||||
func TestGetAttestationWithCanceledContext(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
socketPath := filepath.Join(tmpDir, "attestation-cancel.sock")
|
||||
|
||||
listener, err := net.Listen("unix", socketPath)
|
||||
require.NoError(t, err)
|
||||
defer listener.Close()
|
||||
|
||||
grpcServer := grpc.NewServer()
|
||||
mockServer := &mockAttestationServer{}
|
||||
attestation_v1.RegisterAttestationServiceServer(grpcServer, mockServer)
|
||||
|
||||
go func() {
|
||||
_ = grpcServer.Serve(listener)
|
||||
}()
|
||||
defer grpcServer.Stop()
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
client, err := NewClient(socketPath)
|
||||
require.NoError(t, err)
|
||||
defer client.Close()
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cancel()
|
||||
|
||||
var reportData [64]byte
|
||||
var nonce [32]byte
|
||||
|
||||
_, err = client.GetAttestation(ctx, reportData, nonce, attestation.SNP)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "context canceled")
|
||||
}
|
||||
|
||||
// TestClientClose tests closing the attestation client.
|
||||
func TestClientClose(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
socketPath := filepath.Join(tmpDir, "attestation-close.sock")
|
||||
|
||||
listener, err := net.Listen("unix", socketPath)
|
||||
require.NoError(t, err)
|
||||
defer listener.Close()
|
||||
|
||||
grpcServer := grpc.NewServer()
|
||||
mockServer := &mockAttestationServer{}
|
||||
attestation_v1.RegisterAttestationServiceServer(grpcServer, mockServer)
|
||||
|
||||
go func() {
|
||||
_ = grpcServer.Serve(listener)
|
||||
}()
|
||||
defer grpcServer.Stop()
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
client, err := NewClient(socketPath)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = client.Close()
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
// TestClientOperationsAfterClose tests operations after closing.
|
||||
func TestClientOperationsAfterClose(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
socketPath := filepath.Join(tmpDir, "attestation-after-close.sock")
|
||||
|
||||
listener, err := net.Listen("unix", socketPath)
|
||||
require.NoError(t, err)
|
||||
defer listener.Close()
|
||||
|
||||
grpcServer := grpc.NewServer()
|
||||
mockServer := &mockAttestationServer{}
|
||||
attestation_v1.RegisterAttestationServiceServer(grpcServer, mockServer)
|
||||
|
||||
go func() {
|
||||
_ = grpcServer.Serve(listener)
|
||||
}()
|
||||
defer grpcServer.Stop()
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
client, err := NewClient(socketPath)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = client.Close()
|
||||
require.NoError(t, err)
|
||||
|
||||
ctx := context.Background()
|
||||
var reportData [64]byte
|
||||
var nonce [32]byte
|
||||
|
||||
_, err = client.GetAttestation(ctx, reportData, nonce, attestation.SNP)
|
||||
assert.Error(t, err)
|
||||
}
|
||||
@@ -0,0 +1,64 @@
|
||||
// Copyright (c) Ultraviolet
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
package log
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/ultravioletrs/cocos/agent/log"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/credentials/insecure"
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
)
|
||||
|
||||
type Client interface {
|
||||
SendLog(ctx context.Context, entry *log.LogEntry) error
|
||||
SendEvent(ctx context.Context, entry *log.EventEntry) error
|
||||
Close() error
|
||||
}
|
||||
|
||||
type client struct {
|
||||
conn *grpc.ClientConn
|
||||
client log.LogCollectorClient
|
||||
}
|
||||
|
||||
func NewClient(socketPath string) (Client, error) {
|
||||
conn, err := grpc.NewClient("unix://"+socketPath, grpc.WithTransportCredentials(insecure.NewCredentials()))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &client{
|
||||
conn: conn,
|
||||
client: log.NewLogCollectorClient(conn),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (c *client) Close() error {
|
||||
return c.conn.Close()
|
||||
}
|
||||
|
||||
func (c *client) SendLog(ctx context.Context, entry *log.LogEntry) error {
|
||||
ctx, cancel := context.WithTimeout(ctx, 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
if entry.Timestamp == nil {
|
||||
entry.Timestamp = timestamppb.Now()
|
||||
}
|
||||
|
||||
_, err := c.client.SendLog(ctx, entry)
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *client) SendEvent(ctx context.Context, entry *log.EventEntry) error {
|
||||
ctx, cancel := context.WithTimeout(ctx, 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
if entry.Timestamp == nil {
|
||||
entry.Timestamp = timestamppb.Now()
|
||||
}
|
||||
|
||||
_, err := c.client.SendEvent(ctx, entry)
|
||||
return err
|
||||
}
|
||||
@@ -0,0 +1,332 @@
|
||||
// Copyright (c) Ultraviolet
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
package log
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/ultravioletrs/cocos/agent/log"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/protobuf/types/known/emptypb"
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
)
|
||||
|
||||
// mockLogCollectorServer is a mock implementation of the LogCollectorServer.
|
||||
type mockLogCollectorServer struct {
|
||||
log.UnimplementedLogCollectorServer
|
||||
sendLogCalled bool
|
||||
sendEventCalled bool
|
||||
lastLogEntry *log.LogEntry
|
||||
lastEventEntry *log.EventEntry
|
||||
sendLogErr error
|
||||
sendEventErr error
|
||||
}
|
||||
|
||||
func (m *mockLogCollectorServer) SendLog(ctx context.Context, entry *log.LogEntry) (*emptypb.Empty, error) {
|
||||
m.sendLogCalled = true
|
||||
m.lastLogEntry = entry
|
||||
if m.sendLogErr != nil {
|
||||
return nil, m.sendLogErr
|
||||
}
|
||||
return &emptypb.Empty{}, nil
|
||||
}
|
||||
|
||||
func (m *mockLogCollectorServer) SendEvent(ctx context.Context, entry *log.EventEntry) (*emptypb.Empty, error) {
|
||||
m.sendEventCalled = true
|
||||
m.lastEventEntry = entry
|
||||
if m.sendEventErr != nil {
|
||||
return nil, m.sendEventErr
|
||||
}
|
||||
return &emptypb.Empty{}, nil
|
||||
}
|
||||
|
||||
// TestNewClient tests creating a new log client.
|
||||
func TestNewClient(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
socketPath := filepath.Join(tmpDir, "log-test.sock")
|
||||
|
||||
listener, err := net.Listen("unix", socketPath)
|
||||
require.NoError(t, err)
|
||||
defer listener.Close()
|
||||
|
||||
grpcServer := grpc.NewServer()
|
||||
mockServer := &mockLogCollectorServer{}
|
||||
log.RegisterLogCollectorServer(grpcServer, mockServer)
|
||||
|
||||
go func() {
|
||||
_ = grpcServer.Serve(listener)
|
||||
}()
|
||||
defer grpcServer.Stop()
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
client, err := NewClient(socketPath)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, client)
|
||||
|
||||
err = client.Close()
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
// TestClientSendLog tests sending a log entry.
|
||||
func TestClientSendLog(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
socketPath := filepath.Join(tmpDir, "log-sendlog.sock")
|
||||
|
||||
listener, err := net.Listen("unix", socketPath)
|
||||
require.NoError(t, err)
|
||||
defer listener.Close()
|
||||
|
||||
grpcServer := grpc.NewServer()
|
||||
mockServer := &mockLogCollectorServer{}
|
||||
log.RegisterLogCollectorServer(grpcServer, mockServer)
|
||||
|
||||
go func() {
|
||||
_ = grpcServer.Serve(listener)
|
||||
}()
|
||||
defer grpcServer.Stop()
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
client, err := NewClient(socketPath)
|
||||
require.NoError(t, err)
|
||||
defer client.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
entry := &log.LogEntry{
|
||||
Level: "INFO",
|
||||
Message: "test log message",
|
||||
ComputationId: "test-computation",
|
||||
}
|
||||
|
||||
err = client.SendLog(ctx, entry)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, mockServer.sendLogCalled)
|
||||
assert.Equal(t, "INFO", mockServer.lastLogEntry.Level)
|
||||
assert.Equal(t, "test log message", mockServer.lastLogEntry.Message)
|
||||
assert.Equal(t, "test-computation", mockServer.lastLogEntry.ComputationId)
|
||||
assert.NotNil(t, mockServer.lastLogEntry.Timestamp)
|
||||
}
|
||||
|
||||
// TestClientSendLogWithTimestamp tests sending a log entry with existing timestamp.
|
||||
func TestClientSendLogWithTimestamp(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
socketPath := filepath.Join(tmpDir, "log-timestamp.sock")
|
||||
|
||||
listener, err := net.Listen("unix", socketPath)
|
||||
require.NoError(t, err)
|
||||
defer listener.Close()
|
||||
|
||||
grpcServer := grpc.NewServer()
|
||||
mockServer := &mockLogCollectorServer{}
|
||||
log.RegisterLogCollectorServer(grpcServer, mockServer)
|
||||
|
||||
go func() {
|
||||
_ = grpcServer.Serve(listener)
|
||||
}()
|
||||
defer grpcServer.Stop()
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
client, err := NewClient(socketPath)
|
||||
require.NoError(t, err)
|
||||
defer client.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
customTime := time.Date(2024, 1, 1, 12, 0, 0, 0, time.UTC)
|
||||
entry := &log.LogEntry{
|
||||
Level: "ERROR",
|
||||
Message: "test error",
|
||||
ComputationId: "test",
|
||||
Timestamp: timestamppb.New(customTime),
|
||||
}
|
||||
|
||||
err = client.SendLog(ctx, entry)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, mockServer.sendLogCalled)
|
||||
assert.Equal(t, customTime.Unix(), mockServer.lastLogEntry.Timestamp.AsTime().Unix())
|
||||
}
|
||||
|
||||
// TestClientSendEvent tests sending an event entry.
|
||||
func TestClientSendEvent(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
socketPath := filepath.Join(tmpDir, "log-sendevent.sock")
|
||||
|
||||
listener, err := net.Listen("unix", socketPath)
|
||||
require.NoError(t, err)
|
||||
defer listener.Close()
|
||||
|
||||
grpcServer := grpc.NewServer()
|
||||
mockServer := &mockLogCollectorServer{}
|
||||
log.RegisterLogCollectorServer(grpcServer, mockServer)
|
||||
|
||||
go func() {
|
||||
_ = grpcServer.Serve(listener)
|
||||
}()
|
||||
defer grpcServer.Stop()
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
client, err := NewClient(socketPath)
|
||||
require.NoError(t, err)
|
||||
defer client.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
entry := &log.EventEntry{
|
||||
EventType: "computation.started",
|
||||
ComputationId: "test-computation",
|
||||
Originator: "agent",
|
||||
Status: "started",
|
||||
}
|
||||
|
||||
err = client.SendEvent(ctx, entry)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, mockServer.sendEventCalled)
|
||||
assert.Equal(t, "computation.started", mockServer.lastEventEntry.EventType)
|
||||
assert.Equal(t, "agent", mockServer.lastEventEntry.Originator)
|
||||
assert.NotNil(t, mockServer.lastEventEntry.Timestamp)
|
||||
}
|
||||
|
||||
// TestClientSendEventWithTimestamp tests sending an event with existing timestamp.
|
||||
func TestClientSendEventWithTimestamp(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
socketPath := filepath.Join(tmpDir, "log-event-timestamp.sock")
|
||||
|
||||
listener, err := net.Listen("unix", socketPath)
|
||||
require.NoError(t, err)
|
||||
defer listener.Close()
|
||||
|
||||
grpcServer := grpc.NewServer()
|
||||
mockServer := &mockLogCollectorServer{}
|
||||
log.RegisterLogCollectorServer(grpcServer, mockServer)
|
||||
|
||||
go func() {
|
||||
_ = grpcServer.Serve(listener)
|
||||
}()
|
||||
defer grpcServer.Stop()
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
client, err := NewClient(socketPath)
|
||||
require.NoError(t, err)
|
||||
defer client.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
customTime := time.Date(2024, 6, 15, 10, 30, 0, 0, time.UTC)
|
||||
entry := &log.EventEntry{
|
||||
EventType: "test.event",
|
||||
ComputationId: "test",
|
||||
Timestamp: timestamppb.New(customTime),
|
||||
}
|
||||
|
||||
err = client.SendEvent(ctx, entry)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, mockServer.sendEventCalled)
|
||||
assert.Equal(t, customTime.Unix(), mockServer.lastEventEntry.Timestamp.AsTime().Unix())
|
||||
}
|
||||
|
||||
// TestClientSendLogWithCanceledContext tests SendLog with canceled context.
|
||||
func TestClientSendLogWithCanceledContext(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
socketPath := filepath.Join(tmpDir, "log-cancel.sock")
|
||||
|
||||
listener, err := net.Listen("unix", socketPath)
|
||||
require.NoError(t, err)
|
||||
defer listener.Close()
|
||||
|
||||
grpcServer := grpc.NewServer()
|
||||
mockServer := &mockLogCollectorServer{}
|
||||
log.RegisterLogCollectorServer(grpcServer, mockServer)
|
||||
|
||||
go func() {
|
||||
_ = grpcServer.Serve(listener)
|
||||
}()
|
||||
defer grpcServer.Stop()
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
client, err := NewClient(socketPath)
|
||||
require.NoError(t, err)
|
||||
defer client.Close()
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cancel()
|
||||
|
||||
entry := &log.LogEntry{
|
||||
Level: "INFO",
|
||||
Message: "test",
|
||||
}
|
||||
|
||||
err = client.SendLog(ctx, entry)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "context canceled")
|
||||
}
|
||||
|
||||
// TestClientClose tests closing the client.
|
||||
func TestClientClose(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
socketPath := filepath.Join(tmpDir, "log-close.sock")
|
||||
|
||||
listener, err := net.Listen("unix", socketPath)
|
||||
require.NoError(t, err)
|
||||
defer listener.Close()
|
||||
|
||||
grpcServer := grpc.NewServer()
|
||||
mockServer := &mockLogCollectorServer{}
|
||||
log.RegisterLogCollectorServer(grpcServer, mockServer)
|
||||
|
||||
go func() {
|
||||
_ = grpcServer.Serve(listener)
|
||||
}()
|
||||
defer grpcServer.Stop()
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
client, err := NewClient(socketPath)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = client.Close()
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
// TestClientOperationsAfterClose tests operations after closing.
|
||||
func TestClientOperationsAfterClose(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
socketPath := filepath.Join(tmpDir, "log-after-close.sock")
|
||||
|
||||
listener, err := net.Listen("unix", socketPath)
|
||||
require.NoError(t, err)
|
||||
defer listener.Close()
|
||||
|
||||
grpcServer := grpc.NewServer()
|
||||
mockServer := &mockLogCollectorServer{}
|
||||
log.RegisterLogCollectorServer(grpcServer, mockServer)
|
||||
|
||||
go func() {
|
||||
_ = grpcServer.Serve(listener)
|
||||
}()
|
||||
defer grpcServer.Stop()
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
client, err := NewClient(socketPath)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = client.Close()
|
||||
require.NoError(t, err)
|
||||
|
||||
ctx := context.Background()
|
||||
entry := &log.LogEntry{
|
||||
Level: "INFO",
|
||||
Message: "test",
|
||||
}
|
||||
|
||||
err = client.SendLog(ctx, entry)
|
||||
assert.Error(t, err)
|
||||
}
|
||||
@@ -0,0 +1,52 @@
|
||||
// Copyright (c) Ultraviolet
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
package runner
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
pb "github.com/ultravioletrs/cocos/agent/runner"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/credentials/insecure"
|
||||
"google.golang.org/protobuf/types/known/emptypb"
|
||||
)
|
||||
|
||||
type Client interface {
|
||||
Run(ctx context.Context, req *pb.RunRequest) (*pb.RunResponse, error)
|
||||
Stop(ctx context.Context, req *pb.StopRequest) (*emptypb.Empty, error)
|
||||
Close() error
|
||||
}
|
||||
|
||||
type client struct {
|
||||
conn *grpc.ClientConn
|
||||
client pb.ComputationRunnerClient
|
||||
}
|
||||
|
||||
func NewClient(socketPath string) (Client, error) {
|
||||
conn, err := grpc.NewClient("unix://"+socketPath, grpc.WithTransportCredentials(insecure.NewCredentials()))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &client{
|
||||
conn: conn,
|
||||
client: pb.NewComputationRunnerClient(conn),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (c *client) Close() error {
|
||||
return c.conn.Close()
|
||||
}
|
||||
|
||||
func (c *client) Run(ctx context.Context, req *pb.RunRequest) (*pb.RunResponse, error) {
|
||||
// Run might take long time, so we need unlimited timeout or rely on context cancellation
|
||||
return c.client.Run(ctx, req)
|
||||
}
|
||||
|
||||
func (c *client) Stop(ctx context.Context, req *pb.StopRequest) (*emptypb.Empty, error) {
|
||||
ctx, cancel := context.WithTimeout(ctx, 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
return c.client.Stop(ctx, req)
|
||||
}
|
||||
@@ -0,0 +1,349 @@
|
||||
// Copyright (c) Ultraviolet
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
package runner
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
pb "github.com/ultravioletrs/cocos/agent/runner"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/protobuf/types/known/emptypb"
|
||||
)
|
||||
|
||||
// mockComputationRunnerServer is a mock implementation of the ComputationRunnerServer.
|
||||
type mockComputationRunnerServer struct {
|
||||
pb.UnimplementedComputationRunnerServer
|
||||
runCalled bool
|
||||
stopCalled bool
|
||||
runErr error
|
||||
stopErr error
|
||||
}
|
||||
|
||||
func (m *mockComputationRunnerServer) Run(ctx context.Context, req *pb.RunRequest) (*pb.RunResponse, error) {
|
||||
m.runCalled = true
|
||||
if m.runErr != nil {
|
||||
return nil, m.runErr
|
||||
}
|
||||
return &pb.RunResponse{
|
||||
ComputationId: req.ComputationId,
|
||||
Error: "",
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (m *mockComputationRunnerServer) Stop(ctx context.Context, req *pb.StopRequest) (*emptypb.Empty, error) {
|
||||
m.stopCalled = true
|
||||
if m.stopErr != nil {
|
||||
return nil, m.stopErr
|
||||
}
|
||||
return &emptypb.Empty{}, nil
|
||||
}
|
||||
|
||||
// TestNewClient tests creating a new gRPC client.
|
||||
func TestNewClient(t *testing.T) {
|
||||
// Create a temporary directory for the socket
|
||||
tmpDir := t.TempDir()
|
||||
socketPath := filepath.Join(tmpDir, "test.sock")
|
||||
|
||||
// Start a mock gRPC server
|
||||
listener, err := net.Listen("unix", socketPath)
|
||||
require.NoError(t, err)
|
||||
defer listener.Close()
|
||||
|
||||
grpcServer := grpc.NewServer()
|
||||
mockServer := &mockComputationRunnerServer{}
|
||||
pb.RegisterComputationRunnerServer(grpcServer, mockServer)
|
||||
|
||||
go func() {
|
||||
_ = grpcServer.Serve(listener)
|
||||
}()
|
||||
defer grpcServer.Stop()
|
||||
|
||||
// Give server time to start
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// Create client
|
||||
client, err := NewClient(socketPath)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, client)
|
||||
|
||||
// Clean up
|
||||
err = client.Close()
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
// TestNewClientInvalidSocket tests creating a client with invalid socket path.
|
||||
func TestNewClientInvalidSocket(t *testing.T) {
|
||||
// Use a non-existent socket path
|
||||
socketPath := "/tmp/nonexistent-" + time.Now().Format("20060102150405") + ".sock"
|
||||
|
||||
// Create client - this should succeed as grpc.NewClient is lazy
|
||||
client, err := NewClient(socketPath)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, client)
|
||||
|
||||
// Close should work even if never connected
|
||||
err = client.Close()
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
// TestClientRun tests the Run method.
|
||||
func TestClientRun(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
socketPath := filepath.Join(tmpDir, "test-run.sock")
|
||||
|
||||
listener, err := net.Listen("unix", socketPath)
|
||||
require.NoError(t, err)
|
||||
defer listener.Close()
|
||||
|
||||
grpcServer := grpc.NewServer()
|
||||
mockServer := &mockComputationRunnerServer{}
|
||||
pb.RegisterComputationRunnerServer(grpcServer, mockServer)
|
||||
|
||||
go func() {
|
||||
_ = grpcServer.Serve(listener)
|
||||
}()
|
||||
defer grpcServer.Stop()
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
client, err := NewClient(socketPath)
|
||||
require.NoError(t, err)
|
||||
defer client.Close()
|
||||
|
||||
// Call Run
|
||||
ctx := context.Background()
|
||||
req := &pb.RunRequest{
|
||||
ComputationId: "test-computation",
|
||||
}
|
||||
|
||||
resp, err := client.Run(ctx, req)
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, resp)
|
||||
assert.Equal(t, "test-computation", resp.ComputationId)
|
||||
assert.True(t, mockServer.runCalled)
|
||||
}
|
||||
|
||||
// TestClientRunWithCanceledContext tests Run with a canceled context.
|
||||
func TestClientRunWithCanceledContext(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
socketPath := filepath.Join(tmpDir, "test-run-cancel.sock")
|
||||
|
||||
listener, err := net.Listen("unix", socketPath)
|
||||
require.NoError(t, err)
|
||||
defer listener.Close()
|
||||
|
||||
grpcServer := grpc.NewServer()
|
||||
mockServer := &mockComputationRunnerServer{}
|
||||
pb.RegisterComputationRunnerServer(grpcServer, mockServer)
|
||||
|
||||
go func() {
|
||||
_ = grpcServer.Serve(listener)
|
||||
}()
|
||||
defer grpcServer.Stop()
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
client, err := NewClient(socketPath)
|
||||
require.NoError(t, err)
|
||||
defer client.Close()
|
||||
|
||||
// Create a canceled context
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cancel()
|
||||
|
||||
req := &pb.RunRequest{
|
||||
ComputationId: "test-computation",
|
||||
}
|
||||
|
||||
_, err = client.Run(ctx, req)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "context canceled")
|
||||
}
|
||||
|
||||
// TestClientStop tests the Stop method.
|
||||
func TestClientStop(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
socketPath := filepath.Join(tmpDir, "test-stop.sock")
|
||||
|
||||
listener, err := net.Listen("unix", socketPath)
|
||||
require.NoError(t, err)
|
||||
defer listener.Close()
|
||||
|
||||
grpcServer := grpc.NewServer()
|
||||
mockServer := &mockComputationRunnerServer{}
|
||||
pb.RegisterComputationRunnerServer(grpcServer, mockServer)
|
||||
|
||||
go func() {
|
||||
_ = grpcServer.Serve(listener)
|
||||
}()
|
||||
defer grpcServer.Stop()
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
client, err := NewClient(socketPath)
|
||||
require.NoError(t, err)
|
||||
defer client.Close()
|
||||
|
||||
// Call Stop
|
||||
ctx := context.Background()
|
||||
req := &pb.StopRequest{
|
||||
ComputationId: "test-computation",
|
||||
}
|
||||
|
||||
resp, err := client.Stop(ctx, req)
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, resp)
|
||||
assert.True(t, mockServer.stopCalled)
|
||||
}
|
||||
|
||||
// TestClientStopWithTimeout tests Stop with context timeout.
|
||||
func TestClientStopWithTimeout(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
socketPath := filepath.Join(tmpDir, "test-stop-timeout.sock")
|
||||
|
||||
listener, err := net.Listen("unix", socketPath)
|
||||
require.NoError(t, err)
|
||||
defer listener.Close()
|
||||
|
||||
grpcServer := grpc.NewServer()
|
||||
mockServer := &mockComputationRunnerServer{}
|
||||
pb.RegisterComputationRunnerServer(grpcServer, mockServer)
|
||||
|
||||
go func() {
|
||||
_ = grpcServer.Serve(listener)
|
||||
}()
|
||||
defer grpcServer.Stop()
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
client, err := NewClient(socketPath)
|
||||
require.NoError(t, err)
|
||||
defer client.Close()
|
||||
|
||||
// Create a context that will timeout
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
req := &pb.StopRequest{
|
||||
ComputationId: "test-computation",
|
||||
}
|
||||
|
||||
// Stop should complete within the timeout
|
||||
resp, err := client.Stop(ctx, req)
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, resp)
|
||||
assert.True(t, mockServer.stopCalled)
|
||||
}
|
||||
|
||||
// TestClientClose tests the Close method.
|
||||
func TestClientClose(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
socketPath := filepath.Join(tmpDir, "test-close.sock")
|
||||
|
||||
listener, err := net.Listen("unix", socketPath)
|
||||
require.NoError(t, err)
|
||||
defer listener.Close()
|
||||
|
||||
grpcServer := grpc.NewServer()
|
||||
mockServer := &mockComputationRunnerServer{}
|
||||
pb.RegisterComputationRunnerServer(grpcServer, mockServer)
|
||||
|
||||
go func() {
|
||||
_ = grpcServer.Serve(listener)
|
||||
}()
|
||||
defer grpcServer.Stop()
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
client, err := NewClient(socketPath)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Close should succeed
|
||||
err = client.Close()
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
// TestClientOperationsAfterClose tests that operations fail gracefully after close.
|
||||
func TestClientOperationsAfterClose(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
socketPath := filepath.Join(tmpDir, "test-after-close.sock")
|
||||
|
||||
listener, err := net.Listen("unix", socketPath)
|
||||
require.NoError(t, err)
|
||||
defer listener.Close()
|
||||
|
||||
grpcServer := grpc.NewServer()
|
||||
mockServer := &mockComputationRunnerServer{}
|
||||
pb.RegisterComputationRunnerServer(grpcServer, mockServer)
|
||||
|
||||
go func() {
|
||||
_ = grpcServer.Serve(listener)
|
||||
}()
|
||||
defer grpcServer.Stop()
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
client, err := NewClient(socketPath)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Close the client
|
||||
err = client.Close()
|
||||
require.NoError(t, err)
|
||||
|
||||
// Try to use the client after closing
|
||||
ctx := context.Background()
|
||||
runReq := &pb.RunRequest{ComputationId: "test"}
|
||||
|
||||
// This should fail because connection is closed
|
||||
_, err = client.Run(ctx, runReq)
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
// TestNewClientWithRelativePath tests creating client with relative socket path.
|
||||
func TestNewClientWithRelativePath(t *testing.T) {
|
||||
// Create temp directory
|
||||
tmpDir := t.TempDir()
|
||||
|
||||
// Change to temp directory
|
||||
oldWd, err := os.Getwd()
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
err := os.Chdir(oldWd)
|
||||
require.NoError(t, err)
|
||||
}()
|
||||
|
||||
err = os.Chdir(tmpDir)
|
||||
require.NoError(t, err)
|
||||
|
||||
socketPath := "relative-test.sock"
|
||||
|
||||
listener, err := net.Listen("unix", socketPath)
|
||||
require.NoError(t, err)
|
||||
defer listener.Close()
|
||||
|
||||
grpcServer := grpc.NewServer()
|
||||
mockServer := &mockComputationRunnerServer{}
|
||||
pb.RegisterComputationRunnerServer(grpcServer, mockServer)
|
||||
|
||||
go func() {
|
||||
_ = grpcServer.Serve(listener)
|
||||
}()
|
||||
defer grpcServer.Stop()
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// Create client with relative path
|
||||
client, err := NewClient(socketPath)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, client)
|
||||
|
||||
err = client.Close()
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
@@ -0,0 +1,223 @@
|
||||
// Copyright (c) Ultraviolet
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
// Code generated by mockery; DO NOT EDIT.
|
||||
// github.com/vektra/mockery
|
||||
// template: testify
|
||||
|
||||
package mocks
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
mock "github.com/stretchr/testify/mock"
|
||||
"github.com/ultravioletrs/cocos/agent/runner"
|
||||
"google.golang.org/protobuf/types/known/emptypb"
|
||||
)
|
||||
|
||||
// NewClient creates a new instance of Client. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations.
|
||||
// The first argument is typically a *testing.T value.
|
||||
func NewClient(t interface {
|
||||
mock.TestingT
|
||||
Cleanup(func())
|
||||
}) *Client {
|
||||
mock := &Client{}
|
||||
mock.Mock.Test(t)
|
||||
|
||||
t.Cleanup(func() { mock.AssertExpectations(t) })
|
||||
|
||||
return mock
|
||||
}
|
||||
|
||||
// Client is an autogenerated mock type for the Client type
|
||||
type Client struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
type Client_Expecter struct {
|
||||
mock *mock.Mock
|
||||
}
|
||||
|
||||
func (_m *Client) EXPECT() *Client_Expecter {
|
||||
return &Client_Expecter{mock: &_m.Mock}
|
||||
}
|
||||
|
||||
// Close provides a mock function for the type Client
|
||||
func (_mock *Client) Close() error {
|
||||
ret := _mock.Called()
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for Close")
|
||||
}
|
||||
|
||||
var r0 error
|
||||
if returnFunc, ok := ret.Get(0).(func() error); ok {
|
||||
r0 = returnFunc()
|
||||
} else {
|
||||
r0 = ret.Error(0)
|
||||
}
|
||||
return r0
|
||||
}
|
||||
|
||||
// Client_Close_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Close'
|
||||
type Client_Close_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// Close is a helper method to define mock.On call
|
||||
func (_e *Client_Expecter) Close() *Client_Close_Call {
|
||||
return &Client_Close_Call{Call: _e.mock.On("Close")}
|
||||
}
|
||||
|
||||
func (_c *Client_Close_Call) Run(run func()) *Client_Close_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
run()
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *Client_Close_Call) Return(err error) *Client_Close_Call {
|
||||
_c.Call.Return(err)
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *Client_Close_Call) RunAndReturn(run func() error) *Client_Close_Call {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
// Run provides a mock function for the type Client
|
||||
func (_mock *Client) Run(ctx context.Context, req *runner.RunRequest) (*runner.RunResponse, error) {
|
||||
ret := _mock.Called(ctx, req)
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for Run")
|
||||
}
|
||||
|
||||
var r0 *runner.RunResponse
|
||||
var r1 error
|
||||
if returnFunc, ok := ret.Get(0).(func(context.Context, *runner.RunRequest) (*runner.RunResponse, error)); ok {
|
||||
return returnFunc(ctx, req)
|
||||
}
|
||||
if returnFunc, ok := ret.Get(0).(func(context.Context, *runner.RunRequest) *runner.RunResponse); ok {
|
||||
r0 = returnFunc(ctx, req)
|
||||
} else {
|
||||
if ret.Get(0) != nil {
|
||||
r0 = ret.Get(0).(*runner.RunResponse)
|
||||
}
|
||||
}
|
||||
if returnFunc, ok := ret.Get(1).(func(context.Context, *runner.RunRequest) error); ok {
|
||||
r1 = returnFunc(ctx, req)
|
||||
} else {
|
||||
r1 = ret.Error(1)
|
||||
}
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
// Client_Run_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Run'
|
||||
type Client_Run_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// Run is a helper method to define mock.On call
|
||||
// - ctx context.Context
|
||||
// - req *runner.RunRequest
|
||||
func (_e *Client_Expecter) Run(ctx interface{}, req interface{}) *Client_Run_Call {
|
||||
return &Client_Run_Call{Call: _e.mock.On("Run", ctx, req)}
|
||||
}
|
||||
|
||||
func (_c *Client_Run_Call) Run(run func(ctx context.Context, req *runner.RunRequest)) *Client_Run_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
var arg0 context.Context
|
||||
if args[0] != nil {
|
||||
arg0 = args[0].(context.Context)
|
||||
}
|
||||
var arg1 *runner.RunRequest
|
||||
if args[1] != nil {
|
||||
arg1 = args[1].(*runner.RunRequest)
|
||||
}
|
||||
run(
|
||||
arg0,
|
||||
arg1,
|
||||
)
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *Client_Run_Call) Return(runResponse *runner.RunResponse, err error) *Client_Run_Call {
|
||||
_c.Call.Return(runResponse, err)
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *Client_Run_Call) RunAndReturn(run func(ctx context.Context, req *runner.RunRequest) (*runner.RunResponse, error)) *Client_Run_Call {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
// Stop provides a mock function for the type Client
|
||||
func (_mock *Client) Stop(ctx context.Context, req *runner.StopRequest) (*emptypb.Empty, error) {
|
||||
ret := _mock.Called(ctx, req)
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for Stop")
|
||||
}
|
||||
|
||||
var r0 *emptypb.Empty
|
||||
var r1 error
|
||||
if returnFunc, ok := ret.Get(0).(func(context.Context, *runner.StopRequest) (*emptypb.Empty, error)); ok {
|
||||
return returnFunc(ctx, req)
|
||||
}
|
||||
if returnFunc, ok := ret.Get(0).(func(context.Context, *runner.StopRequest) *emptypb.Empty); ok {
|
||||
r0 = returnFunc(ctx, req)
|
||||
} else {
|
||||
if ret.Get(0) != nil {
|
||||
r0 = ret.Get(0).(*emptypb.Empty)
|
||||
}
|
||||
}
|
||||
if returnFunc, ok := ret.Get(1).(func(context.Context, *runner.StopRequest) error); ok {
|
||||
r1 = returnFunc(ctx, req)
|
||||
} else {
|
||||
r1 = ret.Error(1)
|
||||
}
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
// Client_Stop_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Stop'
|
||||
type Client_Stop_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// Stop is a helper method to define mock.On call
|
||||
// - ctx context.Context
|
||||
// - req *runner.StopRequest
|
||||
func (_e *Client_Expecter) Stop(ctx interface{}, req interface{}) *Client_Stop_Call {
|
||||
return &Client_Stop_Call{Call: _e.mock.On("Stop", ctx, req)}
|
||||
}
|
||||
|
||||
func (_c *Client_Stop_Call) Run(run func(ctx context.Context, req *runner.StopRequest)) *Client_Stop_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
var arg0 context.Context
|
||||
if args[0] != nil {
|
||||
arg0 = args[0].(context.Context)
|
||||
}
|
||||
var arg1 *runner.StopRequest
|
||||
if args[1] != nil {
|
||||
arg1 = args[1].(*runner.StopRequest)
|
||||
}
|
||||
run(
|
||||
arg0,
|
||||
arg1,
|
||||
)
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *Client_Stop_Call) Return(empty *emptypb.Empty, err error) *Client_Stop_Call {
|
||||
_c.Call.Return(empty, err)
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *Client_Stop_Call) RunAndReturn(run func(ctx context.Context, req *runner.StopRequest) (*emptypb.Empty, error)) *Client_Stop_Call {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
@@ -0,0 +1,248 @@
|
||||
// Copyright (c) Ultraviolet
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
package egress
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"io"
|
||||
"log/slog"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httputil"
|
||||
"net/url"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"golang.org/x/net/http2"
|
||||
)
|
||||
|
||||
// Proxy is an egress proxy server.
|
||||
type Proxy struct {
|
||||
logger *slog.Logger
|
||||
server *http.Server
|
||||
addr string
|
||||
transport *http.Transport
|
||||
}
|
||||
|
||||
// NewProxy creates a new egress proxy.
|
||||
func NewProxy(logger *slog.Logger, addr string) *Proxy {
|
||||
// Create HTTP/2 capable transport
|
||||
transport := &http.Transport{
|
||||
TLSClientConfig: &tls.Config{
|
||||
InsecureSkipVerify: false,
|
||||
},
|
||||
ForceAttemptHTTP2: true,
|
||||
}
|
||||
// Enable HTTP/2
|
||||
if err := http2.ConfigureTransport(transport); err != nil {
|
||||
logger.Warn("Failed to configure HTTP/2 transport", "error", err)
|
||||
}
|
||||
|
||||
p := &Proxy{
|
||||
logger: logger,
|
||||
addr: addr,
|
||||
transport: transport,
|
||||
}
|
||||
p.server = &http.Server{
|
||||
Addr: addr,
|
||||
Handler: http.HandlerFunc(p.handle),
|
||||
}
|
||||
return p
|
||||
}
|
||||
|
||||
// Start starts the proxy server.
|
||||
func (p *Proxy) Start() error {
|
||||
p.logger.Info("Starting egress proxy", "addr", p.addr)
|
||||
return p.server.ListenAndServe()
|
||||
}
|
||||
|
||||
// Stop stops the proxy server.
|
||||
func (p *Proxy) Stop(ctx context.Context) error {
|
||||
p.logger.Info("Stopping egress proxy")
|
||||
return p.server.Shutdown(ctx)
|
||||
}
|
||||
|
||||
func (p *Proxy) handle(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method == http.MethodConnect {
|
||||
p.handleConnect(w, r)
|
||||
} else if r.ProtoMajor == 2 {
|
||||
p.handleHTTP2(w, r)
|
||||
} else {
|
||||
p.handleHTTP(w, r)
|
||||
}
|
||||
}
|
||||
|
||||
func (p *Proxy) handleConnect(w http.ResponseWriter, r *http.Request) {
|
||||
host := r.Host
|
||||
p.logger.Info("CONNECT request received", "host", host)
|
||||
|
||||
// nolint:godox // TODO: Check allowlist here - allowlist implementation deferred
|
||||
|
||||
p.logger.Debug("Dialing destination", "host", host)
|
||||
destConn, err := net.DialTimeout("tcp", host, 10*time.Second)
|
||||
if err != nil {
|
||||
p.logger.Error("Failed to dial destination", "host", host, "error", err)
|
||||
http.Error(w, err.Error(), http.StatusServiceUnavailable)
|
||||
return
|
||||
}
|
||||
defer destConn.Close()
|
||||
p.logger.Info("Successfully connected to destination", "host", host)
|
||||
|
||||
p.logger.Debug("Hijacking client connection")
|
||||
hijacker, ok := w.(http.Hijacker)
|
||||
if !ok {
|
||||
p.logger.Error("Hijacking not supported")
|
||||
http.Error(w, "Hijacking not supported", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
clientConn, _, err := hijacker.Hijack()
|
||||
if err != nil {
|
||||
p.logger.Error("Failed to hijack connection", "error", err)
|
||||
return
|
||||
}
|
||||
defer clientConn.Close()
|
||||
p.logger.Info("Successfully hijacked client connection", "host", host)
|
||||
|
||||
// Send 200 Connection Established response
|
||||
p.logger.Debug("Sending 200 Connection Established")
|
||||
_, err = clientConn.Write([]byte("HTTP/1.1 200 Connection Established\r\n\r\n"))
|
||||
if err != nil {
|
||||
p.logger.Error("Failed to send CONNECT response", "error", err)
|
||||
return
|
||||
}
|
||||
p.logger.Info("Starting bidirectional pipe", "host", host)
|
||||
|
||||
p.pipe(clientConn, destConn)
|
||||
p.logger.Info("Pipe completed", "host", host)
|
||||
}
|
||||
|
||||
func (p *Proxy) handleHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
p.logger.Info("HTTP request", "method", r.Method, "url", r.URL.String())
|
||||
|
||||
// nolint:godox // TODO: Check allowlist here - allowlist implementation deferred
|
||||
|
||||
r.RequestURI = "" // RequestURI must be empty for Client.Do
|
||||
|
||||
// Remove hop-by-hop headers
|
||||
delHopHeaders(r.Header)
|
||||
|
||||
// Create a client to send the request
|
||||
client := &http.Client{
|
||||
CheckRedirect: func(req *http.Request, via []*http.Request) error {
|
||||
return http.ErrUseLastResponse
|
||||
},
|
||||
}
|
||||
|
||||
resp, err := client.Do(r)
|
||||
if err != nil {
|
||||
p.logger.Error("Failed to execute request", "error", err)
|
||||
http.Error(w, err.Error(), http.StatusBadGateway)
|
||||
return
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
// Copy headers
|
||||
copyHeader(w.Header(), resp.Header)
|
||||
w.WriteHeader(resp.StatusCode)
|
||||
if _, err := io.Copy(w, resp.Body); err != nil {
|
||||
p.logger.Error("Failed to copy response body", "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (p *Proxy) handleHTTP2(w http.ResponseWriter, r *http.Request) {
|
||||
p.logger.Info("HTTP/2 request", "method", r.Method, "host", r.Host, "path", r.URL.Path)
|
||||
|
||||
// nolint:godox // TODO: Check allowlist here - allowlist implementation deferred
|
||||
|
||||
// Parse the target URL from the request
|
||||
targetURL := &url.URL{
|
||||
Scheme: "http",
|
||||
Host: r.Host,
|
||||
}
|
||||
|
||||
// If the request has a full URL (absolute form), use it
|
||||
if r.URL.IsAbs() {
|
||||
targetURL = r.URL
|
||||
}
|
||||
|
||||
// Create a reverse proxy with HTTP/2 transport
|
||||
proxy := &httputil.ReverseProxy{
|
||||
Director: func(req *http.Request) {
|
||||
req.URL.Scheme = targetURL.Scheme
|
||||
req.URL.Host = targetURL.Host
|
||||
req.Host = targetURL.Host
|
||||
|
||||
// Preserve the original path and query
|
||||
if !r.URL.IsAbs() {
|
||||
req.URL.Path = r.URL.Path
|
||||
req.URL.RawQuery = r.URL.RawQuery
|
||||
}
|
||||
|
||||
// Remove hop-by-hop headers
|
||||
delHopHeaders(req.Header)
|
||||
},
|
||||
Transport: p.transport,
|
||||
ErrorHandler: func(w http.ResponseWriter, r *http.Request, err error) {
|
||||
p.logger.Error("HTTP/2 proxy error", "error", err, "host", r.Host)
|
||||
http.Error(w, err.Error(), http.StatusBadGateway)
|
||||
},
|
||||
}
|
||||
|
||||
proxy.ServeHTTP(w, r)
|
||||
}
|
||||
|
||||
func (p *Proxy) pipe(src, dst net.Conn) {
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(2)
|
||||
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
n, err := io.Copy(dst, src)
|
||||
p.logger.Debug("Pipe src->dst completed", "bytes", n, "error", err)
|
||||
// Close write end of dst if possible, or just close it
|
||||
if c, ok := dst.(*net.TCPConn); ok {
|
||||
if err := c.CloseWrite(); err != nil {
|
||||
p.logger.Debug("Failed to close write end of dst", "error", err)
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
n, err := io.Copy(src, dst)
|
||||
p.logger.Debug("Pipe dst->src completed", "bytes", n, "error", err)
|
||||
if c, ok := src.(*net.TCPConn); ok {
|
||||
if err := c.CloseWrite(); err != nil {
|
||||
p.logger.Debug("Failed to close write end of src", "error", err)
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func copyHeader(dst, src http.Header) {
|
||||
for k, vv := range src {
|
||||
for _, v := range vv {
|
||||
dst.Add(k, v)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func delHopHeaders(header http.Header) {
|
||||
// Standard hop-by-hop headers
|
||||
hopHeaders := []string{
|
||||
"Connection",
|
||||
"Keep-Alive",
|
||||
"Proxy-Authenticate",
|
||||
"Proxy-Authorization",
|
||||
"Te",
|
||||
"Trailers",
|
||||
"Transfer-Encoding",
|
||||
"Upgrade",
|
||||
}
|
||||
for _, h := range hopHeaders {
|
||||
header.Del(h)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,795 @@
|
||||
// Copyright (c) Ultraviolet
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
package egress
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestProxyHTTP(t *testing.T) {
|
||||
// 1. Start a backend server
|
||||
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
if _, err := w.Write([]byte("backend response")); err != nil {
|
||||
t.Logf("Failed to write response: %v", err)
|
||||
}
|
||||
}))
|
||||
defer backend.Close()
|
||||
|
||||
// 2. Start Proxy
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
|
||||
proxy := NewProxy(logger, ":0")
|
||||
|
||||
// Listen on a random port
|
||||
ln, err := net.Listen("tcp", ":0")
|
||||
require.NoError(t, err)
|
||||
proxy.server.Addr = ln.Addr().String()
|
||||
|
||||
go func() {
|
||||
if err := proxy.server.Serve(ln); err != nil && err != http.ErrServerClosed {
|
||||
t.Logf("Proxy server error: %v", err)
|
||||
}
|
||||
}()
|
||||
defer func() {
|
||||
if err := proxy.Stop(context.Background()); err != nil {
|
||||
t.Logf("Failed to stop proxy: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
// waiting for server start
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// 3. Make request via proxy
|
||||
proxyURL := fmt.Sprintf("http://%s", ln.Addr().String())
|
||||
|
||||
os.Setenv("HTTP_PROXY", proxyURL)
|
||||
defer os.Unsetenv("HTTP_PROXY")
|
||||
|
||||
client := &http.Client{
|
||||
Transport: &http.Transport{
|
||||
Proxy: http.ProxyFromEnvironment,
|
||||
},
|
||||
}
|
||||
|
||||
resp, err := client.Get(backend.URL)
|
||||
require.NoError(t, err)
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "backend response", string(body))
|
||||
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
}
|
||||
|
||||
func TestProxyConnect(t *testing.T) {
|
||||
// 1. Start a backend TLS server
|
||||
backend := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
if _, err := w.Write([]byte("secure backend response")); err != nil {
|
||||
t.Logf("Failed to write response: %v", err)
|
||||
}
|
||||
}))
|
||||
defer backend.Close()
|
||||
|
||||
// 2. Start Proxy
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
|
||||
// Listen on a random port
|
||||
ln, err := net.Listen("tcp", ":0")
|
||||
require.NoError(t, err)
|
||||
|
||||
proxy := NewProxy(logger, ln.Addr().String())
|
||||
proxy.server.Addr = ln.Addr().String()
|
||||
|
||||
go func() {
|
||||
if err := proxy.server.Serve(ln); err != nil && err != http.ErrServerClosed {
|
||||
t.Logf("Proxy server error: %v", err)
|
||||
}
|
||||
}()
|
||||
defer func() {
|
||||
if err := proxy.Stop(context.Background()); err != nil {
|
||||
t.Logf("Failed to stop proxy: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// 3. Configure client to use proxy
|
||||
proxyURL := fmt.Sprintf("http://%s", ln.Addr().String())
|
||||
os.Setenv("HTTPS_PROXY", proxyURL)
|
||||
defer os.Unsetenv("HTTPS_PROXY")
|
||||
|
||||
client := backend.Client() // This client trusts the test cert
|
||||
// But we need to update its transport proxy
|
||||
if transport, ok := client.Transport.(*http.Transport); ok {
|
||||
transport.Proxy = http.ProxyFromEnvironment
|
||||
} else {
|
||||
// Create new transport if needed, but backend.Client() returns transport with TLS config
|
||||
tr := &http.Transport{
|
||||
TLSClientConfig: client.Transport.(*http.Transport).TLSClientConfig,
|
||||
Proxy: http.ProxyFromEnvironment,
|
||||
}
|
||||
client.Transport = tr
|
||||
}
|
||||
|
||||
resp, err := client.Get(backend.URL)
|
||||
require.NoError(t, err)
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "secure backend response", string(body))
|
||||
}
|
||||
|
||||
// TestProxyHTTP2 tests HTTP/2 requests through the proxy.
|
||||
func TestProxyHTTP2(t *testing.T) {
|
||||
// 1. Start a backend server
|
||||
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
if _, err := w.Write([]byte("http2 response")); err != nil {
|
||||
t.Logf("Failed to write response: %v", err)
|
||||
}
|
||||
}))
|
||||
defer backend.Close()
|
||||
|
||||
// 2. Start Proxy
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
|
||||
ln, err := net.Listen("tcp", ":0")
|
||||
require.NoError(t, err)
|
||||
|
||||
proxy := NewProxy(logger, ln.Addr().String())
|
||||
proxy.server.Addr = ln.Addr().String()
|
||||
|
||||
go func() {
|
||||
if err := proxy.server.Serve(ln); err != nil && err != http.ErrServerClosed {
|
||||
t.Logf("Proxy server error: %v", err)
|
||||
}
|
||||
}()
|
||||
defer func() {
|
||||
if err := proxy.Stop(context.Background()); err != nil {
|
||||
t.Logf("Failed to stop proxy: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// 3. Make HTTP/2 request via proxy
|
||||
proxyURL := fmt.Sprintf("http://%s", ln.Addr().String())
|
||||
client := &http.Client{
|
||||
Transport: &http.Transport{
|
||||
Proxy: http.ProxyFromEnvironment,
|
||||
ForceAttemptHTTP2: true,
|
||||
},
|
||||
}
|
||||
|
||||
os.Setenv("HTTP_PROXY", proxyURL)
|
||||
defer os.Unsetenv("HTTP_PROXY")
|
||||
|
||||
// This will be an HTTP/1.1 request unless explicitly configured for HTTP/2
|
||||
resp, err := client.Get(backend.URL)
|
||||
if err == nil {
|
||||
defer resp.Body.Close()
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
assert.Equal(t, "http2 response", string(body))
|
||||
}
|
||||
}
|
||||
|
||||
// TestProxyHeaderHandling tests that headers are properly handled.
|
||||
func TestProxyHeaderHandling(t *testing.T) {
|
||||
// Start a backend server that echoes headers
|
||||
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("X-Custom-Header", "custom-value")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
if _, err := w.Write([]byte(r.Header.Get("X-Request-Header"))); err != nil {
|
||||
t.Logf("Failed to write response: %v", err)
|
||||
}
|
||||
}))
|
||||
defer backend.Close()
|
||||
|
||||
// Start Proxy
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
|
||||
ln, err := net.Listen("tcp", ":0")
|
||||
require.NoError(t, err)
|
||||
|
||||
proxy := NewProxy(logger, ln.Addr().String())
|
||||
proxy.server.Addr = ln.Addr().String()
|
||||
|
||||
go func() {
|
||||
if err := proxy.server.Serve(ln); err != nil && err != http.ErrServerClosed {
|
||||
t.Logf("Proxy server error: %v", err)
|
||||
}
|
||||
}()
|
||||
defer func() {
|
||||
if err := proxy.Stop(context.Background()); err != nil {
|
||||
t.Logf("Failed to stop proxy: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// Make request with custom headers
|
||||
proxyURL := fmt.Sprintf("http://%s", ln.Addr().String())
|
||||
os.Setenv("HTTP_PROXY", proxyURL)
|
||||
defer os.Unsetenv("HTTP_PROXY")
|
||||
|
||||
req, _ := http.NewRequest("GET", backend.URL, nil)
|
||||
req.Header.Set("X-Request-Header", "request-value")
|
||||
|
||||
client := &http.Client{
|
||||
Transport: &http.Transport{
|
||||
Proxy: http.ProxyFromEnvironment,
|
||||
},
|
||||
}
|
||||
|
||||
resp, err := client.Do(req)
|
||||
if err == nil {
|
||||
defer resp.Body.Close()
|
||||
assert.Equal(t, "custom-value", resp.Header.Get("X-Custom-Header"))
|
||||
}
|
||||
}
|
||||
|
||||
// TestProxyWithDifferentMethods tests different HTTP methods.
|
||||
func TestProxyWithDifferentMethods(t *testing.T) {
|
||||
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
if _, err := w.Write([]byte(r.Method)); err != nil {
|
||||
t.Logf("Failed to write response: %v", err)
|
||||
}
|
||||
}))
|
||||
defer backend.Close()
|
||||
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
|
||||
ln, err := net.Listen("tcp", ":0")
|
||||
require.NoError(t, err)
|
||||
|
||||
proxy := NewProxy(logger, ln.Addr().String())
|
||||
proxy.server.Addr = ln.Addr().String()
|
||||
|
||||
go func() {
|
||||
if err := proxy.server.Serve(ln); err != nil && err != http.ErrServerClosed {
|
||||
t.Logf("Proxy server error: %v", err)
|
||||
}
|
||||
}()
|
||||
defer func() {
|
||||
if err := proxy.Stop(context.Background()); err != nil {
|
||||
t.Logf("Failed to stop proxy: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
proxyURL := fmt.Sprintf("http://%s", ln.Addr().String())
|
||||
os.Setenv("HTTP_PROXY", proxyURL)
|
||||
defer os.Unsetenv("HTTP_PROXY")
|
||||
|
||||
methods := []string{"GET", "POST", "PUT", "DELETE"}
|
||||
for _, method := range methods {
|
||||
req, _ := http.NewRequest(method, backend.URL, nil)
|
||||
client := &http.Client{
|
||||
Transport: &http.Transport{
|
||||
Proxy: http.ProxyFromEnvironment,
|
||||
},
|
||||
}
|
||||
|
||||
resp, err := client.Do(req)
|
||||
if err == nil {
|
||||
defer resp.Body.Close()
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
assert.Equal(t, method, string(body))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestProxyErrorHandling tests error handling in the proxy.
|
||||
func TestProxyErrorHandling(t *testing.T) {
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
|
||||
ln, err := net.Listen("tcp", ":0")
|
||||
require.NoError(t, err)
|
||||
|
||||
proxy := NewProxy(logger, ln.Addr().String())
|
||||
proxy.server.Addr = ln.Addr().String()
|
||||
|
||||
go func() {
|
||||
if err := proxy.server.Serve(ln); err != nil && err != http.ErrServerClosed {
|
||||
t.Logf("Proxy server error: %v", err)
|
||||
}
|
||||
}()
|
||||
defer func() {
|
||||
if err := proxy.Stop(context.Background()); err != nil {
|
||||
t.Logf("Failed to stop proxy: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// Try to connect to a non-existent backend
|
||||
proxyURL := fmt.Sprintf("http://%s", ln.Addr().String())
|
||||
os.Setenv("HTTP_PROXY", proxyURL)
|
||||
defer os.Unsetenv("HTTP_PROXY")
|
||||
|
||||
client := &http.Client{
|
||||
Transport: &http.Transport{
|
||||
Proxy: http.ProxyFromEnvironment,
|
||||
},
|
||||
}
|
||||
|
||||
// This should fail because the backend doesn't exist
|
||||
resp, err := client.Get("http://nonexistent.example.com:99999")
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if resp != nil {
|
||||
defer resp.Body.Close()
|
||||
// Status should be error
|
||||
assert.NotEqual(t, http.StatusOK, resp.StatusCode)
|
||||
}
|
||||
}
|
||||
|
||||
// TestProxyWithLargeBody tests proxy with large response body.
|
||||
func TestProxyWithLargeBody(t *testing.T) {
|
||||
largeBody := make([]byte, 1024*1024) // 1MB
|
||||
for i := range largeBody {
|
||||
largeBody[i] = byte(i % 256)
|
||||
}
|
||||
|
||||
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
if _, err := w.Write(largeBody); err != nil {
|
||||
t.Logf("Failed to write response: %v", err)
|
||||
}
|
||||
}))
|
||||
defer backend.Close()
|
||||
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
|
||||
ln, err := net.Listen("tcp", ":0")
|
||||
require.NoError(t, err)
|
||||
|
||||
proxy := NewProxy(logger, ln.Addr().String())
|
||||
proxy.server.Addr = ln.Addr().String()
|
||||
|
||||
go func() {
|
||||
if err := proxy.server.Serve(ln); err != nil && err != http.ErrServerClosed {
|
||||
t.Logf("Proxy server error: %v", err)
|
||||
}
|
||||
}()
|
||||
defer func() {
|
||||
if err := proxy.Stop(context.Background()); err != nil {
|
||||
t.Logf("Failed to stop proxy: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
proxyURL := fmt.Sprintf("http://%s", ln.Addr().String())
|
||||
os.Setenv("HTTP_PROXY", proxyURL)
|
||||
defer os.Unsetenv("HTTP_PROXY")
|
||||
|
||||
client := &http.Client{
|
||||
Transport: &http.Transport{
|
||||
Proxy: http.ProxyFromEnvironment,
|
||||
},
|
||||
}
|
||||
|
||||
resp, err := client.Get(backend.URL)
|
||||
require.NoError(t, err)
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, len(largeBody), len(body))
|
||||
}
|
||||
|
||||
// TestCopyHeader tests the copyHeader utility function.
|
||||
func TestCopyHeader(t *testing.T) {
|
||||
src := http.Header{}
|
||||
src.Add("X-Custom-Header", "value1")
|
||||
src.Add("X-Custom-Header", "value2")
|
||||
src.Add("Content-Type", "application/json")
|
||||
|
||||
dst := http.Header{}
|
||||
copyHeader(dst, src)
|
||||
|
||||
assert.Equal(t, []string{"value1", "value2"}, dst["X-Custom-Header"])
|
||||
assert.Equal(t, []string{"application/json"}, dst["Content-Type"])
|
||||
}
|
||||
|
||||
// TestDelHopHeaders tests the delHopHeaders utility function.
|
||||
func TestDelHopHeaders(t *testing.T) {
|
||||
header := http.Header{}
|
||||
header.Set("Connection", "keep-alive")
|
||||
header.Set("Keep-Alive", "timeout=5")
|
||||
header.Set("Proxy-Authenticate", "Basic")
|
||||
header.Set("Proxy-Authorization", "Bearer token")
|
||||
header.Set("Te", "trailers")
|
||||
header.Set("Trailers", "X-Custom")
|
||||
header.Set("Transfer-Encoding", "chunked")
|
||||
header.Set("Upgrade", "websocket")
|
||||
header.Set("X-Custom-Header", "should-remain")
|
||||
|
||||
delHopHeaders(header)
|
||||
|
||||
// Hop-by-hop headers should be removed
|
||||
assert.Empty(t, header.Get("Connection"))
|
||||
assert.Empty(t, header.Get("Keep-Alive"))
|
||||
assert.Empty(t, header.Get("Proxy-Authenticate"))
|
||||
assert.Empty(t, header.Get("Proxy-Authorization"))
|
||||
assert.Empty(t, header.Get("Te"))
|
||||
assert.Empty(t, header.Get("Trailers"))
|
||||
assert.Empty(t, header.Get("Transfer-Encoding"))
|
||||
assert.Empty(t, header.Get("Upgrade"))
|
||||
|
||||
// Custom headers should remain
|
||||
assert.Equal(t, "should-remain", header.Get("X-Custom-Header"))
|
||||
}
|
||||
|
||||
// TestProxyConnectDialTimeout tests CONNECT with dial timeout.
|
||||
func TestProxyConnectDialTimeout(t *testing.T) {
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
|
||||
ln, err := net.Listen("tcp", ":0")
|
||||
require.NoError(t, err)
|
||||
|
||||
proxy := NewProxy(logger, ln.Addr().String())
|
||||
proxy.server.Addr = ln.Addr().String()
|
||||
|
||||
go func() {
|
||||
if err := proxy.server.Serve(ln); err != nil && err != http.ErrServerClosed {
|
||||
t.Logf("Proxy server error: %v", err)
|
||||
}
|
||||
}()
|
||||
defer func() {
|
||||
if err := proxy.Stop(context.Background()); err != nil {
|
||||
t.Logf("Failed to stop proxy: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// Try to CONNECT to a non-routable address (should timeout)
|
||||
proxyURL := fmt.Sprintf("http://%s", ln.Addr().String())
|
||||
os.Setenv("HTTPS_PROXY", proxyURL)
|
||||
defer os.Unsetenv("HTTPS_PROXY")
|
||||
|
||||
// Create a client with very short timeout
|
||||
client := &http.Client{
|
||||
Transport: &http.Transport{
|
||||
Proxy: http.ProxyFromEnvironment,
|
||||
},
|
||||
Timeout: 2 * time.Second,
|
||||
}
|
||||
|
||||
// This should fail because 192.0.2.1 is a TEST-NET address (non-routable)
|
||||
_, err = client.Get("https://192.0.2.1:9999/test")
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
// TestProxyHTTPWithRedirect tests HTTP proxy handling redirects.
|
||||
func TestProxyHTTPWithRedirect(t *testing.T) {
|
||||
// Create a backend that redirects
|
||||
redirectCount := 0
|
||||
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if redirectCount == 0 {
|
||||
redirectCount++
|
||||
http.Redirect(w, r, "/redirected", http.StatusFound)
|
||||
return
|
||||
}
|
||||
w.WriteHeader(http.StatusOK)
|
||||
if _, err := w.Write([]byte("redirected response")); err != nil {
|
||||
t.Logf("Failed to write response: %v", err)
|
||||
}
|
||||
}))
|
||||
defer backend.Close()
|
||||
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
|
||||
ln, err := net.Listen("tcp", ":0")
|
||||
require.NoError(t, err)
|
||||
|
||||
proxy := NewProxy(logger, ln.Addr().String())
|
||||
proxy.server.Addr = ln.Addr().String()
|
||||
|
||||
go func() {
|
||||
if err := proxy.server.Serve(ln); err != nil && err != http.ErrServerClosed {
|
||||
t.Logf("Proxy server error: %v", err)
|
||||
}
|
||||
}()
|
||||
defer func() {
|
||||
if err := proxy.Stop(context.Background()); err != nil {
|
||||
t.Logf("Failed to stop proxy: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
proxyURL := fmt.Sprintf("http://%s", ln.Addr().String())
|
||||
os.Setenv("HTTP_PROXY", proxyURL)
|
||||
defer os.Unsetenv("HTTP_PROXY")
|
||||
|
||||
client := &http.Client{
|
||||
Transport: &http.Transport{
|
||||
Proxy: http.ProxyFromEnvironment,
|
||||
},
|
||||
}
|
||||
|
||||
resp, err := client.Get(backend.URL)
|
||||
require.NoError(t, err)
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "redirected response", string(body))
|
||||
}
|
||||
|
||||
// TestProxyStopContext tests proxy stop with context.
|
||||
func TestProxyStopContext(t *testing.T) {
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
|
||||
ln, err := net.Listen("tcp", ":0")
|
||||
require.NoError(t, err)
|
||||
|
||||
proxy := NewProxy(logger, ln.Addr().String())
|
||||
proxy.server.Addr = ln.Addr().String()
|
||||
|
||||
go func() {
|
||||
if err := proxy.server.Serve(ln); err != nil && err != http.ErrServerClosed {
|
||||
t.Logf("Proxy server error: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
err = proxy.Stop(ctx)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
// TestProxyPipeWithRealConnections tests the pipe function with real TCP connections.
|
||||
func TestProxyPipeWithRealConnections(t *testing.T) {
|
||||
// Create two connected TCP connections
|
||||
listener, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
require.NoError(t, err)
|
||||
defer listener.Close()
|
||||
|
||||
// Channel to receive the server connection
|
||||
serverConnChan := make(chan net.Conn, 1)
|
||||
go func() {
|
||||
conn, err := listener.Accept()
|
||||
if err == nil {
|
||||
serverConnChan <- conn
|
||||
}
|
||||
}()
|
||||
|
||||
// Create client connection
|
||||
clientConn, err := net.Dial("tcp", listener.Addr().String())
|
||||
require.NoError(t, err)
|
||||
defer clientConn.Close()
|
||||
|
||||
// Get server connection
|
||||
serverConn := <-serverConnChan
|
||||
defer serverConn.Close()
|
||||
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
|
||||
proxy := NewProxy(logger, ":0")
|
||||
|
||||
// Test data transfer
|
||||
testData := []byte("test data for pipe")
|
||||
|
||||
// Start pipe in goroutine
|
||||
go proxy.pipe(clientConn, serverConn)
|
||||
|
||||
// Write from client
|
||||
_, err = clientConn.Write(testData)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Read from server
|
||||
buf := make([]byte, len(testData))
|
||||
if err := serverConn.SetReadDeadline(time.Now().Add(1 * time.Second)); err != nil {
|
||||
t.Logf("Failed to set read deadline: %v", err)
|
||||
}
|
||||
n, err := serverConn.Read(buf)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, testData, buf[:n])
|
||||
|
||||
// Close connections to trigger pipe completion
|
||||
clientConn.Close()
|
||||
serverConn.Close()
|
||||
|
||||
// Give pipe time to complete
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
}
|
||||
|
||||
// TestProxyHTTP2ErrorPath tests HTTP/2 error handler.
|
||||
func TestProxyHTTP2ErrorPath(t *testing.T) {
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
|
||||
ln, err := net.Listen("tcp", ":0")
|
||||
require.NoError(t, err)
|
||||
|
||||
proxy := NewProxy(logger, ln.Addr().String())
|
||||
proxy.server.Addr = ln.Addr().String()
|
||||
|
||||
go func() {
|
||||
if err := proxy.server.Serve(ln); err != nil && err != http.ErrServerClosed {
|
||||
t.Logf("Proxy server error: %v", err)
|
||||
}
|
||||
}()
|
||||
defer func() {
|
||||
if err := proxy.Stop(context.Background()); err != nil {
|
||||
t.Logf("Failed to stop proxy: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// Create a request that will trigger HTTP/2 handling
|
||||
req, err := http.NewRequest("GET", "http://"+ln.Addr().String()+"/test", nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Force HTTP/2 by setting the request protocol
|
||||
req.ProtoMajor = 2
|
||||
req.ProtoMinor = 0
|
||||
req.Host = "nonexistent.invalid:9999" // This should cause an error
|
||||
|
||||
// Create a response recorder
|
||||
rr := httptest.NewRecorder()
|
||||
|
||||
// Call the handler directly to test HTTP/2 error path
|
||||
proxy.server.Handler.ServeHTTP(rr, req)
|
||||
|
||||
// Should get an error response
|
||||
assert.Equal(t, http.StatusBadGateway, rr.Code)
|
||||
}
|
||||
|
||||
// TestNewProxyHTTP2ConfigWarning tests NewProxy when HTTP/2 configuration might fail.
|
||||
func TestNewProxyHTTP2ConfigWarning(t *testing.T) {
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
|
||||
|
||||
// Create proxy - HTTP/2 configuration should succeed normally
|
||||
proxy := NewProxy(logger, ":0")
|
||||
|
||||
assert.NotNil(t, proxy)
|
||||
assert.NotNil(t, proxy.transport)
|
||||
assert.True(t, proxy.transport.ForceAttemptHTTP2)
|
||||
}
|
||||
|
||||
// TestProxyHandleHTTPError tests HTTP handler error path.
|
||||
func TestProxyHandleHTTPError(t *testing.T) {
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
|
||||
ln, err := net.Listen("tcp", ":0")
|
||||
require.NoError(t, err)
|
||||
|
||||
proxy := NewProxy(logger, ln.Addr().String())
|
||||
proxy.server.Addr = ln.Addr().String()
|
||||
|
||||
go func() {
|
||||
if err := proxy.server.Serve(ln); err != nil && err != http.ErrServerClosed {
|
||||
t.Logf("Proxy server error: %v", err)
|
||||
}
|
||||
}()
|
||||
defer func() {
|
||||
if err := proxy.Stop(context.Background()); err != nil {
|
||||
t.Logf("Failed to stop proxy: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
proxyURL := fmt.Sprintf("http://%s", ln.Addr().String())
|
||||
os.Setenv("HTTP_PROXY", proxyURL)
|
||||
defer os.Unsetenv("HTTP_PROXY")
|
||||
|
||||
client := &http.Client{
|
||||
Transport: &http.Transport{
|
||||
Proxy: http.ProxyFromEnvironment,
|
||||
},
|
||||
Timeout: 2 * time.Second,
|
||||
}
|
||||
|
||||
// Try to connect to invalid backend
|
||||
resp, err := client.Get("http://invalid.backend.test:99999/test")
|
||||
if err == nil {
|
||||
defer resp.Body.Close()
|
||||
// Should get error status
|
||||
assert.NotEqual(t, http.StatusOK, resp.StatusCode)
|
||||
}
|
||||
// Either error or bad gateway response is acceptable
|
||||
}
|
||||
|
||||
// TestProxyConnectWriteError tests CONNECT with write error after hijacking.
|
||||
func TestProxyConnectWriteError(t *testing.T) {
|
||||
// This test is challenging because we need to trigger a write error
|
||||
// after successful hijacking. We'll test the path by using a closed connection.
|
||||
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
|
||||
ln, err := net.Listen("tcp", ":0")
|
||||
require.NoError(t, err)
|
||||
|
||||
proxy := NewProxy(logger, ln.Addr().String())
|
||||
proxy.server.Addr = ln.Addr().String()
|
||||
|
||||
go func() {
|
||||
if err := proxy.server.Serve(ln); err != nil && err != http.ErrServerClosed {
|
||||
t.Logf("Proxy server error: %v", err)
|
||||
}
|
||||
}()
|
||||
defer func() {
|
||||
if err := proxy.Stop(context.Background()); err != nil {
|
||||
t.Logf("Failed to stop proxy: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// Create a backend server for CONNECT
|
||||
backend := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
defer backend.Close()
|
||||
|
||||
proxyURL := fmt.Sprintf("http://%s", ln.Addr().String())
|
||||
os.Setenv("HTTPS_PROXY", proxyURL)
|
||||
defer os.Unsetenv("HTTPS_PROXY")
|
||||
|
||||
client := backend.Client()
|
||||
if transport, ok := client.Transport.(*http.Transport); ok {
|
||||
transport.Proxy = http.ProxyFromEnvironment
|
||||
}
|
||||
|
||||
// Make a request through CONNECT
|
||||
_, err = client.Get(backend.URL)
|
||||
// The request may succeed or fail, but we're testing the code path
|
||||
if err != nil {
|
||||
t.Logf("Request error (expected in some cases): %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestProxyHTTP2WithAbsoluteURL tests HTTP/2 handling with absolute URL.
|
||||
func TestProxyHTTP2WithAbsoluteURL(t *testing.T) {
|
||||
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
if _, err := w.Write([]byte("http2 absolute url response")); err != nil {
|
||||
t.Logf("Failed to write response: %v", err)
|
||||
}
|
||||
}))
|
||||
defer backend.Close()
|
||||
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
|
||||
ln, err := net.Listen("tcp", ":0")
|
||||
require.NoError(t, err)
|
||||
|
||||
proxy := NewProxy(logger, ln.Addr().String())
|
||||
proxy.server.Addr = ln.Addr().String()
|
||||
|
||||
go func() {
|
||||
if err := proxy.server.Serve(ln); err != nil && err != http.ErrServerClosed {
|
||||
t.Logf("Proxy server error: %v", err)
|
||||
}
|
||||
}()
|
||||
defer func() {
|
||||
if err := proxy.Stop(context.Background()); err != nil {
|
||||
t.Logf("Failed to stop proxy: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// Create request with absolute URL
|
||||
req, err := http.NewRequest("GET", backend.URL+"/test", nil)
|
||||
require.NoError(t, err)
|
||||
req.ProtoMajor = 2
|
||||
req.ProtoMinor = 0
|
||||
|
||||
rr := httptest.NewRecorder()
|
||||
proxy.server.Handler.ServeHTTP(rr, req)
|
||||
|
||||
assert.Equal(t, http.StatusOK, rr.Code)
|
||||
}
|
||||
@@ -0,0 +1,25 @@
|
||||
// Copyright (c) Ultraviolet
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
package ingress
|
||||
|
||||
import "github.com/ultravioletrs/cocos/agent"
|
||||
|
||||
// AgentConfigToProxyConfig converts agent.AgentConfig to ProxyConfig.
|
||||
func AgentConfigToProxyConfig(cfg agent.AgentConfig) ProxyConfig {
|
||||
return ProxyConfig{
|
||||
Port: "7002", // Ingress-proxy always uses port 7002
|
||||
CertFile: cfg.CertFile,
|
||||
KeyFile: cfg.KeyFile,
|
||||
ServerCAFile: cfg.ServerCAFile,
|
||||
ClientCAFile: cfg.ClientCAFile,
|
||||
AttestedTLS: cfg.AttestedTls,
|
||||
}
|
||||
}
|
||||
|
||||
// ComputationToProxyContext converts agent.Computation to ProxyContext.
|
||||
func ComputationToProxyContext(cmp agent.Computation) ProxyContext {
|
||||
return ProxyContext{
|
||||
ID: cmp.ID,
|
||||
Name: cmp.Name,
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,166 @@
|
||||
// Copyright (c) Ultraviolet
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
package ingress
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/ultravioletrs/cocos/agent"
|
||||
)
|
||||
|
||||
// TestAgentConfigToProxyConfig tests conversion from AgentConfig to ProxyConfig.
|
||||
func TestAgentConfigToProxyConfig(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input agent.AgentConfig
|
||||
expected ProxyConfig
|
||||
}{
|
||||
{
|
||||
name: "basic config without TLS",
|
||||
input: agent.AgentConfig{
|
||||
CertFile: "",
|
||||
KeyFile: "",
|
||||
ServerCAFile: "",
|
||||
ClientCAFile: "",
|
||||
AttestedTls: false,
|
||||
},
|
||||
expected: ProxyConfig{
|
||||
Port: "7002",
|
||||
CertFile: "",
|
||||
KeyFile: "",
|
||||
ServerCAFile: "",
|
||||
ClientCAFile: "",
|
||||
AttestedTLS: false,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "config with regular TLS",
|
||||
input: agent.AgentConfig{
|
||||
CertFile: "/path/to/cert.pem",
|
||||
KeyFile: "/path/to/key.pem",
|
||||
ServerCAFile: "/path/to/server-ca.pem",
|
||||
ClientCAFile: "/path/to/client-ca.pem",
|
||||
AttestedTls: false,
|
||||
},
|
||||
expected: ProxyConfig{
|
||||
Port: "7002",
|
||||
CertFile: "/path/to/cert.pem",
|
||||
KeyFile: "/path/to/key.pem",
|
||||
ServerCAFile: "/path/to/server-ca.pem",
|
||||
ClientCAFile: "/path/to/client-ca.pem",
|
||||
AttestedTLS: false,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "config with attested TLS",
|
||||
input: agent.AgentConfig{
|
||||
CertFile: "",
|
||||
KeyFile: "",
|
||||
ServerCAFile: "/path/to/server-ca.pem",
|
||||
ClientCAFile: "/path/to/client-ca.pem",
|
||||
AttestedTls: true,
|
||||
},
|
||||
expected: ProxyConfig{
|
||||
Port: "7002",
|
||||
CertFile: "",
|
||||
KeyFile: "",
|
||||
ServerCAFile: "/path/to/server-ca.pem",
|
||||
ClientCAFile: "/path/to/client-ca.pem",
|
||||
AttestedTLS: true,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "config with mTLS",
|
||||
input: agent.AgentConfig{
|
||||
CertFile: "/path/to/cert.pem",
|
||||
KeyFile: "/path/to/key.pem",
|
||||
ServerCAFile: "/path/to/server-ca.pem",
|
||||
ClientCAFile: "/path/to/client-ca.pem",
|
||||
AttestedTls: false,
|
||||
},
|
||||
expected: ProxyConfig{
|
||||
Port: "7002",
|
||||
CertFile: "/path/to/cert.pem",
|
||||
KeyFile: "/path/to/key.pem",
|
||||
ServerCAFile: "/path/to/server-ca.pem",
|
||||
ClientCAFile: "/path/to/client-ca.pem",
|
||||
AttestedTLS: false,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := AgentConfigToProxyConfig(tt.input)
|
||||
assert.Equal(t, tt.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestComputationToProxyContext tests conversion from Computation to ProxyContext.
|
||||
func TestComputationToProxyContext(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input agent.Computation
|
||||
expected ProxyContext
|
||||
}{
|
||||
{
|
||||
name: "computation with name",
|
||||
input: agent.Computation{
|
||||
ID: "comp-123",
|
||||
Name: "test-computation",
|
||||
Description: "A test computation",
|
||||
},
|
||||
expected: ProxyContext{
|
||||
ID: "comp-123",
|
||||
Name: "test-computation",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "computation without name",
|
||||
input: agent.Computation{
|
||||
ID: "comp-456",
|
||||
Name: "",
|
||||
Description: "Another test computation",
|
||||
},
|
||||
expected: ProxyContext{
|
||||
ID: "comp-456",
|
||||
Name: "",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "computation with special characters in name",
|
||||
input: agent.Computation{
|
||||
ID: "comp-789",
|
||||
Name: "test-computation-with-dashes_and_underscores",
|
||||
Description: "Computation with special chars",
|
||||
},
|
||||
expected: ProxyContext{
|
||||
ID: "comp-789",
|
||||
Name: "test-computation-with-dashes_and_underscores",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := ComputationToProxyContext(tt.input)
|
||||
assert.Equal(t, tt.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestAgentConfigToProxyConfigPortIsFixed tests that port is always set to 7002.
|
||||
func TestAgentConfigToProxyConfigPortIsFixed(t *testing.T) {
|
||||
configs := []agent.AgentConfig{
|
||||
{},
|
||||
{CertFile: "/cert.pem", KeyFile: "/key.pem"},
|
||||
{AttestedTls: true},
|
||||
}
|
||||
|
||||
for i, cfg := range configs {
|
||||
result := AgentConfigToProxyConfig(cfg)
|
||||
assert.Equal(t, "7002", result.Port, "Port should always be 7002 for config %d", i)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,223 @@
|
||||
// Copyright (c) Ultraviolet
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
package ingress
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httputil"
|
||||
"net/url"
|
||||
"sync"
|
||||
|
||||
"github.com/ultravioletrs/cocos/pkg/atls"
|
||||
"github.com/ultravioletrs/cocos/pkg/server"
|
||||
"golang.org/x/net/http2"
|
||||
"golang.org/x/net/http2/h2c"
|
||||
)
|
||||
|
||||
// ProxyConfig contains configuration for starting a proxy instance.
|
||||
type ProxyConfig struct {
|
||||
Port string
|
||||
CertFile string
|
||||
KeyFile string
|
||||
ServerCAFile string
|
||||
ClientCAFile string
|
||||
AttestedTLS bool
|
||||
}
|
||||
|
||||
// ProxyContext provides context information for logging and tracking.
|
||||
type ProxyContext struct {
|
||||
ID string
|
||||
Name string
|
||||
}
|
||||
|
||||
// ProxyServer manages ingress proxy instances.
|
||||
type ProxyServer interface {
|
||||
Start(cfg ProxyConfig, ctx ProxyContext) error
|
||||
Stop() error
|
||||
}
|
||||
|
||||
type proxyServer struct {
|
||||
mu sync.RWMutex
|
||||
logger *slog.Logger
|
||||
backendURL *url.URL
|
||||
certProvider atls.CertificateProvider
|
||||
httpServer *http.Server
|
||||
started bool
|
||||
stopped bool
|
||||
}
|
||||
|
||||
// NewProxyServer creates a new ingress proxy server manager.
|
||||
func NewProxyServer(logger *slog.Logger, backendURL *url.URL, certProvider atls.CertificateProvider) ProxyServer {
|
||||
return &proxyServer{
|
||||
logger: logger,
|
||||
backendURL: backendURL,
|
||||
certProvider: certProvider,
|
||||
}
|
||||
}
|
||||
|
||||
// Start starts the proxy server with the given configuration.
|
||||
func (p *proxyServer) Start(cfg ProxyConfig, ctx ProxyContext) error {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
|
||||
if p.started {
|
||||
return fmt.Errorf("proxy server already started")
|
||||
}
|
||||
if p.stopped {
|
||||
return fmt.Errorf("proxy server already stopped")
|
||||
}
|
||||
|
||||
if cfg.Port == "" {
|
||||
cfg.Port = "7002"
|
||||
}
|
||||
|
||||
addr := fmt.Sprintf("0.0.0.0:%s", cfg.Port)
|
||||
|
||||
// Configure Reverse Proxy
|
||||
var rp *httputil.ReverseProxy
|
||||
|
||||
// Check if backend is Unix socket or TCP
|
||||
if p.backendURL.Scheme == "unix" {
|
||||
// For Unix socket backend, we need to manually configure the reverse proxy
|
||||
// because NewSingleHostReverseProxy doesn't support unix:// scheme
|
||||
targetURL := &url.URL{
|
||||
Scheme: "http",
|
||||
Host: "unix",
|
||||
}
|
||||
rp = httputil.NewSingleHostReverseProxy(targetURL)
|
||||
|
||||
// Override the Director to not modify the request
|
||||
originalDirector := rp.Director
|
||||
rp.Director = func(req *http.Request) {
|
||||
originalDirector(req)
|
||||
// Set the URL to point to the backend service
|
||||
req.URL.Scheme = "http"
|
||||
req.URL.Host = "unix"
|
||||
}
|
||||
|
||||
// Configure Transport for Unix socket with HTTP/2
|
||||
rp.Transport = &http2.Transport{
|
||||
AllowHTTP: true,
|
||||
DialTLSContext: func(ctx context.Context, network, addr string, cfg *tls.Config) (net.Conn, error) {
|
||||
var d net.Dialer
|
||||
// Use Unix socket path from URL
|
||||
return d.DialContext(ctx, "unix", p.backendURL.Path)
|
||||
},
|
||||
}
|
||||
} else {
|
||||
// TCP backend
|
||||
rp = httputil.NewSingleHostReverseProxy(p.backendURL)
|
||||
rp.Transport = &http2.Transport{
|
||||
AllowHTTP: true,
|
||||
DialTLSContext: func(ctx context.Context, network, addr string, cfg *tls.Config) (net.Conn, error) {
|
||||
var d net.Dialer
|
||||
return d.DialContext(ctx, network, addr)
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// Wrap handler with h2c for HTTP/2 cleartext support (required for gRPC without TLS)
|
||||
h2cHandler := h2c.NewHandler(rp, &http2.Server{})
|
||||
|
||||
p.httpServer = &http.Server{
|
||||
Addr: addr,
|
||||
Handler: h2cHandler,
|
||||
}
|
||||
|
||||
// Configure TLS
|
||||
var tlsConfig *tls.Config
|
||||
contextDesc := fmt.Sprintf("context %s", ctx.ID)
|
||||
if ctx.Name != "" {
|
||||
contextDesc = fmt.Sprintf("%s (%s)", ctx.Name, ctx.ID)
|
||||
}
|
||||
|
||||
if cfg.AttestedTLS {
|
||||
if p.certProvider == nil {
|
||||
return fmt.Errorf("attested TLS requested but no certificate provider available")
|
||||
}
|
||||
tlsConfig = &tls.Config{
|
||||
GetCertificate: p.certProvider.GetCertificate,
|
||||
ClientAuth: tls.NoClientCert,
|
||||
NextProtos: []string{"h2", "http/1.1"},
|
||||
}
|
||||
|
||||
mtls, err := server.ConfigureCertificateAuthorities(tlsConfig, cfg.ServerCAFile, cfg.ClientCAFile)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to configure certificate authorities: %w", err)
|
||||
}
|
||||
|
||||
if mtls {
|
||||
tlsConfig.ClientAuth = tls.RequireAndVerifyClientCert
|
||||
p.logger.Info(fmt.Sprintf("ingress-proxy listening at %s with Attested mTLS for %s", addr, contextDesc))
|
||||
} else {
|
||||
p.logger.Info(fmt.Sprintf("ingress-proxy listening at %s with Attested TLS for %s", addr, contextDesc))
|
||||
}
|
||||
} else if cfg.CertFile != "" && cfg.KeyFile != "" {
|
||||
// Regular TLS
|
||||
tlsSetup, err := server.SetupRegularTLS(cfg.CertFile, cfg.KeyFile, cfg.ServerCAFile, cfg.ClientCAFile)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to setup TLS: %w", err)
|
||||
}
|
||||
tlsConfig = tlsSetup.Config
|
||||
tlsConfig.NextProtos = []string{"h2", "http/1.1"}
|
||||
|
||||
if tlsSetup.MTLS {
|
||||
p.logger.Info(fmt.Sprintf("ingress-proxy listening at %s with mTLS for %s", addr, contextDesc))
|
||||
} else {
|
||||
p.logger.Info(fmt.Sprintf("ingress-proxy listening at %s with TLS for %s", addr, contextDesc))
|
||||
}
|
||||
} else {
|
||||
p.logger.Info(fmt.Sprintf("ingress-proxy listening at %s without TLS for %s", addr, contextDesc))
|
||||
}
|
||||
|
||||
p.started = true
|
||||
|
||||
// Start server in goroutine
|
||||
go func() {
|
||||
var err error
|
||||
if tlsConfig != nil {
|
||||
ln, listenErr := net.Listen("tcp", addr)
|
||||
if listenErr != nil {
|
||||
p.logger.Error(fmt.Sprintf("failed to listen: %s", listenErr))
|
||||
return
|
||||
}
|
||||
tlsLn := tls.NewListener(ln, tlsConfig)
|
||||
err = p.httpServer.Serve(tlsLn)
|
||||
} else {
|
||||
err = p.httpServer.ListenAndServe()
|
||||
}
|
||||
|
||||
if err != nil && err != http.ErrServerClosed {
|
||||
p.logger.Error(fmt.Sprintf("ingress-proxy server error: %s", err))
|
||||
}
|
||||
}()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Stop stops the proxy server.
|
||||
func (p *proxyServer) Stop() error {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
|
||||
if p.stopped {
|
||||
return nil
|
||||
}
|
||||
p.stopped = true
|
||||
|
||||
if p.httpServer != nil {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*1000000000) // 5 seconds
|
||||
defer cancel()
|
||||
if err := p.httpServer.Shutdown(ctx); err != nil {
|
||||
return fmt.Errorf("failed to shutdown server: %w", err)
|
||||
}
|
||||
p.logger.Info("ingress-proxy stopped")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,477 @@
|
||||
// Copyright (c) Ultraviolet
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
package ingress
|
||||
|
||||
import (
|
||||
"crypto/ecdsa"
|
||||
"crypto/elliptic"
|
||||
"crypto/rand"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"crypto/x509/pkix"
|
||||
"encoding/pem"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"math/big"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/ultravioletrs/cocos/pkg/atls/mocks"
|
||||
"golang.org/x/net/http2"
|
||||
"golang.org/x/net/http2/h2c"
|
||||
)
|
||||
|
||||
func createTempCert(t *testing.T) (certFile, keyFile string) {
|
||||
t.Helper()
|
||||
|
||||
priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||
require.NoError(t, err)
|
||||
|
||||
template := x509.Certificate{
|
||||
SerialNumber: big.NewInt(1),
|
||||
Subject: pkix.Name{
|
||||
Organization: []string{"Test Org"},
|
||||
},
|
||||
NotBefore: time.Now(),
|
||||
NotAfter: time.Now().Add(time.Hour),
|
||||
|
||||
KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature,
|
||||
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
|
||||
BasicConstraintsValid: true,
|
||||
IPAddresses: []net.IP{net.ParseIP("127.0.0.1")},
|
||||
}
|
||||
|
||||
derBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, &priv.PublicKey, priv)
|
||||
require.NoError(t, err)
|
||||
|
||||
tmpDir := t.TempDir()
|
||||
certPath := filepath.Join(tmpDir, "server.crt")
|
||||
keyPath := filepath.Join(tmpDir, "server.key")
|
||||
|
||||
certOut, err := os.Create(certPath)
|
||||
require.NoError(t, err)
|
||||
err = pem.Encode(certOut, &pem.Block{Type: "CERTIFICATE", Bytes: derBytes})
|
||||
require.NoError(t, err)
|
||||
certOut.Close()
|
||||
|
||||
keyOut, err := os.Create(keyPath)
|
||||
require.NoError(t, err)
|
||||
b, err := x509.MarshalECPrivateKey(priv)
|
||||
require.NoError(t, err)
|
||||
err = pem.Encode(keyOut, &pem.Block{Type: "EC PRIVATE KEY", Bytes: b})
|
||||
require.NoError(t, err)
|
||||
keyOut.Close()
|
||||
|
||||
return certPath, keyPath
|
||||
}
|
||||
|
||||
func getBackendURL() *url.URL {
|
||||
u, _ := url.Parse("http://localhost:8080")
|
||||
return u
|
||||
}
|
||||
|
||||
// TestNewProxyServer tests the creation of a new proxy server.
|
||||
func TestNewProxyServer(t *testing.T) {
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
|
||||
ps := NewProxyServer(logger, getBackendURL(), nil)
|
||||
require.NotNil(t, ps)
|
||||
}
|
||||
|
||||
// TestProxyStartStop tests basic start and stop operations.
|
||||
func TestProxyStartStop(t *testing.T) {
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
|
||||
ps := NewProxyServer(logger, getBackendURL(), nil)
|
||||
|
||||
listener, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
require.NoError(t, err)
|
||||
port := listener.Addr().(*net.TCPAddr).Port
|
||||
listener.Close()
|
||||
|
||||
cfg := ProxyConfig{Port: fmt.Sprintf("%d", port)}
|
||||
ctx := ProxyContext{ID: "test-1", Name: "test-proxy"}
|
||||
|
||||
err = ps.Start(cfg, ctx)
|
||||
require.NoError(t, err)
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
err = ps.Stop()
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
// TestProxyStartWithoutPort tests proxy without explicit port.
|
||||
func TestProxyStartWithoutPort(t *testing.T) {
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
|
||||
ps := NewProxyServer(logger, getBackendURL(), nil)
|
||||
|
||||
cfg := ProxyConfig{Port: ""}
|
||||
ctx := ProxyContext{ID: "test-2"}
|
||||
|
||||
err := ps.Start(cfg, ctx)
|
||||
require.NoError(t, err)
|
||||
defer func() { _ = ps.Stop() }()
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
}
|
||||
|
||||
// TestProxyStartAlreadyStarted tests error when starting twice.
|
||||
func TestProxyStartAlreadyStarted(t *testing.T) {
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
|
||||
ps := NewProxyServer(logger, getBackendURL(), nil)
|
||||
|
||||
listener, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
require.NoError(t, err)
|
||||
port := listener.Addr().(*net.TCPAddr).Port
|
||||
listener.Close()
|
||||
|
||||
cfg := ProxyConfig{Port: fmt.Sprintf("%d", port)}
|
||||
ctx := ProxyContext{ID: "test-3"}
|
||||
|
||||
err = ps.Start(cfg, ctx)
|
||||
require.NoError(t, err)
|
||||
defer func() { _ = ps.Stop() }()
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
err = ps.Start(cfg, ctx)
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, "proxy server already started", err.Error())
|
||||
}
|
||||
|
||||
// TestProxyStartAfterStopped tests error when starting after stop.
|
||||
func TestProxyStartAfterStopped(t *testing.T) {
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
|
||||
ps := NewProxyServer(logger, getBackendURL(), nil)
|
||||
|
||||
listener, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
require.NoError(t, err)
|
||||
port := listener.Addr().(*net.TCPAddr).Port
|
||||
listener.Close()
|
||||
|
||||
cfg := ProxyConfig{Port: fmt.Sprintf("%d", port)}
|
||||
ctx := ProxyContext{ID: "test-4"}
|
||||
|
||||
err = ps.Start(cfg, ctx)
|
||||
require.NoError(t, err)
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
err = ps.Stop()
|
||||
require.NoError(t, err)
|
||||
|
||||
err = ps.Start(cfg, ctx)
|
||||
assert.Error(t, err)
|
||||
// After stop, attempts to start will fail with "already started" error first
|
||||
assert.Contains(t, err.Error(), "proxy server already")
|
||||
}
|
||||
|
||||
// TestProxyWithName tests proxy context with name.
|
||||
func TestProxyWithName(t *testing.T) {
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
|
||||
ps := NewProxyServer(logger, getBackendURL(), nil)
|
||||
|
||||
listener, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
require.NoError(t, err)
|
||||
port := listener.Addr().(*net.TCPAddr).Port
|
||||
listener.Close()
|
||||
|
||||
cfg := ProxyConfig{Port: fmt.Sprintf("%d", port)}
|
||||
ctx := ProxyContext{ID: "id-1", Name: "named-proxy"}
|
||||
|
||||
err = ps.Start(cfg, ctx)
|
||||
require.NoError(t, err)
|
||||
defer func() { _ = ps.Stop() }()
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
}
|
||||
|
||||
// TestProxyWithoutName tests proxy context without name.
|
||||
func TestProxyWithoutName(t *testing.T) {
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
|
||||
ps := NewProxyServer(logger, getBackendURL(), nil)
|
||||
|
||||
listener, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
require.NoError(t, err)
|
||||
port := listener.Addr().(*net.TCPAddr).Port
|
||||
listener.Close()
|
||||
|
||||
cfg := ProxyConfig{Port: fmt.Sprintf("%d", port)}
|
||||
ctx := ProxyContext{ID: "id-only"}
|
||||
|
||||
err = ps.Start(cfg, ctx)
|
||||
require.NoError(t, err)
|
||||
defer func() { _ = ps.Stop() }()
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
}
|
||||
|
||||
// TestProxyMultipleStops tests multiple stop calls.
|
||||
func TestProxyMultipleStops(t *testing.T) {
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
|
||||
ps := NewProxyServer(logger, getBackendURL(), nil)
|
||||
|
||||
listener, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
require.NoError(t, err)
|
||||
port := listener.Addr().(*net.TCPAddr).Port
|
||||
listener.Close()
|
||||
|
||||
cfg := ProxyConfig{Port: fmt.Sprintf("%d", port)}
|
||||
ctx := ProxyContext{ID: "test-multi-stop"}
|
||||
|
||||
err = ps.Start(cfg, ctx)
|
||||
require.NoError(t, err)
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
err = ps.Stop()
|
||||
require.NoError(t, err)
|
||||
|
||||
err = ps.Stop()
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
// TestProxyWithoutTLS tests proxy without TLS.
|
||||
func TestProxyWithoutTLS(t *testing.T) {
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
|
||||
ps := NewProxyServer(logger, getBackendURL(), nil)
|
||||
|
||||
listener, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
require.NoError(t, err)
|
||||
port := listener.Addr().(*net.TCPAddr).Port
|
||||
listener.Close()
|
||||
|
||||
cfg := ProxyConfig{
|
||||
Port: fmt.Sprintf("%d", port),
|
||||
AttestedTLS: false,
|
||||
}
|
||||
ctx := ProxyContext{ID: "test-no-tls"}
|
||||
|
||||
err = ps.Start(cfg, ctx)
|
||||
require.NoError(t, err)
|
||||
defer func() { _ = ps.Stop() }()
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
}
|
||||
|
||||
func TestProxyWithUnixBackend(t *testing.T) {
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
|
||||
|
||||
// Create a temp directory for socket
|
||||
dir := t.TempDir()
|
||||
sockPath := filepath.Join(dir, "backend.sock")
|
||||
|
||||
// Start a dummy backend on unix socket
|
||||
l, err := net.Listen("unix", sockPath)
|
||||
require.NoError(t, err)
|
||||
defer l.Close()
|
||||
|
||||
backendCalled := false
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
backendCalled = true
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
|
||||
h2s := &http2.Server{}
|
||||
h2cHandler := h2c.NewHandler(handler, h2s)
|
||||
|
||||
go func() {
|
||||
_ = http.Serve(l, h2cHandler)
|
||||
}()
|
||||
|
||||
// Configure proxy to use this unix socket
|
||||
backendURL, _ := url.Parse("unix://" + sockPath)
|
||||
ps := NewProxyServer(logger, backendURL, nil)
|
||||
|
||||
// Find free port for proxy
|
||||
listener, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
require.NoError(t, err)
|
||||
port := listener.Addr().(*net.TCPAddr).Port
|
||||
listener.Close()
|
||||
|
||||
cfg := ProxyConfig{
|
||||
Port: fmt.Sprintf("%d", port),
|
||||
}
|
||||
ctx := ProxyContext{ID: "test-unix"}
|
||||
|
||||
err = ps.Start(cfg, ctx)
|
||||
require.NoError(t, err)
|
||||
defer func() { _ = ps.Stop() }()
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// Make request to proxy
|
||||
resp, err := http.Get(fmt.Sprintf("http://localhost:%d", port))
|
||||
require.NoError(t, err)
|
||||
defer resp.Body.Close()
|
||||
|
||||
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
assert.True(t, backendCalled)
|
||||
}
|
||||
|
||||
func TestProxyRegularTLS(t *testing.T) {
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
|
||||
certFile, keyFile := createTempCert(t)
|
||||
|
||||
ps := NewProxyServer(logger, getBackendURL(), nil)
|
||||
|
||||
listener, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
require.NoError(t, err)
|
||||
port := listener.Addr().(*net.TCPAddr).Port
|
||||
listener.Close()
|
||||
|
||||
cfg := ProxyConfig{
|
||||
Port: fmt.Sprintf("%d", port),
|
||||
CertFile: certFile,
|
||||
KeyFile: keyFile,
|
||||
}
|
||||
ctx := ProxyContext{ID: "test-tls"}
|
||||
|
||||
err = ps.Start(cfg, ctx)
|
||||
require.NoError(t, err)
|
||||
defer func() { _ = ps.Stop() }()
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// Client with skipped verification
|
||||
tr := &http.Transport{
|
||||
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
|
||||
}
|
||||
client := &http.Client{Transport: tr}
|
||||
|
||||
resp, err := client.Get(fmt.Sprintf("https://localhost:%d", port))
|
||||
// Backend is not running/reachable so 502 or error is expected from reverse proxy,
|
||||
// but 502 means connection to proxy succeeded.
|
||||
// If the proxy itself was not working, we'd get connection refused or similar.
|
||||
require.NoError(t, err)
|
||||
if err == nil {
|
||||
resp.Body.Close()
|
||||
}
|
||||
}
|
||||
|
||||
func TestProxyRegularTLSInvalidFiles(t *testing.T) {
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
|
||||
ps := NewProxyServer(logger, getBackendURL(), nil)
|
||||
|
||||
listener, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
require.NoError(t, err)
|
||||
port := listener.Addr().(*net.TCPAddr).Port
|
||||
listener.Close()
|
||||
|
||||
cfg := ProxyConfig{
|
||||
Port: fmt.Sprintf("%d", port),
|
||||
CertFile: "non-existent.crt",
|
||||
KeyFile: "non-existent.key",
|
||||
}
|
||||
ctx := ProxyContext{ID: "test-tls-fail"}
|
||||
|
||||
err = ps.Start(cfg, ctx)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "failed to setup TLS")
|
||||
}
|
||||
|
||||
func TestProxyAttestedTLS(t *testing.T) {
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
|
||||
mockProvider := mocks.NewCertificateProvider(t)
|
||||
// We don't expect calls during Listen, only during handshake.
|
||||
// But Start logic doesn't block waiting for handshake.
|
||||
|
||||
ps := NewProxyServer(logger, getBackendURL(), mockProvider)
|
||||
|
||||
listener, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
require.NoError(t, err)
|
||||
port := listener.Addr().(*net.TCPAddr).Port
|
||||
listener.Close()
|
||||
|
||||
cfg := ProxyConfig{
|
||||
Port: fmt.Sprintf("%d", port),
|
||||
AttestedTLS: true,
|
||||
}
|
||||
ctx := ProxyContext{ID: "test-attested-tls"}
|
||||
|
||||
err = ps.Start(cfg, ctx)
|
||||
require.NoError(t, err)
|
||||
defer func() { _ = ps.Stop() }()
|
||||
}
|
||||
|
||||
func TestProxyAttestedTLSMissingProvider(t *testing.T) {
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
|
||||
ps := NewProxyServer(logger, getBackendURL(), nil)
|
||||
|
||||
cfg := ProxyConfig{
|
||||
Port: "0",
|
||||
AttestedTLS: true,
|
||||
}
|
||||
ctx := ProxyContext{ID: "test-attested-fail"}
|
||||
|
||||
err := ps.Start(cfg, ctx)
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, "attested TLS requested but no certificate provider available", err.Error())
|
||||
}
|
||||
|
||||
func TestProxyMTLS(t *testing.T) {
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
|
||||
certFile, _ := createTempCert(t)
|
||||
|
||||
mockProvider := mocks.NewCertificateProvider(t)
|
||||
|
||||
ps := NewProxyServer(logger, getBackendURL(), mockProvider)
|
||||
|
||||
listener, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
require.NoError(t, err)
|
||||
port := listener.Addr().(*net.TCPAddr).Port
|
||||
listener.Close()
|
||||
|
||||
// Test case: AttestedTLS with ClientCAFile (mTLS)
|
||||
// server.ConfigureCertificateAuthorities reads the file.
|
||||
cfg := ProxyConfig{
|
||||
Port: fmt.Sprintf("%d", port),
|
||||
AttestedTLS: true,
|
||||
ClientCAFile: certFile, // Use self-signed cert as CA
|
||||
ServerCAFile: certFile, // Also for server CA
|
||||
}
|
||||
ctx := ProxyContext{ID: "test-mtls"}
|
||||
|
||||
err = ps.Start(cfg, ctx)
|
||||
require.NoError(t, err)
|
||||
defer func() { _ = ps.Stop() }()
|
||||
}
|
||||
|
||||
func TestProxyRegularMTLS(t *testing.T) {
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
|
||||
certFile, keyFile := createTempCert(t)
|
||||
|
||||
ps := NewProxyServer(logger, getBackendURL(), nil)
|
||||
|
||||
listener, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
require.NoError(t, err)
|
||||
port := listener.Addr().(*net.TCPAddr).Port
|
||||
listener.Close()
|
||||
|
||||
cfg := ProxyConfig{
|
||||
Port: fmt.Sprintf("%d", port),
|
||||
CertFile: certFile,
|
||||
KeyFile: keyFile,
|
||||
ServerCAFile: certFile,
|
||||
ClientCAFile: certFile,
|
||||
}
|
||||
ctx := ProxyContext{ID: "test-regular-mtls"}
|
||||
|
||||
err = ps.Start(cfg, ctx)
|
||||
require.NoError(t, err)
|
||||
defer func() { _ = ps.Stop() }()
|
||||
}
|
||||
|
||||
func TestProxyAttestedTLSInvalidCA(t *testing.T) {
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
|
||||
mockProvider := mocks.NewCertificateProvider(t)
|
||||
|
||||
ps := NewProxyServer(logger, getBackendURL(), mockProvider)
|
||||
|
||||
cfg := ProxyConfig{
|
||||
Port: "0",
|
||||
AttestedTLS: true,
|
||||
ServerCAFile: "non-existent.pem",
|
||||
}
|
||||
ctx := ProxyContext{ID: "test-attested-invalid-ca"}
|
||||
|
||||
err := ps.Start(cfg, ctx)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "failed to configure certificate authorities")
|
||||
}
|
||||
+27
-5
@@ -9,6 +9,7 @@ import (
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"net"
|
||||
"os"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
@@ -111,10 +112,26 @@ func (s *Server) Start() error {
|
||||
|
||||
grpcServerOptions = append(grpcServerOptions, creds)
|
||||
|
||||
// Create listener
|
||||
listener, err := net.Listen("tcp", s.Address)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to listen on port %s: %w", s.Address, err)
|
||||
// Create listener - detect Unix socket vs TCP
|
||||
var listener net.Listener
|
||||
baseConfig := s.Config.GetBaseConfig()
|
||||
|
||||
// Check if this is a Unix socket path (starts with /)
|
||||
if len(baseConfig.Host) > 0 && baseConfig.Host[0] == '/' {
|
||||
// Unix socket
|
||||
// Remove existing socket file if it exists
|
||||
_ = os.Remove(baseConfig.Host)
|
||||
|
||||
listener, err = net.Listen("unix", baseConfig.Host)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to listen on Unix socket %s: %w", baseConfig.Host, err)
|
||||
}
|
||||
} else {
|
||||
// TCP socket
|
||||
listener, err = net.Listen("tcp", s.Address)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to listen on port %s: %w", s.Address, err)
|
||||
}
|
||||
}
|
||||
|
||||
// Create and configure server
|
||||
@@ -160,7 +177,12 @@ func (s *Server) configureCredentials() (grpc.ServerOption, error) {
|
||||
}
|
||||
|
||||
// Use insecure credentials
|
||||
s.Logger.Info(fmt.Sprintf("%s service gRPC server listening at %s without TLS", s.Name, s.Address))
|
||||
// Determine address for logging
|
||||
addr := s.Address
|
||||
if len(baseConfig.Host) > 0 && baseConfig.Host[0] == '/' {
|
||||
addr = baseConfig.Host
|
||||
}
|
||||
s.Logger.Info(fmt.Sprintf("%s service gRPC server listening at %s without TLS", s.Name, addr))
|
||||
return grpc.Creds(insecure.NewCredentials()), nil
|
||||
}
|
||||
|
||||
|
||||
@@ -78,6 +78,7 @@ func (s *svc) Run(ctx context.Context, ipAddress string, sendMessage cvmsgrpc.Se
|
||||
return
|
||||
}
|
||||
|
||||
s.logger.Debug("sending computation run request")
|
||||
if err := sendMessage(&cvms.ServerStreamMessage{
|
||||
Message: &cvms.ServerStreamMessage_RunReq{
|
||||
RunReq: &cvms.ComputationRunReq{
|
||||
@@ -98,6 +99,11 @@ func (s *svc) Run(ctx context.Context, ipAddress string, sendMessage cvmsgrpc.Se
|
||||
s.logger.Error(fmt.Sprintf("failed to send run request: %s", err))
|
||||
return
|
||||
}
|
||||
s.logger.Info("computation run request sent successfully")
|
||||
|
||||
// Keep the connection alive
|
||||
<-ctx.Done()
|
||||
s.logger.Info("connection closed")
|
||||
}
|
||||
|
||||
func main() {
|
||||
|
||||
Reference in New Issue
Block a user