mirror of
https://github.com/ultravioletrs/cocos.git
synced 2026-06-23 04:10:25 +00:00
add manager tests (#273)
Signed-off-by: Sammy Oina <sammyoina@gmail.com>
This commit is contained in:
committed by
GitHub
parent
f6b69d65df
commit
5e01ecdab7
@@ -0,0 +1,168 @@
|
||||
// Copyright (c) Ultraviolet
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
package manager
|
||||
|
||||
import (
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
mglog "github.com/absmach/magistrala/logger"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/mock"
|
||||
"github.com/ultravioletrs/cocos/manager/qemu"
|
||||
"github.com/ultravioletrs/cocos/manager/vm"
|
||||
"github.com/ultravioletrs/cocos/pkg/manager"
|
||||
"google.golang.org/protobuf/proto"
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
)
|
||||
|
||||
type MockConn struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
func (m *MockConn) Read(b []byte) (n int, err error) {
|
||||
args := m.Called(b)
|
||||
return args.Int(0), args.Error(1)
|
||||
}
|
||||
|
||||
func (m *MockConn) Write(b []byte) (n int, err error) {
|
||||
args := m.Called(b)
|
||||
return args.Int(0), args.Error(1)
|
||||
}
|
||||
|
||||
func (m *MockConn) Close() error {
|
||||
args := m.Called()
|
||||
return args.Error(0)
|
||||
}
|
||||
|
||||
func (m *MockConn) LocalAddr() net.Addr {
|
||||
args := m.Called()
|
||||
return args.Get(0).(net.Addr)
|
||||
}
|
||||
|
||||
func (m *MockConn) RemoteAddr() net.Addr {
|
||||
args := m.Called()
|
||||
return args.Get(0).(net.Addr)
|
||||
}
|
||||
|
||||
func (m *MockConn) SetDeadline(t time.Time) error {
|
||||
args := m.Called(t)
|
||||
return args.Error(0)
|
||||
}
|
||||
|
||||
func (m *MockConn) SetReadDeadline(t time.Time) error {
|
||||
args := m.Called(t)
|
||||
return args.Error(0)
|
||||
}
|
||||
|
||||
func (m *MockConn) SetWriteDeadline(t time.Time) error {
|
||||
args := m.Called(t)
|
||||
return args.Error(0)
|
||||
}
|
||||
|
||||
type MockAddr struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
func (m *MockAddr) Network() string {
|
||||
args := m.Called()
|
||||
return args.String(0)
|
||||
}
|
||||
|
||||
func (m *MockAddr) String() string {
|
||||
args := m.Called()
|
||||
return args.String(0)
|
||||
}
|
||||
|
||||
func TestComputationIDFromAddress(t *testing.T) {
|
||||
ms := &managerService{
|
||||
vms: map[string]vm.VM{
|
||||
"comp1": qemu.NewVM(qemu.Config{VSockConfig: qemu.VSockConfig{GuestCID: 3}}, make(chan *manager.ClientStreamMessage), "comp1"),
|
||||
"comp2": qemu.NewVM(qemu.Config{VSockConfig: qemu.VSockConfig{GuestCID: 5}}, make(chan *manager.ClientStreamMessage), "comp2"),
|
||||
},
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
address string
|
||||
want string
|
||||
wantErr bool
|
||||
}{
|
||||
{"Valid address", "vm(3)", "comp1", false},
|
||||
{"Invalid address", "invalid", "", true},
|
||||
{"Non-existent CID", "vm(10)", "", true},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := ms.computationIDFromAddress(tt.address)
|
||||
if tt.wantErr {
|
||||
assert.Error(t, err)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, tt.want, got)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleConnection(t *testing.T) {
|
||||
ms := &managerService{
|
||||
vms: map[string]vm.VM{
|
||||
"comp1": qemu.NewVM(qemu.Config{VSockConfig: qemu.VSockConfig{GuestCID: 3}}, make(chan *manager.ClientStreamMessage), "comp1"),
|
||||
},
|
||||
eventsChan: make(chan *manager.ClientStreamMessage, 1),
|
||||
logger: mglog.NewMock(),
|
||||
}
|
||||
|
||||
mockConn := new(MockConn)
|
||||
mockAddr := new(MockAddr)
|
||||
mockConn.On("RemoteAddr").Return(mockAddr)
|
||||
mockConn.On("Close").Return(nil)
|
||||
mockAddr.On("String").Return("vm(3)")
|
||||
|
||||
msg := &manager.ClientStreamMessage{
|
||||
Message: &manager.ClientStreamMessage_AgentEvent{
|
||||
AgentEvent: &manager.AgentEvent{
|
||||
EventType: manager.VmRunning.String(),
|
||||
ComputationId: "comp1",
|
||||
Status: manager.VmRunning.String(),
|
||||
Timestamp: timestamppb.Now(),
|
||||
Originator: "agent",
|
||||
},
|
||||
},
|
||||
}
|
||||
msgBytes, _ := proto.Marshal(msg)
|
||||
|
||||
mockConn.On("Read", mock.Anything).Return(len(msgBytes), nil).Run(func(args mock.Arguments) {
|
||||
copy(args.Get(0).([]byte), msgBytes)
|
||||
}).Once()
|
||||
|
||||
mockConn.On("Read", mock.Anything).Return(0, net.ErrClosed)
|
||||
|
||||
go ms.handleConnection(mockConn)
|
||||
|
||||
receivedMsg := <-ms.eventsChan
|
||||
assert.Equal(t, msg.GetAgentEvent().EventType, receivedMsg.GetAgentEvent().EventType)
|
||||
assert.Equal(t, msg.GetAgentEvent().ComputationId, receivedMsg.GetAgentEvent().ComputationId)
|
||||
|
||||
mockConn.AssertExpectations(t)
|
||||
}
|
||||
|
||||
func TestReportBrokenConnection(t *testing.T) {
|
||||
ms := &managerService{
|
||||
eventsChan: make(chan *manager.ClientStreamMessage, 1),
|
||||
}
|
||||
|
||||
ms.reportBrokenConnection("comp1")
|
||||
|
||||
select {
|
||||
case msg := <-ms.eventsChan:
|
||||
assert.Equal(t, "comp1", msg.GetAgentEvent().ComputationId)
|
||||
assert.Equal(t, manager.Disconnected.String(), msg.GetAgentEvent().Status)
|
||||
assert.Equal(t, "manager", msg.GetAgentEvent().Originator)
|
||||
default:
|
||||
t.Error("Expected message in eventsChan, but none received")
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,181 @@
|
||||
// Copyright (c) Ultraviolet
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
package grpc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
mglog "github.com/absmach/magistrala/logger"
|
||||
"github.com/absmach/magistrala/pkg/errors"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/mock"
|
||||
"github.com/ultravioletrs/cocos/manager/mocks"
|
||||
pkgmanager "github.com/ultravioletrs/cocos/pkg/manager"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/protobuf/proto"
|
||||
)
|
||||
|
||||
type mockStream struct {
|
||||
mock.Mock
|
||||
grpc.ClientStream
|
||||
}
|
||||
|
||||
func (m *mockStream) Recv() (*pkgmanager.ServerStreamMessage, error) {
|
||||
args := m.Called()
|
||||
return args.Get(0).(*pkgmanager.ServerStreamMessage), args.Error(1)
|
||||
}
|
||||
|
||||
func (m *mockStream) Send(msg *pkgmanager.ClientStreamMessage) error {
|
||||
args := m.Called(msg)
|
||||
return args.Error(0)
|
||||
}
|
||||
|
||||
func TestManagerClient_Process(t *testing.T) {
|
||||
mockStream := new(mockStream)
|
||||
mockSvc := new(mocks.Service)
|
||||
messageQueue := make(chan *pkgmanager.ClientStreamMessage, 10)
|
||||
logger := mglog.NewMock()
|
||||
|
||||
client := NewClient(mockStream, mockSvc, messageQueue, logger)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
||||
defer cancel()
|
||||
|
||||
mockStream.On("Recv").Return(&pkgmanager.ServerStreamMessage{Message: &pkgmanager.ServerStreamMessage_StopComputation{StopComputation: &pkgmanager.StopComputation{}}}, nil).Maybe()
|
||||
mockStream.On("Send", mock.Anything).Return(nil).Maybe()
|
||||
|
||||
mockSvc.On("Stop", mock.Anything, mock.Anything).Return(nil).Maybe()
|
||||
|
||||
err := client.Process(ctx, cancel)
|
||||
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "context deadline exceeded")
|
||||
}
|
||||
|
||||
func TestManagerClient_handleRunReqChunks(t *testing.T) {
|
||||
mockStream := new(mockStream)
|
||||
mockSvc := new(mocks.Service)
|
||||
messageQueue := make(chan *pkgmanager.ClientStreamMessage, 10)
|
||||
logger := mglog.NewMock()
|
||||
|
||||
client := NewClient(mockStream, mockSvc, messageQueue, logger)
|
||||
|
||||
runReq := &pkgmanager.ComputationRunReq{
|
||||
Id: "test-id",
|
||||
}
|
||||
runReqBytes, _ := proto.Marshal(runReq)
|
||||
|
||||
chunk1 := &pkgmanager.ServerStreamMessage_RunReqChunks{
|
||||
RunReqChunks: &pkgmanager.RunReqChunks{
|
||||
Id: "chunk-1",
|
||||
Data: runReqBytes[:len(runReqBytes)/2],
|
||||
IsLast: false,
|
||||
},
|
||||
}
|
||||
chunk2 := &pkgmanager.ServerStreamMessage_RunReqChunks{
|
||||
RunReqChunks: &pkgmanager.RunReqChunks{
|
||||
Id: "chunk-1",
|
||||
Data: runReqBytes[len(runReqBytes)/2:],
|
||||
IsLast: true,
|
||||
},
|
||||
}
|
||||
|
||||
mockSvc.On("Run", mock.Anything, mock.AnythingOfType("*manager.ComputationRunReq")).Return("8080", nil)
|
||||
|
||||
err := client.handleRunReqChunks(context.Background(), chunk1)
|
||||
assert.NoError(t, err)
|
||||
|
||||
err = client.handleRunReqChunks(context.Background(), chunk2)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Wait for the goroutine to finish
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
mockSvc.AssertExpectations(t)
|
||||
assert.Len(t, messageQueue, 1)
|
||||
|
||||
msg := <-messageQueue
|
||||
runRes, ok := msg.Message.(*pkgmanager.ClientStreamMessage_RunRes)
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, "8080", runRes.RunRes.AgentPort)
|
||||
assert.Equal(t, "test-id", runRes.RunRes.ComputationId)
|
||||
}
|
||||
|
||||
func TestManagerClient_handleTerminateReq(t *testing.T) {
|
||||
client := ManagerClient{}
|
||||
|
||||
terminateReq := &pkgmanager.ServerStreamMessage_TerminateReq{
|
||||
TerminateReq: &pkgmanager.Terminate{
|
||||
Message: "Test termination",
|
||||
},
|
||||
}
|
||||
|
||||
err := client.handleTerminateReq(terminateReq)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "Test termination")
|
||||
assert.True(t, errors.Contains(err, errTerminationFromServer))
|
||||
}
|
||||
|
||||
func TestManagerClient_handleStopComputation(t *testing.T) {
|
||||
mockStream := new(mockStream)
|
||||
mockSvc := new(mocks.Service)
|
||||
messageQueue := make(chan *pkgmanager.ClientStreamMessage, 10)
|
||||
logger := mglog.NewMock()
|
||||
|
||||
client := NewClient(mockStream, mockSvc, messageQueue, logger)
|
||||
|
||||
stopReq := &pkgmanager.ServerStreamMessage_StopComputation{
|
||||
StopComputation: &pkgmanager.StopComputation{
|
||||
ComputationId: "test-comp-id",
|
||||
},
|
||||
}
|
||||
|
||||
mockSvc.On("Stop", mock.Anything, "test-comp-id").Return(nil)
|
||||
|
||||
client.handleStopComputation(context.Background(), stopReq)
|
||||
|
||||
// Wait for the goroutine to finish
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
mockSvc.AssertExpectations(t)
|
||||
assert.Len(t, messageQueue, 1)
|
||||
|
||||
msg := <-messageQueue
|
||||
stopRes, ok := msg.Message.(*pkgmanager.ClientStreamMessage_StopComputationRes)
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, "test-comp-id", stopRes.StopComputationRes.ComputationId)
|
||||
assert.Empty(t, stopRes.StopComputationRes.Message)
|
||||
}
|
||||
|
||||
func TestManagerClient_handleBackendInfoReq(t *testing.T) {
|
||||
mockStream := new(mockStream)
|
||||
mockSvc := new(mocks.Service)
|
||||
messageQueue := make(chan *pkgmanager.ClientStreamMessage, 10)
|
||||
logger := mglog.NewMock()
|
||||
|
||||
client := NewClient(mockStream, mockSvc, messageQueue, logger)
|
||||
|
||||
infoReq := &pkgmanager.ServerStreamMessage_BackendInfoReq{
|
||||
BackendInfoReq: &pkgmanager.BackendInfoReq{
|
||||
Id: "test-info-id",
|
||||
},
|
||||
}
|
||||
|
||||
mockSvc.On("FetchBackendInfo").Return([]byte("test-backend-info"), nil)
|
||||
|
||||
client.handleBackendInfoReq(infoReq)
|
||||
|
||||
// Wait for the goroutine to finish
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
mockSvc.AssertExpectations(t)
|
||||
assert.Len(t, messageQueue, 1)
|
||||
|
||||
msg := <-messageQueue
|
||||
infoRes, ok := msg.Message.(*pkgmanager.ClientStreamMessage_BackendInfo)
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, "test-info-id", infoRes.BackendInfo.Id)
|
||||
assert.Equal(t, []byte("test-backend-info"), infoRes.BackendInfo.Info)
|
||||
}
|
||||
@@ -56,22 +56,31 @@ func (s *grpcServer) Process(stream manager.ManagerService_ProcessServer) error
|
||||
|
||||
eg.Go(func() error {
|
||||
for {
|
||||
req, err := stream.Recv()
|
||||
if err != nil {
|
||||
return err
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
default:
|
||||
req, err := stream.Recv()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
s.incoming <- req
|
||||
}
|
||||
|
||||
s.incoming <- req
|
||||
}
|
||||
})
|
||||
|
||||
eg.Go(func() error {
|
||||
sendMessage := func(msg *manager.ServerStreamMessage) error {
|
||||
switch m := msg.Message.(type) {
|
||||
case *manager.ServerStreamMessage_RunReq:
|
||||
return s.sendRunReqInChunks(stream, m.RunReq)
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
default:
|
||||
return stream.Send(msg)
|
||||
switch m := msg.Message.(type) {
|
||||
case *manager.ServerStreamMessage_RunReq:
|
||||
return s.sendRunReqInChunks(stream, m.RunReq)
|
||||
default:
|
||||
return stream.Send(msg)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -0,0 +1,232 @@
|
||||
// Copyright (c) Ultraviolet
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
package grpc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/absmach/magistrala/pkg/errors"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/mock"
|
||||
"github.com/ultravioletrs/cocos/pkg/manager"
|
||||
"google.golang.org/grpc/credentials"
|
||||
"google.golang.org/grpc/peer"
|
||||
)
|
||||
|
||||
type mockServerStream struct {
|
||||
mock.Mock
|
||||
manager.ManagerService_ProcessServer
|
||||
}
|
||||
|
||||
func (m *mockServerStream) Send(msg *manager.ServerStreamMessage) error {
|
||||
args := m.Called(msg)
|
||||
return args.Error(0)
|
||||
}
|
||||
|
||||
func (m *mockServerStream) Recv() (*manager.ClientStreamMessage, error) {
|
||||
args := m.Called()
|
||||
return args.Get(0).(*manager.ClientStreamMessage), args.Error(1)
|
||||
}
|
||||
|
||||
func (m *mockServerStream) Context() context.Context {
|
||||
args := m.Called()
|
||||
return args.Get(0).(context.Context)
|
||||
}
|
||||
|
||||
type mockService struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
func (m *mockService) Run(ctx context.Context, ipAddress string, sendMessage SendFunc, authInfo credentials.AuthInfo) {
|
||||
m.Called(ctx, ipAddress, sendMessage, authInfo)
|
||||
}
|
||||
|
||||
func TestNewServer(t *testing.T) {
|
||||
incoming := make(chan *manager.ClientStreamMessage)
|
||||
mockSvc := new(mockService)
|
||||
|
||||
server := NewServer(incoming, mockSvc)
|
||||
|
||||
assert.NotNil(t, server)
|
||||
assert.IsType(t, &grpcServer{}, server)
|
||||
}
|
||||
|
||||
func TestGrpcServer_Process(t *testing.T) {
|
||||
incoming := make(chan *manager.ClientStreamMessage, 1)
|
||||
mockSvc := new(mockService)
|
||||
server := NewServer(incoming, mockSvc).(*grpcServer)
|
||||
|
||||
mockStream := new(mockServerStream)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
||||
defer cancel()
|
||||
|
||||
mockStream.On("Context").Return(peer.NewContext(ctx, &peer.Peer{
|
||||
Addr: mockAddr{},
|
||||
AuthInfo: mockAuthInfo{},
|
||||
}))
|
||||
|
||||
go func() {
|
||||
for mes := range incoming {
|
||||
assert.NotNil(t, mes)
|
||||
}
|
||||
}()
|
||||
|
||||
mockStream.On("Recv").Return(&manager.ClientStreamMessage{}, nil)
|
||||
mockSvc.On("Run", mock.Anything, "test", mock.Anything, mock.AnythingOfType("mockAuthInfo")).Return()
|
||||
|
||||
err := server.Process(mockStream)
|
||||
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "context deadline exceeded")
|
||||
mockStream.AssertExpectations(t)
|
||||
mockSvc.AssertExpectations(t)
|
||||
}
|
||||
|
||||
func TestGrpcServer_sendRunReqInChunks(t *testing.T) {
|
||||
incoming := make(chan *manager.ClientStreamMessage)
|
||||
mockSvc := new(mockService)
|
||||
server := NewServer(incoming, mockSvc).(*grpcServer)
|
||||
|
||||
mockStream := new(mockServerStream)
|
||||
|
||||
runReq := &manager.ComputationRunReq{
|
||||
Id: "test-id",
|
||||
}
|
||||
|
||||
largePayload := make([]byte, bufferSize*2)
|
||||
for i := range largePayload {
|
||||
largePayload[i] = byte(i % 256)
|
||||
}
|
||||
runReq.Algorithm = &manager.Algorithm{}
|
||||
runReq.Algorithm.UserKey = largePayload
|
||||
|
||||
mockStream.On("Send", mock.AnythingOfType("*manager.ServerStreamMessage")).Return(nil).Times(4)
|
||||
|
||||
err := server.sendRunReqInChunks(mockStream, runReq)
|
||||
|
||||
assert.NoError(t, err)
|
||||
mockStream.AssertExpectations(t)
|
||||
|
||||
calls := mockStream.Calls
|
||||
assert.Equal(t, 4, len(calls))
|
||||
|
||||
for i, call := range calls {
|
||||
msg := call.Arguments[0].(*manager.ServerStreamMessage)
|
||||
chunk := msg.GetRunReqChunks()
|
||||
|
||||
assert.NotNil(t, chunk)
|
||||
assert.Equal(t, "test-id", chunk.Id)
|
||||
|
||||
if i < 3 {
|
||||
assert.False(t, chunk.IsLast)
|
||||
} else {
|
||||
assert.Equal(t, 0, len(chunk.Data))
|
||||
assert.True(t, chunk.IsLast)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
type mockAddr struct{}
|
||||
|
||||
func (mockAddr) Network() string { return "test network" }
|
||||
func (mockAddr) String() string { return "test" }
|
||||
|
||||
type mockAuthInfo struct{}
|
||||
|
||||
func (mockAuthInfo) AuthType() string { return "test auth" }
|
||||
|
||||
func TestGrpcServer_ProcessWithMockService(t *testing.T) {
|
||||
incoming := make(chan *manager.ClientStreamMessage, 10)
|
||||
mockSvc := new(mockService)
|
||||
server := NewServer(incoming, mockSvc).(*grpcServer)
|
||||
|
||||
go func() {
|
||||
for mes := range incoming {
|
||||
assert.NotNil(t, mes)
|
||||
}
|
||||
}()
|
||||
|
||||
mockStream := new(mockServerStream)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
|
||||
defer cancel()
|
||||
|
||||
peerCtx := peer.NewContext(ctx, &peer.Peer{
|
||||
Addr: mockAddr{},
|
||||
AuthInfo: mockAuthInfo{},
|
||||
})
|
||||
|
||||
mockStream.On("Context").Return(peerCtx)
|
||||
mockStream.On("Recv").Return(&manager.ClientStreamMessage{}, nil).Maybe()
|
||||
|
||||
mockSvc.On("Run", mock.Anything, "test", mock.Anything, mock.AnythingOfType("mockAuthInfo")).
|
||||
Run(func(args mock.Arguments) {
|
||||
sendFunc := args.Get(2).(SendFunc)
|
||||
// Simulate sending a RunReq
|
||||
runReq := &manager.ComputationRunReq{Id: "test-run-id"}
|
||||
err := sendFunc(&manager.ServerStreamMessage{
|
||||
Message: &manager.ServerStreamMessage_RunReq{
|
||||
RunReq: runReq,
|
||||
},
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
}).
|
||||
Return()
|
||||
|
||||
mockStream.On("Send", mock.MatchedBy(func(msg *manager.ServerStreamMessage) bool {
|
||||
chunks := msg.GetRunReqChunks()
|
||||
return chunks != nil && chunks.Id == "test-run-id"
|
||||
})).Return(nil)
|
||||
|
||||
go func() {
|
||||
time.Sleep(150 * time.Millisecond)
|
||||
cancel()
|
||||
}()
|
||||
|
||||
err := server.Process(mockStream)
|
||||
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "context canceled")
|
||||
mockStream.AssertExpectations(t)
|
||||
mockSvc.AssertExpectations(t)
|
||||
}
|
||||
|
||||
func TestGrpcServer_sendRunReqInChunksError(t *testing.T) {
|
||||
incoming := make(chan *manager.ClientStreamMessage)
|
||||
mockSvc := new(mockService)
|
||||
server := NewServer(incoming, mockSvc).(*grpcServer)
|
||||
|
||||
mockStream := new(mockServerStream)
|
||||
|
||||
runReq := &manager.ComputationRunReq{
|
||||
Id: "test-id",
|
||||
}
|
||||
|
||||
// Simulate an error when sending
|
||||
mockStream.On("Send", mock.AnythingOfType("*manager.ServerStreamMessage")).Return(errors.New("send error")).Once()
|
||||
|
||||
err := server.sendRunReqInChunks(mockStream, runReq)
|
||||
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "send error")
|
||||
mockStream.AssertExpectations(t)
|
||||
}
|
||||
|
||||
func TestGrpcServer_ProcessMissingPeerInfo(t *testing.T) {
|
||||
incoming := make(chan *manager.ClientStreamMessage)
|
||||
mockSvc := new(mockService)
|
||||
server := NewServer(incoming, mockSvc).(*grpcServer)
|
||||
|
||||
mockStream := new(mockServerStream)
|
||||
ctx := context.Background()
|
||||
|
||||
// Return a context without peer info
|
||||
mockStream.On("Context").Return(ctx)
|
||||
|
||||
err := server.Process(mockStream)
|
||||
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "failed to get peer info")
|
||||
mockStream.AssertExpectations(t)
|
||||
}
|
||||
@@ -0,0 +1,114 @@
|
||||
// Code generated by mockery v2.43.2. DO NOT EDIT.
|
||||
|
||||
// Copyright (c) Ultraviolet
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package mocks
|
||||
|
||||
import (
|
||||
context "context"
|
||||
|
||||
mock "github.com/stretchr/testify/mock"
|
||||
|
||||
pkgmanager "github.com/ultravioletrs/cocos/pkg/manager"
|
||||
)
|
||||
|
||||
// Service is an autogenerated mock type for the Service type
|
||||
type Service struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
// FetchBackendInfo provides a mock function with given fields:
|
||||
func (_m *Service) FetchBackendInfo() ([]byte, error) {
|
||||
ret := _m.Called()
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for FetchBackendInfo")
|
||||
}
|
||||
|
||||
var r0 []byte
|
||||
var r1 error
|
||||
if rf, ok := ret.Get(0).(func() ([]byte, error)); ok {
|
||||
return rf()
|
||||
}
|
||||
if rf, ok := ret.Get(0).(func() []byte); ok {
|
||||
r0 = rf()
|
||||
} else {
|
||||
if ret.Get(0) != nil {
|
||||
r0 = ret.Get(0).([]byte)
|
||||
}
|
||||
}
|
||||
|
||||
if rf, ok := ret.Get(1).(func() error); ok {
|
||||
r1 = rf()
|
||||
} else {
|
||||
r1 = ret.Error(1)
|
||||
}
|
||||
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
// RetrieveAgentEventsLogs provides a mock function with given fields:
|
||||
func (_m *Service) RetrieveAgentEventsLogs() {
|
||||
_m.Called()
|
||||
}
|
||||
|
||||
// Run provides a mock function with given fields: ctx, c
|
||||
func (_m *Service) Run(ctx context.Context, c *pkgmanager.ComputationRunReq) (string, error) {
|
||||
ret := _m.Called(ctx, c)
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for Run")
|
||||
}
|
||||
|
||||
var r0 string
|
||||
var r1 error
|
||||
if rf, ok := ret.Get(0).(func(context.Context, *pkgmanager.ComputationRunReq) (string, error)); ok {
|
||||
return rf(ctx, c)
|
||||
}
|
||||
if rf, ok := ret.Get(0).(func(context.Context, *pkgmanager.ComputationRunReq) string); ok {
|
||||
r0 = rf(ctx, c)
|
||||
} else {
|
||||
r0 = ret.Get(0).(string)
|
||||
}
|
||||
|
||||
if rf, ok := ret.Get(1).(func(context.Context, *pkgmanager.ComputationRunReq) error); ok {
|
||||
r1 = rf(ctx, c)
|
||||
} else {
|
||||
r1 = ret.Error(1)
|
||||
}
|
||||
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
// Stop provides a mock function with given fields: ctx, computationID
|
||||
func (_m *Service) Stop(ctx context.Context, computationID string) error {
|
||||
ret := _m.Called(ctx, computationID)
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for Stop")
|
||||
}
|
||||
|
||||
var r0 error
|
||||
if rf, ok := ret.Get(0).(func(context.Context, string) error); ok {
|
||||
r0 = rf(ctx, computationID)
|
||||
} else {
|
||||
r0 = ret.Error(0)
|
||||
}
|
||||
|
||||
return r0
|
||||
}
|
||||
|
||||
// NewService creates a new instance of Service. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations.
|
||||
// The first argument is typically a *testing.T value.
|
||||
func NewService(t interface {
|
||||
mock.TestingT
|
||||
Cleanup(func())
|
||||
}) *Service {
|
||||
mock := &Service{}
|
||||
mock.Mock.Test(t)
|
||||
|
||||
t.Cleanup(func() { mock.AssertExpectations(t) })
|
||||
|
||||
return mock
|
||||
}
|
||||
@@ -0,0 +1,232 @@
|
||||
// Copyright (c) Ultraviolet
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
package qemu
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestConstructQemuArgs(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
config Config
|
||||
expected []string
|
||||
}{
|
||||
{
|
||||
name: "Default configuration",
|
||||
config: Config{
|
||||
QemuBinPath: "qemu-system-x86_64",
|
||||
EnableKVM: true,
|
||||
Machine: "q35",
|
||||
CPU: "EPYC",
|
||||
SMPCount: 4,
|
||||
MaxCPUs: 64,
|
||||
MemID: "ram1",
|
||||
MemoryConfig: MemoryConfig{
|
||||
Size: "2048M",
|
||||
Slots: 5,
|
||||
Max: "30G",
|
||||
},
|
||||
OVMFCodeConfig: OVMFCodeConfig{
|
||||
If: "pflash",
|
||||
Format: "raw",
|
||||
Unit: 0,
|
||||
File: "/usr/share/OVMF/OVMF_CODE.fd",
|
||||
ReadOnly: "on",
|
||||
},
|
||||
OVMFVarsConfig: OVMFVarsConfig{
|
||||
If: "pflash",
|
||||
Format: "raw",
|
||||
Unit: 1,
|
||||
File: "/usr/share/OVMF/OVMF_VARS.fd",
|
||||
},
|
||||
NetDevConfig: NetDevConfig{
|
||||
ID: "vmnic",
|
||||
HostFwdAgent: 7020,
|
||||
GuestFwdAgent: 7002,
|
||||
},
|
||||
VirtioNetPciConfig: VirtioNetPciConfig{
|
||||
DisableLegacy: "on",
|
||||
IOMMUPlatform: true,
|
||||
Addr: "0x2",
|
||||
},
|
||||
VSockConfig: VSockConfig{
|
||||
ID: "vhost-vsock-pci0",
|
||||
GuestCID: 3,
|
||||
},
|
||||
DiskImgConfig: DiskImgConfig{
|
||||
KernelFile: "img/bzImage",
|
||||
RootFsFile: "img/rootfs.cpio.gz",
|
||||
},
|
||||
NoGraphic: true,
|
||||
Monitor: "pty",
|
||||
},
|
||||
expected: []string{
|
||||
"-enable-kvm",
|
||||
"-machine", "q35",
|
||||
"-cpu", "EPYC",
|
||||
"-smp", "4,maxcpus=64",
|
||||
"-m", "2048M,slots=5,maxmem=30G",
|
||||
"-drive", "if=pflash,format=raw,unit=0,file=/usr/share/OVMF/OVMF_CODE.fd,readonly=on",
|
||||
"-drive", "if=pflash,format=raw,unit=1,file=/usr/share/OVMF/OVMF_VARS.fd",
|
||||
"-netdev", "user,id=vmnic,hostfwd=tcp::7020-:7002",
|
||||
"-device", "virtio-net-pci,disable-legacy=on,iommu_platform=true,netdev=vmnic,addr=0x2,romfile=",
|
||||
"-device", "vhost-vsock-pci,id=vhost-vsock-pci0,guest-cid=3",
|
||||
"-kernel", "img/bzImage",
|
||||
"-append", "\"quiet console=null rootfstype=ramfs\"",
|
||||
"-initrd", "img/rootfs.cpio.gz",
|
||||
"-nographic",
|
||||
"-monitor", "pty",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "SEV-SNP enabled configuration",
|
||||
config: Config{
|
||||
QemuBinPath: "qemu-system-x86_64",
|
||||
EnableKVM: true,
|
||||
EnableSEVSNP: true,
|
||||
Machine: "q35",
|
||||
CPU: "EPYC",
|
||||
SMPCount: 4,
|
||||
MaxCPUs: 64,
|
||||
MemID: "ram1",
|
||||
MemoryConfig: MemoryConfig{
|
||||
Size: "2048M",
|
||||
Slots: 5,
|
||||
Max: "30G",
|
||||
},
|
||||
OVMFCodeConfig: OVMFCodeConfig{
|
||||
If: "pflash",
|
||||
Format: "raw",
|
||||
Unit: 0,
|
||||
File: "/usr/share/OVMF/OVMF_CODE.fd",
|
||||
ReadOnly: "on",
|
||||
},
|
||||
OVMFVarsConfig: OVMFVarsConfig{
|
||||
If: "pflash",
|
||||
Format: "raw",
|
||||
Unit: 1,
|
||||
File: "/usr/share/OVMF/OVMF_VARS.fd",
|
||||
},
|
||||
NetDevConfig: NetDevConfig{
|
||||
ID: "vmnic",
|
||||
HostFwdAgent: 7020,
|
||||
GuestFwdAgent: 7002,
|
||||
},
|
||||
VirtioNetPciConfig: VirtioNetPciConfig{
|
||||
DisableLegacy: "on",
|
||||
IOMMUPlatform: true,
|
||||
Addr: "0x2",
|
||||
},
|
||||
VSockConfig: VSockConfig{
|
||||
ID: "vhost-vsock-pci0",
|
||||
GuestCID: 3,
|
||||
},
|
||||
DiskImgConfig: DiskImgConfig{
|
||||
KernelFile: "img/bzImage",
|
||||
RootFsFile: "img/rootfs.cpio.gz",
|
||||
},
|
||||
SevConfig: SevConfig{
|
||||
ID: "sev0",
|
||||
CBitPos: 51,
|
||||
ReducedPhysBits: 1,
|
||||
},
|
||||
NoGraphic: true,
|
||||
Monitor: "pty",
|
||||
},
|
||||
expected: []string{
|
||||
"-enable-kvm",
|
||||
"-machine", "q35",
|
||||
"-cpu", "EPYC",
|
||||
"-smp", "4,maxcpus=64",
|
||||
"-m", "2048M,slots=5,maxmem=30G",
|
||||
"-drive", "if=pflash,format=raw,unit=0,file=/usr/share/OVMF/OVMF_CODE.fd,readonly=on",
|
||||
"-drive", "if=pflash,format=raw,unit=1,file=/usr/share/OVMF/OVMF_VARS.fd",
|
||||
"-netdev", "user,id=vmnic,hostfwd=tcp::7020-:7002",
|
||||
"-device", "virtio-net-pci,disable-legacy=on,iommu_platform=true,netdev=vmnic,addr=0x2,romfile=",
|
||||
"-device", "vhost-vsock-pci,id=vhost-vsock-pci0,guest-cid=3",
|
||||
"-object", "memory-backend-memfd-private,id=ram1,size=2048M,share=true",
|
||||
"-machine", "memory-backend=ram1,kvm-type=protected",
|
||||
"-kernel", "img/bzImage",
|
||||
"-append", "\"quiet console=null rootfstype=ramfs\"",
|
||||
"-initrd", "img/rootfs.cpio.gz",
|
||||
"-object", "sev-snp-guest,id=sev0,cbitpos=51,reduced-phys-bits=1",
|
||||
"-machine", "memory-encryption=sev0",
|
||||
"-nographic",
|
||||
"-monitor", "pty",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := tt.config.ConstructQemuArgs()
|
||||
if !reflect.DeepEqual(result, tt.expected) {
|
||||
t.Errorf("ConstructQemuArgs() = %v, want %v", result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestConstructQemuArgs_KernelHash(t *testing.T) {
|
||||
config := Config{
|
||||
EnableSEVSNP: true,
|
||||
KernelHash: true,
|
||||
SevConfig: SevConfig{
|
||||
ID: "sev0",
|
||||
CBitPos: 51,
|
||||
ReducedPhysBits: 1,
|
||||
},
|
||||
}
|
||||
|
||||
result := config.ConstructQemuArgs()
|
||||
|
||||
expected := "-object"
|
||||
expectedValue := "sev-snp-guest,id=sev0,cbitpos=51,reduced-phys-bits=1,discard=none,kernel-hashes=on"
|
||||
|
||||
found := false
|
||||
for i, arg := range result {
|
||||
if arg == expected && i+1 < len(result) {
|
||||
if result[i+1] == expectedValue {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !found {
|
||||
t.Errorf("ConstructQemuArgs() did not contain expected SEV-SNP configuration with kernel hashes enabled")
|
||||
}
|
||||
}
|
||||
|
||||
func TestConstructQemuArgs_HostData(t *testing.T) {
|
||||
config := Config{
|
||||
EnableSEVSNP: true,
|
||||
SevConfig: SevConfig{
|
||||
ID: "sev0",
|
||||
CBitPos: 51,
|
||||
ReducedPhysBits: 1,
|
||||
HostData: "test-host-data",
|
||||
},
|
||||
}
|
||||
|
||||
result := config.ConstructQemuArgs()
|
||||
|
||||
expected := "-object"
|
||||
expectedValue := "sev-snp-guest,id=sev0,cbitpos=51,reduced-phys-bits=1,host-data=test-host-data"
|
||||
|
||||
found := false
|
||||
for i, arg := range result {
|
||||
if arg == expected && i+1 < len(result) {
|
||||
if result[i+1] == expectedValue {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !found {
|
||||
t.Errorf("ConstructQemuArgs() did not contain expected SEV-SNP configuration with host data")
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,144 @@
|
||||
// Copyright (c) Ultraviolet
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
package qemu
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sync"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestNewFilePersistence(t *testing.T) {
|
||||
tempDir := t.TempDir()
|
||||
|
||||
fp, err := NewFilePersistence(tempDir)
|
||||
if err != nil {
|
||||
t.Fatalf("NewFilePersistence failed: %v", err)
|
||||
}
|
||||
|
||||
if _, ok := fp.(*FilePersistence); !ok {
|
||||
t.Fatalf("NewFilePersistence didn't return a FilePersistence")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSaveVM(t *testing.T) {
|
||||
tempDir := t.TempDir()
|
||||
fp, _ := NewFilePersistence(tempDir)
|
||||
|
||||
state := VMState{
|
||||
ID: "test-vm",
|
||||
Config: Config{},
|
||||
PID: 1234,
|
||||
}
|
||||
|
||||
err := fp.SaveVM(state)
|
||||
if err != nil {
|
||||
t.Fatalf("SaveVM failed: %v", err)
|
||||
}
|
||||
|
||||
// Check if file exists
|
||||
if _, err := os.Stat(filepath.Join(tempDir, "test-vm.json")); os.IsNotExist(err) {
|
||||
t.Fatalf("SaveVM didn't create a file")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadVMs(t *testing.T) {
|
||||
tempDir := t.TempDir()
|
||||
fp, _ := NewFilePersistence(tempDir)
|
||||
|
||||
// Save two VMs
|
||||
states := []VMState{
|
||||
{ID: "vm1", Config: Config{}, PID: 1234},
|
||||
{ID: "vm2", Config: Config{}, PID: 5678},
|
||||
}
|
||||
|
||||
for _, state := range states {
|
||||
if err := fp.SaveVM(state); err != nil {
|
||||
t.Fatalf("SaveVM failed: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Load VMs
|
||||
loadedStates, err := fp.LoadVMs()
|
||||
if err != nil {
|
||||
t.Fatalf("LoadVMs failed: %v", err)
|
||||
}
|
||||
|
||||
if len(loadedStates) != len(states) {
|
||||
t.Fatalf("LoadVMs returned %d states, expected %d", len(loadedStates), len(states))
|
||||
}
|
||||
|
||||
// Check if loaded states match saved states
|
||||
for i, state := range states {
|
||||
if state.ID != loadedStates[i].ID || state.PID != loadedStates[i].PID {
|
||||
t.Fatalf("Loaded state %v doesn't match saved state %v", loadedStates[i], state)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeleteVM(t *testing.T) {
|
||||
tempDir := t.TempDir()
|
||||
fp, _ := NewFilePersistence(tempDir)
|
||||
|
||||
state := VMState{ID: "test-vm", Config: Config{}, PID: 1234}
|
||||
|
||||
// Save VM
|
||||
if err := fp.SaveVM(state); err != nil {
|
||||
t.Fatalf("SaveVM failed: %v", err)
|
||||
}
|
||||
|
||||
// Delete VM
|
||||
if err := fp.DeleteVM(state.ID); err != nil {
|
||||
t.Fatalf("DeleteVM failed: %v", err)
|
||||
}
|
||||
|
||||
// Check if file is deleted
|
||||
if _, err := os.Stat(filepath.Join(tempDir, "test-vm.json")); !os.IsNotExist(err) {
|
||||
t.Fatalf("DeleteVM didn't remove the file")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadVMsWithInvalidFile(t *testing.T) {
|
||||
tempDir := t.TempDir()
|
||||
fp, _ := NewFilePersistence(tempDir)
|
||||
|
||||
invalidData := []byte("{invalid json")
|
||||
if err := os.WriteFile(filepath.Join(tempDir, "invalid.json"), invalidData, 0o644); err != nil {
|
||||
t.Fatalf("Failed to create invalid JSON file: %v", err)
|
||||
}
|
||||
|
||||
_, err := fp.LoadVMs()
|
||||
if err == nil {
|
||||
t.Fatalf("LoadVMs should have failed with invalid JSON")
|
||||
}
|
||||
}
|
||||
|
||||
func TestConcurrentAccess(t *testing.T) {
|
||||
tempDir := t.TempDir()
|
||||
fp, _ := NewFilePersistence(tempDir)
|
||||
|
||||
const numGoroutines = 10
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(numGoroutines * 2)
|
||||
|
||||
for i := 0; i < numGoroutines; i++ {
|
||||
go func(id int) {
|
||||
defer wg.Done()
|
||||
state := VMState{ID: fmt.Sprintf("vm-%d", id), Config: Config{}, PID: id}
|
||||
if err := fp.SaveVM(state); err != nil {
|
||||
t.Errorf("Concurrent SaveVM failed: %v", err)
|
||||
}
|
||||
}(i)
|
||||
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
if _, err := fp.LoadVMs(); err != nil {
|
||||
t.Errorf("Concurrent LoadVMs failed: %v", err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
}
|
||||
+102
-15
@@ -3,8 +3,10 @@
|
||||
package qemu
|
||||
|
||||
import (
|
||||
"os"
|
||||
"os/exec"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/ultravioletrs/cocos/pkg/manager"
|
||||
@@ -15,25 +17,110 @@ func TestNewVM(t *testing.T) {
|
||||
logsChan := make(chan *manager.ClientStreamMessage)
|
||||
computationId := "test-computation"
|
||||
|
||||
nvm := NewVM(config, logsChan, computationId)
|
||||
vm := NewVM(config, logsChan, computationId)
|
||||
|
||||
assert.NotNil(t, nvm)
|
||||
assert.IsType(t, &qemuVM{}, nvm)
|
||||
assert.NotNil(t, vm)
|
||||
assert.IsType(t, &qemuVM{}, vm)
|
||||
}
|
||||
|
||||
func TestVM_Stop(t *testing.T) {
|
||||
// Setup
|
||||
v := &qemuVM{
|
||||
cmd: exec.Command("sleep", "1"),
|
||||
func TestStart(t *testing.T) {
|
||||
// Create a temporary file for testing
|
||||
tmpFile, err := os.CreateTemp("", "test-ovmf-vars")
|
||||
assert.NoError(t, err)
|
||||
defer os.Remove(tmpFile.Name())
|
||||
|
||||
config := Config{
|
||||
OVMFVarsConfig: OVMFVarsConfig{
|
||||
File: tmpFile.Name(),
|
||||
},
|
||||
QemuBinPath: "echo", // Use 'echo' as a dummy QEMU binary
|
||||
}
|
||||
logsChan := make(chan *manager.ClientStreamMessage)
|
||||
computationId := "test-computation"
|
||||
|
||||
vm := NewVM(config, logsChan, computationId).(*qemuVM)
|
||||
|
||||
err = vm.Start()
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, vm.cmd)
|
||||
|
||||
// Clean up
|
||||
_ = vm.Stop()
|
||||
}
|
||||
|
||||
func TestStop(t *testing.T) {
|
||||
cmd := exec.Command("echo", "test")
|
||||
err := cmd.Start()
|
||||
assert.NoError(t, err)
|
||||
|
||||
vm := &qemuVM{
|
||||
cmd: &exec.Cmd{
|
||||
Process: cmd.Process,
|
||||
},
|
||||
}
|
||||
|
||||
err := v.cmd.Start()
|
||||
err = vm.Stop()
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Test
|
||||
err = v.Stop()
|
||||
|
||||
// Assert
|
||||
assert.NoError(t, err)
|
||||
assert.Error(t, v.cmd.Wait()) // Process should have been killed
|
||||
}
|
||||
|
||||
func TestSetProcess(t *testing.T) {
|
||||
vm := &qemuVM{
|
||||
config: Config{
|
||||
QemuBinPath: "echo", // Use 'echo' as a dummy QEMU binary
|
||||
},
|
||||
}
|
||||
|
||||
err := vm.SetProcess(os.Getpid()) // Use current process as a dummy
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, vm.cmd)
|
||||
assert.NotNil(t, vm.cmd.Process)
|
||||
}
|
||||
|
||||
func TestGetProcess(t *testing.T) {
|
||||
expectedPid := 12345
|
||||
vm := &qemuVM{
|
||||
cmd: &exec.Cmd{
|
||||
Process: &os.Process{Pid: expectedPid},
|
||||
},
|
||||
}
|
||||
|
||||
pid := vm.GetProcess()
|
||||
assert.Equal(t, expectedPid, pid)
|
||||
}
|
||||
|
||||
func TestGetCID(t *testing.T) {
|
||||
expectedCID := 42
|
||||
vm := &qemuVM{
|
||||
config: Config{
|
||||
VSockConfig: VSockConfig{
|
||||
GuestCID: expectedCID,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
cid := vm.GetCID()
|
||||
assert.Equal(t, expectedCID, cid)
|
||||
}
|
||||
|
||||
func TestCheckVMProcessPeriodically(t *testing.T) {
|
||||
logsChan := make(chan *manager.ClientStreamMessage, 1)
|
||||
vm := &qemuVM{
|
||||
logsChan: logsChan,
|
||||
computationId: "test-computation",
|
||||
cmd: &exec.Cmd{
|
||||
Process: &os.Process{Pid: -1}, // Use an invalid PID to simulate a stopped process
|
||||
},
|
||||
}
|
||||
|
||||
go vm.checkVMProcessPeriodically()
|
||||
|
||||
select {
|
||||
case msg := <-logsChan:
|
||||
assert.NotNil(t, msg.GetAgentEvent())
|
||||
assert.Equal(t, "test-computation", msg.GetAgentEvent().ComputationId)
|
||||
assert.Equal(t, manager.VmRunning.String(), msg.GetAgentEvent().EventType)
|
||||
assert.Equal(t, manager.Stopped.String(), msg.GetAgentEvent().Status)
|
||||
case <-time.After(2 * interval):
|
||||
t.Fatal("Timeout waiting for VM stopped message")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -53,6 +53,8 @@ var (
|
||||
|
||||
// Service specifies an API that must be fulfilled by the domain service
|
||||
// implementation, and all of its decorators (e.g. logging & metrics).
|
||||
//
|
||||
//go:generate mockery --name Service --output=./mocks --filename service.go --quiet --note "Copyright (c) Ultraviolet \n // SPDX-License-Identifier: Apache-2.0"
|
||||
type Service interface {
|
||||
// Run create a computation.
|
||||
Run(ctx context.Context, c *manager.ComputationRunReq) (string, error)
|
||||
|
||||
@@ -6,11 +6,15 @@ import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"log/slog"
|
||||
"os"
|
||||
"os/exec"
|
||||
"testing"
|
||||
|
||||
mglog "github.com/absmach/magistrala/logger"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/mock"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/ultravioletrs/cocos/agent"
|
||||
"github.com/ultravioletrs/cocos/manager/qemu"
|
||||
persistenceMocks "github.com/ultravioletrs/cocos/manager/qemu/mocks"
|
||||
"github.com/ultravioletrs/cocos/manager/vm"
|
||||
@@ -252,3 +256,110 @@ func TestPublishEvent(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestComputationHash(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
computation agent.Computation
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "Valid computation",
|
||||
computation: agent.Computation{
|
||||
ID: "test-id",
|
||||
Name: "test-name",
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
hash, err := computationHash(tt.computation)
|
||||
if tt.wantErr {
|
||||
assert.Error(t, err)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
assert.NotEmpty(t, hash)
|
||||
|
||||
hash2, _ := computationHash(tt.computation)
|
||||
assert.Equal(t, hash, hash2)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDecodeRange(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
wantStart int
|
||||
wantEnd int
|
||||
wantErr bool
|
||||
}{
|
||||
{"Valid range", "1-5", 1, 5, false},
|
||||
{"Invalid format", "1:5", 0, 0, true},
|
||||
{"Start greater than end", "5-1", 0, 0, true},
|
||||
{"Non-numeric input", "a-b", 0, 0, true},
|
||||
{"Single number", "5", 0, 0, true},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
start, end, err := decodeRange(tt.input)
|
||||
if tt.wantErr {
|
||||
assert.Error(t, err)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, tt.wantStart, start)
|
||||
assert.Equal(t, tt.wantEnd, end)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRestoreVMs(t *testing.T) {
|
||||
mockPersistence := new(persistenceMocks.Persistence)
|
||||
vmf := new(mocks.Provider)
|
||||
vmMock := new(mocks.VM)
|
||||
vmf.On("Execute", mock.Anything, mock.Anything, mock.Anything).Return(vmMock)
|
||||
vmMock.On("SetProcess", mock.Anything).Return(nil)
|
||||
ms := &managerService{
|
||||
persistence: mockPersistence,
|
||||
vms: make(map[string]vm.VM),
|
||||
eventsChan: make(chan *manager.ClientStreamMessage, 10),
|
||||
vmFactory: vmf.Execute,
|
||||
logger: mglog.NewMock(),
|
||||
}
|
||||
|
||||
cmd := exec.Command("echo", "test")
|
||||
err := cmd.Start()
|
||||
assert.NoError(t, err)
|
||||
|
||||
mockPersistence.On("LoadVMs").Return([]qemu.VMState{
|
||||
{ID: "vm1", PID: cmd.Process.Pid},
|
||||
{ID: "vm2", PID: 1000},
|
||||
}, nil)
|
||||
|
||||
mockPersistence.On("DeleteVM", mock.Anything).Return(nil)
|
||||
|
||||
err = ms.restoreVMs()
|
||||
assert.NoError(t, err)
|
||||
|
||||
assert.Len(t, ms.vms, 1)
|
||||
assert.Contains(t, ms.vms, "vm1")
|
||||
|
||||
mockPersistence.AssertExpectations(t)
|
||||
}
|
||||
|
||||
func TestProcessExists(t *testing.T) {
|
||||
ms := &managerService{}
|
||||
|
||||
assert.True(t, ms.processExists(os.Getpid()))
|
||||
|
||||
assert.False(t, ms.processExists(99999))
|
||||
|
||||
if os.Getuid() != 0 { // Skip this test if running as root.
|
||||
assert.False(t, ms.processExists(1)) // PID 1 is usually the init process.
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user