// Copyright (c) Ultraviolet // SPDX-License-Identifier: Apache-2.0 package agent import ( "archive/tar" "archive/zip" "bytes" "compress/gzip" "context" "crypto/rand" "encoding/json" "fmt" "log" "log/slog" "os" "path/filepath" "testing" "time" mglog "github.com/absmach/magistrala/logger" "github.com/absmach/magistrala/pkg/errors" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" "github.com/ultravioletrs/cocos/agent/algorithm" "github.com/ultravioletrs/cocos/agent/algorithm/python" agentevents "github.com/ultravioletrs/cocos/agent/events" "github.com/ultravioletrs/cocos/agent/events/mocks" runnerpb "github.com/ultravioletrs/cocos/agent/runner" "github.com/ultravioletrs/cocos/agent/statemachine" smmocks "github.com/ultravioletrs/cocos/agent/statemachine/mocks" "github.com/ultravioletrs/cocos/pkg/attestation" "github.com/ultravioletrs/cocos/pkg/attestation/vtpm" runnermocks "github.com/ultravioletrs/cocos/pkg/clients/grpc/runner/mocks" "github.com/ultravioletrs/cocos/pkg/oci" "github.com/ultravioletrs/cocos/pkg/resource" "golang.org/x/crypto/sha3" "google.golang.org/grpc/metadata" "google.golang.org/protobuf/types/known/emptypb" ) type MockOCIClient struct { mock.Mock } func (m *MockOCIClient) PullAndDecrypt(ctx context.Context, source oci.ResourceSource, destDir string) error { args := m.Called(ctx, source, destDir) return args.Error(0) } func (m *MockOCIClient) ToDockerArchive(ctx context.Context, ociDir, destFile string) error { args := m.Called(ctx, ociDir, destFile) return args.Error(0) } var ( algoPath = "../test/manual/algo/lin_reg.py" reqPath = "../test/manual/algo/requirements.txt" dataPath = "../test/manual/data/iris.csv" ) const datasetFile = "iris.csv" func TestAlgo(t *testing.T) { algo, err := os.ReadFile(algoPath) require.NoError(t, err) algoHash := sha3.Sum256(algo) vtpm.ExternalTPM = &vtpm.DummyRWC{} reqFile, err := os.ReadFile(reqPath) require.NoError(t, err) testCases := []struct { name string err error algo Algorithm algoType string }{ { name: "Test Algo successfully", algo: Algorithm{ Algorithm: algo, Hash: algoHash, }, algoType: "python", err: nil, }, { name: "Test Algo successfully with requirements file", algo: Algorithm{ Algorithm: algo, Hash: algoHash, Requirements: reqFile, }, algoType: "python", err: nil, }, { name: "Test Algo type binary successfully", algo: Algorithm{ Algorithm: algo, Hash: algoHash, }, algoType: "bin", err: nil, }, { name: "Test Algo type wasm successfully", algo: Algorithm{ Algorithm: algo, Hash: algoHash, }, algoType: "wasm", err: nil, }, { name: "Test Algo type docker successfully", algo: Algorithm{ Algorithm: algo, Hash: algoHash, }, algoType: "docker", err: nil, }, { name: "Test algo hash mismatch", algo: Algorithm{}, algoType: "python", err: ErrHashMismatch, }, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { err = os.RemoveAll("datasets") require.NoError(t, err) ctx := metadata.NewIncomingContext(context.Background(), metadata.Pairs(algorithm.AlgoTypeKey, tc.algoType, python.PyRuntimeKey, python.PyRuntime), ) events := new(mocks.Service) events.EXPECT().SendEvent(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return() ctx, cancel := context.WithCancel(ctx) defer cancel() client := new(MockAttestationClient) runnerCli := new(runnermocks.Client) runnerCli.On("Run", mock.Anything, mock.Anything).Return(&runnerpb.RunResponse{}, nil) svc := New(ctx, mglog.NewMock(), events, client, runnerCli, 0) err := svc.InitComputation(ctx, testComputation(t)) require.NoError(t, err) time.Sleep(300 * time.Millisecond) err = svc.Algo(ctx, tc.algo) assert.True(t, errors.Contains(err, tc.err), "expected %v, got %v", tc.err, err) t.Cleanup(func() { err = os.RemoveAll("venv") err = os.RemoveAll("algo") err = os.RemoveAll("datasets") }) }) } } func TestData(t *testing.T) { algo, err := os.ReadFile(algoPath) require.NoError(t, err) algoHash := sha3.Sum256(algo) vtpm.ExternalTPM = &vtpm.DummyRWC{} alg := Algorithm{ Hash: algoHash, Algorithm: algo, } data, err := os.ReadFile(dataPath) require.NoError(t, err) dataHash := sha3.Sum256(data) cases := []struct { name string data Dataset err error }{ { name: "Test data successfully", data: Dataset{ Hash: dataHash, Dataset: data, Filename: datasetFile, }, }, { name: "Test State not ready", data: Dataset{ Dataset: data, Hash: dataHash, Filename: datasetFile, }, err: ErrStateNotReady, }, { name: "Test File name does not match manifest", data: Dataset{ Dataset: data, Hash: dataHash, Filename: "invalid", }, err: ErrFileNameMismatch, }, { name: "Test dataset not declared in manifest", data: Dataset{ Filename: datasetFile, }, err: ErrUndeclaredDataset, }, } for _, tc := range cases { t.Run(tc.name, func(t *testing.T) { ctx := metadata.NewIncomingContext(context.Background(), metadata.Pairs( algorithm.AlgoTypeKey, "python", python.PyRuntimeKey, python.PyRuntime), ) events := new(mocks.Service) events.EXPECT().SendEvent(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return() if tc.err != ErrUndeclaredDataset { ctx = IndexToContext(ctx, 0) } ctx, cancel := context.WithCancel(ctx) defer cancel() client := new(MockAttestationClient) runnerCli := new(runnermocks.Client) runnerCli.On("Run", mock.Anything, mock.Anything).Return(&runnerpb.RunResponse{}, nil) svc := New(ctx, mglog.NewMock(), events, client, runnerCli, 0) err := svc.InitComputation(ctx, testComputation(t)) require.NoError(t, err) time.Sleep(300 * time.Millisecond) if tc.err != ErrStateNotReady { err = svc.Algo(ctx, alg) require.NoError(t, err) time.Sleep(300 * time.Millisecond) } err = svc.Data(ctx, tc.data) assert.True(t, errors.Contains(err, tc.err), "expected %v, got %v", tc.err, err) t.Cleanup(func() { _ = os.RemoveAll("datasets") _ = os.RemoveAll("results") err = os.RemoveAll("venv") err = os.RemoveAll("algo") }) }) } } func TestResult(t *testing.T) { cases := []struct { name string err error setup func(svc *agentService) ctxSetup func(ctx context.Context) context.Context state statemachine.State }{ { name: "Test results not ready", err: ErrResultsNotReady, setup: func(svc *agentService) { }, state: Running, }, { name: "Test undeclared consumer", err: ErrUndeclaredConsumer, setup: func(svc *agentService) { svc.computation.ResultConsumers = []ResultConsumer{{UserKey: []byte("user")}} }, ctxSetup: func(ctx context.Context) context.Context { return ctx }, state: ConsumingResults, }, { name: "Test results consumed and event sent", err: nil, setup: func(svc *agentService) { svc.computation.ResultConsumers = []ResultConsumer{{UserKey: []byte("key")}} }, ctxSetup: func(ctx context.Context) context.Context { return IndexToContext(ctx, 0) }, state: ConsumingResults, }, } for _, tc := range cases { events := new(mocks.Service) events.EXPECT().SendEvent(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return() t.Run(tc.name, func(t *testing.T) { ctx := metadata.NewIncomingContext(context.Background(), metadata.Pairs(algorithm.AlgoTypeKey, "python", python.PyRuntimeKey, python.PyRuntime), ) if tc.ctxSetup != nil { ctx = tc.ctxSetup(ctx) } client := new(MockAttestationClient) runnerCli := new(runnermocks.Client) sm := new(smmocks.StateMachine) sm.On("Start", ctx).Return(nil) sm.On("GetState").Return(tc.state) sm.On("SendEvent", mock.Anything).Return() svc := &agentService{ sm: sm, eventSvc: events, attestationClient: client, runnerClient: runnerCli, computation: testComputation(t), } go func() { if err := svc.sm.Start(ctx); err != nil { t.Errorf("Error starting state machine: %v", err) } }() tc.setup(svc) _, err := svc.Result(ctx) t.Cleanup(func() { _ = os.RemoveAll("datasets") _ = os.RemoveAll("results") }) assert.ErrorIs(t, err, tc.err, "expected %v, got %v", tc.err, err) }) } } func TestAttestation(t *testing.T) { client := new(MockAttestationClient) cases := []struct { name string reportData [vtpm.SEVNonce]byte nonce [vtpm.Nonce]byte rawQuote []uint8 platform attestation.PlatformType err error }{ { name: "Test SNP attestation successful", reportData: generateReportData(), nonce: [32]byte{}, rawQuote: make([]uint8, 0), platform: attestation.SNP, err: nil, }, { name: "Test SNP attestation failed", reportData: generateReportData(), nonce: [32]byte{}, rawQuote: nil, platform: attestation.SNP, err: ErrAttestationFailed, }, { name: "Test vTPM attestation successful", reportData: generateReportData(), nonce: [32]byte{}, rawQuote: make([]uint8, 0), platform: attestation.VTPM, err: nil, }, { name: "Test vTPM attestation failed", reportData: generateReportData(), nonce: [32]byte{}, rawQuote: nil, platform: attestation.VTPM, err: ErrAttestationVTpmFailed, }, { name: "Test SNP-vTPM attestation successful", reportData: generateReportData(), nonce: [32]byte{}, rawQuote: make([]uint8, 0), platform: attestation.SNPvTPM, err: nil, }, { name: "Test SNP-vTPM attestation failed", reportData: generateReportData(), nonce: [32]byte{}, rawQuote: nil, platform: attestation.SNPvTPM, err: ErrAttestationVTpmFailed, }, } for _, tc := range cases { t.Run(tc.name, func(t *testing.T) { events := new(mocks.Service) events.EXPECT().SendEvent(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return() ctx := metadata.NewIncomingContext(context.Background(), metadata.Pairs(algorithm.AlgoTypeKey, "python", python.PyRuntimeKey, python.PyRuntime), ) ctx, cancel := context.WithCancel(ctx) defer cancel() getQuote := client.On("GetAttestation", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(tc.rawQuote, tc.err) if tc.err != ErrAttestationFailed && tc.err != ErrAttestationVTpmFailed { getQuote = client.On("GetAttestation", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(tc.nonce[:], nil) } defer getQuote.Unset() runnerCli := new(runnermocks.Client) svc := New(ctx, mglog.NewMock(), events, client, runnerCli, 0) time.Sleep(300 * time.Millisecond) _, err := svc.Attestation(ctx, tc.reportData, tc.nonce, tc.platform) assert.True(t, errors.Contains(err, tc.err), "expected %v, got %v", tc.err, err) }) } } func TestAzureAttestationToken(t *testing.T) { client := new(MockAttestationClient) cases := []struct { name string nonce [vtpm.Nonce]byte token []byte err error }{ { name: "Azure token fetch successful", nonce: [32]byte{1, 2, 3}, // any test nonce token: []byte("mockToken"), err: nil, }, { name: "Azure token fetch failed", nonce: [32]byte{4, 5, 6}, token: []byte{}, err: ErrAttestationType, }, } for _, tc := range cases { t.Run(tc.name, func(t *testing.T) { events := new(mocks.Service) events.EXPECT().SendEvent(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return() client.On("GetAzureToken", mock.Anything, tc.nonce).Return(tc.token, tc.err) ctx := context.Background() runnerCli := new(runnermocks.Client) svc := New(ctx, mglog.NewMock(), events, client, runnerCli, 0) _, err := svc.AzureAttestationToken(ctx, tc.nonce) assert.True(t, errors.Contains(err, tc.err), "expected error %v, got %v", tc.err, err) }) } } func generateReportData() [vtpm.SEVNonce]byte { bytes := make([]byte, vtpm.SEVNonce) _, err := rand.Read(bytes) if err != nil { log.Fatalf("Failed to generate random bytes: %v", err) } return [64]byte(bytes) } func testComputation(t *testing.T) Computation { algo, err := os.ReadFile(algoPath) require.NoError(t, err) algoHash := sha3.Sum256(algo) data, err := os.ReadFile(dataPath) require.NoError(t, err) dataHash := sha3.Sum256(data) return Computation{ ID: "1", Name: "sample computation", Description: "sample description", Datasets: []Dataset{{Hash: dataHash, UserKey: []byte("key"), Dataset: data, Filename: datasetFile}}, Algorithm: &Algorithm{Hash: algoHash, UserKey: []byte("key"), Algorithm: algo}, ResultConsumers: []ResultConsumer{{UserKey: []byte("key")}}, } } func TestStopComputation(t *testing.T) { cases := []struct { name string setupDirs bool setupAlgo bool algoStopErr error expectedErr error }{ { name: "Stop computation successfully", setupDirs: true, setupAlgo: true, algoStopErr: nil, expectedErr: nil, }, { name: "Stop computation with algorithm stop error", setupDirs: true, setupAlgo: true, algoStopErr: fmt.Errorf("algorithm stop failed"), expectedErr: nil, // Warn only }, // We log warnings but don't return error in StopComputation in new implementation for Stop failure. } for _, tc := range cases { t.Run(tc.name, func(t *testing.T) { events := new(mocks.Service) events.On("SendEvent", mock.Anything, "Stopped", "Stopped", mock.Anything).Return() ctx := context.Background() ctx, cancel := context.WithCancel(ctx) defer cancel() client := new(MockAttestationClient) runnerCli := new(runnermocks.Client) // Mock Stop call var stopErr error if tc.algoStopErr != nil { stopErr = tc.algoStopErr } runnerCli.On("Stop", mock.Anything, mock.Anything).Return(&emptypb.Empty{}, stopErr) svc := New(ctx, mglog.NewMock(), events, client, runnerCli, 0).(*agentService) svc.computation = Computation{ ID: "test-computation", Name: "test", } if tc.setupDirs { err := os.MkdirAll(algorithm.DatasetsDir, 0o755) require.NoError(t, err) err = os.MkdirAll(algorithm.ResultsDir, 0o755) require.NoError(t, err) } // Use real dirs for test // algorithm.DatasetsDir refers to global var? // "github.com/ultravioletrs/cocos/agent/algorithm" // It uses hardcoded path "datasets" and "results" in current dir. // Tests create them in current dir. err := svc.StopComputation(ctx) if tc.expectedErr != nil { assert.Error(t, err) assert.Contains(t, err.Error(), tc.expectedErr.Error()) } else { assert.NoError(t, err) } assert.Equal(t, ReceivingManifest, svc.sm.GetState()) assert.Nil(t, svc.result) assert.Nil(t, svc.runError) assert.False(t, svc.resultsConsumed) events.AssertExpectations(t) _ = os.RemoveAll(algorithm.DatasetsDir) _ = os.RemoveAll(algorithm.ResultsDir) }) } } func TestStopComputationIntegration(t *testing.T) { if testing.Short() { t.Skip("Skipping integration test in short mode") } algo := []byte("#!/bin/bash\necho 'test algorithm'") algoHash := sha3.Sum256(algo) testDir := "test_integration" err := os.MkdirAll(testDir, 0o755) require.NoError(t, err) defer os.RemoveAll(testDir) algoFile := filepath.Join(testDir, "test_algo") err = os.WriteFile(algoFile, algo, 0o755) require.NoError(t, err) events := new(mocks.Service) events.On("SendEvent", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return() ctx := metadata.NewIncomingContext(context.Background(), metadata.Pairs(algorithm.AlgoTypeKey, "bin"), ) ctx, cancel := context.WithCancel(ctx) defer cancel() client := new(MockAttestationClient) runnerCli := new(runnermocks.Client) runnerCli.On("Run", mock.Anything, mock.Anything).Return(&runnerpb.RunResponse{}, nil) runnerCli.On("Stop", mock.Anything, mock.Anything).Return(&emptypb.Empty{}, nil) svc := New(ctx, mglog.NewMock(), events, client, runnerCli, 0) computation := Computation{ ID: "integration-test", Name: "Integration Test", Algorithm: &Algorithm{ Hash: algoHash, Algorithm: algo, }, } err = svc.InitComputation(ctx, computation) require.NoError(t, err) time.Sleep(100 * time.Millisecond) err = svc.Algo(ctx, Algorithm{ Hash: algoHash, Algorithm: algo, }) require.NoError(t, err) time.Sleep(100 * time.Millisecond) err = svc.StopComputation(ctx) assert.NoError(t, err) assert.Equal(t, "ReceivingManifest", svc.State()) } func TestStopComputationConcurrent(t *testing.T) { events := new(mocks.Service) events.On("SendEvent", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return() ctx := context.Background() ctx, cancel := context.WithCancel(ctx) defer cancel() client := new(MockAttestationClient) runnerCli := new(runnermocks.Client) runnerCli.On("Stop", mock.Anything, mock.Anything).Return(&emptypb.Empty{}, nil) svc := New(ctx, mglog.NewMock(), events, client, runnerCli, 0) svc.(*agentService).computation = Computation{ ID: "concurrent-test", Name: "Concurrent Test", } const numGoroutines = 10 errChan := make(chan error, numGoroutines) for i := 0; i < numGoroutines; i++ { go func() { err := svc.StopComputation(ctx) errChan <- err }() } var errors []error for i := 0; i < numGoroutines; i++ { err := <-errChan if err != nil { errors = append(errors, err) } } assert.True(t, len(errors) < numGoroutines, "All StopComputation calls failed") } // newTestAgentService creates a minimal agentService for direct method testing. func newTestAgentService(sm statemachine.StateMachine, eventSvc agentevents.Service) *agentService { return &agentService{ logger: slog.Default(), eventSvc: eventSvc, sm: sm, } } func TestDownloadAndDecryptResource(t *testing.T) { eventsSvc := new(mocks.Service) eventsSvc.EXPECT().SendEvent(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return().Maybe() sm := &smmocks.StateMachine{} sm.On("SendEvent", mock.Anything).Return().Maybe() svc := newTestAgentService(sm, eventsSvc) ctx := context.Background() t.Run("unsupported URL format no type", func(t *testing.T) { source := &ResourceSource{URL: "abc://unsupported-format"} _, err := svc.downloadAndDecryptResource(ctx, source, "", "algorithm") require.Error(t, err) assert.Contains(t, err.Error(), "unsupported source URL format") }) t.Run("ftp URL unsupported format", func(t *testing.T) { source := &ResourceSource{URL: "ftp://some-server/file"} _, err := svc.downloadAndDecryptResource(ctx, source, "", "algorithm") require.Error(t, err) assert.Contains(t, err.Error(), "unsupported source URL format") }) t.Run("unsupported explicit source type", func(t *testing.T) { source := &ResourceSource{Type: "s3-bucket", URL: "s3://mybucket/algo"} _, err := svc.downloadAndDecryptResource(ctx, source, "", "algorithm") require.Error(t, err) assert.Contains(t, err.Error(), "unsupported source type: s3-bucket") }) t.Run("bare OCI image name inferred as oci-image", func(t *testing.T) { source := &ResourceSource{URL: "ubuntu:latest"} _, err := svc.downloadAndDecryptResource(ctx, source, "", "algorithm") require.Error(t, err) // Should route to OCI and fail at OCI client (which is nil or mock) assert.NotContains(t, err.Error(), "unsupported source URL format") }) t.Run("bare registry image name inferred as oci-image", func(t *testing.T) { source := &ResourceSource{URL: "gcr.io/project/image:latest"} _, err := svc.downloadAndDecryptResource(ctx, source, "", "algorithm") require.Error(t, err) assert.NotContains(t, err.Error(), "unsupported source URL format") }) t.Run("docker:// URL inferred as oci-image routes to skopeo", func(t *testing.T) { // This exercises the oci-image path; will fail at skopeo step source := &ResourceSource{URL: "docker://invalid.example.com/algo:latest"} _, err := svc.downloadAndDecryptResource(ctx, source, "", "algorithm") require.Error(t, err) // Should be a skopeo or OCI error, not an "unsupported" error assert.NotContains(t, err.Error(), "unsupported source URL format") }) t.Run("oci: URL inferred as oci-image routes to skopeo", func(t *testing.T) { source := &ResourceSource{URL: "oci:some-local-dir"} _, err := svc.downloadAndDecryptResource(ctx, source, "", "algorithm") require.Error(t, err) assert.NotContains(t, err.Error(), "unsupported source URL format") }) t.Run("explicit oci-image type routes to skopeo", func(t *testing.T) { source := &ResourceSource{Type: "oci-image", URL: "docker://invalid.example.com/algo:latest"} _, err := svc.downloadAndDecryptResource(ctx, source, "", "algorithm") require.Error(t, err) assert.NotContains(t, err.Error(), "unsupported source type") }) t.Run("dataset resource type with oci-image routes to skopeo", func(t *testing.T) { source := &ResourceSource{Type: "oci-image", URL: "docker://invalid.example.com/data:latest"} _, err := svc.downloadAndDecryptResource(ctx, source, "", "dataset") require.Error(t, err) assert.NotContains(t, err.Error(), "unsupported source type") }) t.Run("https inferred routes to registry", func(t *testing.T) { // Mock registry to fail predictably source := &ResourceSource{URL: "https://example.com/file.bin"} _, err := svc.downloadAndDecryptResource(ctx, source, "", "algorithm") require.Error(t, err) // It should complain about registry missing, because the test service does not initialize the registry assert.Contains(t, err.Error(), "resource registry not initialized") }) t.Run("s3 inferred routes to registry", func(t *testing.T) { source := &ResourceSource{URL: "s3://bucket/key"} _, err := svc.downloadAndDecryptResource(ctx, source, "", "algorithm") require.Error(t, err) assert.Contains(t, err.Error(), "resource registry not initialized") }) } func TestDownloadAlgorithmIfRemote(t *testing.T) { t.Run("no source configured - no-op, waits for direct upload", func(t *testing.T) { eventsSvc := new(mocks.Service) eventsSvc.EXPECT().SendEvent(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return().Maybe() sm := &smmocks.StateMachine{} // No SendEvent expected — just the no-op path svc := newTestAgentService(sm, eventsSvc) svc.computation = Computation{} // Algorithm.Source == nil svc.downloadAlgorithmIfRemote(ReceivingAlgorithm) assert.Nil(t, svc.runError) sm.AssertExpectations(t) }) t.Run("source set but KBS disabled - no-op", func(t *testing.T) { eventsSvc := new(mocks.Service) eventsSvc.EXPECT().SendEvent(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return().Maybe() sm := &smmocks.StateMachine{} svc := newTestAgentService(sm, eventsSvc) svc.computation = Computation{ Algorithm: &Algorithm{ Source: &ResourceSource{URL: "docker://registry/algo:latest"}, KBS: &KBSConfig{Enabled: false}, }, } svc.downloadAlgorithmIfRemote(ReceivingAlgorithm) assert.Nil(t, svc.runError) sm.AssertExpectations(t) }) t.Run("source + KBS enabled - download fails, sends RunFailed", func(t *testing.T) { eventsSvc := new(mocks.Service) eventsSvc.EXPECT().SendEvent(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return().Maybe() sm := &smmocks.StateMachine{} sm.On("SendEvent", RunFailed).Return().Once() svc := newTestAgentService(sm, eventsSvc) svc.computation = Computation{ Algorithm: &Algorithm{ Source: &ResourceSource{ Type: "oci-image", URL: "docker://invalid.example.com/algo:latest", }, KBS: &KBSConfig{Enabled: true, URL: "https://kbs.example.com"}, }, } svc.downloadAlgorithmIfRemote(ReceivingAlgorithm) assert.NotNil(t, svc.runError) assert.Contains(t, svc.runError.Error(), "failed to download and decrypt algorithm") sm.AssertExpectations(t) }) t.Run("unsupported URL format - download fails, sends RunFailed", func(t *testing.T) { eventsSvc := new(mocks.Service) eventsSvc.EXPECT().SendEvent(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return().Maybe() sm := &smmocks.StateMachine{} sm.On("SendEvent", RunFailed).Return().Once() svc := newTestAgentService(sm, eventsSvc) svc.computation = Computation{ Algorithm: &Algorithm{ Source: &ResourceSource{ URL: "http://unsupported-format/algo", }, KBS: &KBSConfig{Enabled: true}, }, } svc.downloadAlgorithmIfRemote(ReceivingAlgorithm) assert.NotNil(t, svc.runError) sm.AssertExpectations(t) }) } func TestDownloadDatasetsIfRemote(t *testing.T) { t.Run("no datasets with remote sources - no-op", func(t *testing.T) { eventsSvc := new(mocks.Service) eventsSvc.EXPECT().SendEvent(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return().Maybe() sm := &smmocks.StateMachine{} svc := newTestAgentService(sm, eventsSvc) // Dataset with no Source dataHash := sha3.Sum256([]byte("testdata")) svc.computation = Computation{ Datasets: []Dataset{ {Hash: dataHash, Filename: "data.csv"}, }, } svc.downloadDatasetsIfRemote(ReceivingData) // No RunFailed event, no DataReceived event sm.AssertExpectations(t) }) t.Run("no datasets at all - no-op", func(t *testing.T) { eventsSvc := new(mocks.Service) eventsSvc.EXPECT().SendEvent(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return().Maybe() sm := &smmocks.StateMachine{} svc := newTestAgentService(sm, eventsSvc) svc.computation = Computation{ Datasets: []Dataset{}, } svc.downloadDatasetsIfRemote(ReceivingData) sm.AssertExpectations(t) }) t.Run("KBS disabled even with source - no-op", func(t *testing.T) { eventsSvc := new(mocks.Service) eventsSvc.EXPECT().SendEvent(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return().Maybe() sm := &smmocks.StateMachine{} svc := newTestAgentService(sm, eventsSvc) svc.computation = Computation{ Datasets: []Dataset{ { Filename: "data.csv", Source: &ResourceSource{URL: "docker://registry/data:latest"}, KBS: &KBSConfig{Enabled: false}, }, }, } svc.downloadDatasetsIfRemote(ReceivingData) sm.AssertExpectations(t) }) t.Run("remote dataset + KBS enabled - download fails, sends RunFailed", func(t *testing.T) { eventsSvc := new(mocks.Service) eventsSvc.EXPECT().SendEvent(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return().Maybe() sm := &smmocks.StateMachine{} sm.On("SendEvent", RunFailed).Return().Once() svc := newTestAgentService(sm, eventsSvc) svc.computation = Computation{ Datasets: []Dataset{ { Filename: "data.csv", Source: &ResourceSource{ Type: "oci-image", URL: "docker://invalid.example.com/data:latest", }, KBS: &KBSConfig{Enabled: true, URL: "https://kbs.example.com"}, }, }, } svc.downloadDatasetsIfRemote(ReceivingData) sm.AssertExpectations(t) }) t.Run("unsupported URL fails - sends RunFailed", func(t *testing.T) { eventsSvc := new(mocks.Service) eventsSvc.EXPECT().SendEvent(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return().Maybe() sm := &smmocks.StateMachine{} sm.On("SendEvent", RunFailed).Return().Once() svc := newTestAgentService(sm, eventsSvc) svc.computation = Computation{ Datasets: []Dataset{ { Filename: "data.csv", Source: &ResourceSource{ URL: "http://unsupported-format/data", }, KBS: &KBSConfig{Enabled: true}, }, }, } svc.downloadDatasetsIfRemote(ReceivingData) sm.AssertExpectations(t) }) } func TestRunComputation(t *testing.T) { // Helper to set up a temp working directory and restore CWD afterwards. withTempDir := func(t *testing.T) (tmpDir string, restore func()) { t.Helper() origDir, err := os.Getwd() require.NoError(t, err) tmpDir = t.TempDir() require.NoError(t, os.Chdir(tmpDir)) return tmpDir, func() { _ = os.Chdir(origDir) } } t.Run("algo file not found sends RunFailed", func(t *testing.T) { _, restore := withTempDir(t) defer restore() eventsSvc := new(mocks.Service) eventsSvc.EXPECT().SendEvent(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return().Maybe() sm := &smmocks.StateMachine{} sm.On("SendEvent", RunFailed).Return().Once() svc := newTestAgentService(sm, eventsSvc) // No algo file exists – runComputation should hit the ReadFile error path. svc.runComputation(Running) assert.Error(t, svc.runError) assert.Contains(t, svc.runError.Error(), "failed to read algo file") sm.AssertExpectations(t) }) t.Run("runner client returns error sends RunFailed", func(t *testing.T) { _, restore := withTempDir(t) defer restore() // Write a dummy algo file so ReadFile succeeds. require.NoError(t, os.WriteFile("algo", []byte("#!/bin/sh\necho ok\n"), 0o755)) runnerCli := new(runnermocks.Client) runnerCli.On("Run", mock.Anything, mock.Anything).Return((*runnerpb.RunResponse)(nil), fmt.Errorf("runner unavailable")) eventsSvc := new(mocks.Service) eventsSvc.EXPECT().SendEvent(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return().Maybe() sm := &smmocks.StateMachine{} sm.On("SendEvent", RunFailed).Return().Once() svc := newTestAgentService(sm, eventsSvc) svc.runnerClient = runnerCli svc.runComputation(Running) assert.Error(t, svc.runError) assert.Contains(t, svc.runError.Error(), "runner unavailable") sm.AssertExpectations(t) }) t.Run("runner returns non-empty error field sends RunFailed", func(t *testing.T) { _, restore := withTempDir(t) defer restore() require.NoError(t, os.WriteFile("algo", []byte("#!/bin/sh\necho ok\n"), 0o755)) runnerCli := new(runnermocks.Client) runnerCli.On("Run", mock.Anything, mock.Anything).Return(&runnerpb.RunResponse{Error: "computation crashed"}, nil) eventsSvc := new(mocks.Service) eventsSvc.EXPECT().SendEvent(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return().Maybe() sm := &smmocks.StateMachine{} sm.On("SendEvent", RunFailed).Return().Once() svc := newTestAgentService(sm, eventsSvc) svc.runnerClient = runnerCli svc.runComputation(Running) assert.Error(t, svc.runError) assert.Contains(t, svc.runError.Error(), "computation crashed") sm.AssertExpectations(t) }) } func TestIMAMeasurements(t *testing.T) { t.Run("error when IMA measurements file does not exist in non-SGX environment", func(t *testing.T) { // In a regular test environment (non-SGX), the IMA measurements file // at /sys/kernel/security/integrity/ima/ascii_runtime_measurements won't exist. // Verify our error handling works correctly. origPath := ImaMeasurementsFilePath ImaMeasurementsFilePath = "/non/existent/path" defer func() { ImaMeasurementsFilePath = origPath }() eventsSvc := new(mocks.Service) eventsSvc.On("SendEvent", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return().Maybe() sm := &smmocks.StateMachine{} svc := newTestAgentService(sm, eventsSvc) data, pcr10, err := svc.IMAMeasurements(context.Background()) assert.Error(t, err) assert.Contains(t, err.Error(), "error reading Linux IMA measurements file") assert.Nil(t, data) assert.Nil(t, pcr10) }) t.Run("successful reading of IMA measurements", func(t *testing.T) { tempFile := filepath.Join(t.TempDir(), "ima_measurements") content := []byte("10 sha1:0000000000000000000000000000000000000000 ima-ng sha256:0000000000000000000000000000000000000000000000000000000000000000 /usr/bin/python3\n") err := os.WriteFile(tempFile, content, 0o644) require.NoError(t, err) vtpm.ExternalTPM = &vtpm.DummyRWC{} origPath := ImaMeasurementsFilePath ImaMeasurementsFilePath = tempFile defer func() { ImaMeasurementsFilePath = origPath }() eventsSvc := new(mocks.Service) eventsSvc.On("SendEvent", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return().Maybe() sm := &smmocks.StateMachine{} svc := newTestAgentService(sm, eventsSvc) data, pcr10, err := svc.IMAMeasurements(context.Background()) assert.NoError(t, err) assert.Equal(t, content, data) assert.NotEmpty(t, pcr10) }) } func TestDownloadAlgorithmIfRemote_Success(t *testing.T) { // Skip this test in short mode as it might involve more setup if we were using real OCI if testing.Short() { t.Skip("skipping in short mode") } origDir, _ := os.Getwd() tmpDir := t.TempDir() require.NoError(t, os.Chdir(tmpDir)) defer func() { require.NoError(t, os.Chdir(origDir)) }() eventsSvc := new(mocks.Service) eventsSvc.On("SendEvent", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return().Maybe() sm := &smmocks.StateMachine{} sm.On("SendEvent", AlgorithmReceived).Return().Once() mockOCI := new(MockOCIClient) algoContent := []byte("print('hello')") mockOCI.On("PullAndDecrypt", mock.Anything, mock.Anything, mock.Anything).Run(func(args mock.Arguments) { destDir := args.String(2) setupMinimalOCI(t, destDir, "main.py", algoContent) }).Return(nil) svc := newTestAgentService(sm, eventsSvc) svc.ociClient = mockOCI algoContent = []byte("print('hello')") algoHash := sha3.Sum256(algoContent) svc.computation = Computation{ Algorithm: &Algorithm{ Hash: algoHash, AlgoType: "python", Source: &ResourceSource{ Type: "oci-image", URL: "docker://test/algo-success", }, KBS: &KBSConfig{Enabled: true}, }, } // We need to bypass oci.ExtractAlgorithm by manually creating what it would create // OR use a real-enough looking OCI layout. // Since we can't easily mock oci.ExtractAlgorithm, we'll try to provide a minimal OCI layout // so that oci.ExtractAlgorithm doesn't fail. svc.downloadAlgorithmIfRemote(ReceivingAlgorithm) assert.Nil(t, svc.runError) assert.True(t, svc.algoReceived) sm.AssertExpectations(t) mockOCI.AssertExpectations(t) } func TestDownloadAlgorithmIfRemote_Docker_Success(t *testing.T) { if testing.Short() { t.Skip("skipping in short mode") } origDir, _ := os.Getwd() tmpDir := t.TempDir() require.NoError(t, os.Chdir(tmpDir)) defer func() { require.NoError(t, os.Chdir(origDir)) }() eventsSvc := new(mocks.Service) eventsSvc.On("SendEvent", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return().Maybe() sm := &smmocks.StateMachine{} sm.On("SendEvent", AlgorithmReceived).Return().Once() mockOCI := new(MockOCIClient) mockOCI.On("PullAndDecrypt", mock.Anything, mock.Anything, mock.Anything).Return(nil) dummyContent := []byte("dummy docker tar") dummyHash := sha3.Sum256(dummyContent) mockOCI.On("ToDockerArchive", mock.Anything, mock.Anything, mock.Anything).Run(func(args mock.Arguments) { destFile := args.String(2) err := os.WriteFile(destFile, dummyContent, 0o644) require.NoError(t, err) }).Return(nil) svc := newTestAgentService(sm, eventsSvc) svc.ociClient = mockOCI svc.computation = Computation{ Algorithm: &Algorithm{ AlgoType: "docker", Hash: dummyHash, Source: &ResourceSource{ Type: "oci-image", URL: "docker://test/algo-docker-success", }, KBS: &KBSConfig{Enabled: true}, }, } svc.downloadAlgorithmIfRemote(ReceivingAlgorithm) assert.Nil(t, svc.runError) assert.True(t, svc.algoReceived) sm.AssertExpectations(t) mockOCI.AssertExpectations(t) } func setupMinimalOCI(t *testing.T, ociDir, filename string, content []byte) { t.Helper() blobsDir := filepath.Join(ociDir, "blobs", "sha256") require.NoError(t, os.MkdirAll(blobsDir, 0o755)) layerPath := filepath.Join(blobsDir, "layer123") layerFile, err := os.Create(layerPath) require.NoError(t, err) gw := gzip.NewWriter(layerFile) tw := tar.NewWriter(gw) hdr := &tar.Header{ Name: filename, Mode: 0o755, Size: int64(len(content)), } require.NoError(t, tw.WriteHeader(hdr)) _, err = tw.Write(content) require.NoError(t, err) require.NoError(t, tw.Close()) require.NoError(t, gw.Close()) require.NoError(t, layerFile.Close()) manifest := struct { Layers []struct { Digest string `json:"digest"` } `json:"layers"` }{ Layers: []struct { Digest string `json:"digest"` }{{Digest: "sha256:layer123"}}, } manifestData, err := json.Marshal(manifest) require.NoError(t, err) require.NoError(t, os.WriteFile(filepath.Join(blobsDir, "manifest123"), manifestData, 0o644)) index := oci.OCIIndex{ SchemaVersion: 2, Manifests: []struct { MediaType string `json:"mediaType"` Digest string `json:"digest"` Size int `json:"size"` }{{Digest: "sha256:manifest123", Size: len(manifestData)}}, } indexData, err := json.Marshal(index) require.NoError(t, err) require.NoError(t, os.WriteFile(filepath.Join(ociDir, "index.json"), indexData, 0o644)) } func TestDownloadDatasetsIfRemote_Success(t *testing.T) { if testing.Short() { t.Skip("skipping in short mode") } origDir, _ := os.Getwd() tmpDir := t.TempDir() require.NoError(t, os.Chdir(tmpDir)) defer func() { require.NoError(t, os.Chdir(origDir)) }() eventsSvc := new(mocks.Service) eventsSvc.On("SendEvent", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return().Maybe() sm := &smmocks.StateMachine{} sm.On("SendEvent", DataReceived).Return().Once() mockOCI := new(MockOCIClient) dataContent := []byte("a,b,c\n1,2,3") mockOCI.On("PullAndDecrypt", mock.Anything, mock.Anything, mock.Anything).Run(func(args mock.Arguments) { destDir := args.String(2) setupMinimalOCI(t, destDir, "data.csv", dataContent) }).Return(nil) svc := newTestAgentService(sm, eventsSvc) svc.ociClient = mockOCI dataContent = []byte("a,b,c\n1,2,3") dataHash := sha3.Sum256(dataContent) svc.computation = Computation{ Datasets: []Dataset{ { Filename: "data.csv", Hash: dataHash, Source: &ResourceSource{ Type: "oci-image", URL: "docker://test/data-success", }, KBS: &KBSConfig{Enabled: true, URL: "https://kbs.example.com"}, }, }, } err := os.MkdirAll(algorithm.DatasetsDir, 0o755) require.NoError(t, err) svc.downloadDatasetsIfRemote(ReceivingData) assert.Nil(t, svc.runError) assert.Len(t, svc.computation.Datasets, 0) sm.AssertExpectations(t) mockOCI.AssertExpectations(t) } func TestDownloadDatasetsIfRemote_Decompress(t *testing.T) { if testing.Short() { t.Skip("skipping in short mode") } origDir, _ := os.Getwd() tmpDir := t.TempDir() require.NoError(t, os.Chdir(tmpDir)) defer func() { require.NoError(t, os.Chdir(origDir)) }() eventsSvc := new(mocks.Service) eventsSvc.On("SendEvent", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return().Maybe() sm := &smmocks.StateMachine{} sm.On("SendEvent", DataReceived).Return().Maybe() sm.On("SendEvent", RunFailed).Return().Maybe() mockOCI := new(MockOCIClient) // Create a zip file in memory var buf bytes.Buffer zw := zip.NewWriter(&buf) f, err := zw.Create("test.txt") require.NoError(t, err) _, err = f.Write([]byte("hello zip")) require.NoError(t, err) require.NoError(t, zw.Close()) zipData := buf.Bytes() mockOCI.On("PullAndDecrypt", mock.Anything, mock.Anything, mock.Anything).Run(func(args mock.Arguments) { destDir := args.String(2) setupMinimalOCI(t, destDir, "data.zip", zipData) }).Return(nil) svc := newTestAgentService(sm, eventsSvc) svc.ociClient = mockOCI dataHash := sha3.Sum256(zipData) svc.computation = Computation{ Datasets: []Dataset{ { Filename: "data.zip", Hash: dataHash, Decompress: true, Source: &ResourceSource{ Type: "oci-image", URL: "docker://test/data-decompress", }, KBS: &KBSConfig{Enabled: true}, }, }, } err = os.MkdirAll(algorithm.DatasetsDir, 0o755) require.NoError(t, err) svc.downloadDatasetsIfRemote(ReceivingData) assert.Nil(t, svc.runError) assert.Len(t, svc.computation.Datasets, 0) // Check if file was decompressed decompressedFile := filepath.Join(algorithm.DatasetsDir, "test.txt") _, err = os.Stat(decompressedFile) assert.NoError(t, err) sm.AssertExpectations(t) mockOCI.AssertExpectations(t) } func TestDownloadAlgorithmIfRemote_ErrorPathsInternal(t *testing.T) { origDir, _ := os.Getwd() tmpDir := t.TempDir() require.NoError(t, os.Chdir(tmpDir)) defer func() { require.NoError(t, os.Chdir(origDir)) }() eventsSvc := new(mocks.Service) eventsSvc.EXPECT().SendEvent(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return().Maybe() t.Run("hash mismatch", func(t *testing.T) { sm := &smmocks.StateMachine{} sm.On("SendEvent", RunFailed).Return().Once() mockOCI := new(MockOCIClient) mockOCI.On("PullAndDecrypt", mock.Anything, mock.Anything, mock.Anything).Run(func(args mock.Arguments) { destDir := args.String(2) setupMinimalOCI(t, destDir, "main.py", []byte("wrong content")) }).Return(nil) svc := newTestAgentService(sm, eventsSvc) svc.ociClient = mockOCI svc.computation = Computation{ Algorithm: &Algorithm{ Hash: sha3.Sum256([]byte("expected content")), AlgoType: "python", Source: &ResourceSource{ Type: "oci-image", URL: "docker://test/algo-hash-mismatch", }, KBS: &KBSConfig{Enabled: true}, }, } svc.downloadAlgorithmIfRemote(ReceivingAlgorithm) assert.Error(t, svc.runError) assert.Contains(t, svc.runError.Error(), "algorithm hash mismatch") sm.AssertExpectations(t) }) t.Run("create algo file failure", func(t *testing.T) { sm := &smmocks.StateMachine{} sm.On("SendEvent", RunFailed).Return().Once() // Create a directory named "algo" to make file creation fail require.NoError(t, os.Mkdir("algo", 0o755)) defer os.RemoveAll("algo") mockOCI := new(MockOCIClient) algoContent := "print(1)" mockOCI.On("PullAndDecrypt", mock.Anything, mock.Anything, mock.Anything).Run(func(args mock.Arguments) { destDir := args.String(2) setupMinimalOCI(t, destDir, "main.py", []byte(algoContent)) }).Return(nil) svc := newTestAgentService(sm, eventsSvc) svc.ociClient = mockOCI svc.computation = Computation{ Algorithm: &Algorithm{ Hash: sha3.Sum256([]byte(algoContent)), AlgoType: "python", Source: &ResourceSource{ Type: "oci-image", URL: "docker://test/algo-create-fail", }, KBS: &KBSConfig{Enabled: true}, }, } svc.downloadAlgorithmIfRemote(ReceivingAlgorithm) assert.Error(t, svc.runError) assert.Contains(t, svc.runError.Error(), "error creating algorithm file") sm.AssertExpectations(t) }) t.Run("extraction failure", func(t *testing.T) { sm := &smmocks.StateMachine{} sm.On("SendEvent", RunFailed).Return().Once() mockOCI := new(MockOCIClient) mockOCI.On("PullAndDecrypt", mock.Anything, mock.Anything, mock.Anything).Run(func(args mock.Arguments) { destDir := args.String(2) // Setup OCI with NO main.py or any algorithm file require.NoError(t, os.MkdirAll(filepath.Join(destDir, "blobs"), 0o755)) // Create a legit-looking but empty index.json require.NoError(t, os.WriteFile(filepath.Join(destDir, "index.json"), []byte(`{"schemaVersion":2,"manifests":[]}`), 0o644)) }).Return(nil) svc := newTestAgentService(sm, eventsSvc) svc.ociClient = mockOCI svc.computation = Computation{ Algorithm: &Algorithm{ AlgoType: "python", Source: &ResourceSource{ Type: "oci-image", URL: "docker://test/image", }, KBS: &KBSConfig{Enabled: true}, }, } svc.downloadAlgorithmIfRemote(ReceivingAlgorithm) assert.Error(t, svc.runError) assert.Contains(t, svc.runError.Error(), "no manifests found") sm.AssertExpectations(t) }) } func TestDownloadDatasetsIfRemote_ErrorPathsInternal(t *testing.T) { origDir, _ := os.Getwd() tmpDir := t.TempDir() require.NoError(t, os.Chdir(tmpDir)) defer func() { require.NoError(t, os.Chdir(origDir)) }() // Use a fresh mock in each subtest to avoid state pollution t.Run("dataset create file failure", func(t *testing.T) { eventsSvc := mocks.NewService(t) eventsSvc.On("SendEvent", mock.Anything, mock.Anything, mock.Anything, mock.MatchedBy(func(json.RawMessage) bool { return true })).Return().Maybe() sm := &smmocks.StateMachine{} sm.On("SendEvent", RunFailed).Return().Once() // Create a directory named "data.csv" in datasets dir to make file creation fail require.NoError(t, os.MkdirAll(filepath.Join(algorithm.DatasetsDir, "data.csv"), 0o755)) defer os.RemoveAll(algorithm.DatasetsDir) mockOCI := new(MockOCIClient) dataContent := "a,b,c" mockOCI.On("PullAndDecrypt", mock.Anything, mock.Anything, mock.Anything).Run(func(args mock.Arguments) { destDir := args.String(2) setupMinimalOCI(t, destDir, "data.csv", []byte(dataContent)) }).Return(nil) svc := newTestAgentService(sm, eventsSvc) svc.ociClient = mockOCI svc.computation = Computation{ Datasets: []Dataset{ { Filename: "data.csv", Hash: sha3.Sum256([]byte(dataContent)), Source: &ResourceSource{ Type: "oci-image", URL: "docker://test/data-create-fail", }, KBS: &KBSConfig{Enabled: true}, }, }, } svc.downloadDatasetsIfRemote(ReceivingData) sm.AssertExpectations(t) }) t.Run("dataset hash mismatch", func(t *testing.T) { eventsSvc := mocks.NewService(t) eventsSvc.On("SendEvent", mock.Anything, mock.Anything, mock.Anything, mock.MatchedBy(func(json.RawMessage) bool { return true })).Return().Maybe() origDir, _ := os.Getwd() tmpDir := t.TempDir() require.NoError(t, os.Chdir(tmpDir)) defer func() { _ = os.Chdir(origDir) }() sm := &smmocks.StateMachine{} sm.On("SendEvent", RunFailed).Return().Once() mockOCI := new(MockOCIClient) dataContent := "wrong content" mockOCI.On("PullAndDecrypt", mock.Anything, mock.Anything, mock.Anything).Run(func(args mock.Arguments) { destDir := args.String(2) setupMinimalOCI(t, destDir, "data.csv", []byte(dataContent)) }).Return(nil) svc := newTestAgentService(sm, eventsSvc) svc.ociClient = mockOCI svc.computation = Computation{ Datasets: []Dataset{ { Filename: "data.csv", Hash: sha3.Sum256([]byte("expected content")), Source: &ResourceSource{ Type: "oci-image", URL: "docker://test/data-mismatch", }, KBS: &KBSConfig{Enabled: true}, }, }, } err := os.MkdirAll(algorithm.DatasetsDir, 0o755) require.NoError(t, err) svc.downloadDatasetsIfRemote(ReceivingData) if svc.runError == nil { t.Fatalf("runError should not be nil in hash mismatch test") } assert.Contains(t, svc.runError.Error(), "dataset data.csv hash mismatch") sm.AssertExpectations(t) }) t.Run("dataset unzip failure", func(t *testing.T) { eventsSvc := mocks.NewService(t) eventsSvc.On("SendEvent", mock.Anything, mock.Anything, mock.Anything, mock.MatchedBy(func(json.RawMessage) bool { return true })).Return().Maybe() origDir, _ := os.Getwd() tmpDir := t.TempDir() require.NoError(t, os.Chdir(tmpDir)) defer func() { _ = os.Chdir(origDir) }() sm := &smmocks.StateMachine{} sm.On("SendEvent", RunFailed).Return().Once() mockOCI := new(MockOCIClient) // Provide invalid zip content dataContent := "not a zip file" mockOCI.On("PullAndDecrypt", mock.Anything, mock.Anything, mock.Anything).Run(func(args mock.Arguments) { destDir := args.String(2) setupMinimalOCI(t, destDir, "data.zip", []byte(dataContent)) }).Return(nil) svc := newTestAgentService(sm, eventsSvc) svc.ociClient = mockOCI svc.computation = Computation{ Datasets: []Dataset{ { Filename: "data.zip", Hash: sha3.Sum256([]byte(dataContent)), Decompress: true, Source: &ResourceSource{ Type: "oci-image", URL: "docker://test/data-unzip-fail", }, KBS: &KBSConfig{Enabled: true}, }, }, } err := os.MkdirAll(algorithm.DatasetsDir, 0o755) require.NoError(t, err) svc.downloadDatasetsIfRemote(ReceivingData) if svc.runError == nil { t.Fatalf("runError should not be nil in unzip failure test") } assert.Contains(t, svc.runError.Error(), "failed to unzip dataset") sm.AssertExpectations(t) }) } func TestAlgo_RemoteSource(t *testing.T) { if testing.Short() { t.Skip("skipping in short mode") } origDir, _ := os.Getwd() tmpDir := t.TempDir() require.NoError(t, os.Chdir(tmpDir)) defer func() { require.NoError(t, os.Chdir(origDir)) }() eventsSvc := new(mocks.Service) eventsSvc.EXPECT().SendEvent(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return().Maybe() sm := &smmocks.StateMachine{} sm.On("GetState").Return(ReceivingAlgorithm) sm.On("SendEvent", AlgorithmReceived).Return().Once() mockOCI := new(MockOCIClient) algoContent := []byte("print('remote algo')") algoHash := sha3.Sum256(algoContent) mockOCI.On("PullAndDecrypt", mock.Anything, mock.Anything, mock.Anything).Run(func(args mock.Arguments) { destDir := args.String(2) setupMinimalOCI(t, destDir, "main.py", algoContent) }).Return(nil) svc := &agentService{ logger: slog.Default(), eventSvc: eventsSvc, sm: sm, ociClient: mockOCI, computation: Computation{ Algorithm: &Algorithm{ Hash: algoHash, AlgoType: "python", Source: &ResourceSource{ Type: "oci-image", URL: "docker://test/algo-remote", }, KBS: &KBSConfig{Enabled: true}, }, }, } ctx := metadata.NewIncomingContext(context.Background(), metadata.Pairs(algorithm.AlgoTypeKey, "python")) err := svc.Algo(ctx, Algorithm{}) assert.NoError(t, err) assert.True(t, svc.algoReceived) sm.AssertExpectations(t) mockOCI.AssertExpectations(t) } func TestData_RemoteSource(t *testing.T) { if testing.Short() { t.Skip("skipping in short mode") } origDir, _ := os.Getwd() tmpDir := t.TempDir() require.NoError(t, os.Chdir(tmpDir)) defer func() { require.NoError(t, os.Chdir(origDir)) }() eventsSvc := new(mocks.Service) eventsSvc.EXPECT().SendEvent(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return().Maybe() sm := &smmocks.StateMachine{} sm.On("GetState").Return(ReceivingData) sm.On("SendEvent", DataReceived).Return().Once() mockOCI := new(MockOCIClient) dataContent := []byte("remote data") dataHash := sha3.Sum256(dataContent) mockOCI.On("PullAndDecrypt", mock.Anything, mock.Anything, mock.Anything).Run(func(args mock.Arguments) { destDir := args.String(2) setupMinimalOCI(t, destDir, "data.csv", dataContent) }).Return(nil) svc := &agentService{ logger: slog.Default(), eventSvc: eventsSvc, sm: sm, ociClient: mockOCI, computation: Computation{ Datasets: []Dataset{ { Filename: "data.csv", Hash: dataHash, Source: &ResourceSource{ Type: "oci-image", URL: "docker://test/data-remote", }, KBS: &KBSConfig{Enabled: true}, }, }, }, } err := os.MkdirAll(algorithm.DatasetsDir, 0o755) require.NoError(t, err) ctx := context.Background() err = svc.Data(ctx, Dataset{}) assert.NoError(t, err) assert.Len(t, svc.computation.Datasets, 0) sm.AssertExpectations(t) mockOCI.AssertExpectations(t) } func TestRunComputation_Success(t *testing.T) { origDir, _ := os.Getwd() tmpDir := t.TempDir() require.NoError(t, os.Chdir(tmpDir)) defer func() { require.NoError(t, os.Chdir(origDir)) }() // Write a dummy algo file require.NoError(t, os.WriteFile("algo", []byte("#!/bin/sh\necho ok\n"), 0o755)) runnerCli := new(runnermocks.Client) runnerCli.On("Run", mock.Anything, mock.Anything).Return(&runnerpb.RunResponse{}, nil) eventsSvc := new(mocks.Service) eventsSvc.EXPECT().SendEvent(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return().Maybe() sm := &smmocks.StateMachine{} sm.On("SendEvent", RunComplete).Return().Once() svc := &agentService{ logger: slog.Default(), eventSvc: eventsSvc, sm: sm, runnerClient: runnerCli, computation: Computation{ID: "test-run"}, } svc.runComputation(Running) assert.Nil(t, svc.runError) sm.AssertExpectations(t) runnerCli.AssertExpectations(t) } func TestInferSourceType(t *testing.T) { testCases := []struct { url string expected string }{ {"docker://test/repo", resource.SourceTypeOCIImage}, {"oci:test/repo", resource.SourceTypeOCIImage}, {"s3://bucket/key", resource.SourceTypeS3}, {"gs://bucket/key", resource.SourceTypeGCS}, {"https://example.com/file", resource.SourceTypeHTTPS}, {"http://example.com/file", resource.SourceTypeHTTP}, {"abc://example.com/file", ""}, {"ftp://example.com/file", ""}, {"unknown://example.com/file", ""}, {"malformed-url", ""}, {"", ""}, } for _, tc := range testCases { t.Run(tc.url, func(t *testing.T) { result := inferSourceType(tc.url) assert.Equal(t, tc.expected, result) }) } }