COCOS-454 - Implement graceful shutdown for services and add TTL management for VMs (#473)

* Implement graceful shutdown for services and add TTL management for VMs

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* Remove unnecessary comment from go-tdx-guest dependency in go.mod

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* Update manager/api/logging.go

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* Add TTL manager initialization in TestStop

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* Fix logging format in Shutdown method for consistency

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* Add unit tests for TTL manager functionality

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* Enhance TTL tests with mutex for thread safety in expiration checks

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* Add TTL parameter to CreateVM in TestRun for improved testing scenarios

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* Add Shutdown test to verify VM cleanup and TTL manager integration

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

---------

Signed-off-by: Sammy Oina <sammyoina@gmail.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
This commit is contained in:
Sammy Kerata Oina
2025-07-11 17:21:29 +03:00
committed by GitHub
parent f543cb4363
commit 45187d7f41
8 changed files with 512 additions and 17 deletions
+11 -3
View File
@@ -114,13 +114,21 @@ func main() {
return
}
svc, err := newService(ctx, logger, tracer, *qemuCfg, cfg.AttestationPolicyBinary, cfg.IgvmMeasureBinary, cfg.PcrValues, cfg.EosVersion)
svc, err := newService(logger, tracer, *qemuCfg, cfg.AttestationPolicyBinary, cfg.IgvmMeasureBinary, cfg.PcrValues, cfg.EosVersion)
if err != nil {
logger.Error(err.Error())
exitCode = 1
return
}
defer func() {
if err := svc.Shutdown(); err != nil {
logger.Error(err.Error())
exitCode = 1
return
}
}()
registerManagerServiceServer := func(srv *grpc.Server) {
reflection.Register(srv)
manager.RegisterManagerServiceServer(srv, managergrpc.NewServer(svc))
@@ -141,8 +149,8 @@ func main() {
}
}
func newService(ctx context.Context, logger *slog.Logger, tracer trace.Tracer, qemuCfg qemu.Config, attestationPolicyPath string, igvmMeasurementBinaryPath string, pcrValuesFilePath string, eosVersion string) (manager.Service, error) {
svc, err := manager.New(ctx, qemuCfg, attestationPolicyPath, igvmMeasurementBinaryPath, pcrValuesFilePath, logger, qemu.NewVM, eosVersion)
func newService(logger *slog.Logger, tracer trace.Tracer, qemuCfg qemu.Config, attestationPolicyPath string, igvmMeasurementBinaryPath string, pcrValuesFilePath string, eosVersion string) (manager.Service, error) {
svc, err := manager.New(qemuCfg, attestationPolicyPath, igvmMeasurementBinaryPath, pcrValuesFilePath, logger, qemu.NewVM, eosVersion)
if err != nil {
return nil, err
}
+17
View File
@@ -75,3 +75,20 @@ func (lm *loggingMiddleware) ReturnCVMInfo(ctx context.Context) (string, int, st
return lm.svc.ReturnCVMInfo(ctx)
}
// Shutdown forwards to the wrapped service and logs how long the call took.
// A non-nil error is logged at Warn level, success at Info level; the error
// from the wrapped service is returned unchanged.
func (lm *loggingMiddleware) Shutdown() (err error) {
	start := time.Now()

	// The closure reads the named result, so it sees the final error value.
	defer func() {
		if err == nil {
			lm.logger.Info("Method Shutdown completed successfully",
				"time_taken", time.Since(start),
			)
			return
		}
		lm.logger.Warn("Method Shutdown completed with error",
			"time_taken", time.Since(start),
			"error", err,
		)
	}()

	err = lm.svc.Shutdown()
	return err
}
+9
View File
@@ -67,3 +67,12 @@ func (ms *metricsMiddleware) ReturnCVMInfo(ctx context.Context) (string, int, st
return ms.svc.ReturnCVMInfo(ctx)
}
// Shutdown forwards to the wrapped service, then increments the call counter
// and records the elapsed time in seconds under the "Shutdown" method label.
func (ms *metricsMiddleware) Shutdown() error {
	start := time.Now()

	// Metrics are recorded after the wrapped call returns, whatever its result.
	defer func() {
		ms.counter.With("method", "Shutdown").Add(1)
		elapsed := time.Since(start).Seconds()
		ms.latency.With("method", "Shutdown").Observe(elapsed)
	}()

	return ms.svc.Shutdown()
}
+28 -12
View File
@@ -81,6 +81,8 @@ type Service interface {
FetchAttestationPolicy(ctx context.Context, computationID string) ([]byte, error)
// ReturnCVMInfo returns CVM information needed for attestation verification and validation.
ReturnCVMInfo(ctx context.Context) (string, int, string, string)
// Shutdown gracefully shuts down the service.
Shutdown() error
}
type managerService struct {
@@ -97,13 +99,13 @@ type managerService struct {
portRangeMax int
persistence qemu.Persistence
eosVersion string
ctx context.Context
ttlManager *TTLManager
}
var _ Service = (*managerService)(nil)
// New instantiates the manager service implementation.
func New(ctx context.Context, cfg qemu.Config, attestationPolicyBinPath string, igvmMeasurementBinaryPath string, pcrValuesFilePath string, logger *slog.Logger, vmFactory vm.Provider, eosVersion string) (Service, error) {
func New(cfg qemu.Config, attestationPolicyBinPath string, igvmMeasurementBinaryPath string, pcrValuesFilePath string, logger *slog.Logger, vmFactory vm.Provider, eosVersion string) (Service, error) {
start, end, err := decodeRange(cfg.HostFwdRange)
if err != nil {
return nil, err
@@ -126,7 +128,7 @@ func New(ctx context.Context, cfg qemu.Config, attestationPolicyBinPath string,
portRangeMax: end,
persistence: persistence,
eosVersion: eosVersion,
ctx: ctx,
ttlManager: NewTTLManager(),
}
if err := ms.restoreVMs(); err != nil {
@@ -224,16 +226,13 @@ func (ms *managerService) CreateVM(ctx context.Context, req *CreateReq) (string,
return "", id, err
}
go func() {
select {
case <-time.After(ttl):
if err := ms.RemoveVM(ctx, id); err != nil {
ms.logger.Error("Failed to remove VM after TTL", "error", err)
}
case <-ms.ctx.Done():
return
ms.ttlManager.SetTTL(id, ttl, func() { //nolint:contextcheck
if err := ms.RemoveVM(context.Background(), id); err != nil {
ms.logger.Error("Failed to remove VM after TTL expiry", "vmID", id, "error", err)
} else {
ms.logger.Info("Successfully removed VM after TTL expiry", "vmID", id)
}
}()
})
}
pid := cvm.GetProcess()
@@ -259,6 +258,9 @@ func (ms *managerService) CreateVM(ctx context.Context, req *CreateReq) (string,
func (ms *managerService) RemoveVM(ctx context.Context, computationID string) error {
ms.mu.Lock()
defer ms.mu.Unlock()
ms.ttlManager.CancelTTL(computationID)
cvm, ok := ms.vms[computationID]
if !ok {
return ErrNotFound
@@ -279,6 +281,20 @@ func (ms *managerService) ReturnCVMInfo(ctx context.Context) (string, int, strin
return ms.qemuCfg.OVMFCodeConfig.Version, ms.qemuCfg.SMPCount, ms.qemuCfg.CPU, ms.eosVersion
}
// Shutdown gracefully shuts down the service.
//
// It cancels every pending TTL timer and then resets the in-memory VM table
// to an empty map. The VM values held in the table are not stopped here; only
// the references are dropped. It always returns nil.
func (ms *managerService) Shutdown() error {
	ms.logger.Info("Shutting down manager service")

	// Cancel timers first, before taking ms.mu.
	ms.ttlManager.CancelAll()

	ms.mu.Lock()
	ms.vms = map[string]vm.VM{}
	ms.mu.Unlock()

	return nil
}
func getFreePort(minPort, maxPort int) (int, error) {
if checkPortisFree(minPort) {
return minPort, nil
+32 -2
View File
@@ -30,7 +30,7 @@ func TestNew(t *testing.T) {
logger := slog.Default()
vmf := new(mocks.Provider)
service, err := New(context.Background(), cfg, "", "", "", logger, vmf.Execute, "")
service, err := New(cfg, "", "", "", logger, vmf.Execute, "")
require.NoError(t, err)
assert.NotNil(t, service)
@@ -47,24 +47,35 @@ func TestRun(t *testing.T) {
binaryBehavior string
vmStartError error
expectedError error
ttl string
}{
{
name: "Successful run",
binaryBehavior: "success",
vmStartError: nil,
expectedError: nil,
ttl: "",
},
{
name: "VM start failure",
binaryBehavior: "success",
vmStartError: assert.AnError,
expectedError: assert.AnError,
ttl: "",
},
{
name: "Invalid attestation policy",
binaryBehavior: "fail",
vmStartError: nil,
expectedError: ErrFailedToCreateAttestationPolicy,
ttl: "",
},
{
name: "With TTL",
binaryBehavior: "success",
vmStartError: nil,
expectedError: nil,
ttl: "10s",
},
}
@@ -101,11 +112,12 @@ func TestRun(t *testing.T) {
vms: make(map[string]vm.VM),
vmFactory: vmf.Execute,
persistence: persistence,
ttlManager: NewTTLManager(),
}
ctx := context.Background()
port, _, err := ms.CreateVM(ctx, &CreateReq{})
port, _, err := ms.CreateVM(ctx, &CreateReq{Ttl: tt.ttl})
if tt.expectedError != nil {
assert.Error(t, err)
@@ -165,6 +177,7 @@ func TestStop(t *testing.T) {
logger: logger,
vms: make(map[string]vm.VM),
persistence: persistence,
ttlManager: NewTTLManager(),
}
vmMock := new(mocks.VM)
@@ -287,3 +300,20 @@ func TestProcessExists(t *testing.T) {
assert.False(t, ms.processExists(1)) // PID 1 is usually the init process.
}
}
func TestShutdown(t *testing.T) {
ms := &managerService{
vms: make(map[string]vm.VM),
ttlManager: NewTTLManager(),
logger: mglog.NewMock(),
}
vmMock := new(mocks.VM)
vmMock.On("Stop").Return(nil).Once()
ms.vms["test-vm"] = vmMock
err := ms.Shutdown()
assert.NoError(t, err)
assert.Len(t, ms.vms, 0)
}
+7
View File
@@ -48,3 +48,10 @@ func (tm *tracingMiddleware) ReturnCVMInfo(ctx context.Context) (string, int, st
return tm.svc.ReturnCVMInfo(ctx)
}
// Shutdown wraps the underlying service's Shutdown in a "shutdown" span.
// Shutdown takes no context, so the span is started from a background context.
func (tm *tracingMiddleware) Shutdown() error {
	root := context.Background()
	_, span := tm.tracer.Start(root, "shutdown")
	defer span.End()

	err := tm.svc.Shutdown()
	return err
}
+66
View File
@@ -0,0 +1,66 @@
// Copyright (c) Ultraviolet
// SPDX-License-Identifier: Apache-2.0
package manager
import (
"context"
sync "sync"
"time"
)
// TTLManager handles TTL functionality for VMs.
//
// It tracks at most one pending timer per VM ID. An entry is removed from the
// table when its timer expires, when it is cancelled, or when it is replaced
// by a later SetTTL call for the same ID.
type TTLManager struct {
	timers map[string]*time.Timer
	mu     sync.RWMutex
}

// NewTTLManager creates a new TTL manager with an empty timer table.
func NewTTLManager() *TTLManager {
	return &TTLManager{
		timers: make(map[string]*time.Timer),
	}
}

// SetTTL sets a TTL for a VM and returns a function to cancel it.
//
// Any timer already registered for vmID is stopped and replaced. When ttl
// elapses, the entry is removed from the table and onExpiry is invoked on the
// timer's own goroutine, without the manager's lock held, so onExpiry may
// safely call back into the manager. If the timer was cancelled or replaced
// before the callback got to run, onExpiry is not invoked. A zero or negative
// ttl expires immediately.
//
// The returned cancel function only affects the timer created by this call:
// invoking it after the TTL has been replaced leaves the replacement intact.
// It is safe to call more than once.
func (tm *TTLManager) SetTTL(vmID string, ttl time.Duration, onExpiry func()) context.CancelFunc {
	tm.mu.Lock()
	defer tm.mu.Unlock()

	if prev, exists := tm.timers[vmID]; exists {
		prev.Stop()
	}

	var timer *time.Timer
	timer = time.AfterFunc(ttl, func() {
		// tm.mu is held by SetTTL until it returns, so by the time this lock
		// is acquired the timer variable is assigned and registered.
		tm.mu.Lock()
		current := tm.timers[vmID] == timer
		if current {
			delete(tm.timers, vmID)
		}
		tm.mu.Unlock()

		// Skip the callback if this timer was cancelled or superseded after
		// it fired; Stop cannot recall a callback that has already started.
		if !current || onExpiry == nil {
			return
		}
		onExpiry()
	})
	tm.timers[vmID] = timer

	return func() {
		tm.mu.Lock()
		defer tm.mu.Unlock()

		timer.Stop()
		// Only drop the entry if it still belongs to this call's timer.
		if tm.timers[vmID] == timer {
			delete(tm.timers, vmID)
		}
	}
}

// CancelTTL cancels the TTL for a specific VM. It is a no-op when no TTL is
// registered for vmID.
func (tm *TTLManager) CancelTTL(vmID string) {
	tm.mu.Lock()
	defer tm.mu.Unlock()

	if timer, exists := tm.timers[vmID]; exists {
		timer.Stop()
		delete(tm.timers, vmID)
	}
}

// CancelAll cancels all active TTLs and empties the timer table.
func (tm *TTLManager) CancelAll() {
	tm.mu.Lock()
	defer tm.mu.Unlock()

	for vmID, timer := range tm.timers {
		timer.Stop()
		delete(tm.timers, vmID)
	}
}
+342
View File
@@ -0,0 +1,342 @@
// Copyright (c) Ultraviolet
// SPDX-License-Identifier: Apache-2.0
package manager
import (
"fmt"
"sync"
"testing"
"time"
)
// TestNewTTLManager verifies that the constructor returns a non-nil manager
// whose timer table is initialised and empty.
func TestNewTTLManager(t *testing.T) {
	tm := NewTTLManager()

	switch {
	case tm == nil:
		t.Fatal("NewTTLManager() returned nil")
	case tm.timers == nil:
		t.Fatal("NewTTLManager() did not initialize timers map")
	}

	if n := len(tm.timers); n != 0 {
		t.Errorf("NewTTLManager() timers map should be empty, got %d entries", n)
	}
}
// TestSetTTL_Basic registers a short TTL, checks the timer is tracked, waits
// for it to fire, and checks that no entry remains once the cancel function
// has been called.
func TestSetTTL_Basic(t *testing.T) {
	const (
		vmID = "test-vm-1"
		ttl  = 50 * time.Millisecond
	)

	var (
		guard sync.Mutex
		fired bool
	)

	tm := NewTTLManager()
	cancel := tm.SetTTL(vmID, ttl, func() {
		guard.Lock()
		fired = true
		guard.Unlock()
	})

	// tracked reports whether a timer is currently registered for vmID.
	tracked := func() bool {
		tm.mu.RLock()
		defer tm.mu.RUnlock()
		_, ok := tm.timers[vmID]
		return ok
	}

	if !tracked() {
		t.Error("Timer was not created for VM")
	}

	time.Sleep(100 * time.Millisecond)

	guard.Lock()
	didFire := fired
	guard.Unlock()
	if !didFire {
		t.Error("TTL did not expire as expected")
	}

	cancel()

	if tracked() {
		t.Error("Timer should be cleaned up after expiry")
	}
}
// TestSetTTL_CancelBeforeExpiry cancels a TTL before it elapses and verifies
// that the expiry callback never runs and the timer entry is removed.
func TestSetTTL_CancelBeforeExpiry(t *testing.T) {
	tm := NewTTLManager()

	// The flag is written from the timer goroutine if the cancel is broken,
	// so guard it like the sibling tests do to keep the failure path race-free.
	var mu sync.Mutex
	expired := false
	vmID := "test-vm-2"
	ttl := 100 * time.Millisecond

	cancelFunc := tm.SetTTL(vmID, ttl, func() {
		mu.Lock()
		defer mu.Unlock()
		expired = true
	})

	time.Sleep(20 * time.Millisecond)
	cancelFunc()

	// Wait past the original deadline to prove the callback stays silent.
	time.Sleep(150 * time.Millisecond)

	mu.Lock()
	if expired {
		t.Error("TTL should not have expired after being cancelled")
	}
	mu.Unlock()

	tm.mu.RLock()
	if _, exists := tm.timers[vmID]; exists {
		t.Error("Timer should be cleaned up after cancellation")
	}
	tm.mu.RUnlock()
}
// TestSetTTL_OverwriteExistingTimer sets two TTLs for the same VM and checks
// that only the second one fires and that at most one entry is tracked.
func TestSetTTL_OverwriteExistingTimer(t *testing.T) {
	tm := NewTTLManager()

	// Both flags may be written from timer goroutines, so share one mutex.
	mu := sync.Mutex{}
	firstExpired := false
	secondExpired := false
	vmID := "test-vm-3"

	// Set first TTL
	tm.SetTTL(vmID, 200*time.Millisecond, func() {
		mu.Lock()
		defer mu.Unlock()
		firstExpired = true
	})

	// Immediately overwrite with second TTL
	tm.SetTTL(vmID, 50*time.Millisecond, func() {
		mu.Lock()
		defer mu.Unlock()
		secondExpired = true
	})

	// Wait past the first TTL's 200ms deadline as well; checking earlier
	// would pass even if the first timer had never been stopped.
	time.Sleep(250 * time.Millisecond)

	mu.Lock()
	if firstExpired {
		t.Error("First TTL should not have expired (it was overwritten)")
	}
	if !secondExpired {
		t.Error("Second TTL should have expired")
	}
	mu.Unlock()

	// Verify only one timer entry exists (or none after cleanup)
	tm.mu.RLock()
	count := len(tm.timers)
	tm.mu.RUnlock()

	if count > 1 {
		t.Errorf("Expected at most 1 timer entry, got %d", count)
	}
}
// TestSetTTL_MultipleConcurrentTimers registers several TTLs for distinct VMs
// and verifies that each is tracked and that every callback eventually runs.
func TestSetTTL_MultipleConcurrentTimers(t *testing.T) {
	const numVMs = 5

	var (
		mu           sync.Mutex
		expiredCount int32
	)
	bump := func() {
		mu.Lock()
		defer mu.Unlock()
		expiredCount++
	}

	tm := NewTTLManager()
	for i := 0; i < numVMs; i++ {
		tm.SetTTL(fmt.Sprintf("vm-%d", i), 50*time.Millisecond, bump)
	}

	tm.mu.RLock()
	if tracked := len(tm.timers); tracked != numVMs {
		t.Errorf("Expected %d timers, got %d", numVMs, tracked)
	}
	tm.mu.RUnlock()

	time.Sleep(100 * time.Millisecond)

	mu.Lock()
	finalCount := expiredCount
	mu.Unlock()

	if int(finalCount) != numVMs {
		t.Errorf("Expected %d timers to expire, got %d", numVMs, finalCount)
	}
}
// TestCancelTTL_ExistingTimer cancels a registered TTL through CancelTTL and
// verifies that the callback never runs and the entry is removed.
func TestCancelTTL_ExistingTimer(t *testing.T) {
	tm := NewTTLManager()

	// The flag is written from the timer goroutine if cancellation is broken,
	// so guard it like the sibling tests do to keep the failure path race-free.
	var mu sync.Mutex
	expired := false
	vmID := "test-vm-4"

	tm.SetTTL(vmID, 100*time.Millisecond, func() {
		mu.Lock()
		defer mu.Unlock()
		expired = true
	})

	// Cancel the timer
	tm.CancelTTL(vmID)

	// Wait past the original deadline to prove the callback stays silent.
	time.Sleep(150 * time.Millisecond)

	mu.Lock()
	if expired {
		t.Error("TTL should not have expired after being cancelled")
	}
	mu.Unlock()

	// Verify timer was removed
	tm.mu.RLock()
	if _, exists := tm.timers[vmID]; exists {
		t.Error("Timer should be removed after cancellation")
	}
	tm.mu.RUnlock()
}
// TestCancelTTL_NonExistentTimer checks that cancelling an unknown VM ID does
// not panic and leaves the timer table empty.
func TestCancelTTL_NonExistentTimer(t *testing.T) {
	tm := NewTTLManager()

	// Must be a harmless no-op for an ID that was never registered.
	tm.CancelTTL("non-existent-vm")

	tm.mu.RLock()
	defer tm.mu.RUnlock()
	if n := len(tm.timers); n != 0 {
		t.Errorf("Expected empty timers map, got %d entries", n)
	}
}
// TestCancelAll_MultipleTimers registers several TTLs, cancels them all at
// once, and verifies that the table is emptied and no callback ever runs.
func TestCancelAll_MultipleTimers(t *testing.T) {
	const numVMs = 3

	var (
		mu           sync.Mutex
		expiredCount int32
	)
	bump := func() {
		mu.Lock()
		defer mu.Unlock()
		expiredCount++
	}

	// tracked returns the number of timers currently registered.
	tm := NewTTLManager()
	tracked := func() int {
		tm.mu.RLock()
		defer tm.mu.RUnlock()
		return len(tm.timers)
	}

	for i := 0; i < numVMs; i++ {
		tm.SetTTL(fmt.Sprintf("vm-%d", i), 200*time.Millisecond, bump)
	}

	if got := tracked(); got != numVMs {
		t.Errorf("Expected %d timers, got %d", numVMs, got)
	}

	tm.CancelAll()

	if got := tracked(); got != 0 {
		t.Errorf("Expected 0 timers after CancelAll, got %d", got)
	}

	// Wait past the original deadline to confirm nothing fires.
	time.Sleep(250 * time.Millisecond)

	mu.Lock()
	finalCount := expiredCount
	mu.Unlock()

	if finalCount != 0 {
		t.Errorf("Expected 0 timers to expire after CancelAll, got %d", finalCount)
	}
}
// TestCancelAll_EmptyManager checks that CancelAll on a fresh manager is a
// no-op that leaves the timer table empty.
func TestCancelAll_EmptyManager(t *testing.T) {
	tm := NewTTLManager()
	tm.CancelAll()

	tm.mu.RLock()
	defer tm.mu.RUnlock()
	if n := len(tm.timers); n != 0 {
		t.Errorf("Expected empty timers map, got %d entries", n)
	}
}
// TestConcurrentAccess hammers the manager from many goroutines that set and
// cancel TTLs for overlapping VM IDs. It makes no assertion on the final
// state, which depends on timing; it exists to surface races and deadlocks.
func TestConcurrentAccess(t *testing.T) {
	const numGoroutines = 10

	tm := NewTTLManager()
	var wg sync.WaitGroup

	// setter registers a TTL and, for even IDs, cancels it straight away.
	setter := func(id int) {
		defer wg.Done()
		cancel := tm.SetTTL(fmt.Sprintf("concurrent-vm-%d", id), 100*time.Millisecond, func() {})
		if id%2 == 0 {
			cancel()
		}
	}

	// canceller waits briefly and then cancels the same ID via CancelTTL.
	canceller := func(id int) {
		defer wg.Done()
		vmID := fmt.Sprintf("concurrent-vm-%d", id)
		time.Sleep(10 * time.Millisecond)
		tm.CancelTTL(vmID)
	}

	for i := 0; i < numGoroutines; i++ {
		wg.Add(1)
		go setter(i)
	}
	for i := 0; i < numGoroutines; i++ {
		wg.Add(1)
		go canceller(i)
	}

	wg.Wait()
	tm.CancelAll()
}
// TestSetTTL_ZeroDuration verifies that a TTL of zero fires its callback
// almost immediately.
func TestSetTTL_ZeroDuration(t *testing.T) {
	var (
		guard sync.Mutex
		fired bool
	)

	tm := NewTTLManager()
	tm.SetTTL("zero-duration-vm", 0, func() {
		guard.Lock()
		fired = true
		guard.Unlock()
	})

	// Give the timer goroutine a moment to run.
	time.Sleep(10 * time.Millisecond)

	guard.Lock()
	defer guard.Unlock()
	if !fired {
		t.Error("TTL with zero duration should expire immediately")
	}
}
// TestSetTTL_NegativeDuration verifies that a negative TTL is treated like an
// already-elapsed deadline and fires its callback almost immediately.
func TestSetTTL_NegativeDuration(t *testing.T) {
	var (
		guard sync.Mutex
		fired bool
	)

	tm := NewTTLManager()
	tm.SetTTL("negative-duration-vm", -100*time.Millisecond, func() {
		guard.Lock()
		fired = true
		guard.Unlock()
	})

	// Give the timer goroutine a moment to run.
	time.Sleep(10 * time.Millisecond)

	guard.Lock()
	didFire := fired
	guard.Unlock()

	if !didFire {
		t.Error("TTL with negative duration should expire immediately")
	}
}