NOISSUE - Introduce computation runner, log forwarder, ingress, and egress proxy services. (#559)

* feat: Introduce computation runner, log forwarder, ingress, and egress proxy services.

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* feat: Update Go environment variable parsing and build system to use new architecture and repository.

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* feat: Update package sources to `sammyoina/cocos-ai` at a specific commit, add log-forwarder pre-start hook, and rename proxy binaries.

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* chore: Update build system references to a specific commit and enhance logging for service connections and message processing.

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* build: Update package source repositories and versions, migrate client logging to slog, and adjust ingress/egress proxy build and install steps.

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* debug stuck

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* debug

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* debug

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* feat: add HTTP/2 support to egress proxy and update build system to use specific commit hashes

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* feat: enhance egress proxy CONNECT handling, update package sources, and add gRPC test utility

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* feat: Update build system for various services to a specific commit from a new repository, change agent gRPC port to 7001, and add a gRPC test client.

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* feat: Migrate agent-internal gRPC communication to Unix sockets, set ingress proxy to port 7002, and update build hashes.

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* refactor: Remove standalone ingress-proxy systemd service and update component versions.

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* fix: Prevent computation re-initialization in agent and update component versions across several packages.

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* feat: update package versions and enable h2c support in ingress proxy.

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* feat: refactor ingress proxy to support HTTP/2 over Unix sockets and update component versions.

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* feat: Update build system package sources to `ultravioletrs/cocos` and reduce agent logging verbosity.

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* refactor: improve error handling in proxy commands and remove unused gRPC test

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* test: add mock service state return value in handleRunReqChunks test

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* feat: add comprehensive tests for service and proxy components

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* fix linter

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* improve coverage

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* test: add gRPC client and ingress adapter tests, and update egress proxy tests.

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* improve coverage

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

---------

Signed-off-by: Sammy Oina <sammyoina@gmail.com>
This commit is contained in:
Sammy Kerata Oina
2026-02-09 12:38:21 +03:00
committed by GitHub
parent ee52551ca4
commit a3265bc346
57 changed files with 6529 additions and 162 deletions
-1
View File
@@ -13,7 +13,6 @@ import (
var _ fmt.Stringer = (*Datasets)(nil)
type AgentConfig struct {
Port string `json:"port,omitempty"`
CertFile string `json:"cert_file,omitempty"`
KeyFile string `json:"server_key,omitempty"`
ServerCAFile string `json:"server_ca_file,omitempty"`
+6 -7
View File
@@ -105,16 +105,15 @@ func TestDecompressToContext(t *testing.T) {
}
func TestAgentConfigJSON(t *testing.T) {
config := AgentConfig{
Port: "8080",
cfg := AgentConfig{
CertFile: "cert.pem",
KeyFile: "key.pem",
ServerCAFile: "server_ca.pem",
ClientCAFile: "client_ca.pem",
ServerCAFile: "server-ca.pem",
ClientCAFile: "client-ca.pem",
AttestedTls: true,
}
data, err := json.Marshal(config)
data, err := json.Marshal(cfg)
if err != nil {
t.Fatalf("Failed to marshal AgentConfig: %v", err)
}
@@ -125,7 +124,7 @@ func TestAgentConfigJSON(t *testing.T) {
t.Fatalf("Failed to unmarshal AgentConfig: %v", err)
}
if !reflect.DeepEqual(config, unmarshaledConfig) {
t.Errorf("Unmarshaled config does not match original. Got %+v, want %+v", unmarshaledConfig, config)
if !reflect.DeepEqual(cfg, unmarshaledConfig) {
t.Errorf("Unmarshaled config does not match original. Got %+v, want %+v", unmarshaledConfig, cfg)
}
}
+39 -2
View File
@@ -5,6 +5,7 @@ package grpc
import (
"context"
"encoding/json"
"fmt"
"log/slog"
"sync"
"time"
@@ -17,6 +18,7 @@ import (
"github.com/ultravioletrs/cocos/pkg/attestation"
"github.com/ultravioletrs/cocos/pkg/attestation/vtpm"
"github.com/ultravioletrs/cocos/pkg/clients/grpc"
"github.com/ultravioletrs/cocos/pkg/ingress"
"golang.org/x/sync/errgroup"
"google.golang.org/protobuf/proto"
)
@@ -44,13 +46,14 @@ type CVMSClient struct {
logger *slog.Logger
runReqManager *runRequestManager
sp server.AgentServer
ingressProxy ingress.ProxyServer
storage storage.Storage
reconnectFn func(context.Context) (grpc.Client, cvms.Service_ProcessClient, error)
grpcClient grpc.Client
}
// NewClient returns new gRPC client instance.
func NewClient(stream cvms.Service_ProcessClient, svc agent.Service, messageQueue chan *cvms.ClientStreamMessage, logger *slog.Logger, sp server.AgentServer, storageDir string, reconnectFn func(context.Context) (grpc.Client, cvms.Service_ProcessClient, error), grpcClient grpc.Client) (*CVMSClient, error) {
func NewClient(stream cvms.Service_ProcessClient, svc agent.Service, messageQueue chan *cvms.ClientStreamMessage, logger *slog.Logger, sp server.AgentServer, ingressProxy ingress.ProxyServer, storageDir string, reconnectFn func(context.Context) (grpc.Client, cvms.Service_ProcessClient, error), grpcClient grpc.Client) (*CVMSClient, error) {
store, err := storage.NewFileStorage(storageDir)
if err != nil {
return nil, err
@@ -63,6 +66,7 @@ func NewClient(stream cvms.Service_ProcessClient, svc agent.Service, messageQueu
logger: logger,
runReqManager: newRunRequestManager(),
sp: sp,
ingressProxy: ingressProxy,
storage: store,
reconnectFn: reconnectFn,
grpcClient: grpcClient,
@@ -205,14 +209,17 @@ func (client *CVMSClient) handleAgentStateReq(mes *cvms.ServerStreamMessage_Agen
}
func (client *CVMSClient) handleRunReqChunks(ctx context.Context, msg *cvms.ServerStreamMessage_RunReqChunks) error {
client.logger.Debug("Received RunReq chunk", "id", msg.RunReqChunks.Id, "size", len(msg.RunReqChunks.Data), "isLast", msg.RunReqChunks.IsLast)
buffer, complete := client.runReqManager.addChunk(msg.RunReqChunks.Id, msg.RunReqChunks.Data, msg.RunReqChunks.IsLast)
if complete {
client.logger.Info("Received complete computation run request", "id", msg.RunReqChunks.Id, "totalSize", len(buffer))
var runReq cvms.ComputationRunReq
if err := proto.Unmarshal(buffer, &runReq); err != nil {
return errors.Wrap(err, errCorruptedManifest)
}
client.logger.Info("Starting computation execution", "computationId", runReq.Id, "name", runReq.Name)
go client.executeRun(ctx, &runReq)
}
@@ -246,6 +253,15 @@ func (client *CVMSClient) executeRun(ctx context.Context, runReq *cvms.Computati
})
}
// Check if the agent is in the correct state to initialize a new computation.
// If the agent is already processing this computation (e.g., after a reconnection),
// skip initialization to avoid state errors.
currentState := client.svc.State()
if currentState != "ReceivingManifest" {
client.logger.Info("Agent already processing computation, skipping initialization", "state", currentState, "computationId", runReq.Id)
return
}
if err := client.svc.InitComputation(ctx, ac); err != nil {
client.logger.Warn(err.Error())
return
@@ -267,7 +283,6 @@ func (client *CVMSClient) executeRun(ctx context.Context, runReq *cvms.Computati
}
if err := client.sp.Start(agent.AgentConfig{
Port: runReq.AgentConfig.Port,
CertFile: runReq.AgentConfig.CertFile,
KeyFile: runReq.AgentConfig.KeyFile,
ServerCAFile: runReq.AgentConfig.ServerCaFile,
@@ -278,6 +293,22 @@ func (client *CVMSClient) executeRun(ctx context.Context, runReq *cvms.Computati
runRes.RunRes.Error = err.Error()
}
// Start ingress proxy if available
if client.ingressProxy != nil {
if err := client.ingressProxy.Start(
ingress.AgentConfigToProxyConfig(agent.AgentConfig{
CertFile: runReq.AgentConfig.CertFile,
KeyFile: runReq.AgentConfig.KeyFile,
ServerCAFile: runReq.AgentConfig.ServerCaFile,
ClientCAFile: runReq.AgentConfig.ClientCaFile,
AttestedTls: runReq.AgentConfig.AttestedTls,
}),
ingress.ComputationToProxyContext(ac),
); err != nil {
client.logger.Warn(fmt.Sprintf("failed to start ingress proxy: %s", err.Error()))
}
}
defer func() {
if ccPlatform == attestation.Azure || ccPlatform == attestation.SNPvTPM {
cmpJson, err := json.Marshal(ac)
@@ -309,6 +340,12 @@ func (client *CVMSClient) handleStopComputation(ctx context.Context, mes *cvms.S
if err := client.sp.Stop(); err != nil {
msg.StopComputationRes.Message = err.Error()
}
// Stop ingress proxy if available
if client.ingressProxy != nil {
if err := client.ingressProxy.Stop(); err != nil {
client.logger.Warn(fmt.Sprintf("failed to stop ingress proxy: %s", err.Error()))
}
}
client.mu.Unlock()
client.sendMessage(&cvms.ClientStreamMessage{Message: msg})
+261 -3
View File
@@ -11,10 +11,12 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"github.com/ultravioletrs/cocos/agent/cvms"
"github.com/ultravioletrs/cocos/agent/cvms/api/grpc/storage"
servermocks "github.com/ultravioletrs/cocos/agent/cvms/server/mocks"
"github.com/ultravioletrs/cocos/agent/mocks"
pkggrpc "github.com/ultravioletrs/cocos/pkg/clients/grpc"
clientmocks "github.com/ultravioletrs/cocos/pkg/clients/grpc/mocks"
"github.com/ultravioletrs/cocos/pkg/ingress"
"golang.org/x/crypto/sha3"
"google.golang.org/grpc"
"google.golang.org/protobuf/proto"
@@ -35,6 +37,21 @@ func (m *mockStream) Send(msg *cvms.ClientStreamMessage) error {
return args.Error(0)
}
// mockIngressProxy is a mock implementation of the ingress proxy.
type mockIngressProxy struct {
mock.Mock
}
func (m *mockIngressProxy) Start(config ingress.ProxyConfig, ctx ingress.ProxyContext) error {
args := m.Called(config, ctx)
return args.Error(0)
}
func (m *mockIngressProxy) Stop() error {
args := m.Called()
return args.Error(0)
}
func TestManagerClient_Process(t *testing.T) {
tests := []struct {
name string
@@ -121,7 +138,7 @@ func TestManagerClient_Process(t *testing.T) {
grpcClient := new(clientmocks.Client)
client, err := NewClient(mockStream, mockSvc, messageQueue, logger, mockServerSvc, t.TempDir(), func(ctx context.Context) (pkggrpc.Client, cvms.Service_ProcessClient, error) { return nil, nil, nil }, grpcClient)
client, err := NewClient(mockStream, mockSvc, messageQueue, logger, mockServerSvc, nil, t.TempDir(), func(ctx context.Context) (pkggrpc.Client, cvms.Service_ProcessClient, error) { return nil, nil, nil }, grpcClient)
assert.NoError(t, err)
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
@@ -151,7 +168,7 @@ func TestManagerClient_handleRunReqChunks(t *testing.T) {
logger := mglog.NewMock()
grpcClient := new(clientmocks.Client)
client, err := NewClient(mockStream, mockSvc, messageQueue, logger, mockServerSvc, t.TempDir(), func(ctx context.Context) (pkggrpc.Client, cvms.Service_ProcessClient, error) { return nil, nil, nil }, grpcClient)
client, err := NewClient(mockStream, mockSvc, messageQueue, logger, mockServerSvc, nil, t.TempDir(), func(ctx context.Context) (pkggrpc.Client, cvms.Service_ProcessClient, error) { return nil, nil, nil }, grpcClient)
assert.NoError(t, err)
runReq := &cvms.ComputationRunReq{
@@ -187,6 +204,7 @@ func TestManagerClient_handleRunReqChunks(t *testing.T) {
},
}
mockSvc.On("State").Return("ReceivingManifest")
mockSvc.On("InitComputation", mock.Anything, mock.Anything).Return(nil)
mockServerSvc.On("Start", mock.Anything, mock.Anything, mock.Anything).Return(nil)
@@ -216,7 +234,7 @@ func TestManagerClient_handleStopComputation(t *testing.T) {
logger := mglog.NewMock()
grpcClient := new(clientmocks.Client)
client, err := NewClient(mockStream, mockSvc, messageQueue, logger, mockServerSvc, t.TempDir(), func(ctx context.Context) (pkggrpc.Client, cvms.Service_ProcessClient, error) { return nil, nil, nil }, grpcClient)
client, err := NewClient(mockStream, mockSvc, messageQueue, logger, mockServerSvc, nil, t.TempDir(), func(ctx context.Context) (pkggrpc.Client, cvms.Service_ProcessClient, error) { return nil, nil, nil }, grpcClient)
assert.NoError(t, err)
stopReq := &cvms.ServerStreamMessage_StopComputation{
@@ -255,3 +273,243 @@ func TestManagerClient_timeoutRequest(t *testing.T) {
assert.Len(t, rm.requests, 0)
}
// TestManagerClient_sendPendingMessages tests sending pending messages on reconnection.
func TestManagerClient_sendPendingMessages(t *testing.T) {
mockStream := new(mockStream)
mockSvc := new(mocks.Service)
mockServerSvc := new(servermocks.AgentServer)
messageQueue := make(chan *cvms.ClientStreamMessage, 10)
logger := mglog.NewMock()
grpcClient := new(clientmocks.Client)
client, err := NewClient(mockStream, mockSvc, messageQueue, logger, mockServerSvc, nil, t.TempDir(), func(ctx context.Context) (pkggrpc.Client, cvms.Service_ProcessClient, error) { return nil, nil, nil }, grpcClient)
assert.NoError(t, err)
// Add a pending message to storage
testMsg := &cvms.ClientStreamMessage{
Message: &cvms.ClientStreamMessage_RunRes{
RunRes: &cvms.RunResponse{
ComputationId: "test-id",
},
},
}
err = client.storage.Add(testMsg)
assert.NoError(t, err)
// Mock successful send
mockStream.On("Send", mock.Anything).Return(nil).Once()
// Load and send pending messages
pending, err := client.storage.Load()
assert.NoError(t, err)
assert.Len(t, pending, 1)
client.sendPendingMessages(pending)
mockStream.AssertExpectations(t)
}
// TestManagerClient_sendPendingMessagesWithError tests pending message send failure.
func TestManagerClient_sendPendingMessagesWithError(t *testing.T) {
mockStream := new(mockStream)
mockSvc := new(mocks.Service)
mockServerSvc := new(servermocks.AgentServer)
messageQueue := make(chan *cvms.ClientStreamMessage, 10)
logger := mglog.NewMock()
grpcClient := new(clientmocks.Client)
client, err := NewClient(mockStream, mockSvc, messageQueue, logger, mockServerSvc, nil, t.TempDir(), func(ctx context.Context) (pkggrpc.Client, cvms.Service_ProcessClient, error) { return nil, nil, nil }, grpcClient)
assert.NoError(t, err)
testMsg := &cvms.ClientStreamMessage{
Message: &cvms.ClientStreamMessage_RunRes{
RunRes: &cvms.RunResponse{
ComputationId: "test-id",
},
},
}
// Mock failed send
mockStream.On("Send", mock.Anything).Return(assert.AnError)
pending := []storage.Message{
{
Message: testMsg,
Time: time.Now(),
},
}
client.sendPendingMessages(pending)
mockStream.AssertExpectations(t)
}
// TestManagerClient_addChunkTimeout tests chunk timeout in runRequestManager.
func TestManagerClient_addChunkTimeout(t *testing.T) {
rm := newRunRequestManager()
// Add first chunk
chunk1 := []byte("chunk1")
buffer, complete := rm.addChunk("test-id", chunk1, false)
assert.Nil(t, buffer)
assert.False(t, complete)
// Verify request exists
rm.mu.Lock()
assert.Contains(t, rm.requests, "test-id")
rm.mu.Unlock()
// Wait for timeout
time.Sleep(35 * time.Second) // runReqTimeout is 30 seconds
// Verify request was removed
rm.mu.Lock()
assert.NotContains(t, rm.requests, "test-id")
rm.mu.Unlock()
}
// TestManagerClient_addChunkMultiple tests adding multiple chunks.
func TestManagerClient_addChunkMultiple(t *testing.T) {
rm := newRunRequestManager()
chunk1 := []byte("chunk1")
chunk2 := []byte("chunk2")
chunk3 := []byte("chunk3")
// Add chunks
buffer, complete := rm.addChunk("test-id", chunk1, false)
assert.Nil(t, buffer)
assert.False(t, complete)
buffer, complete = rm.addChunk("test-id", chunk2, false)
assert.Nil(t, buffer)
assert.False(t, complete)
buffer, complete = rm.addChunk("test-id", chunk3, true)
assert.NotNil(t, buffer)
assert.True(t, complete)
expected := append(append(chunk1, chunk2...), chunk3...)
assert.Equal(t, expected, buffer)
}
// TestManagerClient_handleStopComputationWithIngressProxy tests stop with ingress proxy.
func TestManagerClient_handleStopComputationWithIngressProxy(t *testing.T) {
mockStream := new(mockStream)
mockSvc := new(mocks.Service)
mockServerSvc := new(servermocks.AgentServer)
mockIngressProxy := new(mockIngressProxy)
messageQueue := make(chan *cvms.ClientStreamMessage, 10)
logger := mglog.NewMock()
grpcClient := new(clientmocks.Client)
client, err := NewClient(mockStream, mockSvc, messageQueue, logger, mockServerSvc, mockIngressProxy, t.TempDir(), func(ctx context.Context) (pkggrpc.Client, cvms.Service_ProcessClient, error) { return nil, nil, nil }, grpcClient)
assert.NoError(t, err)
stopReq := &cvms.ServerStreamMessage_StopComputation{
StopComputation: &cvms.StopComputation{
ComputationId: "test-comp-id",
},
}
mockSvc.On("StopComputation", mock.Anything).Return(nil)
mockServerSvc.On("Stop").Return(nil)
mockIngressProxy.On("Stop").Return(nil)
client.handleStopComputation(context.Background(), stopReq)
time.Sleep(50 * time.Millisecond)
mockSvc.AssertExpectations(t)
mockServerSvc.AssertExpectations(t)
mockIngressProxy.AssertExpectations(t)
assert.Len(t, messageQueue, 1)
}
// TestManagerClient_handleStopComputationWithIngressProxyError tests stop with ingress proxy error.
func TestManagerClient_handleStopComputationWithIngressProxyError(t *testing.T) {
mockStream := new(mockStream)
mockSvc := new(mocks.Service)
mockServerSvc := new(servermocks.AgentServer)
mockIngressProxy := new(mockIngressProxy)
messageQueue := make(chan *cvms.ClientStreamMessage, 10)
logger := mglog.NewMock()
grpcClient := new(clientmocks.Client)
client, err := NewClient(mockStream, mockSvc, messageQueue, logger, mockServerSvc, mockIngressProxy, t.TempDir(), func(ctx context.Context) (pkggrpc.Client, cvms.Service_ProcessClient, error) { return nil, nil, nil }, grpcClient)
assert.NoError(t, err)
stopReq := &cvms.ServerStreamMessage_StopComputation{
StopComputation: &cvms.StopComputation{
ComputationId: "test-comp-id",
},
}
mockSvc.On("StopComputation", mock.Anything).Return(nil)
mockServerSvc.On("Stop").Return(nil)
mockIngressProxy.On("Stop").Return(assert.AnError)
client.handleStopComputation(context.Background(), stopReq)
time.Sleep(50 * time.Millisecond)
mockIngressProxy.AssertExpectations(t)
}
// TestManagerClient_sendMessage tests sendMessage with timeout.
func TestManagerClient_sendMessage(t *testing.T) {
mockStream := new(mockStream)
mockSvc := new(mocks.Service)
mockServerSvc := new(servermocks.AgentServer)
messageQueue := make(chan *cvms.ClientStreamMessage, 1)
logger := mglog.NewMock()
grpcClient := new(clientmocks.Client)
client, err := NewClient(mockStream, mockSvc, messageQueue, logger, mockServerSvc, nil, t.TempDir(), func(ctx context.Context) (pkggrpc.Client, cvms.Service_ProcessClient, error) { return nil, nil, nil }, grpcClient)
assert.NoError(t, err)
msg := &cvms.ClientStreamMessage{
Message: &cvms.ClientStreamMessage_RunRes{
RunRes: &cvms.RunResponse{
ComputationId: "test-id",
},
},
}
client.sendMessage(msg)
select {
case received := <-messageQueue:
assert.Equal(t, msg, received)
case <-time.After(1 * time.Second):
t.Fatal("Message not received")
}
}
// TestManagerClient_sendMessageTimeout tests sendMessage timeout when queue is full.
func TestManagerClient_sendMessageTimeout(t *testing.T) {
mockStream := new(mockStream)
mockSvc := new(mocks.Service)
mockServerSvc := new(servermocks.AgentServer)
messageQueue := make(chan *cvms.ClientStreamMessage) // No buffer
logger := mglog.NewMock()
grpcClient := new(clientmocks.Client)
client, err := NewClient(mockStream, mockSvc, messageQueue, logger, mockServerSvc, nil, t.TempDir(), func(ctx context.Context) (pkggrpc.Client, cvms.Service_ProcessClient, error) { return nil, nil, nil }, grpcClient)
assert.NoError(t, err)
msg := &cvms.ClientStreamMessage{
Message: &cvms.ClientStreamMessage_RunRes{
RunRes: &cvms.RunResponse{
ComputationId: "test-id",
},
},
}
// Don't read from queue, so sendMessage will timeout
client.sendMessage(msg)
// Should complete without blocking
time.Sleep(100 * time.Millisecond)
}
+9 -1
View File
@@ -7,6 +7,7 @@ import (
"context"
"errors"
"io"
"log/slog"
"time"
"github.com/ultravioletrs/cocos/agent/cvms"
@@ -52,16 +53,20 @@ func (s *grpcServer) Process(stream cvms.Service_ProcessServer) error {
return errors.New("failed to get peer info")
}
slog.Info("client connected to cvms server", "address", client.Addr.String())
eg, ctx := errgroup.WithContext(stream.Context())
eg.Go(func() error {
for {
select {
case <-ctx.Done():
slog.Info("receive goroutine context done", "address", client.Addr.String())
return ctx.Err()
default:
req, err := stream.Recv()
if err != nil {
slog.Error("failed to receive from stream", "address", client.Addr.String(), "error", err)
return err
}
s.incoming <- req
@@ -85,10 +90,13 @@ func (s *grpcServer) Process(stream cvms.Service_ProcessServer) error {
}
s.svc.Run(ctx, client.Addr.String(), sendMessage, client.AuthInfo)
slog.Info("send goroutine Run() returned", "address", client.Addr.String())
return nil
})
return eg.Wait()
err := eg.Wait()
slog.Info("stream closed", "address", client.Addr.String(), "error", err)
return err
}
func (s *grpcServer) sendRunReqInChunks(stream cvms.Service_ProcessServer, runReq *cvms.ComputationRunReq) error {
+4 -8
View File
@@ -19,8 +19,8 @@ import (
)
const (
svcName = "agent"
defSvcGRPCPort = "7002"
svcName = "agent"
defSvcGRPCSocket = "/run/cocos/agent.sock"
)
type AgentServer interface {
@@ -46,15 +46,11 @@ func NewServer(logger *slog.Logger, svc agent.Service, host string, certProvider
}
func (as *agentServer) Start(cfg agent.AgentConfig, cmp agent.Computation) error {
if cfg.Port == "" {
cfg.Port = defSvcGRPCPort
}
agentGrpcServerConfig := server.AgentConfig{
ServerConfig: server.ServerConfig{
Config: server.Config{
Host: as.host,
Port: cfg.Port,
Host: defSvcGRPCSocket,
Port: "",
CertFile: cfg.CertFile,
KeyFile: cfg.KeyFile,
ServerCAFile: cfg.ServerCAFile,
+12 -24
View File
@@ -97,7 +97,6 @@ func TestAgentServer_Start(t *testing.T) {
{
name: "successful start with default port",
cfg: agent.AgentConfig{
Port: "",
CertFile: "cert.pem",
KeyFile: "key.pem",
ServerCAFile: "server-ca.pem",
@@ -131,7 +130,6 @@ func TestAgentServer_Start(t *testing.T) {
{
name: "successful start with custom port",
cfg: agent.AgentConfig{
Port: "8080",
CertFile: "cert.pem",
KeyFile: "key.pem",
ServerCAFile: "server-ca.pem",
@@ -165,7 +163,6 @@ func TestAgentServer_Start(t *testing.T) {
{
name: "start with minimal config",
cfg: agent.AgentConfig{
Port: "9090",
AttestedTls: false,
},
cmp: agent.Computation{
@@ -243,9 +240,7 @@ func TestAgentServer_Stop(t *testing.T) {
{
name: "stop started server",
setupServer: func(server AgentServer) error {
cfg := agent.AgentConfig{
Port: "7004",
}
cfg := agent.AgentConfig{}
cmp := agent.Computation{
ID: "test-stop-computation",
Name: "Stop Test",
@@ -304,7 +299,7 @@ func TestAgentServer_StopMultipleTimes(t *testing.T) {
server := NewServer(logger, svc, host, nil)
// Start the server
cfg := agent.AgentConfig{Port: "7005"}
cfg := agent.AgentConfig{}
cmp := agent.Computation{
ID: "test-multiple-stop",
Name: "Multiple Stop Test",
@@ -347,7 +342,7 @@ func TestAgentServer_StartAfterStop(t *testing.T) {
logger, svc, host, pubKey := setupTest(t)
server := NewServer(logger, svc, host, nil)
cfg := agent.AgentConfig{Port: "7006"}
cfg := agent.AgentConfig{}
cmp := agent.Computation{
ID: "test-restart",
Name: "Restart Test",
@@ -378,7 +373,7 @@ func TestAgentServer_StartAfterStop(t *testing.T) {
assert.NoError(t, err)
// Start again with different config
cfg2 := agent.AgentConfig{Port: "7007"}
cfg2 := agent.AgentConfig{}
cmp2 := agent.Computation{
ID: "test-restart-2",
Name: "Restart Test 2",
@@ -422,7 +417,6 @@ func TestAgentServer_ConfigValidation(t *testing.T) {
{
name: "valid config with all fields",
config: agent.AgentConfig{
Port: "8080",
CertFile: "cert.pem",
KeyFile: "key.pem",
ServerCAFile: "server-ca.pem",
@@ -451,10 +445,8 @@ func TestAgentServer_ConfigValidation(t *testing.T) {
valid: true,
},
{
name: "valid config with minimal fields",
config: agent.AgentConfig{
Port: "9090",
},
name: "valid config with minimal fields",
config: agent.AgentConfig{},
cmp: agent.Computation{
ID: "minimal-config-test",
Name: "Minimal Config Test",
@@ -477,10 +469,8 @@ func TestAgentServer_ConfigValidation(t *testing.T) {
valid: true,
},
{
name: "config with empty port uses default",
config: agent.AgentConfig{
Port: "",
},
name: "config with empty port uses default",
config: agent.AgentConfig{},
cmp: agent.Computation{
ID: "default-port-test",
Name: "Default Port Test",
@@ -505,11 +495,9 @@ func TestAgentServer_ConfigValidation(t *testing.T) {
if tt.valid {
assert.NoError(t, err)
// Verify default port is used when empty
if tt.config.Port == "" {
agentSrv := server.(*agentServer)
assert.NotNil(t, agentSrv.gs)
}
// Verify server started successfully
agentSrv := server.(*agentServer)
assert.NotNil(t, agentSrv.gs)
time.Sleep(10 * time.Millisecond)
if err := server.Stop(); err != nil {
@@ -526,5 +514,5 @@ func TestAgentServer_ConfigValidation(t *testing.T) {
func TestConstants(t *testing.T) {
assert.Equal(t, "agent", svcName)
assert.Equal(t, "7002", defSvcGRPCPort)
assert.Equal(t, "/run/cocos/agent.sock", defSvcGRPCSocket)
}
+261
View File
@@ -0,0 +1,261 @@
// Copyright (c) Ultraviolet
// SPDX-License-Identifier: Apache-2.0
// Code generated by protoc-gen-go. DO NOT EDIT.
// versions:
// protoc-gen-go v1.36.8
// protoc v6.33.1
// source: agent/log/log.proto
package log
import (
protoreflect "google.golang.org/protobuf/reflect/protoreflect"
protoimpl "google.golang.org/protobuf/runtime/protoimpl"
emptypb "google.golang.org/protobuf/types/known/emptypb"
timestamppb "google.golang.org/protobuf/types/known/timestamppb"
reflect "reflect"
sync "sync"
unsafe "unsafe"
)
const (
// Verify that this generated code is sufficiently up-to-date.
_ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion)
// Verify that runtime/protoimpl is sufficiently up-to-date.
_ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20)
)
type LogEntry struct {
state protoimpl.MessageState `protogen:"open.v1"`
Message string `protobuf:"bytes,1,opt,name=message,proto3" json:"message,omitempty"`
ComputationId string `protobuf:"bytes,2,opt,name=computation_id,json=computationId,proto3" json:"computation_id,omitempty"`
Level string `protobuf:"bytes,3,opt,name=level,proto3" json:"level,omitempty"`
Timestamp *timestamppb.Timestamp `protobuf:"bytes,4,opt,name=timestamp,proto3" json:"timestamp,omitempty"`
unknownFields protoimpl.UnknownFields
sizeCache protoimpl.SizeCache
}
func (x *LogEntry) Reset() {
*x = LogEntry{}
mi := &file_agent_log_log_proto_msgTypes[0]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
func (x *LogEntry) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*LogEntry) ProtoMessage() {}
func (x *LogEntry) ProtoReflect() protoreflect.Message {
mi := &file_agent_log_log_proto_msgTypes[0]
if x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use LogEntry.ProtoReflect.Descriptor instead.
func (*LogEntry) Descriptor() ([]byte, []int) {
return file_agent_log_log_proto_rawDescGZIP(), []int{0}
}
func (x *LogEntry) GetMessage() string {
if x != nil {
return x.Message
}
return ""
}
func (x *LogEntry) GetComputationId() string {
if x != nil {
return x.ComputationId
}
return ""
}
func (x *LogEntry) GetLevel() string {
if x != nil {
return x.Level
}
return ""
}
func (x *LogEntry) GetTimestamp() *timestamppb.Timestamp {
if x != nil {
return x.Timestamp
}
return nil
}
type EventEntry struct {
state protoimpl.MessageState `protogen:"open.v1"`
EventType string `protobuf:"bytes,1,opt,name=event_type,json=eventType,proto3" json:"event_type,omitempty"`
Timestamp *timestamppb.Timestamp `protobuf:"bytes,2,opt,name=timestamp,proto3" json:"timestamp,omitempty"`
ComputationId string `protobuf:"bytes,3,opt,name=computation_id,json=computationId,proto3" json:"computation_id,omitempty"`
Details []byte `protobuf:"bytes,4,opt,name=details,proto3" json:"details,omitempty"` // JSON payload
Originator string `protobuf:"bytes,5,opt,name=originator,proto3" json:"originator,omitempty"`
Status string `protobuf:"bytes,6,opt,name=status,proto3" json:"status,omitempty"`
unknownFields protoimpl.UnknownFields
sizeCache protoimpl.SizeCache
}
func (x *EventEntry) Reset() {
*x = EventEntry{}
mi := &file_agent_log_log_proto_msgTypes[1]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
func (x *EventEntry) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*EventEntry) ProtoMessage() {}
func (x *EventEntry) ProtoReflect() protoreflect.Message {
mi := &file_agent_log_log_proto_msgTypes[1]
if x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use EventEntry.ProtoReflect.Descriptor instead.
func (*EventEntry) Descriptor() ([]byte, []int) {
return file_agent_log_log_proto_rawDescGZIP(), []int{1}
}
func (x *EventEntry) GetEventType() string {
if x != nil {
return x.EventType
}
return ""
}
func (x *EventEntry) GetTimestamp() *timestamppb.Timestamp {
if x != nil {
return x.Timestamp
}
return nil
}
func (x *EventEntry) GetComputationId() string {
if x != nil {
return x.ComputationId
}
return ""
}
func (x *EventEntry) GetDetails() []byte {
if x != nil {
return x.Details
}
return nil
}
func (x *EventEntry) GetOriginator() string {
if x != nil {
return x.Originator
}
return ""
}
func (x *EventEntry) GetStatus() string {
if x != nil {
return x.Status
}
return ""
}
var File_agent_log_log_proto protoreflect.FileDescriptor
const file_agent_log_log_proto_rawDesc = "" +
"\n" +
"\x13agent/log/log.proto\x12\x03log\x1a\x1fgoogle/protobuf/timestamp.proto\x1a\x1bgoogle/protobuf/empty.proto\"\x9b\x01\n" +
"\bLogEntry\x12\x18\n" +
"\amessage\x18\x01 \x01(\tR\amessage\x12%\n" +
"\x0ecomputation_id\x18\x02 \x01(\tR\rcomputationId\x12\x14\n" +
"\x05level\x18\x03 \x01(\tR\x05level\x128\n" +
"\ttimestamp\x18\x04 \x01(\v2\x1a.google.protobuf.TimestampR\ttimestamp\"\xde\x01\n" +
"\n" +
"EventEntry\x12\x1d\n" +
"\n" +
"event_type\x18\x01 \x01(\tR\teventType\x128\n" +
"\ttimestamp\x18\x02 \x01(\v2\x1a.google.protobuf.TimestampR\ttimestamp\x12%\n" +
"\x0ecomputation_id\x18\x03 \x01(\tR\rcomputationId\x12\x18\n" +
"\adetails\x18\x04 \x01(\fR\adetails\x12\x1e\n" +
"\n" +
"originator\x18\x05 \x01(\tR\n" +
"originator\x12\x16\n" +
"\x06status\x18\x06 \x01(\tR\x06status2v\n" +
"\fLogCollector\x120\n" +
"\aSendLog\x12\r.log.LogEntry\x1a\x16.google.protobuf.Empty\x124\n" +
"\tSendEvent\x12\x0f.log.EventEntry\x1a\x16.google.protobuf.EmptyB\aZ\x05./logb\x06proto3"
var (
file_agent_log_log_proto_rawDescOnce sync.Once
file_agent_log_log_proto_rawDescData []byte
)
func file_agent_log_log_proto_rawDescGZIP() []byte {
file_agent_log_log_proto_rawDescOnce.Do(func() {
file_agent_log_log_proto_rawDescData = protoimpl.X.CompressGZIP(unsafe.Slice(unsafe.StringData(file_agent_log_log_proto_rawDesc), len(file_agent_log_log_proto_rawDesc)))
})
return file_agent_log_log_proto_rawDescData
}
var file_agent_log_log_proto_msgTypes = make([]protoimpl.MessageInfo, 2)
var file_agent_log_log_proto_goTypes = []any{
(*LogEntry)(nil), // 0: log.LogEntry
(*EventEntry)(nil), // 1: log.EventEntry
(*timestamppb.Timestamp)(nil), // 2: google.protobuf.Timestamp
(*emptypb.Empty)(nil), // 3: google.protobuf.Empty
}
var file_agent_log_log_proto_depIdxs = []int32{
2, // 0: log.LogEntry.timestamp:type_name -> google.protobuf.Timestamp
2, // 1: log.EventEntry.timestamp:type_name -> google.protobuf.Timestamp
0, // 2: log.LogCollector.SendLog:input_type -> log.LogEntry
1, // 3: log.LogCollector.SendEvent:input_type -> log.EventEntry
3, // 4: log.LogCollector.SendLog:output_type -> google.protobuf.Empty
3, // 5: log.LogCollector.SendEvent:output_type -> google.protobuf.Empty
4, // [4:6] is the sub-list for method output_type
2, // [2:4] is the sub-list for method input_type
2, // [2:2] is the sub-list for extension type_name
2, // [2:2] is the sub-list for extension extendee
0, // [0:2] is the sub-list for field type_name
}
func init() { file_agent_log_log_proto_init() }
func file_agent_log_log_proto_init() {
if File_agent_log_log_proto != nil {
return
}
type x struct{}
out := protoimpl.TypeBuilder{
File: protoimpl.DescBuilder{
GoPackagePath: reflect.TypeOf(x{}).PkgPath(),
RawDescriptor: unsafe.Slice(unsafe.StringData(file_agent_log_log_proto_rawDesc), len(file_agent_log_log_proto_rawDesc)),
NumEnums: 0,
NumMessages: 2,
NumExtensions: 0,
NumServices: 1,
},
GoTypes: file_agent_log_log_proto_goTypes,
DependencyIndexes: file_agent_log_log_proto_depIdxs,
MessageInfos: file_agent_log_log_proto_msgTypes,
}.Build()
File_agent_log_log_proto = out.File
file_agent_log_log_proto_goTypes = nil
file_agent_log_log_proto_depIdxs = nil
}
+32
View File
@@ -0,0 +1,32 @@
// Copyright (c) Ultraviolet
// SPDX-License-Identifier: Apache-2.0
syntax = "proto3";
package log;
option go_package = "./log";
import "google/protobuf/timestamp.proto";
import "google/protobuf/empty.proto";
service LogCollector {
rpc SendLog(LogEntry) returns (google.protobuf.Empty);
rpc SendEvent(EventEntry) returns (google.protobuf.Empty);
}
message LogEntry {
string message = 1;
string computation_id = 2;
string level = 3;
google.protobuf.Timestamp timestamp = 4;
}
message EventEntry {
string event_type = 1;
google.protobuf.Timestamp timestamp = 2;
string computation_id = 3;
bytes details = 4; // JSON payload
string originator = 5;
string status = 6;
}
+163
View File
@@ -0,0 +1,163 @@
// Copyright (c) Ultraviolet
// SPDX-License-Identifier: Apache-2.0
// Code generated by protoc-gen-go-grpc. DO NOT EDIT.
// versions:
// - protoc-gen-go-grpc v1.5.1
// - protoc v6.33.1
// source: agent/log/log.proto
package log
import (
context "context"
grpc "google.golang.org/grpc"
codes "google.golang.org/grpc/codes"
status "google.golang.org/grpc/status"
emptypb "google.golang.org/protobuf/types/known/emptypb"
)
// This is a compile-time assertion to ensure that this generated file
// is compatible with the grpc package it is being compiled against.
// Requires gRPC-Go v1.64.0 or later.
const _ = grpc.SupportPackageIsVersion9
const (
LogCollector_SendLog_FullMethodName = "/log.LogCollector/SendLog"
LogCollector_SendEvent_FullMethodName = "/log.LogCollector/SendEvent"
)
// LogCollectorClient is the client API for LogCollector service.
//
// For semantics around ctx use and closing/ending streaming RPCs, please refer to https://pkg.go.dev/google.golang.org/grpc/?tab=doc#ClientConn.NewStream.
type LogCollectorClient interface {
SendLog(ctx context.Context, in *LogEntry, opts ...grpc.CallOption) (*emptypb.Empty, error)
SendEvent(ctx context.Context, in *EventEntry, opts ...grpc.CallOption) (*emptypb.Empty, error)
}
type logCollectorClient struct {
cc grpc.ClientConnInterface
}
func NewLogCollectorClient(cc grpc.ClientConnInterface) LogCollectorClient {
return &logCollectorClient{cc}
}
func (c *logCollectorClient) SendLog(ctx context.Context, in *LogEntry, opts ...grpc.CallOption) (*emptypb.Empty, error) {
cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
out := new(emptypb.Empty)
err := c.cc.Invoke(ctx, LogCollector_SendLog_FullMethodName, in, out, cOpts...)
if err != nil {
return nil, err
}
return out, nil
}
func (c *logCollectorClient) SendEvent(ctx context.Context, in *EventEntry, opts ...grpc.CallOption) (*emptypb.Empty, error) {
cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
out := new(emptypb.Empty)
err := c.cc.Invoke(ctx, LogCollector_SendEvent_FullMethodName, in, out, cOpts...)
if err != nil {
return nil, err
}
return out, nil
}
// LogCollectorServer is the server API for LogCollector service.
// All implementations must embed UnimplementedLogCollectorServer
// for forward compatibility.
type LogCollectorServer interface {
SendLog(context.Context, *LogEntry) (*emptypb.Empty, error)
SendEvent(context.Context, *EventEntry) (*emptypb.Empty, error)
mustEmbedUnimplementedLogCollectorServer()
}
// UnimplementedLogCollectorServer must be embedded to have
// forward compatible implementations.
//
// NOTE: this should be embedded by value instead of pointer to avoid a nil
// pointer dereference when methods are called.
type UnimplementedLogCollectorServer struct{}
func (UnimplementedLogCollectorServer) SendLog(context.Context, *LogEntry) (*emptypb.Empty, error) {
return nil, status.Errorf(codes.Unimplemented, "method SendLog not implemented")
}
func (UnimplementedLogCollectorServer) SendEvent(context.Context, *EventEntry) (*emptypb.Empty, error) {
return nil, status.Errorf(codes.Unimplemented, "method SendEvent not implemented")
}
func (UnimplementedLogCollectorServer) mustEmbedUnimplementedLogCollectorServer() {}
func (UnimplementedLogCollectorServer) testEmbeddedByValue() {}
// UnsafeLogCollectorServer may be embedded to opt out of forward compatibility for this service.
// Use of this interface is not recommended, as added methods to LogCollectorServer will
// result in compilation errors.
type UnsafeLogCollectorServer interface {
mustEmbedUnimplementedLogCollectorServer()
}
func RegisterLogCollectorServer(s grpc.ServiceRegistrar, srv LogCollectorServer) {
// If the following call pancis, it indicates UnimplementedLogCollectorServer was
// embedded by pointer and is nil. This will cause panics if an
// unimplemented method is ever invoked, so we test this at initialization
// time to prevent it from happening at runtime later due to I/O.
if t, ok := srv.(interface{ testEmbeddedByValue() }); ok {
t.testEmbeddedByValue()
}
s.RegisterService(&LogCollector_ServiceDesc, srv)
}
func _LogCollector_SendLog_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(LogEntry)
if err := dec(in); err != nil {
return nil, err
}
if interceptor == nil {
return srv.(LogCollectorServer).SendLog(ctx, in)
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: LogCollector_SendLog_FullMethodName,
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(LogCollectorServer).SendLog(ctx, req.(*LogEntry))
}
return interceptor(ctx, in, info, handler)
}
func _LogCollector_SendEvent_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(EventEntry)
if err := dec(in); err != nil {
return nil, err
}
if interceptor == nil {
return srv.(LogCollectorServer).SendEvent(ctx, in)
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: LogCollector_SendEvent_FullMethodName,
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(LogCollectorServer).SendEvent(ctx, req.(*EventEntry))
}
return interceptor(ctx, in, info, handler)
}
// LogCollector_ServiceDesc is the grpc.ServiceDesc for LogCollector service.
// It's only intended for direct use with grpc.RegisterService,
// and not to be introspected or modified (even as a copy)
var LogCollector_ServiceDesc = grpc.ServiceDesc{
ServiceName: "log.LogCollector",
HandlerType: (*LogCollectorServer)(nil),
Methods: []grpc.MethodDesc{
{
MethodName: "SendLog",
Handler: _LogCollector_SendLog_Handler,
},
{
MethodName: "SendEvent",
Handler: _LogCollector_SendEvent_Handler,
},
},
Streams: []grpc.StreamDesc{},
Metadata: "agent/log/log.proto",
}
+59
View File
@@ -0,0 +1,59 @@
// Copyright (c) Ultraviolet
// SPDX-License-Identifier: Apache-2.0
package service
import (
"context"
"log/slog"
"github.com/ultravioletrs/cocos/agent/cvms"
"github.com/ultravioletrs/cocos/agent/log"
"google.golang.org/protobuf/types/known/emptypb"
)
var _ log.LogCollectorServer = (*LogForwarder)(nil)
type LogForwarder struct {
log.UnimplementedLogCollectorServer
cvmsClient cvms.ServiceClient
logger *slog.Logger
logQueue chan *cvms.ClientStreamMessage
}
func New(logger *slog.Logger, cvmsClient cvms.ServiceClient, queue chan *cvms.ClientStreamMessage) *LogForwarder {
return &LogForwarder{
cvmsClient: cvmsClient,
logger: logger,
logQueue: queue,
}
}
func (s *LogForwarder) SendLog(ctx context.Context, req *log.LogEntry) (*emptypb.Empty, error) {
s.logQueue <- &cvms.ClientStreamMessage{
Message: &cvms.ClientStreamMessage_AgentLog{
AgentLog: &cvms.AgentLog{
Message: req.Message,
ComputationId: req.ComputationId,
Level: req.Level,
Timestamp: req.Timestamp,
},
},
}
return &emptypb.Empty{}, nil
}
func (s *LogForwarder) SendEvent(ctx context.Context, req *log.EventEntry) (*emptypb.Empty, error) {
s.logQueue <- &cvms.ClientStreamMessage{
Message: &cvms.ClientStreamMessage_AgentEvent{
AgentEvent: &cvms.AgentEvent{
EventType: req.EventType,
Timestamp: req.Timestamp,
ComputationId: req.ComputationId,
Details: req.Details,
Originator: req.Originator,
Status: req.Status,
},
},
}
return &emptypb.Empty{}, nil
}
+303
View File
@@ -0,0 +1,303 @@
// Copyright (c) Ultraviolet
// SPDX-License-Identifier: Apache-2.0
package service
import (
"context"
"encoding/json"
"log/slog"
"os"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/ultravioletrs/cocos/agent/cvms"
"github.com/ultravioletrs/cocos/agent/log"
"google.golang.org/protobuf/types/known/timestamppb"
)
// TestNewLogForwarder tests the creation of a new log forwarder.
func TestNewLogForwarder(t *testing.T) {
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
queue := make(chan *cvms.ClientStreamMessage, 10)
lf := New(logger, nil, queue)
require.NotNil(t, lf)
assert.NotNil(t, lf.logger)
assert.Nil(t, lf.cvmsClient)
assert.NotNil(t, lf.logQueue)
}
// TestSendLog tests sending a log entry.
func TestSendLog(t *testing.T) {
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
queue := make(chan *cvms.ClientStreamMessage, 10)
lf := New(logger, nil, queue)
req := &log.LogEntry{
Message: "Test log message",
ComputationId: "computation-1",
Level: "INFO",
Timestamp: timestamppb.New(time.Now()),
}
resp, err := lf.SendLog(context.Background(), req)
require.NoError(t, err)
require.NotNil(t, resp)
// Verify message was queued
select {
case msg := <-queue:
require.NotNil(t, msg)
agentLog := msg.GetAgentLog()
assert.NotNil(t, agentLog)
assert.Equal(t, "Test log message", agentLog.Message)
assert.Equal(t, "computation-1", agentLog.ComputationId)
assert.Equal(t, "INFO", agentLog.Level)
default:
t.Fatal("No message in queue")
}
}
// TestSendEvent tests sending an event entry.
func TestSendEvent(t *testing.T) {
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
queue := make(chan *cvms.ClientStreamMessage, 10)
lf := New(logger, nil, queue)
details, err := json.Marshal(map[string]string{"key": "value"})
require.NoError(t, err)
req := &log.EventEntry{
EventType: "COMPUTATION_STARTED",
Timestamp: timestamppb.New(time.Now()),
ComputationId: "computation-1",
Details: details,
Originator: "runner",
Status: "SUCCESS",
}
resp, err := lf.SendEvent(context.Background(), req)
require.NoError(t, err)
require.NotNil(t, resp)
// Verify message was queued
select {
case msg := <-queue:
require.NotNil(t, msg)
agentEvent := msg.GetAgentEvent()
assert.NotNil(t, agentEvent)
assert.Equal(t, "COMPUTATION_STARTED", agentEvent.EventType)
assert.Equal(t, "computation-1", agentEvent.ComputationId)
assert.Equal(t, "runner", agentEvent.Originator)
assert.Equal(t, "SUCCESS", agentEvent.Status)
default:
t.Fatal("No message in queue")
}
}
// TestSendMultipleLogs tests sending multiple log entries.
func TestSendMultipleLogs(t *testing.T) {
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
queue := make(chan *cvms.ClientStreamMessage, 100)
lf := New(logger, nil, queue)
for i := 0; i < 5; i++ {
req := &log.LogEntry{
Message: "Log message",
ComputationId: "computation-1",
Level: "INFO",
Timestamp: timestamppb.New(time.Now()),
}
resp, err := lf.SendLog(context.Background(), req)
require.NoError(t, err)
require.NotNil(t, resp)
}
assert.Equal(t, 5, len(queue))
}
// TestSendEventWithVariousTypes tests sending events with different types.
func TestSendEventWithVariousTypes(t *testing.T) {
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
queue := make(chan *cvms.ClientStreamMessage, 100)
lf := New(logger, nil, queue)
eventTypes := []string{"STARTED", "RUNNING", "COMPLETED", "FAILED"}
for _, eventType := range eventTypes {
details, err := json.Marshal(map[string]string{"type": eventType})
require.NoError(t, err)
req := &log.EventEntry{
EventType: eventType,
Timestamp: timestamppb.New(time.Now()),
ComputationId: "computation-1",
Details: details,
Originator: "runner",
Status: "OK",
}
resp, err := lf.SendEvent(context.Background(), req)
require.NoError(t, err)
require.NotNil(t, resp)
}
assert.Equal(t, 4, len(queue))
}
// TestSendLogWithEmptyMessage tests sending log with empty message.
func TestSendLogWithEmptyMessage(t *testing.T) {
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
queue := make(chan *cvms.ClientStreamMessage, 10)
lf := New(logger, nil, queue)
req := &log.LogEntry{
Message: "",
ComputationId: "computation-1",
Level: "INFO",
Timestamp: timestamppb.New(time.Now()),
}
resp, err := lf.SendLog(context.Background(), req)
require.NoError(t, err)
require.NotNil(t, resp)
select {
case msg := <-queue:
agentLog := msg.GetAgentLog()
assert.Equal(t, "", agentLog.Message)
default:
t.Fatal("No message in queue")
}
}
// TestSendEventWithNilDetails tests sending event with nil details.
func TestSendEventWithNilDetails(t *testing.T) {
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
queue := make(chan *cvms.ClientStreamMessage, 10)
lf := New(logger, nil, queue)
req := &log.EventEntry{
EventType: "TEST_EVENT",
Timestamp: timestamppb.New(time.Now()),
ComputationId: "computation-1",
Details: nil,
Originator: "test",
Status: "OK",
}
resp, err := lf.SendEvent(context.Background(), req)
require.NoError(t, err)
require.NotNil(t, resp)
select {
case msg := <-queue:
agentEvent := msg.GetAgentEvent()
assert.Nil(t, agentEvent.Details)
default:
t.Fatal("No message in queue")
}
}
// TestSendLogWithVariousLevels tests sending logs with various severity levels.
func TestSendLogWithVariousLevels(t *testing.T) {
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
queue := make(chan *cvms.ClientStreamMessage, 100)
lf := New(logger, nil, queue)
levels := []string{"DEBUG", "INFO", "WARN", "ERROR"}
for _, level := range levels {
req := &log.LogEntry{
Message: "Test " + level,
ComputationId: "computation-1",
Level: level,
Timestamp: timestamppb.New(time.Now()),
}
resp, err := lf.SendLog(context.Background(), req)
require.NoError(t, err)
require.NotNil(t, resp)
}
assert.Equal(t, 4, len(queue))
}
// TestSendLogWithDifferentComputationIds tests sending logs with different computation IDs.
func TestSendLogWithDifferentComputationIds(t *testing.T) {
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
queue := make(chan *cvms.ClientStreamMessage, 100)
lf := New(logger, nil, queue)
for i := 0; i < 3; i++ {
req := &log.LogEntry{
Message: "Message",
ComputationId: "computation-" + string(rune(48+i)),
Level: "INFO",
Timestamp: timestamppb.New(time.Now()),
}
resp, err := lf.SendLog(context.Background(), req)
require.NoError(t, err)
require.NotNil(t, resp)
}
assert.Equal(t, 3, len(queue))
}
// TestQueueBehavior tests that queue is properly used.
func TestQueueBehavior(t *testing.T) {
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
queue := make(chan *cvms.ClientStreamMessage, 1)
lf := New(logger, nil, queue)
req := &log.LogEntry{
Message: "Test",
ComputationId: "computation-1",
Level: "INFO",
Timestamp: timestamppb.New(time.Now()),
}
resp, err := lf.SendLog(context.Background(), req)
require.NoError(t, err)
require.NotNil(t, resp)
assert.Equal(t, 1, len(queue))
}
// TestConcurrentSendLog tests concurrent log sending.
func TestConcurrentSendLog(t *testing.T) {
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
queue := make(chan *cvms.ClientStreamMessage, 100)
lf := New(logger, nil, queue)
for i := 0; i < 10; i++ {
go func(id int) {
req := &log.LogEntry{
Message: "Concurrent log",
ComputationId: "computation-1",
Level: "INFO",
Timestamp: timestamppb.New(time.Now()),
}
_, err := lf.SendLog(context.Background(), req)
require.NoError(t, err)
}(i)
}
// Give goroutines time to complete
time.Sleep(100 * time.Millisecond)
// Should have received all messages
assert.True(t, len(queue) > 0)
}
+38
View File
@@ -0,0 +1,38 @@
// Copyright (c) Ultraviolet
// SPDX-License-Identifier: Apache-2.0
package events
import (
"context"
"encoding/json"
"log/slog"
"github.com/ultravioletrs/cocos/agent/events"
logpb "github.com/ultravioletrs/cocos/agent/log"
logclient "github.com/ultravioletrs/cocos/pkg/clients/grpc/log"
)
type adapter struct {
client logclient.Client
svc string
}
func NewAdapter(client logclient.Client, svc string) events.Service {
return &adapter{
client: client,
svc: svc,
}
}
func (a *adapter) SendEvent(cmpID, event, status string, details json.RawMessage) {
err := a.client.SendEvent(context.Background(), &logpb.EventEntry{
EventType: event,
ComputationId: cmpID,
Details: details,
Originator: a.svc,
Status: status,
})
if err != nil {
slog.Error("failed to send event to log-forwarder", "error", err)
}
}
+138
View File
@@ -0,0 +1,138 @@
// Copyright (c) Ultraviolet
// SPDX-License-Identifier: Apache-2.0
package events
import (
"context"
"encoding/json"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
logpb "github.com/ultravioletrs/cocos/agent/log"
)
const testServiceName = "test-service"
// mockLogClient is a mock implementation of the log client.
type mockLogClient struct {
mock.Mock
}
func (m *mockLogClient) SendLog(ctx context.Context, entry *logpb.LogEntry) error {
args := m.Called(ctx, entry)
return args.Error(0)
}
func (m *mockLogClient) SendEvent(ctx context.Context, entry *logpb.EventEntry) error {
args := m.Called(ctx, entry)
return args.Error(0)
}
func (m *mockLogClient) Close() error {
args := m.Called()
return args.Error(0)
}
// TestNewAdapter tests creating a new adapter.
func TestNewAdapter(t *testing.T) {
mockClient := new(mockLogClient)
svc := testServiceName
adapter := NewAdapter(mockClient, svc)
assert.NotNil(t, adapter)
}
// TestSendEvent tests sending an event successfully.
func TestSendEvent(t *testing.T) {
mockClient := new(mockLogClient)
svc := testServiceName
adapter := NewAdapter(mockClient, svc)
cmpID := "test-computation-id"
event := "computation.started"
status := "success"
details := json.RawMessage(`{"key": "value"}`)
expectedEntry := &logpb.EventEntry{
EventType: event,
ComputationId: cmpID,
Details: details,
Originator: svc,
Status: status,
}
mockClient.On("SendEvent", mock.Anything, expectedEntry).Return(nil)
adapter.SendEvent(cmpID, event, status, details)
mockClient.AssertExpectations(t)
mockClient.AssertCalled(t, "SendEvent", mock.Anything, expectedEntry)
}
// TestSendEventWithError tests sending an event when client returns an error.
func TestSendEventWithError(t *testing.T) {
mockClient := new(mockLogClient)
svc := testServiceName
adapter := NewAdapter(mockClient, svc)
cmpID := "test-computation-id"
event := "computation.failed"
status := "error"
details := json.RawMessage(`{"error": "something went wrong"}`)
mockClient.On("SendEvent", mock.Anything, mock.Anything).Return(assert.AnError)
// This should not panic even when error occurs
adapter.SendEvent(cmpID, event, status, details)
mockClient.AssertExpectations(t)
mockClient.AssertCalled(t, "SendEvent", mock.Anything, mock.Anything)
}
// TestSendEventWithNilDetails tests sending an event with nil details.
func TestSendEventWithNilDetails(t *testing.T) {
mockClient := new(mockLogClient)
svc := "runner-service"
adapter := NewAdapter(mockClient, svc)
cmpID := "comp-123"
event := "test.event"
status := "pending"
expectedEntry := &logpb.EventEntry{
EventType: event,
ComputationId: cmpID,
Details: nil,
Originator: svc,
Status: status,
}
mockClient.On("SendEvent", mock.Anything, expectedEntry).Return(nil)
adapter.SendEvent(cmpID, event, status, nil)
mockClient.AssertExpectations(t)
}
// TestSendEventWithEmptyStrings tests sending an event with empty strings.
func TestSendEventWithEmptyStrings(t *testing.T) {
mockClient := new(mockLogClient)
svc := testServiceName
adapter := NewAdapter(mockClient, svc)
expectedEntry := &logpb.EventEntry{
EventType: "",
ComputationId: "",
Details: nil,
Originator: svc,
Status: "",
}
mockClient.On("SendEvent", mock.Anything, expectedEntry).Return(nil)
adapter.SendEvent("", "", "", nil)
mockClient.AssertExpectations(t)
}
+341
View File
@@ -0,0 +1,341 @@
// Copyright (c) Ultraviolet
// SPDX-License-Identifier: Apache-2.0
// Code generated by protoc-gen-go. DO NOT EDIT.
// versions:
// protoc-gen-go v1.36.8
// protoc v6.33.1
// source: agent/runner/runner.proto
package runner
import (
protoreflect "google.golang.org/protobuf/reflect/protoreflect"
protoimpl "google.golang.org/protobuf/runtime/protoimpl"
emptypb "google.golang.org/protobuf/types/known/emptypb"
reflect "reflect"
sync "sync"
unsafe "unsafe"
)
const (
// Verify that this generated code is sufficiently up-to-date.
_ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion)
// Verify that runtime/protoimpl is sufficiently up-to-date.
_ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20)
)
type RunRequest struct {
state protoimpl.MessageState `protogen:"open.v1"`
ComputationId string `protobuf:"bytes,1,opt,name=computation_id,json=computationId,proto3" json:"computation_id,omitempty"`
AlgoType string `protobuf:"bytes,2,opt,name=algo_type,json=algoType,proto3" json:"algo_type,omitempty"` // "binary", "python", "wasm", "docker"
Algorithm []byte `protobuf:"bytes,3,opt,name=algorithm,proto3" json:"algorithm,omitempty"` // The algorithm binary/script content
Requirements []byte `protobuf:"bytes,4,opt,name=requirements,proto3" json:"requirements,omitempty"` // Python requirements.txt content
Args []string `protobuf:"bytes,5,rep,name=args,proto3" json:"args,omitempty"`
Datasets []*Dataset `protobuf:"bytes,6,rep,name=datasets,proto3" json:"datasets,omitempty"`
unknownFields protoimpl.UnknownFields
sizeCache protoimpl.SizeCache
}
func (x *RunRequest) Reset() {
*x = RunRequest{}
mi := &file_agent_runner_runner_proto_msgTypes[0]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
func (x *RunRequest) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*RunRequest) ProtoMessage() {}
func (x *RunRequest) ProtoReflect() protoreflect.Message {
mi := &file_agent_runner_runner_proto_msgTypes[0]
if x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use RunRequest.ProtoReflect.Descriptor instead.
func (*RunRequest) Descriptor() ([]byte, []int) {
return file_agent_runner_runner_proto_rawDescGZIP(), []int{0}
}
func (x *RunRequest) GetComputationId() string {
if x != nil {
return x.ComputationId
}
return ""
}
func (x *RunRequest) GetAlgoType() string {
if x != nil {
return x.AlgoType
}
return ""
}
func (x *RunRequest) GetAlgorithm() []byte {
if x != nil {
return x.Algorithm
}
return nil
}
func (x *RunRequest) GetRequirements() []byte {
if x != nil {
return x.Requirements
}
return nil
}
func (x *RunRequest) GetArgs() []string {
if x != nil {
return x.Args
}
return nil
}
func (x *RunRequest) GetDatasets() []*Dataset {
if x != nil {
return x.Datasets
}
return nil
}
type Dataset struct {
state protoimpl.MessageState `protogen:"open.v1"`
Filename string `protobuf:"bytes,1,opt,name=filename,proto3" json:"filename,omitempty"`
Hash []byte `protobuf:"bytes,2,opt,name=hash,proto3" json:"hash,omitempty"`
unknownFields protoimpl.UnknownFields
sizeCache protoimpl.SizeCache
}
func (x *Dataset) Reset() {
*x = Dataset{}
mi := &file_agent_runner_runner_proto_msgTypes[1]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
func (x *Dataset) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*Dataset) ProtoMessage() {}
func (x *Dataset) ProtoReflect() protoreflect.Message {
mi := &file_agent_runner_runner_proto_msgTypes[1]
if x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use Dataset.ProtoReflect.Descriptor instead.
func (*Dataset) Descriptor() ([]byte, []int) {
return file_agent_runner_runner_proto_rawDescGZIP(), []int{1}
}
func (x *Dataset) GetFilename() string {
if x != nil {
return x.Filename
}
return ""
}
func (x *Dataset) GetHash() []byte {
if x != nil {
return x.Hash
}
return nil
}
type RunResponse struct {
state protoimpl.MessageState `protogen:"open.v1"`
ComputationId string `protobuf:"bytes,1,opt,name=computation_id,json=computationId,proto3" json:"computation_id,omitempty"`
Error string `protobuf:"bytes,2,opt,name=error,proto3" json:"error,omitempty"`
unknownFields protoimpl.UnknownFields
sizeCache protoimpl.SizeCache
}
func (x *RunResponse) Reset() {
*x = RunResponse{}
mi := &file_agent_runner_runner_proto_msgTypes[2]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
func (x *RunResponse) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*RunResponse) ProtoMessage() {}
func (x *RunResponse) ProtoReflect() protoreflect.Message {
mi := &file_agent_runner_runner_proto_msgTypes[2]
if x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use RunResponse.ProtoReflect.Descriptor instead.
func (*RunResponse) Descriptor() ([]byte, []int) {
return file_agent_runner_runner_proto_rawDescGZIP(), []int{2}
}
func (x *RunResponse) GetComputationId() string {
if x != nil {
return x.ComputationId
}
return ""
}
func (x *RunResponse) GetError() string {
if x != nil {
return x.Error
}
return ""
}
type StopRequest struct {
state protoimpl.MessageState `protogen:"open.v1"`
ComputationId string `protobuf:"bytes,1,opt,name=computation_id,json=computationId,proto3" json:"computation_id,omitempty"`
unknownFields protoimpl.UnknownFields
sizeCache protoimpl.SizeCache
}
func (x *StopRequest) Reset() {
*x = StopRequest{}
mi := &file_agent_runner_runner_proto_msgTypes[3]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
func (x *StopRequest) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*StopRequest) ProtoMessage() {}
func (x *StopRequest) ProtoReflect() protoreflect.Message {
mi := &file_agent_runner_runner_proto_msgTypes[3]
if x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use StopRequest.ProtoReflect.Descriptor instead.
func (*StopRequest) Descriptor() ([]byte, []int) {
return file_agent_runner_runner_proto_rawDescGZIP(), []int{3}
}
func (x *StopRequest) GetComputationId() string {
if x != nil {
return x.ComputationId
}
return ""
}
var File_agent_runner_runner_proto protoreflect.FileDescriptor
const file_agent_runner_runner_proto_rawDesc = "" +
"\n" +
"\x19agent/runner/runner.proto\x12\x06runner\x1a\x1bgoogle/protobuf/empty.proto\"\xd3\x01\n" +
"\n" +
"RunRequest\x12%\n" +
"\x0ecomputation_id\x18\x01 \x01(\tR\rcomputationId\x12\x1b\n" +
"\talgo_type\x18\x02 \x01(\tR\balgoType\x12\x1c\n" +
"\talgorithm\x18\x03 \x01(\fR\talgorithm\x12\"\n" +
"\frequirements\x18\x04 \x01(\fR\frequirements\x12\x12\n" +
"\x04args\x18\x05 \x03(\tR\x04args\x12+\n" +
"\bdatasets\x18\x06 \x03(\v2\x0f.runner.DatasetR\bdatasets\"9\n" +
"\aDataset\x12\x1a\n" +
"\bfilename\x18\x01 \x01(\tR\bfilename\x12\x12\n" +
"\x04hash\x18\x02 \x01(\fR\x04hash\"J\n" +
"\vRunResponse\x12%\n" +
"\x0ecomputation_id\x18\x01 \x01(\tR\rcomputationId\x12\x14\n" +
"\x05error\x18\x02 \x01(\tR\x05error\"4\n" +
"\vStopRequest\x12%\n" +
"\x0ecomputation_id\x18\x01 \x01(\tR\rcomputationId2x\n" +
"\x11ComputationRunner\x12.\n" +
"\x03Run\x12\x12.runner.RunRequest\x1a\x13.runner.RunResponse\x123\n" +
"\x04Stop\x12\x13.runner.StopRequest\x1a\x16.google.protobuf.EmptyB\n" +
"Z\b./runnerb\x06proto3"
var (
file_agent_runner_runner_proto_rawDescOnce sync.Once
file_agent_runner_runner_proto_rawDescData []byte
)
func file_agent_runner_runner_proto_rawDescGZIP() []byte {
file_agent_runner_runner_proto_rawDescOnce.Do(func() {
file_agent_runner_runner_proto_rawDescData = protoimpl.X.CompressGZIP(unsafe.Slice(unsafe.StringData(file_agent_runner_runner_proto_rawDesc), len(file_agent_runner_runner_proto_rawDesc)))
})
return file_agent_runner_runner_proto_rawDescData
}
var file_agent_runner_runner_proto_msgTypes = make([]protoimpl.MessageInfo, 4)
var file_agent_runner_runner_proto_goTypes = []any{
(*RunRequest)(nil), // 0: runner.RunRequest
(*Dataset)(nil), // 1: runner.Dataset
(*RunResponse)(nil), // 2: runner.RunResponse
(*StopRequest)(nil), // 3: runner.StopRequest
(*emptypb.Empty)(nil), // 4: google.protobuf.Empty
}
var file_agent_runner_runner_proto_depIdxs = []int32{
1, // 0: runner.RunRequest.datasets:type_name -> runner.Dataset
0, // 1: runner.ComputationRunner.Run:input_type -> runner.RunRequest
3, // 2: runner.ComputationRunner.Stop:input_type -> runner.StopRequest
2, // 3: runner.ComputationRunner.Run:output_type -> runner.RunResponse
4, // 4: runner.ComputationRunner.Stop:output_type -> google.protobuf.Empty
3, // [3:5] is the sub-list for method output_type
1, // [1:3] is the sub-list for method input_type
1, // [1:1] is the sub-list for extension type_name
1, // [1:1] is the sub-list for extension extendee
0, // [0:1] is the sub-list for field type_name
}
func init() { file_agent_runner_runner_proto_init() }
func file_agent_runner_runner_proto_init() {
if File_agent_runner_runner_proto != nil {
return
}
type x struct{}
out := protoimpl.TypeBuilder{
File: protoimpl.DescBuilder{
GoPackagePath: reflect.TypeOf(x{}).PkgPath(),
RawDescriptor: unsafe.Slice(unsafe.StringData(file_agent_runner_runner_proto_rawDesc), len(file_agent_runner_runner_proto_rawDesc)),
NumEnums: 0,
NumMessages: 4,
NumExtensions: 0,
NumServices: 1,
},
GoTypes: file_agent_runner_runner_proto_goTypes,
DependencyIndexes: file_agent_runner_runner_proto_depIdxs,
MessageInfos: file_agent_runner_runner_proto_msgTypes,
}.Build()
File_agent_runner_runner_proto = out.File
file_agent_runner_runner_proto_goTypes = nil
file_agent_runner_runner_proto_depIdxs = nil
}
+38
View File
@@ -0,0 +1,38 @@
// Copyright (c) Ultraviolet
// SPDX-License-Identifier: Apache-2.0
syntax = "proto3";
package runner;
option go_package = "./runner";
import "google/protobuf/empty.proto";
service ComputationRunner {
rpc Run(RunRequest) returns (RunResponse);
rpc Stop(StopRequest) returns (google.protobuf.Empty);
}
message RunRequest {
string computation_id = 1;
string algo_type = 2; // "binary", "python", "wasm", "docker"
bytes algorithm = 3; // The algorithm binary/script content
bytes requirements = 4; // Python requirements.txt content
repeated string args = 5;
repeated Dataset datasets = 6;
}
message Dataset {
string filename = 1;
bytes hash = 2;
}
message RunResponse {
string computation_id = 1;
string error = 2;
}
message StopRequest {
string computation_id = 1;
}
+163
View File
@@ -0,0 +1,163 @@
// Copyright (c) Ultraviolet
// SPDX-License-Identifier: Apache-2.0
// Code generated by protoc-gen-go-grpc. DO NOT EDIT.
// versions:
// - protoc-gen-go-grpc v1.5.1
// - protoc v6.33.1
// source: agent/runner/runner.proto
package runner
import (
context "context"
grpc "google.golang.org/grpc"
codes "google.golang.org/grpc/codes"
status "google.golang.org/grpc/status"
emptypb "google.golang.org/protobuf/types/known/emptypb"
)
// This is a compile-time assertion to ensure that this generated file
// is compatible with the grpc package it is being compiled against.
// Requires gRPC-Go v1.64.0 or later.
const _ = grpc.SupportPackageIsVersion9
const (
ComputationRunner_Run_FullMethodName = "/runner.ComputationRunner/Run"
ComputationRunner_Stop_FullMethodName = "/runner.ComputationRunner/Stop"
)
// ComputationRunnerClient is the client API for ComputationRunner service.
//
// For semantics around ctx use and closing/ending streaming RPCs, please refer to https://pkg.go.dev/google.golang.org/grpc/?tab=doc#ClientConn.NewStream.
type ComputationRunnerClient interface {
Run(ctx context.Context, in *RunRequest, opts ...grpc.CallOption) (*RunResponse, error)
Stop(ctx context.Context, in *StopRequest, opts ...grpc.CallOption) (*emptypb.Empty, error)
}
type computationRunnerClient struct {
cc grpc.ClientConnInterface
}
func NewComputationRunnerClient(cc grpc.ClientConnInterface) ComputationRunnerClient {
return &computationRunnerClient{cc}
}
func (c *computationRunnerClient) Run(ctx context.Context, in *RunRequest, opts ...grpc.CallOption) (*RunResponse, error) {
cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
out := new(RunResponse)
err := c.cc.Invoke(ctx, ComputationRunner_Run_FullMethodName, in, out, cOpts...)
if err != nil {
return nil, err
}
return out, nil
}
func (c *computationRunnerClient) Stop(ctx context.Context, in *StopRequest, opts ...grpc.CallOption) (*emptypb.Empty, error) {
cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
out := new(emptypb.Empty)
err := c.cc.Invoke(ctx, ComputationRunner_Stop_FullMethodName, in, out, cOpts...)
if err != nil {
return nil, err
}
return out, nil
}
// ComputationRunnerServer is the server API for ComputationRunner service.
// All implementations must embed UnimplementedComputationRunnerServer
// for forward compatibility.
type ComputationRunnerServer interface {
Run(context.Context, *RunRequest) (*RunResponse, error)
Stop(context.Context, *StopRequest) (*emptypb.Empty, error)
mustEmbedUnimplementedComputationRunnerServer()
}
// UnimplementedComputationRunnerServer must be embedded to have
// forward compatible implementations.
//
// NOTE: this should be embedded by value instead of pointer to avoid a nil
// pointer dereference when methods are called.
type UnimplementedComputationRunnerServer struct{}
func (UnimplementedComputationRunnerServer) Run(context.Context, *RunRequest) (*RunResponse, error) {
return nil, status.Errorf(codes.Unimplemented, "method Run not implemented")
}
func (UnimplementedComputationRunnerServer) Stop(context.Context, *StopRequest) (*emptypb.Empty, error) {
return nil, status.Errorf(codes.Unimplemented, "method Stop not implemented")
}
func (UnimplementedComputationRunnerServer) mustEmbedUnimplementedComputationRunnerServer() {}
func (UnimplementedComputationRunnerServer) testEmbeddedByValue() {}
// UnsafeComputationRunnerServer may be embedded to opt out of forward compatibility for this service.
// Use of this interface is not recommended, as added methods to ComputationRunnerServer will
// result in compilation errors.
type UnsafeComputationRunnerServer interface {
mustEmbedUnimplementedComputationRunnerServer()
}
func RegisterComputationRunnerServer(s grpc.ServiceRegistrar, srv ComputationRunnerServer) {
// If the following call pancis, it indicates UnimplementedComputationRunnerServer was
// embedded by pointer and is nil. This will cause panics if an
// unimplemented method is ever invoked, so we test this at initialization
// time to prevent it from happening at runtime later due to I/O.
if t, ok := srv.(interface{ testEmbeddedByValue() }); ok {
t.testEmbeddedByValue()
}
s.RegisterService(&ComputationRunner_ServiceDesc, srv)
}
func _ComputationRunner_Run_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(RunRequest)
if err := dec(in); err != nil {
return nil, err
}
if interceptor == nil {
return srv.(ComputationRunnerServer).Run(ctx, in)
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: ComputationRunner_Run_FullMethodName,
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(ComputationRunnerServer).Run(ctx, req.(*RunRequest))
}
return interceptor(ctx, in, info, handler)
}
func _ComputationRunner_Stop_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(StopRequest)
if err := dec(in); err != nil {
return nil, err
}
if interceptor == nil {
return srv.(ComputationRunnerServer).Stop(ctx, in)
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: ComputationRunner_Stop_FullMethodName,
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(ComputationRunnerServer).Stop(ctx, req.(*StopRequest))
}
return interceptor(ctx, in, info, handler)
}
// ComputationRunner_ServiceDesc is the grpc.ServiceDesc for ComputationRunner service.
// It's only intended for direct use with grpc.RegisterService,
// and not to be introspected or modified (even as a copy)
var ComputationRunner_ServiceDesc = grpc.ServiceDesc{
ServiceName: "runner.ComputationRunner",
HandlerType: (*ComputationRunnerServer)(nil),
Methods: []grpc.MethodDesc{
{
MethodName: "Run",
Handler: _ComputationRunner_Run_Handler,
},
{
MethodName: "Stop",
Handler: _ComputationRunner_Stop_Handler,
},
},
Streams: []grpc.StreamDesc{},
Metadata: "agent/runner/runner.proto",
}
+141
View File
@@ -0,0 +1,141 @@
// Copyright (c) Ultraviolet
// SPDX-License-Identifier: Apache-2.0
package service
import (
"context"
"fmt"
"log/slog"
"os"
"path/filepath"
"sync"
"github.com/ultravioletrs/cocos/agent/algorithm"
"github.com/ultravioletrs/cocos/agent/algorithm/binary"
"github.com/ultravioletrs/cocos/agent/algorithm/docker"
"github.com/ultravioletrs/cocos/agent/algorithm/python"
"github.com/ultravioletrs/cocos/agent/algorithm/wasm"
"github.com/ultravioletrs/cocos/agent/events"
pb "github.com/ultravioletrs/cocos/agent/runner"
"google.golang.org/protobuf/types/known/emptypb"
)
const (
algoFilePermission = 0o700
)
var _ pb.ComputationRunnerServer = (*RunnerService)(nil)
type RunnerService struct {
pb.UnimplementedComputationRunnerServer
logger *slog.Logger
eventSvc events.Service
currentAlgo algorithm.Algorithm
mu sync.Mutex
}
func New(logger *slog.Logger, eventSvc events.Service) *RunnerService {
return &RunnerService{
logger: logger,
eventSvc: eventSvc,
}
}
func (s *RunnerService) Run(ctx context.Context, req *pb.RunRequest) (*pb.RunResponse, error) {
s.mu.Lock()
if s.currentAlgo != nil {
s.mu.Unlock()
return &pb.RunResponse{
ComputationId: req.ComputationId,
Error: "computation already running",
}, nil
}
s.mu.Unlock()
defer func() {
s.mu.Lock()
s.currentAlgo = nil
s.mu.Unlock()
}()
currentDir, err := os.Getwd()
if err != nil {
return nil, fmt.Errorf("error getting current directory: %v", err)
}
// Write Algo File
algoPath := filepath.Join(currentDir, "algo")
f, err := os.Create(algoPath)
if err != nil {
return nil, fmt.Errorf("error creating algorithm file: %v", err)
}
if _, err := f.Write(req.Algorithm); err != nil {
return nil, fmt.Errorf("error writing algorithm to file: %v", err)
}
if err := os.Chmod(algoPath, algoFilePermission); err != nil {
return nil, fmt.Errorf("error changing file permissions: %v", err)
}
if err := f.Close(); err != nil {
return nil, fmt.Errorf("error closing file: %v", err)
}
var algo algorithm.Algorithm
switch req.AlgoType {
case string(algorithm.AlgoTypeBin):
algo = binary.NewAlgorithm(s.logger, s.eventSvc, algoPath, req.Args, req.ComputationId)
case string(algorithm.AlgoTypePython):
var requirementsFile string
if len(req.Requirements) > 0 {
fr, err := os.CreateTemp("", "requirements.txt")
if err != nil {
return nil, fmt.Errorf("error creating requirments file: %v", err)
}
if _, err := fr.Write(req.Requirements); err != nil {
return nil, fmt.Errorf("error writing requirements to file: %v", err)
}
if err := fr.Close(); err != nil {
return nil, fmt.Errorf("error closing file: %v", err)
}
requirementsFile = fr.Name()
}
// Assuming default python runtime if not specified in request (proto doesn't have runtime field yet)
// We can add it or assume.
runtime := python.PyRuntime
algo = python.NewAlgorithm(s.logger, s.eventSvc, runtime, requirementsFile, algoPath, req.Args, req.ComputationId)
case string(algorithm.AlgoTypeWasm):
algo = wasm.NewAlgorithm(s.logger, s.eventSvc, req.Args, algoPath, req.ComputationId)
case string(algorithm.AlgoTypeDocker):
algo = docker.NewAlgorithm(s.logger, s.eventSvc, algoPath, req.ComputationId)
default:
return nil, fmt.Errorf("unsupported algorithm type: %s", req.AlgoType)
}
s.mu.Lock()
s.currentAlgo = algo
s.mu.Unlock()
if err := algo.Run(); err != nil {
s.logger.Error("computation failed", "error", err)
return &pb.RunResponse{
ComputationId: req.ComputationId,
Error: err.Error(),
}, nil
}
return &pb.RunResponse{
ComputationId: req.ComputationId,
}, nil
}
func (s *RunnerService) Stop(ctx context.Context, req *pb.StopRequest) (*emptypb.Empty, error) {
s.mu.Lock()
defer s.mu.Unlock()
if s.currentAlgo != nil {
if err := s.currentAlgo.Stop(); err != nil {
return nil, err
}
}
return &emptypb.Empty{}, nil
}
+271
View File
@@ -0,0 +1,271 @@
// Copyright (c) Ultraviolet
// SPDX-License-Identifier: Apache-2.0
package service
import (
"context"
"encoding/json"
"log/slog"
"os"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
pb "github.com/ultravioletrs/cocos/agent/runner"
)
// MockEventService is a mock implementation of events.Service.
type MockEventService struct {
events []interface{}
}
func (m *MockEventService) SendEvent(cmpID, event, status string, details json.RawMessage) {
m.events = append(m.events, map[string]interface{}{
"cmpID": cmpID,
"event": event,
"status": status,
"details": details,
})
}
// TestNewRunnerService tests the creation of a new runner service.
func TestNewRunnerService(t *testing.T) {
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
eventSvc := &MockEventService{}
rs := New(logger, eventSvc)
require.NotNil(t, rs)
assert.NotNil(t, rs.logger)
assert.NotNil(t, rs.eventSvc)
assert.Nil(t, rs.currentAlgo)
}
// TestRunWithBinaryAlgorithm tests running a binary algorithm.
func TestRunWithBinaryAlgorithm(t *testing.T) {
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
eventSvc := &MockEventService{}
rs := New(logger, eventSvc)
req := &pb.RunRequest{
ComputationId: "test-1",
AlgoType: "bin",
Algorithm: []byte("#!/bin/bash\necho 'test'"),
Args: []string{"arg1", "arg2"},
}
resp, err := rs.Run(context.Background(), req)
require.NoError(t, err)
require.NotNil(t, resp)
assert.Equal(t, "test-1", resp.ComputationId)
}
// TestRunWithPythonAlgorithm tests running a Python algorithm.
func TestRunWithPythonAlgorithm(t *testing.T) {
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
eventSvc := &MockEventService{}
rs := New(logger, eventSvc)
req := &pb.RunRequest{
ComputationId: "test-python",
AlgoType: "python",
Algorithm: []byte("print('hello')"),
Args: []string{},
Requirements: []byte("numpy==1.21.0"),
}
resp, err := rs.Run(context.Background(), req)
require.NoError(t, err)
require.NotNil(t, resp)
assert.Equal(t, "test-python", resp.ComputationId)
}
// TestRunWithPythonAlgorithmNoRequirements tests running Python without requirements.
func TestRunWithPythonAlgorithmNoRequirements(t *testing.T) {
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
eventSvc := &MockEventService{}
rs := New(logger, eventSvc)
req := &pb.RunRequest{
ComputationId: "test-python-noreq",
AlgoType: "python",
Algorithm: []byte("print('hello')"),
Args: []string{},
}
resp, err := rs.Run(context.Background(), req)
require.NoError(t, err)
require.NotNil(t, resp)
assert.Equal(t, "test-python-noreq", resp.ComputationId)
}
// TestRunWithWasmAlgorithm tests running a WASM algorithm.
func TestRunWithWasmAlgorithm(t *testing.T) {
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
eventSvc := &MockEventService{}
rs := New(logger, eventSvc)
req := &pb.RunRequest{
ComputationId: "test-wasm",
AlgoType: "wasm",
Algorithm: []byte{0x00, 0x61, 0x73, 0x6d},
Args: []string{},
}
resp, err := rs.Run(context.Background(), req)
require.NoError(t, err)
require.NotNil(t, resp)
assert.Equal(t, "test-wasm", resp.ComputationId)
}
// TestRunWithDockerAlgorithm tests running a Docker algorithm.
func TestRunWithDockerAlgorithm(t *testing.T) {
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
eventSvc := &MockEventService{}
rs := New(logger, eventSvc)
req := &pb.RunRequest{
ComputationId: "test-docker",
AlgoType: "docker",
Algorithm: []byte("FROM ubuntu:latest\nRUN echo 'test'"),
Args: []string{},
}
resp, err := rs.Run(context.Background(), req)
require.NoError(t, err)
require.NotNil(t, resp)
assert.Equal(t, "test-docker", resp.ComputationId)
}
// TestRunWithUnsupportedAlgorithmType tests running with unsupported algorithm type.
func TestRunWithUnsupportedAlgorithmType(t *testing.T) {
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
eventSvc := &MockEventService{}
rs := New(logger, eventSvc)
req := &pb.RunRequest{
ComputationId: "test-unsupported",
AlgoType: "unsupported",
Algorithm: []byte("test"),
Args: []string{},
}
resp, err := rs.Run(context.Background(), req)
require.Error(t, err)
require.Nil(t, resp)
}
// TestRunAlreadyRunning tests running computation when one is already running.
func TestRunAlreadyRunning(t *testing.T) {
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
eventSvc := &MockEventService{}
rs := New(logger, eventSvc)
// Use a long-running bash script
req := &pb.RunRequest{
ComputationId: "test-running",
AlgoType: "bin",
Algorithm: []byte("#!/bin/bash\nsleep 30"),
Args: []string{},
}
// Start first computation (will run for 30 seconds)
go func() {
_, _ = rs.Run(context.Background(), req)
}()
// Give it time to start
time.Sleep(500 * time.Millisecond)
// Try to run another immediately - should fail
resp, err := rs.Run(context.Background(), req)
require.NoError(t, err)
require.NotNil(t, resp)
assert.Equal(t, "computation already running", resp.Error)
}
// TestStopWhenRunning tests stopping a running computation.
func TestStopWhenRunning(t *testing.T) {
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
eventSvc := &MockEventService{}
rs := New(logger, eventSvc)
req := &pb.RunRequest{
ComputationId: "test-stop",
AlgoType: "bin",
Algorithm: []byte("#!/bin/bash\nsleep 10"),
Args: []string{},
}
_, err := rs.Run(context.Background(), req)
require.NoError(t, err)
stopReq := &pb.StopRequest{
ComputationId: "test-stop",
}
stopResp, err := rs.Stop(context.Background(), stopReq)
require.NoError(t, err)
require.NotNil(t, stopResp)
}
// TestStopWhenNotRunning tests stopping when no computation is running.
func TestStopWhenNotRunning(t *testing.T) {
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
eventSvc := &MockEventService{}
rs := New(logger, eventSvc)
stopReq := &pb.StopRequest{
ComputationId: "test-not-running",
}
stopResp, err := rs.Stop(context.Background(), stopReq)
require.NoError(t, err)
require.NotNil(t, stopResp)
}
// TestConcurrentRun tests that concurrent runs are properly serialized.
func TestConcurrentRun(t *testing.T) {
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
eventSvc := &MockEventService{}
rs := New(logger, eventSvc)
req := &pb.RunRequest{
ComputationId: "test-concurrent",
AlgoType: "bin",
Algorithm: []byte("#!/bin/bash\nsleep 15"),
Args: []string{},
}
// Start first run in goroutine (will run for 15 seconds)
go func() {
_, _ = rs.Run(context.Background(), req)
}()
// Give it time to actually start
time.Sleep(500 * time.Millisecond)
// Concurrent attempt should fail
resp2, err := rs.Run(context.Background(), req)
require.NoError(t, err)
assert.Equal(t, "computation already running", resp2.Error)
}
// TestRunWithMultipleArgs tests running with multiple arguments.
func TestRunWithMultipleArgs(t *testing.T) {
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
eventSvc := &MockEventService{}
rs := New(logger, eventSvc)
req := &pb.RunRequest{
ComputationId: "test-multi-args",
AlgoType: "bin",
Algorithm: []byte("#!/bin/bash\necho $@"),
Args: []string{"arg1", "arg2", "arg3", "arg4"},
}
resp, err := rs.Run(context.Background(), req)
require.NoError(t, err)
require.NotNil(t, resp)
assert.Equal(t, "test-multi-args", resp.ComputationId)
}
+52 -41
View File
@@ -16,17 +16,15 @@ import (
"github.com/absmach/supermq/pkg/errors"
"github.com/ultravioletrs/cocos/agent/algorithm"
"github.com/ultravioletrs/cocos/agent/algorithm/binary"
"github.com/ultravioletrs/cocos/agent/algorithm/docker"
"github.com/ultravioletrs/cocos/agent/algorithm/python"
"github.com/ultravioletrs/cocos/agent/algorithm/wasm"
"github.com/ultravioletrs/cocos/agent/events"
runnerpb "github.com/ultravioletrs/cocos/agent/runner"
"github.com/ultravioletrs/cocos/agent/statemachine"
"github.com/ultravioletrs/cocos/internal"
"github.com/ultravioletrs/cocos/pkg/attestation"
"github.com/ultravioletrs/cocos/pkg/attestation/quoteprovider"
"github.com/ultravioletrs/cocos/pkg/attestation/vtpm"
attestation_client "github.com/ultravioletrs/cocos/pkg/clients/grpc/attestation"
runner_client "github.com/ultravioletrs/cocos/pkg/clients/grpc/runner"
"golang.org/x/crypto/sha3"
)
@@ -130,8 +128,12 @@ type Service interface {
type agentService struct {
mu sync.Mutex
computation Computation // Holds the current computation request details.
algorithm algorithm.Algorithm // Filepath to the algorithm received for the computation.
computation Computation // Holds the current computation request details.
runnerClient runner_client.Client
algoType string
algoArgs []string
algoRequirements []byte
algoReceived bool
result []byte // Stores the result of the computation.
sm statemachine.StateMachine // Manages the state transitions of the agent service.
runError error // Stores any error encountered during the computation run.
@@ -146,13 +148,14 @@ type agentService struct {
var _ Service = (*agentService)(nil)
// New instantiates the agent service implementation.
func New(ctx context.Context, logger *slog.Logger, eventSvc events.Service, attestationClient attestation_client.Client, vmlp int) Service {
func New(ctx context.Context, logger *slog.Logger, eventSvc events.Service, attestationClient attestation_client.Client, runnerClient runner_client.Client, vmlp int) Service {
sm := statemachine.NewStateMachine(Idle)
ctx, cancel := context.WithCancel(ctx)
svc := &agentService{
sm: sm,
eventSvc: eventSvc,
attestationClient: attestationClient,
runnerClient: runnerClient,
logger: logger,
cancel: cancel,
vmpl: vmlp,
@@ -233,10 +236,9 @@ func (as *agentService) StopComputation(ctx context.Context) error {
as.cancel()
if as.algorithm != nil {
if err := as.algorithm.Stop(); err != nil {
return fmt.Errorf("error stopping computation: %v", err)
}
if _, err := as.runnerClient.Stop(ctx, &runnerpb.StopRequest{ComputationId: as.computation.ID}); err != nil {
as.logger.Warn("failed to stop runner", "error", err)
// proceed to cleanup
}
if err := os.RemoveAll(algorithm.DatasetsDir); err != nil {
@@ -250,7 +252,10 @@ func (as *agentService) StopComputation(ctx context.Context) error {
as.sm.Reset(Idle)
as.computation = Computation{}
as.algorithm = nil
as.algoReceived = false
as.algoType = ""
as.algoArgs = nil
as.algoRequirements = nil
as.result = nil
as.runError = nil
as.resultsConsumed = false
@@ -278,7 +283,7 @@ func (as *agentService) Algo(ctx context.Context, algo Algorithm) error {
}
as.mu.Lock()
defer as.mu.Unlock()
if as.algorithm != nil {
if as.algoReceived {
return ErrAllManifestItemsReceived
}
@@ -317,38 +322,16 @@ func (as *agentService) Algo(ctx context.Context, algo Algorithm) error {
args := algorithm.AlgorithmArgsFromContext(ctx)
switch algoType {
case string(algorithm.AlgoTypeBin):
as.algorithm = binary.NewAlgorithm(as.logger, as.eventSvc, f.Name(), args, as.computation.ID)
case string(algorithm.AlgoTypePython):
var requirementsFile string
if len(algo.Requirements) > 0 {
fr, err := os.CreateTemp("", "requirements.txt")
if err != nil {
return fmt.Errorf("error creating requirments file: %v", err)
}
if _, err := fr.Write(algo.Requirements); err != nil {
return fmt.Errorf("error writing requirements to file: %v", err)
}
if err := fr.Close(); err != nil {
return fmt.Errorf("error closing file: %v", err)
}
requirementsFile = fr.Name()
}
runtime := python.PythonRunTimeFromContext(ctx)
as.algorithm = python.NewAlgorithm(as.logger, as.eventSvc, runtime, requirementsFile, f.Name(), args, as.computation.ID)
case string(algorithm.AlgoTypeWasm):
as.algorithm = wasm.NewAlgorithm(as.logger, as.eventSvc, args, f.Name(), as.computation.ID)
case string(algorithm.AlgoTypeDocker):
as.algorithm = docker.NewAlgorithm(as.logger, as.eventSvc, f.Name(), as.computation.ID)
}
as.algoType = algoType
as.algoArgs = args
as.algoRequirements = algo.Requirements
as.algoReceived = true
if err := os.Mkdir(algorithm.DatasetsDir, 0o755); err != nil {
return fmt.Errorf("error creating datasets directory: %v", err)
}
if as.algorithm != nil {
if as.algoReceived {
as.sm.SendEvent(AlgorithmReceived)
}
@@ -478,14 +461,42 @@ func (as *agentService) runComputation(state statemachine.State) {
}
}()
// Read algo file
currentDir, _ := os.Getwd()
algoFile := filepath.Join(currentDir, "algo")
algoBytes, err := os.ReadFile(algoFile)
if err != nil {
as.runError = fmt.Errorf("failed to read algo file: %w", err)
as.logger.Warn(as.runError.Error())
as.publishEvent(Failed.String())(state)
return
}
as.publishEvent(InProgress.String())(state)
if err := as.algorithm.Run(); err != nil {
// Call Runner
resp, err := as.runnerClient.Run(context.Background(), &runnerpb.RunRequest{
ComputationId: as.computation.ID,
AlgoType: as.algoType,
Algorithm: algoBytes,
Requirements: as.algoRequirements,
Args: as.algoArgs,
// Datasets implicit on shared FS
})
if err != nil {
as.runError = err
as.logger.Warn(fmt.Sprintf("failed to run computation: %s", err.Error()))
as.publishEvent(Failed.String())(state)
return
}
if resp.Error != "" {
as.runError = errors.New(resp.Error)
as.logger.Warn(fmt.Sprintf("failed to run computation: %s", resp.Error))
as.publishEvent(Failed.String())(state)
return
}
results, err := internal.ZipDirectoryToMemory(algorithm.ResultsDir)
if err != nil {
as.runError = err
+46 -47
View File
@@ -18,16 +18,18 @@ import (
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
"github.com/ultravioletrs/cocos/agent/algorithm"
algomocks "github.com/ultravioletrs/cocos/agent/algorithm/mocks"
"github.com/ultravioletrs/cocos/agent/algorithm/python"
"github.com/ultravioletrs/cocos/agent/events/mocks"
runnerpb "github.com/ultravioletrs/cocos/agent/runner"
"github.com/ultravioletrs/cocos/agent/statemachine"
smmocks "github.com/ultravioletrs/cocos/agent/statemachine/mocks"
"github.com/ultravioletrs/cocos/pkg/attestation"
"github.com/ultravioletrs/cocos/pkg/attestation/quoteprovider"
"github.com/ultravioletrs/cocos/pkg/attestation/vtpm"
runnermocks "github.com/ultravioletrs/cocos/pkg/clients/grpc/runner/mocks"
"golang.org/x/crypto/sha3"
"google.golang.org/grpc/metadata"
"google.golang.org/protobuf/types/known/emptypb"
)
var (
@@ -123,7 +125,9 @@ func TestAlgo(t *testing.T) {
ctx, cancel := context.WithCancel(ctx)
defer cancel()
client := new(MockAttestationClient)
svc := New(ctx, mglog.NewMock(), events, client, 0)
runnerCli := new(runnermocks.Client)
runnerCli.On("Run", mock.Anything, mock.Anything).Return(&runnerpb.RunResponse{}, nil)
svc := New(ctx, mglog.NewMock(), events, client, runnerCli, 0)
err := svc.InitComputation(ctx, testComputation(t))
require.NoError(t, err)
@@ -217,7 +221,9 @@ func TestData(t *testing.T) {
defer cancel()
client := new(MockAttestationClient)
svc := New(ctx, mglog.NewMock(), events, client, 0)
runnerCli := new(runnermocks.Client)
runnerCli.On("Run", mock.Anything, mock.Anything).Return(&runnerpb.RunResponse{}, nil)
svc := New(ctx, mglog.NewMock(), events, client, runnerCli, 0)
err := svc.InitComputation(ctx, testComputation(t))
require.NoError(t, err)
@@ -294,6 +300,7 @@ func TestResult(t *testing.T) {
}
client := new(MockAttestationClient)
runnerCli := new(runnermocks.Client)
sm := new(smmocks.StateMachine)
sm.On("Start", ctx).Return(nil)
@@ -304,6 +311,7 @@ func TestResult(t *testing.T) {
sm: sm,
eventSvc: events,
attestationClient: client,
runnerClient: runnerCli,
computation: testComputation(t),
}
@@ -400,7 +408,8 @@ func TestAttestation(t *testing.T) {
}
defer getQuote.Unset()
svc := New(ctx, mglog.NewMock(), events, client, 0)
runnerCli := new(runnermocks.Client)
svc := New(ctx, mglog.NewMock(), events, client, runnerCli, 0)
time.Sleep(300 * time.Millisecond)
_, err := svc.Attestation(ctx, tc.reportData, tc.nonce, tc.platform)
assert.True(t, errors.Contains(err, tc.err), "expected %v, got %v", tc.err, err)
@@ -420,18 +429,7 @@ func TestAzureAttestationToken(t *testing.T) {
name: "Azure token fetch successful",
nonce: [32]byte{1, 2, 3}, // any test nonce
token: []byte("mockToken"),
err: nil, // fixed expectation as err was ErrAttestationType in original but logic suggests success if token returns? Wait, orig test had ErrAttestationType? Ah, maybe provider mock returns error.
// Re-reading original test:
// err: ErrAttestationType
// provider.On(...).Return(tc.token, tc.err)
// svc.AzureAttestationToken...
// In original code, AzureAttestationToken checked `attestation.CCPlatform() != attestation.Azure`.
// Since test runs on non-azure, it returns ErrAttestationType.
// My new client calls GetAzureToken. The logic for checking platform moved to attestation-service.
// So `agent` just calls the client.
// So here we should expect whatever the client returns.
// Mock client returns tc.err.
// If I want to test success, I should set err: nil.
err: nil,
},
{
name: "Azure token fetch failed",
@@ -450,7 +448,8 @@ func TestAzureAttestationToken(t *testing.T) {
ctx := context.Background()
svc := New(ctx, mglog.NewMock(), events, client, 0)
runnerCli := new(runnermocks.Client)
svc := New(ctx, mglog.NewMock(), events, client, runnerCli, 0)
_, err := svc.AzureAttestationToken(ctx, tc.nonce)
assert.True(t, errors.Contains(err, tc.err), "expected error %v, got %v", tc.err, err)
@@ -489,9 +488,6 @@ func testComputation(t *testing.T) Computation {
}
func TestStopComputation(t *testing.T) {
testDataDir := "test_datasets"
testResultsDir := "test_results"
cases := []struct {
name string
setupDirs bool
@@ -511,22 +507,9 @@ func TestStopComputation(t *testing.T) {
setupDirs: true,
setupAlgo: true,
algoStopErr: fmt.Errorf("algorithm stop failed"),
expectedErr: fmt.Errorf("error stopping computation: algorithm stop failed"),
},
{
name: "Stop computation without algorithm",
setupDirs: true,
setupAlgo: false,
algoStopErr: nil,
expectedErr: nil,
},
{
name: "Stop computation with missing directories",
setupDirs: false,
setupAlgo: false,
algoStopErr: nil,
expectedErr: nil, // os.RemoveAll doesn't error on non-existing directories
expectedErr: nil, // Warn only
},
// We log warnings but don't return error in StopComputation in new implementation for Stop failure.
}
for _, tc := range cases {
@@ -539,7 +522,16 @@ func TestStopComputation(t *testing.T) {
defer cancel()
client := new(MockAttestationClient)
svc := New(ctx, mglog.NewMock(), events, client, 0).(*agentService)
runnerCli := new(runnermocks.Client)
// Mock Stop call
var stopErr error
if tc.algoStopErr != nil {
stopErr = tc.algoStopErr
}
runnerCli.On("Stop", mock.Anything, mock.Anything).Return(&emptypb.Empty{}, stopErr)
svc := New(ctx, mglog.NewMock(), events, client, runnerCli, 0).(*agentService)
svc.computation = Computation{
ID: "test-computation",
@@ -547,17 +539,17 @@ func TestStopComputation(t *testing.T) {
}
if tc.setupDirs {
err := os.MkdirAll(testDataDir, 0o755)
err := os.MkdirAll(algorithm.DatasetsDir, 0o755)
require.NoError(t, err)
err = os.MkdirAll(testResultsDir, 0o755)
err = os.MkdirAll(algorithm.ResultsDir, 0o755)
require.NoError(t, err)
}
if tc.setupAlgo {
mockAlgo := new(algomocks.Algorithm)
mockAlgo.On("Stop").Return(tc.algoStopErr)
svc.algorithm = mockAlgo
}
// Use real dirs for test
// algorithm.DatasetsDir refers to global var?
// "github.com/ultravioletrs/cocos/agent/algorithm"
// It uses hardcoded path "datasets" and "results" in current dir.
// Tests create them in current dir.
err := svc.StopComputation(ctx)
@@ -575,8 +567,8 @@ func TestStopComputation(t *testing.T) {
events.AssertExpectations(t)
_ = os.RemoveAll(testDataDir)
_ = os.RemoveAll(testResultsDir)
_ = os.RemoveAll(algorithm.DatasetsDir)
_ = os.RemoveAll(algorithm.ResultsDir)
})
}
}
@@ -608,7 +600,11 @@ func TestStopComputationIntegration(t *testing.T) {
defer cancel()
client := new(MockAttestationClient)
svc := New(ctx, mglog.NewMock(), events, client, 0)
runnerCli := new(runnermocks.Client)
runnerCli.On("Run", mock.Anything, mock.Anything).Return(&runnerpb.RunResponse{}, nil)
runnerCli.On("Stop", mock.Anything, mock.Anything).Return(&emptypb.Empty{}, nil)
svc := New(ctx, mglog.NewMock(), events, client, runnerCli, 0)
computation := Computation{
ID: "integration-test",
@@ -647,7 +643,10 @@ func TestStopComputationConcurrent(t *testing.T) {
defer cancel()
client := new(MockAttestationClient)
svc := New(ctx, mglog.NewMock(), events, client, 0)
runnerCli := new(runnermocks.Client)
runnerCli.On("Stop", mock.Anything, mock.Anything).Return(&emptypb.Empty{}, nil)
svc := New(ctx, mglog.NewMock(), events, client, runnerCli, 0)
svc.(*agentService).computation = Computation{
ID: "concurrent-test",