mirror of
https://github.com/ultravioletrs/cocos.git
synced 2026-06-23 04:10:25 +00:00
COCOS-159 - Improve test coverage to 65% (#310)
* few atls tests Signed-off-by: Sammy Oina <sammyoina@gmail.com> * remove commented code Signed-off-by: Sammy Oina <sammyoina@gmail.com> * add atls tests Signed-off-by: Sammy Oina <sammyoina@gmail.com> * new line Signed-off-by: Sammy Oina <sammyoina@gmail.com> * add more test cases Signed-off-by: Sammy Oina <sammyoina@gmail.com> * more test cases Signed-off-by: Sammy Oina <sammyoina@gmail.com> * add empty line and parallel test cases Signed-off-by: Sammy Oina <sammyoina@gmail.com> * move const outside test case Signed-off-by: Sammy Oina <sammyoina@gmail.com> --------- Signed-off-by: Sammy Oina <sammyoina@gmail.com>
This commit is contained in:
committed by
GitHub
parent
d5941edb56
commit
5a22ac2eca
@@ -38,6 +38,13 @@ func TestCopyFile(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestCopyFile_NonExistentSource(t *testing.T) {
|
||||
err := CopyFile("nonexistent.txt", "destination.txt")
|
||||
if err == nil {
|
||||
t.Error("CopyFile did not return an error for a nonexistent source file")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeleteFilesInDir(t *testing.T) {
|
||||
tempDir, err := os.MkdirTemp("", "deletefiles_test")
|
||||
if err != nil {
|
||||
@@ -111,6 +118,13 @@ func TestChecksum(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestChecksum_NonExistentFile(t *testing.T) {
|
||||
_, err := Checksum("nonexistent.txt")
|
||||
if err == nil {
|
||||
t.Error("Checksum did not return an error for a nonexistent file")
|
||||
}
|
||||
}
|
||||
|
||||
func TestChecksumHex(t *testing.T) {
|
||||
tempFile, err := os.CreateTemp("", "checksumhex_test")
|
||||
if err != nil {
|
||||
@@ -134,3 +148,10 @@ func TestChecksumHex(t *testing.T) {
|
||||
t.Errorf("ChecksumHex mismatch. Got %s, want %s", checksumHex, expectedChecksumHex)
|
||||
}
|
||||
}
|
||||
|
||||
func TestChecksumHex_NonExistentFile(t *testing.T) {
|
||||
_, err := ChecksumHex("nonexistent.txt")
|
||||
if err == nil {
|
||||
t.Error("ChecksumHex did not return an error for a nonexistent file")
|
||||
}
|
||||
}
|
||||
|
||||
+101
-14
@@ -33,26 +33,100 @@ func (m *mockStream) Send(msg *manager.ClientStreamMessage) error {
|
||||
return args.Error(0)
|
||||
}
|
||||
|
||||
func TestManagerClient_Process(t *testing.T) {
|
||||
mockStream := new(mockStream)
|
||||
mockSvc := new(mocks.Service)
|
||||
messageQueue := make(chan *manager.ClientStreamMessage, 10)
|
||||
logger := mglog.NewMock()
|
||||
func TestManagerClient_Process1(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
setupMocks func(mockStream *mockStream, mockSvc *mocks.Service)
|
||||
expectError bool
|
||||
errorMsg string
|
||||
}{
|
||||
{
|
||||
name: "Stop computation",
|
||||
setupMocks: func(mockStream *mockStream, mockSvc *mocks.Service) {
|
||||
mockStream.On("Recv").Return(&manager.ServerStreamMessage{
|
||||
Message: &manager.ServerStreamMessage_StopComputation{
|
||||
StopComputation: &manager.StopComputation{},
|
||||
},
|
||||
}, nil)
|
||||
mockStream.On("Send", mock.Anything).Return(nil)
|
||||
mockSvc.On("Stop", mock.Anything, mock.Anything).Return(nil)
|
||||
},
|
||||
expectError: true,
|
||||
errorMsg: "context deadline exceeded",
|
||||
},
|
||||
{
|
||||
name: "Terminate request",
|
||||
setupMocks: func(mockStream *mockStream, mockSvc *mocks.Service) {
|
||||
mockStream.On("Recv").Return(&manager.ServerStreamMessage{
|
||||
Message: &manager.ServerStreamMessage_TerminateReq{
|
||||
TerminateReq: &manager.Terminate{},
|
||||
},
|
||||
}, nil)
|
||||
},
|
||||
expectError: true,
|
||||
errorMsg: errTerminationFromServer.Error(),
|
||||
},
|
||||
{
|
||||
name: "Backend info request",
|
||||
setupMocks: func(mockStream *mockStream, mockSvc *mocks.Service) {
|
||||
mockStream.On("Recv").Return(&manager.ServerStreamMessage{
|
||||
Message: &manager.ServerStreamMessage_BackendInfoReq{
|
||||
BackendInfoReq: &manager.BackendInfoReq{},
|
||||
},
|
||||
}, nil)
|
||||
mockStream.On("Send", mock.Anything).Return(nil).Once()
|
||||
mockSvc.On("FetchBackendInfo", mock.Anything, mock.Anything).Return(nil, assert.AnError)
|
||||
},
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "Run request chunks",
|
||||
setupMocks: func(mockStream *mockStream, mockSvc *mocks.Service) {
|
||||
mockStream.On("Recv").Return(&manager.ServerStreamMessage{
|
||||
Message: &manager.ServerStreamMessage_RunReqChunks{
|
||||
RunReqChunks: &manager.RunReqChunks{},
|
||||
},
|
||||
}, nil)
|
||||
mockStream.On("Send", mock.Anything).Return(nil).Once()
|
||||
mockSvc.On("Run", mock.Anything, mock.Anything).Return("", assert.AnError).Once()
|
||||
},
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "Receive error",
|
||||
setupMocks: func(mockStream *mockStream, mockSvc *mocks.Service) {
|
||||
mockStream.On("Recv").Return(&manager.ServerStreamMessage{}, assert.AnError)
|
||||
},
|
||||
expectError: true,
|
||||
},
|
||||
}
|
||||
|
||||
client := NewClient(mockStream, mockSvc, messageQueue, logger)
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
mockStream := new(mockStream)
|
||||
mockSvc := new(mocks.Service)
|
||||
messageQueue := make(chan *manager.ClientStreamMessage, 10)
|
||||
logger := mglog.NewMock()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
||||
defer cancel()
|
||||
client := NewClient(mockStream, mockSvc, messageQueue, logger)
|
||||
|
||||
mockStream.On("Recv").Return(&manager.ServerStreamMessage{Message: &manager.ServerStreamMessage_StopComputation{StopComputation: &manager.StopComputation{}}}, nil).Maybe()
|
||||
mockStream.On("Send", mock.Anything).Return(nil).Maybe()
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
||||
defer cancel()
|
||||
|
||||
mockSvc.On("Stop", mock.Anything, mock.Anything).Return(nil).Maybe()
|
||||
tc.setupMocks(mockStream, mockSvc)
|
||||
|
||||
err := client.Process(ctx, cancel)
|
||||
err := client.Process(ctx, cancel)
|
||||
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "context deadline exceeded")
|
||||
if tc.expectError {
|
||||
assert.Error(t, err)
|
||||
if tc.errorMsg != "" {
|
||||
assert.Contains(t, err.Error(), tc.errorMsg)
|
||||
}
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestManagerClient_handleRunReqChunks(t *testing.T) {
|
||||
@@ -233,3 +307,16 @@ func TestManagerClient_handleSVMInfoReq(t *testing.T) {
|
||||
assert.Equal(t, "", infoRes.SvmInfo.EosVersion)
|
||||
assert.Equal(t, qemu.KernelCommandLine, infoRes.SvmInfo.KernelCmd)
|
||||
}
|
||||
|
||||
func TestManagerClient_timeoutRequest(t *testing.T) {
|
||||
rm := newRunRequestManager()
|
||||
rm.requests["test-id"] = &runRequest{
|
||||
timer: time.NewTimer(100 * time.Millisecond),
|
||||
buffer: []byte("test-data"),
|
||||
lastChunk: time.Now(),
|
||||
}
|
||||
|
||||
rm.timeoutRequest("test-id")
|
||||
|
||||
assert.Len(t, rm.requests, 0)
|
||||
}
|
||||
|
||||
+122
-64
@@ -54,34 +54,60 @@ func TestNewServer(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestGrpcServer_Process(t *testing.T) {
|
||||
incoming := make(chan *manager.ClientStreamMessage, 1)
|
||||
mockSvc := new(mockService)
|
||||
server := NewServer(incoming, mockSvc).(*grpcServer)
|
||||
tests := []struct {
|
||||
name string
|
||||
recvReturn *manager.ClientStreamMessage
|
||||
recvError error
|
||||
expectedError string
|
||||
}{
|
||||
{
|
||||
name: "Process with context deadline exceeded",
|
||||
recvReturn: &manager.ClientStreamMessage{},
|
||||
recvError: nil,
|
||||
expectedError: "context deadline exceeded",
|
||||
},
|
||||
{
|
||||
name: "Process with Recv error",
|
||||
recvReturn: &manager.ClientStreamMessage{},
|
||||
recvError: errors.New("recv error"),
|
||||
expectedError: "recv error",
|
||||
},
|
||||
}
|
||||
|
||||
mockStream := new(mockServerStream)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
||||
defer cancel()
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
incoming := make(chan *manager.ClientStreamMessage, 1)
|
||||
mockSvc := new(mockService)
|
||||
server := NewServer(incoming, mockSvc).(*grpcServer)
|
||||
|
||||
mockStream.On("Context").Return(peer.NewContext(ctx, &peer.Peer{
|
||||
Addr: mockAddr{},
|
||||
AuthInfo: mockAuthInfo{},
|
||||
}))
|
||||
mockStream := new(mockServerStream)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
||||
defer cancel()
|
||||
|
||||
go func() {
|
||||
for mes := range incoming {
|
||||
assert.NotNil(t, mes)
|
||||
}
|
||||
}()
|
||||
mockStream.On("Context").Return(peer.NewContext(ctx, &peer.Peer{
|
||||
Addr: mockAddr{},
|
||||
AuthInfo: mockAuthInfo{},
|
||||
}))
|
||||
|
||||
mockStream.On("Recv").Return(&manager.ClientStreamMessage{}, nil)
|
||||
mockSvc.On("Run", mock.Anything, "test", mock.Anything, mock.AnythingOfType("mockAuthInfo")).Return()
|
||||
if tt.recvError == nil {
|
||||
go func() {
|
||||
for mes := range incoming {
|
||||
assert.NotNil(t, mes)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
err := server.Process(mockStream)
|
||||
mockStream.On("Recv").Return(tt.recvReturn, tt.recvError)
|
||||
mockSvc.On("Run", mock.Anything, "test", mock.Anything, mock.AnythingOfType("mockAuthInfo")).Return()
|
||||
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "context deadline exceeded")
|
||||
mockStream.AssertExpectations(t)
|
||||
mockSvc.AssertExpectations(t)
|
||||
err := server.Process(mockStream)
|
||||
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), tt.expectedError)
|
||||
mockStream.AssertExpectations(t)
|
||||
mockSvc.AssertExpectations(t)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGrpcServer_sendRunReqInChunks(t *testing.T) {
|
||||
@@ -138,58 +164,90 @@ type mockAuthInfo struct{}
|
||||
func (mockAuthInfo) AuthType() string { return "test auth" }
|
||||
|
||||
func TestGrpcServer_ProcessWithMockService(t *testing.T) {
|
||||
incoming := make(chan *manager.ClientStreamMessage, 10)
|
||||
mockSvc := new(mockService)
|
||||
server := NewServer(incoming, mockSvc).(*grpcServer)
|
||||
tests := []struct {
|
||||
name string
|
||||
setupMockFn func(*mockService, *mockServerStream)
|
||||
}{
|
||||
{
|
||||
name: "Run Request Test",
|
||||
setupMockFn: func(mockSvc *mockService, mockStream *mockServerStream) {
|
||||
mockSvc.On("Run", mock.Anything, "test", mock.Anything, mock.AnythingOfType("mockAuthInfo")).
|
||||
Run(func(args mock.Arguments) {
|
||||
sendFunc := args.Get(2).(SendFunc)
|
||||
runReq := &manager.ComputationRunReq{Id: "test-run-id"}
|
||||
err := sendFunc(&manager.ServerStreamMessage{
|
||||
Message: &manager.ServerStreamMessage_RunReq{
|
||||
RunReq: runReq,
|
||||
},
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
}).
|
||||
Return()
|
||||
|
||||
go func() {
|
||||
for mes := range incoming {
|
||||
assert.NotNil(t, mes)
|
||||
}
|
||||
}()
|
||||
mockStream.On("Send", mock.MatchedBy(func(msg *manager.ServerStreamMessage) bool {
|
||||
chunks := msg.GetRunReqChunks()
|
||||
return chunks != nil && chunks.Id == "test-run-id"
|
||||
})).Return(nil)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Terminate Request Test",
|
||||
setupMockFn: func(mockSvc *mockService, mockStream *mockServerStream) {
|
||||
mockSvc.On("Run", mock.Anything, "test", mock.Anything, mock.AnythingOfType("mockAuthInfo")).
|
||||
Run(func(args mock.Arguments) {
|
||||
sendFunc := args.Get(2).(SendFunc)
|
||||
err := sendFunc(&manager.ServerStreamMessage{
|
||||
Message: &manager.ServerStreamMessage_TerminateReq{
|
||||
TerminateReq: &manager.Terminate{},
|
||||
},
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
}).Return()
|
||||
|
||||
mockStream := new(mockServerStream)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
|
||||
defer cancel()
|
||||
mockStream.On("Send", mock.AnythingOfType("*manager.ServerStreamMessage")).Return(nil)
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
peerCtx := peer.NewContext(ctx, &peer.Peer{
|
||||
Addr: mockAddr{},
|
||||
AuthInfo: mockAuthInfo{},
|
||||
})
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
incoming := make(chan *manager.ClientStreamMessage, 10)
|
||||
mockSvc := new(mockService)
|
||||
server := NewServer(incoming, mockSvc).(*grpcServer)
|
||||
|
||||
mockStream.On("Context").Return(peerCtx)
|
||||
mockStream.On("Recv").Return(&manager.ClientStreamMessage{}, nil).Maybe()
|
||||
go func() {
|
||||
for mes := range incoming {
|
||||
assert.NotNil(t, mes)
|
||||
}
|
||||
}()
|
||||
|
||||
mockSvc.On("Run", mock.Anything, "test", mock.Anything, mock.AnythingOfType("mockAuthInfo")).
|
||||
Run(func(args mock.Arguments) {
|
||||
sendFunc := args.Get(2).(SendFunc)
|
||||
// Simulate sending a RunReq
|
||||
runReq := &manager.ComputationRunReq{Id: "test-run-id"}
|
||||
err := sendFunc(&manager.ServerStreamMessage{
|
||||
Message: &manager.ServerStreamMessage_RunReq{
|
||||
RunReq: runReq,
|
||||
},
|
||||
mockStream := new(mockServerStream)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
|
||||
defer cancel()
|
||||
|
||||
peerCtx := peer.NewContext(ctx, &peer.Peer{
|
||||
Addr: mockAddr{},
|
||||
AuthInfo: mockAuthInfo{},
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
}).
|
||||
Return()
|
||||
|
||||
mockStream.On("Send", mock.MatchedBy(func(msg *manager.ServerStreamMessage) bool {
|
||||
chunks := msg.GetRunReqChunks()
|
||||
return chunks != nil && chunks.Id == "test-run-id"
|
||||
})).Return(nil)
|
||||
mockStream.On("Context").Return(peerCtx)
|
||||
mockStream.On("Recv").Return(&manager.ClientStreamMessage{}, nil).Maybe()
|
||||
|
||||
go func() {
|
||||
time.Sleep(150 * time.Millisecond)
|
||||
cancel()
|
||||
}()
|
||||
tt.setupMockFn(mockSvc, mockStream)
|
||||
|
||||
err := server.Process(mockStream)
|
||||
go func() {
|
||||
time.Sleep(150 * time.Millisecond)
|
||||
cancel()
|
||||
}()
|
||||
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "context canceled")
|
||||
mockStream.AssertExpectations(t)
|
||||
mockSvc.AssertExpectations(t)
|
||||
err := server.Process(mockStream)
|
||||
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "context canceled")
|
||||
mockStream.AssertExpectations(t)
|
||||
mockSvc.AssertExpectations(t)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGrpcServer_sendRunReqInChunksError(t *testing.T) {
|
||||
|
||||
+110
-67
@@ -87,19 +87,20 @@ func (m *MockConn) SetWriteDeadline(t time.Time) error {
|
||||
}
|
||||
|
||||
func TestNew(t *testing.T) {
|
||||
if !vsockDeviceExists() {
|
||||
t.Skip("Skipping test: vsock device not available")
|
||||
}
|
||||
|
||||
logger := &slog.Logger{}
|
||||
reportBrokenConnection := func(address string) {}
|
||||
eventsChan := make(chan *manager.ClientStreamMessage)
|
||||
|
||||
e, err := New(logger, reportBrokenConnection, eventsChan)
|
||||
assert.NoError(t, err)
|
||||
|
||||
assert.NotNil(t, e)
|
||||
assert.IsType(t, &events{}, e)
|
||||
if vsockDeviceExists() {
|
||||
assert.NoError(t, err)
|
||||
|
||||
assert.NotNil(t, e)
|
||||
assert.IsType(t, &events{}, e)
|
||||
} else {
|
||||
assert.Error(t, err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestListen(t *testing.T) {
|
||||
@@ -123,6 +124,25 @@ func TestListen(t *testing.T) {
|
||||
mockListener.AssertExpectations(t)
|
||||
}
|
||||
|
||||
func TestListenContextDone(t *testing.T) {
|
||||
mockListener := new(MockVsockListener)
|
||||
mockConn := new(MockConn)
|
||||
|
||||
e := &events{
|
||||
lis: mockListener,
|
||||
logger: mglog.NewMock(),
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cancel()
|
||||
|
||||
mockListener.On("Accept").Return(mockConn, nil)
|
||||
|
||||
e.Listen(ctx)
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
}
|
||||
|
||||
func vsockDeviceExists() bool {
|
||||
fs, err := os.Stat("/dev/vsock")
|
||||
if err != nil {
|
||||
@@ -180,73 +200,96 @@ func (m *MockConnWithBuffer) SetWriteDeadline(t time.Time) error {
|
||||
}
|
||||
|
||||
func TestHandleConnection(t *testing.T) {
|
||||
mockConn := NewMockConnWithBuffer()
|
||||
eventsChan := make(chan *manager.ClientStreamMessage, 1)
|
||||
|
||||
e := &events{
|
||||
logger: mglog.NewMock(),
|
||||
eventsChan: eventsChan,
|
||||
reportBrokenConnection: func(address string) {},
|
||||
}
|
||||
|
||||
message := &manager.ClientStreamMessage{
|
||||
Message: &manager.ClientStreamMessage_AgentEvent{
|
||||
AgentEvent: &manager.AgentEvent{
|
||||
EventType: "test_event",
|
||||
ComputationId: "test_computation",
|
||||
Status: "test_status",
|
||||
Originator: "test_originator",
|
||||
Timestamp: timestamppb.Now(),
|
||||
Details: []byte("test_details"),
|
||||
tests := []struct {
|
||||
name string
|
||||
message *manager.ClientStreamMessage
|
||||
}{
|
||||
{
|
||||
name: "handle agent event",
|
||||
message: &manager.ClientStreamMessage{
|
||||
Message: &manager.ClientStreamMessage_AgentEvent{
|
||||
AgentEvent: &manager.AgentEvent{
|
||||
EventType: "test_event",
|
||||
ComputationId: "test_computation",
|
||||
Status: "test_status",
|
||||
Originator: "test_originator",
|
||||
Timestamp: timestamppb.Now(),
|
||||
Details: []byte("test_details"),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "handle agent log",
|
||||
message: &manager.ClientStreamMessage{
|
||||
Message: &manager.ClientStreamMessage_AgentLog{
|
||||
AgentLog: &manager.AgentLog{
|
||||
ComputationId: "test_computation",
|
||||
Timestamp: timestamppb.Now(),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
data, err := proto.Marshal(message)
|
||||
assert.NoError(t, err)
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
mockConn := NewMockConnWithBuffer()
|
||||
eventsChan := make(chan *manager.ClientStreamMessage, 1)
|
||||
|
||||
messageID := uint32(1)
|
||||
err = binary.Write(mockConn.readBuf, binary.LittleEndian, messageID)
|
||||
assert.NoError(t, err)
|
||||
err = binary.Write(mockConn.readBuf, binary.LittleEndian, uint32(len(data)))
|
||||
assert.NoError(t, err)
|
||||
_, err = mockConn.readBuf.Write(data)
|
||||
assert.NoError(t, err)
|
||||
e := &events{
|
||||
logger: mglog.NewMock(),
|
||||
eventsChan: eventsChan,
|
||||
reportBrokenConnection: func(address string) {},
|
||||
}
|
||||
|
||||
// Add EOF to signal end of stream
|
||||
err = binary.Write(mockConn.readBuf, binary.LittleEndian, uint32(0))
|
||||
assert.NoError(t, err)
|
||||
err = binary.Write(mockConn.readBuf, binary.LittleEndian, uint32(0))
|
||||
assert.NoError(t, err)
|
||||
data, err := proto.Marshal(tt.message)
|
||||
assert.NoError(t, err)
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
e.handleConnection(mockConn)
|
||||
close(done)
|
||||
}()
|
||||
messageID := uint32(1)
|
||||
err = binary.Write(mockConn.readBuf, binary.LittleEndian, messageID)
|
||||
assert.NoError(t, err)
|
||||
err = binary.Write(mockConn.readBuf, binary.LittleEndian, uint32(len(data)))
|
||||
assert.NoError(t, err)
|
||||
_, err = mockConn.readBuf.Write(data)
|
||||
assert.NoError(t, err)
|
||||
|
||||
var receivedMessage *manager.ClientStreamMessage
|
||||
select {
|
||||
case receivedMessage = <-eventsChan:
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("Timeout waiting for message in eventsChan")
|
||||
// Add EOF to signal end of stream
|
||||
err = binary.Write(mockConn.readBuf, binary.LittleEndian, uint32(0))
|
||||
assert.NoError(t, err)
|
||||
err = binary.Write(mockConn.readBuf, binary.LittleEndian, uint32(0))
|
||||
assert.NoError(t, err)
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
e.handleConnection(mockConn)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
var receivedMessage *manager.ClientStreamMessage
|
||||
select {
|
||||
case receivedMessage = <-eventsChan:
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("Timeout waiting for message in eventsChan")
|
||||
}
|
||||
|
||||
assert.NotNil(t, receivedMessage)
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
// handleConnection has exited
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("Timeout waiting for handleConnection to exit")
|
||||
}
|
||||
|
||||
// Check if ack was written
|
||||
var receivedAck uint32
|
||||
err = binary.Read(mockConn.writeBuf, binary.LittleEndian, &receivedAck)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, messageID, receivedAck)
|
||||
|
||||
// Ensure no unexpected calls were made on the mock
|
||||
mockConn.AssertExpectations(t)
|
||||
})
|
||||
}
|
||||
|
||||
assert.NotNil(t, receivedMessage)
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
// handleConnection has exited
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("Timeout waiting for handleConnection to exit")
|
||||
}
|
||||
|
||||
// Check if ack was written
|
||||
var receivedAck uint32
|
||||
err = binary.Read(mockConn.writeBuf, binary.LittleEndian, &receivedAck)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, messageID, receivedAck)
|
||||
|
||||
// Ensure no unexpected calls were made on the mock
|
||||
mockConn.AssertExpectations(t)
|
||||
}
|
||||
|
||||
@@ -0,0 +1,191 @@
|
||||
// Copyright (c) Ultraviolet
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
package vm
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/ultravioletrs/cocos/pkg/manager"
|
||||
)
|
||||
|
||||
const numGoroutines = 10
|
||||
|
||||
func TestNewStateMachine(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
expectedState manager.ManagerState
|
||||
}{
|
||||
{
|
||||
name: "New state machine initialization",
|
||||
expectedState: manager.VmProvision,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
sm := NewStateMachine()
|
||||
assert.Equal(t, tc.expectedState.String(), sm.State())
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestStateMachineTransitions(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
initialState manager.ManagerState
|
||||
newState manager.ManagerState
|
||||
expectedError bool
|
||||
expectedState manager.ManagerState
|
||||
transitionDesc string
|
||||
}{
|
||||
{
|
||||
name: "Valid transition from VmProvision to VmRunning",
|
||||
initialState: manager.VmProvision,
|
||||
newState: manager.VmRunning,
|
||||
expectedError: false,
|
||||
expectedState: manager.VmRunning,
|
||||
transitionDesc: "should succeed",
|
||||
},
|
||||
{
|
||||
name: "Valid transition from VmProvision to StopComputationRun",
|
||||
initialState: manager.VmProvision,
|
||||
newState: manager.StopComputationRun,
|
||||
expectedError: false,
|
||||
expectedState: manager.StopComputationRun,
|
||||
transitionDesc: "should succeed",
|
||||
},
|
||||
{
|
||||
name: "Valid transition from VmRunning to StopComputationRun",
|
||||
initialState: manager.VmRunning,
|
||||
newState: manager.StopComputationRun,
|
||||
expectedError: false,
|
||||
expectedState: manager.StopComputationRun,
|
||||
transitionDesc: "should succeed",
|
||||
},
|
||||
{
|
||||
name: "Valid transition from StopComputationRun to VmRunning",
|
||||
initialState: manager.StopComputationRun,
|
||||
newState: manager.VmRunning,
|
||||
expectedError: false,
|
||||
expectedState: manager.VmRunning,
|
||||
transitionDesc: "should succeed",
|
||||
},
|
||||
{
|
||||
name: "Invalid transition from VmRunning to VmProvision",
|
||||
initialState: manager.VmRunning,
|
||||
newState: manager.VmProvision,
|
||||
expectedError: true,
|
||||
expectedState: manager.VmRunning,
|
||||
transitionDesc: "should fail",
|
||||
},
|
||||
{
|
||||
name: "Invalid transition from StopComputationRun to VmProvision",
|
||||
initialState: manager.StopComputationRun,
|
||||
newState: manager.VmProvision,
|
||||
expectedError: true,
|
||||
expectedState: manager.StopComputationRun,
|
||||
transitionDesc: "should fail",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
sm := &sm{state: tc.initialState}
|
||||
|
||||
err := sm.Transition(tc.newState)
|
||||
|
||||
if tc.expectedError {
|
||||
assert.Error(t, err, "Expected transition to fail")
|
||||
} else {
|
||||
assert.NoError(t, err, "Expected transition to succeed")
|
||||
}
|
||||
|
||||
assert.Equal(t, tc.expectedState.String(), sm.State(),
|
||||
"State should be %s after transition", tc.expectedState.String())
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestStateMachineConcurrency(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
initialState manager.ManagerState
|
||||
transitionState manager.ManagerState
|
||||
expectedStates []string
|
||||
}{
|
||||
{
|
||||
name: "Transition from VmProvision to VmRunning",
|
||||
initialState: manager.VmProvision,
|
||||
transitionState: manager.VmRunning,
|
||||
expectedStates: []string{manager.VmProvision.String(), manager.VmRunning.String()},
|
||||
},
|
||||
{
|
||||
name: "Transition from VmRunning to StopComputationRun",
|
||||
initialState: manager.VmRunning,
|
||||
transitionState: manager.StopComputationRun,
|
||||
expectedStates: []string{manager.VmRunning.String(), manager.StopComputationRun.String()},
|
||||
},
|
||||
{
|
||||
name: "Transition from StopComputationRun back to VmRunning",
|
||||
initialState: manager.StopComputationRun,
|
||||
transitionState: manager.VmRunning,
|
||||
expectedStates: []string{manager.StopComputationRun.String(), manager.VmRunning.String()},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
sm := NewStateMachine()
|
||||
var wg sync.WaitGroup
|
||||
|
||||
wg.Add(numGoroutines)
|
||||
for i := 0; i < numGoroutines; i++ {
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
_ = sm.Transition(tc.transitionState)
|
||||
_ = sm.State()
|
||||
}()
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
finalState := sm.State()
|
||||
assert.Contains(t, tc.expectedStates, finalState,
|
||||
"Final state should be one of the expected states")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestStateRetrieval(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
state manager.ManagerState
|
||||
expectedString string
|
||||
}{
|
||||
{
|
||||
name: "Get VmProvision state",
|
||||
state: manager.VmProvision,
|
||||
expectedString: manager.VmProvision.String(),
|
||||
},
|
||||
{
|
||||
name: "Get VmRunning state",
|
||||
state: manager.VmRunning,
|
||||
expectedString: manager.VmRunning.String(),
|
||||
},
|
||||
{
|
||||
name: "Get StopComputationRun state",
|
||||
state: manager.StopComputationRun,
|
||||
expectedString: manager.StopComputationRun.String(),
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
sm := &sm{state: tc.state}
|
||||
assert.Equal(t, tc.expectedString, sm.State())
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -42,8 +42,4 @@ type Log struct {
|
||||
Timestamp *timestamppb.Timestamp
|
||||
}
|
||||
|
||||
func (l *Log) IsEventLog() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
type EventSender func(event interface{}) error
|
||||
|
||||
@@ -0,0 +1,101 @@
|
||||
// Copyright (c) Ultraviolet
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
package atls
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/absmach/magistrala/pkg/errors"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestListen(t *testing.T) {
|
||||
cert := []byte("dummy_cert")
|
||||
key := []byte("dummy_key")
|
||||
|
||||
cases := []struct {
|
||||
name string
|
||||
address string
|
||||
err error
|
||||
}{
|
||||
{
|
||||
name: "Valid address",
|
||||
address: "127.0.0.1:8889",
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
name: "Invalid address format",
|
||||
address: "127.0.0.1",
|
||||
err: errListener,
|
||||
},
|
||||
}
|
||||
|
||||
for _, c := range cases {
|
||||
t.Run(c.name, func(t *testing.T) {
|
||||
l, err := Listen(c.address, cert, key)
|
||||
assert.True(t, errors.Contains(err, c.err))
|
||||
if l != nil {
|
||||
t.Cleanup(func() {
|
||||
err := l.Close()
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestATLSServerListener_Accept(t *testing.T) {
|
||||
t.Run("Accepts connection", func(t *testing.T) {
|
||||
listener, err := Listen("127.0.0.1:8887", []byte("dummy_cert"), []byte("dummy_key"))
|
||||
assert.NoError(t, err)
|
||||
t.Cleanup(func() {
|
||||
err := listener.Close()
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
conn, err := listener.Accept()
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, conn)
|
||||
})
|
||||
}
|
||||
|
||||
func TestATLSConn_Read(t *testing.T) {
|
||||
buffer := make([]byte, 1024)
|
||||
|
||||
t.Run("Read with nil connection", func(t *testing.T) {
|
||||
conn := &ATLSConn{tlsConn: nil}
|
||||
_, err := conn.Read(buffer)
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, err, errConnFailed)
|
||||
})
|
||||
}
|
||||
|
||||
func TestATLSConn_Write(t *testing.T) {
|
||||
data := []byte("test data")
|
||||
|
||||
t.Run("Write with nil connection", func(t *testing.T) {
|
||||
conn := &ATLSConn{tlsConn: nil}
|
||||
_, err := conn.Write(data)
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, err, errConnFailed)
|
||||
})
|
||||
}
|
||||
|
||||
func TestATLSConn_DeadlineFunctions(t *testing.T) {
|
||||
conn := &ATLSConn{}
|
||||
|
||||
t.Run("SetDeadline - valid time", func(t *testing.T) {
|
||||
err := conn.SetDeadline(time.Now().Add(1 * time.Minute))
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("SetReadDeadline - past time", func(t *testing.T) {
|
||||
err := conn.SetReadDeadline(time.Now().Add(-1 * time.Minute))
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("SetWriteDeadline - zero time", func(t *testing.T) {
|
||||
err := conn.SetWriteDeadline(time.Time{})
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
}
|
||||
@@ -78,17 +78,35 @@ func TestAgentClientIntegration(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
serverRunning bool
|
||||
config pkggrpc.Config
|
||||
err error
|
||||
}{
|
||||
{
|
||||
name: "successful connection",
|
||||
serverRunning: true,
|
||||
err: nil,
|
||||
config: pkggrpc.Config{
|
||||
URL: testServer.listenAddr,
|
||||
Timeout: 1,
|
||||
},
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
name: "server not healthy",
|
||||
serverRunning: false,
|
||||
err: ErrAgentServiceUnavailable,
|
||||
config: pkggrpc.Config{
|
||||
URL: "",
|
||||
Timeout: 1,
|
||||
},
|
||||
err: ErrAgentServiceUnavailable,
|
||||
},
|
||||
{
|
||||
name: "invalid config, missing BackendInfo with aTLS",
|
||||
config: pkggrpc.Config{
|
||||
URL: testServer.listenAddr,
|
||||
Timeout: 1,
|
||||
AttestedTLS: true,
|
||||
},
|
||||
err: pkggrpc.ErrBackendInfoMissing,
|
||||
},
|
||||
}
|
||||
|
||||
@@ -102,16 +120,7 @@ func TestAgentClientIntegration(t *testing.T) {
|
||||
testServer.health.SetServingStatus("agent", grpchealth.HealthCheckResponse_SERVING)
|
||||
}
|
||||
|
||||
cfg := pkggrpc.Config{
|
||||
URL: testServer.listenAddr,
|
||||
Timeout: 1,
|
||||
}
|
||||
|
||||
if !tt.serverRunning {
|
||||
cfg.URL = ""
|
||||
}
|
||||
|
||||
client, agentClient, err := NewAgentClient(ctx, cfg)
|
||||
client, agentClient, err := NewAgentClient(ctx, tt.config)
|
||||
assert.True(t, errors.Contains(err, tt.err))
|
||||
if err != nil {
|
||||
assert.Nil(t, client)
|
||||
|
||||
@@ -134,6 +134,11 @@ func TestClientSecure(t *testing.T) {
|
||||
secure: withmTLS,
|
||||
expected: "with mTLS",
|
||||
},
|
||||
{
|
||||
name: "With aTLS",
|
||||
secure: withaTLS,
|
||||
expected: WithATLS,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
|
||||
@@ -1,25 +1,28 @@
|
||||
{
|
||||
"rootOfTrust":{
|
||||
"product":"Milan",
|
||||
"checkCrl":true,
|
||||
"productLine":"Milan"
|
||||
},
|
||||
"policy":{
|
||||
"policy":"196608",
|
||||
"permit_provisional_firmware":true,
|
||||
"familyId":"AAAAAAAAAAAAAAAAAAAAAA==",
|
||||
"imageId":"AAAAAAAAAAAAAAAAAAAAAA==",
|
||||
"vmpl":0,
|
||||
"minimumTcb":"15352208179752599555",
|
||||
"minimumLaunchTcb":"15352208179752599555",
|
||||
"measurement":"TsWRmg8efWUW9XHZIomxBKrv4iCYeMO3ZlUPr+OhU5/QAPjCr96w0Dq9gJ7EaaP/",
|
||||
"hostData":"HE5X+yGlBfpKlg4z9TTdV6ATs7MUr4Y+EhN+reuG+zY=",
|
||||
"reportIdMa":"//////////////////////////////////////////8=",
|
||||
"chipId":"GrFqtQ+lrkLsjBslu9pcC6XqkrtFWY1ArIQ+I4gugQIsvCG0qekSvEtE4P/SLSJ6mHNpOkY0MHnGpvz1OkV+kw==",
|
||||
"minimumBuild":7,
|
||||
"minimumVersion":"1.55",
|
||||
"product":{
|
||||
"name":"SEV_PRODUCT_MILAN"
|
||||
}
|
||||
"policy": {
|
||||
"policy": 196608,
|
||||
"family_id": "AAAAAAAAAAAAAAAAAAAAAA==",
|
||||
"image_id": "AAAAAAAAAAAAAAAAAAAAAA==",
|
||||
"vmpl": 0,
|
||||
"minimum_tcb": 15352208179752599555,
|
||||
"minimum_launch_tcb": 15352208179752599555,
|
||||
"require_author_key": false,
|
||||
"measurement": "AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA",
|
||||
"host_data": "AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=",
|
||||
"report_id_ma": "//////////////////////////////////////////8=",
|
||||
"chip_id": "GrFqtQ+lrkLsjBslu9pcC6XqkrtFWY1ArIQ+I4gugQIsvCG0qekSvEtE4P/SLSJ6mHNpOkY0MHnGpvz1OkV+kw==",
|
||||
"minimum_build": 8,
|
||||
"minimum_version": "1.55",
|
||||
"permit_provisional_firmware": true,
|
||||
"require_id_block": false,
|
||||
"product": {
|
||||
"name": 1
|
||||
}
|
||||
}
|
||||
},
|
||||
"root_of_trust": {
|
||||
"product": "Milan",
|
||||
"check_crl": true,
|
||||
"disallow_network": false,
|
||||
"product_line": "Milan"
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user