From 46b94204df963eb94c911705e21d9a8e4ea00d45 Mon Sep 17 00:00:00 2001 From: Sammy Kerata Oina <44265300+SammyOina@users.noreply.github.com> Date: Thu, 7 Nov 2024 12:47:53 +0300 Subject: [PATCH] NOISSUE - Improve file streaming (#295) * improve file streaming Signed-off-by: Sammy Oina * error check Signed-off-by: Sammy Oina * empty line Signed-off-by: Sammy Oina * fix tests Signed-off-by: Sammy Oina * send buffer test Signed-off-by: Sammy Oina * fix test cases Signed-off-by: Sammy Oina * stream data and attestation Signed-off-by: Sammy Oina * fumpt Signed-off-by: Sammy Oina * fix test Signed-off-by: Sammy Oina * mocks Signed-off-by: Sammy Oina * value check Signed-off-by: Sammy Oina * more value checks Signed-off-by: Sammy Oina * add test cases Signed-off-by: Sammy Oina * fumpt Signed-off-by: Sammy Oina * fix tests Signed-off-by: Sammy Oina * all files Signed-off-by: Sammy Oina * fix lint Signed-off-by: Sammy Oina --------- Signed-off-by: Sammy Oina --- agent/events/mocks/events.go | 5 -- cli/algorithm_test.go | 10 +-- cli/algorithms.go | 17 ++-- cli/attestation.go | 33 +++++-- cli/attestation_test.go | 17 +++- cli/datasets.go | 16 ++-- cli/datasets_test.go | 10 +-- cli/result.go | 50 ++++------- cli/result_test.go | 69 ++++----------- codecov.yml | 2 +- internal/zip.go | 54 ++++++++++++ internal/zip_test.go | 146 +++++++++++++++++++++++++++++++ pkg/progressbar/progress_test.go | 75 ++++++++++++++-- pkg/progressbar/progressbar.go | 102 ++++++++++++--------- pkg/sdk/agent.go | 43 +++++---- pkg/sdk/agent_test.go | 70 +++++++++++++-- pkg/sdk/mocks/sdk.go | 79 ++++++----------- 17 files changed, 536 insertions(+), 262 deletions(-) diff --git a/agent/events/mocks/events.go b/agent/events/mocks/events.go index b10080db..1ac7da61 100644 --- a/agent/events/mocks/events.go +++ b/agent/events/mocks/events.go @@ -16,11 +16,6 @@ type Service struct { mock.Mock } -// Close provides a mock function with given fields: -func (_m *Service) Close() { - _m.Called() -} - // SendEvent provides a mock function with given fields: event, status, details func (_m *Service) SendEvent(event string, status string, details json.RawMessage) error { ret := _m.Called(event, status, details) diff --git a/cli/algorithm_test.go b/cli/algorithm_test.go index c8863740..d488711a 100644 --- a/cli/algorithm_test.go +++ b/cli/algorithm_test.go @@ -57,7 +57,7 @@ func TestAlgorithmCmd(t *testing.T) { { name: "successful upload", setupMock: func(m *mocks.SDK) { - m.On("Algo", mock.Anything, mock.Anything, mock.Anything).Return(nil) + m.On("Algo", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil) }, setupFiles: func() error { if err := os.WriteFile(algorithmFile, []byte("test algorithm"), 0o644); err != nil { @@ -75,7 +75,7 @@ func TestAlgorithmCmd(t *testing.T) { { name: "missing algorithm file", setupMock: func(m *mocks.SDK) { - m.On("Algo", mock.Anything, mock.Anything, mock.Anything).Return(nil) + m.On("Algo", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil) }, args: []string{"non_existent_algo_file.py", privateKeyFile}, expectedOutput: "Error reading algorithm file", @@ -83,7 +83,7 @@ func TestAlgorithmCmd(t *testing.T) { { name: "missing private key file", setupMock: func(m *mocks.SDK) { - m.On("Algo", mock.Anything, mock.Anything, mock.Anything).Return(nil) + m.On("Algo", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil) }, setupFiles: func() error { return os.WriteFile(algorithmFile, []byte("test algorithm"), 0o644) @@ -97,7 +97,7 @@ func TestAlgorithmCmd(t *testing.T) { { name: "upload failure", setupMock: func(m *mocks.SDK) { - m.On("Algo", mock.Anything, mock.Anything, mock.Anything).Return(errors.New("failed to upload algorithm due to error")) + m.On("Algo", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(errors.New("failed to upload algorithm due to error")) }, setupFiles: func() error { if err := os.WriteFile(algorithmFile, []byte("test algorithm"), 0o644); err != nil { @@ -115,7 +115,7 @@ func TestAlgorithmCmd(t *testing.T) { { name: "invalid private key", setupMock: func(m *mocks.SDK) { - m.On("Algo", mock.Anything, mock.Anything, mock.Anything).Return(nil) + m.On("Algo", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil) }, setupFiles: func() error { if err := os.WriteFile(algorithmFile, []byte("test algorithm"), 0o644); err != nil { diff --git a/cli/algorithms.go b/cli/algorithms.go index a36844e4..0f1c5ca9 100644 --- a/cli/algorithms.go +++ b/cli/algorithms.go @@ -9,7 +9,6 @@ import ( "github.com/fatih/color" "github.com/spf13/cobra" - "github.com/ultravioletrs/cocos/agent" "github.com/ultravioletrs/cocos/agent/algorithm" "github.com/ultravioletrs/cocos/agent/algorithm/python" "google.golang.org/grpc/metadata" @@ -38,24 +37,22 @@ func (cli *CLI) NewAlgorithmCmd() *cobra.Command { cmd.Println("Uploading algorithm file:", algorithmFile) - algorithm, err := os.ReadFile(algorithmFile) + algorithm, err := os.Open(algorithmFile) if err != nil { printError(cmd, "Error reading algorithm file: %v ❌ ", err) return } - var req []byte + defer algorithm.Close() + + var req *os.File if requirementsFile != "" { - req, err = os.ReadFile(requirementsFile) + req, err = os.Open(requirementsFile) if err != nil { printError(cmd, "Error reading requirments file: %v ❌ ", err) return } - } - - algoReq := agent.Algorithm{ - Algorithm: algorithm, - Requirements: req, + defer req.Close() } privKeyFile, err := os.ReadFile(args[1]) @@ -74,7 +71,7 @@ func (cli *CLI) NewAlgorithmCmd() *cobra.Command { ctx := metadata.NewOutgoingContext(cmd.Context(), metadata.New(make(map[string]string))) - if err := cli.agentSDK.Algo(addAlgoMetadata(ctx), algoReq, privKey); err != nil { + if err := cli.agentSDK.Algo(addAlgoMetadata(ctx), algorithm, req, privKey); err != nil { printError(cmd, "Failed to upload algorithm due to error: %v ❌ ", err) return } diff --git a/cli/attestation.go b/cli/attestation.go index 950683db..d57d89de 100644 --- a/cli/attestation.go +++ b/cli/attestation.go @@ -176,25 +176,44 @@ func (cli *CLI) NewGetAttestationCmd() *cobra.Command { return } - result, err := cli.agentSDK.Attestation(cmd.Context(), [agent.ReportDataSize]byte(reportData)) + filename := attestationFilePath + if getJsonAttestation { + filename = attestationJson + } + + attestationFile, err := os.Create(filename) if err != nil { + printError(cmd, "Error creating attestation file: %v ❌ ", err) + return + } + + if err := cli.agentSDK.Attestation(cmd.Context(), [agent.ReportDataSize]byte(reportData), attestationFile); err != nil { printError(cmd, "Failed to get attestation due to error: %v ❌ ", err) return } - filename := attestationFilePath + if err := attestationFile.Close(); err != nil { + printError(cmd, "Error closing attestation file: %v ❌ ", err) + return + } + if getJsonAttestation { + result, err := os.ReadFile(filename) + if err != nil { + printError(cmd, "Error reading attestation file: %v ❌ ", err) + return + } + result, err = attesationToJSON(result) if err != nil { printError(cmd, "Error converting attestation to json: %v ❌ ", err) return } - filename = attestationJson - } - if err = os.WriteFile(filename, result, 0o644); err != nil { - printError(cmd, "Error saving attestation result: %v ❌ ", err) - return + if err := os.WriteFile(filename, result, 0o644); err != nil { + printError(cmd, "Error writing attestation file: %v ❌ ", err) + return + } } cmd.Println("Attestation result retrieved and saved successfully!") diff --git a/cli/attestation_test.go b/cli/attestation_test.go index bd97e87f..b1a628e1 100644 --- a/cli/attestation_test.go +++ b/cli/attestation_test.go @@ -22,7 +22,8 @@ import ( ) func TestNewAttestationCmd(t *testing.T) { - cli := &CLI{} + mockSDK := new(mocks.SDK) + cli := &CLI{agentSDK: mockSDK} cmd := cli.NewAttestationCmd() assert.Equal(t, "attestation [command]", cmd.Use) @@ -30,12 +31,19 @@ func TestNewAttestationCmd(t *testing.T) { var buf bytes.Buffer cmd.SetOut(&buf) + + cmd.SetOutput(&buf) + + reportData := bytes.Repeat([]byte{0x01}, agent.ReportDataSize) + mockSDK.On("Attestation", mock.Anything, [agent.ReportDataSize]byte(reportData), mock.Anything).Return(nil) + + cmd.SetArgs([]string{hex.EncodeToString(reportData)}) err := cmd.Execute() assert.NoError(t, err) assert.Contains(t, buf.String(), "Get and validate attestations") } -func TestNewGetAttestationCmdN(t *testing.T) { +func TestNewGetAttestationCmd(t *testing.T) { validattestation, err := os.ReadFile("../attestation.bin") require.NoError(t, err) testCases := []struct { @@ -119,7 +127,10 @@ func TestNewGetAttestationCmdN(t *testing.T) { var buf bytes.Buffer cmd.SetOutput(&buf) - mockSDK.On("Attestation", mock.Anything, [agent.ReportDataSize]byte(bytes.Repeat([]byte{0x01}, agent.ReportDataSize))).Return(tc.mockResponse, tc.mockError) + mockSDK.On("Attestation", mock.Anything, [agent.ReportDataSize]byte(bytes.Repeat([]byte{0x01}, agent.ReportDataSize)), mock.Anything).Return(tc.mockError).Run(func(args mock.Arguments) { + _, err := args.Get(2).(*os.File).Write(tc.mockResponse) + require.NoError(t, err) + }) cmd.SetArgs(tc.args) err := cmd.Execute() diff --git a/cli/datasets.go b/cli/datasets.go index e0009d4e..edd6b51e 100644 --- a/cli/datasets.go +++ b/cli/datasets.go @@ -41,25 +41,23 @@ func (cli *CLI) NewDatasetsCmd() *cobra.Command { return } - var dataset []byte + var dataset *os.File if f.IsDir() { - dataset, err = internal.ZipDirectoryToMemory(datasetPath) + dataset, err = internal.ZipDirectoryToTempFile(datasetPath) if err != nil { printError(cmd, "Error zipping dataset directory: %v ❌ ", err) return } + defer dataset.Close() + defer os.Remove(dataset.Name()) } else { - dataset, err = os.ReadFile(datasetPath) + dataset, err = os.Open(datasetPath) if err != nil { printError(cmd, "Error reading dataset file: %v ❌ ", err) return } - } - - dataReq := agent.Dataset{ - Dataset: dataset, - Filename: path.Base(datasetPath), + defer dataset.Close() } privKeyFile, err := os.ReadFile(args[1]) @@ -77,7 +75,7 @@ func (cli *CLI) NewDatasetsCmd() *cobra.Command { } ctx := metadata.NewOutgoingContext(cmd.Context(), metadata.New(make(map[string]string))) - if err := cli.agentSDK.Data(addDatasetMetadata(ctx), dataReq, privKey); err != nil { + if err := cli.agentSDK.Data(addDatasetMetadata(ctx), dataset, path.Base(datasetPath), privKey); err != nil { printError(cmd, "Failed to upload dataset due to error: %v ❌ ", err) return } diff --git a/cli/datasets_test.go b/cli/datasets_test.go index 0d45c881..7f2cc9a0 100644 --- a/cli/datasets_test.go +++ b/cli/datasets_test.go @@ -39,7 +39,7 @@ func TestDatasetsCmd(t *testing.T) { { name: "successful upload", setupMock: func(m *mocks.SDK) { - m.On("Data", mock.Anything, mock.Anything, mock.Anything).Return(nil) + m.On("Data", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil) }, setupFiles: func() (string, error) { datasetFile, err := createTempDatasetFile("test dataset content") @@ -58,7 +58,7 @@ func TestDatasetsCmd(t *testing.T) { { name: "missing dataset file", setupMock: func(m *mocks.SDK) { - m.On("Data", mock.Anything, mock.Anything, mock.Anything).Return(nil) + m.On("Data", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil) }, setupFiles: func() (string, error) { return "", nil @@ -68,7 +68,7 @@ func TestDatasetsCmd(t *testing.T) { { name: "missing private key file", setupMock: func(m *mocks.SDK) { - m.On("Data", mock.Anything, mock.Anything, mock.Anything).Return(nil) + m.On("Data", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil) }, setupFiles: func() (string, error) { return createTempDatasetFile("test dataset content") @@ -81,7 +81,7 @@ func TestDatasetsCmd(t *testing.T) { { name: "upload failure", setupMock: func(m *mocks.SDK) { - m.On("Data", mock.Anything, mock.Anything, mock.Anything).Return(errors.New("failed to upload algorithm due to error")) + m.On("Data", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(errors.New("failed to upload algorithm due to error")) }, setupFiles: func() (string, error) { datasetFile, err := createTempDatasetFile("test dataset content") @@ -100,7 +100,7 @@ func TestDatasetsCmd(t *testing.T) { { name: "invalid private key", setupMock: func(m *mocks.SDK) { - m.On("Data", mock.Anything, mock.Anything, mock.Anything).Return(nil) + m.On("Data", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil) }, setupFiles: func() (string, error) { datasetFile, err := createTempDatasetFile("test dataset content") diff --git a/cli/result.go b/cli/result.go index be1b1fc1..866cc37b 100644 --- a/cli/result.go +++ b/cli/result.go @@ -4,7 +4,6 @@ package cli import ( "encoding/pem" - "fmt" "os" "github.com/fatih/color" @@ -14,13 +13,14 @@ import ( const ( resultFilePrefix = "results" resultFileExt = ".zip" + resultfilename = "results.zip" ) func (cli *CLI) NewResultsCmd() *cobra.Command { return &cobra.Command{ Use: "result", Short: "Retrieve computation result file", - Example: "result ", + Example: "result ", Args: cobra.ExactArgs(1), Run: func(cmd *cobra.Command, args []string) { if cli.connectErr != nil { @@ -36,50 +36,32 @@ func (cli *CLI) NewResultsCmd() *cobra.Command { return } - pemBlock, _ := pem.Decode(privKeyFile) + filename := resultfilename + if len(args) > 1 { + filename = args[1] + } - var result []byte + pemBlock, _ := pem.Decode(privKeyFile) privKey, err := decodeKey(pemBlock) if err != nil { printError(cmd, "Error decoding private key: %v ❌ ", err) return } - result, err = cli.agentSDK.Result(cmd.Context(), privKey) + + resultFile, err := os.Create(filename) if err != nil { + printError(cmd, "Error creating result file: %v ❌ ", err) + return + } + defer resultFile.Close() + + if err = cli.agentSDK.Result(cmd.Context(), privKey, resultFile); err != nil { printError(cmd, "Error retrieving computation result: %v ❌ ", err) return } - resultFilePath, err := getUniqueFilePath(resultFilePrefix, resultFileExt) - if err != nil { - printError(cmd, "Error generating unique file path: %v ❌ ", err) - return - } - - if err := os.WriteFile(resultFilePath, result, 0o644); err != nil { - printError(cmd, "Error saving computation result file: %v ❌ ", err) - return - } - - cmd.Println(color.New(color.FgGreen).Sprintf("Computation result retrieved and saved successfully as %s! ✔ ", resultFilePath)) + cmd.Println(color.New(color.FgGreen).Sprintf("Computation result retrieved and saved successfully as %s! ✔ ", filename)) }, } } - -func getUniqueFilePath(prefix, ext string) (string, error) { - for i := 0; ; i++ { - var filename string - if i == 0 { - filename = prefix + ext - } else { - filename = fmt.Sprintf("%s_%d%s", prefix, i, ext) - } - - if _, err := os.Stat(filename); os.IsNotExist(err) { - return filename, nil - } else if err != nil { - return "", err - } - } -} diff --git a/cli/result_test.go b/cli/result_test.go index 963000da..848c5d9a 100644 --- a/cli/result_test.go +++ b/cli/result_test.go @@ -6,7 +6,6 @@ package cli import ( "bytes" "errors" - "fmt" "os" "path/filepath" "testing" @@ -20,7 +19,10 @@ const compResult = "Test computation result" func TestResultsCmd_MultipleExecutions(t *testing.T) { mockSDK := new(mocks.SDK) - mockSDK.On("Result", mock.Anything, mock.Anything).Return([]byte(compResult), nil) + mockSDK.On("Result", mock.Anything, mock.Anything, mock.Anything).Return(nil).Run(func(args mock.Arguments) { + _, err := args.Get(2).(*os.File).WriteString(compResult) + require.NoError(t, err) + }) testCLI := CLI{agentSDK: mockSDK} err := generateRSAPrivateKeyFile(privateKeyFile) @@ -40,7 +42,6 @@ func TestResultsCmd_MultipleExecutions(t *testing.T) { files, err := filepath.Glob("results*.zip") require.NoError(t, err) - require.Len(t, files, 3) t.Cleanup(func() { for _, file := range files { @@ -52,7 +53,10 @@ func TestResultsCmd_MultipleExecutions(t *testing.T) { func TestResultsCmd_InvalidPrivateKey(t *testing.T) { mockSDK := new(mocks.SDK) - mockSDK.On("Result", mock.Anything, mock.Anything).Return([]byte(compResult), nil) + mockSDK.On("Result", mock.Anything, mock.Anything, mock.Anything).Return(nil).Run(func(args mock.Arguments) { + _, err := args.Get(2).(*os.File).WriteString(compResult) + require.NoError(t, err) + }) testCLI := CLI{agentSDK: mockSDK} invalidPrivateKey, err := os.CreateTemp("", "invalid_private_key.pem") @@ -73,30 +77,7 @@ func TestResultsCmd_InvalidPrivateKey(t *testing.T) { require.NoError(t, err) require.Contains(t, buf.String(), "Error decoding private key") - mockSDK.AssertNotCalled(t, "Result", mock.Anything, mock.Anything) -} - -func TestGetUniqueFilePath(t *testing.T) { - prefix := "test" - ext := ".txt" - - path, err := getUniqueFilePath(prefix, ext) - require.NoError(t, err) - require.Equal(t, "test.txt", path) - - _, err = os.Create("test.txt") - require.NoError(t, err) - defer os.Remove("test.txt") - for i := 1; i < 3; i++ { - fileName := fmt.Sprintf("%s_%d%s", prefix, i, ext) - _, err := os.Create(fileName) - require.NoError(t, err) - defer os.Remove(fileName) - } - - path, err = getUniqueFilePath(prefix, ext) - require.NoError(t, err) - require.Equal(t, "test_3.txt", path) + mockSDK.AssertNotCalled(t, "Result", mock.Anything, mock.Anything, mock.Anything) } func TestResultsCmd(t *testing.T) { @@ -111,7 +92,10 @@ func TestResultsCmd(t *testing.T) { { name: "successful result retrieval", setupMock: func(m *mocks.SDK) { - m.On("Result", mock.Anything, mock.Anything).Return([]byte(compResult), nil) + m.On("Result", mock.Anything, mock.Anything, mock.Anything).Return(nil).Run(func(args mock.Arguments) { + _, err := args.Get(2).(*os.File).WriteString(compResult) + require.NoError(t, err) + }) }, setupFiles: func() (string, error) { return privateKeyFile, generateRSAPrivateKeyFile(privateKeyFile) @@ -128,7 +112,10 @@ func TestResultsCmd(t *testing.T) { { name: "missing private key file", setupMock: func(m *mocks.SDK) { - m.On("Result", mock.Anything, mock.Anything).Return([]byte(compResult), nil) + m.On("Result", mock.Anything, mock.Anything, mock.Anything).Return(nil).Run(func(args mock.Arguments) { + _, err := args.Get(2).(*os.File).WriteString(compResult) + require.NoError(t, err) + }) }, setupFiles: func() (string, error) { return "non_existent_private_key.pem", nil @@ -138,7 +125,7 @@ func TestResultsCmd(t *testing.T) { { name: "result retrieval failure", setupMock: func(m *mocks.SDK) { - m.On("Result", mock.Anything, mock.Anything).Return(nil, errors.New("error retrieving computation result")) + m.On("Result", mock.Anything, mock.Anything, mock.Anything).Return(errors.New("error retrieving computation result")) }, setupFiles: func() (string, error) { return privateKeyFile, generateRSAPrivateKeyFile(privateKeyFile) @@ -148,26 +135,6 @@ func TestResultsCmd(t *testing.T) { os.Remove(privateKeyFile) }, }, - { - name: "save failure", - setupMock: func(m *mocks.SDK) { - m.On("Result", mock.Anything, mock.Anything).Return([]byte(compResult), nil) - }, - setupFiles: func() (string, error) { - err := generateRSAPrivateKeyFile(privateKeyFile) - if err != nil { - return "", err - } - // Simulate failure in saving the result file by making all files read-only - return privateKeyFile, os.Chmod(".", 0o555) - }, - expectedOutput: "Error saving computation result file", - cleanup: func() { - err := os.Chmod(".", 0o755) - require.NoError(t, err) - os.Remove(privateKeyFile) - }, - }, { name: "connection error", setupMock: func(m *mocks.SDK) { diff --git a/codecov.yml b/codecov.yml index 433c373e..ea7b5fc3 100644 --- a/codecov.yml +++ b/codecov.yml @@ -4,7 +4,7 @@ coverage: ignore: - "test/*" - - "cmd/*" + - "cmd/**" - "**/mocks/**" - "mocks/**" - "**/*.pb.go" diff --git a/internal/zip.go b/internal/zip.go index 25bce054..c3bba928 100644 --- a/internal/zip.go +++ b/internal/zip.go @@ -60,6 +60,60 @@ func ZipDirectoryToMemory(sourceDir string) ([]byte, error) { return buf.Bytes(), nil } +func ZipDirectoryToTempFile(sourceDir string) (*os.File, error) { + tmpFile, err := os.CreateTemp("", "dataset*.zip") + if err != nil { + return nil, err + } + + zipWriter := zip.NewWriter(tmpFile) + + err = filepath.Walk(sourceDir, func(path string, info os.FileInfo, err error) error { + if err != nil { + return err + } + + if info.IsDir() { + return nil + } + + relPath, err := filepath.Rel(sourceDir, path) + if err != nil { + return err + } + + zipHeader, err := zip.FileInfoHeader(info) + if err != nil { + return err + } + zipHeader.Name = relPath + + zipWriterEntry, err := zipWriter.CreateHeader(zipHeader) + if err != nil { + return err + } + + fileToZip, err := os.Open(path) + if err != nil { + return err + } + defer fileToZip.Close() + + _, err = io.Copy(zipWriterEntry, fileToZip) + return err + }) + if err != nil { + zipWriter.Close() + return nil, err + } + + if err := zipWriter.Close(); err != nil { + return nil, err + } + + return tmpFile, nil +} + func UnzipFromMemory(zipData []byte, targetDir string) error { reader := bytes.NewReader(zipData) zipReader, err := zip.NewReader(reader, int64(len(zipData))) diff --git a/internal/zip_test.go b/internal/zip_test.go index 0bfa9506..d3224bf5 100644 --- a/internal/zip_test.go +++ b/internal/zip_test.go @@ -3,6 +3,8 @@ package internal import ( + "archive/zip" + "io" "os" "path/filepath" "testing" @@ -104,3 +106,147 @@ func TestZipDirectoryToMemory_NonExistentDirectory(t *testing.T) { t.Error("ZipDirectoryToMemory should fail with non-existent directory") } } + +func TestZipDirectoryToTempFile(t *testing.T) { + tests := []struct { + name string + setupFiles map[string]string // map of relative path to content + expectError bool + }{ + { + name: "single file", + setupFiles: map[string]string{ + "test.txt": "hello world", + }, + expectError: false, + }, + { + name: "multiple files in root", + setupFiles: map[string]string{ + "test1.txt": "content1", + "test2.txt": "content2", + "test3.txt": "content3", + }, + expectError: false, + }, + { + name: "nested directory structure", + setupFiles: map[string]string{ + "file1.txt": "root file", + "dir1/file2.txt": "nested file", + "dir1/dir2/file3.txt": "deeply nested file", + "dir1/dir2/dir3/file4.txt": "very deeply nested file", + }, + expectError: false, + }, + { + name: "empty directory", + setupFiles: map[string]string{}, + expectError: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + sourceDir, err := os.MkdirTemp("", "source") + if err != nil { + t.Fatalf("Failed to create temp source directory: %v", err) + } + defer os.RemoveAll(sourceDir) + + for relPath, content := range tt.setupFiles { + fullPath := filepath.Join(sourceDir, relPath) + dir := filepath.Dir(fullPath) + + if err := os.MkdirAll(dir, 0o755); err != nil { + t.Fatalf("Failed to create directory %s: %v", dir, err) + } + + if err := os.WriteFile(fullPath, []byte(content), 0o644); err != nil { + t.Fatalf("Failed to write file %s: %v", fullPath, err) + } + } + + zipFile, err := ZipDirectoryToTempFile(sourceDir) + if err != nil { + if !tt.expectError { + t.Fatalf("Unexpected error: %v", err) + } + return + } + defer os.Remove(zipFile.Name()) + defer zipFile.Close() + + if tt.expectError { + t.Fatal("Expected error but got none") + } + + zipReader, err := zip.OpenReader(zipFile.Name()) + if err != nil { + t.Fatalf("Failed to open zip file: %v", err) + } + defer zipReader.Close() + + expectedFiles := make(map[string]string) + for path, content := range tt.setupFiles { + expectedFiles[filepath.ToSlash(path)] = content + } + + for _, file := range zipReader.File { + expectedContent, exists := expectedFiles[file.Name] + if !exists { + t.Errorf("Unexpected file in zip: %s", file.Name) + continue + } + + rc, err := file.Open() + if err != nil { + t.Errorf("Failed to open file in zip %s: %v", file.Name, err) + continue + } + + content, err := io.ReadAll(rc) + rc.Close() + if err != nil { + t.Errorf("Failed to read file in zip %s: %v", file.Name, err) + continue + } + + if string(content) != expectedContent { + t.Errorf("File %s content mismatch: got %s, want %s", file.Name, content, expectedContent) + } + + delete(expectedFiles, file.Name) + } + + for path := range expectedFiles { + t.Errorf("Missing file in zip: %s", path) + } + }) + } +} + +func TestZipDirectoryToTempFile_InvalidInput(t *testing.T) { + tests := []struct { + name string + sourceDir string + }{ + { + name: "non-existent directory", + sourceDir: "/path/that/does/not/exist", + }, + { + name: "empty path", + sourceDir: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, err := ZipDirectoryToTempFile(tt.sourceDir) + if err == nil { + t.Error("Expected error but got none") + } + }) + } +} diff --git a/pkg/progressbar/progress_test.go b/pkg/progressbar/progress_test.go index 11acae69..7e611e42 100644 --- a/pkg/progressbar/progress_test.go +++ b/pkg/progressbar/progress_test.go @@ -53,16 +53,37 @@ func TestSendAlgorithm(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { pb := New(false) - algobuffer := bytes.NewBufferString("algorithm content") - reqBuffer := bytes.NewBufferString("requirements content") + + algo, err := os.CreateTemp("", "test_algo") + assert.NoError(t, err) + req, err := os.CreateTemp("", "test_req") + assert.NoError(t, err) + + _, err = algo.WriteString("test algorithm") + assert.NoError(t, err) + + err = algo.Close() + assert.NoError(t, err) + + algo, err = os.Open(algo.Name()) + assert.NoError(t, err) + + _, err = req.WriteString("test request") + assert.NoError(t, err) + + err = req.Close() + assert.NoError(t, err) + + req, err = os.Open(req.Name()) + assert.NoError(t, err) algoStream := new(mocks.AgentService_AlgoClient) algoStream.On("Send", mock.Anything).Return(tc.sendError) algoStream.On("CloseAndRecv").Return(&agent.AlgoResponse{}, tc.closeRecvError) mockStream := &mockAlgoStream{stream: algoStream} - err := pb.SendAlgorithm("Test Algorithm", algobuffer, reqBuffer, &mockStream.stream) - assert.True(t, errors.Contains(err, tc.err)) + err = pb.SendAlgorithm("Test Algorithm", algo, req, mockStream.stream) + assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("expected error: %v, got: %v", tc.err, err)) }) } } @@ -108,14 +129,24 @@ func TestSendData(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { pb := New(false) - buffer := bytes.NewBufferString(tc.dataContent) + dataset, err := os.CreateTemp("", "test_dataset") + assert.NoError(t, err) + + _, err = dataset.WriteString(tc.dataContent) + assert.NoError(t, err) + + err = dataset.Close() + assert.NoError(t, err) + + dataset, err = os.Open(dataset.Name()) + assert.NoError(t, err) dataStream := new(mocks.AgentService_DataClient) dataStream.On("Send", mock.Anything).Return(tc.sendError) dataStream.On("CloseAndRecv").Return(&agent.DataResponse{}, tc.closeRecvError) mockStream := &mockDataStream{stream: dataStream} - err := pb.SendData("Test Data", "test.txt", buffer, &mockStream.stream) + err = pb.SendData("Test Data", "test.txt", dataset, mockStream.stream) assert.True(t, errors.Contains(err, tc.err)) }) } @@ -361,7 +392,7 @@ func TestReceiveResult(t *testing.T) { setupMock: func(m *MockResultStream) { m.On("Recv").Return(nil, io.EOF).Once() }, - wantResult: nil, + wantResult: []byte{}, wantErr: nil, }, } @@ -375,13 +406,26 @@ func TestReceiveResult(t *testing.T) { // Disable terminal width check for tests p.TerminalWidthFunc = func() (int, error) { return 100, nil } - result, err := p.ReceiveResult(tt.description, tt.totalSize, mockStream) + resultFile, err := os.CreateTemp("", "test_result") + assert.NoError(t, err) + + t.Cleanup(func() { + os.Remove(resultFile.Name()) + }) + + err = p.ReceiveResult(tt.description, tt.totalSize, mockStream, resultFile) + + assert.NoError(t, resultFile.Close()) if tt.wantErr != nil { assert.Error(t, err) assert.Equal(t, tt.wantErr.Error(), err.Error()) } else { assert.NoError(t, err) + + result, err := os.ReadFile(resultFile.Name()) + assert.NoError(t, err) + assert.Equal(t, tt.wantResult, result) } @@ -457,13 +501,26 @@ func TestReceiveAttestation(t *testing.T) { // Disable terminal width check for tests p.TerminalWidthFunc = func() (int, error) { return 100, nil } - result, err := p.ReceiveAttestation(tt.description, tt.totalSize, mockStream) + resultFile, err := os.CreateTemp("", "test_attestation") + assert.NoError(t, err) + + t.Cleanup(func() { + os.Remove(resultFile.Name()) + }) + + err = p.ReceiveAttestation(tt.description, tt.totalSize, mockStream, resultFile) + + assert.NoError(t, resultFile.Close()) if tt.wantErr != nil { assert.Error(t, err) assert.Equal(t, tt.wantErr.Error(), err.Error()) } else { assert.NoError(t, err) + + result, err := os.ReadFile(resultFile.Name()) + assert.NoError(t, err) + assert.Equal(t, tt.wantResult, result) } diff --git a/pkg/progressbar/progressbar.go b/pkg/progressbar/progressbar.go index 049593bc..b0187f62 100644 --- a/pkg/progressbar/progressbar.go +++ b/pkg/progressbar/progressbar.go @@ -3,7 +3,6 @@ package progressbar import ( - "bytes" "fmt" "io" "os" @@ -32,7 +31,7 @@ type streamSender interface { } type algoClientWrapper struct { - client *agent.AgentService_AlgoClient + client agent.AgentService_AlgoClient } func (a *algoClientWrapper) Send(req interface{}) error { @@ -41,15 +40,15 @@ func (a *algoClientWrapper) Send(req interface{}) error { return fmt.Errorf("expected *AlgoRequest, got %T", req) } - return (*a.client).Send(algoReq) + return a.client.Send(algoReq) } func (a *algoClientWrapper) CloseAndRecv() (interface{}, error) { - return (*a.client).CloseAndRecv() + return a.client.CloseAndRecv() } type dataClientWrapper struct { - client *agent.AgentService_DataClient + client agent.AgentService_DataClient } func (a *dataClientWrapper) Send(req interface{}) error { @@ -58,11 +57,11 @@ func (a *dataClientWrapper) Send(req interface{}) error { return fmt.Errorf("expected *DataRequest, got %T", req) } - return (*a.client).Send(dataReq) + return a.client.Send(dataReq) } func (a *dataClientWrapper) CloseAndRecv() (interface{}, error) { - return (*a.client).CloseAndRecv() + return a.client.CloseAndRecv() } type ProgressBar struct { @@ -82,21 +81,37 @@ func New(isDownload bool) *ProgressBar { } } -func (p *ProgressBar) SendAlgorithm(description string, algobuffer, reqBuffer *bytes.Buffer, stream *agent.AgentService_AlgoClient) error { - totalSize := algobuffer.Len() + reqBuffer.Len() +func (p *ProgressBar) SendAlgorithm(description string, algo, req *os.File, stream agent.AgentService_AlgoClient) error { + algoFileInfo, err := algo.Stat() + if err != nil { + return err + } + + reqSize := 0 + if req != nil { + reqFileInfo, err := req.Stat() + if err != nil { + return err + } + reqSize = int(reqFileInfo.Size()) + } + + totalSize := int(algoFileInfo.Size()) + reqSize p.reset(description, totalSize) wrapper := &algoClientWrapper{client: stream} - // Send reqBuffer first - if err := p.sendBuffer(reqBuffer, wrapper, func(data []byte) interface{} { - return &agent.AlgoRequest{Requirements: data} - }); err != nil { - return err + // Send req first + if req != nil { + if err := p.sendBuffer(req, wrapper, func(data []byte) interface{} { + return &agent.AlgoRequest{Requirements: data} + }); err != nil { + return err + } } - // Then send algobuffer - if err := p.sendBuffer(algobuffer, wrapper, func(data []byte) interface{} { + // Then send algo + if err := p.sendBuffer(algo, wrapper, func(data []byte) interface{} { return &agent.AlgoRequest{Algorithm: data} }); err != nil { return err @@ -106,23 +121,32 @@ func (p *ProgressBar) SendAlgorithm(description string, algobuffer, reqBuffer *b return err } - _, err := wrapper.CloseAndRecv() - return err + _, err = wrapper.CloseAndRecv() + if err != nil { + return err + } + + return nil } -func (p *ProgressBar) SendData(description, filename string, buffer *bytes.Buffer, stream *agent.AgentService_DataClient) error { - return p.sendData(description, buffer, &dataClientWrapper{client: stream}, func(data []byte) interface{} { +func (p *ProgressBar) SendData(description, filename string, file *os.File, stream agent.AgentService_DataClient) error { + return p.sendData(description, file, &dataClientWrapper{client: stream}, func(data []byte) interface{} { return &agent.DataRequest{Dataset: data, Filename: filename} }) } -func (p *ProgressBar) sendData(description string, buffer *bytes.Buffer, stream streamSender, createRequest func([]byte) interface{}) error { - p.reset(description, buffer.Len()) +func (p *ProgressBar) sendData(description string, file *os.File, stream streamSender, createRequest func([]byte) interface{}) error { + dataInfo, err := file.Stat() + if err != nil { + return err + } + + p.reset(description, int(dataInfo.Size())) buf := make([]byte, bufferSize) for { - n, err := buffer.Read(buf) + n, err := file.Read(buf) if err == io.EOF { if _, err := io.WriteString(os.Stdout, "\n"); err != nil { return err @@ -147,15 +171,15 @@ func (p *ProgressBar) sendData(description string, buffer *bytes.Buffer, stream } } - _, err := stream.CloseAndRecv() + _, err = stream.CloseAndRecv() return err } -func (p *ProgressBar) sendBuffer(buffer *bytes.Buffer, stream streamSender, createRequest func([]byte) interface{}) error { +func (p *ProgressBar) sendBuffer(file *os.File, stream streamSender, createRequest func([]byte) interface{}) error { buf := make([]byte, bufferSize) for { - n, err := buffer.Read(buf) + n, err := file.Read(buf) if err == io.EOF { break } @@ -303,7 +327,7 @@ func (p *ProgressBar) clearProgressBar() error { return nil } -func (p *ProgressBar) ReceiveResult(description string, totalSize int, stream agent.AgentService_ResultClient) ([]byte, error) { +func (p *ProgressBar) ReceiveResult(description string, totalSize int, stream agent.AgentService_ResultClient, resultFile *os.File) error { return p.receiveStream(description, totalSize, func() ([]byte, error) { response, err := stream.Recv() if err != nil { @@ -311,10 +335,10 @@ func (p *ProgressBar) ReceiveResult(description string, totalSize int, stream ag } return response.File, nil - }) + }, resultFile) } -func (p *ProgressBar) ReceiveAttestation(description string, totalSize int, stream agent.AgentService_AttestationClient) ([]byte, error) { +func (p *ProgressBar) ReceiveAttestation(description string, totalSize int, stream agent.AgentService_AttestationClient, attestationFile *os.File) error { return p.receiveStream(description, totalSize, func() ([]byte, error) { response, err := stream.Recv() if err != nil { @@ -322,37 +346,37 @@ func (p *ProgressBar) ReceiveAttestation(description string, totalSize int, stre } return response.File, nil - }) + }, attestationFile) } -func (p *ProgressBar) receiveStream(description string, totalSize int, recv func() ([]byte, error)) ([]byte, error) { +func (p *ProgressBar) receiveStream(description string, totalSize int, recv func() ([]byte, error), file *os.File) error { p.reset(description, totalSize) p.isDownload = true - var result []byte for { chunk, err := recv() if err == io.EOF { if _, err := io.WriteString(os.Stdout, "\n"); err != nil { - return nil, err + return err } break } if err != nil { - return nil, err + return err } chunkSize := len(chunk) if err = p.updateProgress(chunkSize); err != nil { - return nil, err + return err } - result = append(result, chunk...) - + if _, err := file.Write(chunk); err != nil { + return err + } if err := p.renderProgressBar(); err != nil { - return nil, err + return err } } - return result, nil + return nil } diff --git a/pkg/sdk/agent.go b/pkg/sdk/agent.go index 4dd79678..1bf6c647 100644 --- a/pkg/sdk/agent.go +++ b/pkg/sdk/agent.go @@ -3,7 +3,6 @@ package sdk import ( - "bytes" "context" "crypto" "crypto/ecdsa" @@ -12,6 +11,7 @@ import ( "crypto/rsa" "crypto/sha256" "encoding/base64" + "os" "strconv" "github.com/absmach/magistrala/pkg/errors" @@ -24,10 +24,10 @@ import ( //go:generate mockery --name SDK --output=mocks --filename sdk.go --quiet --note "Copyright (c) Ultraviolet \n // SPDX-License-Identifier: Apache-2.0" type SDK interface { - Algo(ctx context.Context, algorithm agent.Algorithm, privKey any) error - Data(ctx context.Context, dataset agent.Dataset, privKey any) error - Result(ctx context.Context, privKey any) ([]byte, error) - Attestation(ctx context.Context, reportData [size64]byte) ([]byte, error) + Algo(ctx context.Context, algorithm, requirements *os.File, privKey any) error + Data(ctx context.Context, dataset *os.File, filename string, privKey any) error + Result(ctx context.Context, privKey any, resultFile *os.File) error + Attestation(ctx context.Context, reportData [size64]byte, attestationFile *os.File) error } const ( @@ -48,7 +48,7 @@ func NewAgentSDK(agentClient agent.AgentServiceClient) SDK { } } -func (sdk *agentSDK) Algo(ctx context.Context, algorithm agent.Algorithm, privKey any) error { +func (sdk *agentSDK) Algo(ctx context.Context, algorithm, requirements *os.File, privKey any) error { md, err := generateMetadata(string(auth.AlgorithmProviderRole), privKey) if err != nil { return err @@ -62,14 +62,12 @@ func (sdk *agentSDK) Algo(ctx context.Context, algorithm agent.Algorithm, privKe if err != nil { return err } - algoBuffer := bytes.NewBuffer(algorithm.Algorithm) - reqBuffer := bytes.NewBuffer(algorithm.Requirements) pb := progressbar.New(false) - return pb.SendAlgorithm(algoProgressBarDescription, algoBuffer, reqBuffer, &stream) + return pb.SendAlgorithm(algoProgressBarDescription, algorithm, requirements, stream) } -func (sdk *agentSDK) Data(ctx context.Context, dataset agent.Dataset, privKey any) error { +func (sdk *agentSDK) Data(ctx context.Context, dataset *os.File, filename string, privKey any) error { md, err := generateMetadata(string(auth.DataProviderRole), privKey) if err != nil { return err @@ -83,29 +81,28 @@ func (sdk *agentSDK) Data(ctx context.Context, dataset agent.Dataset, privKey an if err != nil { return err } - dataBuffer := bytes.NewBuffer(dataset.Dataset) pb := progressbar.New(false) - return pb.SendData(dataProgressBarDescription, dataset.Filename, dataBuffer, &stream) + return pb.SendData(dataProgressBarDescription, filename, dataset, stream) } -func (sdk *agentSDK) Result(ctx context.Context, privKey any) ([]byte, error) { +func (sdk *agentSDK) Result(ctx context.Context, privKey any, resultFile *os.File) error { request := &agent.ResultRequest{} md, err := generateMetadata(string(auth.ConsumerRole), privKey) if err != nil { - return nil, err + return err } ctx = metadata.NewOutgoingContext(ctx, md) stream, err := sdk.client.Result(ctx, request) if err != nil { - return nil, err + return err } incomingmd, err := stream.Header() if err != nil { - return nil, err + return err } fileSizeStr := incomingmd.Get(grpc.FileSizeKey) @@ -116,27 +113,27 @@ func (sdk *agentSDK) Result(ctx context.Context, privKey any) ([]byte, error) { fileSize, err := strconv.Atoi(fileSizeStr[0]) if err != nil { - return nil, err + return err } pb := progressbar.New(true) - return pb.ReceiveResult(resultProgressDescription, fileSize, stream) + return pb.ReceiveResult(resultProgressDescription, fileSize, stream, resultFile) } -func (sdk *agentSDK) Attestation(ctx context.Context, reportData [size64]byte) ([]byte, error) { +func (sdk *agentSDK) Attestation(ctx context.Context, reportData [size64]byte, attestationFile *os.File) error { request := &agent.AttestationRequest{ ReportData: reportData[:], } stream, err := sdk.client.Attestation(ctx, request) if err != nil { - return nil, err + return err } incomingmd, err := stream.Header() if err != nil { - return nil, err + return err } fileSizeStr := incomingmd.Get(grpc.FileSizeKey) @@ -147,12 +144,12 @@ func (sdk *agentSDK) Attestation(ctx context.Context, reportData [size64]byte) ( fileSize, err := strconv.Atoi(fileSizeStr[0]) if err != nil { - return nil, err + return err } pb := progressbar.New(true) - return pb.ReceiveAttestation(attestationProgressDescription, fileSize, stream) + return pb.ReceiveAttestation(attestationProgressDescription, fileSize, stream, attestationFile) } func signData(userID string, privKey crypto.Signer) ([]byte, error) { diff --git a/pkg/sdk/agent_test.go b/pkg/sdk/agent_test.go index d84582af..c3379722 100644 --- a/pkg/sdk/agent_test.go +++ b/pkg/sdk/agent_test.go @@ -115,7 +115,21 @@ func TestAlgo(t *testing.T) { for _, tc := range cases { t.Run(tc.name, func(t *testing.T) { svcCall := svc.On("Algo", mock.Anything, mock.Anything).Return(tc.err) - err = sdk.Algo(context.Background(), tc.algo, tc.userKey) + + algo, err := os.CreateTemp("", "algo") + require.NoError(t, err) + defer os.Remove(algo.Name()) + + _, err = algo.Write(algorithm.Algorithm) + require.NoError(t, err) + + err = algo.Close() + require.NoError(t, err) + + algo, err = os.Open(algo.Name()) + require.NoError(t, err) + + err = sdk.Algo(context.Background(), algo, nil, tc.userKey) st, _ := status.FromError(err) @@ -212,7 +226,19 @@ func TestData(t *testing.T) { t.Run(tc.name, func(t *testing.T) { dataCall := svc.On("Data", mock.Anything, mock.Anything).Return(tc.svcErr) - err = sdk.Data(context.Background(), tc.data, tc.userKey) + data, err := os.CreateTemp("", "data") + require.NoError(t, err) + + _, err = data.Write(dataset.Dataset) + require.NoError(t, err) + + err = data.Close() + require.NoError(t, err) + + data, err = os.Open(data.Name()) + require.NoError(t, err) + + err = sdk.Data(context.Background(), data, tc.data.Filename, tc.userKey) st, _ := status.FromError(err) @@ -273,7 +299,7 @@ func TestResult(t *testing.T) { name: "Results not ready", userKey: resultConsumer1Key, response: &agent.ResultResponse{ - File: []byte(nil), + File: []byte{}, }, svcRes: nil, err: agent.ErrResultsNotReady, @@ -282,7 +308,7 @@ func TestResult(t *testing.T) { name: "All manifest items received", userKey: resultConsumer1Key, response: &agent.ResultResponse{ - File: []byte(nil), + File: []byte{}, }, svcRes: nil, err: agent.ErrAllManifestItemsReceived, @@ -291,7 +317,7 @@ func TestResult(t *testing.T) { name: "Undeclared consumer", userKey: resultConsumer1Key, response: &agent.ResultResponse{ - File: []byte(nil), + File: []byte{}, }, svcRes: nil, err: agent.ErrUndeclaredConsumer, @@ -300,7 +326,17 @@ func TestResult(t *testing.T) { for _, tc := range cases { t.Run(tc.name, func(t *testing.T) { svcCall := svc.On("Result", mock.Anything, mock.Anything).Return(tc.svcRes, tc.err) - res, err := sdk.Result(context.Background(), tc.userKey) + + resultFile, err := os.CreateTemp("", "result") + require.NoError(t, err) + + t.Cleanup(func() { + os.Remove(resultFile.Name()) + }) + + err = sdk.Result(context.Background(), tc.userKey, resultFile) + + require.NoError(t, resultFile.Close()) st, ok := status.FromError(err) if !ok { @@ -312,6 +348,10 @@ func TestResult(t *testing.T) { t.Errorf("%s: Expected error message %q, but got %q", tc.name, tc.err.Error(), st.Message()) } } + + res, err := os.ReadFile(resultFile.Name()) + require.NoError(t, err) + assert.Equal(t, tc.response.File, res, tc.name) svcCall.Unset() @@ -375,7 +415,7 @@ func TestAttestation(t *testing.T) { userKey: resultConsumerKey, reportData: [agent.ReportDataSize]byte(reportData), response: &agent.AttestationResponse{ - File: nil, + File: []byte{}, }, err: nil, }, @@ -384,7 +424,7 @@ func TestAttestation(t *testing.T) { userKey: resultConsumerKey, reportData: [agent.ReportDataSize]byte{}, response: &agent.AttestationResponse{ - File: nil, + File: []byte{}, }, svcRes: nil, err: errors.New("invalid report data"), @@ -395,7 +435,16 @@ func TestAttestation(t *testing.T) { t.Run(tc.name, func(t *testing.T) { svcCall := svc.On("Attestation", mock.Anything, mock.Anything).Return(tc.svcRes, tc.err) - res, err := sdk.Attestation(context.Background(), tc.reportData) + file, err := os.CreateTemp("", "attestation") + require.NoError(t, err) + + t.Cleanup(func() { + os.Remove(file.Name()) + }) + + err = sdk.Attestation(context.Background(), tc.reportData, file) + + require.NoError(t, file.Close()) st, ok := status.FromError(err) if !ok { @@ -408,6 +457,9 @@ func TestAttestation(t *testing.T) { } } + res, err := os.ReadFile(file.Name()) + require.NoError(t, err) + assert.Equal(t, tc.response.File, res, tc.name) svcCall.Unset() diff --git a/pkg/sdk/mocks/sdk.go b/pkg/sdk/mocks/sdk.go index be6efd1a..ca59c7e3 100644 --- a/pkg/sdk/mocks/sdk.go +++ b/pkg/sdk/mocks/sdk.go @@ -7,8 +7,7 @@ package mocks import ( context "context" - - agent "github.com/ultravioletrs/cocos/agent" + os "os" mock "github.com/stretchr/testify/mock" ) @@ -18,17 +17,17 @@ type SDK struct { mock.Mock } -// Algo provides a mock function with given fields: ctx, algorithm, privKey -func (_m *SDK) Algo(ctx context.Context, algorithm agent.Algorithm, privKey interface{}) error { - ret := _m.Called(ctx, algorithm, privKey) +// Algo provides a mock function with given fields: ctx, algorithm, requirements, privKey +func (_m *SDK) Algo(ctx context.Context, algorithm *os.File, requirements *os.File, privKey interface{}) error { + ret := _m.Called(ctx, algorithm, requirements, privKey) if len(ret) == 0 { panic("no return value specified for Algo") } var r0 error - if rf, ok := ret.Get(0).(func(context.Context, agent.Algorithm, interface{}) error); ok { - r0 = rf(ctx, algorithm, privKey) + if rf, ok := ret.Get(0).(func(context.Context, *os.File, *os.File, interface{}) error); ok { + r0 = rf(ctx, algorithm, requirements, privKey) } else { r0 = ret.Error(0) } @@ -36,47 +35,35 @@ func (_m *SDK) Algo(ctx context.Context, algorithm agent.Algorithm, privKey inte return r0 } -// Attestation provides a mock function with given fields: ctx, reportData -func (_m *SDK) Attestation(ctx context.Context, reportData [64]byte) ([]byte, error) { - ret := _m.Called(ctx, reportData) +// Attestation provides a mock function with given fields: ctx, reportData, attestationFile +func (_m *SDK) Attestation(ctx context.Context, reportData [64]byte, attestationFile *os.File) error { + ret := _m.Called(ctx, reportData, attestationFile) if len(ret) == 0 { panic("no return value specified for Attestation") } - var r0 []byte - var r1 error - if rf, ok := ret.Get(0).(func(context.Context, [64]byte) ([]byte, error)); ok { - return rf(ctx, reportData) - } - if rf, ok := ret.Get(0).(func(context.Context, [64]byte) []byte); ok { - r0 = rf(ctx, reportData) + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, [64]byte, *os.File) error); ok { + r0 = rf(ctx, reportData, attestationFile) } else { - if ret.Get(0) != nil { - r0 = ret.Get(0).([]byte) - } + r0 = ret.Error(0) } - if rf, ok := ret.Get(1).(func(context.Context, [64]byte) error); ok { - r1 = rf(ctx, reportData) - } else { - r1 = ret.Error(1) - } - - return r0, r1 + return r0 } -// Data provides a mock function with given fields: ctx, dataset, privKey -func (_m *SDK) Data(ctx context.Context, dataset agent.Dataset, privKey interface{}) error { - ret := _m.Called(ctx, dataset, privKey) +// Data provides a mock function with given fields: ctx, dataset, filename, privKey +func (_m *SDK) Data(ctx context.Context, dataset *os.File, filename string, privKey interface{}) error { + ret := _m.Called(ctx, dataset, filename, privKey) if len(ret) == 0 { panic("no return value specified for Data") } var r0 error - if rf, ok := ret.Get(0).(func(context.Context, agent.Dataset, interface{}) error); ok { - r0 = rf(ctx, dataset, privKey) + if rf, ok := ret.Get(0).(func(context.Context, *os.File, string, interface{}) error); ok { + r0 = rf(ctx, dataset, filename, privKey) } else { r0 = ret.Error(0) } @@ -84,34 +71,22 @@ func (_m *SDK) Data(ctx context.Context, dataset agent.Dataset, privKey interfac return r0 } -// Result provides a mock function with given fields: ctx, privKey -func (_m *SDK) Result(ctx context.Context, privKey interface{}) ([]byte, error) { - ret := _m.Called(ctx, privKey) +// Result provides a mock function with given fields: ctx, privKey, resultFile +func (_m *SDK) Result(ctx context.Context, privKey interface{}, resultFile *os.File) error { + ret := _m.Called(ctx, privKey, resultFile) if len(ret) == 0 { panic("no return value specified for Result") } - var r0 []byte - var r1 error - if rf, ok := ret.Get(0).(func(context.Context, interface{}) ([]byte, error)); ok { - return rf(ctx, privKey) - } - if rf, ok := ret.Get(0).(func(context.Context, interface{}) []byte); ok { - r0 = rf(ctx, privKey) + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, interface{}, *os.File) error); ok { + r0 = rf(ctx, privKey, resultFile) } else { - if ret.Get(0) != nil { - r0 = ret.Get(0).([]byte) - } + r0 = ret.Error(0) } - if rf, ok := ret.Get(1).(func(context.Context, interface{}) error); ok { - r1 = rf(ctx, privKey) - } else { - r1 = ret.Error(1) - } - - return r0, r1 + return r0 } // NewSDK creates a new instance of SDK. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations.