mirror of
https://github.com/ultravioletrs/cocos.git
synced 2026-06-22 20:00:18 +00:00
COCOS-144 - Add Agent SDK Tests (#167)
* add tests and mocks Signed-off-by: WashingtonKK <washingtonkigan@gmail.com> fix ci Signed-off-by: WashingtonKK <washingtonkigan@gmail.com> update test Signed-off-by: WashingtonKK <washingtonkigan@gmail.com> fix(agent/grpc): revert change Signed-off-by: WashingtonKK <washingtonkigan@gmail.com> fix ci Signed-off-by: WashingtonKK <washingtonkigan@gmail.com> * refactor attestation and report tests Signed-off-by: WashingtonKK <washingtonkigan@gmail.com> refactor tests Signed-off-by: WashingtonKK <washingtonkigan@gmail.com> remove commented code Signed-off-by: WashingtonKK <washingtonkigan@gmail.com> remove comment Signed-off-by: WashingtonKK <washingtonkigan@gmail.com> remove comments * add test cases Signed-off-by: WashingtonKK <washingtonkigan@gmail.com> export agent errors Signed-off-by: WashingtonKK <washingtonkigan@gmail.com> remove comm Signed-off-by: WashingtonKK <washingtonkigan@gmail.com> * fix tests Signed-off-by: WashingtonKK <washingtonkigan@gmail.com> --------- Signed-off-by: WashingtonKK <washingtonkigan@gmail.com>
This commit is contained in:
committed by
GitHub
parent
2ceb1c3562
commit
d76074ae41
@@ -0,0 +1,129 @@
|
||||
// Code generated by mockery v2.42.3. DO NOT EDIT.
|
||||
|
||||
// Copyright (c) Ultraviolet
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package mocks
|
||||
|
||||
import (
|
||||
context "context"
|
||||
|
||||
agent "github.com/ultravioletrs/cocos/agent"
|
||||
|
||||
mock "github.com/stretchr/testify/mock"
|
||||
)
|
||||
|
||||
// Service is an autogenerated mock type for the Service type
|
||||
type Service struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
// Algo provides a mock function with given fields: ctx, algorithm
|
||||
func (_m *Service) Algo(ctx context.Context, algorithm agent.Algorithm) error {
|
||||
ret := _m.Called(ctx, algorithm)
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for Algo")
|
||||
}
|
||||
|
||||
var r0 error
|
||||
if rf, ok := ret.Get(0).(func(context.Context, agent.Algorithm) error); ok {
|
||||
r0 = rf(ctx, algorithm)
|
||||
} else {
|
||||
r0 = ret.Error(0)
|
||||
}
|
||||
|
||||
return r0
|
||||
}
|
||||
|
||||
// Attestation provides a mock function with given fields: ctx, reportData
|
||||
func (_m *Service) Attestation(ctx context.Context, reportData [64]byte) ([]byte, error) {
|
||||
ret := _m.Called(ctx, reportData)
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for Attestation")
|
||||
}
|
||||
|
||||
var r0 []byte
|
||||
var r1 error
|
||||
if rf, ok := ret.Get(0).(func(context.Context, [64]byte) ([]byte, error)); ok {
|
||||
return rf(ctx, reportData)
|
||||
}
|
||||
if rf, ok := ret.Get(0).(func(context.Context, [64]byte) []byte); ok {
|
||||
r0 = rf(ctx, reportData)
|
||||
} else {
|
||||
if ret.Get(0) != nil {
|
||||
r0 = ret.Get(0).([]byte)
|
||||
}
|
||||
}
|
||||
|
||||
if rf, ok := ret.Get(1).(func(context.Context, [64]byte) error); ok {
|
||||
r1 = rf(ctx, reportData)
|
||||
} else {
|
||||
r1 = ret.Error(1)
|
||||
}
|
||||
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
// Data provides a mock function with given fields: ctx, dataset
|
||||
func (_m *Service) Data(ctx context.Context, dataset agent.Dataset) error {
|
||||
ret := _m.Called(ctx, dataset)
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for Data")
|
||||
}
|
||||
|
||||
var r0 error
|
||||
if rf, ok := ret.Get(0).(func(context.Context, agent.Dataset) error); ok {
|
||||
r0 = rf(ctx, dataset)
|
||||
} else {
|
||||
r0 = ret.Error(0)
|
||||
}
|
||||
|
||||
return r0
|
||||
}
|
||||
|
||||
// Result provides a mock function with given fields: ctx
|
||||
func (_m *Service) Result(ctx context.Context) ([]byte, error) {
|
||||
ret := _m.Called(ctx)
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for Result")
|
||||
}
|
||||
|
||||
var r0 []byte
|
||||
var r1 error
|
||||
if rf, ok := ret.Get(0).(func(context.Context) ([]byte, error)); ok {
|
||||
return rf(ctx)
|
||||
}
|
||||
if rf, ok := ret.Get(0).(func(context.Context) []byte); ok {
|
||||
r0 = rf(ctx)
|
||||
} else {
|
||||
if ret.Get(0) != nil {
|
||||
r0 = ret.Get(0).([]byte)
|
||||
}
|
||||
}
|
||||
|
||||
if rf, ok := ret.Get(1).(func(context.Context) error); ok {
|
||||
r1 = rf(ctx)
|
||||
} else {
|
||||
r1 = ret.Error(1)
|
||||
}
|
||||
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
// NewService creates a new instance of Service. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations.
|
||||
// The first argument is typically a *testing.T value.
|
||||
func NewService(t interface {
|
||||
mock.TestingT
|
||||
Cleanup(func())
|
||||
}) *Service {
|
||||
mock := &Service{}
|
||||
mock.Mock.Test(t)
|
||||
|
||||
t.Cleanup(func() { mock.AssertExpectations(t) })
|
||||
|
||||
return mock
|
||||
}
|
||||
+1
-1
@@ -1,4 +1,4 @@
|
||||
// Code generated by mockery v2.43.2. DO NOT EDIT.
|
||||
// Code generated by mockery v2.42.3. DO NOT EDIT.
|
||||
|
||||
// Copyright (c) Ultraviolet
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
+18
-16
@@ -34,21 +34,23 @@ var (
|
||||
// when accessing a protected resource.
|
||||
ErrUnauthorizedAccess = errors.New("missing or invalid credentials provided")
|
||||
// errUndeclaredAlgorithm indicates algorithm was not declared in computation manifest.
|
||||
errUndeclaredDataset = errors.New("dataset not declared in computation manifest")
|
||||
ErrUndeclaredDataset = errors.New("dataset not declared in computation manifest")
|
||||
// errAllManifestItemsReceived indicates no new computation manifest items expected.
|
||||
errAllManifestItemsReceived = errors.New("all expected manifest Items have been received")
|
||||
ErrAllManifestItemsReceived = errors.New("all expected manifest Items have been received")
|
||||
// errUndeclaredConsumer indicates the consumer requesting results in not declared in computation manifest.
|
||||
errUndeclaredConsumer = errors.New("result consumer is undeclared in computation manifest")
|
||||
ErrUndeclaredConsumer = errors.New("result consumer is undeclared in computation manifest")
|
||||
// errResultsNotReady indicates the computation results are not ready.
|
||||
errResultsNotReady = errors.New("computation results are not yet ready")
|
||||
ErrResultsNotReady = errors.New("computation results are not yet ready")
|
||||
// errStateNotReady agent received a request in the wrong state.
|
||||
errStateNotReady = errors.New("agent not expecting this operation in the current state")
|
||||
ErrStateNotReady = errors.New("agent not expecting this operation in the current state")
|
||||
// errHashMismatch provided algorithm/dataset does not match hash in manifest.
|
||||
errHashMismatch = errors.New("malformed data, hash does not match manifest")
|
||||
ErrHashMismatch = errors.New("malformed data, hash does not match manifest")
|
||||
)
|
||||
|
||||
// Service specifies an API that must be fullfiled by the domain service
|
||||
// implementation, and all of its decorators (e.g. logging & metrics).
|
||||
//
|
||||
//go:generate mockery --name Service --output=mocks --filename agent.go --quiet --note "Copyright (c) Ultraviolet \n // SPDX-License-Identifier: Apache-2.0"
|
||||
type Service interface {
|
||||
Algo(ctx context.Context, algorithm Algorithm) error
|
||||
Data(ctx context.Context, dataset Dataset) error
|
||||
@@ -92,16 +94,16 @@ func New(ctx context.Context, logger *slog.Logger, eventSvc events.Service, cmp
|
||||
|
||||
func (as *agentService) Algo(ctx context.Context, algorithm Algorithm) error {
|
||||
if as.sm.GetState() != receivingAlgorithm {
|
||||
return errStateNotReady
|
||||
return ErrStateNotReady
|
||||
}
|
||||
if as.algorithm != "" {
|
||||
return errAllManifestItemsReceived
|
||||
return ErrAllManifestItemsReceived
|
||||
}
|
||||
|
||||
hash := sha3.Sum256(algorithm.Algorithm)
|
||||
|
||||
if hash != as.computation.Algorithm.Hash {
|
||||
return errHashMismatch
|
||||
return ErrHashMismatch
|
||||
}
|
||||
|
||||
f, err := os.CreateTemp("", "algorithm")
|
||||
@@ -132,21 +134,21 @@ func (as *agentService) Algo(ctx context.Context, algorithm Algorithm) error {
|
||||
|
||||
func (as *agentService) Data(ctx context.Context, dataset Dataset) error {
|
||||
if as.sm.GetState() != receivingData {
|
||||
return errStateNotReady
|
||||
return ErrStateNotReady
|
||||
}
|
||||
if len(as.computation.Datasets) == 0 {
|
||||
return errAllManifestItemsReceived
|
||||
return ErrAllManifestItemsReceived
|
||||
}
|
||||
|
||||
hash := sha3.Sum256(dataset.Dataset)
|
||||
|
||||
index, ok := IndexFromContext(ctx)
|
||||
if !ok {
|
||||
return errUndeclaredDataset
|
||||
return ErrUndeclaredDataset
|
||||
}
|
||||
|
||||
if hash != as.computation.Datasets[index].Hash {
|
||||
return errHashMismatch
|
||||
return ErrHashMismatch
|
||||
}
|
||||
as.computation.Datasets = slices.Delete(as.computation.Datasets, index, index+1)
|
||||
|
||||
@@ -173,14 +175,14 @@ func (as *agentService) Data(ctx context.Context, dataset Dataset) error {
|
||||
|
||||
func (as *agentService) Result(ctx context.Context) ([]byte, error) {
|
||||
if as.sm.GetState() != resultsReady {
|
||||
return []byte{}, errResultsNotReady
|
||||
return []byte{}, ErrResultsNotReady
|
||||
}
|
||||
if len(as.computation.ResultConsumers) == 0 {
|
||||
return []byte{}, errAllManifestItemsReceived
|
||||
return []byte{}, ErrAllManifestItemsReceived
|
||||
}
|
||||
index, ok := IndexFromContext(ctx)
|
||||
if !ok {
|
||||
return []byte{}, errUndeclaredConsumer
|
||||
return []byte{}, ErrUndeclaredConsumer
|
||||
}
|
||||
as.computation.ResultConsumers = slices.Delete(as.computation.ResultConsumers, index, index+1)
|
||||
|
||||
|
||||
@@ -0,0 +1,435 @@
|
||||
// Copyright (c) Ultraviolet
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package sdk_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/ecdsa"
|
||||
"crypto/ed25519"
|
||||
"crypto/elliptic"
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"crypto/x509"
|
||||
"errors"
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
mglog "github.com/absmach/magistrala/logger"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/mock"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/ultravioletrs/cocos/agent"
|
||||
"github.com/ultravioletrs/cocos/pkg/sdk"
|
||||
"golang.org/x/crypto/sha3"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/status"
|
||||
)
|
||||
|
||||
var (
|
||||
algoPath = "../../test/manual/algo/lin_reg.py"
|
||||
dataPath = "../../test/manual/data/iris.csv"
|
||||
|
||||
errInappropriateIoctl = errors.New("inappropriate ioctl for device")
|
||||
)
|
||||
|
||||
func TestAlgo(t *testing.T) {
|
||||
logger, err := mglog.New(os.Stdout, "info")
|
||||
require.NoError(t, err)
|
||||
|
||||
conn, err := grpc.DialContext(context.Background(), "bufnet", grpc.WithContextDialer(bufDialer), grpc.WithInsecure())
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to dial bufnet: %v", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
client := agent.NewAgentServiceClient(conn)
|
||||
|
||||
sdk := sdk.NewAgentSDK(logger, client)
|
||||
algo, err := os.ReadFile(algoPath)
|
||||
require.NoError(t, err)
|
||||
|
||||
algoHash := sha3.Sum256(algo)
|
||||
|
||||
algorithmProviderKey, algorithmProviderPubKey := generateKeys(t, "ed25519")
|
||||
|
||||
algoProvider1Key, algoProvider1PubKey := generateKeys(t, "ed25519")
|
||||
|
||||
algorithm := agent.Algorithm{
|
||||
Algorithm: algo,
|
||||
Hash: algoHash,
|
||||
UserKey: algorithmProviderPubKey,
|
||||
}
|
||||
|
||||
cases := []struct {
|
||||
name string
|
||||
err error
|
||||
algo agent.Algorithm
|
||||
userKey any
|
||||
}{
|
||||
{
|
||||
name: "Test Algo successfully",
|
||||
algo: algorithm,
|
||||
userKey: algorithmProviderKey,
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
name: "hash mismatch",
|
||||
algo: agent.Algorithm{
|
||||
Algorithm: algo,
|
||||
Hash: sha3.Sum256([]byte("algo")),
|
||||
UserKey: algoProvider1PubKey,
|
||||
},
|
||||
userKey: algoProvider1Key,
|
||||
err: errInappropriateIoctl,
|
||||
},
|
||||
{
|
||||
name: "no manifest expected",
|
||||
algo: agent.Algorithm{
|
||||
Algorithm: algo,
|
||||
UserKey: algorithmProviderPubKey,
|
||||
Hash: algoHash,
|
||||
},
|
||||
userKey: algorithmProviderKey,
|
||||
err: errInappropriateIoctl,
|
||||
},
|
||||
{
|
||||
name: "state not ready",
|
||||
algo: agent.Algorithm{
|
||||
Algorithm: algo,
|
||||
UserKey: algorithmProviderPubKey,
|
||||
Hash: algoHash,
|
||||
},
|
||||
userKey: algorithmProviderKey,
|
||||
err: errInappropriateIoctl,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
svcCall := svc.On("Algo", mock.Anything, mock.Anything).Return(tc.err)
|
||||
err = sdk.Algo(context.Background(), tc.algo, tc.userKey)
|
||||
|
||||
st, _ := status.FromError(err)
|
||||
|
||||
if tc.err != nil {
|
||||
if st.Message() != tc.err.Error() {
|
||||
t.Errorf("%s : Expected error message %q, but got %q", tc.name, tc.err.Error(), st.Message())
|
||||
}
|
||||
}
|
||||
|
||||
svcCall.Unset()
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestData(t *testing.T) {
|
||||
logger, err := mglog.New(os.Stdout, "info")
|
||||
require.NoError(t, err)
|
||||
|
||||
conn, err := grpc.DialContext(context.Background(), "bufnet", grpc.WithContextDialer(bufDialer), grpc.WithInsecure())
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to dial bufnet: %v", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
client := agent.NewAgentServiceClient(conn)
|
||||
|
||||
sdk := sdk.NewAgentSDK(logger, client)
|
||||
|
||||
data, err := os.ReadFile(dataPath)
|
||||
require.NoError(t, err)
|
||||
|
||||
dataHash := sha3.Sum256(data)
|
||||
|
||||
dataProviderKey, dataProviderPubKey := generateKeys(t, "ecdsa")
|
||||
|
||||
dataProvider1Key, dataProvider1PubKey := generateKeys(t, "ed25519")
|
||||
|
||||
dataset := agent.Dataset{
|
||||
Hash: dataHash,
|
||||
Dataset: data,
|
||||
UserKey: dataProviderPubKey,
|
||||
}
|
||||
|
||||
cases := []struct {
|
||||
name string
|
||||
data agent.Dataset
|
||||
userKey any
|
||||
svcErr error
|
||||
}{
|
||||
{
|
||||
name: "Test data successfully",
|
||||
data: dataset,
|
||||
userKey: dataProviderKey,
|
||||
},
|
||||
{
|
||||
name: "undeclared dataset",
|
||||
data: agent.Dataset{
|
||||
Dataset: data,
|
||||
UserKey: dataProvider1PubKey,
|
||||
Hash: dataHash,
|
||||
},
|
||||
userKey: dataProvider1Key,
|
||||
svcErr: errInappropriateIoctl,
|
||||
},
|
||||
{
|
||||
name: "hash mismatch",
|
||||
data: agent.Dataset{
|
||||
Dataset: data,
|
||||
UserKey: dataProvider1PubKey,
|
||||
Hash: dataHash,
|
||||
},
|
||||
userKey: dataProvider1Key,
|
||||
svcErr: errInappropriateIoctl,
|
||||
},
|
||||
{
|
||||
name: "all manifest items received",
|
||||
data: agent.Dataset{
|
||||
Dataset: data,
|
||||
UserKey: dataProvider1PubKey,
|
||||
Hash: dataHash,
|
||||
},
|
||||
userKey: dataProvider1Key,
|
||||
svcErr: errInappropriateIoctl,
|
||||
},
|
||||
{
|
||||
name: "missing dataset file",
|
||||
data: agent.Dataset{
|
||||
UserKey: dataProvider1PubKey,
|
||||
Hash: dataHash,
|
||||
},
|
||||
userKey: dataProvider1Key,
|
||||
svcErr: errors.New("dataset CSV file is required"),
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
dataCall := svc.On("Data", mock.Anything, mock.Anything).Return(tc.svcErr)
|
||||
|
||||
err = sdk.Data(context.Background(), tc.data, tc.userKey)
|
||||
|
||||
st, _ := status.FromError(err)
|
||||
|
||||
if tc.svcErr != nil {
|
||||
if st.Message() != tc.svcErr.Error() {
|
||||
t.Errorf("%s: Expected error message %q, but got %q", tc.name, tc.svcErr.Error(), st.Message())
|
||||
}
|
||||
}
|
||||
|
||||
dataCall.Unset()
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestResult(t *testing.T) {
|
||||
logger, err := mglog.New(os.Stdout, "info")
|
||||
require.NoError(t, err)
|
||||
|
||||
conn, err := grpc.DialContext(context.Background(), "bufnet", grpc.WithContextDialer(bufDialer), grpc.WithInsecure())
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to dial bufnet: %v", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
client := agent.NewAgentServiceClient(conn)
|
||||
|
||||
sdk := sdk.NewAgentSDK(logger, client)
|
||||
|
||||
resultConsumerKey, _ := generateKeys(t, "ecdsa")
|
||||
resultConsumer1Key, _ := generateKeys(t, "ed25519")
|
||||
|
||||
response := &agent.ResultResponse{
|
||||
File: []byte{
|
||||
0x01, 0x02, 0x03, 0x04,
|
||||
0x05, 0x06, 0x07, 0x08,
|
||||
},
|
||||
}
|
||||
|
||||
cases := []struct {
|
||||
name string
|
||||
userKey any
|
||||
response *agent.ResultResponse
|
||||
svcRes []byte
|
||||
err error
|
||||
}{
|
||||
{
|
||||
name: "Test result successfully",
|
||||
userKey: resultConsumerKey,
|
||||
response: response,
|
||||
svcRes: response.File,
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
name: "Test result successfully with ed25519 key type",
|
||||
userKey: resultConsumer1Key,
|
||||
response: response,
|
||||
svcRes: response.File,
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
name: "Results not ready",
|
||||
userKey: resultConsumer1Key,
|
||||
response: &agent.ResultResponse{
|
||||
File: []byte(nil),
|
||||
},
|
||||
svcRes: nil,
|
||||
err: agent.ErrResultsNotReady,
|
||||
},
|
||||
{
|
||||
name: "All manifest items received",
|
||||
userKey: resultConsumer1Key,
|
||||
response: &agent.ResultResponse{
|
||||
File: []byte(nil),
|
||||
},
|
||||
svcRes: nil,
|
||||
err: agent.ErrAllManifestItemsReceived,
|
||||
},
|
||||
{
|
||||
name: "Undeclared consumer",
|
||||
userKey: resultConsumer1Key,
|
||||
response: &agent.ResultResponse{
|
||||
File: []byte(nil),
|
||||
},
|
||||
svcRes: nil,
|
||||
err: agent.ErrUndeclaredConsumer,
|
||||
},
|
||||
}
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
svcCall := svc.On("Result", mock.Anything, mock.Anything).Return(tc.svcRes, tc.err)
|
||||
res, err := sdk.Result(context.Background(), tc.userKey)
|
||||
|
||||
st, ok := status.FromError(err)
|
||||
if !ok {
|
||||
t.Fatalf("Expected gRPC status error, but got: %v", err)
|
||||
}
|
||||
|
||||
if tc.err != nil {
|
||||
if st.Message() != tc.err.Error() {
|
||||
t.Errorf("%s: Expected error message %q, but got %q", tc.name, tc.err.Error(), st.Message())
|
||||
}
|
||||
}
|
||||
assert.Equal(t, tc.response.File, res, tc.name)
|
||||
|
||||
svcCall.Unset()
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAttestation(t *testing.T) {
|
||||
logger, err := mglog.New(os.Stdout, "info")
|
||||
require.NoError(t, err)
|
||||
|
||||
resultConsumerKey, _ := generateKeys(t, "rsa")
|
||||
resultConsumer1Key, _ := generateKeys(t, "ed25519")
|
||||
|
||||
reportData := make([]byte, 64)
|
||||
report := []byte{
|
||||
0x01, 0x02, 0x03, 0x04,
|
||||
0x05, 0x06, 0x07, 0x08,
|
||||
}
|
||||
|
||||
conn, err := grpc.DialContext(context.Background(), "bufnet", grpc.WithContextDialer(bufDialer), grpc.WithInsecure())
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to dial bufnet: %v", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
client := agent.NewAgentServiceClient(conn)
|
||||
|
||||
sdk := sdk.NewAgentSDK(logger, client)
|
||||
|
||||
_, err = rand.Read(reportData)
|
||||
require.NoError(t, err)
|
||||
|
||||
cases := []struct {
|
||||
name string
|
||||
userKey any
|
||||
reportData [agent.ReportDataSize]byte
|
||||
response *agent.AttestationResponse
|
||||
svcRes []byte
|
||||
err error
|
||||
}{
|
||||
{
|
||||
name: "fetch attestation report successfully",
|
||||
userKey: resultConsumerKey,
|
||||
reportData: [agent.ReportDataSize]byte(reportData),
|
||||
response: &agent.AttestationResponse{
|
||||
File: report,
|
||||
},
|
||||
svcRes: report,
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
name: "fetch attestation report with different key type",
|
||||
userKey: resultConsumer1Key,
|
||||
reportData: [agent.ReportDataSize]byte(reportData),
|
||||
response: &agent.AttestationResponse{
|
||||
File: report,
|
||||
},
|
||||
svcRes: report,
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
name: "failed to fetch attestation report",
|
||||
userKey: resultConsumerKey,
|
||||
reportData: [agent.ReportDataSize]byte(reportData),
|
||||
response: &agent.AttestationResponse{
|
||||
File: nil,
|
||||
},
|
||||
err: nil,
|
||||
},
|
||||
}
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
svcCall := svc.On("Attestation", mock.Anything, mock.Anything).Return(tc.svcRes, tc.err)
|
||||
|
||||
res, err := sdk.Attestation(context.Background(), tc.reportData)
|
||||
|
||||
st, ok := status.FromError(err)
|
||||
if !ok {
|
||||
t.Fatalf("Expected gRPC status error, but got: %v", err)
|
||||
}
|
||||
|
||||
if tc.err != nil {
|
||||
if st.Message() != tc.err.Error() {
|
||||
t.Errorf("%s: Expected error message %q, but got %q", tc.name, tc.err.Error(), st.Message())
|
||||
}
|
||||
}
|
||||
|
||||
assert.Equal(t, tc.response.File, res, tc.name)
|
||||
|
||||
svcCall.Unset()
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func generateKeys(t *testing.T, keyType string) (priv any, pub []byte) {
|
||||
switch keyType {
|
||||
case "ecdsa":
|
||||
privEcdsaKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||
require.NoError(t, err)
|
||||
|
||||
pubKeyBytes, err := x509.MarshalPKIXPublicKey(&privEcdsaKey.PublicKey)
|
||||
require.NoError(t, err)
|
||||
return privEcdsaKey, pubKeyBytes
|
||||
|
||||
case "ed25519":
|
||||
pubEd25519Key, privEd25519Key, err := ed25519.GenerateKey(rand.Reader)
|
||||
require.NoError(t, err)
|
||||
|
||||
pubKey, err := x509.MarshalPKIXPublicKey(pubEd25519Key)
|
||||
require.NoError(t, err)
|
||||
return privEd25519Key, pubKey
|
||||
|
||||
default:
|
||||
privKey, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
require.NoError(t, err)
|
||||
|
||||
pubKeyBytes, err := x509.MarshalPKIXPublicKey(&privKey.PublicKey)
|
||||
require.NoError(t, err)
|
||||
return privKey, pubKeyBytes
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,48 @@
|
||||
// Copyright (c) Ultraviolet
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package sdk_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/ultravioletrs/cocos/agent"
|
||||
agentgrpc "github.com/ultravioletrs/cocos/agent/api/grpc"
|
||||
"github.com/ultravioletrs/cocos/agent/mocks"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/test/bufconn"
|
||||
)
|
||||
|
||||
const bufSize = 1024 * 1024
|
||||
|
||||
var (
|
||||
lis *bufconn.Listener
|
||||
svc = &mocks.Service{}
|
||||
)
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
lis = bufconn.Listen(bufSize)
|
||||
s := grpc.NewServer()
|
||||
|
||||
agent.RegisterAgentServiceServer(s, agentgrpc.NewServer(svc))
|
||||
|
||||
go func() {
|
||||
if err := s.Serve(lis); err != nil {
|
||||
fmt.Println("Server exited with error:", err)
|
||||
}
|
||||
}()
|
||||
|
||||
code := m.Run()
|
||||
|
||||
s.Stop()
|
||||
lis.Close()
|
||||
os.Exit(code)
|
||||
}
|
||||
|
||||
func bufDialer(context.Context, string) (net.Conn, error) {
|
||||
return lis.Dial()
|
||||
}
|
||||
Reference in New Issue
Block a user