NOISSUE - Add agent pkg tests (#271)

* add agent tests

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* fix lint

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

---------

Signed-off-by: Sammy Oina <sammyoina@gmail.com>
This commit is contained in:
Sammy Kerata Oina
2024-10-08 16:29:21 +03:00
committed by GitHub
parent faaddc3571
commit f6b69d65df
9 changed files with 948 additions and 0 deletions
+100
View File
@@ -0,0 +1,100 @@
// Copyright (c) Ultraviolet
// SPDX-License-Identifier: Apache-2.0
package binary
import (
"bytes"
"log/slog"
"os"
"testing"
"github.com/ultravioletrs/cocos/agent/algorithm"
"github.com/ultravioletrs/cocos/agent/events/mocks"
)
func TestNewAlgorithm(t *testing.T) {
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
eventsSvc := new(mocks.Service)
algoFile := "/path/to/algo"
args := []string{"arg1", "arg2"}
algo := NewAlgorithm(logger, eventsSvc, algoFile, args)
b, ok := algo.(*binary)
if !ok {
t.Fatalf("NewAlgorithm did not return a *binary")
}
if b.algoFile != algoFile {
t.Errorf("Expected algoFile to be %s, got %s", algoFile, b.algoFile)
}
if len(b.args) != len(args) {
t.Errorf("Expected %d args, got %d", len(args), len(b.args))
}
for i, arg := range args {
if b.args[i] != arg {
t.Errorf("Expected arg %d to be %s, got %s", i, arg, b.args[i])
}
}
if _, ok := b.stderr.(*algorithm.Stderr); !ok {
t.Errorf("Expected stderr to be *algorithm.Stderr")
}
if _, ok := b.stdout.(*algorithm.Stdout); !ok {
t.Errorf("Expected stdout to be *algorithm.Stdout")
}
}
func TestBinaryRun(t *testing.T) {
tests := []struct {
name string
algoFile string
args []string
expectedError bool
}{
{
name: "Successful execution",
algoFile: "echo",
args: []string{"Hello, World!"},
expectedError: false,
},
{
name: "Non-existent binary",
algoFile: "non_existent_binary",
args: []string{},
expectedError: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
eventsSvc := new(mocks.Service)
b := NewAlgorithm(logger, eventsSvc, tt.algoFile, tt.args).(*binary)
var stdout, stderr bytes.Buffer
b.stdout = &stdout
b.stderr = &stderr
err := b.Run()
if tt.expectedError && err == nil {
t.Errorf("Expected an error, but got none")
}
if !tt.expectedError && err != nil {
t.Errorf("Unexpected error: %v", err)
}
if !tt.expectedError {
if stdout.Len() == 0 {
t.Errorf("Expected non-empty stdout")
}
}
})
}
}
+29
View File
@@ -0,0 +1,29 @@
// Copyright (c) Ultraviolet
// SPDX-License-Identifier: Apache-2.0
package docker
import (
"log/slog"
"os"
"testing"
"github.com/stretchr/testify/assert"
"github.com/ultravioletrs/cocos/agent/algorithm"
"github.com/ultravioletrs/cocos/agent/events/mocks"
)
// TestNewAlgorithm tests the NewAlgorithm function.
func TestNewAlgorithm(t *testing.T) {
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
eventsSvc := new(mocks.Service)
algoFile := "/path/to/algo.tar"
algo := NewAlgorithm(logger, eventsSvc, algoFile)
d, ok := algo.(*docker)
assert.True(t, ok, "NewAlgorithm should return a *docker")
assert.Equal(t, algoFile, d.algoFile, "algoFile should be set correctly")
assert.NotNil(t, d.logger, "logger should be set")
assert.IsType(t, &algorithm.Stderr{}, d.stderr, "stderr should be of type *algorithm.Stderr")
assert.IsType(t, &algorithm.Stdout{}, d.stdout, "stdout should be of type *algorithm.Stdout")
}
+7
View File
@@ -67,6 +67,13 @@ func (p *python) Run() error {
pythonPath := filepath.Join(venvPath, "bin", "python")
updatePipCmd := exec.Command(pythonPath, "-m", "pip", "install", "--upgrade", "pip")
updatePipCmd.Stderr = p.stderr
updatePipCmd.Stdout = p.stdout
if err := updatePipCmd.Run(); err != nil {
return fmt.Errorf("error updating pip: %v", err)
}
if p.requirementsFile != "" {
rcmd := exec.Command(pythonPath, "-m", "pip", "install", "-r", p.requirementsFile)
rcmd.Stderr = p.stderr
+148
View File
@@ -0,0 +1,148 @@
// Copyright (c) Ultraviolet
// SPDX-License-Identifier: Apache-2.0
package python
import (
"bytes"
"context"
"io"
"log/slog"
"os"
"path/filepath"
"strings"
"testing"
"github.com/ultravioletrs/cocos/agent/algorithm"
"github.com/ultravioletrs/cocos/agent/events/mocks"
"google.golang.org/grpc/metadata"
)
const runtime = "python3"
func TestPythonRunTimeToContext(t *testing.T) {
ctx := context.Background()
newCtx := PythonRunTimeToContext(ctx, runtime)
md, ok := metadata.FromOutgoingContext(newCtx)
if !ok {
t.Fatal("Expected metadata in context")
}
values := md.Get(PyRuntimeKey)
if len(values) != 1 || values[0] != runtime {
t.Errorf("Expected runtime %s, got %v", runtime, values)
}
}
func TestPythonRunTimeFromContext(t *testing.T) {
ctx := metadata.NewIncomingContext(context.Background(), metadata.Pairs(PyRuntimeKey, runtime))
got := PythonRunTimeFromContext(ctx)
if got != runtime {
t.Errorf("Expected runtime %s, got %s", runtime, got)
}
}
func TestNewAlgorithm(t *testing.T) {
logger := &slog.Logger{}
eventsSvc := new(mocks.Service)
requirementsFile := "requirements.txt"
algoFile := "algorithm.py"
args := []string{"--arg1", "value1"}
algo := NewAlgorithm(logger, eventsSvc, runtime, requirementsFile, algoFile, args)
p, ok := algo.(*python)
if !ok {
t.Fatal("Expected *python type")
}
if p.runtime != runtime {
t.Errorf("Expected runtime %s, got %s", runtime, p.runtime)
}
if p.requirementsFile != requirementsFile {
t.Errorf("Expected requirementsFile %s, got %s", requirementsFile, p.requirementsFile)
}
if p.algoFile != algoFile {
t.Errorf("Expected algoFile %s, got %s", algoFile, p.algoFile)
}
if len(p.args) != len(args) {
t.Errorf("Expected %d args, got %d", len(args), len(p.args))
}
}
func TestRun(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "python-test")
if err != nil {
t.Fatal(err)
}
defer os.RemoveAll(tmpDir)
scriptContent := []byte("print('Hello, World!')")
scriptPath := filepath.Join(tmpDir, "test_script.py")
if err := os.WriteFile(scriptPath, scriptContent, 0o644); err != nil {
t.Fatal(err)
}
eventsSvc := new(mocks.Service)
var stdout, stderr bytes.Buffer
algo := &python{
algoFile: scriptPath,
stderr: io.MultiWriter(&stderr, &algorithm.Stderr{Logger: slog.Default(), EventSvc: eventsSvc}),
stdout: io.MultiWriter(&stdout, &algorithm.Stdout{Logger: slog.Default()}),
runtime: "python3",
}
err = algo.Run()
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
expectedOutput := "Hello, World!\n"
if !strings.Contains(stdout.String(), expectedOutput) {
t.Errorf("Expected output to contain %q, got %q", expectedOutput, stdout.String())
}
}
func TestRunWithRequirements(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "python-test")
if err != nil {
t.Fatal(err)
}
defer os.RemoveAll(tmpDir)
scriptContent := []byte("import requests\nprint(requests.__version__)")
scriptPath := filepath.Join(tmpDir, "test_script.py")
if err := os.WriteFile(scriptPath, scriptContent, 0o644); err != nil {
t.Fatal(err)
}
requirementsContent := []byte("requests==2.26.0")
requirementsPath := filepath.Join(tmpDir, "requirements.txt")
if err := os.WriteFile(requirementsPath, requirementsContent, 0o644); err != nil {
t.Fatal(err)
}
eventsSvc := new(mocks.Service)
var stdout, stderr bytes.Buffer
algo := &python{
algoFile: scriptPath,
requirementsFile: requirementsPath,
stderr: io.MultiWriter(&stderr, &algorithm.Stderr{Logger: slog.Default(), EventSvc: eventsSvc}),
stdout: io.MultiWriter(&stdout, &algorithm.Stdout{Logger: slog.Default()}),
runtime: "python3",
}
err = algo.Run()
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
if !strings.Contains(stdout.String(), "2.26.0") {
t.Errorf("Expected output to contain requests version 2.26.0, got %q", stdout.String())
}
}
+89
View File
@@ -0,0 +1,89 @@
// Copyright (c) Ultraviolet
// SPDX-License-Identifier: Apache-2.0
package wasm
import (
"log/slog"
"os"
"os/exec"
"testing"
"github.com/ultravioletrs/cocos/agent/algorithm"
"github.com/ultravioletrs/cocos/agent/events/mocks"
)
func TestNewAlgorithm(t *testing.T) {
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
eventsSvc := new(mocks.Service)
algoFile := "test.wasm"
args := []string{"arg1", "arg2"}
algo := NewAlgorithm(logger, eventsSvc, algoFile, args)
w, ok := algo.(*wasm)
if !ok {
t.Fatalf("NewAlgorithm did not return a *wasm")
}
if w.algoFile != algoFile {
t.Errorf("Expected algoFile to be %s, got %s", algoFile, w.algoFile)
}
if len(w.args) != len(args) {
t.Errorf("Expected %d args, got %d", len(args), len(w.args))
}
_, ok = w.stderr.(*algorithm.Stderr)
if !ok {
t.Errorf("Expected stderr to be *algorithm.Stderr")
}
_, ok = w.stdout.(*algorithm.Stdout)
if !ok {
t.Errorf("Expected stdout to be *algorithm.Stdout")
}
}
func TestRunError(t *testing.T) {
// Mock exec.Command to return an error
execCommand = mockExecCommandError
defer func() { execCommand = exec.Command }()
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
eventsSvc := new(mocks.Service)
algoFile := "test.wasm"
args := []string{"arg1", "arg2"}
w := NewAlgorithm(logger, eventsSvc, algoFile, args).(*wasm)
err := w.Run()
if err == nil {
t.Errorf("Run() should have returned an error")
}
}
func mockExecCommand(command string, args ...string) *exec.Cmd {
cs := []string{"-test.run=TestHelperProcess", "--", command}
cs = append(cs, args...)
cmd := exec.Command(os.Args[0], cs...)
cmd.Env = []string{"GO_WANT_HELPER_PROCESS=1"}
return cmd
}
func mockExecCommandError(command string, args ...string) *exec.Cmd {
cmd := mockExecCommand(command, args...)
cmd.Env = append(cmd.Env, "GO_WANT_HELPER_PROCESS_ERROR=1")
return cmd
}
func TestHelperProcess(t *testing.T) {
if os.Getenv("GO_WANT_HELPER_PROCESS") != "1" {
return
}
if os.Getenv("GO_WANT_HELPER_PROCESS_ERROR") == "1" {
os.Exit(1)
}
os.Exit(0)
}
var execCommand = exec.Command
+173
View File
@@ -0,0 +1,173 @@
// Copyright (c) Ultraviolet
// SPDX-License-Identifier: Apache-2.0
package grpc
import (
"context"
"errors"
"testing"
"github.com/ultravioletrs/cocos/agent"
"github.com/ultravioletrs/cocos/agent/mocks"
"golang.org/x/crypto/sha3"
)
const svcErr = "Service Error"
func TestAlgoEndpoint(t *testing.T) {
svc := new(mocks.Service)
tests := []struct {
name string
req algoReq
expectedErr bool
}{
{
name: "Success",
req: algoReq{Algorithm: []byte("algorithm")},
},
{
name: "Validation Error",
req: algoReq{},
expectedErr: true,
},
{
name: "Service Error",
req: algoReq{Algorithm: []byte("algorithm")},
expectedErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if tt.name == svcErr {
svc.On("Algo", context.Background(), agent.Algorithm{Algorithm: tt.req.Algorithm}).Return(errors.New("")).Once()
} else {
svc.On("Algo", context.Background(), agent.Algorithm{Algorithm: tt.req.Algorithm}).Return(nil).Once()
}
endpoint := algoEndpoint(svc)
_, err := endpoint(context.Background(), tt.req)
if (err != nil) != tt.expectedErr {
t.Errorf("algoEndpoint() error = %v, expectedErr %v", err, tt.expectedErr)
}
})
}
}
func TestDataEndpoint(t *testing.T) {
svc := new(mocks.Service)
tests := []struct {
name string
req dataReq
expectedErr bool
}{
{
name: "Success",
req: dataReq{Dataset: []byte("dataset")},
},
{
name: "Validation Error",
req: dataReq{},
expectedErr: true,
},
{
name: "Service Error",
req: dataReq{Dataset: []byte("dataset")},
expectedErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if tt.name == svcErr {
svc.On("Data", context.Background(), agent.Dataset{Dataset: tt.req.Dataset}).Return(errors.New("")).Once()
} else {
svc.On("Data", context.Background(), agent.Dataset{Dataset: tt.req.Dataset}).Return(nil).Once()
}
endpoint := dataEndpoint(svc)
_, err := endpoint(context.Background(), tt.req)
if (err != nil) != tt.expectedErr {
t.Errorf("dataEndpoint() error = %v, expectedErr %v", err, tt.expectedErr)
}
})
}
}
func TestResultEndpoint(t *testing.T) {
svc := new(mocks.Service)
tests := []struct {
name string
req resultReq
expectedErr bool
}{
{
name: "Success",
req: resultReq{},
},
{
name: "Service Error",
req: resultReq{},
expectedErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if tt.name == svcErr {
svc.On("Result", context.Background()).Return([]byte{}, errors.New("")).Once()
} else {
svc.On("Result", context.Background()).Return([]byte{}, nil).Once()
}
endpoint := resultEndpoint(svc)
res, err := endpoint(context.Background(), tt.req)
if (err != nil) != tt.expectedErr {
t.Errorf("resultEndpoint() error = %v, expectedErr %v", err, tt.expectedErr)
}
if err == nil {
_, ok := res.(resultRes)
if !ok {
t.Errorf("resultEndpoint() returned unexpected type %T", res)
}
}
})
}
}
func TestAttestationEndpoint(t *testing.T) {
svc := new(mocks.Service)
tests := []struct {
name string
req attestationReq
expectedErr bool
}{
{
name: "Success",
req: attestationReq{ReportData: sha3.Sum512([]byte("report data"))},
},
{
name: "Service Error",
req: attestationReq{ReportData: sha3.Sum512([]byte("report data"))},
expectedErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if tt.name == svcErr {
svc.On("Attestation", context.Background(), tt.req.ReportData).Return([]byte{}, errors.New("")).Once()
} else {
svc.On("Attestation", context.Background(), tt.req.ReportData).Return([]byte{}, nil).Once()
}
endpoint := attestationEndpoint(svc)
res, err := endpoint(context.Background(), tt.req)
if (err != nil) != tt.expectedErr {
t.Errorf("attestationEndpoint() error = %v, expectedErr %v", err, tt.expectedErr)
}
if err == nil {
_, ok := res.(attestationRes)
if !ok {
t.Errorf("attestationEndpoint() returned unexpected type %T", res)
}
}
})
}
}
+187
View File
@@ -0,0 +1,187 @@
// Copyright (c) Ultraviolet
// SPDX-License-Identifier: Apache-2.0
package grpc
import (
"context"
"io"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"github.com/ultravioletrs/cocos/agent"
"github.com/ultravioletrs/cocos/agent/mocks"
"google.golang.org/grpc"
)
type MockAgentService_AlgoServer struct {
grpc.ServerStream
mock.Mock
ctx context.Context
}
func (m *MockAgentService_AlgoServer) Context() context.Context {
return m.ctx
}
func (m *MockAgentService_AlgoServer) Recv() (*agent.AlgoRequest, error) {
args := m.Called()
return args.Get(0).(*agent.AlgoRequest), args.Error(1)
}
func (m *MockAgentService_AlgoServer) SendAndClose(resp *agent.AlgoResponse) error {
args := m.Called(resp)
return args.Error(0)
}
type MockAgentService_DataServer struct {
grpc.ServerStream
mock.Mock
ctx context.Context
}
func (m *MockAgentService_DataServer) Context() context.Context {
return m.ctx
}
func (m *MockAgentService_DataServer) Recv() (*agent.DataRequest, error) {
args := m.Called()
return args.Get(0).(*agent.DataRequest), args.Error(1)
}
func (m *MockAgentService_DataServer) SendAndClose(resp *agent.DataResponse) error {
args := m.Called(resp)
return args.Error(0)
}
type MockAgentService_ResultServer struct {
grpc.ServerStream
mock.Mock
ctx context.Context
}
func (m *MockAgentService_ResultServer) Context() context.Context {
return m.ctx
}
func (m *MockAgentService_ResultServer) Send(resp *agent.ResultResponse) error {
args := m.Called(resp)
return args.Error(0)
}
func TestAlgo(t *testing.T) {
mockService := new(mocks.Service)
server := NewServer(mockService)
mockStream := &MockAgentService_AlgoServer{ctx: context.Background()}
mockStream.On("Recv").Return(&agent.AlgoRequest{Algorithm: []byte("algo"), Requirements: []byte("req")}, nil).Once()
mockStream.On("Recv").Return(&agent.AlgoRequest{}, io.EOF)
mockStream.On("SendAndClose", &agent.AlgoResponse{}).Return(nil)
mockService.On("Algo", context.Background(), agent.Algorithm{Algorithm: []byte("algo"), Requirements: []byte("req")}).Return(nil)
err := server.Algo(mockStream)
assert.NoError(t, err)
mockStream.AssertExpectations(t)
mockService.AssertExpectations(t)
}
func TestData(t *testing.T) {
mockService := new(mocks.Service)
server := NewServer(mockService)
mockStream := &MockAgentService_DataServer{ctx: context.Background()}
mockStream.On("Recv").Return(&agent.DataRequest{Dataset: []byte("data"), Filename: "test.txt"}, nil).Once()
mockStream.On("Recv").Return(&agent.DataRequest{}, io.EOF)
mockStream.On("SendAndClose", &agent.DataResponse{}).Return(nil)
mockService.On("Data", context.Background(), agent.Dataset{Dataset: []byte("data"), Filename: "test.txt"}).Return(nil)
err := server.Data(mockStream)
assert.NoError(t, err)
mockStream.AssertExpectations(t)
mockService.AssertExpectations(t)
}
func TestResult(t *testing.T) {
mockService := new(mocks.Service)
server := NewServer(mockService)
mockStream := &MockAgentService_ResultServer{ctx: context.Background()}
mockService.On("Result", mock.Anything).Return([]byte("result data"), nil)
mockStream.On("Send", mock.AnythingOfType("*agent.ResultResponse")).Return(nil)
err := server.Result(&agent.ResultRequest{}, mockStream)
assert.NoError(t, err)
mockStream.AssertExpectations(t)
mockService.AssertExpectations(t)
}
func TestAttestation(t *testing.T) {
mockService := new(mocks.Service)
server := NewServer(mockService)
reportData := [agent.ReportDataSize]byte{}
mockService.On("Attestation", mock.Anything, reportData).Return([]byte("attestation data"), nil)
resp, err := server.Attestation(context.Background(), &agent.AttestationRequest{ReportData: reportData[:]})
assert.NoError(t, err)
assert.Equal(t, []byte("attestation data"), resp.File)
mockService.AssertExpectations(t)
}
func TestDecodeAlgoRequest(t *testing.T) {
req := &agent.AlgoRequest{Algorithm: []byte("algo"), Requirements: []byte("req")}
decoded, err := decodeAlgoRequest(context.Background(), req)
assert.NoError(t, err)
assert.Equal(t, algoReq{Algorithm: []byte("algo"), Requirements: []byte("req")}, decoded)
}
func TestEncodeAlgoResponse(t *testing.T) {
encoded, err := encodeAlgoResponse(context.Background(), algoRes{})
assert.NoError(t, err)
assert.Equal(t, &agent.AlgoResponse{}, encoded)
}
func TestDecodeDataRequest(t *testing.T) {
req := &agent.DataRequest{Dataset: []byte("data"), Filename: "test.txt"}
decoded, err := decodeDataRequest(context.Background(), req)
assert.NoError(t, err)
assert.Equal(t, dataReq{Dataset: []byte("data"), Filename: "test.txt"}, decoded)
}
func TestEncodeDataResponse(t *testing.T) {
encoded, err := encodeDataResponse(context.Background(), dataRes{})
assert.NoError(t, err)
assert.Equal(t, &agent.DataResponse{}, encoded)
}
func TestDecodeResultRequest(t *testing.T) {
decoded, err := decodeResultRequest(context.Background(), &agent.ResultRequest{})
assert.NoError(t, err)
assert.Equal(t, resultReq{}, decoded)
}
func TestEncodeResultResponse(t *testing.T) {
encoded, err := encodeResultResponse(context.Background(), resultRes{File: []byte("result")})
assert.NoError(t, err)
assert.Equal(t, &agent.ResultResponse{File: []byte("result")}, encoded)
}
func TestDecodeAttestationRequest(t *testing.T) {
reportData := [agent.ReportDataSize]byte{}
req := &agent.AttestationRequest{ReportData: reportData[:]}
decoded, err := decodeAttestationRequest(context.Background(), req)
assert.NoError(t, err)
assert.Equal(t, attestationReq{ReportData: reportData}, decoded)
}
func TestEncodeAttestationResponse(t *testing.T) {
encoded, err := encodeAttestationResponse(context.Background(), attestationRes{File: []byte("attestation")})
assert.NoError(t, err)
assert.Equal(t, &agent.AttestationResponse{File: []byte("attestation")}, encoded)
}
+133
View File
@@ -0,0 +1,133 @@
// Copyright (c) Ultraviolet
// SPDX-License-Identifier: Apache-2.0
package agent
import (
"context"
"encoding/json"
"reflect"
"testing"
"google.golang.org/grpc/metadata"
)
func TestDatasetsString(t *testing.T) {
datasets := Datasets{
{
Hash: [32]byte{1, 2, 3},
UserKey: []byte("user_key"),
Filename: "test.dat",
},
}
expected := `[{"hash":[1,2,3,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],"user_key":"dXNlcl9rZXk=","filename":"test.dat"}]`
result := datasets.String()
if result != expected {
t.Errorf("Datasets.String() = %v, want %v", result, expected)
}
}
func TestIndexToContext(t *testing.T) {
ctx := context.Background()
index := 5
newCtx := IndexToContext(ctx, index)
result, ok := IndexFromContext(newCtx)
if !ok {
t.Errorf("IndexFromContext() ok = false, want true")
}
if result != index {
t.Errorf("IndexFromContext() = %v, want %v", result, index)
}
}
func TestDecompressFromContext(t *testing.T) {
tests := []struct {
name string
ctx context.Context
expected bool
}{
{
name: "No decompress metadata",
ctx: context.Background(),
expected: false,
},
{
name: "Decompress true",
ctx: metadata.NewIncomingContext(
context.Background(),
metadata.Pairs(DecompressKey, "true"),
),
expected: true,
},
{
name: "Decompress false",
ctx: metadata.NewIncomingContext(
context.Background(),
metadata.Pairs(DecompressKey, "false"),
),
expected: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := DecompressFromContext(tt.ctx)
if result != tt.expected {
t.Errorf("DecompressFromContext() = %v, want %v", result, tt.expected)
}
})
}
}
func TestDecompressToContext(t *testing.T) {
ctx := context.Background()
decompress := true
newCtx := DecompressToContext(ctx, decompress)
md, ok := metadata.FromOutgoingContext(newCtx)
if !ok {
t.Errorf("metadata.FromOutgoingContext() ok = false, want true")
}
vals := md.Get(DecompressKey)
if len(vals) != 1 {
t.Errorf("len(md.Get(DecompressKey)) = %v, want 1", len(vals))
}
if vals[0] != "true" {
t.Errorf("md.Get(DecompressKey)[0] = %v, want 'true'", vals[0])
}
}
func TestAgentConfigJSON(t *testing.T) {
config := AgentConfig{
LogLevel: "info",
Host: "localhost",
Port: "8080",
CertFile: "cert.pem",
KeyFile: "key.pem",
ServerCAFile: "server_ca.pem",
ClientCAFile: "client_ca.pem",
AttestedTls: true,
}
data, err := json.Marshal(config)
if err != nil {
t.Fatalf("Failed to marshal AgentConfig: %v", err)
}
var unmarshaledConfig AgentConfig
err = json.Unmarshal(data, &unmarshaledConfig)
if err != nil {
t.Fatalf("Failed to unmarshal AgentConfig: %v", err)
}
if !reflect.DeepEqual(config, unmarshaledConfig) {
t.Errorf("Unmarshaled config does not match original. Got %+v, want %+v", unmarshaledConfig, config)
}
}
+82
View File
@@ -0,0 +1,82 @@
// Copyright (c) Ultraviolet
// SPDX-License-Identifier: Apache-2.0
package events
import (
"bytes"
"encoding/json"
"errors"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/ultravioletrs/cocos/pkg/manager"
"google.golang.org/protobuf/proto"
)
type mockConn struct {
writeErr error
buf bytes.Buffer
}
func (m *mockConn) Write(p []byte) (n int, err error) {
if m.writeErr != nil {
return 0, m.writeErr
}
return m.buf.Write(p)
}
func TestSendEventSuccess(t *testing.T) {
mockConnection := &mockConn{}
svc, err := New("test_service", "12345", mockConnection)
assert.NoError(t, err)
details := json.RawMessage(`{"key": "value"}`)
err = svc.SendEvent("test_event", "success", details)
assert.NoError(t, err)
var writtenMessage manager.ClientStreamMessage
err = proto.Unmarshal(mockConnection.buf.Bytes(), &writtenMessage)
assert.NoError(t, err)
assert.Equal(t, "test_event", writtenMessage.GetAgentEvent().EventType)
assert.Equal(t, "12345", writtenMessage.GetAgentEvent().ComputationId)
assert.Equal(t, "test_service", writtenMessage.GetAgentEvent().Originator)
assert.Equal(t, "success", writtenMessage.GetAgentEvent().Status)
now := time.Now()
eventTimestamp := writtenMessage.GetAgentEvent().GetTimestamp().AsTime()
assert.WithinDuration(t, now, eventTimestamp, 1*time.Second)
}
func TestSendEventFailure(t *testing.T) {
mockConnection := &mockConn{writeErr: errors.New("write error")}
svc, err := New("test_service", "12345", mockConnection)
assert.NoError(t, err)
details := json.RawMessage(`{"key": "value"}`)
err = svc.SendEvent("test_event", "failure", details)
assert.Error(t, err)
assert.Equal(t, "write error", err.Error())
assert.Len(t, svc.(*service).cachedMessages, 1)
}
func TestClose(t *testing.T) {
mockConnection := &mockConn{}
svc, err := New("test_service", "12345", mockConnection)
assert.NoError(t, err)
svc.Close()
time.Sleep(1 * time.Second)
details := json.RawMessage(`{"key": "value"}`)
err = svc.SendEvent("test_event", "success", details)
assert.NoError(t, err)
}