COCOS-143 - Add agent service tests (#170)

* Add agent service tests

Signed-off-by: Jilks Smith <smithjilks@gmail.com>

* Update agent service tests

* Fix agent service tests

* Improve agent service test coverage

* Improve agent service test coverage

Signed-off-by: Jilks Smith <smithjilks@gmail.com>

* Fix tests

Signed-off-by: Jilks Smith <smithjilks@gmail.com>

* Refactor and improve coverage

Signed-off-by: Jilks Smith <smithjilks@gmail.com>

---------

Signed-off-by: Jilks Smith <smithjilks@gmail.com>
This commit is contained in:
Smith Jilks
2024-09-12 17:54:09 +03:00
committed by GitHub
parent 20ddb3aa29
commit e26deb98e4
5 changed files with 499 additions and 6 deletions
+2 -1
View File
@@ -4,6 +4,7 @@ package algorithm
import (
"bytes"
"encoding/json"
"io"
"log/slog"
@@ -65,7 +66,7 @@ func (s *Stderr) Write(p []byte) (n int, err error) {
s.Logger.Error(string(buf[:n]))
}
if err := s.EventSvc.SendEvent("algorithm-run", "error", nil); err != nil {
if err := s.EventSvc.SendEvent("algorithm-run", "error", json.RawMessage{}); err != nil {
return len(p), err
}
+3 -3
View File
@@ -18,15 +18,15 @@ import (
const (
PyRuntime = "python3"
pyRuntimeKey = "python_runtime"
PyRuntimeKey = "python_runtime"
)
func PythonRunTimeToContext(ctx context.Context, runtime string) context.Context {
return metadata.AppendToOutgoingContext(ctx, pyRuntimeKey, runtime)
return metadata.AppendToOutgoingContext(ctx, PyRuntimeKey, runtime)
}
func PythonRunTimeFromContext(ctx context.Context) string {
return metadata.ValueFromIncomingContext(ctx, pyRuntimeKey)[0]
return metadata.ValueFromIncomingContext(ctx, PyRuntimeKey)[0]
}
var _ algorithm.Algorithm = (*python)(nil)
@@ -0,0 +1,95 @@
// Code generated by mockery v2.45.0. DO NOT EDIT.
package mocks
import (
sevsnp "github.com/google/go-sev-guest/proto/sevsnp"
mock "github.com/stretchr/testify/mock"
)
// QuoteProvider is an autogenerated mock type for the QuoteProvider type
type QuoteProvider struct {
mock.Mock
}
// GetRawQuote provides a mock function with given fields: reportData
func (_m *QuoteProvider) GetRawQuote(reportData [64]byte) ([]uint8, error) {
ret := _m.Called(reportData)
if len(ret) == 0 {
panic("no return value specified for GetRawQuote")
}
var r0 []uint8
var r1 error
if rf, ok := ret.Get(0).(func([64]byte) ([]uint8, error)); ok {
return rf(reportData)
}
if rf, ok := ret.Get(0).(func([64]byte) []uint8); ok {
r0 = rf(reportData)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).([]uint8)
}
}
if rf, ok := ret.Get(1).(func([64]byte) error); ok {
r1 = rf(reportData)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// IsSupported provides a mock function with given fields:
func (_m *QuoteProvider) IsSupported() bool {
ret := _m.Called()
if len(ret) == 0 {
panic("no return value specified for IsSupported")
}
var r0 bool
if rf, ok := ret.Get(0).(func() bool); ok {
r0 = rf()
} else {
r0 = ret.Get(0).(bool)
}
return r0
}
// Product provides a mock function with given fields:
func (_m *QuoteProvider) Product() *sevsnp.SevProduct {
ret := _m.Called()
if len(ret) == 0 {
panic("no return value specified for Product")
}
var r0 *sevsnp.SevProduct
if rf, ok := ret.Get(0).(func() *sevsnp.SevProduct); ok {
r0 = rf()
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(*sevsnp.SevProduct)
}
}
return r0
}
// NewQuoteProvider creates a new instance of QuoteProvider. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations.
// The first argument is typically a *testing.T value.
func NewQuoteProvider(t interface {
mock.TestingT
Cleanup(func())
}) *QuoteProvider {
mock := &QuoteProvider{}
mock.Mock.Test(t)
t.Cleanup(func() { mock.AssertExpectations(t) })
return mock
}
+5 -2
View File
@@ -55,6 +55,8 @@ var (
ErrFileNameMismatch = errors.New("malformed data, filename does not match manifest")
// ErrAllResultsConsumed indicates all results have been consumed.
ErrAllResultsConsumed = errors.New("all results have been consumed by declared consumers")
// ErrAttestationFailed attestation failed.
ErrAttestationFailed = errors.New("failed to get raw quote")
)
// Service specifies an API that must be fullfiled by the domain service
@@ -124,7 +126,7 @@ func (as *agentService) Algo(ctx context.Context, algo Algorithm) error {
return fmt.Errorf("error getting current directory: %v", err)
}
f, err := os.Create(filepath.Join(currentDir, "algorithm"))
f, err := os.Create(filepath.Join(currentDir, "algo"))
if err != nil {
return fmt.Errorf("error creating algorithm file: %v", err)
}
@@ -317,8 +319,9 @@ func (as *agentService) runComputation() {
}
func (as *agentService) publishEvent(status string, details json.RawMessage) func() {
st := as.sm.GetState().String()
return func() {
if err := as.eventSvc.SendEvent(as.sm.State.String(), status, details); err != nil {
if err := as.eventSvc.SendEvent(st, status, details); err != nil {
as.sm.logger.Warn(err.Error())
}
}
+394
View File
@@ -0,0 +1,394 @@
// Copyright (c) Ultraviolet
// SPDX-License-Identifier: Apache-2.0
package agent
import (
"context"
"crypto/rand"
"log"
"os"
"testing"
"time"
mglog "github.com/absmach/magistrala/logger"
"github.com/absmach/magistrala/pkg/errors"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
"github.com/ultravioletrs/cocos/agent/algorithm"
"github.com/ultravioletrs/cocos/agent/algorithm/python"
"github.com/ultravioletrs/cocos/agent/events/mocks"
"github.com/ultravioletrs/cocos/agent/quoteprovider"
mocks2 "github.com/ultravioletrs/cocos/agent/quoteprovider/mocks"
"golang.org/x/crypto/sha3"
"google.golang.org/grpc/metadata"
)
var (
algoPath = "../test/manual/algo/lin_reg.py"
reqPath = "../test/manual/algo/requirements.txt"
dataPath = "../test/manual/data/iris.csv"
)
const datasetFile = "iris.csv"
func TestAlgo(t *testing.T) {
events := new(mocks.Service)
evCall := events.On("SendEvent", mock.Anything, mock.Anything, mock.Anything).Return(nil)
defer evCall.Unset()
qp, err := quoteprovider.GetQuoteProvider()
require.NoError(t, err)
algo, err := os.ReadFile(algoPath)
require.NoError(t, err)
algoHash := sha3.Sum256(algo)
reqFile, err := os.ReadFile(reqPath)
require.NoError(t, err)
testCases := []struct {
name string
err error
algo Algorithm
algoType string
}{
{
name: "Test Algo successfully",
algo: Algorithm{
Algorithm: algo,
Hash: algoHash,
},
algoType: "python",
err: nil,
},
{
name: "Test Algo successfully with requirements file",
algo: Algorithm{
Algorithm: algo,
Hash: algoHash,
Requirements: reqFile,
},
algoType: "python",
err: nil,
},
{
name: "Test Algo type binary successfully",
algo: Algorithm{
Algorithm: algo,
Hash: algoHash,
},
algoType: "bin",
err: nil,
},
{
name: "Test Algo type wasm successfully",
algo: Algorithm{
Algorithm: algo,
Hash: algoHash,
},
algoType: "wasm",
err: nil,
},
{
name: "Test Algo type docker successfully",
algo: Algorithm{
Algorithm: algo,
Hash: algoHash,
},
algoType: "docker",
err: nil,
},
{
name: "Test algo hash mismatch",
algo: Algorithm{},
algoType: "python",
err: ErrHashMismatch,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
err = os.RemoveAll("datasets")
require.NoError(t, err)
ctx := metadata.NewIncomingContext(context.Background(),
metadata.Pairs(algorithm.AlgoTypeKey, tc.algoType, python.PyRuntimeKey, python.PyRuntime),
)
ctx, cancel := context.WithCancel(ctx)
defer cancel()
svc := New(ctx, mglog.NewMock(), events, testComputation(t), qp)
time.Sleep(300 * time.Millisecond)
err = svc.Algo(ctx, tc.algo)
assert.True(t, errors.Contains(err, tc.err), "expected %v, got %v", tc.err, err)
})
}
}
func TestData(t *testing.T) {
events := new(mocks.Service)
evCall := events.On("SendEvent", mock.Anything, mock.Anything, mock.Anything).Return(nil)
defer evCall.Unset()
qp, err := quoteprovider.GetQuoteProvider()
require.NoError(t, err)
algo, err := os.ReadFile(algoPath)
require.NoError(t, err)
algoHash := sha3.Sum256(algo)
alg := Algorithm{
Hash: algoHash,
Algorithm: algo,
}
data, err := os.ReadFile(dataPath)
require.NoError(t, err)
dataHash := sha3.Sum256(data)
cases := []struct {
name string
data Dataset
err error
}{
{
name: "Test data successfully",
data: Dataset{
Hash: dataHash,
Dataset: data,
Filename: datasetFile,
},
},
{
name: "Test State not ready",
data: Dataset{
Dataset: data,
Hash: dataHash,
Filename: datasetFile,
},
err: ErrStateNotReady,
},
{
name: "Test File name does not match manifest",
data: Dataset{
Dataset: data,
Hash: dataHash,
Filename: "invalid",
},
err: ErrFileNameMismatch,
},
{
name: "Test dataset not declared in manifest",
data: Dataset{
Filename: datasetFile,
},
err: ErrUndeclaredDataset,
},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
ctx := metadata.NewIncomingContext(context.Background(),
metadata.Pairs(
algorithm.AlgoTypeKey, "python",
python.PyRuntimeKey, python.PyRuntime),
)
if tc.err != ErrUndeclaredDataset {
ctx = IndexToContext(ctx, 0)
}
ctx, cancel := context.WithCancel(ctx)
defer cancel()
comp := testComputation(t)
svc := New(ctx, mglog.NewMock(), events, comp, qp)
time.Sleep(300 * time.Millisecond)
if tc.err != ErrStateNotReady {
_ = svc.Algo(ctx, alg)
time.Sleep(300 * time.Millisecond)
}
err = svc.Data(ctx, tc.data)
_ = os.RemoveAll("datasets")
_ = os.RemoveAll("results")
assert.True(t, errors.Contains(err, tc.err), "expected %v, got %v", tc.err, err)
})
}
}
func TestResult(t *testing.T) {
events := new(mocks.Service)
evCall := events.On("SendEvent", mock.Anything, mock.Anything, mock.Anything).Return(nil)
defer evCall.Unset()
qp, err := quoteprovider.GetQuoteProvider()
require.NoError(t, err)
cases := []struct {
name string
err error
setup func(svc *agentService)
ctxSetup func(ctx context.Context) context.Context
}{
{
name: "Test results not ready",
err: ErrResultsNotReady,
setup: func(svc *agentService) {
},
},
{
name: "Test all results consumed",
err: ErrAllResultsConsumed,
setup: func(svc *agentService) {
svc.sm.SetState(resultsReady)
svc.computation.ResultConsumers = []ResultConsumer{}
},
ctxSetup: func(ctx context.Context) context.Context {
return IndexToContext(ctx, 0)
},
},
{
name: "Test undeclared consumer",
err: ErrUndeclaredConsumer,
setup: func(svc *agentService) {
svc.sm.SetState(resultsReady)
svc.computation.ResultConsumers = []ResultConsumer{{UserKey: []byte("user")}}
},
ctxSetup: func(ctx context.Context) context.Context {
return ctx
},
},
{
name: "Test results consumed and event sent",
err: nil,
setup: func(svc *agentService) {
svc.sm.SetState(resultsReady)
svc.computation.ResultConsumers = []ResultConsumer{{UserKey: []byte("key")}}
},
ctxSetup: func(ctx context.Context) context.Context {
return IndexToContext(ctx, 0)
},
},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
ctx := metadata.NewIncomingContext(context.Background(),
metadata.Pairs(algorithm.AlgoTypeKey, "python", python.PyRuntimeKey, python.PyRuntime),
)
if tc.ctxSetup != nil {
ctx = tc.ctxSetup(ctx)
}
svc := &agentService{
sm: NewStateMachine(mglog.NewMock(), testComputation(t)),
eventSvc: events,
quoteProvider: qp,
computation: testComputation(t),
}
go svc.sm.Start(ctx)
tc.setup(svc)
_, err := svc.Result(ctx)
_ = os.RemoveAll("datasets")
_ = os.RemoveAll("results")
assert.ErrorIs(t, err, tc.err, "expected %v, got %v", tc.err, err)
})
}
}
func TestAttestation(t *testing.T) {
events := new(mocks.Service)
qp := new(mocks2.QuoteProvider)
evCall := events.On("SendEvent", mock.Anything, mock.Anything, mock.Anything).Return(nil)
defer evCall.Unset()
cases := []struct {
name string
reportData [ReportDataSize]byte
rawQuote []uint8
err error
}{
{
name: "Test attestation successful",
reportData: generateReportData(),
rawQuote: make([]uint8, 0),
err: nil,
},
{
name: "Test attestation failed",
reportData: generateReportData(),
rawQuote: nil,
err: ErrAttestationFailed,
},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
ctx := metadata.NewIncomingContext(context.Background(),
metadata.Pairs(algorithm.AlgoTypeKey, "python", python.PyRuntimeKey, python.PyRuntime),
)
ctx, cancel := context.WithCancel(ctx)
defer cancel()
getQuote := qp.On("GetRawQuote", mock.Anything).Return(tc.rawQuote, tc.err)
if tc.err != ErrAttestationFailed {
getQuote = qp.On("GetRawQuote", mock.Anything).Return(tc.reportData, nil)
}
defer getQuote.Unset()
svc := New(ctx, mglog.NewMock(), events, testComputation(t), qp)
time.Sleep(300 * time.Millisecond)
_, err := svc.Attestation(ctx, tc.reportData)
assert.True(t, errors.Contains(err, tc.err), "expected %v, got %v", tc.err, err)
})
}
}
func generateReportData() [ReportDataSize]byte {
bytes := make([]byte, ReportDataSize)
_, err := rand.Read(bytes)
if err != nil {
log.Fatalf("Failed to generate random bytes: %v", err)
}
return [64]byte(bytes)
}
func testComputation(t *testing.T) Computation {
algo, err := os.ReadFile(algoPath)
require.NoError(t, err)
algoHash := sha3.Sum256(algo)
data, err := os.ReadFile(dataPath)
require.NoError(t, err)
dataHash := sha3.Sum256(data)
return Computation{
ID: "1",
Name: "sample computation",
Description: "sample description",
Datasets: []Dataset{{Hash: dataHash, UserKey: []byte("key"), Dataset: data, Filename: datasetFile}},
Algorithm: Algorithm{Hash: algoHash, UserKey: []byte("key"), Algorithm: algo},
ResultConsumers: []ResultConsumer{{UserKey: []byte("key")}},
AgentConfig: AgentConfig{
Port: "7002",
LogLevel: "debug",
AttestedTls: false,
},
}
}