mirror of
https://github.com/ultravioletrs/cocos.git
synced 2026-06-23 04:10:25 +00:00
COCOS-214 - Improve manager resiliance by tracking vms on restart (#219)
* track hanging vm processes Signed-off-by: SammyOina <sammyoina@gmail.com> * fix lint Signed-off-by: SammyOina <sammyoina@gmail.com> * fix run test Signed-off-by: SammyOina <sammyoina@gmail.com> * fix stop computation Signed-off-by: SammyOina <sammyoina@gmail.com> * shutdown gracefully Signed-off-by: SammyOina <sammyoina@gmail.com> * check if process still exists Signed-off-by: Sammy Oina <sammyoina@gmail.com> * fix lint Signed-off-by: Sammy Oina <sammyoina@gmail.com> * use const Signed-off-by: Sammy Oina <sammyoina@gmail.com> --------- Signed-off-by: SammyOina <sammyoina@gmail.com> Signed-off-by: Sammy Oina <sammyoina@gmail.com>
This commit is contained in:
committed by
GitHub
parent
e572793295
commit
9ca045b06a
+1
-1
@@ -115,7 +115,7 @@ func main() {
|
||||
if err != nil {
|
||||
log.Fatal("failed to reconnect: ", err)
|
||||
}
|
||||
handler = agentlogger.NewProtoHandler(conn, &slog.HandlerOptions{Level: level})
|
||||
handler = agentlogger.NewProtoHandler(conn, &slog.HandlerOptions{Level: level}, cfg.ID)
|
||||
logger = slog.New(handler)
|
||||
}
|
||||
time.Sleep(retryInterval)
|
||||
|
||||
@@ -10,7 +10,9 @@ import (
|
||||
"log/slog"
|
||||
"net/url"
|
||||
"os"
|
||||
"os/signal"
|
||||
"strings"
|
||||
"syscall"
|
||||
|
||||
mglog "github.com/absmach/magistrala/logger"
|
||||
"github.com/absmach/magistrala/pkg/jaeger"
|
||||
@@ -120,6 +122,21 @@ func main() {
|
||||
|
||||
mc := managerapi.NewClient(pc, svc, eventsChan, logger)
|
||||
|
||||
g.Go(func() error {
|
||||
ch := make(chan os.Signal, 1)
|
||||
signal.Notify(ch, syscall.SIGINT, syscall.SIGTERM)
|
||||
defer signal.Stop(ch)
|
||||
|
||||
select {
|
||||
case <-ch:
|
||||
logger.Info("Received signal, shutting down...")
|
||||
cancel()
|
||||
return nil
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
}
|
||||
})
|
||||
|
||||
g.Go(func() error {
|
||||
return mc.Process(ctx, cancel)
|
||||
})
|
||||
|
||||
@@ -7,6 +7,8 @@ import (
|
||||
"strconv"
|
||||
)
|
||||
|
||||
const BaseGuestCID = 3
|
||||
|
||||
type MemoryConfig struct {
|
||||
Size string `env:"MEMORY_SIZE" envDefault:"2048M"`
|
||||
Slots int `env:"MEMORY_SLOTS" envDefault:"5"`
|
||||
|
||||
@@ -0,0 +1,96 @@
|
||||
// Code generated by mockery v2.43.2. DO NOT EDIT.
|
||||
|
||||
// Copyright (c) Ultraviolet
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package mocks
|
||||
|
||||
import (
|
||||
mock "github.com/stretchr/testify/mock"
|
||||
qemu "github.com/ultravioletrs/cocos/manager/qemu"
|
||||
)
|
||||
|
||||
// Persistence is an autogenerated mock type for the Persistence type
|
||||
type Persistence struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
// DeleteVM provides a mock function with given fields: id
|
||||
func (_m *Persistence) DeleteVM(id string) error {
|
||||
ret := _m.Called(id)
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for DeleteVM")
|
||||
}
|
||||
|
||||
var r0 error
|
||||
if rf, ok := ret.Get(0).(func(string) error); ok {
|
||||
r0 = rf(id)
|
||||
} else {
|
||||
r0 = ret.Error(0)
|
||||
}
|
||||
|
||||
return r0
|
||||
}
|
||||
|
||||
// LoadVMs provides a mock function with given fields:
|
||||
func (_m *Persistence) LoadVMs() ([]qemu.VMState, error) {
|
||||
ret := _m.Called()
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for LoadVMs")
|
||||
}
|
||||
|
||||
var r0 []qemu.VMState
|
||||
var r1 error
|
||||
if rf, ok := ret.Get(0).(func() ([]qemu.VMState, error)); ok {
|
||||
return rf()
|
||||
}
|
||||
if rf, ok := ret.Get(0).(func() []qemu.VMState); ok {
|
||||
r0 = rf()
|
||||
} else {
|
||||
if ret.Get(0) != nil {
|
||||
r0 = ret.Get(0).([]qemu.VMState)
|
||||
}
|
||||
}
|
||||
|
||||
if rf, ok := ret.Get(1).(func() error); ok {
|
||||
r1 = rf()
|
||||
} else {
|
||||
r1 = ret.Error(1)
|
||||
}
|
||||
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
// SaveVM provides a mock function with given fields: state
|
||||
func (_m *Persistence) SaveVM(state qemu.VMState) error {
|
||||
ret := _m.Called(state)
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for SaveVM")
|
||||
}
|
||||
|
||||
var r0 error
|
||||
if rf, ok := ret.Get(0).(func(qemu.VMState) error); ok {
|
||||
r0 = rf(state)
|
||||
} else {
|
||||
r0 = ret.Error(0)
|
||||
}
|
||||
|
||||
return r0
|
||||
}
|
||||
|
||||
// NewPersistence creates a new instance of Persistence. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations.
|
||||
// The first argument is typically a *testing.T value.
|
||||
func NewPersistence(t interface {
|
||||
mock.TestingT
|
||||
Cleanup(func())
|
||||
}) *Persistence {
|
||||
mock := &Persistence{}
|
||||
mock.Mock.Test(t)
|
||||
|
||||
t.Cleanup(func() { mock.AssertExpectations(t) })
|
||||
|
||||
return mock
|
||||
}
|
||||
@@ -0,0 +1,89 @@
|
||||
// Copyright (c) Ultraviolet
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
package qemu
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sync"
|
||||
)
|
||||
|
||||
const jsonExt = ".json"
|
||||
|
||||
type VMState struct {
|
||||
ID string
|
||||
Config Config
|
||||
PID int
|
||||
}
|
||||
|
||||
type FilePersistence struct {
|
||||
dir string
|
||||
lock sync.Mutex
|
||||
}
|
||||
|
||||
// Persistence is an interface for saving and loading VM states.
|
||||
//
|
||||
//go:generate mockery --name Persistence --output=./mocks --filename persistence.go --quiet --note "Copyright (c) Ultraviolet \n // SPDX-License-Identifier: Apache-2.0"
|
||||
type Persistence interface {
|
||||
SaveVM(state VMState) error
|
||||
LoadVMs() ([]VMState, error)
|
||||
DeleteVM(id string) error
|
||||
}
|
||||
|
||||
func NewFilePersistence(dir string) (Persistence, error) {
|
||||
if err := os.MkdirAll(dir, 0o755); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &FilePersistence{dir: dir}, nil
|
||||
}
|
||||
|
||||
func (fp *FilePersistence) SaveVM(state VMState) error {
|
||||
fp.lock.Lock()
|
||||
defer fp.lock.Unlock()
|
||||
|
||||
data, err := json.Marshal(state)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return os.WriteFile(filepath.Join(fp.dir, state.ID+jsonExt), data, 0o644)
|
||||
}
|
||||
|
||||
func (fp *FilePersistence) LoadVMs() ([]VMState, error) {
|
||||
fp.lock.Lock()
|
||||
defer fp.lock.Unlock()
|
||||
|
||||
files, err := os.ReadDir(fp.dir)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var states []VMState
|
||||
for _, file := range files {
|
||||
if filepath.Ext(file.Name()) != jsonExt {
|
||||
continue
|
||||
}
|
||||
|
||||
data, err := os.ReadFile(filepath.Join(fp.dir, file.Name()))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var state VMState
|
||||
if err := json.Unmarshal(data, &state); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
states = append(states, state)
|
||||
}
|
||||
|
||||
return states, nil
|
||||
}
|
||||
|
||||
func (fp *FilePersistence) DeleteVM(id string) error {
|
||||
fp.lock.Lock()
|
||||
defer fp.lock.Unlock()
|
||||
|
||||
return os.Remove(filepath.Join(fp.dir, id+jsonExt))
|
||||
}
|
||||
+24
-3
@@ -4,6 +4,7 @@ package qemu
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"os/exec"
|
||||
|
||||
"github.com/gofrs/uuid"
|
||||
@@ -38,9 +39,9 @@ func (v *qemuVM) Start() error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
qemuCfg := v.config
|
||||
qemuCfg.NetDevConfig.ID = fmt.Sprintf("%s-%s", qemuCfg.NetDevConfig.ID, id)
|
||||
qemuCfg.SevConfig.ID = fmt.Sprintf("%s-%s", qemuCfg.SevConfig.ID, id)
|
||||
|
||||
v.config.NetDevConfig.ID = fmt.Sprintf("%s-%s", v.config.NetDevConfig.ID, id)
|
||||
v.config.SevConfig.ID = fmt.Sprintf("%s-%s", v.config.SevConfig.ID, id)
|
||||
|
||||
exe, args, err := v.executableAndArgs()
|
||||
if err != nil {
|
||||
@@ -58,6 +59,26 @@ func (v *qemuVM) Stop() error {
|
||||
return v.cmd.Process.Kill()
|
||||
}
|
||||
|
||||
func (v *qemuVM) SetProcess(pid int) error {
|
||||
process, err := os.FindProcess(pid)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
exe, args, err := v.executableAndArgs()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
v.cmd = exec.Command(exe, args...)
|
||||
v.cmd.Process = process
|
||||
return nil
|
||||
}
|
||||
|
||||
func (v *qemuVM) GetProcess() int {
|
||||
return v.cmd.Process.Pid
|
||||
}
|
||||
|
||||
func (v *qemuVM) executableAndArgs() (string, []string, error) {
|
||||
exe, err := exec.LookPath(v.config.QemuBinPath)
|
||||
if err != nil {
|
||||
|
||||
+87
-2
@@ -9,9 +9,11 @@ import (
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"net"
|
||||
"os"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"sync"
|
||||
"syscall"
|
||||
|
||||
"github.com/absmach/magistrala/pkg/errors"
|
||||
"github.com/cenkalti/backoff/v4"
|
||||
@@ -23,7 +25,10 @@ import (
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
)
|
||||
|
||||
const hashLength = 32
|
||||
const (
|
||||
hashLength = 32
|
||||
persistenceDir = "/tmp/cocos"
|
||||
)
|
||||
|
||||
var (
|
||||
// ErrMalformedEntity indicates malformed entity specification (e.g.
|
||||
@@ -68,6 +73,7 @@ type managerService struct {
|
||||
vmFactory vm.Provider
|
||||
portRangeMin int
|
||||
portRangeMax int
|
||||
persistence qemu.Persistence
|
||||
}
|
||||
|
||||
var _ Service = (*managerService)(nil)
|
||||
@@ -78,6 +84,12 @@ func New(cfg qemu.Config, backendMeasurementBinPath string, logger *slog.Logger,
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
persistence, err := qemu.NewFilePersistence(persistenceDir)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ms := &managerService{
|
||||
qemuCfg: cfg,
|
||||
logger: logger,
|
||||
@@ -87,7 +99,13 @@ func New(cfg qemu.Config, backendMeasurementBinPath string, logger *slog.Logger,
|
||||
backendMeasurementBinaryPath: backendMeasurementBinPath,
|
||||
portRangeMin: start,
|
||||
portRangeMax: end,
|
||||
persistence: persistence,
|
||||
}
|
||||
|
||||
if err := ms.restoreVMs(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return ms, nil
|
||||
}
|
||||
|
||||
@@ -127,6 +145,7 @@ func (ms *managerService) Run(ctx context.Context, c *manager.ComputationRunReq)
|
||||
return "", errors.Wrap(ErrFailedToAllocatePort, err)
|
||||
}
|
||||
ms.qemuCfg.HostFwdAgent = agentPort
|
||||
ms.qemuCfg.VSockConfig.GuestCID = qemu.BaseGuestCID + len(ms.vms)
|
||||
|
||||
ch, err := computationHash(ac)
|
||||
if err != nil {
|
||||
@@ -145,13 +164,24 @@ func (ms *managerService) Run(ctx context.Context, c *manager.ComputationRunReq)
|
||||
}
|
||||
ms.vms[c.Id] = cvm
|
||||
|
||||
pid := cvm.GetProcess()
|
||||
|
||||
state := qemu.VMState{
|
||||
ID: c.Id,
|
||||
Config: ms.qemuCfg,
|
||||
PID: pid,
|
||||
}
|
||||
if err := ms.persistence.SaveVM(state); err != nil {
|
||||
ms.logger.Error("Failed to persist VM state", "error", err)
|
||||
}
|
||||
|
||||
err = backoff.Retry(func() error {
|
||||
return cvm.SendAgentConfig(ac)
|
||||
}, backoff.NewExponentialBackOff())
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
ms.qemuCfg.VSockConfig.GuestCID++
|
||||
|
||||
ms.qemuCfg.VSockConfig.Vnc++
|
||||
|
||||
ms.publishEvent("vm-provision", c.Id, "complete", json.RawMessage{})
|
||||
@@ -169,6 +199,11 @@ func (ms *managerService) Stop(ctx context.Context, computationID string) error
|
||||
return err
|
||||
}
|
||||
delete(ms.vms, computationID)
|
||||
|
||||
if err := ms.persistence.DeleteVM(computationID); err != nil {
|
||||
ms.logger.Error("Failed to delete persisted VM state", "error", err)
|
||||
}
|
||||
|
||||
defer ms.publishEvent("stop-computation", computationID, "complete", json.RawMessage{})
|
||||
return nil
|
||||
}
|
||||
@@ -264,3 +299,53 @@ func decodeRange(input string) (int, int, error) {
|
||||
|
||||
return start, end, nil
|
||||
}
|
||||
|
||||
func (ms *managerService) restoreVMs() error {
|
||||
states, err := ms.persistence.LoadVMs()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, state := range states {
|
||||
exists, err := processExists(state.PID)
|
||||
if err != nil {
|
||||
ms.logger.Warn("Failed to check process existence", "computation", state.ID, "pid", state.PID, "error", err)
|
||||
continue
|
||||
}
|
||||
|
||||
if !exists {
|
||||
if err := ms.persistence.DeleteVM(state.ID); err != nil {
|
||||
ms.logger.Error("Failed to delete persisted VM state", "computation", state.ID, "error", err)
|
||||
}
|
||||
ms.logger.Info("Deleted persisted state for non-existent process", "computation", state.ID, "pid", state.PID)
|
||||
continue
|
||||
}
|
||||
|
||||
cvm := ms.vmFactory(state.Config, ms.eventsChan, state.ID)
|
||||
|
||||
if err = cvm.SetProcess(state.PID); err != nil {
|
||||
ms.logger.Warn("Failed to reattach to process", "computation", state.ID, "pid", state.PID, "error", err)
|
||||
continue
|
||||
}
|
||||
|
||||
ms.vms[state.ID] = cvm
|
||||
ms.logger.Info("Successfully restored VM state", "id", state.ID, "computationId", state.ID, "pid", state.PID)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func processExists(pid int) (bool, error) {
|
||||
process, err := os.FindProcess(pid)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
if err = process.Signal(syscall.Signal(0)); err == nil {
|
||||
return true, nil
|
||||
}
|
||||
if err == syscall.ESRCH {
|
||||
return false, nil
|
||||
}
|
||||
return false, err
|
||||
}
|
||||
|
||||
+18
-8
@@ -12,6 +12,7 @@ import (
|
||||
"github.com/stretchr/testify/mock"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/ultravioletrs/cocos/manager/qemu"
|
||||
persistenceMocks "github.com/ultravioletrs/cocos/manager/qemu/mocks"
|
||||
"github.com/ultravioletrs/cocos/manager/vm"
|
||||
"github.com/ultravioletrs/cocos/manager/vm/mocks"
|
||||
"github.com/ultravioletrs/cocos/pkg/manager"
|
||||
@@ -35,6 +36,7 @@ func TestNew(t *testing.T) {
|
||||
func TestRun(t *testing.T) {
|
||||
vmf := new(mocks.Provider)
|
||||
vmMock := new(mocks.VM)
|
||||
persistence := new(persistenceMocks.Persistence)
|
||||
vmf.On("Execute", mock.Anything, mock.Anything, mock.Anything).Return(vmMock)
|
||||
tests := []struct {
|
||||
name string
|
||||
@@ -79,6 +81,9 @@ func TestRun(t *testing.T) {
|
||||
}
|
||||
|
||||
vmMock.On("SendAgentConfig", mock.Anything).Return(nil)
|
||||
vmMock.On("GetProcess").Return(1234)
|
||||
|
||||
persistence.On("SaveVM", mock.Anything).Return(nil)
|
||||
|
||||
qemuCfg := qemu.Config{
|
||||
VSockConfig: qemu.VSockConfig{
|
||||
@@ -90,11 +95,12 @@ func TestRun(t *testing.T) {
|
||||
eventsChan := make(chan *manager.ClientStreamMessage, 10)
|
||||
|
||||
ms := &managerService{
|
||||
qemuCfg: qemuCfg,
|
||||
logger: logger,
|
||||
vms: make(map[string]vm.VM),
|
||||
eventsChan: eventsChan,
|
||||
vmFactory: vmf.Execute,
|
||||
qemuCfg: qemuCfg,
|
||||
logger: logger,
|
||||
vms: make(map[string]vm.VM),
|
||||
eventsChan: eventsChan,
|
||||
vmFactory: vmf.Execute,
|
||||
persistence: persistence,
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
@@ -123,6 +129,7 @@ func TestRun(t *testing.T) {
|
||||
func TestStop(t *testing.T) {
|
||||
vmf := new(mocks.Provider)
|
||||
vmMock := new(mocks.VM)
|
||||
persistence := new(persistenceMocks.Persistence)
|
||||
vmf.On("Execute", mock.Anything, mock.Anything, mock.Anything).Return(vmMock)
|
||||
|
||||
tests := []struct {
|
||||
@@ -160,9 +167,10 @@ func TestStop(t *testing.T) {
|
||||
logger := slog.Default()
|
||||
eventsChan := make(chan *manager.ClientStreamMessage, 10)
|
||||
ms := &managerService{
|
||||
logger: logger,
|
||||
vms: make(map[string]vm.VM),
|
||||
eventsChan: eventsChan,
|
||||
logger: logger,
|
||||
vms: make(map[string]vm.VM),
|
||||
eventsChan: eventsChan,
|
||||
persistence: persistence,
|
||||
}
|
||||
vmMock := new(mocks.VM)
|
||||
|
||||
@@ -172,6 +180,8 @@ func TestStop(t *testing.T) {
|
||||
vmMock.On("Stop").Return(assert.AnError).Once()
|
||||
}
|
||||
|
||||
persistence.On("DeleteVM", tt.computationID).Return(nil)
|
||||
|
||||
if tt.initialVMCount > 0 {
|
||||
ms.vms[tt.computationID] = vmMock
|
||||
}
|
||||
|
||||
@@ -15,6 +15,24 @@ type VM struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
// GetProcess provides a mock function with given fields:
|
||||
func (_m *VM) GetProcess() int {
|
||||
ret := _m.Called()
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for GetProcess")
|
||||
}
|
||||
|
||||
var r0 int
|
||||
if rf, ok := ret.Get(0).(func() int); ok {
|
||||
r0 = rf()
|
||||
} else {
|
||||
r0 = ret.Get(0).(int)
|
||||
}
|
||||
|
||||
return r0
|
||||
}
|
||||
|
||||
// SendAgentConfig provides a mock function with given fields: ac
|
||||
func (_m *VM) SendAgentConfig(ac agent.Computation) error {
|
||||
ret := _m.Called(ac)
|
||||
@@ -33,6 +51,24 @@ func (_m *VM) SendAgentConfig(ac agent.Computation) error {
|
||||
return r0
|
||||
}
|
||||
|
||||
// SetProcess provides a mock function with given fields: pid
|
||||
func (_m *VM) SetProcess(pid int) error {
|
||||
ret := _m.Called(pid)
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for SetProcess")
|
||||
}
|
||||
|
||||
var r0 error
|
||||
if rf, ok := ret.Get(0).(func(int) error); ok {
|
||||
r0 = rf(pid)
|
||||
} else {
|
||||
r0 = ret.Error(0)
|
||||
}
|
||||
|
||||
return r0
|
||||
}
|
||||
|
||||
// Start provides a mock function with given fields:
|
||||
func (_m *VM) Start() error {
|
||||
ret := _m.Called()
|
||||
|
||||
@@ -14,6 +14,8 @@ type VM interface {
|
||||
Start() error
|
||||
Stop() error
|
||||
SendAgentConfig(ac agent.Computation) error
|
||||
SetProcess(pid int) error
|
||||
GetProcess() int
|
||||
}
|
||||
|
||||
//go:generate mockery --name Provider --output=./mocks --filename provider.go --quiet --note "Copyright (c) Ultraviolet \n // SPDX-License-Identifier: Apache-2.0"
|
||||
|
||||
Reference in New Issue
Block a user