mirror of
https://github.com/ultravioletrs/cocos.git
synced 2026-06-23 04:10:25 +00:00
NOISSUE - Add agent pkg tests (#271)
* add agent tests Signed-off-by: Sammy Oina <sammyoina@gmail.com> * fix lint Signed-off-by: Sammy Oina <sammyoina@gmail.com> --------- Signed-off-by: Sammy Oina <sammyoina@gmail.com>
This commit is contained in:
committed by
GitHub
parent
faaddc3571
commit
f6b69d65df
@@ -0,0 +1,100 @@
|
||||
// Copyright (c) Ultraviolet
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
package binary
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"log/slog"
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/ultravioletrs/cocos/agent/algorithm"
|
||||
"github.com/ultravioletrs/cocos/agent/events/mocks"
|
||||
)
|
||||
|
||||
func TestNewAlgorithm(t *testing.T) {
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
|
||||
eventsSvc := new(mocks.Service)
|
||||
algoFile := "/path/to/algo"
|
||||
args := []string{"arg1", "arg2"}
|
||||
|
||||
algo := NewAlgorithm(logger, eventsSvc, algoFile, args)
|
||||
|
||||
b, ok := algo.(*binary)
|
||||
if !ok {
|
||||
t.Fatalf("NewAlgorithm did not return a *binary")
|
||||
}
|
||||
|
||||
if b.algoFile != algoFile {
|
||||
t.Errorf("Expected algoFile to be %s, got %s", algoFile, b.algoFile)
|
||||
}
|
||||
|
||||
if len(b.args) != len(args) {
|
||||
t.Errorf("Expected %d args, got %d", len(args), len(b.args))
|
||||
}
|
||||
|
||||
for i, arg := range args {
|
||||
if b.args[i] != arg {
|
||||
t.Errorf("Expected arg %d to be %s, got %s", i, arg, b.args[i])
|
||||
}
|
||||
}
|
||||
|
||||
if _, ok := b.stderr.(*algorithm.Stderr); !ok {
|
||||
t.Errorf("Expected stderr to be *algorithm.Stderr")
|
||||
}
|
||||
|
||||
if _, ok := b.stdout.(*algorithm.Stdout); !ok {
|
||||
t.Errorf("Expected stdout to be *algorithm.Stdout")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBinaryRun(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
algoFile string
|
||||
args []string
|
||||
expectedError bool
|
||||
}{
|
||||
{
|
||||
name: "Successful execution",
|
||||
algoFile: "echo",
|
||||
args: []string{"Hello, World!"},
|
||||
expectedError: false,
|
||||
},
|
||||
{
|
||||
name: "Non-existent binary",
|
||||
algoFile: "non_existent_binary",
|
||||
args: []string{},
|
||||
expectedError: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
|
||||
eventsSvc := new(mocks.Service)
|
||||
|
||||
b := NewAlgorithm(logger, eventsSvc, tt.algoFile, tt.args).(*binary)
|
||||
|
||||
var stdout, stderr bytes.Buffer
|
||||
b.stdout = &stdout
|
||||
b.stderr = &stderr
|
||||
|
||||
err := b.Run()
|
||||
|
||||
if tt.expectedError && err == nil {
|
||||
t.Errorf("Expected an error, but got none")
|
||||
}
|
||||
|
||||
if !tt.expectedError && err != nil {
|
||||
t.Errorf("Unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if !tt.expectedError {
|
||||
if stdout.Len() == 0 {
|
||||
t.Errorf("Expected non-empty stdout")
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,29 @@
|
||||
// Copyright (c) Ultraviolet
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
package docker
|
||||
|
||||
import (
|
||||
"log/slog"
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/ultravioletrs/cocos/agent/algorithm"
|
||||
"github.com/ultravioletrs/cocos/agent/events/mocks"
|
||||
)
|
||||
|
||||
// TestNewAlgorithm tests the NewAlgorithm function.
|
||||
func TestNewAlgorithm(t *testing.T) {
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
|
||||
eventsSvc := new(mocks.Service)
|
||||
algoFile := "/path/to/algo.tar"
|
||||
|
||||
algo := NewAlgorithm(logger, eventsSvc, algoFile)
|
||||
|
||||
d, ok := algo.(*docker)
|
||||
assert.True(t, ok, "NewAlgorithm should return a *docker")
|
||||
assert.Equal(t, algoFile, d.algoFile, "algoFile should be set correctly")
|
||||
assert.NotNil(t, d.logger, "logger should be set")
|
||||
assert.IsType(t, &algorithm.Stderr{}, d.stderr, "stderr should be of type *algorithm.Stderr")
|
||||
assert.IsType(t, &algorithm.Stdout{}, d.stdout, "stdout should be of type *algorithm.Stdout")
|
||||
}
|
||||
@@ -67,6 +67,13 @@ func (p *python) Run() error {
|
||||
|
||||
pythonPath := filepath.Join(venvPath, "bin", "python")
|
||||
|
||||
updatePipCmd := exec.Command(pythonPath, "-m", "pip", "install", "--upgrade", "pip")
|
||||
updatePipCmd.Stderr = p.stderr
|
||||
updatePipCmd.Stdout = p.stdout
|
||||
if err := updatePipCmd.Run(); err != nil {
|
||||
return fmt.Errorf("error updating pip: %v", err)
|
||||
}
|
||||
|
||||
if p.requirementsFile != "" {
|
||||
rcmd := exec.Command(pythonPath, "-m", "pip", "install", "-r", p.requirementsFile)
|
||||
rcmd.Stderr = p.stderr
|
||||
|
||||
@@ -0,0 +1,148 @@
|
||||
// Copyright (c) Ultraviolet
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
package python
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"io"
|
||||
"log/slog"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/ultravioletrs/cocos/agent/algorithm"
|
||||
"github.com/ultravioletrs/cocos/agent/events/mocks"
|
||||
"google.golang.org/grpc/metadata"
|
||||
)
|
||||
|
||||
const runtime = "python3"
|
||||
|
||||
func TestPythonRunTimeToContext(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
newCtx := PythonRunTimeToContext(ctx, runtime)
|
||||
|
||||
md, ok := metadata.FromOutgoingContext(newCtx)
|
||||
if !ok {
|
||||
t.Fatal("Expected metadata in context")
|
||||
}
|
||||
|
||||
values := md.Get(PyRuntimeKey)
|
||||
if len(values) != 1 || values[0] != runtime {
|
||||
t.Errorf("Expected runtime %s, got %v", runtime, values)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPythonRunTimeFromContext(t *testing.T) {
|
||||
ctx := metadata.NewIncomingContext(context.Background(), metadata.Pairs(PyRuntimeKey, runtime))
|
||||
|
||||
got := PythonRunTimeFromContext(ctx)
|
||||
if got != runtime {
|
||||
t.Errorf("Expected runtime %s, got %s", runtime, got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewAlgorithm(t *testing.T) {
|
||||
logger := &slog.Logger{}
|
||||
eventsSvc := new(mocks.Service)
|
||||
requirementsFile := "requirements.txt"
|
||||
algoFile := "algorithm.py"
|
||||
args := []string{"--arg1", "value1"}
|
||||
|
||||
algo := NewAlgorithm(logger, eventsSvc, runtime, requirementsFile, algoFile, args)
|
||||
|
||||
p, ok := algo.(*python)
|
||||
if !ok {
|
||||
t.Fatal("Expected *python type")
|
||||
}
|
||||
|
||||
if p.runtime != runtime {
|
||||
t.Errorf("Expected runtime %s, got %s", runtime, p.runtime)
|
||||
}
|
||||
if p.requirementsFile != requirementsFile {
|
||||
t.Errorf("Expected requirementsFile %s, got %s", requirementsFile, p.requirementsFile)
|
||||
}
|
||||
if p.algoFile != algoFile {
|
||||
t.Errorf("Expected algoFile %s, got %s", algoFile, p.algoFile)
|
||||
}
|
||||
if len(p.args) != len(args) {
|
||||
t.Errorf("Expected %d args, got %d", len(args), len(p.args))
|
||||
}
|
||||
}
|
||||
|
||||
func TestRun(t *testing.T) {
|
||||
tmpDir, err := os.MkdirTemp("", "python-test")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
scriptContent := []byte("print('Hello, World!')")
|
||||
scriptPath := filepath.Join(tmpDir, "test_script.py")
|
||||
if err := os.WriteFile(scriptPath, scriptContent, 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
eventsSvc := new(mocks.Service)
|
||||
|
||||
var stdout, stderr bytes.Buffer
|
||||
|
||||
algo := &python{
|
||||
algoFile: scriptPath,
|
||||
stderr: io.MultiWriter(&stderr, &algorithm.Stderr{Logger: slog.Default(), EventSvc: eventsSvc}),
|
||||
stdout: io.MultiWriter(&stdout, &algorithm.Stdout{Logger: slog.Default()}),
|
||||
runtime: "python3",
|
||||
}
|
||||
|
||||
err = algo.Run()
|
||||
if err != nil {
|
||||
t.Fatalf("Unexpected error: %v", err)
|
||||
}
|
||||
|
||||
expectedOutput := "Hello, World!\n"
|
||||
if !strings.Contains(stdout.String(), expectedOutput) {
|
||||
t.Errorf("Expected output to contain %q, got %q", expectedOutput, stdout.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestRunWithRequirements(t *testing.T) {
|
||||
tmpDir, err := os.MkdirTemp("", "python-test")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
scriptContent := []byte("import requests\nprint(requests.__version__)")
|
||||
scriptPath := filepath.Join(tmpDir, "test_script.py")
|
||||
if err := os.WriteFile(scriptPath, scriptContent, 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
requirementsContent := []byte("requests==2.26.0")
|
||||
requirementsPath := filepath.Join(tmpDir, "requirements.txt")
|
||||
if err := os.WriteFile(requirementsPath, requirementsContent, 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
eventsSvc := new(mocks.Service)
|
||||
|
||||
var stdout, stderr bytes.Buffer
|
||||
|
||||
algo := &python{
|
||||
algoFile: scriptPath,
|
||||
requirementsFile: requirementsPath,
|
||||
stderr: io.MultiWriter(&stderr, &algorithm.Stderr{Logger: slog.Default(), EventSvc: eventsSvc}),
|
||||
stdout: io.MultiWriter(&stdout, &algorithm.Stdout{Logger: slog.Default()}),
|
||||
runtime: "python3",
|
||||
}
|
||||
|
||||
err = algo.Run()
|
||||
if err != nil {
|
||||
t.Fatalf("Unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if !strings.Contains(stdout.String(), "2.26.0") {
|
||||
t.Errorf("Expected output to contain requests version 2.26.0, got %q", stdout.String())
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,89 @@
|
||||
// Copyright (c) Ultraviolet
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
package wasm
|
||||
|
||||
import (
|
||||
"log/slog"
|
||||
"os"
|
||||
"os/exec"
|
||||
"testing"
|
||||
|
||||
"github.com/ultravioletrs/cocos/agent/algorithm"
|
||||
"github.com/ultravioletrs/cocos/agent/events/mocks"
|
||||
)
|
||||
|
||||
func TestNewAlgorithm(t *testing.T) {
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
|
||||
eventsSvc := new(mocks.Service)
|
||||
algoFile := "test.wasm"
|
||||
args := []string{"arg1", "arg2"}
|
||||
|
||||
algo := NewAlgorithm(logger, eventsSvc, algoFile, args)
|
||||
|
||||
w, ok := algo.(*wasm)
|
||||
if !ok {
|
||||
t.Fatalf("NewAlgorithm did not return a *wasm")
|
||||
}
|
||||
|
||||
if w.algoFile != algoFile {
|
||||
t.Errorf("Expected algoFile to be %s, got %s", algoFile, w.algoFile)
|
||||
}
|
||||
|
||||
if len(w.args) != len(args) {
|
||||
t.Errorf("Expected %d args, got %d", len(args), len(w.args))
|
||||
}
|
||||
|
||||
_, ok = w.stderr.(*algorithm.Stderr)
|
||||
if !ok {
|
||||
t.Errorf("Expected stderr to be *algorithm.Stderr")
|
||||
}
|
||||
|
||||
_, ok = w.stdout.(*algorithm.Stdout)
|
||||
if !ok {
|
||||
t.Errorf("Expected stdout to be *algorithm.Stdout")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRunError(t *testing.T) {
|
||||
// Mock exec.Command to return an error
|
||||
execCommand = mockExecCommandError
|
||||
defer func() { execCommand = exec.Command }()
|
||||
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
|
||||
eventsSvc := new(mocks.Service)
|
||||
algoFile := "test.wasm"
|
||||
args := []string{"arg1", "arg2"}
|
||||
|
||||
w := NewAlgorithm(logger, eventsSvc, algoFile, args).(*wasm)
|
||||
|
||||
err := w.Run()
|
||||
if err == nil {
|
||||
t.Errorf("Run() should have returned an error")
|
||||
}
|
||||
}
|
||||
|
||||
func mockExecCommand(command string, args ...string) *exec.Cmd {
|
||||
cs := []string{"-test.run=TestHelperProcess", "--", command}
|
||||
cs = append(cs, args...)
|
||||
cmd := exec.Command(os.Args[0], cs...)
|
||||
cmd.Env = []string{"GO_WANT_HELPER_PROCESS=1"}
|
||||
return cmd
|
||||
}
|
||||
|
||||
func mockExecCommandError(command string, args ...string) *exec.Cmd {
|
||||
cmd := mockExecCommand(command, args...)
|
||||
cmd.Env = append(cmd.Env, "GO_WANT_HELPER_PROCESS_ERROR=1")
|
||||
return cmd
|
||||
}
|
||||
|
||||
func TestHelperProcess(t *testing.T) {
|
||||
if os.Getenv("GO_WANT_HELPER_PROCESS") != "1" {
|
||||
return
|
||||
}
|
||||
if os.Getenv("GO_WANT_HELPER_PROCESS_ERROR") == "1" {
|
||||
os.Exit(1)
|
||||
}
|
||||
os.Exit(0)
|
||||
}
|
||||
|
||||
var execCommand = exec.Command
|
||||
@@ -0,0 +1,173 @@
|
||||
// Copyright (c) Ultraviolet
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
package grpc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/ultravioletrs/cocos/agent"
|
||||
"github.com/ultravioletrs/cocos/agent/mocks"
|
||||
"golang.org/x/crypto/sha3"
|
||||
)
|
||||
|
||||
const svcErr = "Service Error"
|
||||
|
||||
func TestAlgoEndpoint(t *testing.T) {
|
||||
svc := new(mocks.Service)
|
||||
tests := []struct {
|
||||
name string
|
||||
req algoReq
|
||||
expectedErr bool
|
||||
}{
|
||||
{
|
||||
name: "Success",
|
||||
req: algoReq{Algorithm: []byte("algorithm")},
|
||||
},
|
||||
{
|
||||
name: "Validation Error",
|
||||
req: algoReq{},
|
||||
expectedErr: true,
|
||||
},
|
||||
{
|
||||
name: "Service Error",
|
||||
req: algoReq{Algorithm: []byte("algorithm")},
|
||||
expectedErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if tt.name == svcErr {
|
||||
svc.On("Algo", context.Background(), agent.Algorithm{Algorithm: tt.req.Algorithm}).Return(errors.New("")).Once()
|
||||
} else {
|
||||
svc.On("Algo", context.Background(), agent.Algorithm{Algorithm: tt.req.Algorithm}).Return(nil).Once()
|
||||
}
|
||||
endpoint := algoEndpoint(svc)
|
||||
_, err := endpoint(context.Background(), tt.req)
|
||||
if (err != nil) != tt.expectedErr {
|
||||
t.Errorf("algoEndpoint() error = %v, expectedErr %v", err, tt.expectedErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDataEndpoint(t *testing.T) {
|
||||
svc := new(mocks.Service)
|
||||
tests := []struct {
|
||||
name string
|
||||
req dataReq
|
||||
expectedErr bool
|
||||
}{
|
||||
{
|
||||
name: "Success",
|
||||
req: dataReq{Dataset: []byte("dataset")},
|
||||
},
|
||||
{
|
||||
name: "Validation Error",
|
||||
req: dataReq{},
|
||||
expectedErr: true,
|
||||
},
|
||||
{
|
||||
name: "Service Error",
|
||||
req: dataReq{Dataset: []byte("dataset")},
|
||||
expectedErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if tt.name == svcErr {
|
||||
svc.On("Data", context.Background(), agent.Dataset{Dataset: tt.req.Dataset}).Return(errors.New("")).Once()
|
||||
} else {
|
||||
svc.On("Data", context.Background(), agent.Dataset{Dataset: tt.req.Dataset}).Return(nil).Once()
|
||||
}
|
||||
endpoint := dataEndpoint(svc)
|
||||
_, err := endpoint(context.Background(), tt.req)
|
||||
if (err != nil) != tt.expectedErr {
|
||||
t.Errorf("dataEndpoint() error = %v, expectedErr %v", err, tt.expectedErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestResultEndpoint(t *testing.T) {
|
||||
svc := new(mocks.Service)
|
||||
tests := []struct {
|
||||
name string
|
||||
req resultReq
|
||||
expectedErr bool
|
||||
}{
|
||||
{
|
||||
name: "Success",
|
||||
req: resultReq{},
|
||||
},
|
||||
{
|
||||
name: "Service Error",
|
||||
req: resultReq{},
|
||||
expectedErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if tt.name == svcErr {
|
||||
svc.On("Result", context.Background()).Return([]byte{}, errors.New("")).Once()
|
||||
} else {
|
||||
svc.On("Result", context.Background()).Return([]byte{}, nil).Once()
|
||||
}
|
||||
endpoint := resultEndpoint(svc)
|
||||
res, err := endpoint(context.Background(), tt.req)
|
||||
if (err != nil) != tt.expectedErr {
|
||||
t.Errorf("resultEndpoint() error = %v, expectedErr %v", err, tt.expectedErr)
|
||||
}
|
||||
if err == nil {
|
||||
_, ok := res.(resultRes)
|
||||
if !ok {
|
||||
t.Errorf("resultEndpoint() returned unexpected type %T", res)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAttestationEndpoint(t *testing.T) {
|
||||
svc := new(mocks.Service)
|
||||
tests := []struct {
|
||||
name string
|
||||
req attestationReq
|
||||
expectedErr bool
|
||||
}{
|
||||
{
|
||||
name: "Success",
|
||||
req: attestationReq{ReportData: sha3.Sum512([]byte("report data"))},
|
||||
},
|
||||
{
|
||||
name: "Service Error",
|
||||
req: attestationReq{ReportData: sha3.Sum512([]byte("report data"))},
|
||||
expectedErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if tt.name == svcErr {
|
||||
svc.On("Attestation", context.Background(), tt.req.ReportData).Return([]byte{}, errors.New("")).Once()
|
||||
} else {
|
||||
svc.On("Attestation", context.Background(), tt.req.ReportData).Return([]byte{}, nil).Once()
|
||||
}
|
||||
endpoint := attestationEndpoint(svc)
|
||||
res, err := endpoint(context.Background(), tt.req)
|
||||
if (err != nil) != tt.expectedErr {
|
||||
t.Errorf("attestationEndpoint() error = %v, expectedErr %v", err, tt.expectedErr)
|
||||
}
|
||||
if err == nil {
|
||||
_, ok := res.(attestationRes)
|
||||
if !ok {
|
||||
t.Errorf("attestationEndpoint() returned unexpected type %T", res)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,187 @@
|
||||
// Copyright (c) Ultraviolet
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
package grpc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/mock"
|
||||
"github.com/ultravioletrs/cocos/agent"
|
||||
"github.com/ultravioletrs/cocos/agent/mocks"
|
||||
"google.golang.org/grpc"
|
||||
)
|
||||
|
||||
type MockAgentService_AlgoServer struct {
|
||||
grpc.ServerStream
|
||||
mock.Mock
|
||||
ctx context.Context
|
||||
}
|
||||
|
||||
func (m *MockAgentService_AlgoServer) Context() context.Context {
|
||||
return m.ctx
|
||||
}
|
||||
|
||||
func (m *MockAgentService_AlgoServer) Recv() (*agent.AlgoRequest, error) {
|
||||
args := m.Called()
|
||||
return args.Get(0).(*agent.AlgoRequest), args.Error(1)
|
||||
}
|
||||
|
||||
func (m *MockAgentService_AlgoServer) SendAndClose(resp *agent.AlgoResponse) error {
|
||||
args := m.Called(resp)
|
||||
return args.Error(0)
|
||||
}
|
||||
|
||||
type MockAgentService_DataServer struct {
|
||||
grpc.ServerStream
|
||||
mock.Mock
|
||||
ctx context.Context
|
||||
}
|
||||
|
||||
func (m *MockAgentService_DataServer) Context() context.Context {
|
||||
return m.ctx
|
||||
}
|
||||
|
||||
func (m *MockAgentService_DataServer) Recv() (*agent.DataRequest, error) {
|
||||
args := m.Called()
|
||||
return args.Get(0).(*agent.DataRequest), args.Error(1)
|
||||
}
|
||||
|
||||
func (m *MockAgentService_DataServer) SendAndClose(resp *agent.DataResponse) error {
|
||||
args := m.Called(resp)
|
||||
return args.Error(0)
|
||||
}
|
||||
|
||||
type MockAgentService_ResultServer struct {
|
||||
grpc.ServerStream
|
||||
mock.Mock
|
||||
ctx context.Context
|
||||
}
|
||||
|
||||
func (m *MockAgentService_ResultServer) Context() context.Context {
|
||||
return m.ctx
|
||||
}
|
||||
|
||||
func (m *MockAgentService_ResultServer) Send(resp *agent.ResultResponse) error {
|
||||
args := m.Called(resp)
|
||||
return args.Error(0)
|
||||
}
|
||||
|
||||
func TestAlgo(t *testing.T) {
|
||||
mockService := new(mocks.Service)
|
||||
server := NewServer(mockService)
|
||||
|
||||
mockStream := &MockAgentService_AlgoServer{ctx: context.Background()}
|
||||
mockStream.On("Recv").Return(&agent.AlgoRequest{Algorithm: []byte("algo"), Requirements: []byte("req")}, nil).Once()
|
||||
mockStream.On("Recv").Return(&agent.AlgoRequest{}, io.EOF)
|
||||
mockStream.On("SendAndClose", &agent.AlgoResponse{}).Return(nil)
|
||||
|
||||
mockService.On("Algo", context.Background(), agent.Algorithm{Algorithm: []byte("algo"), Requirements: []byte("req")}).Return(nil)
|
||||
|
||||
err := server.Algo(mockStream)
|
||||
assert.NoError(t, err)
|
||||
|
||||
mockStream.AssertExpectations(t)
|
||||
mockService.AssertExpectations(t)
|
||||
}
|
||||
|
||||
func TestData(t *testing.T) {
|
||||
mockService := new(mocks.Service)
|
||||
server := NewServer(mockService)
|
||||
|
||||
mockStream := &MockAgentService_DataServer{ctx: context.Background()}
|
||||
mockStream.On("Recv").Return(&agent.DataRequest{Dataset: []byte("data"), Filename: "test.txt"}, nil).Once()
|
||||
mockStream.On("Recv").Return(&agent.DataRequest{}, io.EOF)
|
||||
mockStream.On("SendAndClose", &agent.DataResponse{}).Return(nil)
|
||||
|
||||
mockService.On("Data", context.Background(), agent.Dataset{Dataset: []byte("data"), Filename: "test.txt"}).Return(nil)
|
||||
|
||||
err := server.Data(mockStream)
|
||||
assert.NoError(t, err)
|
||||
|
||||
mockStream.AssertExpectations(t)
|
||||
mockService.AssertExpectations(t)
|
||||
}
|
||||
|
||||
func TestResult(t *testing.T) {
|
||||
mockService := new(mocks.Service)
|
||||
server := NewServer(mockService)
|
||||
|
||||
mockStream := &MockAgentService_ResultServer{ctx: context.Background()}
|
||||
mockService.On("Result", mock.Anything).Return([]byte("result data"), nil)
|
||||
mockStream.On("Send", mock.AnythingOfType("*agent.ResultResponse")).Return(nil)
|
||||
|
||||
err := server.Result(&agent.ResultRequest{}, mockStream)
|
||||
assert.NoError(t, err)
|
||||
|
||||
mockStream.AssertExpectations(t)
|
||||
mockService.AssertExpectations(t)
|
||||
}
|
||||
|
||||
func TestAttestation(t *testing.T) {
|
||||
mockService := new(mocks.Service)
|
||||
server := NewServer(mockService)
|
||||
|
||||
reportData := [agent.ReportDataSize]byte{}
|
||||
mockService.On("Attestation", mock.Anything, reportData).Return([]byte("attestation data"), nil)
|
||||
|
||||
resp, err := server.Attestation(context.Background(), &agent.AttestationRequest{ReportData: reportData[:]})
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, []byte("attestation data"), resp.File)
|
||||
|
||||
mockService.AssertExpectations(t)
|
||||
}
|
||||
|
||||
func TestDecodeAlgoRequest(t *testing.T) {
|
||||
req := &agent.AlgoRequest{Algorithm: []byte("algo"), Requirements: []byte("req")}
|
||||
decoded, err := decodeAlgoRequest(context.Background(), req)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, algoReq{Algorithm: []byte("algo"), Requirements: []byte("req")}, decoded)
|
||||
}
|
||||
|
||||
func TestEncodeAlgoResponse(t *testing.T) {
|
||||
encoded, err := encodeAlgoResponse(context.Background(), algoRes{})
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, &agent.AlgoResponse{}, encoded)
|
||||
}
|
||||
|
||||
func TestDecodeDataRequest(t *testing.T) {
|
||||
req := &agent.DataRequest{Dataset: []byte("data"), Filename: "test.txt"}
|
||||
decoded, err := decodeDataRequest(context.Background(), req)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, dataReq{Dataset: []byte("data"), Filename: "test.txt"}, decoded)
|
||||
}
|
||||
|
||||
func TestEncodeDataResponse(t *testing.T) {
|
||||
encoded, err := encodeDataResponse(context.Background(), dataRes{})
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, &agent.DataResponse{}, encoded)
|
||||
}
|
||||
|
||||
func TestDecodeResultRequest(t *testing.T) {
|
||||
decoded, err := decodeResultRequest(context.Background(), &agent.ResultRequest{})
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, resultReq{}, decoded)
|
||||
}
|
||||
|
||||
func TestEncodeResultResponse(t *testing.T) {
|
||||
encoded, err := encodeResultResponse(context.Background(), resultRes{File: []byte("result")})
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, &agent.ResultResponse{File: []byte("result")}, encoded)
|
||||
}
|
||||
|
||||
func TestDecodeAttestationRequest(t *testing.T) {
|
||||
reportData := [agent.ReportDataSize]byte{}
|
||||
req := &agent.AttestationRequest{ReportData: reportData[:]}
|
||||
decoded, err := decodeAttestationRequest(context.Background(), req)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, attestationReq{ReportData: reportData}, decoded)
|
||||
}
|
||||
|
||||
func TestEncodeAttestationResponse(t *testing.T) {
|
||||
encoded, err := encodeAttestationResponse(context.Background(), attestationRes{File: []byte("attestation")})
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, &agent.AttestationResponse{File: []byte("attestation")}, encoded)
|
||||
}
|
||||
@@ -0,0 +1,133 @@
|
||||
// Copyright (c) Ultraviolet
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
package agent
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"google.golang.org/grpc/metadata"
|
||||
)
|
||||
|
||||
func TestDatasetsString(t *testing.T) {
|
||||
datasets := Datasets{
|
||||
{
|
||||
Hash: [32]byte{1, 2, 3},
|
||||
UserKey: []byte("user_key"),
|
||||
Filename: "test.dat",
|
||||
},
|
||||
}
|
||||
|
||||
expected := `[{"hash":[1,2,3,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],"user_key":"dXNlcl9rZXk=","filename":"test.dat"}]`
|
||||
result := datasets.String()
|
||||
|
||||
if result != expected {
|
||||
t.Errorf("Datasets.String() = %v, want %v", result, expected)
|
||||
}
|
||||
}
|
||||
|
||||
func TestIndexToContext(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
index := 5
|
||||
|
||||
newCtx := IndexToContext(ctx, index)
|
||||
result, ok := IndexFromContext(newCtx)
|
||||
|
||||
if !ok {
|
||||
t.Errorf("IndexFromContext() ok = false, want true")
|
||||
}
|
||||
|
||||
if result != index {
|
||||
t.Errorf("IndexFromContext() = %v, want %v", result, index)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDecompressFromContext(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
ctx context.Context
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
name: "No decompress metadata",
|
||||
ctx: context.Background(),
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "Decompress true",
|
||||
ctx: metadata.NewIncomingContext(
|
||||
context.Background(),
|
||||
metadata.Pairs(DecompressKey, "true"),
|
||||
),
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "Decompress false",
|
||||
ctx: metadata.NewIncomingContext(
|
||||
context.Background(),
|
||||
metadata.Pairs(DecompressKey, "false"),
|
||||
),
|
||||
expected: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := DecompressFromContext(tt.ctx)
|
||||
if result != tt.expected {
|
||||
t.Errorf("DecompressFromContext() = %v, want %v", result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDecompressToContext(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
decompress := true
|
||||
|
||||
newCtx := DecompressToContext(ctx, decompress)
|
||||
md, ok := metadata.FromOutgoingContext(newCtx)
|
||||
|
||||
if !ok {
|
||||
t.Errorf("metadata.FromOutgoingContext() ok = false, want true")
|
||||
}
|
||||
|
||||
vals := md.Get(DecompressKey)
|
||||
if len(vals) != 1 {
|
||||
t.Errorf("len(md.Get(DecompressKey)) = %v, want 1", len(vals))
|
||||
}
|
||||
|
||||
if vals[0] != "true" {
|
||||
t.Errorf("md.Get(DecompressKey)[0] = %v, want 'true'", vals[0])
|
||||
}
|
||||
}
|
||||
|
||||
func TestAgentConfigJSON(t *testing.T) {
|
||||
config := AgentConfig{
|
||||
LogLevel: "info",
|
||||
Host: "localhost",
|
||||
Port: "8080",
|
||||
CertFile: "cert.pem",
|
||||
KeyFile: "key.pem",
|
||||
ServerCAFile: "server_ca.pem",
|
||||
ClientCAFile: "client_ca.pem",
|
||||
AttestedTls: true,
|
||||
}
|
||||
|
||||
data, err := json.Marshal(config)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to marshal AgentConfig: %v", err)
|
||||
}
|
||||
|
||||
var unmarshaledConfig AgentConfig
|
||||
err = json.Unmarshal(data, &unmarshaledConfig)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to unmarshal AgentConfig: %v", err)
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(config, unmarshaledConfig) {
|
||||
t.Errorf("Unmarshaled config does not match original. Got %+v, want %+v", unmarshaledConfig, config)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,82 @@
|
||||
// Copyright (c) Ultraviolet
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
package events
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/ultravioletrs/cocos/pkg/manager"
|
||||
"google.golang.org/protobuf/proto"
|
||||
)
|
||||
|
||||
type mockConn struct {
|
||||
writeErr error
|
||||
buf bytes.Buffer
|
||||
}
|
||||
|
||||
func (m *mockConn) Write(p []byte) (n int, err error) {
|
||||
if m.writeErr != nil {
|
||||
return 0, m.writeErr
|
||||
}
|
||||
return m.buf.Write(p)
|
||||
}
|
||||
|
||||
func TestSendEventSuccess(t *testing.T) {
|
||||
mockConnection := &mockConn{}
|
||||
|
||||
svc, err := New("test_service", "12345", mockConnection)
|
||||
assert.NoError(t, err)
|
||||
|
||||
details := json.RawMessage(`{"key": "value"}`)
|
||||
|
||||
err = svc.SendEvent("test_event", "success", details)
|
||||
assert.NoError(t, err)
|
||||
|
||||
var writtenMessage manager.ClientStreamMessage
|
||||
err = proto.Unmarshal(mockConnection.buf.Bytes(), &writtenMessage)
|
||||
assert.NoError(t, err)
|
||||
|
||||
assert.Equal(t, "test_event", writtenMessage.GetAgentEvent().EventType)
|
||||
assert.Equal(t, "12345", writtenMessage.GetAgentEvent().ComputationId)
|
||||
assert.Equal(t, "test_service", writtenMessage.GetAgentEvent().Originator)
|
||||
assert.Equal(t, "success", writtenMessage.GetAgentEvent().Status)
|
||||
|
||||
now := time.Now()
|
||||
eventTimestamp := writtenMessage.GetAgentEvent().GetTimestamp().AsTime()
|
||||
assert.WithinDuration(t, now, eventTimestamp, 1*time.Second)
|
||||
}
|
||||
|
||||
func TestSendEventFailure(t *testing.T) {
|
||||
mockConnection := &mockConn{writeErr: errors.New("write error")}
|
||||
|
||||
svc, err := New("test_service", "12345", mockConnection)
|
||||
assert.NoError(t, err)
|
||||
|
||||
details := json.RawMessage(`{"key": "value"}`)
|
||||
|
||||
err = svc.SendEvent("test_event", "failure", details)
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, "write error", err.Error())
|
||||
|
||||
assert.Len(t, svc.(*service).cachedMessages, 1)
|
||||
}
|
||||
|
||||
func TestClose(t *testing.T) {
|
||||
mockConnection := &mockConn{}
|
||||
|
||||
svc, err := New("test_service", "12345", mockConnection)
|
||||
assert.NoError(t, err)
|
||||
|
||||
svc.Close()
|
||||
|
||||
time.Sleep(1 * time.Second)
|
||||
|
||||
details := json.RawMessage(`{"key": "value"}`)
|
||||
err = svc.SendEvent("test_event", "success", details)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
Reference in New Issue
Block a user