COCOS-214 - Improve manager resilience by tracking VMs on restart (#219)

* track hanging vm processes

Signed-off-by: SammyOina <sammyoina@gmail.com>

* fix lint

Signed-off-by: SammyOina <sammyoina@gmail.com>

* fix run test

Signed-off-by: SammyOina <sammyoina@gmail.com>

* fix stop computation

Signed-off-by: SammyOina <sammyoina@gmail.com>

* shutdown gracefully

Signed-off-by: SammyOina <sammyoina@gmail.com>

* check if process still exists

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* fix lint

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* use const

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

---------

Signed-off-by: SammyOina <sammyoina@gmail.com>
Signed-off-by: Sammy Oina <sammyoina@gmail.com>
This commit is contained in:
Sammy Kerata Oina
2024-08-30 19:08:11 +03:00
committed by GitHub
parent e572793295
commit 9ca045b06a
10 changed files with 372 additions and 14 deletions
+1 -1
View File
@@ -115,7 +115,7 @@ func main() {
if err != nil {
log.Fatal("failed to reconnect: ", err)
}
handler = agentlogger.NewProtoHandler(conn, &slog.HandlerOptions{Level: level})
handler = agentlogger.NewProtoHandler(conn, &slog.HandlerOptions{Level: level}, cfg.ID)
logger = slog.New(handler)
}
time.Sleep(retryInterval)
+17
View File
@@ -10,7 +10,9 @@ import (
"log/slog"
"net/url"
"os"
"os/signal"
"strings"
"syscall"
mglog "github.com/absmach/magistrala/logger"
"github.com/absmach/magistrala/pkg/jaeger"
@@ -120,6 +122,21 @@ func main() {
mc := managerapi.NewClient(pc, svc, eventsChan, logger)
g.Go(func() error {
ch := make(chan os.Signal, 1)
signal.Notify(ch, syscall.SIGINT, syscall.SIGTERM)
defer signal.Stop(ch)
select {
case <-ch:
logger.Info("Received signal, shutting down...")
cancel()
return nil
case <-ctx.Done():
return ctx.Err()
}
})
g.Go(func() error {
return mc.Process(ctx, cancel)
})
+2
View File
@@ -7,6 +7,8 @@ import (
"strconv"
)
// BaseGuestCID is the starting vsock guest context ID. The manager adds the
// number of VMs it already tracks to this value when configuring a new VM.
const BaseGuestCID = 3
type MemoryConfig struct {
Size string `env:"MEMORY_SIZE" envDefault:"2048M"`
Slots int `env:"MEMORY_SLOTS" envDefault:"5"`
+96
View File
@@ -0,0 +1,96 @@
// Code generated by mockery v2.43.2. DO NOT EDIT.
// Copyright (c) Ultraviolet
// SPDX-License-Identifier: Apache-2.0
package mocks
import (
mock "github.com/stretchr/testify/mock"
qemu "github.com/ultravioletrs/cocos/manager/qemu"
)
// Persistence is an autogenerated mock type for the Persistence type.
// It embeds mock.Mock, so tests set expectations with On(...).Return(...).
type Persistence struct {
mock.Mock
}
// DeleteVM provides a mock function with given fields: id
//
// It records the call and panics if the test configured no return value.
func (_m *Persistence) DeleteVM(id string) error {
ret := _m.Called(id)
if len(ret) == 0 {
panic("no return value specified for DeleteVM")
}
var r0 error
// A configured func(string) error is invoked with id; anything else is
// treated as the error value itself.
if rf, ok := ret.Get(0).(func(string) error); ok {
r0 = rf(id)
} else {
r0 = ret.Error(0)
}
return r0
}
// LoadVMs provides a mock function with given fields:
//
// It records the call and panics if the test configured no return value.
func (_m *Persistence) LoadVMs() ([]qemu.VMState, error) {
ret := _m.Called()
if len(ret) == 0 {
panic("no return value specified for LoadVMs")
}
var r0 []qemu.VMState
var r1 error
// A single function producing both results takes precedence.
if rf, ok := ret.Get(0).(func() ([]qemu.VMState, error)); ok {
return rf()
}
// Otherwise each result is resolved on its own: either from a configured
// function or from the configured value.
if rf, ok := ret.Get(0).(func() []qemu.VMState); ok {
r0 = rf()
} else {
// A nil first return value leaves r0 as a nil slice.
if ret.Get(0) != nil {
r0 = ret.Get(0).([]qemu.VMState)
}
}
if rf, ok := ret.Get(1).(func() error); ok {
r1 = rf()
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// SaveVM provides a mock function with given fields: state
//
// It records the call and panics if the test configured no return value.
func (_m *Persistence) SaveVM(state qemu.VMState) error {
ret := _m.Called(state)
if len(ret) == 0 {
panic("no return value specified for SaveVM")
}
var r0 error
// A configured func(qemu.VMState) error is invoked with state; anything else
// is treated as the error value itself.
if rf, ok := ret.Get(0).(func(qemu.VMState) error); ok {
r0 = rf(state)
} else {
r0 = ret.Error(0)
}
return r0
}
// NewPersistence creates a new instance of Persistence. It also registers a testing interface on the mock and a cleanup function to assert the mock's expectations.
// The first argument is typically a *testing.T value.
func NewPersistence(t interface {
mock.TestingT
Cleanup(func())
}) *Persistence {
mock := &Persistence{}
mock.Mock.Test(t)
// Expectations are verified when the test's cleanup functions run.
t.Cleanup(func() { mock.AssertExpectations(t) })
return mock
}
+89
View File
@@ -0,0 +1,89 @@
// Copyright (c) Ultraviolet
// SPDX-License-Identifier: Apache-2.0
package qemu
import (
"encoding/json"
"os"
"path/filepath"
"sync"
)
// jsonExt is the file extension of persisted VM state files. LoadVMs ignores
// directory entries that do not carry it.
const jsonExt = ".json"
// VMState is the per-VM record that is written to disk and read back when
// the manager restores its VMs.
type VMState struct {
// ID identifies the VM; it is also the base name of its state file.
ID string
// Config is the QEMU configuration stored with the VM.
Config Config
// PID is the operating-system process ID of the VM process.
PID int
}
// FilePersistence stores each VM state as a JSON file inside a directory.
type FilePersistence struct {
// dir is the directory that holds the state files.
dir string
// lock is held for the duration of SaveVM, LoadVMs and DeleteVM.
lock sync.Mutex
}
// Persistence is an interface for saving and loading VM states.
//
//go:generate mockery --name Persistence --output=./mocks --filename persistence.go --quiet --note "Copyright (c) Ultraviolet \n // SPDX-License-Identifier: Apache-2.0"
type Persistence interface {
// SaveVM stores the given VM state.
SaveVM(state VMState) error
// LoadVMs returns all stored VM states.
LoadVMs() ([]VMState, error)
// DeleteVM removes the stored state of the VM with the given id.
DeleteVM(id string) error
}
// NewFilePersistence returns a Persistence that keeps VM state files in dir.
// The directory, including any missing parents, is created with mode 0o755;
// a failure to create it is returned to the caller.
func NewFilePersistence(dir string) (Persistence, error) {
	err := os.MkdirAll(dir, 0o755)
	if err != nil {
		return nil, err
	}

	store := new(FilePersistence)
	store.dir = dir

	return store, nil
}
// SaveVM serializes state as JSON and stores it in "<dir>/<state.ID>.json",
// replacing any state previously saved under the same ID.
//
// The data is first written to a temporary sibling file and then renamed over
// the destination, so an interrupted or failed write cannot leave a truncated
// state file behind. A truncated file would make LoadVMs fail while decoding
// it. The temporary file ends in ".tmp", which LoadVMs skips because it only
// reads entries with the ".json" extension.
//
// It returns any error from encoding, writing or renaming.
func (fp *FilePersistence) SaveVM(state VMState) error {
	fp.lock.Lock()
	defer fp.lock.Unlock()

	data, err := json.Marshal(state)
	if err != nil {
		return err
	}

	target := filepath.Join(fp.dir, state.ID+jsonExt)
	tmp := target + ".tmp"

	if err := os.WriteFile(tmp, data, 0o644); err != nil {
		// Best effort: do not leave a partial temporary file around.
		_ = os.Remove(tmp)
		return err
	}

	if err := os.Rename(tmp, target); err != nil {
		// Best effort cleanup; the rename error is the one worth reporting.
		_ = os.Remove(tmp)
		return err
	}

	return nil
}
// LoadVMs reads every entry with the ".json" extension in the persistence
// directory and decodes it into a VMState. Entries with any other extension
// are ignored. The first read or decode failure aborts the load and is
// returned; the returned slice is nil when no state file is found.
func (fp *FilePersistence) LoadVMs() ([]VMState, error) {
	fp.lock.Lock()
	defer fp.lock.Unlock()

	entries, err := os.ReadDir(fp.dir)
	if err != nil {
		return nil, err
	}

	var loaded []VMState
	for i := range entries {
		name := entries[i].Name()
		if filepath.Ext(name) != jsonExt {
			continue
		}

		raw, readErr := os.ReadFile(filepath.Join(fp.dir, name))
		if readErr != nil {
			return nil, readErr
		}

		var vmState VMState
		if decodeErr := json.Unmarshal(raw, &vmState); decodeErr != nil {
			return nil, decodeErr
		}

		loaded = append(loaded, vmState)
	}

	return loaded, nil
}
// DeleteVM removes the state file stored for id. The error from os.Remove is
// returned unchanged, so deleting an ID that has no state file reports an
// error.
func (fp *FilePersistence) DeleteVM(id string) error {
	statePath := filepath.Join(fp.dir, id+jsonExt)

	fp.lock.Lock()
	defer fp.lock.Unlock()

	return os.Remove(statePath)
}
+24 -3
View File
@@ -4,6 +4,7 @@ package qemu
import (
"fmt"
"os"
"os/exec"
"github.com/gofrs/uuid"
@@ -38,9 +39,9 @@ func (v *qemuVM) Start() error {
if err != nil {
return err
}
qemuCfg := v.config
qemuCfg.NetDevConfig.ID = fmt.Sprintf("%s-%s", qemuCfg.NetDevConfig.ID, id)
qemuCfg.SevConfig.ID = fmt.Sprintf("%s-%s", qemuCfg.SevConfig.ID, id)
v.config.NetDevConfig.ID = fmt.Sprintf("%s-%s", v.config.NetDevConfig.ID, id)
v.config.SevConfig.ID = fmt.Sprintf("%s-%s", v.config.SevConfig.ID, id)
exe, args, err := v.executableAndArgs()
if err != nil {
@@ -58,6 +59,26 @@ func (v *qemuVM) Stop() error {
return v.cmd.Process.Kill()
}
// SetProcess attaches the VM to an already running process identified by pid.
// It rebuilds the command from the VM's configuration without starting it and
// points the command's Process at the existing process, so that Stop and
// GetProcess operate on that process.
//
// It returns an error if the process cannot be looked up or the executable and
// arguments cannot be resolved.
func (v *qemuVM) SetProcess(pid int) error {
	proc, err := os.FindProcess(pid)
	if err != nil {
		return err
	}

	bin, argv, err := v.executableAndArgs()
	if err != nil {
		return err
	}

	cmd := exec.Command(bin, argv...)
	cmd.Process = proc
	v.cmd = cmd

	return nil
}
// GetProcess returns the process ID of the process behind the VM's command.
// It does not guard against a nil command or a nil process, so it must only
// be called once the command has a process attached.
func (v *qemuVM) GetProcess() int {
	proc := v.cmd.Process
	return proc.Pid
}
func (v *qemuVM) executableAndArgs() (string, []string, error) {
exe, err := exec.LookPath(v.config.QemuBinPath)
if err != nil {
+87 -2
View File
@@ -9,9 +9,11 @@ import (
"fmt"
"log/slog"
"net"
"os"
"regexp"
"strconv"
"sync"
"syscall"
"github.com/absmach/magistrala/pkg/errors"
"github.com/cenkalti/backoff/v4"
@@ -23,7 +25,10 @@ import (
"google.golang.org/protobuf/types/known/timestamppb"
)
const hashLength = 32
const (
hashLength = 32
// persistenceDir is the directory handed to qemu.NewFilePersistence for
// storing VM state files.
// NOTE(review): this lives under /tmp, which is commonly cleared on reboot
// and writable by other users — confirm that is acceptable here.
persistenceDir = "/tmp/cocos"
)
var (
// ErrMalformedEntity indicates malformed entity specification (e.g.
@@ -68,6 +73,7 @@ type managerService struct {
vmFactory vm.Provider
portRangeMin int
portRangeMax int
persistence qemu.Persistence
}
var _ Service = (*managerService)(nil)
@@ -78,6 +84,12 @@ func New(cfg qemu.Config, backendMeasurementBinPath string, logger *slog.Logger,
if err != nil {
return nil, err
}
persistence, err := qemu.NewFilePersistence(persistenceDir)
if err != nil {
return nil, err
}
ms := &managerService{
qemuCfg: cfg,
logger: logger,
@@ -87,7 +99,13 @@ func New(cfg qemu.Config, backendMeasurementBinPath string, logger *slog.Logger,
backendMeasurementBinaryPath: backendMeasurementBinPath,
portRangeMin: start,
portRangeMax: end,
persistence: persistence,
}
if err := ms.restoreVMs(); err != nil {
return nil, err
}
return ms, nil
}
@@ -127,6 +145,7 @@ func (ms *managerService) Run(ctx context.Context, c *manager.ComputationRunReq)
return "", errors.Wrap(ErrFailedToAllocatePort, err)
}
ms.qemuCfg.HostFwdAgent = agentPort
ms.qemuCfg.VSockConfig.GuestCID = qemu.BaseGuestCID + len(ms.vms)
ch, err := computationHash(ac)
if err != nil {
@@ -145,13 +164,24 @@ func (ms *managerService) Run(ctx context.Context, c *manager.ComputationRunReq)
}
ms.vms[c.Id] = cvm
pid := cvm.GetProcess()
state := qemu.VMState{
ID: c.Id,
Config: ms.qemuCfg,
PID: pid,
}
if err := ms.persistence.SaveVM(state); err != nil {
ms.logger.Error("Failed to persist VM state", "error", err)
}
err = backoff.Retry(func() error {
return cvm.SendAgentConfig(ac)
}, backoff.NewExponentialBackOff())
if err != nil {
return "", err
}
ms.qemuCfg.VSockConfig.GuestCID++
ms.qemuCfg.VSockConfig.Vnc++
ms.publishEvent("vm-provision", c.Id, "complete", json.RawMessage{})
@@ -169,6 +199,11 @@ func (ms *managerService) Stop(ctx context.Context, computationID string) error
return err
}
delete(ms.vms, computationID)
if err := ms.persistence.DeleteVM(computationID); err != nil {
ms.logger.Error("Failed to delete persisted VM state", "error", err)
}
defer ms.publishEvent("stop-computation", computationID, "complete", json.RawMessage{})
return nil
}
@@ -264,3 +299,53 @@ func decodeRange(input string) (int, int, error) {
return start, end, nil
}
// restoreVMs rebuilds the in-memory VM table from the persisted VM states.
//
// For every persisted state it checks whether the recorded process still
// exists:
//   - if the check itself fails, the state is left untouched and skipped;
//   - if the process is gone, the stale state is deleted;
//   - otherwise a VM is created from the stored config, attached to the
//     process and registered under the state's ID.
//
// Only a failure to load the states is returned. Per-VM problems are logged
// and do not stop the remaining VMs from being restored.
func (ms *managerService) restoreVMs() error {
	states, err := ms.persistence.LoadVMs()
	if err != nil {
		return err
	}

	for _, state := range states {
		exists, err := processExists(state.PID)
		if err != nil {
			ms.logger.Warn("Failed to check process existence", "computation", state.ID, "pid", state.PID, "error", err)
			continue
		}

		if !exists {
			// Report the deletion only when it actually happened, so the log
			// never claims success right after a failure.
			if err := ms.persistence.DeleteVM(state.ID); err != nil {
				ms.logger.Error("Failed to delete persisted VM state", "computation", state.ID, "error", err)
				continue
			}
			ms.logger.Info("Deleted persisted state for non-existent process", "computation", state.ID, "pid", state.PID)
			continue
		}

		cvm := ms.vmFactory(state.Config, ms.eventsChan, state.ID)

		if err := cvm.SetProcess(state.PID); err != nil {
			ms.logger.Warn("Failed to reattach to process", "computation", state.ID, "pid", state.PID, "error", err)
			continue
		}

		ms.vms[state.ID] = cvm
		ms.logger.Info("Successfully restored VM state", "id", state.ID, "computationId", state.ID, "pid", state.PID)
	}

	return nil
}
// processExists reports whether a process with the given pid currently
// exists. It probes the process with signal 0, which performs the existence
// and permission checks without delivering a signal.
//
// It returns (true, nil) when the probe succeeds and (false, nil) when the
// process no longer exists. Any other failure, such as lacking permission to
// signal the process, is returned as an error together with false.
func processExists(pid int) (bool, error) {
	process, err := os.FindProcess(pid)
	if err != nil {
		return false, err
	}

	err = process.Signal(syscall.Signal(0))
	if err == nil {
		return true, nil
	}

	// On Unix, os.Process.Signal reports a missing process as
	// os.ErrProcessDone instead of the raw ESRCH errno, so checking ESRCH
	// alone never matches. Both are returned unwrapped, hence the direct
	// comparison; ESRCH is kept in case a platform surfaces it as is.
	if err == os.ErrProcessDone || err == syscall.ESRCH {
		return false, nil
	}

	return false, err
}
+18 -8
View File
@@ -12,6 +12,7 @@ import (
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
"github.com/ultravioletrs/cocos/manager/qemu"
persistenceMocks "github.com/ultravioletrs/cocos/manager/qemu/mocks"
"github.com/ultravioletrs/cocos/manager/vm"
"github.com/ultravioletrs/cocos/manager/vm/mocks"
"github.com/ultravioletrs/cocos/pkg/manager"
@@ -35,6 +36,7 @@ func TestNew(t *testing.T) {
func TestRun(t *testing.T) {
vmf := new(mocks.Provider)
vmMock := new(mocks.VM)
persistence := new(persistenceMocks.Persistence)
vmf.On("Execute", mock.Anything, mock.Anything, mock.Anything).Return(vmMock)
tests := []struct {
name string
@@ -79,6 +81,9 @@ func TestRun(t *testing.T) {
}
vmMock.On("SendAgentConfig", mock.Anything).Return(nil)
vmMock.On("GetProcess").Return(1234)
persistence.On("SaveVM", mock.Anything).Return(nil)
qemuCfg := qemu.Config{
VSockConfig: qemu.VSockConfig{
@@ -90,11 +95,12 @@ func TestRun(t *testing.T) {
eventsChan := make(chan *manager.ClientStreamMessage, 10)
ms := &managerService{
qemuCfg: qemuCfg,
logger: logger,
vms: make(map[string]vm.VM),
eventsChan: eventsChan,
vmFactory: vmf.Execute,
qemuCfg: qemuCfg,
logger: logger,
vms: make(map[string]vm.VM),
eventsChan: eventsChan,
vmFactory: vmf.Execute,
persistence: persistence,
}
ctx := context.Background()
@@ -123,6 +129,7 @@ func TestRun(t *testing.T) {
func TestStop(t *testing.T) {
vmf := new(mocks.Provider)
vmMock := new(mocks.VM)
persistence := new(persistenceMocks.Persistence)
vmf.On("Execute", mock.Anything, mock.Anything, mock.Anything).Return(vmMock)
tests := []struct {
@@ -160,9 +167,10 @@ func TestStop(t *testing.T) {
logger := slog.Default()
eventsChan := make(chan *manager.ClientStreamMessage, 10)
ms := &managerService{
logger: logger,
vms: make(map[string]vm.VM),
eventsChan: eventsChan,
logger: logger,
vms: make(map[string]vm.VM),
eventsChan: eventsChan,
persistence: persistence,
}
vmMock := new(mocks.VM)
@@ -172,6 +180,8 @@ func TestStop(t *testing.T) {
vmMock.On("Stop").Return(assert.AnError).Once()
}
persistence.On("DeleteVM", tt.computationID).Return(nil)
if tt.initialVMCount > 0 {
ms.vms[tt.computationID] = vmMock
}
+36
View File
@@ -15,6 +15,24 @@ type VM struct {
mock.Mock
}
// GetProcess provides a mock function with given fields:
//
// It records the call and panics if the test configured no return value.
func (_m *VM) GetProcess() int {
ret := _m.Called()
if len(ret) == 0 {
panic("no return value specified for GetProcess")
}
var r0 int
// A configured func() int is invoked; anything else must be an int.
if rf, ok := ret.Get(0).(func() int); ok {
r0 = rf()
} else {
r0 = ret.Get(0).(int)
}
return r0
}
// SendAgentConfig provides a mock function with given fields: ac
func (_m *VM) SendAgentConfig(ac agent.Computation) error {
ret := _m.Called(ac)
@@ -33,6 +51,24 @@ func (_m *VM) SendAgentConfig(ac agent.Computation) error {
return r0
}
// SetProcess provides a mock function with given fields: pid
//
// It records the call and panics if the test configured no return value.
func (_m *VM) SetProcess(pid int) error {
ret := _m.Called(pid)
if len(ret) == 0 {
panic("no return value specified for SetProcess")
}
var r0 error
// A configured func(int) error is invoked with pid; anything else is treated
// as the error value itself.
if rf, ok := ret.Get(0).(func(int) error); ok {
r0 = rf(pid)
} else {
r0 = ret.Error(0)
}
return r0
}
// Start provides a mock function with given fields:
func (_m *VM) Start() error {
ret := _m.Called()
+2
View File
@@ -14,6 +14,8 @@ type VM interface {
Start() error
Stop() error
SendAgentConfig(ac agent.Computation) error
SetProcess(pid int) error
GetProcess() int
}
//go:generate mockery --name Provider --output=./mocks --filename provider.go --quiet --note "Copyright (c) Ultraviolet \n // SPDX-License-Identifier: Apache-2.0"