mirror of
https://github.com/ultravioletrs/cocos.git
synced 2026-06-23 04:10:25 +00:00
NOISSUE - Improve file streaming (#295)
* improve file streaming Signed-off-by: Sammy Oina <sammyoina@gmail.com> * error check Signed-off-by: Sammy Oina <sammyoina@gmail.com> * empty line Signed-off-by: Sammy Oina <sammyoina@gmail.com> * fix tests Signed-off-by: Sammy Oina <sammyoina@gmail.com> * send buffer test Signed-off-by: Sammy Oina <sammyoina@gmail.com> * fix test cases Signed-off-by: Sammy Oina <sammyoina@gmail.com> * stream data and attestation Signed-off-by: Sammy Oina <sammyoina@gmail.com> * fumpt Signed-off-by: Sammy Oina <sammyoina@gmail.com> * fix test Signed-off-by: Sammy Oina <sammyoina@gmail.com> * mocks Signed-off-by: Sammy Oina <sammyoina@gmail.com> * value check Signed-off-by: Sammy Oina <sammyoina@gmail.com> * more value checks Signed-off-by: Sammy Oina <sammyoina@gmail.com> * add test cases Signed-off-by: Sammy Oina <sammyoina@gmail.com> * fumpt Signed-off-by: Sammy Oina <sammyoina@gmail.com> * fix tests Signed-off-by: Sammy Oina <sammyoina@gmail.com> * all files Signed-off-by: Sammy Oina <sammyoina@gmail.com> * fix lint Signed-off-by: Sammy Oina <sammyoina@gmail.com> --------- Signed-off-by: Sammy Oina <sammyoina@gmail.com>
This commit is contained in:
committed by
GitHub
parent
01a619fd2a
commit
46b94204df
@@ -16,11 +16,6 @@ type Service struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
// Close provides a mock function with given fields:
|
||||
func (_m *Service) Close() {
|
||||
_m.Called()
|
||||
}
|
||||
|
||||
// SendEvent provides a mock function with given fields: event, status, details
|
||||
func (_m *Service) SendEvent(event string, status string, details json.RawMessage) error {
|
||||
ret := _m.Called(event, status, details)
|
||||
|
||||
@@ -57,7 +57,7 @@ func TestAlgorithmCmd(t *testing.T) {
|
||||
{
|
||||
name: "successful upload",
|
||||
setupMock: func(m *mocks.SDK) {
|
||||
m.On("Algo", mock.Anything, mock.Anything, mock.Anything).Return(nil)
|
||||
m.On("Algo", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil)
|
||||
},
|
||||
setupFiles: func() error {
|
||||
if err := os.WriteFile(algorithmFile, []byte("test algorithm"), 0o644); err != nil {
|
||||
@@ -75,7 +75,7 @@ func TestAlgorithmCmd(t *testing.T) {
|
||||
{
|
||||
name: "missing algorithm file",
|
||||
setupMock: func(m *mocks.SDK) {
|
||||
m.On("Algo", mock.Anything, mock.Anything, mock.Anything).Return(nil)
|
||||
m.On("Algo", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil)
|
||||
},
|
||||
args: []string{"non_existent_algo_file.py", privateKeyFile},
|
||||
expectedOutput: "Error reading algorithm file",
|
||||
@@ -83,7 +83,7 @@ func TestAlgorithmCmd(t *testing.T) {
|
||||
{
|
||||
name: "missing private key file",
|
||||
setupMock: func(m *mocks.SDK) {
|
||||
m.On("Algo", mock.Anything, mock.Anything, mock.Anything).Return(nil)
|
||||
m.On("Algo", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil)
|
||||
},
|
||||
setupFiles: func() error {
|
||||
return os.WriteFile(algorithmFile, []byte("test algorithm"), 0o644)
|
||||
@@ -97,7 +97,7 @@ func TestAlgorithmCmd(t *testing.T) {
|
||||
{
|
||||
name: "upload failure",
|
||||
setupMock: func(m *mocks.SDK) {
|
||||
m.On("Algo", mock.Anything, mock.Anything, mock.Anything).Return(errors.New("failed to upload algorithm due to error"))
|
||||
m.On("Algo", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(errors.New("failed to upload algorithm due to error"))
|
||||
},
|
||||
setupFiles: func() error {
|
||||
if err := os.WriteFile(algorithmFile, []byte("test algorithm"), 0o644); err != nil {
|
||||
@@ -115,7 +115,7 @@ func TestAlgorithmCmd(t *testing.T) {
|
||||
{
|
||||
name: "invalid private key",
|
||||
setupMock: func(m *mocks.SDK) {
|
||||
m.On("Algo", mock.Anything, mock.Anything, mock.Anything).Return(nil)
|
||||
m.On("Algo", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil)
|
||||
},
|
||||
setupFiles: func() error {
|
||||
if err := os.WriteFile(algorithmFile, []byte("test algorithm"), 0o644); err != nil {
|
||||
|
||||
+7
-10
@@ -9,7 +9,6 @@ import (
|
||||
|
||||
"github.com/fatih/color"
|
||||
"github.com/spf13/cobra"
|
||||
"github.com/ultravioletrs/cocos/agent"
|
||||
"github.com/ultravioletrs/cocos/agent/algorithm"
|
||||
"github.com/ultravioletrs/cocos/agent/algorithm/python"
|
||||
"google.golang.org/grpc/metadata"
|
||||
@@ -38,24 +37,22 @@ func (cli *CLI) NewAlgorithmCmd() *cobra.Command {
|
||||
|
||||
cmd.Println("Uploading algorithm file:", algorithmFile)
|
||||
|
||||
algorithm, err := os.ReadFile(algorithmFile)
|
||||
algorithm, err := os.Open(algorithmFile)
|
||||
if err != nil {
|
||||
printError(cmd, "Error reading algorithm file: %v ❌ ", err)
|
||||
return
|
||||
}
|
||||
|
||||
var req []byte
|
||||
defer algorithm.Close()
|
||||
|
||||
var req *os.File
|
||||
if requirementsFile != "" {
|
||||
req, err = os.ReadFile(requirementsFile)
|
||||
req, err = os.Open(requirementsFile)
|
||||
if err != nil {
|
||||
printError(cmd, "Error reading requirments file: %v ❌ ", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
algoReq := agent.Algorithm{
|
||||
Algorithm: algorithm,
|
||||
Requirements: req,
|
||||
defer req.Close()
|
||||
}
|
||||
|
||||
privKeyFile, err := os.ReadFile(args[1])
|
||||
@@ -74,7 +71,7 @@ func (cli *CLI) NewAlgorithmCmd() *cobra.Command {
|
||||
|
||||
ctx := metadata.NewOutgoingContext(cmd.Context(), metadata.New(make(map[string]string)))
|
||||
|
||||
if err := cli.agentSDK.Algo(addAlgoMetadata(ctx), algoReq, privKey); err != nil {
|
||||
if err := cli.agentSDK.Algo(addAlgoMetadata(ctx), algorithm, req, privKey); err != nil {
|
||||
printError(cmd, "Failed to upload algorithm due to error: %v ❌ ", err)
|
||||
return
|
||||
}
|
||||
|
||||
+25
-6
@@ -176,26 +176,45 @@ func (cli *CLI) NewGetAttestationCmd() *cobra.Command {
|
||||
return
|
||||
}
|
||||
|
||||
result, err := cli.agentSDK.Attestation(cmd.Context(), [agent.ReportDataSize]byte(reportData))
|
||||
filename := attestationFilePath
|
||||
if getJsonAttestation {
|
||||
filename = attestationJson
|
||||
}
|
||||
|
||||
attestationFile, err := os.Create(filename)
|
||||
if err != nil {
|
||||
printError(cmd, "Error creating attestation file: %v ❌ ", err)
|
||||
return
|
||||
}
|
||||
|
||||
if err := cli.agentSDK.Attestation(cmd.Context(), [agent.ReportDataSize]byte(reportData), attestationFile); err != nil {
|
||||
printError(cmd, "Failed to get attestation due to error: %v ❌ ", err)
|
||||
return
|
||||
}
|
||||
|
||||
filename := attestationFilePath
|
||||
if err := attestationFile.Close(); err != nil {
|
||||
printError(cmd, "Error closing attestation file: %v ❌ ", err)
|
||||
return
|
||||
}
|
||||
|
||||
if getJsonAttestation {
|
||||
result, err := os.ReadFile(filename)
|
||||
if err != nil {
|
||||
printError(cmd, "Error reading attestation file: %v ❌ ", err)
|
||||
return
|
||||
}
|
||||
|
||||
result, err = attesationToJSON(result)
|
||||
if err != nil {
|
||||
printError(cmd, "Error converting attestation to json: %v ❌ ", err)
|
||||
return
|
||||
}
|
||||
filename = attestationJson
|
||||
}
|
||||
|
||||
if err = os.WriteFile(filename, result, 0o644); err != nil {
|
||||
printError(cmd, "Error saving attestation result: %v ❌ ", err)
|
||||
if err := os.WriteFile(filename, result, 0o644); err != nil {
|
||||
printError(cmd, "Error writing attestation file: %v ❌ ", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
cmd.Println("Attestation result retrieved and saved successfully!")
|
||||
},
|
||||
|
||||
+14
-3
@@ -22,7 +22,8 @@ import (
|
||||
)
|
||||
|
||||
func TestNewAttestationCmd(t *testing.T) {
|
||||
cli := &CLI{}
|
||||
mockSDK := new(mocks.SDK)
|
||||
cli := &CLI{agentSDK: mockSDK}
|
||||
cmd := cli.NewAttestationCmd()
|
||||
|
||||
assert.Equal(t, "attestation [command]", cmd.Use)
|
||||
@@ -30,12 +31,19 @@ func TestNewAttestationCmd(t *testing.T) {
|
||||
|
||||
var buf bytes.Buffer
|
||||
cmd.SetOut(&buf)
|
||||
|
||||
cmd.SetOutput(&buf)
|
||||
|
||||
reportData := bytes.Repeat([]byte{0x01}, agent.ReportDataSize)
|
||||
mockSDK.On("Attestation", mock.Anything, [agent.ReportDataSize]byte(reportData), mock.Anything).Return(nil)
|
||||
|
||||
cmd.SetArgs([]string{hex.EncodeToString(reportData)})
|
||||
err := cmd.Execute()
|
||||
assert.NoError(t, err)
|
||||
assert.Contains(t, buf.String(), "Get and validate attestations")
|
||||
}
|
||||
|
||||
func TestNewGetAttestationCmdN(t *testing.T) {
|
||||
func TestNewGetAttestationCmd(t *testing.T) {
|
||||
validattestation, err := os.ReadFile("../attestation.bin")
|
||||
require.NoError(t, err)
|
||||
testCases := []struct {
|
||||
@@ -119,7 +127,10 @@ func TestNewGetAttestationCmdN(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
cmd.SetOutput(&buf)
|
||||
|
||||
mockSDK.On("Attestation", mock.Anything, [agent.ReportDataSize]byte(bytes.Repeat([]byte{0x01}, agent.ReportDataSize))).Return(tc.mockResponse, tc.mockError)
|
||||
mockSDK.On("Attestation", mock.Anything, [agent.ReportDataSize]byte(bytes.Repeat([]byte{0x01}, agent.ReportDataSize)), mock.Anything).Return(tc.mockError).Run(func(args mock.Arguments) {
|
||||
_, err := args.Get(2).(*os.File).Write(tc.mockResponse)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
|
||||
cmd.SetArgs(tc.args)
|
||||
err := cmd.Execute()
|
||||
|
||||
+7
-9
@@ -41,25 +41,23 @@ func (cli *CLI) NewDatasetsCmd() *cobra.Command {
|
||||
return
|
||||
}
|
||||
|
||||
var dataset []byte
|
||||
var dataset *os.File
|
||||
|
||||
if f.IsDir() {
|
||||
dataset, err = internal.ZipDirectoryToMemory(datasetPath)
|
||||
dataset, err = internal.ZipDirectoryToTempFile(datasetPath)
|
||||
if err != nil {
|
||||
printError(cmd, "Error zipping dataset directory: %v ❌ ", err)
|
||||
return
|
||||
}
|
||||
defer dataset.Close()
|
||||
defer os.Remove(dataset.Name())
|
||||
} else {
|
||||
dataset, err = os.ReadFile(datasetPath)
|
||||
dataset, err = os.Open(datasetPath)
|
||||
if err != nil {
|
||||
printError(cmd, "Error reading dataset file: %v ❌ ", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
dataReq := agent.Dataset{
|
||||
Dataset: dataset,
|
||||
Filename: path.Base(datasetPath),
|
||||
defer dataset.Close()
|
||||
}
|
||||
|
||||
privKeyFile, err := os.ReadFile(args[1])
|
||||
@@ -77,7 +75,7 @@ func (cli *CLI) NewDatasetsCmd() *cobra.Command {
|
||||
}
|
||||
|
||||
ctx := metadata.NewOutgoingContext(cmd.Context(), metadata.New(make(map[string]string)))
|
||||
if err := cli.agentSDK.Data(addDatasetMetadata(ctx), dataReq, privKey); err != nil {
|
||||
if err := cli.agentSDK.Data(addDatasetMetadata(ctx), dataset, path.Base(datasetPath), privKey); err != nil {
|
||||
printError(cmd, "Failed to upload dataset due to error: %v ❌ ", err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -39,7 +39,7 @@ func TestDatasetsCmd(t *testing.T) {
|
||||
{
|
||||
name: "successful upload",
|
||||
setupMock: func(m *mocks.SDK) {
|
||||
m.On("Data", mock.Anything, mock.Anything, mock.Anything).Return(nil)
|
||||
m.On("Data", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil)
|
||||
},
|
||||
setupFiles: func() (string, error) {
|
||||
datasetFile, err := createTempDatasetFile("test dataset content")
|
||||
@@ -58,7 +58,7 @@ func TestDatasetsCmd(t *testing.T) {
|
||||
{
|
||||
name: "missing dataset file",
|
||||
setupMock: func(m *mocks.SDK) {
|
||||
m.On("Data", mock.Anything, mock.Anything, mock.Anything).Return(nil)
|
||||
m.On("Data", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil)
|
||||
},
|
||||
setupFiles: func() (string, error) {
|
||||
return "", nil
|
||||
@@ -68,7 +68,7 @@ func TestDatasetsCmd(t *testing.T) {
|
||||
{
|
||||
name: "missing private key file",
|
||||
setupMock: func(m *mocks.SDK) {
|
||||
m.On("Data", mock.Anything, mock.Anything, mock.Anything).Return(nil)
|
||||
m.On("Data", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil)
|
||||
},
|
||||
setupFiles: func() (string, error) {
|
||||
return createTempDatasetFile("test dataset content")
|
||||
@@ -81,7 +81,7 @@ func TestDatasetsCmd(t *testing.T) {
|
||||
{
|
||||
name: "upload failure",
|
||||
setupMock: func(m *mocks.SDK) {
|
||||
m.On("Data", mock.Anything, mock.Anything, mock.Anything).Return(errors.New("failed to upload algorithm due to error"))
|
||||
m.On("Data", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(errors.New("failed to upload algorithm due to error"))
|
||||
},
|
||||
setupFiles: func() (string, error) {
|
||||
datasetFile, err := createTempDatasetFile("test dataset content")
|
||||
@@ -100,7 +100,7 @@ func TestDatasetsCmd(t *testing.T) {
|
||||
{
|
||||
name: "invalid private key",
|
||||
setupMock: func(m *mocks.SDK) {
|
||||
m.On("Data", mock.Anything, mock.Anything, mock.Anything).Return(nil)
|
||||
m.On("Data", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil)
|
||||
},
|
||||
setupFiles: func() (string, error) {
|
||||
datasetFile, err := createTempDatasetFile("test dataset content")
|
||||
|
||||
+16
-34
@@ -4,7 +4,6 @@ package cli
|
||||
|
||||
import (
|
||||
"encoding/pem"
|
||||
"fmt"
|
||||
"os"
|
||||
|
||||
"github.com/fatih/color"
|
||||
@@ -14,13 +13,14 @@ import (
|
||||
const (
|
||||
resultFilePrefix = "results"
|
||||
resultFileExt = ".zip"
|
||||
resultfilename = "results.zip"
|
||||
)
|
||||
|
||||
func (cli *CLI) NewResultsCmd() *cobra.Command {
|
||||
return &cobra.Command{
|
||||
Use: "result",
|
||||
Short: "Retrieve computation result file",
|
||||
Example: "result <private_key_file_path>",
|
||||
Example: "result <private_key_file_path> <optional_file_name.zip>",
|
||||
Args: cobra.ExactArgs(1),
|
||||
Run: func(cmd *cobra.Command, args []string) {
|
||||
if cli.connectErr != nil {
|
||||
@@ -36,50 +36,32 @@ func (cli *CLI) NewResultsCmd() *cobra.Command {
|
||||
return
|
||||
}
|
||||
|
||||
pemBlock, _ := pem.Decode(privKeyFile)
|
||||
filename := resultfilename
|
||||
if len(args) > 1 {
|
||||
filename = args[1]
|
||||
}
|
||||
|
||||
var result []byte
|
||||
pemBlock, _ := pem.Decode(privKeyFile)
|
||||
|
||||
privKey, err := decodeKey(pemBlock)
|
||||
if err != nil {
|
||||
printError(cmd, "Error decoding private key: %v ❌ ", err)
|
||||
return
|
||||
}
|
||||
result, err = cli.agentSDK.Result(cmd.Context(), privKey)
|
||||
|
||||
resultFile, err := os.Create(filename)
|
||||
if err != nil {
|
||||
printError(cmd, "Error creating result file: %v ❌ ", err)
|
||||
return
|
||||
}
|
||||
defer resultFile.Close()
|
||||
|
||||
if err = cli.agentSDK.Result(cmd.Context(), privKey, resultFile); err != nil {
|
||||
printError(cmd, "Error retrieving computation result: %v ❌ ", err)
|
||||
return
|
||||
}
|
||||
|
||||
resultFilePath, err := getUniqueFilePath(resultFilePrefix, resultFileExt)
|
||||
if err != nil {
|
||||
printError(cmd, "Error generating unique file path: %v ❌ ", err)
|
||||
return
|
||||
}
|
||||
|
||||
if err := os.WriteFile(resultFilePath, result, 0o644); err != nil {
|
||||
printError(cmd, "Error saving computation result file: %v ❌ ", err)
|
||||
return
|
||||
}
|
||||
|
||||
cmd.Println(color.New(color.FgGreen).Sprintf("Computation result retrieved and saved successfully as %s! ✔ ", resultFilePath))
|
||||
cmd.Println(color.New(color.FgGreen).Sprintf("Computation result retrieved and saved successfully as %s! ✔ ", filename))
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func getUniqueFilePath(prefix, ext string) (string, error) {
|
||||
for i := 0; ; i++ {
|
||||
var filename string
|
||||
if i == 0 {
|
||||
filename = prefix + ext
|
||||
} else {
|
||||
filename = fmt.Sprintf("%s_%d%s", prefix, i, ext)
|
||||
}
|
||||
|
||||
if _, err := os.Stat(filename); os.IsNotExist(err) {
|
||||
return filename, nil
|
||||
} else if err != nil {
|
||||
return "", err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
+18
-51
@@ -6,7 +6,6 @@ package cli
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
@@ -20,7 +19,10 @@ const compResult = "Test computation result"
|
||||
|
||||
func TestResultsCmd_MultipleExecutions(t *testing.T) {
|
||||
mockSDK := new(mocks.SDK)
|
||||
mockSDK.On("Result", mock.Anything, mock.Anything).Return([]byte(compResult), nil)
|
||||
mockSDK.On("Result", mock.Anything, mock.Anything, mock.Anything).Return(nil).Run(func(args mock.Arguments) {
|
||||
_, err := args.Get(2).(*os.File).WriteString(compResult)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
testCLI := CLI{agentSDK: mockSDK}
|
||||
|
||||
err := generateRSAPrivateKeyFile(privateKeyFile)
|
||||
@@ -40,7 +42,6 @@ func TestResultsCmd_MultipleExecutions(t *testing.T) {
|
||||
|
||||
files, err := filepath.Glob("results*.zip")
|
||||
require.NoError(t, err)
|
||||
require.Len(t, files, 3)
|
||||
|
||||
t.Cleanup(func() {
|
||||
for _, file := range files {
|
||||
@@ -52,7 +53,10 @@ func TestResultsCmd_MultipleExecutions(t *testing.T) {
|
||||
|
||||
func TestResultsCmd_InvalidPrivateKey(t *testing.T) {
|
||||
mockSDK := new(mocks.SDK)
|
||||
mockSDK.On("Result", mock.Anything, mock.Anything).Return([]byte(compResult), nil)
|
||||
mockSDK.On("Result", mock.Anything, mock.Anything, mock.Anything).Return(nil).Run(func(args mock.Arguments) {
|
||||
_, err := args.Get(2).(*os.File).WriteString(compResult)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
testCLI := CLI{agentSDK: mockSDK}
|
||||
|
||||
invalidPrivateKey, err := os.CreateTemp("", "invalid_private_key.pem")
|
||||
@@ -73,30 +77,7 @@ func TestResultsCmd_InvalidPrivateKey(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Contains(t, buf.String(), "Error decoding private key")
|
||||
mockSDK.AssertNotCalled(t, "Result", mock.Anything, mock.Anything)
|
||||
}
|
||||
|
||||
func TestGetUniqueFilePath(t *testing.T) {
|
||||
prefix := "test"
|
||||
ext := ".txt"
|
||||
|
||||
path, err := getUniqueFilePath(prefix, ext)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "test.txt", path)
|
||||
|
||||
_, err = os.Create("test.txt")
|
||||
require.NoError(t, err)
|
||||
defer os.Remove("test.txt")
|
||||
for i := 1; i < 3; i++ {
|
||||
fileName := fmt.Sprintf("%s_%d%s", prefix, i, ext)
|
||||
_, err := os.Create(fileName)
|
||||
require.NoError(t, err)
|
||||
defer os.Remove(fileName)
|
||||
}
|
||||
|
||||
path, err = getUniqueFilePath(prefix, ext)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "test_3.txt", path)
|
||||
mockSDK.AssertNotCalled(t, "Result", mock.Anything, mock.Anything, mock.Anything)
|
||||
}
|
||||
|
||||
func TestResultsCmd(t *testing.T) {
|
||||
@@ -111,7 +92,10 @@ func TestResultsCmd(t *testing.T) {
|
||||
{
|
||||
name: "successful result retrieval",
|
||||
setupMock: func(m *mocks.SDK) {
|
||||
m.On("Result", mock.Anything, mock.Anything).Return([]byte(compResult), nil)
|
||||
m.On("Result", mock.Anything, mock.Anything, mock.Anything).Return(nil).Run(func(args mock.Arguments) {
|
||||
_, err := args.Get(2).(*os.File).WriteString(compResult)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
},
|
||||
setupFiles: func() (string, error) {
|
||||
return privateKeyFile, generateRSAPrivateKeyFile(privateKeyFile)
|
||||
@@ -128,7 +112,10 @@ func TestResultsCmd(t *testing.T) {
|
||||
{
|
||||
name: "missing private key file",
|
||||
setupMock: func(m *mocks.SDK) {
|
||||
m.On("Result", mock.Anything, mock.Anything).Return([]byte(compResult), nil)
|
||||
m.On("Result", mock.Anything, mock.Anything, mock.Anything).Return(nil).Run(func(args mock.Arguments) {
|
||||
_, err := args.Get(2).(*os.File).WriteString(compResult)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
},
|
||||
setupFiles: func() (string, error) {
|
||||
return "non_existent_private_key.pem", nil
|
||||
@@ -138,7 +125,7 @@ func TestResultsCmd(t *testing.T) {
|
||||
{
|
||||
name: "result retrieval failure",
|
||||
setupMock: func(m *mocks.SDK) {
|
||||
m.On("Result", mock.Anything, mock.Anything).Return(nil, errors.New("error retrieving computation result"))
|
||||
m.On("Result", mock.Anything, mock.Anything, mock.Anything).Return(errors.New("error retrieving computation result"))
|
||||
},
|
||||
setupFiles: func() (string, error) {
|
||||
return privateKeyFile, generateRSAPrivateKeyFile(privateKeyFile)
|
||||
@@ -148,26 +135,6 @@ func TestResultsCmd(t *testing.T) {
|
||||
os.Remove(privateKeyFile)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "save failure",
|
||||
setupMock: func(m *mocks.SDK) {
|
||||
m.On("Result", mock.Anything, mock.Anything).Return([]byte(compResult), nil)
|
||||
},
|
||||
setupFiles: func() (string, error) {
|
||||
err := generateRSAPrivateKeyFile(privateKeyFile)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
// Simulate failure in saving the result file by making all files read-only
|
||||
return privateKeyFile, os.Chmod(".", 0o555)
|
||||
},
|
||||
expectedOutput: "Error saving computation result file",
|
||||
cleanup: func() {
|
||||
err := os.Chmod(".", 0o755)
|
||||
require.NoError(t, err)
|
||||
os.Remove(privateKeyFile)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "connection error",
|
||||
setupMock: func(m *mocks.SDK) {
|
||||
|
||||
+1
-1
@@ -4,7 +4,7 @@
|
||||
coverage:
|
||||
ignore:
|
||||
- "test/*"
|
||||
- "cmd/*"
|
||||
- "cmd/**"
|
||||
- "**/mocks/**"
|
||||
- "mocks/**"
|
||||
- "**/*.pb.go"
|
||||
|
||||
@@ -60,6 +60,60 @@ func ZipDirectoryToMemory(sourceDir string) ([]byte, error) {
|
||||
return buf.Bytes(), nil
|
||||
}
|
||||
|
||||
func ZipDirectoryToTempFile(sourceDir string) (*os.File, error) {
|
||||
tmpFile, err := os.CreateTemp("", "dataset*.zip")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
zipWriter := zip.NewWriter(tmpFile)
|
||||
|
||||
err = filepath.Walk(sourceDir, func(path string, info os.FileInfo, err error) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if info.IsDir() {
|
||||
return nil
|
||||
}
|
||||
|
||||
relPath, err := filepath.Rel(sourceDir, path)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
zipHeader, err := zip.FileInfoHeader(info)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
zipHeader.Name = relPath
|
||||
|
||||
zipWriterEntry, err := zipWriter.CreateHeader(zipHeader)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
fileToZip, err := os.Open(path)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer fileToZip.Close()
|
||||
|
||||
_, err = io.Copy(zipWriterEntry, fileToZip)
|
||||
return err
|
||||
})
|
||||
if err != nil {
|
||||
zipWriter.Close()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := zipWriter.Close(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return tmpFile, nil
|
||||
}
|
||||
|
||||
func UnzipFromMemory(zipData []byte, targetDir string) error {
|
||||
reader := bytes.NewReader(zipData)
|
||||
zipReader, err := zip.NewReader(reader, int64(len(zipData)))
|
||||
|
||||
@@ -3,6 +3,8 @@
|
||||
package internal
|
||||
|
||||
import (
|
||||
"archive/zip"
|
||||
"io"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
@@ -104,3 +106,147 @@ func TestZipDirectoryToMemory_NonExistentDirectory(t *testing.T) {
|
||||
t.Error("ZipDirectoryToMemory should fail with non-existent directory")
|
||||
}
|
||||
}
|
||||
|
||||
func TestZipDirectoryToTempFile(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
setupFiles map[string]string // map of relative path to content
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
name: "single file",
|
||||
setupFiles: map[string]string{
|
||||
"test.txt": "hello world",
|
||||
},
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "multiple files in root",
|
||||
setupFiles: map[string]string{
|
||||
"test1.txt": "content1",
|
||||
"test2.txt": "content2",
|
||||
"test3.txt": "content3",
|
||||
},
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "nested directory structure",
|
||||
setupFiles: map[string]string{
|
||||
"file1.txt": "root file",
|
||||
"dir1/file2.txt": "nested file",
|
||||
"dir1/dir2/file3.txt": "deeply nested file",
|
||||
"dir1/dir2/dir3/file4.txt": "very deeply nested file",
|
||||
},
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "empty directory",
|
||||
setupFiles: map[string]string{},
|
||||
expectError: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
sourceDir, err := os.MkdirTemp("", "source")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temp source directory: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(sourceDir)
|
||||
|
||||
for relPath, content := range tt.setupFiles {
|
||||
fullPath := filepath.Join(sourceDir, relPath)
|
||||
dir := filepath.Dir(fullPath)
|
||||
|
||||
if err := os.MkdirAll(dir, 0o755); err != nil {
|
||||
t.Fatalf("Failed to create directory %s: %v", dir, err)
|
||||
}
|
||||
|
||||
if err := os.WriteFile(fullPath, []byte(content), 0o644); err != nil {
|
||||
t.Fatalf("Failed to write file %s: %v", fullPath, err)
|
||||
}
|
||||
}
|
||||
|
||||
zipFile, err := ZipDirectoryToTempFile(sourceDir)
|
||||
if err != nil {
|
||||
if !tt.expectError {
|
||||
t.Fatalf("Unexpected error: %v", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
defer os.Remove(zipFile.Name())
|
||||
defer zipFile.Close()
|
||||
|
||||
if tt.expectError {
|
||||
t.Fatal("Expected error but got none")
|
||||
}
|
||||
|
||||
zipReader, err := zip.OpenReader(zipFile.Name())
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to open zip file: %v", err)
|
||||
}
|
||||
defer zipReader.Close()
|
||||
|
||||
expectedFiles := make(map[string]string)
|
||||
for path, content := range tt.setupFiles {
|
||||
expectedFiles[filepath.ToSlash(path)] = content
|
||||
}
|
||||
|
||||
for _, file := range zipReader.File {
|
||||
expectedContent, exists := expectedFiles[file.Name]
|
||||
if !exists {
|
||||
t.Errorf("Unexpected file in zip: %s", file.Name)
|
||||
continue
|
||||
}
|
||||
|
||||
rc, err := file.Open()
|
||||
if err != nil {
|
||||
t.Errorf("Failed to open file in zip %s: %v", file.Name, err)
|
||||
continue
|
||||
}
|
||||
|
||||
content, err := io.ReadAll(rc)
|
||||
rc.Close()
|
||||
if err != nil {
|
||||
t.Errorf("Failed to read file in zip %s: %v", file.Name, err)
|
||||
continue
|
||||
}
|
||||
|
||||
if string(content) != expectedContent {
|
||||
t.Errorf("File %s content mismatch: got %s, want %s", file.Name, content, expectedContent)
|
||||
}
|
||||
|
||||
delete(expectedFiles, file.Name)
|
||||
}
|
||||
|
||||
for path := range expectedFiles {
|
||||
t.Errorf("Missing file in zip: %s", path)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestZipDirectoryToTempFile_InvalidInput(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
sourceDir string
|
||||
}{
|
||||
{
|
||||
name: "non-existent directory",
|
||||
sourceDir: "/path/that/does/not/exist",
|
||||
},
|
||||
{
|
||||
name: "empty path",
|
||||
sourceDir: "",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
_, err := ZipDirectoryToTempFile(tt.sourceDir)
|
||||
if err == nil {
|
||||
t.Error("Expected error but got none")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -53,16 +53,37 @@ func TestSendAlgorithm(t *testing.T) {
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
pb := New(false)
|
||||
algobuffer := bytes.NewBufferString("algorithm content")
|
||||
reqBuffer := bytes.NewBufferString("requirements content")
|
||||
|
||||
algo, err := os.CreateTemp("", "test_algo")
|
||||
assert.NoError(t, err)
|
||||
req, err := os.CreateTemp("", "test_req")
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, err = algo.WriteString("test algorithm")
|
||||
assert.NoError(t, err)
|
||||
|
||||
err = algo.Close()
|
||||
assert.NoError(t, err)
|
||||
|
||||
algo, err = os.Open(algo.Name())
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, err = req.WriteString("test request")
|
||||
assert.NoError(t, err)
|
||||
|
||||
err = req.Close()
|
||||
assert.NoError(t, err)
|
||||
|
||||
req, err = os.Open(req.Name())
|
||||
assert.NoError(t, err)
|
||||
|
||||
algoStream := new(mocks.AgentService_AlgoClient)
|
||||
algoStream.On("Send", mock.Anything).Return(tc.sendError)
|
||||
algoStream.On("CloseAndRecv").Return(&agent.AlgoResponse{}, tc.closeRecvError)
|
||||
mockStream := &mockAlgoStream{stream: algoStream}
|
||||
|
||||
err := pb.SendAlgorithm("Test Algorithm", algobuffer, reqBuffer, &mockStream.stream)
|
||||
assert.True(t, errors.Contains(err, tc.err))
|
||||
err = pb.SendAlgorithm("Test Algorithm", algo, req, mockStream.stream)
|
||||
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("expected error: %v, got: %v", tc.err, err))
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -108,14 +129,24 @@ func TestSendData(t *testing.T) {
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
pb := New(false)
|
||||
buffer := bytes.NewBufferString(tc.dataContent)
|
||||
dataset, err := os.CreateTemp("", "test_dataset")
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, err = dataset.WriteString(tc.dataContent)
|
||||
assert.NoError(t, err)
|
||||
|
||||
err = dataset.Close()
|
||||
assert.NoError(t, err)
|
||||
|
||||
dataset, err = os.Open(dataset.Name())
|
||||
assert.NoError(t, err)
|
||||
|
||||
dataStream := new(mocks.AgentService_DataClient)
|
||||
dataStream.On("Send", mock.Anything).Return(tc.sendError)
|
||||
dataStream.On("CloseAndRecv").Return(&agent.DataResponse{}, tc.closeRecvError)
|
||||
mockStream := &mockDataStream{stream: dataStream}
|
||||
|
||||
err := pb.SendData("Test Data", "test.txt", buffer, &mockStream.stream)
|
||||
err = pb.SendData("Test Data", "test.txt", dataset, mockStream.stream)
|
||||
assert.True(t, errors.Contains(err, tc.err))
|
||||
})
|
||||
}
|
||||
@@ -361,7 +392,7 @@ func TestReceiveResult(t *testing.T) {
|
||||
setupMock: func(m *MockResultStream) {
|
||||
m.On("Recv").Return(nil, io.EOF).Once()
|
||||
},
|
||||
wantResult: nil,
|
||||
wantResult: []byte{},
|
||||
wantErr: nil,
|
||||
},
|
||||
}
|
||||
@@ -375,13 +406,26 @@ func TestReceiveResult(t *testing.T) {
|
||||
// Disable terminal width check for tests
|
||||
p.TerminalWidthFunc = func() (int, error) { return 100, nil }
|
||||
|
||||
result, err := p.ReceiveResult(tt.description, tt.totalSize, mockStream)
|
||||
resultFile, err := os.CreateTemp("", "test_result")
|
||||
assert.NoError(t, err)
|
||||
|
||||
t.Cleanup(func() {
|
||||
os.Remove(resultFile.Name())
|
||||
})
|
||||
|
||||
err = p.ReceiveResult(tt.description, tt.totalSize, mockStream, resultFile)
|
||||
|
||||
assert.NoError(t, resultFile.Close())
|
||||
|
||||
if tt.wantErr != nil {
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, tt.wantErr.Error(), err.Error())
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
|
||||
result, err := os.ReadFile(resultFile.Name())
|
||||
assert.NoError(t, err)
|
||||
|
||||
assert.Equal(t, tt.wantResult, result)
|
||||
}
|
||||
|
||||
@@ -457,13 +501,26 @@ func TestReceiveAttestation(t *testing.T) {
|
||||
// Disable terminal width check for tests
|
||||
p.TerminalWidthFunc = func() (int, error) { return 100, nil }
|
||||
|
||||
result, err := p.ReceiveAttestation(tt.description, tt.totalSize, mockStream)
|
||||
resultFile, err := os.CreateTemp("", "test_attestation")
|
||||
assert.NoError(t, err)
|
||||
|
||||
t.Cleanup(func() {
|
||||
os.Remove(resultFile.Name())
|
||||
})
|
||||
|
||||
err = p.ReceiveAttestation(tt.description, tt.totalSize, mockStream, resultFile)
|
||||
|
||||
assert.NoError(t, resultFile.Close())
|
||||
|
||||
if tt.wantErr != nil {
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, tt.wantErr.Error(), err.Error())
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
|
||||
result, err := os.ReadFile(resultFile.Name())
|
||||
assert.NoError(t, err)
|
||||
|
||||
assert.Equal(t, tt.wantResult, result)
|
||||
}
|
||||
|
||||
|
||||
@@ -3,7 +3,6 @@
|
||||
package progressbar
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
@@ -32,7 +31,7 @@ type streamSender interface {
|
||||
}
|
||||
|
||||
type algoClientWrapper struct {
|
||||
client *agent.AgentService_AlgoClient
|
||||
client agent.AgentService_AlgoClient
|
||||
}
|
||||
|
||||
func (a *algoClientWrapper) Send(req interface{}) error {
|
||||
@@ -41,15 +40,15 @@ func (a *algoClientWrapper) Send(req interface{}) error {
|
||||
return fmt.Errorf("expected *AlgoRequest, got %T", req)
|
||||
}
|
||||
|
||||
return (*a.client).Send(algoReq)
|
||||
return a.client.Send(algoReq)
|
||||
}
|
||||
|
||||
func (a *algoClientWrapper) CloseAndRecv() (interface{}, error) {
|
||||
return (*a.client).CloseAndRecv()
|
||||
return a.client.CloseAndRecv()
|
||||
}
|
||||
|
||||
type dataClientWrapper struct {
|
||||
client *agent.AgentService_DataClient
|
||||
client agent.AgentService_DataClient
|
||||
}
|
||||
|
||||
func (a *dataClientWrapper) Send(req interface{}) error {
|
||||
@@ -58,11 +57,11 @@ func (a *dataClientWrapper) Send(req interface{}) error {
|
||||
return fmt.Errorf("expected *DataRequest, got %T", req)
|
||||
}
|
||||
|
||||
return (*a.client).Send(dataReq)
|
||||
return a.client.Send(dataReq)
|
||||
}
|
||||
|
||||
func (a *dataClientWrapper) CloseAndRecv() (interface{}, error) {
|
||||
return (*a.client).CloseAndRecv()
|
||||
return a.client.CloseAndRecv()
|
||||
}
|
||||
|
||||
type ProgressBar struct {
|
||||
@@ -82,21 +81,37 @@ func New(isDownload bool) *ProgressBar {
|
||||
}
|
||||
}
|
||||
|
||||
func (p *ProgressBar) SendAlgorithm(description string, algobuffer, reqBuffer *bytes.Buffer, stream *agent.AgentService_AlgoClient) error {
|
||||
totalSize := algobuffer.Len() + reqBuffer.Len()
|
||||
func (p *ProgressBar) SendAlgorithm(description string, algo, req *os.File, stream agent.AgentService_AlgoClient) error {
|
||||
algoFileInfo, err := algo.Stat()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
reqSize := 0
|
||||
if req != nil {
|
||||
reqFileInfo, err := req.Stat()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
reqSize = int(reqFileInfo.Size())
|
||||
}
|
||||
|
||||
totalSize := int(algoFileInfo.Size()) + reqSize
|
||||
p.reset(description, totalSize)
|
||||
|
||||
wrapper := &algoClientWrapper{client: stream}
|
||||
|
||||
// Send reqBuffer first
|
||||
if err := p.sendBuffer(reqBuffer, wrapper, func(data []byte) interface{} {
|
||||
// Send req first
|
||||
if req != nil {
|
||||
if err := p.sendBuffer(req, wrapper, func(data []byte) interface{} {
|
||||
return &agent.AlgoRequest{Requirements: data}
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// Then send algobuffer
|
||||
if err := p.sendBuffer(algobuffer, wrapper, func(data []byte) interface{} {
|
||||
// Then send algo
|
||||
if err := p.sendBuffer(algo, wrapper, func(data []byte) interface{} {
|
||||
return &agent.AlgoRequest{Algorithm: data}
|
||||
}); err != nil {
|
||||
return err
|
||||
@@ -106,23 +121,32 @@ func (p *ProgressBar) SendAlgorithm(description string, algobuffer, reqBuffer *b
|
||||
return err
|
||||
}
|
||||
|
||||
_, err := wrapper.CloseAndRecv()
|
||||
_, err = wrapper.CloseAndRecv()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *ProgressBar) SendData(description, filename string, buffer *bytes.Buffer, stream *agent.AgentService_DataClient) error {
|
||||
return p.sendData(description, buffer, &dataClientWrapper{client: stream}, func(data []byte) interface{} {
|
||||
func (p *ProgressBar) SendData(description, filename string, file *os.File, stream agent.AgentService_DataClient) error {
|
||||
return p.sendData(description, file, &dataClientWrapper{client: stream}, func(data []byte) interface{} {
|
||||
return &agent.DataRequest{Dataset: data, Filename: filename}
|
||||
})
|
||||
}
|
||||
|
||||
func (p *ProgressBar) sendData(description string, buffer *bytes.Buffer, stream streamSender, createRequest func([]byte) interface{}) error {
|
||||
p.reset(description, buffer.Len())
|
||||
func (p *ProgressBar) sendData(description string, file *os.File, stream streamSender, createRequest func([]byte) interface{}) error {
|
||||
dataInfo, err := file.Stat()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
p.reset(description, int(dataInfo.Size()))
|
||||
|
||||
buf := make([]byte, bufferSize)
|
||||
|
||||
for {
|
||||
n, err := buffer.Read(buf)
|
||||
n, err := file.Read(buf)
|
||||
if err == io.EOF {
|
||||
if _, err := io.WriteString(os.Stdout, "\n"); err != nil {
|
||||
return err
|
||||
@@ -147,15 +171,15 @@ func (p *ProgressBar) sendData(description string, buffer *bytes.Buffer, stream
|
||||
}
|
||||
}
|
||||
|
||||
_, err := stream.CloseAndRecv()
|
||||
_, err = stream.CloseAndRecv()
|
||||
return err
|
||||
}
|
||||
|
||||
func (p *ProgressBar) sendBuffer(buffer *bytes.Buffer, stream streamSender, createRequest func([]byte) interface{}) error {
|
||||
func (p *ProgressBar) sendBuffer(file *os.File, stream streamSender, createRequest func([]byte) interface{}) error {
|
||||
buf := make([]byte, bufferSize)
|
||||
|
||||
for {
|
||||
n, err := buffer.Read(buf)
|
||||
n, err := file.Read(buf)
|
||||
if err == io.EOF {
|
||||
break
|
||||
}
|
||||
@@ -303,7 +327,7 @@ func (p *ProgressBar) clearProgressBar() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *ProgressBar) ReceiveResult(description string, totalSize int, stream agent.AgentService_ResultClient) ([]byte, error) {
|
||||
func (p *ProgressBar) ReceiveResult(description string, totalSize int, stream agent.AgentService_ResultClient, resultFile *os.File) error {
|
||||
return p.receiveStream(description, totalSize, func() ([]byte, error) {
|
||||
response, err := stream.Recv()
|
||||
if err != nil {
|
||||
@@ -311,10 +335,10 @@ func (p *ProgressBar) ReceiveResult(description string, totalSize int, stream ag
|
||||
}
|
||||
|
||||
return response.File, nil
|
||||
})
|
||||
}, resultFile)
|
||||
}
|
||||
|
||||
func (p *ProgressBar) ReceiveAttestation(description string, totalSize int, stream agent.AgentService_AttestationClient) ([]byte, error) {
|
||||
func (p *ProgressBar) ReceiveAttestation(description string, totalSize int, stream agent.AgentService_AttestationClient, attestationFile *os.File) error {
|
||||
return p.receiveStream(description, totalSize, func() ([]byte, error) {
|
||||
response, err := stream.Recv()
|
||||
if err != nil {
|
||||
@@ -322,37 +346,37 @@ func (p *ProgressBar) ReceiveAttestation(description string, totalSize int, stre
|
||||
}
|
||||
|
||||
return response.File, nil
|
||||
})
|
||||
}, attestationFile)
|
||||
}
|
||||
|
||||
func (p *ProgressBar) receiveStream(description string, totalSize int, recv func() ([]byte, error)) ([]byte, error) {
|
||||
func (p *ProgressBar) receiveStream(description string, totalSize int, recv func() ([]byte, error), file *os.File) error {
|
||||
p.reset(description, totalSize)
|
||||
p.isDownload = true
|
||||
|
||||
var result []byte
|
||||
for {
|
||||
chunk, err := recv()
|
||||
if err == io.EOF {
|
||||
if _, err := io.WriteString(os.Stdout, "\n"); err != nil {
|
||||
return nil, err
|
||||
return err
|
||||
}
|
||||
break
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return err
|
||||
}
|
||||
|
||||
chunkSize := len(chunk)
|
||||
if err = p.updateProgress(chunkSize); err != nil {
|
||||
return nil, err
|
||||
return err
|
||||
}
|
||||
|
||||
result = append(result, chunk...)
|
||||
|
||||
if _, err := file.Write(chunk); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := p.renderProgressBar(); err != nil {
|
||||
return nil, err
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return result, nil
|
||||
return nil
|
||||
}
|
||||
|
||||
+20
-23
@@ -3,7 +3,6 @@
|
||||
package sdk
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto"
|
||||
"crypto/ecdsa"
|
||||
@@ -12,6 +11,7 @@ import (
|
||||
"crypto/rsa"
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"os"
|
||||
"strconv"
|
||||
|
||||
"github.com/absmach/magistrala/pkg/errors"
|
||||
@@ -24,10 +24,10 @@ import (
|
||||
|
||||
//go:generate mockery --name SDK --output=mocks --filename sdk.go --quiet --note "Copyright (c) Ultraviolet \n // SPDX-License-Identifier: Apache-2.0"
|
||||
type SDK interface {
|
||||
Algo(ctx context.Context, algorithm agent.Algorithm, privKey any) error
|
||||
Data(ctx context.Context, dataset agent.Dataset, privKey any) error
|
||||
Result(ctx context.Context, privKey any) ([]byte, error)
|
||||
Attestation(ctx context.Context, reportData [size64]byte) ([]byte, error)
|
||||
Algo(ctx context.Context, algorithm, requirements *os.File, privKey any) error
|
||||
Data(ctx context.Context, dataset *os.File, filename string, privKey any) error
|
||||
Result(ctx context.Context, privKey any, resultFile *os.File) error
|
||||
Attestation(ctx context.Context, reportData [size64]byte, attestationFile *os.File) error
|
||||
}
|
||||
|
||||
const (
|
||||
@@ -48,7 +48,7 @@ func NewAgentSDK(agentClient agent.AgentServiceClient) SDK {
|
||||
}
|
||||
}
|
||||
|
||||
func (sdk *agentSDK) Algo(ctx context.Context, algorithm agent.Algorithm, privKey any) error {
|
||||
func (sdk *agentSDK) Algo(ctx context.Context, algorithm, requirements *os.File, privKey any) error {
|
||||
md, err := generateMetadata(string(auth.AlgorithmProviderRole), privKey)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -62,14 +62,12 @@ func (sdk *agentSDK) Algo(ctx context.Context, algorithm agent.Algorithm, privKe
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
algoBuffer := bytes.NewBuffer(algorithm.Algorithm)
|
||||
reqBuffer := bytes.NewBuffer(algorithm.Requirements)
|
||||
|
||||
pb := progressbar.New(false)
|
||||
return pb.SendAlgorithm(algoProgressBarDescription, algoBuffer, reqBuffer, &stream)
|
||||
return pb.SendAlgorithm(algoProgressBarDescription, algorithm, requirements, stream)
|
||||
}
|
||||
|
||||
func (sdk *agentSDK) Data(ctx context.Context, dataset agent.Dataset, privKey any) error {
|
||||
func (sdk *agentSDK) Data(ctx context.Context, dataset *os.File, filename string, privKey any) error {
|
||||
md, err := generateMetadata(string(auth.DataProviderRole), privKey)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -83,29 +81,28 @@ func (sdk *agentSDK) Data(ctx context.Context, dataset agent.Dataset, privKey an
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
dataBuffer := bytes.NewBuffer(dataset.Dataset)
|
||||
|
||||
pb := progressbar.New(false)
|
||||
return pb.SendData(dataProgressBarDescription, dataset.Filename, dataBuffer, &stream)
|
||||
return pb.SendData(dataProgressBarDescription, filename, dataset, stream)
|
||||
}
|
||||
|
||||
func (sdk *agentSDK) Result(ctx context.Context, privKey any) ([]byte, error) {
|
||||
func (sdk *agentSDK) Result(ctx context.Context, privKey any, resultFile *os.File) error {
|
||||
request := &agent.ResultRequest{}
|
||||
|
||||
md, err := generateMetadata(string(auth.ConsumerRole), privKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return err
|
||||
}
|
||||
|
||||
ctx = metadata.NewOutgoingContext(ctx, md)
|
||||
stream, err := sdk.client.Result(ctx, request)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return err
|
||||
}
|
||||
|
||||
incomingmd, err := stream.Header()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return err
|
||||
}
|
||||
|
||||
fileSizeStr := incomingmd.Get(grpc.FileSizeKey)
|
||||
@@ -116,27 +113,27 @@ func (sdk *agentSDK) Result(ctx context.Context, privKey any) ([]byte, error) {
|
||||
|
||||
fileSize, err := strconv.Atoi(fileSizeStr[0])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return err
|
||||
}
|
||||
|
||||
pb := progressbar.New(true)
|
||||
|
||||
return pb.ReceiveResult(resultProgressDescription, fileSize, stream)
|
||||
return pb.ReceiveResult(resultProgressDescription, fileSize, stream, resultFile)
|
||||
}
|
||||
|
||||
func (sdk *agentSDK) Attestation(ctx context.Context, reportData [size64]byte) ([]byte, error) {
|
||||
func (sdk *agentSDK) Attestation(ctx context.Context, reportData [size64]byte, attestationFile *os.File) error {
|
||||
request := &agent.AttestationRequest{
|
||||
ReportData: reportData[:],
|
||||
}
|
||||
|
||||
stream, err := sdk.client.Attestation(ctx, request)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return err
|
||||
}
|
||||
|
||||
incomingmd, err := stream.Header()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return err
|
||||
}
|
||||
|
||||
fileSizeStr := incomingmd.Get(grpc.FileSizeKey)
|
||||
@@ -147,12 +144,12 @@ func (sdk *agentSDK) Attestation(ctx context.Context, reportData [size64]byte) (
|
||||
|
||||
fileSize, err := strconv.Atoi(fileSizeStr[0])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return err
|
||||
}
|
||||
|
||||
pb := progressbar.New(true)
|
||||
|
||||
return pb.ReceiveAttestation(attestationProgressDescription, fileSize, stream)
|
||||
return pb.ReceiveAttestation(attestationProgressDescription, fileSize, stream, attestationFile)
|
||||
}
|
||||
|
||||
func signData(userID string, privKey crypto.Signer) ([]byte, error) {
|
||||
|
||||
+61
-9
@@ -115,7 +115,21 @@ func TestAlgo(t *testing.T) {
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
svcCall := svc.On("Algo", mock.Anything, mock.Anything).Return(tc.err)
|
||||
err = sdk.Algo(context.Background(), tc.algo, tc.userKey)
|
||||
|
||||
algo, err := os.CreateTemp("", "algo")
|
||||
require.NoError(t, err)
|
||||
defer os.Remove(algo.Name())
|
||||
|
||||
_, err = algo.Write(algorithm.Algorithm)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = algo.Close()
|
||||
require.NoError(t, err)
|
||||
|
||||
algo, err = os.Open(algo.Name())
|
||||
require.NoError(t, err)
|
||||
|
||||
err = sdk.Algo(context.Background(), algo, nil, tc.userKey)
|
||||
|
||||
st, _ := status.FromError(err)
|
||||
|
||||
@@ -212,7 +226,19 @@ func TestData(t *testing.T) {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
dataCall := svc.On("Data", mock.Anything, mock.Anything).Return(tc.svcErr)
|
||||
|
||||
err = sdk.Data(context.Background(), tc.data, tc.userKey)
|
||||
data, err := os.CreateTemp("", "data")
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = data.Write(dataset.Dataset)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = data.Close()
|
||||
require.NoError(t, err)
|
||||
|
||||
data, err = os.Open(data.Name())
|
||||
require.NoError(t, err)
|
||||
|
||||
err = sdk.Data(context.Background(), data, tc.data.Filename, tc.userKey)
|
||||
|
||||
st, _ := status.FromError(err)
|
||||
|
||||
@@ -273,7 +299,7 @@ func TestResult(t *testing.T) {
|
||||
name: "Results not ready",
|
||||
userKey: resultConsumer1Key,
|
||||
response: &agent.ResultResponse{
|
||||
File: []byte(nil),
|
||||
File: []byte{},
|
||||
},
|
||||
svcRes: nil,
|
||||
err: agent.ErrResultsNotReady,
|
||||
@@ -282,7 +308,7 @@ func TestResult(t *testing.T) {
|
||||
name: "All manifest items received",
|
||||
userKey: resultConsumer1Key,
|
||||
response: &agent.ResultResponse{
|
||||
File: []byte(nil),
|
||||
File: []byte{},
|
||||
},
|
||||
svcRes: nil,
|
||||
err: agent.ErrAllManifestItemsReceived,
|
||||
@@ -291,7 +317,7 @@ func TestResult(t *testing.T) {
|
||||
name: "Undeclared consumer",
|
||||
userKey: resultConsumer1Key,
|
||||
response: &agent.ResultResponse{
|
||||
File: []byte(nil),
|
||||
File: []byte{},
|
||||
},
|
||||
svcRes: nil,
|
||||
err: agent.ErrUndeclaredConsumer,
|
||||
@@ -300,7 +326,17 @@ func TestResult(t *testing.T) {
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
svcCall := svc.On("Result", mock.Anything, mock.Anything).Return(tc.svcRes, tc.err)
|
||||
res, err := sdk.Result(context.Background(), tc.userKey)
|
||||
|
||||
resultFile, err := os.CreateTemp("", "result")
|
||||
require.NoError(t, err)
|
||||
|
||||
t.Cleanup(func() {
|
||||
os.Remove(resultFile.Name())
|
||||
})
|
||||
|
||||
err = sdk.Result(context.Background(), tc.userKey, resultFile)
|
||||
|
||||
require.NoError(t, resultFile.Close())
|
||||
|
||||
st, ok := status.FromError(err)
|
||||
if !ok {
|
||||
@@ -312,6 +348,10 @@ func TestResult(t *testing.T) {
|
||||
t.Errorf("%s: Expected error message %q, but got %q", tc.name, tc.err.Error(), st.Message())
|
||||
}
|
||||
}
|
||||
|
||||
res, err := os.ReadFile(resultFile.Name())
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, tc.response.File, res, tc.name)
|
||||
|
||||
svcCall.Unset()
|
||||
@@ -375,7 +415,7 @@ func TestAttestation(t *testing.T) {
|
||||
userKey: resultConsumerKey,
|
||||
reportData: [agent.ReportDataSize]byte(reportData),
|
||||
response: &agent.AttestationResponse{
|
||||
File: nil,
|
||||
File: []byte{},
|
||||
},
|
||||
err: nil,
|
||||
},
|
||||
@@ -384,7 +424,7 @@ func TestAttestation(t *testing.T) {
|
||||
userKey: resultConsumerKey,
|
||||
reportData: [agent.ReportDataSize]byte{},
|
||||
response: &agent.AttestationResponse{
|
||||
File: nil,
|
||||
File: []byte{},
|
||||
},
|
||||
svcRes: nil,
|
||||
err: errors.New("invalid report data"),
|
||||
@@ -395,7 +435,16 @@ func TestAttestation(t *testing.T) {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
svcCall := svc.On("Attestation", mock.Anything, mock.Anything).Return(tc.svcRes, tc.err)
|
||||
|
||||
res, err := sdk.Attestation(context.Background(), tc.reportData)
|
||||
file, err := os.CreateTemp("", "attestation")
|
||||
require.NoError(t, err)
|
||||
|
||||
t.Cleanup(func() {
|
||||
os.Remove(file.Name())
|
||||
})
|
||||
|
||||
err = sdk.Attestation(context.Background(), tc.reportData, file)
|
||||
|
||||
require.NoError(t, file.Close())
|
||||
|
||||
st, ok := status.FromError(err)
|
||||
if !ok {
|
||||
@@ -408,6 +457,9 @@ func TestAttestation(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
res, err := os.ReadFile(file.Name())
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, tc.response.File, res, tc.name)
|
||||
|
||||
svcCall.Unset()
|
||||
|
||||
+27
-52
@@ -7,8 +7,7 @@ package mocks
|
||||
|
||||
import (
|
||||
context "context"
|
||||
|
||||
agent "github.com/ultravioletrs/cocos/agent"
|
||||
os "os"
|
||||
|
||||
mock "github.com/stretchr/testify/mock"
|
||||
)
|
||||
@@ -18,17 +17,17 @@ type SDK struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
// Algo provides a mock function with given fields: ctx, algorithm, privKey
|
||||
func (_m *SDK) Algo(ctx context.Context, algorithm agent.Algorithm, privKey interface{}) error {
|
||||
ret := _m.Called(ctx, algorithm, privKey)
|
||||
// Algo provides a mock function with given fields: ctx, algorithm, requirements, privKey
|
||||
func (_m *SDK) Algo(ctx context.Context, algorithm *os.File, requirements *os.File, privKey interface{}) error {
|
||||
ret := _m.Called(ctx, algorithm, requirements, privKey)
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for Algo")
|
||||
}
|
||||
|
||||
var r0 error
|
||||
if rf, ok := ret.Get(0).(func(context.Context, agent.Algorithm, interface{}) error); ok {
|
||||
r0 = rf(ctx, algorithm, privKey)
|
||||
if rf, ok := ret.Get(0).(func(context.Context, *os.File, *os.File, interface{}) error); ok {
|
||||
r0 = rf(ctx, algorithm, requirements, privKey)
|
||||
} else {
|
||||
r0 = ret.Error(0)
|
||||
}
|
||||
@@ -36,47 +35,35 @@ func (_m *SDK) Algo(ctx context.Context, algorithm agent.Algorithm, privKey inte
|
||||
return r0
|
||||
}
|
||||
|
||||
// Attestation provides a mock function with given fields: ctx, reportData
|
||||
func (_m *SDK) Attestation(ctx context.Context, reportData [64]byte) ([]byte, error) {
|
||||
ret := _m.Called(ctx, reportData)
|
||||
// Attestation provides a mock function with given fields: ctx, reportData, attestationFile
|
||||
func (_m *SDK) Attestation(ctx context.Context, reportData [64]byte, attestationFile *os.File) error {
|
||||
ret := _m.Called(ctx, reportData, attestationFile)
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for Attestation")
|
||||
}
|
||||
|
||||
var r0 []byte
|
||||
var r1 error
|
||||
if rf, ok := ret.Get(0).(func(context.Context, [64]byte) ([]byte, error)); ok {
|
||||
return rf(ctx, reportData)
|
||||
}
|
||||
if rf, ok := ret.Get(0).(func(context.Context, [64]byte) []byte); ok {
|
||||
r0 = rf(ctx, reportData)
|
||||
var r0 error
|
||||
if rf, ok := ret.Get(0).(func(context.Context, [64]byte, *os.File) error); ok {
|
||||
r0 = rf(ctx, reportData, attestationFile)
|
||||
} else {
|
||||
if ret.Get(0) != nil {
|
||||
r0 = ret.Get(0).([]byte)
|
||||
}
|
||||
r0 = ret.Error(0)
|
||||
}
|
||||
|
||||
if rf, ok := ret.Get(1).(func(context.Context, [64]byte) error); ok {
|
||||
r1 = rf(ctx, reportData)
|
||||
} else {
|
||||
r1 = ret.Error(1)
|
||||
}
|
||||
|
||||
return r0, r1
|
||||
return r0
|
||||
}
|
||||
|
||||
// Data provides a mock function with given fields: ctx, dataset, privKey
|
||||
func (_m *SDK) Data(ctx context.Context, dataset agent.Dataset, privKey interface{}) error {
|
||||
ret := _m.Called(ctx, dataset, privKey)
|
||||
// Data provides a mock function with given fields: ctx, dataset, filename, privKey
|
||||
func (_m *SDK) Data(ctx context.Context, dataset *os.File, filename string, privKey interface{}) error {
|
||||
ret := _m.Called(ctx, dataset, filename, privKey)
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for Data")
|
||||
}
|
||||
|
||||
var r0 error
|
||||
if rf, ok := ret.Get(0).(func(context.Context, agent.Dataset, interface{}) error); ok {
|
||||
r0 = rf(ctx, dataset, privKey)
|
||||
if rf, ok := ret.Get(0).(func(context.Context, *os.File, string, interface{}) error); ok {
|
||||
r0 = rf(ctx, dataset, filename, privKey)
|
||||
} else {
|
||||
r0 = ret.Error(0)
|
||||
}
|
||||
@@ -84,34 +71,22 @@ func (_m *SDK) Data(ctx context.Context, dataset agent.Dataset, privKey interfac
|
||||
return r0
|
||||
}
|
||||
|
||||
// Result provides a mock function with given fields: ctx, privKey
|
||||
func (_m *SDK) Result(ctx context.Context, privKey interface{}) ([]byte, error) {
|
||||
ret := _m.Called(ctx, privKey)
|
||||
// Result provides a mock function with given fields: ctx, privKey, resultFile
|
||||
func (_m *SDK) Result(ctx context.Context, privKey interface{}, resultFile *os.File) error {
|
||||
ret := _m.Called(ctx, privKey, resultFile)
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for Result")
|
||||
}
|
||||
|
||||
var r0 []byte
|
||||
var r1 error
|
||||
if rf, ok := ret.Get(0).(func(context.Context, interface{}) ([]byte, error)); ok {
|
||||
return rf(ctx, privKey)
|
||||
}
|
||||
if rf, ok := ret.Get(0).(func(context.Context, interface{}) []byte); ok {
|
||||
r0 = rf(ctx, privKey)
|
||||
var r0 error
|
||||
if rf, ok := ret.Get(0).(func(context.Context, interface{}, *os.File) error); ok {
|
||||
r0 = rf(ctx, privKey, resultFile)
|
||||
} else {
|
||||
if ret.Get(0) != nil {
|
||||
r0 = ret.Get(0).([]byte)
|
||||
}
|
||||
r0 = ret.Error(0)
|
||||
}
|
||||
|
||||
if rf, ok := ret.Get(1).(func(context.Context, interface{}) error); ok {
|
||||
r1 = rf(ctx, privKey)
|
||||
} else {
|
||||
r1 = ret.Error(1)
|
||||
}
|
||||
|
||||
return r0, r1
|
||||
return r0
|
||||
}
|
||||
|
||||
// NewSDK creates a new instance of SDK. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations.
|
||||
|
||||
Reference in New Issue
Block a user