mirror of
https://github.com/ultravioletrs/cocos.git
synced 2026-06-23 04:10:25 +00:00
COCOS-143 - Add agent service tests (#170)
* Add agent service tests Signed-off-by: Jilks Smith <smithjilks@gmail.com> * Update agent service tests * Fix agent service tests * Improve agent service test coverage * Improve agent service test coverage Signed-off-by: Jilks Smith <smithjilks@gmail.com> * Fix tests Signed-off-by: Jilks Smith <smithjilks@gmail.com> * Refactor and improve coverage Signed-off-by: Jilks Smith <smithjilks@gmail.com> --------- Signed-off-by: Jilks Smith <smithjilks@gmail.com>
This commit is contained in:
@@ -4,6 +4,7 @@ package algorithm
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"log/slog"
|
||||
|
||||
@@ -65,7 +66,7 @@ func (s *Stderr) Write(p []byte) (n int, err error) {
|
||||
s.Logger.Error(string(buf[:n]))
|
||||
}
|
||||
|
||||
if err := s.EventSvc.SendEvent("algorithm-run", "error", nil); err != nil {
|
||||
if err := s.EventSvc.SendEvent("algorithm-run", "error", json.RawMessage{}); err != nil {
|
||||
return len(p), err
|
||||
}
|
||||
|
||||
|
||||
@@ -18,15 +18,15 @@ import (
|
||||
|
||||
const (
|
||||
PyRuntime = "python3"
|
||||
pyRuntimeKey = "python_runtime"
|
||||
PyRuntimeKey = "python_runtime"
|
||||
)
|
||||
|
||||
func PythonRunTimeToContext(ctx context.Context, runtime string) context.Context {
|
||||
return metadata.AppendToOutgoingContext(ctx, pyRuntimeKey, runtime)
|
||||
return metadata.AppendToOutgoingContext(ctx, PyRuntimeKey, runtime)
|
||||
}
|
||||
|
||||
func PythonRunTimeFromContext(ctx context.Context) string {
|
||||
return metadata.ValueFromIncomingContext(ctx, pyRuntimeKey)[0]
|
||||
return metadata.ValueFromIncomingContext(ctx, PyRuntimeKey)[0]
|
||||
}
|
||||
|
||||
var _ algorithm.Algorithm = (*python)(nil)
|
||||
|
||||
@@ -0,0 +1,95 @@
|
||||
// Code generated by mockery v2.45.0. DO NOT EDIT.
|
||||
|
||||
package mocks
|
||||
|
||||
import (
|
||||
sevsnp "github.com/google/go-sev-guest/proto/sevsnp"
|
||||
mock "github.com/stretchr/testify/mock"
|
||||
)
|
||||
|
||||
// QuoteProvider is an autogenerated mock type for the QuoteProvider type
|
||||
type QuoteProvider struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
// GetRawQuote provides a mock function with given fields: reportData
|
||||
func (_m *QuoteProvider) GetRawQuote(reportData [64]byte) ([]uint8, error) {
|
||||
ret := _m.Called(reportData)
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for GetRawQuote")
|
||||
}
|
||||
|
||||
var r0 []uint8
|
||||
var r1 error
|
||||
if rf, ok := ret.Get(0).(func([64]byte) ([]uint8, error)); ok {
|
||||
return rf(reportData)
|
||||
}
|
||||
if rf, ok := ret.Get(0).(func([64]byte) []uint8); ok {
|
||||
r0 = rf(reportData)
|
||||
} else {
|
||||
if ret.Get(0) != nil {
|
||||
r0 = ret.Get(0).([]uint8)
|
||||
}
|
||||
}
|
||||
|
||||
if rf, ok := ret.Get(1).(func([64]byte) error); ok {
|
||||
r1 = rf(reportData)
|
||||
} else {
|
||||
r1 = ret.Error(1)
|
||||
}
|
||||
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
// IsSupported provides a mock function with given fields:
|
||||
func (_m *QuoteProvider) IsSupported() bool {
|
||||
ret := _m.Called()
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for IsSupported")
|
||||
}
|
||||
|
||||
var r0 bool
|
||||
if rf, ok := ret.Get(0).(func() bool); ok {
|
||||
r0 = rf()
|
||||
} else {
|
||||
r0 = ret.Get(0).(bool)
|
||||
}
|
||||
|
||||
return r0
|
||||
}
|
||||
|
||||
// Product provides a mock function with given fields:
|
||||
func (_m *QuoteProvider) Product() *sevsnp.SevProduct {
|
||||
ret := _m.Called()
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for Product")
|
||||
}
|
||||
|
||||
var r0 *sevsnp.SevProduct
|
||||
if rf, ok := ret.Get(0).(func() *sevsnp.SevProduct); ok {
|
||||
r0 = rf()
|
||||
} else {
|
||||
if ret.Get(0) != nil {
|
||||
r0 = ret.Get(0).(*sevsnp.SevProduct)
|
||||
}
|
||||
}
|
||||
|
||||
return r0
|
||||
}
|
||||
|
||||
// NewQuoteProvider creates a new instance of QuoteProvider. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations.
|
||||
// The first argument is typically a *testing.T value.
|
||||
func NewQuoteProvider(t interface {
|
||||
mock.TestingT
|
||||
Cleanup(func())
|
||||
}) *QuoteProvider {
|
||||
mock := &QuoteProvider{}
|
||||
mock.Mock.Test(t)
|
||||
|
||||
t.Cleanup(func() { mock.AssertExpectations(t) })
|
||||
|
||||
return mock
|
||||
}
|
||||
+5
-2
@@ -55,6 +55,8 @@ var (
|
||||
ErrFileNameMismatch = errors.New("malformed data, filename does not match manifest")
|
||||
// ErrAllResultsConsumed indicates all results have been consumed.
|
||||
ErrAllResultsConsumed = errors.New("all results have been consumed by declared consumers")
|
||||
// ErrAttestationFailed attestation failed.
|
||||
ErrAttestationFailed = errors.New("failed to get raw quote")
|
||||
)
|
||||
|
||||
// Service specifies an API that must be fullfiled by the domain service
|
||||
@@ -124,7 +126,7 @@ func (as *agentService) Algo(ctx context.Context, algo Algorithm) error {
|
||||
return fmt.Errorf("error getting current directory: %v", err)
|
||||
}
|
||||
|
||||
f, err := os.Create(filepath.Join(currentDir, "algorithm"))
|
||||
f, err := os.Create(filepath.Join(currentDir, "algo"))
|
||||
if err != nil {
|
||||
return fmt.Errorf("error creating algorithm file: %v", err)
|
||||
}
|
||||
@@ -317,8 +319,9 @@ func (as *agentService) runComputation() {
|
||||
}
|
||||
|
||||
func (as *agentService) publishEvent(status string, details json.RawMessage) func() {
|
||||
st := as.sm.GetState().String()
|
||||
return func() {
|
||||
if err := as.eventSvc.SendEvent(as.sm.State.String(), status, details); err != nil {
|
||||
if err := as.eventSvc.SendEvent(st, status, details); err != nil {
|
||||
as.sm.logger.Warn(err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,394 @@
|
||||
// Copyright (c) Ultraviolet
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
package agent
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"log"
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
mglog "github.com/absmach/magistrala/logger"
|
||||
"github.com/absmach/magistrala/pkg/errors"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/mock"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/ultravioletrs/cocos/agent/algorithm"
|
||||
"github.com/ultravioletrs/cocos/agent/algorithm/python"
|
||||
"github.com/ultravioletrs/cocos/agent/events/mocks"
|
||||
"github.com/ultravioletrs/cocos/agent/quoteprovider"
|
||||
mocks2 "github.com/ultravioletrs/cocos/agent/quoteprovider/mocks"
|
||||
"golang.org/x/crypto/sha3"
|
||||
"google.golang.org/grpc/metadata"
|
||||
)
|
||||
|
||||
var (
|
||||
algoPath = "../test/manual/algo/lin_reg.py"
|
||||
reqPath = "../test/manual/algo/requirements.txt"
|
||||
dataPath = "../test/manual/data/iris.csv"
|
||||
)
|
||||
|
||||
const datasetFile = "iris.csv"
|
||||
|
||||
func TestAlgo(t *testing.T) {
|
||||
events := new(mocks.Service)
|
||||
|
||||
evCall := events.On("SendEvent", mock.Anything, mock.Anything, mock.Anything).Return(nil)
|
||||
defer evCall.Unset()
|
||||
|
||||
qp, err := quoteprovider.GetQuoteProvider()
|
||||
require.NoError(t, err)
|
||||
|
||||
algo, err := os.ReadFile(algoPath)
|
||||
require.NoError(t, err)
|
||||
|
||||
algoHash := sha3.Sum256(algo)
|
||||
|
||||
reqFile, err := os.ReadFile(reqPath)
|
||||
require.NoError(t, err)
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
err error
|
||||
algo Algorithm
|
||||
algoType string
|
||||
}{
|
||||
{
|
||||
name: "Test Algo successfully",
|
||||
algo: Algorithm{
|
||||
Algorithm: algo,
|
||||
Hash: algoHash,
|
||||
},
|
||||
algoType: "python",
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
name: "Test Algo successfully with requirements file",
|
||||
algo: Algorithm{
|
||||
Algorithm: algo,
|
||||
Hash: algoHash,
|
||||
Requirements: reqFile,
|
||||
},
|
||||
algoType: "python",
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
name: "Test Algo type binary successfully",
|
||||
algo: Algorithm{
|
||||
Algorithm: algo,
|
||||
Hash: algoHash,
|
||||
},
|
||||
algoType: "bin",
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
name: "Test Algo type wasm successfully",
|
||||
algo: Algorithm{
|
||||
Algorithm: algo,
|
||||
Hash: algoHash,
|
||||
},
|
||||
algoType: "wasm",
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
name: "Test Algo type docker successfully",
|
||||
algo: Algorithm{
|
||||
Algorithm: algo,
|
||||
Hash: algoHash,
|
||||
},
|
||||
algoType: "docker",
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
name: "Test algo hash mismatch",
|
||||
algo: Algorithm{},
|
||||
algoType: "python",
|
||||
err: ErrHashMismatch,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
err = os.RemoveAll("datasets")
|
||||
require.NoError(t, err)
|
||||
|
||||
ctx := metadata.NewIncomingContext(context.Background(),
|
||||
metadata.Pairs(algorithm.AlgoTypeKey, tc.algoType, python.PyRuntimeKey, python.PyRuntime),
|
||||
)
|
||||
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
defer cancel()
|
||||
svc := New(ctx, mglog.NewMock(), events, testComputation(t), qp)
|
||||
|
||||
time.Sleep(300 * time.Millisecond)
|
||||
|
||||
err = svc.Algo(ctx, tc.algo)
|
||||
assert.True(t, errors.Contains(err, tc.err), "expected %v, got %v", tc.err, err)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestData(t *testing.T) {
|
||||
events := new(mocks.Service)
|
||||
|
||||
evCall := events.On("SendEvent", mock.Anything, mock.Anything, mock.Anything).Return(nil)
|
||||
defer evCall.Unset()
|
||||
|
||||
qp, err := quoteprovider.GetQuoteProvider()
|
||||
require.NoError(t, err)
|
||||
|
||||
algo, err := os.ReadFile(algoPath)
|
||||
require.NoError(t, err)
|
||||
|
||||
algoHash := sha3.Sum256(algo)
|
||||
|
||||
alg := Algorithm{
|
||||
Hash: algoHash,
|
||||
Algorithm: algo,
|
||||
}
|
||||
|
||||
data, err := os.ReadFile(dataPath)
|
||||
require.NoError(t, err)
|
||||
|
||||
dataHash := sha3.Sum256(data)
|
||||
|
||||
cases := []struct {
|
||||
name string
|
||||
data Dataset
|
||||
err error
|
||||
}{
|
||||
{
|
||||
name: "Test data successfully",
|
||||
data: Dataset{
|
||||
Hash: dataHash,
|
||||
Dataset: data,
|
||||
Filename: datasetFile,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Test State not ready",
|
||||
data: Dataset{
|
||||
Dataset: data,
|
||||
Hash: dataHash,
|
||||
Filename: datasetFile,
|
||||
},
|
||||
err: ErrStateNotReady,
|
||||
},
|
||||
{
|
||||
name: "Test File name does not match manifest",
|
||||
data: Dataset{
|
||||
Dataset: data,
|
||||
Hash: dataHash,
|
||||
Filename: "invalid",
|
||||
},
|
||||
err: ErrFileNameMismatch,
|
||||
},
|
||||
{
|
||||
name: "Test dataset not declared in manifest",
|
||||
data: Dataset{
|
||||
Filename: datasetFile,
|
||||
},
|
||||
err: ErrUndeclaredDataset,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
ctx := metadata.NewIncomingContext(context.Background(),
|
||||
metadata.Pairs(
|
||||
algorithm.AlgoTypeKey, "python",
|
||||
python.PyRuntimeKey, python.PyRuntime),
|
||||
)
|
||||
|
||||
if tc.err != ErrUndeclaredDataset {
|
||||
ctx = IndexToContext(ctx, 0)
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
defer cancel()
|
||||
|
||||
comp := testComputation(t)
|
||||
|
||||
svc := New(ctx, mglog.NewMock(), events, comp, qp)
|
||||
time.Sleep(300 * time.Millisecond)
|
||||
|
||||
if tc.err != ErrStateNotReady {
|
||||
_ = svc.Algo(ctx, alg)
|
||||
time.Sleep(300 * time.Millisecond)
|
||||
}
|
||||
err = svc.Data(ctx, tc.data)
|
||||
_ = os.RemoveAll("datasets")
|
||||
_ = os.RemoveAll("results")
|
||||
assert.True(t, errors.Contains(err, tc.err), "expected %v, got %v", tc.err, err)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestResult(t *testing.T) {
|
||||
events := new(mocks.Service)
|
||||
|
||||
evCall := events.On("SendEvent", mock.Anything, mock.Anything, mock.Anything).Return(nil)
|
||||
defer evCall.Unset()
|
||||
|
||||
qp, err := quoteprovider.GetQuoteProvider()
|
||||
require.NoError(t, err)
|
||||
|
||||
cases := []struct {
|
||||
name string
|
||||
err error
|
||||
setup func(svc *agentService)
|
||||
ctxSetup func(ctx context.Context) context.Context
|
||||
}{
|
||||
{
|
||||
name: "Test results not ready",
|
||||
err: ErrResultsNotReady,
|
||||
setup: func(svc *agentService) {
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Test all results consumed",
|
||||
err: ErrAllResultsConsumed,
|
||||
setup: func(svc *agentService) {
|
||||
svc.sm.SetState(resultsReady)
|
||||
svc.computation.ResultConsumers = []ResultConsumer{}
|
||||
},
|
||||
ctxSetup: func(ctx context.Context) context.Context {
|
||||
return IndexToContext(ctx, 0)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Test undeclared consumer",
|
||||
err: ErrUndeclaredConsumer,
|
||||
setup: func(svc *agentService) {
|
||||
svc.sm.SetState(resultsReady)
|
||||
svc.computation.ResultConsumers = []ResultConsumer{{UserKey: []byte("user")}}
|
||||
},
|
||||
ctxSetup: func(ctx context.Context) context.Context {
|
||||
return ctx
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Test results consumed and event sent",
|
||||
err: nil,
|
||||
setup: func(svc *agentService) {
|
||||
svc.sm.SetState(resultsReady)
|
||||
svc.computation.ResultConsumers = []ResultConsumer{{UserKey: []byte("key")}}
|
||||
},
|
||||
ctxSetup: func(ctx context.Context) context.Context {
|
||||
return IndexToContext(ctx, 0)
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
ctx := metadata.NewIncomingContext(context.Background(),
|
||||
metadata.Pairs(algorithm.AlgoTypeKey, "python", python.PyRuntimeKey, python.PyRuntime),
|
||||
)
|
||||
|
||||
if tc.ctxSetup != nil {
|
||||
ctx = tc.ctxSetup(ctx)
|
||||
}
|
||||
|
||||
svc := &agentService{
|
||||
sm: NewStateMachine(mglog.NewMock(), testComputation(t)),
|
||||
eventSvc: events,
|
||||
quoteProvider: qp,
|
||||
computation: testComputation(t),
|
||||
}
|
||||
|
||||
go svc.sm.Start(ctx)
|
||||
tc.setup(svc)
|
||||
_, err := svc.Result(ctx)
|
||||
_ = os.RemoveAll("datasets")
|
||||
_ = os.RemoveAll("results")
|
||||
|
||||
assert.ErrorIs(t, err, tc.err, "expected %v, got %v", tc.err, err)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAttestation(t *testing.T) {
|
||||
events := new(mocks.Service)
|
||||
qp := new(mocks2.QuoteProvider)
|
||||
|
||||
evCall := events.On("SendEvent", mock.Anything, mock.Anything, mock.Anything).Return(nil)
|
||||
defer evCall.Unset()
|
||||
|
||||
cases := []struct {
|
||||
name string
|
||||
reportData [ReportDataSize]byte
|
||||
rawQuote []uint8
|
||||
err error
|
||||
}{
|
||||
{
|
||||
name: "Test attestation successful",
|
||||
reportData: generateReportData(),
|
||||
rawQuote: make([]uint8, 0),
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
name: "Test attestation failed",
|
||||
reportData: generateReportData(),
|
||||
rawQuote: nil,
|
||||
err: ErrAttestationFailed,
|
||||
},
|
||||
}
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
ctx := metadata.NewIncomingContext(context.Background(),
|
||||
metadata.Pairs(algorithm.AlgoTypeKey, "python", python.PyRuntimeKey, python.PyRuntime),
|
||||
)
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
defer cancel()
|
||||
|
||||
getQuote := qp.On("GetRawQuote", mock.Anything).Return(tc.rawQuote, tc.err)
|
||||
if tc.err != ErrAttestationFailed {
|
||||
getQuote = qp.On("GetRawQuote", mock.Anything).Return(tc.reportData, nil)
|
||||
}
|
||||
defer getQuote.Unset()
|
||||
|
||||
svc := New(ctx, mglog.NewMock(), events, testComputation(t), qp)
|
||||
time.Sleep(300 * time.Millisecond)
|
||||
_, err := svc.Attestation(ctx, tc.reportData)
|
||||
assert.True(t, errors.Contains(err, tc.err), "expected %v, got %v", tc.err, err)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func generateReportData() [ReportDataSize]byte {
|
||||
bytes := make([]byte, ReportDataSize)
|
||||
_, err := rand.Read(bytes)
|
||||
if err != nil {
|
||||
log.Fatalf("Failed to generate random bytes: %v", err)
|
||||
}
|
||||
return [64]byte(bytes)
|
||||
}
|
||||
|
||||
func testComputation(t *testing.T) Computation {
|
||||
algo, err := os.ReadFile(algoPath)
|
||||
require.NoError(t, err)
|
||||
|
||||
algoHash := sha3.Sum256(algo)
|
||||
|
||||
data, err := os.ReadFile(dataPath)
|
||||
require.NoError(t, err)
|
||||
|
||||
dataHash := sha3.Sum256(data)
|
||||
|
||||
return Computation{
|
||||
ID: "1",
|
||||
Name: "sample computation",
|
||||
Description: "sample description",
|
||||
Datasets: []Dataset{{Hash: dataHash, UserKey: []byte("key"), Dataset: data, Filename: datasetFile}},
|
||||
Algorithm: Algorithm{Hash: algoHash, UserKey: []byte("key"), Algorithm: algo},
|
||||
ResultConsumers: []ResultConsumer{{UserKey: []byte("key")}},
|
||||
AgentConfig: AgentConfig{
|
||||
Port: "7002",
|
||||
LogLevel: "debug",
|
||||
AttestedTls: false,
|
||||
},
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user