mirror of
https://github.com/ultravioletrs/cocos.git
synced 2026-06-23 04:10:25 +00:00
b44780df95
CI / lint (push) Has been cancelled
CI / test (agent) (push) Has been cancelled
CI / test (cli) (push) Has been cancelled
CI / test (cmd) (push) Has been cancelled
CI / test (internal) (push) Has been cancelled
CI / test (manager, true) (push) Has been cancelled
CI / test (pkg) (push) Has been cancelled
CI / upload-coverage (push) Has been cancelled
* feat: Enhance OCI image extraction to return algorithm and requirements paths, and add deferred cleanup for temporary files. Signed-off-by: Sammy Oina <sammyoina@gmail.com> * feat: implement deterministic zipping and enhance checksum verification for resources Signed-off-by: Sammy Oina <sammyoina@gmail.com> * feat: Update component build sources, add gRPC health checks to the CVM server, and refine algorithm argument handling and documentation. Signed-off-by: Sammy Oina <sammyoina@gmail.com> * docs: Update remote resources testing guide with `sudo` for KBS, algorithm result saving, `requirements.txt`, and `algo-args` for RVPS. Signed-off-by: Sammy Oina <sammyoina@gmail.com> * refactor: Explicitly ignore `stderr.Write` return values and add minor whitespace in tests. Signed-off-by: Sammy Oina <sammyoina@gmail.com> * test: add comprehensive error path and edge case tests for file, zip, OCI, and agent components. Signed-off-by: Sammy Oina <sammyoina@gmail.com> * feat: Add mutexes for thread-safe algorithm execution and expand recognized data file extensions to include common archive formats. Signed-off-by: Sammy Oina <sammyoina@gmail.com> * feat: Add OCI extraction tests for Python algorithms and multi-layer datasets, refactor algorithm execution for testability, and enhance algorithm stop and error handling tests. Signed-off-by: Sammy Oina <sammyoina@gmail.com> * test: Add error assertions to OCI extraction test helpers and remove an unused mock exec command. Signed-off-by: Sammy Oina <sammyoina@gmail.com> * test: Improve error handling test coverage for algorithm execution and OCI resource extraction. Signed-off-by: Sammy Oina <sammyoina@gmail.com> * fix: Improve algorithm process termination, enhance computation error handling, and add concurrency safety to agent service. Signed-off-by: Sammy Oina <sammyoina@gmail.com> --------- Signed-off-by: Sammy Oina <sammyoina@gmail.com>
1754 lines
50 KiB
Go
1754 lines
50 KiB
Go
// Copyright (c) Ultraviolet
|
||
// SPDX-License-Identifier: Apache-2.0
|
||
package agent
|
||
|
||
import (
|
||
"archive/tar"
|
||
"archive/zip"
|
||
"bytes"
|
||
"compress/gzip"
|
||
"context"
|
||
"crypto/rand"
|
||
"encoding/json"
|
||
"fmt"
|
||
"log"
|
||
"log/slog"
|
||
"os"
|
||
"path/filepath"
|
||
"testing"
|
||
"time"
|
||
|
||
mglog "github.com/absmach/supermq/logger"
|
||
"github.com/absmach/supermq/pkg/errors"
|
||
"github.com/stretchr/testify/assert"
|
||
"github.com/stretchr/testify/mock"
|
||
"github.com/stretchr/testify/require"
|
||
"github.com/ultravioletrs/cocos/agent/algorithm"
|
||
"github.com/ultravioletrs/cocos/agent/algorithm/python"
|
||
agentevents "github.com/ultravioletrs/cocos/agent/events"
|
||
"github.com/ultravioletrs/cocos/agent/events/mocks"
|
||
runnerpb "github.com/ultravioletrs/cocos/agent/runner"
|
||
"github.com/ultravioletrs/cocos/agent/statemachine"
|
||
smmocks "github.com/ultravioletrs/cocos/agent/statemachine/mocks"
|
||
"github.com/ultravioletrs/cocos/pkg/attestation"
|
||
"github.com/ultravioletrs/cocos/pkg/attestation/vtpm"
|
||
runnermocks "github.com/ultravioletrs/cocos/pkg/clients/grpc/runner/mocks"
|
||
"github.com/ultravioletrs/cocos/pkg/oci"
|
||
"golang.org/x/crypto/sha3"
|
||
"google.golang.org/grpc/metadata"
|
||
"google.golang.org/protobuf/types/known/emptypb"
|
||
)
|
||
|
||
type MockOCIClient struct {
|
||
mock.Mock
|
||
}
|
||
|
||
func (m *MockOCIClient) PullAndDecrypt(ctx context.Context, source oci.ResourceSource, destDir string) error {
|
||
args := m.Called(ctx, source, destDir)
|
||
return args.Error(0)
|
||
}
|
||
|
||
func (m *MockOCIClient) ToDockerArchive(ctx context.Context, ociDir, destFile string) error {
|
||
args := m.Called(ctx, ociDir, destFile)
|
||
return args.Error(0)
|
||
}
|
||
|
||
// datasetFile is the dataset filename declared in the test manifest; it is
// the base name of dataPath.
const datasetFile = "iris.csv"

