mirror of
https://github.com/ultravioletrs/cocos.git
synced 2026-06-23 04:10:25 +00:00
4b27b98edb
* Refactor attestation handling: rename AttestationResult to AzureAttestationToken - Updated the protobuf definition to change azureAttestationResponse to azureAttestationToken. - Refactored the Service interface and its implementation to replace AttestationResult with AzureAttestationToken. - Modified mock functions and tests to reflect the new naming and functionality. - Adjusted CLI commands to use the new AzureAttestationToken method. - Removed the AzureToken constant from the attestation package as it is no longer needed. Signed-off-by: Sammy Oina <sammyoina@gmail.com> * Remove redundant data checks and logging in SendData and sendData methods Signed-off-by: Sammy Oina <sammyoina@gmail.com> * Update agent/api/grpc/server_test.go Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Update agent/api/grpc/endpoint_test.go Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Refactor attestation handling: rename AttestationToken to AzureAttestationToken in server and test files Signed-off-by: Sammy Oina <sammyoina@gmail.com> * Refactor attestation command output messages for clarity and consistency Signed-off-by: Sammy Oina <sammyoina@gmail.com> * Rename AttestationToken to AzureAttestationToken in TestAttestationToken for consistency Signed-off-by: Sammy Oina <sammyoina@gmail.com> * Refactor TestChangeAttestationConfiguration to use vtpm.ConvertPolicyToJSON for JSON conversion Signed-off-by: Sammy Oina <sammyoina@gmail.com> * Fix: reset temporary file pointer after zipping directory Signed-off-by: Sammy Oina <sammyoina@gmail.com> --------- Signed-off-by: Sammy Oina <sammyoina@gmail.com> Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
672 lines
17 KiB
Go
672 lines
17 KiB
Go
// Copyright (c) Ultraviolet
|
|
// SPDX-License-Identifier: Apache-2.0
|
|
package agent
|
|
|
|
import (
|
|
"context"
|
|
"crypto/rand"
|
|
"fmt"
|
|
"log"
|
|
"os"
|
|
"path/filepath"
|
|
"testing"
|
|
"time"
|
|
|
|
mglog "github.com/absmach/supermq/logger"
|
|
"github.com/absmach/supermq/pkg/errors"
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/stretchr/testify/mock"
|
|
"github.com/stretchr/testify/require"
|
|
"github.com/ultravioletrs/cocos/agent/algorithm"
|
|
algomocks "github.com/ultravioletrs/cocos/agent/algorithm/mocks"
|
|
"github.com/ultravioletrs/cocos/agent/algorithm/python"
|
|
"github.com/ultravioletrs/cocos/agent/events/mocks"
|
|
"github.com/ultravioletrs/cocos/agent/statemachine"
|
|
smmocks "github.com/ultravioletrs/cocos/agent/statemachine/mocks"
|
|
"github.com/ultravioletrs/cocos/pkg/attestation"
|
|
mocks2 "github.com/ultravioletrs/cocos/pkg/attestation/mocks"
|
|
"github.com/ultravioletrs/cocos/pkg/attestation/quoteprovider"
|
|
"github.com/ultravioletrs/cocos/pkg/attestation/vtpm"
|
|
"golang.org/x/crypto/sha3"
|
|
"google.golang.org/grpc/metadata"
|
|
)
|
|
|
|
// algoPath, reqPath and dataPath locate the sample linear-regression
// algorithm, its Python requirements file and the iris dataset used as test
// fixtures. They are relative to the working directory, which is the package
// directory under `go test`.
var algoPath = "../test/manual/algo/lin_reg.py"

var reqPath = "../test/manual/algo/requirements.txt"

var dataPath = "../test/manual/data/iris.csv"

