NOISSUE - Host data verification (#275)

* host data verification

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* update mocks

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* fix test

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* debug host data

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* debug

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* check device

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* imorove test

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* missing header

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* update embed option

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* minor fixes

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* add tests

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* fix lint

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* update deps

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* each case is unique

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* all files

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* fix coverage

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* resolve comments

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* improve coverage

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* add test case

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* add test cases

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* use consts

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* add coverage

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* make sure pid is exited

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

---------

Signed-off-by: Sammy Oina <sammyoina@gmail.com>
This commit is contained in:
Sammy Kerata Oina
2024-10-14 12:12:49 +03:00
committed by GitHub
parent bb903c0170
commit 184617da9e
25 changed files with 675 additions and 104 deletions
+19 -5
View File
@@ -34,14 +34,28 @@ jobs:
run: |
mkdir coverage
- name: Run tests
run: go test -v --race -covermode=atomic -coverprofile coverage/cover.out ./...
- name: Run Agent tests
run: go test -v --race -covermode=atomic -coverprofile coverage/agent.out ./agent/...
- name: Run cli tests
run: go test -v --race -covermode=atomic -coverprofile coverage/cli.out ./cli/...
- name: Run cmd tests
run: go test -v --race -covermode=atomic -coverprofile coverage/cmd.out ./cmd/...
- name: Run internal tests
run: go test -v --race -covermode=atomic -coverprofile coverage/internal.out ./internal/...
- name: Run pkg tests
run: go test -v --race -covermode=atomic -coverprofile coverage/pkg.out ./pkg/...
- name: Run manager tests
run: sudo go test -v --race -covermode=atomic -coverprofile coverage/manager.out ./manager/...
- name: Upload results to Codecov
uses: codecov/codecov-action@v4
with:
token: ${{ secrets.CODECOV_TOKEN }}
directory: ./coverage/
name: codecov-umbrella
files: ./coverage/*.out
codecov_yml_path: codecov.yml
verbose: true
+8 -8
View File
@@ -13,14 +13,14 @@ import (
var _ fmt.Stringer = (*Datasets)(nil)
type AgentConfig struct {
LogLevel string `json:"log_level"`
Host string `json:"host"`
Port string `json:"port"`
CertFile string `json:"cert_file"`
KeyFile string `json:"server_key"`
ServerCAFile string `json:"server_ca_file"`
ClientCAFile string `json:"client_ca_file"`
AttestedTls bool `json:"attested_tls"`
LogLevel string `json:"log_level,omitempty"`
Host string `json:"host,omitempty"`
Port string `json:"port,omitempty"`
CertFile string `json:"cert_file,omitempty"`
KeyFile string `json:"server_key,omitempty"`
ServerCAFile string `json:"server_ca_file,omitempty"`
ClientCAFile string `json:"client_ca_file,omitempty"`
AttestedTls bool `json:"attested_tls,omitempty"`
}
type Computation struct {
+8 -9
View File
@@ -8,9 +8,8 @@ package mocks
import (
context "context"
agent "github.com/ultravioletrs/cocos/agent/statemachine"
mock "github.com/stretchr/testify/mock"
statemachine "github.com/ultravioletrs/cocos/agent/statemachine"
)
// StateMachine is an autogenerated mock type for the StateMachine type
@@ -19,24 +18,24 @@ type StateMachine struct {
}
// AddTransition provides a mock function with given fields: t
func (_m *StateMachine) AddTransition(t agent.Transition) {
func (_m *StateMachine) AddTransition(t statemachine.Transition) {
_m.Called(t)
}
// GetState provides a mock function with given fields:
func (_m *StateMachine) GetState() agent.State {
func (_m *StateMachine) GetState() statemachine.State {
ret := _m.Called()
if len(ret) == 0 {
panic("no return value specified for GetState")
}
var r0 agent.State
if rf, ok := ret.Get(0).(func() agent.State); ok {
var r0 statemachine.State
if rf, ok := ret.Get(0).(func() statemachine.State); ok {
r0 = rf()
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(agent.State)
r0 = ret.Get(0).(statemachine.State)
}
}
@@ -44,12 +43,12 @@ func (_m *StateMachine) GetState() agent.State {
}
// SendEvent provides a mock function with given fields: event
func (_m *StateMachine) SendEvent(event agent.Event) {
func (_m *StateMachine) SendEvent(event statemachine.Event) {
_m.Called(event)
}
// SetAction provides a mock function with given fields: state, action
func (_m *StateMachine) SetAction(state agent.State, action agent.Action) {
func (_m *StateMachine) SetAction(state statemachine.State, action statemachine.Action) {
_m.Called(state, action)
}
+58 -6
View File
@@ -3,17 +3,20 @@
package main
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"log"
"log/slog"
"os"
"time"
mglog "github.com/absmach/magistrala/logger"
"github.com/absmach/magistrala/pkg/prometheus"
"github.com/cenkalti/backoff/v4"
"github.com/google/go-sev-guest/abi"
"github.com/google/go-sev-guest/client"
"github.com/mdlayher/vsock"
"github.com/ultravioletrs/cocos/agent"
@@ -28,6 +31,7 @@ import (
ackvsock "github.com/ultravioletrs/cocos/internal/vsock"
"github.com/ultravioletrs/cocos/manager"
"github.com/ultravioletrs/cocos/manager/qemu"
"golang.org/x/crypto/sha3"
"golang.org/x/sync/errgroup"
"google.golang.org/grpc"
"google.golang.org/grpc/reflection"
@@ -84,6 +88,14 @@ func main() {
return
}
if err := verifyManifest(cfg, qp); err != nil {
logger.Error(err.Error())
exitCode = 1
return
}
setDefaultValues(&cfg)
svc := newService(ctx, logger, eventSvc, cfg, qp)
grpcServerConfig := server.Config{
@@ -179,15 +191,23 @@ func readConfig() (agent.Computation, error) {
if err := json.Unmarshal(buffer, &ac); err != nil {
return agent.Computation{}, err
}
if ac.AgentConfig.LogLevel == "" {
ac.AgentConfig.LogLevel = "info"
}
if ac.AgentConfig.Port == "" {
ac.AgentConfig.Port = defSvcGRPCPort
}
return ac, nil
}
func setDefaultValues(cfg *agent.Computation) {
if cfg.AgentConfig.LogLevel == "" {
cfg.AgentConfig.LogLevel = "info"
}
if cfg.AgentConfig.Port == "" {
cfg.AgentConfig.Port = defSvcGRPCPort
}
}
func isTEE() bool {
_, err := os.Stat("/dev/sev-guest")
return !os.IsNotExist(err)
}
func dialVsock() (*vsock.Conn, error) {
var conn *vsock.Conn
var err error
@@ -207,3 +227,35 @@ func dialVsock() (*vsock.Conn, error) {
return conn, nil
}
func verifyManifest(cfg agent.Computation, qp client.QuoteProvider) error {
if !isTEE() {
return nil
}
ar, err := qp.GetRawQuote(sha3.Sum512([]byte(cfg.ID)))
if err != nil {
return err
}
arProto, err := abi.ReportCertsToProto(ar[:abi.ReportSize])
if err != nil {
return err
}
cfgBytes, err := json.Marshal(cfg)
if err != nil {
return err
}
mcHash := sha3.Sum256(cfgBytes)
if arProto.Report.HostData == nil {
return fmt.Errorf("manifest verification failed: HostData is nil")
}
if !bytes.Equal(arProto.Report.HostData, mcHash[:]) {
return fmt.Errorf("manifest verification failed")
}
return nil
}
+94
View File
@@ -0,0 +1,94 @@
// Copyright (c) Ultraviolet
// SPDX-License-Identifier: Apache-2.0
package main
import (
"context"
"log/slog"
"os"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"github.com/ultravioletrs/cocos/agent"
"github.com/ultravioletrs/cocos/agent/events/mocks"
qpmocks "github.com/ultravioletrs/cocos/agent/quoteprovider/mocks"
)
func TestSetDefaultValues(t *testing.T) {
tests := []struct {
name string
input agent.Computation
expected agent.Computation
}{
{
name: "Empty config",
input: agent.Computation{
AgentConfig: agent.AgentConfig{},
},
expected: agent.Computation{
AgentConfig: agent.AgentConfig{
LogLevel: "info",
Port: "7002",
},
},
},
{
name: "Partial config",
input: agent.Computation{
AgentConfig: agent.AgentConfig{
LogLevel: "debug",
},
},
expected: agent.Computation{
AgentConfig: agent.AgentConfig{
LogLevel: "debug",
Port: "7002",
},
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
setDefaultValues(&tt.input)
assert.Equal(t, tt.expected, tt.input)
})
}
}
func TestNewService(t *testing.T) {
ctx := context.Background()
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
eventSvc := new(mocks.Service)
eventSvc.On("SendEvent", mock.Anything, mock.Anything, mock.Anything).Return(nil)
cmp := agent.Computation{
ID: "test-computation",
AgentConfig: agent.AgentConfig{
LogLevel: "info",
Port: "7002",
},
}
qp := new(qpmocks.QuoteProvider)
svc := newService(ctx, logger, eventSvc, cmp, qp)
assert.NotNil(t, svc)
}
func TestVerifyManifest(t *testing.T) {
cfg := agent.Computation{
ID: "test-computation",
AgentConfig: agent.AgentConfig{
LogLevel: "info",
Port: "7002",
},
}
mockQP := new(qpmocks.QuoteProvider)
mockQP.On("GetRawQuote", mock.Anything).Return([]byte{}, nil)
err := verifyManifest(cfg, mockQP)
assert.NoError(t, err)
}
+6
View File
@@ -0,0 +1,6 @@
# Copyright (c) Ultraviolet
# SPDX-License-Identifier: Apache-2.0
coverage:
ignore:
- "test/*"
-1
View File
@@ -6,7 +6,6 @@ require (
github.com/absmach/magistrala v0.14.1-0.20240709113739-04c359462746
github.com/caarlos0/env/v11 v11.2.2
github.com/cenkalti/backoff/v4 v4.3.0
github.com/digitalocean/go-libvirt v0.0.0-20240709142323-d8406205c752
github.com/fatih/color v1.17.0
github.com/go-kit/kit v0.13.0
github.com/gofrs/uuid v4.4.0+incompatible
-2
View File
@@ -19,8 +19,6 @@ github.com/containerd/log v0.1.0/go.mod h1:VRRf09a7mHDIRezVKTRCrOq78v577GXq3bSa3
github.com/cpuguy83/go-md2man/v2 v2.0.4/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o=
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM=
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/digitalocean/go-libvirt v0.0.0-20240709142323-d8406205c752 h1:NI7XEcHzWVvBfVjSVK6Qk4wmrUfoyQxCNpBjrHelZFk=
github.com/digitalocean/go-libvirt v0.0.0-20240709142323-d8406205c752/go.mod h1:/Ok8PA2qi/ve0Py38+oL+VxoYmlowigYRyLEODRYdgc=
github.com/distribution/reference v0.6.0 h1:0IXCQ5g4/QMHHkarYzh5l+u8T3t73zM5QvfrDyIgxBk=
github.com/distribution/reference v0.6.0/go.mod h1:BbU0aIcezP1/5jX/8MP0YiH4SdvB5Y4f/wlDRiLyi3E=
github.com/docker/docker v27.3.1+incompatible h1:KttF0XoteNTicmUtBO0L2tP+J7FGRFTjaEF4k6WdhfI=
+3 -3
View File
@@ -80,7 +80,7 @@ func (client ManagerClient) processIncomingMessage(ctx context.Context, req *pkg
case *pkgmanager.ServerStreamMessage_StopComputation:
go client.handleStopComputation(ctx, mes)
case *pkgmanager.ServerStreamMessage_BackendInfoReq:
go client.handleBackendInfoReq(mes)
go client.handleBackendInfoReq(ctx, mes)
default:
return errors.New("unknown message type")
}
@@ -133,8 +133,8 @@ func (client ManagerClient) handleStopComputation(ctx context.Context, mes *pkgm
client.sendMessage(&pkgmanager.ClientStreamMessage{Message: msg})
}
func (client ManagerClient) handleBackendInfoReq(mes *pkgmanager.ServerStreamMessage_BackendInfoReq) {
res, err := client.svc.FetchBackendInfo()
func (client ManagerClient) handleBackendInfoReq(ctx context.Context, mes *pkgmanager.ServerStreamMessage_BackendInfoReq) {
res, err := client.svc.FetchBackendInfo(ctx, mes.BackendInfoReq.Id)
if err != nil {
client.logger.Warn(err.Error())
return
+46 -21
View File
@@ -150,32 +150,57 @@ func TestManagerClient_handleStopComputation(t *testing.T) {
}
func TestManagerClient_handleBackendInfoReq(t *testing.T) {
mockStream := new(mockStream)
mockSvc := new(mocks.Service)
messageQueue := make(chan *pkgmanager.ClientStreamMessage, 10)
logger := mglog.NewMock()
t.Run("success", func(t *testing.T) {
mockStream := new(mockStream)
mockSvc := new(mocks.Service)
messageQueue := make(chan *pkgmanager.ClientStreamMessage, 10)
logger := mglog.NewMock()
client := NewClient(mockStream, mockSvc, messageQueue, logger)
client := NewClient(mockStream, mockSvc, messageQueue, logger)
infoReq := &pkgmanager.ServerStreamMessage_BackendInfoReq{
BackendInfoReq: &pkgmanager.BackendInfoReq{
Id: "test-info-id",
},
}
infoReq := &pkgmanager.ServerStreamMessage_BackendInfoReq{
BackendInfoReq: &pkgmanager.BackendInfoReq{
Id: "test-info-id",
},
}
mockSvc.On("FetchBackendInfo").Return([]byte("test-backend-info"), nil)
mockSvc.On("FetchBackendInfo", context.Background(), infoReq.BackendInfoReq.Id).Return([]byte("test-backend-info"), nil)
client.handleBackendInfoReq(infoReq)
client.handleBackendInfoReq(context.Background(), infoReq)
// Wait for the goroutine to finish
time.Sleep(50 * time.Millisecond)
// Wait for the goroutine to finish
time.Sleep(50 * time.Millisecond)
mockSvc.AssertExpectations(t)
assert.Len(t, messageQueue, 1)
mockSvc.AssertExpectations(t)
assert.Len(t, messageQueue, 1)
msg := <-messageQueue
infoRes, ok := msg.Message.(*pkgmanager.ClientStreamMessage_BackendInfo)
assert.True(t, ok)
assert.Equal(t, "test-info-id", infoRes.BackendInfo.Id)
assert.Equal(t, []byte("test-backend-info"), infoRes.BackendInfo.Info)
msg := <-messageQueue
infoRes, ok := msg.Message.(*pkgmanager.ClientStreamMessage_BackendInfo)
assert.True(t, ok)
assert.Equal(t, "test-info-id", infoRes.BackendInfo.Id)
assert.Equal(t, []byte("test-backend-info"), infoRes.BackendInfo.Info)
})
t.Run("error", func(t *testing.T) {
mockStream := new(mockStream)
mockSvc := new(mocks.Service)
messageQueue := make(chan *pkgmanager.ClientStreamMessage, 10)
logger := mglog.NewMock()
client := NewClient(mockStream, mockSvc, messageQueue, logger)
infoReq := &pkgmanager.ServerStreamMessage_BackendInfoReq{
BackendInfoReq: &pkgmanager.BackendInfoReq{
Id: "test-info-id",
},
}
mockSvc.On("FetchBackendInfo", context.Background(), infoReq.BackendInfoReq.Id).Return(nil, assert.AnError)
client.handleBackendInfoReq(context.Background(), infoReq)
time.Sleep(50 * time.Millisecond)
mockSvc.AssertExpectations(t)
assert.Len(t, messageQueue, 0)
})
}
+8 -3
View File
@@ -58,11 +58,16 @@ func (lm *loggingMiddleware) RetrieveAgentEventsLogs() {
lm.svc.RetrieveAgentEventsLogs()
}
func (lm *loggingMiddleware) FetchBackendInfo() ([]byte, error) {
func (lm *loggingMiddleware) FetchBackendInfo(ctx context.Context, cmpId string) (body []byte, err error) {
defer func(begin time.Time) {
message := fmt.Sprintf("Method FetchBackendInfo took %s to complete", time.Since(begin))
message := fmt.Sprintf("Method FetchBackendInfo for computation %s took %s to complete", cmpId, time.Since(begin))
if err != nil {
lm.logger.Warn(fmt.Sprintf("%s with error: %s.", message, err))
return
}
lm.logger.Info(message)
}(time.Now())
return lm.svc.FetchBackendInfo()
return lm.svc.FetchBackendInfo(ctx, cmpId)
}
+2 -2
View File
@@ -55,11 +55,11 @@ func (ms *metricsMiddleware) RetrieveAgentEventsLogs() {
ms.svc.RetrieveAgentEventsLogs()
}
func (ms *metricsMiddleware) FetchBackendInfo() ([]byte, error) {
func (ms *metricsMiddleware) FetchBackendInfo(ctx context.Context, cmpId string) ([]byte, error) {
defer func(begin time.Time) {
ms.counter.With("method", "FetchBackendInfo").Add(1)
ms.latency.With("method", "FetchBackendInfo").Observe(time.Since(begin).Seconds())
}(time.Now())
return ms.svc.FetchBackendInfo()
return ms.svc.FetchBackendInfo(ctx, cmpId)
}
+21 -7
View File
@@ -7,6 +7,7 @@
package manager
import (
"context"
"encoding/base64"
"encoding/json"
"fmt"
@@ -22,9 +23,21 @@ import (
const defGuestFeatures = 0x1
func (ms *managerService) FetchBackendInfo() ([]byte, error) {
func (ms *managerService) FetchBackendInfo(_ context.Context, computationId string) ([]byte, error) {
cmd := exec.Command("sudo", fmt.Sprintf("%s/backend_info", ms.backendMeasurementBinaryPath), "--policy", "1966081")
ms.mu.Lock()
vm, exists := ms.vms[computationId]
ms.mu.Unlock()
if !exists {
return nil, fmt.Errorf("computationId %s not found", computationId)
}
config, ok := vm.GetConfig().(qemu.Config)
if !ok {
return nil, fmt.Errorf("failed to cast config to qemu.Config")
}
_, err := cmd.Output()
if err != nil {
return nil, err
@@ -42,13 +55,14 @@ func (ms *managerService) FetchBackendInfo() ([]byte, error) {
}
var measurement []byte
if ms.qemuCfg.EnableSEV {
measurement, err = guest.CalcLaunchDigest(guest.SEV, ms.qemuCfg.SMPCount, uint64(cpuid.CpuSigs[ms.qemuCfg.CPU]), ms.qemuCfg.OVMFCodeConfig.File, ms.qemuCfg.KernelFile, ms.qemuCfg.RootFsFile, qemu.KernelCommandLine, defGuestFeatures, "", vmmtypes.QEMU, false, "", 0)
switch {
case config.EnableSEV:
measurement, err = guest.CalcLaunchDigest(guest.SEV, config.SMPCount, uint64(cpuid.CpuSigs[ms.qemuCfg.CPU]), config.OVMFCodeConfig.File, config.KernelFile, config.RootFsFile, qemu.KernelCommandLine, defGuestFeatures, "", vmmtypes.QEMU, false, "", 0)
if err != nil {
return nil, err
}
} else if ms.qemuCfg.EnableSEVSNP {
measurement, err = guest.CalcLaunchDigest(guest.SEV_SNP, ms.qemuCfg.SMPCount, uint64(cpuid.CpuSigs[ms.qemuCfg.CPU]), ms.qemuCfg.OVMFCodeConfig.File, ms.qemuCfg.KernelFile, ms.qemuCfg.RootFsFile, qemu.KernelCommandLine, defGuestFeatures, "", vmmtypes.QEMU, false, "", 0)
case config.EnableSEVSNP:
measurement, err = guest.CalcLaunchDigest(guest.SEV_SNP, config.SMPCount, uint64(cpuid.CpuSigs[config.CPU]), config.OVMFCodeConfig.File, config.KernelFile, config.RootFsFile, qemu.KernelCommandLine, defGuestFeatures, "", vmmtypes.QEMU, false, "", 0)
if err != nil {
return nil, err
}
@@ -57,8 +71,8 @@ func (ms *managerService) FetchBackendInfo() ([]byte, error) {
backendInfo.SNPPolicy.Measurement = measurement
}
if ms.qemuCfg.HostData != "" {
hostData, err := base64.StdEncoding.DecodeString(ms.qemuCfg.HostData)
if config.HostData != "" {
hostData, err := base64.StdEncoding.DecodeString(config.HostData)
if err != nil {
return nil, err
}
+6 -2
View File
@@ -6,8 +6,12 @@
package manager
import backendinfo "github.com/ultravioletrs/cocos/scripts/backend_info"
import (
"context"
func (ms *managerService) FetchBackendInfo() ([]byte, error) {
backendinfo "github.com/ultravioletrs/cocos/scripts/backend_info"
)
func (ms *managerService) FetchBackendInfo(_ context.Context, _ string) ([]byte, error) {
return backendinfo.BackendInfo, nil
}
+157
View File
@@ -0,0 +1,157 @@
// Copyright (c) Ultraviolet
// SPDX-License-Identifier: Apache-2.0
package manager
import (
"context"
"encoding/json"
"os"
"path/filepath"
"testing"
"github.com/stretchr/testify/assert"
"github.com/ultravioletrs/cocos/manager/qemu"
"github.com/ultravioletrs/cocos/manager/vm"
"github.com/ultravioletrs/cocos/manager/vm/mocks"
)
func createDummyBackendInfoBinary(t *testing.T, behavior string) string {
var content []byte
switch behavior {
case "success":
content = []byte(`#!/bin/sh
echo '{"snp_policy": {"measurement": null, "host_data": null}}' > backend_info.json
`)
case "fail":
content = []byte(`#!/bin/sh
echo "Error: Failed to execute backend_info" >&2
exit 1
`)
case "no_json":
content = []byte(`#!/bin/sh
echo 'No JSON file created'
`)
default:
t.Fatalf("Unknown behavior: %s", behavior)
}
tempDir := t.TempDir()
binaryPath := filepath.Join(tempDir, "backend_info")
err := os.WriteFile(binaryPath, content, 0o755)
assert.NoError(t, err)
return tempDir
}
func TestFetchBackendInfo(t *testing.T) {
testCases := []struct {
name string
computationId string
vmConfig interface{}
binaryBehavior string
expectedError string
expectedResult map[string]interface{}
}{
{
name: "Valid SEV configuration",
computationId: "sev-computation",
binaryBehavior: "success",
vmConfig: qemu.Config{
EnableSEV: true,
SMPCount: 2,
CPU: "EPYC",
OVMFCodeConfig: qemu.OVMFCodeConfig{
File: "/path/to/OVMF_CODE.fd",
},
},
expectedError: "open /path/to/OVMF_CODE.fd: no such file or directory",
},
{
name: "Valid SEV-SNP configuration",
computationId: "sev-snp-computation",
binaryBehavior: "success",
vmConfig: qemu.Config{
EnableSEVSNP: true,
SMPCount: 4,
CPU: "EPYC-v2",
OVMFCodeConfig: qemu.OVMFCodeConfig{
File: "/path/to/OVMF_CODE_SNP.fd",
},
},
expectedError: "open /path/to/OVMF_CODE_SNP.fd: no such file or director",
},
{
name: "Invalid computation ID",
computationId: "non-existent",
binaryBehavior: "success",
vmConfig: qemu.Config{},
expectedError: "computationId non-existent not found",
},
{
name: "Invalid config type",
computationId: "invalid-config",
binaryBehavior: "success",
vmConfig: struct{}{},
expectedError: "failed to cast config to qemu.Config",
},
{
name: "Binary execution failure",
computationId: "binary-fail",
binaryBehavior: "fail",
vmConfig: qemu.Config{
EnableSEV: true,
},
expectedError: "exit status 1",
},
{
name: "JSON file not created",
computationId: "no-json",
binaryBehavior: "no_json",
vmConfig: qemu.Config{
EnableSEV: true,
},
expectedError: "no such file or directory",
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
tempDir := createDummyBackendInfoBinary(t, tc.binaryBehavior)
defer os.RemoveAll(tempDir)
ms := &managerService{
vms: make(map[string]vm.VM),
backendMeasurementBinaryPath: tempDir,
qemuCfg: qemu.Config{
CPU: "EPYC",
},
}
mockVM := new(mocks.VM)
mockVM.On("GetConfig").Return(tc.vmConfig)
if tc.computationId != "non-existent" {
ms.vms[tc.computationId] = mockVM
}
result, err := ms.FetchBackendInfo(context.Background(), tc.computationId)
if tc.expectedError != "" {
assert.Error(t, err)
assert.Contains(t, err.Error(), tc.expectedError)
} else {
assert.NoError(t, err)
assert.NotNil(t, result)
var backendInfo map[string]interface{}
err = json.Unmarshal(result, &backendInfo)
assert.NoError(t, err)
assert.Equal(t, tc.expectedResult, backendInfo)
}
if tc.binaryBehavior == "success" {
os.Remove("backend_info.json")
}
})
}
}
+9 -9
View File
@@ -18,9 +18,9 @@ type Service struct {
mock.Mock
}
// FetchBackendInfo provides a mock function with given fields:
func (_m *Service) FetchBackendInfo() ([]byte, error) {
ret := _m.Called()
// FetchBackendInfo provides a mock function with given fields: ctx, computationID
func (_m *Service) FetchBackendInfo(ctx context.Context, computationID string) ([]byte, error) {
ret := _m.Called(ctx, computationID)
if len(ret) == 0 {
panic("no return value specified for FetchBackendInfo")
@@ -28,19 +28,19 @@ func (_m *Service) FetchBackendInfo() ([]byte, error) {
var r0 []byte
var r1 error
if rf, ok := ret.Get(0).(func() ([]byte, error)); ok {
return rf()
if rf, ok := ret.Get(0).(func(context.Context, string) ([]byte, error)); ok {
return rf(ctx, computationID)
}
if rf, ok := ret.Get(0).(func() []byte); ok {
r0 = rf()
if rf, ok := ret.Get(0).(func(context.Context, string) []byte); ok {
r0 = rf(ctx, computationID)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).([]byte)
}
}
if rf, ok := ret.Get(1).(func() error); ok {
r1 = rf()
if rf, ok := ret.Get(1).(func(context.Context, string) error); ok {
r1 = rf(ctx, computationID)
} else {
r1 = ret.Error(1)
}
+4
View File
@@ -197,3 +197,7 @@ func processExists(pid int) bool {
func (v *qemuVM) GetCID() int {
return v.config.GuestCID
}
func (v *qemuVM) GetConfig() interface{} {
return v.config
}
+82 -19
View File
@@ -10,15 +10,17 @@ import (
"github.com/stretchr/testify/assert"
"github.com/ultravioletrs/cocos/manager/vm"
"github.com/ultravioletrs/cocos/manager/vm/mocks"
"github.com/ultravioletrs/cocos/pkg/manager"
)
const testComputationID = "test-computation"
func TestNewVM(t *testing.T) {
config := Config{}
logsChan := make(chan *manager.ClientStreamMessage)
computationId := "test-computation"
vm := NewVM(config, logsChan, computationId)
vm := NewVM(config, logsChan, testComputationID)
assert.NotNil(t, vm)
assert.IsType(t, &qemuVM{}, vm)
@@ -34,35 +36,84 @@ func TestStart(t *testing.T) {
OVMFVarsConfig: OVMFVarsConfig{
File: tmpFile.Name(),
},
QemuBinPath: "echo", // Use 'echo' as a dummy QEMU binary
QemuBinPath: "echo",
}
logsChan := make(chan *manager.ClientStreamMessage)
computationId := "test-computation"
vm := NewVM(config, logsChan, computationId).(*qemuVM)
vm := NewVM(config, logsChan, testComputationID).(*qemuVM)
err = vm.Start()
assert.NoError(t, err)
assert.NotNil(t, vm.cmd)
_ = vm.Stop()
}
func TestStartSudo(t *testing.T) {
// Create a temporary file for testing
tmpFile, err := os.CreateTemp("", "test-ovmf-vars")
assert.NoError(t, err)
defer os.Remove(tmpFile.Name())
config := Config{
OVMFVarsConfig: OVMFVarsConfig{
File: tmpFile.Name(),
},
QemuBinPath: "echo",
UseSudo: true,
}
logsChan := make(chan *manager.ClientStreamMessage)
vm := NewVM(config, logsChan, testComputationID).(*qemuVM)
err = vm.Start()
assert.NoError(t, err)
assert.NotNil(t, vm.cmd)
// Clean up
_ = vm.Stop()
}
func TestStop(t *testing.T) {
cmd := exec.Command("echo", "test")
err := cmd.Start()
assert.NoError(t, err)
t.Run("success", func(t *testing.T) {
cmd := exec.Command("echo", "test")
err := cmd.Start()
assert.NoError(t, err)
sm := new(mocks.StateMachine)
sm.On("Transition", manager.StopComputationRun).Return(nil)
vm := &qemuVM{
cmd: &exec.Cmd{
Process: cmd.Process,
},
StateMachine: vm.NewStateMachine(),
}
vm := &qemuVM{
cmd: &exec.Cmd{
Process: cmd.Process,
},
StateMachine: sm,
}
err = vm.Stop()
assert.NoError(t, err)
err = vm.Stop()
assert.NoError(t, err)
})
t.Run("transition error", func(t *testing.T) {
cmd := exec.Command("echo", "test")
err := cmd.Start()
assert.NoError(t, err)
sm := new(mocks.StateMachine)
sm.On("Transition", manager.StopComputationRun).Return(assert.AnError)
sm.On("State").Return(manager.Stopped.String())
vm := &qemuVM{
cmd: &exec.Cmd{
Process: cmd.Process,
},
StateMachine: sm,
logsChan: make(chan *manager.ClientStreamMessage),
}
go func() {
<-vm.logsChan
}()
err = vm.Stop()
assert.NoError(t, err)
})
}
func TestSetProcess(t *testing.T) {
@@ -104,11 +155,23 @@ func TestGetCID(t *testing.T) {
assert.Equal(t, expectedCID, cid)
}
func TestGetConfig(t *testing.T) {
expectedConfig := Config{
QemuBinPath: "echo",
}
vm := &qemuVM{
config: expectedConfig,
}
config := vm.GetConfig()
assert.Equal(t, expectedConfig, config)
}
func TestCheckVMProcessPeriodically(t *testing.T) {
logsChan := make(chan *manager.ClientStreamMessage, 1)
vm := &qemuVM{
logsChan: logsChan,
computationId: "test-computation",
computationId: testComputationID,
cmd: &exec.Cmd{
Process: &os.Process{Pid: -1}, // Use an invalid PID to simulate a stopped process
},
@@ -120,7 +183,7 @@ func TestCheckVMProcessPeriodically(t *testing.T) {
select {
case msg := <-logsChan:
assert.NotNil(t, msg.GetAgentEvent())
assert.Equal(t, "test-computation", msg.GetAgentEvent().ComputationId)
assert.Equal(t, testComputationID, msg.GetAgentEvent().ComputationId)
assert.Equal(t, manager.VmProvision.String(), msg.GetAgentEvent().EventType)
assert.Equal(t, manager.Stopped.String(), msg.GetAgentEvent().Status)
case <-time.After(2 * interval):
+6 -1
View File
@@ -63,7 +63,7 @@ type Service interface {
// RetrieveAgentEventsLogs Retrieve and forward agent logs and events via vsock.
RetrieveAgentEventsLogs()
// FetchBackendInfo measures and fetches the backend information.
FetchBackendInfo() ([]byte, error)
FetchBackendInfo(ctx context.Context, computationID string) ([]byte, error)
}
type managerService struct {
@@ -128,6 +128,11 @@ func (ms *managerService) Run(ctx context.Context, c *manager.ComputationRunReq)
LogLevel: c.AgentConfig.LogLevel,
},
}
if len(c.Algorithm.Hash) != hashLength {
ms.publishEvent(manager.VmProvision.String(), c.Id, agent.Failed.String(), json.RawMessage{})
return "", errInvalidHashLength
}
ac.Algorithm = agent.Algorithm{Hash: [hashLength]byte(c.Algorithm.Hash), UserKey: c.Algorithm.UserKey}
for _, data := range c.Datasets {
+50 -3
View File
@@ -5,12 +5,15 @@ package manager
import (
"context"
"encoding/json"
"fmt"
"log/slog"
"net"
"os"
"os/exec"
"testing"
mglog "github.com/absmach/magistrala/logger"
"github.com/absmach/magistrala/pkg/errors"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
@@ -74,6 +77,37 @@ func TestRun(t *testing.T) {
vmStartError: assert.AnError,
expectedError: assert.AnError,
},
{
name: "Invalid algorithm hash",
req: &manager.ComputationRunReq{
Id: "test-computation",
Name: "Test Computation",
Algorithm: &manager.Algorithm{
Hash: make([]byte, hashLength-1),
},
AgentConfig: &manager.AgentConfig{},
},
vmStartError: nil,
expectedError: errInvalidHashLength,
},
{
name: "Invalid dataset hash",
req: &manager.ComputationRunReq{
Id: "test-computation",
Name: "Test Computation",
Algorithm: &manager.Algorithm{
Hash: make([]byte, hashLength),
},
AgentConfig: &manager.AgentConfig{},
Datasets: []*manager.Dataset{
{
Hash: make([]byte, hashLength-1),
},
},
},
vmStartError: nil,
expectedError: errInvalidHashLength,
},
}
for _, tt := range tests {
@@ -211,7 +245,14 @@ func TestGetFreePort(t *testing.T) {
port, err := getFreePort(6000, 6100)
assert.NoError(t, err)
assert.Greater(t, port, 0)
assert.GreaterOrEqual(t, port, 6000)
_, err = net.Listen("tcp", net.JoinHostPort("localhost", fmt.Sprint(port)))
assert.NoError(t, err)
port, err = getFreePort(6000, 6100)
assert.NoError(t, err)
assert.Greater(t, port, 6000)
}
func TestPublishEvent(t *testing.T) {
@@ -338,12 +379,18 @@ func TestRestoreVMs(t *testing.T) {
err := cmd.Start()
assert.NoError(t, err)
cmd2 := exec.Command("echo", "test")
err = cmd2.Run()
assert.NoError(t, err)
mockPersistence.On("LoadVMs").Return([]qemu.VMState{
{ID: "vm1", PID: cmd.Process.Pid},
{ID: "vm2", PID: 1000},
{ID: "vm2", PID: cmd2.Process.Pid},
{ID: "vm3", PID: cmd2.Process.Pid},
}, nil)
mockPersistence.On("DeleteVM", mock.Anything).Return(nil)
mockPersistence.On("DeleteVM", "vm2").Return(nil)
mockPersistence.On("DeleteVM", "vm3").Return(errors.New("failed to delete"))
err = ms.restoreVMs()
assert.NoError(t, err)
+3 -3
View File
@@ -40,9 +40,9 @@ func (tm *tracingMiddleware) RetrieveAgentEventsLogs() {
tm.svc.RetrieveAgentEventsLogs()
}
func (tm *tracingMiddleware) FetchBackendInfo() ([]byte, error) {
_, span := tm.tracer.Start(context.Background(), "fetch_backend_info")
func (tm *tracingMiddleware) FetchBackendInfo(ctx context.Context, computationId string) ([]byte, error) {
_, span := tm.tracer.Start(ctx, "fetch_backend_info")
defer span.End()
return tm.svc.FetchBackendInfo()
return tm.svc.FetchBackendInfo(ctx, computationId)
}
+63
View File
@@ -0,0 +1,63 @@
// Code generated by mockery v2.43.2. DO NOT EDIT.
package mocks
import (
mock "github.com/stretchr/testify/mock"
manager "github.com/ultravioletrs/cocos/pkg/manager"
)
// StateMachine is an autogenerated mock type for the StateMachine type
type StateMachine struct {
mock.Mock
}
// State provides a mock function with given fields:
func (_m *StateMachine) State() string {
ret := _m.Called()
if len(ret) == 0 {
panic("no return value specified for State")
}
var r0 string
if rf, ok := ret.Get(0).(func() string); ok {
r0 = rf()
} else {
r0 = ret.Get(0).(string)
}
return r0
}
// Transition provides a mock function with given fields: newState
func (_m *StateMachine) Transition(newState manager.ManagerState) error {
ret := _m.Called(newState)
if len(ret) == 0 {
panic("no return value specified for Transition")
}
var r0 error
if rf, ok := ret.Get(0).(func(manager.ManagerState) error); ok {
r0 = rf(newState)
} else {
r0 = ret.Error(0)
}
return r0
}
// NewStateMachine creates a new instance of StateMachine. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations.
// The first argument is typically a *testing.T value.
func NewStateMachine(t interface {
mock.TestingT
Cleanup(func())
}) *StateMachine {
mock := &StateMachine{}
mock.Mock.Test(t)
t.Cleanup(func() { mock.AssertExpectations(t) })
return mock
}
+20
View File
@@ -35,6 +35,26 @@ func (_m *VM) GetCID() int {
return r0
}
// GetConfig provides a mock function with given fields:
func (_m *VM) GetConfig() interface{} {
ret := _m.Called()
if len(ret) == 0 {
panic("no return value specified for GetConfig")
}
var r0 interface{}
if rf, ok := ret.Get(0).(func() interface{}); ok {
r0 = rf()
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(interface{})
}
}
return r0
}
// GetProcess provides a mock function with given fields:
func (_m *VM) GetProcess() int {
ret := _m.Called()
+1
View File
@@ -14,6 +14,7 @@ type sm struct {
state manager.ManagerState
}
//go:generate mockery --name StateMachine --output=./mocks --filename state_machine.go --quiet
type StateMachine interface {
Transition(newState manager.ManagerState) error
State() string
+1
View File
@@ -19,6 +19,7 @@ type VM interface {
GetCID() int
Transition(newState manager.ManagerState) error
State() string
GetConfig() interface{}
}
//go:generate mockery --name Provider --output=./mocks --filename provider.go --quiet --note "Copyright (c) Ultraviolet \n // SPDX-License-Identifier: Apache-2.0"