// Fixture paths used by the computation tests. They are relative to the
// working directory, which is the package directory under `go test`.
var (
	// algoPath is the sample Python algorithm uploaded by the algorithm tests.
	algoPath = "../test/manual/algo/lin_reg.py"
	// reqPath is the requirements file shipped alongside the sample algorithm.
	reqPath = "../test/manual/algo/requirements.txt"
	// dataPath is the sample CSV dataset referenced by the test manifest.
	dataPath = "../test/manual/data/iris.csv"
)
|
||
|
||
func TestAlgo(t *testing.T) {
|
||
algo, err := os.ReadFile(algoPath)
|
||
require.NoError(t, err)
|
||
|
||
algoHash := sha3.Sum256(algo)
|
||
vtpm.ExternalTPM = &vtpm.DummyRWC{}
|
||
|
||
reqFile, err := os.ReadFile(reqPath)
|
||
require.NoError(t, err)
|
||
|
||
testCases := []struct {
|
||
name string
|
||
err error
|
||
algo Algorithm
|
||
algoType string
|
||
}{
|
||
{
|
||
name: "Test Algo successfully",
|
||
algo: Algorithm{
|
||
Algorithm: algo,
|
||
Hash: algoHash,
|
||
},
|
||
algoType: "python",
|
||
err: nil,
|
||
},
|
||
{
|
||
name: "Test Algo successfully with requirements file",
|
||
algo: Algorithm{
|
||
Algorithm: algo,
|
||
Hash: algoHash,
|
||
Requirements: reqFile,
|
||
},
|
||
algoType: "python",
|
||
err: nil,
|
||
},
|
||
{
|
||
name: "Test Algo type binary successfully",
|
||
algo: Algorithm{
|
||
Algorithm: algo,
|
||
Hash: algoHash,
|
||
},
|
||
algoType: "bin",
|
||
err: nil,
|
||
},
|
||
{
|
||
name: "Test Algo type wasm successfully",
|
||
algo: Algorithm{
|
||
Algorithm: algo,
|
||
Hash: algoHash,
|
||
},
|
||
algoType: "wasm",
|
||
err: nil,
|
||
},
|
||
{
|
||
name: "Test Algo type docker successfully",
|
||
algo: Algorithm{
|
||
Algorithm: algo,
|
||
Hash: algoHash,
|
||
},
|
||
algoType: "docker",
|
||
err: nil,
|
||
},
|
||
{
|
||
name: "Test algo hash mismatch",
|
||
algo: Algorithm{},
|
||
algoType: "python",
|
||
err: ErrHashMismatch,
|
||
},
|
||
}
|
||
|
||
for _, tc := range testCases {
|
||
t.Run(tc.name, func(t *testing.T) {
|
||
err = os.RemoveAll("datasets")
|
||
require.NoError(t, err)
|
||
|
||
ctx := metadata.NewIncomingContext(context.Background(),
|
||
metadata.Pairs(algorithm.AlgoTypeKey, tc.algoType, python.PyRuntimeKey, python.PyRuntime),
|
||
)
|
||
|
||
events := new(mocks.Service)
|
||
events.EXPECT().SendEvent(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return()
|
||
|
||
ctx, cancel := context.WithCancel(ctx)
|
||
defer cancel()
|
||
client := new(MockAttestationClient)
|
||
runnerCli := new(runnermocks.Client)
|
||
runnerCli.On("Run", mock.Anything, mock.Anything).Return(&runnerpb.RunResponse{}, nil)
|
||
svc := New(ctx, mglog.NewMock(), events, client, runnerCli, 0)
|
||
|
||
err := svc.InitComputation(ctx, testComputation(t))
|
||
require.NoError(t, err)
|
||
|
||
time.Sleep(300 * time.Millisecond)
|
||
|
||
err = svc.Algo(ctx, tc.algo)
|
||
assert.True(t, errors.Contains(err, tc.err), "expected %v, got %v", tc.err, err)
|
||
t.Cleanup(func() {
|
||
err = os.RemoveAll("venv")
|
||
err = os.RemoveAll("algo")
|
||
err = os.RemoveAll("datasets")
|
||
})
|
||
})
|
||
}
|
||
}
|
||
|
||
func TestData(t *testing.T) {
|
||
algo, err := os.ReadFile(algoPath)
|
||
require.NoError(t, err)
|
||
|
||
algoHash := sha3.Sum256(algo)
|
||
vtpm.ExternalTPM = &vtpm.DummyRWC{}
|
||
|
||
alg := Algorithm{
|
||
Hash: algoHash,
|
||
Algorithm: algo,
|
||
}
|
||
|
||
data, err := os.ReadFile(dataPath)
|
||
require.NoError(t, err)
|
||
|
||
dataHash := sha3.Sum256(data)
|
||
|
||
cases := []struct {
|
||
name string
|
||
data Dataset
|
||
err error
|
||
}{
|
||
{
|
||
name: "Test data successfully",
|
||
data: Dataset{
|
||
Hash: dataHash,
|
||
Dataset: data,
|
||
Filename: datasetFile,
|
||
},
|
||
},
|
||
{
|
||
name: "Test State not ready",
|
||
data: Dataset{
|
||
Dataset: data,
|
||
Hash: dataHash,
|
||
Filename: datasetFile,
|
||
},
|
||
err: ErrStateNotReady,
|
||
},
|
||
{
|
||
name: "Test File name does not match manifest",
|
||
data: Dataset{
|
||
Dataset: data,
|
||
Hash: dataHash,
|
||
Filename: "invalid",
|
||
},
|
||
err: ErrFileNameMismatch,
|
||
},
|
||
{
|
||
name: "Test dataset not declared in manifest",
|
||
data: Dataset{
|
||
Filename: datasetFile,
|
||
},
|
||
err: ErrUndeclaredDataset,
|
||
},
|
||
}
|
||
|
||
for _, tc := range cases {
|
||
t.Run(tc.name, func(t *testing.T) {
|
||
ctx := metadata.NewIncomingContext(context.Background(),
|
||
metadata.Pairs(
|
||
algorithm.AlgoTypeKey, "python",
|
||
python.PyRuntimeKey, python.PyRuntime),
|
||
)
|
||
|
||
events := new(mocks.Service)
|
||
events.EXPECT().SendEvent(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return()
|
||
|
||
if tc.err != ErrUndeclaredDataset {
|
||
ctx = IndexToContext(ctx, 0)
|
||
}
|
||
|
||
ctx, cancel := context.WithCancel(ctx)
|
||
defer cancel()
|
||
|
||
client := new(MockAttestationClient)
|
||
runnerCli := new(runnermocks.Client)
|
||
runnerCli.On("Run", mock.Anything, mock.Anything).Return(&runnerpb.RunResponse{}, nil)
|
||
svc := New(ctx, mglog.NewMock(), events, client, runnerCli, 0)
|
||
|
||
err := svc.InitComputation(ctx, testComputation(t))
|
||
require.NoError(t, err)
|
||
|
||
time.Sleep(300 * time.Millisecond)
|
||
|
||
if tc.err != ErrStateNotReady {
|
||
err = svc.Algo(ctx, alg)
|
||
require.NoError(t, err)
|
||
time.Sleep(300 * time.Millisecond)
|
||
}
|
||
err = svc.Data(ctx, tc.data)
|
||
assert.True(t, errors.Contains(err, tc.err), "expected %v, got %v", tc.err, err)
|
||
t.Cleanup(func() {
|
||
_ = os.RemoveAll("datasets")
|
||
_ = os.RemoveAll("results")
|
||
err = os.RemoveAll("venv")
|
||
err = os.RemoveAll("algo")
|
||
})
|
||
})
|
||
}
|
||
}
|
||
|
||
func TestResult(t *testing.T) {
|
||
cases := []struct {
|
||
name string
|
||
err error
|
||
setup func(svc *agentService)
|
||
ctxSetup func(ctx context.Context) context.Context
|
||
state statemachine.State
|
||
}{
|
||
{
|
||
name: "Test results not ready",
|
||
err: ErrResultsNotReady,
|
||
setup: func(svc *agentService) {
|
||
},
|
||
state: Running,
|
||
},
|
||
{
|
||
name: "Test undeclared consumer",
|
||
err: ErrUndeclaredConsumer,
|
||
setup: func(svc *agentService) {
|
||
svc.computation.ResultConsumers = []ResultConsumer{{UserKey: []byte("user")}}
|
||
},
|
||
ctxSetup: func(ctx context.Context) context.Context {
|
||
return ctx
|
||
},
|
||
state: ConsumingResults,
|
||
},
|
||
{
|
||
name: "Test results consumed and event sent",
|
||
err: nil,
|
||
setup: func(svc *agentService) {
|
||
svc.computation.ResultConsumers = []ResultConsumer{{UserKey: []byte("key")}}
|
||
},
|
||
ctxSetup: func(ctx context.Context) context.Context {
|
||
return IndexToContext(ctx, 0)
|
||
},
|
||
state: ConsumingResults,
|
||
},
|
||
}
|
||
|
||
for _, tc := range cases {
|
||
events := new(mocks.Service)
|
||
events.EXPECT().SendEvent(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return()
|
||
|
||
t.Run(tc.name, func(t *testing.T) {
|
||
ctx := metadata.NewIncomingContext(context.Background(),
|
||
metadata.Pairs(algorithm.AlgoTypeKey, "python", python.PyRuntimeKey, python.PyRuntime),
|
||
)
|
||
|
||
if tc.ctxSetup != nil {
|
||
ctx = tc.ctxSetup(ctx)
|
||
}
|
||
|
||
client := new(MockAttestationClient)
|
||
runnerCli := new(runnermocks.Client)
|
||
|
||
sm := new(smmocks.StateMachine)
|
||
sm.On("Start", ctx).Return(nil)
|
||
sm.On("GetState").Return(tc.state)
|
||
sm.On("SendEvent", mock.Anything).Return()
|
||
|
||
svc := &agentService{
|
||
sm: sm,
|
||
eventSvc: events,
|
||
attestationClient: client,
|
||
runnerClient: runnerCli,
|
||
computation: testComputation(t),
|
||
}
|
||
|
||
go func() {
|
||
if err := svc.sm.Start(ctx); err != nil {
|
||
t.Errorf("Error starting state machine: %v", err)
|
||
}
|
||
}()
|
||
tc.setup(svc)
|
||
_, err := svc.Result(ctx)
|
||
t.Cleanup(func() {
|
||
_ = os.RemoveAll("datasets")
|
||
_ = os.RemoveAll("results")
|
||
})
|
||
assert.ErrorIs(t, err, tc.err, "expected %v, got %v", tc.err, err)
|
||
})
|
||
}
|
||
}
|
||
|
||
func TestAttestation(t *testing.T) {
|
||
client := new(MockAttestationClient)
|
||
|
||
cases := []struct {
|
||
name string
|
||
reportData [vtpm.SEVNonce]byte
|
||
nonce [vtpm.Nonce]byte
|
||
rawQuote []uint8
|
||
platform attestation.PlatformType
|
||
err error
|
||
}{
|
||
{
|
||
name: "Test SNP attestation successful",
|
||
reportData: generateReportData(),
|
||
nonce: [32]byte{},
|
||
rawQuote: make([]uint8, 0),
|
||
platform: attestation.SNP,
|
||
err: nil,
|
||
},
|
||
{
|
||
name: "Test SNP attestation failed",
|
||
reportData: generateReportData(),
|
||
nonce: [32]byte{},
|
||
rawQuote: nil,
|
||
platform: attestation.SNP,
|
||
err: ErrAttestationFailed,
|
||
},
|
||
{
|
||
name: "Test vTPM attestation successful",
|
||
reportData: generateReportData(),
|
||
nonce: [32]byte{},
|
||
rawQuote: make([]uint8, 0),
|
||
platform: attestation.VTPM,
|
||
err: nil,
|
||
},
|
||
{
|
||
name: "Test vTPM attestation failed",
|
||
reportData: generateReportData(),
|
||
nonce: [32]byte{},
|
||
rawQuote: nil,
|
||
platform: attestation.VTPM,
|
||
err: ErrAttestationVTpmFailed,
|
||
},
|
||
{
|
||
name: "Test SNP-vTPM attestation successful",
|
||
reportData: generateReportData(),
|
||
nonce: [32]byte{},
|
||
rawQuote: make([]uint8, 0),
|
||
platform: attestation.SNPvTPM,
|
||
err: nil,
|
||
},
|
||
{
|
||
name: "Test SNP-vTPM attestation failed",
|
||
reportData: generateReportData(),
|
||
nonce: [32]byte{},
|
||
rawQuote: nil,
|
||
platform: attestation.SNPvTPM,
|
||
err: ErrAttestationVTpmFailed,
|
||
},
|
||
}
|
||
for _, tc := range cases {
|
||
t.Run(tc.name, func(t *testing.T) {
|
||
events := new(mocks.Service)
|
||
events.EXPECT().SendEvent(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return()
|
||
|
||
ctx := metadata.NewIncomingContext(context.Background(),
|
||
metadata.Pairs(algorithm.AlgoTypeKey, "python", python.PyRuntimeKey, python.PyRuntime),
|
||
)
|
||
ctx, cancel := context.WithCancel(ctx)
|
||
defer cancel()
|
||
|
||
getQuote := client.On("GetAttestation", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(tc.rawQuote, tc.err)
|
||
if tc.err != ErrAttestationFailed && tc.err != ErrAttestationVTpmFailed {
|
||
getQuote = client.On("GetAttestation", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(tc.nonce[:], nil)
|
||
}
|
||
defer getQuote.Unset()
|
||
|
||
runnerCli := new(runnermocks.Client)
|
||
svc := New(ctx, mglog.NewMock(), events, client, runnerCli, 0)
|
||
time.Sleep(300 * time.Millisecond)
|
||
_, err := svc.Attestation(ctx, tc.reportData, tc.nonce, tc.platform)
|
||
assert.True(t, errors.Contains(err, tc.err), "expected %v, got %v", tc.err, err)
|
||
})
|
||
}
|
||
}
|
||
|
||
func TestAzureAttestationToken(t *testing.T) {
|
||
client := new(MockAttestationClient)
|
||
cases := []struct {
|
||
name string
|
||
nonce [vtpm.Nonce]byte
|
||
token []byte
|
||
err error
|
||
}{
|
||
{
|
||
name: "Azure token fetch successful",
|
||
nonce: [32]byte{1, 2, 3}, // any test nonce
|
||
token: []byte("mockToken"),
|
||
err: nil,
|
||
},
|
||
{
|
||
name: "Azure token fetch failed",
|
||
nonce: [32]byte{4, 5, 6},
|
||
token: []byte{},
|
||
err: ErrAttestationType,
|
||
},
|
||
}
|
||
|
||
for _, tc := range cases {
|
||
t.Run(tc.name, func(t *testing.T) {
|
||
events := new(mocks.Service)
|
||
events.EXPECT().SendEvent(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return()
|
||
|
||
client.On("GetAzureToken", mock.Anything, tc.nonce).Return(tc.token, tc.err)
|
||
|
||
ctx := context.Background()
|
||
|
||
runnerCli := new(runnermocks.Client)
|
||
svc := New(ctx, mglog.NewMock(), events, client, runnerCli, 0)
|
||
|
||
_, err := svc.AzureAttestationToken(ctx, tc.nonce)
|
||
assert.True(t, errors.Contains(err, tc.err), "expected error %v, got %v", tc.err, err)
|
||
})
|
||
}
|
||
}
|
||
|
||
func generateReportData() [vtpm.SEVNonce]byte {
|
||
bytes := make([]byte, vtpm.SEVNonce)
|
||
_, err := rand.Read(bytes)
|
||
if err != nil {
|
||
log.Fatalf("Failed to generate random bytes: %v", err)
|
||
}
|
||
return [64]byte(bytes)
|
||
}
|
||
|
||
func testComputation(t *testing.T) Computation {
|
||
algo, err := os.ReadFile(algoPath)
|
||
require.NoError(t, err)
|
||
|
||
algoHash := sha3.Sum256(algo)
|
||
|
||
data, err := os.ReadFile(dataPath)
|
||
require.NoError(t, err)
|
||
|
||
dataHash := sha3.Sum256(data)
|
||
|
||
return Computation{
|
||
ID: "1",
|
||
Name: "sample computation",
|
||
Description: "sample description",
|
||
Datasets: []Dataset{{Hash: dataHash, UserKey: []byte("key"), Dataset: data, Filename: datasetFile}},
|
||
Algorithm: Algorithm{Hash: algoHash, UserKey: []byte("key"), Algorithm: algo},
|
||
ResultConsumers: []ResultConsumer{{UserKey: []byte("key")}},
|
||
}
|
||
}
|
||
|
||
func TestStopComputation(t *testing.T) {
|
||
cases := []struct {
|
||
name string
|
||
setupDirs bool
|
||
setupAlgo bool
|
||
algoStopErr error
|
||
expectedErr error
|
||
}{
|
||
{
|
||
name: "Stop computation successfully",
|
||
setupDirs: true,
|
||
setupAlgo: true,
|
||
algoStopErr: nil,
|
||
expectedErr: nil,
|
||
},
|
||
{
|
||
name: "Stop computation with algorithm stop error",
|
||
setupDirs: true,
|
||
setupAlgo: true,
|
||
algoStopErr: fmt.Errorf("algorithm stop failed"),
|
||
expectedErr: nil, // Warn only
|
||
},
|
||
// We log warnings but don't return error in StopComputation in new implementation for Stop failure.
|
||
}
|
||
|
||
for _, tc := range cases {
|
||
t.Run(tc.name, func(t *testing.T) {
|
||
events := new(mocks.Service)
|
||
events.On("SendEvent", mock.Anything, "Stopped", "Stopped", mock.Anything).Return()
|
||
|
||
ctx := context.Background()
|
||
ctx, cancel := context.WithCancel(ctx)
|
||
defer cancel()
|
||
|
||
client := new(MockAttestationClient)
|
||
runnerCli := new(runnermocks.Client)
|
||
|
||
// Mock Stop call
|
||
var stopErr error
|
||
if tc.algoStopErr != nil {
|
||
stopErr = tc.algoStopErr
|
||
}
|
||
runnerCli.On("Stop", mock.Anything, mock.Anything).Return(&emptypb.Empty{}, stopErr)
|
||
|
||
svc := New(ctx, mglog.NewMock(), events, client, runnerCli, 0).(*agentService)
|
||
|
||
svc.computation = Computation{
|
||
ID: "test-computation",
|
||
Name: "test",
|
||
}
|
||
|
||
if tc.setupDirs {
|
||
err := os.MkdirAll(algorithm.DatasetsDir, 0o755)
|
||
require.NoError(t, err)
|
||
err = os.MkdirAll(algorithm.ResultsDir, 0o755)
|
||
require.NoError(t, err)
|
||
}
|
||
|
||
// Use real dirs for test
|
||
// algorithm.DatasetsDir refers to global var?
|
||
// "github.com/ultravioletrs/cocos/agent/algorithm"
|
||
// It uses hardcoded path "datasets" and "results" in current dir.
|
||
// Tests create them in current dir.
|
||
|
||
err := svc.StopComputation(ctx)
|
||
|
||
if tc.expectedErr != nil {
|
||
assert.Error(t, err)
|
||
assert.Contains(t, err.Error(), tc.expectedErr.Error())
|
||
} else {
|
||
assert.NoError(t, err)
|
||
}
|
||
|
||
assert.Equal(t, ReceivingManifest, svc.sm.GetState())
|
||
assert.Nil(t, svc.result)
|
||
assert.Nil(t, svc.runError)
|
||
assert.False(t, svc.resultsConsumed)
|
||
|
||
events.AssertExpectations(t)
|
||
|
||
_ = os.RemoveAll(algorithm.DatasetsDir)
|
||
_ = os.RemoveAll(algorithm.ResultsDir)
|
||
})
|
||
}
|
||
}
|
||
|
||
func TestStopComputationIntegration(t *testing.T) {
|
||
if testing.Short() {
|
||
t.Skip("Skipping integration test in short mode")
|
||
}
|
||
|
||
algo := []byte("#!/bin/bash\necho 'test algorithm'")
|
||
algoHash := sha3.Sum256(algo)
|
||
|
||
testDir := "test_integration"
|
||
err := os.MkdirAll(testDir, 0o755)
|
||
require.NoError(t, err)
|
||
defer os.RemoveAll(testDir)
|
||
|
||
algoFile := filepath.Join(testDir, "test_algo")
|
||
err = os.WriteFile(algoFile, algo, 0o755)
|
||
require.NoError(t, err)
|
||
|
||
events := new(mocks.Service)
|
||
events.On("SendEvent", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return()
|
||
|
||
ctx := metadata.NewIncomingContext(context.Background(),
|
||
metadata.Pairs(algorithm.AlgoTypeKey, "bin"),
|
||
)
|
||
ctx, cancel := context.WithCancel(ctx)
|
||
defer cancel()
|
||
|
||
client := new(MockAttestationClient)
|
||
runnerCli := new(runnermocks.Client)
|
||
runnerCli.On("Run", mock.Anything, mock.Anything).Return(&runnerpb.RunResponse{}, nil)
|
||
runnerCli.On("Stop", mock.Anything, mock.Anything).Return(&emptypb.Empty{}, nil)
|
||
|
||
svc := New(ctx, mglog.NewMock(), events, client, runnerCli, 0)
|
||
|
||
computation := Computation{
|
||
ID: "integration-test",
|
||
Name: "Integration Test",
|
||
Algorithm: Algorithm{
|
||
Hash: algoHash,
|
||
Algorithm: algo,
|
||
},
|
||
}
|
||
|
||
err = svc.InitComputation(ctx, computation)
|
||
require.NoError(t, err)
|
||
|
||
time.Sleep(100 * time.Millisecond)
|
||
|
||
err = svc.Algo(ctx, Algorithm{
|
||
Hash: algoHash,
|
||
Algorithm: algo,
|
||
})
|
||
require.NoError(t, err)
|
||
|
||
time.Sleep(100 * time.Millisecond)
|
||
|
||
err = svc.StopComputation(ctx)
|
||
assert.NoError(t, err)
|
||
|
||
assert.Equal(t, "ReceivingManifest", svc.State())
|
||
}
|
||
|
||
func TestStopComputationConcurrent(t *testing.T) {
|
||
events := new(mocks.Service)
|
||
events.On("SendEvent", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return()
|
||
|
||
ctx := context.Background()
|
||
ctx, cancel := context.WithCancel(ctx)
|
||
defer cancel()
|
||
|
||
client := new(MockAttestationClient)
|
||
runnerCli := new(runnermocks.Client)
|
||
runnerCli.On("Stop", mock.Anything, mock.Anything).Return(&emptypb.Empty{}, nil)
|
||
|
||
svc := New(ctx, mglog.NewMock(), events, client, runnerCli, 0)
|
||
|
||
svc.(*agentService).computation = Computation{
|
||
ID: "concurrent-test",
|
||
Name: "Concurrent Test",
|
||
}
|
||
|
||
const numGoroutines = 10
|
||
errChan := make(chan error, numGoroutines)
|
||
|
||
for i := 0; i < numGoroutines; i++ {
|
||
go func() {
|
||
err := svc.StopComputation(ctx)
|
||
errChan <- err
|
||
}()
|
||
}
|
||
|
||
var errors []error
|
||
for i := 0; i < numGoroutines; i++ {
|
||
err := <-errChan
|
||
if err != nil {
|
||
errors = append(errors, err)
|
||
}
|
||
}
|
||
|
||
assert.True(t, len(errors) < numGoroutines, "All StopComputation calls failed")
|
||
}
|
||
|
||
// newTestAgentService creates a minimal agentService for direct method testing.
|
||
func newTestAgentService(sm statemachine.StateMachine, eventSvc agentevents.Service) *agentService {
|
||
return &agentService{
|
||
logger: slog.Default(),
|
||
eventSvc: eventSvc,
|
||
sm: sm,
|
||
}
|
||
}
|
||
|
||
func TestDownloadAndDecryptResource(t *testing.T) {
|
||
eventsSvc := new(mocks.Service)
|
||
eventsSvc.EXPECT().SendEvent(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return().Maybe()
|
||
|
||
sm := &smmocks.StateMachine{}
|
||
sm.On("SendEvent", mock.Anything).Return().Maybe()
|
||
|
||
svc := newTestAgentService(sm, eventsSvc)
|
||
|
||
ctx := context.Background()
|
||
|
||
t.Run("unsupported URL format no type", func(t *testing.T) {
|
||
source := &ResourceSource{URL: "http://unsupported-format"}
|
||
_, err := svc.downloadAndDecryptResource(ctx, source, "algorithm")
|
||
require.Error(t, err)
|
||
assert.Contains(t, err.Error(), "unsupported source URL format")
|
||
})
|
||
|
||
t.Run("ftp URL unsupported format", func(t *testing.T) {
|
||
source := &ResourceSource{URL: "ftp://some-server/file"}
|
||
_, err := svc.downloadAndDecryptResource(ctx, source, "algorithm")
|
||
require.Error(t, err)
|
||
assert.Contains(t, err.Error(), "unsupported source URL format")
|
||
})
|
||
|
||
t.Run("unsupported explicit source type", func(t *testing.T) {
|
||
source := &ResourceSource{Type: "s3-bucket", URL: "s3://mybucket/algo"}
|
||
_, err := svc.downloadAndDecryptResource(ctx, source, "algorithm")
|
||
require.Error(t, err)
|
||
assert.Contains(t, err.Error(), "unsupported source type: s3-bucket")
|
||
})
|
||
|
||
t.Run("docker:// URL inferred as oci-image routes to skopeo", func(t *testing.T) {
|
||
// This exercises the oci-image path; will fail at skopeo step
|
||
source := &ResourceSource{URL: "docker://invalid.example.com/algo:latest"}
|
||
_, err := svc.downloadAndDecryptResource(ctx, source, "algorithm")
|
||
require.Error(t, err)
|
||
// Should be a skopeo or OCI error, not an "unsupported" error
|
||
assert.NotContains(t, err.Error(), "unsupported source URL format")
|
||
})
|
||
|
||
t.Run("oci: URL inferred as oci-image routes to skopeo", func(t *testing.T) {
|
||
source := &ResourceSource{URL: "oci:some-local-dir"}
|
||
_, err := svc.downloadAndDecryptResource(ctx, source, "algorithm")
|
||
require.Error(t, err)
|
||
assert.NotContains(t, err.Error(), "unsupported source URL format")
|
||
})
|
||
|
||
t.Run("explicit oci-image type routes to skopeo", func(t *testing.T) {
|
||
source := &ResourceSource{Type: "oci-image", URL: "docker://invalid.example.com/algo:latest"}
|
||
_, err := svc.downloadAndDecryptResource(ctx, source, "algorithm")
|
||
require.Error(t, err)
|
||
assert.NotContains(t, err.Error(), "unsupported source type")
|
||
})
|
||
|
||
t.Run("dataset resource type with oci-image", func(t *testing.T) {
|
||
source := &ResourceSource{Type: "oci-image", URL: "docker://invalid.example.com/data:latest"}
|
||
_, err := svc.downloadAndDecryptResource(ctx, source, "dataset")
|
||
require.Error(t, err)
|
||
})
|
||
}
|
||
|
||
func TestDownloadAlgorithmIfRemote(t *testing.T) {
|
||
t.Run("no source configured - no-op, waits for direct upload", func(t *testing.T) {
|
||
eventsSvc := new(mocks.Service)
|
||
eventsSvc.EXPECT().SendEvent(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return().Maybe()
|
||
|
||
sm := &smmocks.StateMachine{}
|
||
// No SendEvent expected — just the no-op path
|
||
|
||
svc := newTestAgentService(sm, eventsSvc)
|
||
svc.computation = Computation{} // Algorithm.Source == nil
|
||
|
||
svc.downloadAlgorithmIfRemote(ReceivingAlgorithm)
|
||
assert.Nil(t, svc.runError)
|
||
sm.AssertExpectations(t)
|
||
})
|
||
|
||
t.Run("source set but KBS disabled - no-op", func(t *testing.T) {
|
||
eventsSvc := new(mocks.Service)
|
||
eventsSvc.EXPECT().SendEvent(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return().Maybe()
|
||
|
||
sm := &smmocks.StateMachine{}
|
||
|
||
svc := newTestAgentService(sm, eventsSvc)
|
||
svc.computation = Computation{
|
||
Algorithm: Algorithm{
|
||
Source: &ResourceSource{URL: "docker://registry/algo:latest"},
|
||
},
|
||
KBS: KBSConfig{Enabled: false},
|
||
}
|
||
|
||
svc.downloadAlgorithmIfRemote(ReceivingAlgorithm)
|
||
assert.Nil(t, svc.runError)
|
||
sm.AssertExpectations(t)
|
||
})
|
||
|
||
t.Run("source + KBS enabled - download fails, sends RunFailed", func(t *testing.T) {
|
||
eventsSvc := new(mocks.Service)
|
||
eventsSvc.EXPECT().SendEvent(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return().Maybe()
|
||
|
||
sm := &smmocks.StateMachine{}
|
||
sm.On("SendEvent", RunFailed).Return().Once()
|
||
|
||
svc := newTestAgentService(sm, eventsSvc)
|
||
svc.computation = Computation{
|
||
Algorithm: Algorithm{
|
||
Source: &ResourceSource{
|
||
Type: "oci-image",
|
||
URL: "docker://invalid.example.com/algo:latest",
|
||
},
|
||
},
|
||
KBS: KBSConfig{Enabled: true, URL: "https://kbs.example.com"},
|
||
}
|
||
|
||
svc.downloadAlgorithmIfRemote(ReceivingAlgorithm)
|
||
assert.NotNil(t, svc.runError)
|
||
assert.Contains(t, svc.runError.Error(), "failed to download and decrypt algorithm")
|
||
sm.AssertExpectations(t)
|
||
})
|
||
|
||
t.Run("unsupported URL format - download fails, sends RunFailed", func(t *testing.T) {
|
||
eventsSvc := new(mocks.Service)
|
||
eventsSvc.EXPECT().SendEvent(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return().Maybe()
|
||
|
||
sm := &smmocks.StateMachine{}
|
||
sm.On("SendEvent", RunFailed).Return().Once()
|
||
|
||
svc := newTestAgentService(sm, eventsSvc)
|
||
svc.computation = Computation{
|
||
Algorithm: Algorithm{
|
||
Source: &ResourceSource{
|
||
URL: "http://unsupported-format/algo",
|
||
},
|
||
},
|
||
KBS: KBSConfig{Enabled: true},
|
||
}
|
||
|
||
svc.downloadAlgorithmIfRemote(ReceivingAlgorithm)
|
||
assert.NotNil(t, svc.runError)
|
||
sm.AssertExpectations(t)
|
||
})
|
||
}
|
||
|
||
func TestDownloadDatasetsIfRemote(t *testing.T) {
|
||
t.Run("no datasets with remote sources - no-op", func(t *testing.T) {
|
||
eventsSvc := new(mocks.Service)
|
||
eventsSvc.EXPECT().SendEvent(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return().Maybe()
|
||
|
||
sm := &smmocks.StateMachine{}
|
||
|
||
svc := newTestAgentService(sm, eventsSvc)
|
||
// Dataset with no Source
|
||
dataHash := sha3.Sum256([]byte("testdata"))
|
||
svc.computation = Computation{
|
||
Datasets: []Dataset{
|
||
{Hash: dataHash, Filename: "data.csv"},
|
||
},
|
||
KBS: KBSConfig{Enabled: true},
|
||
}
|
||
|
||
svc.downloadDatasetsIfRemote(ReceivingData)
|
||
// No RunFailed event, no DataReceived event
|
||
sm.AssertExpectations(t)
|
||
})
|
||
|
||
t.Run("no datasets at all - no-op", func(t *testing.T) {
|
||
eventsSvc := new(mocks.Service)
|
||
eventsSvc.EXPECT().SendEvent(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return().Maybe()
|
||
|
||
sm := &smmocks.StateMachine{}
|
||
|
||
svc := newTestAgentService(sm, eventsSvc)
|
||
svc.computation = Computation{
|
||
Datasets: []Dataset{},
|
||
KBS: KBSConfig{Enabled: true},
|
||
}
|
||
|
||
svc.downloadDatasetsIfRemote(ReceivingData)
|
||
sm.AssertExpectations(t)
|
||
})
|
||
|
||
t.Run("KBS disabled even with source - no-op", func(t *testing.T) {
|
||
eventsSvc := new(mocks.Service)
|
||
eventsSvc.EXPECT().SendEvent(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return().Maybe()
|
||
|
||
sm := &smmocks.StateMachine{}
|
||
|
||
svc := newTestAgentService(sm, eventsSvc)
|
||
svc.computation = Computation{
|
||
Datasets: []Dataset{
|
||
{
|
||
Filename: "data.csv",
|
||
Source: &ResourceSource{URL: "docker://registry/data:latest"},
|
||
},
|
||
},
|
||
KBS: KBSConfig{Enabled: false},
|
||
}
|
||
|
||
svc.downloadDatasetsIfRemote(ReceivingData)
|
||
sm.AssertExpectations(t)
|
||
})
|
||
|
||
t.Run("remote dataset + KBS enabled - download fails, sends RunFailed", func(t *testing.T) {
|
||
eventsSvc := new(mocks.Service)
|
||
eventsSvc.EXPECT().SendEvent(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return().Maybe()
|
||
|
||
sm := &smmocks.StateMachine{}
|
||
sm.On("SendEvent", RunFailed).Return().Once()
|
||
|
||
svc := newTestAgentService(sm, eventsSvc)
|
||
svc.computation = Computation{
|
||
Datasets: []Dataset{
|
||
{
|
||
Filename: "data.csv",
|
||
Source: &ResourceSource{
|
||
Type: "oci-image",
|
||
URL: "docker://invalid.example.com/data:latest",
|
||
},
|
||
},
|
||
},
|
||
KBS: KBSConfig{Enabled: true, URL: "https://kbs.example.com"},
|
||
}
|
||
|
||
svc.downloadDatasetsIfRemote(ReceivingData)
|
||
sm.AssertExpectations(t)
|
||
})
|
||
|
||
t.Run("unsupported URL fails - sends RunFailed", func(t *testing.T) {
|
||
eventsSvc := new(mocks.Service)
|
||
eventsSvc.EXPECT().SendEvent(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return().Maybe()
|
||
|
||
sm := &smmocks.StateMachine{}
|
||
sm.On("SendEvent", RunFailed).Return().Once()
|
||
|
||
svc := newTestAgentService(sm, eventsSvc)
|
||
svc.computation = Computation{
|
||
Datasets: []Dataset{
|
||
{
|
||
Filename: "data.csv",
|
||
Source: &ResourceSource{
|
||
URL: "ftp://unsupported/data",
|
||
},
|
||
},
|
||
},
|
||
KBS: KBSConfig{Enabled: true},
|
||
}
|
||
|
||
svc.downloadDatasetsIfRemote(ReceivingData)
|
||
sm.AssertExpectations(t)
|
||
})
|
||
}
|
||
|
||
func TestRunComputation(t *testing.T) {
|
||
// Helper to set up a temp working directory and restore CWD afterwards.
|
||
withTempDir := func(t *testing.T) (tmpDir string, restore func()) {
|
||
t.Helper()
|
||
origDir, err := os.Getwd()
|
||
require.NoError(t, err)
|
||
tmpDir = t.TempDir()
|
||
require.NoError(t, os.Chdir(tmpDir))
|
||
return tmpDir, func() { _ = os.Chdir(origDir) }
|
||
}
|
||
|
||
t.Run("algo file not found sends RunFailed", func(t *testing.T) {
|
||
_, restore := withTempDir(t)
|
||
defer restore()
|
||
|
||
eventsSvc := new(mocks.Service)
|
||
eventsSvc.EXPECT().SendEvent(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return().Maybe()
|
||
|
||
sm := &smmocks.StateMachine{}
|
||
sm.On("SendEvent", RunFailed).Return().Once()
|
||
|
||
svc := newTestAgentService(sm, eventsSvc)
|
||
// No algo file exists – runComputation should hit the ReadFile error path.
|
||
svc.runComputation(Running)
|
||
|
||
assert.Error(t, svc.runError)
|
||
assert.Contains(t, svc.runError.Error(), "failed to read algo file")
|
||
sm.AssertExpectations(t)
|
||
})
|
||
|
||
t.Run("runner client returns error sends RunFailed", func(t *testing.T) {
|
||
_, restore := withTempDir(t)
|
||
defer restore()
|
||
|
||
// Write a dummy algo file so ReadFile succeeds.
|
||
require.NoError(t, os.WriteFile("algo", []byte("#!/bin/sh\necho ok\n"), 0o755))
|
||
|
||
runnerCli := new(runnermocks.Client)
|
||
runnerCli.On("Run", mock.Anything, mock.Anything).Return((*runnerpb.RunResponse)(nil), fmt.Errorf("runner unavailable"))
|
||
|
||
eventsSvc := new(mocks.Service)
|
||
eventsSvc.EXPECT().SendEvent(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return().Maybe()
|
||
|
||
sm := &smmocks.StateMachine{}
|
||
sm.On("SendEvent", RunFailed).Return().Once()
|
||
|
||
svc := newTestAgentService(sm, eventsSvc)
|
||
svc.runnerClient = runnerCli
|
||
|
||
svc.runComputation(Running)
|
||
|
||
assert.Error(t, svc.runError)
|
||
assert.Contains(t, svc.runError.Error(), "runner unavailable")
|
||
sm.AssertExpectations(t)
|
||
})
|
||
|
||
t.Run("runner returns non-empty error field sends RunFailed", func(t *testing.T) {
|
||
_, restore := withTempDir(t)
|
||
defer restore()
|
||
|
||
require.NoError(t, os.WriteFile("algo", []byte("#!/bin/sh\necho ok\n"), 0o755))
|
||
|
||
runnerCli := new(runnermocks.Client)
|
||
runnerCli.On("Run", mock.Anything, mock.Anything).Return(&runnerpb.RunResponse{Error: "computation crashed"}, nil)
|
||
|
||
eventsSvc := new(mocks.Service)
|
||
eventsSvc.EXPECT().SendEvent(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return().Maybe()
|
||
|
||
sm := &smmocks.StateMachine{}
|
||
sm.On("SendEvent", RunFailed).Return().Once()
|
||
|
||
svc := newTestAgentService(sm, eventsSvc)
|
||
svc.runnerClient = runnerCli
|
||
|
||
svc.runComputation(Running)
|
||
|
||
assert.Error(t, svc.runError)
|
||
assert.Contains(t, svc.runError.Error(), "computation crashed")
|
||
sm.AssertExpectations(t)
|
||
})
|
||
}
|
||
|
||
func TestIMAMeasurements(t *testing.T) {
|
||
t.Run("error when IMA measurements file does not exist in non-SGX environment", func(t *testing.T) {
|
||
// In a regular test environment (non-SGX), the IMA measurements file
|
||
// at /sys/kernel/security/integrity/ima/ascii_runtime_measurements won't exist.
|
||
// Verify our error handling works correctly.
|
||
origPath := ImaMeasurementsFilePath
|
||
ImaMeasurementsFilePath = "/non/existent/path"
|
||
defer func() { ImaMeasurementsFilePath = origPath }()
|
||
|
||
eventsSvc := new(mocks.Service)
|
||
eventsSvc.On("SendEvent", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return().Maybe()
|
||
sm := &smmocks.StateMachine{}
|
||
|
||
svc := newTestAgentService(sm, eventsSvc)
|
||
|
||
data, pcr10, err := svc.IMAMeasurements(context.Background())
|
||
assert.Error(t, err)
|
||
assert.Contains(t, err.Error(), "error reading Linux IMA measurements file")
|
||
assert.Nil(t, data)
|
||
assert.Nil(t, pcr10)
|
||
})
|
||
|
||
t.Run("successful reading of IMA measurements", func(t *testing.T) {
|
||
tempFile := filepath.Join(t.TempDir(), "ima_measurements")
|
||
content := []byte("10 sha1:0000000000000000000000000000000000000000 ima-ng sha256:0000000000000000000000000000000000000000000000000000000000000000 /usr/bin/python3\n")
|
||
err := os.WriteFile(tempFile, content, 0o644)
|
||
require.NoError(t, err)
|
||
vtpm.ExternalTPM = &vtpm.DummyRWC{}
|
||
|
||
origPath := ImaMeasurementsFilePath
|
||
ImaMeasurementsFilePath = tempFile
|
||
defer func() { ImaMeasurementsFilePath = origPath }()
|
||
|
||
eventsSvc := new(mocks.Service)
|
||
eventsSvc.On("SendEvent", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return().Maybe()
|
||
sm := &smmocks.StateMachine{}
|
||
svc := newTestAgentService(sm, eventsSvc)
|
||
|
||
data, pcr10, err := svc.IMAMeasurements(context.Background())
|
||
assert.NoError(t, err)
|
||
assert.Equal(t, content, data)
|
||
assert.NotEmpty(t, pcr10)
|
||
})
|
||
}
|
||
|
||
func TestDownloadAlgorithmIfRemote_Success(t *testing.T) {
|
||
// Skip this test in short mode as it might involve more setup if we were using real OCI
|
||
if testing.Short() {
|
||
t.Skip("skipping in short mode")
|
||
}
|
||
|
||
origDir, _ := os.Getwd()
|
||
tmpDir := t.TempDir()
|
||
require.NoError(t, os.Chdir(tmpDir))
|
||
defer func() { require.NoError(t, os.Chdir(origDir)) }()
|
||
|
||
eventsSvc := new(mocks.Service)
|
||
eventsSvc.On("SendEvent", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return().Maybe()
|
||
sm := &smmocks.StateMachine{}
|
||
sm.On("SendEvent", AlgorithmReceived).Return().Once()
|
||
|
||
mockOCI := new(MockOCIClient)
|
||
algoContent := []byte("print('hello')")
|
||
mockOCI.On("PullAndDecrypt", mock.Anything, mock.Anything, mock.Anything).Run(func(args mock.Arguments) {
|
||
destDir := args.String(2)
|
||
setupMinimalOCI(t, destDir, "main.py", algoContent)
|
||
}).Return(nil)
|
||
|
||
svc := newTestAgentService(sm, eventsSvc)
|
||
svc.ociClient = mockOCI
|
||
|
||
algoContent = []byte("print('hello')")
|
||
algoHash := sha3.Sum256(algoContent)
|
||
|
||
svc.computation = Computation{
|
||
Algorithm: Algorithm{
|
||
Hash: algoHash,
|
||
AlgoType: "python",
|
||
Source: &ResourceSource{
|
||
Type: "oci-image",
|
||
URL: "docker://test/algo-success",
|
||
},
|
||
},
|
||
KBS: KBSConfig{Enabled: true},
|
||
}
|
||
|
||
// We need to bypass oci.ExtractAlgorithm by manually creating what it would create
|
||
// OR use a real-enough looking OCI layout.
|
||
// Since we can't easily mock oci.ExtractAlgorithm, we'll try to provide a minimal OCI layout
|
||
// so that oci.ExtractAlgorithm doesn't fail.
|
||
|
||
svc.downloadAlgorithmIfRemote(ReceivingAlgorithm)
|
||
|
||
assert.Nil(t, svc.runError)
|
||
assert.True(t, svc.algoReceived)
|
||
sm.AssertExpectations(t)
|
||
mockOCI.AssertExpectations(t)
|
||
}
|
||
|
||
func TestDownloadAlgorithmIfRemote_Docker_Success(t *testing.T) {
|
||
if testing.Short() {
|
||
t.Skip("skipping in short mode")
|
||
}
|
||
|
||
origDir, _ := os.Getwd()
|
||
tmpDir := t.TempDir()
|
||
require.NoError(t, os.Chdir(tmpDir))
|
||
defer func() { require.NoError(t, os.Chdir(origDir)) }()
|
||
|
||
eventsSvc := new(mocks.Service)
|
||
eventsSvc.On("SendEvent", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return().Maybe()
|
||
sm := &smmocks.StateMachine{}
|
||
sm.On("SendEvent", AlgorithmReceived).Return().Once()
|
||
|
||
mockOCI := new(MockOCIClient)
|
||
mockOCI.On("PullAndDecrypt", mock.Anything, mock.Anything, mock.Anything).Return(nil)
|
||
|
||
dummyContent := []byte("dummy docker tar")
|
||
dummyHash := sha3.Sum256(dummyContent)
|
||
|
||
mockOCI.On("ToDockerArchive", mock.Anything, mock.Anything, mock.Anything).Run(func(args mock.Arguments) {
|
||
destFile := args.String(2)
|
||
err := os.WriteFile(destFile, dummyContent, 0o644)
|
||
require.NoError(t, err)
|
||
}).Return(nil)
|
||
|
||
svc := newTestAgentService(sm, eventsSvc)
|
||
svc.ociClient = mockOCI
|
||
|
||
svc.computation = Computation{
|
||
Algorithm: Algorithm{
|
||
AlgoType: "docker",
|
||
Hash: dummyHash,
|
||
Source: &ResourceSource{
|
||
Type: "oci-image",
|
||
URL: "docker://test/algo-docker-success",
|
||
},
|
||
},
|
||
KBS: KBSConfig{Enabled: true},
|
||
}
|
||
|
||
svc.downloadAlgorithmIfRemote(ReceivingAlgorithm)
|
||
|
||
assert.Nil(t, svc.runError)
|
||
assert.True(t, svc.algoReceived)
|
||
sm.AssertExpectations(t)
|
||
mockOCI.AssertExpectations(t)
|
||
}
|
||
|
||
func setupMinimalOCI(t *testing.T, ociDir, filename string, content []byte) {
|
||
t.Helper()
|
||
blobsDir := filepath.Join(ociDir, "blobs", "sha256")
|
||
require.NoError(t, os.MkdirAll(blobsDir, 0o755))
|
||
|
||
layerPath := filepath.Join(blobsDir, "layer123")
|
||
layerFile, err := os.Create(layerPath)
|
||
require.NoError(t, err)
|
||
|
||
gw := gzip.NewWriter(layerFile)
|
||
tw := tar.NewWriter(gw)
|
||
|
||
hdr := &tar.Header{
|
||
Name: filename,
|
||
Mode: 0o755,
|
||
Size: int64(len(content)),
|
||
}
|
||
require.NoError(t, tw.WriteHeader(hdr))
|
||
_, err = tw.Write(content)
|
||
|
||
require.NoError(t, err)
|
||
|
||
require.NoError(t, tw.Close())
|
||
require.NoError(t, gw.Close())
|
||
require.NoError(t, layerFile.Close())
|
||
|
||
manifest := struct {
|
||
Layers []struct {
|
||
Digest string `json:"digest"`
|
||
} `json:"layers"`
|
||
}{
|
||
Layers: []struct {
|
||
Digest string `json:"digest"`
|
||
}{{Digest: "sha256:layer123"}},
|
||
}
|
||
manifestData, err := json.Marshal(manifest)
|
||
require.NoError(t, err)
|
||
require.NoError(t, os.WriteFile(filepath.Join(blobsDir, "manifest123"), manifestData, 0o644))
|
||
|
||
index := oci.OCIIndex{
|
||
SchemaVersion: 2,
|
||
Manifests: []struct {
|
||
MediaType string `json:"mediaType"`
|
||
Digest string `json:"digest"`
|
||
Size int `json:"size"`
|
||
}{{Digest: "sha256:manifest123", Size: len(manifestData)}},
|
||
}
|
||
indexData, err := json.Marshal(index)
|
||
require.NoError(t, err)
|
||
require.NoError(t, os.WriteFile(filepath.Join(ociDir, "index.json"), indexData, 0o644))
|
||
}
|
||
|
||
func TestDownloadDatasetsIfRemote_Success(t *testing.T) {
|
||
if testing.Short() {
|
||
t.Skip("skipping in short mode")
|
||
}
|
||
|
||
origDir, _ := os.Getwd()
|
||
tmpDir := t.TempDir()
|
||
require.NoError(t, os.Chdir(tmpDir))
|
||
defer func() { require.NoError(t, os.Chdir(origDir)) }()
|
||
|
||
eventsSvc := new(mocks.Service)
|
||
eventsSvc.On("SendEvent", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return().Maybe()
|
||
sm := &smmocks.StateMachine{}
|
||
sm.On("SendEvent", DataReceived).Return().Once()
|
||
|
||
mockOCI := new(MockOCIClient)
|
||
dataContent := []byte("a,b,c\n1,2,3")
|
||
mockOCI.On("PullAndDecrypt", mock.Anything, mock.Anything, mock.Anything).Run(func(args mock.Arguments) {
|
||
destDir := args.String(2)
|
||
setupMinimalOCI(t, destDir, "data.csv", dataContent)
|
||
}).Return(nil)
|
||
|
||
svc := newTestAgentService(sm, eventsSvc)
|
||
svc.ociClient = mockOCI
|
||
|
||
dataContent = []byte("a,b,c\n1,2,3")
|
||
dataHash := sha3.Sum256(dataContent)
|
||
|
||
svc.computation = Computation{
|
||
Datasets: []Dataset{
|
||
{
|
||
Filename: "data.csv",
|
||
Hash: dataHash,
|
||
Source: &ResourceSource{
|
||
Type: "oci-image",
|
||
URL: "docker://test/data-success",
|
||
},
|
||
},
|
||
},
|
||
KBS: KBSConfig{Enabled: true},
|
||
}
|
||
|
||
err := os.MkdirAll(algorithm.DatasetsDir, 0o755)
|
||
require.NoError(t, err)
|
||
|
||
svc.downloadDatasetsIfRemote(ReceivingData)
|
||
|
||
assert.Nil(t, svc.runError)
|
||
assert.Len(t, svc.computation.Datasets, 0)
|
||
sm.AssertExpectations(t)
|
||
mockOCI.AssertExpectations(t)
|
||
}
|
||
|
||
func TestDownloadDatasetsIfRemote_Decompress(t *testing.T) {
|
||
if testing.Short() {
|
||
t.Skip("skipping in short mode")
|
||
}
|
||
|
||
origDir, _ := os.Getwd()
|
||
tmpDir := t.TempDir()
|
||
require.NoError(t, os.Chdir(tmpDir))
|
||
defer func() { require.NoError(t, os.Chdir(origDir)) }()
|
||
|
||
eventsSvc := new(mocks.Service)
|
||
eventsSvc.On("SendEvent", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return().Maybe()
|
||
sm := &smmocks.StateMachine{}
|
||
sm.On("SendEvent", DataReceived).Return().Maybe()
|
||
sm.On("SendEvent", RunFailed).Return().Maybe()
|
||
|
||
mockOCI := new(MockOCIClient)
|
||
|
||
// Create a zip file in memory
|
||
var buf bytes.Buffer
|
||
zw := zip.NewWriter(&buf)
|
||
f, err := zw.Create("test.txt")
|
||
require.NoError(t, err)
|
||
_, err = f.Write([]byte("hello zip"))
|
||
require.NoError(t, err)
|
||
require.NoError(t, zw.Close())
|
||
zipData := buf.Bytes()
|
||
|
||
mockOCI.On("PullAndDecrypt", mock.Anything, mock.Anything, mock.Anything).Run(func(args mock.Arguments) {
|
||
destDir := args.String(2)
|
||
setupMinimalOCI(t, destDir, "data.zip", zipData)
|
||
}).Return(nil)
|
||
|
||
svc := newTestAgentService(sm, eventsSvc)
|
||
svc.ociClient = mockOCI
|
||
|
||
dataHash := sha3.Sum256(zipData)
|
||
|
||
svc.computation = Computation{
|
||
Datasets: []Dataset{
|
||
{
|
||
Filename: "data.zip",
|
||
Hash: dataHash,
|
||
Decompress: true,
|
||
Source: &ResourceSource{
|
||
Type: "oci-image",
|
||
URL: "docker://test/data-decompress",
|
||
},
|
||
},
|
||
},
|
||
KBS: KBSConfig{Enabled: true},
|
||
}
|
||
|
||
err = os.MkdirAll(algorithm.DatasetsDir, 0o755)
|
||
require.NoError(t, err)
|
||
|
||
svc.downloadDatasetsIfRemote(ReceivingData)
|
||
|
||
assert.Nil(t, svc.runError)
|
||
assert.Len(t, svc.computation.Datasets, 0)
|
||
// Check if file was decompressed
|
||
decompressedFile := filepath.Join(algorithm.DatasetsDir, "test.txt")
|
||
_, err = os.Stat(decompressedFile)
|
||
assert.NoError(t, err)
|
||
|
||
sm.AssertExpectations(t)
|
||
mockOCI.AssertExpectations(t)
|
||
}
|
||
|
||
func TestDownloadAlgorithmIfRemote_ErrorPathsInternal(t *testing.T) {
|
||
origDir, _ := os.Getwd()
|
||
tmpDir := t.TempDir()
|
||
require.NoError(t, os.Chdir(tmpDir))
|
||
defer func() { require.NoError(t, os.Chdir(origDir)) }()
|
||
|
||
eventsSvc := new(mocks.Service)
|
||
eventsSvc.EXPECT().SendEvent(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return().Maybe()
|
||
|
||
t.Run("hash mismatch", func(t *testing.T) {
|
||
sm := &smmocks.StateMachine{}
|
||
sm.On("SendEvent", RunFailed).Return().Once()
|
||
|
||
mockOCI := new(MockOCIClient)
|
||
mockOCI.On("PullAndDecrypt", mock.Anything, mock.Anything, mock.Anything).Run(func(args mock.Arguments) {
|
||
destDir := args.String(2)
|
||
setupMinimalOCI(t, destDir, "main.py", []byte("wrong content"))
|
||
}).Return(nil)
|
||
|
||
svc := newTestAgentService(sm, eventsSvc)
|
||
svc.ociClient = mockOCI
|
||
|
||
svc.computation = Computation{
|
||
Algorithm: Algorithm{
|
||
Hash: sha3.Sum256([]byte("expected content")),
|
||
AlgoType: "python",
|
||
Source: &ResourceSource{
|
||
Type: "oci-image",
|
||
URL: "docker://test/algo-hash-mismatch",
|
||
},
|
||
},
|
||
KBS: KBSConfig{Enabled: true},
|
||
}
|
||
|
||
svc.downloadAlgorithmIfRemote(ReceivingAlgorithm)
|
||
assert.Error(t, svc.runError)
|
||
assert.Contains(t, svc.runError.Error(), "algorithm hash mismatch")
|
||
sm.AssertExpectations(t)
|
||
})
|
||
|
||
t.Run("create algo file failure", func(t *testing.T) {
|
||
sm := &smmocks.StateMachine{}
|
||
sm.On("SendEvent", RunFailed).Return().Once()
|
||
|
||
// Create a directory named "algo" to make file creation fail
|
||
require.NoError(t, os.Mkdir("algo", 0o755))
|
||
defer os.RemoveAll("algo")
|
||
|
||
mockOCI := new(MockOCIClient)
|
||
algoContent := "print(1)"
|
||
mockOCI.On("PullAndDecrypt", mock.Anything, mock.Anything, mock.Anything).Run(func(args mock.Arguments) {
|
||
destDir := args.String(2)
|
||
setupMinimalOCI(t, destDir, "main.py", []byte(algoContent))
|
||
}).Return(nil)
|
||
|
||
svc := newTestAgentService(sm, eventsSvc)
|
||
svc.ociClient = mockOCI
|
||
|
||
svc.computation = Computation{
|
||
Algorithm: Algorithm{
|
||
Hash: sha3.Sum256([]byte(algoContent)),
|
||
AlgoType: "python",
|
||
Source: &ResourceSource{
|
||
Type: "oci-image",
|
||
URL: "docker://test/algo-create-fail",
|
||
},
|
||
},
|
||
KBS: KBSConfig{Enabled: true},
|
||
}
|
||
|
||
svc.downloadAlgorithmIfRemote(ReceivingAlgorithm)
|
||
assert.Error(t, svc.runError)
|
||
assert.Contains(t, svc.runError.Error(), "error creating algorithm file")
|
||
sm.AssertExpectations(t)
|
||
})
|
||
t.Run("extraction failure", func(t *testing.T) {
|
||
sm := &smmocks.StateMachine{}
|
||
sm.On("SendEvent", RunFailed).Return().Once()
|
||
|
||
mockOCI := new(MockOCIClient)
|
||
mockOCI.On("PullAndDecrypt", mock.Anything, mock.Anything, mock.Anything).Run(func(args mock.Arguments) {
|
||
destDir := args.String(2)
|
||
// Setup OCI with NO main.py or any algorithm file
|
||
require.NoError(t, os.MkdirAll(filepath.Join(destDir, "blobs"), 0o755))
|
||
// Create a legit-looking but empty index.json
|
||
require.NoError(t, os.WriteFile(filepath.Join(destDir, "index.json"), []byte(`{"schemaVersion":2,"manifests":[]}`), 0o644))
|
||
}).Return(nil)
|
||
|
||
svc := newTestAgentService(sm, eventsSvc)
|
||
svc.ociClient = mockOCI
|
||
|
||
svc.computation = Computation{
|
||
Algorithm: Algorithm{
|
||
AlgoType: "python",
|
||
Source: &ResourceSource{
|
||
Type: "oci-image",
|
||
URL: "docker://test/image",
|
||
},
|
||
},
|
||
KBS: KBSConfig{Enabled: true},
|
||
}
|
||
|
||
svc.downloadAlgorithmIfRemote(ReceivingAlgorithm)
|
||
assert.Error(t, svc.runError)
|
||
assert.Contains(t, svc.runError.Error(), "no manifests found")
|
||
sm.AssertExpectations(t)
|
||
})
|
||
}
|
||
|
||
func TestDownloadDatasetsIfRemote_ErrorPathsInternal(t *testing.T) {
|
||
origDir, _ := os.Getwd()
|
||
tmpDir := t.TempDir()
|
||
require.NoError(t, os.Chdir(tmpDir))
|
||
defer func() { require.NoError(t, os.Chdir(origDir)) }()
|
||
|
||
// Use a fresh mock in each subtest to avoid state pollution
|
||
|
||
t.Run("dataset create file failure", func(t *testing.T) {
|
||
eventsSvc := mocks.NewService(t)
|
||
eventsSvc.On("SendEvent", mock.Anything, mock.Anything, mock.Anything, mock.MatchedBy(func(json.RawMessage) bool { return true })).Return().Maybe()
|
||
sm := &smmocks.StateMachine{}
|
||
sm.On("SendEvent", RunFailed).Return().Once()
|
||
|
||
// Create a directory named "data.csv" in datasets dir to make file creation fail
|
||
require.NoError(t, os.MkdirAll(filepath.Join(algorithm.DatasetsDir, "data.csv"), 0o755))
|
||
defer os.RemoveAll(algorithm.DatasetsDir)
|
||
|
||
mockOCI := new(MockOCIClient)
|
||
dataContent := "a,b,c"
|
||
mockOCI.On("PullAndDecrypt", mock.Anything, mock.Anything, mock.Anything).Run(func(args mock.Arguments) {
|
||
destDir := args.String(2)
|
||
setupMinimalOCI(t, destDir, "data.csv", []byte(dataContent))
|
||
}).Return(nil)
|
||
|
||
svc := newTestAgentService(sm, eventsSvc)
|
||
svc.ociClient = mockOCI
|
||
|
||
svc.computation = Computation{
|
||
Datasets: []Dataset{
|
||
{
|
||
Filename: "data.csv",
|
||
Hash: sha3.Sum256([]byte(dataContent)),
|
||
Source: &ResourceSource{
|
||
Type: "oci-image",
|
||
URL: "docker://test/data-create-fail",
|
||
},
|
||
},
|
||
},
|
||
KBS: KBSConfig{Enabled: true},
|
||
}
|
||
|
||
svc.downloadDatasetsIfRemote(ReceivingData)
|
||
sm.AssertExpectations(t)
|
||
})
|
||
|
||
t.Run("dataset hash mismatch", func(t *testing.T) {
|
||
eventsSvc := mocks.NewService(t)
|
||
eventsSvc.On("SendEvent", mock.Anything, mock.Anything, mock.Anything, mock.MatchedBy(func(json.RawMessage) bool { return true })).Return().Maybe()
|
||
origDir, _ := os.Getwd()
|
||
tmpDir := t.TempDir()
|
||
require.NoError(t, os.Chdir(tmpDir))
|
||
defer func() { _ = os.Chdir(origDir) }()
|
||
|
||
sm := &smmocks.StateMachine{}
|
||
sm.On("SendEvent", RunFailed).Return().Once()
|
||
|
||
mockOCI := new(MockOCIClient)
|
||
dataContent := "wrong content"
|
||
mockOCI.On("PullAndDecrypt", mock.Anything, mock.Anything, mock.Anything).Run(func(args mock.Arguments) {
|
||
destDir := args.String(2)
|
||
setupMinimalOCI(t, destDir, "data.csv", []byte(dataContent))
|
||
}).Return(nil)
|
||
|
||
svc := newTestAgentService(sm, eventsSvc)
|
||
svc.ociClient = mockOCI
|
||
|
||
svc.computation = Computation{
|
||
Datasets: []Dataset{
|
||
{
|
||
Filename: "data.csv",
|
||
Hash: sha3.Sum256([]byte("expected content")),
|
||
Source: &ResourceSource{
|
||
Type: "oci-image",
|
||
URL: "docker://test/data-mismatch",
|
||
},
|
||
},
|
||
},
|
||
KBS: KBSConfig{Enabled: true},
|
||
}
|
||
|
||
err := os.MkdirAll(algorithm.DatasetsDir, 0o755)
|
||
require.NoError(t, err)
|
||
|
||
svc.downloadDatasetsIfRemote(ReceivingData)
|
||
if svc.runError == nil {
|
||
t.Fatalf("runError should not be nil in hash mismatch test")
|
||
}
|
||
assert.Contains(t, svc.runError.Error(), "dataset data.csv hash mismatch")
|
||
sm.AssertExpectations(t)
|
||
})
|
||
|
||
t.Run("dataset unzip failure", func(t *testing.T) {
|
||
eventsSvc := mocks.NewService(t)
|
||
eventsSvc.On("SendEvent", mock.Anything, mock.Anything, mock.Anything, mock.MatchedBy(func(json.RawMessage) bool { return true })).Return().Maybe()
|
||
origDir, _ := os.Getwd()
|
||
tmpDir := t.TempDir()
|
||
require.NoError(t, os.Chdir(tmpDir))
|
||
defer func() { _ = os.Chdir(origDir) }()
|
||
|
||
sm := &smmocks.StateMachine{}
|
||
sm.On("SendEvent", RunFailed).Return().Once()
|
||
|
||
mockOCI := new(MockOCIClient)
|
||
// Provide invalid zip content
|
||
dataContent := "not a zip file"
|
||
mockOCI.On("PullAndDecrypt", mock.Anything, mock.Anything, mock.Anything).Run(func(args mock.Arguments) {
|
||
destDir := args.String(2)
|
||
setupMinimalOCI(t, destDir, "data.zip", []byte(dataContent))
|
||
}).Return(nil)
|
||
|
||
svc := newTestAgentService(sm, eventsSvc)
|
||
svc.ociClient = mockOCI
|
||
|
||
svc.computation = Computation{
|
||
Datasets: []Dataset{
|
||
{
|
||
Filename: "data.zip",
|
||
Hash: sha3.Sum256([]byte(dataContent)),
|
||
Decompress: true,
|
||
Source: &ResourceSource{
|
||
Type: "oci-image",
|
||
URL: "docker://test/data-unzip-fail",
|
||
},
|
||
},
|
||
},
|
||
KBS: KBSConfig{Enabled: true},
|
||
}
|
||
|
||
err := os.MkdirAll(algorithm.DatasetsDir, 0o755)
|
||
require.NoError(t, err)
|
||
|
||
svc.downloadDatasetsIfRemote(ReceivingData)
|
||
if svc.runError == nil {
|
||
t.Fatalf("runError should not be nil in unzip failure test")
|
||
}
|
||
assert.Contains(t, svc.runError.Error(), "failed to unzip dataset")
|
||
sm.AssertExpectations(t)
|
||
})
|
||
}
|
||
|
||
func TestAlgo_RemoteSource(t *testing.T) {
|
||
if testing.Short() {
|
||
t.Skip("skipping in short mode")
|
||
}
|
||
|
||
origDir, _ := os.Getwd()
|
||
tmpDir := t.TempDir()
|
||
require.NoError(t, os.Chdir(tmpDir))
|
||
defer func() { require.NoError(t, os.Chdir(origDir)) }()
|
||
|
||
eventsSvc := new(mocks.Service)
|
||
eventsSvc.EXPECT().SendEvent(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return().Maybe()
|
||
sm := &smmocks.StateMachine{}
|
||
sm.On("GetState").Return(ReceivingAlgorithm)
|
||
sm.On("SendEvent", AlgorithmReceived).Return().Once()
|
||
|
||
mockOCI := new(MockOCIClient)
|
||
algoContent := []byte("print('remote algo')")
|
||
algoHash := sha3.Sum256(algoContent)
|
||
|
||
mockOCI.On("PullAndDecrypt", mock.Anything, mock.Anything, mock.Anything).Run(func(args mock.Arguments) {
|
||
destDir := args.String(2)
|
||
setupMinimalOCI(t, destDir, "main.py", algoContent)
|
||
}).Return(nil)
|
||
|
||
svc := &agentService{
|
||
logger: slog.Default(),
|
||
eventSvc: eventsSvc,
|
||
sm: sm,
|
||
ociClient: mockOCI,
|
||
computation: Computation{
|
||
Algorithm: Algorithm{
|
||
Hash: algoHash,
|
||
AlgoType: "python",
|
||
Source: &ResourceSource{
|
||
Type: "oci-image",
|
||
URL: "docker://test/algo-remote",
|
||
},
|
||
},
|
||
KBS: KBSConfig{Enabled: true},
|
||
},
|
||
}
|
||
|
||
ctx := metadata.NewIncomingContext(context.Background(), metadata.Pairs(algorithm.AlgoTypeKey, "python"))
|
||
err := svc.Algo(ctx, Algorithm{})
|
||
assert.NoError(t, err)
|
||
assert.True(t, svc.algoReceived)
|
||
sm.AssertExpectations(t)
|
||
mockOCI.AssertExpectations(t)
|
||
}
|
||
|
||
func TestData_RemoteSource(t *testing.T) {
|
||
if testing.Short() {
|
||
t.Skip("skipping in short mode")
|
||
}
|
||
|
||
origDir, _ := os.Getwd()
|
||
tmpDir := t.TempDir()
|
||
require.NoError(t, os.Chdir(tmpDir))
|
||
defer func() { require.NoError(t, os.Chdir(origDir)) }()
|
||
|
||
eventsSvc := new(mocks.Service)
|
||
eventsSvc.EXPECT().SendEvent(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return().Maybe()
|
||
sm := &smmocks.StateMachine{}
|
||
sm.On("GetState").Return(ReceivingData)
|
||
sm.On("SendEvent", DataReceived).Return().Once()
|
||
|
||
mockOCI := new(MockOCIClient)
|
||
dataContent := []byte("remote data")
|
||
dataHash := sha3.Sum256(dataContent)
|
||
|
||
mockOCI.On("PullAndDecrypt", mock.Anything, mock.Anything, mock.Anything).Run(func(args mock.Arguments) {
|
||
destDir := args.String(2)
|
||
setupMinimalOCI(t, destDir, "data.csv", dataContent)
|
||
}).Return(nil)
|
||
|
||
svc := &agentService{
|
||
logger: slog.Default(),
|
||
eventSvc: eventsSvc,
|
||
sm: sm,
|
||
ociClient: mockOCI,
|
||
computation: Computation{
|
||
Datasets: []Dataset{
|
||
{
|
||
Filename: "data.csv",
|
||
Hash: dataHash,
|
||
Source: &ResourceSource{
|
||
Type: "oci-image",
|
||
URL: "docker://test/data-remote",
|
||
},
|
||
},
|
||
},
|
||
KBS: KBSConfig{Enabled: true},
|
||
},
|
||
}
|
||
|
||
err := os.MkdirAll(algorithm.DatasetsDir, 0o755)
|
||
require.NoError(t, err)
|
||
|
||
ctx := context.Background()
|
||
err = svc.Data(ctx, Dataset{})
|
||
assert.NoError(t, err)
|
||
assert.Len(t, svc.computation.Datasets, 0)
|
||
sm.AssertExpectations(t)
|
||
mockOCI.AssertExpectations(t)
|
||
}
|
||
|
||
func TestRunComputation_Success(t *testing.T) {
|
||
origDir, _ := os.Getwd()
|
||
tmpDir := t.TempDir()
|
||
require.NoError(t, os.Chdir(tmpDir))
|
||
defer func() { require.NoError(t, os.Chdir(origDir)) }()
|
||
|
||
// Write a dummy algo file
|
||
require.NoError(t, os.WriteFile("algo", []byte("#!/bin/sh\necho ok\n"), 0o755))
|
||
|
||
runnerCli := new(runnermocks.Client)
|
||
runnerCli.On("Run", mock.Anything, mock.Anything).Return(&runnerpb.RunResponse{}, nil)
|
||
|
||
eventsSvc := new(mocks.Service)
|
||
eventsSvc.EXPECT().SendEvent(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return().Maybe()
|
||
|
||
sm := &smmocks.StateMachine{}
|
||
sm.On("SendEvent", RunComplete).Return().Once()
|
||
|
||
svc := &agentService{
|
||
logger: slog.Default(),
|
||
eventSvc: eventsSvc,
|
||
sm: sm,
|
||
runnerClient: runnerCli,
|
||
computation: Computation{ID: "test-run"},
|
||
}
|
||
|
||
svc.runComputation(Running)
|
||
|
||
assert.Nil(t, svc.runError)
|
||
sm.AssertExpectations(t)
|
||
runnerCli.AssertExpectations(t)
|
||
}
|