NOISSUE - Simplify manager to vm provision only (#353)

* new agent structure

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* fix lint

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* fix tests

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* cvm tests fix

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* fix test

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* manager server, for vm provisioning

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* fix lint

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* add cli and test

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* restore result cli

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* fix tests

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* fix failing tests

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* fix failing test

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* refactor: remove context from docker struct and use local context in Run method

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* delete: remove unused gRPC API and related server implementation

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

---------

Signed-off-by: Sammy Oina <sammyoina@gmail.com>
This commit is contained in:
Sammy Kerata Oina
2025-01-20 13:56:18 +03:00
committed by GitHub
parent ecad6514f3
commit 1f32f516b0
36 changed files with 625 additions and 4532 deletions
+68
View File
@@ -0,0 +1,68 @@
// Copyright (c) Ultraviolet
// SPDX-License-Identifier: Apache-2.0
package cli
import (
"github.com/fatih/color"
"github.com/spf13/cobra"
"github.com/ultravioletrs/cocos/manager"
"google.golang.org/protobuf/types/known/emptypb"
)
func (c *CLI) NewCreateVMCmd() *cobra.Command {
return &cobra.Command{
Use: "create-vm",
Short: "Create a new virtual machine",
Example: `create-vm`,
Args: cobra.ExactArgs(0),
Run: func(cmd *cobra.Command, args []string) {
if err := c.InitializeManagerClient(cmd); err == nil {
defer c.Close()
}
if c.connectErr != nil {
printError(cmd, "Failed to connect to manager: %v ❌ ", c.connectErr)
return
}
cmd.Println("🔗 Creating a new virtual machine")
res, err := c.managerClient.CreateVm(cmd.Context(), &emptypb.Empty{})
if err != nil {
printError(cmd, "Error creating virtual machine: %v ❌ ", err)
return
}
cmd.Println(color.New(color.FgGreen).Sprintf("✅ Virtual machine created successfully with id %s and port %s", res.SvmId, res.ForwardedPort))
},
}
}
func (c *CLI) NewRemoveVMCmd() *cobra.Command {
return &cobra.Command{
Use: "remove-vm",
Short: "Remove a virtual machine",
Example: `remove-vm <svm_id>`,
Args: cobra.ExactArgs(1),
Run: func(cmd *cobra.Command, args []string) {
if err := c.InitializeManagerClient(cmd); err == nil {
defer c.Close()
}
if c.connectErr != nil {
printError(cmd, "Failed to connect to manager: %v ❌ ", c.connectErr)
return
}
cmd.Println("🔗 Removing virtual machine")
_, err := c.managerClient.RemoveVm(cmd.Context(), &manager.RemoveReq{SvmId: args[0]})
if err != nil {
printError(cmd, "Error removing virtual machine: %v ❌ ", err)
return
}
cmd.Println(color.New(color.FgGreen).Sprintf("✅ Virtual machine removed successfully"))
},
}
}
+27 -8
View File
@@ -6,28 +6,33 @@ import (
"context"
"github.com/spf13/cobra"
"github.com/ultravioletrs/cocos/manager"
"github.com/ultravioletrs/cocos/pkg/clients/grpc"
"github.com/ultravioletrs/cocos/pkg/clients/grpc/agent"
managergrpc "github.com/ultravioletrs/cocos/pkg/clients/grpc/manager"
"github.com/ultravioletrs/cocos/pkg/sdk"
)
var Verbose bool
type CLI struct {
agentSDK sdk.SDK
config grpc.AgentClientConfig
client grpc.Client
connectErr error
agentSDK sdk.SDK
agentConfig grpc.AgentClientConfig
managerConfig grpc.ManagerClientConfig
client grpc.Client
managerClient manager.ManagerServiceClient
connectErr error
}
func New(config grpc.AgentClientConfig) *CLI {
func New(agentConfig grpc.AgentClientConfig, managerConfig grpc.ManagerClientConfig) *CLI {
return &CLI{
config: config,
agentConfig: agentConfig,
managerConfig: managerConfig,
}
}
func (c *CLI) InitializeSDK(cmd *cobra.Command) error {
agentGRPCClient, agentClient, err := agent.NewAgentClient(context.Background(), c.config)
func (c *CLI) InitializeAgentSDK(cmd *cobra.Command) error {
agentGRPCClient, agentClient, err := agent.NewAgentClient(context.Background(), c.agentConfig)
if err != nil {
c.connectErr = err
return err
@@ -39,6 +44,20 @@ func (c *CLI) InitializeSDK(cmd *cobra.Command) error {
return nil
}
func (c *CLI) InitializeManagerClient(cmd *cobra.Command) error {
managerGRPCClient, managerClient, err := managergrpc.NewManagerClient(c.managerConfig)
if err != nil {
c.connectErr = err
return err
}
cmd.Println("🔗 Connected to manager using ", managerGRPCClient.Secure())
c.client = managerGRPCClient
c.managerClient = managerClient
return nil
}
func (c *CLI) Close() {
c.client.Close()
}
+17 -7
View File
@@ -19,11 +19,12 @@ import (
)
const (
svcName = "cli"
envPrefixAgentGRPC = "AGENT_GRPC_"
completion = "completion"
filePermision = 0o755
cocosDirectory = ".cocos"
svcName = "cli"
envPrefixAgentGRPC = "AGENT_GRPC_"
envPrefixManagerGRPC = "MANAGER_GRPC_"
completion = "completion"
filePermision = 0o755
cocosDirectory = ".cocos"
)
type config struct {
@@ -98,9 +99,16 @@ func main() {
return
}
cliSVC := cli.New(agentGRPCConfig)
managerGRPCConfig := grpc.ManagerClientConfig{}
if err := env.ParseWithOptions(&managerGRPCConfig, env.Options{Prefix: envPrefixManagerGRPC}); err != nil {
message := color.New(color.FgRed).Sprintf("failed to load %s gRPC client configuration : %s", svcName, err)
rootCmd.Println(message)
return
}
if err := cliSVC.InitializeSDK(rootCmd); err == nil {
cliSVC := cli.New(agentGRPCConfig, managerGRPCConfig)
if err := cliSVC.InitializeAgentSDK(rootCmd); err == nil {
defer cliSVC.Close()
}
@@ -119,6 +127,8 @@ func main() {
rootCmd.AddCommand(attestationPolicyCmd)
rootCmd.AddCommand(keysCmd)
rootCmd.AddCommand(cliSVC.NewCABundleCmd(directoryCachePath))
rootCmd.AddCommand(cliSVC.NewCreateVMCmd())
rootCmd.AddCommand(cliSVC.NewRemoveVMCmd())
// Attestation commands
attestationCmd.AddCommand(cliSVC.NewGetAttestationCmd())
+15 -47
View File
@@ -10,25 +10,24 @@ import (
"log/slog"
"net/url"
"os"
"os/signal"
"strings"
"syscall"
mglog "github.com/absmach/magistrala/logger"
"github.com/absmach/magistrala/pkg/jaeger"
"github.com/absmach/magistrala/pkg/prometheus"
"github.com/absmach/magistrala/pkg/uuid"
"github.com/caarlos0/env/v11"
"github.com/ultravioletrs/cocos/internal/server"
grpcserver "github.com/ultravioletrs/cocos/internal/server/grpc"
"github.com/ultravioletrs/cocos/manager"
"github.com/ultravioletrs/cocos/manager/api"
managerapi "github.com/ultravioletrs/cocos/manager/api/grpc"
"github.com/ultravioletrs/cocos/manager/events"
managergrpc "github.com/ultravioletrs/cocos/manager/api/grpc"
"github.com/ultravioletrs/cocos/manager/qemu"
"github.com/ultravioletrs/cocos/manager/tracing"
pkggrpc "github.com/ultravioletrs/cocos/pkg/clients/grpc"
managergrpc "github.com/ultravioletrs/cocos/pkg/clients/grpc/manager"
"go.opentelemetry.io/otel/trace"
"golang.org/x/sync/errgroup"
"google.golang.org/grpc"
"google.golang.org/grpc/reflection"
)
const (
@@ -92,64 +91,33 @@ func main() {
args := qemuCfg.ConstructQemuArgs()
logger.Info(strings.Join(args, " "))
managerGRPCConfig := pkggrpc.CVMClientConfig{}
managerGRPCConfig := server.ServerConfig{}
if err := env.ParseWithOptions(&managerGRPCConfig, env.Options{Prefix: envPrefixGRPC}); err != nil {
logger.Error(fmt.Sprintf("failed to load %s gRPC client configuration : %s", svcName, err))
exitCode = 1
return
}
managerGRPCClient, managerClient, err := managergrpc.NewManagerClient(managerGRPCConfig)
if err != nil {
logger.Error(err.Error())
exitCode = 1
return
}
defer managerGRPCClient.Close()
pc, err := managerClient.Process(ctx)
svc, err := newService(logger, tracer, qemuCfg, cfg.AttestationPolicyBinary, cfg.EosVersion)
if err != nil {
logger.Error(err.Error())
exitCode = 1
return
}
eventsChan := make(chan *manager.ClientStreamMessage, clientBufferSize)
svc, err := newService(logger, tracer, qemuCfg, eventsChan, cfg.AttestationPolicyBinary, cfg.EosVersion)
if err != nil {
logger.Error(err.Error())
exitCode = 1
return
registerManagerServiceServer := func(srv *grpc.Server) {
reflection.Register(srv)
manager.RegisterManagerServiceServer(srv, managergrpc.NewServer(svc))
}
eventsSvc, err := events.New(logger, svc.ReportBrokenConnection, eventsChan)
if err != nil {
logger.Error(err.Error())
exitCode = 1
return
}
go eventsSvc.Listen(ctx)
mc := managerapi.NewClient(pc, svc, eventsChan, logger)
gs := grpcserver.New(ctx, cancel, svcName, managerGRPCConfig, registerManagerServiceServer, logger, nil, nil)
g.Go(func() error {
ch := make(chan os.Signal, 1)
signal.Notify(ch, syscall.SIGINT, syscall.SIGTERM)
defer signal.Stop(ch)
select {
case <-ch:
logger.Info("Received signal, shutting down...")
cancel()
return nil
case <-ctx.Done():
return ctx.Err()
}
return gs.Start()
})
g.Go(func() error {
return mc.Process(ctx, cancel)
return server.StopHandler(ctx, cancel, logger, svcName, gs)
})
if err := g.Wait(); err != nil {
@@ -157,8 +125,8 @@ func main() {
}
}
func newService(logger *slog.Logger, tracer trace.Tracer, qemuCfg qemu.Config, eventsChan chan *manager.ClientStreamMessage, attestationPolicyPath string, eosVersion string) (manager.Service, error) {
svc, err := manager.New(qemuCfg, attestationPolicyPath, logger, eventsChan, qemu.NewVM, eosVersion)
func newService(logger *slog.Logger, tracer trace.Tracer, qemuCfg qemu.Config, attestationPolicyPath string, eosVersion string) (manager.Service, error) {
svc, err := manager.New(qemuCfg, attestationPolicyPath, logger, qemu.NewVM, eosVersion)
if err != nil {
return nil, err
}
+2 -2
View File
@@ -5,7 +5,6 @@ go 1.23.0
require (
github.com/absmach/magistrala v0.15.1
github.com/caarlos0/env/v11 v11.2.2
github.com/cenkalti/backoff/v4 v4.3.0
github.com/fatih/color v1.18.0
github.com/go-kit/kit v0.13.0
github.com/gofrs/uuid v4.4.0+incompatible
@@ -25,6 +24,7 @@ require (
require (
github.com/Microsoft/go-winio v0.6.2 // indirect
github.com/cenkalti/backoff/v4 v4.3.0 // indirect
github.com/containerd/log v0.1.0 // indirect
github.com/distribution/reference v0.6.0 // indirect
github.com/docker/go-connections v0.5.0 // indirect
@@ -59,7 +59,7 @@ require (
github.com/golang/protobuf v1.5.4 // indirect
github.com/google/go-configfs-tsm v0.2.2 // indirect
github.com/google/logger v1.1.1
github.com/google/uuid v1.6.0 // indirect
github.com/google/uuid v1.6.0
github.com/grpc-ecosystem/grpc-gateway/v2 v2.23.0 // indirect
github.com/inconshreveable/mousetrap v1.1.0 // indirect
github.com/mdlayher/socket v0.4.1 // indirect
-247
View File
@@ -1,247 +0,0 @@
// Copyright (c) Ultraviolet
// SPDX-License-Identifier: Apache-2.0
package vsock
import (
"context"
"encoding/binary"
"fmt"
"io"
"log"
"net"
"sync"
"sync/atomic"
"time"
"google.golang.org/protobuf/proto"
)
const (
maxRetries = 3
retryDelay = time.Second
maxMessageSize = 1 << 20 // 1 MB
ackTimeout = 5 * time.Second
maxConcurrent = 100
)
type MessageStatus int
const (
StatusPending MessageStatus = iota
StatusSent
StatusAcknowledged
StatusFailed
)
type Message struct {
ID uint32
Content []byte
Status MessageStatus
Retries int
}
type AckWriter struct {
conn net.Conn
pendingMessages chan *Message
messageStore sync.Map // map[uint32]*Message
nextID uint32
ctx context.Context
cancel context.CancelFunc
wg sync.WaitGroup
}
func NewAckWriter(conn net.Conn) io.WriteCloser {
ctx, cancel := context.WithCancel(context.Background())
aw := &AckWriter{
conn: conn,
pendingMessages: make(chan *Message, maxConcurrent),
nextID: 1,
ctx: ctx,
cancel: cancel,
}
aw.wg.Add(2)
go aw.sendMessages()
go aw.handleAcknowledgments()
return aw
}
func (aw *AckWriter) Write(p []byte) (int, error) {
if len(p) > maxMessageSize {
return 0, fmt.Errorf("message size exceeds maximum allowed size of %d bytes", maxMessageSize)
}
messageID := atomic.AddUint32(&aw.nextID, 1)
message := &Message{
ID: messageID,
Content: make([]byte, len(p)),
Status: StatusPending,
}
copy(message.Content, p)
aw.messageStore.Store(messageID, message)
select {
case aw.pendingMessages <- message:
return len(p), nil
case <-aw.ctx.Done():
return 0, fmt.Errorf("writer is closed")
}
}
func (aw *AckWriter) sendMessages() {
defer aw.wg.Done()
for {
select {
case <-aw.ctx.Done():
return
case msg := <-aw.pendingMessages:
if err := aw.sendWithRetry(msg); err != nil {
log.Printf("Failed to send message %d after all retries: %v", msg.ID, err)
msg.Status = StatusFailed
aw.messageStore.Store(msg.ID, msg)
}
}
}
}
func (aw *AckWriter) sendWithRetry(msg *Message) error {
for msg.Retries < maxRetries {
if err := aw.writeMessage(msg.ID, msg.Content); err != nil {
msg.Retries++
msg.Status = StatusPending
log.Printf("Error writing message %d (attempt %d): %v", msg.ID, msg.Retries, err)
time.Sleep(retryDelay)
continue
}
msg.Status = StatusSent
aw.messageStore.Store(msg.ID, msg)
return nil
}
return fmt.Errorf("max retries reached")
}
func (aw *AckWriter) writeMessage(messageID uint32, p []byte) error {
if err := binary.Write(aw.conn, binary.LittleEndian, messageID); err != nil {
return fmt.Errorf("failed to write message ID: %w", err)
}
messageLen := uint32(len(p))
if err := binary.Write(aw.conn, binary.LittleEndian, messageLen); err != nil {
return fmt.Errorf("failed to write message length: %w", err)
}
if _, err := aw.conn.Write(p); err != nil {
return fmt.Errorf("failed to write message content: %w", err)
}
return nil
}
func (aw *AckWriter) handleAcknowledgments() {
defer aw.wg.Done()
for {
select {
case <-aw.ctx.Done():
return
default:
var ackID uint32
if err := binary.Read(aw.conn, binary.LittleEndian, &ackID); err != nil {
if err == io.EOF {
log.Println("Connection closed, stopping acknowledgment handler")
return
}
log.Printf("Error reading ACK: %v", err)
time.Sleep(retryDelay)
continue
}
if msg, ok := aw.messageStore.Load(ackID); ok {
m := msg.(*Message)
m.Status = StatusAcknowledged
aw.messageStore.Store(ackID, m)
// Clean up old messages periodically
go aw.cleanupOldMessages(ackID)
} else {
log.Printf("Received ACK for unknown message ID: %d", ackID)
}
}
}
}
func (aw *AckWriter) cleanupOldMessages(currentID uint32) {
aw.messageStore.Range(func(key, value interface{}) bool {
msgID := key.(uint32)
msg := value.(*Message)
// Clean up acknowledged messages that are old
if msg.Status == StatusAcknowledged && msgID < currentID-maxConcurrent {
aw.messageStore.Delete(msgID)
}
return true
})
}
func (aw *AckWriter) Close() error {
aw.cancel()
aw.wg.Wait()
return aw.conn.Close()
}
type Reader interface {
Read() ([]byte, error)
ReadProto(msg proto.Message) error
}
type AckReader struct {
conn net.Conn
ctx context.Context
}
func NewAckReader(conn net.Conn) Reader {
return &AckReader{
conn: conn,
ctx: context.Background(),
}
}
func (ar *AckReader) ReadProto(msg proto.Message) error {
data, err := ar.Read()
if err != nil {
return fmt.Errorf("failed to read proto message: %w", err)
}
return proto.Unmarshal(data, msg)
}
func (ar *AckReader) Read() ([]byte, error) {
var messageID uint32
if err := binary.Read(ar.conn, binary.LittleEndian, &messageID); err != nil {
return nil, fmt.Errorf("error reading message ID: %w", err)
}
var messageLen uint32
if err := binary.Read(ar.conn, binary.LittleEndian, &messageLen); err != nil {
return nil, fmt.Errorf("error reading message length: %w", err)
}
if messageLen > maxMessageSize {
return nil, fmt.Errorf("message size %d exceeds maximum allowed size of %d bytes", messageLen, maxMessageSize)
}
data := make([]byte, messageLen)
if _, err := io.ReadFull(ar.conn, data); err != nil {
return nil, fmt.Errorf("error reading message content: %w", err)
}
if err := ar.sendAck(messageID); err != nil {
return nil, fmt.Errorf("error sending ACK: %w", err)
}
return data, nil
}
func (ar *AckReader) sendAck(messageID uint32) error {
return binary.Write(ar.conn, binary.LittleEndian, messageID)
}
-337
View File
@@ -1,337 +0,0 @@
// Copyright (c) Ultraviolet
// SPDX-License-Identifier: Apache-2.0
package vsock
import (
"bytes"
"encoding/binary"
"errors"
"fmt"
"io"
"net"
"sync"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/ultravioletrs/cocos/manager"
"google.golang.org/protobuf/proto"
)
// MockConn implements net.Conn for testing purposes.
type MockConn struct {
ReadData []byte
WrittenData []byte
ReadErr error
WriteErr error
closed bool
mu sync.Mutex
}
func (m *MockConn) Read(b []byte) (n int, err error) {
m.mu.Lock()
defer m.mu.Unlock()
if m.closed {
return 0, io.EOF
}
if len(m.ReadData) == 0 {
return 0, io.EOF // Ensure we handle this case more predictably
}
if m.ReadErr != nil {
return 0, m.ReadErr
}
n = copy(b, m.ReadData)
m.ReadData = m.ReadData[n:]
return n, nil
}
func (m *MockConn) Write(b []byte) (n int, err error) {
m.mu.Lock()
defer m.mu.Unlock()
if m.closed {
return 0, errors.New("connection closed")
}
if m.WriteErr != nil {
return 0, m.WriteErr
}
m.WrittenData = append(m.WrittenData, b...)
return len(b), nil
}
func (m *MockConn) Close() error {
m.mu.Lock()
defer m.mu.Unlock()
m.closed = true
return nil
}
// Implement other net.Conn methods with empty implementations.
func (m *MockConn) LocalAddr() net.Addr { return nil }
func (m *MockConn) RemoteAddr() net.Addr { return nil }
func (m *MockConn) SetDeadline(t time.Time) error { return nil }
func (m *MockConn) SetReadDeadline(t time.Time) error { return nil }
func (m *MockConn) SetWriteDeadline(t time.Time) error { return nil }
func TestAckReader_Read(t *testing.T) {
tests := []struct {
name string
data []byte
wantErr bool
}{
{"Valid message", []byte("Hello, World!"), false},
{"Empty message", []byte{}, false},
{"Message at max size", make([]byte, maxMessageSize), false},
{"Message exceeds max size", make([]byte, maxMessageSize+1), true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
mockConn := &MockConn{}
ar := NewAckReader(mockConn)
// Prepare mock data
messageID := uint32(1)
messageLen := uint32(len(tt.data))
mockData := make([]byte, 8+len(tt.data))
binary.LittleEndian.PutUint32(mockData[:4], messageID)
binary.LittleEndian.PutUint32(mockData[4:8], messageLen)
copy(mockData[8:], tt.data)
mockConn.ReadData = mockData
data, err := ar.Read()
if (err != nil) != tt.wantErr {
t.Errorf("AckReader.Read() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !tt.wantErr {
if !bytes.Equal(data, tt.data) {
t.Errorf("AckReader.Read() got = %v, want %v", data, tt.data)
}
// Check if ACK was sent
if len(mockConn.WrittenData) != 4 {
t.Errorf("AckReader.Read() did not send ACK")
} else {
ackID := binary.LittleEndian.Uint32(mockConn.WrittenData)
if ackID != messageID {
t.Errorf("AckReader.Read() sent wrong ACK ID, got %d, want %d", ackID, messageID)
}
}
}
})
}
}
func TestAckReader_ReadProto(t *testing.T) {
tests := []struct {
name string
msg *manager.ClientStreamMessage
wantErr bool
}{
{"Valid proto message", &manager.ClientStreamMessage{}, false},
{"Empty proto message", &manager.ClientStreamMessage{}, false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
mockConn := &MockConn{}
ar := NewAckReader(mockConn)
// Prepare mock data
protoData, _ := proto.Marshal(tt.msg)
messageID := uint32(1)
messageLen := uint32(len(protoData))
mockData := make([]byte, 8+len(protoData))
binary.LittleEndian.PutUint32(mockData[:4], messageID)
binary.LittleEndian.PutUint32(mockData[4:8], messageLen)
copy(mockData[8:], protoData)
mockConn.ReadData = mockData
receivedMsg := &manager.ClientStreamMessage{}
err := ar.ReadProto(receivedMsg)
if (err != nil) != tt.wantErr {
t.Errorf("AckReader.ReadProto() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !tt.wantErr {
if receivedMsg.Message != tt.msg.Message {
t.Errorf("AckReader.ReadProto() got = %v, want %v", receivedMsg, tt.msg)
}
// Check if ACK was sent
if len(mockConn.WrittenData) != 4 {
t.Errorf("AckReader.ReadProto() did not send ACK")
} else {
ackID := binary.LittleEndian.Uint32(mockConn.WrittenData)
if ackID != messageID {
t.Errorf("AckReader.ReadProto() sent wrong ACK ID, got %d, want %d", ackID, messageID)
}
}
}
})
}
}
func TestNewAckWriter(t *testing.T) {
mockConn := &MockConn{}
writer := NewAckWriter(mockConn)
if _, ok := writer.(io.Writer); !ok {
t.Errorf("NewAckWriter() did not return an io.Writer")
}
}
func TestNewAckReader(t *testing.T) {
mockConn := &MockConn{}
reader := NewAckReader(mockConn)
assert.NotNil(t, reader)
}
func TestAckWriter_Close(t *testing.T) {
mockConn := &MockConn{}
aw := NewAckWriter(mockConn)
err := aw.Close()
if err != nil {
t.Errorf("AckWriter.Close() error = %v, wantErr %v", err, nil)
}
if !mockConn.closed {
t.Errorf("AckWriter.Close() did not close the connection")
}
}
func TestAckWriter_Write(t *testing.T) {
tests := []struct {
name string
input []byte
expectErr bool
expectedError string
}{
{
name: "Message exceeds max size",
input: make([]byte, maxMessageSize+1),
expectErr: true,
expectedError: "message size exceeds maximum allowed size",
},
{
name: "Write succeeds",
input: []byte("Hello, world!"),
expectErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
mockConn := &MockConn{
mu: sync.Mutex{},
}
writer := NewAckWriter(mockConn)
defer writer.Close()
if tt.expectErr {
writer.(*AckWriter).ctx.Done()
}
time.Sleep(100 * time.Millisecond)
n, err := writer.Write(tt.input)
if tt.expectErr {
assert.Error(t, err)
if tt.expectedError != "" {
assert.Contains(t, err.Error(), tt.expectedError)
}
assert.Zero(t, n)
} else {
assert.NoError(t, err)
assert.Equal(t, len(tt.input), n)
}
})
}
}
func TestAckWriter_CleanupOldMessages(t *testing.T) {
mockConn := &MockConn{}
writer := NewAckWriter(mockConn).(*AckWriter)
defer writer.Close()
for i := uint32(1); i <= maxConcurrent+10; i++ {
msg := &Message{
ID: i,
Content: []byte("test"),
Status: StatusAcknowledged,
}
writer.messageStore.Store(i, msg)
}
writer.cleanupOldMessages(maxConcurrent + 11)
var count int
writer.messageStore.Range(func(key, value interface{}) bool {
count++
return true
})
assert.LessOrEqual(t, count, maxConcurrent)
}
func TestAckReader_LargeMessage(t *testing.T) {
mockConn := &MockConn{}
reader := NewAckReader(mockConn)
largeMessage := make([]byte, maxMessageSize-1)
for i := range largeMessage {
largeMessage[i] = byte(i % 256)
}
messageID := uint32(1)
messageLen := uint32(len(largeMessage))
mockData := make([]byte, 8+len(largeMessage))
binary.LittleEndian.PutUint32(mockData[:4], messageID)
binary.LittleEndian.PutUint32(mockData[4:8], messageLen)
copy(mockData[8:], largeMessage)
mockConn.ReadData = mockData
data, err := reader.Read()
assert.NoError(t, err)
assert.Equal(t, largeMessage, data)
assert.Equal(t, 4, len(mockConn.WrittenData))
ackID := binary.LittleEndian.Uint32(mockConn.WrittenData)
assert.Equal(t, messageID, ackID)
}
func TestAckWriter_FailedSends(t *testing.T) {
mockConn := &MockConn{
WriteErr: errors.New("write error"),
}
writer := NewAckWriter(mockConn).(*AckWriter)
defer writer.Close()
// Add some messages to the channel
for i := 0; i < 5; i++ {
msg := &Message{
ID: uint32(i + 1),
Content: []byte(fmt.Sprintf("Message %d", i+1)),
Status: StatusPending,
}
writer.pendingMessages <- msg
}
// Wait for the messages to be sent
time.Sleep(100 * time.Millisecond)
// Check that the messages were marked as failed
writer.messageStore.Range(func(key, value interface{}) bool {
msg := value.(*Message)
assert.Equal(t, StatusFailed, msg.Status)
return true
})
}
-66
View File
@@ -1,66 +0,0 @@
// Copyright (c) Ultraviolet
// SPDX-License-Identifier: Apache-2.0
package manager
import (
"fmt"
"regexp"
"strconv"
"github.com/ultravioletrs/cocos/pkg/manager"
"google.golang.org/protobuf/types/known/timestamppb"
)
var (
errFailedToParseCID = fmt.Errorf("failed to parse computation ID")
errComputationNotFound = fmt.Errorf("computation not found")
)
func (ms *managerService) computationIDFromAddress(address string) (string, error) {
re := regexp.MustCompile(`vm\((\d+)\)`)
matches := re.FindStringSubmatch(address)
if len(matches) > 1 {
cid, err := strconv.Atoi(matches[1])
if err != nil {
return "", err
}
return ms.findComputationID(cid)
}
return "", errFailedToParseCID
}
func (ms *managerService) findComputationID(cid int) (string, error) {
ms.mu.Lock()
defer ms.mu.Unlock()
for cmpID, vm := range ms.vms {
if vm.GetCID() == cid {
return cmpID, nil
}
}
return "", errComputationNotFound
}
func (ms *managerService) reportBrokenConnection(cmpID string) {
ms.eventsChan <- &ClientStreamMessage{
Message: &ClientStreamMessage_AgentEvent{
AgentEvent: &AgentEvent{
EventType: ms.vms[cmpID].State(),
ComputationId: cmpID,
Status: manager.Disconnected.String(),
Timestamp: timestamppb.Now(),
Originator: "manager",
},
},
}
}
func (ms *managerService) ReportBrokenConnection(addr string) {
cmpID, err := ms.computationIDFromAddress(addr)
if err != nil {
ms.logger.Warn(err.Error())
return
}
ms.reportBrokenConnection(cmpID)
}
-64
View File
@@ -1,64 +0,0 @@
// Copyright (c) Ultraviolet
// SPDX-License-Identifier: Apache-2.0
package manager
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/ultravioletrs/cocos/manager/qemu"
"github.com/ultravioletrs/cocos/manager/vm"
"github.com/ultravioletrs/cocos/pkg/manager"
)
func TestComputationIDFromAddress(t *testing.T) {
ms := &managerService{
vms: map[string]vm.VM{
"comp1": qemu.NewVM(qemu.VMInfo{Config: qemu.Config{VSockConfig: qemu.VSockConfig{GuestCID: 3}}}, func(event interface{}) error { return nil }, "comp1"),
"comp2": qemu.NewVM(qemu.VMInfo{Config: qemu.Config{VSockConfig: qemu.VSockConfig{GuestCID: 5}}}, func(event interface{}) error { return nil }, "comp2"),
},
}
tests := []struct {
name string
address string
want string
wantErr bool
}{
{"Valid address", "vm(3)", "comp1", false},
{"Invalid address", "invalid", "", true},
{"Non-existent CID", "vm(10)", "", true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := ms.computationIDFromAddress(tt.address)
if tt.wantErr {
assert.Error(t, err)
} else {
assert.NoError(t, err)
assert.Equal(t, tt.want, got)
}
})
}
}
func TestReportBrokenConnection(t *testing.T) {
ms := &managerService{
eventsChan: make(chan *ClientStreamMessage, 1),
vms: map[string]vm.VM{
"comp1": qemu.NewVM(qemu.VMInfo{Config: qemu.Config{VSockConfig: qemu.VSockConfig{GuestCID: 3}}}, func(event interface{}) error { return nil }, "comp1"),
},
}
ms.reportBrokenConnection("comp1")
select {
case msg := <-ms.eventsChan:
assert.Equal(t, "comp1", msg.GetAgentEvent().ComputationId)
assert.Equal(t, manager.Disconnected.String(), msg.GetAgentEvent().Status)
assert.Equal(t, "manager", msg.GetAgentEvent().Originator)
default:
t.Error("Expected message in eventsChan, but none received")
}
}
-242
View File
@@ -1,242 +0,0 @@
// Copyright (c) Ultraviolet
// SPDX-License-Identifier: Apache-2.0
package grpc
import (
"context"
"log/slog"
"sync"
"time"
"github.com/absmach/magistrala/pkg/errors"
"github.com/ultravioletrs/cocos/manager"
"github.com/ultravioletrs/cocos/manager/qemu"
"golang.org/x/sync/errgroup"
"google.golang.org/protobuf/proto"
)
var (
errTerminationFromServer = errors.New("server requested client termination")
errCorruptedManifest = errors.New("received manifest may be corrupted")
sendTimeout = 5 * time.Second
)
type ManagerClient struct {
stream manager.ManagerService_ProcessClient
svc manager.Service
messageQueue chan *manager.ClientStreamMessage
logger *slog.Logger
runReqManager *runRequestManager
}
// NewClient returns new gRPC client instance.
func NewClient(stream manager.ManagerService_ProcessClient, svc manager.Service, messageQueue chan *manager.ClientStreamMessage, logger *slog.Logger) ManagerClient {
return ManagerClient{
stream: stream,
svc: svc,
messageQueue: messageQueue,
logger: logger,
runReqManager: newRunRequestManager(),
}
}
func (client ManagerClient) Process(ctx context.Context, cancel context.CancelFunc) error {
eg, ctx := errgroup.WithContext(ctx)
eg.Go(func() error {
return client.handleIncomingMessages(ctx)
})
eg.Go(func() error {
return client.handleOutgoingMessages(ctx)
})
return eg.Wait()
}
func (client ManagerClient) handleIncomingMessages(ctx context.Context) error {
for {
select {
case <-ctx.Done():
return ctx.Err()
default:
req, err := client.stream.Recv()
if err != nil {
return err
}
if err := client.processIncomingMessage(ctx, req); err != nil {
return err
}
}
}
}
func (client ManagerClient) processIncomingMessage(ctx context.Context, req *manager.ServerStreamMessage) error {
switch mes := req.Message.(type) {
case *manager.ServerStreamMessage_RunReqChunks:
return client.handleRunReqChunks(ctx, mes)
case *manager.ServerStreamMessage_TerminateReq:
return client.handleTerminateReq(mes)
case *manager.ServerStreamMessage_StopComputation:
go client.handleStopComputation(ctx, mes)
case *manager.ServerStreamMessage_AttestationPolicyReq:
go client.handleAttestationPolicyReq(ctx, mes)
case *manager.ServerStreamMessage_SvmInfoReq:
go client.handleSVMInfoReq(ctx, mes)
default:
return errors.New("unknown message type")
}
return nil
}
func (client *ManagerClient) handleRunReqChunks(ctx context.Context, mes *manager.ServerStreamMessage_RunReqChunks) error {
buffer, complete := client.runReqManager.addChunk(mes.RunReqChunks.Id, mes.RunReqChunks.Data, mes.RunReqChunks.IsLast)
if complete {
var runReq manager.ComputationRunReq
if err := proto.Unmarshal(buffer, &runReq); err != nil {
return errors.Wrap(err, errCorruptedManifest)
}
go client.executeRun(ctx, &runReq)
}
return nil
}
func (client ManagerClient) executeRun(ctx context.Context, runReq *manager.ComputationRunReq) {
port, err := client.svc.Run(ctx, runReq)
if err != nil {
client.logger.Warn(err.Error())
return
}
runRes := &manager.ClientStreamMessage_RunRes{
RunRes: &manager.RunResponse{
AgentPort: port,
ComputationId: runReq.Id,
},
}
client.sendMessage(&manager.ClientStreamMessage{Message: runRes})
}
func (client ManagerClient) handleTerminateReq(mes *manager.ServerStreamMessage_TerminateReq) error {
return errors.Wrap(errTerminationFromServer, errors.New(mes.TerminateReq.Message))
}
func (client ManagerClient) handleStopComputation(ctx context.Context, mes *manager.ServerStreamMessage_StopComputation) {
msg := &manager.ClientStreamMessage_StopComputationRes{
StopComputationRes: &manager.StopComputationResponse{
ComputationId: mes.StopComputation.ComputationId,
},
}
if err := client.svc.Stop(ctx, mes.StopComputation.ComputationId); err != nil {
msg.StopComputationRes.Message = err.Error()
}
client.sendMessage(&manager.ClientStreamMessage{Message: msg})
}
func (client ManagerClient) handleAttestationPolicyReq(ctx context.Context, mes *manager.ServerStreamMessage_AttestationPolicyReq) {
res, err := client.svc.FetchAttestationPolicy(ctx, mes.AttestationPolicyReq.Id)
if err != nil {
client.logger.Warn(err.Error())
return
}
info := &manager.ClientStreamMessage_AttestationPolicy{
AttestationPolicy: &manager.AttestationPolicy{
Info: res,
Id: mes.AttestationPolicyReq.Id,
},
}
client.sendMessage(&manager.ClientStreamMessage{Message: info})
}
func (client ManagerClient) handleSVMInfoReq(ctx context.Context, mes *manager.ServerStreamMessage_SvmInfoReq) {
ovmfVersion, cpuNum, cpuType, eosVersion := client.svc.ReturnSVMInfo(ctx)
info := &manager.ClientStreamMessage_SvmInfo{
SvmInfo: &manager.SVMInfo{
OvmfVersion: ovmfVersion,
CpuNum: int32(cpuNum),
CpuType: cpuType,
KernelCmd: qemu.KernelCommandLine,
EosVersion: eosVersion,
Id: mes.SvmInfoReq.Id,
},
}
client.sendMessage(&manager.ClientStreamMessage{Message: info})
}
func (client ManagerClient) handleOutgoingMessages(ctx context.Context) error {
for {
select {
case <-ctx.Done():
return ctx.Err()
case mes := <-client.messageQueue:
if err := client.stream.Send(mes); err != nil {
return err
}
}
}
}
func (client ManagerClient) sendMessage(mes *manager.ClientStreamMessage) {
ctx, cancel := context.WithTimeout(context.Background(), sendTimeout)
defer cancel()
select {
case client.messageQueue <- mes:
case <-ctx.Done():
client.logger.Warn("Failed to send message: timeout exceeded")
}
}
type runRequestManager struct {
requests map[string]*runRequest
mu sync.Mutex
}
type runRequest struct {
buffer []byte
lastChunk time.Time
timer *time.Timer
}
func newRunRequestManager() *runRequestManager {
return &runRequestManager{
requests: make(map[string]*runRequest),
}
}
func (m *runRequestManager) addChunk(id string, chunk []byte, isLast bool) ([]byte, bool) {
m.mu.Lock()
defer m.mu.Unlock()
req, exists := m.requests[id]
if !exists {
req = &runRequest{
buffer: make([]byte, 0),
lastChunk: time.Now(),
timer: time.AfterFunc(runReqTimeout, func() { m.timeoutRequest(id) }),
}
m.requests[id] = req
}
req.buffer = append(req.buffer, chunk...)
req.lastChunk = time.Now()
req.timer.Reset(runReqTimeout)
if isLast {
delete(m.requests, id)
req.timer.Stop()
return req.buffer, true
}
return nil, false
}
func (m *runRequestManager) timeoutRequest(id string) {
m.mu.Lock()
defer m.mu.Unlock()
delete(m.requests, id)
// Log timeout or handle it as needed
}
-322
View File
@@ -1,322 +0,0 @@
// Copyright (c) Ultraviolet
// SPDX-License-Identifier: Apache-2.0
package grpc
import (
"context"
"testing"
"time"
mglog "github.com/absmach/magistrala/logger"
"github.com/absmach/magistrala/pkg/errors"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"github.com/ultravioletrs/cocos/manager"
"github.com/ultravioletrs/cocos/manager/mocks"
"github.com/ultravioletrs/cocos/manager/qemu"
"google.golang.org/grpc"
"google.golang.org/protobuf/proto"
)
type mockStream struct {
mock.Mock
grpc.ClientStream
}
func (m *mockStream) Recv() (*manager.ServerStreamMessage, error) {
args := m.Called()
return args.Get(0).(*manager.ServerStreamMessage), args.Error(1)
}
func (m *mockStream) Send(msg *manager.ClientStreamMessage) error {
args := m.Called(msg)
return args.Error(0)
}
func TestManagerClient_Process1(t *testing.T) {
tests := []struct {
name string
setupMocks func(mockStream *mockStream, mockSvc *mocks.Service)
expectError bool
errorMsg string
}{
{
name: "Stop computation",
setupMocks: func(mockStream *mockStream, mockSvc *mocks.Service) {
mockStream.On("Recv").Return(&manager.ServerStreamMessage{
Message: &manager.ServerStreamMessage_StopComputation{
StopComputation: &manager.StopComputation{},
},
}, nil)
mockStream.On("Send", mock.Anything).Return(nil)
mockSvc.On("Stop", mock.Anything, mock.Anything).Return(nil)
},
expectError: true,
errorMsg: "context deadline exceeded",
},
{
name: "Terminate request",
setupMocks: func(mockStream *mockStream, mockSvc *mocks.Service) {
mockStream.On("Recv").Return(&manager.ServerStreamMessage{
Message: &manager.ServerStreamMessage_TerminateReq{
TerminateReq: &manager.Terminate{},
},
}, nil)
},
expectError: true,
errorMsg: errTerminationFromServer.Error(),
},
{
name: "Attestation Policy request",
setupMocks: func(mockStream *mockStream, mockSvc *mocks.Service) {
mockStream.On("Recv").Return(&manager.ServerStreamMessage{
Message: &manager.ServerStreamMessage_AttestationPolicyReq{
AttestationPolicyReq: &manager.AttestationPolicyReq{},
},
}, nil)
mockStream.On("Send", mock.Anything).Return(nil).Once()
mockSvc.On("FetchAttestationPolicy", mock.Anything, mock.Anything).Return(nil, assert.AnError)
},
expectError: true,
},
{
name: "Run request chunks",
setupMocks: func(mockStream *mockStream, mockSvc *mocks.Service) {
mockStream.On("Recv").Return(&manager.ServerStreamMessage{
Message: &manager.ServerStreamMessage_RunReqChunks{
RunReqChunks: &manager.RunReqChunks{},
},
}, nil)
mockStream.On("Send", mock.Anything).Return(nil).Once()
mockSvc.On("Run", mock.Anything, mock.Anything).Return("", assert.AnError).Once()
},
expectError: true,
},
{
name: "Receive error",
setupMocks: func(mockStream *mockStream, mockSvc *mocks.Service) {
mockStream.On("Recv").Return(&manager.ServerStreamMessage{}, assert.AnError)
},
expectError: true,
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
mockStream := new(mockStream)
mockSvc := new(mocks.Service)
messageQueue := make(chan *manager.ClientStreamMessage, 10)
logger := mglog.NewMock()
client := NewClient(mockStream, mockSvc, messageQueue, logger)
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
defer cancel()
tc.setupMocks(mockStream, mockSvc)
err := client.Process(ctx, cancel)
if tc.expectError {
assert.Error(t, err)
if tc.errorMsg != "" {
assert.Contains(t, err.Error(), tc.errorMsg)
}
} else {
assert.NoError(t, err)
}
})
}
}
func TestManagerClient_handleRunReqChunks(t *testing.T) {
mockStream := new(mockStream)
mockSvc := new(mocks.Service)
messageQueue := make(chan *manager.ClientStreamMessage, 10)
logger := mglog.NewMock()
client := NewClient(mockStream, mockSvc, messageQueue, logger)
runReq := &manager.ComputationRunReq{
Id: "test-id",
}
runReqBytes, _ := proto.Marshal(runReq)
chunk1 := &manager.ServerStreamMessage_RunReqChunks{
RunReqChunks: &manager.RunReqChunks{
Id: "chunk-1",
Data: runReqBytes[:len(runReqBytes)/2],
IsLast: false,
},
}
chunk2 := &manager.ServerStreamMessage_RunReqChunks{
RunReqChunks: &manager.RunReqChunks{
Id: "chunk-1",
Data: runReqBytes[len(runReqBytes)/2:],
IsLast: true,
},
}
mockSvc.On("Run", mock.Anything, mock.AnythingOfType("*manager.ComputationRunReq")).Return("8080", nil)
err := client.handleRunReqChunks(context.Background(), chunk1)
assert.NoError(t, err)
err = client.handleRunReqChunks(context.Background(), chunk2)
assert.NoError(t, err)
// Wait for the goroutine to finish
time.Sleep(50 * time.Millisecond)
mockSvc.AssertExpectations(t)
assert.Len(t, messageQueue, 1)
msg := <-messageQueue
runRes, ok := msg.Message.(*manager.ClientStreamMessage_RunRes)
assert.True(t, ok)
assert.Equal(t, "8080", runRes.RunRes.AgentPort)
assert.Equal(t, "test-id", runRes.RunRes.ComputationId)
}
func TestManagerClient_handleTerminateReq(t *testing.T) {
client := ManagerClient{}
terminateReq := &manager.ServerStreamMessage_TerminateReq{
TerminateReq: &manager.Terminate{
Message: "Test termination",
},
}
err := client.handleTerminateReq(terminateReq)
assert.Error(t, err)
assert.Contains(t, err.Error(), "Test termination")
assert.True(t, errors.Contains(err, errTerminationFromServer))
}
func TestManagerClient_handleStopComputation(t *testing.T) {
mockStream := new(mockStream)
mockSvc := new(mocks.Service)
messageQueue := make(chan *manager.ClientStreamMessage, 10)
logger := mglog.NewMock()
client := NewClient(mockStream, mockSvc, messageQueue, logger)
stopReq := &manager.ServerStreamMessage_StopComputation{
StopComputation: &manager.StopComputation{
ComputationId: "test-comp-id",
},
}
mockSvc.On("Stop", mock.Anything, "test-comp-id").Return(nil)
client.handleStopComputation(context.Background(), stopReq)
// Wait for the goroutine to finish
time.Sleep(50 * time.Millisecond)
mockSvc.AssertExpectations(t)
assert.Len(t, messageQueue, 1)
msg := <-messageQueue
stopRes, ok := msg.Message.(*manager.ClientStreamMessage_StopComputationRes)
assert.True(t, ok)
assert.Equal(t, "test-comp-id", stopRes.StopComputationRes.ComputationId)
assert.Empty(t, stopRes.StopComputationRes.Message)
}
func TestManagerClient_handleAttestationPolicyReq(t *testing.T) {
t.Run("success", func(t *testing.T) {
mockStream := new(mockStream)
mockSvc := new(mocks.Service)
messageQueue := make(chan *manager.ClientStreamMessage, 10)
logger := mglog.NewMock()
client := NewClient(mockStream, mockSvc, messageQueue, logger)
infoReq := &manager.ServerStreamMessage_AttestationPolicyReq{
AttestationPolicyReq: &manager.AttestationPolicyReq{
Id: "test-info-id",
},
}
mockSvc.On("FetchAttestationPolicy", context.Background(), infoReq.AttestationPolicyReq.Id).Return([]byte("test-attestation-policy"), nil)
client.handleAttestationPolicyReq(context.Background(), infoReq)
// Wait for the goroutine to finish
time.Sleep(50 * time.Millisecond)
mockSvc.AssertExpectations(t)
assert.Len(t, messageQueue, 1)
msg := <-messageQueue
infoRes, ok := msg.Message.(*manager.ClientStreamMessage_AttestationPolicy)
assert.True(t, ok)
assert.Equal(t, "test-info-id", infoRes.AttestationPolicy.Id)
assert.Equal(t, []byte("test-attestation-policy"), infoRes.AttestationPolicy.Info)
})
t.Run("error", func(t *testing.T) {
mockStream := new(mockStream)
mockSvc := new(mocks.Service)
messageQueue := make(chan *manager.ClientStreamMessage, 10)
logger := mglog.NewMock()
client := NewClient(mockStream, mockSvc, messageQueue, logger)
infoReq := &manager.ServerStreamMessage_AttestationPolicyReq{
AttestationPolicyReq: &manager.AttestationPolicyReq{
Id: "test-info-id",
},
}
mockSvc.On("FetchAttestationPolicy", context.Background(), infoReq.AttestationPolicyReq.Id).Return(nil, assert.AnError)
client.handleAttestationPolicyReq(context.Background(), infoReq)
time.Sleep(50 * time.Millisecond)
mockSvc.AssertExpectations(t)
assert.Len(t, messageQueue, 0)
})
}
func TestManagerClient_handleSVMInfoReq(t *testing.T) {
mockStream := new(mockStream)
mockSvc := new(mocks.Service)
messageQueue := make(chan *manager.ClientStreamMessage, 10)
logger := mglog.NewMock()
client := NewClient(mockStream, mockSvc, messageQueue, logger)
mockSvc.On("ReturnSVMInfo", context.Background()).Return("edk2-stable202408", 4, "EPYC", "")
client.handleSVMInfoReq(context.Background(), &manager.ServerStreamMessage_SvmInfoReq{SvmInfoReq: &manager.SVMInfoReq{Id: "test-svm-info-id"}})
// Wait for the goroutine to finish
time.Sleep(50 * time.Millisecond)
mockSvc.AssertExpectations(t)
assert.Len(t, messageQueue, 1)
msg := <-messageQueue
infoRes, ok := msg.Message.(*manager.ClientStreamMessage_SvmInfo)
assert.True(t, ok)
assert.Equal(t, "edk2-stable202408", infoRes.SvmInfo.OvmfVersion)
assert.Equal(t, int32(4), infoRes.SvmInfo.CpuNum)
assert.Equal(t, "EPYC", infoRes.SvmInfo.CpuType)
assert.Equal(t, "", infoRes.SvmInfo.EosVersion)
assert.Equal(t, qemu.KernelCommandLine, infoRes.SvmInfo.KernelCmd)
}
func TestManagerClient_timeoutRequest(t *testing.T) {
rm := newRunRequestManager()
rm.requests["test-id"] = &runRequest{
timer: time.NewTimer(100 * time.Millisecond),
buffer: []byte("test-data"),
lastChunk: time.Now(),
}
rm.timeoutRequest("test-id")
assert.Len(t, rm.requests, 0)
}
+43 -104
View File
@@ -3,17 +3,11 @@
package grpc
import (
"bytes"
"context"
"errors"
"io"
"time"
"github.com/ultravioletrs/cocos/manager"
"golang.org/x/sync/errgroup"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/peer"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/types/known/emptypb"
)
var (
@@ -21,113 +15,58 @@ var (
ErrUnexpectedMsg = errors.New("unknown message type")
)
const (
bufferSize = 1024 * 1024 // 1 MB
runReqTimeout = 30 * time.Second
)
type SendFunc func(*manager.ServerStreamMessage) error
type grpcServer struct {
manager.UnimplementedManagerServiceServer
incoming chan *manager.ClientStreamMessage
svc Service
}
type Service interface {
Run(ctx context.Context, ipAddress string, sendMessage SendFunc, authInfo credentials.AuthInfo)
svc manager.Service
}
// NewServer returns new AuthServiceServer instance.
func NewServer(incoming chan *manager.ClientStreamMessage, svc Service) manager.ManagerServiceServer {
func NewServer(svc manager.Service) manager.ManagerServiceServer {
return &grpcServer{
incoming: incoming,
svc: svc,
svc: svc,
}
}
func (s *grpcServer) Process(stream manager.ManagerService_ProcessServer) error {
client, ok := peer.FromContext(stream.Context())
if !ok {
return errors.New("failed to get peer info")
}
eg, ctx := errgroup.WithContext(stream.Context())
eg.Go(func() error {
for {
select {
case <-ctx.Done():
return ctx.Err()
default:
req, err := stream.Recv()
if err != nil {
return err
}
s.incoming <- req
}
}
})
eg.Go(func() error {
sendMessage := func(msg *manager.ServerStreamMessage) error {
select {
case <-ctx.Done():
return ctx.Err()
default:
switch m := msg.Message.(type) {
case *manager.ServerStreamMessage_RunReq:
return s.sendRunReqInChunks(stream, m.RunReq)
default:
return stream.Send(msg)
}
}
}
s.svc.Run(ctx, client.Addr.String(), sendMessage, client.AuthInfo)
return nil
})
return eg.Wait()
}
func (s *grpcServer) sendRunReqInChunks(stream manager.ManagerService_ProcessServer, runReq *manager.ComputationRunReq) error {
data, err := proto.Marshal(runReq)
func (s *grpcServer) CreateVm(ctx context.Context, _ *emptypb.Empty) (*manager.CreateRes, error) {
port, id, err := s.svc.CreateVM(ctx)
if err != nil {
return err
return nil, err
}
dataBuffer := bytes.NewBuffer(data)
buf := make([]byte, bufferSize)
for {
n, err := dataBuffer.Read(buf)
isLast := false
if err == io.EOF {
isLast = true
} else if err != nil {
return err
}
chunk := &manager.ServerStreamMessage{
Message: &manager.ServerStreamMessage_RunReqChunks{
RunReqChunks: &manager.RunReqChunks{
Id: runReq.Id,
Data: buf[:n],
IsLast: isLast,
},
},
}
if err := stream.Send(chunk); err != nil {
return err
}
if isLast {
break
}
}
return nil
return &manager.CreateRes{
ForwardedPort: port,
SvmId: id,
}, nil
}
func (s *grpcServer) RemoveVm(ctx context.Context, req *manager.RemoveReq) (*emptypb.Empty, error) {
if err := s.svc.RemoveVM(ctx, req.SvmId); err != nil {
return nil, err
}
return &emptypb.Empty{}, nil
}
func (s *grpcServer) SVMInfo(ctx context.Context, req *manager.SVMInfoReq) (*manager.SVMInfoRes, error) {
ovmf, cpunum, cputype, eosversion := s.svc.ReturnSVMInfo(ctx)
return &manager.SVMInfoRes{
OvmfVersion: ovmf,
CpuNum: int32(cpunum),
CpuType: cputype,
EosVersion: eosversion,
Id: req.Id,
}, nil
}
func (s *grpcServer) AttestationPolicy(ctx context.Context, req *manager.AttestationPolicyReq) (*manager.AttestationPolicyRes, error) {
policy, err := s.svc.FetchAttestationPolicy(ctx, req.Id)
if err != nil {
return nil, err
}
return &manager.AttestationPolicyRes{
Info: policy,
Id: req.Id,
}, nil
}
-290
View File
@@ -1,290 +0,0 @@
// Copyright (c) Ultraviolet
// SPDX-License-Identifier: Apache-2.0
package grpc
import (
"context"
"testing"
"time"
"github.com/absmach/magistrala/pkg/errors"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"github.com/ultravioletrs/cocos/manager"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/peer"
)
type mockServerStream struct {
mock.Mock
manager.ManagerService_ProcessServer
}
func (m *mockServerStream) Send(msg *manager.ServerStreamMessage) error {
args := m.Called(msg)
return args.Error(0)
}
func (m *mockServerStream) Recv() (*manager.ClientStreamMessage, error) {
args := m.Called()
return args.Get(0).(*manager.ClientStreamMessage), args.Error(1)
}
func (m *mockServerStream) Context() context.Context {
args := m.Called()
return args.Get(0).(context.Context)
}
type mockService struct {
mock.Mock
}
func (m *mockService) Run(ctx context.Context, ipAddress string, sendMessage SendFunc, authInfo credentials.AuthInfo) {
m.Called(ctx, ipAddress, sendMessage, authInfo)
}
func TestNewServer(t *testing.T) {
incoming := make(chan *manager.ClientStreamMessage)
mockSvc := new(mockService)
server := NewServer(incoming, mockSvc)
assert.NotNil(t, server)
assert.IsType(t, &grpcServer{}, server)
}
func TestGrpcServer_Process(t *testing.T) {
tests := []struct {
name string
recvReturn *manager.ClientStreamMessage
recvError error
expectedError string
}{
{
name: "Process with context deadline exceeded",
recvReturn: &manager.ClientStreamMessage{},
recvError: nil,
expectedError: "context deadline exceeded",
},
{
name: "Process with Recv error",
recvReturn: &manager.ClientStreamMessage{},
recvError: errors.New("recv error"),
expectedError: "recv error",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
incoming := make(chan *manager.ClientStreamMessage, 1)
mockSvc := new(mockService)
server := NewServer(incoming, mockSvc).(*grpcServer)
mockStream := new(mockServerStream)
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
defer cancel()
mockStream.On("Context").Return(peer.NewContext(ctx, &peer.Peer{
Addr: mockAddr{},
AuthInfo: mockAuthInfo{},
}))
if tt.recvError == nil {
go func() {
for mes := range incoming {
assert.NotNil(t, mes)
}
}()
}
mockStream.On("Recv").Return(tt.recvReturn, tt.recvError)
mockSvc.On("Run", mock.Anything, "test", mock.Anything, mock.AnythingOfType("mockAuthInfo")).Return()
err := server.Process(mockStream)
assert.Error(t, err)
assert.Contains(t, err.Error(), tt.expectedError)
mockStream.AssertExpectations(t)
mockSvc.AssertExpectations(t)
})
}
}
func TestGrpcServer_sendRunReqInChunks(t *testing.T) {
incoming := make(chan *manager.ClientStreamMessage)
mockSvc := new(mockService)
server := NewServer(incoming, mockSvc).(*grpcServer)
mockStream := new(mockServerStream)
runReq := &manager.ComputationRunReq{
Id: "test-id",
}
largePayload := make([]byte, bufferSize*2)
for i := range largePayload {
largePayload[i] = byte(i % 256)
}
runReq.Algorithm = &manager.Algorithm{}
runReq.Algorithm.UserKey = largePayload
mockStream.On("Send", mock.AnythingOfType("*manager.ServerStreamMessage")).Return(nil).Times(4)
err := server.sendRunReqInChunks(mockStream, runReq)
assert.NoError(t, err)
mockStream.AssertExpectations(t)
calls := mockStream.Calls
assert.Equal(t, 4, len(calls))
for i, call := range calls {
msg := call.Arguments[0].(*manager.ServerStreamMessage)
chunk := msg.GetRunReqChunks()
assert.NotNil(t, chunk)
assert.Equal(t, "test-id", chunk.Id)
if i < 3 {
assert.False(t, chunk.IsLast)
} else {
assert.Equal(t, 0, len(chunk.Data))
assert.True(t, chunk.IsLast)
}
}
}
type mockAddr struct{}
func (mockAddr) Network() string { return "test network" }
func (mockAddr) String() string { return "test" }
type mockAuthInfo struct{}
func (mockAuthInfo) AuthType() string { return "test auth" }
func TestGrpcServer_ProcessWithMockService(t *testing.T) {
tests := []struct {
name string
setupMockFn func(*mockService, *mockServerStream)
}{
{
name: "Run Request Test",
setupMockFn: func(mockSvc *mockService, mockStream *mockServerStream) {
mockSvc.On("Run", mock.Anything, "test", mock.Anything, mock.AnythingOfType("mockAuthInfo")).
Run(func(args mock.Arguments) {
sendFunc := args.Get(2).(SendFunc)
runReq := &manager.ComputationRunReq{Id: "test-run-id"}
err := sendFunc(&manager.ServerStreamMessage{
Message: &manager.ServerStreamMessage_RunReq{
RunReq: runReq,
},
})
assert.NoError(t, err)
}).
Return()
mockStream.On("Send", mock.MatchedBy(func(msg *manager.ServerStreamMessage) bool {
chunks := msg.GetRunReqChunks()
return chunks != nil && chunks.Id == "test-run-id"
})).Return(nil)
},
},
{
name: "Terminate Request Test",
setupMockFn: func(mockSvc *mockService, mockStream *mockServerStream) {
mockSvc.On("Run", mock.Anything, "test", mock.Anything, mock.AnythingOfType("mockAuthInfo")).
Run(func(args mock.Arguments) {
sendFunc := args.Get(2).(SendFunc)
err := sendFunc(&manager.ServerStreamMessage{
Message: &manager.ServerStreamMessage_TerminateReq{
TerminateReq: &manager.Terminate{},
},
})
assert.NoError(t, err)
}).Return()
mockStream.On("Send", mock.AnythingOfType("*manager.ServerStreamMessage")).Return(nil)
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
incoming := make(chan *manager.ClientStreamMessage, 10)
mockSvc := new(mockService)
server := NewServer(incoming, mockSvc).(*grpcServer)
go func() {
for mes := range incoming {
assert.NotNil(t, mes)
}
}()
mockStream := new(mockServerStream)
ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
defer cancel()
peerCtx := peer.NewContext(ctx, &peer.Peer{
Addr: mockAddr{},
AuthInfo: mockAuthInfo{},
})
mockStream.On("Context").Return(peerCtx)
mockStream.On("Recv").Return(&manager.ClientStreamMessage{}, nil).Maybe()
tt.setupMockFn(mockSvc, mockStream)
go func() {
time.Sleep(150 * time.Millisecond)
cancel()
}()
err := server.Process(mockStream)
assert.Error(t, err)
assert.Contains(t, err.Error(), "context canceled")
mockStream.AssertExpectations(t)
mockSvc.AssertExpectations(t)
})
}
}
func TestGrpcServer_sendRunReqInChunksError(t *testing.T) {
incoming := make(chan *manager.ClientStreamMessage)
mockSvc := new(mockService)
server := NewServer(incoming, mockSvc).(*grpcServer)
mockStream := new(mockServerStream)
runReq := &manager.ComputationRunReq{
Id: "test-id",
}
// Simulate an error when sending
mockStream.On("Send", mock.AnythingOfType("*manager.ServerStreamMessage")).Return(errors.New("send error")).Once()
err := server.sendRunReqInChunks(mockStream, runReq)
assert.Error(t, err)
assert.Contains(t, err.Error(), "send error")
mockStream.AssertExpectations(t)
}
func TestGrpcServer_ProcessMissingPeerInfo(t *testing.T) {
incoming := make(chan *manager.ClientStreamMessage)
mockSvc := new(mockService)
server := NewServer(incoming, mockSvc).(*grpcServer)
mockStream := new(mockServerStream)
ctx := context.Background()
// Return a context without peer info
mockStream.On("Context").Return(ctx)
err := server.Process(mockStream)
assert.Error(t, err)
assert.Contains(t, err.Error(), "failed to get peer info")
mockStream.AssertExpectations(t)
}
+6 -10
View File
@@ -27,9 +27,9 @@ func LoggingMiddleware(svc manager.Service, logger *slog.Logger) manager.Service
return &loggingMiddleware{logger, svc}
}
func (lm *loggingMiddleware) Run(ctx context.Context, mc *manager.ComputationRunReq) (agentAddr string, err error) {
func (lm *loggingMiddleware) CreateVM(ctx context.Context) (agentAddr string, id string, err error) {
defer func(begin time.Time) {
message := fmt.Sprintf("Method Run for computation took %s to complete", time.Since(begin))
message := fmt.Sprintf("Method CreateVM for id %s on port %s took %s to complete", id, agentAddr, time.Since(begin))
if err != nil {
lm.logger.Warn(fmt.Sprintf("%s with error: %s.", message, err))
return
@@ -37,12 +37,12 @@ func (lm *loggingMiddleware) Run(ctx context.Context, mc *manager.ComputationRun
lm.logger.Info(message)
}(time.Now())
return lm.svc.Run(ctx, mc)
return lm.svc.CreateVM(ctx)
}
func (lm *loggingMiddleware) Stop(ctx context.Context, computationID string) (err error) {
func (lm *loggingMiddleware) RemoveVM(ctx context.Context, id string) (err error) {
defer func(begin time.Time) {
message := fmt.Sprintf("Method Stop for computation took %s to complete", time.Since(begin))
message := fmt.Sprintf("Method RemoveVM for vm %s took %s to complete", id, time.Since(begin))
if err != nil {
lm.logger.Warn(fmt.Sprintf("%s with error: %s.", message, err))
return
@@ -50,7 +50,7 @@ func (lm *loggingMiddleware) Stop(ctx context.Context, computationID string) (er
lm.logger.Info(message)
}(time.Now())
return lm.svc.Stop(ctx, computationID)
return lm.svc.RemoveVM(ctx, id)
}
func (lm *loggingMiddleware) FetchAttestationPolicy(ctx context.Context, cmpId string) (body []byte, err error) {
@@ -67,10 +67,6 @@ func (lm *loggingMiddleware) FetchAttestationPolicy(ctx context.Context, cmpId s
return lm.svc.FetchAttestationPolicy(ctx, cmpId)
}
func (lm *loggingMiddleware) ReportBrokenConnection(addr string) {
lm.svc.ReportBrokenConnection(addr)
}
func (lm *loggingMiddleware) ReturnSVMInfo(ctx context.Context) (string, int, string, string) {
defer func(begin time.Time) {
message := fmt.Sprintf("Method ReturnSVMInfo for computation took %s to complete", time.Since(begin))
+4 -8
View File
@@ -32,22 +32,22 @@ func MetricsMiddleware(svc manager.Service, counter metrics.Counter, latency met
}
}
func (ms *metricsMiddleware) Run(ctx context.Context, mc *manager.ComputationRunReq) (string, error) {
func (ms *metricsMiddleware) CreateVM(ctx context.Context) (string, string, error) {
defer func(begin time.Time) {
ms.counter.With("method", "Run").Add(1)
ms.latency.With("method", "Run").Observe(time.Since(begin).Seconds())
}(time.Now())
return ms.svc.Run(ctx, mc)
return ms.svc.CreateVM(ctx)
}
func (ms *metricsMiddleware) Stop(ctx context.Context, computationID string) error {
func (ms *metricsMiddleware) RemoveVM(ctx context.Context, computationID string) error {
defer func(begin time.Time) {
ms.counter.With("method", "Stop").Add(1)
ms.latency.With("method", "Stop").Observe(time.Since(begin).Seconds())
}(time.Now())
return ms.svc.Stop(ctx, computationID)
return ms.svc.RemoveVM(ctx, computationID)
}
func (ms *metricsMiddleware) FetchAttestationPolicy(ctx context.Context, cmpId string) ([]byte, error) {
@@ -59,10 +59,6 @@ func (ms *metricsMiddleware) FetchAttestationPolicy(ctx context.Context, cmpId s
return ms.svc.FetchAttestationPolicy(ctx, cmpId)
}
func (ms *metricsMiddleware) ReportBrokenConnection(addr string) {
ms.svc.ReportBrokenConnection(addr)
}
func (ms *metricsMiddleware) ReturnSVMInfo(ctx context.Context) (string, int, string, string) {
defer func(begin time.Time) {
ms.counter.With("method", "ReturnSVMInfo").Add(1)
-9
View File
@@ -1,9 +0,0 @@
// Copyright (c) Ultraviolet
// SPDX-License-Identifier: Apache-2.0
package events
import "context"
type Listener interface {
Listen(ctx context.Context)
}
-125
View File
@@ -1,125 +0,0 @@
// Copyright (c) Ultraviolet
// SPDX-License-Identifier: Apache-2.0
package events
import (
"context"
"log/slog"
"net"
"github.com/mdlayher/vsock"
agentevents "github.com/ultravioletrs/cocos/agent/events"
internalvsock "github.com/ultravioletrs/cocos/internal/vsock"
"github.com/ultravioletrs/cocos/manager"
"google.golang.org/protobuf/proto"
)
const ManagerVsockPort = 9997
type ReportBrokenConnectionFunc func(address string)
type events struct {
reportBrokenConnection ReportBrokenConnectionFunc
lis net.Listener
logger *slog.Logger
eventsChan chan *manager.ClientStreamMessage
}
func New(logger *slog.Logger, reportBrokenConnection ReportBrokenConnectionFunc, eventsChan chan *manager.ClientStreamMessage) (Listener, error) {
l, err := vsock.Listen(ManagerVsockPort, nil)
if err != nil {
return nil, err
}
return &events{
lis: l,
reportBrokenConnection: reportBrokenConnection,
logger: logger,
eventsChan: eventsChan,
}, nil
}
func (e *events) Listen(ctx context.Context) {
for {
select {
case <-ctx.Done():
e.logger.Info("Listener shutting down")
return
default:
conn, err := e.lis.Accept()
if err != nil {
e.logger.Warn(err.Error())
continue
}
go e.handleConnection(conn)
}
}
}
func (e *events) handleConnection(conn net.Conn) {
defer conn.Close()
ackReader := internalvsock.NewAckReader(conn)
for {
var message agentevents.EventsLogs
data, err := ackReader.Read()
if err != nil {
go e.reportBrokenConnection(conn.RemoteAddr().String())
e.logger.Warn(err.Error())
return
}
if err := proto.Unmarshal(data, &message); err != nil {
e.logger.Warn(err.Error())
continue
}
var mes manager.ClientStreamMessage
args := []any{}
switch message.Message.(type) {
case *agentevents.EventsLogs_AgentEvent:
args = append(args, slog.Group("agent-event",
slog.String("event-type", message.GetAgentEvent().GetEventType()),
slog.String("computation-id", message.GetAgentEvent().GetComputationId()),
slog.String("status", message.GetAgentEvent().GetStatus()),
slog.String("originator", message.GetAgentEvent().GetOriginator()),
slog.String("timestamp", message.GetAgentEvent().GetTimestamp().String()),
slog.String("details", string(message.GetAgentEvent().GetDetails()))))
mes = manager.ClientStreamMessage{
Message: &manager.ClientStreamMessage_AgentEvent{
AgentEvent: &manager.AgentEvent{
EventType: message.GetAgentEvent().GetEventType(),
ComputationId: message.GetAgentEvent().GetComputationId(),
Status: message.GetAgentEvent().GetStatus(),
Originator: message.GetAgentEvent().GetOriginator(),
Timestamp: message.GetAgentEvent().GetTimestamp(),
Details: message.GetAgentEvent().GetDetails(),
},
},
}
case *agentevents.EventsLogs_AgentLog:
args = append(args, slog.Group("agent-log",
slog.String("computation-id", message.GetAgentLog().GetComputationId()),
slog.String("level", message.GetAgentLog().GetLevel()),
slog.String("timestamp", message.GetAgentLog().GetTimestamp().String()),
slog.String("message", message.GetAgentLog().GetMessage())))
mes = manager.ClientStreamMessage{
Message: &manager.ClientStreamMessage_AgentLog{
AgentLog: &manager.AgentLog{
ComputationId: message.GetAgentLog().GetComputationId(),
Level: message.GetAgentLog().GetLevel(),
Timestamp: message.GetAgentLog().GetTimestamp(),
Message: message.GetAgentLog().GetMessage(),
},
},
}
}
e.eventsChan <- &mes
e.logger.Info("", args...)
}
}
-295
View File
@@ -1,295 +0,0 @@
// Copyright (c) Ultraviolet
// SPDX-License-Identifier: Apache-2.0
package events
import (
"bytes"
"context"
"encoding/binary"
"fmt"
"log/slog"
"net"
"os"
"testing"
"time"
mglog "github.com/absmach/magistrala/logger"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"github.com/ultravioletrs/cocos/manager"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/types/known/timestamppb"
)
type MockVsockListener struct {
mock.Mock
}
func (m *MockVsockListener) Accept() (net.Conn, error) {
args := m.Called()
return args.Get(0).(net.Conn), args.Error(1)
}
func (m *MockVsockListener) Close() error {
args := m.Called()
return args.Error(0)
}
func (m *MockVsockListener) Addr() net.Addr {
args := m.Called()
return args.Get(0).(net.Addr)
}
var _ net.Conn = (*MockConn)(nil)
type MockConn struct {
mock.Mock
}
func (m *MockConn) Read(b []byte) (n int, err error) {
args := m.Called(b)
return args.Int(0), args.Error(1)
}
func (m *MockConn) Write(b []byte) (n int, err error) {
args := m.Called(b)
return args.Int(0), args.Error(1)
}
func (m *MockConn) Close() error {
args := m.Called()
return args.Error(0)
}
func (m *MockConn) LocalAddr() net.Addr {
args := m.Called()
return args.Get(0).(net.Addr)
}
func (m *MockConn) RemoteAddr() net.Addr {
args := m.Called()
return args.Get(0).(net.Addr)
}
func (m *MockConn) SetDeadline(t time.Time) error {
args := m.Called(t)
return args.Error(0)
}
func (m *MockConn) SetReadDeadline(t time.Time) error {
args := m.Called(t)
return args.Error(0)
}
func (m *MockConn) SetWriteDeadline(t time.Time) error {
args := m.Called(t)
return args.Error(0)
}
func TestNew(t *testing.T) {
logger := &slog.Logger{}
reportBrokenConnection := func(address string) {}
eventsChan := make(chan *manager.ClientStreamMessage)
e, err := New(logger, reportBrokenConnection, eventsChan)
if vsockDeviceExists() {
assert.NoError(t, err)
assert.NotNil(t, e)
assert.IsType(t, &events{}, e)
} else {
assert.Error(t, err)
}
}
func TestListen(t *testing.T) {
mockListener := new(MockVsockListener)
mockConn := new(MockConn)
e := &events{
lis: mockListener,
logger: mglog.NewMock(),
}
mockListener.On("Accept").Return(mockConn, fmt.Errorf("mock error")).Once()
mockListener.On("Accept").Return(mockConn, nil)
mockConn.On("Close").Return(nil)
mockConn.On("Read", mock.Anything).Return(0, nil)
go e.Listen(context.Background())
time.Sleep(100 * time.Millisecond)
mockListener.AssertExpectations(t)
}
func TestListenContextDone(t *testing.T) {
mockListener := new(MockVsockListener)
mockConn := new(MockConn)
e := &events{
lis: mockListener,
logger: mglog.NewMock(),
}
ctx, cancel := context.WithCancel(context.Background())
cancel()
mockListener.On("Accept").Return(mockConn, nil)
e.Listen(ctx)
time.Sleep(100 * time.Millisecond)
}
func vsockDeviceExists() bool {
fs, err := os.Stat("/dev/vsock")
if err != nil {
return false
}
if fs.Mode()&os.ModeDevice == 0 {
return false
}
return true
}
type MockConnWithBuffer struct {
mock.Mock
readBuf *bytes.Buffer
writeBuf *bytes.Buffer
}
func NewMockConnWithBuffer() *MockConnWithBuffer {
return &MockConnWithBuffer{
readBuf: new(bytes.Buffer),
writeBuf: new(bytes.Buffer),
}
}
func (m *MockConnWithBuffer) Read(b []byte) (n int, err error) {
return m.readBuf.Read(b)
}
func (m *MockConnWithBuffer) Write(b []byte) (n int, err error) {
return m.writeBuf.Write(b)
}
func (m *MockConnWithBuffer) Close() error {
return nil
}
func (m *MockConnWithBuffer) LocalAddr() net.Addr {
return nil
}
func (m *MockConnWithBuffer) RemoteAddr() net.Addr {
return &net.IPAddr{IP: net.ParseIP("localhost")}
}
func (m *MockConnWithBuffer) SetDeadline(t time.Time) error {
return nil
}
func (m *MockConnWithBuffer) SetReadDeadline(t time.Time) error {
return nil
}
func (m *MockConnWithBuffer) SetWriteDeadline(t time.Time) error {
return nil
}
func TestHandleConnection(t *testing.T) {
tests := []struct {
name string
message *manager.ClientStreamMessage
}{
{
name: "handle agent event",
message: &manager.ClientStreamMessage{
Message: &manager.ClientStreamMessage_AgentEvent{
AgentEvent: &manager.AgentEvent{
EventType: "test_event",
ComputationId: "test_computation",
Status: "test_status",
Originator: "test_originator",
Timestamp: timestamppb.Now(),
Details: []byte("test_details"),
},
},
},
},
{
name: "handle agent log",
message: &manager.ClientStreamMessage{
Message: &manager.ClientStreamMessage_AgentLog{
AgentLog: &manager.AgentLog{
ComputationId: "test_computation",
Timestamp: timestamppb.Now(),
},
},
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
mockConn := NewMockConnWithBuffer()
eventsChan := make(chan *manager.ClientStreamMessage, 1)
e := &events{
logger: mglog.NewMock(),
eventsChan: eventsChan,
reportBrokenConnection: func(address string) {},
}
data, err := proto.Marshal(tt.message)
assert.NoError(t, err)
messageID := uint32(1)
err = binary.Write(mockConn.readBuf, binary.LittleEndian, messageID)
assert.NoError(t, err)
err = binary.Write(mockConn.readBuf, binary.LittleEndian, uint32(len(data)))
assert.NoError(t, err)
_, err = mockConn.readBuf.Write(data)
assert.NoError(t, err)
// Add EOF to signal end of stream
err = binary.Write(mockConn.readBuf, binary.LittleEndian, uint32(0))
assert.NoError(t, err)
err = binary.Write(mockConn.readBuf, binary.LittleEndian, uint32(0))
assert.NoError(t, err)
done := make(chan struct{})
go func() {
e.handleConnection(mockConn)
close(done)
}()
var receivedMessage *manager.ClientStreamMessage
select {
case receivedMessage = <-eventsChan:
case <-time.After(2 * time.Second):
t.Fatal("Timeout waiting for message in eventsChan")
}
assert.NotNil(t, receivedMessage)
select {
case <-done:
// handleConnection has exited
case <-time.After(2 * time.Second):
t.Fatal("Timeout waiting for handleConnection to exit")
}
// Check if ack was written
var receivedAck uint32
err = binary.Read(mockConn.writeBuf, binary.LittleEndian, &receivedAck)
assert.NoError(t, err)
assert.Equal(t, messageID, receivedAck)
// Ensure no unexpected calls were made on the mock
mockConn.AssertExpectations(t)
})
}
}
+136 -1250
View File
File diff suppressed because it is too large Load Diff
+12 -98
View File
@@ -3,40 +3,34 @@
syntax = "proto3";
import "google/protobuf/timestamp.proto";
import "google/protobuf/empty.proto";
package manager;
option go_package = "./manager";
service ManagerService {
rpc Process(stream ClientStreamMessage) returns (stream ServerStreamMessage) {}
rpc CreateVm(google.protobuf.Empty) returns (CreateRes) {}
rpc RemoveVm(RemoveReq) returns (google.protobuf.Empty) {}
rpc SVMInfo(SVMInfoReq) returns (SVMInfoRes) {}
rpc AttestationPolicy(AttestationPolicyReq) returns (AttestationPolicyRes) {}
}
message Terminate {
string message = 1;
message CreateRes{
string forwarded_port = 1;
string svm_id = 2;
}
message StopComputation {
string computation_id = 1;
message RemoveReq{
string svm_id = 1;
}
message StopComputationResponse {
string computation_id = 1;
string message = 2;
}
message RunResponse{
string agent_port = 1;
string computation_id = 2;
}
message AttestationPolicy{
message AttestationPolicyRes{
bytes info = 1;
string id = 2;
}
message SVMInfo{
message SVMInfoRes{
string id = 1;
string ovmf_version = 2;
int32 cpu_num = 3;
@@ -45,60 +39,6 @@ message SVMInfo{
string eos_version = 6;
}
message AgentEvent {
string event_type = 1;
google.protobuf.Timestamp timestamp = 2;
string computation_id = 3;
bytes details = 4;
string originator = 5;
string status = 6;
}
message AgentLog {
string message = 1;
string computation_id = 2;
string level = 3;
google.protobuf.Timestamp timestamp = 4;
}
message ClientStreamMessage {
oneof message {
AgentLog agent_log = 1;
AgentEvent agent_event = 2;
RunResponse run_res = 3;
AttestationPolicy attestationPolicy = 4;
StopComputationResponse stopComputationRes = 5;
SVMInfo svm_info = 6;
}
}
message ServerStreamMessage {
oneof message {
RunReqChunks runReqChunks = 1;
ComputationRunReq runReq = 2;
Terminate terminateReq = 3;
StopComputation stopComputation = 4;
AttestationPolicyReq attestationPolicyReq = 5;
SVMInfoReq svmInfoReq = 6;
}
}
message RunReqChunks {
bytes data = 1;
string id = 2;
bool is_last = 3;
}
message ComputationRunReq {
string id = 1;
string name = 2;
string description = 3;
repeated Dataset datasets = 4;
Algorithm algorithm = 5;
repeated ResultConsumer result_consumers = 6;
AgentConfig agent_config = 7;
}
message AttestationPolicyReq {
string id = 1;
}
@@ -107,29 +47,3 @@ message SVMInfoReq {
string id = 1;
}
message ResultConsumer {
bytes userKey = 1;
}
message Dataset {
bytes hash = 1; // should be sha3.Sum256, 32 byte length.
bytes userKey = 2;
string filename = 3;
}
message Algorithm {
bytes hash = 1; // should be sha3.Sum256, 32 byte length.
bytes userKey = 2;
}
message AgentConfig {
string port = 1;
string host = 2;
string cert_file = 3;
string key_file = 4;
string client_ca_file = 5;
string server_ca_file = 6;
string log_level = 7;
bool attested_tls = 8;
}
+142 -21
View File
@@ -14,6 +14,7 @@ import (
grpc "google.golang.org/grpc"
codes "google.golang.org/grpc/codes"
status "google.golang.org/grpc/status"
emptypb "google.golang.org/protobuf/types/known/emptypb"
)
// This is a compile-time assertion to ensure that this generated file
@@ -22,14 +23,20 @@ import (
const _ = grpc.SupportPackageIsVersion9
const (
ManagerService_Process_FullMethodName = "/manager.ManagerService/Process"
ManagerService_CreateVm_FullMethodName = "/manager.ManagerService/CreateVm"
ManagerService_RemoveVm_FullMethodName = "/manager.ManagerService/RemoveVm"
ManagerService_SVMInfo_FullMethodName = "/manager.ManagerService/SVMInfo"
ManagerService_AttestationPolicy_FullMethodName = "/manager.ManagerService/AttestationPolicy"
)
// ManagerServiceClient is the client API for ManagerService service.
//
// For semantics around ctx use and closing/ending streaming RPCs, please refer to https://pkg.go.dev/google.golang.org/grpc/?tab=doc#ClientConn.NewStream.
type ManagerServiceClient interface {
Process(ctx context.Context, opts ...grpc.CallOption) (grpc.BidiStreamingClient[ClientStreamMessage, ServerStreamMessage], error)
CreateVm(ctx context.Context, in *emptypb.Empty, opts ...grpc.CallOption) (*CreateRes, error)
RemoveVm(ctx context.Context, in *RemoveReq, opts ...grpc.CallOption) (*emptypb.Empty, error)
SVMInfo(ctx context.Context, in *SVMInfoReq, opts ...grpc.CallOption) (*SVMInfoRes, error)
AttestationPolicy(ctx context.Context, in *AttestationPolicyReq, opts ...grpc.CallOption) (*AttestationPolicyRes, error)
}
type managerServiceClient struct {
@@ -40,24 +47,54 @@ func NewManagerServiceClient(cc grpc.ClientConnInterface) ManagerServiceClient {
return &managerServiceClient{cc}
}
func (c *managerServiceClient) Process(ctx context.Context, opts ...grpc.CallOption) (grpc.BidiStreamingClient[ClientStreamMessage, ServerStreamMessage], error) {
func (c *managerServiceClient) CreateVm(ctx context.Context, in *emptypb.Empty, opts ...grpc.CallOption) (*CreateRes, error) {
cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
stream, err := c.cc.NewStream(ctx, &ManagerService_ServiceDesc.Streams[0], ManagerService_Process_FullMethodName, cOpts...)
out := new(CreateRes)
err := c.cc.Invoke(ctx, ManagerService_CreateVm_FullMethodName, in, out, cOpts...)
if err != nil {
return nil, err
}
x := &grpc.GenericClientStream[ClientStreamMessage, ServerStreamMessage]{ClientStream: stream}
return x, nil
return out, nil
}
// This type alias is provided for backwards compatibility with existing code that references the prior non-generic stream type by name.
type ManagerService_ProcessClient = grpc.BidiStreamingClient[ClientStreamMessage, ServerStreamMessage]
func (c *managerServiceClient) RemoveVm(ctx context.Context, in *RemoveReq, opts ...grpc.CallOption) (*emptypb.Empty, error) {
cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
out := new(emptypb.Empty)
err := c.cc.Invoke(ctx, ManagerService_RemoveVm_FullMethodName, in, out, cOpts...)
if err != nil {
return nil, err
}
return out, nil
}
func (c *managerServiceClient) SVMInfo(ctx context.Context, in *SVMInfoReq, opts ...grpc.CallOption) (*SVMInfoRes, error) {
cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
out := new(SVMInfoRes)
err := c.cc.Invoke(ctx, ManagerService_SVMInfo_FullMethodName, in, out, cOpts...)
if err != nil {
return nil, err
}
return out, nil
}
func (c *managerServiceClient) AttestationPolicy(ctx context.Context, in *AttestationPolicyReq, opts ...grpc.CallOption) (*AttestationPolicyRes, error) {
cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
out := new(AttestationPolicyRes)
err := c.cc.Invoke(ctx, ManagerService_AttestationPolicy_FullMethodName, in, out, cOpts...)
if err != nil {
return nil, err
}
return out, nil
}
// ManagerServiceServer is the server API for ManagerService service.
// All implementations must embed UnimplementedManagerServiceServer
// for forward compatibility.
type ManagerServiceServer interface {
Process(grpc.BidiStreamingServer[ClientStreamMessage, ServerStreamMessage]) error
CreateVm(context.Context, *emptypb.Empty) (*CreateRes, error)
RemoveVm(context.Context, *RemoveReq) (*emptypb.Empty, error)
SVMInfo(context.Context, *SVMInfoReq) (*SVMInfoRes, error)
AttestationPolicy(context.Context, *AttestationPolicyReq) (*AttestationPolicyRes, error)
mustEmbedUnimplementedManagerServiceServer()
}
@@ -68,8 +105,17 @@ type ManagerServiceServer interface {
// pointer dereference when methods are called.
type UnimplementedManagerServiceServer struct{}
func (UnimplementedManagerServiceServer) Process(grpc.BidiStreamingServer[ClientStreamMessage, ServerStreamMessage]) error {
return status.Errorf(codes.Unimplemented, "method Process not implemented")
func (UnimplementedManagerServiceServer) CreateVm(context.Context, *emptypb.Empty) (*CreateRes, error) {
return nil, status.Errorf(codes.Unimplemented, "method CreateVm not implemented")
}
func (UnimplementedManagerServiceServer) RemoveVm(context.Context, *RemoveReq) (*emptypb.Empty, error) {
return nil, status.Errorf(codes.Unimplemented, "method RemoveVm not implemented")
}
func (UnimplementedManagerServiceServer) SVMInfo(context.Context, *SVMInfoReq) (*SVMInfoRes, error) {
return nil, status.Errorf(codes.Unimplemented, "method SVMInfo not implemented")
}
func (UnimplementedManagerServiceServer) AttestationPolicy(context.Context, *AttestationPolicyReq) (*AttestationPolicyRes, error) {
return nil, status.Errorf(codes.Unimplemented, "method AttestationPolicy not implemented")
}
func (UnimplementedManagerServiceServer) mustEmbedUnimplementedManagerServiceServer() {}
func (UnimplementedManagerServiceServer) testEmbeddedByValue() {}
@@ -92,12 +138,77 @@ func RegisterManagerServiceServer(s grpc.ServiceRegistrar, srv ManagerServiceSer
s.RegisterService(&ManagerService_ServiceDesc, srv)
}
func _ManagerService_Process_Handler(srv interface{}, stream grpc.ServerStream) error {
return srv.(ManagerServiceServer).Process(&grpc.GenericServerStream[ClientStreamMessage, ServerStreamMessage]{ServerStream: stream})
func _ManagerService_CreateVm_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(emptypb.Empty)
if err := dec(in); err != nil {
return nil, err
}
if interceptor == nil {
return srv.(ManagerServiceServer).CreateVm(ctx, in)
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: ManagerService_CreateVm_FullMethodName,
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(ManagerServiceServer).CreateVm(ctx, req.(*emptypb.Empty))
}
return interceptor(ctx, in, info, handler)
}
// This type alias is provided for backwards compatibility with existing code that references the prior non-generic stream type by name.
type ManagerService_ProcessServer = grpc.BidiStreamingServer[ClientStreamMessage, ServerStreamMessage]
func _ManagerService_RemoveVm_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(RemoveReq)
if err := dec(in); err != nil {
return nil, err
}
if interceptor == nil {
return srv.(ManagerServiceServer).RemoveVm(ctx, in)
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: ManagerService_RemoveVm_FullMethodName,
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(ManagerServiceServer).RemoveVm(ctx, req.(*RemoveReq))
}
return interceptor(ctx, in, info, handler)
}
func _ManagerService_SVMInfo_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(SVMInfoReq)
if err := dec(in); err != nil {
return nil, err
}
if interceptor == nil {
return srv.(ManagerServiceServer).SVMInfo(ctx, in)
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: ManagerService_SVMInfo_FullMethodName,
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(ManagerServiceServer).SVMInfo(ctx, req.(*SVMInfoReq))
}
return interceptor(ctx, in, info, handler)
}
func _ManagerService_AttestationPolicy_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(AttestationPolicyReq)
if err := dec(in); err != nil {
return nil, err
}
if interceptor == nil {
return srv.(ManagerServiceServer).AttestationPolicy(ctx, in)
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: ManagerService_AttestationPolicy_FullMethodName,
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(ManagerServiceServer).AttestationPolicy(ctx, req.(*AttestationPolicyReq))
}
return interceptor(ctx, in, info, handler)
}
// ManagerService_ServiceDesc is the grpc.ServiceDesc for ManagerService service.
// It's only intended for direct use with grpc.RegisterService,
@@ -105,14 +216,24 @@ type ManagerService_ProcessServer = grpc.BidiStreamingServer[ClientStreamMessage
var ManagerService_ServiceDesc = grpc.ServiceDesc{
ServiceName: "manager.ManagerService",
HandlerType: (*ManagerServiceServer)(nil),
Methods: []grpc.MethodDesc{},
Streams: []grpc.StreamDesc{
Methods: []grpc.MethodDesc{
{
StreamName: "Process",
Handler: _ManagerService_Process_Handler,
ServerStreams: true,
ClientStreams: true,
MethodName: "CreateVm",
Handler: _ManagerService_CreateVm_Handler,
},
{
MethodName: "RemoveVm",
Handler: _ManagerService_RemoveVm_Handler,
},
{
MethodName: "SVMInfo",
Handler: _ManagerService_SVMInfo_Handler,
},
{
MethodName: "AttestationPolicy",
Handler: _ManagerService_AttestationPolicy_Handler,
},
},
Streams: []grpc.StreamDesc{},
Metadata: "manager/manager.proto",
}
-68
View File
@@ -1,68 +0,0 @@
// Copyright (c) Ultraviolet
// SPDX-License-Identifier: Apache-2.0
package manager_test
import (
"bytes"
"context"
"testing"
"github.com/ultravioletrs/cocos/manager"
"google.golang.org/grpc"
"google.golang.org/protobuf/proto"
)
func TestProcess(t *testing.T) {
ctx := context.Background()
conn, err := grpc.DialContext(ctx, "bufnet", grpc.WithContextDialer(bufDialer), grpc.WithInsecure())
if err != nil {
t.Fatalf("Failed to dial bufnet: %v", err)
}
defer conn.Close()
client := manager.NewManagerServiceClient(conn)
stream, err := client.Process(ctx)
if err != nil {
t.Fatalf("Process failed: %v", err)
}
var data bytes.Buffer
for {
msg, err := stream.Recv()
if err != nil {
t.Fatalf("Failed to receive ServerStreamMessage: %v", err)
}
switch m := msg.Message.(type) {
case *manager.ServerStreamMessage_TerminateReq:
if m.TerminateReq.Message != "test terminate" {
t.Fatalf("Unexpected terminate message: %v", m.TerminateReq.Message)
}
case *manager.ServerStreamMessage_RunReqChunks:
if len(m.RunReqChunks.Data) == 0 {
var runReq manager.ComputationRunReq
if err = proto.Unmarshal(data.Bytes(), &runReq); err != nil {
t.Fatalf("Failed to create run request: %v", err)
}
runRes := &manager.ClientStreamMessage_AgentLog{
AgentLog: &manager.AgentLog{
Message: "test log",
ComputationId: "comp1",
Level: "DEBUG",
},
}
if runReq.Id != "1" || runReq.Name != "sample computation" || runReq.Description != "sample description" {
t.Fatalf("Unexpected run request message: %v", &runReq)
}
if err := stream.Send(&manager.ClientStreamMessage{Message: runRes}); err != nil {
t.Fatalf("Failed to send ClientStreamMessage: %v", err)
}
return
}
data.Write(m.RunReqChunks.Data)
default:
t.Fatalf("Unexpected message type: %T", m)
}
}
}
+91 -119
View File
@@ -9,7 +9,6 @@ import (
context "context"
mock "github.com/stretchr/testify/mock"
manager "github.com/ultravioletrs/cocos/manager"
)
// Service is an autogenerated mock type for the Service type
@@ -25,6 +24,69 @@ func (_m *Service) EXPECT() *Service_Expecter {
return &Service_Expecter{mock: &_m.Mock}
}
// CreateVM provides a mock function with given fields: ctx
func (_m *Service) CreateVM(ctx context.Context) (string, string, error) {
ret := _m.Called(ctx)
if len(ret) == 0 {
panic("no return value specified for CreateVM")
}
var r0 string
var r1 string
var r2 error
if rf, ok := ret.Get(0).(func(context.Context) (string, string, error)); ok {
return rf(ctx)
}
if rf, ok := ret.Get(0).(func(context.Context) string); ok {
r0 = rf(ctx)
} else {
r0 = ret.Get(0).(string)
}
if rf, ok := ret.Get(1).(func(context.Context) string); ok {
r1 = rf(ctx)
} else {
r1 = ret.Get(1).(string)
}
if rf, ok := ret.Get(2).(func(context.Context) error); ok {
r2 = rf(ctx)
} else {
r2 = ret.Error(2)
}
return r0, r1, r2
}
// Service_CreateVM_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'CreateVM'
type Service_CreateVM_Call struct {
*mock.Call
}
// CreateVM is a helper method to define mock.On call
// - ctx context.Context
func (_e *Service_Expecter) CreateVM(ctx interface{}) *Service_CreateVM_Call {
return &Service_CreateVM_Call{Call: _e.mock.On("CreateVM", ctx)}
}
func (_c *Service_CreateVM_Call) Run(run func(ctx context.Context)) *Service_CreateVM_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(context.Context))
})
return _c
}
func (_c *Service_CreateVM_Call) Return(_a0 string, _a1 string, _a2 error) *Service_CreateVM_Call {
_c.Call.Return(_a0, _a1, _a2)
return _c
}
func (_c *Service_CreateVM_Call) RunAndReturn(run func(context.Context) (string, string, error)) *Service_CreateVM_Call {
_c.Call.Return(run)
return _c
}
// FetchAttestationPolicy provides a mock function with given fields: ctx, computationID
func (_m *Service) FetchAttestationPolicy(ctx context.Context, computationID string) ([]byte, error) {
ret := _m.Called(ctx, computationID)
@@ -84,35 +146,49 @@ func (_c *Service_FetchAttestationPolicy_Call) RunAndReturn(run func(context.Con
return _c
}
// ReportBrokenConnection provides a mock function with given fields: addr
func (_m *Service) ReportBrokenConnection(addr string) {
_m.Called(addr)
// RemoveVM provides a mock function with given fields: ctx, computationID
func (_m *Service) RemoveVM(ctx context.Context, computationID string) error {
ret := _m.Called(ctx, computationID)
if len(ret) == 0 {
panic("no return value specified for RemoveVM")
}
var r0 error
if rf, ok := ret.Get(0).(func(context.Context, string) error); ok {
r0 = rf(ctx, computationID)
} else {
r0 = ret.Error(0)
}
return r0
}
// Service_ReportBrokenConnection_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ReportBrokenConnection'
type Service_ReportBrokenConnection_Call struct {
// Service_RemoveVM_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'RemoveVM'
type Service_RemoveVM_Call struct {
*mock.Call
}
// ReportBrokenConnection is a helper method to define mock.On call
// - addr string
func (_e *Service_Expecter) ReportBrokenConnection(addr interface{}) *Service_ReportBrokenConnection_Call {
return &Service_ReportBrokenConnection_Call{Call: _e.mock.On("ReportBrokenConnection", addr)}
// RemoveVM is a helper method to define mock.On call
// - ctx context.Context
// - computationID string
func (_e *Service_Expecter) RemoveVM(ctx interface{}, computationID interface{}) *Service_RemoveVM_Call {
return &Service_RemoveVM_Call{Call: _e.mock.On("RemoveVM", ctx, computationID)}
}
func (_c *Service_ReportBrokenConnection_Call) Run(run func(addr string)) *Service_ReportBrokenConnection_Call {
func (_c *Service_RemoveVM_Call) Run(run func(ctx context.Context, computationID string)) *Service_RemoveVM_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(string))
run(args[0].(context.Context), args[1].(string))
})
return _c
}
func (_c *Service_ReportBrokenConnection_Call) Return() *Service_ReportBrokenConnection_Call {
_c.Call.Return()
func (_c *Service_RemoveVM_Call) Return(_a0 error) *Service_RemoveVM_Call {
_c.Call.Return(_a0)
return _c
}
func (_c *Service_ReportBrokenConnection_Call) RunAndReturn(run func(string)) *Service_ReportBrokenConnection_Call {
func (_c *Service_RemoveVM_Call) RunAndReturn(run func(context.Context, string) error) *Service_RemoveVM_Call {
_c.Call.Return(run)
return _c
}
@@ -187,110 +263,6 @@ func (_c *Service_ReturnSVMInfo_Call) RunAndReturn(run func(context.Context) (st
return _c
}
// Run provides a mock function with given fields: ctx, c
func (_m *Service) Run(ctx context.Context, c *manager.ComputationRunReq) (string, error) {
ret := _m.Called(ctx, c)
if len(ret) == 0 {
panic("no return value specified for Run")
}
var r0 string
var r1 error
if rf, ok := ret.Get(0).(func(context.Context, *manager.ComputationRunReq) (string, error)); ok {
return rf(ctx, c)
}
if rf, ok := ret.Get(0).(func(context.Context, *manager.ComputationRunReq) string); ok {
r0 = rf(ctx, c)
} else {
r0 = ret.Get(0).(string)
}
if rf, ok := ret.Get(1).(func(context.Context, *manager.ComputationRunReq) error); ok {
r1 = rf(ctx, c)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// Service_Run_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Run'
type Service_Run_Call struct {
*mock.Call
}
// Run is a helper method to define mock.On call
// - ctx context.Context
// - c *manager.ComputationRunReq
func (_e *Service_Expecter) Run(ctx interface{}, c interface{}) *Service_Run_Call {
return &Service_Run_Call{Call: _e.mock.On("Run", ctx, c)}
}
func (_c *Service_Run_Call) Run(run func(ctx context.Context, c *manager.ComputationRunReq)) *Service_Run_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(context.Context), args[1].(*manager.ComputationRunReq))
})
return _c
}
func (_c *Service_Run_Call) Return(_a0 string, _a1 error) *Service_Run_Call {
_c.Call.Return(_a0, _a1)
return _c
}
func (_c *Service_Run_Call) RunAndReturn(run func(context.Context, *manager.ComputationRunReq) (string, error)) *Service_Run_Call {
_c.Call.Return(run)
return _c
}
// Stop provides a mock function with given fields: ctx, computationID
func (_m *Service) Stop(ctx context.Context, computationID string) error {
ret := _m.Called(ctx, computationID)
if len(ret) == 0 {
panic("no return value specified for Stop")
}
var r0 error
if rf, ok := ret.Get(0).(func(context.Context, string) error); ok {
r0 = rf(ctx, computationID)
} else {
r0 = ret.Error(0)
}
return r0
}
// Service_Stop_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Stop'
type Service_Stop_Call struct {
*mock.Call
}
// Stop is a helper method to define mock.On call
// - ctx context.Context
// - computationID string
func (_e *Service_Expecter) Stop(ctx interface{}, computationID interface{}) *Service_Stop_Call {
return &Service_Stop_Call{Call: _e.mock.On("Stop", ctx, computationID)}
}
func (_c *Service_Stop_Call) Run(run func(ctx context.Context, computationID string)) *Service_Stop_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(context.Context), args[1].(string))
})
return _c
}
func (_c *Service_Stop_Call) Return(_a0 error) *Service_Stop_Call {
_c.Call.Return(_a0)
return _c
}
func (_c *Service_Stop_Call) RunAndReturn(run func(context.Context, string) error) *Service_Stop_Call {
_c.Call.Return(run)
return _c
}
// NewService creates a new instance of Service. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations.
// The first argument is typically a *testing.T value.
func NewService(t interface {
+10 -30
View File
@@ -13,7 +13,6 @@ import (
"github.com/ultravioletrs/cocos/internal"
"github.com/ultravioletrs/cocos/manager/vm"
"github.com/ultravioletrs/cocos/pkg/manager"
"google.golang.org/protobuf/types/known/timestamppb"
)
const (
@@ -31,19 +30,17 @@ type VMInfo struct {
}
type qemuVM struct {
vmi VMInfo
cmd *exec.Cmd
eventsLogsSender vm.EventSender
computationId string
vmi VMInfo
cmd *exec.Cmd
computationId string
vm.StateMachine
}
func NewVM(config interface{}, eventsLogsSender vm.EventSender, computationId string) vm.VM {
func NewVM(config interface{}, computationId string) vm.VM {
return &qemuVM{
vmi: config.(VMInfo),
eventsLogsSender: eventsLogsSender,
computationId: computationId,
StateMachine: vm.NewStateMachine(),
vmi: config.(VMInfo),
computationId: computationId,
StateMachine: vm.NewStateMachine(),
}
}
@@ -79,8 +76,8 @@ func (v *qemuVM) Start() (err error) {
}
v.cmd = exec.Command(exe, args...)
v.cmd.Stdout = &vm.Stdout{ComputationId: v.computationId, EventSender: v.eventsLogsSender}
v.cmd.Stderr = &vm.Stderr{EventSender: v.eventsLogsSender, ComputationId: v.computationId, StateMachine: v.StateMachine}
v.cmd.Stdout = os.Stdout
v.cmd.Stderr = os.Stderr
return v.cmd.Start()
}
@@ -89,15 +86,7 @@ func (v *qemuVM) Stop() error {
defer func() {
err := v.StateMachine.Transition(manager.StopComputationRun)
if err != nil {
if err := v.eventsLogsSender(&vm.Event{
EventType: v.StateMachine.State(),
Timestamp: timestamppb.Now(),
ComputationId: v.computationId,
Originator: "manager",
Status: manager.Warning.String(),
}); err != nil {
return
}
return
}
}()
err := v.cmd.Process.Signal(syscall.SIGTERM)
@@ -163,15 +152,6 @@ func (v *qemuVM) executableAndArgs() (string, []string, error) {
func (v *qemuVM) checkVMProcessPeriodically() {
for {
if !processExists(v.GetProcess()) {
if err := v.eventsLogsSender(&vm.Event{
EventType: v.StateMachine.State(),
Timestamp: timestamppb.Now(),
ComputationId: v.computationId,
Originator: "manager",
Status: manager.Stopped.String(),
}); err != nil {
return
}
break
}
time.Sleep(interval)
+3 -36
View File
@@ -6,10 +6,8 @@ import (
"os"
"os/exec"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/ultravioletrs/cocos/manager/vm"
"github.com/ultravioletrs/cocos/manager/vm/mocks"
pkgmanager "github.com/ultravioletrs/cocos/pkg/manager"
)
@@ -19,7 +17,7 @@ const testComputationID = "test-computation"
func TestNewVM(t *testing.T) {
config := VMInfo{Config: Config{}}
vm := NewVM(config, func(event interface{}) error { return nil }, testComputationID)
vm := NewVM(config, testComputationID)
assert.NotNil(t, vm)
assert.IsType(t, &qemuVM{}, vm)
@@ -38,7 +36,7 @@ func TestStart(t *testing.T) {
QemuBinPath: "echo",
}}
vm := NewVM(config, func(event interface{}) error { return nil }, testComputationID).(*qemuVM)
vm := NewVM(config, testComputationID).(*qemuVM)
err = vm.Start()
assert.NoError(t, err)
@@ -61,7 +59,7 @@ func TestStartSudo(t *testing.T) {
UseSudo: true,
}}
vm := NewVM(config, func(event interface{}) error { return nil }, testComputationID).(*qemuVM)
vm := NewVM(config, testComputationID).(*qemuVM)
err = vm.Start()
assert.NoError(t, err)
@@ -101,9 +99,6 @@ func TestStop(t *testing.T) {
Process: cmd.Process,
},
StateMachine: sm,
eventsLogsSender: func(event interface{}) error {
return nil
},
}
err = vm.Stop()
@@ -165,31 +160,3 @@ func TestGetConfig(t *testing.T) {
config := vm.GetConfig()
assert.Equal(t, expectedConfig, config)
}
func TestCheckVMProcessPeriodically(t *testing.T) {
logsChan := make(chan interface{}, 1)
vmi := &qemuVM{
eventsLogsSender: func(event interface{}) error {
logsChan <- event
return nil
},
computationId: testComputationID,
cmd: &exec.Cmd{
Process: &os.Process{Pid: -1}, // Use an invalid PID to simulate a stopped process
},
StateMachine: vm.NewStateMachine(),
}
go vmi.checkVMProcessPeriodically()
select {
case msg := <-logsChan:
assert.NotNil(t, msg)
msgE := msg.(*vm.Event)
assert.Equal(t, testComputationID, msgE.ComputationId)
assert.Equal(t, pkgmanager.VmProvision.String(), msgE.EventType)
assert.Equal(t, pkgmanager.Stopped.String(), msgE.Status)
case <-time.After(2 * interval):
t.Fatal("Timeout waiting for VM stopped message")
}
}
+21 -127
View File
@@ -5,7 +5,6 @@ package manager
import (
"context"
"encoding/base64"
"encoding/json"
"fmt"
"log/slog"
"net"
@@ -17,15 +16,13 @@ import (
"syscall"
"github.com/absmach/magistrala/pkg/errors"
"github.com/cenkalti/backoff/v4"
"github.com/google/go-sev-guest/proto/check"
"github.com/ultravioletrs/cocos/agent"
"github.com/google/uuid"
"github.com/ultravioletrs/cocos/manager/qemu"
"github.com/ultravioletrs/cocos/manager/vm"
"github.com/ultravioletrs/cocos/pkg/manager"
"golang.org/x/crypto/sha3"
"google.golang.org/protobuf/encoding/protojson"
"google.golang.org/protobuf/types/known/timestamppb"
)
const (
@@ -48,8 +45,6 @@ var (
// ErrFailedToAllocatePort indicates no free port was found on host.
ErrFailedToAllocatePort = errors.New("failed to allocate free port on host")
errInvalidHashLength = errors.New("hash must be of byte length 32")
// ErrFailedToCalculateHash indicates that agent computation returned an error while calculating the hash of the computation.
ErrFailedToCalculateHash = errors.New("error while calculating the hash of the computation")
@@ -67,13 +62,11 @@ var (
// implementation, and all of its decorators (e.g. logging & metrics).
type Service interface {
// Run create a computation.
Run(ctx context.Context, c *ComputationRunReq) (string, error)
CreateVM(ctx context.Context) (string, string, error)
// Stop stops a computation.
Stop(ctx context.Context, computationID string) error
RemoveVM(ctx context.Context, computationID string) error
// FetchAttestationPolicy measures and fetches the attestation policy.
FetchAttestationPolicy(ctx context.Context, computationID string) ([]byte, error)
// ReportBrokenConnection reports a broken connection.
ReportBrokenConnection(addr string)
// ReturnSVMInfo returns SVM information needed for attestation verification and validation.
ReturnSVMInfo(ctx context.Context) (string, int, string, string)
}
@@ -84,7 +77,6 @@ type managerService struct {
qemuCfg qemu.Config
attestationPolicyBinaryPath string
logger *slog.Logger
eventsChan chan *ClientStreamMessage
vms map[string]vm.VM
vmFactory vm.Provider
portRangeMin int
@@ -96,7 +88,7 @@ type managerService struct {
var _ Service = (*managerService)(nil)
// New instantiates the manager service implementation.
func New(cfg qemu.Config, attestationPolicyBinPath string, logger *slog.Logger, eventsChan chan *ClientStreamMessage, vmFactory vm.Provider, eosVersion string) (Service, error) {
func New(cfg qemu.Config, attestationPolicyBinPath string, logger *slog.Logger, vmFactory vm.Provider, eosVersion string) (Service, error) {
start, end, err := decodeRange(cfg.HostFwdRange)
if err != nil {
return nil, err
@@ -111,7 +103,6 @@ func New(cfg qemu.Config, attestationPolicyBinPath string, logger *slog.Logger,
qemuCfg: cfg,
logger: logger,
vms: make(map[string]vm.VM),
eventsChan: eventsChan,
vmFactory: vmFactory,
attestationPolicyBinaryPath: attestationPolicyBinPath,
portRangeMin: start,
@@ -127,7 +118,8 @@ func New(cfg qemu.Config, attestationPolicyBinPath string, logger *slog.Logger,
return ms, nil
}
func (ms *managerService) Run(ctx context.Context, c *ComputationRunReq) (string, error) {
func (ms *managerService) CreateVM(ctx context.Context) (string, string, error) {
id := uuid.New().String()
ms.mu.Lock()
cfg := qemu.VMInfo{
Config: ms.qemuCfg,
@@ -142,54 +134,29 @@ func (ms *managerService) Run(ctx context.Context, c *ComputationRunReq) (string
_, err := cmd.Output()
ms.ap.Unlock()
if err != nil {
return "", errors.Wrap(ErrFailedToCreateAttestationPolicy, err)
return "", id, errors.Wrap(ErrFailedToCreateAttestationPolicy, err)
}
ms.ap.Lock()
f, err := os.ReadFile("./attestation_policy.json")
ms.ap.Unlock()
if err != nil {
return "", errors.Wrap(ErrFailedToReadPolicy, err)
return "", id, errors.Wrap(ErrFailedToReadPolicy, err)
}
var attestationPolicy check.Config
if err = protojson.Unmarshal(f, &attestationPolicy); err != nil {
return "", errors.Wrap(ErrUnmarshalFailed, err)
return "", id, errors.Wrap(ErrUnmarshalFailed, err)
}
// Define the TCB that was present at launch of the VM.
cfg.LaunchTCB = attestationPolicy.Policy.MinimumLaunchTcb
}
ms.publishEvent(manager.VmProvision.String(), c.Id, manager.Starting.String(), json.RawMessage{})
ac := agent.Computation{
ID: c.Id,
Name: c.Name,
Description: c.Description,
}
if len(c.Algorithm.Hash) != hashLength {
ms.publishEvent(manager.VmProvision.String(), c.Id, agent.Failed.String(), json.RawMessage{})
return "", errInvalidHashLength
}
ac.Algorithm = agent.Algorithm{Hash: [hashLength]byte(c.Algorithm.Hash), UserKey: c.Algorithm.UserKey}
for _, data := range c.Datasets {
if len(data.Hash) != hashLength {
ms.publishEvent(manager.VmProvision.String(), c.Id, agent.Failed.String(), json.RawMessage{})
return "", errInvalidHashLength
}
ac.Datasets = append(ac.Datasets, agent.Dataset{Hash: [hashLength]byte(data.Hash), UserKey: data.UserKey, Filename: data.Filename})
}
for _, rc := range c.ResultConsumers {
ac.ResultConsumers = append(ac.ResultConsumers, agent.ResultConsumer{UserKey: rc.UserKey})
}
agentPort, err := getFreePort(ms.portRangeMin, ms.portRangeMax)
if err != nil {
ms.publishEvent(manager.VmProvision.String(), c.Id, agent.Failed.String(), json.RawMessage{})
return "", errors.Wrap(ErrFailedToAllocatePort, err)
return "", id, errors.Wrap(ErrFailedToAllocatePort, err)
}
cfg.Config.HostFwdAgent = agentPort
@@ -210,30 +177,23 @@ func (ms *managerService) Run(ctx context.Context, c *ComputationRunReq) (string
cfg.Config.VSockConfig.GuestCID = cid
if cfg.Config.EnableSEVSNP {
ch, err := computationHash(ac)
if err != nil {
ms.publishEvent(manager.VmProvision.String(), c.Id, agent.Failed.String(), json.RawMessage{})
return "", errors.Wrap(ErrFailedToCalculateHash, err)
}
todo := sha3.Sum256([]byte("TODO"))
// Define host-data value of QEMU for SEV-SNP, with a base64 encoding of the computation hash.
cfg.Config.SevConfig.HostData = base64.StdEncoding.EncodeToString(ch[:])
cfg.Config.SevConfig.HostData = base64.StdEncoding.EncodeToString(todo[:])
}
cvm := ms.vmFactory(cfg, ms.eventsLogsSender, c.Id)
ms.publishEvent(manager.VmProvision.String(), c.Id, agent.InProgress.String(), json.RawMessage{})
cvm := ms.vmFactory(cfg, id)
if err = cvm.Start(); err != nil {
ms.publishEvent(manager.VmProvision.String(), c.Id, agent.Failed.String(), json.RawMessage{})
return "", err
return "", id, err
}
ms.mu.Lock()
ms.vms[c.Id] = cvm
ms.vms[id] = cvm
ms.mu.Unlock()
pid := cvm.GetProcess()
state := qemu.VMState{
ID: c.Id,
ID: id,
VMinfo: cfg,
PID: pid,
}
@@ -241,34 +201,23 @@ func (ms *managerService) Run(ctx context.Context, c *ComputationRunReq) (string
ms.logger.Error("Failed to persist VM state", "error", err)
}
err = backoff.Retry(func() error {
return cvm.SendAgentConfig(ac)
}, backoff.NewExponentialBackOff())
if err != nil {
return "", err
}
ms.mu.Lock()
if err := ms.vms[c.Id].Transition(manager.VmRunning); err != nil {
ms.logger.Warn("Failed to transition VM state", "computation", c.Id, "error", err)
if err := ms.vms[id].Transition(manager.VmRunning); err != nil {
ms.logger.Warn("Failed to transition VM state", "cvm", id, "error", err)
}
ms.mu.Unlock()
ms.publishEvent(manager.VmProvision.String(), c.Id, agent.Completed.String(), json.RawMessage{})
return fmt.Sprint(agentPort), nil
return fmt.Sprint(agentPort), id, nil
}
func (ms *managerService) Stop(ctx context.Context, computationID string) error {
func (ms *managerService) RemoveVM(ctx context.Context, computationID string) error {
ms.mu.Lock()
defer ms.mu.Unlock()
cvm, ok := ms.vms[computationID]
if !ok {
defer ms.publishEvent(manager.StopComputationRun.String(), computationID, agent.Failed.String(), json.RawMessage{})
return ErrNotFound
}
if err := cvm.Stop(); err != nil {
defer ms.publishEvent(manager.StopComputationRun.String(), computationID, agent.Failed.String(), json.RawMessage{})
return err
}
delete(ms.vms, computationID)
@@ -277,7 +226,6 @@ func (ms *managerService) Stop(ctx context.Context, computationID string) error
ms.logger.Error("Failed to delete persisted VM state", "error", err)
}
defer ms.publishEvent(manager.StopComputationRun.String(), computationID, agent.Completed.String(), json.RawMessage{})
return nil
}
@@ -329,30 +277,6 @@ func checkPortisFree(port int) bool {
return true
}
func (ms *managerService) publishEvent(event, cmpID, status string, details json.RawMessage) {
ms.eventsChan <- &ClientStreamMessage{
Message: &ClientStreamMessage_AgentEvent{
AgentEvent: &AgentEvent{
EventType: event,
ComputationId: cmpID,
Status: status,
Details: details,
Timestamp: timestamppb.Now(),
Originator: "manager",
},
},
}
}
func computationHash(ac agent.Computation) ([32]byte, error) {
jsonData, err := json.Marshal(ac)
if err != nil {
return [32]byte{}, err
}
return sha3.Sum256(jsonData), nil
}
func decodeRange(input string) (int, int, error) {
re := regexp.MustCompile(`(\d+)-(\d+)`)
matches := re.FindStringSubmatch(input)
@@ -392,7 +316,7 @@ func (ms *managerService) restoreVMs() error {
continue
}
cvm := ms.vmFactory(state.VMinfo, ms.eventsLogsSender, state.ID)
cvm := ms.vmFactory(state.VMinfo, state.ID)
if err = cvm.SetProcess(state.PID); err != nil {
ms.logger.Warn("Failed to reattach to process", "computation", state.ID, "pid", state.PID, "error", err)
@@ -425,33 +349,3 @@ func (ms *managerService) processExists(pid int) bool {
}
return false
}
func (ms *managerService) eventsLogsSender(e interface{}) error {
switch msg := e.(type) {
case *vm.Event:
ms.eventsChan <- &ClientStreamMessage{
Message: &ClientStreamMessage_AgentEvent{
AgentEvent: &AgentEvent{
EventType: msg.EventType,
Timestamp: msg.Timestamp,
ComputationId: msg.ComputationId,
Originator: msg.Originator,
Status: msg.Status,
Details: msg.Details,
},
},
}
case *vm.Log:
ms.eventsChan <- &ClientStreamMessage{
Message: &ClientStreamMessage_AgentLog{
AgentLog: &AgentLog{
ComputationId: msg.ComputationId,
Level: msg.Level,
Timestamp: msg.Timestamp,
Message: msg.Message,
},
},
}
}
return nil
}
+6 -156
View File
@@ -4,7 +4,6 @@ package manager
import (
"context"
"encoding/json"
"fmt"
"log/slog"
"net"
@@ -17,7 +16,6 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
"github.com/ultravioletrs/cocos/agent"
"github.com/ultravioletrs/cocos/manager/qemu"
persistenceMocks "github.com/ultravioletrs/cocos/manager/qemu/mocks"
"github.com/ultravioletrs/cocos/manager/vm"
@@ -29,10 +27,9 @@ func TestNew(t *testing.T) {
HostFwdRange: "6000-6100",
}
logger := slog.Default()
eventsChan := make(chan *ClientStreamMessage)
vmf := new(mocks.Provider)
service, err := New(cfg, "", logger, eventsChan, vmf.Execute, "")
service, err := New(cfg, "", logger, vmf.Execute, "")
require.NoError(t, err)
assert.NotNil(t, service)
@@ -46,82 +43,24 @@ func TestRun(t *testing.T) {
vmf.On("Execute", mock.Anything, mock.Anything, mock.Anything).Return(vmMock)
tests := []struct {
name string
req *ComputationRunReq
binaryBehavior string
vmStartError error
expectedError error
}{
{
name: "Successful run",
req: &ComputationRunReq{
Id: "test-computation",
Name: "Test Computation",
Algorithm: &Algorithm{
Hash: make([]byte, hashLength),
},
AgentConfig: &AgentConfig{},
},
name: "Successful run",
binaryBehavior: "success",
vmStartError: nil,
expectedError: nil,
},
{
name: "VM start failure",
req: &ComputationRunReq{
Id: "test-computation",
Name: "Test Computation",
Algorithm: &Algorithm{
Hash: make([]byte, hashLength),
},
AgentConfig: &AgentConfig{},
},
name: "VM start failure",
binaryBehavior: "success",
vmStartError: assert.AnError,
expectedError: assert.AnError,
},
{
name: "Invalid algorithm hash",
req: &ComputationRunReq{
Id: "test-computation",
Name: "Test Computation",
Algorithm: &Algorithm{
Hash: make([]byte, hashLength-1),
},
AgentConfig: &AgentConfig{},
},
binaryBehavior: "success",
vmStartError: nil,
expectedError: errInvalidHashLength,
},
{
name: "Invalid dataset hash",
req: &ComputationRunReq{
Id: "test-computation",
Name: "Test Computation",
Algorithm: &Algorithm{
Hash: make([]byte, hashLength),
},
AgentConfig: &AgentConfig{},
Datasets: []*Dataset{
{
Hash: make([]byte, hashLength-1),
},
},
},
binaryBehavior: "success",
vmStartError: nil,
expectedError: errInvalidHashLength,
},
{
name: "Invalid attestation policy",
req: &ComputationRunReq{
Id: "test-computation",
Name: "Test Computation",
Algorithm: &Algorithm{
Hash: make([]byte, hashLength),
},
AgentConfig: &AgentConfig{},
},
name: "Invalid attestation policy",
binaryBehavior: "fail",
vmStartError: nil,
expectedError: ErrFailedToCreateAttestationPolicy,
@@ -149,7 +88,6 @@ func TestRun(t *testing.T) {
},
}
logger := slog.Default()
eventsChan := make(chan *ClientStreamMessage, 10)
tempDir := CreateDummyAttestationPolicyBinary(t, tt.binaryBehavior)
defer os.RemoveAll(tempDir)
@@ -159,14 +97,13 @@ func TestRun(t *testing.T) {
attestationPolicyBinaryPath: tempDir,
logger: logger,
vms: make(map[string]vm.VM),
eventsChan: eventsChan,
vmFactory: vmf.Execute,
persistence: persistence,
}
ctx := context.Background()
port, err := ms.Run(ctx, tt.req)
port, _, err := ms.CreateVM(ctx)
if tt.expectedError != nil {
assert.Error(t, err)
@@ -179,10 +116,6 @@ func TestRun(t *testing.T) {
}
vmf.AssertExpectations(t)
for len(eventsChan) > 0 {
<-eventsChan
}
})
}
}
@@ -226,11 +159,9 @@ func TestStop(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
logger := slog.Default()
eventsChan := make(chan *ClientStreamMessage, 10)
ms := &managerService{
logger: logger,
vms: make(map[string]vm.VM),
eventsChan: eventsChan,
persistence: persistence,
}
vmMock := new(mocks.VM)
@@ -247,7 +178,7 @@ func TestStop(t *testing.T) {
ms.vms[tt.computationID] = vmMock
}
err := ms.Stop(context.Background(), tt.computationID)
err := ms.RemoveVM(context.Background(), tt.computationID)
if tt.expectedError != nil {
assert.Error(t, err)
@@ -256,10 +187,6 @@ func TestStop(t *testing.T) {
assert.NoError(t, err)
assert.Len(t, ms.vms, 0)
}
for len(eventsChan) > 0 {
<-eventsChan
}
})
}
}
@@ -278,82 +205,6 @@ func TestGetFreePort(t *testing.T) {
assert.Greater(t, port, 6000)
}
func TestPublishEvent(t *testing.T) {
tests := []struct {
name string
event string
computationID string
status string
details json.RawMessage
}{
{
name: "Standard event",
event: "test-event",
computationID: "test-computation",
status: "test-status",
details: nil,
},
{
name: "Event with details",
event: "detailed-event",
computationID: "detailed-computation",
status: "detailed-status",
details: json.RawMessage(`{"key": "value"}`),
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
eventsChan := make(chan *ClientStreamMessage, 1)
ms := &managerService{
eventsChan: eventsChan,
}
ms.publishEvent(tt.event, tt.computationID, tt.status, tt.details)
assert.Len(t, eventsChan, 1)
event := <-eventsChan
assert.Equal(t, tt.event, event.GetAgentEvent().EventType)
assert.Equal(t, tt.computationID, event.GetAgentEvent().ComputationId)
assert.Equal(t, tt.status, event.GetAgentEvent().Status)
assert.Equal(t, "manager", event.GetAgentEvent().Originator)
assert.Equal(t, tt.details, json.RawMessage(event.GetAgentEvent().Details))
})
}
}
func TestComputationHash(t *testing.T) {
tests := []struct {
name string
computation agent.Computation
wantErr bool
}{
{
name: "Valid computation",
computation: agent.Computation{
ID: "test-id",
Name: "test-name",
},
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
hash, err := computationHash(tt.computation)
if tt.wantErr {
assert.Error(t, err)
} else {
assert.NoError(t, err)
assert.NotEmpty(t, hash)
hash2, _ := computationHash(tt.computation)
assert.Equal(t, hash, hash2)
}
})
}
}
func TestDecodeRange(t *testing.T) {
tests := []struct {
name string
@@ -393,7 +244,6 @@ func TestRestoreVMs(t *testing.T) {
ms := &managerService{
persistence: mockPersistence,
vms: make(map[string]vm.VM),
eventsChan: make(chan *ClientStreamMessage, 10),
vmFactory: vmf.Execute,
logger: mglog.NewMock(),
}
-131
View File
@@ -1,131 +0,0 @@
// Copyright (c) Ultraviolet
// SPDX-License-Identifier: Apache-2.0
package manager_test
import (
"context"
"crypto/rand"
"crypto/rsa"
"crypto/x509"
"encoding/pem"
"log/slog"
"net"
"os"
"testing"
"time"
mglog "github.com/absmach/magistrala/logger"
"github.com/ultravioletrs/cocos/manager"
managergrpc "github.com/ultravioletrs/cocos/manager/api/grpc"
"golang.org/x/crypto/sha3"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/test/bufconn"
)
const (
bufSize = 1024 * 1024
keyBitSize = 4096
)
var (
lis *bufconn.Listener
algoPath = "../test/manual/algo/lin_reg.py"
dataPath = "../test/manual/data/iris.csv"
attestedTLS = false
)
type svc struct {
logger *slog.Logger
t *testing.T
}
func TestMain(m *testing.M) {
logger := mglog.NewMock()
lis = bufconn.Listen(bufSize)
s := grpc.NewServer()
manager.RegisterManagerServiceServer(s, managergrpc.NewServer(make(chan *manager.ClientStreamMessage, 1), &svc{logger: logger}))
go func() {
if err := s.Serve(lis); err != nil {
panic(err)
}
}()
code := m.Run()
s.Stop()
lis.Close()
os.Exit(code)
}
func bufDialer(context.Context, string) (net.Conn, error) {
return lis.Dial()
}
func (s *svc) Run(ctx context.Context, ipAddress string, sendMessage managergrpc.SendFunc, authInfo credentials.AuthInfo) {
privKey, err := rsa.GenerateKey(rand.Reader, keyBitSize)
if err != nil {
s.t.Fatalf("Error generating public key: %v", err)
}
pubKey, err := x509.MarshalPKIXPublicKey(&privKey.PublicKey)
if err != nil {
s.t.Fatalf("Error marshalling public key: %v", err)
}
pubPemBytes := pem.EncodeToMemory(&pem.Block{
Type: "PUBLIC KEY",
Bytes: pubKey,
})
go func() {
time.Sleep(time.Millisecond * 100)
if err := sendMessage(&manager.ServerStreamMessage{
Message: &manager.ServerStreamMessage_TerminateReq{
TerminateReq: &manager.Terminate{Message: "test terminate"},
},
}); err != nil {
s.t.Fatalf("failed to send terminate request: %s", err)
}
}()
go func() {
time.Sleep(time.Millisecond * 100)
algo, err := os.ReadFile(algoPath)
if err != nil {
s.t.Fatalf("failed to read algorithm file: %s", err)
return
}
data, err := os.ReadFile(dataPath)
if err != nil {
s.t.Fatalf("failed to read data file: %s", err)
return
}
pubPem, _ := pem.Decode(pubPemBytes)
algoHash := sha3.Sum256(algo)
dataHash := sha3.Sum256(data)
if err := sendMessage(&manager.ServerStreamMessage{
Message: &manager.ServerStreamMessage_RunReq{
RunReq: &manager.ComputationRunReq{
Id: "1",
Name: "sample computation",
Description: "sample description",
Datasets: []*manager.Dataset{{Hash: dataHash[:], UserKey: pubPem.Bytes}},
Algorithm: &manager.Algorithm{Hash: algoHash[:], UserKey: pubPem.Bytes},
ResultConsumers: []*manager.ResultConsumer{{UserKey: pubPem.Bytes}},
AgentConfig: &manager.AgentConfig{
Port: "7002",
LogLevel: "debug",
AttestedTls: attestedTLS,
},
},
},
}); err != nil {
s.t.Fatalf("failed to send run request: %s", err)
}
}()
}
+4 -8
View File
@@ -21,18 +21,18 @@ func New(svc manager.Service, tracer trace.Tracer) manager.Service {
return &tracingMiddleware{tracer, svc}
}
func (tm *tracingMiddleware) Run(ctx context.Context, mc *manager.ComputationRunReq) (string, error) {
func (tm *tracingMiddleware) CreateVM(ctx context.Context) (string, string, error) {
ctx, span := tm.tracer.Start(ctx, "run")
defer span.End()
return tm.svc.Run(ctx, mc)
return tm.svc.CreateVM(ctx)
}
func (tm *tracingMiddleware) Stop(ctx context.Context, computationID string) error {
func (tm *tracingMiddleware) RemoveVM(ctx context.Context, id string) error {
ctx, span := tm.tracer.Start(ctx, "stop")
defer span.End()
return tm.svc.Stop(ctx, computationID)
return tm.svc.RemoveVM(ctx, id)
}
func (tm *tracingMiddleware) FetchAttestationPolicy(ctx context.Context, computationId string) ([]byte, error) {
@@ -42,10 +42,6 @@ func (tm *tracingMiddleware) FetchAttestationPolicy(ctx context.Context, computa
return tm.svc.FetchAttestationPolicy(ctx, computationId)
}
func (tm *tracingMiddleware) ReportBrokenConnection(addr string) {
tm.svc.ReportBrokenConnection(addr)
}
func (tm *tracingMiddleware) ReturnSVMInfo(ctx context.Context) (string, int, string, string) {
_, span := tm.tracer.Start(ctx, "return_svm_info")
defer span.End()
-108
View File
@@ -1,108 +0,0 @@
// Copyright (c) Ultraviolet
// SPDX-License-Identifier: Apache-2.0
package vm
import (
"bytes"
"io"
"log/slog"
"strings"
pkgmanager "github.com/ultravioletrs/cocos/pkg/manager"
"google.golang.org/protobuf/types/known/timestamppb"
)
var (
_ io.Writer = &Stdout{}
_ io.Writer = &Stderr{}
)
const bufSize = 1024
type Stdout struct {
EventSender EventSender
ComputationId string
}
// Write implements io.Writer.
func (s *Stdout) Write(p []byte) (n int, err error) {
inBuf := bytes.NewBuffer(p)
buf := make([]byte, bufSize)
for {
n, err := inBuf.Read(buf)
if err != nil {
if err == io.EOF {
break
}
return len(p) - inBuf.Len(), err
}
if err := sendLog(s.EventSender, s.ComputationId, string(buf[:n]), slog.LevelDebug.String()); err != nil {
return len(p) - inBuf.Len(), err
}
}
return len(p), nil
}
type Stderr struct {
EventSender EventSender
ComputationId string
StateMachine StateMachine
}
// Write implements io.Writer.
func (s *Stderr) Write(p []byte) (n int, err error) {
inBuf := bytes.NewBuffer(p)
buf := make([]byte, bufSize)
for {
n, err := inBuf.Read(buf)
if err != nil {
if err == io.EOF {
break
}
return len(p) - inBuf.Len(), err
}
if err := sendLog(s.EventSender, s.ComputationId, string(buf[:n]), ""); err != nil {
return len(p) - inBuf.Len(), err
}
}
eventMsg := &Event{
ComputationId: s.ComputationId,
EventType: s.StateMachine.State(),
Timestamp: timestamppb.Now(),
Originator: "manager",
Status: pkgmanager.Warning.String(),
}
return len(p), s.EventSender(eventMsg)
}
func sendLog(eventSender EventSender, computationID, message, level string) error {
if len(message) < 3 {
return nil
}
if level == "" {
if strings.Contains(strings.ToLower(message), "warning") {
level = slog.LevelWarn.String()
} else {
level = slog.LevelError.String()
}
}
msg := Log{
Message: message,
ComputationId: computationID,
Level: level,
Timestamp: timestamppb.Now(),
}
return eventSender(&msg)
}
-180
View File
@@ -1,180 +0,0 @@
// Copyright (c) Ultraviolet
// SPDX-License-Identifier: Apache-2.0
package vm
import (
"log/slog"
"testing"
"time"
"github.com/stretchr/testify/assert"
pkgmanager "github.com/ultravioletrs/cocos/pkg/manager"
)
func TestStdoutWrite(t *testing.T) {
tests := []struct {
name string
input string
expectedWrites int
}{
{
name: "Single write within buffer size",
input: "Hello, World!",
expectedWrites: 1,
},
{
name: "Multiple writes within buffer size",
input: "This is a longer message that will be split into multiple writes.",
expectedWrites: 1,
},
{
name: "Large write exceeding buffer size",
input: string(make([]byte, bufSize*2+3)),
expectedWrites: 3,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
eventLogChan := make(chan interface{}, 10)
s := &Stdout{
EventSender: func(event interface{}) error {
eventLogChan <- event
return nil
},
ComputationId: "test-computation",
}
n, err := s.Write([]byte(tt.input))
assert.NoError(t, err)
assert.Equal(t, len(tt.input), n)
var receivedWrites int
for i := 0; i < tt.expectedWrites; i++ {
select {
case msg := <-eventLogChan:
receivedWrites++
agentLog := msg.(*Log)
assert.NotNil(t, agentLog)
assert.Equal(t, "test-computation", agentLog.ComputationId)
assert.Equal(t, slog.LevelDebug.String(), agentLog.Level)
assert.NotEmpty(t, agentLog.Message)
assert.NotNil(t, agentLog.Timestamp)
case <-time.After(time.Second):
t.Fatal("Timed out waiting for log message")
}
}
assert.Equal(t, tt.expectedWrites, receivedWrites)
})
}
}
func TestStderrWrite(t *testing.T) {
tests := []struct {
name string
input string
expectedWrites int
}{
{
name: "Single write within buffer size",
input: "Error: Something went wrong",
expectedWrites: 1,
},
{
name: "Multiple writes within buffer size",
input: "This is a longer error message that will be split into multiple writes.",
expectedWrites: 1,
},
{
name: "Large write exceeding buffer size",
input: string(make([]byte, bufSize*2)),
expectedWrites: 3,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
eventLogChan := make(chan interface{}, 10)
s := &Stderr{
EventSender: func(event interface{}) error {
eventLogChan <- event
return nil
},
ComputationId: "test-computation",
StateMachine: NewStateMachine(),
}
err := s.StateMachine.Transition(pkgmanager.VmRunning)
assert.NoError(t, err)
n, err := s.Write([]byte(tt.input))
assert.NoError(t, err)
assert.Equal(t, len(tt.input), n)
var receivedWrites int
for i := 0; i < tt.expectedWrites; i++ {
select {
case msg := <-eventLogChan:
receivedWrites++
switch logEv := msg.(type) {
case *Log:
assert.NotNil(t, logEv)
assert.Equal(t, "test-computation", logEv.ComputationId)
assert.Equal(t, slog.LevelError.String(), logEv.Level)
assert.NotEmpty(t, logEv.Message)
assert.NotNil(t, logEv.Timestamp)
case *Event:
assert.NotNil(t, logEv)
assert.Equal(t, "test-computation", logEv.ComputationId)
assert.Equal(t, pkgmanager.VmRunning.String(), logEv.EventType)
assert.Equal(t, pkgmanager.Warning.String(), logEv.Status)
assert.NotNil(t, logEv.Timestamp)
}
case <-time.After(time.Second):
t.Fatal("Timed out waiting for log message")
}
}
assert.Equal(t, tt.expectedWrites, receivedWrites)
})
}
}
func TestStdoutWriteErrorHandling(t *testing.T) {
eventLogChan := make(chan interface{}, 10)
s := &Stdout{
EventSender: func(event interface{}) error {
eventLogChan <- event
return assert.AnError
},
ComputationId: "test-computation",
}
message := []byte("This should fail")
n, err := s.Write(message)
assert.Error(t, err)
assert.Equal(t, len(message), n)
assert.Equal(t, assert.AnError, err)
}
func TestStderrWriteErrorHandling(t *testing.T) {
eventLogChan := make(chan interface{}, 10)
s := &Stderr{
EventSender: func(event interface{}) error {
eventLogChan <- event
return assert.AnError
},
ComputationId: "test-computation",
}
message := []byte("This should fail")
n, err := s.Write(message)
assert.Error(t, err)
assert.Equal(t, len(message), n)
assert.Equal(t, assert.AnError, err)
}
+10 -11
View File
@@ -23,17 +23,17 @@ func (_m *Provider) EXPECT() *Provider_Expecter {
return &Provider_Expecter{mock: &_m.Mock}
}
// Execute provides a mock function with given fields: config, eventSender, computationId
func (_m *Provider) Execute(config interface{}, eventSender vm.EventSender, computationId string) vm.VM {
ret := _m.Called(config, eventSender, computationId)
// Execute provides a mock function with given fields: config, computationId
func (_m *Provider) Execute(config interface{}, computationId string) vm.VM {
ret := _m.Called(config, computationId)
if len(ret) == 0 {
panic("no return value specified for Execute")
}
var r0 vm.VM
if rf, ok := ret.Get(0).(func(interface{}, vm.EventSender, string) vm.VM); ok {
r0 = rf(config, eventSender, computationId)
if rf, ok := ret.Get(0).(func(interface{}, string) vm.VM); ok {
r0 = rf(config, computationId)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(vm.VM)
@@ -50,15 +50,14 @@ type Provider_Execute_Call struct {
// Execute is a helper method to define mock.On call
// - config interface{}
// - eventSender vm.EventSender
// - computationId string
func (_e *Provider_Expecter) Execute(config interface{}, eventSender interface{}, computationId interface{}) *Provider_Execute_Call {
return &Provider_Execute_Call{Call: _e.mock.On("Execute", config, eventSender, computationId)}
func (_e *Provider_Expecter) Execute(config interface{}, computationId interface{}) *Provider_Execute_Call {
return &Provider_Execute_Call{Call: _e.mock.On("Execute", config, computationId)}
}
func (_c *Provider_Execute_Call) Run(run func(config interface{}, eventSender vm.EventSender, computationId string)) *Provider_Execute_Call {
func (_c *Provider_Execute_Call) Run(run func(config interface{}, computationId string)) *Provider_Execute_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(interface{}), args[1].(vm.EventSender), args[2].(string))
run(args[0].(interface{}), args[1].(string))
})
return _c
}
@@ -68,7 +67,7 @@ func (_c *Provider_Execute_Call) Return(_a0 vm.VM) *Provider_Execute_Call {
return _c
}
func (_c *Provider_Execute_Call) RunAndReturn(run func(interface{}, vm.EventSender, string) vm.VM) *Provider_Execute_Call {
func (_c *Provider_Execute_Call) RunAndReturn(run func(interface{}, string) vm.VM) *Provider_Execute_Call {
_c.Call.Return(run)
return _c
}
+1 -3
View File
@@ -21,7 +21,7 @@ type VM interface {
GetConfig() interface{}
}
type Provider func(config interface{}, eventSender EventSender, computationId string) VM
type Provider func(config interface{}, computationId string) VM
type Event struct {
EventType string
@@ -38,5 +38,3 @@ type Log struct {
Level string
Timestamp *timestamppb.Timestamp
}
type EventSender func(event interface{}) error
+4
View File
@@ -68,6 +68,10 @@ type AgentClientConfig struct {
AttestedTLS bool `env:"ATTESTED_TLS" envDefault:"false"`
}
type ManagerClientConfig struct {
BaseConfig
}
type CVMClientConfig struct {
BaseConfig
}
+1 -1
View File
@@ -8,7 +8,7 @@ import (
)
// NewManagerClient creates new manager gRPC client instance.
func NewManagerClient(cfg grpc.CVMClientConfig) (grpc.Client, manager.ManagerServiceClient, error) {
func NewManagerClient(cfg grpc.ManagerClientConfig) (grpc.Client, manager.ManagerServiceClient, error) {
client, err := grpc.NewClient(cfg)
if err != nil {
return nil, nil, err
+2 -2
View File
@@ -13,12 +13,12 @@ import (
func TestNewManagerClient(t *testing.T) {
tests := []struct {
name string
cfg grpc.CVMClientConfig
cfg grpc.ManagerClientConfig
err error
}{
{
name: "Valid config",
cfg: grpc.CVMClientConfig{
cfg: grpc.ManagerClientConfig{
BaseConfig: grpc.BaseConfig{
URL: "localhost:7001",
},