diff --git a/agent/mocks/agent.go b/agent/mocks/agent.go new file mode 100644 index 00000000..7f97ae96 --- /dev/null +++ b/agent/mocks/agent.go @@ -0,0 +1,129 @@ +// Code generated by mockery v2.42.3. DO NOT EDIT. + +// Copyright (c) Ultraviolet +// SPDX-License-Identifier: Apache-2.0 + +package mocks + +import ( + context "context" + + agent "github.com/ultravioletrs/cocos/agent" + + mock "github.com/stretchr/testify/mock" +) + +// Service is an autogenerated mock type for the Service type +type Service struct { + mock.Mock +} + +// Algo provides a mock function with given fields: ctx, algorithm +func (_m *Service) Algo(ctx context.Context, algorithm agent.Algorithm) error { + ret := _m.Called(ctx, algorithm) + + if len(ret) == 0 { + panic("no return value specified for Algo") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, agent.Algorithm) error); ok { + r0 = rf(ctx, algorithm) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// Attestation provides a mock function with given fields: ctx, reportData +func (_m *Service) Attestation(ctx context.Context, reportData [64]byte) ([]byte, error) { + ret := _m.Called(ctx, reportData) + + if len(ret) == 0 { + panic("no return value specified for Attestation") + } + + var r0 []byte + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, [64]byte) ([]byte, error)); ok { + return rf(ctx, reportData) + } + if rf, ok := ret.Get(0).(func(context.Context, [64]byte) []byte); ok { + r0 = rf(ctx, reportData) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]byte) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, [64]byte) error); ok { + r1 = rf(ctx, reportData) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// Data provides a mock function with given fields: ctx, dataset +func (_m *Service) Data(ctx context.Context, dataset agent.Dataset) error { + ret := _m.Called(ctx, dataset) + + if len(ret) == 0 { + panic("no return value specified for Data") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, agent.Dataset) error); ok { + r0 = rf(ctx, dataset) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// Result provides a mock function with given fields: ctx +func (_m *Service) Result(ctx context.Context) ([]byte, error) { + ret := _m.Called(ctx) + + if len(ret) == 0 { + panic("no return value specified for Result") + } + + var r0 []byte + var r1 error + if rf, ok := ret.Get(0).(func(context.Context) ([]byte, error)); ok { + return rf(ctx) + } + if rf, ok := ret.Get(0).(func(context.Context) []byte); ok { + r0 = rf(ctx) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]byte) + } + } + + if rf, ok := ret.Get(1).(func(context.Context) error); ok { + r1 = rf(ctx) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// NewService creates a new instance of Service. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewService(t interface { + mock.TestingT + Cleanup(func()) +}) *Service { + mock := &Service{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/agent/mocks/auth.go b/agent/mocks/auth.go index 476d7ac7..718c9c14 100644 --- a/agent/mocks/auth.go +++ b/agent/mocks/auth.go @@ -1,4 +1,4 @@ -// Code generated by mockery v2.43.2. DO NOT EDIT. +// Code generated by mockery v2.42.3. DO NOT EDIT. // Copyright (c) Ultraviolet // SPDX-License-Identifier: Apache-2.0 diff --git a/agent/service.go b/agent/service.go index d53e9a72..f833c0e8 100644 --- a/agent/service.go +++ b/agent/service.go @@ -34,21 +34,23 @@ var ( // when accessing a protected resource. ErrUnauthorizedAccess = errors.New("missing or invalid credentials provided") // errUndeclaredAlgorithm indicates algorithm was not declared in computation manifest. - errUndeclaredDataset = errors.New("dataset not declared in computation manifest") + ErrUndeclaredDataset = errors.New("dataset not declared in computation manifest") // errAllManifestItemsReceived indicates no new computation manifest items expected. - errAllManifestItemsReceived = errors.New("all expected manifest Items have been received") + ErrAllManifestItemsReceived = errors.New("all expected manifest Items have been received") // errUndeclaredConsumer indicates the consumer requesting results in not declared in computation manifest. - errUndeclaredConsumer = errors.New("result consumer is undeclared in computation manifest") + ErrUndeclaredConsumer = errors.New("result consumer is undeclared in computation manifest") // errResultsNotReady indicates the computation results are not ready. - errResultsNotReady = errors.New("computation results are not yet ready") + ErrResultsNotReady = errors.New("computation results are not yet ready") // errStateNotReady agent received a request in the wrong state. - errStateNotReady = errors.New("agent not expecting this operation in the current state") + ErrStateNotReady = errors.New("agent not expecting this operation in the current state") // errHashMismatch provided algorithm/dataset does not match hash in manifest. - errHashMismatch = errors.New("malformed data, hash does not match manifest") + ErrHashMismatch = errors.New("malformed data, hash does not match manifest") ) // Service specifies an API that must be fullfiled by the domain service // implementation, and all of its decorators (e.g. logging & metrics). +// +//go:generate mockery --name Service --output=mocks --filename agent.go --quiet --note "Copyright (c) Ultraviolet \n // SPDX-License-Identifier: Apache-2.0" type Service interface { Algo(ctx context.Context, algorithm Algorithm) error Data(ctx context.Context, dataset Dataset) error @@ -92,16 +94,16 @@ func New(ctx context.Context, logger *slog.Logger, eventSvc events.Service, cmp func (as *agentService) Algo(ctx context.Context, algorithm Algorithm) error { if as.sm.GetState() != receivingAlgorithm { - return errStateNotReady + return ErrStateNotReady } if as.algorithm != "" { - return errAllManifestItemsReceived + return ErrAllManifestItemsReceived } hash := sha3.Sum256(algorithm.Algorithm) if hash != as.computation.Algorithm.Hash { - return errHashMismatch + return ErrHashMismatch } f, err := os.CreateTemp("", "algorithm") @@ -132,21 +134,21 @@ func (as *agentService) Algo(ctx context.Context, algorithm Algorithm) error { func (as *agentService) Data(ctx context.Context, dataset Dataset) error { if as.sm.GetState() != receivingData { - return errStateNotReady + return ErrStateNotReady } if len(as.computation.Datasets) == 0 { - return errAllManifestItemsReceived + return ErrAllManifestItemsReceived } hash := sha3.Sum256(dataset.Dataset) index, ok := IndexFromContext(ctx) if !ok { - return errUndeclaredDataset + return ErrUndeclaredDataset } if hash != as.computation.Datasets[index].Hash { - return errHashMismatch + return ErrHashMismatch } as.computation.Datasets = slices.Delete(as.computation.Datasets, index, index+1) @@ -173,14 +175,14 @@ func (as *agentService) Data(ctx context.Context, dataset Dataset) error { func (as *agentService) Result(ctx context.Context) ([]byte, error) { if as.sm.GetState() != resultsReady { - return []byte{}, errResultsNotReady + return []byte{}, ErrResultsNotReady } if len(as.computation.ResultConsumers) == 0 { - return []byte{}, errAllManifestItemsReceived + return []byte{}, ErrAllManifestItemsReceived } index, ok := IndexFromContext(ctx) if !ok { - return []byte{}, errUndeclaredConsumer + return []byte{}, ErrUndeclaredConsumer } as.computation.ResultConsumers = slices.Delete(as.computation.ResultConsumers, index, index+1) diff --git a/pkg/sdk/agent_test.go b/pkg/sdk/agent_test.go new file mode 100644 index 00000000..57ed7315 --- /dev/null +++ b/pkg/sdk/agent_test.go @@ -0,0 +1,435 @@ +// Copyright (c) Ultraviolet +// SPDX-License-Identifier: Apache-2.0 + +package sdk_test + +import ( + "context" + "crypto/ecdsa" + "crypto/ed25519" + "crypto/elliptic" + "crypto/rand" + "crypto/rsa" + "crypto/x509" + "errors" + "os" + "testing" + + mglog "github.com/absmach/magistrala/logger" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" + "github.com/ultravioletrs/cocos/agent" + "github.com/ultravioletrs/cocos/pkg/sdk" + "golang.org/x/crypto/sha3" + "google.golang.org/grpc" + "google.golang.org/grpc/status" +) + +var ( + algoPath = "../../test/manual/algo/lin_reg.py" + dataPath = "../../test/manual/data/iris.csv" + + errInappropriateIoctl = errors.New("inappropriate ioctl for device") +) + +func TestAlgo(t *testing.T) { + logger, err := mglog.New(os.Stdout, "info") + require.NoError(t, err) + + conn, err := grpc.DialContext(context.Background(), "bufnet", grpc.WithContextDialer(bufDialer), grpc.WithInsecure()) + if err != nil { + t.Fatalf("Failed to dial bufnet: %v", err) + } + defer conn.Close() + + client := agent.NewAgentServiceClient(conn) + + sdk := sdk.NewAgentSDK(logger, client) + algo, err := os.ReadFile(algoPath) + require.NoError(t, err) + + algoHash := sha3.Sum256(algo) + + algorithmProviderKey, algorithmProviderPubKey := generateKeys(t, "ed25519") + + algoProvider1Key, algoProvider1PubKey := generateKeys(t, "ed25519") + + algorithm := agent.Algorithm{ + Algorithm: algo, + Hash: algoHash, + UserKey: algorithmProviderPubKey, + } + + cases := []struct { + name string + err error + algo agent.Algorithm + userKey any + }{ + { + name: "Test Algo successfully", + algo: algorithm, + userKey: algorithmProviderKey, + err: nil, + }, + { + name: "hash mismatch", + algo: agent.Algorithm{ + Algorithm: algo, + Hash: sha3.Sum256([]byte("algo")), + UserKey: algoProvider1PubKey, + }, + userKey: algoProvider1Key, + err: errInappropriateIoctl, + }, + { + name: "no manifest expected", + algo: agent.Algorithm{ + Algorithm: algo, + UserKey: algorithmProviderPubKey, + Hash: algoHash, + }, + userKey: algorithmProviderKey, + err: errInappropriateIoctl, + }, + { + name: "state not ready", + algo: agent.Algorithm{ + Algorithm: algo, + UserKey: algorithmProviderPubKey, + Hash: algoHash, + }, + userKey: algorithmProviderKey, + err: errInappropriateIoctl, + }, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + svcCall := svc.On("Algo", mock.Anything, mock.Anything).Return(tc.err) + err = sdk.Algo(context.Background(), tc.algo, tc.userKey) + + st, _ := status.FromError(err) + + if tc.err != nil { + if st.Message() != tc.err.Error() { + t.Errorf("%s : Expected error message %q, but got %q", tc.name, tc.err.Error(), st.Message()) + } + } + + svcCall.Unset() + }) + } +} + +func TestData(t *testing.T) { + logger, err := mglog.New(os.Stdout, "info") + require.NoError(t, err) + + conn, err := grpc.DialContext(context.Background(), "bufnet", grpc.WithContextDialer(bufDialer), grpc.WithInsecure()) + if err != nil { + t.Fatalf("Failed to dial bufnet: %v", err) + } + defer conn.Close() + + client := agent.NewAgentServiceClient(conn) + + sdk := sdk.NewAgentSDK(logger, client) + + data, err := os.ReadFile(dataPath) + require.NoError(t, err) + + dataHash := sha3.Sum256(data) + + dataProviderKey, dataProviderPubKey := generateKeys(t, "ecdsa") + + dataProvider1Key, dataProvider1PubKey := generateKeys(t, "ed25519") + + dataset := agent.Dataset{ + Hash: dataHash, + Dataset: data, + UserKey: dataProviderPubKey, + } + + cases := []struct { + name string + data agent.Dataset + userKey any + svcErr error + }{ + { + name: "Test data successfully", + data: dataset, + userKey: dataProviderKey, + }, + { + name: "undeclared dataset", + data: agent.Dataset{ + Dataset: data, + UserKey: dataProvider1PubKey, + Hash: dataHash, + }, + userKey: dataProvider1Key, + svcErr: errInappropriateIoctl, + }, + { + name: "hash mismatch", + data: agent.Dataset{ + Dataset: data, + UserKey: dataProvider1PubKey, + Hash: dataHash, + }, + userKey: dataProvider1Key, + svcErr: errInappropriateIoctl, + }, + { + name: "all manifest items received", + data: agent.Dataset{ + Dataset: data, + UserKey: dataProvider1PubKey, + Hash: dataHash, + }, + userKey: dataProvider1Key, + svcErr: errInappropriateIoctl, + }, + { + name: "missing dataset file", + data: agent.Dataset{ + UserKey: dataProvider1PubKey, + Hash: dataHash, + }, + userKey: dataProvider1Key, + svcErr: errors.New("dataset CSV file is required"), + }, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + dataCall := svc.On("Data", mock.Anything, mock.Anything).Return(tc.svcErr) + + err = sdk.Data(context.Background(), tc.data, tc.userKey) + + st, _ := status.FromError(err) + + if tc.svcErr != nil { + if st.Message() != tc.svcErr.Error() { + t.Errorf("%s: Expected error message %q, but got %q", tc.name, tc.svcErr.Error(), st.Message()) + } + } + + dataCall.Unset() + }) + } +} + +func TestResult(t *testing.T) { + logger, err := mglog.New(os.Stdout, "info") + require.NoError(t, err) + + conn, err := grpc.DialContext(context.Background(), "bufnet", grpc.WithContextDialer(bufDialer), grpc.WithInsecure()) + if err != nil { + t.Fatalf("Failed to dial bufnet: %v", err) + } + defer conn.Close() + + client := agent.NewAgentServiceClient(conn) + + sdk := sdk.NewAgentSDK(logger, client) + + resultConsumerKey, _ := generateKeys(t, "ecdsa") + resultConsumer1Key, _ := generateKeys(t, "ed25519") + + response := &agent.ResultResponse{ + File: []byte{ + 0x01, 0x02, 0x03, 0x04, + 0x05, 0x06, 0x07, 0x08, + }, + } + + cases := []struct { + name string + userKey any + response *agent.ResultResponse + svcRes []byte + err error + }{ + { + name: "Test result successfully", + userKey: resultConsumerKey, + response: response, + svcRes: response.File, + err: nil, + }, + { + name: "Test result successfully with ed25519 key type", + userKey: resultConsumer1Key, + response: response, + svcRes: response.File, + err: nil, + }, + { + name: "Results not ready", + userKey: resultConsumer1Key, + response: &agent.ResultResponse{ + File: []byte(nil), + }, + svcRes: nil, + err: agent.ErrResultsNotReady, + }, + { + name: "All manifest items received", + userKey: resultConsumer1Key, + response: &agent.ResultResponse{ + File: []byte(nil), + }, + svcRes: nil, + err: agent.ErrAllManifestItemsReceived, + }, + { + name: "Undeclared consumer", + userKey: resultConsumer1Key, + response: &agent.ResultResponse{ + File: []byte(nil), + }, + svcRes: nil, + err: agent.ErrUndeclaredConsumer, + }, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + svcCall := svc.On("Result", mock.Anything, mock.Anything).Return(tc.svcRes, tc.err) + res, err := sdk.Result(context.Background(), tc.userKey) + + st, ok := status.FromError(err) + if !ok { + t.Fatalf("Expected gRPC status error, but got: %v", err) + } + + if tc.err != nil { + if st.Message() != tc.err.Error() { + t.Errorf("%s: Expected error message %q, but got %q", tc.name, tc.err.Error(), st.Message()) + } + } + assert.Equal(t, tc.response.File, res, tc.name) + + svcCall.Unset() + }) + } +} + +func TestAttestation(t *testing.T) { + logger, err := mglog.New(os.Stdout, "info") + require.NoError(t, err) + + resultConsumerKey, _ := generateKeys(t, "rsa") + resultConsumer1Key, _ := generateKeys(t, "ed25519") + + reportData := make([]byte, 64) + report := []byte{ + 0x01, 0x02, 0x03, 0x04, + 0x05, 0x06, 0x07, 0x08, + } + + conn, err := grpc.DialContext(context.Background(), "bufnet", grpc.WithContextDialer(bufDialer), grpc.WithInsecure()) + if err != nil { + t.Fatalf("Failed to dial bufnet: %v", err) + } + defer conn.Close() + + client := agent.NewAgentServiceClient(conn) + + sdk := sdk.NewAgentSDK(logger, client) + + _, err = rand.Read(reportData) + require.NoError(t, err) + + cases := []struct { + name string + userKey any + reportData [agent.ReportDataSize]byte + response *agent.AttestationResponse + svcRes []byte + err error + }{ + { + name: "fetch attestation report successfully", + userKey: resultConsumerKey, + reportData: [agent.ReportDataSize]byte(reportData), + response: &agent.AttestationResponse{ + File: report, + }, + svcRes: report, + err: nil, + }, + { + name: "fetch attestation report with different key type", + userKey: resultConsumer1Key, + reportData: [agent.ReportDataSize]byte(reportData), + response: &agent.AttestationResponse{ + File: report, + }, + svcRes: report, + err: nil, + }, + { + name: "failed to fetch attestation report", + userKey: resultConsumerKey, + reportData: [agent.ReportDataSize]byte(reportData), + response: &agent.AttestationResponse{ + File: nil, + }, + err: nil, + }, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + svcCall := svc.On("Attestation", mock.Anything, mock.Anything).Return(tc.svcRes, tc.err) + + res, err := sdk.Attestation(context.Background(), tc.reportData) + + st, ok := status.FromError(err) + if !ok { + t.Fatalf("Expected gRPC status error, but got: %v", err) + } + + if tc.err != nil { + if st.Message() != tc.err.Error() { + t.Errorf("%s: Expected error message %q, but got %q", tc.name, tc.err.Error(), st.Message()) + } + } + + assert.Equal(t, tc.response.File, res, tc.name) + + svcCall.Unset() + }) + } +} + +func generateKeys(t *testing.T, keyType string) (priv any, pub []byte) { + switch keyType { + case "ecdsa": + privEcdsaKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + require.NoError(t, err) + + pubKeyBytes, err := x509.MarshalPKIXPublicKey(&privEcdsaKey.PublicKey) + require.NoError(t, err) + return privEcdsaKey, pubKeyBytes + + case "ed25519": + pubEd25519Key, privEd25519Key, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + + pubKey, err := x509.MarshalPKIXPublicKey(pubEd25519Key) + require.NoError(t, err) + return privEd25519Key, pubKey + + default: + privKey, err := rsa.GenerateKey(rand.Reader, 2048) + require.NoError(t, err) + + pubKeyBytes, err := x509.MarshalPKIXPublicKey(&privKey.PublicKey) + require.NoError(t, err) + return privKey, pubKeyBytes + } +} diff --git a/pkg/sdk/setup_test.go b/pkg/sdk/setup_test.go new file mode 100644 index 00000000..c42dd596 --- /dev/null +++ b/pkg/sdk/setup_test.go @@ -0,0 +1,48 @@ +// Copyright (c) Ultraviolet +// SPDX-License-Identifier: Apache-2.0 + +package sdk_test + +import ( + "context" + "fmt" + "net" + "os" + "testing" + + "github.com/ultravioletrs/cocos/agent" + agentgrpc "github.com/ultravioletrs/cocos/agent/api/grpc" + "github.com/ultravioletrs/cocos/agent/mocks" + "google.golang.org/grpc" + "google.golang.org/grpc/test/bufconn" +) + +const bufSize = 1024 * 1024 + +var ( + lis *bufconn.Listener + svc = &mocks.Service{} +) + +func TestMain(m *testing.M) { + lis = bufconn.Listen(bufSize) + s := grpc.NewServer() + + agent.RegisterAgentServiceServer(s, agentgrpc.NewServer(svc)) + + go func() { + if err := s.Serve(lis); err != nil { + fmt.Println("Server exited with error:", err) + } + }() + + code := m.Run() + + s.Stop() + lis.Close() + os.Exit(code) +} + +func bufDialer(context.Context, string) (net.Conn, error) { + return lis.Dial() +}