// datasetFile is the file name the test manifest (testComputation) declares
// for its single dataset.
const datasetFile = "iris.csv"
|
|
|
|
func TestAlgo(t *testing.T) {
|
|
algo, err := os.ReadFile(algoPath)
|
|
require.NoError(t, err)
|
|
|
|
algoHash := sha3.Sum256(algo)
|
|
vtpm.ExternalTPM = &vtpm.DummyRWC{}
|
|
|
|
reqFile, err := os.ReadFile(reqPath)
|
|
require.NoError(t, err)
|
|
|
|
testCases := []struct {
|
|
name string
|
|
err error
|
|
algo Algorithm
|
|
algoType string
|
|
}{
|
|
{
|
|
name: "Test Algo successfully",
|
|
algo: Algorithm{
|
|
Algorithm: algo,
|
|
Hash: algoHash,
|
|
},
|
|
algoType: "python",
|
|
err: nil,
|
|
},
|
|
{
|
|
name: "Test Algo successfully with requirements file",
|
|
algo: Algorithm{
|
|
Algorithm: algo,
|
|
Hash: algoHash,
|
|
Requirements: reqFile,
|
|
},
|
|
algoType: "python",
|
|
err: nil,
|
|
},
|
|
{
|
|
name: "Test Algo type binary successfully",
|
|
algo: Algorithm{
|
|
Algorithm: algo,
|
|
Hash: algoHash,
|
|
},
|
|
algoType: "bin",
|
|
err: nil,
|
|
},
|
|
{
|
|
name: "Test Algo type wasm successfully",
|
|
algo: Algorithm{
|
|
Algorithm: algo,
|
|
Hash: algoHash,
|
|
},
|
|
algoType: "wasm",
|
|
err: nil,
|
|
},
|
|
{
|
|
name: "Test Algo type docker successfully",
|
|
algo: Algorithm{
|
|
Algorithm: algo,
|
|
Hash: algoHash,
|
|
},
|
|
algoType: "docker",
|
|
err: nil,
|
|
},
|
|
{
|
|
name: "Test algo hash mismatch",
|
|
algo: Algorithm{},
|
|
algoType: "python",
|
|
err: ErrHashMismatch,
|
|
},
|
|
}
|
|
|
|
for _, tc := range testCases {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
err = os.RemoveAll("datasets")
|
|
require.NoError(t, err)
|
|
|
|
ctx := metadata.NewIncomingContext(context.Background(),
|
|
metadata.Pairs(algorithm.AlgoTypeKey, tc.algoType, python.PyRuntimeKey, python.PyRuntime),
|
|
)
|
|
|
|
events := new(mocks.Service)
|
|
events.EXPECT().SendEvent(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return()
|
|
|
|
ctx, cancel := context.WithCancel(ctx)
|
|
defer cancel()
|
|
svc := New(ctx, mglog.NewMock(), events, &attestation.EmptyProvider{}, 0)
|
|
|
|
err := svc.InitComputation(ctx, testComputation(t))
|
|
require.NoError(t, err)
|
|
|
|
time.Sleep(300 * time.Millisecond)
|
|
|
|
err = svc.Algo(ctx, tc.algo)
|
|
assert.True(t, errors.Contains(err, tc.err), "expected %v, got %v", tc.err, err)
|
|
t.Cleanup(func() {
|
|
err = os.RemoveAll("venv")
|
|
err = os.RemoveAll("algo")
|
|
err = os.RemoveAll("datasets")
|
|
})
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestData(t *testing.T) {
|
|
algo, err := os.ReadFile(algoPath)
|
|
require.NoError(t, err)
|
|
|
|
algoHash := sha3.Sum256(algo)
|
|
vtpm.ExternalTPM = &vtpm.DummyRWC{}
|
|
|
|
alg := Algorithm{
|
|
Hash: algoHash,
|
|
Algorithm: algo,
|
|
}
|
|
|
|
data, err := os.ReadFile(dataPath)
|
|
require.NoError(t, err)
|
|
|
|
dataHash := sha3.Sum256(data)
|
|
|
|
cases := []struct {
|
|
name string
|
|
data Dataset
|
|
err error
|
|
}{
|
|
{
|
|
name: "Test data successfully",
|
|
data: Dataset{
|
|
Hash: dataHash,
|
|
Dataset: data,
|
|
Filename: datasetFile,
|
|
},
|
|
},
|
|
{
|
|
name: "Test State not ready",
|
|
data: Dataset{
|
|
Dataset: data,
|
|
Hash: dataHash,
|
|
Filename: datasetFile,
|
|
},
|
|
err: ErrStateNotReady,
|
|
},
|
|
{
|
|
name: "Test File name does not match manifest",
|
|
data: Dataset{
|
|
Dataset: data,
|
|
Hash: dataHash,
|
|
Filename: "invalid",
|
|
},
|
|
err: ErrFileNameMismatch,
|
|
},
|
|
{
|
|
name: "Test dataset not declared in manifest",
|
|
data: Dataset{
|
|
Filename: datasetFile,
|
|
},
|
|
err: ErrUndeclaredDataset,
|
|
},
|
|
}
|
|
|
|
for _, tc := range cases {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
ctx := metadata.NewIncomingContext(context.Background(),
|
|
metadata.Pairs(
|
|
algorithm.AlgoTypeKey, "python",
|
|
python.PyRuntimeKey, python.PyRuntime),
|
|
)
|
|
|
|
events := new(mocks.Service)
|
|
events.EXPECT().SendEvent(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return()
|
|
|
|
if tc.err != ErrUndeclaredDataset {
|
|
ctx = IndexToContext(ctx, 0)
|
|
}
|
|
|
|
ctx, cancel := context.WithCancel(ctx)
|
|
defer cancel()
|
|
|
|
svc := New(ctx, mglog.NewMock(), events, &attestation.EmptyProvider{}, 0)
|
|
|
|
err := svc.InitComputation(ctx, testComputation(t))
|
|
require.NoError(t, err)
|
|
|
|
time.Sleep(300 * time.Millisecond)
|
|
|
|
if tc.err != ErrStateNotReady {
|
|
err = svc.Algo(ctx, alg)
|
|
require.NoError(t, err)
|
|
time.Sleep(300 * time.Millisecond)
|
|
}
|
|
err = svc.Data(ctx, tc.data)
|
|
assert.True(t, errors.Contains(err, tc.err), "expected %v, got %v", tc.err, err)
|
|
t.Cleanup(func() {
|
|
_ = os.RemoveAll("datasets")
|
|
_ = os.RemoveAll("results")
|
|
err = os.RemoveAll("venv")
|
|
err = os.RemoveAll("algo")
|
|
})
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestResult(t *testing.T) {
|
|
cases := []struct {
|
|
name string
|
|
err error
|
|
setup func(svc *agentService)
|
|
ctxSetup func(ctx context.Context) context.Context
|
|
state statemachine.State
|
|
}{
|
|
{
|
|
name: "Test results not ready",
|
|
err: ErrResultsNotReady,
|
|
setup: func(svc *agentService) {
|
|
},
|
|
state: Running,
|
|
},
|
|
{
|
|
name: "Test undeclared consumer",
|
|
err: ErrUndeclaredConsumer,
|
|
setup: func(svc *agentService) {
|
|
svc.computation.ResultConsumers = []ResultConsumer{{UserKey: []byte("user")}}
|
|
},
|
|
ctxSetup: func(ctx context.Context) context.Context {
|
|
return ctx
|
|
},
|
|
state: ConsumingResults,
|
|
},
|
|
{
|
|
name: "Test results consumed and event sent",
|
|
err: nil,
|
|
setup: func(svc *agentService) {
|
|
svc.computation.ResultConsumers = []ResultConsumer{{UserKey: []byte("key")}}
|
|
},
|
|
ctxSetup: func(ctx context.Context) context.Context {
|
|
return IndexToContext(ctx, 0)
|
|
},
|
|
state: ConsumingResults,
|
|
},
|
|
}
|
|
|
|
for _, tc := range cases {
|
|
events := new(mocks.Service)
|
|
events.EXPECT().SendEvent(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return()
|
|
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
ctx := metadata.NewIncomingContext(context.Background(),
|
|
metadata.Pairs(algorithm.AlgoTypeKey, "python", python.PyRuntimeKey, python.PyRuntime),
|
|
)
|
|
|
|
if tc.ctxSetup != nil {
|
|
ctx = tc.ctxSetup(ctx)
|
|
}
|
|
|
|
sm := new(smmocks.StateMachine)
|
|
sm.On("Start", ctx).Return(nil)
|
|
sm.On("GetState").Return(tc.state)
|
|
sm.On("SendEvent", mock.Anything).Return()
|
|
|
|
svc := &agentService{
|
|
sm: sm,
|
|
eventSvc: events,
|
|
provider: &attestation.EmptyProvider{},
|
|
computation: testComputation(t),
|
|
}
|
|
|
|
go func() {
|
|
if err := svc.sm.Start(ctx); err != nil {
|
|
t.Errorf("Error starting state machine: %v", err)
|
|
}
|
|
}()
|
|
tc.setup(svc)
|
|
_, err := svc.Result(ctx)
|
|
t.Cleanup(func() {
|
|
_ = os.RemoveAll("datasets")
|
|
_ = os.RemoveAll("results")
|
|
})
|
|
assert.ErrorIs(t, err, tc.err, "expected %v, got %v", tc.err, err)
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestAttestation(t *testing.T) {
|
|
provider := new(mocks2.Provider)
|
|
|
|
cases := []struct {
|
|
name string
|
|
reportData [quoteprovider.Nonce]byte
|
|
nonce [vtpm.Nonce]byte
|
|
rawQuote []uint8
|
|
platform attestation.PlatformType
|
|
err error
|
|
}{
|
|
{
|
|
name: "Test SNP attestation successful",
|
|
reportData: generateReportData(),
|
|
nonce: [32]byte{},
|
|
rawQuote: make([]uint8, 0),
|
|
platform: attestation.SNP,
|
|
err: nil,
|
|
},
|
|
{
|
|
name: "Test SNP attestation failed",
|
|
reportData: generateReportData(),
|
|
nonce: [32]byte{},
|
|
rawQuote: nil,
|
|
platform: attestation.SNP,
|
|
err: ErrAttestationFailed,
|
|
},
|
|
{
|
|
name: "Test vTPM attestation successful",
|
|
reportData: generateReportData(),
|
|
nonce: [32]byte{},
|
|
rawQuote: make([]uint8, 0),
|
|
platform: attestation.VTPM,
|
|
err: nil,
|
|
},
|
|
{
|
|
name: "Test vTPM attestation failed",
|
|
reportData: generateReportData(),
|
|
nonce: [32]byte{},
|
|
rawQuote: nil,
|
|
platform: attestation.VTPM,
|
|
err: ErrAttestationVTpmFailed,
|
|
},
|
|
{
|
|
name: "Test SNP-vTPM attestation successful",
|
|
reportData: generateReportData(),
|
|
nonce: [32]byte{},
|
|
rawQuote: make([]uint8, 0),
|
|
platform: attestation.SNPvTPM,
|
|
err: nil,
|
|
},
|
|
{
|
|
name: "Test SNP-vTPM attestation failed",
|
|
reportData: generateReportData(),
|
|
nonce: [32]byte{},
|
|
rawQuote: nil,
|
|
platform: attestation.SNPvTPM,
|
|
err: ErrAttestationVTpmFailed,
|
|
},
|
|
}
|
|
for _, tc := range cases {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
events := new(mocks.Service)
|
|
events.EXPECT().SendEvent(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return()
|
|
|
|
ctx := metadata.NewIncomingContext(context.Background(),
|
|
metadata.Pairs(algorithm.AlgoTypeKey, "python", python.PyRuntimeKey, python.PyRuntime),
|
|
)
|
|
ctx, cancel := context.WithCancel(ctx)
|
|
defer cancel()
|
|
|
|
getQuote := provider.On("TeeAttestation", mock.Anything).Return(tc.rawQuote, tc.err)
|
|
vtpmQuote := provider.On("VTpmAttestation", mock.Anything).Return(tc.rawQuote, tc.err)
|
|
snpVtpm := provider.On("Attestation", mock.Anything, mock.Anything).Return(tc.rawQuote, tc.err)
|
|
if tc.err != ErrAttestationFailed && tc.err != ErrAttestationVTpmFailed {
|
|
getQuote = provider.On("TeeAttestation", mock.Anything).Return(tc.nonce, nil)
|
|
vtpmQuote = provider.On("VTpmAttestation", mock.Anything).Return(tc.nonce[:], nil)
|
|
snpVtpm = provider.On("Attestation", mock.Anything, mock.Anything).Return(tc.nonce[:], nil)
|
|
}
|
|
defer getQuote.Unset()
|
|
defer vtpmQuote.Unset()
|
|
defer snpVtpm.Unset()
|
|
|
|
svc := New(ctx, mglog.NewMock(), events, provider, 0)
|
|
time.Sleep(300 * time.Millisecond)
|
|
_, err := svc.Attestation(ctx, tc.reportData, tc.nonce, tc.platform)
|
|
assert.True(t, errors.Contains(err, tc.err), "expected %v, got %v", tc.err, err)
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestAzureAttestationToken(t *testing.T) {
|
|
provider := new(mocks2.Provider)
|
|
cases := []struct {
|
|
name string
|
|
nonce [vtpm.Nonce]byte
|
|
token []byte
|
|
err error
|
|
}{
|
|
{
|
|
name: "Azure token fetch successful",
|
|
nonce: [32]byte{1, 2, 3}, // any test nonce
|
|
token: []byte("mockToken"),
|
|
err: ErrAttestationType,
|
|
},
|
|
{
|
|
name: "Azure token fetch failed",
|
|
nonce: [32]byte{4, 5, 6},
|
|
token: []byte{},
|
|
err: ErrAttestationType,
|
|
},
|
|
{
|
|
name: "Invalid attestation type",
|
|
nonce: [32]byte{7, 8, 9},
|
|
token: []byte{},
|
|
err: ErrAttestationType,
|
|
},
|
|
}
|
|
|
|
for _, tc := range cases {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
events := new(mocks.Service)
|
|
events.EXPECT().SendEvent(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return()
|
|
|
|
provider.On("AzureAttestationToken", tc.nonce[:]).Return(tc.token, tc.err)
|
|
|
|
ctx := context.Background()
|
|
|
|
svc := New(ctx, mglog.NewMock(), events, provider, 0)
|
|
|
|
_, err := svc.AzureAttestationToken(ctx, tc.nonce)
|
|
assert.True(t, errors.Contains(err, tc.err), "expected error %v, got %v", tc.err, err)
|
|
})
|
|
}
|
|
}
|
|
|
|
func generateReportData() [quoteprovider.Nonce]byte {
|
|
bytes := make([]byte, quoteprovider.Nonce)
|
|
_, err := rand.Read(bytes)
|
|
if err != nil {
|
|
log.Fatalf("Failed to generate random bytes: %v", err)
|
|
}
|
|
return [64]byte(bytes)
|
|
}
|
|
|
|
func testComputation(t *testing.T) Computation {
|
|
algo, err := os.ReadFile(algoPath)
|
|
require.NoError(t, err)
|
|
|
|
algoHash := sha3.Sum256(algo)
|
|
|
|
data, err := os.ReadFile(dataPath)
|
|
require.NoError(t, err)
|
|
|
|
dataHash := sha3.Sum256(data)
|
|
|
|
return Computation{
|
|
ID: "1",
|
|
Name: "sample computation",
|
|
Description: "sample description",
|
|
Datasets: []Dataset{{Hash: dataHash, UserKey: []byte("key"), Dataset: data, Filename: datasetFile}},
|
|
Algorithm: Algorithm{Hash: algoHash, UserKey: []byte("key"), Algorithm: algo},
|
|
ResultConsumers: []ResultConsumer{{UserKey: []byte("key")}},
|
|
}
|
|
}
|
|
|
|
func TestStopComputation(t *testing.T) {
|
|
testDataDir := "test_datasets"
|
|
testResultsDir := "test_results"
|
|
|
|
cases := []struct {
|
|
name string
|
|
setupDirs bool
|
|
setupAlgo bool
|
|
algoStopErr error
|
|
expectedErr error
|
|
}{
|
|
{
|
|
name: "Stop computation successfully",
|
|
setupDirs: true,
|
|
setupAlgo: true,
|
|
algoStopErr: nil,
|
|
expectedErr: nil,
|
|
},
|
|
{
|
|
name: "Stop computation with algorithm stop error",
|
|
setupDirs: true,
|
|
setupAlgo: true,
|
|
algoStopErr: fmt.Errorf("algorithm stop failed"),
|
|
expectedErr: fmt.Errorf("error stopping computation: algorithm stop failed"),
|
|
},
|
|
{
|
|
name: "Stop computation without algorithm",
|
|
setupDirs: true,
|
|
setupAlgo: false,
|
|
algoStopErr: nil,
|
|
expectedErr: nil,
|
|
},
|
|
{
|
|
name: "Stop computation with missing directories",
|
|
setupDirs: false,
|
|
setupAlgo: false,
|
|
algoStopErr: nil,
|
|
expectedErr: nil, // os.RemoveAll doesn't error on non-existing directories
|
|
},
|
|
}
|
|
|
|
for _, tc := range cases {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
events := new(mocks.Service)
|
|
events.On("SendEvent", mock.Anything, "Stopped", "Stopped", mock.Anything).Return()
|
|
|
|
ctx := context.Background()
|
|
ctx, cancel := context.WithCancel(ctx)
|
|
defer cancel()
|
|
|
|
svc := New(ctx, mglog.NewMock(), events, &attestation.EmptyProvider{}, 0).(*agentService)
|
|
|
|
svc.computation = Computation{
|
|
ID: "test-computation",
|
|
Name: "test",
|
|
}
|
|
|
|
if tc.setupDirs {
|
|
err := os.MkdirAll(testDataDir, 0o755)
|
|
require.NoError(t, err)
|
|
err = os.MkdirAll(testResultsDir, 0o755)
|
|
require.NoError(t, err)
|
|
}
|
|
|
|
if tc.setupAlgo {
|
|
mockAlgo := new(algomocks.Algorithm)
|
|
mockAlgo.On("Stop").Return(tc.algoStopErr)
|
|
svc.algorithm = mockAlgo
|
|
}
|
|
|
|
err := svc.StopComputation(ctx)
|
|
|
|
if tc.expectedErr != nil {
|
|
assert.Error(t, err)
|
|
assert.Contains(t, err.Error(), tc.expectedErr.Error())
|
|
} else {
|
|
assert.NoError(t, err)
|
|
}
|
|
|
|
assert.Equal(t, ReceivingManifest, svc.sm.GetState())
|
|
assert.Nil(t, svc.result)
|
|
assert.Nil(t, svc.runError)
|
|
assert.False(t, svc.resultsConsumed)
|
|
|
|
events.AssertExpectations(t)
|
|
|
|
_ = os.RemoveAll(testDataDir)
|
|
_ = os.RemoveAll(testResultsDir)
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestStopComputationIntegration(t *testing.T) {
|
|
if testing.Short() {
|
|
t.Skip("Skipping integration test in short mode")
|
|
}
|
|
|
|
algo := []byte("#!/bin/bash\necho 'test algorithm'")
|
|
algoHash := sha3.Sum256(algo)
|
|
|
|
testDir := "test_integration"
|
|
err := os.MkdirAll(testDir, 0o755)
|
|
require.NoError(t, err)
|
|
defer os.RemoveAll(testDir)
|
|
|
|
algoFile := filepath.Join(testDir, "test_algo")
|
|
err = os.WriteFile(algoFile, algo, 0o755)
|
|
require.NoError(t, err)
|
|
|
|
events := new(mocks.Service)
|
|
events.On("SendEvent", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return()
|
|
|
|
ctx := metadata.NewIncomingContext(context.Background(),
|
|
metadata.Pairs(algorithm.AlgoTypeKey, "bin"),
|
|
)
|
|
ctx, cancel := context.WithCancel(ctx)
|
|
defer cancel()
|
|
|
|
svc := New(ctx, mglog.NewMock(), events, &attestation.EmptyProvider{}, 0)
|
|
|
|
computation := Computation{
|
|
ID: "integration-test",
|
|
Name: "Integration Test",
|
|
Algorithm: Algorithm{
|
|
Hash: algoHash,
|
|
Algorithm: algo,
|
|
},
|
|
}
|
|
|
|
err = svc.InitComputation(ctx, computation)
|
|
require.NoError(t, err)
|
|
|
|
time.Sleep(100 * time.Millisecond)
|
|
|
|
err = svc.Algo(ctx, Algorithm{
|
|
Hash: algoHash,
|
|
Algorithm: algo,
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
time.Sleep(100 * time.Millisecond)
|
|
|
|
err = svc.StopComputation(ctx)
|
|
assert.NoError(t, err)
|
|
|
|
assert.Equal(t, "ReceivingManifest", svc.State())
|
|
}
|
|
|
|
func TestStopComputationConcurrent(t *testing.T) {
|
|
events := new(mocks.Service)
|
|
events.On("SendEvent", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return()
|
|
|
|
ctx := context.Background()
|
|
ctx, cancel := context.WithCancel(ctx)
|
|
defer cancel()
|
|
|
|
svc := New(ctx, mglog.NewMock(), events, &attestation.EmptyProvider{}, 0)
|
|
|
|
svc.(*agentService).computation = Computation{
|
|
ID: "concurrent-test",
|
|
Name: "Concurrent Test",
|
|
}
|
|
|
|
const numGoroutines = 10
|
|
errChan := make(chan error, numGoroutines)
|
|
|
|
for i := 0; i < numGoroutines; i++ {
|
|
go func() {
|
|
err := svc.StopComputation(ctx)
|
|
errChan <- err
|
|
}()
|
|
}
|
|
|
|
var errors []error
|
|
for i := 0; i < numGoroutines; i++ {
|
|
err := <-errChan
|
|
if err != nil {
|
|
errors = append(errors, err)
|
|
}
|
|
}
|
|
|
|
assert.True(t, len(errors) < numGoroutines, "All StopComputation calls failed")
|
|
}
|