mirror of
https://github.com/ultravioletrs/cocos.git
synced 2026-06-23 04:10:25 +00:00
NOISSUE - Improve file streaming (#295)
* improve file streaming Signed-off-by: Sammy Oina <sammyoina@gmail.com> * error check Signed-off-by: Sammy Oina <sammyoina@gmail.com> * empty line Signed-off-by: Sammy Oina <sammyoina@gmail.com> * fix tests Signed-off-by: Sammy Oina <sammyoina@gmail.com> * send buffer test Signed-off-by: Sammy Oina <sammyoina@gmail.com> * fix test cases Signed-off-by: Sammy Oina <sammyoina@gmail.com> * stream data and attestation Signed-off-by: Sammy Oina <sammyoina@gmail.com> * fumpt Signed-off-by: Sammy Oina <sammyoina@gmail.com> * fix test Signed-off-by: Sammy Oina <sammyoina@gmail.com> * mocks Signed-off-by: Sammy Oina <sammyoina@gmail.com> * value check Signed-off-by: Sammy Oina <sammyoina@gmail.com> * more value checks Signed-off-by: Sammy Oina <sammyoina@gmail.com> * add test cases Signed-off-by: Sammy Oina <sammyoina@gmail.com> * fumpt Signed-off-by: Sammy Oina <sammyoina@gmail.com> * fix tests Signed-off-by: Sammy Oina <sammyoina@gmail.com> * all files Signed-off-by: Sammy Oina <sammyoina@gmail.com> * fix lint Signed-off-by: Sammy Oina <sammyoina@gmail.com> --------- Signed-off-by: Sammy Oina <sammyoina@gmail.com>
This commit is contained in:
committed by
GitHub
parent
01a619fd2a
commit
46b94204df
@@ -16,11 +16,6 @@ type Service struct {
|
|||||||
mock.Mock
|
mock.Mock
|
||||||
}
|
}
|
||||||
|
|
||||||
// Close provides a mock function with given fields:
|
|
||||||
func (_m *Service) Close() {
|
|
||||||
_m.Called()
|
|
||||||
}
|
|
||||||
|
|
||||||
// SendEvent provides a mock function with given fields: event, status, details
|
// SendEvent provides a mock function with given fields: event, status, details
|
||||||
func (_m *Service) SendEvent(event string, status string, details json.RawMessage) error {
|
func (_m *Service) SendEvent(event string, status string, details json.RawMessage) error {
|
||||||
ret := _m.Called(event, status, details)
|
ret := _m.Called(event, status, details)
|
||||||
|
|||||||
@@ -57,7 +57,7 @@ func TestAlgorithmCmd(t *testing.T) {
|
|||||||
{
|
{
|
||||||
name: "successful upload",
|
name: "successful upload",
|
||||||
setupMock: func(m *mocks.SDK) {
|
setupMock: func(m *mocks.SDK) {
|
||||||
m.On("Algo", mock.Anything, mock.Anything, mock.Anything).Return(nil)
|
m.On("Algo", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil)
|
||||||
},
|
},
|
||||||
setupFiles: func() error {
|
setupFiles: func() error {
|
||||||
if err := os.WriteFile(algorithmFile, []byte("test algorithm"), 0o644); err != nil {
|
if err := os.WriteFile(algorithmFile, []byte("test algorithm"), 0o644); err != nil {
|
||||||
@@ -75,7 +75,7 @@ func TestAlgorithmCmd(t *testing.T) {
|
|||||||
{
|
{
|
||||||
name: "missing algorithm file",
|
name: "missing algorithm file",
|
||||||
setupMock: func(m *mocks.SDK) {
|
setupMock: func(m *mocks.SDK) {
|
||||||
m.On("Algo", mock.Anything, mock.Anything, mock.Anything).Return(nil)
|
m.On("Algo", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil)
|
||||||
},
|
},
|
||||||
args: []string{"non_existent_algo_file.py", privateKeyFile},
|
args: []string{"non_existent_algo_file.py", privateKeyFile},
|
||||||
expectedOutput: "Error reading algorithm file",
|
expectedOutput: "Error reading algorithm file",
|
||||||
@@ -83,7 +83,7 @@ func TestAlgorithmCmd(t *testing.T) {
|
|||||||
{
|
{
|
||||||
name: "missing private key file",
|
name: "missing private key file",
|
||||||
setupMock: func(m *mocks.SDK) {
|
setupMock: func(m *mocks.SDK) {
|
||||||
m.On("Algo", mock.Anything, mock.Anything, mock.Anything).Return(nil)
|
m.On("Algo", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil)
|
||||||
},
|
},
|
||||||
setupFiles: func() error {
|
setupFiles: func() error {
|
||||||
return os.WriteFile(algorithmFile, []byte("test algorithm"), 0o644)
|
return os.WriteFile(algorithmFile, []byte("test algorithm"), 0o644)
|
||||||
@@ -97,7 +97,7 @@ func TestAlgorithmCmd(t *testing.T) {
|
|||||||
{
|
{
|
||||||
name: "upload failure",
|
name: "upload failure",
|
||||||
setupMock: func(m *mocks.SDK) {
|
setupMock: func(m *mocks.SDK) {
|
||||||
m.On("Algo", mock.Anything, mock.Anything, mock.Anything).Return(errors.New("failed to upload algorithm due to error"))
|
m.On("Algo", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(errors.New("failed to upload algorithm due to error"))
|
||||||
},
|
},
|
||||||
setupFiles: func() error {
|
setupFiles: func() error {
|
||||||
if err := os.WriteFile(algorithmFile, []byte("test algorithm"), 0o644); err != nil {
|
if err := os.WriteFile(algorithmFile, []byte("test algorithm"), 0o644); err != nil {
|
||||||
@@ -115,7 +115,7 @@ func TestAlgorithmCmd(t *testing.T) {
|
|||||||
{
|
{
|
||||||
name: "invalid private key",
|
name: "invalid private key",
|
||||||
setupMock: func(m *mocks.SDK) {
|
setupMock: func(m *mocks.SDK) {
|
||||||
m.On("Algo", mock.Anything, mock.Anything, mock.Anything).Return(nil)
|
m.On("Algo", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil)
|
||||||
},
|
},
|
||||||
setupFiles: func() error {
|
setupFiles: func() error {
|
||||||
if err := os.WriteFile(algorithmFile, []byte("test algorithm"), 0o644); err != nil {
|
if err := os.WriteFile(algorithmFile, []byte("test algorithm"), 0o644); err != nil {
|
||||||
|
|||||||
+7
-10
@@ -9,7 +9,6 @@ import (
|
|||||||
|
|
||||||
"github.com/fatih/color"
|
"github.com/fatih/color"
|
||||||
"github.com/spf13/cobra"
|
"github.com/spf13/cobra"
|
||||||
"github.com/ultravioletrs/cocos/agent"
|
|
||||||
"github.com/ultravioletrs/cocos/agent/algorithm"
|
"github.com/ultravioletrs/cocos/agent/algorithm"
|
||||||
"github.com/ultravioletrs/cocos/agent/algorithm/python"
|
"github.com/ultravioletrs/cocos/agent/algorithm/python"
|
||||||
"google.golang.org/grpc/metadata"
|
"google.golang.org/grpc/metadata"
|
||||||
@@ -38,24 +37,22 @@ func (cli *CLI) NewAlgorithmCmd() *cobra.Command {
|
|||||||
|
|
||||||
cmd.Println("Uploading algorithm file:", algorithmFile)
|
cmd.Println("Uploading algorithm file:", algorithmFile)
|
||||||
|
|
||||||
algorithm, err := os.ReadFile(algorithmFile)
|
algorithm, err := os.Open(algorithmFile)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
printError(cmd, "Error reading algorithm file: %v ❌ ", err)
|
printError(cmd, "Error reading algorithm file: %v ❌ ", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
var req []byte
|
defer algorithm.Close()
|
||||||
|
|
||||||
|
var req *os.File
|
||||||
if requirementsFile != "" {
|
if requirementsFile != "" {
|
||||||
req, err = os.ReadFile(requirementsFile)
|
req, err = os.Open(requirementsFile)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
printError(cmd, "Error reading requirments file: %v ❌ ", err)
|
printError(cmd, "Error reading requirments file: %v ❌ ", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
defer req.Close()
|
||||||
|
|
||||||
algoReq := agent.Algorithm{
|
|
||||||
Algorithm: algorithm,
|
|
||||||
Requirements: req,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
privKeyFile, err := os.ReadFile(args[1])
|
privKeyFile, err := os.ReadFile(args[1])
|
||||||
@@ -74,7 +71,7 @@ func (cli *CLI) NewAlgorithmCmd() *cobra.Command {
|
|||||||
|
|
||||||
ctx := metadata.NewOutgoingContext(cmd.Context(), metadata.New(make(map[string]string)))
|
ctx := metadata.NewOutgoingContext(cmd.Context(), metadata.New(make(map[string]string)))
|
||||||
|
|
||||||
if err := cli.agentSDK.Algo(addAlgoMetadata(ctx), algoReq, privKey); err != nil {
|
if err := cli.agentSDK.Algo(addAlgoMetadata(ctx), algorithm, req, privKey); err != nil {
|
||||||
printError(cmd, "Failed to upload algorithm due to error: %v ❌ ", err)
|
printError(cmd, "Failed to upload algorithm due to error: %v ❌ ", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|||||||
+26
-7
@@ -176,25 +176,44 @@ func (cli *CLI) NewGetAttestationCmd() *cobra.Command {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
result, err := cli.agentSDK.Attestation(cmd.Context(), [agent.ReportDataSize]byte(reportData))
|
filename := attestationFilePath
|
||||||
|
if getJsonAttestation {
|
||||||
|
filename = attestationJson
|
||||||
|
}
|
||||||
|
|
||||||
|
attestationFile, err := os.Create(filename)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
printError(cmd, "Error creating attestation file: %v ❌ ", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := cli.agentSDK.Attestation(cmd.Context(), [agent.ReportDataSize]byte(reportData), attestationFile); err != nil {
|
||||||
printError(cmd, "Failed to get attestation due to error: %v ❌ ", err)
|
printError(cmd, "Failed to get attestation due to error: %v ❌ ", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
filename := attestationFilePath
|
if err := attestationFile.Close(); err != nil {
|
||||||
|
printError(cmd, "Error closing attestation file: %v ❌ ", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
if getJsonAttestation {
|
if getJsonAttestation {
|
||||||
|
result, err := os.ReadFile(filename)
|
||||||
|
if err != nil {
|
||||||
|
printError(cmd, "Error reading attestation file: %v ❌ ", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
result, err = attesationToJSON(result)
|
result, err = attesationToJSON(result)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
printError(cmd, "Error converting attestation to json: %v ❌ ", err)
|
printError(cmd, "Error converting attestation to json: %v ❌ ", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
filename = attestationJson
|
|
||||||
}
|
|
||||||
|
|
||||||
if err = os.WriteFile(filename, result, 0o644); err != nil {
|
if err := os.WriteFile(filename, result, 0o644); err != nil {
|
||||||
printError(cmd, "Error saving attestation result: %v ❌ ", err)
|
printError(cmd, "Error writing attestation file: %v ❌ ", err)
|
||||||
return
|
return
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
cmd.Println("Attestation result retrieved and saved successfully!")
|
cmd.Println("Attestation result retrieved and saved successfully!")
|
||||||
|
|||||||
+14
-3
@@ -22,7 +22,8 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
func TestNewAttestationCmd(t *testing.T) {
|
func TestNewAttestationCmd(t *testing.T) {
|
||||||
cli := &CLI{}
|
mockSDK := new(mocks.SDK)
|
||||||
|
cli := &CLI{agentSDK: mockSDK}
|
||||||
cmd := cli.NewAttestationCmd()
|
cmd := cli.NewAttestationCmd()
|
||||||
|
|
||||||
assert.Equal(t, "attestation [command]", cmd.Use)
|
assert.Equal(t, "attestation [command]", cmd.Use)
|
||||||
@@ -30,12 +31,19 @@ func TestNewAttestationCmd(t *testing.T) {
|
|||||||
|
|
||||||
var buf bytes.Buffer
|
var buf bytes.Buffer
|
||||||
cmd.SetOut(&buf)
|
cmd.SetOut(&buf)
|
||||||
|
|
||||||
|
cmd.SetOutput(&buf)
|
||||||
|
|
||||||
|
reportData := bytes.Repeat([]byte{0x01}, agent.ReportDataSize)
|
||||||
|
mockSDK.On("Attestation", mock.Anything, [agent.ReportDataSize]byte(reportData), mock.Anything).Return(nil)
|
||||||
|
|
||||||
|
cmd.SetArgs([]string{hex.EncodeToString(reportData)})
|
||||||
err := cmd.Execute()
|
err := cmd.Execute()
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
assert.Contains(t, buf.String(), "Get and validate attestations")
|
assert.Contains(t, buf.String(), "Get and validate attestations")
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestNewGetAttestationCmdN(t *testing.T) {
|
func TestNewGetAttestationCmd(t *testing.T) {
|
||||||
validattestation, err := os.ReadFile("../attestation.bin")
|
validattestation, err := os.ReadFile("../attestation.bin")
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
testCases := []struct {
|
testCases := []struct {
|
||||||
@@ -119,7 +127,10 @@ func TestNewGetAttestationCmdN(t *testing.T) {
|
|||||||
var buf bytes.Buffer
|
var buf bytes.Buffer
|
||||||
cmd.SetOutput(&buf)
|
cmd.SetOutput(&buf)
|
||||||
|
|
||||||
mockSDK.On("Attestation", mock.Anything, [agent.ReportDataSize]byte(bytes.Repeat([]byte{0x01}, agent.ReportDataSize))).Return(tc.mockResponse, tc.mockError)
|
mockSDK.On("Attestation", mock.Anything, [agent.ReportDataSize]byte(bytes.Repeat([]byte{0x01}, agent.ReportDataSize)), mock.Anything).Return(tc.mockError).Run(func(args mock.Arguments) {
|
||||||
|
_, err := args.Get(2).(*os.File).Write(tc.mockResponse)
|
||||||
|
require.NoError(t, err)
|
||||||
|
})
|
||||||
|
|
||||||
cmd.SetArgs(tc.args)
|
cmd.SetArgs(tc.args)
|
||||||
err := cmd.Execute()
|
err := cmd.Execute()
|
||||||
|
|||||||
+7
-9
@@ -41,25 +41,23 @@ func (cli *CLI) NewDatasetsCmd() *cobra.Command {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
var dataset []byte
|
var dataset *os.File
|
||||||
|
|
||||||
if f.IsDir() {
|
if f.IsDir() {
|
||||||
dataset, err = internal.ZipDirectoryToMemory(datasetPath)
|
dataset, err = internal.ZipDirectoryToTempFile(datasetPath)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
printError(cmd, "Error zipping dataset directory: %v ❌ ", err)
|
printError(cmd, "Error zipping dataset directory: %v ❌ ", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
defer dataset.Close()
|
||||||
|
defer os.Remove(dataset.Name())
|
||||||
} else {
|
} else {
|
||||||
dataset, err = os.ReadFile(datasetPath)
|
dataset, err = os.Open(datasetPath)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
printError(cmd, "Error reading dataset file: %v ❌ ", err)
|
printError(cmd, "Error reading dataset file: %v ❌ ", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
defer dataset.Close()
|
||||||
|
|
||||||
dataReq := agent.Dataset{
|
|
||||||
Dataset: dataset,
|
|
||||||
Filename: path.Base(datasetPath),
|
|
||||||
}
|
}
|
||||||
|
|
||||||
privKeyFile, err := os.ReadFile(args[1])
|
privKeyFile, err := os.ReadFile(args[1])
|
||||||
@@ -77,7 +75,7 @@ func (cli *CLI) NewDatasetsCmd() *cobra.Command {
|
|||||||
}
|
}
|
||||||
|
|
||||||
ctx := metadata.NewOutgoingContext(cmd.Context(), metadata.New(make(map[string]string)))
|
ctx := metadata.NewOutgoingContext(cmd.Context(), metadata.New(make(map[string]string)))
|
||||||
if err := cli.agentSDK.Data(addDatasetMetadata(ctx), dataReq, privKey); err != nil {
|
if err := cli.agentSDK.Data(addDatasetMetadata(ctx), dataset, path.Base(datasetPath), privKey); err != nil {
|
||||||
printError(cmd, "Failed to upload dataset due to error: %v ❌ ", err)
|
printError(cmd, "Failed to upload dataset due to error: %v ❌ ", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -39,7 +39,7 @@ func TestDatasetsCmd(t *testing.T) {
|
|||||||
{
|
{
|
||||||
name: "successful upload",
|
name: "successful upload",
|
||||||
setupMock: func(m *mocks.SDK) {
|
setupMock: func(m *mocks.SDK) {
|
||||||
m.On("Data", mock.Anything, mock.Anything, mock.Anything).Return(nil)
|
m.On("Data", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil)
|
||||||
},
|
},
|
||||||
setupFiles: func() (string, error) {
|
setupFiles: func() (string, error) {
|
||||||
datasetFile, err := createTempDatasetFile("test dataset content")
|
datasetFile, err := createTempDatasetFile("test dataset content")
|
||||||
@@ -58,7 +58,7 @@ func TestDatasetsCmd(t *testing.T) {
|
|||||||
{
|
{
|
||||||
name: "missing dataset file",
|
name: "missing dataset file",
|
||||||
setupMock: func(m *mocks.SDK) {
|
setupMock: func(m *mocks.SDK) {
|
||||||
m.On("Data", mock.Anything, mock.Anything, mock.Anything).Return(nil)
|
m.On("Data", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil)
|
||||||
},
|
},
|
||||||
setupFiles: func() (string, error) {
|
setupFiles: func() (string, error) {
|
||||||
return "", nil
|
return "", nil
|
||||||
@@ -68,7 +68,7 @@ func TestDatasetsCmd(t *testing.T) {
|
|||||||
{
|
{
|
||||||
name: "missing private key file",
|
name: "missing private key file",
|
||||||
setupMock: func(m *mocks.SDK) {
|
setupMock: func(m *mocks.SDK) {
|
||||||
m.On("Data", mock.Anything, mock.Anything, mock.Anything).Return(nil)
|
m.On("Data", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil)
|
||||||
},
|
},
|
||||||
setupFiles: func() (string, error) {
|
setupFiles: func() (string, error) {
|
||||||
return createTempDatasetFile("test dataset content")
|
return createTempDatasetFile("test dataset content")
|
||||||
@@ -81,7 +81,7 @@ func TestDatasetsCmd(t *testing.T) {
|
|||||||
{
|
{
|
||||||
name: "upload failure",
|
name: "upload failure",
|
||||||
setupMock: func(m *mocks.SDK) {
|
setupMock: func(m *mocks.SDK) {
|
||||||
m.On("Data", mock.Anything, mock.Anything, mock.Anything).Return(errors.New("failed to upload algorithm due to error"))
|
m.On("Data", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(errors.New("failed to upload algorithm due to error"))
|
||||||
},
|
},
|
||||||
setupFiles: func() (string, error) {
|
setupFiles: func() (string, error) {
|
||||||
datasetFile, err := createTempDatasetFile("test dataset content")
|
datasetFile, err := createTempDatasetFile("test dataset content")
|
||||||
@@ -100,7 +100,7 @@ func TestDatasetsCmd(t *testing.T) {
|
|||||||
{
|
{
|
||||||
name: "invalid private key",
|
name: "invalid private key",
|
||||||
setupMock: func(m *mocks.SDK) {
|
setupMock: func(m *mocks.SDK) {
|
||||||
m.On("Data", mock.Anything, mock.Anything, mock.Anything).Return(nil)
|
m.On("Data", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil)
|
||||||
},
|
},
|
||||||
setupFiles: func() (string, error) {
|
setupFiles: func() (string, error) {
|
||||||
datasetFile, err := createTempDatasetFile("test dataset content")
|
datasetFile, err := createTempDatasetFile("test dataset content")
|
||||||
|
|||||||
+16
-34
@@ -4,7 +4,6 @@ package cli
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/pem"
|
"encoding/pem"
|
||||||
"fmt"
|
|
||||||
"os"
|
"os"
|
||||||
|
|
||||||
"github.com/fatih/color"
|
"github.com/fatih/color"
|
||||||
@@ -14,13 +13,14 @@ import (
|
|||||||
const (
|
const (
|
||||||
resultFilePrefix = "results"
|
resultFilePrefix = "results"
|
||||||
resultFileExt = ".zip"
|
resultFileExt = ".zip"
|
||||||
|
resultfilename = "results.zip"
|
||||||
)
|
)
|
||||||
|
|
||||||
func (cli *CLI) NewResultsCmd() *cobra.Command {
|
func (cli *CLI) NewResultsCmd() *cobra.Command {
|
||||||
return &cobra.Command{
|
return &cobra.Command{
|
||||||
Use: "result",
|
Use: "result",
|
||||||
Short: "Retrieve computation result file",
|
Short: "Retrieve computation result file",
|
||||||
Example: "result <private_key_file_path>",
|
Example: "result <private_key_file_path> <optional_file_name.zip>",
|
||||||
Args: cobra.ExactArgs(1),
|
Args: cobra.ExactArgs(1),
|
||||||
Run: func(cmd *cobra.Command, args []string) {
|
Run: func(cmd *cobra.Command, args []string) {
|
||||||
if cli.connectErr != nil {
|
if cli.connectErr != nil {
|
||||||
@@ -36,50 +36,32 @@ func (cli *CLI) NewResultsCmd() *cobra.Command {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
pemBlock, _ := pem.Decode(privKeyFile)
|
filename := resultfilename
|
||||||
|
if len(args) > 1 {
|
||||||
|
filename = args[1]
|
||||||
|
}
|
||||||
|
|
||||||
var result []byte
|
pemBlock, _ := pem.Decode(privKeyFile)
|
||||||
|
|
||||||
privKey, err := decodeKey(pemBlock)
|
privKey, err := decodeKey(pemBlock)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
printError(cmd, "Error decoding private key: %v ❌ ", err)
|
printError(cmd, "Error decoding private key: %v ❌ ", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
result, err = cli.agentSDK.Result(cmd.Context(), privKey)
|
|
||||||
|
resultFile, err := os.Create(filename)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
printError(cmd, "Error creating result file: %v ❌ ", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer resultFile.Close()
|
||||||
|
|
||||||
|
if err = cli.agentSDK.Result(cmd.Context(), privKey, resultFile); err != nil {
|
||||||
printError(cmd, "Error retrieving computation result: %v ❌ ", err)
|
printError(cmd, "Error retrieving computation result: %v ❌ ", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
resultFilePath, err := getUniqueFilePath(resultFilePrefix, resultFileExt)
|
cmd.Println(color.New(color.FgGreen).Sprintf("Computation result retrieved and saved successfully as %s! ✔ ", filename))
|
||||||
if err != nil {
|
|
||||||
printError(cmd, "Error generating unique file path: %v ❌ ", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := os.WriteFile(resultFilePath, result, 0o644); err != nil {
|
|
||||||
printError(cmd, "Error saving computation result file: %v ❌ ", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
cmd.Println(color.New(color.FgGreen).Sprintf("Computation result retrieved and saved successfully as %s! ✔ ", resultFilePath))
|
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func getUniqueFilePath(prefix, ext string) (string, error) {
|
|
||||||
for i := 0; ; i++ {
|
|
||||||
var filename string
|
|
||||||
if i == 0 {
|
|
||||||
filename = prefix + ext
|
|
||||||
} else {
|
|
||||||
filename = fmt.Sprintf("%s_%d%s", prefix, i, ext)
|
|
||||||
}
|
|
||||||
|
|
||||||
if _, err := os.Stat(filename); os.IsNotExist(err) {
|
|
||||||
return filename, nil
|
|
||||||
} else if err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|||||||
+18
-51
@@ -6,7 +6,6 @@ package cli
|
|||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"testing"
|
"testing"
|
||||||
@@ -20,7 +19,10 @@ const compResult = "Test computation result"
|
|||||||
|
|
||||||
func TestResultsCmd_MultipleExecutions(t *testing.T) {
|
func TestResultsCmd_MultipleExecutions(t *testing.T) {
|
||||||
mockSDK := new(mocks.SDK)
|
mockSDK := new(mocks.SDK)
|
||||||
mockSDK.On("Result", mock.Anything, mock.Anything).Return([]byte(compResult), nil)
|
mockSDK.On("Result", mock.Anything, mock.Anything, mock.Anything).Return(nil).Run(func(args mock.Arguments) {
|
||||||
|
_, err := args.Get(2).(*os.File).WriteString(compResult)
|
||||||
|
require.NoError(t, err)
|
||||||
|
})
|
||||||
testCLI := CLI{agentSDK: mockSDK}
|
testCLI := CLI{agentSDK: mockSDK}
|
||||||
|
|
||||||
err := generateRSAPrivateKeyFile(privateKeyFile)
|
err := generateRSAPrivateKeyFile(privateKeyFile)
|
||||||
@@ -40,7 +42,6 @@ func TestResultsCmd_MultipleExecutions(t *testing.T) {
|
|||||||
|
|
||||||
files, err := filepath.Glob("results*.zip")
|
files, err := filepath.Glob("results*.zip")
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.Len(t, files, 3)
|
|
||||||
|
|
||||||
t.Cleanup(func() {
|
t.Cleanup(func() {
|
||||||
for _, file := range files {
|
for _, file := range files {
|
||||||
@@ -52,7 +53,10 @@ func TestResultsCmd_MultipleExecutions(t *testing.T) {
|
|||||||
|
|
||||||
func TestResultsCmd_InvalidPrivateKey(t *testing.T) {
|
func TestResultsCmd_InvalidPrivateKey(t *testing.T) {
|
||||||
mockSDK := new(mocks.SDK)
|
mockSDK := new(mocks.SDK)
|
||||||
mockSDK.On("Result", mock.Anything, mock.Anything).Return([]byte(compResult), nil)
|
mockSDK.On("Result", mock.Anything, mock.Anything, mock.Anything).Return(nil).Run(func(args mock.Arguments) {
|
||||||
|
_, err := args.Get(2).(*os.File).WriteString(compResult)
|
||||||
|
require.NoError(t, err)
|
||||||
|
})
|
||||||
testCLI := CLI{agentSDK: mockSDK}
|
testCLI := CLI{agentSDK: mockSDK}
|
||||||
|
|
||||||
invalidPrivateKey, err := os.CreateTemp("", "invalid_private_key.pem")
|
invalidPrivateKey, err := os.CreateTemp("", "invalid_private_key.pem")
|
||||||
@@ -73,30 +77,7 @@ func TestResultsCmd_InvalidPrivateKey(t *testing.T) {
|
|||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
require.Contains(t, buf.String(), "Error decoding private key")
|
require.Contains(t, buf.String(), "Error decoding private key")
|
||||||
mockSDK.AssertNotCalled(t, "Result", mock.Anything, mock.Anything)
|
mockSDK.AssertNotCalled(t, "Result", mock.Anything, mock.Anything, mock.Anything)
|
||||||
}
|
|
||||||
|
|
||||||
func TestGetUniqueFilePath(t *testing.T) {
|
|
||||||
prefix := "test"
|
|
||||||
ext := ".txt"
|
|
||||||
|
|
||||||
path, err := getUniqueFilePath(prefix, ext)
|
|
||||||
require.NoError(t, err)
|
|
||||||
require.Equal(t, "test.txt", path)
|
|
||||||
|
|
||||||
_, err = os.Create("test.txt")
|
|
||||||
require.NoError(t, err)
|
|
||||||
defer os.Remove("test.txt")
|
|
||||||
for i := 1; i < 3; i++ {
|
|
||||||
fileName := fmt.Sprintf("%s_%d%s", prefix, i, ext)
|
|
||||||
_, err := os.Create(fileName)
|
|
||||||
require.NoError(t, err)
|
|
||||||
defer os.Remove(fileName)
|
|
||||||
}
|
|
||||||
|
|
||||||
path, err = getUniqueFilePath(prefix, ext)
|
|
||||||
require.NoError(t, err)
|
|
||||||
require.Equal(t, "test_3.txt", path)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestResultsCmd(t *testing.T) {
|
func TestResultsCmd(t *testing.T) {
|
||||||
@@ -111,7 +92,10 @@ func TestResultsCmd(t *testing.T) {
|
|||||||
{
|
{
|
||||||
name: "successful result retrieval",
|
name: "successful result retrieval",
|
||||||
setupMock: func(m *mocks.SDK) {
|
setupMock: func(m *mocks.SDK) {
|
||||||
m.On("Result", mock.Anything, mock.Anything).Return([]byte(compResult), nil)
|
m.On("Result", mock.Anything, mock.Anything, mock.Anything).Return(nil).Run(func(args mock.Arguments) {
|
||||||
|
_, err := args.Get(2).(*os.File).WriteString(compResult)
|
||||||
|
require.NoError(t, err)
|
||||||
|
})
|
||||||
},
|
},
|
||||||
setupFiles: func() (string, error) {
|
setupFiles: func() (string, error) {
|
||||||
return privateKeyFile, generateRSAPrivateKeyFile(privateKeyFile)
|
return privateKeyFile, generateRSAPrivateKeyFile(privateKeyFile)
|
||||||
@@ -128,7 +112,10 @@ func TestResultsCmd(t *testing.T) {
|
|||||||
{
|
{
|
||||||
name: "missing private key file",
|
name: "missing private key file",
|
||||||
setupMock: func(m *mocks.SDK) {
|
setupMock: func(m *mocks.SDK) {
|
||||||
m.On("Result", mock.Anything, mock.Anything).Return([]byte(compResult), nil)
|
m.On("Result", mock.Anything, mock.Anything, mock.Anything).Return(nil).Run(func(args mock.Arguments) {
|
||||||
|
_, err := args.Get(2).(*os.File).WriteString(compResult)
|
||||||
|
require.NoError(t, err)
|
||||||
|
})
|
||||||
},
|
},
|
||||||
setupFiles: func() (string, error) {
|
setupFiles: func() (string, error) {
|
||||||
return "non_existent_private_key.pem", nil
|
return "non_existent_private_key.pem", nil
|
||||||
@@ -138,7 +125,7 @@ func TestResultsCmd(t *testing.T) {
|
|||||||
{
|
{
|
||||||
name: "result retrieval failure",
|
name: "result retrieval failure",
|
||||||
setupMock: func(m *mocks.SDK) {
|
setupMock: func(m *mocks.SDK) {
|
||||||
m.On("Result", mock.Anything, mock.Anything).Return(nil, errors.New("error retrieving computation result"))
|
m.On("Result", mock.Anything, mock.Anything, mock.Anything).Return(errors.New("error retrieving computation result"))
|
||||||
},
|
},
|
||||||
setupFiles: func() (string, error) {
|
setupFiles: func() (string, error) {
|
||||||
return privateKeyFile, generateRSAPrivateKeyFile(privateKeyFile)
|
return privateKeyFile, generateRSAPrivateKeyFile(privateKeyFile)
|
||||||
@@ -148,26 +135,6 @@ func TestResultsCmd(t *testing.T) {
|
|||||||
os.Remove(privateKeyFile)
|
os.Remove(privateKeyFile)
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
|
||||||
name: "save failure",
|
|
||||||
setupMock: func(m *mocks.SDK) {
|
|
||||||
m.On("Result", mock.Anything, mock.Anything).Return([]byte(compResult), nil)
|
|
||||||
},
|
|
||||||
setupFiles: func() (string, error) {
|
|
||||||
err := generateRSAPrivateKeyFile(privateKeyFile)
|
|
||||||
if err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
// Simulate failure in saving the result file by making all files read-only
|
|
||||||
return privateKeyFile, os.Chmod(".", 0o555)
|
|
||||||
},
|
|
||||||
expectedOutput: "Error saving computation result file",
|
|
||||||
cleanup: func() {
|
|
||||||
err := os.Chmod(".", 0o755)
|
|
||||||
require.NoError(t, err)
|
|
||||||
os.Remove(privateKeyFile)
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
{
|
||||||
name: "connection error",
|
name: "connection error",
|
||||||
setupMock: func(m *mocks.SDK) {
|
setupMock: func(m *mocks.SDK) {
|
||||||
|
|||||||
+1
-1
@@ -4,7 +4,7 @@
|
|||||||
coverage:
|
coverage:
|
||||||
ignore:
|
ignore:
|
||||||
- "test/*"
|
- "test/*"
|
||||||
- "cmd/*"
|
- "cmd/**"
|
||||||
- "**/mocks/**"
|
- "**/mocks/**"
|
||||||
- "mocks/**"
|
- "mocks/**"
|
||||||
- "**/*.pb.go"
|
- "**/*.pb.go"
|
||||||
|
|||||||
@@ -60,6 +60,60 @@ func ZipDirectoryToMemory(sourceDir string) ([]byte, error) {
|
|||||||
return buf.Bytes(), nil
|
return buf.Bytes(), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func ZipDirectoryToTempFile(sourceDir string) (*os.File, error) {
|
||||||
|
tmpFile, err := os.CreateTemp("", "dataset*.zip")
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
zipWriter := zip.NewWriter(tmpFile)
|
||||||
|
|
||||||
|
err = filepath.Walk(sourceDir, func(path string, info os.FileInfo, err error) error {
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if info.IsDir() {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
relPath, err := filepath.Rel(sourceDir, path)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
zipHeader, err := zip.FileInfoHeader(info)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
zipHeader.Name = relPath
|
||||||
|
|
||||||
|
zipWriterEntry, err := zipWriter.CreateHeader(zipHeader)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
fileToZip, err := os.Open(path)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer fileToZip.Close()
|
||||||
|
|
||||||
|
_, err = io.Copy(zipWriterEntry, fileToZip)
|
||||||
|
return err
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
zipWriter.Close()
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := zipWriter.Close(); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return tmpFile, nil
|
||||||
|
}
|
||||||
|
|
||||||
func UnzipFromMemory(zipData []byte, targetDir string) error {
|
func UnzipFromMemory(zipData []byte, targetDir string) error {
|
||||||
reader := bytes.NewReader(zipData)
|
reader := bytes.NewReader(zipData)
|
||||||
zipReader, err := zip.NewReader(reader, int64(len(zipData)))
|
zipReader, err := zip.NewReader(reader, int64(len(zipData)))
|
||||||
|
|||||||
@@ -3,6 +3,8 @@
|
|||||||
package internal
|
package internal
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"archive/zip"
|
||||||
|
"io"
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"testing"
|
"testing"
|
||||||
@@ -104,3 +106,147 @@ func TestZipDirectoryToMemory_NonExistentDirectory(t *testing.T) {
|
|||||||
t.Error("ZipDirectoryToMemory should fail with non-existent directory")
|
t.Error("ZipDirectoryToMemory should fail with non-existent directory")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestZipDirectoryToTempFile(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
setupFiles map[string]string // map of relative path to content
|
||||||
|
expectError bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "single file",
|
||||||
|
setupFiles: map[string]string{
|
||||||
|
"test.txt": "hello world",
|
||||||
|
},
|
||||||
|
expectError: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "multiple files in root",
|
||||||
|
setupFiles: map[string]string{
|
||||||
|
"test1.txt": "content1",
|
||||||
|
"test2.txt": "content2",
|
||||||
|
"test3.txt": "content3",
|
||||||
|
},
|
||||||
|
expectError: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "nested directory structure",
|
||||||
|
setupFiles: map[string]string{
|
||||||
|
"file1.txt": "root file",
|
||||||
|
"dir1/file2.txt": "nested file",
|
||||||
|
"dir1/dir2/file3.txt": "deeply nested file",
|
||||||
|
"dir1/dir2/dir3/file4.txt": "very deeply nested file",
|
||||||
|
},
|
||||||
|
expectError: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "empty directory",
|
||||||
|
setupFiles: map[string]string{},
|
||||||
|
expectError: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
sourceDir, err := os.MkdirTemp("", "source")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to create temp source directory: %v", err)
|
||||||
|
}
|
||||||
|
defer os.RemoveAll(sourceDir)
|
||||||
|
|
||||||
|
for relPath, content := range tt.setupFiles {
|
||||||
|
fullPath := filepath.Join(sourceDir, relPath)
|
||||||
|
dir := filepath.Dir(fullPath)
|
||||||
|
|
||||||
|
if err := os.MkdirAll(dir, 0o755); err != nil {
|
||||||
|
t.Fatalf("Failed to create directory %s: %v", dir, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := os.WriteFile(fullPath, []byte(content), 0o644); err != nil {
|
||||||
|
t.Fatalf("Failed to write file %s: %v", fullPath, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
zipFile, err := ZipDirectoryToTempFile(sourceDir)
|
||||||
|
if err != nil {
|
||||||
|
if !tt.expectError {
|
||||||
|
t.Fatalf("Unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer os.Remove(zipFile.Name())
|
||||||
|
defer zipFile.Close()
|
||||||
|
|
||||||
|
if tt.expectError {
|
||||||
|
t.Fatal("Expected error but got none")
|
||||||
|
}
|
||||||
|
|
||||||
|
zipReader, err := zip.OpenReader(zipFile.Name())
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to open zip file: %v", err)
|
||||||
|
}
|
||||||
|
defer zipReader.Close()
|
||||||
|
|
||||||
|
expectedFiles := make(map[string]string)
|
||||||
|
for path, content := range tt.setupFiles {
|
||||||
|
expectedFiles[filepath.ToSlash(path)] = content
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, file := range zipReader.File {
|
||||||
|
expectedContent, exists := expectedFiles[file.Name]
|
||||||
|
if !exists {
|
||||||
|
t.Errorf("Unexpected file in zip: %s", file.Name)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
rc, err := file.Open()
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Failed to open file in zip %s: %v", file.Name, err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
content, err := io.ReadAll(rc)
|
||||||
|
rc.Close()
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Failed to read file in zip %s: %v", file.Name, err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if string(content) != expectedContent {
|
||||||
|
t.Errorf("File %s content mismatch: got %s, want %s", file.Name, content, expectedContent)
|
||||||
|
}
|
||||||
|
|
||||||
|
delete(expectedFiles, file.Name)
|
||||||
|
}
|
||||||
|
|
||||||
|
for path := range expectedFiles {
|
||||||
|
t.Errorf("Missing file in zip: %s", path)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestZipDirectoryToTempFile_InvalidInput(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
sourceDir string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "non-existent directory",
|
||||||
|
sourceDir: "/path/that/does/not/exist",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "empty path",
|
||||||
|
sourceDir: "",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
_, err := ZipDirectoryToTempFile(tt.sourceDir)
|
||||||
|
if err == nil {
|
||||||
|
t.Error("Expected error but got none")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -53,16 +53,37 @@ func TestSendAlgorithm(t *testing.T) {
|
|||||||
for _, tc := range testCases {
|
for _, tc := range testCases {
|
||||||
t.Run(tc.name, func(t *testing.T) {
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
pb := New(false)
|
pb := New(false)
|
||||||
algobuffer := bytes.NewBufferString("algorithm content")
|
|
||||||
reqBuffer := bytes.NewBufferString("requirements content")
|
algo, err := os.CreateTemp("", "test_algo")
|
||||||
|
assert.NoError(t, err)
|
||||||
|
req, err := os.CreateTemp("", "test_req")
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
_, err = algo.WriteString("test algorithm")
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
err = algo.Close()
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
algo, err = os.Open(algo.Name())
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
_, err = req.WriteString("test request")
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
err = req.Close()
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
req, err = os.Open(req.Name())
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
algoStream := new(mocks.AgentService_AlgoClient)
|
algoStream := new(mocks.AgentService_AlgoClient)
|
||||||
algoStream.On("Send", mock.Anything).Return(tc.sendError)
|
algoStream.On("Send", mock.Anything).Return(tc.sendError)
|
||||||
algoStream.On("CloseAndRecv").Return(&agent.AlgoResponse{}, tc.closeRecvError)
|
algoStream.On("CloseAndRecv").Return(&agent.AlgoResponse{}, tc.closeRecvError)
|
||||||
mockStream := &mockAlgoStream{stream: algoStream}
|
mockStream := &mockAlgoStream{stream: algoStream}
|
||||||
|
|
||||||
err := pb.SendAlgorithm("Test Algorithm", algobuffer, reqBuffer, &mockStream.stream)
|
err = pb.SendAlgorithm("Test Algorithm", algo, req, mockStream.stream)
|
||||||
assert.True(t, errors.Contains(err, tc.err))
|
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("expected error: %v, got: %v", tc.err, err))
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -108,14 +129,24 @@ func TestSendData(t *testing.T) {
|
|||||||
for _, tc := range testCases {
|
for _, tc := range testCases {
|
||||||
t.Run(tc.name, func(t *testing.T) {
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
pb := New(false)
|
pb := New(false)
|
||||||
buffer := bytes.NewBufferString(tc.dataContent)
|
dataset, err := os.CreateTemp("", "test_dataset")
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
_, err = dataset.WriteString(tc.dataContent)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
err = dataset.Close()
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
dataset, err = os.Open(dataset.Name())
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
dataStream := new(mocks.AgentService_DataClient)
|
dataStream := new(mocks.AgentService_DataClient)
|
||||||
dataStream.On("Send", mock.Anything).Return(tc.sendError)
|
dataStream.On("Send", mock.Anything).Return(tc.sendError)
|
||||||
dataStream.On("CloseAndRecv").Return(&agent.DataResponse{}, tc.closeRecvError)
|
dataStream.On("CloseAndRecv").Return(&agent.DataResponse{}, tc.closeRecvError)
|
||||||
mockStream := &mockDataStream{stream: dataStream}
|
mockStream := &mockDataStream{stream: dataStream}
|
||||||
|
|
||||||
err := pb.SendData("Test Data", "test.txt", buffer, &mockStream.stream)
|
err = pb.SendData("Test Data", "test.txt", dataset, mockStream.stream)
|
||||||
assert.True(t, errors.Contains(err, tc.err))
|
assert.True(t, errors.Contains(err, tc.err))
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@@ -361,7 +392,7 @@ func TestReceiveResult(t *testing.T) {
|
|||||||
setupMock: func(m *MockResultStream) {
|
setupMock: func(m *MockResultStream) {
|
||||||
m.On("Recv").Return(nil, io.EOF).Once()
|
m.On("Recv").Return(nil, io.EOF).Once()
|
||||||
},
|
},
|
||||||
wantResult: nil,
|
wantResult: []byte{},
|
||||||
wantErr: nil,
|
wantErr: nil,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
@@ -375,13 +406,26 @@ func TestReceiveResult(t *testing.T) {
|
|||||||
// Disable terminal width check for tests
|
// Disable terminal width check for tests
|
||||||
p.TerminalWidthFunc = func() (int, error) { return 100, nil }
|
p.TerminalWidthFunc = func() (int, error) { return 100, nil }
|
||||||
|
|
||||||
result, err := p.ReceiveResult(tt.description, tt.totalSize, mockStream)
|
resultFile, err := os.CreateTemp("", "test_result")
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
t.Cleanup(func() {
|
||||||
|
os.Remove(resultFile.Name())
|
||||||
|
})
|
||||||
|
|
||||||
|
err = p.ReceiveResult(tt.description, tt.totalSize, mockStream, resultFile)
|
||||||
|
|
||||||
|
assert.NoError(t, resultFile.Close())
|
||||||
|
|
||||||
if tt.wantErr != nil {
|
if tt.wantErr != nil {
|
||||||
assert.Error(t, err)
|
assert.Error(t, err)
|
||||||
assert.Equal(t, tt.wantErr.Error(), err.Error())
|
assert.Equal(t, tt.wantErr.Error(), err.Error())
|
||||||
} else {
|
} else {
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
result, err := os.ReadFile(resultFile.Name())
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
assert.Equal(t, tt.wantResult, result)
|
assert.Equal(t, tt.wantResult, result)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -457,13 +501,26 @@ func TestReceiveAttestation(t *testing.T) {
|
|||||||
// Disable terminal width check for tests
|
// Disable terminal width check for tests
|
||||||
p.TerminalWidthFunc = func() (int, error) { return 100, nil }
|
p.TerminalWidthFunc = func() (int, error) { return 100, nil }
|
||||||
|
|
||||||
result, err := p.ReceiveAttestation(tt.description, tt.totalSize, mockStream)
|
resultFile, err := os.CreateTemp("", "test_attestation")
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
t.Cleanup(func() {
|
||||||
|
os.Remove(resultFile.Name())
|
||||||
|
})
|
||||||
|
|
||||||
|
err = p.ReceiveAttestation(tt.description, tt.totalSize, mockStream, resultFile)
|
||||||
|
|
||||||
|
assert.NoError(t, resultFile.Close())
|
||||||
|
|
||||||
if tt.wantErr != nil {
|
if tt.wantErr != nil {
|
||||||
assert.Error(t, err)
|
assert.Error(t, err)
|
||||||
assert.Equal(t, tt.wantErr.Error(), err.Error())
|
assert.Equal(t, tt.wantErr.Error(), err.Error())
|
||||||
} else {
|
} else {
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
result, err := os.ReadFile(resultFile.Name())
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
assert.Equal(t, tt.wantResult, result)
|
assert.Equal(t, tt.wantResult, result)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -3,7 +3,6 @@
|
|||||||
package progressbar
|
package progressbar
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"os"
|
"os"
|
||||||
@@ -32,7 +31,7 @@ type streamSender interface {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type algoClientWrapper struct {
|
type algoClientWrapper struct {
|
||||||
client *agent.AgentService_AlgoClient
|
client agent.AgentService_AlgoClient
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *algoClientWrapper) Send(req interface{}) error {
|
func (a *algoClientWrapper) Send(req interface{}) error {
|
||||||
@@ -41,15 +40,15 @@ func (a *algoClientWrapper) Send(req interface{}) error {
|
|||||||
return fmt.Errorf("expected *AlgoRequest, got %T", req)
|
return fmt.Errorf("expected *AlgoRequest, got %T", req)
|
||||||
}
|
}
|
||||||
|
|
||||||
return (*a.client).Send(algoReq)
|
return a.client.Send(algoReq)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *algoClientWrapper) CloseAndRecv() (interface{}, error) {
|
func (a *algoClientWrapper) CloseAndRecv() (interface{}, error) {
|
||||||
return (*a.client).CloseAndRecv()
|
return a.client.CloseAndRecv()
|
||||||
}
|
}
|
||||||
|
|
||||||
type dataClientWrapper struct {
|
type dataClientWrapper struct {
|
||||||
client *agent.AgentService_DataClient
|
client agent.AgentService_DataClient
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *dataClientWrapper) Send(req interface{}) error {
|
func (a *dataClientWrapper) Send(req interface{}) error {
|
||||||
@@ -58,11 +57,11 @@ func (a *dataClientWrapper) Send(req interface{}) error {
|
|||||||
return fmt.Errorf("expected *DataRequest, got %T", req)
|
return fmt.Errorf("expected *DataRequest, got %T", req)
|
||||||
}
|
}
|
||||||
|
|
||||||
return (*a.client).Send(dataReq)
|
return a.client.Send(dataReq)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *dataClientWrapper) CloseAndRecv() (interface{}, error) {
|
func (a *dataClientWrapper) CloseAndRecv() (interface{}, error) {
|
||||||
return (*a.client).CloseAndRecv()
|
return a.client.CloseAndRecv()
|
||||||
}
|
}
|
||||||
|
|
||||||
type ProgressBar struct {
|
type ProgressBar struct {
|
||||||
@@ -82,21 +81,37 @@ func New(isDownload bool) *ProgressBar {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *ProgressBar) SendAlgorithm(description string, algobuffer, reqBuffer *bytes.Buffer, stream *agent.AgentService_AlgoClient) error {
|
func (p *ProgressBar) SendAlgorithm(description string, algo, req *os.File, stream agent.AgentService_AlgoClient) error {
|
||||||
totalSize := algobuffer.Len() + reqBuffer.Len()
|
algoFileInfo, err := algo.Stat()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
reqSize := 0
|
||||||
|
if req != nil {
|
||||||
|
reqFileInfo, err := req.Stat()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
reqSize = int(reqFileInfo.Size())
|
||||||
|
}
|
||||||
|
|
||||||
|
totalSize := int(algoFileInfo.Size()) + reqSize
|
||||||
p.reset(description, totalSize)
|
p.reset(description, totalSize)
|
||||||
|
|
||||||
wrapper := &algoClientWrapper{client: stream}
|
wrapper := &algoClientWrapper{client: stream}
|
||||||
|
|
||||||
// Send reqBuffer first
|
// Send req first
|
||||||
if err := p.sendBuffer(reqBuffer, wrapper, func(data []byte) interface{} {
|
if req != nil {
|
||||||
return &agent.AlgoRequest{Requirements: data}
|
if err := p.sendBuffer(req, wrapper, func(data []byte) interface{} {
|
||||||
}); err != nil {
|
return &agent.AlgoRequest{Requirements: data}
|
||||||
return err
|
}); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Then send algobuffer
|
// Then send algo
|
||||||
if err := p.sendBuffer(algobuffer, wrapper, func(data []byte) interface{} {
|
if err := p.sendBuffer(algo, wrapper, func(data []byte) interface{} {
|
||||||
return &agent.AlgoRequest{Algorithm: data}
|
return &agent.AlgoRequest{Algorithm: data}
|
||||||
}); err != nil {
|
}); err != nil {
|
||||||
return err
|
return err
|
||||||
@@ -106,23 +121,32 @@ func (p *ProgressBar) SendAlgorithm(description string, algobuffer, reqBuffer *b
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err := wrapper.CloseAndRecv()
|
_, err = wrapper.CloseAndRecv()
|
||||||
return err
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *ProgressBar) SendData(description, filename string, buffer *bytes.Buffer, stream *agent.AgentService_DataClient) error {
|
func (p *ProgressBar) SendData(description, filename string, file *os.File, stream agent.AgentService_DataClient) error {
|
||||||
return p.sendData(description, buffer, &dataClientWrapper{client: stream}, func(data []byte) interface{} {
|
return p.sendData(description, file, &dataClientWrapper{client: stream}, func(data []byte) interface{} {
|
||||||
return &agent.DataRequest{Dataset: data, Filename: filename}
|
return &agent.DataRequest{Dataset: data, Filename: filename}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *ProgressBar) sendData(description string, buffer *bytes.Buffer, stream streamSender, createRequest func([]byte) interface{}) error {
|
func (p *ProgressBar) sendData(description string, file *os.File, stream streamSender, createRequest func([]byte) interface{}) error {
|
||||||
p.reset(description, buffer.Len())
|
dataInfo, err := file.Stat()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
p.reset(description, int(dataInfo.Size()))
|
||||||
|
|
||||||
buf := make([]byte, bufferSize)
|
buf := make([]byte, bufferSize)
|
||||||
|
|
||||||
for {
|
for {
|
||||||
n, err := buffer.Read(buf)
|
n, err := file.Read(buf)
|
||||||
if err == io.EOF {
|
if err == io.EOF {
|
||||||
if _, err := io.WriteString(os.Stdout, "\n"); err != nil {
|
if _, err := io.WriteString(os.Stdout, "\n"); err != nil {
|
||||||
return err
|
return err
|
||||||
@@ -147,15 +171,15 @@ func (p *ProgressBar) sendData(description string, buffer *bytes.Buffer, stream
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err := stream.CloseAndRecv()
|
_, err = stream.CloseAndRecv()
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *ProgressBar) sendBuffer(buffer *bytes.Buffer, stream streamSender, createRequest func([]byte) interface{}) error {
|
func (p *ProgressBar) sendBuffer(file *os.File, stream streamSender, createRequest func([]byte) interface{}) error {
|
||||||
buf := make([]byte, bufferSize)
|
buf := make([]byte, bufferSize)
|
||||||
|
|
||||||
for {
|
for {
|
||||||
n, err := buffer.Read(buf)
|
n, err := file.Read(buf)
|
||||||
if err == io.EOF {
|
if err == io.EOF {
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
@@ -303,7 +327,7 @@ func (p *ProgressBar) clearProgressBar() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *ProgressBar) ReceiveResult(description string, totalSize int, stream agent.AgentService_ResultClient) ([]byte, error) {
|
func (p *ProgressBar) ReceiveResult(description string, totalSize int, stream agent.AgentService_ResultClient, resultFile *os.File) error {
|
||||||
return p.receiveStream(description, totalSize, func() ([]byte, error) {
|
return p.receiveStream(description, totalSize, func() ([]byte, error) {
|
||||||
response, err := stream.Recv()
|
response, err := stream.Recv()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -311,10 +335,10 @@ func (p *ProgressBar) ReceiveResult(description string, totalSize int, stream ag
|
|||||||
}
|
}
|
||||||
|
|
||||||
return response.File, nil
|
return response.File, nil
|
||||||
})
|
}, resultFile)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *ProgressBar) ReceiveAttestation(description string, totalSize int, stream agent.AgentService_AttestationClient) ([]byte, error) {
|
func (p *ProgressBar) ReceiveAttestation(description string, totalSize int, stream agent.AgentService_AttestationClient, attestationFile *os.File) error {
|
||||||
return p.receiveStream(description, totalSize, func() ([]byte, error) {
|
return p.receiveStream(description, totalSize, func() ([]byte, error) {
|
||||||
response, err := stream.Recv()
|
response, err := stream.Recv()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -322,37 +346,37 @@ func (p *ProgressBar) ReceiveAttestation(description string, totalSize int, stre
|
|||||||
}
|
}
|
||||||
|
|
||||||
return response.File, nil
|
return response.File, nil
|
||||||
})
|
}, attestationFile)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *ProgressBar) receiveStream(description string, totalSize int, recv func() ([]byte, error)) ([]byte, error) {
|
func (p *ProgressBar) receiveStream(description string, totalSize int, recv func() ([]byte, error), file *os.File) error {
|
||||||
p.reset(description, totalSize)
|
p.reset(description, totalSize)
|
||||||
p.isDownload = true
|
p.isDownload = true
|
||||||
|
|
||||||
var result []byte
|
|
||||||
for {
|
for {
|
||||||
chunk, err := recv()
|
chunk, err := recv()
|
||||||
if err == io.EOF {
|
if err == io.EOF {
|
||||||
if _, err := io.WriteString(os.Stdout, "\n"); err != nil {
|
if _, err := io.WriteString(os.Stdout, "\n"); err != nil {
|
||||||
return nil, err
|
return err
|
||||||
}
|
}
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
chunkSize := len(chunk)
|
chunkSize := len(chunk)
|
||||||
if err = p.updateProgress(chunkSize); err != nil {
|
if err = p.updateProgress(chunkSize); err != nil {
|
||||||
return nil, err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
result = append(result, chunk...)
|
if _, err := file.Write(chunk); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
if err := p.renderProgressBar(); err != nil {
|
if err := p.renderProgressBar(); err != nil {
|
||||||
return nil, err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return result, nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|||||||
+20
-23
@@ -3,7 +3,6 @@
|
|||||||
package sdk
|
package sdk
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
|
||||||
"context"
|
"context"
|
||||||
"crypto"
|
"crypto"
|
||||||
"crypto/ecdsa"
|
"crypto/ecdsa"
|
||||||
@@ -12,6 +11,7 @@ import (
|
|||||||
"crypto/rsa"
|
"crypto/rsa"
|
||||||
"crypto/sha256"
|
"crypto/sha256"
|
||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
|
"os"
|
||||||
"strconv"
|
"strconv"
|
||||||
|
|
||||||
"github.com/absmach/magistrala/pkg/errors"
|
"github.com/absmach/magistrala/pkg/errors"
|
||||||
@@ -24,10 +24,10 @@ import (
|
|||||||
|
|
||||||
//go:generate mockery --name SDK --output=mocks --filename sdk.go --quiet --note "Copyright (c) Ultraviolet \n // SPDX-License-Identifier: Apache-2.0"
|
//go:generate mockery --name SDK --output=mocks --filename sdk.go --quiet --note "Copyright (c) Ultraviolet \n // SPDX-License-Identifier: Apache-2.0"
|
||||||
type SDK interface {
|
type SDK interface {
|
||||||
Algo(ctx context.Context, algorithm agent.Algorithm, privKey any) error
|
Algo(ctx context.Context, algorithm, requirements *os.File, privKey any) error
|
||||||
Data(ctx context.Context, dataset agent.Dataset, privKey any) error
|
Data(ctx context.Context, dataset *os.File, filename string, privKey any) error
|
||||||
Result(ctx context.Context, privKey any) ([]byte, error)
|
Result(ctx context.Context, privKey any, resultFile *os.File) error
|
||||||
Attestation(ctx context.Context, reportData [size64]byte) ([]byte, error)
|
Attestation(ctx context.Context, reportData [size64]byte, attestationFile *os.File) error
|
||||||
}
|
}
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@@ -48,7 +48,7 @@ func NewAgentSDK(agentClient agent.AgentServiceClient) SDK {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (sdk *agentSDK) Algo(ctx context.Context, algorithm agent.Algorithm, privKey any) error {
|
func (sdk *agentSDK) Algo(ctx context.Context, algorithm, requirements *os.File, privKey any) error {
|
||||||
md, err := generateMetadata(string(auth.AlgorithmProviderRole), privKey)
|
md, err := generateMetadata(string(auth.AlgorithmProviderRole), privKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
@@ -62,14 +62,12 @@ func (sdk *agentSDK) Algo(ctx context.Context, algorithm agent.Algorithm, privKe
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
algoBuffer := bytes.NewBuffer(algorithm.Algorithm)
|
|
||||||
reqBuffer := bytes.NewBuffer(algorithm.Requirements)
|
|
||||||
|
|
||||||
pb := progressbar.New(false)
|
pb := progressbar.New(false)
|
||||||
return pb.SendAlgorithm(algoProgressBarDescription, algoBuffer, reqBuffer, &stream)
|
return pb.SendAlgorithm(algoProgressBarDescription, algorithm, requirements, stream)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (sdk *agentSDK) Data(ctx context.Context, dataset agent.Dataset, privKey any) error {
|
func (sdk *agentSDK) Data(ctx context.Context, dataset *os.File, filename string, privKey any) error {
|
||||||
md, err := generateMetadata(string(auth.DataProviderRole), privKey)
|
md, err := generateMetadata(string(auth.DataProviderRole), privKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
@@ -83,29 +81,28 @@ func (sdk *agentSDK) Data(ctx context.Context, dataset agent.Dataset, privKey an
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
dataBuffer := bytes.NewBuffer(dataset.Dataset)
|
|
||||||
|
|
||||||
pb := progressbar.New(false)
|
pb := progressbar.New(false)
|
||||||
return pb.SendData(dataProgressBarDescription, dataset.Filename, dataBuffer, &stream)
|
return pb.SendData(dataProgressBarDescription, filename, dataset, stream)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (sdk *agentSDK) Result(ctx context.Context, privKey any) ([]byte, error) {
|
func (sdk *agentSDK) Result(ctx context.Context, privKey any, resultFile *os.File) error {
|
||||||
request := &agent.ResultRequest{}
|
request := &agent.ResultRequest{}
|
||||||
|
|
||||||
md, err := generateMetadata(string(auth.ConsumerRole), privKey)
|
md, err := generateMetadata(string(auth.ConsumerRole), privKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
ctx = metadata.NewOutgoingContext(ctx, md)
|
ctx = metadata.NewOutgoingContext(ctx, md)
|
||||||
stream, err := sdk.client.Result(ctx, request)
|
stream, err := sdk.client.Result(ctx, request)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
incomingmd, err := stream.Header()
|
incomingmd, err := stream.Header()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
fileSizeStr := incomingmd.Get(grpc.FileSizeKey)
|
fileSizeStr := incomingmd.Get(grpc.FileSizeKey)
|
||||||
@@ -116,27 +113,27 @@ func (sdk *agentSDK) Result(ctx context.Context, privKey any) ([]byte, error) {
|
|||||||
|
|
||||||
fileSize, err := strconv.Atoi(fileSizeStr[0])
|
fileSize, err := strconv.Atoi(fileSizeStr[0])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
pb := progressbar.New(true)
|
pb := progressbar.New(true)
|
||||||
|
|
||||||
return pb.ReceiveResult(resultProgressDescription, fileSize, stream)
|
return pb.ReceiveResult(resultProgressDescription, fileSize, stream, resultFile)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (sdk *agentSDK) Attestation(ctx context.Context, reportData [size64]byte) ([]byte, error) {
|
func (sdk *agentSDK) Attestation(ctx context.Context, reportData [size64]byte, attestationFile *os.File) error {
|
||||||
request := &agent.AttestationRequest{
|
request := &agent.AttestationRequest{
|
||||||
ReportData: reportData[:],
|
ReportData: reportData[:],
|
||||||
}
|
}
|
||||||
|
|
||||||
stream, err := sdk.client.Attestation(ctx, request)
|
stream, err := sdk.client.Attestation(ctx, request)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
incomingmd, err := stream.Header()
|
incomingmd, err := stream.Header()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
fileSizeStr := incomingmd.Get(grpc.FileSizeKey)
|
fileSizeStr := incomingmd.Get(grpc.FileSizeKey)
|
||||||
@@ -147,12 +144,12 @@ func (sdk *agentSDK) Attestation(ctx context.Context, reportData [size64]byte) (
|
|||||||
|
|
||||||
fileSize, err := strconv.Atoi(fileSizeStr[0])
|
fileSize, err := strconv.Atoi(fileSizeStr[0])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
pb := progressbar.New(true)
|
pb := progressbar.New(true)
|
||||||
|
|
||||||
return pb.ReceiveAttestation(attestationProgressDescription, fileSize, stream)
|
return pb.ReceiveAttestation(attestationProgressDescription, fileSize, stream, attestationFile)
|
||||||
}
|
}
|
||||||
|
|
||||||
func signData(userID string, privKey crypto.Signer) ([]byte, error) {
|
func signData(userID string, privKey crypto.Signer) ([]byte, error) {
|
||||||
|
|||||||
+61
-9
@@ -115,7 +115,21 @@ func TestAlgo(t *testing.T) {
|
|||||||
for _, tc := range cases {
|
for _, tc := range cases {
|
||||||
t.Run(tc.name, func(t *testing.T) {
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
svcCall := svc.On("Algo", mock.Anything, mock.Anything).Return(tc.err)
|
svcCall := svc.On("Algo", mock.Anything, mock.Anything).Return(tc.err)
|
||||||
err = sdk.Algo(context.Background(), tc.algo, tc.userKey)
|
|
||||||
|
algo, err := os.CreateTemp("", "algo")
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer os.Remove(algo.Name())
|
||||||
|
|
||||||
|
_, err = algo.Write(algorithm.Algorithm)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
err = algo.Close()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
algo, err = os.Open(algo.Name())
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
err = sdk.Algo(context.Background(), algo, nil, tc.userKey)
|
||||||
|
|
||||||
st, _ := status.FromError(err)
|
st, _ := status.FromError(err)
|
||||||
|
|
||||||
@@ -212,7 +226,19 @@ func TestData(t *testing.T) {
|
|||||||
t.Run(tc.name, func(t *testing.T) {
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
dataCall := svc.On("Data", mock.Anything, mock.Anything).Return(tc.svcErr)
|
dataCall := svc.On("Data", mock.Anything, mock.Anything).Return(tc.svcErr)
|
||||||
|
|
||||||
err = sdk.Data(context.Background(), tc.data, tc.userKey)
|
data, err := os.CreateTemp("", "data")
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
_, err = data.Write(dataset.Dataset)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
err = data.Close()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
data, err = os.Open(data.Name())
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
err = sdk.Data(context.Background(), data, tc.data.Filename, tc.userKey)
|
||||||
|
|
||||||
st, _ := status.FromError(err)
|
st, _ := status.FromError(err)
|
||||||
|
|
||||||
@@ -273,7 +299,7 @@ func TestResult(t *testing.T) {
|
|||||||
name: "Results not ready",
|
name: "Results not ready",
|
||||||
userKey: resultConsumer1Key,
|
userKey: resultConsumer1Key,
|
||||||
response: &agent.ResultResponse{
|
response: &agent.ResultResponse{
|
||||||
File: []byte(nil),
|
File: []byte{},
|
||||||
},
|
},
|
||||||
svcRes: nil,
|
svcRes: nil,
|
||||||
err: agent.ErrResultsNotReady,
|
err: agent.ErrResultsNotReady,
|
||||||
@@ -282,7 +308,7 @@ func TestResult(t *testing.T) {
|
|||||||
name: "All manifest items received",
|
name: "All manifest items received",
|
||||||
userKey: resultConsumer1Key,
|
userKey: resultConsumer1Key,
|
||||||
response: &agent.ResultResponse{
|
response: &agent.ResultResponse{
|
||||||
File: []byte(nil),
|
File: []byte{},
|
||||||
},
|
},
|
||||||
svcRes: nil,
|
svcRes: nil,
|
||||||
err: agent.ErrAllManifestItemsReceived,
|
err: agent.ErrAllManifestItemsReceived,
|
||||||
@@ -291,7 +317,7 @@ func TestResult(t *testing.T) {
|
|||||||
name: "Undeclared consumer",
|
name: "Undeclared consumer",
|
||||||
userKey: resultConsumer1Key,
|
userKey: resultConsumer1Key,
|
||||||
response: &agent.ResultResponse{
|
response: &agent.ResultResponse{
|
||||||
File: []byte(nil),
|
File: []byte{},
|
||||||
},
|
},
|
||||||
svcRes: nil,
|
svcRes: nil,
|
||||||
err: agent.ErrUndeclaredConsumer,
|
err: agent.ErrUndeclaredConsumer,
|
||||||
@@ -300,7 +326,17 @@ func TestResult(t *testing.T) {
|
|||||||
for _, tc := range cases {
|
for _, tc := range cases {
|
||||||
t.Run(tc.name, func(t *testing.T) {
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
svcCall := svc.On("Result", mock.Anything, mock.Anything).Return(tc.svcRes, tc.err)
|
svcCall := svc.On("Result", mock.Anything, mock.Anything).Return(tc.svcRes, tc.err)
|
||||||
res, err := sdk.Result(context.Background(), tc.userKey)
|
|
||||||
|
resultFile, err := os.CreateTemp("", "result")
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
t.Cleanup(func() {
|
||||||
|
os.Remove(resultFile.Name())
|
||||||
|
})
|
||||||
|
|
||||||
|
err = sdk.Result(context.Background(), tc.userKey, resultFile)
|
||||||
|
|
||||||
|
require.NoError(t, resultFile.Close())
|
||||||
|
|
||||||
st, ok := status.FromError(err)
|
st, ok := status.FromError(err)
|
||||||
if !ok {
|
if !ok {
|
||||||
@@ -312,6 +348,10 @@ func TestResult(t *testing.T) {
|
|||||||
t.Errorf("%s: Expected error message %q, but got %q", tc.name, tc.err.Error(), st.Message())
|
t.Errorf("%s: Expected error message %q, but got %q", tc.name, tc.err.Error(), st.Message())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
res, err := os.ReadFile(resultFile.Name())
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
assert.Equal(t, tc.response.File, res, tc.name)
|
assert.Equal(t, tc.response.File, res, tc.name)
|
||||||
|
|
||||||
svcCall.Unset()
|
svcCall.Unset()
|
||||||
@@ -375,7 +415,7 @@ func TestAttestation(t *testing.T) {
|
|||||||
userKey: resultConsumerKey,
|
userKey: resultConsumerKey,
|
||||||
reportData: [agent.ReportDataSize]byte(reportData),
|
reportData: [agent.ReportDataSize]byte(reportData),
|
||||||
response: &agent.AttestationResponse{
|
response: &agent.AttestationResponse{
|
||||||
File: nil,
|
File: []byte{},
|
||||||
},
|
},
|
||||||
err: nil,
|
err: nil,
|
||||||
},
|
},
|
||||||
@@ -384,7 +424,7 @@ func TestAttestation(t *testing.T) {
|
|||||||
userKey: resultConsumerKey,
|
userKey: resultConsumerKey,
|
||||||
reportData: [agent.ReportDataSize]byte{},
|
reportData: [agent.ReportDataSize]byte{},
|
||||||
response: &agent.AttestationResponse{
|
response: &agent.AttestationResponse{
|
||||||
File: nil,
|
File: []byte{},
|
||||||
},
|
},
|
||||||
svcRes: nil,
|
svcRes: nil,
|
||||||
err: errors.New("invalid report data"),
|
err: errors.New("invalid report data"),
|
||||||
@@ -395,7 +435,16 @@ func TestAttestation(t *testing.T) {
|
|||||||
t.Run(tc.name, func(t *testing.T) {
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
svcCall := svc.On("Attestation", mock.Anything, mock.Anything).Return(tc.svcRes, tc.err)
|
svcCall := svc.On("Attestation", mock.Anything, mock.Anything).Return(tc.svcRes, tc.err)
|
||||||
|
|
||||||
res, err := sdk.Attestation(context.Background(), tc.reportData)
|
file, err := os.CreateTemp("", "attestation")
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
t.Cleanup(func() {
|
||||||
|
os.Remove(file.Name())
|
||||||
|
})
|
||||||
|
|
||||||
|
err = sdk.Attestation(context.Background(), tc.reportData, file)
|
||||||
|
|
||||||
|
require.NoError(t, file.Close())
|
||||||
|
|
||||||
st, ok := status.FromError(err)
|
st, ok := status.FromError(err)
|
||||||
if !ok {
|
if !ok {
|
||||||
@@ -408,6 +457,9 @@ func TestAttestation(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
res, err := os.ReadFile(file.Name())
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
assert.Equal(t, tc.response.File, res, tc.name)
|
assert.Equal(t, tc.response.File, res, tc.name)
|
||||||
|
|
||||||
svcCall.Unset()
|
svcCall.Unset()
|
||||||
|
|||||||
+27
-52
@@ -7,8 +7,7 @@ package mocks
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
context "context"
|
context "context"
|
||||||
|
os "os"
|
||||||
agent "github.com/ultravioletrs/cocos/agent"
|
|
||||||
|
|
||||||
mock "github.com/stretchr/testify/mock"
|
mock "github.com/stretchr/testify/mock"
|
||||||
)
|
)
|
||||||
@@ -18,17 +17,17 @@ type SDK struct {
|
|||||||
mock.Mock
|
mock.Mock
|
||||||
}
|
}
|
||||||
|
|
||||||
// Algo provides a mock function with given fields: ctx, algorithm, privKey
|
// Algo provides a mock function with given fields: ctx, algorithm, requirements, privKey
|
||||||
func (_m *SDK) Algo(ctx context.Context, algorithm agent.Algorithm, privKey interface{}) error {
|
func (_m *SDK) Algo(ctx context.Context, algorithm *os.File, requirements *os.File, privKey interface{}) error {
|
||||||
ret := _m.Called(ctx, algorithm, privKey)
|
ret := _m.Called(ctx, algorithm, requirements, privKey)
|
||||||
|
|
||||||
if len(ret) == 0 {
|
if len(ret) == 0 {
|
||||||
panic("no return value specified for Algo")
|
panic("no return value specified for Algo")
|
||||||
}
|
}
|
||||||
|
|
||||||
var r0 error
|
var r0 error
|
||||||
if rf, ok := ret.Get(0).(func(context.Context, agent.Algorithm, interface{}) error); ok {
|
if rf, ok := ret.Get(0).(func(context.Context, *os.File, *os.File, interface{}) error); ok {
|
||||||
r0 = rf(ctx, algorithm, privKey)
|
r0 = rf(ctx, algorithm, requirements, privKey)
|
||||||
} else {
|
} else {
|
||||||
r0 = ret.Error(0)
|
r0 = ret.Error(0)
|
||||||
}
|
}
|
||||||
@@ -36,47 +35,35 @@ func (_m *SDK) Algo(ctx context.Context, algorithm agent.Algorithm, privKey inte
|
|||||||
return r0
|
return r0
|
||||||
}
|
}
|
||||||
|
|
||||||
// Attestation provides a mock function with given fields: ctx, reportData
|
// Attestation provides a mock function with given fields: ctx, reportData, attestationFile
|
||||||
func (_m *SDK) Attestation(ctx context.Context, reportData [64]byte) ([]byte, error) {
|
func (_m *SDK) Attestation(ctx context.Context, reportData [64]byte, attestationFile *os.File) error {
|
||||||
ret := _m.Called(ctx, reportData)
|
ret := _m.Called(ctx, reportData, attestationFile)
|
||||||
|
|
||||||
if len(ret) == 0 {
|
if len(ret) == 0 {
|
||||||
panic("no return value specified for Attestation")
|
panic("no return value specified for Attestation")
|
||||||
}
|
}
|
||||||
|
|
||||||
var r0 []byte
|
var r0 error
|
||||||
var r1 error
|
if rf, ok := ret.Get(0).(func(context.Context, [64]byte, *os.File) error); ok {
|
||||||
if rf, ok := ret.Get(0).(func(context.Context, [64]byte) ([]byte, error)); ok {
|
r0 = rf(ctx, reportData, attestationFile)
|
||||||
return rf(ctx, reportData)
|
|
||||||
}
|
|
||||||
if rf, ok := ret.Get(0).(func(context.Context, [64]byte) []byte); ok {
|
|
||||||
r0 = rf(ctx, reportData)
|
|
||||||
} else {
|
} else {
|
||||||
if ret.Get(0) != nil {
|
r0 = ret.Error(0)
|
||||||
r0 = ret.Get(0).([]byte)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if rf, ok := ret.Get(1).(func(context.Context, [64]byte) error); ok {
|
return r0
|
||||||
r1 = rf(ctx, reportData)
|
|
||||||
} else {
|
|
||||||
r1 = ret.Error(1)
|
|
||||||
}
|
|
||||||
|
|
||||||
return r0, r1
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Data provides a mock function with given fields: ctx, dataset, privKey
|
// Data provides a mock function with given fields: ctx, dataset, filename, privKey
|
||||||
func (_m *SDK) Data(ctx context.Context, dataset agent.Dataset, privKey interface{}) error {
|
func (_m *SDK) Data(ctx context.Context, dataset *os.File, filename string, privKey interface{}) error {
|
||||||
ret := _m.Called(ctx, dataset, privKey)
|
ret := _m.Called(ctx, dataset, filename, privKey)
|
||||||
|
|
||||||
if len(ret) == 0 {
|
if len(ret) == 0 {
|
||||||
panic("no return value specified for Data")
|
panic("no return value specified for Data")
|
||||||
}
|
}
|
||||||
|
|
||||||
var r0 error
|
var r0 error
|
||||||
if rf, ok := ret.Get(0).(func(context.Context, agent.Dataset, interface{}) error); ok {
|
if rf, ok := ret.Get(0).(func(context.Context, *os.File, string, interface{}) error); ok {
|
||||||
r0 = rf(ctx, dataset, privKey)
|
r0 = rf(ctx, dataset, filename, privKey)
|
||||||
} else {
|
} else {
|
||||||
r0 = ret.Error(0)
|
r0 = ret.Error(0)
|
||||||
}
|
}
|
||||||
@@ -84,34 +71,22 @@ func (_m *SDK) Data(ctx context.Context, dataset agent.Dataset, privKey interfac
|
|||||||
return r0
|
return r0
|
||||||
}
|
}
|
||||||
|
|
||||||
// Result provides a mock function with given fields: ctx, privKey
|
// Result provides a mock function with given fields: ctx, privKey, resultFile
|
||||||
func (_m *SDK) Result(ctx context.Context, privKey interface{}) ([]byte, error) {
|
func (_m *SDK) Result(ctx context.Context, privKey interface{}, resultFile *os.File) error {
|
||||||
ret := _m.Called(ctx, privKey)
|
ret := _m.Called(ctx, privKey, resultFile)
|
||||||
|
|
||||||
if len(ret) == 0 {
|
if len(ret) == 0 {
|
||||||
panic("no return value specified for Result")
|
panic("no return value specified for Result")
|
||||||
}
|
}
|
||||||
|
|
||||||
var r0 []byte
|
var r0 error
|
||||||
var r1 error
|
if rf, ok := ret.Get(0).(func(context.Context, interface{}, *os.File) error); ok {
|
||||||
if rf, ok := ret.Get(0).(func(context.Context, interface{}) ([]byte, error)); ok {
|
r0 = rf(ctx, privKey, resultFile)
|
||||||
return rf(ctx, privKey)
|
|
||||||
}
|
|
||||||
if rf, ok := ret.Get(0).(func(context.Context, interface{}) []byte); ok {
|
|
||||||
r0 = rf(ctx, privKey)
|
|
||||||
} else {
|
} else {
|
||||||
if ret.Get(0) != nil {
|
r0 = ret.Error(0)
|
||||||
r0 = ret.Get(0).([]byte)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if rf, ok := ret.Get(1).(func(context.Context, interface{}) error); ok {
|
return r0
|
||||||
r1 = rf(ctx, privKey)
|
|
||||||
} else {
|
|
||||||
r1 = ret.Error(1)
|
|
||||||
}
|
|
||||||
|
|
||||||
return r0, r1
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewSDK creates a new instance of SDK. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations.
|
// NewSDK creates a new instance of SDK. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations.
|
||||||
|
|||||||
Reference in New Issue
Block a user