NOISSUE - Improve file streaming (#295)

* improve file streaming

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* error check

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* empty line

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* fix tests

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* send buffer test

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* fix test cases

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* stream data and attestation

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* fumpt

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* fix test

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* mocks

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* value check

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* more value checks

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* add test cases

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* fumpt

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* fix tests

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* all  files

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* fix lint

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

---------

Signed-off-by: Sammy Oina <sammyoina@gmail.com>
This commit is contained in:
Sammy Kerata Oina
2024-11-07 12:47:53 +03:00
committed by GitHub
parent 01a619fd2a
commit 46b94204df
17 changed files with 536 additions and 262 deletions
-5
View File
@@ -16,11 +16,6 @@ type Service struct {
mock.Mock mock.Mock
} }
// Close provides a mock function with given fields:
func (_m *Service) Close() {
_m.Called()
}
// SendEvent provides a mock function with given fields: event, status, details // SendEvent provides a mock function with given fields: event, status, details
func (_m *Service) SendEvent(event string, status string, details json.RawMessage) error { func (_m *Service) SendEvent(event string, status string, details json.RawMessage) error {
ret := _m.Called(event, status, details) ret := _m.Called(event, status, details)
+5 -5
View File
@@ -57,7 +57,7 @@ func TestAlgorithmCmd(t *testing.T) {
{ {
name: "successful upload", name: "successful upload",
setupMock: func(m *mocks.SDK) { setupMock: func(m *mocks.SDK) {
m.On("Algo", mock.Anything, mock.Anything, mock.Anything).Return(nil) m.On("Algo", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil)
}, },
setupFiles: func() error { setupFiles: func() error {
if err := os.WriteFile(algorithmFile, []byte("test algorithm"), 0o644); err != nil { if err := os.WriteFile(algorithmFile, []byte("test algorithm"), 0o644); err != nil {
@@ -75,7 +75,7 @@ func TestAlgorithmCmd(t *testing.T) {
{ {
name: "missing algorithm file", name: "missing algorithm file",
setupMock: func(m *mocks.SDK) { setupMock: func(m *mocks.SDK) {
m.On("Algo", mock.Anything, mock.Anything, mock.Anything).Return(nil) m.On("Algo", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil)
}, },
args: []string{"non_existent_algo_file.py", privateKeyFile}, args: []string{"non_existent_algo_file.py", privateKeyFile},
expectedOutput: "Error reading algorithm file", expectedOutput: "Error reading algorithm file",
@@ -83,7 +83,7 @@ func TestAlgorithmCmd(t *testing.T) {
{ {
name: "missing private key file", name: "missing private key file",
setupMock: func(m *mocks.SDK) { setupMock: func(m *mocks.SDK) {
m.On("Algo", mock.Anything, mock.Anything, mock.Anything).Return(nil) m.On("Algo", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil)
}, },
setupFiles: func() error { setupFiles: func() error {
return os.WriteFile(algorithmFile, []byte("test algorithm"), 0o644) return os.WriteFile(algorithmFile, []byte("test algorithm"), 0o644)
@@ -97,7 +97,7 @@ func TestAlgorithmCmd(t *testing.T) {
{ {
name: "upload failure", name: "upload failure",
setupMock: func(m *mocks.SDK) { setupMock: func(m *mocks.SDK) {
m.On("Algo", mock.Anything, mock.Anything, mock.Anything).Return(errors.New("failed to upload algorithm due to error")) m.On("Algo", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(errors.New("failed to upload algorithm due to error"))
}, },
setupFiles: func() error { setupFiles: func() error {
if err := os.WriteFile(algorithmFile, []byte("test algorithm"), 0o644); err != nil { if err := os.WriteFile(algorithmFile, []byte("test algorithm"), 0o644); err != nil {
@@ -115,7 +115,7 @@ func TestAlgorithmCmd(t *testing.T) {
{ {
name: "invalid private key", name: "invalid private key",
setupMock: func(m *mocks.SDK) { setupMock: func(m *mocks.SDK) {
m.On("Algo", mock.Anything, mock.Anything, mock.Anything).Return(nil) m.On("Algo", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil)
}, },
setupFiles: func() error { setupFiles: func() error {
if err := os.WriteFile(algorithmFile, []byte("test algorithm"), 0o644); err != nil { if err := os.WriteFile(algorithmFile, []byte("test algorithm"), 0o644); err != nil {
+7 -10
View File
@@ -9,7 +9,6 @@ import (
"github.com/fatih/color" "github.com/fatih/color"
"github.com/spf13/cobra" "github.com/spf13/cobra"
"github.com/ultravioletrs/cocos/agent"
"github.com/ultravioletrs/cocos/agent/algorithm" "github.com/ultravioletrs/cocos/agent/algorithm"
"github.com/ultravioletrs/cocos/agent/algorithm/python" "github.com/ultravioletrs/cocos/agent/algorithm/python"
"google.golang.org/grpc/metadata" "google.golang.org/grpc/metadata"
@@ -38,24 +37,22 @@ func (cli *CLI) NewAlgorithmCmd() *cobra.Command {
cmd.Println("Uploading algorithm file:", algorithmFile) cmd.Println("Uploading algorithm file:", algorithmFile)
algorithm, err := os.ReadFile(algorithmFile) algorithm, err := os.Open(algorithmFile)
if err != nil { if err != nil {
printError(cmd, "Error reading algorithm file: %v ❌ ", err) printError(cmd, "Error reading algorithm file: %v ❌ ", err)
return return
} }
var req []byte defer algorithm.Close()
var req *os.File
if requirementsFile != "" { if requirementsFile != "" {
req, err = os.ReadFile(requirementsFile) req, err = os.Open(requirementsFile)
if err != nil { if err != nil {
printError(cmd, "Error reading requirments file: %v ❌ ", err) printError(cmd, "Error reading requirments file: %v ❌ ", err)
return return
} }
} defer req.Close()
algoReq := agent.Algorithm{
Algorithm: algorithm,
Requirements: req,
} }
privKeyFile, err := os.ReadFile(args[1]) privKeyFile, err := os.ReadFile(args[1])
@@ -74,7 +71,7 @@ func (cli *CLI) NewAlgorithmCmd() *cobra.Command {
ctx := metadata.NewOutgoingContext(cmd.Context(), metadata.New(make(map[string]string))) ctx := metadata.NewOutgoingContext(cmd.Context(), metadata.New(make(map[string]string)))
if err := cli.agentSDK.Algo(addAlgoMetadata(ctx), algoReq, privKey); err != nil { if err := cli.agentSDK.Algo(addAlgoMetadata(ctx), algorithm, req, privKey); err != nil {
printError(cmd, "Failed to upload algorithm due to error: %v ❌ ", err) printError(cmd, "Failed to upload algorithm due to error: %v ❌ ", err)
return return
} }
+26 -7
View File
@@ -176,25 +176,44 @@ func (cli *CLI) NewGetAttestationCmd() *cobra.Command {
return return
} }
result, err := cli.agentSDK.Attestation(cmd.Context(), [agent.ReportDataSize]byte(reportData)) filename := attestationFilePath
if getJsonAttestation {
filename = attestationJson
}
attestationFile, err := os.Create(filename)
if err != nil { if err != nil {
printError(cmd, "Error creating attestation file: %v ❌ ", err)
return
}
if err := cli.agentSDK.Attestation(cmd.Context(), [agent.ReportDataSize]byte(reportData), attestationFile); err != nil {
printError(cmd, "Failed to get attestation due to error: %v ❌ ", err) printError(cmd, "Failed to get attestation due to error: %v ❌ ", err)
return return
} }
filename := attestationFilePath if err := attestationFile.Close(); err != nil {
printError(cmd, "Error closing attestation file: %v ❌ ", err)
return
}
if getJsonAttestation { if getJsonAttestation {
result, err := os.ReadFile(filename)
if err != nil {
printError(cmd, "Error reading attestation file: %v ❌ ", err)
return
}
result, err = attesationToJSON(result) result, err = attesationToJSON(result)
if err != nil { if err != nil {
printError(cmd, "Error converting attestation to json: %v ❌ ", err) printError(cmd, "Error converting attestation to json: %v ❌ ", err)
return return
} }
filename = attestationJson
}
if err = os.WriteFile(filename, result, 0o644); err != nil { if err := os.WriteFile(filename, result, 0o644); err != nil {
printError(cmd, "Error saving attestation result: %v ❌ ", err) printError(cmd, "Error writing attestation file: %v ❌ ", err)
return return
}
} }
cmd.Println("Attestation result retrieved and saved successfully!") cmd.Println("Attestation result retrieved and saved successfully!")
+14 -3
View File
@@ -22,7 +22,8 @@ import (
) )
func TestNewAttestationCmd(t *testing.T) { func TestNewAttestationCmd(t *testing.T) {
cli := &CLI{} mockSDK := new(mocks.SDK)
cli := &CLI{agentSDK: mockSDK}
cmd := cli.NewAttestationCmd() cmd := cli.NewAttestationCmd()
assert.Equal(t, "attestation [command]", cmd.Use) assert.Equal(t, "attestation [command]", cmd.Use)
@@ -30,12 +31,19 @@ func TestNewAttestationCmd(t *testing.T) {
var buf bytes.Buffer var buf bytes.Buffer
cmd.SetOut(&buf) cmd.SetOut(&buf)
cmd.SetOutput(&buf)
reportData := bytes.Repeat([]byte{0x01}, agent.ReportDataSize)
mockSDK.On("Attestation", mock.Anything, [agent.ReportDataSize]byte(reportData), mock.Anything).Return(nil)
cmd.SetArgs([]string{hex.EncodeToString(reportData)})
err := cmd.Execute() err := cmd.Execute()
assert.NoError(t, err) assert.NoError(t, err)
assert.Contains(t, buf.String(), "Get and validate attestations") assert.Contains(t, buf.String(), "Get and validate attestations")
} }
func TestNewGetAttestationCmdN(t *testing.T) { func TestNewGetAttestationCmd(t *testing.T) {
validattestation, err := os.ReadFile("../attestation.bin") validattestation, err := os.ReadFile("../attestation.bin")
require.NoError(t, err) require.NoError(t, err)
testCases := []struct { testCases := []struct {
@@ -119,7 +127,10 @@ func TestNewGetAttestationCmdN(t *testing.T) {
var buf bytes.Buffer var buf bytes.Buffer
cmd.SetOutput(&buf) cmd.SetOutput(&buf)
mockSDK.On("Attestation", mock.Anything, [agent.ReportDataSize]byte(bytes.Repeat([]byte{0x01}, agent.ReportDataSize))).Return(tc.mockResponse, tc.mockError) mockSDK.On("Attestation", mock.Anything, [agent.ReportDataSize]byte(bytes.Repeat([]byte{0x01}, agent.ReportDataSize)), mock.Anything).Return(tc.mockError).Run(func(args mock.Arguments) {
_, err := args.Get(2).(*os.File).Write(tc.mockResponse)
require.NoError(t, err)
})
cmd.SetArgs(tc.args) cmd.SetArgs(tc.args)
err := cmd.Execute() err := cmd.Execute()
+7 -9
View File
@@ -41,25 +41,23 @@ func (cli *CLI) NewDatasetsCmd() *cobra.Command {
return return
} }
var dataset []byte var dataset *os.File
if f.IsDir() { if f.IsDir() {
dataset, err = internal.ZipDirectoryToMemory(datasetPath) dataset, err = internal.ZipDirectoryToTempFile(datasetPath)
if err != nil { if err != nil {
printError(cmd, "Error zipping dataset directory: %v ❌ ", err) printError(cmd, "Error zipping dataset directory: %v ❌ ", err)
return return
} }
defer dataset.Close()
defer os.Remove(dataset.Name())
} else { } else {
dataset, err = os.ReadFile(datasetPath) dataset, err = os.Open(datasetPath)
if err != nil { if err != nil {
printError(cmd, "Error reading dataset file: %v ❌ ", err) printError(cmd, "Error reading dataset file: %v ❌ ", err)
return return
} }
} defer dataset.Close()
dataReq := agent.Dataset{
Dataset: dataset,
Filename: path.Base(datasetPath),
} }
privKeyFile, err := os.ReadFile(args[1]) privKeyFile, err := os.ReadFile(args[1])
@@ -77,7 +75,7 @@ func (cli *CLI) NewDatasetsCmd() *cobra.Command {
} }
ctx := metadata.NewOutgoingContext(cmd.Context(), metadata.New(make(map[string]string))) ctx := metadata.NewOutgoingContext(cmd.Context(), metadata.New(make(map[string]string)))
if err := cli.agentSDK.Data(addDatasetMetadata(ctx), dataReq, privKey); err != nil { if err := cli.agentSDK.Data(addDatasetMetadata(ctx), dataset, path.Base(datasetPath), privKey); err != nil {
printError(cmd, "Failed to upload dataset due to error: %v ❌ ", err) printError(cmd, "Failed to upload dataset due to error: %v ❌ ", err)
return return
} }
+5 -5
View File
@@ -39,7 +39,7 @@ func TestDatasetsCmd(t *testing.T) {
{ {
name: "successful upload", name: "successful upload",
setupMock: func(m *mocks.SDK) { setupMock: func(m *mocks.SDK) {
m.On("Data", mock.Anything, mock.Anything, mock.Anything).Return(nil) m.On("Data", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil)
}, },
setupFiles: func() (string, error) { setupFiles: func() (string, error) {
datasetFile, err := createTempDatasetFile("test dataset content") datasetFile, err := createTempDatasetFile("test dataset content")
@@ -58,7 +58,7 @@ func TestDatasetsCmd(t *testing.T) {
{ {
name: "missing dataset file", name: "missing dataset file",
setupMock: func(m *mocks.SDK) { setupMock: func(m *mocks.SDK) {
m.On("Data", mock.Anything, mock.Anything, mock.Anything).Return(nil) m.On("Data", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil)
}, },
setupFiles: func() (string, error) { setupFiles: func() (string, error) {
return "", nil return "", nil
@@ -68,7 +68,7 @@ func TestDatasetsCmd(t *testing.T) {
{ {
name: "missing private key file", name: "missing private key file",
setupMock: func(m *mocks.SDK) { setupMock: func(m *mocks.SDK) {
m.On("Data", mock.Anything, mock.Anything, mock.Anything).Return(nil) m.On("Data", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil)
}, },
setupFiles: func() (string, error) { setupFiles: func() (string, error) {
return createTempDatasetFile("test dataset content") return createTempDatasetFile("test dataset content")
@@ -81,7 +81,7 @@ func TestDatasetsCmd(t *testing.T) {
{ {
name: "upload failure", name: "upload failure",
setupMock: func(m *mocks.SDK) { setupMock: func(m *mocks.SDK) {
m.On("Data", mock.Anything, mock.Anything, mock.Anything).Return(errors.New("failed to upload algorithm due to error")) m.On("Data", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(errors.New("failed to upload algorithm due to error"))
}, },
setupFiles: func() (string, error) { setupFiles: func() (string, error) {
datasetFile, err := createTempDatasetFile("test dataset content") datasetFile, err := createTempDatasetFile("test dataset content")
@@ -100,7 +100,7 @@ func TestDatasetsCmd(t *testing.T) {
{ {
name: "invalid private key", name: "invalid private key",
setupMock: func(m *mocks.SDK) { setupMock: func(m *mocks.SDK) {
m.On("Data", mock.Anything, mock.Anything, mock.Anything).Return(nil) m.On("Data", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil)
}, },
setupFiles: func() (string, error) { setupFiles: func() (string, error) {
datasetFile, err := createTempDatasetFile("test dataset content") datasetFile, err := createTempDatasetFile("test dataset content")
+16 -34
View File
@@ -4,7 +4,6 @@ package cli
import ( import (
"encoding/pem" "encoding/pem"
"fmt"
"os" "os"
"github.com/fatih/color" "github.com/fatih/color"
@@ -14,13 +13,14 @@ import (
const ( const (
resultFilePrefix = "results" resultFilePrefix = "results"
resultFileExt = ".zip" resultFileExt = ".zip"
resultfilename = "results.zip"
) )
func (cli *CLI) NewResultsCmd() *cobra.Command { func (cli *CLI) NewResultsCmd() *cobra.Command {
return &cobra.Command{ return &cobra.Command{
Use: "result", Use: "result",
Short: "Retrieve computation result file", Short: "Retrieve computation result file",
Example: "result <private_key_file_path>", Example: "result <private_key_file_path> <optional_file_name.zip>",
Args: cobra.ExactArgs(1), Args: cobra.ExactArgs(1),
Run: func(cmd *cobra.Command, args []string) { Run: func(cmd *cobra.Command, args []string) {
if cli.connectErr != nil { if cli.connectErr != nil {
@@ -36,50 +36,32 @@ func (cli *CLI) NewResultsCmd() *cobra.Command {
return return
} }
pemBlock, _ := pem.Decode(privKeyFile) filename := resultfilename
if len(args) > 1 {
filename = args[1]
}
var result []byte pemBlock, _ := pem.Decode(privKeyFile)
privKey, err := decodeKey(pemBlock) privKey, err := decodeKey(pemBlock)
if err != nil { if err != nil {
printError(cmd, "Error decoding private key: %v ❌ ", err) printError(cmd, "Error decoding private key: %v ❌ ", err)
return return
} }
result, err = cli.agentSDK.Result(cmd.Context(), privKey)
resultFile, err := os.Create(filename)
if err != nil { if err != nil {
printError(cmd, "Error creating result file: %v ❌ ", err)
return
}
defer resultFile.Close()
if err = cli.agentSDK.Result(cmd.Context(), privKey, resultFile); err != nil {
printError(cmd, "Error retrieving computation result: %v ❌ ", err) printError(cmd, "Error retrieving computation result: %v ❌ ", err)
return return
} }
resultFilePath, err := getUniqueFilePath(resultFilePrefix, resultFileExt) cmd.Println(color.New(color.FgGreen).Sprintf("Computation result retrieved and saved successfully as %s! ✔ ", filename))
if err != nil {
printError(cmd, "Error generating unique file path: %v ❌ ", err)
return
}
if err := os.WriteFile(resultFilePath, result, 0o644); err != nil {
printError(cmd, "Error saving computation result file: %v ❌ ", err)
return
}
cmd.Println(color.New(color.FgGreen).Sprintf("Computation result retrieved and saved successfully as %s! ✔ ", resultFilePath))
}, },
} }
} }
func getUniqueFilePath(prefix, ext string) (string, error) {
for i := 0; ; i++ {
var filename string
if i == 0 {
filename = prefix + ext
} else {
filename = fmt.Sprintf("%s_%d%s", prefix, i, ext)
}
if _, err := os.Stat(filename); os.IsNotExist(err) {
return filename, nil
} else if err != nil {
return "", err
}
}
}
+18 -51
View File
@@ -6,7 +6,6 @@ package cli
import ( import (
"bytes" "bytes"
"errors" "errors"
"fmt"
"os" "os"
"path/filepath" "path/filepath"
"testing" "testing"
@@ -20,7 +19,10 @@ const compResult = "Test computation result"
func TestResultsCmd_MultipleExecutions(t *testing.T) { func TestResultsCmd_MultipleExecutions(t *testing.T) {
mockSDK := new(mocks.SDK) mockSDK := new(mocks.SDK)
mockSDK.On("Result", mock.Anything, mock.Anything).Return([]byte(compResult), nil) mockSDK.On("Result", mock.Anything, mock.Anything, mock.Anything).Return(nil).Run(func(args mock.Arguments) {
_, err := args.Get(2).(*os.File).WriteString(compResult)
require.NoError(t, err)
})
testCLI := CLI{agentSDK: mockSDK} testCLI := CLI{agentSDK: mockSDK}
err := generateRSAPrivateKeyFile(privateKeyFile) err := generateRSAPrivateKeyFile(privateKeyFile)
@@ -40,7 +42,6 @@ func TestResultsCmd_MultipleExecutions(t *testing.T) {
files, err := filepath.Glob("results*.zip") files, err := filepath.Glob("results*.zip")
require.NoError(t, err) require.NoError(t, err)
require.Len(t, files, 3)
t.Cleanup(func() { t.Cleanup(func() {
for _, file := range files { for _, file := range files {
@@ -52,7 +53,10 @@ func TestResultsCmd_MultipleExecutions(t *testing.T) {
func TestResultsCmd_InvalidPrivateKey(t *testing.T) { func TestResultsCmd_InvalidPrivateKey(t *testing.T) {
mockSDK := new(mocks.SDK) mockSDK := new(mocks.SDK)
mockSDK.On("Result", mock.Anything, mock.Anything).Return([]byte(compResult), nil) mockSDK.On("Result", mock.Anything, mock.Anything, mock.Anything).Return(nil).Run(func(args mock.Arguments) {
_, err := args.Get(2).(*os.File).WriteString(compResult)
require.NoError(t, err)
})
testCLI := CLI{agentSDK: mockSDK} testCLI := CLI{agentSDK: mockSDK}
invalidPrivateKey, err := os.CreateTemp("", "invalid_private_key.pem") invalidPrivateKey, err := os.CreateTemp("", "invalid_private_key.pem")
@@ -73,30 +77,7 @@ func TestResultsCmd_InvalidPrivateKey(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
require.Contains(t, buf.String(), "Error decoding private key") require.Contains(t, buf.String(), "Error decoding private key")
mockSDK.AssertNotCalled(t, "Result", mock.Anything, mock.Anything) mockSDK.AssertNotCalled(t, "Result", mock.Anything, mock.Anything, mock.Anything)
}
func TestGetUniqueFilePath(t *testing.T) {
prefix := "test"
ext := ".txt"
path, err := getUniqueFilePath(prefix, ext)
require.NoError(t, err)
require.Equal(t, "test.txt", path)
_, err = os.Create("test.txt")
require.NoError(t, err)
defer os.Remove("test.txt")
for i := 1; i < 3; i++ {
fileName := fmt.Sprintf("%s_%d%s", prefix, i, ext)
_, err := os.Create(fileName)
require.NoError(t, err)
defer os.Remove(fileName)
}
path, err = getUniqueFilePath(prefix, ext)
require.NoError(t, err)
require.Equal(t, "test_3.txt", path)
} }
func TestResultsCmd(t *testing.T) { func TestResultsCmd(t *testing.T) {
@@ -111,7 +92,10 @@ func TestResultsCmd(t *testing.T) {
{ {
name: "successful result retrieval", name: "successful result retrieval",
setupMock: func(m *mocks.SDK) { setupMock: func(m *mocks.SDK) {
m.On("Result", mock.Anything, mock.Anything).Return([]byte(compResult), nil) m.On("Result", mock.Anything, mock.Anything, mock.Anything).Return(nil).Run(func(args mock.Arguments) {
_, err := args.Get(2).(*os.File).WriteString(compResult)
require.NoError(t, err)
})
}, },
setupFiles: func() (string, error) { setupFiles: func() (string, error) {
return privateKeyFile, generateRSAPrivateKeyFile(privateKeyFile) return privateKeyFile, generateRSAPrivateKeyFile(privateKeyFile)
@@ -128,7 +112,10 @@ func TestResultsCmd(t *testing.T) {
{ {
name: "missing private key file", name: "missing private key file",
setupMock: func(m *mocks.SDK) { setupMock: func(m *mocks.SDK) {
m.On("Result", mock.Anything, mock.Anything).Return([]byte(compResult), nil) m.On("Result", mock.Anything, mock.Anything, mock.Anything).Return(nil).Run(func(args mock.Arguments) {
_, err := args.Get(2).(*os.File).WriteString(compResult)
require.NoError(t, err)
})
}, },
setupFiles: func() (string, error) { setupFiles: func() (string, error) {
return "non_existent_private_key.pem", nil return "non_existent_private_key.pem", nil
@@ -138,7 +125,7 @@ func TestResultsCmd(t *testing.T) {
{ {
name: "result retrieval failure", name: "result retrieval failure",
setupMock: func(m *mocks.SDK) { setupMock: func(m *mocks.SDK) {
m.On("Result", mock.Anything, mock.Anything).Return(nil, errors.New("error retrieving computation result")) m.On("Result", mock.Anything, mock.Anything, mock.Anything).Return(errors.New("error retrieving computation result"))
}, },
setupFiles: func() (string, error) { setupFiles: func() (string, error) {
return privateKeyFile, generateRSAPrivateKeyFile(privateKeyFile) return privateKeyFile, generateRSAPrivateKeyFile(privateKeyFile)
@@ -148,26 +135,6 @@ func TestResultsCmd(t *testing.T) {
os.Remove(privateKeyFile) os.Remove(privateKeyFile)
}, },
}, },
{
name: "save failure",
setupMock: func(m *mocks.SDK) {
m.On("Result", mock.Anything, mock.Anything).Return([]byte(compResult), nil)
},
setupFiles: func() (string, error) {
err := generateRSAPrivateKeyFile(privateKeyFile)
if err != nil {
return "", err
}
// Simulate failure in saving the result file by making all files read-only
return privateKeyFile, os.Chmod(".", 0o555)
},
expectedOutput: "Error saving computation result file",
cleanup: func() {
err := os.Chmod(".", 0o755)
require.NoError(t, err)
os.Remove(privateKeyFile)
},
},
{ {
name: "connection error", name: "connection error",
setupMock: func(m *mocks.SDK) { setupMock: func(m *mocks.SDK) {
+1 -1
View File
@@ -4,7 +4,7 @@
coverage: coverage:
ignore: ignore:
- "test/*" - "test/*"
- "cmd/*" - "cmd/**"
- "**/mocks/**" - "**/mocks/**"
- "mocks/**" - "mocks/**"
- "**/*.pb.go" - "**/*.pb.go"
+54
View File
@@ -60,6 +60,60 @@ func ZipDirectoryToMemory(sourceDir string) ([]byte, error) {
return buf.Bytes(), nil return buf.Bytes(), nil
} }
func ZipDirectoryToTempFile(sourceDir string) (*os.File, error) {
tmpFile, err := os.CreateTemp("", "dataset*.zip")
if err != nil {
return nil, err
}
zipWriter := zip.NewWriter(tmpFile)
err = filepath.Walk(sourceDir, func(path string, info os.FileInfo, err error) error {
if err != nil {
return err
}
if info.IsDir() {
return nil
}
relPath, err := filepath.Rel(sourceDir, path)
if err != nil {
return err
}
zipHeader, err := zip.FileInfoHeader(info)
if err != nil {
return err
}
zipHeader.Name = relPath
zipWriterEntry, err := zipWriter.CreateHeader(zipHeader)
if err != nil {
return err
}
fileToZip, err := os.Open(path)
if err != nil {
return err
}
defer fileToZip.Close()
_, err = io.Copy(zipWriterEntry, fileToZip)
return err
})
if err != nil {
zipWriter.Close()
return nil, err
}
if err := zipWriter.Close(); err != nil {
return nil, err
}
return tmpFile, nil
}
func UnzipFromMemory(zipData []byte, targetDir string) error { func UnzipFromMemory(zipData []byte, targetDir string) error {
reader := bytes.NewReader(zipData) reader := bytes.NewReader(zipData)
zipReader, err := zip.NewReader(reader, int64(len(zipData))) zipReader, err := zip.NewReader(reader, int64(len(zipData)))
+146
View File
@@ -3,6 +3,8 @@
package internal package internal
import ( import (
"archive/zip"
"io"
"os" "os"
"path/filepath" "path/filepath"
"testing" "testing"
@@ -104,3 +106,147 @@ func TestZipDirectoryToMemory_NonExistentDirectory(t *testing.T) {
t.Error("ZipDirectoryToMemory should fail with non-existent directory") t.Error("ZipDirectoryToMemory should fail with non-existent directory")
} }
} }
func TestZipDirectoryToTempFile(t *testing.T) {
tests := []struct {
name string
setupFiles map[string]string // map of relative path to content
expectError bool
}{
{
name: "single file",
setupFiles: map[string]string{
"test.txt": "hello world",
},
expectError: false,
},
{
name: "multiple files in root",
setupFiles: map[string]string{
"test1.txt": "content1",
"test2.txt": "content2",
"test3.txt": "content3",
},
expectError: false,
},
{
name: "nested directory structure",
setupFiles: map[string]string{
"file1.txt": "root file",
"dir1/file2.txt": "nested file",
"dir1/dir2/file3.txt": "deeply nested file",
"dir1/dir2/dir3/file4.txt": "very deeply nested file",
},
expectError: false,
},
{
name: "empty directory",
setupFiles: map[string]string{},
expectError: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
sourceDir, err := os.MkdirTemp("", "source")
if err != nil {
t.Fatalf("Failed to create temp source directory: %v", err)
}
defer os.RemoveAll(sourceDir)
for relPath, content := range tt.setupFiles {
fullPath := filepath.Join(sourceDir, relPath)
dir := filepath.Dir(fullPath)
if err := os.MkdirAll(dir, 0o755); err != nil {
t.Fatalf("Failed to create directory %s: %v", dir, err)
}
if err := os.WriteFile(fullPath, []byte(content), 0o644); err != nil {
t.Fatalf("Failed to write file %s: %v", fullPath, err)
}
}
zipFile, err := ZipDirectoryToTempFile(sourceDir)
if err != nil {
if !tt.expectError {
t.Fatalf("Unexpected error: %v", err)
}
return
}
defer os.Remove(zipFile.Name())
defer zipFile.Close()
if tt.expectError {
t.Fatal("Expected error but got none")
}
zipReader, err := zip.OpenReader(zipFile.Name())
if err != nil {
t.Fatalf("Failed to open zip file: %v", err)
}
defer zipReader.Close()
expectedFiles := make(map[string]string)
for path, content := range tt.setupFiles {
expectedFiles[filepath.ToSlash(path)] = content
}
for _, file := range zipReader.File {
expectedContent, exists := expectedFiles[file.Name]
if !exists {
t.Errorf("Unexpected file in zip: %s", file.Name)
continue
}
rc, err := file.Open()
if err != nil {
t.Errorf("Failed to open file in zip %s: %v", file.Name, err)
continue
}
content, err := io.ReadAll(rc)
rc.Close()
if err != nil {
t.Errorf("Failed to read file in zip %s: %v", file.Name, err)
continue
}
if string(content) != expectedContent {
t.Errorf("File %s content mismatch: got %s, want %s", file.Name, content, expectedContent)
}
delete(expectedFiles, file.Name)
}
for path := range expectedFiles {
t.Errorf("Missing file in zip: %s", path)
}
})
}
}
func TestZipDirectoryToTempFile_InvalidInput(t *testing.T) {
tests := []struct {
name string
sourceDir string
}{
{
name: "non-existent directory",
sourceDir: "/path/that/does/not/exist",
},
{
name: "empty path",
sourceDir: "",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
_, err := ZipDirectoryToTempFile(tt.sourceDir)
if err == nil {
t.Error("Expected error but got none")
}
})
}
}
+66 -9
View File
@@ -53,16 +53,37 @@ func TestSendAlgorithm(t *testing.T) {
for _, tc := range testCases { for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
pb := New(false) pb := New(false)
algobuffer := bytes.NewBufferString("algorithm content")
reqBuffer := bytes.NewBufferString("requirements content") algo, err := os.CreateTemp("", "test_algo")
assert.NoError(t, err)
req, err := os.CreateTemp("", "test_req")
assert.NoError(t, err)
_, err = algo.WriteString("test algorithm")
assert.NoError(t, err)
err = algo.Close()
assert.NoError(t, err)
algo, err = os.Open(algo.Name())
assert.NoError(t, err)
_, err = req.WriteString("test request")
assert.NoError(t, err)
err = req.Close()
assert.NoError(t, err)
req, err = os.Open(req.Name())
assert.NoError(t, err)
algoStream := new(mocks.AgentService_AlgoClient) algoStream := new(mocks.AgentService_AlgoClient)
algoStream.On("Send", mock.Anything).Return(tc.sendError) algoStream.On("Send", mock.Anything).Return(tc.sendError)
algoStream.On("CloseAndRecv").Return(&agent.AlgoResponse{}, tc.closeRecvError) algoStream.On("CloseAndRecv").Return(&agent.AlgoResponse{}, tc.closeRecvError)
mockStream := &mockAlgoStream{stream: algoStream} mockStream := &mockAlgoStream{stream: algoStream}
err := pb.SendAlgorithm("Test Algorithm", algobuffer, reqBuffer, &mockStream.stream) err = pb.SendAlgorithm("Test Algorithm", algo, req, mockStream.stream)
assert.True(t, errors.Contains(err, tc.err)) assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("expected error: %v, got: %v", tc.err, err))
}) })
} }
} }
@@ -108,14 +129,24 @@ func TestSendData(t *testing.T) {
for _, tc := range testCases { for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
pb := New(false) pb := New(false)
buffer := bytes.NewBufferString(tc.dataContent) dataset, err := os.CreateTemp("", "test_dataset")
assert.NoError(t, err)
_, err = dataset.WriteString(tc.dataContent)
assert.NoError(t, err)
err = dataset.Close()
assert.NoError(t, err)
dataset, err = os.Open(dataset.Name())
assert.NoError(t, err)
dataStream := new(mocks.AgentService_DataClient) dataStream := new(mocks.AgentService_DataClient)
dataStream.On("Send", mock.Anything).Return(tc.sendError) dataStream.On("Send", mock.Anything).Return(tc.sendError)
dataStream.On("CloseAndRecv").Return(&agent.DataResponse{}, tc.closeRecvError) dataStream.On("CloseAndRecv").Return(&agent.DataResponse{}, tc.closeRecvError)
mockStream := &mockDataStream{stream: dataStream} mockStream := &mockDataStream{stream: dataStream}
err := pb.SendData("Test Data", "test.txt", buffer, &mockStream.stream) err = pb.SendData("Test Data", "test.txt", dataset, mockStream.stream)
assert.True(t, errors.Contains(err, tc.err)) assert.True(t, errors.Contains(err, tc.err))
}) })
} }
@@ -361,7 +392,7 @@ func TestReceiveResult(t *testing.T) {
setupMock: func(m *MockResultStream) { setupMock: func(m *MockResultStream) {
m.On("Recv").Return(nil, io.EOF).Once() m.On("Recv").Return(nil, io.EOF).Once()
}, },
wantResult: nil, wantResult: []byte{},
wantErr: nil, wantErr: nil,
}, },
} }
@@ -375,13 +406,26 @@ func TestReceiveResult(t *testing.T) {
// Disable terminal width check for tests // Disable terminal width check for tests
p.TerminalWidthFunc = func() (int, error) { return 100, nil } p.TerminalWidthFunc = func() (int, error) { return 100, nil }
result, err := p.ReceiveResult(tt.description, tt.totalSize, mockStream) resultFile, err := os.CreateTemp("", "test_result")
assert.NoError(t, err)
t.Cleanup(func() {
os.Remove(resultFile.Name())
})
err = p.ReceiveResult(tt.description, tt.totalSize, mockStream, resultFile)
assert.NoError(t, resultFile.Close())
if tt.wantErr != nil { if tt.wantErr != nil {
assert.Error(t, err) assert.Error(t, err)
assert.Equal(t, tt.wantErr.Error(), err.Error()) assert.Equal(t, tt.wantErr.Error(), err.Error())
} else { } else {
assert.NoError(t, err) assert.NoError(t, err)
result, err := os.ReadFile(resultFile.Name())
assert.NoError(t, err)
assert.Equal(t, tt.wantResult, result) assert.Equal(t, tt.wantResult, result)
} }
@@ -457,13 +501,26 @@ func TestReceiveAttestation(t *testing.T) {
// Disable terminal width check for tests // Disable terminal width check for tests
p.TerminalWidthFunc = func() (int, error) { return 100, nil } p.TerminalWidthFunc = func() (int, error) { return 100, nil }
result, err := p.ReceiveAttestation(tt.description, tt.totalSize, mockStream) resultFile, err := os.CreateTemp("", "test_attestation")
assert.NoError(t, err)
t.Cleanup(func() {
os.Remove(resultFile.Name())
})
err = p.ReceiveAttestation(tt.description, tt.totalSize, mockStream, resultFile)
assert.NoError(t, resultFile.Close())
if tt.wantErr != nil { if tt.wantErr != nil {
assert.Error(t, err) assert.Error(t, err)
assert.Equal(t, tt.wantErr.Error(), err.Error()) assert.Equal(t, tt.wantErr.Error(), err.Error())
} else { } else {
assert.NoError(t, err) assert.NoError(t, err)
result, err := os.ReadFile(resultFile.Name())
assert.NoError(t, err)
assert.Equal(t, tt.wantResult, result) assert.Equal(t, tt.wantResult, result)
} }
+63 -39
View File
@@ -3,7 +3,6 @@
package progressbar package progressbar
import ( import (
"bytes"
"fmt" "fmt"
"io" "io"
"os" "os"
@@ -32,7 +31,7 @@ type streamSender interface {
} }
type algoClientWrapper struct { type algoClientWrapper struct {
client *agent.AgentService_AlgoClient client agent.AgentService_AlgoClient
} }
func (a *algoClientWrapper) Send(req interface{}) error { func (a *algoClientWrapper) Send(req interface{}) error {
@@ -41,15 +40,15 @@ func (a *algoClientWrapper) Send(req interface{}) error {
return fmt.Errorf("expected *AlgoRequest, got %T", req) return fmt.Errorf("expected *AlgoRequest, got %T", req)
} }
return (*a.client).Send(algoReq) return a.client.Send(algoReq)
} }
func (a *algoClientWrapper) CloseAndRecv() (interface{}, error) { func (a *algoClientWrapper) CloseAndRecv() (interface{}, error) {
return (*a.client).CloseAndRecv() return a.client.CloseAndRecv()
} }
type dataClientWrapper struct { type dataClientWrapper struct {
client *agent.AgentService_DataClient client agent.AgentService_DataClient
} }
func (a *dataClientWrapper) Send(req interface{}) error { func (a *dataClientWrapper) Send(req interface{}) error {
@@ -58,11 +57,11 @@ func (a *dataClientWrapper) Send(req interface{}) error {
return fmt.Errorf("expected *DataRequest, got %T", req) return fmt.Errorf("expected *DataRequest, got %T", req)
} }
return (*a.client).Send(dataReq) return a.client.Send(dataReq)
} }
func (a *dataClientWrapper) CloseAndRecv() (interface{}, error) { func (a *dataClientWrapper) CloseAndRecv() (interface{}, error) {
return (*a.client).CloseAndRecv() return a.client.CloseAndRecv()
} }
type ProgressBar struct { type ProgressBar struct {
@@ -82,21 +81,37 @@ func New(isDownload bool) *ProgressBar {
} }
} }
func (p *ProgressBar) SendAlgorithm(description string, algobuffer, reqBuffer *bytes.Buffer, stream *agent.AgentService_AlgoClient) error { func (p *ProgressBar) SendAlgorithm(description string, algo, req *os.File, stream agent.AgentService_AlgoClient) error {
totalSize := algobuffer.Len() + reqBuffer.Len() algoFileInfo, err := algo.Stat()
if err != nil {
return err
}
reqSize := 0
if req != nil {
reqFileInfo, err := req.Stat()
if err != nil {
return err
}
reqSize = int(reqFileInfo.Size())
}
totalSize := int(algoFileInfo.Size()) + reqSize
p.reset(description, totalSize) p.reset(description, totalSize)
wrapper := &algoClientWrapper{client: stream} wrapper := &algoClientWrapper{client: stream}
// Send reqBuffer first // Send req first
if err := p.sendBuffer(reqBuffer, wrapper, func(data []byte) interface{} { if req != nil {
return &agent.AlgoRequest{Requirements: data} if err := p.sendBuffer(req, wrapper, func(data []byte) interface{} {
}); err != nil { return &agent.AlgoRequest{Requirements: data}
return err }); err != nil {
return err
}
} }
// Then send algobuffer // Then send algo
if err := p.sendBuffer(algobuffer, wrapper, func(data []byte) interface{} { if err := p.sendBuffer(algo, wrapper, func(data []byte) interface{} {
return &agent.AlgoRequest{Algorithm: data} return &agent.AlgoRequest{Algorithm: data}
}); err != nil { }); err != nil {
return err return err
@@ -106,23 +121,32 @@ func (p *ProgressBar) SendAlgorithm(description string, algobuffer, reqBuffer *b
return err return err
} }
_, err := wrapper.CloseAndRecv() _, err = wrapper.CloseAndRecv()
return err if err != nil {
return err
}
return nil
} }
func (p *ProgressBar) SendData(description, filename string, buffer *bytes.Buffer, stream *agent.AgentService_DataClient) error { func (p *ProgressBar) SendData(description, filename string, file *os.File, stream agent.AgentService_DataClient) error {
return p.sendData(description, buffer, &dataClientWrapper{client: stream}, func(data []byte) interface{} { return p.sendData(description, file, &dataClientWrapper{client: stream}, func(data []byte) interface{} {
return &agent.DataRequest{Dataset: data, Filename: filename} return &agent.DataRequest{Dataset: data, Filename: filename}
}) })
} }
func (p *ProgressBar) sendData(description string, buffer *bytes.Buffer, stream streamSender, createRequest func([]byte) interface{}) error { func (p *ProgressBar) sendData(description string, file *os.File, stream streamSender, createRequest func([]byte) interface{}) error {
p.reset(description, buffer.Len()) dataInfo, err := file.Stat()
if err != nil {
return err
}
p.reset(description, int(dataInfo.Size()))
buf := make([]byte, bufferSize) buf := make([]byte, bufferSize)
for { for {
n, err := buffer.Read(buf) n, err := file.Read(buf)
if err == io.EOF { if err == io.EOF {
if _, err := io.WriteString(os.Stdout, "\n"); err != nil { if _, err := io.WriteString(os.Stdout, "\n"); err != nil {
return err return err
@@ -147,15 +171,15 @@ func (p *ProgressBar) sendData(description string, buffer *bytes.Buffer, stream
} }
} }
_, err := stream.CloseAndRecv() _, err = stream.CloseAndRecv()
return err return err
} }
func (p *ProgressBar) sendBuffer(buffer *bytes.Buffer, stream streamSender, createRequest func([]byte) interface{}) error { func (p *ProgressBar) sendBuffer(file *os.File, stream streamSender, createRequest func([]byte) interface{}) error {
buf := make([]byte, bufferSize) buf := make([]byte, bufferSize)
for { for {
n, err := buffer.Read(buf) n, err := file.Read(buf)
if err == io.EOF { if err == io.EOF {
break break
} }
@@ -303,7 +327,7 @@ func (p *ProgressBar) clearProgressBar() error {
return nil return nil
} }
func (p *ProgressBar) ReceiveResult(description string, totalSize int, stream agent.AgentService_ResultClient) ([]byte, error) { func (p *ProgressBar) ReceiveResult(description string, totalSize int, stream agent.AgentService_ResultClient, resultFile *os.File) error {
return p.receiveStream(description, totalSize, func() ([]byte, error) { return p.receiveStream(description, totalSize, func() ([]byte, error) {
response, err := stream.Recv() response, err := stream.Recv()
if err != nil { if err != nil {
@@ -311,10 +335,10 @@ func (p *ProgressBar) ReceiveResult(description string, totalSize int, stream ag
} }
return response.File, nil return response.File, nil
}) }, resultFile)
} }
func (p *ProgressBar) ReceiveAttestation(description string, totalSize int, stream agent.AgentService_AttestationClient) ([]byte, error) { func (p *ProgressBar) ReceiveAttestation(description string, totalSize int, stream agent.AgentService_AttestationClient, attestationFile *os.File) error {
return p.receiveStream(description, totalSize, func() ([]byte, error) { return p.receiveStream(description, totalSize, func() ([]byte, error) {
response, err := stream.Recv() response, err := stream.Recv()
if err != nil { if err != nil {
@@ -322,37 +346,37 @@ func (p *ProgressBar) ReceiveAttestation(description string, totalSize int, stre
} }
return response.File, nil return response.File, nil
}) }, attestationFile)
} }
func (p *ProgressBar) receiveStream(description string, totalSize int, recv func() ([]byte, error)) ([]byte, error) { func (p *ProgressBar) receiveStream(description string, totalSize int, recv func() ([]byte, error), file *os.File) error {
p.reset(description, totalSize) p.reset(description, totalSize)
p.isDownload = true p.isDownload = true
var result []byte
for { for {
chunk, err := recv() chunk, err := recv()
if err == io.EOF { if err == io.EOF {
if _, err := io.WriteString(os.Stdout, "\n"); err != nil { if _, err := io.WriteString(os.Stdout, "\n"); err != nil {
return nil, err return err
} }
break break
} }
if err != nil { if err != nil {
return nil, err return err
} }
chunkSize := len(chunk) chunkSize := len(chunk)
if err = p.updateProgress(chunkSize); err != nil { if err = p.updateProgress(chunkSize); err != nil {
return nil, err return err
} }
result = append(result, chunk...) if _, err := file.Write(chunk); err != nil {
return err
}
if err := p.renderProgressBar(); err != nil { if err := p.renderProgressBar(); err != nil {
return nil, err return err
} }
} }
return result, nil return nil
} }
+20 -23
View File
@@ -3,7 +3,6 @@
package sdk package sdk
import ( import (
"bytes"
"context" "context"
"crypto" "crypto"
"crypto/ecdsa" "crypto/ecdsa"
@@ -12,6 +11,7 @@ import (
"crypto/rsa" "crypto/rsa"
"crypto/sha256" "crypto/sha256"
"encoding/base64" "encoding/base64"
"os"
"strconv" "strconv"
"github.com/absmach/magistrala/pkg/errors" "github.com/absmach/magistrala/pkg/errors"
@@ -24,10 +24,10 @@ import (
//go:generate mockery --name SDK --output=mocks --filename sdk.go --quiet --note "Copyright (c) Ultraviolet \n // SPDX-License-Identifier: Apache-2.0" //go:generate mockery --name SDK --output=mocks --filename sdk.go --quiet --note "Copyright (c) Ultraviolet \n // SPDX-License-Identifier: Apache-2.0"
type SDK interface { type SDK interface {
Algo(ctx context.Context, algorithm agent.Algorithm, privKey any) error Algo(ctx context.Context, algorithm, requirements *os.File, privKey any) error
Data(ctx context.Context, dataset agent.Dataset, privKey any) error Data(ctx context.Context, dataset *os.File, filename string, privKey any) error
Result(ctx context.Context, privKey any) ([]byte, error) Result(ctx context.Context, privKey any, resultFile *os.File) error
Attestation(ctx context.Context, reportData [size64]byte) ([]byte, error) Attestation(ctx context.Context, reportData [size64]byte, attestationFile *os.File) error
} }
const ( const (
@@ -48,7 +48,7 @@ func NewAgentSDK(agentClient agent.AgentServiceClient) SDK {
} }
} }
func (sdk *agentSDK) Algo(ctx context.Context, algorithm agent.Algorithm, privKey any) error { func (sdk *agentSDK) Algo(ctx context.Context, algorithm, requirements *os.File, privKey any) error {
md, err := generateMetadata(string(auth.AlgorithmProviderRole), privKey) md, err := generateMetadata(string(auth.AlgorithmProviderRole), privKey)
if err != nil { if err != nil {
return err return err
@@ -62,14 +62,12 @@ func (sdk *agentSDK) Algo(ctx context.Context, algorithm agent.Algorithm, privKe
if err != nil { if err != nil {
return err return err
} }
algoBuffer := bytes.NewBuffer(algorithm.Algorithm)
reqBuffer := bytes.NewBuffer(algorithm.Requirements)
pb := progressbar.New(false) pb := progressbar.New(false)
return pb.SendAlgorithm(algoProgressBarDescription, algoBuffer, reqBuffer, &stream) return pb.SendAlgorithm(algoProgressBarDescription, algorithm, requirements, stream)
} }
func (sdk *agentSDK) Data(ctx context.Context, dataset agent.Dataset, privKey any) error { func (sdk *agentSDK) Data(ctx context.Context, dataset *os.File, filename string, privKey any) error {
md, err := generateMetadata(string(auth.DataProviderRole), privKey) md, err := generateMetadata(string(auth.DataProviderRole), privKey)
if err != nil { if err != nil {
return err return err
@@ -83,29 +81,28 @@ func (sdk *agentSDK) Data(ctx context.Context, dataset agent.Dataset, privKey an
if err != nil { if err != nil {
return err return err
} }
dataBuffer := bytes.NewBuffer(dataset.Dataset)
pb := progressbar.New(false) pb := progressbar.New(false)
return pb.SendData(dataProgressBarDescription, dataset.Filename, dataBuffer, &stream) return pb.SendData(dataProgressBarDescription, filename, dataset, stream)
} }
func (sdk *agentSDK) Result(ctx context.Context, privKey any) ([]byte, error) { func (sdk *agentSDK) Result(ctx context.Context, privKey any, resultFile *os.File) error {
request := &agent.ResultRequest{} request := &agent.ResultRequest{}
md, err := generateMetadata(string(auth.ConsumerRole), privKey) md, err := generateMetadata(string(auth.ConsumerRole), privKey)
if err != nil { if err != nil {
return nil, err return err
} }
ctx = metadata.NewOutgoingContext(ctx, md) ctx = metadata.NewOutgoingContext(ctx, md)
stream, err := sdk.client.Result(ctx, request) stream, err := sdk.client.Result(ctx, request)
if err != nil { if err != nil {
return nil, err return err
} }
incomingmd, err := stream.Header() incomingmd, err := stream.Header()
if err != nil { if err != nil {
return nil, err return err
} }
fileSizeStr := incomingmd.Get(grpc.FileSizeKey) fileSizeStr := incomingmd.Get(grpc.FileSizeKey)
@@ -116,27 +113,27 @@ func (sdk *agentSDK) Result(ctx context.Context, privKey any) ([]byte, error) {
fileSize, err := strconv.Atoi(fileSizeStr[0]) fileSize, err := strconv.Atoi(fileSizeStr[0])
if err != nil { if err != nil {
return nil, err return err
} }
pb := progressbar.New(true) pb := progressbar.New(true)
return pb.ReceiveResult(resultProgressDescription, fileSize, stream) return pb.ReceiveResult(resultProgressDescription, fileSize, stream, resultFile)
} }
func (sdk *agentSDK) Attestation(ctx context.Context, reportData [size64]byte) ([]byte, error) { func (sdk *agentSDK) Attestation(ctx context.Context, reportData [size64]byte, attestationFile *os.File) error {
request := &agent.AttestationRequest{ request := &agent.AttestationRequest{
ReportData: reportData[:], ReportData: reportData[:],
} }
stream, err := sdk.client.Attestation(ctx, request) stream, err := sdk.client.Attestation(ctx, request)
if err != nil { if err != nil {
return nil, err return err
} }
incomingmd, err := stream.Header() incomingmd, err := stream.Header()
if err != nil { if err != nil {
return nil, err return err
} }
fileSizeStr := incomingmd.Get(grpc.FileSizeKey) fileSizeStr := incomingmd.Get(grpc.FileSizeKey)
@@ -147,12 +144,12 @@ func (sdk *agentSDK) Attestation(ctx context.Context, reportData [size64]byte) (
fileSize, err := strconv.Atoi(fileSizeStr[0]) fileSize, err := strconv.Atoi(fileSizeStr[0])
if err != nil { if err != nil {
return nil, err return err
} }
pb := progressbar.New(true) pb := progressbar.New(true)
return pb.ReceiveAttestation(attestationProgressDescription, fileSize, stream) return pb.ReceiveAttestation(attestationProgressDescription, fileSize, stream, attestationFile)
} }
func signData(userID string, privKey crypto.Signer) ([]byte, error) { func signData(userID string, privKey crypto.Signer) ([]byte, error) {
+61 -9
View File
@@ -115,7 +115,21 @@ func TestAlgo(t *testing.T) {
for _, tc := range cases { for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
svcCall := svc.On("Algo", mock.Anything, mock.Anything).Return(tc.err) svcCall := svc.On("Algo", mock.Anything, mock.Anything).Return(tc.err)
err = sdk.Algo(context.Background(), tc.algo, tc.userKey)
algo, err := os.CreateTemp("", "algo")
require.NoError(t, err)
defer os.Remove(algo.Name())
_, err = algo.Write(algorithm.Algorithm)
require.NoError(t, err)
err = algo.Close()
require.NoError(t, err)
algo, err = os.Open(algo.Name())
require.NoError(t, err)
err = sdk.Algo(context.Background(), algo, nil, tc.userKey)
st, _ := status.FromError(err) st, _ := status.FromError(err)
@@ -212,7 +226,19 @@ func TestData(t *testing.T) {
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
dataCall := svc.On("Data", mock.Anything, mock.Anything).Return(tc.svcErr) dataCall := svc.On("Data", mock.Anything, mock.Anything).Return(tc.svcErr)
err = sdk.Data(context.Background(), tc.data, tc.userKey) data, err := os.CreateTemp("", "data")
require.NoError(t, err)
_, err = data.Write(dataset.Dataset)
require.NoError(t, err)
err = data.Close()
require.NoError(t, err)
data, err = os.Open(data.Name())
require.NoError(t, err)
err = sdk.Data(context.Background(), data, tc.data.Filename, tc.userKey)
st, _ := status.FromError(err) st, _ := status.FromError(err)
@@ -273,7 +299,7 @@ func TestResult(t *testing.T) {
name: "Results not ready", name: "Results not ready",
userKey: resultConsumer1Key, userKey: resultConsumer1Key,
response: &agent.ResultResponse{ response: &agent.ResultResponse{
File: []byte(nil), File: []byte{},
}, },
svcRes: nil, svcRes: nil,
err: agent.ErrResultsNotReady, err: agent.ErrResultsNotReady,
@@ -282,7 +308,7 @@ func TestResult(t *testing.T) {
name: "All manifest items received", name: "All manifest items received",
userKey: resultConsumer1Key, userKey: resultConsumer1Key,
response: &agent.ResultResponse{ response: &agent.ResultResponse{
File: []byte(nil), File: []byte{},
}, },
svcRes: nil, svcRes: nil,
err: agent.ErrAllManifestItemsReceived, err: agent.ErrAllManifestItemsReceived,
@@ -291,7 +317,7 @@ func TestResult(t *testing.T) {
name: "Undeclared consumer", name: "Undeclared consumer",
userKey: resultConsumer1Key, userKey: resultConsumer1Key,
response: &agent.ResultResponse{ response: &agent.ResultResponse{
File: []byte(nil), File: []byte{},
}, },
svcRes: nil, svcRes: nil,
err: agent.ErrUndeclaredConsumer, err: agent.ErrUndeclaredConsumer,
@@ -300,7 +326,17 @@ func TestResult(t *testing.T) {
for _, tc := range cases { for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
svcCall := svc.On("Result", mock.Anything, mock.Anything).Return(tc.svcRes, tc.err) svcCall := svc.On("Result", mock.Anything, mock.Anything).Return(tc.svcRes, tc.err)
res, err := sdk.Result(context.Background(), tc.userKey)
resultFile, err := os.CreateTemp("", "result")
require.NoError(t, err)
t.Cleanup(func() {
os.Remove(resultFile.Name())
})
err = sdk.Result(context.Background(), tc.userKey, resultFile)
require.NoError(t, resultFile.Close())
st, ok := status.FromError(err) st, ok := status.FromError(err)
if !ok { if !ok {
@@ -312,6 +348,10 @@ func TestResult(t *testing.T) {
t.Errorf("%s: Expected error message %q, but got %q", tc.name, tc.err.Error(), st.Message()) t.Errorf("%s: Expected error message %q, but got %q", tc.name, tc.err.Error(), st.Message())
} }
} }
res, err := os.ReadFile(resultFile.Name())
require.NoError(t, err)
assert.Equal(t, tc.response.File, res, tc.name) assert.Equal(t, tc.response.File, res, tc.name)
svcCall.Unset() svcCall.Unset()
@@ -375,7 +415,7 @@ func TestAttestation(t *testing.T) {
userKey: resultConsumerKey, userKey: resultConsumerKey,
reportData: [agent.ReportDataSize]byte(reportData), reportData: [agent.ReportDataSize]byte(reportData),
response: &agent.AttestationResponse{ response: &agent.AttestationResponse{
File: nil, File: []byte{},
}, },
err: nil, err: nil,
}, },
@@ -384,7 +424,7 @@ func TestAttestation(t *testing.T) {
userKey: resultConsumerKey, userKey: resultConsumerKey,
reportData: [agent.ReportDataSize]byte{}, reportData: [agent.ReportDataSize]byte{},
response: &agent.AttestationResponse{ response: &agent.AttestationResponse{
File: nil, File: []byte{},
}, },
svcRes: nil, svcRes: nil,
err: errors.New("invalid report data"), err: errors.New("invalid report data"),
@@ -395,7 +435,16 @@ func TestAttestation(t *testing.T) {
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
svcCall := svc.On("Attestation", mock.Anything, mock.Anything).Return(tc.svcRes, tc.err) svcCall := svc.On("Attestation", mock.Anything, mock.Anything).Return(tc.svcRes, tc.err)
res, err := sdk.Attestation(context.Background(), tc.reportData) file, err := os.CreateTemp("", "attestation")
require.NoError(t, err)
t.Cleanup(func() {
os.Remove(file.Name())
})
err = sdk.Attestation(context.Background(), tc.reportData, file)
require.NoError(t, file.Close())
st, ok := status.FromError(err) st, ok := status.FromError(err)
if !ok { if !ok {
@@ -408,6 +457,9 @@ func TestAttestation(t *testing.T) {
} }
} }
res, err := os.ReadFile(file.Name())
require.NoError(t, err)
assert.Equal(t, tc.response.File, res, tc.name) assert.Equal(t, tc.response.File, res, tc.name)
svcCall.Unset() svcCall.Unset()
+27 -52
View File
@@ -7,8 +7,7 @@ package mocks
import ( import (
context "context" context "context"
os "os"
agent "github.com/ultravioletrs/cocos/agent"
mock "github.com/stretchr/testify/mock" mock "github.com/stretchr/testify/mock"
) )
@@ -18,17 +17,17 @@ type SDK struct {
mock.Mock mock.Mock
} }
// Algo provides a mock function with given fields: ctx, algorithm, privKey // Algo provides a mock function with given fields: ctx, algorithm, requirements, privKey
func (_m *SDK) Algo(ctx context.Context, algorithm agent.Algorithm, privKey interface{}) error { func (_m *SDK) Algo(ctx context.Context, algorithm *os.File, requirements *os.File, privKey interface{}) error {
ret := _m.Called(ctx, algorithm, privKey) ret := _m.Called(ctx, algorithm, requirements, privKey)
if len(ret) == 0 { if len(ret) == 0 {
panic("no return value specified for Algo") panic("no return value specified for Algo")
} }
var r0 error var r0 error
if rf, ok := ret.Get(0).(func(context.Context, agent.Algorithm, interface{}) error); ok { if rf, ok := ret.Get(0).(func(context.Context, *os.File, *os.File, interface{}) error); ok {
r0 = rf(ctx, algorithm, privKey) r0 = rf(ctx, algorithm, requirements, privKey)
} else { } else {
r0 = ret.Error(0) r0 = ret.Error(0)
} }
@@ -36,47 +35,35 @@ func (_m *SDK) Algo(ctx context.Context, algorithm agent.Algorithm, privKey inte
return r0 return r0
} }
// Attestation provides a mock function with given fields: ctx, reportData // Attestation provides a mock function with given fields: ctx, reportData, attestationFile
func (_m *SDK) Attestation(ctx context.Context, reportData [64]byte) ([]byte, error) { func (_m *SDK) Attestation(ctx context.Context, reportData [64]byte, attestationFile *os.File) error {
ret := _m.Called(ctx, reportData) ret := _m.Called(ctx, reportData, attestationFile)
if len(ret) == 0 { if len(ret) == 0 {
panic("no return value specified for Attestation") panic("no return value specified for Attestation")
} }
var r0 []byte var r0 error
var r1 error if rf, ok := ret.Get(0).(func(context.Context, [64]byte, *os.File) error); ok {
if rf, ok := ret.Get(0).(func(context.Context, [64]byte) ([]byte, error)); ok { r0 = rf(ctx, reportData, attestationFile)
return rf(ctx, reportData)
}
if rf, ok := ret.Get(0).(func(context.Context, [64]byte) []byte); ok {
r0 = rf(ctx, reportData)
} else { } else {
if ret.Get(0) != nil { r0 = ret.Error(0)
r0 = ret.Get(0).([]byte)
}
} }
if rf, ok := ret.Get(1).(func(context.Context, [64]byte) error); ok { return r0
r1 = rf(ctx, reportData)
} else {
r1 = ret.Error(1)
}
return r0, r1
} }
// Data provides a mock function with given fields: ctx, dataset, privKey // Data provides a mock function with given fields: ctx, dataset, filename, privKey
func (_m *SDK) Data(ctx context.Context, dataset agent.Dataset, privKey interface{}) error { func (_m *SDK) Data(ctx context.Context, dataset *os.File, filename string, privKey interface{}) error {
ret := _m.Called(ctx, dataset, privKey) ret := _m.Called(ctx, dataset, filename, privKey)
if len(ret) == 0 { if len(ret) == 0 {
panic("no return value specified for Data") panic("no return value specified for Data")
} }
var r0 error var r0 error
if rf, ok := ret.Get(0).(func(context.Context, agent.Dataset, interface{}) error); ok { if rf, ok := ret.Get(0).(func(context.Context, *os.File, string, interface{}) error); ok {
r0 = rf(ctx, dataset, privKey) r0 = rf(ctx, dataset, filename, privKey)
} else { } else {
r0 = ret.Error(0) r0 = ret.Error(0)
} }
@@ -84,34 +71,22 @@ func (_m *SDK) Data(ctx context.Context, dataset agent.Dataset, privKey interfac
return r0 return r0
} }
// Result provides a mock function with given fields: ctx, privKey // Result provides a mock function with given fields: ctx, privKey, resultFile
func (_m *SDK) Result(ctx context.Context, privKey interface{}) ([]byte, error) { func (_m *SDK) Result(ctx context.Context, privKey interface{}, resultFile *os.File) error {
ret := _m.Called(ctx, privKey) ret := _m.Called(ctx, privKey, resultFile)
if len(ret) == 0 { if len(ret) == 0 {
panic("no return value specified for Result") panic("no return value specified for Result")
} }
var r0 []byte var r0 error
var r1 error if rf, ok := ret.Get(0).(func(context.Context, interface{}, *os.File) error); ok {
if rf, ok := ret.Get(0).(func(context.Context, interface{}) ([]byte, error)); ok { r0 = rf(ctx, privKey, resultFile)
return rf(ctx, privKey)
}
if rf, ok := ret.Get(0).(func(context.Context, interface{}) []byte); ok {
r0 = rf(ctx, privKey)
} else { } else {
if ret.Get(0) != nil { r0 = ret.Error(0)
r0 = ret.Get(0).([]byte)
}
} }
if rf, ok := ret.Get(1).(func(context.Context, interface{}) error); ok { return r0
r1 = rf(ctx, privKey)
} else {
r1 = ret.Error(1)
}
return r0, r1
} }
// NewSDK creates a new instance of SDK. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. // NewSDK creates a new instance of SDK. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations.