mirror of
https://github.com/ultravioletrs/cocos.git
synced 2026-06-23 04:10:25 +00:00
NOISSUE - Host data verification (#275)
* host data verification Signed-off-by: Sammy Oina <sammyoina@gmail.com> * update mocks Signed-off-by: Sammy Oina <sammyoina@gmail.com> * fix test Signed-off-by: Sammy Oina <sammyoina@gmail.com> * debug host data Signed-off-by: Sammy Oina <sammyoina@gmail.com> * debug Signed-off-by: Sammy Oina <sammyoina@gmail.com> * check device Signed-off-by: Sammy Oina <sammyoina@gmail.com> * imorove test Signed-off-by: Sammy Oina <sammyoina@gmail.com> * missing header Signed-off-by: Sammy Oina <sammyoina@gmail.com> * update embed option Signed-off-by: Sammy Oina <sammyoina@gmail.com> * minor fixes Signed-off-by: Sammy Oina <sammyoina@gmail.com> * add tests Signed-off-by: Sammy Oina <sammyoina@gmail.com> * fix lint Signed-off-by: Sammy Oina <sammyoina@gmail.com> * update deps Signed-off-by: Sammy Oina <sammyoina@gmail.com> * each case is unique Signed-off-by: Sammy Oina <sammyoina@gmail.com> * all files Signed-off-by: Sammy Oina <sammyoina@gmail.com> * fix coverage Signed-off-by: Sammy Oina <sammyoina@gmail.com> * resolve comments Signed-off-by: Sammy Oina <sammyoina@gmail.com> * improve coverage Signed-off-by: Sammy Oina <sammyoina@gmail.com> * add test case Signed-off-by: Sammy Oina <sammyoina@gmail.com> * add test cases Signed-off-by: Sammy Oina <sammyoina@gmail.com> * use consts Signed-off-by: Sammy Oina <sammyoina@gmail.com> * add coverage Signed-off-by: Sammy Oina <sammyoina@gmail.com> * make sure pid is exited Signed-off-by: Sammy Oina <sammyoina@gmail.com> --------- Signed-off-by: Sammy Oina <sammyoina@gmail.com>
This commit is contained in:
committed by
GitHub
parent
bb903c0170
commit
184617da9e
@@ -34,14 +34,28 @@ jobs:
|
||||
run: |
|
||||
mkdir coverage
|
||||
|
||||
- name: Run tests
|
||||
run: go test -v --race -covermode=atomic -coverprofile coverage/cover.out ./...
|
||||
- name: Run Agent tests
|
||||
run: go test -v --race -covermode=atomic -coverprofile coverage/agent.out ./agent/...
|
||||
|
||||
- name: Run cli tests
|
||||
run: go test -v --race -covermode=atomic -coverprofile coverage/cli.out ./cli/...
|
||||
|
||||
- name: Run cmd tests
|
||||
run: go test -v --race -covermode=atomic -coverprofile coverage/cmd.out ./cmd/...
|
||||
|
||||
- name: Run internal tests
|
||||
run: go test -v --race -covermode=atomic -coverprofile coverage/internal.out ./internal/...
|
||||
|
||||
- name: Run pkg tests
|
||||
run: go test -v --race -covermode=atomic -coverprofile coverage/pkg.out ./pkg/...
|
||||
|
||||
- name: Run manager tests
|
||||
run: sudo go test -v --race -covermode=atomic -coverprofile coverage/manager.out ./manager/...
|
||||
|
||||
- name: Upload results to Codecov
|
||||
uses: codecov/codecov-action@v4
|
||||
with:
|
||||
token: ${{ secrets.CODECOV_TOKEN }}
|
||||
directory: ./coverage/
|
||||
name: codecov-umbrella
|
||||
files: ./coverage/*.out
|
||||
codecov_yml_path: codecov.yml
|
||||
verbose: true
|
||||
|
||||
|
||||
@@ -13,14 +13,14 @@ import (
|
||||
var _ fmt.Stringer = (*Datasets)(nil)
|
||||
|
||||
type AgentConfig struct {
|
||||
LogLevel string `json:"log_level"`
|
||||
Host string `json:"host"`
|
||||
Port string `json:"port"`
|
||||
CertFile string `json:"cert_file"`
|
||||
KeyFile string `json:"server_key"`
|
||||
ServerCAFile string `json:"server_ca_file"`
|
||||
ClientCAFile string `json:"client_ca_file"`
|
||||
AttestedTls bool `json:"attested_tls"`
|
||||
LogLevel string `json:"log_level,omitempty"`
|
||||
Host string `json:"host,omitempty"`
|
||||
Port string `json:"port,omitempty"`
|
||||
CertFile string `json:"cert_file,omitempty"`
|
||||
KeyFile string `json:"server_key,omitempty"`
|
||||
ServerCAFile string `json:"server_ca_file,omitempty"`
|
||||
ClientCAFile string `json:"client_ca_file,omitempty"`
|
||||
AttestedTls bool `json:"attested_tls,omitempty"`
|
||||
}
|
||||
|
||||
type Computation struct {
|
||||
|
||||
@@ -8,9 +8,8 @@ package mocks
|
||||
import (
|
||||
context "context"
|
||||
|
||||
agent "github.com/ultravioletrs/cocos/agent/statemachine"
|
||||
|
||||
mock "github.com/stretchr/testify/mock"
|
||||
statemachine "github.com/ultravioletrs/cocos/agent/statemachine"
|
||||
)
|
||||
|
||||
// StateMachine is an autogenerated mock type for the StateMachine type
|
||||
@@ -19,24 +18,24 @@ type StateMachine struct {
|
||||
}
|
||||
|
||||
// AddTransition provides a mock function with given fields: t
|
||||
func (_m *StateMachine) AddTransition(t agent.Transition) {
|
||||
func (_m *StateMachine) AddTransition(t statemachine.Transition) {
|
||||
_m.Called(t)
|
||||
}
|
||||
|
||||
// GetState provides a mock function with given fields:
|
||||
func (_m *StateMachine) GetState() agent.State {
|
||||
func (_m *StateMachine) GetState() statemachine.State {
|
||||
ret := _m.Called()
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for GetState")
|
||||
}
|
||||
|
||||
var r0 agent.State
|
||||
if rf, ok := ret.Get(0).(func() agent.State); ok {
|
||||
var r0 statemachine.State
|
||||
if rf, ok := ret.Get(0).(func() statemachine.State); ok {
|
||||
r0 = rf()
|
||||
} else {
|
||||
if ret.Get(0) != nil {
|
||||
r0 = ret.Get(0).(agent.State)
|
||||
r0 = ret.Get(0).(statemachine.State)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -44,12 +43,12 @@ func (_m *StateMachine) GetState() agent.State {
|
||||
}
|
||||
|
||||
// SendEvent provides a mock function with given fields: event
|
||||
func (_m *StateMachine) SendEvent(event agent.Event) {
|
||||
func (_m *StateMachine) SendEvent(event statemachine.Event) {
|
||||
_m.Called(event)
|
||||
}
|
||||
|
||||
// SetAction provides a mock function with given fields: state, action
|
||||
func (_m *StateMachine) SetAction(state agent.State, action agent.Action) {
|
||||
func (_m *StateMachine) SetAction(state statemachine.State, action statemachine.Action) {
|
||||
_m.Called(state, action)
|
||||
}
|
||||
|
||||
|
||||
+58
-6
@@ -3,17 +3,20 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"log/slog"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
mglog "github.com/absmach/magistrala/logger"
|
||||
"github.com/absmach/magistrala/pkg/prometheus"
|
||||
"github.com/cenkalti/backoff/v4"
|
||||
"github.com/google/go-sev-guest/abi"
|
||||
"github.com/google/go-sev-guest/client"
|
||||
"github.com/mdlayher/vsock"
|
||||
"github.com/ultravioletrs/cocos/agent"
|
||||
@@ -28,6 +31,7 @@ import (
|
||||
ackvsock "github.com/ultravioletrs/cocos/internal/vsock"
|
||||
"github.com/ultravioletrs/cocos/manager"
|
||||
"github.com/ultravioletrs/cocos/manager/qemu"
|
||||
"golang.org/x/crypto/sha3"
|
||||
"golang.org/x/sync/errgroup"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/reflection"
|
||||
@@ -84,6 +88,14 @@ func main() {
|
||||
return
|
||||
}
|
||||
|
||||
if err := verifyManifest(cfg, qp); err != nil {
|
||||
logger.Error(err.Error())
|
||||
exitCode = 1
|
||||
return
|
||||
}
|
||||
|
||||
setDefaultValues(&cfg)
|
||||
|
||||
svc := newService(ctx, logger, eventSvc, cfg, qp)
|
||||
|
||||
grpcServerConfig := server.Config{
|
||||
@@ -179,15 +191,23 @@ func readConfig() (agent.Computation, error) {
|
||||
if err := json.Unmarshal(buffer, &ac); err != nil {
|
||||
return agent.Computation{}, err
|
||||
}
|
||||
if ac.AgentConfig.LogLevel == "" {
|
||||
ac.AgentConfig.LogLevel = "info"
|
||||
}
|
||||
if ac.AgentConfig.Port == "" {
|
||||
ac.AgentConfig.Port = defSvcGRPCPort
|
||||
}
|
||||
return ac, nil
|
||||
}
|
||||
|
||||
func setDefaultValues(cfg *agent.Computation) {
|
||||
if cfg.AgentConfig.LogLevel == "" {
|
||||
cfg.AgentConfig.LogLevel = "info"
|
||||
}
|
||||
if cfg.AgentConfig.Port == "" {
|
||||
cfg.AgentConfig.Port = defSvcGRPCPort
|
||||
}
|
||||
}
|
||||
|
||||
func isTEE() bool {
|
||||
_, err := os.Stat("/dev/sev-guest")
|
||||
return !os.IsNotExist(err)
|
||||
}
|
||||
|
||||
func dialVsock() (*vsock.Conn, error) {
|
||||
var conn *vsock.Conn
|
||||
var err error
|
||||
@@ -207,3 +227,35 @@ func dialVsock() (*vsock.Conn, error) {
|
||||
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
func verifyManifest(cfg agent.Computation, qp client.QuoteProvider) error {
|
||||
if !isTEE() {
|
||||
return nil
|
||||
}
|
||||
|
||||
ar, err := qp.GetRawQuote(sha3.Sum512([]byte(cfg.ID)))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
arProto, err := abi.ReportCertsToProto(ar[:abi.ReportSize])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
cfgBytes, err := json.Marshal(cfg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
mcHash := sha3.Sum256(cfgBytes)
|
||||
|
||||
if arProto.Report.HostData == nil {
|
||||
return fmt.Errorf("manifest verification failed: HostData is nil")
|
||||
}
|
||||
if !bytes.Equal(arProto.Report.HostData, mcHash[:]) {
|
||||
return fmt.Errorf("manifest verification failed")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -0,0 +1,94 @@
|
||||
// Copyright (c) Ultraviolet
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log/slog"
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/mock"
|
||||
"github.com/ultravioletrs/cocos/agent"
|
||||
"github.com/ultravioletrs/cocos/agent/events/mocks"
|
||||
qpmocks "github.com/ultravioletrs/cocos/agent/quoteprovider/mocks"
|
||||
)
|
||||
|
||||
func TestSetDefaultValues(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input agent.Computation
|
||||
expected agent.Computation
|
||||
}{
|
||||
{
|
||||
name: "Empty config",
|
||||
input: agent.Computation{
|
||||
AgentConfig: agent.AgentConfig{},
|
||||
},
|
||||
expected: agent.Computation{
|
||||
AgentConfig: agent.AgentConfig{
|
||||
LogLevel: "info",
|
||||
Port: "7002",
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Partial config",
|
||||
input: agent.Computation{
|
||||
AgentConfig: agent.AgentConfig{
|
||||
LogLevel: "debug",
|
||||
},
|
||||
},
|
||||
expected: agent.Computation{
|
||||
AgentConfig: agent.AgentConfig{
|
||||
LogLevel: "debug",
|
||||
Port: "7002",
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
setDefaultValues(&tt.input)
|
||||
assert.Equal(t, tt.expected, tt.input)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewService(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
|
||||
eventSvc := new(mocks.Service)
|
||||
eventSvc.On("SendEvent", mock.Anything, mock.Anything, mock.Anything).Return(nil)
|
||||
cmp := agent.Computation{
|
||||
ID: "test-computation",
|
||||
AgentConfig: agent.AgentConfig{
|
||||
LogLevel: "info",
|
||||
Port: "7002",
|
||||
},
|
||||
}
|
||||
qp := new(qpmocks.QuoteProvider)
|
||||
|
||||
svc := newService(ctx, logger, eventSvc, cmp, qp)
|
||||
|
||||
assert.NotNil(t, svc)
|
||||
}
|
||||
|
||||
func TestVerifyManifest(t *testing.T) {
|
||||
cfg := agent.Computation{
|
||||
ID: "test-computation",
|
||||
AgentConfig: agent.AgentConfig{
|
||||
LogLevel: "info",
|
||||
Port: "7002",
|
||||
},
|
||||
}
|
||||
|
||||
mockQP := new(qpmocks.QuoteProvider)
|
||||
mockQP.On("GetRawQuote", mock.Anything).Return([]byte{}, nil)
|
||||
|
||||
err := verifyManifest(cfg, mockQP)
|
||||
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
@@ -0,0 +1,6 @@
|
||||
# Copyright (c) Ultraviolet
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
coverage:
|
||||
ignore:
|
||||
- "test/*"
|
||||
@@ -6,7 +6,6 @@ require (
|
||||
github.com/absmach/magistrala v0.14.1-0.20240709113739-04c359462746
|
||||
github.com/caarlos0/env/v11 v11.2.2
|
||||
github.com/cenkalti/backoff/v4 v4.3.0
|
||||
github.com/digitalocean/go-libvirt v0.0.0-20240709142323-d8406205c752
|
||||
github.com/fatih/color v1.17.0
|
||||
github.com/go-kit/kit v0.13.0
|
||||
github.com/gofrs/uuid v4.4.0+incompatible
|
||||
|
||||
@@ -19,8 +19,6 @@ github.com/containerd/log v0.1.0/go.mod h1:VRRf09a7mHDIRezVKTRCrOq78v577GXq3bSa3
|
||||
github.com/cpuguy83/go-md2man/v2 v2.0.4/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o=
|
||||
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM=
|
||||
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/digitalocean/go-libvirt v0.0.0-20240709142323-d8406205c752 h1:NI7XEcHzWVvBfVjSVK6Qk4wmrUfoyQxCNpBjrHelZFk=
|
||||
github.com/digitalocean/go-libvirt v0.0.0-20240709142323-d8406205c752/go.mod h1:/Ok8PA2qi/ve0Py38+oL+VxoYmlowigYRyLEODRYdgc=
|
||||
github.com/distribution/reference v0.6.0 h1:0IXCQ5g4/QMHHkarYzh5l+u8T3t73zM5QvfrDyIgxBk=
|
||||
github.com/distribution/reference v0.6.0/go.mod h1:BbU0aIcezP1/5jX/8MP0YiH4SdvB5Y4f/wlDRiLyi3E=
|
||||
github.com/docker/docker v27.3.1+incompatible h1:KttF0XoteNTicmUtBO0L2tP+J7FGRFTjaEF4k6WdhfI=
|
||||
|
||||
@@ -80,7 +80,7 @@ func (client ManagerClient) processIncomingMessage(ctx context.Context, req *pkg
|
||||
case *pkgmanager.ServerStreamMessage_StopComputation:
|
||||
go client.handleStopComputation(ctx, mes)
|
||||
case *pkgmanager.ServerStreamMessage_BackendInfoReq:
|
||||
go client.handleBackendInfoReq(mes)
|
||||
go client.handleBackendInfoReq(ctx, mes)
|
||||
default:
|
||||
return errors.New("unknown message type")
|
||||
}
|
||||
@@ -133,8 +133,8 @@ func (client ManagerClient) handleStopComputation(ctx context.Context, mes *pkgm
|
||||
client.sendMessage(&pkgmanager.ClientStreamMessage{Message: msg})
|
||||
}
|
||||
|
||||
func (client ManagerClient) handleBackendInfoReq(mes *pkgmanager.ServerStreamMessage_BackendInfoReq) {
|
||||
res, err := client.svc.FetchBackendInfo()
|
||||
func (client ManagerClient) handleBackendInfoReq(ctx context.Context, mes *pkgmanager.ServerStreamMessage_BackendInfoReq) {
|
||||
res, err := client.svc.FetchBackendInfo(ctx, mes.BackendInfoReq.Id)
|
||||
if err != nil {
|
||||
client.logger.Warn(err.Error())
|
||||
return
|
||||
|
||||
@@ -150,32 +150,57 @@ func TestManagerClient_handleStopComputation(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestManagerClient_handleBackendInfoReq(t *testing.T) {
|
||||
mockStream := new(mockStream)
|
||||
mockSvc := new(mocks.Service)
|
||||
messageQueue := make(chan *pkgmanager.ClientStreamMessage, 10)
|
||||
logger := mglog.NewMock()
|
||||
t.Run("success", func(t *testing.T) {
|
||||
mockStream := new(mockStream)
|
||||
mockSvc := new(mocks.Service)
|
||||
messageQueue := make(chan *pkgmanager.ClientStreamMessage, 10)
|
||||
logger := mglog.NewMock()
|
||||
|
||||
client := NewClient(mockStream, mockSvc, messageQueue, logger)
|
||||
client := NewClient(mockStream, mockSvc, messageQueue, logger)
|
||||
|
||||
infoReq := &pkgmanager.ServerStreamMessage_BackendInfoReq{
|
||||
BackendInfoReq: &pkgmanager.BackendInfoReq{
|
||||
Id: "test-info-id",
|
||||
},
|
||||
}
|
||||
infoReq := &pkgmanager.ServerStreamMessage_BackendInfoReq{
|
||||
BackendInfoReq: &pkgmanager.BackendInfoReq{
|
||||
Id: "test-info-id",
|
||||
},
|
||||
}
|
||||
|
||||
mockSvc.On("FetchBackendInfo").Return([]byte("test-backend-info"), nil)
|
||||
mockSvc.On("FetchBackendInfo", context.Background(), infoReq.BackendInfoReq.Id).Return([]byte("test-backend-info"), nil)
|
||||
|
||||
client.handleBackendInfoReq(infoReq)
|
||||
client.handleBackendInfoReq(context.Background(), infoReq)
|
||||
|
||||
// Wait for the goroutine to finish
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
// Wait for the goroutine to finish
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
mockSvc.AssertExpectations(t)
|
||||
assert.Len(t, messageQueue, 1)
|
||||
mockSvc.AssertExpectations(t)
|
||||
assert.Len(t, messageQueue, 1)
|
||||
|
||||
msg := <-messageQueue
|
||||
infoRes, ok := msg.Message.(*pkgmanager.ClientStreamMessage_BackendInfo)
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, "test-info-id", infoRes.BackendInfo.Id)
|
||||
assert.Equal(t, []byte("test-backend-info"), infoRes.BackendInfo.Info)
|
||||
msg := <-messageQueue
|
||||
infoRes, ok := msg.Message.(*pkgmanager.ClientStreamMessage_BackendInfo)
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, "test-info-id", infoRes.BackendInfo.Id)
|
||||
assert.Equal(t, []byte("test-backend-info"), infoRes.BackendInfo.Info)
|
||||
})
|
||||
t.Run("error", func(t *testing.T) {
|
||||
mockStream := new(mockStream)
|
||||
mockSvc := new(mocks.Service)
|
||||
messageQueue := make(chan *pkgmanager.ClientStreamMessage, 10)
|
||||
logger := mglog.NewMock()
|
||||
|
||||
client := NewClient(mockStream, mockSvc, messageQueue, logger)
|
||||
|
||||
infoReq := &pkgmanager.ServerStreamMessage_BackendInfoReq{
|
||||
BackendInfoReq: &pkgmanager.BackendInfoReq{
|
||||
Id: "test-info-id",
|
||||
},
|
||||
}
|
||||
|
||||
mockSvc.On("FetchBackendInfo", context.Background(), infoReq.BackendInfoReq.Id).Return(nil, assert.AnError)
|
||||
|
||||
client.handleBackendInfoReq(context.Background(), infoReq)
|
||||
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
mockSvc.AssertExpectations(t)
|
||||
assert.Len(t, messageQueue, 0)
|
||||
})
|
||||
}
|
||||
|
||||
@@ -58,11 +58,16 @@ func (lm *loggingMiddleware) RetrieveAgentEventsLogs() {
|
||||
lm.svc.RetrieveAgentEventsLogs()
|
||||
}
|
||||
|
||||
func (lm *loggingMiddleware) FetchBackendInfo() ([]byte, error) {
|
||||
func (lm *loggingMiddleware) FetchBackendInfo(ctx context.Context, cmpId string) (body []byte, err error) {
|
||||
defer func(begin time.Time) {
|
||||
message := fmt.Sprintf("Method FetchBackendInfo took %s to complete", time.Since(begin))
|
||||
message := fmt.Sprintf("Method FetchBackendInfo for computation %s took %s to complete", cmpId, time.Since(begin))
|
||||
if err != nil {
|
||||
lm.logger.Warn(fmt.Sprintf("%s with error: %s.", message, err))
|
||||
return
|
||||
}
|
||||
|
||||
lm.logger.Info(message)
|
||||
}(time.Now())
|
||||
|
||||
return lm.svc.FetchBackendInfo()
|
||||
return lm.svc.FetchBackendInfo(ctx, cmpId)
|
||||
}
|
||||
|
||||
@@ -55,11 +55,11 @@ func (ms *metricsMiddleware) RetrieveAgentEventsLogs() {
|
||||
ms.svc.RetrieveAgentEventsLogs()
|
||||
}
|
||||
|
||||
func (ms *metricsMiddleware) FetchBackendInfo() ([]byte, error) {
|
||||
func (ms *metricsMiddleware) FetchBackendInfo(ctx context.Context, cmpId string) ([]byte, error) {
|
||||
defer func(begin time.Time) {
|
||||
ms.counter.With("method", "FetchBackendInfo").Add(1)
|
||||
ms.latency.With("method", "FetchBackendInfo").Observe(time.Since(begin).Seconds())
|
||||
}(time.Now())
|
||||
|
||||
return ms.svc.FetchBackendInfo()
|
||||
return ms.svc.FetchBackendInfo(ctx, cmpId)
|
||||
}
|
||||
|
||||
+21
-7
@@ -7,6 +7,7 @@
|
||||
package manager
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
@@ -22,9 +23,21 @@ import (
|
||||
|
||||
const defGuestFeatures = 0x1
|
||||
|
||||
func (ms *managerService) FetchBackendInfo() ([]byte, error) {
|
||||
func (ms *managerService) FetchBackendInfo(_ context.Context, computationId string) ([]byte, error) {
|
||||
cmd := exec.Command("sudo", fmt.Sprintf("%s/backend_info", ms.backendMeasurementBinaryPath), "--policy", "1966081")
|
||||
|
||||
ms.mu.Lock()
|
||||
vm, exists := ms.vms[computationId]
|
||||
ms.mu.Unlock()
|
||||
if !exists {
|
||||
return nil, fmt.Errorf("computationId %s not found", computationId)
|
||||
}
|
||||
|
||||
config, ok := vm.GetConfig().(qemu.Config)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("failed to cast config to qemu.Config")
|
||||
}
|
||||
|
||||
_, err := cmd.Output()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -42,13 +55,14 @@ func (ms *managerService) FetchBackendInfo() ([]byte, error) {
|
||||
}
|
||||
|
||||
var measurement []byte
|
||||
if ms.qemuCfg.EnableSEV {
|
||||
measurement, err = guest.CalcLaunchDigest(guest.SEV, ms.qemuCfg.SMPCount, uint64(cpuid.CpuSigs[ms.qemuCfg.CPU]), ms.qemuCfg.OVMFCodeConfig.File, ms.qemuCfg.KernelFile, ms.qemuCfg.RootFsFile, qemu.KernelCommandLine, defGuestFeatures, "", vmmtypes.QEMU, false, "", 0)
|
||||
switch {
|
||||
case config.EnableSEV:
|
||||
measurement, err = guest.CalcLaunchDigest(guest.SEV, config.SMPCount, uint64(cpuid.CpuSigs[ms.qemuCfg.CPU]), config.OVMFCodeConfig.File, config.KernelFile, config.RootFsFile, qemu.KernelCommandLine, defGuestFeatures, "", vmmtypes.QEMU, false, "", 0)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
} else if ms.qemuCfg.EnableSEVSNP {
|
||||
measurement, err = guest.CalcLaunchDigest(guest.SEV_SNP, ms.qemuCfg.SMPCount, uint64(cpuid.CpuSigs[ms.qemuCfg.CPU]), ms.qemuCfg.OVMFCodeConfig.File, ms.qemuCfg.KernelFile, ms.qemuCfg.RootFsFile, qemu.KernelCommandLine, defGuestFeatures, "", vmmtypes.QEMU, false, "", 0)
|
||||
case config.EnableSEVSNP:
|
||||
measurement, err = guest.CalcLaunchDigest(guest.SEV_SNP, config.SMPCount, uint64(cpuid.CpuSigs[config.CPU]), config.OVMFCodeConfig.File, config.KernelFile, config.RootFsFile, qemu.KernelCommandLine, defGuestFeatures, "", vmmtypes.QEMU, false, "", 0)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -57,8 +71,8 @@ func (ms *managerService) FetchBackendInfo() ([]byte, error) {
|
||||
backendInfo.SNPPolicy.Measurement = measurement
|
||||
}
|
||||
|
||||
if ms.qemuCfg.HostData != "" {
|
||||
hostData, err := base64.StdEncoding.DecodeString(ms.qemuCfg.HostData)
|
||||
if config.HostData != "" {
|
||||
hostData, err := base64.StdEncoding.DecodeString(config.HostData)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -6,8 +6,12 @@
|
||||
|
||||
package manager
|
||||
|
||||
import backendinfo "github.com/ultravioletrs/cocos/scripts/backend_info"
|
||||
import (
|
||||
"context"
|
||||
|
||||
func (ms *managerService) FetchBackendInfo() ([]byte, error) {
|
||||
backendinfo "github.com/ultravioletrs/cocos/scripts/backend_info"
|
||||
)
|
||||
|
||||
func (ms *managerService) FetchBackendInfo(_ context.Context, _ string) ([]byte, error) {
|
||||
return backendinfo.BackendInfo, nil
|
||||
}
|
||||
|
||||
@@ -0,0 +1,157 @@
|
||||
// Copyright (c) Ultraviolet
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
package manager
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/ultravioletrs/cocos/manager/qemu"
|
||||
"github.com/ultravioletrs/cocos/manager/vm"
|
||||
"github.com/ultravioletrs/cocos/manager/vm/mocks"
|
||||
)
|
||||
|
||||
func createDummyBackendInfoBinary(t *testing.T, behavior string) string {
|
||||
var content []byte
|
||||
switch behavior {
|
||||
case "success":
|
||||
content = []byte(`#!/bin/sh
|
||||
echo '{"snp_policy": {"measurement": null, "host_data": null}}' > backend_info.json
|
||||
`)
|
||||
case "fail":
|
||||
content = []byte(`#!/bin/sh
|
||||
echo "Error: Failed to execute backend_info" >&2
|
||||
exit 1
|
||||
`)
|
||||
case "no_json":
|
||||
content = []byte(`#!/bin/sh
|
||||
echo 'No JSON file created'
|
||||
`)
|
||||
default:
|
||||
t.Fatalf("Unknown behavior: %s", behavior)
|
||||
}
|
||||
|
||||
tempDir := t.TempDir()
|
||||
binaryPath := filepath.Join(tempDir, "backend_info")
|
||||
err := os.WriteFile(binaryPath, content, 0o755)
|
||||
assert.NoError(t, err)
|
||||
return tempDir
|
||||
}
|
||||
|
||||
func TestFetchBackendInfo(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
computationId string
|
||||
vmConfig interface{}
|
||||
binaryBehavior string
|
||||
expectedError string
|
||||
expectedResult map[string]interface{}
|
||||
}{
|
||||
{
|
||||
name: "Valid SEV configuration",
|
||||
computationId: "sev-computation",
|
||||
binaryBehavior: "success",
|
||||
vmConfig: qemu.Config{
|
||||
EnableSEV: true,
|
||||
SMPCount: 2,
|
||||
CPU: "EPYC",
|
||||
OVMFCodeConfig: qemu.OVMFCodeConfig{
|
||||
File: "/path/to/OVMF_CODE.fd",
|
||||
},
|
||||
},
|
||||
expectedError: "open /path/to/OVMF_CODE.fd: no such file or directory",
|
||||
},
|
||||
{
|
||||
name: "Valid SEV-SNP configuration",
|
||||
computationId: "sev-snp-computation",
|
||||
binaryBehavior: "success",
|
||||
vmConfig: qemu.Config{
|
||||
EnableSEVSNP: true,
|
||||
SMPCount: 4,
|
||||
CPU: "EPYC-v2",
|
||||
OVMFCodeConfig: qemu.OVMFCodeConfig{
|
||||
File: "/path/to/OVMF_CODE_SNP.fd",
|
||||
},
|
||||
},
|
||||
expectedError: "open /path/to/OVMF_CODE_SNP.fd: no such file or director",
|
||||
},
|
||||
{
|
||||
name: "Invalid computation ID",
|
||||
computationId: "non-existent",
|
||||
binaryBehavior: "success",
|
||||
vmConfig: qemu.Config{},
|
||||
expectedError: "computationId non-existent not found",
|
||||
},
|
||||
{
|
||||
name: "Invalid config type",
|
||||
computationId: "invalid-config",
|
||||
binaryBehavior: "success",
|
||||
vmConfig: struct{}{},
|
||||
expectedError: "failed to cast config to qemu.Config",
|
||||
},
|
||||
{
|
||||
name: "Binary execution failure",
|
||||
computationId: "binary-fail",
|
||||
binaryBehavior: "fail",
|
||||
vmConfig: qemu.Config{
|
||||
EnableSEV: true,
|
||||
},
|
||||
expectedError: "exit status 1",
|
||||
},
|
||||
{
|
||||
name: "JSON file not created",
|
||||
computationId: "no-json",
|
||||
binaryBehavior: "no_json",
|
||||
vmConfig: qemu.Config{
|
||||
EnableSEV: true,
|
||||
},
|
||||
expectedError: "no such file or directory",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
tempDir := createDummyBackendInfoBinary(t, tc.binaryBehavior)
|
||||
defer os.RemoveAll(tempDir)
|
||||
|
||||
ms := &managerService{
|
||||
vms: make(map[string]vm.VM),
|
||||
backendMeasurementBinaryPath: tempDir,
|
||||
qemuCfg: qemu.Config{
|
||||
CPU: "EPYC",
|
||||
},
|
||||
}
|
||||
|
||||
mockVM := new(mocks.VM)
|
||||
mockVM.On("GetConfig").Return(tc.vmConfig)
|
||||
|
||||
if tc.computationId != "non-existent" {
|
||||
ms.vms[tc.computationId] = mockVM
|
||||
}
|
||||
|
||||
result, err := ms.FetchBackendInfo(context.Background(), tc.computationId)
|
||||
|
||||
if tc.expectedError != "" {
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), tc.expectedError)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, result)
|
||||
|
||||
var backendInfo map[string]interface{}
|
||||
err = json.Unmarshal(result, &backendInfo)
|
||||
assert.NoError(t, err)
|
||||
|
||||
assert.Equal(t, tc.expectedResult, backendInfo)
|
||||
}
|
||||
|
||||
if tc.binaryBehavior == "success" {
|
||||
os.Remove("backend_info.json")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -18,9 +18,9 @@ type Service struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
// FetchBackendInfo provides a mock function with given fields:
|
||||
func (_m *Service) FetchBackendInfo() ([]byte, error) {
|
||||
ret := _m.Called()
|
||||
// FetchBackendInfo provides a mock function with given fields: ctx, computationID
|
||||
func (_m *Service) FetchBackendInfo(ctx context.Context, computationID string) ([]byte, error) {
|
||||
ret := _m.Called(ctx, computationID)
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for FetchBackendInfo")
|
||||
@@ -28,19 +28,19 @@ func (_m *Service) FetchBackendInfo() ([]byte, error) {
|
||||
|
||||
var r0 []byte
|
||||
var r1 error
|
||||
if rf, ok := ret.Get(0).(func() ([]byte, error)); ok {
|
||||
return rf()
|
||||
if rf, ok := ret.Get(0).(func(context.Context, string) ([]byte, error)); ok {
|
||||
return rf(ctx, computationID)
|
||||
}
|
||||
if rf, ok := ret.Get(0).(func() []byte); ok {
|
||||
r0 = rf()
|
||||
if rf, ok := ret.Get(0).(func(context.Context, string) []byte); ok {
|
||||
r0 = rf(ctx, computationID)
|
||||
} else {
|
||||
if ret.Get(0) != nil {
|
||||
r0 = ret.Get(0).([]byte)
|
||||
}
|
||||
}
|
||||
|
||||
if rf, ok := ret.Get(1).(func() error); ok {
|
||||
r1 = rf()
|
||||
if rf, ok := ret.Get(1).(func(context.Context, string) error); ok {
|
||||
r1 = rf(ctx, computationID)
|
||||
} else {
|
||||
r1 = ret.Error(1)
|
||||
}
|
||||
|
||||
@@ -197,3 +197,7 @@ func processExists(pid int) bool {
|
||||
func (v *qemuVM) GetCID() int {
|
||||
return v.config.GuestCID
|
||||
}
|
||||
|
||||
func (v *qemuVM) GetConfig() interface{} {
|
||||
return v.config
|
||||
}
|
||||
|
||||
+82
-19
@@ -10,15 +10,17 @@ import (
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/ultravioletrs/cocos/manager/vm"
|
||||
"github.com/ultravioletrs/cocos/manager/vm/mocks"
|
||||
"github.com/ultravioletrs/cocos/pkg/manager"
|
||||
)
|
||||
|
||||
const testComputationID = "test-computation"
|
||||
|
||||
func TestNewVM(t *testing.T) {
|
||||
config := Config{}
|
||||
logsChan := make(chan *manager.ClientStreamMessage)
|
||||
computationId := "test-computation"
|
||||
|
||||
vm := NewVM(config, logsChan, computationId)
|
||||
vm := NewVM(config, logsChan, testComputationID)
|
||||
|
||||
assert.NotNil(t, vm)
|
||||
assert.IsType(t, &qemuVM{}, vm)
|
||||
@@ -34,35 +36,84 @@ func TestStart(t *testing.T) {
|
||||
OVMFVarsConfig: OVMFVarsConfig{
|
||||
File: tmpFile.Name(),
|
||||
},
|
||||
QemuBinPath: "echo", // Use 'echo' as a dummy QEMU binary
|
||||
QemuBinPath: "echo",
|
||||
}
|
||||
logsChan := make(chan *manager.ClientStreamMessage)
|
||||
computationId := "test-computation"
|
||||
|
||||
vm := NewVM(config, logsChan, computationId).(*qemuVM)
|
||||
vm := NewVM(config, logsChan, testComputationID).(*qemuVM)
|
||||
|
||||
err = vm.Start()
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, vm.cmd)
|
||||
|
||||
_ = vm.Stop()
|
||||
}
|
||||
|
||||
func TestStartSudo(t *testing.T) {
|
||||
// Create a temporary file for testing
|
||||
tmpFile, err := os.CreateTemp("", "test-ovmf-vars")
|
||||
assert.NoError(t, err)
|
||||
defer os.Remove(tmpFile.Name())
|
||||
|
||||
config := Config{
|
||||
OVMFVarsConfig: OVMFVarsConfig{
|
||||
File: tmpFile.Name(),
|
||||
},
|
||||
QemuBinPath: "echo",
|
||||
UseSudo: true,
|
||||
}
|
||||
logsChan := make(chan *manager.ClientStreamMessage)
|
||||
|
||||
vm := NewVM(config, logsChan, testComputationID).(*qemuVM)
|
||||
|
||||
err = vm.Start()
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, vm.cmd)
|
||||
|
||||
// Clean up
|
||||
_ = vm.Stop()
|
||||
}
|
||||
|
||||
func TestStop(t *testing.T) {
|
||||
cmd := exec.Command("echo", "test")
|
||||
err := cmd.Start()
|
||||
assert.NoError(t, err)
|
||||
t.Run("success", func(t *testing.T) {
|
||||
cmd := exec.Command("echo", "test")
|
||||
err := cmd.Start()
|
||||
assert.NoError(t, err)
|
||||
sm := new(mocks.StateMachine)
|
||||
sm.On("Transition", manager.StopComputationRun).Return(nil)
|
||||
|
||||
vm := &qemuVM{
|
||||
cmd: &exec.Cmd{
|
||||
Process: cmd.Process,
|
||||
},
|
||||
StateMachine: vm.NewStateMachine(),
|
||||
}
|
||||
vm := &qemuVM{
|
||||
cmd: &exec.Cmd{
|
||||
Process: cmd.Process,
|
||||
},
|
||||
StateMachine: sm,
|
||||
}
|
||||
|
||||
err = vm.Stop()
|
||||
assert.NoError(t, err)
|
||||
err = vm.Stop()
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
t.Run("transition error", func(t *testing.T) {
|
||||
cmd := exec.Command("echo", "test")
|
||||
err := cmd.Start()
|
||||
assert.NoError(t, err)
|
||||
sm := new(mocks.StateMachine)
|
||||
sm.On("Transition", manager.StopComputationRun).Return(assert.AnError)
|
||||
sm.On("State").Return(manager.Stopped.String())
|
||||
|
||||
vm := &qemuVM{
|
||||
cmd: &exec.Cmd{
|
||||
Process: cmd.Process,
|
||||
},
|
||||
StateMachine: sm,
|
||||
logsChan: make(chan *manager.ClientStreamMessage),
|
||||
}
|
||||
|
||||
go func() {
|
||||
<-vm.logsChan
|
||||
}()
|
||||
|
||||
err = vm.Stop()
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestSetProcess(t *testing.T) {
|
||||
@@ -104,11 +155,23 @@ func TestGetCID(t *testing.T) {
|
||||
assert.Equal(t, expectedCID, cid)
|
||||
}
|
||||
|
||||
func TestGetConfig(t *testing.T) {
|
||||
expectedConfig := Config{
|
||||
QemuBinPath: "echo",
|
||||
}
|
||||
vm := &qemuVM{
|
||||
config: expectedConfig,
|
||||
}
|
||||
|
||||
config := vm.GetConfig()
|
||||
assert.Equal(t, expectedConfig, config)
|
||||
}
|
||||
|
||||
func TestCheckVMProcessPeriodically(t *testing.T) {
|
||||
logsChan := make(chan *manager.ClientStreamMessage, 1)
|
||||
vm := &qemuVM{
|
||||
logsChan: logsChan,
|
||||
computationId: "test-computation",
|
||||
computationId: testComputationID,
|
||||
cmd: &exec.Cmd{
|
||||
Process: &os.Process{Pid: -1}, // Use an invalid PID to simulate a stopped process
|
||||
},
|
||||
@@ -120,7 +183,7 @@ func TestCheckVMProcessPeriodically(t *testing.T) {
|
||||
select {
|
||||
case msg := <-logsChan:
|
||||
assert.NotNil(t, msg.GetAgentEvent())
|
||||
assert.Equal(t, "test-computation", msg.GetAgentEvent().ComputationId)
|
||||
assert.Equal(t, testComputationID, msg.GetAgentEvent().ComputationId)
|
||||
assert.Equal(t, manager.VmProvision.String(), msg.GetAgentEvent().EventType)
|
||||
assert.Equal(t, manager.Stopped.String(), msg.GetAgentEvent().Status)
|
||||
case <-time.After(2 * interval):
|
||||
|
||||
+6
-1
@@ -63,7 +63,7 @@ type Service interface {
|
||||
// RetrieveAgentEventsLogs Retrieve and forward agent logs and events via vsock.
|
||||
RetrieveAgentEventsLogs()
|
||||
// FetchBackendInfo measures and fetches the backend information.
|
||||
FetchBackendInfo() ([]byte, error)
|
||||
FetchBackendInfo(ctx context.Context, computationID string) ([]byte, error)
|
||||
}
|
||||
|
||||
type managerService struct {
|
||||
@@ -128,6 +128,11 @@ func (ms *managerService) Run(ctx context.Context, c *manager.ComputationRunReq)
|
||||
LogLevel: c.AgentConfig.LogLevel,
|
||||
},
|
||||
}
|
||||
if len(c.Algorithm.Hash) != hashLength {
|
||||
ms.publishEvent(manager.VmProvision.String(), c.Id, agent.Failed.String(), json.RawMessage{})
|
||||
return "", errInvalidHashLength
|
||||
}
|
||||
|
||||
ac.Algorithm = agent.Algorithm{Hash: [hashLength]byte(c.Algorithm.Hash), UserKey: c.Algorithm.UserKey}
|
||||
|
||||
for _, data := range c.Datasets {
|
||||
|
||||
+50
-3
@@ -5,12 +5,15 @@ package manager
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"net"
|
||||
"os"
|
||||
"os/exec"
|
||||
"testing"
|
||||
|
||||
mglog "github.com/absmach/magistrala/logger"
|
||||
"github.com/absmach/magistrala/pkg/errors"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/mock"
|
||||
"github.com/stretchr/testify/require"
|
||||
@@ -74,6 +77,37 @@ func TestRun(t *testing.T) {
|
||||
vmStartError: assert.AnError,
|
||||
expectedError: assert.AnError,
|
||||
},
|
||||
{
|
||||
name: "Invalid algorithm hash",
|
||||
req: &manager.ComputationRunReq{
|
||||
Id: "test-computation",
|
||||
Name: "Test Computation",
|
||||
Algorithm: &manager.Algorithm{
|
||||
Hash: make([]byte, hashLength-1),
|
||||
},
|
||||
AgentConfig: &manager.AgentConfig{},
|
||||
},
|
||||
vmStartError: nil,
|
||||
expectedError: errInvalidHashLength,
|
||||
},
|
||||
{
|
||||
name: "Invalid dataset hash",
|
||||
req: &manager.ComputationRunReq{
|
||||
Id: "test-computation",
|
||||
Name: "Test Computation",
|
||||
Algorithm: &manager.Algorithm{
|
||||
Hash: make([]byte, hashLength),
|
||||
},
|
||||
AgentConfig: &manager.AgentConfig{},
|
||||
Datasets: []*manager.Dataset{
|
||||
{
|
||||
Hash: make([]byte, hashLength-1),
|
||||
},
|
||||
},
|
||||
},
|
||||
vmStartError: nil,
|
||||
expectedError: errInvalidHashLength,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
@@ -211,7 +245,14 @@ func TestGetFreePort(t *testing.T) {
|
||||
port, err := getFreePort(6000, 6100)
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.Greater(t, port, 0)
|
||||
assert.GreaterOrEqual(t, port, 6000)
|
||||
|
||||
_, err = net.Listen("tcp", net.JoinHostPort("localhost", fmt.Sprint(port)))
|
||||
assert.NoError(t, err)
|
||||
|
||||
port, err = getFreePort(6000, 6100)
|
||||
assert.NoError(t, err)
|
||||
assert.Greater(t, port, 6000)
|
||||
}
|
||||
|
||||
func TestPublishEvent(t *testing.T) {
|
||||
@@ -338,12 +379,18 @@ func TestRestoreVMs(t *testing.T) {
|
||||
err := cmd.Start()
|
||||
assert.NoError(t, err)
|
||||
|
||||
cmd2 := exec.Command("echo", "test")
|
||||
err = cmd2.Run()
|
||||
assert.NoError(t, err)
|
||||
|
||||
mockPersistence.On("LoadVMs").Return([]qemu.VMState{
|
||||
{ID: "vm1", PID: cmd.Process.Pid},
|
||||
{ID: "vm2", PID: 1000},
|
||||
{ID: "vm2", PID: cmd2.Process.Pid},
|
||||
{ID: "vm3", PID: cmd2.Process.Pid},
|
||||
}, nil)
|
||||
|
||||
mockPersistence.On("DeleteVM", mock.Anything).Return(nil)
|
||||
mockPersistence.On("DeleteVM", "vm2").Return(nil)
|
||||
mockPersistence.On("DeleteVM", "vm3").Return(errors.New("failed to delete"))
|
||||
|
||||
err = ms.restoreVMs()
|
||||
assert.NoError(t, err)
|
||||
|
||||
@@ -40,9 +40,9 @@ func (tm *tracingMiddleware) RetrieveAgentEventsLogs() {
|
||||
tm.svc.RetrieveAgentEventsLogs()
|
||||
}
|
||||
|
||||
func (tm *tracingMiddleware) FetchBackendInfo() ([]byte, error) {
|
||||
_, span := tm.tracer.Start(context.Background(), "fetch_backend_info")
|
||||
func (tm *tracingMiddleware) FetchBackendInfo(ctx context.Context, computationId string) ([]byte, error) {
|
||||
_, span := tm.tracer.Start(ctx, "fetch_backend_info")
|
||||
defer span.End()
|
||||
|
||||
return tm.svc.FetchBackendInfo()
|
||||
return tm.svc.FetchBackendInfo(ctx, computationId)
|
||||
}
|
||||
|
||||
@@ -0,0 +1,63 @@
|
||||
// Code generated by mockery v2.43.2. DO NOT EDIT.
|
||||
|
||||
package mocks
|
||||
|
||||
import (
|
||||
mock "github.com/stretchr/testify/mock"
|
||||
manager "github.com/ultravioletrs/cocos/pkg/manager"
|
||||
)
|
||||
|
||||
// StateMachine is an autogenerated mock type for the StateMachine type
|
||||
type StateMachine struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
// State provides a mock function with given fields:
|
||||
func (_m *StateMachine) State() string {
|
||||
ret := _m.Called()
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for State")
|
||||
}
|
||||
|
||||
var r0 string
|
||||
if rf, ok := ret.Get(0).(func() string); ok {
|
||||
r0 = rf()
|
||||
} else {
|
||||
r0 = ret.Get(0).(string)
|
||||
}
|
||||
|
||||
return r0
|
||||
}
|
||||
|
||||
// Transition provides a mock function with given fields: newState
|
||||
func (_m *StateMachine) Transition(newState manager.ManagerState) error {
|
||||
ret := _m.Called(newState)
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for Transition")
|
||||
}
|
||||
|
||||
var r0 error
|
||||
if rf, ok := ret.Get(0).(func(manager.ManagerState) error); ok {
|
||||
r0 = rf(newState)
|
||||
} else {
|
||||
r0 = ret.Error(0)
|
||||
}
|
||||
|
||||
return r0
|
||||
}
|
||||
|
||||
// NewStateMachine creates a new instance of StateMachine. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations.
|
||||
// The first argument is typically a *testing.T value.
|
||||
func NewStateMachine(t interface {
|
||||
mock.TestingT
|
||||
Cleanup(func())
|
||||
}) *StateMachine {
|
||||
mock := &StateMachine{}
|
||||
mock.Mock.Test(t)
|
||||
|
||||
t.Cleanup(func() { mock.AssertExpectations(t) })
|
||||
|
||||
return mock
|
||||
}
|
||||
@@ -35,6 +35,26 @@ func (_m *VM) GetCID() int {
|
||||
return r0
|
||||
}
|
||||
|
||||
// GetConfig provides a mock function with given fields:
|
||||
func (_m *VM) GetConfig() interface{} {
|
||||
ret := _m.Called()
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for GetConfig")
|
||||
}
|
||||
|
||||
var r0 interface{}
|
||||
if rf, ok := ret.Get(0).(func() interface{}); ok {
|
||||
r0 = rf()
|
||||
} else {
|
||||
if ret.Get(0) != nil {
|
||||
r0 = ret.Get(0).(interface{})
|
||||
}
|
||||
}
|
||||
|
||||
return r0
|
||||
}
|
||||
|
||||
// GetProcess provides a mock function with given fields:
|
||||
func (_m *VM) GetProcess() int {
|
||||
ret := _m.Called()
|
||||
|
||||
@@ -14,6 +14,7 @@ type sm struct {
|
||||
state manager.ManagerState
|
||||
}
|
||||
|
||||
//go:generate mockery --name StateMachine --output=./mocks --filename state_machine.go --quiet
|
||||
type StateMachine interface {
|
||||
Transition(newState manager.ManagerState) error
|
||||
State() string
|
||||
|
||||
@@ -19,6 +19,7 @@ type VM interface {
|
||||
GetCID() int
|
||||
Transition(newState manager.ManagerState) error
|
||||
State() string
|
||||
GetConfig() interface{}
|
||||
}
|
||||
|
||||
//go:generate mockery --name Provider --output=./mocks --filename provider.go --quiet --note "Copyright (c) Ultraviolet \n // SPDX-License-Identifier: Apache-2.0"
|
||||
|
||||
Reference in New Issue
Block a user