Files
cocos/agent/service_test.go
Sammy Kerata Oina 6169766666
CI / lint (push) Has been cancelled
CI / test (agent) (push) Has been cancelled
CI / test (cli) (push) Has been cancelled
CI / test (cmd) (push) Has been cancelled
CI / test (internal) (push) Has been cancelled
CI / test (manager, true) (push) Has been cancelled
CI / test (pkg) (push) Has been cancelled
CI / upload-coverage (push) Has been cancelled
NOISSUE - Fix agent startup issues (#605)
* Update attestationFromCert function to include ccPlatform parameter for enhanced attestation processing

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* chore: migrate dependencies from supermq to magistrala and update build configurations

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* chore: update project dependencies, repository source, and support TDX QuoteV5 attestation

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

---------

Signed-off-by: Sammy Oina <sammyoina@gmail.com>
2026-06-11 17:08:24 +02:00

1811 lines
52 KiB
Go
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
// Copyright (c) Ultraviolet
// SPDX-License-Identifier: Apache-2.0
package agent
import (
"archive/tar"
"archive/zip"
"bytes"
"compress/gzip"
"context"
"crypto/rand"
"encoding/json"
"fmt"
"log"
"log/slog"
"os"
"path/filepath"
"testing"
"time"
mglog "github.com/absmach/magistrala/logger"
"github.com/absmach/magistrala/pkg/errors"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
"github.com/ultravioletrs/cocos/agent/algorithm"
"github.com/ultravioletrs/cocos/agent/algorithm/python"
agentevents "github.com/ultravioletrs/cocos/agent/events"
"github.com/ultravioletrs/cocos/agent/events/mocks"
runnerpb "github.com/ultravioletrs/cocos/agent/runner"
"github.com/ultravioletrs/cocos/agent/statemachine"
smmocks "github.com/ultravioletrs/cocos/agent/statemachine/mocks"
"github.com/ultravioletrs/cocos/pkg/attestation"
"github.com/ultravioletrs/cocos/pkg/attestation/vtpm"
runnermocks "github.com/ultravioletrs/cocos/pkg/clients/grpc/runner/mocks"
"github.com/ultravioletrs/cocos/pkg/oci"
"github.com/ultravioletrs/cocos/pkg/resource"
"golang.org/x/crypto/sha3"
"google.golang.org/grpc/metadata"
"google.golang.org/protobuf/types/known/emptypb"
)
// MockOCIClient is a testify mock that tests plug into agentService in place
// of a real OCI client. Expectations are registered with On and each method
// replays whatever error was configured for the matching call.
type MockOCIClient struct {
	mock.Mock
}

// PullAndDecrypt records the invocation and returns the configured error.
func (m *MockOCIClient) PullAndDecrypt(ctx context.Context, source oci.ResourceSource, destDir string) error {
	return m.Called(ctx, source, destDir).Error(0)
}

// ToDockerArchive records the invocation and returns the configured error.
func (m *MockOCIClient) ToDockerArchive(ctx context.Context, ociDir, destFile string) error {
	return m.Called(ctx, ociDir, destFile).Error(0)
}
// Relative paths of the fixture files the tests in this file load with
// os.ReadFile: the sample algorithm, its requirements file and the sample
// dataset.
var (
algoPath = "../test/manual/algo/lin_reg.py"
reqPath = "../test/manual/algo/requirements.txt"
dataPath = "../test/manual/data/iris.csv"
)
// datasetFile is the file name attached to the sample dataset in the test
// manifest and in uploaded Dataset values.
const datasetFile = "iris.csv"
// TestAlgo uploads an algorithm through agentService.Algo for the python, bin,
// wasm and docker algorithm types, and checks that a payload whose hash does
// not match the manifest is rejected with ErrHashMismatch.
func TestAlgo(t *testing.T) {
	algo, err := os.ReadFile(algoPath)
	require.NoError(t, err)
	algoHash := sha3.Sum256(algo)

	vtpm.ExternalTPM = &vtpm.DummyRWC{}

	reqFile, err := os.ReadFile(reqPath)
	require.NoError(t, err)

	testCases := []struct {
		name     string
		err      error
		algo     Algorithm
		algoType string
	}{
		{
			name: "Test Algo successfully",
			algo: Algorithm{
				Algorithm: algo,
				Hash:      algoHash,
			},
			algoType: "python",
			err:      nil,
		},
		{
			name: "Test Algo successfully with requirements file",
			algo: Algorithm{
				Algorithm:    algo,
				Hash:         algoHash,
				Requirements: reqFile,
			},
			algoType: "python",
			err:      nil,
		},
		{
			name: "Test Algo type binary successfully",
			algo: Algorithm{
				Algorithm: algo,
				Hash:      algoHash,
			},
			algoType: "bin",
			err:      nil,
		},
		{
			name: "Test Algo type wasm successfully",
			algo: Algorithm{
				Algorithm: algo,
				Hash:      algoHash,
			},
			algoType: "wasm",
			err:      nil,
		},
		{
			name: "Test Algo type docker successfully",
			algo: Algorithm{
				Algorithm: algo,
				Hash:      algoHash,
			},
			algoType: "docker",
			err:      nil,
		},
		{
			name:     "Test algo hash mismatch",
			algo:     Algorithm{},
			algoType: "python",
			err:      ErrHashMismatch,
		},
	}

	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			require.NoError(t, os.RemoveAll("datasets"))

			// Register cleanup before anything can abort the subtest so the
			// working directory is always restored. Removal is best-effort:
			// a failure is logged rather than failing an otherwise green test.
			t.Cleanup(func() {
				for _, dir := range []string{"venv", "algo", "datasets"} {
					if rmErr := os.RemoveAll(dir); rmErr != nil {
						t.Logf("cleanup: removing %q: %v", dir, rmErr)
					}
				}
			})

			ctx := metadata.NewIncomingContext(context.Background(),
				metadata.Pairs(algorithm.AlgoTypeKey, tc.algoType, python.PyRuntimeKey, python.PyRuntime),
			)

			events := new(mocks.Service)
			events.EXPECT().SendEvent(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return()

			ctx, cancel := context.WithCancel(ctx)
			defer cancel()

			client := new(MockAttestationClient)
			runnerCli := new(runnermocks.Client)
			runnerCli.On("Run", mock.Anything, mock.Anything).Return(&runnerpb.RunResponse{}, nil)

			svc := New(ctx, mglog.NewMock(), events, client, runnerCli, 0)

			err := svc.InitComputation(ctx, testComputation(t))
			require.NoError(t, err)

			// Give the service time to react to the manifest before uploading.
			time.Sleep(300 * time.Millisecond)

			err = svc.Algo(ctx, tc.algo)
			assert.True(t, errors.Contains(err, tc.err), "expected %v, got %v", tc.err, err)
		})
	}
}
// TestData uploads a dataset through agentService.Data and checks the success
// path plus the ErrStateNotReady, ErrFileNameMismatch and ErrUndeclaredDataset
// rejections.
func TestData(t *testing.T) {
	algo, err := os.ReadFile(algoPath)
	require.NoError(t, err)
	algoHash := sha3.Sum256(algo)

	vtpm.ExternalTPM = &vtpm.DummyRWC{}

	alg := Algorithm{
		Hash:      algoHash,
		Algorithm: algo,
	}

	data, err := os.ReadFile(dataPath)
	require.NoError(t, err)
	dataHash := sha3.Sum256(data)

	cases := []struct {
		name string
		data Dataset
		err  error
	}{
		{
			name: "Test data successfully",
			data: Dataset{
				Hash:     dataHash,
				Dataset:  data,
				Filename: datasetFile,
			},
		},
		{
			name: "Test State not ready",
			data: Dataset{
				Dataset:  data,
				Hash:     dataHash,
				Filename: datasetFile,
			},
			err: ErrStateNotReady,
		},
		{
			name: "Test File name does not match manifest",
			data: Dataset{
				Dataset:  data,
				Hash:     dataHash,
				Filename: "invalid",
			},
			err: ErrFileNameMismatch,
		},
		{
			name: "Test dataset not declared in manifest",
			data: Dataset{
				Filename: datasetFile,
			},
			err: ErrUndeclaredDataset,
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			// Register cleanup up front so it also runs when a require below
			// aborts the subtest. Removal is best-effort and only logged.
			t.Cleanup(func() {
				for _, dir := range []string{"datasets", "results", "venv", "algo"} {
					if rmErr := os.RemoveAll(dir); rmErr != nil {
						t.Logf("cleanup: removing %q: %v", dir, rmErr)
					}
				}
			})

			ctx := metadata.NewIncomingContext(context.Background(),
				metadata.Pairs(
					algorithm.AlgoTypeKey, "python",
					python.PyRuntimeKey, python.PyRuntime),
			)

			events := new(mocks.Service)
			events.EXPECT().SendEvent(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return()

			// The undeclared-dataset case deliberately omits the index from
			// the context.
			if tc.err != ErrUndeclaredDataset {
				ctx = IndexToContext(ctx, 0)
			}

			ctx, cancel := context.WithCancel(ctx)
			defer cancel()

			client := new(MockAttestationClient)
			runnerCli := new(runnermocks.Client)
			runnerCli.On("Run", mock.Anything, mock.Anything).Return(&runnerpb.RunResponse{}, nil)

			svc := New(ctx, mglog.NewMock(), events, client, runnerCli, 0)

			err := svc.InitComputation(ctx, testComputation(t))
			require.NoError(t, err)

			time.Sleep(300 * time.Millisecond)

			// Skipping the algorithm upload keeps the service in a state
			// where Data must fail with ErrStateNotReady.
			if tc.err != ErrStateNotReady {
				err = svc.Algo(ctx, alg)
				require.NoError(t, err)
				time.Sleep(300 * time.Millisecond)
			}

			err = svc.Data(ctx, tc.data)
			assert.True(t, errors.Contains(err, tc.err), "expected %v, got %v", tc.err, err)
		})
	}
}
// TestResult calls agentService.Result against a mocked state machine pinned
// to a given state and checks the returned error for: results not ready, a
// caller that is not a declared consumer, and a successful consumption.
func TestResult(t *testing.T) {
	cases := []struct {
		name     string
		err      error
		setup    func(svc *agentService)
		ctxSetup func(ctx context.Context) context.Context
		state    statemachine.State
	}{
		{
			name: "Test results not ready",
			err:  ErrResultsNotReady,
			setup: func(svc *agentService) {
			},
			state: Running,
		},
		{
			name: "Test undeclared consumer",
			err:  ErrUndeclaredConsumer,
			setup: func(svc *agentService) {
				svc.computation.ResultConsumers = []ResultConsumer{{UserKey: []byte("user")}}
			},
			ctxSetup: func(ctx context.Context) context.Context {
				return ctx
			},
			state: ConsumingResults,
		},
		{
			name: "Test results consumed and event sent",
			err:  nil,
			setup: func(svc *agentService) {
				svc.computation.ResultConsumers = []ResultConsumer{{UserKey: []byte("key")}}
			},
			ctxSetup: func(ctx context.Context) context.Context {
				return IndexToContext(ctx, 0)
			},
			state: ConsumingResults,
		},
	}

	for _, tc := range cases {
		events := new(mocks.Service)
		events.EXPECT().SendEvent(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return()

		t.Run(tc.name, func(t *testing.T) {
			// Best-effort removal, registered first so it always runs.
			t.Cleanup(func() {
				_ = os.RemoveAll("datasets")
				_ = os.RemoveAll("results")
			})

			ctx := metadata.NewIncomingContext(context.Background(),
				metadata.Pairs(algorithm.AlgoTypeKey, "python", python.PyRuntimeKey, python.PyRuntime),
			)
			if tc.ctxSetup != nil {
				ctx = tc.ctxSetup(ctx)
			}

			client := new(MockAttestationClient)
			runnerCli := new(runnermocks.Client)

			sm := new(smmocks.StateMachine)
			sm.On("Start", ctx).Return(nil)
			sm.On("GetState").Return(tc.state)
			sm.On("SendEvent", mock.Anything).Return()

			svc := &agentService{
				sm:                sm,
				eventSvc:          events,
				attestationClient: client,
				runnerClient:      runnerCli,
				computation:       testComputation(t),
			}

			// Join the goroutine before the subtest finishes: calling
			// t.Errorf after the test has returned panics.
			smDone := make(chan struct{})
			go func() {
				defer close(smDone)
				if err := svc.sm.Start(ctx); err != nil {
					t.Errorf("Error starting state machine: %v", err)
				}
			}()
			t.Cleanup(func() { <-smDone })

			tc.setup(svc)

			_, err := svc.Result(ctx)
			assert.ErrorIs(t, err, tc.err, "expected %v, got %v", tc.err, err)
		})
	}
}
// TestAttestation drives agentService.Attestation for the SNP, vTPM and
// SNP-vTPM platforms, once with a mocked attestation client that succeeds and
// once with one that returns the platform's failure error.
func TestAttestation(t *testing.T) {
	attClient := new(MockAttestationClient)

	type attestationCase struct {
		name       string
		reportData [vtpm.SEVNonce]byte
		nonce      [vtpm.Nonce]byte
		rawQuote   []uint8
		platform   attestation.PlatformType
		err        error
	}

	cases := []attestationCase{
		{
			name:       "Test SNP attestation successful",
			platform:   attestation.SNP,
			reportData: generateReportData(),
			nonce:      [32]byte{},
			rawQuote:   make([]uint8, 0),
			err:        nil,
		},
		{
			name:       "Test SNP attestation failed",
			platform:   attestation.SNP,
			reportData: generateReportData(),
			nonce:      [32]byte{},
			rawQuote:   nil,
			err:        ErrAttestationFailed,
		},
		{
			name:       "Test vTPM attestation successful",
			platform:   attestation.VTPM,
			reportData: generateReportData(),
			nonce:      [32]byte{},
			rawQuote:   make([]uint8, 0),
			err:        nil,
		},
		{
			name:       "Test vTPM attestation failed",
			platform:   attestation.VTPM,
			reportData: generateReportData(),
			nonce:      [32]byte{},
			rawQuote:   nil,
			err:        ErrAttestationVTpmFailed,
		},
		{
			name:       "Test SNP-vTPM attestation successful",
			platform:   attestation.SNPvTPM,
			reportData: generateReportData(),
			nonce:      [32]byte{},
			rawQuote:   make([]uint8, 0),
			err:        nil,
		},
		{
			name:       "Test SNP-vTPM attestation failed",
			platform:   attestation.SNPvTPM,
			reportData: generateReportData(),
			nonce:      [32]byte{},
			rawQuote:   nil,
			err:        ErrAttestationVTpmFailed,
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			eventsSvc := new(mocks.Service)
			eventsSvc.EXPECT().SendEvent(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return()

			md := metadata.Pairs(algorithm.AlgoTypeKey, "python", python.PyRuntimeKey, python.PyRuntime)
			ctx, cancel := context.WithCancel(metadata.NewIncomingContext(context.Background(), md))
			defer cancel()

			// The mock client is shared by all subtests, so the expectation
			// registered here is removed again when the subtest returns.
			expectation := attClient.On("GetAttestation", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(tc.rawQuote, tc.err)
			expectsFailure := tc.err == ErrAttestationFailed || tc.err == ErrAttestationVTpmFailed
			if !expectsFailure {
				expectation = attClient.On("GetAttestation", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(tc.nonce[:], nil)
			}
			defer expectation.Unset()

			svc := New(ctx, mglog.NewMock(), eventsSvc, attClient, new(runnermocks.Client), 0)
			time.Sleep(300 * time.Millisecond)

			_, gotErr := svc.Attestation(ctx, tc.reportData, tc.nonce, tc.platform)
			assert.True(t, errors.Contains(gotErr, tc.err), "expected %v, got %v", tc.err, gotErr)
		})
	}
}
// TestAzureAttestationToken checks that agentService.AzureAttestationToken
// surfaces the error (or lack of one) produced by the attestation client's
// GetAzureToken call.
func TestAzureAttestationToken(t *testing.T) {
	attClient := new(MockAttestationClient)

	type tokenCase struct {
		name  string
		nonce [vtpm.Nonce]byte
		token []byte
		err   error
	}

	cases := []tokenCase{
		{
			name:  "Azure token fetch successful",
			nonce: [32]byte{1, 2, 3}, // arbitrary nonce
			token: []byte("mockToken"),
			err:   nil,
		},
		{
			name:  "Azure token fetch failed",
			nonce: [32]byte{4, 5, 6},
			token: []byte{},
			err:   ErrAttestationType,
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			eventsSvc := new(mocks.Service)
			eventsSvc.EXPECT().SendEvent(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return()

			// Each case uses its own nonce, so expectations on the shared
			// client are told apart by argument.
			attClient.On("GetAzureToken", mock.Anything, tc.nonce).Return(tc.token, tc.err)

			ctx := context.Background()
			svc := New(ctx, mglog.NewMock(), eventsSvc, attClient, new(runnermocks.Client), 0)

			_, gotErr := svc.AzureAttestationToken(ctx, tc.nonce)
			assert.True(t, errors.Contains(gotErr, tc.err), "expected error %v, got %v", tc.err, gotErr)
		})
	}
}
// generateReportData returns an array of vtpm.SEVNonce bytes read from
// crypto/rand. The process is terminated via log.Fatalf if the random source
// fails.
func generateReportData() [vtpm.SEVNonce]byte {
	// Fill the array in place: this avoids shadowing the imported bytes
	// package and keeps the size tied to vtpm.SEVNonce instead of a literal.
	var reportData [vtpm.SEVNonce]byte
	if _, err := rand.Read(reportData[:]); err != nil {
		log.Fatalf("Failed to generate random bytes: %v", err)
	}
	return reportData
}
// testComputation builds the manifest shared by the tests: one dataset and one
// algorithm, both hashed from the fixture files on disk, and a single result
// consumer. The test fails immediately if a fixture cannot be read.
func testComputation(t *testing.T) Computation {
	// Attribute fixture-read failures to the calling test, not this helper.
	t.Helper()

	algo, err := os.ReadFile(algoPath)
	require.NoError(t, err)
	algoHash := sha3.Sum256(algo)

	data, err := os.ReadFile(dataPath)
	require.NoError(t, err)
	dataHash := sha3.Sum256(data)

	return Computation{
		ID:              "1",
		Name:            "sample computation",
		Description:     "sample description",
		Datasets:        []Dataset{{Hash: dataHash, UserKey: []byte("key"), Dataset: data, Filename: datasetFile}},
		Algorithm:       &Algorithm{Hash: algoHash, UserKey: []byte("key"), Algorithm: algo},
		ResultConsumers: []ResultConsumer{{UserKey: []byte("key")}},
	}
}
func TestStopComputation(t *testing.T) {
cases := []struct {
name string
setupDirs bool
setupAlgo bool
algoStopErr error
expectedErr error
}{
{
name: "Stop computation successfully",
setupDirs: true,
setupAlgo: true,
algoStopErr: nil,
expectedErr: nil,
},
{
name: "Stop computation with algorithm stop error",
setupDirs: true,
setupAlgo: true,
algoStopErr: fmt.Errorf("algorithm stop failed"),
expectedErr: nil, // Warn only
},
// We log warnings but don't return error in StopComputation in new implementation for Stop failure.
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
events := new(mocks.Service)
events.On("SendEvent", mock.Anything, "Stopped", "Stopped", mock.Anything).Return()
ctx := context.Background()
ctx, cancel := context.WithCancel(ctx)
defer cancel()
client := new(MockAttestationClient)
runnerCli := new(runnermocks.Client)
// Mock Stop call
var stopErr error
if tc.algoStopErr != nil {
stopErr = tc.algoStopErr
}
runnerCli.On("Stop", mock.Anything, mock.Anything).Return(&emptypb.Empty{}, stopErr)
svc := New(ctx, mglog.NewMock(), events, client, runnerCli, 0).(*agentService)
svc.computation = Computation{
ID: "test-computation",
Name: "test",
}
if tc.setupDirs {
err := os.MkdirAll(algorithm.DatasetsDir, 0o755)
require.NoError(t, err)
err = os.MkdirAll(algorithm.ResultsDir, 0o755)
require.NoError(t, err)
}
// Use real dirs for test
// algorithm.DatasetsDir refers to global var?
// "github.com/ultravioletrs/cocos/agent/algorithm"
// It uses hardcoded path "datasets" and "results" in current dir.
// Tests create them in current dir.
err := svc.StopComputation(ctx)
if tc.expectedErr != nil {
assert.Error(t, err)
assert.Contains(t, err.Error(), tc.expectedErr.Error())
} else {
assert.NoError(t, err)
}
assert.Equal(t, ReceivingManifest, svc.sm.GetState())
assert.Nil(t, svc.result)
assert.Nil(t, svc.runError)
assert.False(t, svc.resultsConsumed)
events.AssertExpectations(t)
_ = os.RemoveAll(algorithm.DatasetsDir)
_ = os.RemoveAll(algorithm.ResultsDir)
})
}
}
// TestStopComputationIntegration walks a service through InitComputation and
// Algo with a "bin" algorithm and then checks that StopComputation returns it
// to the ReceivingManifest state. It is skipped in -short mode.
func TestStopComputationIntegration(t *testing.T) {
	if testing.Short() {
		t.Skip("Skipping integration test in short mode")
	}

	algo := []byte("#!/bin/bash\necho 'test algorithm'")
	algoHash := sha3.Sum256(algo)

	// t.TempDir is removed automatically, so nothing is left behind in the
	// package directory and no removal error is silently dropped.
	algoFile := filepath.Join(t.TempDir(), "test_algo")
	err := os.WriteFile(algoFile, algo, 0o755)
	require.NoError(t, err)

	events := new(mocks.Service)
	events.On("SendEvent", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return()

	ctx := metadata.NewIncomingContext(context.Background(),
		metadata.Pairs(algorithm.AlgoTypeKey, "bin"),
	)
	ctx, cancel := context.WithCancel(ctx)
	defer cancel()

	client := new(MockAttestationClient)
	runnerCli := new(runnermocks.Client)
	runnerCli.On("Run", mock.Anything, mock.Anything).Return(&runnerpb.RunResponse{}, nil)
	runnerCli.On("Stop", mock.Anything, mock.Anything).Return(&emptypb.Empty{}, nil)

	svc := New(ctx, mglog.NewMock(), events, client, runnerCli, 0)

	computation := Computation{
		ID:   "integration-test",
		Name: "Integration Test",
		Algorithm: &Algorithm{
			Hash:      algoHash,
			Algorithm: algo,
		},
	}

	err = svc.InitComputation(ctx, computation)
	require.NoError(t, err)

	// Let the service process the manifest before sending the algorithm.
	time.Sleep(100 * time.Millisecond)

	err = svc.Algo(ctx, Algorithm{
		Hash:      algoHash,
		Algorithm: algo,
	})
	require.NoError(t, err)

	time.Sleep(100 * time.Millisecond)

	err = svc.StopComputation(ctx)
	assert.NoError(t, err)

	assert.Equal(t, "ReceivingManifest", svc.State())
}
// TestStopComputationConcurrent fires StopComputation from several goroutines
// at once and requires that at least one of the calls succeeds.
func TestStopComputationConcurrent(t *testing.T) {
	eventsSvc := new(mocks.Service)
	eventsSvc.On("SendEvent", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return()

	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	attClient := new(MockAttestationClient)
	runnerCli := new(runnermocks.Client)
	runnerCli.On("Stop", mock.Anything, mock.Anything).Return(&emptypb.Empty{}, nil)

	svc := New(ctx, mglog.NewMock(), eventsSvc, attClient, runnerCli, 0)
	svc.(*agentService).computation = Computation{
		ID:   "concurrent-test",
		Name: "Concurrent Test",
	}

	const numGoroutines = 10

	// Buffered to numGoroutines so no sender can block.
	results := make(chan error, numGoroutines)
	for worker := 0; worker < numGoroutines; worker++ {
		go func() {
			results <- svc.StopComputation(ctx)
		}()
	}

	var failures []error
	for received := 0; received < numGoroutines; received++ {
		if stopErr := <-results; stopErr != nil {
			failures = append(failures, stopErr)
		}
	}

	assert.True(t, len(failures) < numGoroutines, "All StopComputation calls failed")
}
// newTestAgentService returns an agentService with only the logger, event
// service and state machine populated, for tests that call unexported methods
// directly. All other fields keep their zero values.
func newTestAgentService(sm statemachine.StateMachine, eventSvc agentevents.Service) *agentService {
	svc := new(agentService)
	svc.sm = sm
	svc.eventSvc = eventSvc
	svc.logger = slog.Default()
	return svc
}
// TestDownloadAndDecryptResource checks how downloadAndDecryptResource routes
// a ResourceSource based on its explicit Type or its URL. Every case is
// expected to fail (the test service has no working backends); the error text
// shows which route was taken.
func TestDownloadAndDecryptResource(t *testing.T) {
	eventsSvc := new(mocks.Service)
	eventsSvc.EXPECT().SendEvent(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return().Maybe()

	sm := &smmocks.StateMachine{}
	sm.On("SendEvent", mock.Anything).Return().Maybe()

	svc := newTestAgentService(sm, eventsSvc)
	ctx := context.Background()

	const (
		badURLFormat = "unsupported source URL format"
		badType      = "unsupported source type"
		noRegistry   = "resource registry not initialized"
	)

	cases := []struct {
		name         string
		source       *ResourceSource
		resourceKind string
		substr       string
		// mustContain selects between asserting that the error message
		// contains substr (true) or does not contain it (false).
		mustContain bool
	}{
		{
			name:         "unsupported URL format no type",
			source:       &ResourceSource{URL: "abc://unsupported-format"},
			resourceKind: "algorithm",
			substr:       badURLFormat,
			mustContain:  true,
		},
		{
			name:         "ftp URL unsupported format",
			source:       &ResourceSource{URL: "ftp://some-server/file"},
			resourceKind: "algorithm",
			substr:       badURLFormat,
			mustContain:  true,
		},
		{
			name:         "unsupported explicit source type",
			source:       &ResourceSource{Type: "s3-bucket", URL: "s3://mybucket/algo"},
			resourceKind: "algorithm",
			substr:       "unsupported source type: s3-bucket",
			mustContain:  true,
		},
		{
			// Expected to be routed to the OCI path and fail there.
			name:         "bare OCI image name inferred as oci-image",
			source:       &ResourceSource{URL: "ubuntu:latest"},
			resourceKind: "algorithm",
			substr:       badURLFormat,
		},
		{
			name:         "bare registry image name inferred as oci-image",
			source:       &ResourceSource{URL: "gcr.io/project/image:latest"},
			resourceKind: "algorithm",
			substr:       badURLFormat,
		},
		{
			// Exercises the oci-image path; the failure should come from the
			// download step, not from URL classification.
			name:         "docker:// URL inferred as oci-image routes to skopeo",
			source:       &ResourceSource{URL: "docker://invalid.example.com/algo:latest"},
			resourceKind: "algorithm",
			substr:       badURLFormat,
		},
		{
			name:         "oci: URL inferred as oci-image routes to skopeo",
			source:       &ResourceSource{URL: "oci:some-local-dir"},
			resourceKind: "algorithm",
			substr:       badURLFormat,
		},
		{
			name:         "explicit oci-image type routes to skopeo",
			source:       &ResourceSource{Type: "oci-image", URL: "docker://invalid.example.com/algo:latest"},
			resourceKind: "algorithm",
			substr:       badType,
		},
		{
			name:         "dataset resource type with oci-image routes to skopeo",
			source:       &ResourceSource{Type: "oci-image", URL: "docker://invalid.example.com/data:latest"},
			resourceKind: "dataset",
			substr:       badType,
		},
		{
			// The test service never initializes a registry, so the registry
			// route fails predictably.
			name:         "https inferred routes to registry",
			source:       &ResourceSource{URL: "https://example.com/file.bin"},
			resourceKind: "algorithm",
			substr:       noRegistry,
			mustContain:  true,
		},
		{
			name:         "s3 inferred routes to registry",
			source:       &ResourceSource{URL: "s3://bucket/key"},
			resourceKind: "algorithm",
			substr:       noRegistry,
			mustContain:  true,
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			_, err := svc.downloadAndDecryptResource(ctx, tc.source, "", tc.resourceKind)
			require.Error(t, err)

			if tc.mustContain {
				assert.Contains(t, err.Error(), tc.substr)
				return
			}
			assert.NotContains(t, err.Error(), tc.substr)
		})
	}
}
func TestDownloadAlgorithmIfRemote(t *testing.T) {
t.Run("no source configured - no-op, waits for direct upload", func(t *testing.T) {
eventsSvc := new(mocks.Service)
eventsSvc.EXPECT().SendEvent(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return().Maybe()
sm := &smmocks.StateMachine{}
// No SendEvent expected — just the no-op path
svc := newTestAgentService(sm, eventsSvc)
svc.computation = Computation{} // Algorithm.Source == nil
svc.downloadAlgorithmIfRemote(ReceivingAlgorithm)
assert.Nil(t, svc.runError)
sm.AssertExpectations(t)
})
t.Run("source set but KBS disabled - no-op", func(t *testing.T) {
eventsSvc := new(mocks.Service)
eventsSvc.EXPECT().SendEvent(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return().Maybe()
sm := &smmocks.StateMachine{}
svc := newTestAgentService(sm, eventsSvc)
svc.computation = Computation{
Algorithm: &Algorithm{
Source: &ResourceSource{URL: "docker://registry/algo:latest"},
KBS: &KBSConfig{Enabled: false},
},
}
svc.downloadAlgorithmIfRemote(ReceivingAlgorithm)
assert.Nil(t, svc.runError)
sm.AssertExpectations(t)
})
t.Run("source + KBS enabled - download fails, sends RunFailed", func(t *testing.T) {
eventsSvc := new(mocks.Service)
eventsSvc.EXPECT().SendEvent(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return().Maybe()
sm := &smmocks.StateMachine{}
sm.On("SendEvent", RunFailed).Return().Once()
svc := newTestAgentService(sm, eventsSvc)
svc.computation = Computation{
Algorithm: &Algorithm{
Source: &ResourceSource{
Type: "oci-image",
URL: "docker://invalid.example.com/algo:latest",
},
KBS: &KBSConfig{Enabled: true, URL: "https://kbs.example.com"},
},
}
svc.downloadAlgorithmIfRemote(ReceivingAlgorithm)
assert.NotNil(t, svc.runError)
assert.Contains(t, svc.runError.Error(), "failed to download and decrypt algorithm")
sm.AssertExpectations(t)
})
t.Run("unsupported URL format - download fails, sends RunFailed", func(t *testing.T) {
eventsSvc := new(mocks.Service)
eventsSvc.EXPECT().SendEvent(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return().Maybe()
sm := &smmocks.StateMachine{}
sm.On("SendEvent", RunFailed).Return().Once()
svc := newTestAgentService(sm, eventsSvc)
svc.computation = Computation{
Algorithm: &Algorithm{
Source: &ResourceSource{
URL: "http://unsupported-format/algo",
},
KBS: &KBSConfig{Enabled: true},
},
}
svc.downloadAlgorithmIfRemote(ReceivingAlgorithm)
assert.NotNil(t, svc.runError)
sm.AssertExpectations(t)
})
}
func TestDownloadDatasetsIfRemote(t *testing.T) {
t.Run("no datasets with remote sources - no-op", func(t *testing.T) {
eventsSvc := new(mocks.Service)
eventsSvc.EXPECT().SendEvent(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return().Maybe()
sm := &smmocks.StateMachine{}
svc := newTestAgentService(sm, eventsSvc)
// Dataset with no Source
dataHash := sha3.Sum256([]byte("testdata"))
svc.computation = Computation{
Datasets: []Dataset{
{Hash: dataHash, Filename: "data.csv"},
},
}
svc.downloadDatasetsIfRemote(ReceivingData)
// No RunFailed event, no DataReceived event
sm.AssertExpectations(t)
})
t.Run("no datasets at all - no-op", func(t *testing.T) {
eventsSvc := new(mocks.Service)
eventsSvc.EXPECT().SendEvent(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return().Maybe()
sm := &smmocks.StateMachine{}
svc := newTestAgentService(sm, eventsSvc)
svc.computation = Computation{
Datasets: []Dataset{},
}
svc.downloadDatasetsIfRemote(ReceivingData)
sm.AssertExpectations(t)
})
t.Run("KBS disabled even with source - no-op", func(t *testing.T) {
eventsSvc := new(mocks.Service)
eventsSvc.EXPECT().SendEvent(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return().Maybe()
sm := &smmocks.StateMachine{}
svc := newTestAgentService(sm, eventsSvc)
svc.computation = Computation{
Datasets: []Dataset{
{
Filename: "data.csv",
Source: &ResourceSource{URL: "docker://registry/data:latest"},
KBS: &KBSConfig{Enabled: false},
},
},
}
svc.downloadDatasetsIfRemote(ReceivingData)
sm.AssertExpectations(t)
})
t.Run("remote dataset + KBS enabled - download fails, sends RunFailed", func(t *testing.T) {
eventsSvc := new(mocks.Service)
eventsSvc.EXPECT().SendEvent(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return().Maybe()
sm := &smmocks.StateMachine{}
sm.On("SendEvent", RunFailed).Return().Once()
svc := newTestAgentService(sm, eventsSvc)
svc.computation = Computation{
Datasets: []Dataset{
{
Filename: "data.csv",
Source: &ResourceSource{
Type: "oci-image",
URL: "docker://invalid.example.com/data:latest",
},
KBS: &KBSConfig{Enabled: true, URL: "https://kbs.example.com"},
},
},
}
svc.downloadDatasetsIfRemote(ReceivingData)
sm.AssertExpectations(t)
})
t.Run("unsupported URL fails - sends RunFailed", func(t *testing.T) {
eventsSvc := new(mocks.Service)
eventsSvc.EXPECT().SendEvent(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return().Maybe()
sm := &smmocks.StateMachine{}
sm.On("SendEvent", RunFailed).Return().Once()
svc := newTestAgentService(sm, eventsSvc)
svc.computation = Computation{
Datasets: []Dataset{
{
Filename: "data.csv",
Source: &ResourceSource{
URL: "http://unsupported-format/data",
},
KBS: &KBSConfig{Enabled: true},
},
},
}
svc.downloadDatasetsIfRemote(ReceivingData)
sm.AssertExpectations(t)
})
}
// TestRunComputation checks the three failure paths of runComputation: the
// algo file is missing, the runner client returns an error, and the runner
// responds with a non-empty Error field. Each must send RunFailed and record
// a matching runError.
func TestRunComputation(t *testing.T) {
	// enterTempDir switches the working directory to a fresh temp dir and
	// returns a function that switches back.
	enterTempDir := func(t *testing.T) (leave func()) {
		t.Helper()
		previous, err := os.Getwd()
		require.NoError(t, err)
		require.NoError(t, os.Chdir(t.TempDir()))
		return func() { _ = os.Chdir(previous) }
	}

	cases := []struct {
		name string
		// writeAlgo creates a dummy "algo" file so reading it succeeds.
		writeAlgo bool
		// newRunner builds the runner mock; nil leaves runnerClient unset.
		newRunner func() *runnermocks.Client
		errSubstr string
	}{
		{
			// No algo file exists, so the read error path is taken.
			name:      "algo file not found sends RunFailed",
			errSubstr: "failed to read algo file",
		},
		{
			name:      "runner client returns error sends RunFailed",
			writeAlgo: true,
			newRunner: func() *runnermocks.Client {
				cli := new(runnermocks.Client)
				cli.On("Run", mock.Anything, mock.Anything).Return((*runnerpb.RunResponse)(nil), fmt.Errorf("runner unavailable"))
				return cli
			},
			errSubstr: "runner unavailable",
		},
		{
			name:      "runner returns non-empty error field sends RunFailed",
			writeAlgo: true,
			newRunner: func() *runnermocks.Client {
				cli := new(runnermocks.Client)
				cli.On("Run", mock.Anything, mock.Anything).Return(&runnerpb.RunResponse{Error: "computation crashed"}, nil)
				return cli
			},
			errSubstr: "computation crashed",
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			leave := enterTempDir(t)
			defer leave()

			if tc.writeAlgo {
				require.NoError(t, os.WriteFile("algo", []byte("#!/bin/sh\necho ok\n"), 0o755))
			}

			var runnerCli *runnermocks.Client
			if tc.newRunner != nil {
				runnerCli = tc.newRunner()
			}

			eventsSvc := new(mocks.Service)
			eventsSvc.EXPECT().SendEvent(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return().Maybe()

			sm := &smmocks.StateMachine{}
			sm.On("SendEvent", RunFailed).Return().Once()

			svc := newTestAgentService(sm, eventsSvc)
			// Assign only a real mock so the interface field stays nil
			// (not a typed nil) in the missing-file case.
			if runnerCli != nil {
				svc.runnerClient = runnerCli
			}

			svc.runComputation(Running)

			assert.Error(t, svc.runError)
			assert.Contains(t, svc.runError.Error(), tc.errSubstr)
			sm.AssertExpectations(t)
		})
	}
}
func TestIMAMeasurements(t *testing.T) {
t.Run("error when IMA measurements file does not exist in non-SGX environment", func(t *testing.T) {
// In a regular test environment (non-SGX), the IMA measurements file
// at /sys/kernel/security/integrity/ima/ascii_runtime_measurements won't exist.
// Verify our error handling works correctly.
origPath := ImaMeasurementsFilePath
ImaMeasurementsFilePath = "/non/existent/path"
defer func() { ImaMeasurementsFilePath = origPath }()
eventsSvc := new(mocks.Service)
eventsSvc.On("SendEvent", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return().Maybe()
sm := &smmocks.StateMachine{}
svc := newTestAgentService(sm, eventsSvc)
data, pcr10, err := svc.IMAMeasurements(context.Background())
assert.Error(t, err)
assert.Contains(t, err.Error(), "error reading Linux IMA measurements file")
assert.Nil(t, data)
assert.Nil(t, pcr10)
})
t.Run("successful reading of IMA measurements", func(t *testing.T) {
tempFile := filepath.Join(t.TempDir(), "ima_measurements")
content := []byte("10 sha1:0000000000000000000000000000000000000000 ima-ng sha256:0000000000000000000000000000000000000000000000000000000000000000 /usr/bin/python3\n")
err := os.WriteFile(tempFile, content, 0o644)
require.NoError(t, err)
vtpm.ExternalTPM = &vtpm.DummyRWC{}
origPath := ImaMeasurementsFilePath
ImaMeasurementsFilePath = tempFile
defer func() { ImaMeasurementsFilePath = origPath }()
eventsSvc := new(mocks.Service)
eventsSvc.On("SendEvent", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return().Maybe()
sm := &smmocks.StateMachine{}
svc := newTestAgentService(sm, eventsSvc)
data, pcr10, err := svc.IMAMeasurements(context.Background())
assert.NoError(t, err)
assert.Equal(t, content, data)
assert.NotEmpty(t, pcr10)
})
}
// TestDownloadAlgorithmIfRemote_Success checks the happy path of
// downloadAlgorithmIfRemote with a mocked OCI client: AlgorithmReceived is
// sent to the state machine, no runError is recorded and algoReceived is set.
// It is skipped in -short mode.
func TestDownloadAlgorithmIfRemote_Success(t *testing.T) {
	if testing.Short() {
		t.Skip("skipping in short mode")
	}

	// Run inside a temp dir and always return to the original directory.
	origDir, err := os.Getwd()
	require.NoError(t, err)
	require.NoError(t, os.Chdir(t.TempDir()))
	defer func() { require.NoError(t, os.Chdir(origDir)) }()

	eventsSvc := new(mocks.Service)
	eventsSvc.On("SendEvent", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return().Maybe()

	sm := &smmocks.StateMachine{}
	sm.On("SendEvent", AlgorithmReceived).Return().Once()

	algoContent := []byte("print('hello')")
	algoHash := sha3.Sum256(algoContent)

	// The extraction step that follows the pull is not mocked here, so the
	// PullAndDecrypt stub writes a minimal OCI layout (via setupMinimalOCI)
	// into the destination directory for it to work on.
	mockOCI := new(MockOCIClient)
	mockOCI.On("PullAndDecrypt", mock.Anything, mock.Anything, mock.Anything).Run(func(args mock.Arguments) {
		destDir := args.String(2)
		setupMinimalOCI(t, destDir, "main.py", algoContent)
	}).Return(nil)

	svc := newTestAgentService(sm, eventsSvc)
	svc.ociClient = mockOCI
	svc.computation = Computation{
		Algorithm: &Algorithm{
			Hash:     algoHash,
			AlgoType: "python",
			Source: &ResourceSource{
				Type: "oci-image",
				URL:  "docker://test/algo-success",
			},
			KBS: &KBSConfig{Enabled: true},
		},
	}

	svc.downloadAlgorithmIfRemote(ReceivingAlgorithm)

	assert.Nil(t, svc.runError)
	assert.True(t, svc.algoReceived)
	sm.AssertExpectations(t)
	mockOCI.AssertExpectations(t)
}
// TestDownloadAlgorithmIfRemote_Docker_Success checks the "docker" algorithm
// type: the mocked OCI client pulls successfully, ToDockerArchive writes a dummy
// archive whose SHA3-256 matches the configured hash, and the service is
// expected to record no error, mark the algorithm as received and send
// AlgorithmReceived.
func TestDownloadAlgorithmIfRemote_Docker_Success(t *testing.T) {
	if testing.Short() {
		t.Skip("skipping in short mode")
	}

	// Run inside a scratch working directory and restore the original one
	// afterwards; the Getwd error is checked rather than discarded.
	origDir, err := os.Getwd()
	require.NoError(t, err)
	tmpDir := t.TempDir()
	require.NoError(t, os.Chdir(tmpDir))
	defer func() { require.NoError(t, os.Chdir(origDir)) }()

	eventsSvc := new(mocks.Service)
	eventsSvc.On("SendEvent", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return().Maybe()

	sm := &smmocks.StateMachine{}
	sm.On("SendEvent", AlgorithmReceived).Return().Once()

	dummyContent := []byte("dummy docker tar")
	dummyHash := sha3.Sum256(dummyContent)

	mockOCI := new(MockOCIClient)
	mockOCI.On("PullAndDecrypt", mock.Anything, mock.Anything, mock.Anything).Return(nil)
	// The mocked conversion writes the dummy archive to the destination file
	// passed as the third call argument.
	mockOCI.On("ToDockerArchive", mock.Anything, mock.Anything, mock.Anything).Run(func(args mock.Arguments) {
		require.NoError(t, os.WriteFile(args.String(2), dummyContent, 0o644))
	}).Return(nil)

	svc := newTestAgentService(sm, eventsSvc)
	svc.ociClient = mockOCI
	svc.computation = Computation{
		Algorithm: &Algorithm{
			AlgoType: "docker",
			Hash:     dummyHash,
			Source: &ResourceSource{
				Type: "oci-image",
				URL:  "docker://test/algo-docker-success",
			},
			KBS: &KBSConfig{Enabled: true},
		},
	}

	svc.downloadAlgorithmIfRemote(ReceivingAlgorithm)

	assert.Nil(t, svc.runError)
	assert.True(t, svc.algoReceived)
	sm.AssertExpectations(t)
	mockOCI.AssertExpectations(t)
}
// setupMinimalOCI writes a minimal OCI image layout under ociDir for tests:
//
//   - blobs/sha256/layer123: a gzip-compressed tar archive holding one entry
//     named filename with the given content;
//   - blobs/sha256/manifest123: a JSON manifest listing that layer by digest;
//   - index.json: an oci.OCIIndex pointing at the manifest.
//
// Any failure aborts the calling test through require.
func setupMinimalOCI(t *testing.T, ociDir, filename string, content []byte) {
	t.Helper()

	blobsDir := filepath.Join(ociDir, "blobs", "sha256")
	require.NoError(t, os.MkdirAll(blobsDir, 0o755))

	// Build the layer in memory and write it in one call, so no open file
	// handle is left behind if one of the require checks aborts the test.
	var layer bytes.Buffer
	gw := gzip.NewWriter(&layer)
	tw := tar.NewWriter(gw)
	require.NoError(t, tw.WriteHeader(&tar.Header{
		Name: filename,
		Mode: 0o755,
		Size: int64(len(content)),
	}))
	_, err := tw.Write(content)
	require.NoError(t, err)
	// Close the tar writer before the gzip writer so both trailers are flushed.
	require.NoError(t, tw.Close())
	require.NoError(t, gw.Close())
	require.NoError(t, os.WriteFile(filepath.Join(blobsDir, "layer123"), layer.Bytes(), 0o644))

	type layerRef struct {
		Digest string `json:"digest"`
	}
	manifest := struct {
		Layers []layerRef `json:"layers"`
	}{
		Layers: []layerRef{{Digest: "sha256:layer123"}},
	}
	manifestData, err := json.Marshal(manifest)
	require.NoError(t, err)
	require.NoError(t, os.WriteFile(filepath.Join(blobsDir, "manifest123"), manifestData, 0o644))

	index := oci.OCIIndex{
		SchemaVersion: 2,
		Manifests: []struct {
			MediaType string `json:"mediaType"`
			Digest    string `json:"digest"`
			Size      int    `json:"size"`
		}{{Digest: "sha256:manifest123", Size: len(manifestData)}},
	}
	indexData, err := json.Marshal(index)
	require.NoError(t, err)
	require.NoError(t, os.WriteFile(filepath.Join(ociDir, "index.json"), indexData, 0o644))
}
// TestDownloadDatasetsIfRemote_Success checks that downloadDatasetsIfRemote,
// given one dataset whose source is an OCI image pulled through a mocked OCI
// client, records no run error, leaves no pending datasets in the computation
// and sends DataReceived to the state machine.
func TestDownloadDatasetsIfRemote_Success(t *testing.T) {
	if testing.Short() {
		t.Skip("skipping in short mode")
	}

	// Run inside a scratch working directory and restore the original one
	// afterwards; the Getwd error is checked rather than discarded.
	origDir, err := os.Getwd()
	require.NoError(t, err)
	tmpDir := t.TempDir()
	require.NoError(t, os.Chdir(tmpDir))
	defer func() { require.NoError(t, os.Chdir(origDir)) }()

	eventsSvc := new(mocks.Service)
	eventsSvc.On("SendEvent", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return().Maybe()

	sm := &smmocks.StateMachine{}
	sm.On("SendEvent", DataReceived).Return().Once()

	dataContent := []byte("a,b,c\n1,2,3")
	dataHash := sha3.Sum256(dataContent)

	// The mocked pull materialises a minimal OCI layout containing data.csv
	// in the destination directory (third call argument).
	mockOCI := new(MockOCIClient)
	mockOCI.On("PullAndDecrypt", mock.Anything, mock.Anything, mock.Anything).Run(func(args mock.Arguments) {
		setupMinimalOCI(t, args.String(2), "data.csv", dataContent)
	}).Return(nil)

	svc := newTestAgentService(sm, eventsSvc)
	svc.ociClient = mockOCI
	svc.computation = Computation{
		Datasets: []Dataset{
			{
				Filename: "data.csv",
				Hash:     dataHash,
				Source: &ResourceSource{
					Type: "oci-image",
					URL:  "docker://test/data-success",
				},
				KBS: &KBSConfig{Enabled: true, URL: "https://kbs.example.com"},
			},
		},
	}

	require.NoError(t, os.MkdirAll(algorithm.DatasetsDir, 0o755))

	svc.downloadDatasetsIfRemote(ReceivingData)

	assert.Nil(t, svc.runError)
	assert.Len(t, svc.computation.Datasets, 0)
	sm.AssertExpectations(t)
	mockOCI.AssertExpectations(t)
}
// TestDownloadDatasetsIfRemote_Decompress checks the Decompress option: the
// mocked OCI pull delivers data.zip (a zip archive containing test.txt), and
// after downloadDatasetsIfRemote the test expects no run error, no pending
// datasets and an extracted test.txt inside algorithm.DatasetsDir.
func TestDownloadDatasetsIfRemote_Decompress(t *testing.T) {
	if testing.Short() {
		t.Skip("skipping in short mode")
	}

	// Run inside a scratch working directory and restore the original one
	// afterwards; the Getwd error is checked rather than discarded.
	origDir, err := os.Getwd()
	require.NoError(t, err)
	tmpDir := t.TempDir()
	require.NoError(t, os.Chdir(tmpDir))
	defer func() { require.NoError(t, os.Chdir(origDir)) }()

	eventsSvc := new(mocks.Service)
	eventsSvc.On("SendEvent", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return().Maybe()

	// NOTE(review): both state-machine expectations are optional (Maybe), so
	// sm.AssertExpectations below cannot fail on them; the outcome is pinned
	// by the runError and file-existence assertions instead.
	sm := &smmocks.StateMachine{}
	sm.On("SendEvent", DataReceived).Return().Maybe()
	sm.On("SendEvent", RunFailed).Return().Maybe()

	// Create a zip file in memory
	var buf bytes.Buffer
	zw := zip.NewWriter(&buf)
	f, err := zw.Create("test.txt")
	require.NoError(t, err)
	_, err = f.Write([]byte("hello zip"))
	require.NoError(t, err)
	require.NoError(t, zw.Close())
	zipData := buf.Bytes()
	dataHash := sha3.Sum256(zipData)

	mockOCI := new(MockOCIClient)
	mockOCI.On("PullAndDecrypt", mock.Anything, mock.Anything, mock.Anything).Run(func(args mock.Arguments) {
		setupMinimalOCI(t, args.String(2), "data.zip", zipData)
	}).Return(nil)

	svc := newTestAgentService(sm, eventsSvc)
	svc.ociClient = mockOCI
	svc.computation = Computation{
		Datasets: []Dataset{
			{
				Filename:   "data.zip",
				Hash:       dataHash,
				Decompress: true,
				Source: &ResourceSource{
					Type: "oci-image",
					URL:  "docker://test/data-decompress",
				},
				KBS: &KBSConfig{Enabled: true},
			},
		},
	}

	require.NoError(t, os.MkdirAll(algorithm.DatasetsDir, 0o755))

	svc.downloadDatasetsIfRemote(ReceivingData)

	assert.Nil(t, svc.runError)
	assert.Len(t, svc.computation.Datasets, 0)

	// Check if file was decompressed
	_, err = os.Stat(filepath.Join(algorithm.DatasetsDir, "test.txt"))
	assert.NoError(t, err)

	sm.AssertExpectations(t)
	mockOCI.AssertExpectations(t)
}
// TestDownloadAlgorithmIfRemote_ErrorPathsInternal covers failure paths of
// downloadAlgorithmIfRemote after a successful (mocked) OCI pull: a content hash
// that does not match, an "algo" path that cannot be created as a file, and an
// OCI index without manifests. Each subtest expects RunFailed to be sent once
// and checks the recorded run error message.
func TestDownloadAlgorithmIfRemote_ErrorPathsInternal(t *testing.T) {
	// Run inside a scratch working directory and restore the original one
	// afterwards; the Getwd error is checked rather than discarded.
	origDir, err := os.Getwd()
	require.NoError(t, err)
	tmpDir := t.TempDir()
	require.NoError(t, os.Chdir(tmpDir))
	defer func() { require.NoError(t, os.Chdir(origDir)) }()

	eventsSvc := new(mocks.Service)
	eventsSvc.EXPECT().SendEvent(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return().Maybe()

	t.Run("hash mismatch", func(t *testing.T) {
		sm := &smmocks.StateMachine{}
		sm.On("SendEvent", RunFailed).Return().Once()

		mockOCI := new(MockOCIClient)
		mockOCI.On("PullAndDecrypt", mock.Anything, mock.Anything, mock.Anything).Run(func(args mock.Arguments) {
			setupMinimalOCI(t, args.String(2), "main.py", []byte("wrong content"))
		}).Return(nil)

		svc := newTestAgentService(sm, eventsSvc)
		svc.ociClient = mockOCI
		svc.computation = Computation{
			Algorithm: &Algorithm{
				Hash:     sha3.Sum256([]byte("expected content")),
				AlgoType: "python",
				Source: &ResourceSource{
					Type: "oci-image",
					URL:  "docker://test/algo-hash-mismatch",
				},
				KBS: &KBSConfig{Enabled: true},
			},
		}

		svc.downloadAlgorithmIfRemote(ReceivingAlgorithm)

		// require (not assert): a nil runError must stop the subtest here
		// rather than panic on runError.Error() below.
		require.Error(t, svc.runError)
		assert.Contains(t, svc.runError.Error(), "algorithm hash mismatch")
		sm.AssertExpectations(t)
	})

	t.Run("create algo file failure", func(t *testing.T) {
		sm := &smmocks.StateMachine{}
		sm.On("SendEvent", RunFailed).Return().Once()

		// Create a directory named "algo" to make file creation fail
		require.NoError(t, os.Mkdir("algo", 0o755))
		defer func() { assert.NoError(t, os.RemoveAll("algo")) }()

		algoContent := "print(1)"
		mockOCI := new(MockOCIClient)
		mockOCI.On("PullAndDecrypt", mock.Anything, mock.Anything, mock.Anything).Run(func(args mock.Arguments) {
			setupMinimalOCI(t, args.String(2), "main.py", []byte(algoContent))
		}).Return(nil)

		svc := newTestAgentService(sm, eventsSvc)
		svc.ociClient = mockOCI
		svc.computation = Computation{
			Algorithm: &Algorithm{
				Hash:     sha3.Sum256([]byte(algoContent)),
				AlgoType: "python",
				Source: &ResourceSource{
					Type: "oci-image",
					URL:  "docker://test/algo-create-fail",
				},
				KBS: &KBSConfig{Enabled: true},
			},
		}

		svc.downloadAlgorithmIfRemote(ReceivingAlgorithm)

		require.Error(t, svc.runError)
		assert.Contains(t, svc.runError.Error(), "error creating algorithm file")
		sm.AssertExpectations(t)
	})

	t.Run("extraction failure", func(t *testing.T) {
		sm := &smmocks.StateMachine{}
		sm.On("SendEvent", RunFailed).Return().Once()

		mockOCI := new(MockOCIClient)
		mockOCI.On("PullAndDecrypt", mock.Anything, mock.Anything, mock.Anything).Run(func(args mock.Arguments) {
			destDir := args.String(2)
			// Setup OCI with NO main.py or any algorithm file
			require.NoError(t, os.MkdirAll(filepath.Join(destDir, "blobs"), 0o755))
			// Create a legit-looking but empty index.json
			require.NoError(t, os.WriteFile(filepath.Join(destDir, "index.json"), []byte(`{"schemaVersion":2,"manifests":[]}`), 0o644))
		}).Return(nil)

		svc := newTestAgentService(sm, eventsSvc)
		svc.ociClient = mockOCI
		svc.computation = Computation{
			Algorithm: &Algorithm{
				AlgoType: "python",
				Source: &ResourceSource{
					Type: "oci-image",
					URL:  "docker://test/image",
				},
				KBS: &KBSConfig{Enabled: true},
			},
		}

		svc.downloadAlgorithmIfRemote(ReceivingAlgorithm)

		require.Error(t, svc.runError)
		assert.Contains(t, svc.runError.Error(), "no manifests found")
		sm.AssertExpectations(t)
	})
}
// TestDownloadDatasetsIfRemote_ErrorPathsInternal covers failure paths of
// downloadDatasetsIfRemote after a successful (mocked) OCI pull: a dataset path
// that cannot be created as a file, a dataset hash that does not match, and a
// dataset flagged for decompression whose content is not a zip archive. Each
// subtest expects RunFailed to be sent to the state machine exactly once.
func TestDownloadDatasetsIfRemote_ErrorPathsInternal(t *testing.T) {
	// Run inside a scratch working directory and restore the original one
	// afterwards; the Getwd error is checked rather than discarded.
	origDir, err := os.Getwd()
	require.NoError(t, err)
	tmpDir := t.TempDir()
	require.NoError(t, os.Chdir(tmpDir))
	defer func() { require.NoError(t, os.Chdir(origDir)) }()

	// Use a fresh mock in each subtest to avoid state pollution
	t.Run("dataset create file failure", func(t *testing.T) {
		eventsSvc := mocks.NewService(t)
		eventsSvc.On("SendEvent", mock.Anything, mock.Anything, mock.Anything, mock.MatchedBy(func(json.RawMessage) bool { return true })).Return().Maybe()

		sm := &smmocks.StateMachine{}
		sm.On("SendEvent", RunFailed).Return().Once()

		// Create a directory named "data.csv" in datasets dir to make file creation fail
		require.NoError(t, os.MkdirAll(filepath.Join(algorithm.DatasetsDir, "data.csv"), 0o755))
		defer func() { assert.NoError(t, os.RemoveAll(algorithm.DatasetsDir)) }()

		dataContent := "a,b,c"
		mockOCI := new(MockOCIClient)
		mockOCI.On("PullAndDecrypt", mock.Anything, mock.Anything, mock.Anything).Run(func(args mock.Arguments) {
			setupMinimalOCI(t, args.String(2), "data.csv", []byte(dataContent))
		}).Return(nil)

		svc := newTestAgentService(sm, eventsSvc)
		svc.ociClient = mockOCI
		svc.computation = Computation{
			Datasets: []Dataset{
				{
					Filename: "data.csv",
					Hash:     sha3.Sum256([]byte(dataContent)),
					Source: &ResourceSource{
						Type: "oci-image",
						URL:  "docker://test/data-create-fail",
					},
					KBS: &KBSConfig{Enabled: true},
				},
			},
		}

		svc.downloadDatasetsIfRemote(ReceivingData)

		// NOTE(review): only the RunFailed event is verified here; svc.runError
		// is not asserted in this subtest.
		sm.AssertExpectations(t)
	})

	t.Run("dataset hash mismatch", func(t *testing.T) {
		eventsSvc := mocks.NewService(t)
		eventsSvc.On("SendEvent", mock.Anything, mock.Anything, mock.Anything, mock.MatchedBy(func(json.RawMessage) bool { return true })).Return().Maybe()

		// Give this subtest its own working directory and return to the
		// previous one when it finishes.
		prevDir, err := os.Getwd()
		require.NoError(t, err)
		require.NoError(t, os.Chdir(t.TempDir()))
		defer func() { require.NoError(t, os.Chdir(prevDir)) }()

		sm := &smmocks.StateMachine{}
		sm.On("SendEvent", RunFailed).Return().Once()

		dataContent := "wrong content"
		mockOCI := new(MockOCIClient)
		mockOCI.On("PullAndDecrypt", mock.Anything, mock.Anything, mock.Anything).Run(func(args mock.Arguments) {
			setupMinimalOCI(t, args.String(2), "data.csv", []byte(dataContent))
		}).Return(nil)

		svc := newTestAgentService(sm, eventsSvc)
		svc.ociClient = mockOCI
		svc.computation = Computation{
			Datasets: []Dataset{
				{
					Filename: "data.csv",
					Hash:     sha3.Sum256([]byte("expected content")),
					Source: &ResourceSource{
						Type: "oci-image",
						URL:  "docker://test/data-mismatch",
					},
					KBS: &KBSConfig{Enabled: true},
				},
			},
		}

		require.NoError(t, os.MkdirAll(algorithm.DatasetsDir, 0o755))

		svc.downloadDatasetsIfRemote(ReceivingData)

		// Stop before dereferencing runError if the failure was not recorded.
		if svc.runError == nil {
			t.Fatalf("runError should not be nil in hash mismatch test")
		}
		assert.Contains(t, svc.runError.Error(), "dataset data.csv hash mismatch")
		sm.AssertExpectations(t)
	})

	t.Run("dataset unzip failure", func(t *testing.T) {
		eventsSvc := mocks.NewService(t)
		eventsSvc.On("SendEvent", mock.Anything, mock.Anything, mock.Anything, mock.MatchedBy(func(json.RawMessage) bool { return true })).Return().Maybe()

		// Give this subtest its own working directory and return to the
		// previous one when it finishes.
		prevDir, err := os.Getwd()
		require.NoError(t, err)
		require.NoError(t, os.Chdir(t.TempDir()))
		defer func() { require.NoError(t, os.Chdir(prevDir)) }()

		sm := &smmocks.StateMachine{}
		sm.On("SendEvent", RunFailed).Return().Once()

		// Provide invalid zip content
		dataContent := "not a zip file"
		mockOCI := new(MockOCIClient)
		mockOCI.On("PullAndDecrypt", mock.Anything, mock.Anything, mock.Anything).Run(func(args mock.Arguments) {
			setupMinimalOCI(t, args.String(2), "data.zip", []byte(dataContent))
		}).Return(nil)

		svc := newTestAgentService(sm, eventsSvc)
		svc.ociClient = mockOCI
		svc.computation = Computation{
			Datasets: []Dataset{
				{
					Filename:   "data.zip",
					Hash:       sha3.Sum256([]byte(dataContent)),
					Decompress: true,
					Source: &ResourceSource{
						Type: "oci-image",
						URL:  "docker://test/data-unzip-fail",
					},
					KBS: &KBSConfig{Enabled: true},
				},
			},
		}

		require.NoError(t, os.MkdirAll(algorithm.DatasetsDir, 0o755))

		svc.downloadDatasetsIfRemote(ReceivingData)

		// Stop before dereferencing runError if the failure was not recorded.
		if svc.runError == nil {
			t.Fatalf("runError should not be nil in unzip failure test")
		}
		assert.Contains(t, svc.runError.Error(), "failed to unzip dataset")
		sm.AssertExpectations(t)
	})
}
// TestAlgo_RemoteSource calls the Algo entry point with an empty Algorithm
// while the service's computation already describes a remote (OCI image)
// python algorithm. With the state machine reporting ReceivingAlgorithm and a
// mocked OCI pull, the call is expected to succeed, mark the algorithm as
// received and send AlgorithmReceived.
func TestAlgo_RemoteSource(t *testing.T) {
	if testing.Short() {
		t.Skip("skipping in short mode")
	}

	// Run inside a scratch working directory and restore the original one
	// afterwards; the Getwd error is checked rather than discarded.
	origDir, err := os.Getwd()
	require.NoError(t, err)
	tmpDir := t.TempDir()
	require.NoError(t, os.Chdir(tmpDir))
	defer func() { require.NoError(t, os.Chdir(origDir)) }()

	eventsSvc := new(mocks.Service)
	eventsSvc.EXPECT().SendEvent(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return().Maybe()

	sm := &smmocks.StateMachine{}
	sm.On("GetState").Return(ReceivingAlgorithm)
	sm.On("SendEvent", AlgorithmReceived).Return().Once()

	algoContent := []byte("print('remote algo')")
	algoHash := sha3.Sum256(algoContent)

	// The mocked pull materialises a minimal OCI layout containing main.py
	// in the destination directory (third call argument).
	mockOCI := new(MockOCIClient)
	mockOCI.On("PullAndDecrypt", mock.Anything, mock.Anything, mock.Anything).Run(func(args mock.Arguments) {
		setupMinimalOCI(t, args.String(2), "main.py", algoContent)
	}).Return(nil)

	svc := &agentService{
		logger:    slog.Default(),
		eventSvc:  eventsSvc,
		sm:        sm,
		ociClient: mockOCI,
		computation: Computation{
			Algorithm: &Algorithm{
				Hash:     algoHash,
				AlgoType: "python",
				Source: &ResourceSource{
					Type: "oci-image",
					URL:  "docker://test/algo-remote",
				},
				KBS: &KBSConfig{Enabled: true},
			},
		},
	}

	// The algorithm type travels in the incoming request metadata.
	ctx := metadata.NewIncomingContext(context.Background(), metadata.Pairs(algorithm.AlgoTypeKey, "python"))
	err = svc.Algo(ctx, Algorithm{})

	assert.NoError(t, err)
	assert.True(t, svc.algoReceived)
	sm.AssertExpectations(t)
	mockOCI.AssertExpectations(t)
}
// TestData_RemoteSource calls the Data entry point with an empty Dataset while
// the service's computation already describes a remote (OCI image) dataset.
// With the state machine reporting ReceivingData and a mocked OCI pull, the
// call is expected to succeed, leave no pending datasets and send DataReceived.
func TestData_RemoteSource(t *testing.T) {
	if testing.Short() {
		t.Skip("skipping in short mode")
	}

	// Run inside a scratch working directory and restore the original one
	// afterwards; the Getwd error is checked rather than discarded.
	origDir, err := os.Getwd()
	require.NoError(t, err)
	tmpDir := t.TempDir()
	require.NoError(t, os.Chdir(tmpDir))
	defer func() { require.NoError(t, os.Chdir(origDir)) }()

	eventsSvc := new(mocks.Service)
	eventsSvc.EXPECT().SendEvent(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return().Maybe()

	sm := &smmocks.StateMachine{}
	sm.On("GetState").Return(ReceivingData)
	sm.On("SendEvent", DataReceived).Return().Once()

	dataContent := []byte("remote data")
	dataHash := sha3.Sum256(dataContent)

	// The mocked pull materialises a minimal OCI layout containing data.csv
	// in the destination directory (third call argument).
	mockOCI := new(MockOCIClient)
	mockOCI.On("PullAndDecrypt", mock.Anything, mock.Anything, mock.Anything).Run(func(args mock.Arguments) {
		setupMinimalOCI(t, args.String(2), "data.csv", dataContent)
	}).Return(nil)

	svc := &agentService{
		logger:    slog.Default(),
		eventSvc:  eventsSvc,
		sm:        sm,
		ociClient: mockOCI,
		computation: Computation{
			Datasets: []Dataset{
				{
					Filename: "data.csv",
					Hash:     dataHash,
					Source: &ResourceSource{
						Type: "oci-image",
						URL:  "docker://test/data-remote",
					},
					KBS: &KBSConfig{Enabled: true},
				},
			},
		},
	}

	require.NoError(t, os.MkdirAll(algorithm.DatasetsDir, 0o755))

	err = svc.Data(context.Background(), Dataset{})

	assert.NoError(t, err)
	assert.Len(t, svc.computation.Datasets, 0)
	sm.AssertExpectations(t)
	mockOCI.AssertExpectations(t)
}
// TestRunComputation_Success checks that runComputation, with a runner client
// mock whose Run call returns an empty response and no error, records no run
// error and sends RunComplete to the state machine exactly once.
func TestRunComputation_Success(t *testing.T) {
	// Run inside a scratch working directory and restore the original one
	// afterwards; the Getwd error is checked rather than discarded.
	origDir, err := os.Getwd()
	require.NoError(t, err)
	tmpDir := t.TempDir()
	require.NoError(t, os.Chdir(tmpDir))
	defer func() { require.NoError(t, os.Chdir(origDir)) }()

	// Write a dummy algo file
	require.NoError(t, os.WriteFile("algo", []byte("#!/bin/sh\necho ok\n"), 0o755))

	runnerCli := new(runnermocks.Client)
	runnerCli.On("Run", mock.Anything, mock.Anything).Return(&runnerpb.RunResponse{}, nil)

	eventsSvc := new(mocks.Service)
	eventsSvc.EXPECT().SendEvent(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return().Maybe()

	sm := &smmocks.StateMachine{}
	sm.On("SendEvent", RunComplete).Return().Once()

	svc := &agentService{
		logger:       slog.Default(),
		eventSvc:     eventsSvc,
		sm:           sm,
		runnerClient: runnerCli,
		computation:  Computation{ID: "test-run"},
	}

	svc.runComputation(Running)

	assert.Nil(t, svc.runError)
	sm.AssertExpectations(t)
	runnerCli.AssertExpectations(t)
}
// TestInferSourceType is a table-driven check of inferSourceType: each URL is
// mapped to the expected resource source type, and URLs with an unrecognised
// scheme, no scheme, or no content at all are expected to yield "".
func TestInferSourceType(t *testing.T) {
	cases := []struct {
		url      string
		expected string
	}{
		{url: "docker://test/repo", expected: resource.SourceTypeOCIImage},
		{url: "oci:test/repo", expected: resource.SourceTypeOCIImage},
		{url: "s3://bucket/key", expected: resource.SourceTypeS3},
		{url: "gs://bucket/key", expected: resource.SourceTypeGCS},
		{url: "https://example.com/file", expected: resource.SourceTypeHTTPS},
		{url: "http://example.com/file", expected: resource.SourceTypeHTTP},
		{url: "abc://example.com/file", expected: ""},
		{url: "ftp://example.com/file", expected: ""},
		{url: "unknown://example.com/file", expected: ""},
		{url: "malformed-url", expected: ""},
		{url: "", expected: ""},
	}

	for i := range cases {
		tc := cases[i]
		// The URL doubles as the subtest name.
		t.Run(tc.url, func(t *testing.T) {
			assert.Equal(t, tc.expected, inferSourceType(tc.url))
		})
	}
}