mirror of
https://github.com/ultravioletrs/cocos.git
synced 2026-06-22 20:00:18 +00:00
* Implement IMAMeasurements method in agentSDK and add corresponding unit tests Signed-off-by: Sammy Oina <sammyoina@gmail.com> * Add unit tests for NewIMAMeasurements command in CLI Signed-off-by: Sammy Oina <sammyoina@gmail.com> * Add error assertion for command execution in NewIMAMeasurements test Signed-off-by: Sammy Oina <sammyoina@gmail.com> * Fix nil pointer dereference in Close method and update NewCreateVMCmd logic for manager client initialization Signed-off-by: Sammy Oina <sammyoina@gmail.com> * Refactor file permission settings to use octal notation and improve cleanup handling in NewCreateVMCmd test Signed-off-by: Sammy Oina <sammyoina@gmail.com> * Add comprehensive unit tests for state machine functionality Signed-off-by: Sammy Oina <sammyoina@gmail.com> * Add mock implementation for Algorithm interface and corresponding test cases Signed-off-by: Sammy Oina <sammyoina@gmail.com> * Refactor file permission settings to use octal notation in TestStopComputationIntegration Signed-off-by: Sammy Oina <sammyoina@gmail.com> * Remove redundant reset test cases from TestStateMachine_Reset Signed-off-by: Sammy Oina <sammyoina@gmail.com> * Fix race condition in action call verification in TestStateMachine_HandleEvent Signed-off-by: Sammy Oina <sammyoina@gmail.com> * Enhance state machine with reset functionality and improve thread safety in event handling Signed-off-by: Sammy Oina <sammyoina@gmail.com> * Improve error handling in state machine start function during tests Signed-off-by: Sammy Oina <sammyoina@gmail.com> * Remove concurrent reset and send event test from state machine tests Signed-off-by: Sammy Oina <sammyoina@gmail.com> * Remove error logging for Start function in transition tests Signed-off-by: Sammy Oina <sammyoina@gmail.com> * Add mock implementations for AgentService_IMAMeasurementsClient and Service Shutdown method; enhance progress tests for IMA measurements handling Signed-off-by: Sammy Oina <sammyoina@gmail.com> * Add comprehensive tests for FileStorage functionality including loading, saving, and concurrent access Signed-off-by: Sammy Oina <sammyoina@gmail.com> * Enhance tests by adding dataset and algorithm hashes in handleRunReqChunks; improve error handling in TestFileStorage_ErrorHandling cleanup Signed-off-by: Sammy Oina <sammyoina@gmail.com> * Enhance TestManagerClient_Process by adding new test cases for Agent state and Disconnect requests; update setupMocks to include grpcClient Signed-off-by: Sammy Oina <sammyoina@gmail.com> * Fix graceful shutdown in gRPC server by adding nil checks for health and server instances Signed-off-by: Sammy Oina <sammyoina@gmail.com> * Enhance TestAttestation by adding mock expectations for VTpmAttestation and Attestation methods; update service call to include platform parameter Signed-off-by: Sammy Oina <sammyoina@gmail.com> * Enhance gRPC Server by adding synchronization for start/stop methods; prevent multiple starts and ensure graceful shutdown Signed-off-by: Sammy Oina <sammyoina@gmail.com> * Add unit tests for gRPC server methods including VM creation, removal, and info retrieval Signed-off-by: Sammy Oina <sammyoina@gmail.com> * Add tests for SEVSNP and TDX host capabilities; remove unused vsock code Signed-off-by: Sammy Oina <sammyoina@gmail.com> * Add a newline for better readability in vm_test.go Signed-off-by: Sammy Oina <sammyoina@gmail.com> * Add integration tests for gRPC client in cvm_test.go Signed-off-by: Sammy Oina <sammyoina@gmail.com> * Remove unused vsock dependencies and add comprehensive unit tests for GCP attestation functions Signed-off-by: Sammy Oina <sammyoina@gmail.com> * Skip GCP tests if credentials are not set Signed-off-by: Sammy Oina <sammyoina@gmail.com> * Add tests for error handling in attestation configuration and GCP commands Signed-off-by: Sammy Oina <sammyoina@gmail.com> * Improve error handling in Azure VM test response writing Signed-off-by: Sammy Oina <sammyoina@gmail.com> * Skip tests in GCP functions if credentials are not set Signed-off-by: Sammy Oina <sammyoina@gmail.com> * Add comprehensive unit tests for Azure attestation provider and verifier Signed-off-by: Sammy Oina <sammyoina@gmail.com> * Add unit tests for TPM functionality and improve error handling Signed-off-by: Sammy Oina <sammyoina@gmail.com> * Add comprehensive tests for attestation functionality and improve error handling Signed-off-by: Sammy Oina <sammyoina@gmail.com> * Add validation for teeNonce in TeeAttestation and implement comprehensive tests for provider methods Signed-off-by: Sammy Oina <sammyoina@gmail.com> * Refactor error messages in TDX attestation tests for clarity Signed-off-by: Sammy Oina <sammyoina@gmail.com> * Fix error message in TeeAttestation test for valid nonce case Signed-off-by: Sammy Oina <sammyoina@gmail.com> * Add MeasurementProvider mock and update mockery configuration Signed-off-by: Sammy Oina <sammyoina@gmail.com> * Add logging for product in parseUints and rename test functions for clarity Signed-off-by: Sammy Oina <sammyoina@gmail.com> * Refactor TestSevsnpverify to reset configuration and improve error logging Signed-off-by: Sammy Oina <sammyoina@gmail.com> --------- Signed-off-by: Sammy Oina <sammyoina@gmail.com>
This commit is contained in:
committed by
GitHub
parent
85a2b7a6c8
commit
4e8057f481
@@ -0,0 +1,125 @@
|
||||
// Copyright (c) Ultraviolet
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
// Code generated by mockery v2.53.3. DO NOT EDIT.
|
||||
|
||||
package mocks
|
||||
|
||||
import mock "github.com/stretchr/testify/mock"
|
||||
|
||||
// Algorithm is an autogenerated mock type for the Algorithm type
|
||||
type Algorithm struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
type Algorithm_Expecter struct {
|
||||
mock *mock.Mock
|
||||
}
|
||||
|
||||
func (_m *Algorithm) EXPECT() *Algorithm_Expecter {
|
||||
return &Algorithm_Expecter{mock: &_m.Mock}
|
||||
}
|
||||
|
||||
// Run provides a mock function with no fields
|
||||
func (_m *Algorithm) Run() error {
|
||||
ret := _m.Called()
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for Run")
|
||||
}
|
||||
|
||||
var r0 error
|
||||
if rf, ok := ret.Get(0).(func() error); ok {
|
||||
r0 = rf()
|
||||
} else {
|
||||
r0 = ret.Error(0)
|
||||
}
|
||||
|
||||
return r0
|
||||
}
|
||||
|
||||
// Algorithm_Run_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Run'
|
||||
type Algorithm_Run_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// Run is a helper method to define mock.On call
|
||||
func (_e *Algorithm_Expecter) Run() *Algorithm_Run_Call {
|
||||
return &Algorithm_Run_Call{Call: _e.mock.On("Run")}
|
||||
}
|
||||
|
||||
func (_c *Algorithm_Run_Call) Run(run func()) *Algorithm_Run_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
run()
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *Algorithm_Run_Call) Return(_a0 error) *Algorithm_Run_Call {
|
||||
_c.Call.Return(_a0)
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *Algorithm_Run_Call) RunAndReturn(run func() error) *Algorithm_Run_Call {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
// Stop provides a mock function with no fields
|
||||
func (_m *Algorithm) Stop() error {
|
||||
ret := _m.Called()
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for Stop")
|
||||
}
|
||||
|
||||
var r0 error
|
||||
if rf, ok := ret.Get(0).(func() error); ok {
|
||||
r0 = rf()
|
||||
} else {
|
||||
r0 = ret.Error(0)
|
||||
}
|
||||
|
||||
return r0
|
||||
}
|
||||
|
||||
// Algorithm_Stop_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Stop'
|
||||
type Algorithm_Stop_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// Stop is a helper method to define mock.On call
|
||||
func (_e *Algorithm_Expecter) Stop() *Algorithm_Stop_Call {
|
||||
return &Algorithm_Stop_Call{Call: _e.mock.On("Stop")}
|
||||
}
|
||||
|
||||
func (_c *Algorithm_Stop_Call) Run(run func()) *Algorithm_Stop_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
run()
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *Algorithm_Stop_Call) Return(_a0 error) *Algorithm_Stop_Call {
|
||||
_c.Call.Return(_a0)
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *Algorithm_Stop_Call) RunAndReturn(run func() error) *Algorithm_Stop_Call {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
// NewAlgorithm creates a new instance of Algorithm. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations.
|
||||
// The first argument is typically a *testing.T value.
|
||||
func NewAlgorithm(t interface {
|
||||
mock.TestingT
|
||||
Cleanup(func())
|
||||
}) *Algorithm {
|
||||
mock := &Algorithm{}
|
||||
mock.Mock.Test(t)
|
||||
|
||||
t.Cleanup(func() { mock.AssertExpectations(t) })
|
||||
|
||||
return mock
|
||||
}
|
||||
@@ -15,6 +15,7 @@ import (
|
||||
"github.com/ultravioletrs/cocos/agent/mocks"
|
||||
pkggrpc "github.com/ultravioletrs/cocos/pkg/clients/grpc"
|
||||
clientmocks "github.com/ultravioletrs/cocos/pkg/clients/grpc/mocks"
|
||||
"golang.org/x/crypto/sha3"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/protobuf/proto"
|
||||
)
|
||||
@@ -37,13 +38,13 @@ func (m *mockStream) Send(msg *cvms.ClientStreamMessage) error {
|
||||
func TestManagerClient_Process(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
setupMocks func(mockStream *mockStream, mockSvc *mocks.Service, mockServerSvc *servermocks.AgentServer)
|
||||
setupMocks func(mockStream *mockStream, mockSvc *mocks.Service, mockServerSvc *servermocks.AgentServer, grpcClient *clientmocks.Client)
|
||||
expectError bool
|
||||
errorMsg string
|
||||
}{
|
||||
{
|
||||
name: "Stop computation",
|
||||
setupMocks: func(mockStream *mockStream, mockSvc *mocks.Service, mockServerSvc *servermocks.AgentServer) {
|
||||
setupMocks: func(mockStream *mockStream, mockSvc *mocks.Service, mockServerSvc *servermocks.AgentServer, grpcClient *clientmocks.Client) {
|
||||
mockStream.On("Recv").Return(&cvms.ServerStreamMessage{
|
||||
Message: &cvms.ServerStreamMessage_StopComputation{
|
||||
StopComputation: &cvms.StopComputation{},
|
||||
@@ -58,7 +59,7 @@ func TestManagerClient_Process(t *testing.T) {
|
||||
},
|
||||
{
|
||||
name: "Run request chunks",
|
||||
setupMocks: func(mockStream *mockStream, mockSvc *mocks.Service, mockServerSvc *servermocks.AgentServer) {
|
||||
setupMocks: func(mockStream *mockStream, mockSvc *mocks.Service, mockServerSvc *servermocks.AgentServer, grpcClient *clientmocks.Client) {
|
||||
mockStream.On("Recv").Return(&cvms.ServerStreamMessage{
|
||||
Message: &cvms.ServerStreamMessage_RunReqChunks{
|
||||
RunReqChunks: &cvms.RunReqChunks{},
|
||||
@@ -69,9 +70,37 @@ func TestManagerClient_Process(t *testing.T) {
|
||||
},
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "Agent state request",
|
||||
setupMocks: func(mockStream *mockStream, mockSvc *mocks.Service, mockServerSvc *servermocks.AgentServer, grpcClient *clientmocks.Client) {
|
||||
mockStream.On("Recv").Return(&cvms.ServerStreamMessage{
|
||||
Message: &cvms.ServerStreamMessage_AgentStateReq{
|
||||
AgentStateReq: &cvms.AgentStateReq{
|
||||
Id: "test-agent",
|
||||
},
|
||||
},
|
||||
}, nil)
|
||||
mockStream.On("Send", mock.Anything).Return(nil)
|
||||
mockSvc.On("State").Return("test-state")
|
||||
},
|
||||
expectError: true,
|
||||
errorMsg: "context deadline exceeded",
|
||||
},
|
||||
{
|
||||
name: "Disconnect request",
|
||||
setupMocks: func(mockStream *mockStream, mockSvc *mocks.Service, mockServerSvc *servermocks.AgentServer, grpcClient *clientmocks.Client) {
|
||||
mockStream.On("Recv").Return(&cvms.ServerStreamMessage{
|
||||
Message: &cvms.ServerStreamMessage_DisconnectReq{},
|
||||
}, nil)
|
||||
mockStream.On("Send", mock.Anything).Return(nil)
|
||||
grpcClient.On("Close").Return(nil)
|
||||
},
|
||||
expectError: true,
|
||||
errorMsg: "context deadline exceeded",
|
||||
},
|
||||
{
|
||||
name: "Receive error",
|
||||
setupMocks: func(mockStream *mockStream, mockSvc *mocks.Service, mockServerSvc *servermocks.AgentServer) {
|
||||
setupMocks: func(mockStream *mockStream, mockSvc *mocks.Service, mockServerSvc *servermocks.AgentServer, grpcClient *clientmocks.Client) {
|
||||
mockStream.On("Recv").Return(&cvms.ServerStreamMessage{}, assert.AnError)
|
||||
},
|
||||
expectError: true,
|
||||
@@ -98,7 +127,7 @@ func TestManagerClient_Process(t *testing.T) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
||||
defer cancel()
|
||||
|
||||
tc.setupMocks(mockStream, mockSvc, mockServerSvc)
|
||||
tc.setupMocks(mockStream, mockSvc, mockServerSvc, grpcClient)
|
||||
|
||||
err = client.Process(ctx, cancel)
|
||||
|
||||
@@ -127,6 +156,19 @@ func TestManagerClient_handleRunReqChunks(t *testing.T) {
|
||||
|
||||
runReq := &cvms.ComputationRunReq{
|
||||
Id: "test-id",
|
||||
Datasets: []*cvms.Dataset{
|
||||
{
|
||||
Hash: sha3.New256().Sum([]byte("test-dataset")),
|
||||
},
|
||||
},
|
||||
Algorithm: &cvms.Algorithm{
|
||||
Hash: sha3.New256().Sum([]byte("test-algorithm")),
|
||||
},
|
||||
ResultConsumers: []*cvms.ResultConsumer{
|
||||
{
|
||||
UserKey: []byte("test-consumer"),
|
||||
},
|
||||
},
|
||||
}
|
||||
runReqBytes, _ := proto.Marshal(runReq)
|
||||
|
||||
|
||||
@@ -0,0 +1,450 @@
|
||||
// Copyright (c) Ultraviolet
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
package storage
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/ultravioletrs/cocos/agent/cvms"
|
||||
)
|
||||
|
||||
func createTempDir(t *testing.T) string {
|
||||
tmpDir, err := os.MkdirTemp("", "storage_test_*")
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() {
|
||||
os.RemoveAll(tmpDir)
|
||||
})
|
||||
return tmpDir
|
||||
}
|
||||
|
||||
func createTestMessage(content string) *cvms.ClientStreamMessage {
|
||||
return &cvms.ClientStreamMessage{
|
||||
Message: &cvms.ClientStreamMessage_RunRes{
|
||||
RunRes: &cvms.RunResponse{
|
||||
Error: "",
|
||||
ComputationId: content,
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewFileStorage(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
storageDir string
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
name: "valid directory",
|
||||
storageDir: createTempDir(t),
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "non-existent directory gets created",
|
||||
storageDir: filepath.Join(createTempDir(t), "subdir"),
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "invalid directory path",
|
||||
storageDir: "/invalid/path/that/cannot/be/created",
|
||||
expectError: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
storage, err := NewFileStorage(tt.storageDir)
|
||||
|
||||
if tt.expectError {
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, storage)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, storage)
|
||||
assert.Equal(t, filepath.Join(tt.storageDir, "pending_messages.json"), storage.path)
|
||||
assert.Empty(t, storage.msgs)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestFileStorage_Load(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
setupFile func(string) error
|
||||
expectedMsgs int
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
name: "load from non-existent file",
|
||||
setupFile: func(path string) error {
|
||||
// Don't create file
|
||||
return nil
|
||||
},
|
||||
expectedMsgs: 0,
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "load from empty file",
|
||||
setupFile: func(path string) error {
|
||||
return os.WriteFile(path, []byte("[]"), 0o644)
|
||||
},
|
||||
expectedMsgs: 0,
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "load from corrupted file",
|
||||
setupFile: func(path string) error {
|
||||
return os.WriteFile(path, []byte("invalid json"), 0o644)
|
||||
},
|
||||
expectedMsgs: 0,
|
||||
expectError: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
tmpDir := createTempDir(t)
|
||||
storage, err := NewFileStorage(tmpDir)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = tt.setupFile(storage.path)
|
||||
require.NoError(t, err)
|
||||
|
||||
msgs, err := storage.Load()
|
||||
|
||||
if tt.expectError {
|
||||
assert.Error(t, err)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, msgs, tt.expectedMsgs)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestFileStorage_Save(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
messages []Message
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
name: "save empty messages",
|
||||
messages: []Message{},
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "save single message",
|
||||
messages: []Message{
|
||||
{
|
||||
Message: createTestMessage("test"),
|
||||
Time: time.Now(),
|
||||
},
|
||||
},
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "save multiple messages",
|
||||
messages: []Message{
|
||||
{
|
||||
Message: createTestMessage("test1"),
|
||||
Time: time.Now(),
|
||||
},
|
||||
{
|
||||
Message: createTestMessage("test2"),
|
||||
Time: time.Now().Add(time.Second),
|
||||
},
|
||||
},
|
||||
expectError: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
tmpDir := createTempDir(t)
|
||||
storage, err := NewFileStorage(tmpDir)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = storage.Save(tt.messages)
|
||||
|
||||
if tt.expectError {
|
||||
assert.Error(t, err)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Verify file was written correctly
|
||||
_, err := os.ReadFile(storage.path)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Verify internal state was updated
|
||||
assert.Equal(t, tt.messages, storage.msgs)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestFileStorage_Add(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
initialMsgs []Message
|
||||
newMessage *cvms.ClientStreamMessage
|
||||
expectError bool
|
||||
expectedCount int
|
||||
}{
|
||||
{
|
||||
name: "add to empty storage",
|
||||
initialMsgs: []Message{},
|
||||
newMessage: createTestMessage("new"),
|
||||
expectError: false,
|
||||
expectedCount: 1,
|
||||
},
|
||||
{
|
||||
name: "add to existing messages",
|
||||
initialMsgs: []Message{
|
||||
{
|
||||
Message: createTestMessage("existing"),
|
||||
Time: time.Now(),
|
||||
},
|
||||
},
|
||||
newMessage: createTestMessage("new"),
|
||||
expectError: false,
|
||||
expectedCount: 2,
|
||||
},
|
||||
{
|
||||
name: "add nil message",
|
||||
initialMsgs: []Message{},
|
||||
newMessage: nil,
|
||||
expectError: false,
|
||||
expectedCount: 1,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
tmpDir := createTempDir(t)
|
||||
storage, err := NewFileStorage(tmpDir)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Setup initial messages
|
||||
if len(tt.initialMsgs) > 0 {
|
||||
err = storage.Save(tt.initialMsgs)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
beforeTime := time.Now()
|
||||
err = storage.Add(tt.newMessage)
|
||||
afterTime := time.Now()
|
||||
|
||||
if tt.expectError {
|
||||
assert.Error(t, err)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Verify message was added to internal state
|
||||
assert.Len(t, storage.msgs, tt.expectedCount)
|
||||
|
||||
// Verify timestamp is reasonable
|
||||
if tt.expectedCount > 0 {
|
||||
lastMsg := storage.msgs[len(storage.msgs)-1]
|
||||
assert.True(t, lastMsg.Time.After(beforeTime) || lastMsg.Time.Equal(beforeTime))
|
||||
assert.True(t, lastMsg.Time.Before(afterTime) || lastMsg.Time.Equal(afterTime))
|
||||
assert.Equal(t, tt.newMessage, lastMsg.Message)
|
||||
}
|
||||
|
||||
_, err := os.ReadFile(storage.path)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestFileStorage_Clear(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
initialMsgs []Message
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
name: "clear empty storage",
|
||||
initialMsgs: []Message{},
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "clear storage with messages",
|
||||
initialMsgs: []Message{
|
||||
{
|
||||
Message: createTestMessage("test1"),
|
||||
Time: time.Now(),
|
||||
},
|
||||
{
|
||||
Message: createTestMessage("test2"),
|
||||
Time: time.Now(),
|
||||
},
|
||||
},
|
||||
expectError: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
tmpDir := createTempDir(t)
|
||||
storage, err := NewFileStorage(tmpDir)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Setup initial messages
|
||||
if len(tt.initialMsgs) > 0 {
|
||||
err = storage.Save(tt.initialMsgs)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
err = storage.Clear()
|
||||
|
||||
if tt.expectError {
|
||||
assert.Error(t, err)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Verify internal state is cleared
|
||||
assert.Empty(t, storage.msgs)
|
||||
|
||||
// Verify file contains empty array
|
||||
data, err := os.ReadFile(storage.path)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "[]", string(data))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestFileStorage_ConcurrentAccess(t *testing.T) {
|
||||
tmpDir := createTempDir(t)
|
||||
storage, err := NewFileStorage(tmpDir)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Test concurrent Add operations
|
||||
numGoroutines := 10
|
||||
done := make(chan bool, numGoroutines)
|
||||
|
||||
for i := 0; i < numGoroutines; i++ {
|
||||
go func(id int) {
|
||||
defer func() { done <- true }()
|
||||
|
||||
msg := createTestMessage(string(rune('A' + id)))
|
||||
err := storage.Add(msg)
|
||||
assert.NoError(t, err)
|
||||
}(i)
|
||||
}
|
||||
|
||||
// Wait for all goroutines to complete
|
||||
for i := 0; i < numGoroutines; i++ {
|
||||
<-done
|
||||
}
|
||||
|
||||
// Verify all messages were added
|
||||
msgs, err := storage.Load()
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, msgs, numGoroutines)
|
||||
}
|
||||
|
||||
func TestFileStorage_IntegrationFlow(t *testing.T) {
|
||||
tmpDir := createTempDir(t)
|
||||
storage, err := NewFileStorage(tmpDir)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Test full workflow
|
||||
|
||||
// 1. Load from empty storage
|
||||
msgs, err := storage.Load()
|
||||
assert.NoError(t, err)
|
||||
assert.Empty(t, msgs)
|
||||
|
||||
// 2. Add some messages
|
||||
msg1 := createTestMessage("message1")
|
||||
err = storage.Add(msg1)
|
||||
assert.NoError(t, err)
|
||||
|
||||
msg2 := createTestMessage("message2")
|
||||
err = storage.Add(msg2)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// 3. Load and verify
|
||||
msgs, err = storage.Load()
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, msgs, 2)
|
||||
|
||||
// 4. Save new set of messages
|
||||
newMsgs := []Message{
|
||||
{
|
||||
Message: createTestMessage("new1"),
|
||||
Time: time.Now(),
|
||||
},
|
||||
}
|
||||
err = storage.Save(newMsgs)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// 5. Load and verify replacement
|
||||
msgs, err = storage.Load()
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, msgs, 1)
|
||||
|
||||
// 6. Clear storage
|
||||
err = storage.Clear()
|
||||
assert.NoError(t, err)
|
||||
|
||||
// 7. Verify empty
|
||||
msgs, err = storage.Load()
|
||||
assert.NoError(t, err)
|
||||
assert.Empty(t, msgs)
|
||||
}
|
||||
|
||||
func TestFileStorage_FilePermissions(t *testing.T) {
|
||||
tmpDir := createTempDir(t)
|
||||
storage, err := NewFileStorage(tmpDir)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Add a message to create the file
|
||||
msg := createTestMessage("test")
|
||||
err = storage.Add(msg)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Check file permissions
|
||||
info, err := os.Stat(storage.path)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, os.FileMode(0o644), info.Mode().Perm())
|
||||
}
|
||||
|
||||
func TestFileStorage_ErrorHandling(t *testing.T) {
|
||||
tmpDir := createTempDir(t)
|
||||
storage, err := NewFileStorage(tmpDir)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Make directory read-only to trigger write errors
|
||||
err = os.Chmod(tmpDir, 0o555)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Restore permissions for cleanup
|
||||
t.Cleanup(func() {
|
||||
if err := os.Chmod(tmpDir, 0o755); err != nil {
|
||||
t.Errorf("Failed to restore permissions: %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
// Try to add a message - should fail due to write permissions
|
||||
msg := createTestMessage("test")
|
||||
err = storage.Add(msg)
|
||||
assert.Error(t, err)
|
||||
|
||||
// Try to save - should fail due to write permissions
|
||||
err = storage.Save([]Message{})
|
||||
assert.Error(t, err)
|
||||
|
||||
// Try to clear - should fail due to write permissions
|
||||
err = storage.Clear()
|
||||
assert.Error(t, err)
|
||||
}
|
||||
@@ -0,0 +1,544 @@
|
||||
// Copyright (c) Ultraviolet
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package server
|
||||
|
||||
import (
|
||||
"crypto/ecdsa"
|
||||
"crypto/elliptic"
|
||||
"crypto/rand"
|
||||
"crypto/x509"
|
||||
"log/slog"
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/ultravioletrs/cocos/agent"
|
||||
"github.com/ultravioletrs/cocos/agent/mocks"
|
||||
)
|
||||
|
||||
func setupTest(t *testing.T) (*slog.Logger, *mocks.Service, string, string, string, []byte) {
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelError}))
|
||||
mockSvc := new(mocks.Service)
|
||||
host := "localhost"
|
||||
caUrl := "https://ca.example.com"
|
||||
cvmId := "test-cvm-id"
|
||||
|
||||
privateKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||
assert.NoError(t, err, "Failed to generate ECDSA key")
|
||||
|
||||
pubkey, err := x509.MarshalPKIXPublicKey(privateKey.Public())
|
||||
assert.NoError(t, err, "Failed to marshal public key")
|
||||
|
||||
return logger, mockSvc, host, caUrl, cvmId, pubkey
|
||||
}
|
||||
|
||||
func TestNewServer(t *testing.T) {
|
||||
logger, svc, host, caUrl, cvmId, _ := setupTest(t)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
logger *slog.Logger
|
||||
svc agent.Service
|
||||
host string
|
||||
caUrl string
|
||||
cvmId string
|
||||
expected AgentServer
|
||||
}{
|
||||
{
|
||||
name: "valid server creation",
|
||||
logger: logger,
|
||||
svc: svc,
|
||||
host: host,
|
||||
caUrl: caUrl,
|
||||
cvmId: cvmId,
|
||||
},
|
||||
{
|
||||
name: "server with empty host",
|
||||
logger: logger,
|
||||
svc: svc,
|
||||
host: "",
|
||||
caUrl: caUrl,
|
||||
cvmId: cvmId,
|
||||
},
|
||||
{
|
||||
name: "server with empty caUrl",
|
||||
logger: logger,
|
||||
svc: svc,
|
||||
host: host,
|
||||
caUrl: "",
|
||||
cvmId: cvmId,
|
||||
},
|
||||
{
|
||||
name: "server with empty cvmId",
|
||||
logger: logger,
|
||||
svc: svc,
|
||||
host: host,
|
||||
caUrl: caUrl,
|
||||
cvmId: "",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
server := NewServer(tt.logger, tt.svc, tt.host, tt.caUrl, tt.cvmId)
|
||||
|
||||
assert.NotNil(t, server)
|
||||
|
||||
agentSrv, ok := server.(*agentServer)
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, tt.logger, agentSrv.logger)
|
||||
assert.Equal(t, tt.svc, agentSrv.svc)
|
||||
assert.Equal(t, tt.host, agentSrv.host)
|
||||
assert.Equal(t, tt.caUrl, agentSrv.caUrl)
|
||||
assert.Equal(t, tt.cvmId, agentSrv.cvmId)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAgentServer_Start(t *testing.T) {
|
||||
logger, svc, host, caUrl, cvmId, pubKey := setupTest(t)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
cfg agent.AgentConfig
|
||||
cmp agent.Computation
|
||||
setupMocks func(*mocks.Service)
|
||||
expectedError bool
|
||||
errorContains string
|
||||
}{
|
||||
{
|
||||
name: "successful start with default port",
|
||||
cfg: agent.AgentConfig{
|
||||
Port: "",
|
||||
CertFile: "cert.pem",
|
||||
KeyFile: "key.pem",
|
||||
ServerCAFile: "server-ca.pem",
|
||||
ClientCAFile: "client-ca.pem",
|
||||
AttestedTls: true,
|
||||
},
|
||||
cmp: agent.Computation{
|
||||
ID: "test-computation-1",
|
||||
Name: "Test Computation",
|
||||
Description: "A test computation",
|
||||
Algorithm: agent.Algorithm{
|
||||
Hash: [32]byte{0x01, 0x02, 0x03},
|
||||
UserKey: pubKey,
|
||||
},
|
||||
Datasets: []agent.Dataset{
|
||||
{
|
||||
Hash: [32]byte{0x04, 0x05, 0x06},
|
||||
UserKey: pubKey,
|
||||
},
|
||||
},
|
||||
ResultConsumers: []agent.ResultConsumer{
|
||||
{
|
||||
UserKey: pubKey,
|
||||
},
|
||||
},
|
||||
},
|
||||
setupMocks: func(m *mocks.Service) {
|
||||
},
|
||||
expectedError: false,
|
||||
},
|
||||
{
|
||||
name: "successful start with custom port",
|
||||
cfg: agent.AgentConfig{
|
||||
Port: "8080",
|
||||
CertFile: "cert.pem",
|
||||
KeyFile: "key.pem",
|
||||
ServerCAFile: "server-ca.pem",
|
||||
ClientCAFile: "client-ca.pem",
|
||||
AttestedTls: false,
|
||||
},
|
||||
cmp: agent.Computation{
|
||||
ID: "test-computation-2",
|
||||
Name: "Test Computation 2",
|
||||
Description: "Another test computation",
|
||||
Algorithm: agent.Algorithm{
|
||||
Hash: [32]byte{0x07, 0x08, 0x09},
|
||||
UserKey: pubKey,
|
||||
},
|
||||
Datasets: []agent.Dataset{
|
||||
{
|
||||
Hash: [32]byte{0x0a, 0x0b, 0x0c},
|
||||
UserKey: pubKey,
|
||||
},
|
||||
},
|
||||
ResultConsumers: []agent.ResultConsumer{
|
||||
{
|
||||
UserKey: pubKey,
|
||||
},
|
||||
},
|
||||
},
|
||||
setupMocks: func(m *mocks.Service) {
|
||||
},
|
||||
expectedError: false,
|
||||
},
|
||||
{
|
||||
name: "start with minimal config",
|
||||
cfg: agent.AgentConfig{
|
||||
Port: "9090",
|
||||
AttestedTls: false,
|
||||
},
|
||||
cmp: agent.Computation{
|
||||
ID: "test-computation-3",
|
||||
Name: "Minimal Test",
|
||||
Algorithm: agent.Algorithm{
|
||||
Hash: [32]byte{0x0d, 0x0e, 0x0f},
|
||||
UserKey: pubKey,
|
||||
},
|
||||
Datasets: []agent.Dataset{
|
||||
{
|
||||
Hash: [32]byte{0x10, 0x11, 0x12},
|
||||
UserKey: pubKey,
|
||||
},
|
||||
},
|
||||
ResultConsumers: []agent.ResultConsumer{
|
||||
{
|
||||
UserKey: pubKey,
|
||||
},
|
||||
},
|
||||
},
|
||||
setupMocks: func(m *mocks.Service) {
|
||||
},
|
||||
expectedError: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
tt.setupMocks(svc)
|
||||
|
||||
server := NewServer(logger, svc, host, caUrl, cvmId)
|
||||
|
||||
err := server.Start(tt.cfg, tt.cmp)
|
||||
|
||||
if tt.expectedError {
|
||||
assert.Error(t, err)
|
||||
if tt.errorContains != "" {
|
||||
assert.Contains(t, err.Error(), tt.errorContains)
|
||||
}
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Verify the port was set correctly
|
||||
agentSrv := server.(*agentServer)
|
||||
assert.NotNil(t, agentSrv.gs)
|
||||
|
||||
if err := server.Stop(); err != nil {
|
||||
t.Fatalf("Failed to stop server after start: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
svc.AssertExpectations(t)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAgentServer_Stop(t *testing.T) {
|
||||
logger, svc, host, caUrl, cvmId, pubKey := setupTest(t)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
setupServer func(AgentServer) error
|
||||
expectedError bool
|
||||
errorContains string
|
||||
}{
|
||||
{
|
||||
name: "stop unstarted server",
|
||||
setupServer: func(server AgentServer) error {
|
||||
// Don't start the server
|
||||
return nil
|
||||
},
|
||||
expectedError: false,
|
||||
},
|
||||
{
|
||||
name: "stop started server",
|
||||
setupServer: func(server AgentServer) error {
|
||||
cfg := agent.AgentConfig{
|
||||
Port: "7004",
|
||||
}
|
||||
cmp := agent.Computation{
|
||||
ID: "test-stop-computation",
|
||||
Name: "Stop Test",
|
||||
Algorithm: agent.Algorithm{
|
||||
Hash: [32]byte{0x19, 0x1a, 0x1b},
|
||||
UserKey: pubKey,
|
||||
},
|
||||
Datasets: []agent.Dataset{
|
||||
{
|
||||
Hash: [32]byte{0x1c, 0x1d, 0x1e},
|
||||
UserKey: pubKey,
|
||||
},
|
||||
},
|
||||
ResultConsumers: []agent.ResultConsumer{
|
||||
{
|
||||
UserKey: pubKey,
|
||||
},
|
||||
},
|
||||
}
|
||||
return server.Start(cfg, cmp)
|
||||
},
|
||||
expectedError: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
server := NewServer(logger, svc, host, caUrl, cvmId)
|
||||
|
||||
err := tt.setupServer(server)
|
||||
if err != nil {
|
||||
t.Fatalf("Setup failed: %v", err)
|
||||
}
|
||||
|
||||
// Give the server a moment to start if it was started
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
|
||||
err = server.Stop()
|
||||
|
||||
if tt.expectedError {
|
||||
assert.Error(t, err)
|
||||
if tt.errorContains != "" {
|
||||
assert.Contains(t, err.Error(), tt.errorContains)
|
||||
}
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
svc.AssertExpectations(t)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAgentServer_StopMultipleTimes(t *testing.T) {
|
||||
logger, svc, host, caUrl, cvmId, pubKey := setupTest(t)
|
||||
server := NewServer(logger, svc, host, caUrl, cvmId)
|
||||
|
||||
// Start the server
|
||||
cfg := agent.AgentConfig{Port: "7005"}
|
||||
cmp := agent.Computation{
|
||||
ID: "test-multiple-stop",
|
||||
Name: "Multiple Stop Test",
|
||||
Algorithm: agent.Algorithm{
|
||||
Hash: [32]byte{0x1f, 0x20, 0x21},
|
||||
UserKey: pubKey,
|
||||
},
|
||||
Datasets: []agent.Dataset{
|
||||
{
|
||||
Hash: [32]byte{0x22, 0x23, 0x24},
|
||||
UserKey: pubKey,
|
||||
},
|
||||
},
|
||||
ResultConsumers: []agent.ResultConsumer{
|
||||
{
|
||||
UserKey: pubKey,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
err := server.Start(cfg, cmp)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Give the server a moment to start
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
|
||||
// Stop the server multiple times
|
||||
err1 := server.Stop()
|
||||
err2 := server.Stop()
|
||||
err3 := server.Stop()
|
||||
|
||||
assert.NoError(t, err1)
|
||||
assert.NoError(t, err2)
|
||||
assert.NoError(t, err3)
|
||||
|
||||
svc.AssertExpectations(t)
|
||||
}
|
||||
|
||||
func TestAgentServer_StartAfterStop(t *testing.T) {
|
||||
logger, svc, host, caUrl, cvmId, pubKey := setupTest(t)
|
||||
server := NewServer(logger, svc, host, caUrl, cvmId)
|
||||
|
||||
cfg := agent.AgentConfig{Port: "7006"}
|
||||
cmp := agent.Computation{
|
||||
ID: "test-restart",
|
||||
Name: "Restart Test",
|
||||
Algorithm: agent.Algorithm{
|
||||
Hash: [32]byte{0x25, 0x26, 0x27},
|
||||
UserKey: pubKey,
|
||||
},
|
||||
Datasets: []agent.Dataset{
|
||||
{
|
||||
Hash: [32]byte{0x28, 0x29, 0x2a},
|
||||
UserKey: pubKey,
|
||||
},
|
||||
},
|
||||
ResultConsumers: []agent.ResultConsumer{
|
||||
{
|
||||
UserKey: pubKey,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// Start, stop, then start again
|
||||
err := server.Start(cfg, cmp)
|
||||
assert.NoError(t, err)
|
||||
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
|
||||
err = server.Stop()
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Start again with different config
|
||||
cfg2 := agent.AgentConfig{Port: "7007"}
|
||||
cmp2 := agent.Computation{
|
||||
ID: "test-restart-2",
|
||||
Name: "Restart Test 2",
|
||||
Algorithm: agent.Algorithm{
|
||||
Hash: [32]byte{0x2b, 0x2c, 0x2d},
|
||||
UserKey: pubKey,
|
||||
},
|
||||
Datasets: []agent.Dataset{
|
||||
{
|
||||
Hash: [32]byte{0x2e, 0x2f, 0x30},
|
||||
UserKey: pubKey,
|
||||
},
|
||||
},
|
||||
ResultConsumers: []agent.ResultConsumer{
|
||||
{
|
||||
UserKey: pubKey,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
err = server.Start(cfg2, cmp2)
|
||||
assert.NoError(t, err)
|
||||
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
|
||||
err = server.Stop()
|
||||
assert.NoError(t, err)
|
||||
|
||||
svc.AssertExpectations(t)
|
||||
}
|
||||
|
||||
func TestAgentServer_ConfigValidation(t *testing.T) {
|
||||
logger, svc, host, caUrl, cvmId, pubKey := setupTest(t)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
config agent.AgentConfig
|
||||
cmp agent.Computation
|
||||
valid bool
|
||||
}{
|
||||
{
|
||||
name: "valid config with all fields",
|
||||
config: agent.AgentConfig{
|
||||
Port: "8080",
|
||||
CertFile: "cert.pem",
|
||||
KeyFile: "key.pem",
|
||||
ServerCAFile: "server-ca.pem",
|
||||
ClientCAFile: "client-ca.pem",
|
||||
AttestedTls: true,
|
||||
},
|
||||
cmp: agent.Computation{
|
||||
ID: "valid-config-test",
|
||||
Name: "Valid Config Test",
|
||||
Algorithm: agent.Algorithm{
|
||||
Hash: [32]byte{0x31, 0x32, 0x33},
|
||||
UserKey: pubKey,
|
||||
},
|
||||
Datasets: []agent.Dataset{
|
||||
{
|
||||
Hash: [32]byte{0x34, 0x35, 0x36},
|
||||
UserKey: pubKey,
|
||||
},
|
||||
},
|
||||
ResultConsumers: []agent.ResultConsumer{
|
||||
{
|
||||
UserKey: pubKey,
|
||||
},
|
||||
},
|
||||
},
|
||||
valid: true,
|
||||
},
|
||||
{
|
||||
name: "valid config with minimal fields",
|
||||
config: agent.AgentConfig{
|
||||
Port: "9090",
|
||||
},
|
||||
cmp: agent.Computation{
|
||||
ID: "minimal-config-test",
|
||||
Name: "Minimal Config Test",
|
||||
Algorithm: agent.Algorithm{
|
||||
Hash: [32]byte{0x37, 0x38, 0x39},
|
||||
UserKey: pubKey,
|
||||
},
|
||||
Datasets: []agent.Dataset{
|
||||
{
|
||||
Hash: [32]byte{0x3a, 0x3b, 0x3c},
|
||||
UserKey: pubKey,
|
||||
},
|
||||
},
|
||||
ResultConsumers: []agent.ResultConsumer{
|
||||
{
|
||||
UserKey: pubKey,
|
||||
},
|
||||
},
|
||||
},
|
||||
valid: true,
|
||||
},
|
||||
{
|
||||
name: "config with empty port uses default",
|
||||
config: agent.AgentConfig{
|
||||
Port: "",
|
||||
},
|
||||
cmp: agent.Computation{
|
||||
ID: "default-port-test",
|
||||
Name: "Default Port Test",
|
||||
Algorithm: agent.Algorithm{Hash: [32]byte{0x3d, 0x3e, 0x3f}, UserKey: pubKey},
|
||||
Datasets: []agent.Dataset{
|
||||
{Hash: [32]byte{0x40, 0x41, 0x42}, UserKey: pubKey},
|
||||
},
|
||||
ResultConsumers: []agent.ResultConsumer{
|
||||
{UserKey: pubKey},
|
||||
},
|
||||
},
|
||||
valid: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
server := NewServer(logger, svc, host, caUrl, cvmId)
|
||||
|
||||
err := server.Start(tt.config, tt.cmp)
|
||||
|
||||
if tt.valid {
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Verify default port is used when empty
|
||||
if tt.config.Port == "" {
|
||||
agentSrv := server.(*agentServer)
|
||||
assert.NotNil(t, agentSrv.gs)
|
||||
}
|
||||
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
if err := server.Stop(); err != nil {
|
||||
t.Fatalf("Failed to stop server after start: %v", err)
|
||||
}
|
||||
} else {
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
svc.AssertExpectations(t)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestConstants(t *testing.T) {
|
||||
assert.Equal(t, "agent", svcName)
|
||||
assert.Equal(t, "7002", defSvcGRPCPort)
|
||||
}
|
||||
@@ -0,0 +1,388 @@
|
||||
// Copyright (c) Ultraviolet
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
// Code generated by mockery v2.53.3. DO NOT EDIT.
|
||||
|
||||
package mocks
|
||||
|
||||
import (
|
||||
context "context"
|
||||
|
||||
agent "github.com/ultravioletrs/cocos/agent"
|
||||
|
||||
metadata "google.golang.org/grpc/metadata"
|
||||
|
||||
mock "github.com/stretchr/testify/mock"
|
||||
)
|
||||
|
||||
// AgentService_IMAMeasurementsClient is an autogenerated mock type for the AgentService_IMAMeasurementsClient type
|
||||
type AgentService_IMAMeasurementsClient[Res interface{}] struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
type AgentService_IMAMeasurementsClient_Expecter[Res interface{}] struct {
|
||||
mock *mock.Mock
|
||||
}
|
||||
|
||||
func (_m *AgentService_IMAMeasurementsClient[Res]) EXPECT() *AgentService_IMAMeasurementsClient_Expecter[Res] {
|
||||
return &AgentService_IMAMeasurementsClient_Expecter[Res]{mock: &_m.Mock}
|
||||
}
|
||||
|
||||
// CloseSend provides a mock function with no fields
|
||||
func (_m *AgentService_IMAMeasurementsClient[Res]) CloseSend() error {
|
||||
ret := _m.Called()
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for CloseSend")
|
||||
}
|
||||
|
||||
var r0 error
|
||||
if rf, ok := ret.Get(0).(func() error); ok {
|
||||
r0 = rf()
|
||||
} else {
|
||||
r0 = ret.Error(0)
|
||||
}
|
||||
|
||||
return r0
|
||||
}
|
||||
|
||||
// AgentService_IMAMeasurementsClient_CloseSend_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'CloseSend'
|
||||
type AgentService_IMAMeasurementsClient_CloseSend_Call[Res interface{}] struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// CloseSend is a helper method to define mock.On call
|
||||
func (_e *AgentService_IMAMeasurementsClient_Expecter[Res]) CloseSend() *AgentService_IMAMeasurementsClient_CloseSend_Call[Res] {
|
||||
return &AgentService_IMAMeasurementsClient_CloseSend_Call[Res]{Call: _e.mock.On("CloseSend")}
|
||||
}
|
||||
|
||||
func (_c *AgentService_IMAMeasurementsClient_CloseSend_Call[Res]) Run(run func()) *AgentService_IMAMeasurementsClient_CloseSend_Call[Res] {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
run()
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *AgentService_IMAMeasurementsClient_CloseSend_Call[Res]) Return(_a0 error) *AgentService_IMAMeasurementsClient_CloseSend_Call[Res] {
|
||||
_c.Call.Return(_a0)
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *AgentService_IMAMeasurementsClient_CloseSend_Call[Res]) RunAndReturn(run func() error) *AgentService_IMAMeasurementsClient_CloseSend_Call[Res] {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
// Context provides a mock function with no fields
|
||||
func (_m *AgentService_IMAMeasurementsClient[Res]) Context() context.Context {
|
||||
ret := _m.Called()
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for Context")
|
||||
}
|
||||
|
||||
var r0 context.Context
|
||||
if rf, ok := ret.Get(0).(func() context.Context); ok {
|
||||
r0 = rf()
|
||||
} else {
|
||||
if ret.Get(0) != nil {
|
||||
r0 = ret.Get(0).(context.Context)
|
||||
}
|
||||
}
|
||||
|
||||
return r0
|
||||
}
|
||||
|
||||
// AgentService_IMAMeasurementsClient_Context_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Context'
|
||||
type AgentService_IMAMeasurementsClient_Context_Call[Res interface{}] struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// Context is a helper method to define mock.On call
|
||||
func (_e *AgentService_IMAMeasurementsClient_Expecter[Res]) Context() *AgentService_IMAMeasurementsClient_Context_Call[Res] {
|
||||
return &AgentService_IMAMeasurementsClient_Context_Call[Res]{Call: _e.mock.On("Context")}
|
||||
}
|
||||
|
||||
func (_c *AgentService_IMAMeasurementsClient_Context_Call[Res]) Run(run func()) *AgentService_IMAMeasurementsClient_Context_Call[Res] {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
run()
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *AgentService_IMAMeasurementsClient_Context_Call[Res]) Return(_a0 context.Context) *AgentService_IMAMeasurementsClient_Context_Call[Res] {
|
||||
_c.Call.Return(_a0)
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *AgentService_IMAMeasurementsClient_Context_Call[Res]) RunAndReturn(run func() context.Context) *AgentService_IMAMeasurementsClient_Context_Call[Res] {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
// Header provides a mock function with no fields
|
||||
func (_m *AgentService_IMAMeasurementsClient[Res]) Header() (metadata.MD, error) {
|
||||
ret := _m.Called()
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for Header")
|
||||
}
|
||||
|
||||
var r0 metadata.MD
|
||||
var r1 error
|
||||
if rf, ok := ret.Get(0).(func() (metadata.MD, error)); ok {
|
||||
return rf()
|
||||
}
|
||||
if rf, ok := ret.Get(0).(func() metadata.MD); ok {
|
||||
r0 = rf()
|
||||
} else {
|
||||
if ret.Get(0) != nil {
|
||||
r0 = ret.Get(0).(metadata.MD)
|
||||
}
|
||||
}
|
||||
|
||||
if rf, ok := ret.Get(1).(func() error); ok {
|
||||
r1 = rf()
|
||||
} else {
|
||||
r1 = ret.Error(1)
|
||||
}
|
||||
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
// AgentService_IMAMeasurementsClient_Header_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Header'
|
||||
type AgentService_IMAMeasurementsClient_Header_Call[Res interface{}] struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// Header is a helper method to define mock.On call
|
||||
func (_e *AgentService_IMAMeasurementsClient_Expecter[Res]) Header() *AgentService_IMAMeasurementsClient_Header_Call[Res] {
|
||||
return &AgentService_IMAMeasurementsClient_Header_Call[Res]{Call: _e.mock.On("Header")}
|
||||
}
|
||||
|
||||
func (_c *AgentService_IMAMeasurementsClient_Header_Call[Res]) Run(run func()) *AgentService_IMAMeasurementsClient_Header_Call[Res] {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
run()
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *AgentService_IMAMeasurementsClient_Header_Call[Res]) Return(_a0 metadata.MD, _a1 error) *AgentService_IMAMeasurementsClient_Header_Call[Res] {
|
||||
_c.Call.Return(_a0, _a1)
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *AgentService_IMAMeasurementsClient_Header_Call[Res]) RunAndReturn(run func() (metadata.MD, error)) *AgentService_IMAMeasurementsClient_Header_Call[Res] {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
// Recv provides a mock function with no fields
|
||||
func (_m *AgentService_IMAMeasurementsClient[Res]) Recv() (*agent.IMAMeasurementsResponse, error) {
|
||||
ret := _m.Called()
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for Recv")
|
||||
}
|
||||
|
||||
var r0 *agent.IMAMeasurementsResponse
|
||||
var r1 error
|
||||
if rf, ok := ret.Get(0).(func() (*agent.IMAMeasurementsResponse, error)); ok {
|
||||
return rf()
|
||||
}
|
||||
if rf, ok := ret.Get(0).(func() *agent.IMAMeasurementsResponse); ok {
|
||||
r0 = rf()
|
||||
} else {
|
||||
if ret.Get(0) != nil {
|
||||
r0 = ret.Get(0).(*agent.IMAMeasurementsResponse)
|
||||
}
|
||||
}
|
||||
|
||||
if rf, ok := ret.Get(1).(func() error); ok {
|
||||
r1 = rf()
|
||||
} else {
|
||||
r1 = ret.Error(1)
|
||||
}
|
||||
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
// AgentService_IMAMeasurementsClient_Recv_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Recv'
|
||||
type AgentService_IMAMeasurementsClient_Recv_Call[Res interface{}] struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// Recv is a helper method to define mock.On call
|
||||
func (_e *AgentService_IMAMeasurementsClient_Expecter[Res]) Recv() *AgentService_IMAMeasurementsClient_Recv_Call[Res] {
|
||||
return &AgentService_IMAMeasurementsClient_Recv_Call[Res]{Call: _e.mock.On("Recv")}
|
||||
}
|
||||
|
||||
func (_c *AgentService_IMAMeasurementsClient_Recv_Call[Res]) Run(run func()) *AgentService_IMAMeasurementsClient_Recv_Call[Res] {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
run()
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *AgentService_IMAMeasurementsClient_Recv_Call[Res]) Return(_a0 *agent.IMAMeasurementsResponse, _a1 error) *AgentService_IMAMeasurementsClient_Recv_Call[Res] {
|
||||
_c.Call.Return(_a0, _a1)
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *AgentService_IMAMeasurementsClient_Recv_Call[Res]) RunAndReturn(run func() (*agent.IMAMeasurementsResponse, error)) *AgentService_IMAMeasurementsClient_Recv_Call[Res] {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
// RecvMsg provides a mock function with given fields: m
|
||||
func (_m *AgentService_IMAMeasurementsClient[Res]) RecvMsg(m interface{}) error {
|
||||
ret := _m.Called(m)
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for RecvMsg")
|
||||
}
|
||||
|
||||
var r0 error
|
||||
if rf, ok := ret.Get(0).(func(interface{}) error); ok {
|
||||
r0 = rf(m)
|
||||
} else {
|
||||
r0 = ret.Error(0)
|
||||
}
|
||||
|
||||
return r0
|
||||
}
|
||||
|
||||
// AgentService_IMAMeasurementsClient_RecvMsg_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'RecvMsg'
|
||||
type AgentService_IMAMeasurementsClient_RecvMsg_Call[Res interface{}] struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// RecvMsg is a helper method to define mock.On call
|
||||
// - m interface{}
|
||||
func (_e *AgentService_IMAMeasurementsClient_Expecter[Res]) RecvMsg(m interface{}) *AgentService_IMAMeasurementsClient_RecvMsg_Call[Res] {
|
||||
return &AgentService_IMAMeasurementsClient_RecvMsg_Call[Res]{Call: _e.mock.On("RecvMsg", m)}
|
||||
}
|
||||
|
||||
func (_c *AgentService_IMAMeasurementsClient_RecvMsg_Call[Res]) Run(run func(m interface{})) *AgentService_IMAMeasurementsClient_RecvMsg_Call[Res] {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
run(args[0].(interface{}))
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *AgentService_IMAMeasurementsClient_RecvMsg_Call[Res]) Return(_a0 error) *AgentService_IMAMeasurementsClient_RecvMsg_Call[Res] {
|
||||
_c.Call.Return(_a0)
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *AgentService_IMAMeasurementsClient_RecvMsg_Call[Res]) RunAndReturn(run func(interface{}) error) *AgentService_IMAMeasurementsClient_RecvMsg_Call[Res] {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
// SendMsg provides a mock function with given fields: m
|
||||
func (_m *AgentService_IMAMeasurementsClient[Res]) SendMsg(m interface{}) error {
|
||||
ret := _m.Called(m)
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for SendMsg")
|
||||
}
|
||||
|
||||
var r0 error
|
||||
if rf, ok := ret.Get(0).(func(interface{}) error); ok {
|
||||
r0 = rf(m)
|
||||
} else {
|
||||
r0 = ret.Error(0)
|
||||
}
|
||||
|
||||
return r0
|
||||
}
|
||||
|
||||
// AgentService_IMAMeasurementsClient_SendMsg_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'SendMsg'
|
||||
type AgentService_IMAMeasurementsClient_SendMsg_Call[Res interface{}] struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// SendMsg is a helper method to define mock.On call
|
||||
// - m interface{}
|
||||
func (_e *AgentService_IMAMeasurementsClient_Expecter[Res]) SendMsg(m interface{}) *AgentService_IMAMeasurementsClient_SendMsg_Call[Res] {
|
||||
return &AgentService_IMAMeasurementsClient_SendMsg_Call[Res]{Call: _e.mock.On("SendMsg", m)}
|
||||
}
|
||||
|
||||
func (_c *AgentService_IMAMeasurementsClient_SendMsg_Call[Res]) Run(run func(m interface{})) *AgentService_IMAMeasurementsClient_SendMsg_Call[Res] {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
run(args[0].(interface{}))
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *AgentService_IMAMeasurementsClient_SendMsg_Call[Res]) Return(_a0 error) *AgentService_IMAMeasurementsClient_SendMsg_Call[Res] {
|
||||
_c.Call.Return(_a0)
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *AgentService_IMAMeasurementsClient_SendMsg_Call[Res]) RunAndReturn(run func(interface{}) error) *AgentService_IMAMeasurementsClient_SendMsg_Call[Res] {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
// Trailer provides a mock function with no fields
|
||||
func (_m *AgentService_IMAMeasurementsClient[Res]) Trailer() metadata.MD {
|
||||
ret := _m.Called()
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for Trailer")
|
||||
}
|
||||
|
||||
var r0 metadata.MD
|
||||
if rf, ok := ret.Get(0).(func() metadata.MD); ok {
|
||||
r0 = rf()
|
||||
} else {
|
||||
if ret.Get(0) != nil {
|
||||
r0 = ret.Get(0).(metadata.MD)
|
||||
}
|
||||
}
|
||||
|
||||
return r0
|
||||
}
|
||||
|
||||
// AgentService_IMAMeasurementsClient_Trailer_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Trailer'
|
||||
type AgentService_IMAMeasurementsClient_Trailer_Call[Res interface{}] struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// Trailer is a helper method to define mock.On call
|
||||
func (_e *AgentService_IMAMeasurementsClient_Expecter[Res]) Trailer() *AgentService_IMAMeasurementsClient_Trailer_Call[Res] {
|
||||
return &AgentService_IMAMeasurementsClient_Trailer_Call[Res]{Call: _e.mock.On("Trailer")}
|
||||
}
|
||||
|
||||
func (_c *AgentService_IMAMeasurementsClient_Trailer_Call[Res]) Run(run func()) *AgentService_IMAMeasurementsClient_Trailer_Call[Res] {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
run()
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *AgentService_IMAMeasurementsClient_Trailer_Call[Res]) Return(_a0 metadata.MD) *AgentService_IMAMeasurementsClient_Trailer_Call[Res] {
|
||||
_c.Call.Return(_a0)
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *AgentService_IMAMeasurementsClient_Trailer_Call[Res]) RunAndReturn(run func() metadata.MD) *AgentService_IMAMeasurementsClient_Trailer_Call[Res] {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
// NewAgentService_IMAMeasurementsClient creates a new instance of AgentService_IMAMeasurementsClient. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations.
|
||||
// The first argument is typically a *testing.T value.
|
||||
func NewAgentService_IMAMeasurementsClient[Res interface{}](t interface {
|
||||
mock.TestingT
|
||||
Cleanup(func())
|
||||
}) *AgentService_IMAMeasurementsClient[Res] {
|
||||
mock := &AgentService_IMAMeasurementsClient[Res]{}
|
||||
mock.Mock.Test(t)
|
||||
|
||||
t.Cleanup(func() { mock.AssertExpectations(t) })
|
||||
|
||||
return mock
|
||||
}
|
||||
@@ -12,6 +12,7 @@ import (
|
||||
"path/filepath"
|
||||
"slices"
|
||||
sync "sync"
|
||||
"time"
|
||||
|
||||
"github.com/absmach/magistrala/pkg/errors"
|
||||
"github.com/ultravioletrs/cocos/agent/algorithm"
|
||||
@@ -183,8 +184,12 @@ func New(ctx context.Context, logger *slog.Logger, eventSvc events.Service, prov
|
||||
logger.Error(err.Error())
|
||||
}
|
||||
}()
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
sm.SendEvent(Start)
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
return svc
|
||||
}
|
||||
|
||||
@@ -257,8 +262,12 @@ func (as *agentService) StopComputation(ctx context.Context) error {
|
||||
as.logger.Error(err.Error())
|
||||
}
|
||||
}()
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
as.sm.SendEvent(Start)
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
+194
-1
@@ -5,8 +5,10 @@ package agent
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@@ -16,6 +18,7 @@ import (
|
||||
"github.com/stretchr/testify/mock"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/ultravioletrs/cocos/agent/algorithm"
|
||||
algomocks "github.com/ultravioletrs/cocos/agent/algorithm/mocks"
|
||||
"github.com/ultravioletrs/cocos/agent/algorithm/python"
|
||||
"github.com/ultravioletrs/cocos/agent/events/mocks"
|
||||
"github.com/ultravioletrs/cocos/agent/statemachine"
|
||||
@@ -389,14 +392,20 @@ func TestAttestation(t *testing.T) {
|
||||
defer cancel()
|
||||
|
||||
getQuote := provider.On("TeeAttestation", mock.Anything).Return(tc.rawQuote, tc.err)
|
||||
vtpmQuote := provider.On("VTpmAttestation", mock.Anything).Return(tc.rawQuote, tc.err)
|
||||
snpVtpm := provider.On("Attestation", mock.Anything, mock.Anything).Return(tc.rawQuote, tc.err)
|
||||
if tc.err != ErrAttestationFailed && tc.err != ErrAttestationVTpmFailed {
|
||||
getQuote = provider.On("TeeAttestation", mock.Anything).Return(tc.nonce, nil)
|
||||
vtpmQuote = provider.On("VTpmAttestation", mock.Anything).Return(tc.nonce[:], nil)
|
||||
snpVtpm = provider.On("Attestation", mock.Anything, mock.Anything).Return(tc.nonce[:], nil)
|
||||
}
|
||||
defer getQuote.Unset()
|
||||
defer vtpmQuote.Unset()
|
||||
defer snpVtpm.Unset()
|
||||
|
||||
svc := New(ctx, mglog.NewMock(), events, provider, 0)
|
||||
time.Sleep(300 * time.Millisecond)
|
||||
_, err := svc.Attestation(ctx, tc.reportData, tc.nonce, 0)
|
||||
_, err := svc.Attestation(ctx, tc.reportData, tc.nonce, tc.platform)
|
||||
assert.True(t, errors.Contains(err, tc.err), "expected %v, got %v", tc.err, err)
|
||||
})
|
||||
}
|
||||
@@ -483,3 +492,187 @@ func testComputation(t *testing.T) Computation {
|
||||
ResultConsumers: []ResultConsumer{{UserKey: []byte("key")}},
|
||||
}
|
||||
}
|
||||
|
||||
func TestStopComputation(t *testing.T) {
|
||||
testDataDir := "test_datasets"
|
||||
testResultsDir := "test_results"
|
||||
|
||||
cases := []struct {
|
||||
name string
|
||||
setupDirs bool
|
||||
setupAlgo bool
|
||||
algoStopErr error
|
||||
expectedErr error
|
||||
}{
|
||||
{
|
||||
name: "Stop computation successfully",
|
||||
setupDirs: true,
|
||||
setupAlgo: true,
|
||||
algoStopErr: nil,
|
||||
expectedErr: nil,
|
||||
},
|
||||
{
|
||||
name: "Stop computation with algorithm stop error",
|
||||
setupDirs: true,
|
||||
setupAlgo: true,
|
||||
algoStopErr: fmt.Errorf("algorithm stop failed"),
|
||||
expectedErr: fmt.Errorf("error stopping computation: algorithm stop failed"),
|
||||
},
|
||||
{
|
||||
name: "Stop computation without algorithm",
|
||||
setupDirs: true,
|
||||
setupAlgo: false,
|
||||
algoStopErr: nil,
|
||||
expectedErr: nil,
|
||||
},
|
||||
{
|
||||
name: "Stop computation with missing directories",
|
||||
setupDirs: false,
|
||||
setupAlgo: false,
|
||||
algoStopErr: nil,
|
||||
expectedErr: nil, // os.RemoveAll doesn't error on non-existing directories
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
events := new(mocks.Service)
|
||||
events.On("SendEvent", mock.Anything, "Stopped", "Stopped", mock.Anything).Return()
|
||||
|
||||
ctx := context.Background()
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
defer cancel()
|
||||
|
||||
svc := New(ctx, mglog.NewMock(), events, &attestation.EmptyProvider{}, 0).(*agentService)
|
||||
|
||||
svc.computation = Computation{
|
||||
ID: "test-computation",
|
||||
Name: "test",
|
||||
}
|
||||
|
||||
if tc.setupDirs {
|
||||
err := os.MkdirAll(testDataDir, 0o755)
|
||||
require.NoError(t, err)
|
||||
err = os.MkdirAll(testResultsDir, 0o755)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
if tc.setupAlgo {
|
||||
mockAlgo := new(algomocks.Algorithm)
|
||||
mockAlgo.On("Stop").Return(tc.algoStopErr)
|
||||
svc.algorithm = mockAlgo
|
||||
}
|
||||
|
||||
err := svc.StopComputation(ctx)
|
||||
|
||||
if tc.expectedErr != nil {
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), tc.expectedErr.Error())
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
assert.Equal(t, ReceivingManifest, svc.sm.GetState())
|
||||
assert.Nil(t, svc.result)
|
||||
assert.Nil(t, svc.runError)
|
||||
assert.False(t, svc.resultsConsumed)
|
||||
|
||||
events.AssertExpectations(t)
|
||||
|
||||
_ = os.RemoveAll(testDataDir)
|
||||
_ = os.RemoveAll(testResultsDir)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestStopComputationIntegration(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("Skipping integration test in short mode")
|
||||
}
|
||||
|
||||
algo := []byte("#!/bin/bash\necho 'test algorithm'")
|
||||
algoHash := sha3.Sum256(algo)
|
||||
|
||||
testDir := "test_integration"
|
||||
err := os.MkdirAll(testDir, 0o755)
|
||||
require.NoError(t, err)
|
||||
defer os.RemoveAll(testDir)
|
||||
|
||||
algoFile := filepath.Join(testDir, "test_algo")
|
||||
err = os.WriteFile(algoFile, algo, 0o755)
|
||||
require.NoError(t, err)
|
||||
|
||||
events := new(mocks.Service)
|
||||
events.On("SendEvent", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return()
|
||||
|
||||
ctx := metadata.NewIncomingContext(context.Background(),
|
||||
metadata.Pairs(algorithm.AlgoTypeKey, "bin"),
|
||||
)
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
defer cancel()
|
||||
|
||||
svc := New(ctx, mglog.NewMock(), events, &attestation.EmptyProvider{}, 0)
|
||||
|
||||
computation := Computation{
|
||||
ID: "integration-test",
|
||||
Name: "Integration Test",
|
||||
Algorithm: Algorithm{
|
||||
Hash: algoHash,
|
||||
Algorithm: algo,
|
||||
},
|
||||
}
|
||||
|
||||
err = svc.InitComputation(ctx, computation)
|
||||
require.NoError(t, err)
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
err = svc.Algo(ctx, Algorithm{
|
||||
Hash: algoHash,
|
||||
Algorithm: algo,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
err = svc.StopComputation(ctx)
|
||||
assert.NoError(t, err)
|
||||
|
||||
assert.Equal(t, "ReceivingManifest", svc.State())
|
||||
}
|
||||
|
||||
func TestStopComputationConcurrent(t *testing.T) {
|
||||
events := new(mocks.Service)
|
||||
events.On("SendEvent", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return()
|
||||
|
||||
ctx := context.Background()
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
defer cancel()
|
||||
|
||||
svc := New(ctx, mglog.NewMock(), events, &attestation.EmptyProvider{}, 0)
|
||||
|
||||
svc.(*agentService).computation = Computation{
|
||||
ID: "concurrent-test",
|
||||
Name: "Concurrent Test",
|
||||
}
|
||||
|
||||
const numGoroutines = 10
|
||||
errChan := make(chan error, numGoroutines)
|
||||
|
||||
for i := 0; i < numGoroutines; i++ {
|
||||
go func() {
|
||||
err := svc.StopComputation(ctx)
|
||||
errChan <- err
|
||||
}()
|
||||
}
|
||||
|
||||
var errors []error
|
||||
for i := 0; i < numGoroutines; i++ {
|
||||
err := <-errChan
|
||||
if err != nil {
|
||||
errors = append(errors, err)
|
||||
}
|
||||
}
|
||||
|
||||
assert.True(t, len(errors) < numGoroutines, "All StopComputation calls failed")
|
||||
}
|
||||
|
||||
+9
-3
@@ -54,10 +54,11 @@ func TestAddTransition(t *testing.T) {
|
||||
|
||||
go func() {
|
||||
if err := sm.Start(ctx); err != context.Canceled {
|
||||
t.Errorf("Start returned error: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
sm.SendEvent(Event1)
|
||||
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
@@ -79,7 +80,7 @@ func TestSetAction(t *testing.T) {
|
||||
|
||||
sm.AddTransition(statemachine.Transition{From: State1, Event: Event1, To: State2})
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 150*time.Millisecond)
|
||||
defer cancel()
|
||||
|
||||
go func() {
|
||||
@@ -88,8 +89,12 @@ func TestSetAction(t *testing.T) {
|
||||
}
|
||||
}()
|
||||
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
sm.SendEvent(Event1)
|
||||
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
wg.Wait()
|
||||
|
||||
if ctx.Err() != nil {
|
||||
@@ -132,10 +137,11 @@ func TestMultipleTransitions(t *testing.T) {
|
||||
|
||||
go func() {
|
||||
if err := sm.Start(ctx); err != context.Canceled {
|
||||
t.Errorf("Start returned error: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
transitions := []struct {
|
||||
event MockEvent
|
||||
want MockState
|
||||
|
||||
@@ -39,6 +39,7 @@ type stateMachine struct {
|
||||
transitions map[State]map[Event]State
|
||||
actions map[State]Action
|
||||
eventChan chan Event
|
||||
resetChan chan struct{}
|
||||
}
|
||||
|
||||
func NewStateMachine(initialState State) StateMachine {
|
||||
@@ -47,6 +48,7 @@ func NewStateMachine(initialState State) StateMachine {
|
||||
transitions: make(map[State]map[Event]State),
|
||||
actions: make(map[State]Action),
|
||||
eventChan: make(chan Event),
|
||||
resetChan: make(chan struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -74,16 +76,31 @@ func (sm *stateMachine) GetState() State {
|
||||
}
|
||||
|
||||
func (sm *stateMachine) SendEvent(event Event) {
|
||||
sm.eventChan <- event
|
||||
sm.mu.Lock()
|
||||
eventChan := sm.eventChan
|
||||
sm.mu.Unlock()
|
||||
|
||||
select {
|
||||
case eventChan <- event:
|
||||
default:
|
||||
// Channel might be closed or full, ignore the event
|
||||
}
|
||||
}
|
||||
|
||||
func (sm *stateMachine) Start(ctx context.Context) error {
|
||||
for {
|
||||
sm.mu.Lock()
|
||||
eventChan := sm.eventChan
|
||||
resetChan := sm.resetChan
|
||||
sm.mu.Unlock()
|
||||
|
||||
select {
|
||||
case event := <-sm.eventChan:
|
||||
case event := <-eventChan:
|
||||
if err := sm.handleEvent(event); err != nil {
|
||||
return err
|
||||
}
|
||||
case <-resetChan:
|
||||
continue
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
}
|
||||
@@ -100,8 +117,11 @@ func (sm *stateMachine) Reset(initialState State) {
|
||||
// Close the existing event channel to stop processing events
|
||||
close(sm.eventChan)
|
||||
|
||||
// Create a new event channel
|
||||
// Close the reset channel to signal Start() to restart
|
||||
close(sm.resetChan)
|
||||
|
||||
sm.eventChan = make(chan Event)
|
||||
sm.resetChan = make(chan struct{})
|
||||
}
|
||||
|
||||
func (sm *stateMachine) handleEvent(event Event) error {
|
||||
|
||||
@@ -0,0 +1,607 @@
|
||||
// Copyright (c) Ultraviolet
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
package statemachine
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
type testState string
|
||||
|
||||
func (s testState) String() string {
|
||||
return string(s)
|
||||
}
|
||||
|
||||
type testEvent string
|
||||
|
||||
func (e testEvent) String() string {
|
||||
return string(e)
|
||||
}
|
||||
|
||||
const (
|
||||
StateIdle testState = "idle"
|
||||
StateRunning testState = "running"
|
||||
StatePaused testState = "paused"
|
||||
StateStopped testState = "stopped"
|
||||
StateError testState = "error"
|
||||
)
|
||||
|
||||
const (
|
||||
EventStart testEvent = "start"
|
||||
EventPause testEvent = "pause"
|
||||
EventStop testEvent = "stop"
|
||||
EventReset testEvent = "reset"
|
||||
EventError testEvent = "error"
|
||||
)
|
||||
|
||||
func TestNewStateMachine(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
initialState State
|
||||
want State
|
||||
}{
|
||||
{
|
||||
name: "create with idle state",
|
||||
initialState: StateIdle,
|
||||
want: StateIdle,
|
||||
},
|
||||
{
|
||||
name: "create with running state",
|
||||
initialState: StateRunning,
|
||||
want: StateRunning,
|
||||
},
|
||||
{
|
||||
name: "create with custom state",
|
||||
initialState: testState("custom"),
|
||||
want: testState("custom"),
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
sm := NewStateMachine(tt.initialState)
|
||||
if got := sm.GetState(); got != tt.want {
|
||||
t.Errorf("NewStateMachine() initial state = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestStateMachine_AddTransition(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
transitions []Transition
|
||||
from State
|
||||
event Event
|
||||
expectTo State
|
||||
expectValid bool
|
||||
}{
|
||||
{
|
||||
name: "single transition",
|
||||
transitions: []Transition{
|
||||
{From: StateIdle, Event: EventStart, To: StateRunning},
|
||||
},
|
||||
from: StateIdle,
|
||||
event: EventStart,
|
||||
expectTo: StateRunning,
|
||||
expectValid: true,
|
||||
},
|
||||
{
|
||||
name: "multiple transitions from same state",
|
||||
transitions: []Transition{
|
||||
{From: StateIdle, Event: EventStart, To: StateRunning},
|
||||
{From: StateIdle, Event: EventError, To: StateError},
|
||||
},
|
||||
from: StateIdle,
|
||||
event: EventError,
|
||||
expectTo: StateError,
|
||||
expectValid: true,
|
||||
},
|
||||
{
|
||||
name: "overwrite existing transition",
|
||||
transitions: []Transition{
|
||||
{From: StateIdle, Event: EventStart, To: StateRunning},
|
||||
{From: StateIdle, Event: EventStart, To: StatePaused}, // Overwrite
|
||||
},
|
||||
from: StateIdle,
|
||||
event: EventStart,
|
||||
expectTo: StatePaused,
|
||||
expectValid: true,
|
||||
},
|
||||
{
|
||||
name: "transition not found",
|
||||
transitions: []Transition{
|
||||
{From: StateIdle, Event: EventStart, To: StateRunning},
|
||||
},
|
||||
from: StateRunning,
|
||||
event: EventPause,
|
||||
expectValid: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
sm := NewStateMachine(StateIdle).(*stateMachine)
|
||||
|
||||
for _, transition := range tt.transitions {
|
||||
sm.AddTransition(transition)
|
||||
}
|
||||
|
||||
sm.mu.Lock()
|
||||
nextState, valid := sm.transitions[tt.from][tt.event]
|
||||
sm.mu.Unlock()
|
||||
|
||||
if valid != tt.expectValid {
|
||||
t.Errorf("Transition validity = %v, want %v", valid, tt.expectValid)
|
||||
}
|
||||
|
||||
if tt.expectValid && nextState != tt.expectTo {
|
||||
t.Errorf("Transition destination = %v, want %v", nextState, tt.expectTo)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestStateMachine_SetAction(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
state State
|
||||
action Action
|
||||
expectAction bool
|
||||
}{
|
||||
{
|
||||
name: "set action for state",
|
||||
state: StateRunning,
|
||||
action: func(s State) {
|
||||
},
|
||||
expectAction: true,
|
||||
},
|
||||
{
|
||||
name: "set nil action",
|
||||
state: StatePaused,
|
||||
action: nil,
|
||||
expectAction: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
sm := NewStateMachine(StateIdle).(*stateMachine)
|
||||
sm.SetAction(tt.state, tt.action)
|
||||
|
||||
sm.mu.Lock()
|
||||
action := sm.actions[tt.state]
|
||||
sm.mu.Unlock()
|
||||
|
||||
if tt.expectAction && action == nil {
|
||||
t.Error("Expected action to be set, but it was nil")
|
||||
}
|
||||
if !tt.expectAction && action != nil {
|
||||
t.Error("Expected action to be nil, but it was set")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestStateMachine_GetState(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
initialState State
|
||||
transitions []Transition
|
||||
events []Event
|
||||
finalState State
|
||||
}{
|
||||
{
|
||||
name: "get initial state",
|
||||
initialState: StateIdle,
|
||||
finalState: StateIdle,
|
||||
},
|
||||
{
|
||||
name: "get state after transition",
|
||||
initialState: StateIdle,
|
||||
transitions: []Transition{
|
||||
{From: StateIdle, Event: EventStart, To: StateRunning},
|
||||
},
|
||||
events: []Event{EventStart},
|
||||
finalState: StateRunning,
|
||||
},
|
||||
{
|
||||
name: "get state after multiple transitions",
|
||||
initialState: StateIdle,
|
||||
transitions: []Transition{
|
||||
{From: StateIdle, Event: EventStart, To: StateRunning},
|
||||
{From: StateRunning, Event: EventPause, To: StatePaused},
|
||||
{From: StatePaused, Event: EventStart, To: StateRunning},
|
||||
},
|
||||
events: []Event{EventStart, EventPause, EventStart},
|
||||
finalState: StateRunning,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
sm := NewStateMachine(tt.initialState)
|
||||
|
||||
for _, transition := range tt.transitions {
|
||||
sm.AddTransition(transition)
|
||||
}
|
||||
|
||||
smImpl := sm.(*stateMachine)
|
||||
for _, event := range tt.events {
|
||||
if err := smImpl.handleEvent(event); err != nil {
|
||||
t.Fatalf("Failed to handle event %v: %v", event, err)
|
||||
}
|
||||
}
|
||||
|
||||
if got := sm.GetState(); got != tt.finalState {
|
||||
t.Errorf("GetState() = %v, want %v", got, tt.finalState)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestStateMachine_Start(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
initialState State
|
||||
transitions []Transition
|
||||
events []Event
|
||||
cancelAfter time.Duration
|
||||
expectError bool
|
||||
expectedStates []State
|
||||
}{
|
||||
{
|
||||
name: "start and cancel immediately",
|
||||
initialState: StateIdle,
|
||||
cancelAfter: 10 * time.Millisecond,
|
||||
expectError: true, // context.Canceled
|
||||
},
|
||||
{
|
||||
name: "process events then cancel",
|
||||
initialState: StateIdle,
|
||||
transitions: []Transition{
|
||||
{From: StateIdle, Event: EventStart, To: StateRunning},
|
||||
{From: StateRunning, Event: EventStop, To: StateStopped},
|
||||
},
|
||||
events: []Event{EventStart, EventStop},
|
||||
cancelAfter: 100 * time.Millisecond,
|
||||
expectError: true, // context.Canceled
|
||||
expectedStates: []State{StateRunning, StateStopped},
|
||||
},
|
||||
{
|
||||
name: "invalid transition error",
|
||||
initialState: StateIdle,
|
||||
transitions: []Transition{
|
||||
{From: StateIdle, Event: EventStart, To: StateRunning},
|
||||
},
|
||||
events: []Event{EventPause}, // Invalid from StateIdle
|
||||
cancelAfter: 50 * time.Millisecond,
|
||||
expectError: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
sm := NewStateMachine(tt.initialState)
|
||||
|
||||
for _, transition := range tt.transitions {
|
||||
sm.AddTransition(transition)
|
||||
}
|
||||
|
||||
var states []State
|
||||
var mu sync.Mutex
|
||||
|
||||
for _, state := range tt.expectedStates {
|
||||
sm.SetAction(state, func(s State) {
|
||||
mu.Lock()
|
||||
states = append(states, s)
|
||||
mu.Unlock()
|
||||
})
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
errChan := make(chan error, 1)
|
||||
go func() {
|
||||
errChan <- sm.Start(ctx)
|
||||
}()
|
||||
|
||||
time.Sleep(5 * time.Millisecond)
|
||||
|
||||
for _, event := range tt.events {
|
||||
sm.SendEvent(event)
|
||||
time.Sleep(5 * time.Millisecond)
|
||||
}
|
||||
|
||||
time.Sleep(tt.cancelAfter)
|
||||
cancel()
|
||||
|
||||
err := <-errChan
|
||||
|
||||
if tt.expectError && err == nil {
|
||||
t.Error("Expected error but got none")
|
||||
}
|
||||
if !tt.expectError && err != nil {
|
||||
t.Errorf("Unexpected error: %v", err)
|
||||
}
|
||||
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
|
||||
mu.Lock()
|
||||
if len(states) != len(tt.expectedStates) {
|
||||
t.Errorf("Expected %d state changes, got %d", len(tt.expectedStates), len(states))
|
||||
}
|
||||
for i, expectedState := range tt.expectedStates {
|
||||
if i < len(states) && states[i] != expectedState {
|
||||
t.Errorf("State change %d = %v, want %v", i, states[i], expectedState)
|
||||
}
|
||||
}
|
||||
mu.Unlock()
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestStateMachine_Reset(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
initialState State
|
||||
resetState State
|
||||
setupTransitions []Transition
|
||||
eventsBeforeReset []Event
|
||||
eventsAfterReset []Event
|
||||
expectedState State
|
||||
}{
|
||||
{
|
||||
name: "reset to same state",
|
||||
initialState: StateIdle,
|
||||
resetState: StateIdle,
|
||||
expectedState: StateIdle,
|
||||
},
|
||||
{
|
||||
name: "reset to different state",
|
||||
initialState: StateIdle,
|
||||
resetState: StateRunning,
|
||||
expectedState: StateRunning,
|
||||
},
|
||||
{
|
||||
name: "reset after state changes",
|
||||
initialState: StateIdle,
|
||||
resetState: StateIdle,
|
||||
setupTransitions: []Transition{
|
||||
{From: StateIdle, Event: EventStart, To: StateRunning},
|
||||
},
|
||||
eventsBeforeReset: []Event{EventStart},
|
||||
expectedState: StateIdle,
|
||||
},
|
||||
{
|
||||
name: "reset and send new events",
|
||||
initialState: StateIdle,
|
||||
resetState: StateIdle,
|
||||
setupTransitions: []Transition{
|
||||
{From: StateIdle, Event: EventStart, To: StateRunning},
|
||||
{From: StateRunning, Event: EventStop, To: StateStopped},
|
||||
},
|
||||
eventsBeforeReset: []Event{EventStart},
|
||||
eventsAfterReset: []Event{EventStart},
|
||||
expectedState: StateIdle,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
sm := NewStateMachine(tt.initialState)
|
||||
smImpl := sm.(*stateMachine)
|
||||
|
||||
for _, transition := range tt.setupTransitions {
|
||||
sm.AddTransition(transition)
|
||||
}
|
||||
|
||||
for _, event := range tt.eventsBeforeReset {
|
||||
if err := smImpl.handleEvent(event); err != nil {
|
||||
// Ignore errors for this test
|
||||
}
|
||||
}
|
||||
|
||||
sm.Reset(tt.resetState)
|
||||
|
||||
if got := sm.GetState(); got != tt.expectedState {
|
||||
t.Errorf("State after reset = %v, want %v", got, tt.expectedState)
|
||||
}
|
||||
|
||||
for _, event := range tt.eventsAfterReset {
|
||||
sm.SendEvent(event)
|
||||
}
|
||||
|
||||
// For events after reset, we can't easily check the channel length
|
||||
// due to the synchronization changes, so we just verify the reset worked
|
||||
if len(tt.eventsAfterReset) > 0 {
|
||||
time.Sleep(5 * time.Millisecond)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestStateMachine_Reset_WithRunningStateMachine(t *testing.T) {
|
||||
sm := NewStateMachine(StateIdle)
|
||||
sm.AddTransition(Transition{From: StateIdle, Event: EventStart, To: StateRunning})
|
||||
sm.AddTransition(Transition{From: StateRunning, Event: EventStop, To: StateStopped})
|
||||
|
||||
var stateChanges []State
|
||||
var mu sync.Mutex
|
||||
|
||||
sm.SetAction(StateRunning, func(s State) {
|
||||
mu.Lock()
|
||||
stateChanges = append(stateChanges, s)
|
||||
mu.Unlock()
|
||||
})
|
||||
|
||||
sm.SetAction(StateStopped, func(s State) {
|
||||
mu.Lock()
|
||||
stateChanges = append(stateChanges, s)
|
||||
mu.Unlock()
|
||||
})
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
go func() {
|
||||
if err := sm.Start(ctx); err != nil {
|
||||
}
|
||||
}()
|
||||
|
||||
// Give it time to start
|
||||
time.Sleep(5 * time.Millisecond)
|
||||
|
||||
// Send an event
|
||||
sm.SendEvent(EventStart)
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
|
||||
// Reset while running
|
||||
sm.Reset(StateIdle)
|
||||
|
||||
// Verify state was reset
|
||||
if got := sm.GetState(); got != StateIdle {
|
||||
t.Errorf("State after reset = %v, want %v", got, StateIdle)
|
||||
}
|
||||
|
||||
// Send another event after reset
|
||||
sm.SendEvent(EventStart)
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
|
||||
mu.Lock()
|
||||
changes := len(stateChanges)
|
||||
mu.Unlock()
|
||||
|
||||
// Should have at least processed the first event
|
||||
if changes < 1 {
|
||||
t.Errorf("Expected at least 1 state change, got %d", changes)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStateMachine_HandleEvent(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
initialState State
|
||||
transitions []Transition
|
||||
event Event
|
||||
expectedState State
|
||||
expectError bool
|
||||
expectActionCall bool
|
||||
}{
|
||||
{
|
||||
name: "valid transition",
|
||||
initialState: StateIdle,
|
||||
transitions: []Transition{
|
||||
{From: StateIdle, Event: EventStart, To: StateRunning},
|
||||
},
|
||||
event: EventStart,
|
||||
expectedState: StateRunning,
|
||||
expectError: false,
|
||||
expectActionCall: true,
|
||||
},
|
||||
{
|
||||
name: "invalid transition",
|
||||
initialState: StateIdle,
|
||||
transitions: []Transition{
|
||||
{From: StateRunning, Event: EventPause, To: StatePaused},
|
||||
},
|
||||
event: EventStart,
|
||||
expectedState: StateIdle,
|
||||
expectError: true,
|
||||
expectActionCall: false,
|
||||
},
|
||||
{
|
||||
name: "transition with no action",
|
||||
initialState: StateIdle,
|
||||
transitions: []Transition{
|
||||
{From: StateIdle, Event: EventStart, To: StateRunning},
|
||||
},
|
||||
event: EventStart,
|
||||
expectedState: StateRunning,
|
||||
expectError: false,
|
||||
expectActionCall: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
sm := NewStateMachine(tt.initialState).(*stateMachine)
|
||||
|
||||
for _, transition := range tt.transitions {
|
||||
sm.AddTransition(transition)
|
||||
}
|
||||
|
||||
var actionCalled bool
|
||||
var mu sync.Mutex
|
||||
|
||||
if tt.expectActionCall {
|
||||
sm.SetAction(tt.expectedState, func(s State) {
|
||||
mu.Lock()
|
||||
actionCalled = true
|
||||
mu.Unlock()
|
||||
})
|
||||
}
|
||||
|
||||
err := sm.handleEvent(tt.event)
|
||||
|
||||
if tt.expectError && err == nil {
|
||||
t.Error("Expected error but got none")
|
||||
}
|
||||
if !tt.expectError && err != nil {
|
||||
t.Errorf("Unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if sm.GetState() != tt.expectedState {
|
||||
t.Errorf("State after handleEvent = %v, want %v", sm.GetState(), tt.expectedState)
|
||||
}
|
||||
|
||||
if tt.expectActionCall {
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
mu.Lock()
|
||||
called := actionCalled
|
||||
mu.Unlock()
|
||||
if !called {
|
||||
t.Error("Expected action to be called but it wasn't")
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestStateMachine_SendEvent_ThreadSafety(t *testing.T) {
|
||||
sm := NewStateMachine(StateIdle)
|
||||
sm.AddTransition(Transition{From: StateIdle, Event: EventStart, To: StateRunning})
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
go func() {
|
||||
if err := sm.Start(ctx); err != nil {
|
||||
}
|
||||
}()
|
||||
|
||||
time.Sleep(5 * time.Millisecond)
|
||||
|
||||
var wg sync.WaitGroup
|
||||
numGoroutines := 10
|
||||
eventsPerGoroutine := 100
|
||||
|
||||
// Send events concurrently
|
||||
for i := 0; i < numGoroutines; i++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
for j := 0; j < eventsPerGoroutine; j++ {
|
||||
sm.SendEvent(EventStart)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
|
||||
// If we reach here without panicking, the test passes
|
||||
}
|
||||
@@ -3,6 +3,7 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"os"
|
||||
@@ -131,3 +132,247 @@ func TestNewAddHostDataCmd(t *testing.T) {
|
||||
assert.Equal(t, "hostdata <host-data> <attestation_policy.json>", cmd.Example)
|
||||
assert.NotNil(t, cmd.Run)
|
||||
}
|
||||
|
||||
func TestChangeAttestationConfigurationFileErrors(t *testing.T) {
|
||||
t.Run("File Not Found", func(t *testing.T) {
|
||||
err := changeAttestationConfiguration("nonexistent.json", base64.StdEncoding.EncodeToString(make([]byte, measurementLength)), measurementLength, measurementField)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "error while reading the attestation policy file")
|
||||
})
|
||||
|
||||
t.Run("Invalid JSON Content", func(t *testing.T) {
|
||||
tmpfile, err := os.CreateTemp("", "invalid.json")
|
||||
require.NoError(t, err)
|
||||
defer os.Remove(tmpfile.Name())
|
||||
|
||||
err = os.WriteFile(tmpfile.Name(), []byte("invalid json"), 0o644)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = changeAttestationConfiguration(tmpfile.Name(), base64.StdEncoding.EncodeToString(make([]byte, measurementLength)), measurementLength, measurementField)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "failed to unmarshal json")
|
||||
})
|
||||
}
|
||||
|
||||
func TestNewGCPAttestationPolicy(t *testing.T) {
|
||||
cli := &CLI{}
|
||||
cmd := cli.NewGCPAttestationPolicy()
|
||||
|
||||
assert.Equal(t, "gcp", cmd.Use)
|
||||
assert.Equal(t, "Get attestation policy for GCP CVM", cmd.Short)
|
||||
assert.Equal(t, "gcp <bin_vtmp_attestation_report_file> <vcpu_count>", cmd.Example)
|
||||
assert.NotNil(t, cmd.Run)
|
||||
|
||||
t.Run("File Not Found", func(t *testing.T) {
|
||||
cmd.SetArgs([]string{"nonexistent.bin", "4"})
|
||||
|
||||
var buf bytes.Buffer
|
||||
cmd.SetOut(&buf)
|
||||
cmd.SetErr(&buf)
|
||||
|
||||
err := cmd.Execute()
|
||||
assert.NoError(t, err)
|
||||
|
||||
output := buf.String()
|
||||
assert.Contains(t, output, "Error reading attestation report file")
|
||||
assert.Contains(t, output, "❌")
|
||||
})
|
||||
|
||||
t.Run("Invalid vCPU Count", func(t *testing.T) {
|
||||
tmpfile, err := os.CreateTemp("", "attestation.bin")
|
||||
require.NoError(t, err)
|
||||
defer os.Remove(tmpfile.Name())
|
||||
|
||||
err = os.WriteFile(tmpfile.Name(), []byte("dummy content"), 0o644)
|
||||
require.NoError(t, err)
|
||||
|
||||
cmd.SetArgs([]string{tmpfile.Name(), "invalid"})
|
||||
|
||||
var buf bytes.Buffer
|
||||
cmd.SetOut(&buf)
|
||||
cmd.SetErr(&buf)
|
||||
|
||||
err = cmd.Execute()
|
||||
assert.NoError(t, err)
|
||||
|
||||
output := buf.String()
|
||||
assert.Contains(t, output, "Error converting vCPU count to integer")
|
||||
assert.Contains(t, output, "❌")
|
||||
})
|
||||
|
||||
t.Run("Invalid Attestation Data", func(t *testing.T) {
|
||||
tmpfile, err := os.CreateTemp("", "attestation.bin")
|
||||
require.NoError(t, err)
|
||||
defer os.Remove(tmpfile.Name())
|
||||
|
||||
err = os.WriteFile(tmpfile.Name(), []byte("invalid protobuf data"), 0o644)
|
||||
require.NoError(t, err)
|
||||
|
||||
cmd.SetArgs([]string{tmpfile.Name(), "4"})
|
||||
|
||||
var buf bytes.Buffer
|
||||
cmd.SetOut(&buf)
|
||||
cmd.SetErr(&buf)
|
||||
|
||||
err = cmd.Execute()
|
||||
assert.NoError(t, err)
|
||||
|
||||
output := buf.String()
|
||||
assert.Contains(t, output, "Error unmarshaling attestation report")
|
||||
assert.Contains(t, output, "❌")
|
||||
})
|
||||
}
|
||||
|
||||
func TestNewDownloadGCPOvmfFile(t *testing.T) {
|
||||
cli := &CLI{}
|
||||
cmd := cli.NewDownloadGCPOvmfFile()
|
||||
|
||||
assert.Equal(t, "download", cmd.Use)
|
||||
assert.Equal(t, "Download GCP OVMF file", cmd.Short)
|
||||
assert.Equal(t, "download <bin_vtmp_attestation_report_file>", cmd.Example)
|
||||
assert.NotNil(t, cmd.Run)
|
||||
|
||||
t.Run("File Not Found", func(t *testing.T) {
|
||||
cmd.SetArgs([]string{"nonexistent.bin"})
|
||||
|
||||
var buf bytes.Buffer
|
||||
cmd.SetOut(&buf)
|
||||
cmd.SetErr(&buf)
|
||||
|
||||
err := cmd.Execute()
|
||||
assert.NoError(t, err)
|
||||
|
||||
output := buf.String()
|
||||
assert.Contains(t, output, "Error reading attestation report file")
|
||||
assert.Contains(t, output, "❌")
|
||||
})
|
||||
|
||||
t.Run("Invalid Attestation Data", func(t *testing.T) {
|
||||
tmpfile, err := os.CreateTemp("", "attestation.bin")
|
||||
require.NoError(t, err)
|
||||
defer os.Remove(tmpfile.Name())
|
||||
|
||||
err = os.WriteFile(tmpfile.Name(), []byte("invalid protobuf data"), 0o644)
|
||||
require.NoError(t, err)
|
||||
|
||||
cmd.SetArgs([]string{tmpfile.Name()})
|
||||
|
||||
var buf bytes.Buffer
|
||||
cmd.SetOut(&buf)
|
||||
cmd.SetErr(&buf)
|
||||
|
||||
err = cmd.Execute()
|
||||
assert.NoError(t, err)
|
||||
|
||||
output := buf.String()
|
||||
assert.Contains(t, output, "Error unmarshaling attestation report")
|
||||
assert.Contains(t, output, "❌")
|
||||
})
|
||||
}
|
||||
|
||||
func TestNewAzureAttestationPolicy(t *testing.T) {
|
||||
cli := &CLI{}
|
||||
cmd := cli.NewAzureAttestationPolicy()
|
||||
|
||||
assert.Equal(t, "azure", cmd.Use)
|
||||
assert.Equal(t, "Get attestation policy for Azure CVM", cmd.Short)
|
||||
assert.Equal(t, "azure <azure_maa_token_file> <product_name>", cmd.Example)
|
||||
assert.NotNil(t, cmd.Run)
|
||||
|
||||
flag := cmd.Flags().Lookup("policy")
|
||||
assert.NotNil(t, flag)
|
||||
assert.Equal(t, "Policy of the guest CVM", flag.Usage)
|
||||
|
||||
t.Run("File Not Found", func(t *testing.T) {
|
||||
cmd.SetArgs([]string{"nonexistent.token", "test-product"})
|
||||
|
||||
var buf bytes.Buffer
|
||||
cmd.SetOut(&buf)
|
||||
cmd.SetErr(&buf)
|
||||
|
||||
err := cmd.Execute()
|
||||
assert.NoError(t, err)
|
||||
|
||||
output := buf.String()
|
||||
assert.Contains(t, output, "Error reading attestation report file")
|
||||
assert.Contains(t, output, "❌")
|
||||
})
|
||||
|
||||
t.Run("Valid Token File", func(t *testing.T) {
|
||||
tmpfile, err := os.CreateTemp("", "token.maa")
|
||||
require.NoError(t, err)
|
||||
defer os.Remove(tmpfile.Name())
|
||||
|
||||
err = os.WriteFile(tmpfile.Name(), []byte("dummy.token.content"), 0o644)
|
||||
require.NoError(t, err)
|
||||
|
||||
defer os.Remove("attestation_policy.json")
|
||||
|
||||
cmd.SetArgs([]string{tmpfile.Name(), "test-product"})
|
||||
|
||||
var buf bytes.Buffer
|
||||
cmd.SetOut(&buf)
|
||||
cmd.SetErr(&buf)
|
||||
|
||||
err = cmd.Execute()
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("Custom Policy Flag", func(t *testing.T) {
|
||||
tmpfile, err := os.CreateTemp("", "token.maa")
|
||||
require.NoError(t, err)
|
||||
defer os.Remove(tmpfile.Name())
|
||||
|
||||
err = os.WriteFile(tmpfile.Name(), []byte("dummy.token.content"), 0o644)
|
||||
require.NoError(t, err)
|
||||
|
||||
cmd.SetArgs([]string{"--policy", "123456", tmpfile.Name(), "test-product"})
|
||||
|
||||
var buf bytes.Buffer
|
||||
cmd.SetOut(&buf)
|
||||
cmd.SetErr(&buf)
|
||||
|
||||
err = cmd.Execute()
|
||||
assert.NoError(t, err)
|
||||
|
||||
flag := cmd.Flags().Lookup("policy")
|
||||
assert.NotNil(t, flag)
|
||||
assert.Equal(t, "123456", flag.Value.String())
|
||||
})
|
||||
}
|
||||
|
||||
func TestCommandErrorHandling(t *testing.T) {
|
||||
cli := &CLI{}
|
||||
|
||||
t.Run("Measurement Command Error", func(t *testing.T) {
|
||||
cmd := cli.NewAddMeasurementCmd()
|
||||
cmd.SetArgs([]string{"invalid-base64", "nonexistent.json"})
|
||||
|
||||
var buf bytes.Buffer
|
||||
cmd.SetOut(&buf)
|
||||
cmd.SetErr(&buf)
|
||||
|
||||
err := cmd.Execute()
|
||||
assert.NoError(t, err)
|
||||
|
||||
output := buf.String()
|
||||
assert.Contains(t, output, "Error could not change measurement data")
|
||||
assert.Contains(t, output, "❌")
|
||||
})
|
||||
|
||||
t.Run("Host Data Command Error", func(t *testing.T) {
|
||||
cmd := cli.NewAddHostDataCmd()
|
||||
cmd.SetArgs([]string{"invalid-base64", "nonexistent.json"})
|
||||
|
||||
var buf bytes.Buffer
|
||||
cmd.SetOut(&buf)
|
||||
cmd.SetErr(&buf)
|
||||
|
||||
err := cmd.Execute()
|
||||
assert.NoError(t, err)
|
||||
|
||||
output := buf.String()
|
||||
assert.Contains(t, output, "Error could not change host data")
|
||||
assert.Contains(t, output, "❌")
|
||||
})
|
||||
}
|
||||
|
||||
@@ -331,6 +331,7 @@ func parseUints() error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
cfg.Policy.Product.MachineStepping = wrapperspb.UInt32(uint32(num))
|
||||
} else {
|
||||
num, err := strconv.ParseUint(stepping[2:], base, 8)
|
||||
|
||||
@@ -0,0 +1,870 @@
|
||||
// Copyright (c) Ultraviolet
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
package cli
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/google/go-sev-guest/abi"
|
||||
"github.com/google/go-sev-guest/proto/check"
|
||||
"github.com/google/go-sev-guest/proto/sevsnp"
|
||||
tpmAttest "github.com/google/go-tpm-tools/proto/attest"
|
||||
"github.com/google/go-tpm-tools/proto/tpm"
|
||||
"github.com/spf13/cobra"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/mock"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/ultravioletrs/cocos/pkg/attestation/mocks"
|
||||
"google.golang.org/protobuf/encoding/prototext"
|
||||
"google.golang.org/protobuf/proto"
|
||||
)
|
||||
|
||||
func TestAddSEVSNPVerificationOptions(t *testing.T) {
|
||||
cmd := &cobra.Command{
|
||||
Use: "test",
|
||||
}
|
||||
|
||||
result := addSEVSNPVerificationOptions(cmd)
|
||||
|
||||
assert.Equal(t, cmd, result)
|
||||
|
||||
// Check that important flags are added
|
||||
flags := []string{
|
||||
"host_data",
|
||||
"family_id",
|
||||
"image_id",
|
||||
"report_id",
|
||||
"report_id_ma",
|
||||
"measurement",
|
||||
"chip_id",
|
||||
"minimum_tcb",
|
||||
"minimum_lauch_tcb",
|
||||
"guest_policy",
|
||||
"minimum_guest_svn",
|
||||
"minimum_build",
|
||||
"check_crl",
|
||||
"timeout",
|
||||
"max_retry_delay",
|
||||
"require_author_key",
|
||||
"require_id_block",
|
||||
"platform_info",
|
||||
"minimum_version",
|
||||
"trusted_author_keys",
|
||||
"trusted_author_key_hashes",
|
||||
"trusted_id_keys",
|
||||
"trusted_id_key_hashes",
|
||||
"product",
|
||||
"stepping",
|
||||
"CA_bundles_paths",
|
||||
"CA_bundles",
|
||||
}
|
||||
|
||||
for _, flagName := range flags {
|
||||
flag := cmd.Flags().Lookup(flagName)
|
||||
assert.NotNil(t, flag, "Flag %s should exist", flagName)
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateInput(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
setupCfg func()
|
||||
expectErr bool
|
||||
errMsg string
|
||||
}{
|
||||
{
|
||||
name: "valid empty config",
|
||||
setupCfg: func() {
|
||||
cfg = check.Config{Policy: &check.Policy{}, RootOfTrust: &check.RootOfTrust{}}
|
||||
},
|
||||
expectErr: false,
|
||||
},
|
||||
{
|
||||
name: "CA bundles without product name",
|
||||
setupCfg: func() {
|
||||
cfg = check.Config{
|
||||
Policy: &check.Policy{},
|
||||
RootOfTrust: &check.RootOfTrust{
|
||||
CabundlePaths: []string{"test.pem"},
|
||||
ProductLine: "",
|
||||
},
|
||||
}
|
||||
},
|
||||
expectErr: true,
|
||||
errMsg: "product name must be set if CA bundles are provided",
|
||||
},
|
||||
{
|
||||
name: "invalid report_data length",
|
||||
setupCfg: func() {
|
||||
cfg = check.Config{
|
||||
Policy: &check.Policy{
|
||||
ReportData: []byte("invalid"),
|
||||
},
|
||||
RootOfTrust: &check.RootOfTrust{},
|
||||
}
|
||||
},
|
||||
expectErr: true,
|
||||
errMsg: "report_data",
|
||||
},
|
||||
{
|
||||
name: "invalid host_data length",
|
||||
setupCfg: func() {
|
||||
cfg = check.Config{
|
||||
Policy: &check.Policy{
|
||||
HostData: []byte("invalid"),
|
||||
},
|
||||
RootOfTrust: &check.RootOfTrust{},
|
||||
}
|
||||
},
|
||||
expectErr: true,
|
||||
errMsg: "host_data",
|
||||
},
|
||||
{
|
||||
name: "invalid family_id length",
|
||||
setupCfg: func() {
|
||||
cfg = check.Config{
|
||||
Policy: &check.Policy{
|
||||
FamilyId: []byte("invalid"),
|
||||
},
|
||||
RootOfTrust: &check.RootOfTrust{},
|
||||
}
|
||||
},
|
||||
expectErr: true,
|
||||
errMsg: "family_id",
|
||||
},
|
||||
{
|
||||
name: "invalid image_id length",
|
||||
setupCfg: func() {
|
||||
cfg = check.Config{
|
||||
Policy: &check.Policy{
|
||||
ImageId: []byte("invalid"),
|
||||
},
|
||||
RootOfTrust: &check.RootOfTrust{},
|
||||
}
|
||||
},
|
||||
expectErr: true,
|
||||
errMsg: "image_id",
|
||||
},
|
||||
{
|
||||
name: "invalid trusted author key hash",
|
||||
setupCfg: func() {
|
||||
cfg = check.Config{
|
||||
Policy: &check.Policy{
|
||||
TrustedAuthorKeyHashes: [][]byte{[]byte("invalid")},
|
||||
},
|
||||
RootOfTrust: &check.RootOfTrust{},
|
||||
}
|
||||
},
|
||||
expectErr: true,
|
||||
errMsg: "trusted_author_key_hash",
|
||||
},
|
||||
{
|
||||
name: "invalid trusted id key hash",
|
||||
setupCfg: func() {
|
||||
cfg = check.Config{
|
||||
Policy: &check.Policy{
|
||||
TrustedIdKeyHashes: [][]byte{[]byte("invalid")},
|
||||
},
|
||||
RootOfTrust: &check.RootOfTrust{},
|
||||
}
|
||||
},
|
||||
expectErr: true,
|
||||
errMsg: "trusted_id_key_hash",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
tt.setupCfg()
|
||||
err := validateInput()
|
||||
if tt.expectErr {
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), tt.errMsg)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseTrustedKeys(t *testing.T) {
|
||||
tempDir := t.TempDir()
|
||||
|
||||
authorKeyFile := filepath.Join(tempDir, "author.pem")
|
||||
idKeyFile := filepath.Join(tempDir, "id.pem")
|
||||
nonExistentFile := filepath.Join(tempDir, "nonexistent.pem")
|
||||
|
||||
authorKeyContent := "-----BEGIN CERTIFICATE-----\nMIIBkTCB+wIJAOI..."
|
||||
idKeyContent := "-----BEGIN CERTIFICATE-----\nMIIBkTCB+wIJAOI..."
|
||||
|
||||
require.NoError(t, os.WriteFile(authorKeyFile, []byte(authorKeyContent), 0o644))
|
||||
require.NoError(t, os.WriteFile(idKeyFile, []byte(idKeyContent), 0o644))
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
trustedAuthorKeys []string
|
||||
trustedIdKeys []string
|
||||
expectErr bool
|
||||
}{
|
||||
{
|
||||
name: "valid files",
|
||||
trustedAuthorKeys: []string{authorKeyFile},
|
||||
trustedIdKeys: []string{idKeyFile},
|
||||
expectErr: false,
|
||||
},
|
||||
{
|
||||
name: "nonexistent author key file",
|
||||
trustedAuthorKeys: []string{nonExistentFile},
|
||||
trustedIdKeys: []string{},
|
||||
expectErr: true,
|
||||
},
|
||||
{
|
||||
name: "nonexistent id key file",
|
||||
trustedAuthorKeys: []string{},
|
||||
trustedIdKeys: []string{nonExistentFile},
|
||||
expectErr: true,
|
||||
},
|
||||
{
|
||||
name: "empty file lists",
|
||||
trustedAuthorKeys: []string{},
|
||||
trustedIdKeys: []string{},
|
||||
expectErr: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
cfg = check.Config{Policy: &check.Policy{}, RootOfTrust: &check.RootOfTrust{}}
|
||||
trustedAuthorKeys = tt.trustedAuthorKeys
|
||||
trustedIdKeys = tt.trustedIdKeys
|
||||
|
||||
err := parseTrustedKeys()
|
||||
if tt.expectErr {
|
||||
assert.Error(t, err)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
if len(tt.trustedAuthorKeys) > 0 {
|
||||
assert.Len(t, cfg.Policy.TrustedAuthorKeys, len(tt.trustedAuthorKeys))
|
||||
assert.Equal(t, []byte(authorKeyContent), cfg.Policy.TrustedAuthorKeys[0])
|
||||
}
|
||||
if len(tt.trustedIdKeys) > 0 {
|
||||
assert.Len(t, cfg.Policy.TrustedIdKeys, len(tt.trustedIdKeys))
|
||||
assert.Equal(t, []byte(idKeyContent), cfg.Policy.TrustedIdKeys[0])
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseUints(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
stepping string
|
||||
platformInfo string
|
||||
expectErr bool
|
||||
expectedStep *uint32
|
||||
expectedPlatform *uint64
|
||||
}{
|
||||
{
|
||||
name: "empty values",
|
||||
stepping: "",
|
||||
platformInfo: "",
|
||||
expectErr: false,
|
||||
},
|
||||
{
|
||||
name: "decimal values",
|
||||
stepping: "5",
|
||||
platformInfo: "10",
|
||||
expectErr: false,
|
||||
expectedStep: uint32Ptr(5),
|
||||
expectedPlatform: uint64Ptr(10),
|
||||
},
|
||||
{
|
||||
name: "hex values",
|
||||
stepping: "0x5",
|
||||
platformInfo: "0xa",
|
||||
expectErr: false,
|
||||
expectedStep: uint32Ptr(5),
|
||||
expectedPlatform: uint64Ptr(10),
|
||||
},
|
||||
{
|
||||
name: "octal values",
|
||||
stepping: "0o7",
|
||||
platformInfo: "0o12",
|
||||
expectErr: false,
|
||||
expectedStep: uint32Ptr(7),
|
||||
expectedPlatform: uint64Ptr(10),
|
||||
},
|
||||
{
|
||||
name: "binary values",
|
||||
stepping: "0b101",
|
||||
platformInfo: "0b1010",
|
||||
expectErr: false,
|
||||
expectedStep: uint32Ptr(5),
|
||||
expectedPlatform: uint64Ptr(10),
|
||||
},
|
||||
{
|
||||
name: "invalid stepping",
|
||||
stepping: "invalid",
|
||||
platformInfo: "",
|
||||
expectErr: true,
|
||||
},
|
||||
{
|
||||
name: "invalid platform info",
|
||||
stepping: "",
|
||||
platformInfo: "invalid",
|
||||
expectErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
cfg = check.Config{Policy: &check.Policy{Product: &sevsnp.SevProduct{}}, RootOfTrust: &check.RootOfTrust{}}
|
||||
stepping = tt.stepping
|
||||
platformInfo = tt.platformInfo
|
||||
|
||||
err := parseUints()
|
||||
if tt.expectErr {
|
||||
assert.Error(t, err)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
if tt.expectedStep != nil {
|
||||
assert.Equal(t, *tt.expectedStep, cfg.Policy.Product.MachineStepping.Value)
|
||||
}
|
||||
if tt.expectedPlatform != nil {
|
||||
assert.Equal(t, *tt.expectedPlatform, cfg.Policy.PlatformInfo.Value)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetBase(t *testing.T) {
|
||||
tests := []struct {
|
||||
input string
|
||||
expected int
|
||||
}{
|
||||
{"0x10", 16},
|
||||
{"0o10", 8},
|
||||
{"0b10", 2},
|
||||
{"10", 10},
|
||||
{"", 10},
|
||||
{"abc", 10},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.input, func(t *testing.T) {
|
||||
result := getBase(tt.input)
|
||||
assert.Equal(t, tt.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseConfig(t *testing.T) {
|
||||
tempDir := t.TempDir()
|
||||
|
||||
validConfig := map[string]interface{}{
|
||||
"rootOfTrust": map[string]interface{}{
|
||||
"product": "test_product",
|
||||
"cabundlePaths": []string{"test_path"},
|
||||
"cabundles": []string{"test_bundle"},
|
||||
"checkCrl": true,
|
||||
"disallowNetwork": true,
|
||||
},
|
||||
"policy": map[string]interface{}{
|
||||
"minimumGuestSvn": 1,
|
||||
"policy": "1",
|
||||
"minimumBuild": 1,
|
||||
"minimumVersion": "0.90",
|
||||
"requireAuthorKey": true,
|
||||
"requireIdBlock": true,
|
||||
},
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
setupConfig func() string
|
||||
expectErr bool
|
||||
}{
|
||||
{
|
||||
name: "empty config string",
|
||||
setupConfig: func() string {
|
||||
return ""
|
||||
},
|
||||
expectErr: false,
|
||||
},
|
||||
{
|
||||
name: "valid config file",
|
||||
setupConfig: func() string {
|
||||
configFile := filepath.Join(tempDir, "valid_config.json")
|
||||
configBytes, err := json.Marshal(validConfig)
|
||||
assert.NoError(t, err)
|
||||
if err := os.WriteFile(configFile, configBytes, 0o644); err != nil {
|
||||
t.Errorf("failed to write config file: %v", err)
|
||||
}
|
||||
return configFile
|
||||
},
|
||||
expectErr: false,
|
||||
},
|
||||
{
|
||||
name: "nonexistent config file",
|
||||
setupConfig: func() string {
|
||||
return filepath.Join(tempDir, "nonexistent.json")
|
||||
},
|
||||
expectErr: true,
|
||||
},
|
||||
{
|
||||
name: "invalid JSON config",
|
||||
setupConfig: func() string {
|
||||
configFile := filepath.Join(tempDir, "invalid_config.json")
|
||||
if err := os.WriteFile(configFile, []byte("invalid json"), 0o644); err != nil {
|
||||
t.Errorf("failed to write invalid config file: %v", err)
|
||||
}
|
||||
return configFile
|
||||
},
|
||||
expectErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
cfg = check.Config{Policy: &check.Policy{}, RootOfTrust: &check.RootOfTrust{}}
|
||||
cfgString = tt.setupConfig()
|
||||
|
||||
err := parseConfig()
|
||||
if tt.expectErr {
|
||||
assert.Error(t, err)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, cfg.Policy)
|
||||
assert.NotNil(t, cfg.RootOfTrust)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseHashes(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
trustedAuthorHashes []string
|
||||
trustedIdKeyHashes []string
|
||||
expectErr bool
|
||||
}{
|
||||
{
|
||||
name: "valid hashes",
|
||||
trustedAuthorHashes: []string{"deadbeef", "cafebabe"},
|
||||
trustedIdKeyHashes: []string{"12345678", "87654321"},
|
||||
expectErr: false,
|
||||
},
|
||||
{
|
||||
name: "empty hashes",
|
||||
trustedAuthorHashes: []string{},
|
||||
trustedIdKeyHashes: []string{},
|
||||
expectErr: false,
|
||||
},
|
||||
{
|
||||
name: "invalid author hash",
|
||||
trustedAuthorHashes: []string{"invalid_hex"},
|
||||
trustedIdKeyHashes: []string{},
|
||||
expectErr: true,
|
||||
},
|
||||
{
|
||||
name: "invalid id key hash",
|
||||
trustedAuthorHashes: []string{},
|
||||
trustedIdKeyHashes: []string{"invalid_hex"},
|
||||
expectErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
cfg = check.Config{Policy: &check.Policy{}, RootOfTrust: &check.RootOfTrust{}}
|
||||
trustedAuthorHashes = tt.trustedAuthorHashes
|
||||
trustedIdKeyHashes = tt.trustedIdKeyHashes
|
||||
|
||||
err := parseHashes()
|
||||
if tt.expectErr {
|
||||
assert.Error(t, err)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, cfg.Policy.TrustedAuthorKeyHashes, len(tt.trustedAuthorHashes))
|
||||
assert.Len(t, cfg.Policy.TrustedIdKeyHashes, len(tt.trustedIdKeyHashes))
|
||||
|
||||
for i, hash := range tt.trustedAuthorHashes {
|
||||
expected, _ := hex.DecodeString(hash)
|
||||
assert.Equal(t, expected, cfg.Policy.TrustedAuthorKeyHashes[i])
|
||||
}
|
||||
|
||||
for i, hash := range tt.trustedIdKeyHashes {
|
||||
expected, _ := hex.DecodeString(hash)
|
||||
assert.Equal(t, expected, cfg.Policy.TrustedIdKeyHashes[i])
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseAttestationFile(t *testing.T) {
|
||||
tempDir := t.TempDir()
|
||||
|
||||
binaryFile := filepath.Join(tempDir, "attestation.bin")
|
||||
jsonFile := filepath.Join(tempDir, "attestation.json")
|
||||
|
||||
binaryData := make([]byte, 1024)
|
||||
for i := range binaryData {
|
||||
binaryData[i] = byte(i % 256)
|
||||
}
|
||||
|
||||
jsonData := &sevsnp.Attestation{
|
||||
Report: &sevsnp.Report{
|
||||
FamilyId: make([]byte, 16),
|
||||
ImageId: make([]byte, 16),
|
||||
ReportData: make([]byte, 64),
|
||||
Measurement: make([]byte, 48),
|
||||
HostData: make([]byte, 32),
|
||||
IdKeyDigest: make([]byte, 48),
|
||||
AuthorKeyDigest: make([]byte, 48),
|
||||
ReportId: make([]byte, 32),
|
||||
ReportIdMa: make([]byte, 32),
|
||||
ChipId: make([]byte, 64),
|
||||
Signature: make([]byte, 512),
|
||||
},
|
||||
}
|
||||
jsonBytes, err := json.Marshal(jsonData)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.NoError(t, os.WriteFile(binaryFile, binaryData, 0o644))
|
||||
require.NoError(t, os.WriteFile(jsonFile, jsonBytes, 0o644))
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
attestationFile string
|
||||
expectErr bool
|
||||
}{
|
||||
{
|
||||
name: "valid binary file",
|
||||
attestationFile: binaryFile,
|
||||
expectErr: false,
|
||||
},
|
||||
{
|
||||
name: "valid JSON file",
|
||||
attestationFile: jsonFile,
|
||||
expectErr: false,
|
||||
},
|
||||
{
|
||||
name: "nonexistent file",
|
||||
attestationFile: filepath.Join(tempDir, "nonexistent.bin"),
|
||||
expectErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
attestationFile = tt.attestationFile
|
||||
|
||||
err := parseAttestationFile()
|
||||
if tt.expectErr {
|
||||
assert.Error(t, err)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, attestationRaw)
|
||||
assert.NotEmpty(t, attestationRaw)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSevsnpverify(t *testing.T) {
|
||||
trustedAuthorHashes = []string{}
|
||||
trustedIdKeyHashes = []string{}
|
||||
stepping = ""
|
||||
platformInfo = ""
|
||||
tempDir := t.TempDir()
|
||||
cfg = check.Config{Policy: &check.Policy{Product: &sevsnp.SevProduct{}}, RootOfTrust: &check.RootOfTrust{}}
|
||||
|
||||
attestationFile := filepath.Join(tempDir, "attestation.bin")
|
||||
attestationData := make([]byte, abi.ReportSize+100)
|
||||
for i := range attestationData {
|
||||
attestationData[i] = byte(i % 256)
|
||||
}
|
||||
require.NoError(t, os.WriteFile(attestationFile, attestationData, 0o644))
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
args []string
|
||||
setupMock func(*mocks.Verifier)
|
||||
expectErr bool
|
||||
expectedMsg string
|
||||
}{
|
||||
{
|
||||
name: "successful verification",
|
||||
args: []string{attestationFile},
|
||||
setupMock: func(m *mocks.Verifier) {
|
||||
m.On("VerifTeeAttestation", mock.Anything, mock.Anything).Return(nil)
|
||||
},
|
||||
expectErr: false,
|
||||
expectedMsg: "Attestation validation and verification is successful!",
|
||||
},
|
||||
{
|
||||
name: "verification failure",
|
||||
args: []string{attestationFile},
|
||||
setupMock: func(m *mocks.Verifier) {
|
||||
m.On("VerifTeeAttestation", mock.Anything, mock.Anything).Return(fmt.Errorf("verification failed"))
|
||||
},
|
||||
expectErr: true,
|
||||
expectedMsg: "attestation validation and verification failed",
|
||||
},
|
||||
{
|
||||
name: "nonexistent file",
|
||||
args: []string{filepath.Join(tempDir, "nonexistent.bin")},
|
||||
setupMock: func(m *mocks.Verifier) {},
|
||||
expectErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
cfgString = ""
|
||||
|
||||
mockVerifier := new(mocks.Verifier)
|
||||
tt.setupMock(mockVerifier)
|
||||
|
||||
var output bytes.Buffer
|
||||
cmd := &cobra.Command{}
|
||||
cmd.SetOut(&output)
|
||||
|
||||
err := sevsnpverify(cmd, mockVerifier, tt.args)
|
||||
fmt.Println("error1", err)
|
||||
|
||||
if tt.expectErr {
|
||||
assert.Error(t, err)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
if tt.expectedMsg != "" {
|
||||
assert.Contains(t, output.String(), tt.expectedMsg)
|
||||
}
|
||||
}
|
||||
|
||||
mockVerifier.AssertExpectations(t)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestReturnvTPMAttestation(t *testing.T) {
|
||||
tempDir := t.TempDir()
|
||||
|
||||
attestation := &tpmAttest.Attestation{
|
||||
Quotes: []*tpm.Quote{
|
||||
{
|
||||
Quote: []byte("test quote"),
|
||||
RawSig: []byte("test signature"),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
binaryData, err := proto.Marshal(attestation)
|
||||
require.NoError(t, err)
|
||||
|
||||
binaryFile := filepath.Join(tempDir, "attestation.pb")
|
||||
require.NoError(t, os.WriteFile(binaryFile, binaryData, 0o644))
|
||||
|
||||
textData, err := prototext.Marshal(attestation)
|
||||
require.NoError(t, err)
|
||||
|
||||
textFile := filepath.Join(tempDir, "attestation.txtpb")
|
||||
require.NoError(t, os.WriteFile(textFile, textData, 0o644))
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
args []string
|
||||
format string
|
||||
expectErr bool
|
||||
}{
|
||||
{
|
||||
name: "binary protobuf format",
|
||||
args: []string{binaryFile},
|
||||
format: FormatBinaryPB,
|
||||
expectErr: false,
|
||||
},
|
||||
{
|
||||
name: "text protobuf format",
|
||||
args: []string{textFile},
|
||||
format: FormatTextProto,
|
||||
expectErr: false,
|
||||
},
|
||||
{
|
||||
name: "invalid format",
|
||||
args: []string{binaryFile},
|
||||
format: "invalid",
|
||||
expectErr: true,
|
||||
},
|
||||
{
|
||||
name: "nonexistent file",
|
||||
args: []string{filepath.Join(tempDir, "nonexistent.pb")},
|
||||
format: FormatBinaryPB,
|
||||
expectErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
format = tt.format
|
||||
|
||||
result, err := returnvTPMAttestation(tt.args)
|
||||
|
||||
if tt.expectErr {
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, result)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, result)
|
||||
assert.NotEmpty(t, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestVtpmSevSnpverify(t *testing.T) {
|
||||
stepping = ""
|
||||
platformInfo = ""
|
||||
trustedAuthorHashes = []string{}
|
||||
trustedIdKeyHashes = []string{}
|
||||
cfg = check.Config{Policy: &check.Policy{}, RootOfTrust: &check.RootOfTrust{}}
|
||||
tempDir := t.TempDir()
|
||||
|
||||
attestation := &tpmAttest.Attestation{
|
||||
Quotes: []*tpm.Quote{
|
||||
{
|
||||
Quote: []byte("test quote"),
|
||||
RawSig: []byte("test signature"),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
binaryData, err := proto.Marshal(attestation)
|
||||
require.NoError(t, err)
|
||||
|
||||
attestationFile := filepath.Join(tempDir, "vtpm_attestation.pb")
|
||||
require.NoError(t, os.WriteFile(attestationFile, binaryData, 0o644))
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
args []string
|
||||
setupMock func(*mocks.Verifier)
|
||||
expectErr bool
|
||||
}{
|
||||
{
|
||||
name: "successful verification",
|
||||
args: []string{attestationFile},
|
||||
setupMock: func(m *mocks.Verifier) {
|
||||
m.On("VerifyAttestation", mock.Anything, mock.Anything, mock.Anything).Return(nil)
|
||||
},
|
||||
expectErr: false,
|
||||
},
|
||||
{
|
||||
name: "verification failure",
|
||||
args: []string{attestationFile},
|
||||
setupMock: func(m *mocks.Verifier) {
|
||||
m.On("VerifyAttestation", mock.Anything, mock.Anything, mock.Anything).Return(fmt.Errorf("verification failed"))
|
||||
},
|
||||
expectErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
cfg = check.Config{Policy: &check.Policy{}, RootOfTrust: &check.RootOfTrust{}}
|
||||
cfgString = ""
|
||||
format = FormatBinaryPB
|
||||
|
||||
mockVerifier := new(mocks.Verifier)
|
||||
tt.setupMock(mockVerifier)
|
||||
|
||||
err := vtpmSevSnpverify(tt.args, mockVerifier)
|
||||
|
||||
if tt.expectErr {
|
||||
assert.Error(t, err)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
mockVerifier.AssertExpectations(t)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestVtpmverify(t *testing.T) {
|
||||
tempDir := t.TempDir()
|
||||
|
||||
attestation := &tpmAttest.Attestation{
|
||||
Quotes: []*tpm.Quote{
|
||||
{
|
||||
Quote: []byte("test quote"),
|
||||
RawSig: []byte("test signature"),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
binaryData, err := proto.Marshal(attestation)
|
||||
require.NoError(t, err)
|
||||
|
||||
attestationFile := filepath.Join(tempDir, "vtpm_attestation.pb")
|
||||
require.NoError(t, os.WriteFile(attestationFile, binaryData, 0o644))
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
args []string
|
||||
setupMock func(*mocks.Verifier)
|
||||
expectErr bool
|
||||
}{
|
||||
{
|
||||
name: "successful verification",
|
||||
args: []string{attestationFile},
|
||||
setupMock: func(m *mocks.Verifier) {
|
||||
m.On("VerifVTpmAttestation", mock.Anything, mock.Anything).Return(nil)
|
||||
},
|
||||
expectErr: false,
|
||||
},
|
||||
{
|
||||
name: "verification failure",
|
||||
args: []string{attestationFile},
|
||||
setupMock: func(m *mocks.Verifier) {
|
||||
m.On("VerifVTpmAttestation", mock.Anything, mock.Anything).Return(fmt.Errorf("verification failed"))
|
||||
},
|
||||
expectErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
format = FormatBinaryPB
|
||||
|
||||
mockVerifier := new(mocks.Verifier)
|
||||
tt.setupMock(mockVerifier)
|
||||
|
||||
err := vtpmverify(tt.args, mockVerifier)
|
||||
|
||||
if tt.expectErr {
|
||||
assert.Error(t, err)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
mockVerifier.AssertExpectations(t)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func uint32Ptr(v uint32) *uint32 {
|
||||
return &v
|
||||
}
|
||||
|
||||
func uint64Ptr(v uint64) *uint64 {
|
||||
return &v
|
||||
}
|
||||
+694
-23
@@ -4,10 +4,12 @@ package cli
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/absmach/magistrala/pkg/errors"
|
||||
@@ -18,6 +20,7 @@ import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/mock"
|
||||
"github.com/stretchr/testify/require"
|
||||
mmocks "github.com/ultravioletrs/cocos/pkg/attestation/cmdconfig/mocks"
|
||||
"github.com/ultravioletrs/cocos/pkg/attestation/quoteprovider"
|
||||
"github.com/ultravioletrs/cocos/pkg/attestation/vtpm"
|
||||
"github.com/ultravioletrs/cocos/pkg/sdk/mocks"
|
||||
@@ -311,26 +314,12 @@ func TestNewValidateAttestationValidationCmd(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
type MockMeasurement struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
func (m *MockMeasurement) Run(igvmBinaryPath string) ([]byte, error) {
|
||||
args := m.Called(igvmBinaryPath)
|
||||
return nil, args.Error(0)
|
||||
}
|
||||
|
||||
func (m *MockMeasurement) Stop() error {
|
||||
args := m.Called()
|
||||
return args.Error(0)
|
||||
}
|
||||
|
||||
func TestNewMeasureCmd_RunSuccess(t *testing.T) {
|
||||
cliInstance := &CLI{}
|
||||
mockMeasurement := new(MockMeasurement)
|
||||
mockMeasurement := new(mmocks.MeasurementProvider)
|
||||
cliInstance.measurement = mockMeasurement
|
||||
|
||||
mockMeasurement.On("Run", "testfile.igvm").Return(nil)
|
||||
mockMeasurement.On("Run", "testfile.igvm").Return([]byte{}, nil)
|
||||
|
||||
cmd := cliInstance.NewMeasureCmd("fake_binary_path")
|
||||
buf := new(bytes.Buffer)
|
||||
@@ -346,11 +335,11 @@ func TestNewMeasureCmd_RunSuccess(t *testing.T) {
|
||||
|
||||
func TestNewMeasureCmd_RunError(t *testing.T) {
|
||||
cliInstance := &CLI{}
|
||||
mockMeasurement := new(MockMeasurement)
|
||||
mockMeasurement := new(mmocks.MeasurementProvider)
|
||||
cliInstance.measurement = mockMeasurement
|
||||
expectedError := errors.New("mocked measurement error")
|
||||
|
||||
mockMeasurement.On("Run", "testfile.igvm").Return(expectedError)
|
||||
mockMeasurement.On("Run", "testfile.igvm").Return([]byte{}, expectedError)
|
||||
|
||||
cmd := cliInstance.NewMeasureCmd("fake_binary_path")
|
||||
|
||||
@@ -366,7 +355,7 @@ func TestNewMeasureCmd_RunError(t *testing.T) {
|
||||
mockMeasurement.AssertExpectations(t)
|
||||
}
|
||||
|
||||
func TestParseConfig(t *testing.T) {
|
||||
func TestParseConfig1(t *testing.T) {
|
||||
tmpfile, err := os.CreateTemp("", "attestation_policy.json")
|
||||
require.NoError(t, err)
|
||||
defer os.Remove(tmpfile.Name())
|
||||
@@ -393,7 +382,7 @@ func TestParseConfig(t *testing.T) {
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func TestParseHashes(t *testing.T) {
|
||||
func TestParseHashes1(t *testing.T) {
|
||||
trustedAuthorHashes = []string{"0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef"}
|
||||
trustedIdKeyHashes = []string{"fedcba9876543210fedcba9876543210fedcba9876543210fedcba9876543210fedcba9876543210fedcba9876543210"}
|
||||
|
||||
@@ -444,7 +433,7 @@ func TestParseFiles(t *testing.T) {
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func TestParseUints(t *testing.T) {
|
||||
func TestParseUints1(t *testing.T) {
|
||||
stepping = "10"
|
||||
platformInfo = "0xFF"
|
||||
|
||||
@@ -469,7 +458,7 @@ func TestParseUints(t *testing.T) {
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func TestValidateInput(t *testing.T) {
|
||||
func TestValidateInput1(t *testing.T) {
|
||||
cfg = check.Config{}
|
||||
if cfg.Policy == nil {
|
||||
cfg.Policy = &check.Policy{}
|
||||
@@ -494,7 +483,7 @@ func TestValidateInput(t *testing.T) {
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func TestGetBase(t *testing.T) {
|
||||
func TestGetBase1(t *testing.T) {
|
||||
assert.Equal(t, 16, getBase("0xFF"))
|
||||
assert.Equal(t, 8, getBase("0o77"))
|
||||
assert.Equal(t, 2, getBase("0b1010"))
|
||||
@@ -716,3 +705,685 @@ func TestDecodeJWTToJSON(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func setupTestEnvironment() func() {
|
||||
originalMode := mode
|
||||
originalCfgString := cfgString
|
||||
originalTimeout := timeout
|
||||
originalMaxRetryDelay := maxRetryDelay
|
||||
originalPlatformInfo := platformInfo
|
||||
originalStepping := stepping
|
||||
originalTrustedAuthorKeys := trustedAuthorKeys
|
||||
originalTrustedAuthorHashes := trustedAuthorHashes
|
||||
originalTrustedIdKeys := trustedIdKeys
|
||||
originalTrustedIdKeyHashes := trustedIdKeyHashes
|
||||
originalAttestationFile := attestationFile
|
||||
originalAttestationRaw := attestationRaw
|
||||
originalOutput := output
|
||||
originalNonce := nonce
|
||||
originalFormat := format
|
||||
originalTeeNonce := teeNonce
|
||||
originalTokenNonce := tokenNonce
|
||||
originalGetTextProtoAttestationReport := getTextProtoAttestationReport
|
||||
originalGetAzureTokenJWT := getAzureTokenJWT
|
||||
originalCloud := cloud
|
||||
originalReportData := reportData
|
||||
originalCheckCrl := checkCrl
|
||||
|
||||
mode = ""
|
||||
cfgString = ""
|
||||
timeout = 0
|
||||
maxRetryDelay = 0
|
||||
platformInfo = ""
|
||||
stepping = ""
|
||||
trustedAuthorKeys = []string{}
|
||||
trustedAuthorHashes = []string{}
|
||||
trustedIdKeys = []string{}
|
||||
trustedIdKeyHashes = []string{}
|
||||
attestationFile = ""
|
||||
attestationRaw = []byte{}
|
||||
output = ""
|
||||
nonce = []byte{}
|
||||
format = ""
|
||||
teeNonce = []byte{}
|
||||
tokenNonce = []byte{}
|
||||
getTextProtoAttestationReport = false
|
||||
getAzureTokenJWT = false
|
||||
cloud = ""
|
||||
reportData = []byte{}
|
||||
checkCrl = false
|
||||
|
||||
return func() {
|
||||
mode = originalMode
|
||||
cfgString = originalCfgString
|
||||
timeout = originalTimeout
|
||||
maxRetryDelay = originalMaxRetryDelay
|
||||
platformInfo = originalPlatformInfo
|
||||
stepping = originalStepping
|
||||
trustedAuthorKeys = originalTrustedAuthorKeys
|
||||
trustedAuthorHashes = originalTrustedAuthorHashes
|
||||
trustedIdKeys = originalTrustedIdKeys
|
||||
trustedIdKeyHashes = originalTrustedIdKeyHashes
|
||||
attestationFile = originalAttestationFile
|
||||
attestationRaw = originalAttestationRaw
|
||||
output = originalOutput
|
||||
nonce = originalNonce
|
||||
format = originalFormat
|
||||
teeNonce = originalTeeNonce
|
||||
tokenNonce = originalTokenNonce
|
||||
getTextProtoAttestationReport = originalGetTextProtoAttestationReport
|
||||
getAzureTokenJWT = originalGetAzureTokenJWT
|
||||
cloud = originalCloud
|
||||
reportData = originalReportData
|
||||
checkCrl = originalCheckCrl
|
||||
}
|
||||
}
|
||||
|
||||
func createTempFile(t *testing.T, content []byte) string {
|
||||
tmpfile, err := os.CreateTemp("", "test_*.bin")
|
||||
require.NoError(t, err)
|
||||
defer tmpfile.Close()
|
||||
|
||||
_, err = tmpfile.Write(content)
|
||||
require.NoError(t, err)
|
||||
|
||||
return tmpfile.Name()
|
||||
}
|
||||
|
||||
func TestNewAttestationCmdEdgeCases(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
args []string
|
||||
expectedOutput string
|
||||
hasSubcommands bool
|
||||
}{
|
||||
{
|
||||
name: "no arguments shows help",
|
||||
args: []string{},
|
||||
expectedOutput: "Get and validate attestations",
|
||||
hasSubcommands: true,
|
||||
},
|
||||
{
|
||||
name: "help flag shows usage",
|
||||
args: []string{"--help"},
|
||||
expectedOutput: "Get and validate attestations",
|
||||
hasSubcommands: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
mockSDK := new(mocks.SDK)
|
||||
cli := &CLI{agentSDK: mockSDK}
|
||||
cmd := cli.NewAttestationCmd()
|
||||
|
||||
var buf bytes.Buffer
|
||||
cmd.SetOut(&buf)
|
||||
cmd.SetErr(&buf)
|
||||
cmd.SetArgs(tt.args)
|
||||
|
||||
err := cmd.Execute()
|
||||
assert.NoError(t, err)
|
||||
|
||||
output := buf.String()
|
||||
assert.Contains(t, output, tt.expectedOutput)
|
||||
|
||||
if tt.hasSubcommands {
|
||||
assert.Contains(t, output, "Get and validate attestations")
|
||||
assert.Contains(t, output, "Usage:")
|
||||
assert.Contains(t, output, "Flags:")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetAttestationCmdEdgeCases(t *testing.T) {
|
||||
defer setupTestEnvironment()()
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
args []string
|
||||
setupMock func(*mocks.SDK)
|
||||
expectedErr string
|
||||
expectedOut string
|
||||
}{
|
||||
{
|
||||
name: "no arguments provided",
|
||||
args: []string{},
|
||||
setupMock: func(sdk *mocks.SDK) {
|
||||
},
|
||||
expectedErr: "accepts 1 arg(s), received 0",
|
||||
},
|
||||
{
|
||||
name: "too many arguments",
|
||||
args: []string{"snp", "extra"},
|
||||
setupMock: func(sdk *mocks.SDK) {
|
||||
},
|
||||
expectedErr: "accepts 1 arg(s), received 2",
|
||||
},
|
||||
{
|
||||
name: "invalid attestation type",
|
||||
args: []string{"invalid-type"},
|
||||
setupMock: func(sdk *mocks.SDK) {
|
||||
},
|
||||
expectedErr: "Bad attestation type",
|
||||
},
|
||||
{
|
||||
name: "SNP with missing TEE nonce",
|
||||
args: []string{"snp"},
|
||||
setupMock: func(sdk *mocks.SDK) {
|
||||
},
|
||||
expectedErr: "TEE nonce must be defined for SEV-SNP attestation",
|
||||
},
|
||||
{
|
||||
name: "vTPM with missing nonce",
|
||||
args: []string{"vtpm"},
|
||||
setupMock: func(sdk *mocks.SDK) {
|
||||
},
|
||||
expectedErr: "vTPM nonce must be defined for vTPM attestation",
|
||||
},
|
||||
{
|
||||
name: "Azure token with missing token nonce",
|
||||
args: []string{"azure-token"},
|
||||
setupMock: func(sdk *mocks.SDK) {
|
||||
},
|
||||
expectedErr: "Token nonce must be defined for Azure attestation",
|
||||
},
|
||||
{
|
||||
name: "TEE nonce too large",
|
||||
args: []string{"snp", "--tee", hex.EncodeToString(bytes.Repeat([]byte{0x00}, quoteprovider.Nonce+1))},
|
||||
setupMock: func(sdk *mocks.SDK) {
|
||||
},
|
||||
expectedErr: "nonce must be a hex encoded string of length lesser or equal 64 bytes",
|
||||
},
|
||||
{
|
||||
name: "vTPM nonce too large",
|
||||
args: []string{"vtpm", "--vtpm", hex.EncodeToString(bytes.Repeat([]byte{0x00}, vtpm.Nonce+1))},
|
||||
setupMock: func(sdk *mocks.SDK) {
|
||||
},
|
||||
expectedErr: "vTPM nonce must be a hex encoded string of length lesser or equal 32 bytes",
|
||||
},
|
||||
{
|
||||
name: "Token nonce too large",
|
||||
args: []string{"azure-token", "--token", hex.EncodeToString(bytes.Repeat([]byte{0x00}, vtpm.Nonce+1))},
|
||||
setupMock: func(sdk *mocks.SDK) {
|
||||
},
|
||||
expectedErr: "vTPM nonce must be a hex encoded string of length lesser or equal 32 bytes",
|
||||
},
|
||||
{
|
||||
name: "successful TDX attestation",
|
||||
args: []string{"tdx", "--tee", hex.EncodeToString(bytes.Repeat([]byte{0x00}, quoteprovider.Nonce))},
|
||||
setupMock: func(sdk *mocks.SDK) {
|
||||
sdk.On("Attestation", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).
|
||||
Return(nil).Run(func(args mock.Arguments) {
|
||||
if _, err := args.Get(4).(*os.File).Write([]byte("mock tdx attestation")); err != nil {
|
||||
t.Fatalf("Failed to write to attestation file: %v", err)
|
||||
}
|
||||
})
|
||||
},
|
||||
expectedOut: "Fetching TDX attestation report",
|
||||
},
|
||||
{
|
||||
name: "file creation error",
|
||||
args: []string{"snp", "--tee", hex.EncodeToString(bytes.Repeat([]byte{0x00}, quoteprovider.Nonce))},
|
||||
setupMock: func(sdk *mocks.SDK) {
|
||||
},
|
||||
expectedErr: "Error creating attestation file",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
os.Remove(attestationFilePath)
|
||||
os.Remove(azureAttestResultFilePath)
|
||||
os.Remove(azureAttestTokenFilePath)
|
||||
defer func() {
|
||||
os.Remove(attestationFilePath)
|
||||
os.Remove(azureAttestResultFilePath)
|
||||
os.Remove(azureAttestTokenFilePath)
|
||||
}()
|
||||
|
||||
mockSDK := new(mocks.SDK)
|
||||
cli := &CLI{agentSDK: mockSDK}
|
||||
tc.setupMock(mockSDK)
|
||||
|
||||
if tc.name == "file creation error" {
|
||||
err := os.Mkdir(attestationFilePath, 0o755)
|
||||
require.NoError(t, err)
|
||||
defer os.Remove(attestationFilePath)
|
||||
}
|
||||
|
||||
cmd := cli.NewGetAttestationCmd()
|
||||
var buf bytes.Buffer
|
||||
cmd.SetOut(&buf)
|
||||
cmd.SetErr(&buf)
|
||||
cmd.SetArgs(tc.args)
|
||||
|
||||
err := cmd.Execute()
|
||||
output := buf.String()
|
||||
|
||||
if tc.expectedErr != "" {
|
||||
assert.Contains(t, output, tc.expectedErr)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
if tc.expectedOut != "" {
|
||||
assert.Contains(t, output, tc.expectedOut)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestFileOperations(t *testing.T) {
|
||||
defer setupTestEnvironment()()
|
||||
|
||||
t.Run("openInputFile", func(t *testing.T) {
|
||||
attestationFile = ""
|
||||
reader, err := openInputFile()
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, errEmptyFile, err)
|
||||
assert.Nil(t, reader)
|
||||
|
||||
tempFile := createTempFile(t, []byte("test content"))
|
||||
defer os.Remove(tempFile)
|
||||
attestationFile = tempFile
|
||||
reader, err = openInputFile()
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, reader)
|
||||
if file, ok := reader.(*os.File); ok {
|
||||
file.Close()
|
||||
}
|
||||
|
||||
attestationFile = "non-existent-file.bin"
|
||||
reader, err = openInputFile()
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, reader)
|
||||
})
|
||||
|
||||
t.Run("createOutputFile", func(t *testing.T) {
|
||||
output = ""
|
||||
writer, err := createOutputFile()
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, os.Stdout, writer)
|
||||
|
||||
tempDir := t.TempDir()
|
||||
output = filepath.Join(tempDir, "test_output.txt")
|
||||
writer, err = createOutputFile()
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, writer)
|
||||
if file, ok := writer.(*os.File); ok {
|
||||
file.Close()
|
||||
}
|
||||
|
||||
output = "/invalid/path/file.txt"
|
||||
writer, err = createOutputFile()
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, writer)
|
||||
})
|
||||
}
|
||||
|
||||
func TestValidationFunctions(t *testing.T) {
|
||||
t.Run("validateFieldLength", func(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
fieldName string
|
||||
field []byte
|
||||
expectedLength int
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
name: "nil field",
|
||||
fieldName: "test",
|
||||
field: nil,
|
||||
expectedLength: 32,
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "correct length",
|
||||
fieldName: "test",
|
||||
field: make([]byte, 32),
|
||||
expectedLength: 32,
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "incorrect length",
|
||||
fieldName: "test",
|
||||
field: make([]byte, 16),
|
||||
expectedLength: 32,
|
||||
expectError: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := validateFieldLength(tt.fieldName, tt.field, tt.expectedLength)
|
||||
if tt.expectError {
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), tt.fieldName)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestDecodeJWTToJSONEdgeCases(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input []byte
|
||||
expected string
|
||||
hasError bool
|
||||
}{
|
||||
{
|
||||
name: "empty input",
|
||||
input: []byte(""),
|
||||
hasError: true,
|
||||
},
|
||||
{
|
||||
name: "single part",
|
||||
input: []byte("onlyonepart"),
|
||||
hasError: true,
|
||||
},
|
||||
{
|
||||
name: "invalid base64 in header",
|
||||
input: []byte("invalid@base64.validpart"),
|
||||
hasError: true,
|
||||
},
|
||||
{
|
||||
name: "invalid base64 in payload",
|
||||
input: []byte("eyJhbGciOiJIUzI1NiJ9.invalid@base64"),
|
||||
hasError: true,
|
||||
},
|
||||
{
|
||||
name: "invalid JSON in header",
|
||||
input: []byte("bm90anNvbg.eyJzdWIiOiJ0ZXN0In0"),
|
||||
hasError: true,
|
||||
},
|
||||
{
|
||||
name: "invalid JSON in payload",
|
||||
input: []byte("eyJhbGciOiJIUzI1NiJ9.bm90anNvbg"),
|
||||
hasError: true,
|
||||
},
|
||||
{
|
||||
name: "valid JWT with padding",
|
||||
input: []byte("eyJhbGciOiJIUzI1NiJ9.eyJzdWIiOiJ0ZXN0In0.signature"),
|
||||
expected: `{
|
||||
"header": {
|
||||
"alg": "HS256"
|
||||
},
|
||||
"payload": {
|
||||
"sub": "test"
|
||||
}
|
||||
}`,
|
||||
hasError: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result, err := decodeJWTToJSON(tt.input)
|
||||
if tt.hasError {
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, result)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, result)
|
||||
if tt.expected != "" {
|
||||
assert.JSONEq(t, tt.expected, string(result))
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestMeasureCmdEdgeCases(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
args []string
|
||||
mockSetup func(*mmocks.MeasurementProvider)
|
||||
expectedError string
|
||||
expectedOut string
|
||||
}{
|
||||
{
|
||||
name: "no arguments",
|
||||
args: []string{},
|
||||
mockSetup: func(m *mmocks.MeasurementProvider) {
|
||||
},
|
||||
expectedError: "requires at least 1 arg(s), only received 0",
|
||||
},
|
||||
{
|
||||
name: "single line output success",
|
||||
args: []string{"test.igvm"},
|
||||
mockSetup: func(m *mmocks.MeasurementProvider) {
|
||||
m.On("Run", "test.igvm").Return([]byte("ABCDEF123456"), nil)
|
||||
},
|
||||
expectedOut: "",
|
||||
},
|
||||
{
|
||||
name: "multi-line output error",
|
||||
args: []string{"test.igvm"},
|
||||
mockSetup: func(m *mmocks.MeasurementProvider) {
|
||||
m.On("Run", "test.igvm").Return([]byte("line1\nline2\nERROR: something went wrong"), nil)
|
||||
},
|
||||
expectedError: "ERROR: something went wrong",
|
||||
},
|
||||
{
|
||||
name: "measurement run error",
|
||||
args: []string{"test.igvm"},
|
||||
mockSetup: func(m *mmocks.MeasurementProvider) {
|
||||
m.On("Run", "test.igvm").Return(nil, errors.New("measurement failed"))
|
||||
},
|
||||
expectedError: "measurement failed",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
mockMeasurement := new(mmocks.MeasurementProvider)
|
||||
tt.mockSetup(mockMeasurement)
|
||||
|
||||
cli := &CLI{measurement: mockMeasurement}
|
||||
cmd := cli.NewMeasureCmd("fake_binary_path")
|
||||
|
||||
var buf bytes.Buffer
|
||||
cmd.SetOut(&buf)
|
||||
cmd.SetErr(&buf)
|
||||
cmd.SetArgs(tt.args)
|
||||
|
||||
err := cmd.Execute()
|
||||
|
||||
if tt.expectedError != "" {
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), tt.expectedError)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
if tt.expectedOut != "" {
|
||||
assert.Contains(t, buf.String(), tt.expectedOut)
|
||||
}
|
||||
}
|
||||
|
||||
mockMeasurement.AssertExpectations(t)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateAttestationValidationCmdPreRunE(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
args []string
|
||||
flags map[string]string
|
||||
expectedErr string
|
||||
}{
|
||||
{
|
||||
name: "no file path provided",
|
||||
args: []string{},
|
||||
flags: map[string]string{"mode": "snp"},
|
||||
expectedErr: "please pass the attestation report file path",
|
||||
},
|
||||
{
|
||||
name: "multiple file paths",
|
||||
args: []string{"file1.bin", "file2.bin"},
|
||||
flags: map[string]string{"mode": "snp"},
|
||||
expectedErr: "please pass the attestation report file path",
|
||||
},
|
||||
{
|
||||
name: "unknown mode",
|
||||
args: []string{"test.bin"},
|
||||
flags: map[string]string{"mode": "unknown"},
|
||||
expectedErr: "unknown mode: unknown",
|
||||
},
|
||||
{
|
||||
name: "SNP mode missing report_data",
|
||||
args: []string{"test.bin"},
|
||||
flags: map[string]string{"mode": "snp"},
|
||||
expectedErr: "",
|
||||
},
|
||||
{
|
||||
name: "SNP mode missing product",
|
||||
args: []string{"test.bin"},
|
||||
flags: map[string]string{"mode": "snp", "report_data": "123"},
|
||||
expectedErr: "",
|
||||
},
|
||||
{
|
||||
name: "vTPM mode missing nonce",
|
||||
args: []string{"test.bin"},
|
||||
flags: map[string]string{"mode": "vtpm"},
|
||||
expectedErr: "",
|
||||
},
|
||||
{
|
||||
name: "SNP-vTPM mode missing required flags",
|
||||
args: []string{"test.bin"},
|
||||
flags: map[string]string{"mode": "snp-vtpm"},
|
||||
expectedErr: "",
|
||||
},
|
||||
{
|
||||
name: "TDX mode missing report_data",
|
||||
args: []string{"test.bin"},
|
||||
flags: map[string]string{"mode": "tdx"},
|
||||
expectedErr: "",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
cli := &CLI{}
|
||||
cmd := cli.NewValidateAttestationValidationCmd()
|
||||
|
||||
for key, value := range tt.flags {
|
||||
if err := cmd.Flags().Set(key, value); err != nil {
|
||||
}
|
||||
}
|
||||
|
||||
err := cmd.PreRunE(cmd, tt.args)
|
||||
if tt.expectedErr != "" {
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), tt.expectedErr)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCloudProviderConfigurations(t *testing.T) {
|
||||
defer setupTestEnvironment()()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
cloud string
|
||||
expectedType string
|
||||
}{
|
||||
{
|
||||
name: "none cloud provider",
|
||||
cloud: CCNone,
|
||||
expectedType: "vtpm",
|
||||
},
|
||||
{
|
||||
name: "azure cloud provider",
|
||||
cloud: CCAzure,
|
||||
expectedType: "azure",
|
||||
},
|
||||
{
|
||||
name: "gcp cloud provider",
|
||||
cloud: CCGCP,
|
||||
expectedType: "vtpm",
|
||||
},
|
||||
{
|
||||
name: "default cloud provider",
|
||||
cloud: "",
|
||||
expectedType: "vtpm",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
cli := &CLI{}
|
||||
cmd := cli.NewValidateAttestationValidationCmd()
|
||||
|
||||
if err := cmd.Flags().Set("cloud", tt.cloud); err != nil {
|
||||
t.Fatalf("Failed to set cloud flag: %v", err)
|
||||
}
|
||||
cloud, _ := cmd.Flags().GetString("cloud")
|
||||
assert.Equal(t, tt.cloud, cloud)
|
||||
|
||||
assert.Contains(t, cmd.Short, tt.cloud)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestFileOperationErrors(t *testing.T) {
|
||||
defer setupTestEnvironment()()
|
||||
|
||||
t.Run("file close error handling", func(t *testing.T) {
|
||||
tempFile := createTempFile(t, []byte("test content"))
|
||||
defer os.Remove(tempFile)
|
||||
|
||||
assert.True(t, true)
|
||||
})
|
||||
|
||||
t.Run("file write error handling", func(t *testing.T) {
|
||||
tempFile := createTempFile(t, []byte("test content"))
|
||||
defer os.Remove(tempFile)
|
||||
|
||||
err := os.Chmod(tempFile, 0o444)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = os.WriteFile(tempFile, []byte("new content"), 0o644)
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("file read error handling", func(t *testing.T) {
|
||||
tempDir := t.TempDir()
|
||||
|
||||
_, err := os.ReadFile(tempDir)
|
||||
assert.Error(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestContextCancellation(t *testing.T) {
|
||||
defer setupTestEnvironment()()
|
||||
|
||||
mockSDK := new(mocks.SDK)
|
||||
cli := &CLI{agentSDK: mockSDK}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cancel()
|
||||
|
||||
mockSDK.On("Attestation", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).
|
||||
Return(context.Canceled)
|
||||
|
||||
cmd := cli.NewGetAttestationCmd()
|
||||
cmd.SetContext(ctx)
|
||||
|
||||
var buf bytes.Buffer
|
||||
cmd.SetOut(&buf)
|
||||
cmd.SetErr(&buf)
|
||||
|
||||
teeNonceHex := hex.EncodeToString(bytes.Repeat([]byte{0x00}, quoteprovider.Nonce))
|
||||
cmd.SetArgs([]string{"snp", "--tee", teeNonceHex})
|
||||
|
||||
err := cmd.Execute()
|
||||
assert.NoError(t, err)
|
||||
assert.Contains(t, buf.String(), "Failed to get attestation due to error")
|
||||
}
|
||||
|
||||
@@ -0,0 +1,169 @@
|
||||
// Copyright (c) Ultraviolet
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package cli
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/mock"
|
||||
"github.com/ultravioletrs/cocos/pkg/sdk/mocks"
|
||||
)
|
||||
|
||||
func TestCLI_NewIMAMeasurementsCmd(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
args []string
|
||||
connectErr error
|
||||
mockIMAData string
|
||||
mockError error
|
||||
expectedFilename string
|
||||
expectedOutput []string
|
||||
expectedError []string
|
||||
shouldCreateFile bool
|
||||
fileCreationError bool
|
||||
invalidDigestData bool
|
||||
setupCustomFile func(filename string) error
|
||||
}{
|
||||
{
|
||||
name: "successful_retrieval_default_filename",
|
||||
args: []string{},
|
||||
connectErr: nil,
|
||||
mockIMAData: "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00",
|
||||
mockError: nil,
|
||||
expectedFilename: imaMeasurementsFilename,
|
||||
expectedOutput: []string{"⏳ Retrieving computation Linux IMA measurements file", "Linux IMA measurements file retrieved and saved successfully", "PCR10 = 0000000000000000000000000000000000000000", "Measurements file verified!"},
|
||||
shouldCreateFile: true,
|
||||
},
|
||||
{
|
||||
name: "successful_retrieval_custom_filename",
|
||||
args: []string{"custom_ima_file.txt"},
|
||||
connectErr: nil,
|
||||
mockIMAData: "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00",
|
||||
mockError: nil,
|
||||
expectedFilename: "custom_ima_file.txt",
|
||||
expectedOutput: []string{"⏳ Retrieving computation Linux IMA measurements file", "custom_ima_file.txt", "Measurements file verified!"},
|
||||
shouldCreateFile: true,
|
||||
},
|
||||
{
|
||||
name: "connection_error",
|
||||
args: []string{},
|
||||
connectErr: fmt.Errorf("connection failed"),
|
||||
expectedError: []string{"Failed to connect to agent: connection failed ❌"},
|
||||
},
|
||||
{
|
||||
name: "file_creation_error",
|
||||
args: []string{"/invalid/path/file.txt"},
|
||||
connectErr: nil,
|
||||
fileCreationError: true,
|
||||
expectedError: []string{"Error creating imaMeasurements file:"},
|
||||
},
|
||||
{
|
||||
name: "sdk_error",
|
||||
args: []string{},
|
||||
connectErr: nil,
|
||||
mockError: fmt.Errorf("SDK communication failed"),
|
||||
expectedError: []string{"Error retrieving Linux IMA measurements file: SDK communication failed ❌"},
|
||||
},
|
||||
{
|
||||
name: "verification_failure_wrong_pcr",
|
||||
args: []string{},
|
||||
connectErr: nil,
|
||||
mockIMAData: "10 9999999999999999999999999999999999999999 ima-ng sha1:0000000000000000000000000000000000000000 /usr/bin/test",
|
||||
mockError: nil,
|
||||
expectedOutput: []string{"⏳ Retrieving computation Linux IMA measurements file", "Linux IMA measurements file retrieved and saved successfully"},
|
||||
expectedError: []string{"Measurements file not verified ❌"},
|
||||
shouldCreateFile: true,
|
||||
},
|
||||
{
|
||||
name: "empty_measurements_file",
|
||||
args: []string{},
|
||||
connectErr: nil,
|
||||
mockIMAData: "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00",
|
||||
mockError: nil,
|
||||
expectedOutput: []string{"⏳ Retrieving computation Linux IMA measurements file", "Linux IMA measurements file retrieved and saved successfully", "Measurements file verified!"},
|
||||
shouldCreateFile: true,
|
||||
},
|
||||
{
|
||||
name: "measurements_with_non_pcr10_entries",
|
||||
args: []string{},
|
||||
connectErr: nil,
|
||||
mockIMAData: "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00",
|
||||
mockError: nil,
|
||||
expectedOutput: []string{"⏳ Retrieving computation Linux IMA measurements file", "Linux IMA measurements file retrieved and saved successfully", "Measurements file verified!"},
|
||||
shouldCreateFile: true,
|
||||
},
|
||||
{
|
||||
name: "measurements_with_zero_digest_replacement",
|
||||
args: []string{},
|
||||
connectErr: nil,
|
||||
mockIMAData: "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00",
|
||||
mockError: nil,
|
||||
expectedOutput: []string{"⏳ Retrieving computation Linux IMA measurements file", "Linux IMA measurements file retrieved and saved successfully", "Measurements file verified!"},
|
||||
shouldCreateFile: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
mockSDK := new(mocks.SDK)
|
||||
|
||||
cli := &CLI{
|
||||
agentSDK: mockSDK,
|
||||
connectErr: tc.connectErr,
|
||||
}
|
||||
|
||||
if tc.connectErr == nil && !tc.fileCreationError {
|
||||
mockSDK.On("IMAMeasurements", mock.Anything, mock.Anything).Return([]byte(tc.mockIMAData), tc.mockError)
|
||||
}
|
||||
|
||||
cmd := cli.NewIMAMeasurementsCmd()
|
||||
|
||||
var output bytes.Buffer
|
||||
cmd.SetOut(&output)
|
||||
cmd.SetErr(&output)
|
||||
|
||||
expectedFilename := tc.expectedFilename
|
||||
if expectedFilename == "" {
|
||||
if len(tc.args) > 0 {
|
||||
expectedFilename = tc.args[0]
|
||||
} else {
|
||||
expectedFilename = imaMeasurementsFilename
|
||||
}
|
||||
}
|
||||
|
||||
if tc.setupCustomFile != nil {
|
||||
err := tc.setupCustomFile(expectedFilename)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
cmd.SetArgs(tc.args)
|
||||
err := cmd.Execute()
|
||||
assert.NoError(t, err, "Command execution failed")
|
||||
|
||||
outputStr := output.String()
|
||||
|
||||
for _, expectedMsg := range tc.expectedOutput {
|
||||
assert.Contains(t, outputStr, expectedMsg, "Expected output message not found")
|
||||
}
|
||||
|
||||
for _, expectedErr := range tc.expectedError {
|
||||
assert.Contains(t, outputStr, expectedErr, "Expected error message not found")
|
||||
}
|
||||
|
||||
if tc.shouldCreateFile && tc.connectErr == nil && !tc.fileCreationError && tc.mockError == nil {
|
||||
if _, err := os.Stat(expectedFilename); err == nil {
|
||||
os.Remove(expectedFilename)
|
||||
}
|
||||
}
|
||||
|
||||
if tc.connectErr == nil && !tc.fileCreationError {
|
||||
mockSDK.AssertExpectations(t)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
+10
-6
@@ -38,9 +38,11 @@ func (c *CLI) NewCreateVMCmd() *cobra.Command {
|
||||
Example: `create-vm`,
|
||||
Args: cobra.ExactArgs(0),
|
||||
Run: func(cmd *cobra.Command, args []string) {
|
||||
if err := c.InitializeManagerClient(cmd); err != nil {
|
||||
printError(cmd, "Failed to connect to manager: %v ❌ ", c.connectErr)
|
||||
return
|
||||
if c.managerClient == nil || c.connectErr != nil {
|
||||
if err := c.InitializeManagerClient(cmd); err != nil {
|
||||
printError(cmd, "Failed to connect to manager: %v ❌ ", c.connectErr)
|
||||
return
|
||||
}
|
||||
}
|
||||
defer c.Close()
|
||||
|
||||
@@ -74,7 +76,7 @@ func (c *CLI) NewCreateVMCmd() *cobra.Command {
|
||||
cmd.Flags().StringVar(&agentCVMServerCA, serverCA, "", "CVM server CA")
|
||||
cmd.Flags().StringVar(&agentCVMClientKey, clientKey, "", "CVM client key")
|
||||
cmd.Flags().StringVar(&agentCVMClientCrt, clientCrt, "", "CVM client crt")
|
||||
cmd.Flags().StringVar(&agentCVMCaUrl, agentCVMCaUrl, "", "CVM CA service URL")
|
||||
cmd.Flags().StringVar(&agentCVMCaUrl, caUrl, "", "CVM CA service URL")
|
||||
cmd.Flags().StringVar(&agentLogLevel, logLevel, "", "Agent Log level")
|
||||
cmd.Flags().DurationVar(&ttl, ttlFlag, 0, "TTL for the VM")
|
||||
if err := cmd.MarkFlagRequired(serverURL); err != nil {
|
||||
@@ -92,8 +94,10 @@ func (c *CLI) NewRemoveVMCmd() *cobra.Command {
|
||||
Example: `remove-vm <cvm_id>`,
|
||||
Args: cobra.ExactArgs(1),
|
||||
Run: func(cmd *cobra.Command, args []string) {
|
||||
if err := c.InitializeManagerClient(cmd); err == nil {
|
||||
defer c.Close()
|
||||
if c.managerClient == nil || c.connectErr != nil {
|
||||
if err := c.InitializeManagerClient(cmd); err == nil {
|
||||
defer c.Close()
|
||||
}
|
||||
}
|
||||
|
||||
if c.connectErr != nil {
|
||||
|
||||
@@ -0,0 +1,600 @@
|
||||
// Copyright (c) Ultraviolet
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
package cli
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/mock"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/ultravioletrs/cocos/manager"
|
||||
"github.com/ultravioletrs/cocos/manager/mocks"
|
||||
"google.golang.org/protobuf/types/known/emptypb"
|
||||
)
|
||||
|
||||
func TestCLI_NewCreateVMCmd(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
setupMock func(*mocks.ManagerServiceClient)
|
||||
setupCLI func(*CLI)
|
||||
setupFiles func(string) error
|
||||
flags map[string]string
|
||||
expectedOutput string
|
||||
expectedError string
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
name: "successful VM creation with all flags",
|
||||
setupMock: func(m *mocks.ManagerServiceClient) {
|
||||
m.On("CreateVm", mock.Anything, mock.MatchedBy(func(req *manager.CreateReq) bool {
|
||||
return req.AgentCvmServerUrl == "https://server.com" &&
|
||||
req.AgentLogLevel == "debug" &&
|
||||
req.AgentCvmCaUrl == "https://ca.com" &&
|
||||
req.Ttl == "1h0m0s" &&
|
||||
string(req.AgentCvmServerCaCert) == "ca-cert-content" &&
|
||||
string(req.AgentCvmClientKey) == "client-key-content" &&
|
||||
string(req.AgentCvmClientCert) == "client-cert-content"
|
||||
})).Return(&manager.CreateRes{
|
||||
CvmId: "vm-123",
|
||||
ForwardedPort: "8080",
|
||||
}, nil)
|
||||
},
|
||||
setupCLI: func(cli *CLI) {
|
||||
},
|
||||
setupFiles: func(tmpDir string) error {
|
||||
files := map[string]string{
|
||||
"server-ca.pem": "ca-cert-content",
|
||||
"client-key.pem": "client-key-content",
|
||||
"client-crt.pem": "client-cert-content",
|
||||
}
|
||||
for filename, content := range files {
|
||||
if err := os.WriteFile(filepath.Join(tmpDir, filename), []byte(content), 0o644); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
},
|
||||
flags: map[string]string{
|
||||
"server-url": "https://server.com",
|
||||
"server-ca": "server-ca.pem",
|
||||
"client-key": "client-key.pem",
|
||||
"client-crt": "client-crt.pem",
|
||||
"ca-url": "https://ca.com",
|
||||
"log-level": "debug",
|
||||
"ttl": "1h",
|
||||
},
|
||||
expectedOutput: "✅ Virtual machine created successfully with id vm-123 and port 8080",
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "successful VM creation with minimal flags",
|
||||
setupMock: func(m *mocks.ManagerServiceClient) {
|
||||
m.On("CreateVm", mock.Anything, mock.MatchedBy(func(req *manager.CreateReq) bool {
|
||||
return req.AgentCvmServerUrl == "https://server.com" &&
|
||||
req.AgentLogLevel == "" &&
|
||||
req.AgentCvmCaUrl == "" &&
|
||||
req.Ttl == "" &&
|
||||
len(req.AgentCvmServerCaCert) == 0 &&
|
||||
len(req.AgentCvmClientKey) == 0 &&
|
||||
len(req.AgentCvmClientCert) == 0
|
||||
})).Return(&manager.CreateRes{
|
||||
CvmId: "vm-456",
|
||||
ForwardedPort: "9090",
|
||||
}, nil)
|
||||
},
|
||||
setupCLI: func(cli *CLI) {
|
||||
},
|
||||
setupFiles: func(tmpDir string) error {
|
||||
return nil // No files needed for minimal test
|
||||
},
|
||||
flags: map[string]string{
|
||||
"server-url": "https://server.com",
|
||||
},
|
||||
expectedOutput: "✅ Virtual machine created successfully with id vm-456 and port 9090",
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "manager client initialization failure",
|
||||
setupMock: func(m *mocks.ManagerServiceClient) {
|
||||
// No expectations set as initialization fails
|
||||
},
|
||||
setupCLI: func(cli *CLI) {
|
||||
cli.connectErr = errors.New("connection failed")
|
||||
},
|
||||
setupFiles: func(tmpDir string) error {
|
||||
return nil
|
||||
},
|
||||
flags: map[string]string{
|
||||
"server-url": "https://server.com",
|
||||
},
|
||||
expectedError: "Failed to connect to manager: failed to connect to grpc server : failed to exit idle mode: passthrough: received empty target in Build() ❌",
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "certificate loading failure",
|
||||
setupMock: func(m *mocks.ManagerServiceClient) {
|
||||
// No expectations set as cert loading fails
|
||||
},
|
||||
setupCLI: func(cli *CLI) {
|
||||
},
|
||||
setupFiles: func(tmpDir string) error {
|
||||
return nil // Don't create the cert file
|
||||
},
|
||||
flags: map[string]string{
|
||||
"server-url": "https://server.com",
|
||||
"server-ca": "nonexistent-ca.pem",
|
||||
},
|
||||
expectedError: "Error loading certs:",
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "CreateVm API call failure",
|
||||
setupMock: func(m *mocks.ManagerServiceClient) {
|
||||
m.On("CreateVm", mock.Anything, mock.Anything).Return(nil, errors.New("API error"))
|
||||
},
|
||||
setupCLI: func(cli *CLI) {
|
||||
},
|
||||
setupFiles: func(tmpDir string) error {
|
||||
return nil
|
||||
},
|
||||
flags: map[string]string{
|
||||
"server-url": "https://server.com",
|
||||
},
|
||||
expectedError: "Error creating virtual machine: API error ❌",
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "missing required server-url flag",
|
||||
setupMock: func(m *mocks.ManagerServiceClient) {
|
||||
// No expectations set as command validation fails
|
||||
},
|
||||
setupCLI: func(cli *CLI) {
|
||||
},
|
||||
setupFiles: func(tmpDir string) error {
|
||||
return nil
|
||||
},
|
||||
flags: map[string]string{}, // No server-url flag
|
||||
expectError: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
tmpDir, err := os.MkdirTemp("", "cli-test-")
|
||||
require.NoError(t, err)
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
oldDir, err := os.Getwd()
|
||||
require.NoError(t, err)
|
||||
err = os.Chdir(tmpDir)
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() {
|
||||
err := os.Chdir(oldDir)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
|
||||
err = tt.setupFiles(tmpDir)
|
||||
require.NoError(t, err)
|
||||
|
||||
mockClient := new(mocks.ManagerServiceClient)
|
||||
tt.setupMock(mockClient)
|
||||
|
||||
mockCLI := &CLI{
|
||||
managerClient: mockClient,
|
||||
}
|
||||
|
||||
tt.setupCLI(mockCLI)
|
||||
|
||||
cmd := mockCLI.NewCreateVMCmd()
|
||||
|
||||
for flag, value := range tt.flags {
|
||||
err := cmd.Flags().Set(flag, value)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
var buf bytes.Buffer
|
||||
cmd.SetOut(&buf)
|
||||
cmd.SetErr(&buf)
|
||||
|
||||
err = cmd.Execute()
|
||||
|
||||
if tt.expectError {
|
||||
if tt.expectedError != "" {
|
||||
assert.Contains(t, buf.String(), tt.expectedError)
|
||||
}
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
if tt.expectedOutput != "" {
|
||||
assert.Contains(t, buf.String(), tt.expectedOutput)
|
||||
}
|
||||
}
|
||||
|
||||
mockClient.AssertExpectations(t)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCLI_NewRemoveVMCmd(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
setupMock func(*mocks.ManagerServiceClient)
|
||||
setupCLI func(*CLI)
|
||||
args []string
|
||||
expectedOutput string
|
||||
expectedError string
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
name: "successful VM removal",
|
||||
setupMock: func(m *mocks.ManagerServiceClient) {
|
||||
m.On("RemoveVm", mock.Anything, &manager.RemoveReq{
|
||||
CvmId: "vm-123",
|
||||
}).Return(&emptypb.Empty{}, nil)
|
||||
},
|
||||
setupCLI: func(cli *CLI) {
|
||||
},
|
||||
args: []string{"vm-123"},
|
||||
expectedOutput: "✅ Virtual machine removed successfully",
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "manager client initialization failure",
|
||||
setupMock: func(m *mocks.ManagerServiceClient) {
|
||||
// No expectations set as initialization fails
|
||||
},
|
||||
setupCLI: func(cli *CLI) {
|
||||
cli.connectErr = errors.New("connection failed")
|
||||
},
|
||||
args: []string{"vm-123"},
|
||||
expectedError: "Failed to connect to manager: failed to connect to grpc server : failed to exit idle mode: passthrough: received empty target in Build() ❌",
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "RemoveVm API call failure",
|
||||
setupMock: func(m *mocks.ManagerServiceClient) {
|
||||
m.On("RemoveVm", mock.Anything, &manager.RemoveReq{
|
||||
CvmId: "vm-456",
|
||||
}).Return(nil, errors.New("removal failed"))
|
||||
},
|
||||
setupCLI: func(cli *CLI) {
|
||||
},
|
||||
args: []string{"vm-456"},
|
||||
expectedError: "Error removing virtual machine: removal failed ❌",
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "missing VM ID argument",
|
||||
setupMock: func(m *mocks.ManagerServiceClient) {
|
||||
// No expectations set as command validation fails
|
||||
},
|
||||
setupCLI: func(cli *CLI) {
|
||||
},
|
||||
args: []string{}, // No VM ID provided
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "too many arguments",
|
||||
setupMock: func(m *mocks.ManagerServiceClient) {
|
||||
// No expectations set as command validation fails
|
||||
},
|
||||
setupCLI: func(cli *CLI) {
|
||||
},
|
||||
args: []string{"vm-123", "extra-arg"},
|
||||
expectError: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
mockClient := new(mocks.ManagerServiceClient)
|
||||
tt.setupMock(mockClient)
|
||||
|
||||
mockCLI := &CLI{
|
||||
managerClient: mockClient,
|
||||
}
|
||||
tt.setupCLI(mockCLI)
|
||||
|
||||
cmd := mockCLI.NewRemoveVMCmd()
|
||||
|
||||
cmd.SetArgs(tt.args)
|
||||
|
||||
var buf bytes.Buffer
|
||||
cmd.SetOut(&buf)
|
||||
cmd.SetErr(&buf)
|
||||
|
||||
err := cmd.Execute()
|
||||
|
||||
if tt.expectError {
|
||||
if tt.expectedError != "" {
|
||||
assert.Contains(t, buf.String(), tt.expectedError)
|
||||
}
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
if tt.expectedOutput != "" {
|
||||
assert.Contains(t, buf.String(), tt.expectedOutput)
|
||||
}
|
||||
}
|
||||
|
||||
mockClient.AssertExpectations(t)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestFileReader(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
setupFile func(string) (string, error)
|
||||
path string
|
||||
expectedResult []byte
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
name: "successful file read",
|
||||
setupFile: func(tmpDir string) (string, error) {
|
||||
filePath := filepath.Join(tmpDir, "test.txt")
|
||||
err := os.WriteFile(filePath, []byte("test content"), 0o644)
|
||||
return filePath, err
|
||||
},
|
||||
expectedResult: []byte("test content"),
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "empty path returns nil",
|
||||
setupFile: func(tmpDir string) (string, error) {
|
||||
return "", nil
|
||||
},
|
||||
path: "",
|
||||
expectedResult: nil,
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "nonexistent file returns error",
|
||||
setupFile: func(tmpDir string) (string, error) {
|
||||
return filepath.Join(tmpDir, "nonexistent.txt"), nil
|
||||
},
|
||||
expectedResult: nil,
|
||||
expectError: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
tmpDir, err := os.MkdirTemp("", "fileReader-test-")
|
||||
require.NoError(t, err)
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
filePath, err := tt.setupFile(tmpDir)
|
||||
require.NoError(t, err)
|
||||
|
||||
if tt.path != "" {
|
||||
filePath = tt.path
|
||||
}
|
||||
|
||||
result, err := fileReader(filePath)
|
||||
|
||||
if tt.expectError {
|
||||
assert.Error(t, err)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, tt.expectedResult, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadCerts(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
setupFiles func(string) error
|
||||
setupGlobal func(string)
|
||||
expectError bool
|
||||
validate func(*testing.T, *manager.CreateReq)
|
||||
}{
|
||||
{
|
||||
name: "successful cert loading with all files",
|
||||
setupFiles: func(tmpDir string) error {
|
||||
files := map[string]string{
|
||||
"client.key": "client-key-content",
|
||||
"client.crt": "client-cert-content",
|
||||
"server.ca": "server-ca-content",
|
||||
}
|
||||
for filename, content := range files {
|
||||
if err := os.WriteFile(filepath.Join(tmpDir, filename), []byte(content), 0o644); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
},
|
||||
setupGlobal: func(tmpDir string) {
|
||||
agentCVMClientKey = filepath.Join(tmpDir, "client.key")
|
||||
agentCVMClientCrt = filepath.Join(tmpDir, "client.crt")
|
||||
agentCVMServerCA = filepath.Join(tmpDir, "server.ca")
|
||||
},
|
||||
expectError: false,
|
||||
validate: func(t *testing.T, req *manager.CreateReq) {
|
||||
assert.Equal(t, []byte("client-key-content"), req.AgentCvmClientKey)
|
||||
assert.Equal(t, []byte("client-cert-content"), req.AgentCvmClientCert)
|
||||
assert.Equal(t, []byte("server-ca-content"), req.AgentCvmServerCaCert)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "successful cert loading with empty paths",
|
||||
setupFiles: func(tmpDir string) error {
|
||||
return nil
|
||||
},
|
||||
setupGlobal: func(tmpDir string) {
|
||||
agentCVMClientKey = ""
|
||||
agentCVMClientCrt = ""
|
||||
agentCVMServerCA = ""
|
||||
},
|
||||
expectError: false,
|
||||
validate: func(t *testing.T, req *manager.CreateReq) {
|
||||
assert.Nil(t, req.AgentCvmClientKey)
|
||||
assert.Nil(t, req.AgentCvmClientCert)
|
||||
assert.Nil(t, req.AgentCvmServerCaCert)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "client key file read error",
|
||||
setupFiles: func(tmpDir string) error {
|
||||
return nil // Don't create client key file
|
||||
},
|
||||
setupGlobal: func(tmpDir string) {
|
||||
agentCVMClientKey = filepath.Join(tmpDir, "nonexistent.key")
|
||||
agentCVMClientCrt = ""
|
||||
agentCVMServerCA = ""
|
||||
},
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "client cert file read error",
|
||||
setupFiles: func(tmpDir string) error {
|
||||
// Create client key but not cert
|
||||
return os.WriteFile(filepath.Join(tmpDir, "client.key"), []byte("key-content"), 0o644)
|
||||
},
|
||||
setupGlobal: func(tmpDir string) {
|
||||
agentCVMClientKey = filepath.Join(tmpDir, "client.key")
|
||||
agentCVMClientCrt = filepath.Join(tmpDir, "nonexistent.crt")
|
||||
agentCVMServerCA = ""
|
||||
},
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "server CA file read error",
|
||||
setupFiles: func(tmpDir string) error {
|
||||
files := map[string]string{
|
||||
"client.key": "client-key-content",
|
||||
"client.crt": "client-cert-content",
|
||||
}
|
||||
for filename, content := range files {
|
||||
if err := os.WriteFile(filepath.Join(tmpDir, filename), []byte(content), 0o644); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
},
|
||||
setupGlobal: func(tmpDir string) {
|
||||
agentCVMClientKey = filepath.Join(tmpDir, "client.key")
|
||||
agentCVMClientCrt = filepath.Join(tmpDir, "client.crt")
|
||||
agentCVMServerCA = filepath.Join(tmpDir, "nonexistent.ca")
|
||||
},
|
||||
expectError: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
tmpDir, err := os.MkdirTemp("", "loadCerts-test-")
|
||||
require.NoError(t, err)
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
err = tt.setupFiles(tmpDir)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Store original global variables
|
||||
origClientKey := agentCVMClientKey
|
||||
origClientCrt := agentCVMClientCrt
|
||||
origServerCA := agentCVMServerCA
|
||||
|
||||
// Setup global variables for test
|
||||
tt.setupGlobal(tmpDir)
|
||||
|
||||
// Restore original values after test
|
||||
defer func() {
|
||||
agentCVMClientKey = origClientKey
|
||||
agentCVMClientCrt = origClientCrt
|
||||
agentCVMServerCA = origServerCA
|
||||
}()
|
||||
|
||||
result, err := loadCerts()
|
||||
|
||||
if tt.expectError {
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, result)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, result)
|
||||
if tt.validate != nil {
|
||||
tt.validate(t, result)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCommandCreation(t *testing.T) {
|
||||
cli := &CLI{}
|
||||
|
||||
t.Run("create-vm command creation", func(t *testing.T) {
|
||||
cmd := cli.NewCreateVMCmd()
|
||||
assert.NotNil(t, cmd)
|
||||
assert.Equal(t, "create-vm", cmd.Use)
|
||||
assert.Equal(t, "Create a new virtual machine", cmd.Short)
|
||||
|
||||
// Check that required flags are set
|
||||
flag := cmd.Flags().Lookup("server-url")
|
||||
assert.NotNil(t, flag)
|
||||
// Note: We can't easily test MarkFlagRequired in unit tests
|
||||
})
|
||||
|
||||
t.Run("remove-vm command creation", func(t *testing.T) {
|
||||
cmd := cli.NewRemoveVMCmd()
|
||||
assert.NotNil(t, cmd)
|
||||
assert.Equal(t, "remove-vm", cmd.Use)
|
||||
assert.Equal(t, "Remove a virtual machine", cmd.Short)
|
||||
})
|
||||
}
|
||||
|
||||
func TestTTLHandling(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
ttlInput string
|
||||
expectedTTL time.Duration
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
name: "valid duration",
|
||||
ttlInput: "1h30m",
|
||||
expectedTTL: time.Hour + 30*time.Minute,
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "zero duration",
|
||||
ttlInput: "0",
|
||||
expectedTTL: 0,
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "empty string",
|
||||
ttlInput: "",
|
||||
expectedTTL: 0,
|
||||
expectError: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
mockCLI := &CLI{
|
||||
managerClient: new(mocks.ManagerServiceClient),
|
||||
}
|
||||
|
||||
cmd := mockCLI.NewCreateVMCmd()
|
||||
|
||||
if tt.ttlInput != "" {
|
||||
err := cmd.Flags().Set("ttl", tt.ttlInput)
|
||||
if tt.expectError {
|
||||
assert.Error(t, err)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, tt.expectedTTL, ttl)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
+3
-1
@@ -62,5 +62,7 @@ func (c *CLI) InitializeManagerClient(cmd *cobra.Command) error {
|
||||
}
|
||||
|
||||
func (c *CLI) Close() {
|
||||
c.client.Close()
|
||||
if c.client != nil {
|
||||
c.client.Close()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -12,7 +12,6 @@ require (
|
||||
github.com/gofrs/uuid v4.4.0+incompatible
|
||||
github.com/google/go-sev-guest v0.13.0
|
||||
github.com/google/go-tdx-guest v0.3.2-0.20241009005452-097ee70d0843
|
||||
github.com/mdlayher/vsock v1.2.1
|
||||
github.com/spf13/cobra v1.9.1
|
||||
github.com/spf13/pflag v1.0.6
|
||||
github.com/stretchr/testify v1.10.0
|
||||
@@ -115,7 +114,6 @@ require (
|
||||
github.com/google/uuid v1.6.0
|
||||
github.com/grpc-ecosystem/grpc-gateway/v2 v2.26.3 // indirect
|
||||
github.com/inconshreveable/mousetrap v1.1.0 // indirect
|
||||
github.com/mdlayher/socket v0.4.1 // indirect
|
||||
github.com/pkg/errors v0.9.1 // indirect
|
||||
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect
|
||||
github.com/prometheus/client_golang v1.22.0 // indirect
|
||||
|
||||
@@ -166,10 +166,6 @@ github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovk
|
||||
github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM=
|
||||
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
|
||||
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
|
||||
github.com/mdlayher/socket v0.4.1 h1:eM9y2/jlbs1M615oshPQOHZzj6R6wMT7bX5NPiQvn2U=
|
||||
github.com/mdlayher/socket v0.4.1/go.mod h1:cAqeGjoufqdxWkD7DkpyS+wcefOtmu5OQ8KuoJGIReA=
|
||||
github.com/mdlayher/vsock v1.2.1 h1:pC1mTJTvjo1r9n9fbm7S1j04rCgCzhCOS5DY0zqHlnQ=
|
||||
github.com/mdlayher/vsock v1.2.1/go.mod h1:NRfCibel++DgeMD8z/hP+PPTjlNJsdPOmxcnENvE+SE=
|
||||
github.com/moby/docker-image-spec v1.3.1 h1:jMKff3w6PgbfSa69GfNg+zN/XLhfXJGnEx3Nl2EsFP0=
|
||||
github.com/moby/docker-image-spec v1.3.1/go.mod h1:eKmb5VW8vQEh/BAr2yvVNvuiJuY6UIocYsFu/DxxRpo=
|
||||
github.com/moby/sys/atomicwriter v0.1.0 h1:kw5D/EqkBwsBFi0ss9v1VG3wIkVhzGvLklJ+w3A14Sw=
|
||||
|
||||
@@ -12,6 +12,7 @@ import (
|
||||
"net"
|
||||
"os"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
agentgrpc "github.com/ultravioletrs/cocos/agent/api/grpc"
|
||||
@@ -42,12 +43,15 @@ const (
|
||||
|
||||
type Server struct {
|
||||
server.BaseServer
|
||||
mu sync.RWMutex
|
||||
server *grpc.Server
|
||||
health *health.Server
|
||||
registerService serviceRegister
|
||||
authSvc auth.Authenticator
|
||||
health *health.Server
|
||||
caUrl string
|
||||
cvmId string
|
||||
started bool
|
||||
stopped bool
|
||||
}
|
||||
|
||||
type serviceRegister func(srv *grpc.Server)
|
||||
@@ -74,6 +78,18 @@ func New(ctx context.Context, cancel context.CancelFunc, name string, config ser
|
||||
}
|
||||
|
||||
func (s *Server) Start() error {
|
||||
s.mu.Lock()
|
||||
if s.started {
|
||||
s.mu.Unlock()
|
||||
return fmt.Errorf("server already started")
|
||||
}
|
||||
if s.stopped {
|
||||
s.mu.Unlock()
|
||||
return fmt.Errorf("server already stopped")
|
||||
}
|
||||
s.started = true
|
||||
s.mu.Unlock()
|
||||
|
||||
errCh := make(chan error)
|
||||
grpcServerOptions := []grpc.ServerOption{
|
||||
grpc.StatsHandler(otelgrpc.NewServerHandler()),
|
||||
@@ -199,14 +215,22 @@ func (s *Server) Start() error {
|
||||
|
||||
grpcServerOptions = append(grpcServerOptions, creds)
|
||||
|
||||
s.mu.Lock()
|
||||
s.server = grpc.NewServer(grpcServerOptions...)
|
||||
s.health = health.NewServer()
|
||||
grpchealth.RegisterHealthServer(s.server, s.health)
|
||||
s.registerService(s.server)
|
||||
s.health.SetServingStatus(s.Name, grpchealth.HealthCheckResponse_SERVING)
|
||||
s.mu.Unlock()
|
||||
|
||||
go func() {
|
||||
errCh <- s.server.Serve(listener)
|
||||
s.mu.RLock()
|
||||
server := s.server
|
||||
s.mu.RUnlock()
|
||||
|
||||
if server != nil {
|
||||
errCh <- server.Serve(listener)
|
||||
}
|
||||
}()
|
||||
|
||||
select {
|
||||
@@ -219,19 +243,33 @@ func (s *Server) Start() error {
|
||||
}
|
||||
|
||||
func (s *Server) Stop() error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
if s.stopped {
|
||||
return nil
|
||||
}
|
||||
s.stopped = true
|
||||
|
||||
defer s.Cancel()
|
||||
|
||||
c := make(chan bool)
|
||||
go func() {
|
||||
defer close(c)
|
||||
s.health.Shutdown()
|
||||
s.server.GracefulStop()
|
||||
if s.health != nil {
|
||||
s.health.Shutdown()
|
||||
}
|
||||
if s.server != nil {
|
||||
s.server.GracefulStop()
|
||||
}
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-c:
|
||||
case <-time.After(stopWaitTime):
|
||||
}
|
||||
s.Logger.Info(fmt.Sprintf("%s gRPC service shutdown at %s", s.Name, s.Address))
|
||||
|
||||
s.Logger.Info(fmt.Sprintf("%s gRPC service shutdown at %s", s.Name, s.Address))
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@@ -0,0 +1,425 @@
|
||||
// Copyright (c) Ultraviolet
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
// Copyright (c) Ultraviolet
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
package grpc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/mock"
|
||||
"github.com/ultravioletrs/cocos/manager"
|
||||
"github.com/ultravioletrs/cocos/manager/mocks"
|
||||
"google.golang.org/protobuf/types/known/emptypb"
|
||||
)
|
||||
|
||||
func TestNewServer(t *testing.T) {
|
||||
mockSvc := new(mocks.Service)
|
||||
server := NewServer(mockSvc)
|
||||
|
||||
assert.NotNil(t, server)
|
||||
assert.IsType(t, &grpcServer{}, server)
|
||||
|
||||
grpcSrv := server.(*grpcServer)
|
||||
assert.Equal(t, mockSvc, grpcSrv.svc)
|
||||
}
|
||||
|
||||
func TestCreateVm(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
req *manager.CreateReq
|
||||
mockPort string
|
||||
mockId string
|
||||
mockErr error
|
||||
expectedRes *manager.CreateRes
|
||||
expectedErr error
|
||||
}{
|
||||
{
|
||||
name: "successful VM creation",
|
||||
req: &manager.CreateReq{},
|
||||
mockPort: "8080",
|
||||
mockId: "vm-123",
|
||||
mockErr: nil,
|
||||
expectedRes: &manager.CreateRes{
|
||||
ForwardedPort: "8080",
|
||||
CvmId: "vm-123",
|
||||
},
|
||||
expectedErr: nil,
|
||||
},
|
||||
{
|
||||
name: "VM creation with different port",
|
||||
req: &manager.CreateReq{},
|
||||
mockPort: "9090",
|
||||
mockId: "vm-456",
|
||||
mockErr: nil,
|
||||
expectedRes: &manager.CreateRes{
|
||||
ForwardedPort: "9090",
|
||||
CvmId: "vm-456",
|
||||
},
|
||||
expectedErr: nil,
|
||||
},
|
||||
{
|
||||
name: "VM creation failure",
|
||||
req: &manager.CreateReq{},
|
||||
mockPort: "",
|
||||
mockId: "",
|
||||
mockErr: errors.New("failed to create VM"),
|
||||
expectedRes: nil,
|
||||
expectedErr: errors.New("failed to create VM"),
|
||||
},
|
||||
{
|
||||
name: "VM creation with empty request",
|
||||
req: &manager.CreateReq{},
|
||||
mockPort: "3000",
|
||||
mockId: "vm-empty",
|
||||
mockErr: nil,
|
||||
expectedRes: &manager.CreateRes{
|
||||
ForwardedPort: "3000",
|
||||
CvmId: "vm-empty",
|
||||
},
|
||||
expectedErr: nil,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
mockSvc := new(mocks.Service)
|
||||
server := NewServer(mockSvc)
|
||||
|
||||
mockSvc.On("CreateVM", mock.Anything, tt.req).Return(tt.mockPort, tt.mockId, tt.mockErr)
|
||||
|
||||
res, err := server.CreateVm(context.Background(), tt.req)
|
||||
|
||||
if tt.expectedErr != nil {
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, tt.expectedErr.Error(), err.Error())
|
||||
assert.Nil(t, res)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, tt.expectedRes, res)
|
||||
}
|
||||
|
||||
mockSvc.AssertExpectations(t)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRemoveVm(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
req *manager.RemoveReq
|
||||
mockErr error
|
||||
expectedErr error
|
||||
}{
|
||||
{
|
||||
name: "successful VM removal",
|
||||
req: &manager.RemoveReq{
|
||||
CvmId: "vm-123",
|
||||
},
|
||||
mockErr: nil,
|
||||
expectedErr: nil,
|
||||
},
|
||||
{
|
||||
name: "VM removal failure",
|
||||
req: &manager.RemoveReq{
|
||||
CvmId: "vm-456",
|
||||
},
|
||||
mockErr: errors.New("VM not found"),
|
||||
expectedErr: errors.New("VM not found"),
|
||||
},
|
||||
{
|
||||
name: "VM removal with empty ID",
|
||||
req: &manager.RemoveReq{
|
||||
CvmId: "",
|
||||
},
|
||||
mockErr: errors.New("invalid VM ID"),
|
||||
expectedErr: errors.New("invalid VM ID"),
|
||||
},
|
||||
{
|
||||
name: "VM removal with non-existent ID",
|
||||
req: &manager.RemoveReq{
|
||||
CvmId: "non-existent-vm",
|
||||
},
|
||||
mockErr: errors.New("VM does not exist"),
|
||||
expectedErr: errors.New("VM does not exist"),
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
mockSvc := new(mocks.Service)
|
||||
server := NewServer(mockSvc)
|
||||
|
||||
mockSvc.On("RemoveVM", mock.Anything, tt.req.CvmId).Return(tt.mockErr)
|
||||
|
||||
res, err := server.RemoveVm(context.Background(), tt.req)
|
||||
|
||||
if tt.expectedErr != nil {
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, tt.expectedErr.Error(), err.Error())
|
||||
assert.Nil(t, res)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, &emptypb.Empty{}, res)
|
||||
}
|
||||
|
||||
mockSvc.AssertExpectations(t)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCVMInfo(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
req *manager.CVMInfoReq
|
||||
mockOvmf string
|
||||
mockCpuNum int
|
||||
mockCpuType string
|
||||
mockEosVersion string
|
||||
expectedRes *manager.CVMInfoRes
|
||||
}{
|
||||
{
|
||||
name: "successful CVM info retrieval",
|
||||
req: &manager.CVMInfoReq{
|
||||
Id: "cvm-123",
|
||||
},
|
||||
mockOvmf: "OVMF-v1.0",
|
||||
mockCpuNum: 4,
|
||||
mockCpuType: "Intel-x86_64",
|
||||
mockEosVersion: "EOS-v2.1",
|
||||
expectedRes: &manager.CVMInfoRes{
|
||||
OvmfVersion: "OVMF-v1.0",
|
||||
CpuNum: 4,
|
||||
CpuType: "Intel-x86_64",
|
||||
EosVersion: "EOS-v2.1",
|
||||
Id: "cvm-123",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "CVM info with different values",
|
||||
req: &manager.CVMInfoReq{
|
||||
Id: "cvm-456",
|
||||
},
|
||||
mockOvmf: "OVMF-v2.0",
|
||||
mockCpuNum: 8,
|
||||
mockCpuType: "AMD-x86_64",
|
||||
mockEosVersion: "EOS-v3.0",
|
||||
expectedRes: &manager.CVMInfoRes{
|
||||
OvmfVersion: "OVMF-v2.0",
|
||||
CpuNum: 8,
|
||||
CpuType: "AMD-x86_64",
|
||||
EosVersion: "EOS-v3.0",
|
||||
Id: "cvm-456",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "CVM info with empty ID",
|
||||
req: &manager.CVMInfoReq{
|
||||
Id: "",
|
||||
},
|
||||
mockOvmf: "OVMF-v1.5",
|
||||
mockCpuNum: 2,
|
||||
mockCpuType: "ARM64",
|
||||
mockEosVersion: "EOS-v1.8",
|
||||
expectedRes: &manager.CVMInfoRes{
|
||||
OvmfVersion: "OVMF-v1.5",
|
||||
CpuNum: 2,
|
||||
CpuType: "ARM64",
|
||||
EosVersion: "EOS-v1.8",
|
||||
Id: "",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "CVM info with zero CPU count",
|
||||
req: &manager.CVMInfoReq{
|
||||
Id: "cvm-zero",
|
||||
},
|
||||
mockOvmf: "OVMF-v1.0",
|
||||
mockCpuNum: 0,
|
||||
mockCpuType: "Unknown",
|
||||
mockEosVersion: "EOS-v1.0",
|
||||
expectedRes: &manager.CVMInfoRes{
|
||||
OvmfVersion: "OVMF-v1.0",
|
||||
CpuNum: 0,
|
||||
CpuType: "Unknown",
|
||||
EosVersion: "EOS-v1.0",
|
||||
Id: "cvm-zero",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
mockSvc := new(mocks.Service)
|
||||
server := NewServer(mockSvc)
|
||||
|
||||
mockSvc.On("ReturnCVMInfo", mock.Anything).Return(
|
||||
tt.mockOvmf, tt.mockCpuNum, tt.mockCpuType, tt.mockEosVersion)
|
||||
|
||||
res, err := server.CVMInfo(context.Background(), tt.req)
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, tt.expectedRes, res)
|
||||
|
||||
mockSvc.AssertExpectations(t)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAttestationPolicy(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
req *manager.AttestationPolicyReq
|
||||
mockPolicy string
|
||||
mockErr error
|
||||
expectedRes *manager.AttestationPolicyRes
|
||||
expectedErr error
|
||||
}{
|
||||
{
|
||||
name: "successful attestation policy fetch",
|
||||
req: &manager.AttestationPolicyReq{
|
||||
Id: "policy-123",
|
||||
},
|
||||
mockPolicy: `{"version": "1.0", "rules": ["rule1", "rule2"]}`,
|
||||
mockErr: nil,
|
||||
expectedRes: &manager.AttestationPolicyRes{
|
||||
Info: []byte(`{"version": "1.0", "rules": ["rule1", "rule2"]}`),
|
||||
Id: "policy-123",
|
||||
},
|
||||
expectedErr: nil,
|
||||
},
|
||||
{
|
||||
name: "attestation policy fetch failure",
|
||||
req: &manager.AttestationPolicyReq{
|
||||
Id: "policy-456",
|
||||
},
|
||||
mockPolicy: "",
|
||||
mockErr: errors.New("policy not found"),
|
||||
expectedRes: nil,
|
||||
expectedErr: errors.New("policy not found"),
|
||||
},
|
||||
{
|
||||
name: "attestation policy with empty ID",
|
||||
req: &manager.AttestationPolicyReq{
|
||||
Id: "",
|
||||
},
|
||||
mockPolicy: "",
|
||||
mockErr: errors.New("invalid policy ID"),
|
||||
expectedRes: nil,
|
||||
expectedErr: errors.New("invalid policy ID"),
|
||||
},
|
||||
{
|
||||
name: "attestation policy with different content",
|
||||
req: &manager.AttestationPolicyReq{
|
||||
Id: "policy-789",
|
||||
},
|
||||
mockPolicy: `{"version": "2.0", "attestation_type": "SGX"}`,
|
||||
mockErr: nil,
|
||||
expectedRes: &manager.AttestationPolicyRes{
|
||||
Info: []byte(`{"version": "2.0", "attestation_type": "SGX"}`),
|
||||
Id: "policy-789",
|
||||
},
|
||||
expectedErr: nil,
|
||||
},
|
||||
{
|
||||
name: "attestation policy with empty policy content",
|
||||
req: &manager.AttestationPolicyReq{
|
||||
Id: "policy-empty",
|
||||
},
|
||||
mockPolicy: "",
|
||||
mockErr: nil,
|
||||
expectedRes: &manager.AttestationPolicyRes{
|
||||
Info: []byte{},
|
||||
Id: "policy-empty",
|
||||
},
|
||||
expectedErr: nil,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
mockSvc := new(mocks.Service)
|
||||
server := NewServer(mockSvc)
|
||||
|
||||
mockSvc.On("FetchAttestationPolicy", mock.Anything, tt.req.Id).Return([]byte(tt.mockPolicy), tt.mockErr)
|
||||
|
||||
res, err := server.AttestationPolicy(context.Background(), tt.req)
|
||||
|
||||
if tt.expectedErr != nil {
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, tt.expectedErr.Error(), err.Error())
|
||||
assert.Nil(t, res)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, tt.expectedRes, res)
|
||||
}
|
||||
|
||||
mockSvc.AssertExpectations(t)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestContextCancellation(t *testing.T) {
|
||||
t.Run("CreateVm with cancelled context", func(t *testing.T) {
|
||||
mockSvc := new(mocks.Service)
|
||||
server := NewServer(mockSvc)
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cancel() // Cancel the context immediately
|
||||
|
||||
req := &manager.CreateReq{}
|
||||
mockSvc.On("CreateVM", mock.Anything, req).Return("", "", context.Canceled)
|
||||
|
||||
res, err := server.CreateVm(ctx, req)
|
||||
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, context.Canceled, err)
|
||||
assert.Nil(t, res)
|
||||
|
||||
mockSvc.AssertExpectations(t)
|
||||
})
|
||||
|
||||
t.Run("RemoveVm with cancelled context", func(t *testing.T) {
|
||||
mockSvc := new(mocks.Service)
|
||||
server := NewServer(mockSvc)
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cancel() // Cancel the context immediately
|
||||
|
||||
req := &manager.RemoveReq{CvmId: "vm-123"}
|
||||
mockSvc.On("RemoveVM", mock.Anything, "vm-123").Return(context.Canceled)
|
||||
|
||||
res, err := server.RemoveVm(ctx, req)
|
||||
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, context.Canceled, err)
|
||||
assert.Nil(t, res)
|
||||
|
||||
mockSvc.AssertExpectations(t)
|
||||
})
|
||||
}
|
||||
|
||||
func TestErrorHandling(t *testing.T) {
|
||||
t.Run("service returns multiple error types", func(t *testing.T) {
|
||||
mockSvc := new(mocks.Service)
|
||||
server := NewServer(mockSvc)
|
||||
|
||||
// Test with different error types
|
||||
customErr := errors.New("custom service error")
|
||||
|
||||
req := &manager.CreateReq{}
|
||||
mockSvc.On("CreateVM", mock.Anything, req).Return("", "", customErr)
|
||||
|
||||
res, err := server.CreateVm(context.Background(), req)
|
||||
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, customErr, err)
|
||||
assert.Nil(t, res)
|
||||
|
||||
mockSvc.AssertExpectations(t)
|
||||
})
|
||||
}
|
||||
@@ -265,6 +265,51 @@ func (_c *Service_ReturnCVMInfo_Call) RunAndReturn(run func(context.Context) (st
|
||||
return _c
|
||||
}
|
||||
|
||||
// Shutdown provides a mock function with no fields
|
||||
func (_m *Service) Shutdown() error {
|
||||
ret := _m.Called()
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for Shutdown")
|
||||
}
|
||||
|
||||
var r0 error
|
||||
if rf, ok := ret.Get(0).(func() error); ok {
|
||||
r0 = rf()
|
||||
} else {
|
||||
r0 = ret.Error(0)
|
||||
}
|
||||
|
||||
return r0
|
||||
}
|
||||
|
||||
// Service_Shutdown_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Shutdown'
|
||||
type Service_Shutdown_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// Shutdown is a helper method to define mock.On call
|
||||
func (_e *Service_Expecter) Shutdown() *Service_Shutdown_Call {
|
||||
return &Service_Shutdown_Call{Call: _e.mock.On("Shutdown")}
|
||||
}
|
||||
|
||||
func (_c *Service_Shutdown_Call) Run(run func()) *Service_Shutdown_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
run()
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *Service_Shutdown_Call) Return(_a0 error) *Service_Shutdown_Call {
|
||||
_c.Call.Return(_a0)
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *Service_Shutdown_Call) RunAndReturn(run func() error) *Service_Shutdown_Call {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
// NewService creates a new instance of Service. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations.
|
||||
// The first argument is typically a *testing.T value.
|
||||
func NewService(t interface {
|
||||
|
||||
@@ -189,3 +189,11 @@ func TestTDXEnabled(t *testing.T) {
|
||||
assert.False(t, TDXEnabled("flags: tdx_host_platform", "0"))
|
||||
})
|
||||
}
|
||||
|
||||
func TestSEVSNPEnabledOnHost(t *testing.T) {
|
||||
assert.False(t, SEVSNPEnabledOnHost())
|
||||
}
|
||||
|
||||
func TestTDXEnabledOnHost(t *testing.T) {
|
||||
assert.False(t, TDXEnabledOnHost())
|
||||
}
|
||||
|
||||
@@ -1,28 +0,0 @@
|
||||
// Copyright (c) Ultraviolet
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
package qemu
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
|
||||
"github.com/mdlayher/vsock"
|
||||
"github.com/ultravioletrs/cocos/agent"
|
||||
)
|
||||
|
||||
const VsockConfigPort uint32 = 9999
|
||||
|
||||
func (v *qemuVM) SendAgentConfig(ac agent.Computation) error {
|
||||
conn, err := vsock.Dial(uint32(v.vmi.Config.GuestCID), VsockConfigPort, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer conn.Close()
|
||||
payload, err := json.Marshal(ac)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := conn.Write(payload); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
+1
-49
@@ -6,10 +6,8 @@
|
||||
package mocks
|
||||
|
||||
import (
|
||||
agent "github.com/ultravioletrs/cocos/agent"
|
||||
manager "github.com/ultravioletrs/cocos/pkg/manager"
|
||||
|
||||
mock "github.com/stretchr/testify/mock"
|
||||
manager "github.com/ultravioletrs/cocos/pkg/manager"
|
||||
)
|
||||
|
||||
// VM is an autogenerated mock type for the VM type
|
||||
@@ -162,52 +160,6 @@ func (_c *VM_GetProcess_Call) RunAndReturn(run func() int) *VM_GetProcess_Call {
|
||||
return _c
|
||||
}
|
||||
|
||||
// SendAgentConfig provides a mock function with given fields: ac
|
||||
func (_m *VM) SendAgentConfig(ac agent.Computation) error {
|
||||
ret := _m.Called(ac)
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for SendAgentConfig")
|
||||
}
|
||||
|
||||
var r0 error
|
||||
if rf, ok := ret.Get(0).(func(agent.Computation) error); ok {
|
||||
r0 = rf(ac)
|
||||
} else {
|
||||
r0 = ret.Error(0)
|
||||
}
|
||||
|
||||
return r0
|
||||
}
|
||||
|
||||
// VM_SendAgentConfig_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'SendAgentConfig'
|
||||
type VM_SendAgentConfig_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// SendAgentConfig is a helper method to define mock.On call
|
||||
// - ac agent.Computation
|
||||
func (_e *VM_Expecter) SendAgentConfig(ac interface{}) *VM_SendAgentConfig_Call {
|
||||
return &VM_SendAgentConfig_Call{Call: _e.mock.On("SendAgentConfig", ac)}
|
||||
}
|
||||
|
||||
func (_c *VM_SendAgentConfig_Call) Run(run func(ac agent.Computation)) *VM_SendAgentConfig_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
run(args[0].(agent.Computation))
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *VM_SendAgentConfig_Call) Return(_a0 error) *VM_SendAgentConfig_Call {
|
||||
_c.Call.Return(_a0)
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *VM_SendAgentConfig_Call) RunAndReturn(run func(agent.Computation) error) *VM_SendAgentConfig_Call {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
// SetProcess provides a mock function with given fields: pid
|
||||
func (_m *VM) SetProcess(pid int) error {
|
||||
ret := _m.Called(pid)
|
||||
|
||||
@@ -5,7 +5,6 @@ package vm
|
||||
import (
|
||||
"log/slog"
|
||||
|
||||
"github.com/ultravioletrs/cocos/agent"
|
||||
pkgmanager "github.com/ultravioletrs/cocos/pkg/manager"
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
)
|
||||
@@ -14,7 +13,6 @@ import (
|
||||
type VM interface {
|
||||
Start() error
|
||||
Stop() error
|
||||
SendAgentConfig(ac agent.Computation) error
|
||||
SetProcess(pid int) error
|
||||
GetProcess() int
|
||||
GetCID() int
|
||||
|
||||
+24
@@ -20,6 +20,11 @@ packages:
|
||||
dir: "{{.InterfaceDir}}/mocks"
|
||||
filename: "agent_grpc_algo.go"
|
||||
mockname: "{{.InterfaceName}}"
|
||||
AgentService_IMAMeasurementsClient:
|
||||
config:
|
||||
dir: "{{.InterfaceDir}}/mocks"
|
||||
filename: "agent_grpc_ima.go"
|
||||
mockname: "{{.InterfaceName}}"
|
||||
github.com/ultravioletrs/cocos/agent/auth:
|
||||
interfaces:
|
||||
Authenticator:
|
||||
@@ -119,3 +124,22 @@ packages:
|
||||
dir: "{{.InterfaceDir}}/mocks"
|
||||
filename: "attestation.go"
|
||||
mockname: "{{.InterfaceName}}"
|
||||
Verifier:
|
||||
config:
|
||||
dir: "{{.InterfaceDir}}/mocks"
|
||||
filename: "verifier.go"
|
||||
mockname: "{{.InterfaceName}}"
|
||||
github.com/ultravioletrs/cocos/agent/algorithm:
|
||||
interfaces:
|
||||
Algorithm:
|
||||
config:
|
||||
dir: "{{.InterfaceDir}}/mocks"
|
||||
filename: "algorithm.go"
|
||||
mockname: "{{.InterfaceName}}"
|
||||
github.com/ultravioletrs/cocos/pkg/attestation/cmdconfig:
|
||||
interfaces:
|
||||
MeasurementProvider:
|
||||
config:
|
||||
dir: "{{.InterfaceDir}}/mocks"
|
||||
filename: "measurement_provider.go"
|
||||
mockname: "{{.InterfaceName}}"
|
||||
|
||||
@@ -3,19 +3,33 @@
|
||||
package atls
|
||||
|
||||
import (
|
||||
"crypto/ecdsa"
|
||||
"crypto/elliptic"
|
||||
"crypto/rand"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"crypto/x509/pkix"
|
||||
"encoding/asn1"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
certssdk "github.com/absmach/certs/sdk"
|
||||
"github.com/absmach/magistrala/pkg/errors"
|
||||
"github.com/google/go-sev-guest/abi"
|
||||
"github.com/google/go-sev-guest/proto/check"
|
||||
"github.com/google/go-sev-guest/proto/sevsnp"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/mock"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/ultravioletrs/cocos/pkg/attestation"
|
||||
"github.com/ultravioletrs/cocos/pkg/attestation/mocks"
|
||||
"github.com/ultravioletrs/cocos/pkg/attestation/vtpm"
|
||||
"golang.org/x/crypto/sha3"
|
||||
"google.golang.org/protobuf/encoding/protojson"
|
||||
)
|
||||
|
||||
@@ -217,6 +231,346 @@ func TestGetPlatformTypeFromOID(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestVerifyCertificateExtension(t *testing.T) {
|
||||
tempDir, err := os.MkdirTemp("", "policy")
|
||||
require.NoError(t, err)
|
||||
defer os.RemoveAll(tempDir)
|
||||
|
||||
attestationPB := prepVerifyAttReport(t)
|
||||
err = setAttestationPolicy(attestationPB, tempDir)
|
||||
require.NoError(t, err)
|
||||
|
||||
privateKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||
require.NoError(t, err)
|
||||
|
||||
pubKeyDER, err := x509.MarshalPKIXPublicKey(&privateKey.PublicKey)
|
||||
require.NoError(t, err)
|
||||
|
||||
nonce := make([]byte, 64)
|
||||
_, err = rand.Read(nonce)
|
||||
require.NoError(t, err)
|
||||
|
||||
teeNonce := append(pubKeyDER, nonce...)
|
||||
hashNonce := sha3.Sum512(teeNonce)
|
||||
|
||||
cases := []struct {
|
||||
name string
|
||||
extension []byte
|
||||
pubKey []byte
|
||||
nonce []byte
|
||||
platformType attestation.PlatformType
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
name: "Valid extension with SNPvTPM",
|
||||
extension: hashNonce[:],
|
||||
pubKey: pubKeyDER,
|
||||
nonce: nonce,
|
||||
platformType: attestation.SNPvTPM,
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "Invalid platform type",
|
||||
extension: hashNonce[:],
|
||||
pubKey: pubKeyDER,
|
||||
nonce: nonce,
|
||||
platformType: 999,
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "Empty extension",
|
||||
extension: []byte{},
|
||||
pubKey: pubKeyDER,
|
||||
nonce: nonce,
|
||||
platformType: attestation.SNPvTPM,
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "Empty public key",
|
||||
extension: hashNonce[:],
|
||||
pubKey: []byte{},
|
||||
nonce: nonce,
|
||||
platformType: attestation.SNPvTPM,
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "Empty nonce",
|
||||
extension: hashNonce[:],
|
||||
pubKey: pubKeyDER,
|
||||
nonce: []byte{},
|
||||
platformType: attestation.SNPvTPM,
|
||||
expectError: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, c := range cases {
|
||||
t.Run(c.name, func(t *testing.T) {
|
||||
err := VerifyCertificateExtension(c.extension, c.pubKey, c.nonce, c.platformType)
|
||||
if c.expectError {
|
||||
assert.Error(t, err)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetCertificateExtension(t *testing.T) {
|
||||
mockProvider := new(mocks.Provider)
|
||||
|
||||
mockProvider.On("Attestation", mock.Anything, mock.Anything).Return([]byte("mock-attestation-data"), nil)
|
||||
|
||||
pubKey := []byte("test-public-key")
|
||||
nonce := make([]byte, 32)
|
||||
_, err := rand.Read(nonce)
|
||||
require.NoError(t, err)
|
||||
|
||||
testOID := asn1.ObjectIdentifier{1, 2, 3, 4}
|
||||
|
||||
extension, err := getCertificateExtension(mockProvider, pubKey, nonce, testOID)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, testOID, extension.Id)
|
||||
assert.Equal(t, []byte("mock-attestation-data"), extension.Value)
|
||||
}
|
||||
|
||||
func TestGetCertificateWithSelfSigned(t *testing.T) {
|
||||
getCertFunc := GetCertificate("", "")
|
||||
|
||||
nonce := make([]byte, 64)
|
||||
_, err := rand.Read(nonce)
|
||||
require.NoError(t, err)
|
||||
|
||||
serverName := hex.EncodeToString(nonce) + ".nonce"
|
||||
|
||||
clientHello := &tls.ClientHelloInfo{
|
||||
ServerName: serverName,
|
||||
}
|
||||
|
||||
cert, err := getCertFunc(clientHello)
|
||||
|
||||
if err != nil {
|
||||
t.Logf("Expected error due to missing attestation setup: %v", err)
|
||||
assert.Error(t, err)
|
||||
} else {
|
||||
assert.NotNil(t, cert)
|
||||
assert.NotEmpty(t, cert.Certificate)
|
||||
assert.NotNil(t, cert.PrivateKey)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetCertificateWithCA(t *testing.T) {
|
||||
mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodPost {
|
||||
w.WriteHeader(http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
|
||||
mockCert := certssdk.Certificate{
|
||||
Certificate: "-----BEGIN CERTIFICATE-----\nMIIBkTCB+wIBATANBgkqhkiG9w0BAQsFADAYMRYwFAYDVQQDDA1UZXN0IENBIFJvb3QwHhcNMjMwMzMxMDAwMDAwWhcNMjQwMzMxMDAwMDAwWjAYMRYwFAYDVQQDDA1UZXN0IENlcnRpZmljYXRlMFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAEtest-key-data-here\n-----END CERTIFICATE-----",
|
||||
}
|
||||
|
||||
response, _ := json.Marshal(mockCert)
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
if _, err := w.Write(response); err != nil {
|
||||
http.Error(w, "Failed to write response", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
}))
|
||||
defer mockServer.Close()
|
||||
|
||||
getCertFunc := GetCertificate(mockServer.URL, "test-cvm-id")
|
||||
|
||||
nonce := make([]byte, 64)
|
||||
_, err := rand.Read(nonce)
|
||||
require.NoError(t, err)
|
||||
|
||||
serverName := hex.EncodeToString(nonce) + ".nonce"
|
||||
|
||||
clientHello := &tls.ClientHelloInfo{
|
||||
ServerName: serverName,
|
||||
}
|
||||
|
||||
_, err = getCertFunc(clientHello)
|
||||
if err != nil {
|
||||
t.Logf("Expected error due to missing attestation setup: %v", err)
|
||||
assert.Error(t, err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetCertificateInvalidServerName(t *testing.T) {
|
||||
getCertFunc := GetCertificate("", "")
|
||||
|
||||
cases := []struct {
|
||||
name string
|
||||
serverName string
|
||||
expectErr string
|
||||
}{
|
||||
{
|
||||
name: "Missing .nonce suffix",
|
||||
serverName: "invalidname",
|
||||
expectErr: "failed to get platform provider",
|
||||
},
|
||||
{
|
||||
name: "Too short server name",
|
||||
serverName: "short",
|
||||
expectErr: "failed to get platform provider",
|
||||
},
|
||||
{
|
||||
name: "Invalid nonce encoding",
|
||||
serverName: "invalidhex.nonce",
|
||||
expectErr: "failed to get platform provider",
|
||||
},
|
||||
{
|
||||
name: "Wrong nonce length",
|
||||
serverName: "deadbeef.nonce",
|
||||
expectErr: "failed to get platform provider",
|
||||
},
|
||||
}
|
||||
|
||||
for _, c := range cases {
|
||||
t.Run(c.name, func(t *testing.T) {
|
||||
clientHello := &tls.ClientHelloInfo{
|
||||
ServerName: c.serverName,
|
||||
}
|
||||
|
||||
cert, err := getCertFunc(clientHello)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), c.expectErr)
|
||||
assert.Nil(t, cert)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestProcessRequest(t *testing.T) {
|
||||
testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
case "/success":
|
||||
w.WriteHeader(http.StatusOK)
|
||||
if _, err := w.Write([]byte(`{"message": "success"}`)); err != nil {
|
||||
http.Error(w, "Failed to write response", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
case "/notfound":
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
if _, err := w.Write([]byte(`{"error": "not found"}`)); err != nil {
|
||||
http.Error(w, "Failed to write response", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
case "/headers":
|
||||
if r.Header.Get("X-Custom-Header") == "test-value" {
|
||||
w.Header().Set("X-Response-Header", "received")
|
||||
}
|
||||
w.WriteHeader(http.StatusOK)
|
||||
if _, err := w.Write([]byte(`{"headers": "ok"}`)); err != nil {
|
||||
http.Error(w, "Failed to write response", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
default:
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
}
|
||||
}))
|
||||
defer testServer.Close()
|
||||
|
||||
cases := []struct {
|
||||
name string
|
||||
method string
|
||||
url string
|
||||
data []byte
|
||||
headers map[string]string
|
||||
expectedRespCodes []int
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
name: "Successful GET request",
|
||||
method: http.MethodGet,
|
||||
url: testServer.URL + "/success",
|
||||
data: nil,
|
||||
headers: nil,
|
||||
expectedRespCodes: []int{http.StatusOK},
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "Successful POST request with data",
|
||||
method: http.MethodPost,
|
||||
url: testServer.URL + "/success",
|
||||
data: []byte(`{"test": "data"}`),
|
||||
headers: nil,
|
||||
expectedRespCodes: []int{http.StatusOK},
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "Request with custom headers",
|
||||
method: http.MethodGet,
|
||||
url: testServer.URL + "/headers",
|
||||
data: nil,
|
||||
headers: map[string]string{"X-Custom-Header": "test-value"},
|
||||
expectedRespCodes: []int{http.StatusOK},
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "Request with unexpected status code",
|
||||
method: http.MethodGet,
|
||||
url: testServer.URL + "/notfound",
|
||||
data: nil,
|
||||
headers: nil,
|
||||
expectedRespCodes: []int{http.StatusOK},
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "Request with multiple expected status codes",
|
||||
method: http.MethodGet,
|
||||
url: testServer.URL + "/notfound",
|
||||
data: nil,
|
||||
headers: nil,
|
||||
expectedRespCodes: []int{http.StatusOK, http.StatusNotFound},
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "Request to invalid URL",
|
||||
method: http.MethodGet,
|
||||
url: "invalid-url",
|
||||
data: nil,
|
||||
headers: nil,
|
||||
expectedRespCodes: []int{http.StatusOK},
|
||||
expectError: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, c := range cases {
|
||||
t.Run(c.name, func(t *testing.T) {
|
||||
headers, body, err := processRequest(c.method, c.url, c.data, c.headers, c.expectedRespCodes...)
|
||||
|
||||
if c.expectError {
|
||||
assert.Error(t, err)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, headers)
|
||||
assert.NotNil(t, body)
|
||||
|
||||
if c.name == "Request with custom headers" {
|
||||
assert.Equal(t, "received", headers.Get("X-Response-Header"))
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetCertificateExtensionError(t *testing.T) {
|
||||
mockProvider := new(mocks.Provider)
|
||||
|
||||
mockProvider.On("Attestation", mock.Anything, mock.Anything).Return(nil, errors.New("failed to get attestation"))
|
||||
|
||||
pubKey := []byte("test-public-key")
|
||||
nonce := make([]byte, 32)
|
||||
testOID := asn1.ObjectIdentifier{1, 2, 3, 4}
|
||||
|
||||
extension, err := getCertificateExtension(mockProvider, pubKey, nonce, testOID)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "failed to get attestation")
|
||||
assert.Equal(t, pkix.Extension{}, extension)
|
||||
}
|
||||
|
||||
func prepVerifyAttReport(t *testing.T) *sevsnp.Attestation {
|
||||
file, err := os.ReadFile("../../attestation.bin")
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -0,0 +1,213 @@
|
||||
// Copyright (c) Ultraviolet
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package attestation
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestCCPlatform(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
sevSnpGuestExists bool
|
||||
sevSnpGuestvTPMExists bool
|
||||
tdxGuestExists bool
|
||||
isAzure bool
|
||||
expected PlatformType
|
||||
}{
|
||||
{
|
||||
name: "No CC platform detected",
|
||||
sevSnpGuestExists: false,
|
||||
sevSnpGuestvTPMExists: false,
|
||||
tdxGuestExists: false,
|
||||
isAzure: false,
|
||||
expected: NoCC,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := CCPlatform()
|
||||
assert.Equal(t, tt.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSevSnpGuestDeviceExists(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
openDeviceErr error
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
name: "device does not exist or fails to open",
|
||||
openDeviceErr: fmt.Errorf("device not found"),
|
||||
expected: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := SevSnpGuestDeviceExists()
|
||||
assert.Equal(t, tt.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSevSnpGuestvTPMExists(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
vTPMExists bool
|
||||
sevSnpExists bool
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
name: "vTPM exists but SEV-SNP does not",
|
||||
vTPMExists: true,
|
||||
sevSnpExists: false,
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "SEV-SNP exists but vTPM does not",
|
||||
vTPMExists: false,
|
||||
sevSnpExists: true,
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "neither exists",
|
||||
vTPMExists: false,
|
||||
sevSnpExists: false,
|
||||
expected: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := SevSnpGuestvTPMExists()
|
||||
assert.Equal(t, tt.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestVTPMExists(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
openTPMErr error
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
name: "TPM fails to open",
|
||||
openTPMErr: fmt.Errorf("TPM not found"),
|
||||
expected: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := vTPMExists()
|
||||
assert.Equal(t, tt.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsAzureVM(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
vTPMExists bool
|
||||
statusCode int
|
||||
responseBody string
|
||||
httpError error
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
name: "Azure VM with empty response body",
|
||||
vTPMExists: true,
|
||||
statusCode: http.StatusOK,
|
||||
responseBody: "",
|
||||
httpError: nil,
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "Azure VM with non-200 status code",
|
||||
vTPMExists: true,
|
||||
statusCode: http.StatusNotFound,
|
||||
responseBody: "",
|
||||
httpError: nil,
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "HTTP request error",
|
||||
vTPMExists: true,
|
||||
statusCode: 0,
|
||||
responseBody: "",
|
||||
httpError: fmt.Errorf("connection failed"),
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "vTPM does not exist",
|
||||
vTPMExists: false,
|
||||
statusCode: http.StatusOK,
|
||||
responseBody: `{"compute":{"name":"test-vm"}}`,
|
||||
httpError: nil,
|
||||
expected: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
assert.Equal(t, "GET", r.Method)
|
||||
assert.Equal(t, "true", r.Header.Get("Metadata"))
|
||||
expectedURL := fmt.Sprintf("/?api-version=%s", azureApiVersion)
|
||||
assert.Equal(t, expectedURL, r.URL.String())
|
||||
|
||||
if tt.httpError != nil {
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
w.WriteHeader(tt.statusCode)
|
||||
if tt.responseBody != "" {
|
||||
if _, err := w.Write([]byte(tt.responseBody)); err != nil {
|
||||
t.Fatalf("Failed to write response body: %v", err)
|
||||
}
|
||||
}
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
if tt.httpError != nil {
|
||||
server.Close()
|
||||
}
|
||||
|
||||
result := isAzureVM()
|
||||
assert.Equal(t, tt.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestTDXGuestDeviceExists(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
openDeviceErr error
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
name: "TDX device does not exist or fails to open",
|
||||
openDeviceErr: fmt.Errorf("device not found"),
|
||||
expected: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := TDXGuestDeviceExists()
|
||||
assert.Equal(t, tt.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,578 @@
|
||||
// Copyright (c) Ultraviolet
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package azure
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
"github.com/google/go-sev-guest/proto/check"
|
||||
"github.com/google/go-sev-guest/proto/sevsnp"
|
||||
"github.com/google/go-tpm-tools/proto/attest"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/ultravioletrs/cocos/pkg/attestation"
|
||||
"google.golang.org/protobuf/proto"
|
||||
)
|
||||
|
||||
var (
|
||||
testNonce = []byte("test-nonce-12345678901234567890123456789012")
|
||||
testReport = []byte("test-report-data")
|
||||
)
|
||||
|
||||
func TestNewProvider(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
want attestation.Provider
|
||||
}{
|
||||
{
|
||||
name: "creates new provider successfully",
|
||||
want: provider{},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := NewProvider()
|
||||
assert.Equal(t, tt.want, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestProvider_Attestation(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
teeNonce []byte
|
||||
vTpmNonce []byte
|
||||
wantErr bool
|
||||
errorMessage string
|
||||
}{
|
||||
{
|
||||
name: "maa parameters error",
|
||||
teeNonce: testNonce,
|
||||
vTpmNonce: testNonce,
|
||||
wantErr: true,
|
||||
errorMessage: "failed to get report",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
p := NewProvider()
|
||||
|
||||
result, err := p.Attestation(tt.teeNonce, tt.vTpmNonce)
|
||||
|
||||
if tt.wantErr {
|
||||
assert.Error(t, err)
|
||||
if tt.errorMessage != "" {
|
||||
assert.Contains(t, err.Error(), tt.errorMessage)
|
||||
}
|
||||
assert.Nil(t, result)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestProvider_TeeAttestation(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
teeNonce []byte
|
||||
wantErr bool
|
||||
errorMessage string
|
||||
}{
|
||||
{
|
||||
name: "maa parameters error",
|
||||
teeNonce: testNonce,
|
||||
wantErr: true,
|
||||
errorMessage: "failed to get report",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
p := NewProvider()
|
||||
|
||||
result, err := p.TeeAttestation(tt.teeNonce)
|
||||
|
||||
if tt.wantErr {
|
||||
assert.Error(t, err)
|
||||
if tt.errorMessage != "" {
|
||||
assert.Contains(t, err.Error(), tt.errorMessage)
|
||||
}
|
||||
assert.Nil(t, result)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestProvider_AzureAttestationToken(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
tokenNonce []byte
|
||||
setupServer func() *httptest.Server
|
||||
wantErr bool
|
||||
errorMessage string
|
||||
}{
|
||||
{
|
||||
name: "server error",
|
||||
tokenNonce: testNonce,
|
||||
setupServer: func() *httptest.Server {
|
||||
return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
}))
|
||||
},
|
||||
wantErr: true,
|
||||
errorMessage: "failed to fetch Azure token",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
server := tt.setupServer()
|
||||
defer server.Close()
|
||||
|
||||
originalURL := MaaURL
|
||||
MaaURL = server.URL
|
||||
defer func() { MaaURL = originalURL }()
|
||||
|
||||
p := NewProvider()
|
||||
|
||||
result, err := p.AzureAttestationToken(tt.tokenNonce)
|
||||
|
||||
if tt.wantErr {
|
||||
assert.Error(t, err)
|
||||
if tt.errorMessage != "" {
|
||||
assert.Contains(t, err.Error(), tt.errorMessage)
|
||||
}
|
||||
assert.Nil(t, result)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewVerifier(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
writer io.Writer
|
||||
}{
|
||||
{
|
||||
name: "creates verifier with buffer writer",
|
||||
writer: &bytes.Buffer{},
|
||||
},
|
||||
{
|
||||
name: "creates verifier with nil writer",
|
||||
writer: nil,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
v := NewVerifier(tt.writer)
|
||||
|
||||
verifier, ok := v.(verifier)
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, tt.writer, verifier.writer)
|
||||
assert.NotNil(t, verifier.Policy)
|
||||
assert.NotNil(t, verifier.Policy.Config)
|
||||
assert.NotNil(t, verifier.Policy.PcrConfig)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewVerifierWithPolicy(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
writer io.Writer
|
||||
policy *attestation.Config
|
||||
}{
|
||||
{
|
||||
name: "creates verifier with custom policy",
|
||||
writer: &bytes.Buffer{},
|
||||
policy: &attestation.Config{
|
||||
Config: &check.Config{
|
||||
Policy: &check.Policy{},
|
||||
RootOfTrust: &check.RootOfTrust{},
|
||||
},
|
||||
PcrConfig: &attestation.PcrConfig{},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "creates verifier with nil policy",
|
||||
writer: &bytes.Buffer{},
|
||||
policy: nil,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
v := NewVerifierWithPolicy(tt.writer, tt.policy)
|
||||
|
||||
verifier, ok := v.(verifier)
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, tt.writer, verifier.writer)
|
||||
assert.NotNil(t, verifier.Policy)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestVerifier_VerifTeeAttestation(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
report []byte
|
||||
teeNonce []byte
|
||||
wantErr bool
|
||||
errorMessage string
|
||||
}{
|
||||
{
|
||||
name: "empty report",
|
||||
report: []byte{},
|
||||
teeNonce: testNonce,
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "invalid report format",
|
||||
report: []byte("invalid-report"),
|
||||
teeNonce: testNonce,
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "nil nonce",
|
||||
report: testReport,
|
||||
teeNonce: nil,
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
v := NewVerifier(&bytes.Buffer{})
|
||||
|
||||
err := v.VerifTeeAttestation(tt.report, tt.teeNonce)
|
||||
|
||||
if tt.wantErr {
|
||||
assert.Error(t, err)
|
||||
if tt.errorMessage != "" {
|
||||
assert.Contains(t, err.Error(), tt.errorMessage)
|
||||
}
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestVerifier_VerifyAttestation(t *testing.T) {
|
||||
validQuote := &attest.Attestation{
|
||||
TeeAttestation: &attest.Attestation_SevSnpAttestation{
|
||||
SevSnpAttestation: &sevsnp.Attestation{
|
||||
Report: &sevsnp.Report{
|
||||
HostData: []byte("test-data"),
|
||||
},
|
||||
Product: &sevsnp.SevProduct{
|
||||
Name: sevsnp.SevProduct_SEV_PRODUCT_GENOA,
|
||||
},
|
||||
CertificateChain: &sevsnp.CertificateChain{
|
||||
Extras: make(map[string][]byte),
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
validReport, _ := proto.Marshal(validQuote)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
report []byte
|
||||
teeNonce []byte
|
||||
vTpmNonce []byte
|
||||
wantErr bool
|
||||
errorMessage string
|
||||
}{
|
||||
{
|
||||
name: "successful verification",
|
||||
report: validReport,
|
||||
teeNonce: testNonce,
|
||||
vTpmNonce: testNonce,
|
||||
wantErr: true,
|
||||
errorMessage: "failed to verify vTPM attestation report",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
v := NewVerifier(&bytes.Buffer{})
|
||||
|
||||
err := v.VerifyAttestation(tt.report, tt.teeNonce, tt.vTpmNonce)
|
||||
|
||||
if tt.wantErr {
|
||||
assert.Error(t, err)
|
||||
if tt.errorMessage != "" {
|
||||
assert.Contains(t, err.Error(), tt.errorMessage)
|
||||
}
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestFetchAzureAttestationToken(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
tokenNonce []byte
|
||||
maaURL string
|
||||
setupServer func() *httptest.Server
|
||||
wantErr bool
|
||||
errorMessage string
|
||||
}{
|
||||
{
|
||||
name: "server error",
|
||||
tokenNonce: testNonce,
|
||||
setupServer: func() *httptest.Server {
|
||||
return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
}))
|
||||
},
|
||||
wantErr: true,
|
||||
errorMessage: "error fetching azure token",
|
||||
},
|
||||
{
|
||||
name: "invalid url",
|
||||
tokenNonce: testNonce,
|
||||
setupServer: func() *httptest.Server {
|
||||
return nil
|
||||
},
|
||||
wantErr: true,
|
||||
errorMessage: "error fetching azure token",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
var url string
|
||||
if tt.setupServer != nil {
|
||||
server := tt.setupServer()
|
||||
if server != nil {
|
||||
defer server.Close()
|
||||
url = server.URL
|
||||
}
|
||||
}
|
||||
|
||||
if tt.name == "invalid url" {
|
||||
url = "invalid-url"
|
||||
}
|
||||
|
||||
result, err := FetchAzureAttestationToken(tt.tokenNonce, url)
|
||||
|
||||
if tt.wantErr {
|
||||
assert.Error(t, err)
|
||||
if tt.errorMessage != "" {
|
||||
assert.Contains(t, err.Error(), tt.errorMessage)
|
||||
}
|
||||
assert.Nil(t, result)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateToken(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
token string
|
||||
setupServer func() *httptest.Server
|
||||
wantErr bool
|
||||
errorMessage string
|
||||
}{
|
||||
{
|
||||
name: "invalid token format",
|
||||
token: "invalid-token",
|
||||
setupServer: nil,
|
||||
wantErr: true,
|
||||
errorMessage: "failed to parse token",
|
||||
},
|
||||
{
|
||||
name: "empty token",
|
||||
token: "",
|
||||
setupServer: nil,
|
||||
wantErr: true,
|
||||
errorMessage: "failed to parse token",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if tt.setupServer != nil {
|
||||
server := tt.setupServer()
|
||||
defer server.Close()
|
||||
|
||||
originalURL := MaaURL
|
||||
MaaURL = server.URL
|
||||
defer func() { MaaURL = originalURL }()
|
||||
}
|
||||
|
||||
result, err := validateToken(tt.token)
|
||||
|
||||
if tt.wantErr {
|
||||
assert.Error(t, err)
|
||||
if tt.errorMessage != "" {
|
||||
assert.Contains(t, err.Error(), tt.errorMessage)
|
||||
}
|
||||
assert.Nil(t, result)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestIntegration_FullAttestationFlow(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("Skipping integration test in short mode")
|
||||
}
|
||||
|
||||
t.Run("full attestation flow with mock server", func(t *testing.T) {
|
||||
maaServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
case "/attest":
|
||||
response := map[string]interface{}{
|
||||
"token": createMockJWT(),
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
if err := json.NewEncoder(w).Encode(response); err != nil {
|
||||
t.Fatalf("Failed to encode response: %v", err)
|
||||
}
|
||||
case "/.well-known/openid_configuration":
|
||||
config := map[string]interface{}{
|
||||
"jwks_uri": "maaServer.URL" + "/certs",
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
if err := json.NewEncoder(w).Encode(config); err != nil {
|
||||
t.Fatalf("Failed to encode OpenID configuration: %v", err)
|
||||
}
|
||||
case "/certs":
|
||||
jwks := map[string]interface{}{
|
||||
"keys": []map[string]interface{}{
|
||||
{
|
||||
"kid": "test-kid",
|
||||
"kty": "RSA",
|
||||
"use": "sig",
|
||||
"n": "test-n-value",
|
||||
"e": "AQAB",
|
||||
},
|
||||
},
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
if err := json.NewEncoder(w).Encode(jwks); err != nil {
|
||||
t.Fatalf("Failed to encode JWKS: %v", err)
|
||||
}
|
||||
default:
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
}
|
||||
}))
|
||||
defer maaServer.Close()
|
||||
|
||||
originalURL := MaaURL
|
||||
MaaURL = maaServer.URL
|
||||
defer func() { MaaURL = originalURL }()
|
||||
|
||||
provider := NewProvider()
|
||||
verifier := NewVerifier(&bytes.Buffer{})
|
||||
|
||||
teeNonce := []byte("test-tee-nonce-1234567890123456789012")
|
||||
vtpmNonce := []byte("test-vtpm-nonce-123456789012345678901")
|
||||
|
||||
teeReport, err := provider.TeeAttestation(teeNonce)
|
||||
if err != nil {
|
||||
t.Logf("TEE attestation failed (expected in mock environment): %v", err)
|
||||
}
|
||||
|
||||
vtpmReport, err := provider.VTpmAttestation(vtpmNonce)
|
||||
if err != nil {
|
||||
t.Logf("vTPM attestation failed (expected in mock environment): %v", err)
|
||||
}
|
||||
|
||||
token, err := provider.AzureAttestationToken(teeNonce)
|
||||
if err != nil {
|
||||
t.Logf("Azure attestation token failed (expected in mock environment): %v", err)
|
||||
}
|
||||
|
||||
assert.NotNil(t, provider)
|
||||
assert.NotNil(t, verifier)
|
||||
|
||||
t.Logf("TEE report length: %d", len(teeReport))
|
||||
t.Logf("vTPM report length: %d", len(vtpmReport))
|
||||
t.Logf("Token length: %d", len(token))
|
||||
})
|
||||
}
|
||||
|
||||
func TestIntegration_ErrorPropagation(t *testing.T) {
|
||||
t.Run("error propagation through full stack", func(t *testing.T) {
|
||||
failingServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
if _, err := w.Write([]byte("Internal Server Error")); err != nil {
|
||||
t.Fatalf("Failed to write response: %v", err)
|
||||
}
|
||||
}))
|
||||
defer failingServer.Close()
|
||||
|
||||
originalURL := MaaURL
|
||||
MaaURL = failingServer.URL
|
||||
defer func() { MaaURL = originalURL }()
|
||||
|
||||
provider := NewProvider()
|
||||
|
||||
_, err := provider.AzureAttestationToken([]byte("test-nonce"))
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "failed to fetch Azure token")
|
||||
|
||||
_, err = GenerateAttestationPolicy("invalid-token", "test-product", 1)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "failed to validate token")
|
||||
})
|
||||
}
|
||||
|
||||
func createMockJWT() string {
|
||||
claims := jwt.MapClaims{
|
||||
"iss": "https://test-issuer.com",
|
||||
"aud": "test-audience",
|
||||
"exp": time.Now().Add(time.Hour).Unix(),
|
||||
"iat": time.Now().Unix(),
|
||||
"x-ms-isolation-tee": map[string]interface{}{
|
||||
"x-ms-sevsnpvm-familyId": "1234567890abcdef",
|
||||
"x-ms-sevsnpvm-imageId": "fedcba0987654321",
|
||||
"x-ms-sevsnpvm-launchmeasurement": "abcdef1234567890",
|
||||
"x-ms-sevsnpvm-bootloader-svn": float64(1),
|
||||
"x-ms-sevsnpvm-tee-svn": float64(2),
|
||||
"x-ms-sevsnpvm-snpfw-svn": float64(3),
|
||||
"x-ms-sevsnpvm-microcode-svn": float64(4),
|
||||
"x-ms-sevsnpvm-guestsvn": float64(5),
|
||||
"x-ms-sevsnpvm-idkeydigest": "1234567890abcdef",
|
||||
"x-ms-sevsnpvm-reportid": "fedcba0987654321",
|
||||
},
|
||||
}
|
||||
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims)
|
||||
token.Header["jku"] = "https://test-url.com"
|
||||
token.Header["kid"] = "test-kid"
|
||||
|
||||
// Return unsigned token for testing
|
||||
return token.Raw
|
||||
}
|
||||
@@ -0,0 +1,138 @@
|
||||
// Copyright (c) Ultraviolet
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
// Code generated by mockery v2.53.3. DO NOT EDIT.
|
||||
|
||||
package mocks
|
||||
|
||||
import mock "github.com/stretchr/testify/mock"
|
||||
|
||||
// MeasurementProvider is an autogenerated mock type for the MeasurementProvider type
|
||||
type MeasurementProvider struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
type MeasurementProvider_Expecter struct {
|
||||
mock *mock.Mock
|
||||
}
|
||||
|
||||
func (_m *MeasurementProvider) EXPECT() *MeasurementProvider_Expecter {
|
||||
return &MeasurementProvider_Expecter{mock: &_m.Mock}
|
||||
}
|
||||
|
||||
// Run provides a mock function with given fields: binaryPath
|
||||
func (_m *MeasurementProvider) Run(binaryPath string) ([]byte, error) {
|
||||
ret := _m.Called(binaryPath)
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for Run")
|
||||
}
|
||||
|
||||
var r0 []byte
|
||||
var r1 error
|
||||
if rf, ok := ret.Get(0).(func(string) ([]byte, error)); ok {
|
||||
return rf(binaryPath)
|
||||
}
|
||||
if rf, ok := ret.Get(0).(func(string) []byte); ok {
|
||||
r0 = rf(binaryPath)
|
||||
} else {
|
||||
if ret.Get(0) != nil {
|
||||
r0 = ret.Get(0).([]byte)
|
||||
}
|
||||
}
|
||||
|
||||
if rf, ok := ret.Get(1).(func(string) error); ok {
|
||||
r1 = rf(binaryPath)
|
||||
} else {
|
||||
r1 = ret.Error(1)
|
||||
}
|
||||
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
// MeasurementProvider_Run_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Run'
|
||||
type MeasurementProvider_Run_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// Run is a helper method to define mock.On call
|
||||
// - binaryPath string
|
||||
func (_e *MeasurementProvider_Expecter) Run(binaryPath interface{}) *MeasurementProvider_Run_Call {
|
||||
return &MeasurementProvider_Run_Call{Call: _e.mock.On("Run", binaryPath)}
|
||||
}
|
||||
|
||||
func (_c *MeasurementProvider_Run_Call) Run(run func(binaryPath string)) *MeasurementProvider_Run_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
run(args[0].(string))
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *MeasurementProvider_Run_Call) Return(_a0 []byte, _a1 error) *MeasurementProvider_Run_Call {
|
||||
_c.Call.Return(_a0, _a1)
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *MeasurementProvider_Run_Call) RunAndReturn(run func(string) ([]byte, error)) *MeasurementProvider_Run_Call {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
// Stop provides a mock function with no fields
|
||||
func (_m *MeasurementProvider) Stop() error {
|
||||
ret := _m.Called()
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for Stop")
|
||||
}
|
||||
|
||||
var r0 error
|
||||
if rf, ok := ret.Get(0).(func() error); ok {
|
||||
r0 = rf()
|
||||
} else {
|
||||
r0 = ret.Error(0)
|
||||
}
|
||||
|
||||
return r0
|
||||
}
|
||||
|
||||
// MeasurementProvider_Stop_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Stop'
|
||||
type MeasurementProvider_Stop_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// Stop is a helper method to define mock.On call
|
||||
func (_e *MeasurementProvider_Expecter) Stop() *MeasurementProvider_Stop_Call {
|
||||
return &MeasurementProvider_Stop_Call{Call: _e.mock.On("Stop")}
|
||||
}
|
||||
|
||||
func (_c *MeasurementProvider_Stop_Call) Run(run func()) *MeasurementProvider_Stop_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
run()
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *MeasurementProvider_Stop_Call) Return(_a0 error) *MeasurementProvider_Stop_Call {
|
||||
_c.Call.Return(_a0)
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *MeasurementProvider_Stop_Call) RunAndReturn(run func() error) *MeasurementProvider_Stop_Call {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
// NewMeasurementProvider creates a new instance of MeasurementProvider. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations.
|
||||
// The first argument is typically a *testing.T value.
|
||||
func NewMeasurementProvider(t interface {
|
||||
mock.TestingT
|
||||
Cleanup(func())
|
||||
}) *MeasurementProvider {
|
||||
mock := &MeasurementProvider{}
|
||||
mock.Mock.Test(t)
|
||||
|
||||
t.Cleanup(func() { mock.AssertExpectations(t) })
|
||||
|
||||
return mock
|
||||
}
|
||||
@@ -0,0 +1,261 @@
|
||||
// Copyright (c) Ultraviolet
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package gcp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"cloud.google.com/go/storage"
|
||||
"github.com/google/gce-tcb-verifier/proto/endorsement"
|
||||
"github.com/google/go-sev-guest/proto/sevsnp"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"google.golang.org/protobuf/proto"
|
||||
)
|
||||
|
||||
func TestExtract384BitMeasurement(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
attestation *sevsnp.Attestation
|
||||
setupMock func()
|
||||
expected string
|
||||
expectError bool
|
||||
errorMsg string
|
||||
}{
|
||||
{
|
||||
name: "nil attestation",
|
||||
attestation: nil,
|
||||
expectError: true,
|
||||
errorMsg: "report is nil",
|
||||
},
|
||||
{
|
||||
name: "short report",
|
||||
attestation: &sevsnp.Attestation{Report: &sevsnp.Report{}},
|
||||
expectError: true,
|
||||
errorMsg: "failed to transform report to binary",
|
||||
},
|
||||
{
|
||||
name: "empty report",
|
||||
attestation: &sevsnp.Attestation{},
|
||||
expectError: true,
|
||||
errorMsg: "failed to transform report to binary",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result, err := Extract384BitMeasurement(tt.attestation)
|
||||
|
||||
if tt.expectError {
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), tt.errorMsg)
|
||||
assert.Empty(t, result)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, tt.expected, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetLaunchEndorsement(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
measurement384 string
|
||||
setupMock func() ([]byte, error)
|
||||
expectError bool
|
||||
errorMsg string
|
||||
}{
|
||||
{
|
||||
name: "successful retrieval",
|
||||
measurement384: "test-measurement",
|
||||
setupMock: func() ([]byte, error) {
|
||||
goldenUEFI := &endorsement.VMGoldenMeasurement{
|
||||
SevSnp: &endorsement.VMSevSnp{
|
||||
Policy: 12345,
|
||||
Measurements: map[uint32][]byte{1: []byte("test-measurement")},
|
||||
},
|
||||
}
|
||||
goldenBytes, _ := proto.Marshal(goldenUEFI)
|
||||
|
||||
launchEndorsement := &endorsement.VMLaunchEndorsement{
|
||||
SerializedUefiGolden: goldenBytes,
|
||||
}
|
||||
return proto.Marshal(launchEndorsement)
|
||||
},
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "storage client error",
|
||||
measurement384: "test-measurement",
|
||||
setupMock: func() ([]byte, error) {
|
||||
return nil, errors.New("storage client error")
|
||||
},
|
||||
expectError: true,
|
||||
errorMsg: "failed to create reader",
|
||||
},
|
||||
{
|
||||
name: "object not found",
|
||||
measurement384: "non-existent-measurement",
|
||||
setupMock: func() ([]byte, error) {
|
||||
return nil, storage.ErrObjectNotExist
|
||||
},
|
||||
expectError: true,
|
||||
errorMsg: "failed to create reader",
|
||||
},
|
||||
{
|
||||
name: "invalid protobuf data",
|
||||
measurement384: "test-measurement",
|
||||
setupMock: func() ([]byte, error) {
|
||||
return []byte("invalid protobuf data"), nil
|
||||
},
|
||||
expectError: true,
|
||||
errorMsg: "failed to create reader",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
// skip if credentials are not set
|
||||
if _, err := storage.NewClient(ctx); err != nil && tt.expectError {
|
||||
t.Skip("Skipping test due to missing GCP credentials")
|
||||
}
|
||||
|
||||
_, err := GetLaunchEndorsement(ctx, tt.measurement384)
|
||||
|
||||
if tt.expectError {
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), tt.errorMsg)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenerateAttestationPolicy(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
endorsement *endorsement.VMGoldenMeasurement
|
||||
vcpuNum uint32
|
||||
expectError bool
|
||||
errorMsg string
|
||||
}{
|
||||
{
|
||||
name: "valid endorsement",
|
||||
endorsement: &endorsement.VMGoldenMeasurement{
|
||||
SevSnp: &endorsement.VMSevSnp{
|
||||
Policy: 12345,
|
||||
Measurements: map[uint32][]byte{1: []byte("test-measurement")},
|
||||
},
|
||||
},
|
||||
vcpuNum: 1,
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "missing measurement for vcpu",
|
||||
endorsement: &endorsement.VMGoldenMeasurement{
|
||||
SevSnp: &endorsement.VMSevSnp{
|
||||
Policy: 12345,
|
||||
Measurements: map[uint32][]byte{2: []byte("test-measurement")},
|
||||
},
|
||||
},
|
||||
vcpuNum: 1,
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "empty measurements map",
|
||||
endorsement: &endorsement.VMGoldenMeasurement{
|
||||
SevSnp: &endorsement.VMSevSnp{
|
||||
Policy: 12345,
|
||||
Measurements: map[uint32][]byte{},
|
||||
},
|
||||
},
|
||||
vcpuNum: 1,
|
||||
expectError: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result, err := GenerateAttestationPolicy(tt.endorsement, tt.vcpuNum)
|
||||
|
||||
if tt.expectError {
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), tt.errorMsg)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, result)
|
||||
assert.NotNil(t, result.Config)
|
||||
assert.NotNil(t, result.Config.Policy)
|
||||
assert.NotNil(t, result.Config.RootOfTrust)
|
||||
assert.NotNil(t, result.PcrConfig)
|
||||
|
||||
assert.Equal(t, tt.endorsement.SevSnp.Policy, result.Config.Policy.Policy)
|
||||
assert.Equal(t, tt.endorsement.SevSnp.Measurements[tt.vcpuNum], result.Config.Policy.Measurement)
|
||||
assert.False(t, result.Config.RootOfTrust.DisallowNetwork)
|
||||
assert.True(t, result.Config.RootOfTrust.CheckCrl)
|
||||
assert.Equal(t, "Milan", result.Config.RootOfTrust.Product)
|
||||
assert.Equal(t, "Milan", result.Config.RootOfTrust.ProductLine)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDownloadOvmfFile(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
digest string
|
||||
expectError bool
|
||||
errorMsg string
|
||||
}{
|
||||
{
|
||||
name: "successful download",
|
||||
digest: "test-digest",
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "storage client error",
|
||||
digest: "test-digest",
|
||||
expectError: true,
|
||||
errorMsg: "failed to create reader",
|
||||
},
|
||||
{
|
||||
name: "object not found",
|
||||
digest: "non-existent-digest",
|
||||
expectError: true,
|
||||
errorMsg: "failed to create reader",
|
||||
},
|
||||
{
|
||||
name: "read error",
|
||||
digest: "test-digest",
|
||||
expectError: true,
|
||||
errorMsg: "failed to create reader",
|
||||
},
|
||||
{
|
||||
name: "empty digest",
|
||||
digest: "",
|
||||
expectError: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
// skip if credentials are not set
|
||||
if _, err := storage.NewClient(ctx); err != nil && tt.expectError {
|
||||
t.Skip("Skipping test due to missing GCP credentials")
|
||||
}
|
||||
|
||||
_, err := DownloadOvmfFile(ctx, tt.digest)
|
||||
|
||||
if tt.expectError {
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), tt.errorMsg)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -253,148 +253,6 @@ func (_c *Provider_VTpmAttestation_Call) RunAndReturn(run func([]byte) ([]byte,
|
||||
return _c
|
||||
}
|
||||
|
||||
// VerifTeeAttestation provides a mock function with given fields: report, teeNonce
|
||||
func (_m *Provider) VerifTeeAttestation(report []byte, teeNonce []byte) error {
|
||||
ret := _m.Called(report, teeNonce)
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for VerifTeeAttestation")
|
||||
}
|
||||
|
||||
var r0 error
|
||||
if rf, ok := ret.Get(0).(func([]byte, []byte) error); ok {
|
||||
r0 = rf(report, teeNonce)
|
||||
} else {
|
||||
r0 = ret.Error(0)
|
||||
}
|
||||
|
||||
return r0
|
||||
}
|
||||
|
||||
// Provider_VerifTeeAttestation_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'VerifTeeAttestation'
|
||||
type Provider_VerifTeeAttestation_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// VerifTeeAttestation is a helper method to define mock.On call
|
||||
// - report []byte
|
||||
// - teeNonce []byte
|
||||
func (_e *Provider_Expecter) VerifTeeAttestation(report interface{}, teeNonce interface{}) *Provider_VerifTeeAttestation_Call {
|
||||
return &Provider_VerifTeeAttestation_Call{Call: _e.mock.On("VerifTeeAttestation", report, teeNonce)}
|
||||
}
|
||||
|
||||
func (_c *Provider_VerifTeeAttestation_Call) Run(run func(report []byte, teeNonce []byte)) *Provider_VerifTeeAttestation_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
run(args[0].([]byte), args[1].([]byte))
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *Provider_VerifTeeAttestation_Call) Return(_a0 error) *Provider_VerifTeeAttestation_Call {
|
||||
_c.Call.Return(_a0)
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *Provider_VerifTeeAttestation_Call) RunAndReturn(run func([]byte, []byte) error) *Provider_VerifTeeAttestation_Call {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
// VerifVTpmAttestation provides a mock function with given fields: report, vTpmNonce
|
||||
func (_m *Provider) VerifVTpmAttestation(report []byte, vTpmNonce []byte) error {
|
||||
ret := _m.Called(report, vTpmNonce)
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for VerifVTpmAttestation")
|
||||
}
|
||||
|
||||
var r0 error
|
||||
if rf, ok := ret.Get(0).(func([]byte, []byte) error); ok {
|
||||
r0 = rf(report, vTpmNonce)
|
||||
} else {
|
||||
r0 = ret.Error(0)
|
||||
}
|
||||
|
||||
return r0
|
||||
}
|
||||
|
||||
// Provider_VerifVTpmAttestation_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'VerifVTpmAttestation'
|
||||
type Provider_VerifVTpmAttestation_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// VerifVTpmAttestation is a helper method to define mock.On call
|
||||
// - report []byte
|
||||
// - vTpmNonce []byte
|
||||
func (_e *Provider_Expecter) VerifVTpmAttestation(report interface{}, vTpmNonce interface{}) *Provider_VerifVTpmAttestation_Call {
|
||||
return &Provider_VerifVTpmAttestation_Call{Call: _e.mock.On("VerifVTpmAttestation", report, vTpmNonce)}
|
||||
}
|
||||
|
||||
func (_c *Provider_VerifVTpmAttestation_Call) Run(run func(report []byte, vTpmNonce []byte)) *Provider_VerifVTpmAttestation_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
run(args[0].([]byte), args[1].([]byte))
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *Provider_VerifVTpmAttestation_Call) Return(_a0 error) *Provider_VerifVTpmAttestation_Call {
|
||||
_c.Call.Return(_a0)
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *Provider_VerifVTpmAttestation_Call) RunAndReturn(run func([]byte, []byte) error) *Provider_VerifVTpmAttestation_Call {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
// VerifyAttestation provides a mock function with given fields: report, teeNonce, vTpmNonce
|
||||
func (_m *Provider) VerifyAttestation(report []byte, teeNonce []byte, vTpmNonce []byte) error {
|
||||
ret := _m.Called(report, teeNonce, vTpmNonce)
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for VerifyAttestation")
|
||||
}
|
||||
|
||||
var r0 error
|
||||
if rf, ok := ret.Get(0).(func([]byte, []byte, []byte) error); ok {
|
||||
r0 = rf(report, teeNonce, vTpmNonce)
|
||||
} else {
|
||||
r0 = ret.Error(0)
|
||||
}
|
||||
|
||||
return r0
|
||||
}
|
||||
|
||||
// Provider_VerifyAttestation_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'VerifyAttestation'
|
||||
type Provider_VerifyAttestation_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// VerifyAttestation is a helper method to define mock.On call
|
||||
// - report []byte
|
||||
// - teeNonce []byte
|
||||
// - vTpmNonce []byte
|
||||
func (_e *Provider_Expecter) VerifyAttestation(report interface{}, teeNonce interface{}, vTpmNonce interface{}) *Provider_VerifyAttestation_Call {
|
||||
return &Provider_VerifyAttestation_Call{Call: _e.mock.On("VerifyAttestation", report, teeNonce, vTpmNonce)}
|
||||
}
|
||||
|
||||
func (_c *Provider_VerifyAttestation_Call) Run(run func(report []byte, teeNonce []byte, vTpmNonce []byte)) *Provider_VerifyAttestation_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
run(args[0].([]byte), args[1].([]byte), args[2].([]byte))
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *Provider_VerifyAttestation_Call) Return(_a0 error) *Provider_VerifyAttestation_Call {
|
||||
_c.Call.Return(_a0)
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *Provider_VerifyAttestation_Call) RunAndReturn(run func([]byte, []byte, []byte) error) *Provider_VerifyAttestation_Call {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
// NewProvider creates a new instance of Provider. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations.
|
||||
// The first argument is typically a *testing.T value.
|
||||
func NewProvider(t interface {
|
||||
|
||||
@@ -0,0 +1,223 @@
|
||||
// Copyright (c) Ultraviolet
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
// Code generated by mockery v2.53.3. DO NOT EDIT.
|
||||
|
||||
package mocks
|
||||
|
||||
import mock "github.com/stretchr/testify/mock"
|
||||
|
||||
// Verifier is an autogenerated mock type for the Verifier type
|
||||
type Verifier struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
type Verifier_Expecter struct {
|
||||
mock *mock.Mock
|
||||
}
|
||||
|
||||
func (_m *Verifier) EXPECT() *Verifier_Expecter {
|
||||
return &Verifier_Expecter{mock: &_m.Mock}
|
||||
}
|
||||
|
||||
// JSONToPolicy provides a mock function with given fields: path
|
||||
func (_m *Verifier) JSONToPolicy(path string) error {
|
||||
ret := _m.Called(path)
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for JSONToPolicy")
|
||||
}
|
||||
|
||||
var r0 error
|
||||
if rf, ok := ret.Get(0).(func(string) error); ok {
|
||||
r0 = rf(path)
|
||||
} else {
|
||||
r0 = ret.Error(0)
|
||||
}
|
||||
|
||||
return r0
|
||||
}
|
||||
|
||||
// Verifier_JSONToPolicy_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'JSONToPolicy'
|
||||
type Verifier_JSONToPolicy_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// JSONToPolicy is a helper method to define mock.On call
|
||||
// - path string
|
||||
func (_e *Verifier_Expecter) JSONToPolicy(path interface{}) *Verifier_JSONToPolicy_Call {
|
||||
return &Verifier_JSONToPolicy_Call{Call: _e.mock.On("JSONToPolicy", path)}
|
||||
}
|
||||
|
||||
func (_c *Verifier_JSONToPolicy_Call) Run(run func(path string)) *Verifier_JSONToPolicy_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
run(args[0].(string))
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *Verifier_JSONToPolicy_Call) Return(_a0 error) *Verifier_JSONToPolicy_Call {
|
||||
_c.Call.Return(_a0)
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *Verifier_JSONToPolicy_Call) RunAndReturn(run func(string) error) *Verifier_JSONToPolicy_Call {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
// VerifTeeAttestation provides a mock function with given fields: report, teeNonce
|
||||
func (_m *Verifier) VerifTeeAttestation(report []byte, teeNonce []byte) error {
|
||||
ret := _m.Called(report, teeNonce)
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for VerifTeeAttestation")
|
||||
}
|
||||
|
||||
var r0 error
|
||||
if rf, ok := ret.Get(0).(func([]byte, []byte) error); ok {
|
||||
r0 = rf(report, teeNonce)
|
||||
} else {
|
||||
r0 = ret.Error(0)
|
||||
}
|
||||
|
||||
return r0
|
||||
}
|
||||
|
||||
// Verifier_VerifTeeAttestation_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'VerifTeeAttestation'
|
||||
type Verifier_VerifTeeAttestation_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// VerifTeeAttestation is a helper method to define mock.On call
|
||||
// - report []byte
|
||||
// - teeNonce []byte
|
||||
func (_e *Verifier_Expecter) VerifTeeAttestation(report interface{}, teeNonce interface{}) *Verifier_VerifTeeAttestation_Call {
|
||||
return &Verifier_VerifTeeAttestation_Call{Call: _e.mock.On("VerifTeeAttestation", report, teeNonce)}
|
||||
}
|
||||
|
||||
func (_c *Verifier_VerifTeeAttestation_Call) Run(run func(report []byte, teeNonce []byte)) *Verifier_VerifTeeAttestation_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
run(args[0].([]byte), args[1].([]byte))
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *Verifier_VerifTeeAttestation_Call) Return(_a0 error) *Verifier_VerifTeeAttestation_Call {
|
||||
_c.Call.Return(_a0)
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *Verifier_VerifTeeAttestation_Call) RunAndReturn(run func([]byte, []byte) error) *Verifier_VerifTeeAttestation_Call {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
// VerifVTpmAttestation provides a mock function with given fields: report, vTpmNonce
|
||||
func (_m *Verifier) VerifVTpmAttestation(report []byte, vTpmNonce []byte) error {
|
||||
ret := _m.Called(report, vTpmNonce)
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for VerifVTpmAttestation")
|
||||
}
|
||||
|
||||
var r0 error
|
||||
if rf, ok := ret.Get(0).(func([]byte, []byte) error); ok {
|
||||
r0 = rf(report, vTpmNonce)
|
||||
} else {
|
||||
r0 = ret.Error(0)
|
||||
}
|
||||
|
||||
return r0
|
||||
}
|
||||
|
||||
// Verifier_VerifVTpmAttestation_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'VerifVTpmAttestation'
|
||||
type Verifier_VerifVTpmAttestation_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// VerifVTpmAttestation is a helper method to define mock.On call
|
||||
// - report []byte
|
||||
// - vTpmNonce []byte
|
||||
func (_e *Verifier_Expecter) VerifVTpmAttestation(report interface{}, vTpmNonce interface{}) *Verifier_VerifVTpmAttestation_Call {
|
||||
return &Verifier_VerifVTpmAttestation_Call{Call: _e.mock.On("VerifVTpmAttestation", report, vTpmNonce)}
|
||||
}
|
||||
|
||||
func (_c *Verifier_VerifVTpmAttestation_Call) Run(run func(report []byte, vTpmNonce []byte)) *Verifier_VerifVTpmAttestation_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
run(args[0].([]byte), args[1].([]byte))
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *Verifier_VerifVTpmAttestation_Call) Return(_a0 error) *Verifier_VerifVTpmAttestation_Call {
|
||||
_c.Call.Return(_a0)
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *Verifier_VerifVTpmAttestation_Call) RunAndReturn(run func([]byte, []byte) error) *Verifier_VerifVTpmAttestation_Call {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
// VerifyAttestation provides a mock function with given fields: report, teeNonce, vTpmNonce
|
||||
func (_m *Verifier) VerifyAttestation(report []byte, teeNonce []byte, vTpmNonce []byte) error {
|
||||
ret := _m.Called(report, teeNonce, vTpmNonce)
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for VerifyAttestation")
|
||||
}
|
||||
|
||||
var r0 error
|
||||
if rf, ok := ret.Get(0).(func([]byte, []byte, []byte) error); ok {
|
||||
r0 = rf(report, teeNonce, vTpmNonce)
|
||||
} else {
|
||||
r0 = ret.Error(0)
|
||||
}
|
||||
|
||||
return r0
|
||||
}
|
||||
|
||||
// Verifier_VerifyAttestation_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'VerifyAttestation'
|
||||
type Verifier_VerifyAttestation_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// VerifyAttestation is a helper method to define mock.On call
|
||||
// - report []byte
|
||||
// - teeNonce []byte
|
||||
// - vTpmNonce []byte
|
||||
func (_e *Verifier_Expecter) VerifyAttestation(report interface{}, teeNonce interface{}, vTpmNonce interface{}) *Verifier_VerifyAttestation_Call {
|
||||
return &Verifier_VerifyAttestation_Call{Call: _e.mock.On("VerifyAttestation", report, teeNonce, vTpmNonce)}
|
||||
}
|
||||
|
||||
func (_c *Verifier_VerifyAttestation_Call) Run(run func(report []byte, teeNonce []byte, vTpmNonce []byte)) *Verifier_VerifyAttestation_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
run(args[0].([]byte), args[1].([]byte), args[2].([]byte))
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *Verifier_VerifyAttestation_Call) Return(_a0 error) *Verifier_VerifyAttestation_Call {
|
||||
_c.Call.Return(_a0)
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *Verifier_VerifyAttestation_Call) RunAndReturn(run func([]byte, []byte, []byte) error) *Verifier_VerifyAttestation_Call {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
// NewVerifier creates a new instance of Verifier. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations.
|
||||
// The first argument is typically a *testing.T value.
|
||||
func NewVerifier(t interface {
|
||||
mock.TestingT
|
||||
Cleanup(func())
|
||||
}) *Verifier {
|
||||
mock := &Verifier{}
|
||||
mock.Mock.Test(t)
|
||||
|
||||
t.Cleanup(func() { mock.AssertExpectations(t) })
|
||||
|
||||
return mock
|
||||
}
|
||||
@@ -7,11 +7,10 @@
|
||||
package quoteprovider
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path"
|
||||
"testing"
|
||||
|
||||
"github.com/absmach/magistrala/pkg/errors"
|
||||
"github.com/google/go-sev-guest/proto/check"
|
||||
"github.com/google/go-sev-guest/proto/sevsnp"
|
||||
"github.com/stretchr/testify/assert"
|
||||
@@ -19,51 +18,361 @@ import (
|
||||
)
|
||||
|
||||
func TestFillInAttestationLocal(t *testing.T) {
|
||||
originalHome := os.Getenv("HOME")
|
||||
defer func() {
|
||||
os.Setenv("HOME", originalHome)
|
||||
}()
|
||||
|
||||
tempDir, err := os.MkdirTemp("", "test_home")
|
||||
require.NoError(t, err)
|
||||
defer os.RemoveAll(tempDir)
|
||||
|
||||
cocosDir := tempDir + "/.cocos/Milan"
|
||||
os.Setenv("HOME", tempDir)
|
||||
|
||||
cocosDir := path.Join(tempDir, cocosDirectory, sevSnpProductMilan)
|
||||
err = os.MkdirAll(cocosDir, 0o755)
|
||||
require.NoError(t, err)
|
||||
|
||||
bundleContent := []byte("mock ASK ARK bundle")
|
||||
err = os.WriteFile(cocosDir+"/ask_ark.pem", bundleContent, 0o644)
|
||||
bundlePath := path.Join(cocosDir, caBundleName)
|
||||
err = os.WriteFile(bundlePath, bundleContent, 0o644)
|
||||
require.NoError(t, err)
|
||||
|
||||
config := check.Config{
|
||||
RootOfTrust: &check.RootOfTrust{},
|
||||
Policy: &check.Policy{},
|
||||
config := &check.Config{
|
||||
RootOfTrust: &check.RootOfTrust{
|
||||
ProductLine: sevSnpProductMilan,
|
||||
},
|
||||
Policy: &check.Policy{},
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
attestation *sevsnp.Attestation
|
||||
err error
|
||||
name string
|
||||
attestation *sevsnp.Attestation
|
||||
setupFunc func()
|
||||
expectedError bool
|
||||
errorContains string
|
||||
}{
|
||||
{
|
||||
name: "Empty attestation",
|
||||
name: "Empty attestation - creates new chain",
|
||||
attestation: &sevsnp.Attestation{
|
||||
CertificateChain: &sevsnp.CertificateChain{},
|
||||
CertificateChain: nil,
|
||||
},
|
||||
err: nil,
|
||||
setupFunc: func() {},
|
||||
expectedError: true,
|
||||
errorContains: "could not find ASK or ASVK PEM block; could not find ARK PEM block",
|
||||
},
|
||||
{
|
||||
name: "Attestation with existing chain",
|
||||
name: "Attestation with existing chain - no changes needed",
|
||||
attestation: &sevsnp.Attestation{
|
||||
CertificateChain: &sevsnp.CertificateChain{
|
||||
AskCert: []byte("existing ASK cert"),
|
||||
ArkCert: []byte("existing ARK cert"),
|
||||
},
|
||||
},
|
||||
err: nil,
|
||||
setupFunc: func() {},
|
||||
expectedError: false,
|
||||
},
|
||||
{
|
||||
name: "Attestation with empty chain - tries to load from file",
|
||||
attestation: &sevsnp.Attestation{
|
||||
CertificateChain: &sevsnp.CertificateChain{},
|
||||
},
|
||||
setupFunc: func() {},
|
||||
expectedError: true,
|
||||
errorContains: "could not find ASK or ASVK PEM block; could not find ARK PEM block",
|
||||
},
|
||||
{
|
||||
name: "No bundle file exists - no error",
|
||||
attestation: &sevsnp.Attestation{
|
||||
CertificateChain: &sevsnp.CertificateChain{},
|
||||
},
|
||||
setupFunc: func() {
|
||||
os.Remove(bundlePath)
|
||||
},
|
||||
expectedError: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := fillInAttestationLocal(tt.attestation, &config)
|
||||
assert.True(t, errors.Contains(err, tt.err), fmt.Sprintf("expected error %v, got %v", tt.err, err))
|
||||
os.Setenv("HOME", tempDir)
|
||||
if _, err := os.Stat(bundlePath); os.IsNotExist(err) {
|
||||
if err := os.WriteFile(bundlePath, bundleContent, 0o644); err != nil {
|
||||
t.Fatalf("Failed to write bundle file: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
tt.setupFunc()
|
||||
|
||||
err := fillInAttestationLocal(tt.attestation, config)
|
||||
|
||||
if tt.expectedError {
|
||||
assert.Error(t, err)
|
||||
if tt.errorContains != "" {
|
||||
assert.Contains(t, err.Error(), tt.errorContains)
|
||||
}
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetProductName(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
product string
|
||||
expected sevsnp.SevProduct_SevProductName
|
||||
}{
|
||||
{
|
||||
name: "Milan product",
|
||||
product: sevSnpProductMilan,
|
||||
expected: sevsnp.SevProduct_SEV_PRODUCT_MILAN,
|
||||
},
|
||||
{
|
||||
name: "Genoa product",
|
||||
product: sevSnpProductGenoa,
|
||||
expected: sevsnp.SevProduct_SEV_PRODUCT_GENOA,
|
||||
},
|
||||
{
|
||||
name: "Unknown product",
|
||||
product: "UnknownProduct",
|
||||
expected: sevsnp.SevProduct_SEV_PRODUCT_UNKNOWN,
|
||||
},
|
||||
{
|
||||
name: "Empty product",
|
||||
product: "",
|
||||
expected: sevsnp.SevProduct_SEV_PRODUCT_UNKNOWN,
|
||||
},
|
||||
{
|
||||
name: "Case sensitive - milan lowercase",
|
||||
product: "milan",
|
||||
expected: sevsnp.SevProduct_SEV_PRODUCT_UNKNOWN,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := GetProductName(tt.product)
|
||||
assert.Equal(t, tt.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestVerifyReport(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
attestation *sevsnp.Attestation
|
||||
config *check.Config
|
||||
expectedError bool
|
||||
errorContains string
|
||||
}{
|
||||
{
|
||||
name: "Invalid product line",
|
||||
attestation: &sevsnp.Attestation{
|
||||
CertificateChain: &sevsnp.CertificateChain{},
|
||||
},
|
||||
config: &check.Config{
|
||||
RootOfTrust: &check.RootOfTrust{
|
||||
ProductLine: "InvalidProduct",
|
||||
},
|
||||
Policy: &check.Policy{},
|
||||
},
|
||||
expectedError: true,
|
||||
errorContains: "product name must be",
|
||||
},
|
||||
{
|
||||
name: "Valid Milan product line",
|
||||
attestation: &sevsnp.Attestation{
|
||||
CertificateChain: &sevsnp.CertificateChain{
|
||||
AskCert: []byte("mock ask cert"),
|
||||
ArkCert: []byte("mock ark cert"),
|
||||
},
|
||||
},
|
||||
config: &check.Config{
|
||||
RootOfTrust: &check.RootOfTrust{
|
||||
ProductLine: sevSnpProductMilan,
|
||||
},
|
||||
Policy: &check.Policy{},
|
||||
},
|
||||
expectedError: true,
|
||||
errorContains: "attestation verification failed",
|
||||
},
|
||||
{
|
||||
name: "Valid Genoa product line",
|
||||
attestation: &sevsnp.Attestation{
|
||||
CertificateChain: &sevsnp.CertificateChain{
|
||||
AskCert: []byte("mock ask cert"),
|
||||
ArkCert: []byte("mock ark cert"),
|
||||
},
|
||||
},
|
||||
config: &check.Config{
|
||||
RootOfTrust: &check.RootOfTrust{
|
||||
ProductLine: sevSnpProductGenoa,
|
||||
},
|
||||
Policy: &check.Policy{},
|
||||
},
|
||||
expectedError: true,
|
||||
errorContains: "attestation verification failed",
|
||||
},
|
||||
{
|
||||
name: "Config with existing product policy",
|
||||
attestation: &sevsnp.Attestation{
|
||||
CertificateChain: &sevsnp.CertificateChain{
|
||||
AskCert: []byte("mock ask cert"),
|
||||
ArkCert: []byte("mock ark cert"),
|
||||
},
|
||||
},
|
||||
config: &check.Config{
|
||||
RootOfTrust: &check.RootOfTrust{
|
||||
ProductLine: sevSnpProductMilan,
|
||||
},
|
||||
Policy: &check.Policy{
|
||||
Product: &sevsnp.SevProduct{
|
||||
Name: sevsnp.SevProduct_SEV_PRODUCT_MILAN,
|
||||
},
|
||||
},
|
||||
},
|
||||
expectedError: true,
|
||||
errorContains: "attestation verification failed",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := verifyReport(tt.attestation, tt.config)
|
||||
|
||||
if tt.expectedError {
|
||||
assert.Error(t, err)
|
||||
if tt.errorContains != "" {
|
||||
assert.Contains(t, err.Error(), tt.errorContains)
|
||||
}
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateReport(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
attestation *sevsnp.Attestation
|
||||
config *check.Config
|
||||
expectedError bool
|
||||
errorContains string
|
||||
}{
|
||||
{
|
||||
name: "Basic validation test",
|
||||
attestation: &sevsnp.Attestation{
|
||||
CertificateChain: &sevsnp.CertificateChain{},
|
||||
},
|
||||
config: &check.Config{
|
||||
Policy: &check.Policy{
|
||||
Policy: 196608,
|
||||
},
|
||||
},
|
||||
expectedError: true,
|
||||
errorContains: "attestation validation failed",
|
||||
},
|
||||
{
|
||||
name: "Validation with report data",
|
||||
attestation: &sevsnp.Attestation{
|
||||
CertificateChain: &sevsnp.CertificateChain{},
|
||||
},
|
||||
config: &check.Config{
|
||||
Policy: &check.Policy{
|
||||
Policy: 196608,
|
||||
ReportData: []byte("test report datatest report datatest report datatest report data"),
|
||||
},
|
||||
},
|
||||
expectedError: true,
|
||||
errorContains: "attestation validation failed",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := validateReport(tt.attestation, tt.config)
|
||||
|
||||
if tt.expectedError {
|
||||
assert.Error(t, err)
|
||||
if tt.errorContains != "" {
|
||||
assert.Contains(t, err.Error(), tt.errorContains)
|
||||
}
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestFetchAttestation(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
reportData []byte
|
||||
vmpl uint
|
||||
expectedError bool
|
||||
errorContains string
|
||||
}{
|
||||
{
|
||||
name: "Report data too large",
|
||||
reportData: make([]byte, Nonce+1),
|
||||
vmpl: 0,
|
||||
expectedError: true,
|
||||
errorContains: "could not get quote provider",
|
||||
},
|
||||
{
|
||||
name: "Valid report data size",
|
||||
reportData: make([]byte, 32),
|
||||
vmpl: 0,
|
||||
expectedError: true,
|
||||
errorContains: "could not get quote provider",
|
||||
},
|
||||
{
|
||||
name: "Maximum valid report data size",
|
||||
reportData: make([]byte, Nonce),
|
||||
vmpl: 1,
|
||||
expectedError: true,
|
||||
errorContains: "could not get quote provider",
|
||||
},
|
||||
{
|
||||
name: "Empty report data",
|
||||
reportData: []byte{},
|
||||
vmpl: 0,
|
||||
expectedError: true,
|
||||
errorContains: "could not get quote provider",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result, err := FetchAttestation(tt.reportData, tt.vmpl)
|
||||
|
||||
if tt.expectedError {
|
||||
assert.Error(t, err)
|
||||
assert.Empty(t, result)
|
||||
if tt.errorContains != "" {
|
||||
assert.Contains(t, err.Error(), tt.errorContains)
|
||||
}
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
assert.NotEmpty(t, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetLeveledQuoteProvider(t *testing.T) {
|
||||
t.Run("GetLeveledQuoteProvider call", func(t *testing.T) {
|
||||
provider, err := GetLeveledQuoteProvider()
|
||||
|
||||
if err != nil {
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, provider)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, provider)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
@@ -45,6 +45,14 @@ func (v provider) Attestation(teeNonce []byte, vTpmNonce []byte) ([]byte, error)
|
||||
}
|
||||
|
||||
func (v provider) TeeAttestation(teeNonce []byte) ([]byte, error) {
|
||||
if teeNonce == nil {
|
||||
return nil, errors.New("tee nonce is required for TDX attestation")
|
||||
}
|
||||
|
||||
if len(teeNonce) != 64 {
|
||||
return nil, fmt.Errorf("invalid tee nonce length: expected 64 bytes, got %d bytes", len(teeNonce))
|
||||
}
|
||||
|
||||
quoteprovider, err := client.GetQuoteProvider()
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, errOpenTDXDevice)
|
||||
|
||||
@@ -0,0 +1,622 @@
|
||||
// Copyright (c) Ultraviolet
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package tdx
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/google/go-tdx-guest/proto/checkconfig"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/ultravioletrs/cocos/pkg/attestation"
|
||||
"google.golang.org/protobuf/encoding/protojson"
|
||||
)
|
||||
|
||||
func TestNewProvider(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
want attestation.Provider
|
||||
}{
|
||||
{
|
||||
name: "should create new provider successfully",
|
||||
want: provider{},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := NewProvider()
|
||||
assert.IsType(t, tt.want, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestProvider_Attestation(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
teeNonce []byte
|
||||
vTpmNonce []byte
|
||||
wantErr bool
|
||||
errContains string
|
||||
}{
|
||||
{
|
||||
name: "should handle empty nonces",
|
||||
teeNonce: []byte{},
|
||||
vTpmNonce: []byte{},
|
||||
wantErr: true,
|
||||
errContains: "invalid tee nonce length: expected 64 bytes, got 0 bytes",
|
||||
},
|
||||
{
|
||||
name: "should handle valid nonces",
|
||||
teeNonce: []byte("test-noncetest-noncetest-noncetest-noncetest-noncetest-noncetest"),
|
||||
vTpmNonce: []byte("vtpm-nonce"),
|
||||
wantErr: true,
|
||||
errContains: "/sys/kernel/config/tsm/report",
|
||||
},
|
||||
{
|
||||
name: "should handle nil nonces",
|
||||
teeNonce: nil,
|
||||
vTpmNonce: nil,
|
||||
wantErr: true,
|
||||
errContains: "tee nonce is required for TDX attestation",
|
||||
},
|
||||
{
|
||||
name: "should handle large nonce",
|
||||
teeNonce: make([]byte, 64),
|
||||
vTpmNonce: make([]byte, 32),
|
||||
wantErr: true,
|
||||
errContains: "/sys/kernel/config/tsm/report",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
p := provider{}
|
||||
got, err := p.Attestation(tt.teeNonce, tt.vTpmNonce)
|
||||
|
||||
if tt.wantErr {
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), tt.errContains)
|
||||
assert.Nil(t, got)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, got)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestProvider_TeeAttestation(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
teeNonce []byte
|
||||
wantErr bool
|
||||
errContains string
|
||||
}{
|
||||
{
|
||||
name: "should handle empty nonce",
|
||||
teeNonce: []byte{},
|
||||
wantErr: true,
|
||||
errContains: "invalid tee nonce length: expected 64 bytes, got 0 bytes",
|
||||
},
|
||||
{
|
||||
name: "should handle valid nonce",
|
||||
teeNonce: []byte("test-noncetest-noncetest-noncetest-noncetest-noncetest-noncetest"),
|
||||
wantErr: true,
|
||||
errContains: "/sys/kernel/config/tsm/report",
|
||||
},
|
||||
{
|
||||
name: "should handle nil nonce",
|
||||
teeNonce: nil,
|
||||
wantErr: true,
|
||||
errContains: "tee nonce is required for TDX attestation",
|
||||
},
|
||||
{
|
||||
name: "should handle 64-byte nonce",
|
||||
teeNonce: make([]byte, 64),
|
||||
wantErr: true,
|
||||
errContains: "/sys/kernel/config/tsm/report",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
p := provider{}
|
||||
got, err := p.TeeAttestation(tt.teeNonce)
|
||||
|
||||
if tt.wantErr {
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), tt.errContains)
|
||||
assert.Nil(t, got)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, got)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestProvider_VTpmAttestation(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
vTpmNonce []byte
|
||||
wantErr bool
|
||||
errContains string
|
||||
}{
|
||||
{
|
||||
name: "should return error for empty nonce",
|
||||
vTpmNonce: []byte{},
|
||||
wantErr: true,
|
||||
errContains: "vTPM attestation fetch is not supported",
|
||||
},
|
||||
{
|
||||
name: "should return error for valid nonce",
|
||||
vTpmNonce: []byte("vtpm-nonce"),
|
||||
wantErr: true,
|
||||
errContains: "vTPM attestation fetch is not supported",
|
||||
},
|
||||
{
|
||||
name: "should return error for nil nonce",
|
||||
vTpmNonce: nil,
|
||||
wantErr: true,
|
||||
errContains: "vTPM attestation fetch is not supported",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
p := provider{}
|
||||
got, err := p.VTpmAttestation(tt.vTpmNonce)
|
||||
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), tt.errContains)
|
||||
assert.Nil(t, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestProvider_AzureAttestationToken(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
tokenNonce []byte
|
||||
wantErr bool
|
||||
errContains string
|
||||
}{
|
||||
{
|
||||
name: "should return error for empty nonce",
|
||||
tokenNonce: []byte{},
|
||||
wantErr: true,
|
||||
errContains: "Azure attestation token is not supported",
|
||||
},
|
||||
{
|
||||
name: "should return error for valid nonce",
|
||||
tokenNonce: []byte("token-nonce"),
|
||||
wantErr: true,
|
||||
errContains: "Azure attestation token is not supported",
|
||||
},
|
||||
{
|
||||
name: "should return error for nil nonce",
|
||||
tokenNonce: nil,
|
||||
wantErr: true,
|
||||
errContains: "Azure attestation token is not supported",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
p := provider{}
|
||||
got, err := p.AzureAttestationToken(tt.tokenNonce)
|
||||
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), tt.errContains)
|
||||
assert.Nil(t, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewVerifier(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
}{
|
||||
{
|
||||
name: "should create new verifier successfully",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := NewVerifier()
|
||||
v, ok := got.(verifier)
|
||||
assert.True(t, ok)
|
||||
assert.NotNil(t, v.Policy)
|
||||
assert.NotNil(t, v.Policy.RootOfTrust)
|
||||
assert.NotNil(t, v.Policy.Policy)
|
||||
assert.NotNil(t, v.Policy.Policy.HeaderPolicy)
|
||||
assert.NotNil(t, v.Policy.Policy.TdQuoteBodyPolicy)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewVerifierWithPolicy(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
policy *checkconfig.Config
|
||||
}{
|
||||
{
|
||||
name: "should create verifier with nil policy",
|
||||
policy: nil,
|
||||
},
|
||||
{
|
||||
name: "should create verifier with valid policy",
|
||||
policy: &checkconfig.Config{
|
||||
RootOfTrust: &checkconfig.RootOfTrust{},
|
||||
Policy: &checkconfig.Policy{},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "should create verifier with empty policy",
|
||||
policy: &checkconfig.Config{},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := NewVerifierWithPolicy(tt.policy)
|
||||
v, ok := got.(verifier)
|
||||
assert.True(t, ok)
|
||||
|
||||
if tt.policy == nil {
|
||||
assert.NotNil(t, v.Policy)
|
||||
assert.NotNil(t, v.Policy.RootOfTrust)
|
||||
assert.NotNil(t, v.Policy.Policy)
|
||||
} else {
|
||||
assert.Equal(t, tt.policy, v.Policy)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestVerifier_VerifTeeAttestation(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
verifier verifier
|
||||
report []byte
|
||||
teeNonce []byte
|
||||
wantErr bool
|
||||
errContains string
|
||||
}{
|
||||
{
|
||||
name: "should return error when policy is nil",
|
||||
verifier: verifier{
|
||||
Policy: nil,
|
||||
},
|
||||
report: []byte("test-report"),
|
||||
teeNonce: []byte("test-nonce"),
|
||||
wantErr: true,
|
||||
errContains: "tdx policy is not provided",
|
||||
},
|
||||
{
|
||||
name: "should handle invalid report format",
|
||||
verifier: verifier{
|
||||
Policy: &checkconfig.Config{
|
||||
RootOfTrust: &checkconfig.RootOfTrust{},
|
||||
Policy: &checkconfig.Policy{},
|
||||
},
|
||||
},
|
||||
report: []byte("invalid-report"),
|
||||
teeNonce: []byte("test-nonce"),
|
||||
wantErr: true,
|
||||
errContains: "",
|
||||
},
|
||||
{
|
||||
name: "should handle empty report",
|
||||
verifier: verifier{
|
||||
Policy: &checkconfig.Config{
|
||||
RootOfTrust: &checkconfig.RootOfTrust{},
|
||||
Policy: &checkconfig.Policy{},
|
||||
},
|
||||
},
|
||||
report: []byte{},
|
||||
teeNonce: []byte("test-nonce"),
|
||||
wantErr: true,
|
||||
errContains: "",
|
||||
},
|
||||
{
|
||||
name: "should handle nil report",
|
||||
verifier: verifier{
|
||||
Policy: &checkconfig.Config{
|
||||
RootOfTrust: &checkconfig.RootOfTrust{},
|
||||
Policy: &checkconfig.Policy{},
|
||||
},
|
||||
},
|
||||
report: nil,
|
||||
teeNonce: []byte("test-nonce"),
|
||||
wantErr: true,
|
||||
errContains: "",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := tt.verifier.VerifTeeAttestation(tt.report, tt.teeNonce)
|
||||
|
||||
if tt.wantErr {
|
||||
assert.Error(t, err)
|
||||
if tt.errContains != "" {
|
||||
assert.Contains(t, err.Error(), tt.errContains)
|
||||
}
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestVerifier_VerifVTpmAttestation(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
verifier verifier
|
||||
report []byte
|
||||
vTpmNonce []byte
|
||||
wantErr bool
|
||||
errContains string
|
||||
}{
|
||||
{
|
||||
name: "should return error for any input",
|
||||
verifier: verifier{},
|
||||
report: []byte("test-report"),
|
||||
vTpmNonce: []byte("test-nonce"),
|
||||
wantErr: true,
|
||||
errContains: "VTPM attestation verification is not supported",
|
||||
},
|
||||
{
|
||||
name: "should return error for empty inputs",
|
||||
verifier: verifier{},
|
||||
report: []byte{},
|
||||
vTpmNonce: []byte{},
|
||||
wantErr: true,
|
||||
errContains: "VTPM attestation verification is not supported",
|
||||
},
|
||||
{
|
||||
name: "should return error for nil inputs",
|
||||
verifier: verifier{},
|
||||
report: nil,
|
||||
vTpmNonce: nil,
|
||||
wantErr: true,
|
||||
errContains: "VTPM attestation verification is not supported",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := tt.verifier.VerifVTpmAttestation(tt.report, tt.vTpmNonce)
|
||||
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), tt.errContains)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestVerifier_VerifyAttestation(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
verifier verifier
|
||||
report []byte
|
||||
teeNonce []byte
|
||||
vTpmNonce []byte
|
||||
wantErr bool
|
||||
errContains string
|
||||
}{
|
||||
{
|
||||
name: "should delegate to VerifTeeAttestation with nil policy",
|
||||
verifier: verifier{
|
||||
Policy: nil,
|
||||
},
|
||||
report: []byte("test-report"),
|
||||
teeNonce: []byte("test-nonce"),
|
||||
vTpmNonce: []byte("vtpm-nonce"),
|
||||
wantErr: true,
|
||||
errContains: "tdx policy is not provided",
|
||||
},
|
||||
{
|
||||
name: "should delegate to VerifTeeAttestation with valid policy",
|
||||
verifier: verifier{
|
||||
Policy: &checkconfig.Config{
|
||||
RootOfTrust: &checkconfig.RootOfTrust{},
|
||||
Policy: &checkconfig.Policy{},
|
||||
},
|
||||
},
|
||||
report: []byte("invalid-report"),
|
||||
teeNonce: []byte("test-nonce"),
|
||||
vTpmNonce: []byte("vtpm-nonce"),
|
||||
wantErr: true,
|
||||
errContains: "",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := tt.verifier.VerifyAttestation(tt.report, tt.teeNonce, tt.vTpmNonce)
|
||||
|
||||
if tt.wantErr {
|
||||
assert.Error(t, err)
|
||||
if tt.errContains != "" {
|
||||
assert.Contains(t, err.Error(), tt.errContains)
|
||||
}
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestVerifier_JSONToPolicy(t *testing.T) {
|
||||
tempDir := t.TempDir()
|
||||
|
||||
testPolicy := &checkconfig.Config{
|
||||
RootOfTrust: &checkconfig.RootOfTrust{},
|
||||
Policy: &checkconfig.Policy{
|
||||
HeaderPolicy: &checkconfig.HeaderPolicy{},
|
||||
TdQuoteBodyPolicy: &checkconfig.TDQuoteBodyPolicy{},
|
||||
},
|
||||
}
|
||||
|
||||
validPolicyJSON, err := protojson.Marshal(testPolicy)
|
||||
require.NoError(t, err)
|
||||
|
||||
validPolicyFile := filepath.Join(tempDir, "valid_policy.json")
|
||||
err = os.WriteFile(validPolicyFile, validPolicyJSON, 0o644)
|
||||
require.NoError(t, err)
|
||||
|
||||
invalidPolicyFile := filepath.Join(tempDir, "invalid_policy.json")
|
||||
err = os.WriteFile(invalidPolicyFile, []byte("invalid json"), 0o644)
|
||||
require.NoError(t, err)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
verifier verifier
|
||||
path string
|
||||
wantErr bool
|
||||
errContains string
|
||||
}{
|
||||
{
|
||||
name: "should load valid policy file",
|
||||
verifier: verifier{
|
||||
Policy: &checkconfig.Config{},
|
||||
},
|
||||
path: validPolicyFile,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "should return error for non-existent file",
|
||||
verifier: verifier{
|
||||
Policy: &checkconfig.Config{},
|
||||
},
|
||||
path: filepath.Join(tempDir, "non_existent.json"),
|
||||
wantErr: true,
|
||||
errContains: "no such file or directory",
|
||||
},
|
||||
{
|
||||
name: "should return error for invalid JSON",
|
||||
verifier: verifier{
|
||||
Policy: &checkconfig.Config{},
|
||||
},
|
||||
path: invalidPolicyFile,
|
||||
wantErr: true,
|
||||
errContains: "",
|
||||
},
|
||||
{
|
||||
name: "should return error for empty path",
|
||||
verifier: verifier{
|
||||
Policy: &checkconfig.Config{},
|
||||
},
|
||||
path: "",
|
||||
wantErr: true,
|
||||
errContains: "",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := tt.verifier.JSONToPolicy(tt.path)
|
||||
|
||||
if tt.wantErr {
|
||||
assert.Error(t, err)
|
||||
if tt.errContains != "" {
|
||||
assert.Contains(t, err.Error(), tt.errContains)
|
||||
}
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestReadTDXAttestationPolicy(t *testing.T) {
|
||||
tempDir := t.TempDir()
|
||||
|
||||
testPolicy := &checkconfig.Config{
|
||||
RootOfTrust: &checkconfig.RootOfTrust{},
|
||||
Policy: &checkconfig.Policy{
|
||||
HeaderPolicy: &checkconfig.HeaderPolicy{},
|
||||
TdQuoteBodyPolicy: &checkconfig.TDQuoteBodyPolicy{},
|
||||
},
|
||||
}
|
||||
|
||||
validPolicyJSON, err := protojson.Marshal(testPolicy)
|
||||
require.NoError(t, err)
|
||||
|
||||
validPolicyFile := filepath.Join(tempDir, "valid_policy.json")
|
||||
err = os.WriteFile(validPolicyFile, validPolicyJSON, 0o644)
|
||||
require.NoError(t, err)
|
||||
|
||||
invalidPolicyFile := filepath.Join(tempDir, "invalid_policy.json")
|
||||
err = os.WriteFile(invalidPolicyFile, []byte("invalid json"), 0o644)
|
||||
require.NoError(t, err)
|
||||
|
||||
emptyFile := filepath.Join(tempDir, "empty.json")
|
||||
err = os.WriteFile(emptyFile, []byte{}, 0o644)
|
||||
require.NoError(t, err)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
policyPath string
|
||||
policy *checkconfig.Config
|
||||
wantErr bool
|
||||
errContains string
|
||||
}{
|
||||
{
|
||||
name: "should read valid policy file",
|
||||
policyPath: validPolicyFile,
|
||||
policy: &checkconfig.Config{},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "should return error for non-existent file",
|
||||
policyPath: filepath.Join(tempDir, "non_existent.json"),
|
||||
policy: &checkconfig.Config{},
|
||||
wantErr: true,
|
||||
errContains: "no such file or directory",
|
||||
},
|
||||
{
|
||||
name: "should return error for invalid JSON",
|
||||
policyPath: invalidPolicyFile,
|
||||
policy: &checkconfig.Config{},
|
||||
wantErr: true,
|
||||
errContains: "",
|
||||
},
|
||||
{
|
||||
name: "should return error for empty file",
|
||||
policyPath: emptyFile,
|
||||
policy: &checkconfig.Config{},
|
||||
wantErr: true,
|
||||
errContains: "",
|
||||
},
|
||||
{
|
||||
name: "should return error for empty path",
|
||||
policyPath: "",
|
||||
policy: &checkconfig.Config{},
|
||||
wantErr: true,
|
||||
errContains: "",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := ReadTDXAttestationPolicy(tt.policyPath, tt.policy)
|
||||
|
||||
if tt.wantErr {
|
||||
assert.Error(t, err)
|
||||
if tt.errContains != "" {
|
||||
assert.Contains(t, err.Error(), tt.errContains)
|
||||
}
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, tt.policy)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -4,7 +4,11 @@
|
||||
package vtpm
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
@@ -13,6 +17,8 @@ import (
|
||||
"github.com/google/go-sev-guest/abi"
|
||||
"github.com/google/go-sev-guest/proto/check"
|
||||
"github.com/google/go-sev-guest/proto/sevsnp"
|
||||
"github.com/google/go-tpm-tools/proto/attest"
|
||||
ptpm "github.com/google/go-tpm-tools/proto/tpm"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/ultravioletrs/cocos/pkg/attestation"
|
||||
@@ -24,6 +30,633 @@ const sevSnpProductMilan = "Milan"
|
||||
|
||||
var policy = attestation.Config{Config: &check.Config{Policy: &check.Policy{}, RootOfTrust: &check.RootOfTrust{}}, PcrConfig: &attestation.PcrConfig{}}
|
||||
|
||||
type mockTPM struct {
|
||||
*bytes.Buffer
|
||||
closeErr error
|
||||
}
|
||||
|
||||
func (m *mockTPM) Close() error {
|
||||
return m.closeErr
|
||||
}
|
||||
|
||||
type mockWriter struct {
|
||||
data []byte
|
||||
err error
|
||||
}
|
||||
|
||||
func (m *mockWriter) Write(p []byte) (n int, err error) {
|
||||
if m.err != nil {
|
||||
return 0, m.err
|
||||
}
|
||||
m.data = append(m.data, p...)
|
||||
return len(p), nil
|
||||
}
|
||||
|
||||
func TestOpenTpm(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
externalTPM io.ReadWriteCloser
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
name: "External TPM available",
|
||||
externalTPM: &mockTPM{Buffer: &bytes.Buffer{}},
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "No external TPM",
|
||||
externalTPM: nil,
|
||||
expectError: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
originalExternalTPM := ExternalTPM
|
||||
defer func() { ExternalTPM = originalExternalTPM }()
|
||||
|
||||
ExternalTPM = tt.externalTPM
|
||||
|
||||
tpm, err := OpenTpm()
|
||||
if tt.expectError {
|
||||
assert.Error(t, err)
|
||||
} else {
|
||||
if tt.externalTPM != nil {
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, tpm)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestTpmEventLog(t *testing.T) {
|
||||
tempFile, err := os.CreateTemp("", "event_log")
|
||||
require.NoError(t, err)
|
||||
defer os.Remove(tempFile.Name())
|
||||
|
||||
testData := []byte("test event log data")
|
||||
_, err = tempFile.Write(testData)
|
||||
require.NoError(t, err)
|
||||
tempFile.Close()
|
||||
|
||||
tpm := &tpm{ReadWriteCloser: &mockTPM{Buffer: &bytes.Buffer{}}}
|
||||
|
||||
_, err = tpm.EventLog()
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func TestNewProvider(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
teeAttestation bool
|
||||
vmpl uint
|
||||
}{
|
||||
{
|
||||
name: "TEE attestation enabled",
|
||||
teeAttestation: true,
|
||||
vmpl: 1,
|
||||
},
|
||||
{
|
||||
name: "TEE attestation disabled",
|
||||
teeAttestation: false,
|
||||
vmpl: 0,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
provider := NewProvider(tt.teeAttestation, tt.vmpl)
|
||||
assert.NotNil(t, provider)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestProviderAzureAttestationToken(t *testing.T) {
|
||||
provider := NewProvider(false, 0)
|
||||
|
||||
token, err := provider.AzureAttestationToken([]byte("test-nonce"))
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, token)
|
||||
assert.Contains(t, err.Error(), "Azure attestation token is not supported")
|
||||
}
|
||||
|
||||
func TestNewVerifier(t *testing.T) {
|
||||
writer := &mockWriter{}
|
||||
verifier := NewVerifier(writer)
|
||||
|
||||
assert.NotNil(t, verifier)
|
||||
}
|
||||
|
||||
func TestNewVerifierWithPolicy(t *testing.T) {
|
||||
writer := &mockWriter{}
|
||||
policy := &attestation.Config{
|
||||
Config: &check.Config{Policy: &check.Policy{}, RootOfTrust: &check.RootOfTrust{}},
|
||||
PcrConfig: &attestation.PcrConfig{},
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
policy *attestation.Config
|
||||
}{
|
||||
{
|
||||
name: "With policy",
|
||||
policy: policy,
|
||||
},
|
||||
{
|
||||
name: "Without policy (nil)",
|
||||
policy: nil,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
verifier := NewVerifierWithPolicy([]byte("test-key"), writer, tt.policy)
|
||||
assert.NotNil(t, verifier)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestMarshalQuote(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
attestation *attest.Attestation
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
name: "Valid attestation",
|
||||
attestation: &attest.Attestation{
|
||||
AkPub: []byte("test-key"),
|
||||
},
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "Nil attestation",
|
||||
attestation: nil,
|
||||
expectError: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
data, err := marshalQuote(tt.attestation)
|
||||
if tt.expectError {
|
||||
assert.Error(t, err)
|
||||
assert.Empty(t, data)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
if tt.attestation != nil {
|
||||
assert.NotEmpty(t, data)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCheckExpectedPCRValues(t *testing.T) {
|
||||
testPCRValue := make([]byte, 32)
|
||||
for i := range testPCRValue {
|
||||
testPCRValue[i] = byte(i)
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
attestation *attest.Attestation
|
||||
policy *attestation.Config
|
||||
expectError bool
|
||||
errorMsg string
|
||||
}{
|
||||
{
|
||||
name: "Matching PCR values SHA256",
|
||||
attestation: &attest.Attestation{
|
||||
Quotes: []*ptpm.Quote{
|
||||
{
|
||||
Pcrs: &ptpm.PCRs{
|
||||
Hash: ptpm.HashAlgo_SHA256,
|
||||
Pcrs: map[uint32][]byte{
|
||||
0: testPCRValue,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
policy: &attestation.Config{
|
||||
PcrConfig: &attestation.PcrConfig{
|
||||
PCRValues: attestation.PcrValues{
|
||||
Sha256: map[string]string{
|
||||
"0": hex.EncodeToString(testPCRValue),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "Mismatched PCR values",
|
||||
attestation: &attest.Attestation{
|
||||
Quotes: []*ptpm.Quote{
|
||||
{
|
||||
Pcrs: &ptpm.PCRs{
|
||||
Hash: ptpm.HashAlgo_SHA256,
|
||||
Pcrs: map[uint32][]byte{
|
||||
0: testPCRValue,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
policy: &attestation.Config{
|
||||
PcrConfig: &attestation.PcrConfig{
|
||||
PCRValues: attestation.PcrValues{
|
||||
Sha256: map[string]string{
|
||||
"0": hex.EncodeToString(make([]byte, 32)),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
expectError: true,
|
||||
errorMsg: "expected",
|
||||
},
|
||||
{
|
||||
name: "Unsupported hash algorithm",
|
||||
attestation: &attest.Attestation{
|
||||
Quotes: []*ptpm.Quote{
|
||||
{
|
||||
Pcrs: &ptpm.PCRs{
|
||||
Hash: ptpm.HashAlgo_HASH_INVALID,
|
||||
Pcrs: map[uint32][]byte{
|
||||
0: testPCRValue,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
policy: &attestation.Config{
|
||||
PcrConfig: &attestation.PcrConfig{},
|
||||
},
|
||||
expectError: true,
|
||||
errorMsg: "hash algo is not supported",
|
||||
},
|
||||
{
|
||||
name: "Invalid PCR index",
|
||||
attestation: &attest.Attestation{
|
||||
Quotes: []*ptpm.Quote{
|
||||
{
|
||||
Pcrs: &ptpm.PCRs{
|
||||
Hash: ptpm.HashAlgo_SHA256,
|
||||
Pcrs: map[uint32][]byte{
|
||||
0: testPCRValue,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
policy: &attestation.Config{
|
||||
PcrConfig: &attestation.PcrConfig{
|
||||
PCRValues: attestation.PcrValues{
|
||||
Sha256: map[string]string{
|
||||
"invalid": hex.EncodeToString(testPCRValue),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
expectError: true,
|
||||
errorMsg: "error converting PCR index to int32",
|
||||
},
|
||||
{
|
||||
name: "Invalid PCR value hex",
|
||||
attestation: &attest.Attestation{
|
||||
Quotes: []*ptpm.Quote{
|
||||
{
|
||||
Pcrs: &ptpm.PCRs{
|
||||
Hash: ptpm.HashAlgo_SHA256,
|
||||
Pcrs: map[uint32][]byte{
|
||||
0: testPCRValue,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
policy: &attestation.Config{
|
||||
PcrConfig: &attestation.PcrConfig{
|
||||
PCRValues: attestation.PcrValues{
|
||||
Sha256: map[string]string{
|
||||
"0": "invalid-hex",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
expectError: true,
|
||||
errorMsg: "error converting PCR value to byte",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := checkExpectedPCRValues(tt.attestation, tt.policy)
|
||||
if tt.expectError {
|
||||
assert.Error(t, err)
|
||||
if tt.errorMsg != "" {
|
||||
assert.Contains(t, err.Error(), tt.errorMsg)
|
||||
}
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestReadPolicy(t *testing.T) {
|
||||
tempDir, err := os.MkdirTemp("", "policy_test")
|
||||
require.NoError(t, err)
|
||||
defer os.RemoveAll(tempDir)
|
||||
|
||||
validPolicy := map[string]interface{}{
|
||||
"policy": map[string]interface{}{
|
||||
"product": map[string]interface{}{
|
||||
"name": "test-product",
|
||||
},
|
||||
},
|
||||
"rootOfTrust": map[string]interface{}{
|
||||
"productLine": "test-line",
|
||||
},
|
||||
"pcrConfig": map[string]interface{}{
|
||||
"pcrValues": map[string]interface{}{
|
||||
"sha256": map[string]string{
|
||||
"0": "0000000000000000000000000000000000000000000000000000000000000000",
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
validPolicyData, err := json.Marshal(validPolicy)
|
||||
require.NoError(t, err)
|
||||
|
||||
validPolicyPath := filepath.Join(tempDir, "valid_policy.json")
|
||||
err = os.WriteFile(validPolicyPath, validPolicyData, 0o644)
|
||||
require.NoError(t, err)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
policyPath string
|
||||
expectError bool
|
||||
expectedErr error
|
||||
}{
|
||||
{
|
||||
name: "Valid policy file",
|
||||
policyPath: validPolicyPath,
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "Non-existent policy file",
|
||||
policyPath: "/nonexistent/path",
|
||||
expectError: true,
|
||||
expectedErr: ErrAttestationPolicyOpen,
|
||||
},
|
||||
{
|
||||
name: "Empty policy path",
|
||||
policyPath: "",
|
||||
expectError: true,
|
||||
expectedErr: ErrAttestationPolicyMissing,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
config := &attestation.Config{
|
||||
Config: &check.Config{Policy: &check.Policy{}, RootOfTrust: &check.RootOfTrust{}},
|
||||
PcrConfig: &attestation.PcrConfig{},
|
||||
}
|
||||
|
||||
err := ReadPolicy(tt.policyPath, config)
|
||||
if tt.expectError {
|
||||
assert.Error(t, err)
|
||||
if tt.expectedErr != nil {
|
||||
assert.True(t, errors.Contains(err, tt.expectedErr))
|
||||
}
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestReadPolicyFromByte(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
policyData []byte
|
||||
expectError bool
|
||||
expectedErr error
|
||||
}{
|
||||
{
|
||||
name: "Valid policy data",
|
||||
policyData: []byte(`{
|
||||
"policy": {
|
||||
"product": {
|
||||
"name": "test-product"
|
||||
}
|
||||
},
|
||||
"rootOfTrust": {
|
||||
"productLine": "test-line"
|
||||
},
|
||||
"pcrConfig": {
|
||||
"pcrValues": {
|
||||
"sha256": {
|
||||
"0": "0000000000000000000000000000000000000000000000000000000000000000"
|
||||
}
|
||||
}
|
||||
}
|
||||
}`),
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "Invalid JSON",
|
||||
policyData: []byte(`{invalid json`),
|
||||
expectError: true,
|
||||
expectedErr: ErrAttestationPolicyDecode,
|
||||
},
|
||||
{
|
||||
name: "Empty policy data",
|
||||
policyData: []byte(``),
|
||||
expectError: true,
|
||||
expectedErr: ErrAttestationPolicyDecode,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
config := &attestation.Config{
|
||||
Config: &check.Config{Policy: &check.Policy{}, RootOfTrust: &check.RootOfTrust{}},
|
||||
PcrConfig: &attestation.PcrConfig{},
|
||||
}
|
||||
|
||||
err := ReadPolicyFromByte(tt.policyData, config)
|
||||
if tt.expectError {
|
||||
assert.Error(t, err)
|
||||
if tt.expectedErr != nil {
|
||||
assert.True(t, errors.Contains(err, tt.expectedErr))
|
||||
}
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertPolicyToJSON(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
config *attestation.Config
|
||||
expectError bool
|
||||
expectedErr error
|
||||
}{
|
||||
{
|
||||
name: "Valid config",
|
||||
config: &attestation.Config{
|
||||
Config: &check.Config{
|
||||
Policy: &check.Policy{
|
||||
Product: &sevsnp.SevProduct{
|
||||
Name: sevsnp.SevProduct_SEV_PRODUCT_MILAN,
|
||||
},
|
||||
},
|
||||
RootOfTrust: &check.RootOfTrust{
|
||||
ProductLine: "Milan",
|
||||
},
|
||||
},
|
||||
PcrConfig: &attestation.PcrConfig{
|
||||
PCRValues: attestation.PcrValues{
|
||||
Sha256: map[string]string{
|
||||
"0": "0000000000000000000000000000000000000000000000000000000000000000",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "Nil config",
|
||||
config: &attestation.Config{
|
||||
Config: nil,
|
||||
PcrConfig: &attestation.PcrConfig{},
|
||||
},
|
||||
expectError: false,
|
||||
expectedErr: ErrProtoMarshalFailed,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
jsonData, err := ConvertPolicyToJSON(tt.config)
|
||||
if tt.expectError {
|
||||
assert.Error(t, err)
|
||||
if tt.expectedErr != nil {
|
||||
assert.True(t, errors.Contains(err, tt.expectedErr))
|
||||
}
|
||||
assert.Nil(t, jsonData)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, jsonData)
|
||||
|
||||
var result map[string]interface{}
|
||||
err = json.Unmarshal(jsonData, &result)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestVTPMVerify(t *testing.T) {
|
||||
writer := &mockWriter{}
|
||||
policy := &attestation.Config{
|
||||
Config: &check.Config{Policy: &check.Policy{}, RootOfTrust: &check.RootOfTrust{}},
|
||||
PcrConfig: &attestation.PcrConfig{},
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
quote []byte
|
||||
teeNonce []byte
|
||||
vtpmNonce []byte
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
name: "Invalid quote data",
|
||||
quote: []byte("invalid"),
|
||||
teeNonce: []byte("tee-nonce"),
|
||||
vtpmNonce: []byte("vtpm-nonce"),
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "Empty quote",
|
||||
quote: []byte{},
|
||||
teeNonce: []byte("tee-nonce"),
|
||||
vtpmNonce: []byte("vtpm-nonce"),
|
||||
expectError: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := VTPMVerify(tt.quote, tt.teeNonce, tt.vtpmNonce, writer, policy)
|
||||
if tt.expectError {
|
||||
assert.Error(t, err)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestVerifyQuote(t *testing.T) {
|
||||
writer := &mockWriter{}
|
||||
policy := &attestation.Config{
|
||||
Config: &check.Config{Policy: &check.Policy{}, RootOfTrust: &check.RootOfTrust{}},
|
||||
PcrConfig: &attestation.PcrConfig{},
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
quote []byte
|
||||
vtpmNonce []byte
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
name: "Invalid quote data",
|
||||
quote: []byte("invalid"),
|
||||
vtpmNonce: []byte("vtpm-nonce"),
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "Empty quote",
|
||||
quote: []byte{},
|
||||
vtpmNonce: []byte("vtpm-nonce"),
|
||||
expectError: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := VerifyQuote(tt.quote, tt.vtpmNonce, writer, policy)
|
||||
if tt.expectError {
|
||||
assert.Error(t, err)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestWriterError(t *testing.T) {
|
||||
writer := &mockWriter{err: fmt.Errorf("write error")}
|
||||
policy := &attestation.Config{
|
||||
Config: &check.Config{Policy: &check.Policy{}, RootOfTrust: &check.RootOfTrust{}},
|
||||
PcrConfig: &attestation.PcrConfig{},
|
||||
}
|
||||
|
||||
err := VerifyQuote([]byte("invalid"), []byte("nonce"), writer, policy)
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func TestVerifyAttestationReportMalformedSignature(t *testing.T) {
|
||||
tempDir, err := os.MkdirTemp("", "policy")
|
||||
require.NoError(t, err)
|
||||
@@ -33,7 +666,7 @@ func TestVerifyAttestationReportMalformedSignature(t *testing.T) {
|
||||
err = setAttestationPolicy(attestationPB, tempDir)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Change random data so in the signature so the signature failes
|
||||
// Change random data so in the signature so the signature fails
|
||||
attestationPB.Report.Signature[0] = attestationPB.Report.Signature[0] ^ 0x01
|
||||
|
||||
tests := []struct {
|
||||
|
||||
@@ -0,0 +1,128 @@
|
||||
// Copyright (c) Ultraviolet
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
package cvm
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/absmach/magistrala/pkg/errors"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/ultravioletrs/cocos/agent"
|
||||
agentgrpc "github.com/ultravioletrs/cocos/agent/api/grpc"
|
||||
"github.com/ultravioletrs/cocos/agent/mocks"
|
||||
pkggrpc "github.com/ultravioletrs/cocos/pkg/clients/grpc"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/health"
|
||||
grpchealth "google.golang.org/grpc/health/grpc_health_v1"
|
||||
)
|
||||
|
||||
type TestServer struct {
|
||||
agent.UnimplementedAgentServiceServer
|
||||
server *grpc.Server
|
||||
health *health.Server
|
||||
port int
|
||||
listenAddr string
|
||||
}
|
||||
|
||||
func NewTestServer() (*TestServer, error) {
|
||||
listener, err := net.Listen("tcp", "localhost:0")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to listen: %v", err)
|
||||
}
|
||||
|
||||
addr := listener.Addr().(*net.TCPAddr)
|
||||
|
||||
server := grpc.NewServer()
|
||||
healthServer := health.NewServer()
|
||||
|
||||
ts := &TestServer{
|
||||
server: server,
|
||||
health: healthServer,
|
||||
port: addr.Port,
|
||||
listenAddr: fmt.Sprintf("localhost:%d", addr.Port),
|
||||
}
|
||||
|
||||
svc := new(mocks.Service)
|
||||
agent.RegisterAgentServiceServer(server, agentgrpc.NewServer(svc))
|
||||
grpchealth.RegisterHealthServer(server, healthServer)
|
||||
|
||||
go func() {
|
||||
if err := server.Serve(listener); err != nil {
|
||||
fmt.Printf("Server exited with error: %v\n", err)
|
||||
}
|
||||
}()
|
||||
|
||||
healthServer.SetServingStatus("agent", grpchealth.HealthCheckResponse_SERVING)
|
||||
|
||||
return ts, nil
|
||||
}
|
||||
|
||||
func (s *TestServer) Stop() {
|
||||
if s.server != nil {
|
||||
s.server.GracefulStop()
|
||||
}
|
||||
}
|
||||
|
||||
func TestAgentClientIntegration(t *testing.T) {
|
||||
testServer, err := NewTestServer()
|
||||
require.NoError(t, err)
|
||||
defer testServer.Stop()
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
serverRunning bool
|
||||
config pkggrpc.CVMClientConfig
|
||||
err error
|
||||
}{
|
||||
{
|
||||
name: "successful connection",
|
||||
serverRunning: true,
|
||||
config: pkggrpc.CVMClientConfig{
|
||||
BaseConfig: pkggrpc.BaseConfig{
|
||||
URL: testServer.listenAddr,
|
||||
Timeout: 1,
|
||||
},
|
||||
},
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
name: "server not healthy",
|
||||
serverRunning: false,
|
||||
config: pkggrpc.CVMClientConfig{
|
||||
BaseConfig: pkggrpc.BaseConfig{
|
||||
URL: "",
|
||||
Timeout: 1,
|
||||
},
|
||||
},
|
||||
err: errors.New("failed to connect to grpc server"),
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if !tt.serverRunning {
|
||||
testServer.health.SetServingStatus("agent", grpchealth.HealthCheckResponse_NOT_SERVING)
|
||||
} else {
|
||||
testServer.health.SetServingStatus("agent", grpchealth.HealthCheckResponse_SERVING)
|
||||
}
|
||||
|
||||
client, agentClient, err := NewCVMClient(tt.config)
|
||||
assert.True(t, errors.Contains(err, tt.err))
|
||||
if err != nil {
|
||||
assert.Nil(t, client)
|
||||
assert.Nil(t, agentClient)
|
||||
return
|
||||
}
|
||||
|
||||
require.NotNil(t, client)
|
||||
require.NotNil(t, agentClient)
|
||||
defer client.Close()
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -529,6 +529,76 @@ func TestReceiveAttestation(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestReceiverIMAMeasurements(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
description string
|
||||
totalSize int
|
||||
chunks [][]byte
|
||||
wantResult []byte
|
||||
wantErr error
|
||||
}{
|
||||
{
|
||||
name: "successful single chunk receive",
|
||||
description: "Receiving IMA measurements",
|
||||
totalSize: 20,
|
||||
chunks: [][]byte{[]byte("12345678912345678999")},
|
||||
wantResult: []byte("12345678912345678999"),
|
||||
wantErr: nil,
|
||||
},
|
||||
{
|
||||
name: "stream error",
|
||||
description: "Receiving IMA measurements",
|
||||
totalSize: 20,
|
||||
chunks: [][]byte{[]byte("12345678912345678999")},
|
||||
wantResult: nil,
|
||||
wantErr: errors.New("stream error"),
|
||||
},
|
||||
{
|
||||
name: "size mismatch",
|
||||
description: "Receiving IMA measurements",
|
||||
totalSize: 10,
|
||||
chunks: [][]byte{[]byte("12345678912345678999")},
|
||||
wantResult: nil,
|
||||
wantErr: errors.New("progress update exceeds total bytes: attempted to add 20 bytes, but only 10 bytes remain"),
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
mockStream := new(mocks.AgentService_IMAMeasurementsClient[agent.IMAMeasurementsResponse])
|
||||
|
||||
p := New(true)
|
||||
p.TerminalWidthFunc = func() (int, error) { return 100, nil }
|
||||
|
||||
resultFile, err := os.CreateTemp("", "test_ima_measurements")
|
||||
assert.NoError(t, err)
|
||||
|
||||
t.Cleanup(func() {
|
||||
os.Remove(resultFile.Name())
|
||||
})
|
||||
|
||||
if tt.wantErr != nil {
|
||||
mockStream.On("Recv").Return(nil, tt.wantErr).Once()
|
||||
}
|
||||
mockStream.On("Recv").Return(&agent.IMAMeasurementsResponse{Pcr10: []byte(tt.chunks[0]), File: []byte(tt.chunks[0])}, nil).Once()
|
||||
mockStream.On("Recv").Return(nil, io.EOF).Once()
|
||||
|
||||
pcr10, err := p.ReceiveIMAMeasurements(tt.description, tt.totalSize, mockStream, resultFile)
|
||||
|
||||
assert.NoError(t, resultFile.Close())
|
||||
|
||||
if tt.wantErr != nil {
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, tt.wantErr.Error(), err.Error())
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, tt.wantResult, pcr10)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
type mockAlgoStream struct {
|
||||
stream agent.AgentService_AlgoClient
|
||||
sendCount int
|
||||
|
||||
+29
-29
@@ -176,6 +176,35 @@ func (sdk *agentSDK) AttestationResult(ctx context.Context, nonce [size32]byte,
|
||||
return nil
|
||||
}
|
||||
|
||||
func (sdk *agentSDK) IMAMeasurements(ctx context.Context, resultFile *os.File) ([]byte, error) {
|
||||
request := &agent.IMAMeasurementsRequest{}
|
||||
|
||||
stream, err := sdk.client.IMAMeasurements(ctx, request)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
incomingmd, err := stream.Header()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
fileSizeStr := incomingmd.Get(grpc.FileSizeKey)
|
||||
|
||||
if len(fileSizeStr) == 0 {
|
||||
fileSizeStr = append(fileSizeStr, "0")
|
||||
}
|
||||
|
||||
fileSize, err := strconv.Atoi(fileSizeStr[0])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
pb := progressbar.New(true)
|
||||
|
||||
return pb.ReceiveIMAMeasurements(imaMeasurementsProgressDescription, fileSize, stream, resultFile)
|
||||
}
|
||||
|
||||
func signData(userID string, privKey crypto.Signer) ([]byte, error) {
|
||||
var signature []byte
|
||||
var err error
|
||||
@@ -208,32 +237,3 @@ func generateMetadata(userID string, privateKey crypto.PrivateKey) (metadata.MD,
|
||||
kv[auth.SignatureMetadataKey] = base64.StdEncoding.EncodeToString(signature)
|
||||
return metadata.New(kv), nil
|
||||
}
|
||||
|
||||
func (sdk *agentSDK) IMAMeasurements(ctx context.Context, resultFile *os.File) ([]byte, error) {
|
||||
request := &agent.IMAMeasurementsRequest{}
|
||||
|
||||
stream, err := sdk.client.IMAMeasurements(ctx, request)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
incomingmd, err := stream.Header()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
fileSizeStr := incomingmd.Get(grpc.FileSizeKey)
|
||||
|
||||
if len(fileSizeStr) == 0 {
|
||||
fileSizeStr = append(fileSizeStr, "0")
|
||||
}
|
||||
|
||||
fileSize, err := strconv.Atoi(fileSizeStr[0])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
pb := progressbar.New(true)
|
||||
|
||||
return pb.ReceiveIMAMeasurements(imaMeasurementsProgressDescription, fileSize, stream, resultFile)
|
||||
}
|
||||
|
||||
@@ -558,6 +558,82 @@ func TestAttestationResult(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestIMAMeasurements(t *testing.T) {
|
||||
conn, err := grpc.NewClient("passthrough://bufnet", grpc.WithTransportCredentials(insecure.NewCredentials()), grpc.WithContextDialer(bufDialer))
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to dial bufnet: %v", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
client := agent.NewAgentServiceClient(conn)
|
||||
|
||||
sdk := sdk.NewAgentSDK(client)
|
||||
|
||||
response := &agent.IMAMeasurementsResponse{
|
||||
File: []byte{
|
||||
0x01, 0x02, 0x03, 0x04,
|
||||
0x05, 0x06, 0x07, 0x08,
|
||||
0x01, 0x02, 0x03, 0x04,
|
||||
0x05, 0x06, 0x07, 0x08,
|
||||
0x01, 0x02, 0x03, 0x04,
|
||||
0x05, 0x06, 0x07, 0x08,
|
||||
},
|
||||
}
|
||||
|
||||
cases := []struct {
|
||||
name string
|
||||
response *agent.IMAMeasurementsResponse
|
||||
svcRes []byte
|
||||
err error
|
||||
}{
|
||||
{
|
||||
name: "fetch IMA measurements successfully",
|
||||
response: response,
|
||||
svcRes: response.File,
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
name: "failed to fetch IMA measurements",
|
||||
response: &agent.IMAMeasurementsResponse{File: []byte{}},
|
||||
svcRes: nil,
|
||||
err: errors.New("failed to fetch IMA measurements"),
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
svcCall := svc.On("IMAMeasurements", mock.Anything).Return(tc.svcRes, tc.svcRes, tc.err)
|
||||
|
||||
file, err := os.CreateTemp("", "ima_measurements")
|
||||
require.NoError(t, err)
|
||||
|
||||
t.Cleanup(func() {
|
||||
os.Remove(file.Name())
|
||||
})
|
||||
|
||||
_, err = sdk.IMAMeasurements(context.Background(), file)
|
||||
|
||||
require.NoError(t, file.Close())
|
||||
|
||||
st, ok := status.FromError(err)
|
||||
if !ok {
|
||||
t.Fatalf("Expected gRPC status error, but got: %v", err)
|
||||
}
|
||||
|
||||
if tc.err != nil {
|
||||
if st.Message() != tc.err.Error() {
|
||||
t.Errorf("%s: Expected error message %q, but got %q", tc.name, tc.err.Error(), st.Message())
|
||||
}
|
||||
}
|
||||
|
||||
res, err := os.ReadFile(file.Name())
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, tc.response.File, res, tc.name)
|
||||
svcCall.Unset()
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func generateKeys(t *testing.T, keyType string) (priv any, pub []byte) {
|
||||
switch keyType {
|
||||
case "ecdsa":
|
||||
|
||||
Reference in New Issue
Block a user