mirror of
https://github.com/ultravioletrs/cocos.git
synced 2026-06-23 04:10:25 +00:00
NOISSUE - Simplify manager to vm provision only (#353)
* new agent structure Signed-off-by: Sammy Oina <sammyoina@gmail.com> * fix lint Signed-off-by: Sammy Oina <sammyoina@gmail.com> * fix tests Signed-off-by: Sammy Oina <sammyoina@gmail.com> * cvm tests fix Signed-off-by: Sammy Oina <sammyoina@gmail.com> * fix test Signed-off-by: Sammy Oina <sammyoina@gmail.com> * manager server, for vm provisioning Signed-off-by: Sammy Oina <sammyoina@gmail.com> * fix lint Signed-off-by: Sammy Oina <sammyoina@gmail.com> * add cli and test Signed-off-by: Sammy Oina <sammyoina@gmail.com> * restore result cli Signed-off-by: Sammy Oina <sammyoina@gmail.com> * fix tests Signed-off-by: Sammy Oina <sammyoina@gmail.com> * fix failing tests Signed-off-by: Sammy Oina <sammyoina@gmail.com> * fix failing test Signed-off-by: Sammy Oina <sammyoina@gmail.com> * refactor: remove context from docker struct and use local context in Run method Signed-off-by: Sammy Oina <sammyoina@gmail.com> * delete: remove unused gRPC API and related server implementation Signed-off-by: Sammy Oina <sammyoina@gmail.com> --------- Signed-off-by: Sammy Oina <sammyoina@gmail.com>
This commit is contained in:
committed by
GitHub
parent
ecad6514f3
commit
1f32f516b0
@@ -0,0 +1,68 @@
|
||||
// Copyright (c) Ultraviolet
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
package cli
|
||||
|
||||
import (
|
||||
"github.com/fatih/color"
|
||||
"github.com/spf13/cobra"
|
||||
"github.com/ultravioletrs/cocos/manager"
|
||||
"google.golang.org/protobuf/types/known/emptypb"
|
||||
)
|
||||
|
||||
func (c *CLI) NewCreateVMCmd() *cobra.Command {
|
||||
return &cobra.Command{
|
||||
Use: "create-vm",
|
||||
Short: "Create a new virtual machine",
|
||||
Example: `create-vm`,
|
||||
Args: cobra.ExactArgs(0),
|
||||
Run: func(cmd *cobra.Command, args []string) {
|
||||
if err := c.InitializeManagerClient(cmd); err == nil {
|
||||
defer c.Close()
|
||||
}
|
||||
|
||||
if c.connectErr != nil {
|
||||
printError(cmd, "Failed to connect to manager: %v ❌ ", c.connectErr)
|
||||
return
|
||||
}
|
||||
|
||||
cmd.Println("🔗 Creating a new virtual machine")
|
||||
|
||||
res, err := c.managerClient.CreateVm(cmd.Context(), &emptypb.Empty{})
|
||||
if err != nil {
|
||||
printError(cmd, "Error creating virtual machine: %v ❌ ", err)
|
||||
return
|
||||
}
|
||||
|
||||
cmd.Println(color.New(color.FgGreen).Sprintf("✅ Virtual machine created successfully with id %s and port %s", res.SvmId, res.ForwardedPort))
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (c *CLI) NewRemoveVMCmd() *cobra.Command {
|
||||
return &cobra.Command{
|
||||
Use: "remove-vm",
|
||||
Short: "Remove a virtual machine",
|
||||
Example: `remove-vm <svm_id>`,
|
||||
Args: cobra.ExactArgs(1),
|
||||
Run: func(cmd *cobra.Command, args []string) {
|
||||
if err := c.InitializeManagerClient(cmd); err == nil {
|
||||
defer c.Close()
|
||||
}
|
||||
|
||||
if c.connectErr != nil {
|
||||
printError(cmd, "Failed to connect to manager: %v ❌ ", c.connectErr)
|
||||
return
|
||||
}
|
||||
|
||||
cmd.Println("🔗 Removing virtual machine")
|
||||
|
||||
_, err := c.managerClient.RemoveVm(cmd.Context(), &manager.RemoveReq{SvmId: args[0]})
|
||||
if err != nil {
|
||||
printError(cmd, "Error removing virtual machine: %v ❌ ", err)
|
||||
return
|
||||
}
|
||||
|
||||
cmd.Println(color.New(color.FgGreen).Sprintf("✅ Virtual machine removed successfully"))
|
||||
},
|
||||
}
|
||||
}
|
||||
+27
-8
@@ -6,28 +6,33 @@ import (
|
||||
"context"
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
"github.com/ultravioletrs/cocos/manager"
|
||||
"github.com/ultravioletrs/cocos/pkg/clients/grpc"
|
||||
"github.com/ultravioletrs/cocos/pkg/clients/grpc/agent"
|
||||
managergrpc "github.com/ultravioletrs/cocos/pkg/clients/grpc/manager"
|
||||
"github.com/ultravioletrs/cocos/pkg/sdk"
|
||||
)
|
||||
|
||||
var Verbose bool
|
||||
|
||||
type CLI struct {
|
||||
agentSDK sdk.SDK
|
||||
config grpc.AgentClientConfig
|
||||
client grpc.Client
|
||||
connectErr error
|
||||
agentSDK sdk.SDK
|
||||
agentConfig grpc.AgentClientConfig
|
||||
managerConfig grpc.ManagerClientConfig
|
||||
client grpc.Client
|
||||
managerClient manager.ManagerServiceClient
|
||||
connectErr error
|
||||
}
|
||||
|
||||
func New(config grpc.AgentClientConfig) *CLI {
|
||||
func New(agentConfig grpc.AgentClientConfig, managerConfig grpc.ManagerClientConfig) *CLI {
|
||||
return &CLI{
|
||||
config: config,
|
||||
agentConfig: agentConfig,
|
||||
managerConfig: managerConfig,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *CLI) InitializeSDK(cmd *cobra.Command) error {
|
||||
agentGRPCClient, agentClient, err := agent.NewAgentClient(context.Background(), c.config)
|
||||
func (c *CLI) InitializeAgentSDK(cmd *cobra.Command) error {
|
||||
agentGRPCClient, agentClient, err := agent.NewAgentClient(context.Background(), c.agentConfig)
|
||||
if err != nil {
|
||||
c.connectErr = err
|
||||
return err
|
||||
@@ -39,6 +44,20 @@ func (c *CLI) InitializeSDK(cmd *cobra.Command) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *CLI) InitializeManagerClient(cmd *cobra.Command) error {
|
||||
managerGRPCClient, managerClient, err := managergrpc.NewManagerClient(c.managerConfig)
|
||||
if err != nil {
|
||||
c.connectErr = err
|
||||
return err
|
||||
}
|
||||
|
||||
cmd.Println("🔗 Connected to manager using ", managerGRPCClient.Secure())
|
||||
c.client = managerGRPCClient
|
||||
|
||||
c.managerClient = managerClient
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *CLI) Close() {
|
||||
c.client.Close()
|
||||
}
|
||||
|
||||
+17
-7
@@ -19,11 +19,12 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
svcName = "cli"
|
||||
envPrefixAgentGRPC = "AGENT_GRPC_"
|
||||
completion = "completion"
|
||||
filePermision = 0o755
|
||||
cocosDirectory = ".cocos"
|
||||
svcName = "cli"
|
||||
envPrefixAgentGRPC = "AGENT_GRPC_"
|
||||
envPrefixManagerGRPC = "MANAGER_GRPC_"
|
||||
completion = "completion"
|
||||
filePermision = 0o755
|
||||
cocosDirectory = ".cocos"
|
||||
)
|
||||
|
||||
type config struct {
|
||||
@@ -98,9 +99,16 @@ func main() {
|
||||
return
|
||||
}
|
||||
|
||||
cliSVC := cli.New(agentGRPCConfig)
|
||||
managerGRPCConfig := grpc.ManagerClientConfig{}
|
||||
if err := env.ParseWithOptions(&managerGRPCConfig, env.Options{Prefix: envPrefixManagerGRPC}); err != nil {
|
||||
message := color.New(color.FgRed).Sprintf("failed to load %s gRPC client configuration : %s", svcName, err)
|
||||
rootCmd.Println(message)
|
||||
return
|
||||
}
|
||||
|
||||
if err := cliSVC.InitializeSDK(rootCmd); err == nil {
|
||||
cliSVC := cli.New(agentGRPCConfig, managerGRPCConfig)
|
||||
|
||||
if err := cliSVC.InitializeAgentSDK(rootCmd); err == nil {
|
||||
defer cliSVC.Close()
|
||||
}
|
||||
|
||||
@@ -119,6 +127,8 @@ func main() {
|
||||
rootCmd.AddCommand(attestationPolicyCmd)
|
||||
rootCmd.AddCommand(keysCmd)
|
||||
rootCmd.AddCommand(cliSVC.NewCABundleCmd(directoryCachePath))
|
||||
rootCmd.AddCommand(cliSVC.NewCreateVMCmd())
|
||||
rootCmd.AddCommand(cliSVC.NewRemoveVMCmd())
|
||||
|
||||
// Attestation commands
|
||||
attestationCmd.AddCommand(cliSVC.NewGetAttestationCmd())
|
||||
|
||||
+15
-47
@@ -10,25 +10,24 @@ import (
|
||||
"log/slog"
|
||||
"net/url"
|
||||
"os"
|
||||
"os/signal"
|
||||
"strings"
|
||||
"syscall"
|
||||
|
||||
mglog "github.com/absmach/magistrala/logger"
|
||||
"github.com/absmach/magistrala/pkg/jaeger"
|
||||
"github.com/absmach/magistrala/pkg/prometheus"
|
||||
"github.com/absmach/magistrala/pkg/uuid"
|
||||
"github.com/caarlos0/env/v11"
|
||||
"github.com/ultravioletrs/cocos/internal/server"
|
||||
grpcserver "github.com/ultravioletrs/cocos/internal/server/grpc"
|
||||
"github.com/ultravioletrs/cocos/manager"
|
||||
"github.com/ultravioletrs/cocos/manager/api"
|
||||
managerapi "github.com/ultravioletrs/cocos/manager/api/grpc"
|
||||
"github.com/ultravioletrs/cocos/manager/events"
|
||||
managergrpc "github.com/ultravioletrs/cocos/manager/api/grpc"
|
||||
"github.com/ultravioletrs/cocos/manager/qemu"
|
||||
"github.com/ultravioletrs/cocos/manager/tracing"
|
||||
pkggrpc "github.com/ultravioletrs/cocos/pkg/clients/grpc"
|
||||
managergrpc "github.com/ultravioletrs/cocos/pkg/clients/grpc/manager"
|
||||
"go.opentelemetry.io/otel/trace"
|
||||
"golang.org/x/sync/errgroup"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/reflection"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -92,64 +91,33 @@ func main() {
|
||||
args := qemuCfg.ConstructQemuArgs()
|
||||
logger.Info(strings.Join(args, " "))
|
||||
|
||||
managerGRPCConfig := pkggrpc.CVMClientConfig{}
|
||||
managerGRPCConfig := server.ServerConfig{}
|
||||
if err := env.ParseWithOptions(&managerGRPCConfig, env.Options{Prefix: envPrefixGRPC}); err != nil {
|
||||
logger.Error(fmt.Sprintf("failed to load %s gRPC client configuration : %s", svcName, err))
|
||||
exitCode = 1
|
||||
return
|
||||
}
|
||||
|
||||
managerGRPCClient, managerClient, err := managergrpc.NewManagerClient(managerGRPCConfig)
|
||||
if err != nil {
|
||||
logger.Error(err.Error())
|
||||
exitCode = 1
|
||||
return
|
||||
}
|
||||
defer managerGRPCClient.Close()
|
||||
|
||||
pc, err := managerClient.Process(ctx)
|
||||
svc, err := newService(logger, tracer, qemuCfg, cfg.AttestationPolicyBinary, cfg.EosVersion)
|
||||
if err != nil {
|
||||
logger.Error(err.Error())
|
||||
exitCode = 1
|
||||
return
|
||||
}
|
||||
|
||||
eventsChan := make(chan *manager.ClientStreamMessage, clientBufferSize)
|
||||
svc, err := newService(logger, tracer, qemuCfg, eventsChan, cfg.AttestationPolicyBinary, cfg.EosVersion)
|
||||
if err != nil {
|
||||
logger.Error(err.Error())
|
||||
exitCode = 1
|
||||
return
|
||||
registerManagerServiceServer := func(srv *grpc.Server) {
|
||||
reflection.Register(srv)
|
||||
manager.RegisterManagerServiceServer(srv, managergrpc.NewServer(svc))
|
||||
}
|
||||
|
||||
eventsSvc, err := events.New(logger, svc.ReportBrokenConnection, eventsChan)
|
||||
if err != nil {
|
||||
logger.Error(err.Error())
|
||||
exitCode = 1
|
||||
return
|
||||
}
|
||||
|
||||
go eventsSvc.Listen(ctx)
|
||||
|
||||
mc := managerapi.NewClient(pc, svc, eventsChan, logger)
|
||||
gs := grpcserver.New(ctx, cancel, svcName, managerGRPCConfig, registerManagerServiceServer, logger, nil, nil)
|
||||
|
||||
g.Go(func() error {
|
||||
ch := make(chan os.Signal, 1)
|
||||
signal.Notify(ch, syscall.SIGINT, syscall.SIGTERM)
|
||||
defer signal.Stop(ch)
|
||||
|
||||
select {
|
||||
case <-ch:
|
||||
logger.Info("Received signal, shutting down...")
|
||||
cancel()
|
||||
return nil
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
}
|
||||
return gs.Start()
|
||||
})
|
||||
|
||||
g.Go(func() error {
|
||||
return mc.Process(ctx, cancel)
|
||||
return server.StopHandler(ctx, cancel, logger, svcName, gs)
|
||||
})
|
||||
|
||||
if err := g.Wait(); err != nil {
|
||||
@@ -157,8 +125,8 @@ func main() {
|
||||
}
|
||||
}
|
||||
|
||||
func newService(logger *slog.Logger, tracer trace.Tracer, qemuCfg qemu.Config, eventsChan chan *manager.ClientStreamMessage, attestationPolicyPath string, eosVersion string) (manager.Service, error) {
|
||||
svc, err := manager.New(qemuCfg, attestationPolicyPath, logger, eventsChan, qemu.NewVM, eosVersion)
|
||||
func newService(logger *slog.Logger, tracer trace.Tracer, qemuCfg qemu.Config, attestationPolicyPath string, eosVersion string) (manager.Service, error) {
|
||||
svc, err := manager.New(qemuCfg, attestationPolicyPath, logger, qemu.NewVM, eosVersion)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -5,7 +5,6 @@ go 1.23.0
|
||||
require (
|
||||
github.com/absmach/magistrala v0.15.1
|
||||
github.com/caarlos0/env/v11 v11.2.2
|
||||
github.com/cenkalti/backoff/v4 v4.3.0
|
||||
github.com/fatih/color v1.18.0
|
||||
github.com/go-kit/kit v0.13.0
|
||||
github.com/gofrs/uuid v4.4.0+incompatible
|
||||
@@ -25,6 +24,7 @@ require (
|
||||
|
||||
require (
|
||||
github.com/Microsoft/go-winio v0.6.2 // indirect
|
||||
github.com/cenkalti/backoff/v4 v4.3.0 // indirect
|
||||
github.com/containerd/log v0.1.0 // indirect
|
||||
github.com/distribution/reference v0.6.0 // indirect
|
||||
github.com/docker/go-connections v0.5.0 // indirect
|
||||
@@ -59,7 +59,7 @@ require (
|
||||
github.com/golang/protobuf v1.5.4 // indirect
|
||||
github.com/google/go-configfs-tsm v0.2.2 // indirect
|
||||
github.com/google/logger v1.1.1
|
||||
github.com/google/uuid v1.6.0 // indirect
|
||||
github.com/google/uuid v1.6.0
|
||||
github.com/grpc-ecosystem/grpc-gateway/v2 v2.23.0 // indirect
|
||||
github.com/inconshreveable/mousetrap v1.1.0 // indirect
|
||||
github.com/mdlayher/socket v0.4.1 // indirect
|
||||
|
||||
@@ -1,247 +0,0 @@
|
||||
// Copyright (c) Ultraviolet
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package vsock
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"google.golang.org/protobuf/proto"
|
||||
)
|
||||
|
||||
const (
|
||||
maxRetries = 3
|
||||
retryDelay = time.Second
|
||||
maxMessageSize = 1 << 20 // 1 MB
|
||||
ackTimeout = 5 * time.Second
|
||||
maxConcurrent = 100
|
||||
)
|
||||
|
||||
type MessageStatus int
|
||||
|
||||
const (
|
||||
StatusPending MessageStatus = iota
|
||||
StatusSent
|
||||
StatusAcknowledged
|
||||
StatusFailed
|
||||
)
|
||||
|
||||
type Message struct {
|
||||
ID uint32
|
||||
Content []byte
|
||||
Status MessageStatus
|
||||
Retries int
|
||||
}
|
||||
|
||||
type AckWriter struct {
|
||||
conn net.Conn
|
||||
pendingMessages chan *Message
|
||||
messageStore sync.Map // map[uint32]*Message
|
||||
nextID uint32
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
wg sync.WaitGroup
|
||||
}
|
||||
|
||||
func NewAckWriter(conn net.Conn) io.WriteCloser {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
aw := &AckWriter{
|
||||
conn: conn,
|
||||
pendingMessages: make(chan *Message, maxConcurrent),
|
||||
nextID: 1,
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
}
|
||||
aw.wg.Add(2)
|
||||
go aw.sendMessages()
|
||||
go aw.handleAcknowledgments()
|
||||
return aw
|
||||
}
|
||||
|
||||
func (aw *AckWriter) Write(p []byte) (int, error) {
|
||||
if len(p) > maxMessageSize {
|
||||
return 0, fmt.Errorf("message size exceeds maximum allowed size of %d bytes", maxMessageSize)
|
||||
}
|
||||
|
||||
messageID := atomic.AddUint32(&aw.nextID, 1)
|
||||
message := &Message{
|
||||
ID: messageID,
|
||||
Content: make([]byte, len(p)),
|
||||
Status: StatusPending,
|
||||
}
|
||||
copy(message.Content, p)
|
||||
|
||||
aw.messageStore.Store(messageID, message)
|
||||
select {
|
||||
case aw.pendingMessages <- message:
|
||||
return len(p), nil
|
||||
case <-aw.ctx.Done():
|
||||
return 0, fmt.Errorf("writer is closed")
|
||||
}
|
||||
}
|
||||
|
||||
func (aw *AckWriter) sendMessages() {
|
||||
defer aw.wg.Done()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-aw.ctx.Done():
|
||||
return
|
||||
case msg := <-aw.pendingMessages:
|
||||
if err := aw.sendWithRetry(msg); err != nil {
|
||||
log.Printf("Failed to send message %d after all retries: %v", msg.ID, err)
|
||||
msg.Status = StatusFailed
|
||||
aw.messageStore.Store(msg.ID, msg)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (aw *AckWriter) sendWithRetry(msg *Message) error {
|
||||
for msg.Retries < maxRetries {
|
||||
if err := aw.writeMessage(msg.ID, msg.Content); err != nil {
|
||||
msg.Retries++
|
||||
msg.Status = StatusPending
|
||||
log.Printf("Error writing message %d (attempt %d): %v", msg.ID, msg.Retries, err)
|
||||
time.Sleep(retryDelay)
|
||||
continue
|
||||
}
|
||||
msg.Status = StatusSent
|
||||
aw.messageStore.Store(msg.ID, msg)
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("max retries reached")
|
||||
}
|
||||
|
||||
func (aw *AckWriter) writeMessage(messageID uint32, p []byte) error {
|
||||
if err := binary.Write(aw.conn, binary.LittleEndian, messageID); err != nil {
|
||||
return fmt.Errorf("failed to write message ID: %w", err)
|
||||
}
|
||||
|
||||
messageLen := uint32(len(p))
|
||||
if err := binary.Write(aw.conn, binary.LittleEndian, messageLen); err != nil {
|
||||
return fmt.Errorf("failed to write message length: %w", err)
|
||||
}
|
||||
|
||||
if _, err := aw.conn.Write(p); err != nil {
|
||||
return fmt.Errorf("failed to write message content: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (aw *AckWriter) handleAcknowledgments() {
|
||||
defer aw.wg.Done()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-aw.ctx.Done():
|
||||
return
|
||||
default:
|
||||
var ackID uint32
|
||||
if err := binary.Read(aw.conn, binary.LittleEndian, &ackID); err != nil {
|
||||
if err == io.EOF {
|
||||
log.Println("Connection closed, stopping acknowledgment handler")
|
||||
return
|
||||
}
|
||||
log.Printf("Error reading ACK: %v", err)
|
||||
time.Sleep(retryDelay)
|
||||
continue
|
||||
}
|
||||
|
||||
if msg, ok := aw.messageStore.Load(ackID); ok {
|
||||
m := msg.(*Message)
|
||||
m.Status = StatusAcknowledged
|
||||
aw.messageStore.Store(ackID, m)
|
||||
|
||||
// Clean up old messages periodically
|
||||
go aw.cleanupOldMessages(ackID)
|
||||
} else {
|
||||
log.Printf("Received ACK for unknown message ID: %d", ackID)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (aw *AckWriter) cleanupOldMessages(currentID uint32) {
|
||||
aw.messageStore.Range(func(key, value interface{}) bool {
|
||||
msgID := key.(uint32)
|
||||
msg := value.(*Message)
|
||||
|
||||
// Clean up acknowledged messages that are old
|
||||
if msg.Status == StatusAcknowledged && msgID < currentID-maxConcurrent {
|
||||
aw.messageStore.Delete(msgID)
|
||||
}
|
||||
return true
|
||||
})
|
||||
}
|
||||
|
||||
func (aw *AckWriter) Close() error {
|
||||
aw.cancel()
|
||||
aw.wg.Wait()
|
||||
return aw.conn.Close()
|
||||
}
|
||||
|
||||
type Reader interface {
|
||||
Read() ([]byte, error)
|
||||
ReadProto(msg proto.Message) error
|
||||
}
|
||||
|
||||
type AckReader struct {
|
||||
conn net.Conn
|
||||
ctx context.Context
|
||||
}
|
||||
|
||||
func NewAckReader(conn net.Conn) Reader {
|
||||
return &AckReader{
|
||||
conn: conn,
|
||||
ctx: context.Background(),
|
||||
}
|
||||
}
|
||||
|
||||
func (ar *AckReader) ReadProto(msg proto.Message) error {
|
||||
data, err := ar.Read()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to read proto message: %w", err)
|
||||
}
|
||||
return proto.Unmarshal(data, msg)
|
||||
}
|
||||
|
||||
func (ar *AckReader) Read() ([]byte, error) {
|
||||
var messageID uint32
|
||||
if err := binary.Read(ar.conn, binary.LittleEndian, &messageID); err != nil {
|
||||
return nil, fmt.Errorf("error reading message ID: %w", err)
|
||||
}
|
||||
|
||||
var messageLen uint32
|
||||
if err := binary.Read(ar.conn, binary.LittleEndian, &messageLen); err != nil {
|
||||
return nil, fmt.Errorf("error reading message length: %w", err)
|
||||
}
|
||||
|
||||
if messageLen > maxMessageSize {
|
||||
return nil, fmt.Errorf("message size %d exceeds maximum allowed size of %d bytes", messageLen, maxMessageSize)
|
||||
}
|
||||
|
||||
data := make([]byte, messageLen)
|
||||
if _, err := io.ReadFull(ar.conn, data); err != nil {
|
||||
return nil, fmt.Errorf("error reading message content: %w", err)
|
||||
}
|
||||
|
||||
if err := ar.sendAck(messageID); err != nil {
|
||||
return nil, fmt.Errorf("error sending ACK: %w", err)
|
||||
}
|
||||
|
||||
return data, nil
|
||||
}
|
||||
|
||||
func (ar *AckReader) sendAck(messageID uint32) error {
|
||||
return binary.Write(ar.conn, binary.LittleEndian, messageID)
|
||||
}
|
||||
@@ -1,337 +0,0 @@
|
||||
// Copyright (c) Ultraviolet
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
package vsock
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/ultravioletrs/cocos/manager"
|
||||
"google.golang.org/protobuf/proto"
|
||||
)
|
||||
|
||||
// MockConn implements net.Conn for testing purposes.
|
||||
type MockConn struct {
|
||||
ReadData []byte
|
||||
WrittenData []byte
|
||||
ReadErr error
|
||||
WriteErr error
|
||||
closed bool
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
func (m *MockConn) Read(b []byte) (n int, err error) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
if m.closed {
|
||||
return 0, io.EOF
|
||||
}
|
||||
if len(m.ReadData) == 0 {
|
||||
return 0, io.EOF // Ensure we handle this case more predictably
|
||||
}
|
||||
if m.ReadErr != nil {
|
||||
return 0, m.ReadErr
|
||||
}
|
||||
n = copy(b, m.ReadData)
|
||||
m.ReadData = m.ReadData[n:]
|
||||
return n, nil
|
||||
}
|
||||
|
||||
func (m *MockConn) Write(b []byte) (n int, err error) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
if m.closed {
|
||||
return 0, errors.New("connection closed")
|
||||
}
|
||||
if m.WriteErr != nil {
|
||||
return 0, m.WriteErr
|
||||
}
|
||||
m.WrittenData = append(m.WrittenData, b...)
|
||||
return len(b), nil
|
||||
}
|
||||
|
||||
func (m *MockConn) Close() error {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
m.closed = true
|
||||
return nil
|
||||
}
|
||||
|
||||
// Implement other net.Conn methods with empty implementations.
|
||||
func (m *MockConn) LocalAddr() net.Addr { return nil }
|
||||
func (m *MockConn) RemoteAddr() net.Addr { return nil }
|
||||
func (m *MockConn) SetDeadline(t time.Time) error { return nil }
|
||||
func (m *MockConn) SetReadDeadline(t time.Time) error { return nil }
|
||||
func (m *MockConn) SetWriteDeadline(t time.Time) error { return nil }
|
||||
|
||||
func TestAckReader_Read(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
data []byte
|
||||
wantErr bool
|
||||
}{
|
||||
{"Valid message", []byte("Hello, World!"), false},
|
||||
{"Empty message", []byte{}, false},
|
||||
{"Message at max size", make([]byte, maxMessageSize), false},
|
||||
{"Message exceeds max size", make([]byte, maxMessageSize+1), true},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
mockConn := &MockConn{}
|
||||
ar := NewAckReader(mockConn)
|
||||
|
||||
// Prepare mock data
|
||||
messageID := uint32(1)
|
||||
messageLen := uint32(len(tt.data))
|
||||
mockData := make([]byte, 8+len(tt.data))
|
||||
binary.LittleEndian.PutUint32(mockData[:4], messageID)
|
||||
binary.LittleEndian.PutUint32(mockData[4:8], messageLen)
|
||||
copy(mockData[8:], tt.data)
|
||||
mockConn.ReadData = mockData
|
||||
|
||||
data, err := ar.Read()
|
||||
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("AckReader.Read() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
|
||||
if !tt.wantErr {
|
||||
if !bytes.Equal(data, tt.data) {
|
||||
t.Errorf("AckReader.Read() got = %v, want %v", data, tt.data)
|
||||
}
|
||||
|
||||
// Check if ACK was sent
|
||||
if len(mockConn.WrittenData) != 4 {
|
||||
t.Errorf("AckReader.Read() did not send ACK")
|
||||
} else {
|
||||
ackID := binary.LittleEndian.Uint32(mockConn.WrittenData)
|
||||
if ackID != messageID {
|
||||
t.Errorf("AckReader.Read() sent wrong ACK ID, got %d, want %d", ackID, messageID)
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAckReader_ReadProto(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
msg *manager.ClientStreamMessage
|
||||
wantErr bool
|
||||
}{
|
||||
{"Valid proto message", &manager.ClientStreamMessage{}, false},
|
||||
{"Empty proto message", &manager.ClientStreamMessage{}, false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
mockConn := &MockConn{}
|
||||
ar := NewAckReader(mockConn)
|
||||
|
||||
// Prepare mock data
|
||||
protoData, _ := proto.Marshal(tt.msg)
|
||||
messageID := uint32(1)
|
||||
messageLen := uint32(len(protoData))
|
||||
mockData := make([]byte, 8+len(protoData))
|
||||
binary.LittleEndian.PutUint32(mockData[:4], messageID)
|
||||
binary.LittleEndian.PutUint32(mockData[4:8], messageLen)
|
||||
copy(mockData[8:], protoData)
|
||||
mockConn.ReadData = mockData
|
||||
|
||||
receivedMsg := &manager.ClientStreamMessage{}
|
||||
err := ar.ReadProto(receivedMsg)
|
||||
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("AckReader.ReadProto() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
|
||||
if !tt.wantErr {
|
||||
if receivedMsg.Message != tt.msg.Message {
|
||||
t.Errorf("AckReader.ReadProto() got = %v, want %v", receivedMsg, tt.msg)
|
||||
}
|
||||
|
||||
// Check if ACK was sent
|
||||
if len(mockConn.WrittenData) != 4 {
|
||||
t.Errorf("AckReader.ReadProto() did not send ACK")
|
||||
} else {
|
||||
ackID := binary.LittleEndian.Uint32(mockConn.WrittenData)
|
||||
if ackID != messageID {
|
||||
t.Errorf("AckReader.ReadProto() sent wrong ACK ID, got %d, want %d", ackID, messageID)
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewAckWriter(t *testing.T) {
|
||||
mockConn := &MockConn{}
|
||||
writer := NewAckWriter(mockConn)
|
||||
|
||||
if _, ok := writer.(io.Writer); !ok {
|
||||
t.Errorf("NewAckWriter() did not return an io.Writer")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewAckReader(t *testing.T) {
|
||||
mockConn := &MockConn{}
|
||||
reader := NewAckReader(mockConn)
|
||||
|
||||
assert.NotNil(t, reader)
|
||||
}
|
||||
|
||||
func TestAckWriter_Close(t *testing.T) {
|
||||
mockConn := &MockConn{}
|
||||
aw := NewAckWriter(mockConn)
|
||||
|
||||
err := aw.Close()
|
||||
if err != nil {
|
||||
t.Errorf("AckWriter.Close() error = %v, wantErr %v", err, nil)
|
||||
}
|
||||
|
||||
if !mockConn.closed {
|
||||
t.Errorf("AckWriter.Close() did not close the connection")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAckWriter_Write(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input []byte
|
||||
expectErr bool
|
||||
expectedError string
|
||||
}{
|
||||
{
|
||||
name: "Message exceeds max size",
|
||||
input: make([]byte, maxMessageSize+1),
|
||||
expectErr: true,
|
||||
expectedError: "message size exceeds maximum allowed size",
|
||||
},
|
||||
{
|
||||
name: "Write succeeds",
|
||||
input: []byte("Hello, world!"),
|
||||
expectErr: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
mockConn := &MockConn{
|
||||
mu: sync.Mutex{},
|
||||
}
|
||||
|
||||
writer := NewAckWriter(mockConn)
|
||||
defer writer.Close()
|
||||
|
||||
if tt.expectErr {
|
||||
writer.(*AckWriter).ctx.Done()
|
||||
}
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
n, err := writer.Write(tt.input)
|
||||
|
||||
if tt.expectErr {
|
||||
assert.Error(t, err)
|
||||
if tt.expectedError != "" {
|
||||
assert.Contains(t, err.Error(), tt.expectedError)
|
||||
}
|
||||
assert.Zero(t, n)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, len(tt.input), n)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAckWriter_CleanupOldMessages(t *testing.T) {
|
||||
mockConn := &MockConn{}
|
||||
writer := NewAckWriter(mockConn).(*AckWriter)
|
||||
defer writer.Close()
|
||||
|
||||
for i := uint32(1); i <= maxConcurrent+10; i++ {
|
||||
msg := &Message{
|
||||
ID: i,
|
||||
Content: []byte("test"),
|
||||
Status: StatusAcknowledged,
|
||||
}
|
||||
writer.messageStore.Store(i, msg)
|
||||
}
|
||||
|
||||
writer.cleanupOldMessages(maxConcurrent + 11)
|
||||
|
||||
var count int
|
||||
writer.messageStore.Range(func(key, value interface{}) bool {
|
||||
count++
|
||||
return true
|
||||
})
|
||||
|
||||
assert.LessOrEqual(t, count, maxConcurrent)
|
||||
}
|
||||
|
||||
func TestAckReader_LargeMessage(t *testing.T) {
|
||||
mockConn := &MockConn{}
|
||||
reader := NewAckReader(mockConn)
|
||||
|
||||
largeMessage := make([]byte, maxMessageSize-1)
|
||||
for i := range largeMessage {
|
||||
largeMessage[i] = byte(i % 256)
|
||||
}
|
||||
|
||||
messageID := uint32(1)
|
||||
messageLen := uint32(len(largeMessage))
|
||||
mockData := make([]byte, 8+len(largeMessage))
|
||||
binary.LittleEndian.PutUint32(mockData[:4], messageID)
|
||||
binary.LittleEndian.PutUint32(mockData[4:8], messageLen)
|
||||
copy(mockData[8:], largeMessage)
|
||||
mockConn.ReadData = mockData
|
||||
|
||||
data, err := reader.Read()
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, largeMessage, data)
|
||||
|
||||
assert.Equal(t, 4, len(mockConn.WrittenData))
|
||||
ackID := binary.LittleEndian.Uint32(mockConn.WrittenData)
|
||||
assert.Equal(t, messageID, ackID)
|
||||
}
|
||||
|
||||
func TestAckWriter_FailedSends(t *testing.T) {
|
||||
mockConn := &MockConn{
|
||||
WriteErr: errors.New("write error"),
|
||||
}
|
||||
writer := NewAckWriter(mockConn).(*AckWriter)
|
||||
defer writer.Close()
|
||||
|
||||
// Add some messages to the channel
|
||||
for i := 0; i < 5; i++ {
|
||||
msg := &Message{
|
||||
ID: uint32(i + 1),
|
||||
Content: []byte(fmt.Sprintf("Message %d", i+1)),
|
||||
Status: StatusPending,
|
||||
}
|
||||
writer.pendingMessages <- msg
|
||||
}
|
||||
|
||||
// Wait for the messages to be sent
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// Check that the messages were marked as failed
|
||||
writer.messageStore.Range(func(key, value interface{}) bool {
|
||||
msg := value.(*Message)
|
||||
assert.Equal(t, StatusFailed, msg.Status)
|
||||
return true
|
||||
})
|
||||
}
|
||||
@@ -1,66 +0,0 @@
|
||||
// Copyright (c) Ultraviolet
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
package manager
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"regexp"
|
||||
"strconv"
|
||||
|
||||
"github.com/ultravioletrs/cocos/pkg/manager"
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
)
|
||||
|
||||
var (
|
||||
errFailedToParseCID = fmt.Errorf("failed to parse computation ID")
|
||||
errComputationNotFound = fmt.Errorf("computation not found")
|
||||
)
|
||||
|
||||
func (ms *managerService) computationIDFromAddress(address string) (string, error) {
|
||||
re := regexp.MustCompile(`vm\((\d+)\)`)
|
||||
matches := re.FindStringSubmatch(address)
|
||||
|
||||
if len(matches) > 1 {
|
||||
cid, err := strconv.Atoi(matches[1])
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return ms.findComputationID(cid)
|
||||
}
|
||||
return "", errFailedToParseCID
|
||||
}
|
||||
|
||||
func (ms *managerService) findComputationID(cid int) (string, error) {
|
||||
ms.mu.Lock()
|
||||
defer ms.mu.Unlock()
|
||||
for cmpID, vm := range ms.vms {
|
||||
if vm.GetCID() == cid {
|
||||
return cmpID, nil
|
||||
}
|
||||
}
|
||||
|
||||
return "", errComputationNotFound
|
||||
}
|
||||
|
||||
func (ms *managerService) reportBrokenConnection(cmpID string) {
|
||||
ms.eventsChan <- &ClientStreamMessage{
|
||||
Message: &ClientStreamMessage_AgentEvent{
|
||||
AgentEvent: &AgentEvent{
|
||||
EventType: ms.vms[cmpID].State(),
|
||||
ComputationId: cmpID,
|
||||
Status: manager.Disconnected.String(),
|
||||
Timestamp: timestamppb.Now(),
|
||||
Originator: "manager",
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (ms *managerService) ReportBrokenConnection(addr string) {
|
||||
cmpID, err := ms.computationIDFromAddress(addr)
|
||||
if err != nil {
|
||||
ms.logger.Warn(err.Error())
|
||||
return
|
||||
}
|
||||
ms.reportBrokenConnection(cmpID)
|
||||
}
|
||||
@@ -1,64 +0,0 @@
|
||||
// Copyright (c) Ultraviolet
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
package manager
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/ultravioletrs/cocos/manager/qemu"
|
||||
"github.com/ultravioletrs/cocos/manager/vm"
|
||||
"github.com/ultravioletrs/cocos/pkg/manager"
|
||||
)
|
||||
|
||||
func TestComputationIDFromAddress(t *testing.T) {
|
||||
ms := &managerService{
|
||||
vms: map[string]vm.VM{
|
||||
"comp1": qemu.NewVM(qemu.VMInfo{Config: qemu.Config{VSockConfig: qemu.VSockConfig{GuestCID: 3}}}, func(event interface{}) error { return nil }, "comp1"),
|
||||
"comp2": qemu.NewVM(qemu.VMInfo{Config: qemu.Config{VSockConfig: qemu.VSockConfig{GuestCID: 5}}}, func(event interface{}) error { return nil }, "comp2"),
|
||||
},
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
address string
|
||||
want string
|
||||
wantErr bool
|
||||
}{
|
||||
{"Valid address", "vm(3)", "comp1", false},
|
||||
{"Invalid address", "invalid", "", true},
|
||||
{"Non-existent CID", "vm(10)", "", true},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := ms.computationIDFromAddress(tt.address)
|
||||
if tt.wantErr {
|
||||
assert.Error(t, err)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, tt.want, got)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestReportBrokenConnection(t *testing.T) {
|
||||
ms := &managerService{
|
||||
eventsChan: make(chan *ClientStreamMessage, 1),
|
||||
vms: map[string]vm.VM{
|
||||
"comp1": qemu.NewVM(qemu.VMInfo{Config: qemu.Config{VSockConfig: qemu.VSockConfig{GuestCID: 3}}}, func(event interface{}) error { return nil }, "comp1"),
|
||||
},
|
||||
}
|
||||
|
||||
ms.reportBrokenConnection("comp1")
|
||||
|
||||
select {
|
||||
case msg := <-ms.eventsChan:
|
||||
assert.Equal(t, "comp1", msg.GetAgentEvent().ComputationId)
|
||||
assert.Equal(t, manager.Disconnected.String(), msg.GetAgentEvent().Status)
|
||||
assert.Equal(t, "manager", msg.GetAgentEvent().Originator)
|
||||
default:
|
||||
t.Error("Expected message in eventsChan, but none received")
|
||||
}
|
||||
}
|
||||
@@ -1,242 +0,0 @@
|
||||
// Copyright (c) Ultraviolet
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
package grpc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log/slog"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/absmach/magistrala/pkg/errors"
|
||||
"github.com/ultravioletrs/cocos/manager"
|
||||
"github.com/ultravioletrs/cocos/manager/qemu"
|
||||
"golang.org/x/sync/errgroup"
|
||||
"google.golang.org/protobuf/proto"
|
||||
)
|
||||
|
||||
var (
|
||||
errTerminationFromServer = errors.New("server requested client termination")
|
||||
errCorruptedManifest = errors.New("received manifest may be corrupted")
|
||||
sendTimeout = 5 * time.Second
|
||||
)
|
||||
|
||||
type ManagerClient struct {
|
||||
stream manager.ManagerService_ProcessClient
|
||||
svc manager.Service
|
||||
messageQueue chan *manager.ClientStreamMessage
|
||||
logger *slog.Logger
|
||||
runReqManager *runRequestManager
|
||||
}
|
||||
|
||||
// NewClient returns new gRPC client instance.
|
||||
func NewClient(stream manager.ManagerService_ProcessClient, svc manager.Service, messageQueue chan *manager.ClientStreamMessage, logger *slog.Logger) ManagerClient {
|
||||
return ManagerClient{
|
||||
stream: stream,
|
||||
svc: svc,
|
||||
messageQueue: messageQueue,
|
||||
logger: logger,
|
||||
runReqManager: newRunRequestManager(),
|
||||
}
|
||||
}
|
||||
|
||||
func (client ManagerClient) Process(ctx context.Context, cancel context.CancelFunc) error {
|
||||
eg, ctx := errgroup.WithContext(ctx)
|
||||
|
||||
eg.Go(func() error {
|
||||
return client.handleIncomingMessages(ctx)
|
||||
})
|
||||
|
||||
eg.Go(func() error {
|
||||
return client.handleOutgoingMessages(ctx)
|
||||
})
|
||||
|
||||
return eg.Wait()
|
||||
}
|
||||
|
||||
func (client ManagerClient) handleIncomingMessages(ctx context.Context) error {
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
default:
|
||||
req, err := client.stream.Recv()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := client.processIncomingMessage(ctx, req); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (client ManagerClient) processIncomingMessage(ctx context.Context, req *manager.ServerStreamMessage) error {
|
||||
switch mes := req.Message.(type) {
|
||||
case *manager.ServerStreamMessage_RunReqChunks:
|
||||
return client.handleRunReqChunks(ctx, mes)
|
||||
case *manager.ServerStreamMessage_TerminateReq:
|
||||
return client.handleTerminateReq(mes)
|
||||
case *manager.ServerStreamMessage_StopComputation:
|
||||
go client.handleStopComputation(ctx, mes)
|
||||
case *manager.ServerStreamMessage_AttestationPolicyReq:
|
||||
go client.handleAttestationPolicyReq(ctx, mes)
|
||||
case *manager.ServerStreamMessage_SvmInfoReq:
|
||||
go client.handleSVMInfoReq(ctx, mes)
|
||||
default:
|
||||
return errors.New("unknown message type")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (client *ManagerClient) handleRunReqChunks(ctx context.Context, mes *manager.ServerStreamMessage_RunReqChunks) error {
|
||||
buffer, complete := client.runReqManager.addChunk(mes.RunReqChunks.Id, mes.RunReqChunks.Data, mes.RunReqChunks.IsLast)
|
||||
|
||||
if complete {
|
||||
var runReq manager.ComputationRunReq
|
||||
if err := proto.Unmarshal(buffer, &runReq); err != nil {
|
||||
return errors.Wrap(err, errCorruptedManifest)
|
||||
}
|
||||
|
||||
go client.executeRun(ctx, &runReq)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (client ManagerClient) executeRun(ctx context.Context, runReq *manager.ComputationRunReq) {
|
||||
port, err := client.svc.Run(ctx, runReq)
|
||||
if err != nil {
|
||||
client.logger.Warn(err.Error())
|
||||
return
|
||||
}
|
||||
runRes := &manager.ClientStreamMessage_RunRes{
|
||||
RunRes: &manager.RunResponse{
|
||||
AgentPort: port,
|
||||
ComputationId: runReq.Id,
|
||||
},
|
||||
}
|
||||
client.sendMessage(&manager.ClientStreamMessage{Message: runRes})
|
||||
}
|
||||
|
||||
func (client ManagerClient) handleTerminateReq(mes *manager.ServerStreamMessage_TerminateReq) error {
|
||||
return errors.Wrap(errTerminationFromServer, errors.New(mes.TerminateReq.Message))
|
||||
}
|
||||
|
||||
func (client ManagerClient) handleStopComputation(ctx context.Context, mes *manager.ServerStreamMessage_StopComputation) {
|
||||
msg := &manager.ClientStreamMessage_StopComputationRes{
|
||||
StopComputationRes: &manager.StopComputationResponse{
|
||||
ComputationId: mes.StopComputation.ComputationId,
|
||||
},
|
||||
}
|
||||
if err := client.svc.Stop(ctx, mes.StopComputation.ComputationId); err != nil {
|
||||
msg.StopComputationRes.Message = err.Error()
|
||||
}
|
||||
client.sendMessage(&manager.ClientStreamMessage{Message: msg})
|
||||
}
|
||||
|
||||
func (client ManagerClient) handleAttestationPolicyReq(ctx context.Context, mes *manager.ServerStreamMessage_AttestationPolicyReq) {
|
||||
res, err := client.svc.FetchAttestationPolicy(ctx, mes.AttestationPolicyReq.Id)
|
||||
if err != nil {
|
||||
client.logger.Warn(err.Error())
|
||||
return
|
||||
}
|
||||
info := &manager.ClientStreamMessage_AttestationPolicy{
|
||||
AttestationPolicy: &manager.AttestationPolicy{
|
||||
Info: res,
|
||||
Id: mes.AttestationPolicyReq.Id,
|
||||
},
|
||||
}
|
||||
client.sendMessage(&manager.ClientStreamMessage{Message: info})
|
||||
}
|
||||
|
||||
func (client ManagerClient) handleSVMInfoReq(ctx context.Context, mes *manager.ServerStreamMessage_SvmInfoReq) {
|
||||
ovmfVersion, cpuNum, cpuType, eosVersion := client.svc.ReturnSVMInfo(ctx)
|
||||
info := &manager.ClientStreamMessage_SvmInfo{
|
||||
SvmInfo: &manager.SVMInfo{
|
||||
OvmfVersion: ovmfVersion,
|
||||
CpuNum: int32(cpuNum),
|
||||
CpuType: cpuType,
|
||||
KernelCmd: qemu.KernelCommandLine,
|
||||
EosVersion: eosVersion,
|
||||
Id: mes.SvmInfoReq.Id,
|
||||
},
|
||||
}
|
||||
client.sendMessage(&manager.ClientStreamMessage{Message: info})
|
||||
}
|
||||
|
||||
func (client ManagerClient) handleOutgoingMessages(ctx context.Context) error {
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case mes := <-client.messageQueue:
|
||||
if err := client.stream.Send(mes); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (client ManagerClient) sendMessage(mes *manager.ClientStreamMessage) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), sendTimeout)
|
||||
defer cancel()
|
||||
|
||||
select {
|
||||
case client.messageQueue <- mes:
|
||||
case <-ctx.Done():
|
||||
client.logger.Warn("Failed to send message: timeout exceeded")
|
||||
}
|
||||
}
|
||||
|
||||
type runRequestManager struct {
|
||||
requests map[string]*runRequest
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
type runRequest struct {
|
||||
buffer []byte
|
||||
lastChunk time.Time
|
||||
timer *time.Timer
|
||||
}
|
||||
|
||||
func newRunRequestManager() *runRequestManager {
|
||||
return &runRequestManager{
|
||||
requests: make(map[string]*runRequest),
|
||||
}
|
||||
}
|
||||
|
||||
func (m *runRequestManager) addChunk(id string, chunk []byte, isLast bool) ([]byte, bool) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
req, exists := m.requests[id]
|
||||
if !exists {
|
||||
req = &runRequest{
|
||||
buffer: make([]byte, 0),
|
||||
lastChunk: time.Now(),
|
||||
timer: time.AfterFunc(runReqTimeout, func() { m.timeoutRequest(id) }),
|
||||
}
|
||||
m.requests[id] = req
|
||||
}
|
||||
|
||||
req.buffer = append(req.buffer, chunk...)
|
||||
req.lastChunk = time.Now()
|
||||
req.timer.Reset(runReqTimeout)
|
||||
|
||||
if isLast {
|
||||
delete(m.requests, id)
|
||||
req.timer.Stop()
|
||||
return req.buffer, true
|
||||
}
|
||||
|
||||
return nil, false
|
||||
}
|
||||
|
||||
func (m *runRequestManager) timeoutRequest(id string) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
delete(m.requests, id)
|
||||
// Log timeout or handle it as needed
|
||||
}
|
||||
@@ -1,322 +0,0 @@
|
||||
// Copyright (c) Ultraviolet
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
package grpc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
mglog "github.com/absmach/magistrala/logger"
|
||||
"github.com/absmach/magistrala/pkg/errors"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/mock"
|
||||
"github.com/ultravioletrs/cocos/manager"
|
||||
"github.com/ultravioletrs/cocos/manager/mocks"
|
||||
"github.com/ultravioletrs/cocos/manager/qemu"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/protobuf/proto"
|
||||
)
|
||||
|
||||
type mockStream struct {
|
||||
mock.Mock
|
||||
grpc.ClientStream
|
||||
}
|
||||
|
||||
func (m *mockStream) Recv() (*manager.ServerStreamMessage, error) {
|
||||
args := m.Called()
|
||||
return args.Get(0).(*manager.ServerStreamMessage), args.Error(1)
|
||||
}
|
||||
|
||||
func (m *mockStream) Send(msg *manager.ClientStreamMessage) error {
|
||||
args := m.Called(msg)
|
||||
return args.Error(0)
|
||||
}
|
||||
|
||||
func TestManagerClient_Process1(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
setupMocks func(mockStream *mockStream, mockSvc *mocks.Service)
|
||||
expectError bool
|
||||
errorMsg string
|
||||
}{
|
||||
{
|
||||
name: "Stop computation",
|
||||
setupMocks: func(mockStream *mockStream, mockSvc *mocks.Service) {
|
||||
mockStream.On("Recv").Return(&manager.ServerStreamMessage{
|
||||
Message: &manager.ServerStreamMessage_StopComputation{
|
||||
StopComputation: &manager.StopComputation{},
|
||||
},
|
||||
}, nil)
|
||||
mockStream.On("Send", mock.Anything).Return(nil)
|
||||
mockSvc.On("Stop", mock.Anything, mock.Anything).Return(nil)
|
||||
},
|
||||
expectError: true,
|
||||
errorMsg: "context deadline exceeded",
|
||||
},
|
||||
{
|
||||
name: "Terminate request",
|
||||
setupMocks: func(mockStream *mockStream, mockSvc *mocks.Service) {
|
||||
mockStream.On("Recv").Return(&manager.ServerStreamMessage{
|
||||
Message: &manager.ServerStreamMessage_TerminateReq{
|
||||
TerminateReq: &manager.Terminate{},
|
||||
},
|
||||
}, nil)
|
||||
},
|
||||
expectError: true,
|
||||
errorMsg: errTerminationFromServer.Error(),
|
||||
},
|
||||
{
|
||||
name: "Attestation Policy request",
|
||||
setupMocks: func(mockStream *mockStream, mockSvc *mocks.Service) {
|
||||
mockStream.On("Recv").Return(&manager.ServerStreamMessage{
|
||||
Message: &manager.ServerStreamMessage_AttestationPolicyReq{
|
||||
AttestationPolicyReq: &manager.AttestationPolicyReq{},
|
||||
},
|
||||
}, nil)
|
||||
mockStream.On("Send", mock.Anything).Return(nil).Once()
|
||||
mockSvc.On("FetchAttestationPolicy", mock.Anything, mock.Anything).Return(nil, assert.AnError)
|
||||
},
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "Run request chunks",
|
||||
setupMocks: func(mockStream *mockStream, mockSvc *mocks.Service) {
|
||||
mockStream.On("Recv").Return(&manager.ServerStreamMessage{
|
||||
Message: &manager.ServerStreamMessage_RunReqChunks{
|
||||
RunReqChunks: &manager.RunReqChunks{},
|
||||
},
|
||||
}, nil)
|
||||
mockStream.On("Send", mock.Anything).Return(nil).Once()
|
||||
mockSvc.On("Run", mock.Anything, mock.Anything).Return("", assert.AnError).Once()
|
||||
},
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "Receive error",
|
||||
setupMocks: func(mockStream *mockStream, mockSvc *mocks.Service) {
|
||||
mockStream.On("Recv").Return(&manager.ServerStreamMessage{}, assert.AnError)
|
||||
},
|
||||
expectError: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
mockStream := new(mockStream)
|
||||
mockSvc := new(mocks.Service)
|
||||
messageQueue := make(chan *manager.ClientStreamMessage, 10)
|
||||
logger := mglog.NewMock()
|
||||
|
||||
client := NewClient(mockStream, mockSvc, messageQueue, logger)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
||||
defer cancel()
|
||||
|
||||
tc.setupMocks(mockStream, mockSvc)
|
||||
|
||||
err := client.Process(ctx, cancel)
|
||||
|
||||
if tc.expectError {
|
||||
assert.Error(t, err)
|
||||
if tc.errorMsg != "" {
|
||||
assert.Contains(t, err.Error(), tc.errorMsg)
|
||||
}
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestManagerClient_handleRunReqChunks(t *testing.T) {
|
||||
mockStream := new(mockStream)
|
||||
mockSvc := new(mocks.Service)
|
||||
messageQueue := make(chan *manager.ClientStreamMessage, 10)
|
||||
logger := mglog.NewMock()
|
||||
|
||||
client := NewClient(mockStream, mockSvc, messageQueue, logger)
|
||||
|
||||
runReq := &manager.ComputationRunReq{
|
||||
Id: "test-id",
|
||||
}
|
||||
runReqBytes, _ := proto.Marshal(runReq)
|
||||
|
||||
chunk1 := &manager.ServerStreamMessage_RunReqChunks{
|
||||
RunReqChunks: &manager.RunReqChunks{
|
||||
Id: "chunk-1",
|
||||
Data: runReqBytes[:len(runReqBytes)/2],
|
||||
IsLast: false,
|
||||
},
|
||||
}
|
||||
chunk2 := &manager.ServerStreamMessage_RunReqChunks{
|
||||
RunReqChunks: &manager.RunReqChunks{
|
||||
Id: "chunk-1",
|
||||
Data: runReqBytes[len(runReqBytes)/2:],
|
||||
IsLast: true,
|
||||
},
|
||||
}
|
||||
|
||||
mockSvc.On("Run", mock.Anything, mock.AnythingOfType("*manager.ComputationRunReq")).Return("8080", nil)
|
||||
|
||||
err := client.handleRunReqChunks(context.Background(), chunk1)
|
||||
assert.NoError(t, err)
|
||||
|
||||
err = client.handleRunReqChunks(context.Background(), chunk2)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Wait for the goroutine to finish
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
mockSvc.AssertExpectations(t)
|
||||
assert.Len(t, messageQueue, 1)
|
||||
|
||||
msg := <-messageQueue
|
||||
runRes, ok := msg.Message.(*manager.ClientStreamMessage_RunRes)
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, "8080", runRes.RunRes.AgentPort)
|
||||
assert.Equal(t, "test-id", runRes.RunRes.ComputationId)
|
||||
}
|
||||
|
||||
func TestManagerClient_handleTerminateReq(t *testing.T) {
|
||||
client := ManagerClient{}
|
||||
|
||||
terminateReq := &manager.ServerStreamMessage_TerminateReq{
|
||||
TerminateReq: &manager.Terminate{
|
||||
Message: "Test termination",
|
||||
},
|
||||
}
|
||||
|
||||
err := client.handleTerminateReq(terminateReq)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "Test termination")
|
||||
assert.True(t, errors.Contains(err, errTerminationFromServer))
|
||||
}
|
||||
|
||||
func TestManagerClient_handleStopComputation(t *testing.T) {
|
||||
mockStream := new(mockStream)
|
||||
mockSvc := new(mocks.Service)
|
||||
messageQueue := make(chan *manager.ClientStreamMessage, 10)
|
||||
logger := mglog.NewMock()
|
||||
|
||||
client := NewClient(mockStream, mockSvc, messageQueue, logger)
|
||||
|
||||
stopReq := &manager.ServerStreamMessage_StopComputation{
|
||||
StopComputation: &manager.StopComputation{
|
||||
ComputationId: "test-comp-id",
|
||||
},
|
||||
}
|
||||
|
||||
mockSvc.On("Stop", mock.Anything, "test-comp-id").Return(nil)
|
||||
|
||||
client.handleStopComputation(context.Background(), stopReq)
|
||||
|
||||
// Wait for the goroutine to finish
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
mockSvc.AssertExpectations(t)
|
||||
assert.Len(t, messageQueue, 1)
|
||||
|
||||
msg := <-messageQueue
|
||||
stopRes, ok := msg.Message.(*manager.ClientStreamMessage_StopComputationRes)
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, "test-comp-id", stopRes.StopComputationRes.ComputationId)
|
||||
assert.Empty(t, stopRes.StopComputationRes.Message)
|
||||
}
|
||||
|
||||
func TestManagerClient_handleAttestationPolicyReq(t *testing.T) {
|
||||
t.Run("success", func(t *testing.T) {
|
||||
mockStream := new(mockStream)
|
||||
mockSvc := new(mocks.Service)
|
||||
messageQueue := make(chan *manager.ClientStreamMessage, 10)
|
||||
logger := mglog.NewMock()
|
||||
|
||||
client := NewClient(mockStream, mockSvc, messageQueue, logger)
|
||||
|
||||
infoReq := &manager.ServerStreamMessage_AttestationPolicyReq{
|
||||
AttestationPolicyReq: &manager.AttestationPolicyReq{
|
||||
Id: "test-info-id",
|
||||
},
|
||||
}
|
||||
|
||||
mockSvc.On("FetchAttestationPolicy", context.Background(), infoReq.AttestationPolicyReq.Id).Return([]byte("test-attestation-policy"), nil)
|
||||
|
||||
client.handleAttestationPolicyReq(context.Background(), infoReq)
|
||||
|
||||
// Wait for the goroutine to finish
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
mockSvc.AssertExpectations(t)
|
||||
assert.Len(t, messageQueue, 1)
|
||||
|
||||
msg := <-messageQueue
|
||||
infoRes, ok := msg.Message.(*manager.ClientStreamMessage_AttestationPolicy)
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, "test-info-id", infoRes.AttestationPolicy.Id)
|
||||
assert.Equal(t, []byte("test-attestation-policy"), infoRes.AttestationPolicy.Info)
|
||||
})
|
||||
t.Run("error", func(t *testing.T) {
|
||||
mockStream := new(mockStream)
|
||||
mockSvc := new(mocks.Service)
|
||||
messageQueue := make(chan *manager.ClientStreamMessage, 10)
|
||||
logger := mglog.NewMock()
|
||||
|
||||
client := NewClient(mockStream, mockSvc, messageQueue, logger)
|
||||
|
||||
infoReq := &manager.ServerStreamMessage_AttestationPolicyReq{
|
||||
AttestationPolicyReq: &manager.AttestationPolicyReq{
|
||||
Id: "test-info-id",
|
||||
},
|
||||
}
|
||||
|
||||
mockSvc.On("FetchAttestationPolicy", context.Background(), infoReq.AttestationPolicyReq.Id).Return(nil, assert.AnError)
|
||||
|
||||
client.handleAttestationPolicyReq(context.Background(), infoReq)
|
||||
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
mockSvc.AssertExpectations(t)
|
||||
assert.Len(t, messageQueue, 0)
|
||||
})
|
||||
}
|
||||
|
||||
func TestManagerClient_handleSVMInfoReq(t *testing.T) {
|
||||
mockStream := new(mockStream)
|
||||
mockSvc := new(mocks.Service)
|
||||
messageQueue := make(chan *manager.ClientStreamMessage, 10)
|
||||
logger := mglog.NewMock()
|
||||
|
||||
client := NewClient(mockStream, mockSvc, messageQueue, logger)
|
||||
|
||||
mockSvc.On("ReturnSVMInfo", context.Background()).Return("edk2-stable202408", 4, "EPYC", "")
|
||||
|
||||
client.handleSVMInfoReq(context.Background(), &manager.ServerStreamMessage_SvmInfoReq{SvmInfoReq: &manager.SVMInfoReq{Id: "test-svm-info-id"}})
|
||||
|
||||
// Wait for the goroutine to finish
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
mockSvc.AssertExpectations(t)
|
||||
assert.Len(t, messageQueue, 1)
|
||||
|
||||
msg := <-messageQueue
|
||||
infoRes, ok := msg.Message.(*manager.ClientStreamMessage_SvmInfo)
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, "edk2-stable202408", infoRes.SvmInfo.OvmfVersion)
|
||||
assert.Equal(t, int32(4), infoRes.SvmInfo.CpuNum)
|
||||
assert.Equal(t, "EPYC", infoRes.SvmInfo.CpuType)
|
||||
assert.Equal(t, "", infoRes.SvmInfo.EosVersion)
|
||||
assert.Equal(t, qemu.KernelCommandLine, infoRes.SvmInfo.KernelCmd)
|
||||
}
|
||||
|
||||
func TestManagerClient_timeoutRequest(t *testing.T) {
|
||||
rm := newRunRequestManager()
|
||||
rm.requests["test-id"] = &runRequest{
|
||||
timer: time.NewTimer(100 * time.Millisecond),
|
||||
buffer: []byte("test-data"),
|
||||
lastChunk: time.Now(),
|
||||
}
|
||||
|
||||
rm.timeoutRequest("test-id")
|
||||
|
||||
assert.Len(t, rm.requests, 0)
|
||||
}
|
||||
+43
-104
@@ -3,17 +3,11 @@
|
||||
package grpc
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
"time"
|
||||
|
||||
"github.com/ultravioletrs/cocos/manager"
|
||||
"golang.org/x/sync/errgroup"
|
||||
"google.golang.org/grpc/credentials"
|
||||
"google.golang.org/grpc/peer"
|
||||
"google.golang.org/protobuf/proto"
|
||||
"google.golang.org/protobuf/types/known/emptypb"
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -21,113 +15,58 @@ var (
|
||||
ErrUnexpectedMsg = errors.New("unknown message type")
|
||||
)
|
||||
|
||||
const (
|
||||
bufferSize = 1024 * 1024 // 1 MB
|
||||
runReqTimeout = 30 * time.Second
|
||||
)
|
||||
|
||||
type SendFunc func(*manager.ServerStreamMessage) error
|
||||
|
||||
type grpcServer struct {
|
||||
manager.UnimplementedManagerServiceServer
|
||||
incoming chan *manager.ClientStreamMessage
|
||||
svc Service
|
||||
}
|
||||
|
||||
type Service interface {
|
||||
Run(ctx context.Context, ipAddress string, sendMessage SendFunc, authInfo credentials.AuthInfo)
|
||||
svc manager.Service
|
||||
}
|
||||
|
||||
// NewServer returns new AuthServiceServer instance.
|
||||
func NewServer(incoming chan *manager.ClientStreamMessage, svc Service) manager.ManagerServiceServer {
|
||||
func NewServer(svc manager.Service) manager.ManagerServiceServer {
|
||||
return &grpcServer{
|
||||
incoming: incoming,
|
||||
svc: svc,
|
||||
svc: svc,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *grpcServer) Process(stream manager.ManagerService_ProcessServer) error {
|
||||
client, ok := peer.FromContext(stream.Context())
|
||||
if !ok {
|
||||
return errors.New("failed to get peer info")
|
||||
}
|
||||
|
||||
eg, ctx := errgroup.WithContext(stream.Context())
|
||||
|
||||
eg.Go(func() error {
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
default:
|
||||
req, err := stream.Recv()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
s.incoming <- req
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
eg.Go(func() error {
|
||||
sendMessage := func(msg *manager.ServerStreamMessage) error {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
default:
|
||||
switch m := msg.Message.(type) {
|
||||
case *manager.ServerStreamMessage_RunReq:
|
||||
return s.sendRunReqInChunks(stream, m.RunReq)
|
||||
default:
|
||||
return stream.Send(msg)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
s.svc.Run(ctx, client.Addr.String(), sendMessage, client.AuthInfo)
|
||||
return nil
|
||||
})
|
||||
|
||||
return eg.Wait()
|
||||
}
|
||||
|
||||
func (s *grpcServer) sendRunReqInChunks(stream manager.ManagerService_ProcessServer, runReq *manager.ComputationRunReq) error {
|
||||
data, err := proto.Marshal(runReq)
|
||||
func (s *grpcServer) CreateVm(ctx context.Context, _ *emptypb.Empty) (*manager.CreateRes, error) {
|
||||
port, id, err := s.svc.CreateVM(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
return nil, err
|
||||
}
|
||||
|
||||
dataBuffer := bytes.NewBuffer(data)
|
||||
buf := make([]byte, bufferSize)
|
||||
|
||||
for {
|
||||
n, err := dataBuffer.Read(buf)
|
||||
isLast := false
|
||||
|
||||
if err == io.EOF {
|
||||
isLast = true
|
||||
} else if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
chunk := &manager.ServerStreamMessage{
|
||||
Message: &manager.ServerStreamMessage_RunReqChunks{
|
||||
RunReqChunks: &manager.RunReqChunks{
|
||||
Id: runReq.Id,
|
||||
Data: buf[:n],
|
||||
IsLast: isLast,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
if err := stream.Send(chunk); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if isLast {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
return &manager.CreateRes{
|
||||
ForwardedPort: port,
|
||||
SvmId: id,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *grpcServer) RemoveVm(ctx context.Context, req *manager.RemoveReq) (*emptypb.Empty, error) {
|
||||
if err := s.svc.RemoveVM(ctx, req.SvmId); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &emptypb.Empty{}, nil
|
||||
}
|
||||
|
||||
func (s *grpcServer) SVMInfo(ctx context.Context, req *manager.SVMInfoReq) (*manager.SVMInfoRes, error) {
|
||||
ovmf, cpunum, cputype, eosversion := s.svc.ReturnSVMInfo(ctx)
|
||||
|
||||
return &manager.SVMInfoRes{
|
||||
OvmfVersion: ovmf,
|
||||
CpuNum: int32(cpunum),
|
||||
CpuType: cputype,
|
||||
EosVersion: eosversion,
|
||||
Id: req.Id,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *grpcServer) AttestationPolicy(ctx context.Context, req *manager.AttestationPolicyReq) (*manager.AttestationPolicyRes, error) {
|
||||
policy, err := s.svc.FetchAttestationPolicy(ctx, req.Id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &manager.AttestationPolicyRes{
|
||||
Info: policy,
|
||||
Id: req.Id,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -1,290 +0,0 @@
|
||||
// Copyright (c) Ultraviolet
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
package grpc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/absmach/magistrala/pkg/errors"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/mock"
|
||||
"github.com/ultravioletrs/cocos/manager"
|
||||
"google.golang.org/grpc/credentials"
|
||||
"google.golang.org/grpc/peer"
|
||||
)
|
||||
|
||||
type mockServerStream struct {
|
||||
mock.Mock
|
||||
manager.ManagerService_ProcessServer
|
||||
}
|
||||
|
||||
func (m *mockServerStream) Send(msg *manager.ServerStreamMessage) error {
|
||||
args := m.Called(msg)
|
||||
return args.Error(0)
|
||||
}
|
||||
|
||||
func (m *mockServerStream) Recv() (*manager.ClientStreamMessage, error) {
|
||||
args := m.Called()
|
||||
return args.Get(0).(*manager.ClientStreamMessage), args.Error(1)
|
||||
}
|
||||
|
||||
func (m *mockServerStream) Context() context.Context {
|
||||
args := m.Called()
|
||||
return args.Get(0).(context.Context)
|
||||
}
|
||||
|
||||
type mockService struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
func (m *mockService) Run(ctx context.Context, ipAddress string, sendMessage SendFunc, authInfo credentials.AuthInfo) {
|
||||
m.Called(ctx, ipAddress, sendMessage, authInfo)
|
||||
}
|
||||
|
||||
func TestNewServer(t *testing.T) {
|
||||
incoming := make(chan *manager.ClientStreamMessage)
|
||||
mockSvc := new(mockService)
|
||||
|
||||
server := NewServer(incoming, mockSvc)
|
||||
|
||||
assert.NotNil(t, server)
|
||||
assert.IsType(t, &grpcServer{}, server)
|
||||
}
|
||||
|
||||
func TestGrpcServer_Process(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
recvReturn *manager.ClientStreamMessage
|
||||
recvError error
|
||||
expectedError string
|
||||
}{
|
||||
{
|
||||
name: "Process with context deadline exceeded",
|
||||
recvReturn: &manager.ClientStreamMessage{},
|
||||
recvError: nil,
|
||||
expectedError: "context deadline exceeded",
|
||||
},
|
||||
{
|
||||
name: "Process with Recv error",
|
||||
recvReturn: &manager.ClientStreamMessage{},
|
||||
recvError: errors.New("recv error"),
|
||||
expectedError: "recv error",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
incoming := make(chan *manager.ClientStreamMessage, 1)
|
||||
mockSvc := new(mockService)
|
||||
server := NewServer(incoming, mockSvc).(*grpcServer)
|
||||
|
||||
mockStream := new(mockServerStream)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
||||
defer cancel()
|
||||
|
||||
mockStream.On("Context").Return(peer.NewContext(ctx, &peer.Peer{
|
||||
Addr: mockAddr{},
|
||||
AuthInfo: mockAuthInfo{},
|
||||
}))
|
||||
|
||||
if tt.recvError == nil {
|
||||
go func() {
|
||||
for mes := range incoming {
|
||||
assert.NotNil(t, mes)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
mockStream.On("Recv").Return(tt.recvReturn, tt.recvError)
|
||||
mockSvc.On("Run", mock.Anything, "test", mock.Anything, mock.AnythingOfType("mockAuthInfo")).Return()
|
||||
|
||||
err := server.Process(mockStream)
|
||||
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), tt.expectedError)
|
||||
mockStream.AssertExpectations(t)
|
||||
mockSvc.AssertExpectations(t)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGrpcServer_sendRunReqInChunks(t *testing.T) {
|
||||
incoming := make(chan *manager.ClientStreamMessage)
|
||||
mockSvc := new(mockService)
|
||||
server := NewServer(incoming, mockSvc).(*grpcServer)
|
||||
|
||||
mockStream := new(mockServerStream)
|
||||
|
||||
runReq := &manager.ComputationRunReq{
|
||||
Id: "test-id",
|
||||
}
|
||||
|
||||
largePayload := make([]byte, bufferSize*2)
|
||||
for i := range largePayload {
|
||||
largePayload[i] = byte(i % 256)
|
||||
}
|
||||
runReq.Algorithm = &manager.Algorithm{}
|
||||
runReq.Algorithm.UserKey = largePayload
|
||||
|
||||
mockStream.On("Send", mock.AnythingOfType("*manager.ServerStreamMessage")).Return(nil).Times(4)
|
||||
|
||||
err := server.sendRunReqInChunks(mockStream, runReq)
|
||||
|
||||
assert.NoError(t, err)
|
||||
mockStream.AssertExpectations(t)
|
||||
|
||||
calls := mockStream.Calls
|
||||
assert.Equal(t, 4, len(calls))
|
||||
|
||||
for i, call := range calls {
|
||||
msg := call.Arguments[0].(*manager.ServerStreamMessage)
|
||||
chunk := msg.GetRunReqChunks()
|
||||
|
||||
assert.NotNil(t, chunk)
|
||||
assert.Equal(t, "test-id", chunk.Id)
|
||||
|
||||
if i < 3 {
|
||||
assert.False(t, chunk.IsLast)
|
||||
} else {
|
||||
assert.Equal(t, 0, len(chunk.Data))
|
||||
assert.True(t, chunk.IsLast)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
type mockAddr struct{}
|
||||
|
||||
func (mockAddr) Network() string { return "test network" }
|
||||
func (mockAddr) String() string { return "test" }
|
||||
|
||||
type mockAuthInfo struct{}
|
||||
|
||||
func (mockAuthInfo) AuthType() string { return "test auth" }
|
||||
|
||||
func TestGrpcServer_ProcessWithMockService(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
setupMockFn func(*mockService, *mockServerStream)
|
||||
}{
|
||||
{
|
||||
name: "Run Request Test",
|
||||
setupMockFn: func(mockSvc *mockService, mockStream *mockServerStream) {
|
||||
mockSvc.On("Run", mock.Anything, "test", mock.Anything, mock.AnythingOfType("mockAuthInfo")).
|
||||
Run(func(args mock.Arguments) {
|
||||
sendFunc := args.Get(2).(SendFunc)
|
||||
runReq := &manager.ComputationRunReq{Id: "test-run-id"}
|
||||
err := sendFunc(&manager.ServerStreamMessage{
|
||||
Message: &manager.ServerStreamMessage_RunReq{
|
||||
RunReq: runReq,
|
||||
},
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
}).
|
||||
Return()
|
||||
|
||||
mockStream.On("Send", mock.MatchedBy(func(msg *manager.ServerStreamMessage) bool {
|
||||
chunks := msg.GetRunReqChunks()
|
||||
return chunks != nil && chunks.Id == "test-run-id"
|
||||
})).Return(nil)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Terminate Request Test",
|
||||
setupMockFn: func(mockSvc *mockService, mockStream *mockServerStream) {
|
||||
mockSvc.On("Run", mock.Anything, "test", mock.Anything, mock.AnythingOfType("mockAuthInfo")).
|
||||
Run(func(args mock.Arguments) {
|
||||
sendFunc := args.Get(2).(SendFunc)
|
||||
err := sendFunc(&manager.ServerStreamMessage{
|
||||
Message: &manager.ServerStreamMessage_TerminateReq{
|
||||
TerminateReq: &manager.Terminate{},
|
||||
},
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
}).Return()
|
||||
|
||||
mockStream.On("Send", mock.AnythingOfType("*manager.ServerStreamMessage")).Return(nil)
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
incoming := make(chan *manager.ClientStreamMessage, 10)
|
||||
mockSvc := new(mockService)
|
||||
server := NewServer(incoming, mockSvc).(*grpcServer)
|
||||
|
||||
go func() {
|
||||
for mes := range incoming {
|
||||
assert.NotNil(t, mes)
|
||||
}
|
||||
}()
|
||||
|
||||
mockStream := new(mockServerStream)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
|
||||
defer cancel()
|
||||
|
||||
peerCtx := peer.NewContext(ctx, &peer.Peer{
|
||||
Addr: mockAddr{},
|
||||
AuthInfo: mockAuthInfo{},
|
||||
})
|
||||
|
||||
mockStream.On("Context").Return(peerCtx)
|
||||
mockStream.On("Recv").Return(&manager.ClientStreamMessage{}, nil).Maybe()
|
||||
|
||||
tt.setupMockFn(mockSvc, mockStream)
|
||||
|
||||
go func() {
|
||||
time.Sleep(150 * time.Millisecond)
|
||||
cancel()
|
||||
}()
|
||||
|
||||
err := server.Process(mockStream)
|
||||
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "context canceled")
|
||||
mockStream.AssertExpectations(t)
|
||||
mockSvc.AssertExpectations(t)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGrpcServer_sendRunReqInChunksError(t *testing.T) {
|
||||
incoming := make(chan *manager.ClientStreamMessage)
|
||||
mockSvc := new(mockService)
|
||||
server := NewServer(incoming, mockSvc).(*grpcServer)
|
||||
|
||||
mockStream := new(mockServerStream)
|
||||
|
||||
runReq := &manager.ComputationRunReq{
|
||||
Id: "test-id",
|
||||
}
|
||||
|
||||
// Simulate an error when sending
|
||||
mockStream.On("Send", mock.AnythingOfType("*manager.ServerStreamMessage")).Return(errors.New("send error")).Once()
|
||||
|
||||
err := server.sendRunReqInChunks(mockStream, runReq)
|
||||
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "send error")
|
||||
mockStream.AssertExpectations(t)
|
||||
}
|
||||
|
||||
func TestGrpcServer_ProcessMissingPeerInfo(t *testing.T) {
|
||||
incoming := make(chan *manager.ClientStreamMessage)
|
||||
mockSvc := new(mockService)
|
||||
server := NewServer(incoming, mockSvc).(*grpcServer)
|
||||
|
||||
mockStream := new(mockServerStream)
|
||||
ctx := context.Background()
|
||||
|
||||
// Return a context without peer info
|
||||
mockStream.On("Context").Return(ctx)
|
||||
|
||||
err := server.Process(mockStream)
|
||||
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "failed to get peer info")
|
||||
mockStream.AssertExpectations(t)
|
||||
}
|
||||
+6
-10
@@ -27,9 +27,9 @@ func LoggingMiddleware(svc manager.Service, logger *slog.Logger) manager.Service
|
||||
return &loggingMiddleware{logger, svc}
|
||||
}
|
||||
|
||||
func (lm *loggingMiddleware) Run(ctx context.Context, mc *manager.ComputationRunReq) (agentAddr string, err error) {
|
||||
func (lm *loggingMiddleware) CreateVM(ctx context.Context) (agentAddr string, id string, err error) {
|
||||
defer func(begin time.Time) {
|
||||
message := fmt.Sprintf("Method Run for computation took %s to complete", time.Since(begin))
|
||||
message := fmt.Sprintf("Method CreateVM for id %s on port %s took %s to complete", id, agentAddr, time.Since(begin))
|
||||
if err != nil {
|
||||
lm.logger.Warn(fmt.Sprintf("%s with error: %s.", message, err))
|
||||
return
|
||||
@@ -37,12 +37,12 @@ func (lm *loggingMiddleware) Run(ctx context.Context, mc *manager.ComputationRun
|
||||
lm.logger.Info(message)
|
||||
}(time.Now())
|
||||
|
||||
return lm.svc.Run(ctx, mc)
|
||||
return lm.svc.CreateVM(ctx)
|
||||
}
|
||||
|
||||
func (lm *loggingMiddleware) Stop(ctx context.Context, computationID string) (err error) {
|
||||
func (lm *loggingMiddleware) RemoveVM(ctx context.Context, id string) (err error) {
|
||||
defer func(begin time.Time) {
|
||||
message := fmt.Sprintf("Method Stop for computation took %s to complete", time.Since(begin))
|
||||
message := fmt.Sprintf("Method RemoveVM for vm %s took %s to complete", id, time.Since(begin))
|
||||
if err != nil {
|
||||
lm.logger.Warn(fmt.Sprintf("%s with error: %s.", message, err))
|
||||
return
|
||||
@@ -50,7 +50,7 @@ func (lm *loggingMiddleware) Stop(ctx context.Context, computationID string) (er
|
||||
lm.logger.Info(message)
|
||||
}(time.Now())
|
||||
|
||||
return lm.svc.Stop(ctx, computationID)
|
||||
return lm.svc.RemoveVM(ctx, id)
|
||||
}
|
||||
|
||||
func (lm *loggingMiddleware) FetchAttestationPolicy(ctx context.Context, cmpId string) (body []byte, err error) {
|
||||
@@ -67,10 +67,6 @@ func (lm *loggingMiddleware) FetchAttestationPolicy(ctx context.Context, cmpId s
|
||||
return lm.svc.FetchAttestationPolicy(ctx, cmpId)
|
||||
}
|
||||
|
||||
func (lm *loggingMiddleware) ReportBrokenConnection(addr string) {
|
||||
lm.svc.ReportBrokenConnection(addr)
|
||||
}
|
||||
|
||||
func (lm *loggingMiddleware) ReturnSVMInfo(ctx context.Context) (string, int, string, string) {
|
||||
defer func(begin time.Time) {
|
||||
message := fmt.Sprintf("Method ReturnSVMInfo for computation took %s to complete", time.Since(begin))
|
||||
|
||||
@@ -32,22 +32,22 @@ func MetricsMiddleware(svc manager.Service, counter metrics.Counter, latency met
|
||||
}
|
||||
}
|
||||
|
||||
func (ms *metricsMiddleware) Run(ctx context.Context, mc *manager.ComputationRunReq) (string, error) {
|
||||
func (ms *metricsMiddleware) CreateVM(ctx context.Context) (string, string, error) {
|
||||
defer func(begin time.Time) {
|
||||
ms.counter.With("method", "Run").Add(1)
|
||||
ms.latency.With("method", "Run").Observe(time.Since(begin).Seconds())
|
||||
}(time.Now())
|
||||
|
||||
return ms.svc.Run(ctx, mc)
|
||||
return ms.svc.CreateVM(ctx)
|
||||
}
|
||||
|
||||
func (ms *metricsMiddleware) Stop(ctx context.Context, computationID string) error {
|
||||
func (ms *metricsMiddleware) RemoveVM(ctx context.Context, computationID string) error {
|
||||
defer func(begin time.Time) {
|
||||
ms.counter.With("method", "Stop").Add(1)
|
||||
ms.latency.With("method", "Stop").Observe(time.Since(begin).Seconds())
|
||||
}(time.Now())
|
||||
|
||||
return ms.svc.Stop(ctx, computationID)
|
||||
return ms.svc.RemoveVM(ctx, computationID)
|
||||
}
|
||||
|
||||
func (ms *metricsMiddleware) FetchAttestationPolicy(ctx context.Context, cmpId string) ([]byte, error) {
|
||||
@@ -59,10 +59,6 @@ func (ms *metricsMiddleware) FetchAttestationPolicy(ctx context.Context, cmpId s
|
||||
return ms.svc.FetchAttestationPolicy(ctx, cmpId)
|
||||
}
|
||||
|
||||
func (ms *metricsMiddleware) ReportBrokenConnection(addr string) {
|
||||
ms.svc.ReportBrokenConnection(addr)
|
||||
}
|
||||
|
||||
func (ms *metricsMiddleware) ReturnSVMInfo(ctx context.Context) (string, int, string, string) {
|
||||
defer func(begin time.Time) {
|
||||
ms.counter.With("method", "ReturnSVMInfo").Add(1)
|
||||
|
||||
@@ -1,9 +0,0 @@
|
||||
// Copyright (c) Ultraviolet
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
package events
|
||||
|
||||
import "context"
|
||||
|
||||
type Listener interface {
|
||||
Listen(ctx context.Context)
|
||||
}
|
||||
@@ -1,125 +0,0 @@
|
||||
// Copyright (c) Ultraviolet
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
package events
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log/slog"
|
||||
"net"
|
||||
|
||||
"github.com/mdlayher/vsock"
|
||||
agentevents "github.com/ultravioletrs/cocos/agent/events"
|
||||
internalvsock "github.com/ultravioletrs/cocos/internal/vsock"
|
||||
"github.com/ultravioletrs/cocos/manager"
|
||||
"google.golang.org/protobuf/proto"
|
||||
)
|
||||
|
||||
const ManagerVsockPort = 9997
|
||||
|
||||
type ReportBrokenConnectionFunc func(address string)
|
||||
|
||||
type events struct {
|
||||
reportBrokenConnection ReportBrokenConnectionFunc
|
||||
lis net.Listener
|
||||
logger *slog.Logger
|
||||
eventsChan chan *manager.ClientStreamMessage
|
||||
}
|
||||
|
||||
func New(logger *slog.Logger, reportBrokenConnection ReportBrokenConnectionFunc, eventsChan chan *manager.ClientStreamMessage) (Listener, error) {
|
||||
l, err := vsock.Listen(ManagerVsockPort, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &events{
|
||||
lis: l,
|
||||
reportBrokenConnection: reportBrokenConnection,
|
||||
logger: logger,
|
||||
eventsChan: eventsChan,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (e *events) Listen(ctx context.Context) {
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
e.logger.Info("Listener shutting down")
|
||||
return
|
||||
default:
|
||||
conn, err := e.lis.Accept()
|
||||
if err != nil {
|
||||
e.logger.Warn(err.Error())
|
||||
continue
|
||||
}
|
||||
|
||||
go e.handleConnection(conn)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (e *events) handleConnection(conn net.Conn) {
|
||||
defer conn.Close()
|
||||
|
||||
ackReader := internalvsock.NewAckReader(conn)
|
||||
|
||||
for {
|
||||
var message agentevents.EventsLogs
|
||||
data, err := ackReader.Read()
|
||||
if err != nil {
|
||||
go e.reportBrokenConnection(conn.RemoteAddr().String())
|
||||
e.logger.Warn(err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
if err := proto.Unmarshal(data, &message); err != nil {
|
||||
e.logger.Warn(err.Error())
|
||||
continue
|
||||
}
|
||||
|
||||
var mes manager.ClientStreamMessage
|
||||
|
||||
args := []any{}
|
||||
|
||||
switch message.Message.(type) {
|
||||
case *agentevents.EventsLogs_AgentEvent:
|
||||
args = append(args, slog.Group("agent-event",
|
||||
slog.String("event-type", message.GetAgentEvent().GetEventType()),
|
||||
slog.String("computation-id", message.GetAgentEvent().GetComputationId()),
|
||||
slog.String("status", message.GetAgentEvent().GetStatus()),
|
||||
slog.String("originator", message.GetAgentEvent().GetOriginator()),
|
||||
slog.String("timestamp", message.GetAgentEvent().GetTimestamp().String()),
|
||||
slog.String("details", string(message.GetAgentEvent().GetDetails()))))
|
||||
mes = manager.ClientStreamMessage{
|
||||
Message: &manager.ClientStreamMessage_AgentEvent{
|
||||
AgentEvent: &manager.AgentEvent{
|
||||
EventType: message.GetAgentEvent().GetEventType(),
|
||||
ComputationId: message.GetAgentEvent().GetComputationId(),
|
||||
Status: message.GetAgentEvent().GetStatus(),
|
||||
Originator: message.GetAgentEvent().GetOriginator(),
|
||||
Timestamp: message.GetAgentEvent().GetTimestamp(),
|
||||
Details: message.GetAgentEvent().GetDetails(),
|
||||
},
|
||||
},
|
||||
}
|
||||
case *agentevents.EventsLogs_AgentLog:
|
||||
args = append(args, slog.Group("agent-log",
|
||||
slog.String("computation-id", message.GetAgentLog().GetComputationId()),
|
||||
slog.String("level", message.GetAgentLog().GetLevel()),
|
||||
slog.String("timestamp", message.GetAgentLog().GetTimestamp().String()),
|
||||
slog.String("message", message.GetAgentLog().GetMessage())))
|
||||
mes = manager.ClientStreamMessage{
|
||||
Message: &manager.ClientStreamMessage_AgentLog{
|
||||
AgentLog: &manager.AgentLog{
|
||||
ComputationId: message.GetAgentLog().GetComputationId(),
|
||||
Level: message.GetAgentLog().GetLevel(),
|
||||
Timestamp: message.GetAgentLog().GetTimestamp(),
|
||||
Message: message.GetAgentLog().GetMessage(),
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
e.eventsChan <- &mes
|
||||
|
||||
e.logger.Info("", args...)
|
||||
}
|
||||
}
|
||||
@@ -1,295 +0,0 @@
|
||||
// Copyright (c) Ultraviolet
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
package events
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"net"
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
mglog "github.com/absmach/magistrala/logger"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/mock"
|
||||
"github.com/ultravioletrs/cocos/manager"
|
||||
"google.golang.org/protobuf/proto"
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
)
|
||||
|
||||
type MockVsockListener struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
func (m *MockVsockListener) Accept() (net.Conn, error) {
|
||||
args := m.Called()
|
||||
return args.Get(0).(net.Conn), args.Error(1)
|
||||
}
|
||||
|
||||
func (m *MockVsockListener) Close() error {
|
||||
args := m.Called()
|
||||
return args.Error(0)
|
||||
}
|
||||
|
||||
func (m *MockVsockListener) Addr() net.Addr {
|
||||
args := m.Called()
|
||||
return args.Get(0).(net.Addr)
|
||||
}
|
||||
|
||||
var _ net.Conn = (*MockConn)(nil)
|
||||
|
||||
type MockConn struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
func (m *MockConn) Read(b []byte) (n int, err error) {
|
||||
args := m.Called(b)
|
||||
return args.Int(0), args.Error(1)
|
||||
}
|
||||
|
||||
func (m *MockConn) Write(b []byte) (n int, err error) {
|
||||
args := m.Called(b)
|
||||
return args.Int(0), args.Error(1)
|
||||
}
|
||||
|
||||
func (m *MockConn) Close() error {
|
||||
args := m.Called()
|
||||
return args.Error(0)
|
||||
}
|
||||
|
||||
func (m *MockConn) LocalAddr() net.Addr {
|
||||
args := m.Called()
|
||||
return args.Get(0).(net.Addr)
|
||||
}
|
||||
|
||||
func (m *MockConn) RemoteAddr() net.Addr {
|
||||
args := m.Called()
|
||||
return args.Get(0).(net.Addr)
|
||||
}
|
||||
|
||||
func (m *MockConn) SetDeadline(t time.Time) error {
|
||||
args := m.Called(t)
|
||||
return args.Error(0)
|
||||
}
|
||||
|
||||
func (m *MockConn) SetReadDeadline(t time.Time) error {
|
||||
args := m.Called(t)
|
||||
return args.Error(0)
|
||||
}
|
||||
|
||||
func (m *MockConn) SetWriteDeadline(t time.Time) error {
|
||||
args := m.Called(t)
|
||||
return args.Error(0)
|
||||
}
|
||||
|
||||
func TestNew(t *testing.T) {
|
||||
logger := &slog.Logger{}
|
||||
reportBrokenConnection := func(address string) {}
|
||||
eventsChan := make(chan *manager.ClientStreamMessage)
|
||||
|
||||
e, err := New(logger, reportBrokenConnection, eventsChan)
|
||||
|
||||
if vsockDeviceExists() {
|
||||
assert.NoError(t, err)
|
||||
|
||||
assert.NotNil(t, e)
|
||||
assert.IsType(t, &events{}, e)
|
||||
} else {
|
||||
assert.Error(t, err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestListen(t *testing.T) {
|
||||
mockListener := new(MockVsockListener)
|
||||
mockConn := new(MockConn)
|
||||
|
||||
e := &events{
|
||||
lis: mockListener,
|
||||
logger: mglog.NewMock(),
|
||||
}
|
||||
|
||||
mockListener.On("Accept").Return(mockConn, fmt.Errorf("mock error")).Once()
|
||||
mockListener.On("Accept").Return(mockConn, nil)
|
||||
mockConn.On("Close").Return(nil)
|
||||
mockConn.On("Read", mock.Anything).Return(0, nil)
|
||||
|
||||
go e.Listen(context.Background())
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
mockListener.AssertExpectations(t)
|
||||
}
|
||||
|
||||
func TestListenContextDone(t *testing.T) {
|
||||
mockListener := new(MockVsockListener)
|
||||
mockConn := new(MockConn)
|
||||
|
||||
e := &events{
|
||||
lis: mockListener,
|
||||
logger: mglog.NewMock(),
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cancel()
|
||||
|
||||
mockListener.On("Accept").Return(mockConn, nil)
|
||||
|
||||
e.Listen(ctx)
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
}
|
||||
|
||||
func vsockDeviceExists() bool {
|
||||
fs, err := os.Stat("/dev/vsock")
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
if fs.Mode()&os.ModeDevice == 0 {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
type MockConnWithBuffer struct {
|
||||
mock.Mock
|
||||
readBuf *bytes.Buffer
|
||||
writeBuf *bytes.Buffer
|
||||
}
|
||||
|
||||
func NewMockConnWithBuffer() *MockConnWithBuffer {
|
||||
return &MockConnWithBuffer{
|
||||
readBuf: new(bytes.Buffer),
|
||||
writeBuf: new(bytes.Buffer),
|
||||
}
|
||||
}
|
||||
|
||||
func (m *MockConnWithBuffer) Read(b []byte) (n int, err error) {
|
||||
return m.readBuf.Read(b)
|
||||
}
|
||||
|
||||
func (m *MockConnWithBuffer) Write(b []byte) (n int, err error) {
|
||||
return m.writeBuf.Write(b)
|
||||
}
|
||||
|
||||
func (m *MockConnWithBuffer) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *MockConnWithBuffer) LocalAddr() net.Addr {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *MockConnWithBuffer) RemoteAddr() net.Addr {
|
||||
return &net.IPAddr{IP: net.ParseIP("localhost")}
|
||||
}
|
||||
|
||||
func (m *MockConnWithBuffer) SetDeadline(t time.Time) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *MockConnWithBuffer) SetReadDeadline(t time.Time) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *MockConnWithBuffer) SetWriteDeadline(t time.Time) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestHandleConnection(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
message *manager.ClientStreamMessage
|
||||
}{
|
||||
{
|
||||
name: "handle agent event",
|
||||
message: &manager.ClientStreamMessage{
|
||||
Message: &manager.ClientStreamMessage_AgentEvent{
|
||||
AgentEvent: &manager.AgentEvent{
|
||||
EventType: "test_event",
|
||||
ComputationId: "test_computation",
|
||||
Status: "test_status",
|
||||
Originator: "test_originator",
|
||||
Timestamp: timestamppb.Now(),
|
||||
Details: []byte("test_details"),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "handle agent log",
|
||||
message: &manager.ClientStreamMessage{
|
||||
Message: &manager.ClientStreamMessage_AgentLog{
|
||||
AgentLog: &manager.AgentLog{
|
||||
ComputationId: "test_computation",
|
||||
Timestamp: timestamppb.Now(),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
mockConn := NewMockConnWithBuffer()
|
||||
eventsChan := make(chan *manager.ClientStreamMessage, 1)
|
||||
|
||||
e := &events{
|
||||
logger: mglog.NewMock(),
|
||||
eventsChan: eventsChan,
|
||||
reportBrokenConnection: func(address string) {},
|
||||
}
|
||||
|
||||
data, err := proto.Marshal(tt.message)
|
||||
assert.NoError(t, err)
|
||||
|
||||
messageID := uint32(1)
|
||||
err = binary.Write(mockConn.readBuf, binary.LittleEndian, messageID)
|
||||
assert.NoError(t, err)
|
||||
err = binary.Write(mockConn.readBuf, binary.LittleEndian, uint32(len(data)))
|
||||
assert.NoError(t, err)
|
||||
_, err = mockConn.readBuf.Write(data)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Add EOF to signal end of stream
|
||||
err = binary.Write(mockConn.readBuf, binary.LittleEndian, uint32(0))
|
||||
assert.NoError(t, err)
|
||||
err = binary.Write(mockConn.readBuf, binary.LittleEndian, uint32(0))
|
||||
assert.NoError(t, err)
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
e.handleConnection(mockConn)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
var receivedMessage *manager.ClientStreamMessage
|
||||
select {
|
||||
case receivedMessage = <-eventsChan:
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("Timeout waiting for message in eventsChan")
|
||||
}
|
||||
|
||||
assert.NotNil(t, receivedMessage)
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
// handleConnection has exited
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("Timeout waiting for handleConnection to exit")
|
||||
}
|
||||
|
||||
// Check if ack was written
|
||||
var receivedAck uint32
|
||||
err = binary.Read(mockConn.writeBuf, binary.LittleEndian, &receivedAck)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, messageID, receivedAck)
|
||||
|
||||
// Ensure no unexpected calls were made on the mock
|
||||
mockConn.AssertExpectations(t)
|
||||
})
|
||||
}
|
||||
}
|
||||
+136
-1250
File diff suppressed because it is too large
Load Diff
+12
-98
@@ -3,40 +3,34 @@
|
||||
|
||||
syntax = "proto3";
|
||||
|
||||
import "google/protobuf/timestamp.proto";
|
||||
import "google/protobuf/empty.proto";
|
||||
|
||||
package manager;
|
||||
|
||||
option go_package = "./manager";
|
||||
|
||||
service ManagerService {
|
||||
rpc Process(stream ClientStreamMessage) returns (stream ServerStreamMessage) {}
|
||||
rpc CreateVm(google.protobuf.Empty) returns (CreateRes) {}
|
||||
rpc RemoveVm(RemoveReq) returns (google.protobuf.Empty) {}
|
||||
rpc SVMInfo(SVMInfoReq) returns (SVMInfoRes) {}
|
||||
rpc AttestationPolicy(AttestationPolicyReq) returns (AttestationPolicyRes) {}
|
||||
}
|
||||
|
||||
message Terminate {
|
||||
string message = 1;
|
||||
message CreateRes{
|
||||
string forwarded_port = 1;
|
||||
string svm_id = 2;
|
||||
}
|
||||
|
||||
message StopComputation {
|
||||
string computation_id = 1;
|
||||
message RemoveReq{
|
||||
string svm_id = 1;
|
||||
}
|
||||
|
||||
message StopComputationResponse {
|
||||
string computation_id = 1;
|
||||
string message = 2;
|
||||
}
|
||||
|
||||
message RunResponse{
|
||||
string agent_port = 1;
|
||||
string computation_id = 2;
|
||||
}
|
||||
|
||||
message AttestationPolicy{
|
||||
message AttestationPolicyRes{
|
||||
bytes info = 1;
|
||||
string id = 2;
|
||||
}
|
||||
|
||||
message SVMInfo{
|
||||
message SVMInfoRes{
|
||||
string id = 1;
|
||||
string ovmf_version = 2;
|
||||
int32 cpu_num = 3;
|
||||
@@ -45,60 +39,6 @@ message SVMInfo{
|
||||
string eos_version = 6;
|
||||
}
|
||||
|
||||
message AgentEvent {
|
||||
string event_type = 1;
|
||||
google.protobuf.Timestamp timestamp = 2;
|
||||
string computation_id = 3;
|
||||
bytes details = 4;
|
||||
string originator = 5;
|
||||
string status = 6;
|
||||
}
|
||||
|
||||
message AgentLog {
|
||||
string message = 1;
|
||||
string computation_id = 2;
|
||||
string level = 3;
|
||||
google.protobuf.Timestamp timestamp = 4;
|
||||
}
|
||||
|
||||
message ClientStreamMessage {
|
||||
oneof message {
|
||||
AgentLog agent_log = 1;
|
||||
AgentEvent agent_event = 2;
|
||||
RunResponse run_res = 3;
|
||||
AttestationPolicy attestationPolicy = 4;
|
||||
StopComputationResponse stopComputationRes = 5;
|
||||
SVMInfo svm_info = 6;
|
||||
}
|
||||
}
|
||||
|
||||
message ServerStreamMessage {
|
||||
oneof message {
|
||||
RunReqChunks runReqChunks = 1;
|
||||
ComputationRunReq runReq = 2;
|
||||
Terminate terminateReq = 3;
|
||||
StopComputation stopComputation = 4;
|
||||
AttestationPolicyReq attestationPolicyReq = 5;
|
||||
SVMInfoReq svmInfoReq = 6;
|
||||
}
|
||||
}
|
||||
|
||||
message RunReqChunks {
|
||||
bytes data = 1;
|
||||
string id = 2;
|
||||
bool is_last = 3;
|
||||
}
|
||||
|
||||
message ComputationRunReq {
|
||||
string id = 1;
|
||||
string name = 2;
|
||||
string description = 3;
|
||||
repeated Dataset datasets = 4;
|
||||
Algorithm algorithm = 5;
|
||||
repeated ResultConsumer result_consumers = 6;
|
||||
AgentConfig agent_config = 7;
|
||||
}
|
||||
|
||||
message AttestationPolicyReq {
|
||||
string id = 1;
|
||||
}
|
||||
@@ -107,29 +47,3 @@ message SVMInfoReq {
|
||||
string id = 1;
|
||||
}
|
||||
|
||||
message ResultConsumer {
|
||||
bytes userKey = 1;
|
||||
}
|
||||
|
||||
message Dataset {
|
||||
bytes hash = 1; // should be sha3.Sum256, 32 byte length.
|
||||
bytes userKey = 2;
|
||||
string filename = 3;
|
||||
}
|
||||
|
||||
message Algorithm {
|
||||
bytes hash = 1; // should be sha3.Sum256, 32 byte length.
|
||||
bytes userKey = 2;
|
||||
}
|
||||
|
||||
message AgentConfig {
|
||||
string port = 1;
|
||||
string host = 2;
|
||||
string cert_file = 3;
|
||||
string key_file = 4;
|
||||
string client_ca_file = 5;
|
||||
string server_ca_file = 6;
|
||||
string log_level = 7;
|
||||
bool attested_tls = 8;
|
||||
}
|
||||
|
||||
|
||||
+142
-21
@@ -14,6 +14,7 @@ import (
|
||||
grpc "google.golang.org/grpc"
|
||||
codes "google.golang.org/grpc/codes"
|
||||
status "google.golang.org/grpc/status"
|
||||
emptypb "google.golang.org/protobuf/types/known/emptypb"
|
||||
)
|
||||
|
||||
// This is a compile-time assertion to ensure that this generated file
|
||||
@@ -22,14 +23,20 @@ import (
|
||||
const _ = grpc.SupportPackageIsVersion9
|
||||
|
||||
const (
|
||||
ManagerService_Process_FullMethodName = "/manager.ManagerService/Process"
|
||||
ManagerService_CreateVm_FullMethodName = "/manager.ManagerService/CreateVm"
|
||||
ManagerService_RemoveVm_FullMethodName = "/manager.ManagerService/RemoveVm"
|
||||
ManagerService_SVMInfo_FullMethodName = "/manager.ManagerService/SVMInfo"
|
||||
ManagerService_AttestationPolicy_FullMethodName = "/manager.ManagerService/AttestationPolicy"
|
||||
)
|
||||
|
||||
// ManagerServiceClient is the client API for ManagerService service.
|
||||
//
|
||||
// For semantics around ctx use and closing/ending streaming RPCs, please refer to https://pkg.go.dev/google.golang.org/grpc/?tab=doc#ClientConn.NewStream.
|
||||
type ManagerServiceClient interface {
|
||||
Process(ctx context.Context, opts ...grpc.CallOption) (grpc.BidiStreamingClient[ClientStreamMessage, ServerStreamMessage], error)
|
||||
CreateVm(ctx context.Context, in *emptypb.Empty, opts ...grpc.CallOption) (*CreateRes, error)
|
||||
RemoveVm(ctx context.Context, in *RemoveReq, opts ...grpc.CallOption) (*emptypb.Empty, error)
|
||||
SVMInfo(ctx context.Context, in *SVMInfoReq, opts ...grpc.CallOption) (*SVMInfoRes, error)
|
||||
AttestationPolicy(ctx context.Context, in *AttestationPolicyReq, opts ...grpc.CallOption) (*AttestationPolicyRes, error)
|
||||
}
|
||||
|
||||
type managerServiceClient struct {
|
||||
@@ -40,24 +47,54 @@ func NewManagerServiceClient(cc grpc.ClientConnInterface) ManagerServiceClient {
|
||||
return &managerServiceClient{cc}
|
||||
}
|
||||
|
||||
func (c *managerServiceClient) Process(ctx context.Context, opts ...grpc.CallOption) (grpc.BidiStreamingClient[ClientStreamMessage, ServerStreamMessage], error) {
|
||||
func (c *managerServiceClient) CreateVm(ctx context.Context, in *emptypb.Empty, opts ...grpc.CallOption) (*CreateRes, error) {
|
||||
cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
|
||||
stream, err := c.cc.NewStream(ctx, &ManagerService_ServiceDesc.Streams[0], ManagerService_Process_FullMethodName, cOpts...)
|
||||
out := new(CreateRes)
|
||||
err := c.cc.Invoke(ctx, ManagerService_CreateVm_FullMethodName, in, out, cOpts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
x := &grpc.GenericClientStream[ClientStreamMessage, ServerStreamMessage]{ClientStream: stream}
|
||||
return x, nil
|
||||
return out, nil
|
||||
}
|
||||
|
||||
// This type alias is provided for backwards compatibility with existing code that references the prior non-generic stream type by name.
|
||||
type ManagerService_ProcessClient = grpc.BidiStreamingClient[ClientStreamMessage, ServerStreamMessage]
|
||||
func (c *managerServiceClient) RemoveVm(ctx context.Context, in *RemoveReq, opts ...grpc.CallOption) (*emptypb.Empty, error) {
|
||||
cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
|
||||
out := new(emptypb.Empty)
|
||||
err := c.cc.Invoke(ctx, ManagerService_RemoveVm_FullMethodName, in, out, cOpts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (c *managerServiceClient) SVMInfo(ctx context.Context, in *SVMInfoReq, opts ...grpc.CallOption) (*SVMInfoRes, error) {
|
||||
cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
|
||||
out := new(SVMInfoRes)
|
||||
err := c.cc.Invoke(ctx, ManagerService_SVMInfo_FullMethodName, in, out, cOpts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (c *managerServiceClient) AttestationPolicy(ctx context.Context, in *AttestationPolicyReq, opts ...grpc.CallOption) (*AttestationPolicyRes, error) {
|
||||
cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
|
||||
out := new(AttestationPolicyRes)
|
||||
err := c.cc.Invoke(ctx, ManagerService_AttestationPolicy_FullMethodName, in, out, cOpts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
// ManagerServiceServer is the server API for ManagerService service.
|
||||
// All implementations must embed UnimplementedManagerServiceServer
|
||||
// for forward compatibility.
|
||||
type ManagerServiceServer interface {
|
||||
Process(grpc.BidiStreamingServer[ClientStreamMessage, ServerStreamMessage]) error
|
||||
CreateVm(context.Context, *emptypb.Empty) (*CreateRes, error)
|
||||
RemoveVm(context.Context, *RemoveReq) (*emptypb.Empty, error)
|
||||
SVMInfo(context.Context, *SVMInfoReq) (*SVMInfoRes, error)
|
||||
AttestationPolicy(context.Context, *AttestationPolicyReq) (*AttestationPolicyRes, error)
|
||||
mustEmbedUnimplementedManagerServiceServer()
|
||||
}
|
||||
|
||||
@@ -68,8 +105,17 @@ type ManagerServiceServer interface {
|
||||
// pointer dereference when methods are called.
|
||||
type UnimplementedManagerServiceServer struct{}
|
||||
|
||||
func (UnimplementedManagerServiceServer) Process(grpc.BidiStreamingServer[ClientStreamMessage, ServerStreamMessage]) error {
|
||||
return status.Errorf(codes.Unimplemented, "method Process not implemented")
|
||||
func (UnimplementedManagerServiceServer) CreateVm(context.Context, *emptypb.Empty) (*CreateRes, error) {
|
||||
return nil, status.Errorf(codes.Unimplemented, "method CreateVm not implemented")
|
||||
}
|
||||
func (UnimplementedManagerServiceServer) RemoveVm(context.Context, *RemoveReq) (*emptypb.Empty, error) {
|
||||
return nil, status.Errorf(codes.Unimplemented, "method RemoveVm not implemented")
|
||||
}
|
||||
func (UnimplementedManagerServiceServer) SVMInfo(context.Context, *SVMInfoReq) (*SVMInfoRes, error) {
|
||||
return nil, status.Errorf(codes.Unimplemented, "method SVMInfo not implemented")
|
||||
}
|
||||
func (UnimplementedManagerServiceServer) AttestationPolicy(context.Context, *AttestationPolicyReq) (*AttestationPolicyRes, error) {
|
||||
return nil, status.Errorf(codes.Unimplemented, "method AttestationPolicy not implemented")
|
||||
}
|
||||
func (UnimplementedManagerServiceServer) mustEmbedUnimplementedManagerServiceServer() {}
|
||||
func (UnimplementedManagerServiceServer) testEmbeddedByValue() {}
|
||||
@@ -92,12 +138,77 @@ func RegisterManagerServiceServer(s grpc.ServiceRegistrar, srv ManagerServiceSer
|
||||
s.RegisterService(&ManagerService_ServiceDesc, srv)
|
||||
}
|
||||
|
||||
func _ManagerService_Process_Handler(srv interface{}, stream grpc.ServerStream) error {
|
||||
return srv.(ManagerServiceServer).Process(&grpc.GenericServerStream[ClientStreamMessage, ServerStreamMessage]{ServerStream: stream})
|
||||
func _ManagerService_CreateVm_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
|
||||
in := new(emptypb.Empty)
|
||||
if err := dec(in); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if interceptor == nil {
|
||||
return srv.(ManagerServiceServer).CreateVm(ctx, in)
|
||||
}
|
||||
info := &grpc.UnaryServerInfo{
|
||||
Server: srv,
|
||||
FullMethod: ManagerService_CreateVm_FullMethodName,
|
||||
}
|
||||
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||
return srv.(ManagerServiceServer).CreateVm(ctx, req.(*emptypb.Empty))
|
||||
}
|
||||
return interceptor(ctx, in, info, handler)
|
||||
}
|
||||
|
||||
// This type alias is provided for backwards compatibility with existing code that references the prior non-generic stream type by name.
|
||||
type ManagerService_ProcessServer = grpc.BidiStreamingServer[ClientStreamMessage, ServerStreamMessage]
|
||||
func _ManagerService_RemoveVm_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
|
||||
in := new(RemoveReq)
|
||||
if err := dec(in); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if interceptor == nil {
|
||||
return srv.(ManagerServiceServer).RemoveVm(ctx, in)
|
||||
}
|
||||
info := &grpc.UnaryServerInfo{
|
||||
Server: srv,
|
||||
FullMethod: ManagerService_RemoveVm_FullMethodName,
|
||||
}
|
||||
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||
return srv.(ManagerServiceServer).RemoveVm(ctx, req.(*RemoveReq))
|
||||
}
|
||||
return interceptor(ctx, in, info, handler)
|
||||
}
|
||||
|
||||
func _ManagerService_SVMInfo_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
|
||||
in := new(SVMInfoReq)
|
||||
if err := dec(in); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if interceptor == nil {
|
||||
return srv.(ManagerServiceServer).SVMInfo(ctx, in)
|
||||
}
|
||||
info := &grpc.UnaryServerInfo{
|
||||
Server: srv,
|
||||
FullMethod: ManagerService_SVMInfo_FullMethodName,
|
||||
}
|
||||
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||
return srv.(ManagerServiceServer).SVMInfo(ctx, req.(*SVMInfoReq))
|
||||
}
|
||||
return interceptor(ctx, in, info, handler)
|
||||
}
|
||||
|
||||
func _ManagerService_AttestationPolicy_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
|
||||
in := new(AttestationPolicyReq)
|
||||
if err := dec(in); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if interceptor == nil {
|
||||
return srv.(ManagerServiceServer).AttestationPolicy(ctx, in)
|
||||
}
|
||||
info := &grpc.UnaryServerInfo{
|
||||
Server: srv,
|
||||
FullMethod: ManagerService_AttestationPolicy_FullMethodName,
|
||||
}
|
||||
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||
return srv.(ManagerServiceServer).AttestationPolicy(ctx, req.(*AttestationPolicyReq))
|
||||
}
|
||||
return interceptor(ctx, in, info, handler)
|
||||
}
|
||||
|
||||
// ManagerService_ServiceDesc is the grpc.ServiceDesc for ManagerService service.
|
||||
// It's only intended for direct use with grpc.RegisterService,
|
||||
@@ -105,14 +216,24 @@ type ManagerService_ProcessServer = grpc.BidiStreamingServer[ClientStreamMessage
|
||||
var ManagerService_ServiceDesc = grpc.ServiceDesc{
|
||||
ServiceName: "manager.ManagerService",
|
||||
HandlerType: (*ManagerServiceServer)(nil),
|
||||
Methods: []grpc.MethodDesc{},
|
||||
Streams: []grpc.StreamDesc{
|
||||
Methods: []grpc.MethodDesc{
|
||||
{
|
||||
StreamName: "Process",
|
||||
Handler: _ManagerService_Process_Handler,
|
||||
ServerStreams: true,
|
||||
ClientStreams: true,
|
||||
MethodName: "CreateVm",
|
||||
Handler: _ManagerService_CreateVm_Handler,
|
||||
},
|
||||
{
|
||||
MethodName: "RemoveVm",
|
||||
Handler: _ManagerService_RemoveVm_Handler,
|
||||
},
|
||||
{
|
||||
MethodName: "SVMInfo",
|
||||
Handler: _ManagerService_SVMInfo_Handler,
|
||||
},
|
||||
{
|
||||
MethodName: "AttestationPolicy",
|
||||
Handler: _ManagerService_AttestationPolicy_Handler,
|
||||
},
|
||||
},
|
||||
Streams: []grpc.StreamDesc{},
|
||||
Metadata: "manager/manager.proto",
|
||||
}
|
||||
|
||||
@@ -1,68 +0,0 @@
|
||||
// Copyright (c) Ultraviolet
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
package manager_test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/ultravioletrs/cocos/manager"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/protobuf/proto"
|
||||
)
|
||||
|
||||
func TestProcess(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
conn, err := grpc.DialContext(ctx, "bufnet", grpc.WithContextDialer(bufDialer), grpc.WithInsecure())
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to dial bufnet: %v", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
client := manager.NewManagerServiceClient(conn)
|
||||
stream, err := client.Process(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("Process failed: %v", err)
|
||||
}
|
||||
|
||||
var data bytes.Buffer
|
||||
for {
|
||||
msg, err := stream.Recv()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to receive ServerStreamMessage: %v", err)
|
||||
}
|
||||
|
||||
switch m := msg.Message.(type) {
|
||||
case *manager.ServerStreamMessage_TerminateReq:
|
||||
if m.TerminateReq.Message != "test terminate" {
|
||||
t.Fatalf("Unexpected terminate message: %v", m.TerminateReq.Message)
|
||||
}
|
||||
case *manager.ServerStreamMessage_RunReqChunks:
|
||||
if len(m.RunReqChunks.Data) == 0 {
|
||||
var runReq manager.ComputationRunReq
|
||||
if err = proto.Unmarshal(data.Bytes(), &runReq); err != nil {
|
||||
t.Fatalf("Failed to create run request: %v", err)
|
||||
}
|
||||
|
||||
runRes := &manager.ClientStreamMessage_AgentLog{
|
||||
AgentLog: &manager.AgentLog{
|
||||
Message: "test log",
|
||||
ComputationId: "comp1",
|
||||
Level: "DEBUG",
|
||||
},
|
||||
}
|
||||
if runReq.Id != "1" || runReq.Name != "sample computation" || runReq.Description != "sample description" {
|
||||
t.Fatalf("Unexpected run request message: %v", &runReq)
|
||||
}
|
||||
if err := stream.Send(&manager.ClientStreamMessage{Message: runRes}); err != nil {
|
||||
t.Fatalf("Failed to send ClientStreamMessage: %v", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
data.Write(m.RunReqChunks.Data)
|
||||
default:
|
||||
t.Fatalf("Unexpected message type: %T", m)
|
||||
}
|
||||
}
|
||||
}
|
||||
+91
-119
@@ -9,7 +9,6 @@ import (
|
||||
context "context"
|
||||
|
||||
mock "github.com/stretchr/testify/mock"
|
||||
manager "github.com/ultravioletrs/cocos/manager"
|
||||
)
|
||||
|
||||
// Service is an autogenerated mock type for the Service type
|
||||
@@ -25,6 +24,69 @@ func (_m *Service) EXPECT() *Service_Expecter {
|
||||
return &Service_Expecter{mock: &_m.Mock}
|
||||
}
|
||||
|
||||
// CreateVM provides a mock function with given fields: ctx
|
||||
func (_m *Service) CreateVM(ctx context.Context) (string, string, error) {
|
||||
ret := _m.Called(ctx)
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for CreateVM")
|
||||
}
|
||||
|
||||
var r0 string
|
||||
var r1 string
|
||||
var r2 error
|
||||
if rf, ok := ret.Get(0).(func(context.Context) (string, string, error)); ok {
|
||||
return rf(ctx)
|
||||
}
|
||||
if rf, ok := ret.Get(0).(func(context.Context) string); ok {
|
||||
r0 = rf(ctx)
|
||||
} else {
|
||||
r0 = ret.Get(0).(string)
|
||||
}
|
||||
|
||||
if rf, ok := ret.Get(1).(func(context.Context) string); ok {
|
||||
r1 = rf(ctx)
|
||||
} else {
|
||||
r1 = ret.Get(1).(string)
|
||||
}
|
||||
|
||||
if rf, ok := ret.Get(2).(func(context.Context) error); ok {
|
||||
r2 = rf(ctx)
|
||||
} else {
|
||||
r2 = ret.Error(2)
|
||||
}
|
||||
|
||||
return r0, r1, r2
|
||||
}
|
||||
|
||||
// Service_CreateVM_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'CreateVM'
|
||||
type Service_CreateVM_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// CreateVM is a helper method to define mock.On call
|
||||
// - ctx context.Context
|
||||
func (_e *Service_Expecter) CreateVM(ctx interface{}) *Service_CreateVM_Call {
|
||||
return &Service_CreateVM_Call{Call: _e.mock.On("CreateVM", ctx)}
|
||||
}
|
||||
|
||||
func (_c *Service_CreateVM_Call) Run(run func(ctx context.Context)) *Service_CreateVM_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
run(args[0].(context.Context))
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *Service_CreateVM_Call) Return(_a0 string, _a1 string, _a2 error) *Service_CreateVM_Call {
|
||||
_c.Call.Return(_a0, _a1, _a2)
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *Service_CreateVM_Call) RunAndReturn(run func(context.Context) (string, string, error)) *Service_CreateVM_Call {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
// FetchAttestationPolicy provides a mock function with given fields: ctx, computationID
|
||||
func (_m *Service) FetchAttestationPolicy(ctx context.Context, computationID string) ([]byte, error) {
|
||||
ret := _m.Called(ctx, computationID)
|
||||
@@ -84,35 +146,49 @@ func (_c *Service_FetchAttestationPolicy_Call) RunAndReturn(run func(context.Con
|
||||
return _c
|
||||
}
|
||||
|
||||
// ReportBrokenConnection provides a mock function with given fields: addr
|
||||
func (_m *Service) ReportBrokenConnection(addr string) {
|
||||
_m.Called(addr)
|
||||
// RemoveVM provides a mock function with given fields: ctx, computationID
|
||||
func (_m *Service) RemoveVM(ctx context.Context, computationID string) error {
|
||||
ret := _m.Called(ctx, computationID)
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for RemoveVM")
|
||||
}
|
||||
|
||||
var r0 error
|
||||
if rf, ok := ret.Get(0).(func(context.Context, string) error); ok {
|
||||
r0 = rf(ctx, computationID)
|
||||
} else {
|
||||
r0 = ret.Error(0)
|
||||
}
|
||||
|
||||
return r0
|
||||
}
|
||||
|
||||
// Service_ReportBrokenConnection_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ReportBrokenConnection'
|
||||
type Service_ReportBrokenConnection_Call struct {
|
||||
// Service_RemoveVM_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'RemoveVM'
|
||||
type Service_RemoveVM_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// ReportBrokenConnection is a helper method to define mock.On call
|
||||
// - addr string
|
||||
func (_e *Service_Expecter) ReportBrokenConnection(addr interface{}) *Service_ReportBrokenConnection_Call {
|
||||
return &Service_ReportBrokenConnection_Call{Call: _e.mock.On("ReportBrokenConnection", addr)}
|
||||
// RemoveVM is a helper method to define mock.On call
|
||||
// - ctx context.Context
|
||||
// - computationID string
|
||||
func (_e *Service_Expecter) RemoveVM(ctx interface{}, computationID interface{}) *Service_RemoveVM_Call {
|
||||
return &Service_RemoveVM_Call{Call: _e.mock.On("RemoveVM", ctx, computationID)}
|
||||
}
|
||||
|
||||
func (_c *Service_ReportBrokenConnection_Call) Run(run func(addr string)) *Service_ReportBrokenConnection_Call {
|
||||
func (_c *Service_RemoveVM_Call) Run(run func(ctx context.Context, computationID string)) *Service_RemoveVM_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
run(args[0].(string))
|
||||
run(args[0].(context.Context), args[1].(string))
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *Service_ReportBrokenConnection_Call) Return() *Service_ReportBrokenConnection_Call {
|
||||
_c.Call.Return()
|
||||
func (_c *Service_RemoveVM_Call) Return(_a0 error) *Service_RemoveVM_Call {
|
||||
_c.Call.Return(_a0)
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *Service_ReportBrokenConnection_Call) RunAndReturn(run func(string)) *Service_ReportBrokenConnection_Call {
|
||||
func (_c *Service_RemoveVM_Call) RunAndReturn(run func(context.Context, string) error) *Service_RemoveVM_Call {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
@@ -187,110 +263,6 @@ func (_c *Service_ReturnSVMInfo_Call) RunAndReturn(run func(context.Context) (st
|
||||
return _c
|
||||
}
|
||||
|
||||
// Run provides a mock function with given fields: ctx, c
|
||||
func (_m *Service) Run(ctx context.Context, c *manager.ComputationRunReq) (string, error) {
|
||||
ret := _m.Called(ctx, c)
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for Run")
|
||||
}
|
||||
|
||||
var r0 string
|
||||
var r1 error
|
||||
if rf, ok := ret.Get(0).(func(context.Context, *manager.ComputationRunReq) (string, error)); ok {
|
||||
return rf(ctx, c)
|
||||
}
|
||||
if rf, ok := ret.Get(0).(func(context.Context, *manager.ComputationRunReq) string); ok {
|
||||
r0 = rf(ctx, c)
|
||||
} else {
|
||||
r0 = ret.Get(0).(string)
|
||||
}
|
||||
|
||||
if rf, ok := ret.Get(1).(func(context.Context, *manager.ComputationRunReq) error); ok {
|
||||
r1 = rf(ctx, c)
|
||||
} else {
|
||||
r1 = ret.Error(1)
|
||||
}
|
||||
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
// Service_Run_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Run'
|
||||
type Service_Run_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// Run is a helper method to define mock.On call
|
||||
// - ctx context.Context
|
||||
// - c *manager.ComputationRunReq
|
||||
func (_e *Service_Expecter) Run(ctx interface{}, c interface{}) *Service_Run_Call {
|
||||
return &Service_Run_Call{Call: _e.mock.On("Run", ctx, c)}
|
||||
}
|
||||
|
||||
func (_c *Service_Run_Call) Run(run func(ctx context.Context, c *manager.ComputationRunReq)) *Service_Run_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
run(args[0].(context.Context), args[1].(*manager.ComputationRunReq))
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *Service_Run_Call) Return(_a0 string, _a1 error) *Service_Run_Call {
|
||||
_c.Call.Return(_a0, _a1)
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *Service_Run_Call) RunAndReturn(run func(context.Context, *manager.ComputationRunReq) (string, error)) *Service_Run_Call {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
// Stop provides a mock function with given fields: ctx, computationID
|
||||
func (_m *Service) Stop(ctx context.Context, computationID string) error {
|
||||
ret := _m.Called(ctx, computationID)
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for Stop")
|
||||
}
|
||||
|
||||
var r0 error
|
||||
if rf, ok := ret.Get(0).(func(context.Context, string) error); ok {
|
||||
r0 = rf(ctx, computationID)
|
||||
} else {
|
||||
r0 = ret.Error(0)
|
||||
}
|
||||
|
||||
return r0
|
||||
}
|
||||
|
||||
// Service_Stop_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Stop'
|
||||
type Service_Stop_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// Stop is a helper method to define mock.On call
|
||||
// - ctx context.Context
|
||||
// - computationID string
|
||||
func (_e *Service_Expecter) Stop(ctx interface{}, computationID interface{}) *Service_Stop_Call {
|
||||
return &Service_Stop_Call{Call: _e.mock.On("Stop", ctx, computationID)}
|
||||
}
|
||||
|
||||
func (_c *Service_Stop_Call) Run(run func(ctx context.Context, computationID string)) *Service_Stop_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
run(args[0].(context.Context), args[1].(string))
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *Service_Stop_Call) Return(_a0 error) *Service_Stop_Call {
|
||||
_c.Call.Return(_a0)
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *Service_Stop_Call) RunAndReturn(run func(context.Context, string) error) *Service_Stop_Call {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
// NewService creates a new instance of Service. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations.
|
||||
// The first argument is typically a *testing.T value.
|
||||
func NewService(t interface {
|
||||
|
||||
+10
-30
@@ -13,7 +13,6 @@ import (
|
||||
"github.com/ultravioletrs/cocos/internal"
|
||||
"github.com/ultravioletrs/cocos/manager/vm"
|
||||
"github.com/ultravioletrs/cocos/pkg/manager"
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -31,19 +30,17 @@ type VMInfo struct {
|
||||
}
|
||||
|
||||
type qemuVM struct {
|
||||
vmi VMInfo
|
||||
cmd *exec.Cmd
|
||||
eventsLogsSender vm.EventSender
|
||||
computationId string
|
||||
vmi VMInfo
|
||||
cmd *exec.Cmd
|
||||
computationId string
|
||||
vm.StateMachine
|
||||
}
|
||||
|
||||
func NewVM(config interface{}, eventsLogsSender vm.EventSender, computationId string) vm.VM {
|
||||
func NewVM(config interface{}, computationId string) vm.VM {
|
||||
return &qemuVM{
|
||||
vmi: config.(VMInfo),
|
||||
eventsLogsSender: eventsLogsSender,
|
||||
computationId: computationId,
|
||||
StateMachine: vm.NewStateMachine(),
|
||||
vmi: config.(VMInfo),
|
||||
computationId: computationId,
|
||||
StateMachine: vm.NewStateMachine(),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -79,8 +76,8 @@ func (v *qemuVM) Start() (err error) {
|
||||
}
|
||||
|
||||
v.cmd = exec.Command(exe, args...)
|
||||
v.cmd.Stdout = &vm.Stdout{ComputationId: v.computationId, EventSender: v.eventsLogsSender}
|
||||
v.cmd.Stderr = &vm.Stderr{EventSender: v.eventsLogsSender, ComputationId: v.computationId, StateMachine: v.StateMachine}
|
||||
v.cmd.Stdout = os.Stdout
|
||||
v.cmd.Stderr = os.Stderr
|
||||
|
||||
return v.cmd.Start()
|
||||
}
|
||||
@@ -89,15 +86,7 @@ func (v *qemuVM) Stop() error {
|
||||
defer func() {
|
||||
err := v.StateMachine.Transition(manager.StopComputationRun)
|
||||
if err != nil {
|
||||
if err := v.eventsLogsSender(&vm.Event{
|
||||
EventType: v.StateMachine.State(),
|
||||
Timestamp: timestamppb.Now(),
|
||||
ComputationId: v.computationId,
|
||||
Originator: "manager",
|
||||
Status: manager.Warning.String(),
|
||||
}); err != nil {
|
||||
return
|
||||
}
|
||||
return
|
||||
}
|
||||
}()
|
||||
err := v.cmd.Process.Signal(syscall.SIGTERM)
|
||||
@@ -163,15 +152,6 @@ func (v *qemuVM) executableAndArgs() (string, []string, error) {
|
||||
func (v *qemuVM) checkVMProcessPeriodically() {
|
||||
for {
|
||||
if !processExists(v.GetProcess()) {
|
||||
if err := v.eventsLogsSender(&vm.Event{
|
||||
EventType: v.StateMachine.State(),
|
||||
Timestamp: timestamppb.Now(),
|
||||
ComputationId: v.computationId,
|
||||
Originator: "manager",
|
||||
Status: manager.Stopped.String(),
|
||||
}); err != nil {
|
||||
return
|
||||
}
|
||||
break
|
||||
}
|
||||
time.Sleep(interval)
|
||||
|
||||
+3
-36
@@ -6,10 +6,8 @@ import (
|
||||
"os"
|
||||
"os/exec"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/ultravioletrs/cocos/manager/vm"
|
||||
"github.com/ultravioletrs/cocos/manager/vm/mocks"
|
||||
pkgmanager "github.com/ultravioletrs/cocos/pkg/manager"
|
||||
)
|
||||
@@ -19,7 +17,7 @@ const testComputationID = "test-computation"
|
||||
func TestNewVM(t *testing.T) {
|
||||
config := VMInfo{Config: Config{}}
|
||||
|
||||
vm := NewVM(config, func(event interface{}) error { return nil }, testComputationID)
|
||||
vm := NewVM(config, testComputationID)
|
||||
|
||||
assert.NotNil(t, vm)
|
||||
assert.IsType(t, &qemuVM{}, vm)
|
||||
@@ -38,7 +36,7 @@ func TestStart(t *testing.T) {
|
||||
QemuBinPath: "echo",
|
||||
}}
|
||||
|
||||
vm := NewVM(config, func(event interface{}) error { return nil }, testComputationID).(*qemuVM)
|
||||
vm := NewVM(config, testComputationID).(*qemuVM)
|
||||
|
||||
err = vm.Start()
|
||||
assert.NoError(t, err)
|
||||
@@ -61,7 +59,7 @@ func TestStartSudo(t *testing.T) {
|
||||
UseSudo: true,
|
||||
}}
|
||||
|
||||
vm := NewVM(config, func(event interface{}) error { return nil }, testComputationID).(*qemuVM)
|
||||
vm := NewVM(config, testComputationID).(*qemuVM)
|
||||
|
||||
err = vm.Start()
|
||||
assert.NoError(t, err)
|
||||
@@ -101,9 +99,6 @@ func TestStop(t *testing.T) {
|
||||
Process: cmd.Process,
|
||||
},
|
||||
StateMachine: sm,
|
||||
eventsLogsSender: func(event interface{}) error {
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
err = vm.Stop()
|
||||
@@ -165,31 +160,3 @@ func TestGetConfig(t *testing.T) {
|
||||
config := vm.GetConfig()
|
||||
assert.Equal(t, expectedConfig, config)
|
||||
}
|
||||
|
||||
func TestCheckVMProcessPeriodically(t *testing.T) {
|
||||
logsChan := make(chan interface{}, 1)
|
||||
vmi := &qemuVM{
|
||||
eventsLogsSender: func(event interface{}) error {
|
||||
logsChan <- event
|
||||
return nil
|
||||
},
|
||||
computationId: testComputationID,
|
||||
cmd: &exec.Cmd{
|
||||
Process: &os.Process{Pid: -1}, // Use an invalid PID to simulate a stopped process
|
||||
},
|
||||
StateMachine: vm.NewStateMachine(),
|
||||
}
|
||||
|
||||
go vmi.checkVMProcessPeriodically()
|
||||
|
||||
select {
|
||||
case msg := <-logsChan:
|
||||
assert.NotNil(t, msg)
|
||||
msgE := msg.(*vm.Event)
|
||||
assert.Equal(t, testComputationID, msgE.ComputationId)
|
||||
assert.Equal(t, pkgmanager.VmProvision.String(), msgE.EventType)
|
||||
assert.Equal(t, pkgmanager.Stopped.String(), msgE.Status)
|
||||
case <-time.After(2 * interval):
|
||||
t.Fatal("Timeout waiting for VM stopped message")
|
||||
}
|
||||
}
|
||||
|
||||
+21
-127
@@ -5,7 +5,6 @@ package manager
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"net"
|
||||
@@ -17,15 +16,13 @@ import (
|
||||
"syscall"
|
||||
|
||||
"github.com/absmach/magistrala/pkg/errors"
|
||||
"github.com/cenkalti/backoff/v4"
|
||||
"github.com/google/go-sev-guest/proto/check"
|
||||
"github.com/ultravioletrs/cocos/agent"
|
||||
"github.com/google/uuid"
|
||||
"github.com/ultravioletrs/cocos/manager/qemu"
|
||||
"github.com/ultravioletrs/cocos/manager/vm"
|
||||
"github.com/ultravioletrs/cocos/pkg/manager"
|
||||
"golang.org/x/crypto/sha3"
|
||||
"google.golang.org/protobuf/encoding/protojson"
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -48,8 +45,6 @@ var (
|
||||
// ErrFailedToAllocatePort indicates no free port was found on host.
|
||||
ErrFailedToAllocatePort = errors.New("failed to allocate free port on host")
|
||||
|
||||
errInvalidHashLength = errors.New("hash must be of byte length 32")
|
||||
|
||||
// ErrFailedToCalculateHash indicates that agent computation returned an error while calculating the hash of the computation.
|
||||
ErrFailedToCalculateHash = errors.New("error while calculating the hash of the computation")
|
||||
|
||||
@@ -67,13 +62,11 @@ var (
|
||||
// implementation, and all of its decorators (e.g. logging & metrics).
|
||||
type Service interface {
|
||||
// Run create a computation.
|
||||
Run(ctx context.Context, c *ComputationRunReq) (string, error)
|
||||
CreateVM(ctx context.Context) (string, string, error)
|
||||
// Stop stops a computation.
|
||||
Stop(ctx context.Context, computationID string) error
|
||||
RemoveVM(ctx context.Context, computationID string) error
|
||||
// FetchAttestationPolicy measures and fetches the attestation policy.
|
||||
FetchAttestationPolicy(ctx context.Context, computationID string) ([]byte, error)
|
||||
// ReportBrokenConnection reports a broken connection.
|
||||
ReportBrokenConnection(addr string)
|
||||
// ReturnSVMInfo returns SVM information needed for attestation verification and validation.
|
||||
ReturnSVMInfo(ctx context.Context) (string, int, string, string)
|
||||
}
|
||||
@@ -84,7 +77,6 @@ type managerService struct {
|
||||
qemuCfg qemu.Config
|
||||
attestationPolicyBinaryPath string
|
||||
logger *slog.Logger
|
||||
eventsChan chan *ClientStreamMessage
|
||||
vms map[string]vm.VM
|
||||
vmFactory vm.Provider
|
||||
portRangeMin int
|
||||
@@ -96,7 +88,7 @@ type managerService struct {
|
||||
var _ Service = (*managerService)(nil)
|
||||
|
||||
// New instantiates the manager service implementation.
|
||||
func New(cfg qemu.Config, attestationPolicyBinPath string, logger *slog.Logger, eventsChan chan *ClientStreamMessage, vmFactory vm.Provider, eosVersion string) (Service, error) {
|
||||
func New(cfg qemu.Config, attestationPolicyBinPath string, logger *slog.Logger, vmFactory vm.Provider, eosVersion string) (Service, error) {
|
||||
start, end, err := decodeRange(cfg.HostFwdRange)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -111,7 +103,6 @@ func New(cfg qemu.Config, attestationPolicyBinPath string, logger *slog.Logger,
|
||||
qemuCfg: cfg,
|
||||
logger: logger,
|
||||
vms: make(map[string]vm.VM),
|
||||
eventsChan: eventsChan,
|
||||
vmFactory: vmFactory,
|
||||
attestationPolicyBinaryPath: attestationPolicyBinPath,
|
||||
portRangeMin: start,
|
||||
@@ -127,7 +118,8 @@ func New(cfg qemu.Config, attestationPolicyBinPath string, logger *slog.Logger,
|
||||
return ms, nil
|
||||
}
|
||||
|
||||
func (ms *managerService) Run(ctx context.Context, c *ComputationRunReq) (string, error) {
|
||||
func (ms *managerService) CreateVM(ctx context.Context) (string, string, error) {
|
||||
id := uuid.New().String()
|
||||
ms.mu.Lock()
|
||||
cfg := qemu.VMInfo{
|
||||
Config: ms.qemuCfg,
|
||||
@@ -142,54 +134,29 @@ func (ms *managerService) Run(ctx context.Context, c *ComputationRunReq) (string
|
||||
_, err := cmd.Output()
|
||||
ms.ap.Unlock()
|
||||
if err != nil {
|
||||
return "", errors.Wrap(ErrFailedToCreateAttestationPolicy, err)
|
||||
return "", id, errors.Wrap(ErrFailedToCreateAttestationPolicy, err)
|
||||
}
|
||||
|
||||
ms.ap.Lock()
|
||||
f, err := os.ReadFile("./attestation_policy.json")
|
||||
ms.ap.Unlock()
|
||||
if err != nil {
|
||||
return "", errors.Wrap(ErrFailedToReadPolicy, err)
|
||||
return "", id, errors.Wrap(ErrFailedToReadPolicy, err)
|
||||
}
|
||||
|
||||
var attestationPolicy check.Config
|
||||
|
||||
if err = protojson.Unmarshal(f, &attestationPolicy); err != nil {
|
||||
return "", errors.Wrap(ErrUnmarshalFailed, err)
|
||||
return "", id, errors.Wrap(ErrUnmarshalFailed, err)
|
||||
}
|
||||
|
||||
// Define the TCB that was present at launch of the VM.
|
||||
cfg.LaunchTCB = attestationPolicy.Policy.MinimumLaunchTcb
|
||||
}
|
||||
ms.publishEvent(manager.VmProvision.String(), c.Id, manager.Starting.String(), json.RawMessage{})
|
||||
ac := agent.Computation{
|
||||
ID: c.Id,
|
||||
Name: c.Name,
|
||||
Description: c.Description,
|
||||
}
|
||||
if len(c.Algorithm.Hash) != hashLength {
|
||||
ms.publishEvent(manager.VmProvision.String(), c.Id, agent.Failed.String(), json.RawMessage{})
|
||||
return "", errInvalidHashLength
|
||||
}
|
||||
|
||||
ac.Algorithm = agent.Algorithm{Hash: [hashLength]byte(c.Algorithm.Hash), UserKey: c.Algorithm.UserKey}
|
||||
|
||||
for _, data := range c.Datasets {
|
||||
if len(data.Hash) != hashLength {
|
||||
ms.publishEvent(manager.VmProvision.String(), c.Id, agent.Failed.String(), json.RawMessage{})
|
||||
return "", errInvalidHashLength
|
||||
}
|
||||
ac.Datasets = append(ac.Datasets, agent.Dataset{Hash: [hashLength]byte(data.Hash), UserKey: data.UserKey, Filename: data.Filename})
|
||||
}
|
||||
|
||||
for _, rc := range c.ResultConsumers {
|
||||
ac.ResultConsumers = append(ac.ResultConsumers, agent.ResultConsumer{UserKey: rc.UserKey})
|
||||
}
|
||||
|
||||
agentPort, err := getFreePort(ms.portRangeMin, ms.portRangeMax)
|
||||
if err != nil {
|
||||
ms.publishEvent(manager.VmProvision.String(), c.Id, agent.Failed.String(), json.RawMessage{})
|
||||
return "", errors.Wrap(ErrFailedToAllocatePort, err)
|
||||
return "", id, errors.Wrap(ErrFailedToAllocatePort, err)
|
||||
}
|
||||
cfg.Config.HostFwdAgent = agentPort
|
||||
|
||||
@@ -210,30 +177,23 @@ func (ms *managerService) Run(ctx context.Context, c *ComputationRunReq) (string
|
||||
cfg.Config.VSockConfig.GuestCID = cid
|
||||
|
||||
if cfg.Config.EnableSEVSNP {
|
||||
ch, err := computationHash(ac)
|
||||
if err != nil {
|
||||
ms.publishEvent(manager.VmProvision.String(), c.Id, agent.Failed.String(), json.RawMessage{})
|
||||
return "", errors.Wrap(ErrFailedToCalculateHash, err)
|
||||
}
|
||||
|
||||
todo := sha3.Sum256([]byte("TODO"))
|
||||
// Define host-data value of QEMU for SEV-SNP, with a base64 encoding of the computation hash.
|
||||
cfg.Config.SevConfig.HostData = base64.StdEncoding.EncodeToString(ch[:])
|
||||
cfg.Config.SevConfig.HostData = base64.StdEncoding.EncodeToString(todo[:])
|
||||
}
|
||||
|
||||
cvm := ms.vmFactory(cfg, ms.eventsLogsSender, c.Id)
|
||||
ms.publishEvent(manager.VmProvision.String(), c.Id, agent.InProgress.String(), json.RawMessage{})
|
||||
cvm := ms.vmFactory(cfg, id)
|
||||
if err = cvm.Start(); err != nil {
|
||||
ms.publishEvent(manager.VmProvision.String(), c.Id, agent.Failed.String(), json.RawMessage{})
|
||||
return "", err
|
||||
return "", id, err
|
||||
}
|
||||
ms.mu.Lock()
|
||||
ms.vms[c.Id] = cvm
|
||||
ms.vms[id] = cvm
|
||||
ms.mu.Unlock()
|
||||
|
||||
pid := cvm.GetProcess()
|
||||
|
||||
state := qemu.VMState{
|
||||
ID: c.Id,
|
||||
ID: id,
|
||||
VMinfo: cfg,
|
||||
PID: pid,
|
||||
}
|
||||
@@ -241,34 +201,23 @@ func (ms *managerService) Run(ctx context.Context, c *ComputationRunReq) (string
|
||||
ms.logger.Error("Failed to persist VM state", "error", err)
|
||||
}
|
||||
|
||||
err = backoff.Retry(func() error {
|
||||
return cvm.SendAgentConfig(ac)
|
||||
}, backoff.NewExponentialBackOff())
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
ms.mu.Lock()
|
||||
if err := ms.vms[c.Id].Transition(manager.VmRunning); err != nil {
|
||||
ms.logger.Warn("Failed to transition VM state", "computation", c.Id, "error", err)
|
||||
if err := ms.vms[id].Transition(manager.VmRunning); err != nil {
|
||||
ms.logger.Warn("Failed to transition VM state", "cvm", id, "error", err)
|
||||
}
|
||||
ms.mu.Unlock()
|
||||
|
||||
ms.publishEvent(manager.VmProvision.String(), c.Id, agent.Completed.String(), json.RawMessage{})
|
||||
|
||||
return fmt.Sprint(agentPort), nil
|
||||
return fmt.Sprint(agentPort), id, nil
|
||||
}
|
||||
|
||||
func (ms *managerService) Stop(ctx context.Context, computationID string) error {
|
||||
func (ms *managerService) RemoveVM(ctx context.Context, computationID string) error {
|
||||
ms.mu.Lock()
|
||||
defer ms.mu.Unlock()
|
||||
cvm, ok := ms.vms[computationID]
|
||||
if !ok {
|
||||
defer ms.publishEvent(manager.StopComputationRun.String(), computationID, agent.Failed.String(), json.RawMessage{})
|
||||
return ErrNotFound
|
||||
}
|
||||
if err := cvm.Stop(); err != nil {
|
||||
defer ms.publishEvent(manager.StopComputationRun.String(), computationID, agent.Failed.String(), json.RawMessage{})
|
||||
return err
|
||||
}
|
||||
delete(ms.vms, computationID)
|
||||
@@ -277,7 +226,6 @@ func (ms *managerService) Stop(ctx context.Context, computationID string) error
|
||||
ms.logger.Error("Failed to delete persisted VM state", "error", err)
|
||||
}
|
||||
|
||||
defer ms.publishEvent(manager.StopComputationRun.String(), computationID, agent.Completed.String(), json.RawMessage{})
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -329,30 +277,6 @@ func checkPortisFree(port int) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (ms *managerService) publishEvent(event, cmpID, status string, details json.RawMessage) {
|
||||
ms.eventsChan <- &ClientStreamMessage{
|
||||
Message: &ClientStreamMessage_AgentEvent{
|
||||
AgentEvent: &AgentEvent{
|
||||
EventType: event,
|
||||
ComputationId: cmpID,
|
||||
Status: status,
|
||||
Details: details,
|
||||
Timestamp: timestamppb.Now(),
|
||||
Originator: "manager",
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func computationHash(ac agent.Computation) ([32]byte, error) {
|
||||
jsonData, err := json.Marshal(ac)
|
||||
if err != nil {
|
||||
return [32]byte{}, err
|
||||
}
|
||||
|
||||
return sha3.Sum256(jsonData), nil
|
||||
}
|
||||
|
||||
func decodeRange(input string) (int, int, error) {
|
||||
re := regexp.MustCompile(`(\d+)-(\d+)`)
|
||||
matches := re.FindStringSubmatch(input)
|
||||
@@ -392,7 +316,7 @@ func (ms *managerService) restoreVMs() error {
|
||||
continue
|
||||
}
|
||||
|
||||
cvm := ms.vmFactory(state.VMinfo, ms.eventsLogsSender, state.ID)
|
||||
cvm := ms.vmFactory(state.VMinfo, state.ID)
|
||||
|
||||
if err = cvm.SetProcess(state.PID); err != nil {
|
||||
ms.logger.Warn("Failed to reattach to process", "computation", state.ID, "pid", state.PID, "error", err)
|
||||
@@ -425,33 +349,3 @@ func (ms *managerService) processExists(pid int) bool {
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (ms *managerService) eventsLogsSender(e interface{}) error {
|
||||
switch msg := e.(type) {
|
||||
case *vm.Event:
|
||||
ms.eventsChan <- &ClientStreamMessage{
|
||||
Message: &ClientStreamMessage_AgentEvent{
|
||||
AgentEvent: &AgentEvent{
|
||||
EventType: msg.EventType,
|
||||
Timestamp: msg.Timestamp,
|
||||
ComputationId: msg.ComputationId,
|
||||
Originator: msg.Originator,
|
||||
Status: msg.Status,
|
||||
Details: msg.Details,
|
||||
},
|
||||
},
|
||||
}
|
||||
case *vm.Log:
|
||||
ms.eventsChan <- &ClientStreamMessage{
|
||||
Message: &ClientStreamMessage_AgentLog{
|
||||
AgentLog: &AgentLog{
|
||||
ComputationId: msg.ComputationId,
|
||||
Level: msg.Level,
|
||||
Timestamp: msg.Timestamp,
|
||||
Message: msg.Message,
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
+6
-156
@@ -4,7 +4,6 @@ package manager
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"net"
|
||||
@@ -17,7 +16,6 @@ import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/mock"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/ultravioletrs/cocos/agent"
|
||||
"github.com/ultravioletrs/cocos/manager/qemu"
|
||||
persistenceMocks "github.com/ultravioletrs/cocos/manager/qemu/mocks"
|
||||
"github.com/ultravioletrs/cocos/manager/vm"
|
||||
@@ -29,10 +27,9 @@ func TestNew(t *testing.T) {
|
||||
HostFwdRange: "6000-6100",
|
||||
}
|
||||
logger := slog.Default()
|
||||
eventsChan := make(chan *ClientStreamMessage)
|
||||
vmf := new(mocks.Provider)
|
||||
|
||||
service, err := New(cfg, "", logger, eventsChan, vmf.Execute, "")
|
||||
service, err := New(cfg, "", logger, vmf.Execute, "")
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.NotNil(t, service)
|
||||
@@ -46,82 +43,24 @@ func TestRun(t *testing.T) {
|
||||
vmf.On("Execute", mock.Anything, mock.Anything, mock.Anything).Return(vmMock)
|
||||
tests := []struct {
|
||||
name string
|
||||
req *ComputationRunReq
|
||||
binaryBehavior string
|
||||
vmStartError error
|
||||
expectedError error
|
||||
}{
|
||||
{
|
||||
name: "Successful run",
|
||||
req: &ComputationRunReq{
|
||||
Id: "test-computation",
|
||||
Name: "Test Computation",
|
||||
Algorithm: &Algorithm{
|
||||
Hash: make([]byte, hashLength),
|
||||
},
|
||||
AgentConfig: &AgentConfig{},
|
||||
},
|
||||
name: "Successful run",
|
||||
binaryBehavior: "success",
|
||||
vmStartError: nil,
|
||||
expectedError: nil,
|
||||
},
|
||||
{
|
||||
name: "VM start failure",
|
||||
req: &ComputationRunReq{
|
||||
Id: "test-computation",
|
||||
Name: "Test Computation",
|
||||
Algorithm: &Algorithm{
|
||||
Hash: make([]byte, hashLength),
|
||||
},
|
||||
AgentConfig: &AgentConfig{},
|
||||
},
|
||||
name: "VM start failure",
|
||||
binaryBehavior: "success",
|
||||
vmStartError: assert.AnError,
|
||||
expectedError: assert.AnError,
|
||||
},
|
||||
{
|
||||
name: "Invalid algorithm hash",
|
||||
req: &ComputationRunReq{
|
||||
Id: "test-computation",
|
||||
Name: "Test Computation",
|
||||
Algorithm: &Algorithm{
|
||||
Hash: make([]byte, hashLength-1),
|
||||
},
|
||||
AgentConfig: &AgentConfig{},
|
||||
},
|
||||
binaryBehavior: "success",
|
||||
vmStartError: nil,
|
||||
expectedError: errInvalidHashLength,
|
||||
},
|
||||
{
|
||||
name: "Invalid dataset hash",
|
||||
req: &ComputationRunReq{
|
||||
Id: "test-computation",
|
||||
Name: "Test Computation",
|
||||
Algorithm: &Algorithm{
|
||||
Hash: make([]byte, hashLength),
|
||||
},
|
||||
AgentConfig: &AgentConfig{},
|
||||
Datasets: []*Dataset{
|
||||
{
|
||||
Hash: make([]byte, hashLength-1),
|
||||
},
|
||||
},
|
||||
},
|
||||
binaryBehavior: "success",
|
||||
vmStartError: nil,
|
||||
expectedError: errInvalidHashLength,
|
||||
},
|
||||
{
|
||||
name: "Invalid attestation policy",
|
||||
req: &ComputationRunReq{
|
||||
Id: "test-computation",
|
||||
Name: "Test Computation",
|
||||
Algorithm: &Algorithm{
|
||||
Hash: make([]byte, hashLength),
|
||||
},
|
||||
AgentConfig: &AgentConfig{},
|
||||
},
|
||||
name: "Invalid attestation policy",
|
||||
binaryBehavior: "fail",
|
||||
vmStartError: nil,
|
||||
expectedError: ErrFailedToCreateAttestationPolicy,
|
||||
@@ -149,7 +88,6 @@ func TestRun(t *testing.T) {
|
||||
},
|
||||
}
|
||||
logger := slog.Default()
|
||||
eventsChan := make(chan *ClientStreamMessage, 10)
|
||||
|
||||
tempDir := CreateDummyAttestationPolicyBinary(t, tt.binaryBehavior)
|
||||
defer os.RemoveAll(tempDir)
|
||||
@@ -159,14 +97,13 @@ func TestRun(t *testing.T) {
|
||||
attestationPolicyBinaryPath: tempDir,
|
||||
logger: logger,
|
||||
vms: make(map[string]vm.VM),
|
||||
eventsChan: eventsChan,
|
||||
vmFactory: vmf.Execute,
|
||||
persistence: persistence,
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
port, err := ms.Run(ctx, tt.req)
|
||||
port, _, err := ms.CreateVM(ctx)
|
||||
|
||||
if tt.expectedError != nil {
|
||||
assert.Error(t, err)
|
||||
@@ -179,10 +116,6 @@ func TestRun(t *testing.T) {
|
||||
}
|
||||
|
||||
vmf.AssertExpectations(t)
|
||||
|
||||
for len(eventsChan) > 0 {
|
||||
<-eventsChan
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -226,11 +159,9 @@ func TestStop(t *testing.T) {
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
logger := slog.Default()
|
||||
eventsChan := make(chan *ClientStreamMessage, 10)
|
||||
ms := &managerService{
|
||||
logger: logger,
|
||||
vms: make(map[string]vm.VM),
|
||||
eventsChan: eventsChan,
|
||||
persistence: persistence,
|
||||
}
|
||||
vmMock := new(mocks.VM)
|
||||
@@ -247,7 +178,7 @@ func TestStop(t *testing.T) {
|
||||
ms.vms[tt.computationID] = vmMock
|
||||
}
|
||||
|
||||
err := ms.Stop(context.Background(), tt.computationID)
|
||||
err := ms.RemoveVM(context.Background(), tt.computationID)
|
||||
|
||||
if tt.expectedError != nil {
|
||||
assert.Error(t, err)
|
||||
@@ -256,10 +187,6 @@ func TestStop(t *testing.T) {
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, ms.vms, 0)
|
||||
}
|
||||
|
||||
for len(eventsChan) > 0 {
|
||||
<-eventsChan
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -278,82 +205,6 @@ func TestGetFreePort(t *testing.T) {
|
||||
assert.Greater(t, port, 6000)
|
||||
}
|
||||
|
||||
func TestPublishEvent(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
event string
|
||||
computationID string
|
||||
status string
|
||||
details json.RawMessage
|
||||
}{
|
||||
{
|
||||
name: "Standard event",
|
||||
event: "test-event",
|
||||
computationID: "test-computation",
|
||||
status: "test-status",
|
||||
details: nil,
|
||||
},
|
||||
{
|
||||
name: "Event with details",
|
||||
event: "detailed-event",
|
||||
computationID: "detailed-computation",
|
||||
status: "detailed-status",
|
||||
details: json.RawMessage(`{"key": "value"}`),
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
eventsChan := make(chan *ClientStreamMessage, 1)
|
||||
ms := &managerService{
|
||||
eventsChan: eventsChan,
|
||||
}
|
||||
|
||||
ms.publishEvent(tt.event, tt.computationID, tt.status, tt.details)
|
||||
|
||||
assert.Len(t, eventsChan, 1)
|
||||
event := <-eventsChan
|
||||
assert.Equal(t, tt.event, event.GetAgentEvent().EventType)
|
||||
assert.Equal(t, tt.computationID, event.GetAgentEvent().ComputationId)
|
||||
assert.Equal(t, tt.status, event.GetAgentEvent().Status)
|
||||
assert.Equal(t, "manager", event.GetAgentEvent().Originator)
|
||||
assert.Equal(t, tt.details, json.RawMessage(event.GetAgentEvent().Details))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestComputationHash(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
computation agent.Computation
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "Valid computation",
|
||||
computation: agent.Computation{
|
||||
ID: "test-id",
|
||||
Name: "test-name",
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
hash, err := computationHash(tt.computation)
|
||||
if tt.wantErr {
|
||||
assert.Error(t, err)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
assert.NotEmpty(t, hash)
|
||||
|
||||
hash2, _ := computationHash(tt.computation)
|
||||
assert.Equal(t, hash, hash2)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDecodeRange(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
@@ -393,7 +244,6 @@ func TestRestoreVMs(t *testing.T) {
|
||||
ms := &managerService{
|
||||
persistence: mockPersistence,
|
||||
vms: make(map[string]vm.VM),
|
||||
eventsChan: make(chan *ClientStreamMessage, 10),
|
||||
vmFactory: vmf.Execute,
|
||||
logger: mglog.NewMock(),
|
||||
}
|
||||
|
||||
@@ -1,131 +0,0 @@
|
||||
// Copyright (c) Ultraviolet
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
package manager_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"crypto/x509"
|
||||
"encoding/pem"
|
||||
"log/slog"
|
||||
"net"
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
mglog "github.com/absmach/magistrala/logger"
|
||||
"github.com/ultravioletrs/cocos/manager"
|
||||
managergrpc "github.com/ultravioletrs/cocos/manager/api/grpc"
|
||||
"golang.org/x/crypto/sha3"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/credentials"
|
||||
"google.golang.org/grpc/test/bufconn"
|
||||
)
|
||||
|
||||
const (
|
||||
bufSize = 1024 * 1024
|
||||
keyBitSize = 4096
|
||||
)
|
||||
|
||||
var (
|
||||
lis *bufconn.Listener
|
||||
algoPath = "../test/manual/algo/lin_reg.py"
|
||||
dataPath = "../test/manual/data/iris.csv"
|
||||
attestedTLS = false
|
||||
)
|
||||
|
||||
type svc struct {
|
||||
logger *slog.Logger
|
||||
t *testing.T
|
||||
}
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
logger := mglog.NewMock()
|
||||
|
||||
lis = bufconn.Listen(bufSize)
|
||||
s := grpc.NewServer()
|
||||
|
||||
manager.RegisterManagerServiceServer(s, managergrpc.NewServer(make(chan *manager.ClientStreamMessage, 1), &svc{logger: logger}))
|
||||
go func() {
|
||||
if err := s.Serve(lis); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}()
|
||||
|
||||
code := m.Run()
|
||||
|
||||
s.Stop()
|
||||
lis.Close()
|
||||
os.Exit(code)
|
||||
}
|
||||
|
||||
func bufDialer(context.Context, string) (net.Conn, error) {
|
||||
return lis.Dial()
|
||||
}
|
||||
|
||||
func (s *svc) Run(ctx context.Context, ipAddress string, sendMessage managergrpc.SendFunc, authInfo credentials.AuthInfo) {
|
||||
privKey, err := rsa.GenerateKey(rand.Reader, keyBitSize)
|
||||
if err != nil {
|
||||
s.t.Fatalf("Error generating public key: %v", err)
|
||||
}
|
||||
|
||||
pubKey, err := x509.MarshalPKIXPublicKey(&privKey.PublicKey)
|
||||
if err != nil {
|
||||
s.t.Fatalf("Error marshalling public key: %v", err)
|
||||
}
|
||||
|
||||
pubPemBytes := pem.EncodeToMemory(&pem.Block{
|
||||
Type: "PUBLIC KEY",
|
||||
Bytes: pubKey,
|
||||
})
|
||||
|
||||
go func() {
|
||||
time.Sleep(time.Millisecond * 100)
|
||||
if err := sendMessage(&manager.ServerStreamMessage{
|
||||
Message: &manager.ServerStreamMessage_TerminateReq{
|
||||
TerminateReq: &manager.Terminate{Message: "test terminate"},
|
||||
},
|
||||
}); err != nil {
|
||||
s.t.Fatalf("failed to send terminate request: %s", err)
|
||||
}
|
||||
}()
|
||||
|
||||
go func() {
|
||||
time.Sleep(time.Millisecond * 100)
|
||||
algo, err := os.ReadFile(algoPath)
|
||||
if err != nil {
|
||||
s.t.Fatalf("failed to read algorithm file: %s", err)
|
||||
return
|
||||
}
|
||||
data, err := os.ReadFile(dataPath)
|
||||
if err != nil {
|
||||
s.t.Fatalf("failed to read data file: %s", err)
|
||||
return
|
||||
}
|
||||
|
||||
pubPem, _ := pem.Decode(pubPemBytes)
|
||||
algoHash := sha3.Sum256(algo)
|
||||
dataHash := sha3.Sum256(data)
|
||||
|
||||
if err := sendMessage(&manager.ServerStreamMessage{
|
||||
Message: &manager.ServerStreamMessage_RunReq{
|
||||
RunReq: &manager.ComputationRunReq{
|
||||
Id: "1",
|
||||
Name: "sample computation",
|
||||
Description: "sample description",
|
||||
Datasets: []*manager.Dataset{{Hash: dataHash[:], UserKey: pubPem.Bytes}},
|
||||
Algorithm: &manager.Algorithm{Hash: algoHash[:], UserKey: pubPem.Bytes},
|
||||
ResultConsumers: []*manager.ResultConsumer{{UserKey: pubPem.Bytes}},
|
||||
AgentConfig: &manager.AgentConfig{
|
||||
Port: "7002",
|
||||
LogLevel: "debug",
|
||||
AttestedTls: attestedTLS,
|
||||
},
|
||||
},
|
||||
},
|
||||
}); err != nil {
|
||||
s.t.Fatalf("failed to send run request: %s", err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
@@ -21,18 +21,18 @@ func New(svc manager.Service, tracer trace.Tracer) manager.Service {
|
||||
return &tracingMiddleware{tracer, svc}
|
||||
}
|
||||
|
||||
func (tm *tracingMiddleware) Run(ctx context.Context, mc *manager.ComputationRunReq) (string, error) {
|
||||
func (tm *tracingMiddleware) CreateVM(ctx context.Context) (string, string, error) {
|
||||
ctx, span := tm.tracer.Start(ctx, "run")
|
||||
defer span.End()
|
||||
|
||||
return tm.svc.Run(ctx, mc)
|
||||
return tm.svc.CreateVM(ctx)
|
||||
}
|
||||
|
||||
func (tm *tracingMiddleware) Stop(ctx context.Context, computationID string) error {
|
||||
func (tm *tracingMiddleware) RemoveVM(ctx context.Context, id string) error {
|
||||
ctx, span := tm.tracer.Start(ctx, "stop")
|
||||
defer span.End()
|
||||
|
||||
return tm.svc.Stop(ctx, computationID)
|
||||
return tm.svc.RemoveVM(ctx, id)
|
||||
}
|
||||
|
||||
func (tm *tracingMiddleware) FetchAttestationPolicy(ctx context.Context, computationId string) ([]byte, error) {
|
||||
@@ -42,10 +42,6 @@ func (tm *tracingMiddleware) FetchAttestationPolicy(ctx context.Context, computa
|
||||
return tm.svc.FetchAttestationPolicy(ctx, computationId)
|
||||
}
|
||||
|
||||
func (tm *tracingMiddleware) ReportBrokenConnection(addr string) {
|
||||
tm.svc.ReportBrokenConnection(addr)
|
||||
}
|
||||
|
||||
func (tm *tracingMiddleware) ReturnSVMInfo(ctx context.Context) (string, int, string, string) {
|
||||
_, span := tm.tracer.Start(ctx, "return_svm_info")
|
||||
defer span.End()
|
||||
|
||||
@@ -1,108 +0,0 @@
|
||||
// Copyright (c) Ultraviolet
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
package vm
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
"log/slog"
|
||||
"strings"
|
||||
|
||||
pkgmanager "github.com/ultravioletrs/cocos/pkg/manager"
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
)
|
||||
|
||||
var (
|
||||
_ io.Writer = &Stdout{}
|
||||
_ io.Writer = &Stderr{}
|
||||
)
|
||||
|
||||
const bufSize = 1024
|
||||
|
||||
type Stdout struct {
|
||||
EventSender EventSender
|
||||
ComputationId string
|
||||
}
|
||||
|
||||
// Write implements io.Writer.
|
||||
func (s *Stdout) Write(p []byte) (n int, err error) {
|
||||
inBuf := bytes.NewBuffer(p)
|
||||
|
||||
buf := make([]byte, bufSize)
|
||||
|
||||
for {
|
||||
n, err := inBuf.Read(buf)
|
||||
if err != nil {
|
||||
if err == io.EOF {
|
||||
break
|
||||
}
|
||||
return len(p) - inBuf.Len(), err
|
||||
}
|
||||
|
||||
if err := sendLog(s.EventSender, s.ComputationId, string(buf[:n]), slog.LevelDebug.String()); err != nil {
|
||||
return len(p) - inBuf.Len(), err
|
||||
}
|
||||
}
|
||||
|
||||
return len(p), nil
|
||||
}
|
||||
|
||||
type Stderr struct {
|
||||
EventSender EventSender
|
||||
ComputationId string
|
||||
StateMachine StateMachine
|
||||
}
|
||||
|
||||
// Write implements io.Writer.
|
||||
func (s *Stderr) Write(p []byte) (n int, err error) {
|
||||
inBuf := bytes.NewBuffer(p)
|
||||
|
||||
buf := make([]byte, bufSize)
|
||||
|
||||
for {
|
||||
n, err := inBuf.Read(buf)
|
||||
if err != nil {
|
||||
if err == io.EOF {
|
||||
break
|
||||
}
|
||||
return len(p) - inBuf.Len(), err
|
||||
}
|
||||
|
||||
if err := sendLog(s.EventSender, s.ComputationId, string(buf[:n]), ""); err != nil {
|
||||
return len(p) - inBuf.Len(), err
|
||||
}
|
||||
}
|
||||
|
||||
eventMsg := &Event{
|
||||
ComputationId: s.ComputationId,
|
||||
EventType: s.StateMachine.State(),
|
||||
Timestamp: timestamppb.Now(),
|
||||
Originator: "manager",
|
||||
Status: pkgmanager.Warning.String(),
|
||||
}
|
||||
|
||||
return len(p), s.EventSender(eventMsg)
|
||||
}
|
||||
|
||||
func sendLog(eventSender EventSender, computationID, message, level string) error {
|
||||
if len(message) < 3 {
|
||||
return nil
|
||||
}
|
||||
|
||||
if level == "" {
|
||||
if strings.Contains(strings.ToLower(message), "warning") {
|
||||
level = slog.LevelWarn.String()
|
||||
} else {
|
||||
level = slog.LevelError.String()
|
||||
}
|
||||
}
|
||||
|
||||
msg := Log{
|
||||
Message: message,
|
||||
ComputationId: computationID,
|
||||
Level: level,
|
||||
Timestamp: timestamppb.Now(),
|
||||
}
|
||||
|
||||
return eventSender(&msg)
|
||||
}
|
||||
@@ -1,180 +0,0 @@
|
||||
// Copyright (c) Ultraviolet
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
package vm
|
||||
|
||||
import (
|
||||
"log/slog"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
pkgmanager "github.com/ultravioletrs/cocos/pkg/manager"
|
||||
)
|
||||
|
||||
func TestStdoutWrite(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
expectedWrites int
|
||||
}{
|
||||
{
|
||||
name: "Single write within buffer size",
|
||||
input: "Hello, World!",
|
||||
expectedWrites: 1,
|
||||
},
|
||||
{
|
||||
name: "Multiple writes within buffer size",
|
||||
input: "This is a longer message that will be split into multiple writes.",
|
||||
expectedWrites: 1,
|
||||
},
|
||||
{
|
||||
name: "Large write exceeding buffer size",
|
||||
input: string(make([]byte, bufSize*2+3)),
|
||||
expectedWrites: 3,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
eventLogChan := make(chan interface{}, 10)
|
||||
s := &Stdout{
|
||||
EventSender: func(event interface{}) error {
|
||||
eventLogChan <- event
|
||||
return nil
|
||||
},
|
||||
ComputationId: "test-computation",
|
||||
}
|
||||
|
||||
n, err := s.Write([]byte(tt.input))
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, len(tt.input), n)
|
||||
|
||||
var receivedWrites int
|
||||
for i := 0; i < tt.expectedWrites; i++ {
|
||||
select {
|
||||
case msg := <-eventLogChan:
|
||||
receivedWrites++
|
||||
agentLog := msg.(*Log)
|
||||
assert.NotNil(t, agentLog)
|
||||
assert.Equal(t, "test-computation", agentLog.ComputationId)
|
||||
assert.Equal(t, slog.LevelDebug.String(), agentLog.Level)
|
||||
assert.NotEmpty(t, agentLog.Message)
|
||||
assert.NotNil(t, agentLog.Timestamp)
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("Timed out waiting for log message")
|
||||
}
|
||||
}
|
||||
|
||||
assert.Equal(t, tt.expectedWrites, receivedWrites)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestStderrWrite(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
expectedWrites int
|
||||
}{
|
||||
{
|
||||
name: "Single write within buffer size",
|
||||
input: "Error: Something went wrong",
|
||||
expectedWrites: 1,
|
||||
},
|
||||
{
|
||||
name: "Multiple writes within buffer size",
|
||||
input: "This is a longer error message that will be split into multiple writes.",
|
||||
expectedWrites: 1,
|
||||
},
|
||||
{
|
||||
name: "Large write exceeding buffer size",
|
||||
input: string(make([]byte, bufSize*2)),
|
||||
expectedWrites: 3,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
eventLogChan := make(chan interface{}, 10)
|
||||
s := &Stderr{
|
||||
EventSender: func(event interface{}) error {
|
||||
eventLogChan <- event
|
||||
return nil
|
||||
},
|
||||
ComputationId: "test-computation",
|
||||
StateMachine: NewStateMachine(),
|
||||
}
|
||||
|
||||
err := s.StateMachine.Transition(pkgmanager.VmRunning)
|
||||
assert.NoError(t, err)
|
||||
|
||||
n, err := s.Write([]byte(tt.input))
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, len(tt.input), n)
|
||||
|
||||
var receivedWrites int
|
||||
for i := 0; i < tt.expectedWrites; i++ {
|
||||
select {
|
||||
case msg := <-eventLogChan:
|
||||
receivedWrites++
|
||||
switch logEv := msg.(type) {
|
||||
case *Log:
|
||||
assert.NotNil(t, logEv)
|
||||
assert.Equal(t, "test-computation", logEv.ComputationId)
|
||||
assert.Equal(t, slog.LevelError.String(), logEv.Level)
|
||||
assert.NotEmpty(t, logEv.Message)
|
||||
assert.NotNil(t, logEv.Timestamp)
|
||||
case *Event:
|
||||
assert.NotNil(t, logEv)
|
||||
assert.Equal(t, "test-computation", logEv.ComputationId)
|
||||
assert.Equal(t, pkgmanager.VmRunning.String(), logEv.EventType)
|
||||
assert.Equal(t, pkgmanager.Warning.String(), logEv.Status)
|
||||
assert.NotNil(t, logEv.Timestamp)
|
||||
}
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("Timed out waiting for log message")
|
||||
}
|
||||
}
|
||||
|
||||
assert.Equal(t, tt.expectedWrites, receivedWrites)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestStdoutWriteErrorHandling(t *testing.T) {
|
||||
eventLogChan := make(chan interface{}, 10)
|
||||
s := &Stdout{
|
||||
EventSender: func(event interface{}) error {
|
||||
eventLogChan <- event
|
||||
return assert.AnError
|
||||
},
|
||||
ComputationId: "test-computation",
|
||||
}
|
||||
|
||||
message := []byte("This should fail")
|
||||
n, err := s.Write(message)
|
||||
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, len(message), n)
|
||||
assert.Equal(t, assert.AnError, err)
|
||||
}
|
||||
|
||||
func TestStderrWriteErrorHandling(t *testing.T) {
|
||||
eventLogChan := make(chan interface{}, 10)
|
||||
s := &Stderr{
|
||||
EventSender: func(event interface{}) error {
|
||||
eventLogChan <- event
|
||||
return assert.AnError
|
||||
},
|
||||
ComputationId: "test-computation",
|
||||
}
|
||||
|
||||
message := []byte("This should fail")
|
||||
n, err := s.Write(message)
|
||||
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, len(message), n)
|
||||
assert.Equal(t, assert.AnError, err)
|
||||
}
|
||||
@@ -23,17 +23,17 @@ func (_m *Provider) EXPECT() *Provider_Expecter {
|
||||
return &Provider_Expecter{mock: &_m.Mock}
|
||||
}
|
||||
|
||||
// Execute provides a mock function with given fields: config, eventSender, computationId
|
||||
func (_m *Provider) Execute(config interface{}, eventSender vm.EventSender, computationId string) vm.VM {
|
||||
ret := _m.Called(config, eventSender, computationId)
|
||||
// Execute provides a mock function with given fields: config, computationId
|
||||
func (_m *Provider) Execute(config interface{}, computationId string) vm.VM {
|
||||
ret := _m.Called(config, computationId)
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for Execute")
|
||||
}
|
||||
|
||||
var r0 vm.VM
|
||||
if rf, ok := ret.Get(0).(func(interface{}, vm.EventSender, string) vm.VM); ok {
|
||||
r0 = rf(config, eventSender, computationId)
|
||||
if rf, ok := ret.Get(0).(func(interface{}, string) vm.VM); ok {
|
||||
r0 = rf(config, computationId)
|
||||
} else {
|
||||
if ret.Get(0) != nil {
|
||||
r0 = ret.Get(0).(vm.VM)
|
||||
@@ -50,15 +50,14 @@ type Provider_Execute_Call struct {
|
||||
|
||||
// Execute is a helper method to define mock.On call
|
||||
// - config interface{}
|
||||
// - eventSender vm.EventSender
|
||||
// - computationId string
|
||||
func (_e *Provider_Expecter) Execute(config interface{}, eventSender interface{}, computationId interface{}) *Provider_Execute_Call {
|
||||
return &Provider_Execute_Call{Call: _e.mock.On("Execute", config, eventSender, computationId)}
|
||||
func (_e *Provider_Expecter) Execute(config interface{}, computationId interface{}) *Provider_Execute_Call {
|
||||
return &Provider_Execute_Call{Call: _e.mock.On("Execute", config, computationId)}
|
||||
}
|
||||
|
||||
func (_c *Provider_Execute_Call) Run(run func(config interface{}, eventSender vm.EventSender, computationId string)) *Provider_Execute_Call {
|
||||
func (_c *Provider_Execute_Call) Run(run func(config interface{}, computationId string)) *Provider_Execute_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
run(args[0].(interface{}), args[1].(vm.EventSender), args[2].(string))
|
||||
run(args[0].(interface{}), args[1].(string))
|
||||
})
|
||||
return _c
|
||||
}
|
||||
@@ -68,7 +67,7 @@ func (_c *Provider_Execute_Call) Return(_a0 vm.VM) *Provider_Execute_Call {
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *Provider_Execute_Call) RunAndReturn(run func(interface{}, vm.EventSender, string) vm.VM) *Provider_Execute_Call {
|
||||
func (_c *Provider_Execute_Call) RunAndReturn(run func(interface{}, string) vm.VM) *Provider_Execute_Call {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
+1
-3
@@ -21,7 +21,7 @@ type VM interface {
|
||||
GetConfig() interface{}
|
||||
}
|
||||
|
||||
type Provider func(config interface{}, eventSender EventSender, computationId string) VM
|
||||
type Provider func(config interface{}, computationId string) VM
|
||||
|
||||
type Event struct {
|
||||
EventType string
|
||||
@@ -38,5 +38,3 @@ type Log struct {
|
||||
Level string
|
||||
Timestamp *timestamppb.Timestamp
|
||||
}
|
||||
|
||||
type EventSender func(event interface{}) error
|
||||
|
||||
@@ -68,6 +68,10 @@ type AgentClientConfig struct {
|
||||
AttestedTLS bool `env:"ATTESTED_TLS" envDefault:"false"`
|
||||
}
|
||||
|
||||
type ManagerClientConfig struct {
|
||||
BaseConfig
|
||||
}
|
||||
|
||||
type CVMClientConfig struct {
|
||||
BaseConfig
|
||||
}
|
||||
|
||||
@@ -8,7 +8,7 @@ import (
|
||||
)
|
||||
|
||||
// NewManagerClient creates new manager gRPC client instance.
|
||||
func NewManagerClient(cfg grpc.CVMClientConfig) (grpc.Client, manager.ManagerServiceClient, error) {
|
||||
func NewManagerClient(cfg grpc.ManagerClientConfig) (grpc.Client, manager.ManagerServiceClient, error) {
|
||||
client, err := grpc.NewClient(cfg)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
|
||||
@@ -13,12 +13,12 @@ import (
|
||||
func TestNewManagerClient(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
cfg grpc.CVMClientConfig
|
||||
cfg grpc.ManagerClientConfig
|
||||
err error
|
||||
}{
|
||||
{
|
||||
name: "Valid config",
|
||||
cfg: grpc.CVMClientConfig{
|
||||
cfg: grpc.ManagerClientConfig{
|
||||
BaseConfig: grpc.BaseConfig{
|
||||
URL: "localhost:7001",
|
||||
},
|
||||
|
||||
Reference in New Issue
Block a user