mirror of
https://github.com/ultravioletrs/cocos.git
synced 2026-06-23 04:10:25 +00:00
COCOS-454 - Implement graceful shutdown for services and add TTL management for VMs (#473)
* Implement graceful shutdown for services and add TTL management for VMs Signed-off-by: Sammy Oina <sammyoina@gmail.com> * Remove unnecessary comment from go-tdx-guest dependency in go.mod Signed-off-by: Sammy Oina <sammyoina@gmail.com> * Update manager/api/logging.go Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Add TTL manager initialization in TestStop Signed-off-by: Sammy Oina <sammyoina@gmail.com> * Fix logging format in Shutdown method for consistency Signed-off-by: Sammy Oina <sammyoina@gmail.com> * Add unit tests for TTL manager functionality Signed-off-by: Sammy Oina <sammyoina@gmail.com> * Enhance TTL tests with mutex for thread safety in expiration checks Signed-off-by: Sammy Oina <sammyoina@gmail.com> * Add TTL parameter to CreateVM in TestRun for improved testing scenarios Signed-off-by: Sammy Oina <sammyoina@gmail.com> * Add Shutdown test to verify VM cleanup and TTL manager integration Signed-off-by: Sammy Oina <sammyoina@gmail.com> --------- Signed-off-by: Sammy Oina <sammyoina@gmail.com> Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
parent
f543cb4363
commit
45187d7f41
+11
-3
@@ -114,13 +114,21 @@ func main() {
|
||||
return
|
||||
}
|
||||
|
||||
svc, err := newService(ctx, logger, tracer, *qemuCfg, cfg.AttestationPolicyBinary, cfg.IgvmMeasureBinary, cfg.PcrValues, cfg.EosVersion)
|
||||
svc, err := newService(logger, tracer, *qemuCfg, cfg.AttestationPolicyBinary, cfg.IgvmMeasureBinary, cfg.PcrValues, cfg.EosVersion)
|
||||
if err != nil {
|
||||
logger.Error(err.Error())
|
||||
exitCode = 1
|
||||
return
|
||||
}
|
||||
|
||||
defer func() {
|
||||
if err := svc.Shutdown(); err != nil {
|
||||
logger.Error(err.Error())
|
||||
exitCode = 1
|
||||
return
|
||||
}
|
||||
}()
|
||||
|
||||
registerManagerServiceServer := func(srv *grpc.Server) {
|
||||
reflection.Register(srv)
|
||||
manager.RegisterManagerServiceServer(srv, managergrpc.NewServer(svc))
|
||||
@@ -141,8 +149,8 @@ func main() {
|
||||
}
|
||||
}
|
||||
|
||||
func newService(ctx context.Context, logger *slog.Logger, tracer trace.Tracer, qemuCfg qemu.Config, attestationPolicyPath string, igvmMeasurementBinaryPath string, pcrValuesFilePath string, eosVersion string) (manager.Service, error) {
|
||||
svc, err := manager.New(ctx, qemuCfg, attestationPolicyPath, igvmMeasurementBinaryPath, pcrValuesFilePath, logger, qemu.NewVM, eosVersion)
|
||||
func newService(logger *slog.Logger, tracer trace.Tracer, qemuCfg qemu.Config, attestationPolicyPath string, igvmMeasurementBinaryPath string, pcrValuesFilePath string, eosVersion string) (manager.Service, error) {
|
||||
svc, err := manager.New(qemuCfg, attestationPolicyPath, igvmMeasurementBinaryPath, pcrValuesFilePath, logger, qemu.NewVM, eosVersion)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -75,3 +75,20 @@ func (lm *loggingMiddleware) ReturnCVMInfo(ctx context.Context) (string, int, st
|
||||
|
||||
return lm.svc.ReturnCVMInfo(ctx)
|
||||
}
|
||||
|
||||
func (lm *loggingMiddleware) Shutdown() (err error) {
|
||||
defer func(begin time.Time) {
|
||||
if err != nil {
|
||||
lm.logger.Warn("Method Shutdown completed with error",
|
||||
"time_taken", time.Since(begin),
|
||||
"error", err,
|
||||
)
|
||||
return
|
||||
}
|
||||
lm.logger.Info("Method Shutdown completed successfully",
|
||||
"time_taken", time.Since(begin),
|
||||
)
|
||||
}(time.Now())
|
||||
|
||||
return lm.svc.Shutdown()
|
||||
}
|
||||
|
||||
@@ -67,3 +67,12 @@ func (ms *metricsMiddleware) ReturnCVMInfo(ctx context.Context) (string, int, st
|
||||
|
||||
return ms.svc.ReturnCVMInfo(ctx)
|
||||
}
|
||||
|
||||
func (ms *metricsMiddleware) Shutdown() error {
|
||||
defer func(begin time.Time) {
|
||||
ms.counter.With("method", "Shutdown").Add(1)
|
||||
ms.latency.With("method", "Shutdown").Observe(time.Since(begin).Seconds())
|
||||
}(time.Now())
|
||||
|
||||
return ms.svc.Shutdown()
|
||||
}
|
||||
|
||||
+28
-12
@@ -81,6 +81,8 @@ type Service interface {
|
||||
FetchAttestationPolicy(ctx context.Context, computationID string) ([]byte, error)
|
||||
// ReturnCVMInfo returns CVM information needed for attestation verification and validation.
|
||||
ReturnCVMInfo(ctx context.Context) (string, int, string, string)
|
||||
// Shutdown gracefully shuts down the service
|
||||
Shutdown() error
|
||||
}
|
||||
|
||||
type managerService struct {
|
||||
@@ -97,13 +99,13 @@ type managerService struct {
|
||||
portRangeMax int
|
||||
persistence qemu.Persistence
|
||||
eosVersion string
|
||||
ctx context.Context
|
||||
ttlManager *TTLManager
|
||||
}
|
||||
|
||||
var _ Service = (*managerService)(nil)
|
||||
|
||||
// New instantiates the manager service implementation.
|
||||
func New(ctx context.Context, cfg qemu.Config, attestationPolicyBinPath string, igvmMeasurementBinaryPath string, pcrValuesFilePath string, logger *slog.Logger, vmFactory vm.Provider, eosVersion string) (Service, error) {
|
||||
func New(cfg qemu.Config, attestationPolicyBinPath string, igvmMeasurementBinaryPath string, pcrValuesFilePath string, logger *slog.Logger, vmFactory vm.Provider, eosVersion string) (Service, error) {
|
||||
start, end, err := decodeRange(cfg.HostFwdRange)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -126,7 +128,7 @@ func New(ctx context.Context, cfg qemu.Config, attestationPolicyBinPath string,
|
||||
portRangeMax: end,
|
||||
persistence: persistence,
|
||||
eosVersion: eosVersion,
|
||||
ctx: ctx,
|
||||
ttlManager: NewTTLManager(),
|
||||
}
|
||||
|
||||
if err := ms.restoreVMs(); err != nil {
|
||||
@@ -224,16 +226,13 @@ func (ms *managerService) CreateVM(ctx context.Context, req *CreateReq) (string,
|
||||
return "", id, err
|
||||
}
|
||||
|
||||
go func() {
|
||||
select {
|
||||
case <-time.After(ttl):
|
||||
if err := ms.RemoveVM(ctx, id); err != nil {
|
||||
ms.logger.Error("Failed to remove VM after TTL", "error", err)
|
||||
}
|
||||
case <-ms.ctx.Done():
|
||||
return
|
||||
ms.ttlManager.SetTTL(id, ttl, func() { //nolint:contextcheck
|
||||
if err := ms.RemoveVM(context.Background(), id); err != nil {
|
||||
ms.logger.Error("Failed to remove VM after TTL expiry", "vmID", id, "error", err)
|
||||
} else {
|
||||
ms.logger.Info("Successfully removed VM after TTL expiry", "vmID", id)
|
||||
}
|
||||
}()
|
||||
})
|
||||
}
|
||||
|
||||
pid := cvm.GetProcess()
|
||||
@@ -259,6 +258,9 @@ func (ms *managerService) CreateVM(ctx context.Context, req *CreateReq) (string,
|
||||
func (ms *managerService) RemoveVM(ctx context.Context, computationID string) error {
|
||||
ms.mu.Lock()
|
||||
defer ms.mu.Unlock()
|
||||
|
||||
ms.ttlManager.CancelTTL(computationID)
|
||||
|
||||
cvm, ok := ms.vms[computationID]
|
||||
if !ok {
|
||||
return ErrNotFound
|
||||
@@ -279,6 +281,20 @@ func (ms *managerService) ReturnCVMInfo(ctx context.Context) (string, int, strin
|
||||
return ms.qemuCfg.OVMFCodeConfig.Version, ms.qemuCfg.SMPCount, ms.qemuCfg.CPU, ms.eosVersion
|
||||
}
|
||||
|
||||
// Shutdown gracefully shuts down the service.
|
||||
func (ms *managerService) Shutdown() error {
|
||||
ms.logger.Info("Shutting down manager service")
|
||||
|
||||
ms.ttlManager.CancelAll()
|
||||
|
||||
ms.mu.Lock()
|
||||
defer ms.mu.Unlock()
|
||||
|
||||
ms.vms = make(map[string]vm.VM)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func getFreePort(minPort, maxPort int) (int, error) {
|
||||
if checkPortisFree(minPort) {
|
||||
return minPort, nil
|
||||
|
||||
+32
-2
@@ -30,7 +30,7 @@ func TestNew(t *testing.T) {
|
||||
logger := slog.Default()
|
||||
vmf := new(mocks.Provider)
|
||||
|
||||
service, err := New(context.Background(), cfg, "", "", "", logger, vmf.Execute, "")
|
||||
service, err := New(cfg, "", "", "", logger, vmf.Execute, "")
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.NotNil(t, service)
|
||||
@@ -47,24 +47,35 @@ func TestRun(t *testing.T) {
|
||||
binaryBehavior string
|
||||
vmStartError error
|
||||
expectedError error
|
||||
ttl string
|
||||
}{
|
||||
{
|
||||
name: "Successful run",
|
||||
binaryBehavior: "success",
|
||||
vmStartError: nil,
|
||||
expectedError: nil,
|
||||
ttl: "",
|
||||
},
|
||||
{
|
||||
name: "VM start failure",
|
||||
binaryBehavior: "success",
|
||||
vmStartError: assert.AnError,
|
||||
expectedError: assert.AnError,
|
||||
ttl: "",
|
||||
},
|
||||
{
|
||||
name: "Invalid attestation policy",
|
||||
binaryBehavior: "fail",
|
||||
vmStartError: nil,
|
||||
expectedError: ErrFailedToCreateAttestationPolicy,
|
||||
ttl: "",
|
||||
},
|
||||
{
|
||||
name: "With TTL",
|
||||
binaryBehavior: "success",
|
||||
vmStartError: nil,
|
||||
expectedError: nil,
|
||||
ttl: "10s",
|
||||
},
|
||||
}
|
||||
|
||||
@@ -101,11 +112,12 @@ func TestRun(t *testing.T) {
|
||||
vms: make(map[string]vm.VM),
|
||||
vmFactory: vmf.Execute,
|
||||
persistence: persistence,
|
||||
ttlManager: NewTTLManager(),
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
port, _, err := ms.CreateVM(ctx, &CreateReq{})
|
||||
port, _, err := ms.CreateVM(ctx, &CreateReq{Ttl: tt.ttl})
|
||||
|
||||
if tt.expectedError != nil {
|
||||
assert.Error(t, err)
|
||||
@@ -165,6 +177,7 @@ func TestStop(t *testing.T) {
|
||||
logger: logger,
|
||||
vms: make(map[string]vm.VM),
|
||||
persistence: persistence,
|
||||
ttlManager: NewTTLManager(),
|
||||
}
|
||||
vmMock := new(mocks.VM)
|
||||
|
||||
@@ -287,3 +300,20 @@ func TestProcessExists(t *testing.T) {
|
||||
assert.False(t, ms.processExists(1)) // PID 1 is usually the init process.
|
||||
}
|
||||
}
|
||||
|
||||
func TestShutdown(t *testing.T) {
|
||||
ms := &managerService{
|
||||
vms: make(map[string]vm.VM),
|
||||
ttlManager: NewTTLManager(),
|
||||
logger: mglog.NewMock(),
|
||||
}
|
||||
|
||||
vmMock := new(mocks.VM)
|
||||
vmMock.On("Stop").Return(nil).Once()
|
||||
ms.vms["test-vm"] = vmMock
|
||||
|
||||
err := ms.Shutdown()
|
||||
assert.NoError(t, err)
|
||||
|
||||
assert.Len(t, ms.vms, 0)
|
||||
}
|
||||
|
||||
@@ -48,3 +48,10 @@ func (tm *tracingMiddleware) ReturnCVMInfo(ctx context.Context) (string, int, st
|
||||
|
||||
return tm.svc.ReturnCVMInfo(ctx)
|
||||
}
|
||||
|
||||
func (tm *tracingMiddleware) Shutdown() error {
|
||||
_, span := tm.tracer.Start(context.Background(), "shutdown")
|
||||
defer span.End()
|
||||
|
||||
return tm.svc.Shutdown()
|
||||
}
|
||||
|
||||
@@ -0,0 +1,66 @@
|
||||
// Copyright (c) Ultraviolet
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
package manager
|
||||
|
||||
import (
|
||||
"context"
|
||||
sync "sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// TTLManager handles TTL functionality for VMs.
|
||||
type TTLManager struct {
|
||||
timers map[string]*time.Timer
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
// NewTTLManager creates a new TTL manager.
|
||||
func NewTTLManager() *TTLManager {
|
||||
return &TTLManager{
|
||||
timers: make(map[string]*time.Timer),
|
||||
}
|
||||
}
|
||||
|
||||
// SetTTL sets a TTL for a VM and returns a function to cancel it.
|
||||
func (tm *TTLManager) SetTTL(vmID string, ttl time.Duration, onExpiry func()) context.CancelFunc {
|
||||
tm.mu.Lock()
|
||||
defer tm.mu.Unlock()
|
||||
|
||||
if timer, exists := tm.timers[vmID]; exists {
|
||||
timer.Stop()
|
||||
}
|
||||
|
||||
timer := time.AfterFunc(ttl, onExpiry)
|
||||
tm.timers[vmID] = timer
|
||||
|
||||
return func() {
|
||||
tm.mu.Lock()
|
||||
defer tm.mu.Unlock()
|
||||
if t, exists := tm.timers[vmID]; exists {
|
||||
t.Stop()
|
||||
delete(tm.timers, vmID)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// CancelTTL cancels the TTL for a specific VM.
|
||||
func (tm *TTLManager) CancelTTL(vmID string) {
|
||||
tm.mu.Lock()
|
||||
defer tm.mu.Unlock()
|
||||
|
||||
if timer, exists := tm.timers[vmID]; exists {
|
||||
timer.Stop()
|
||||
delete(tm.timers, vmID)
|
||||
}
|
||||
}
|
||||
|
||||
// CancelAll cancels all active TTLs.
|
||||
func (tm *TTLManager) CancelAll() {
|
||||
tm.mu.Lock()
|
||||
defer tm.mu.Unlock()
|
||||
|
||||
for vmID, timer := range tm.timers {
|
||||
timer.Stop()
|
||||
delete(tm.timers, vmID)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,342 @@
|
||||
// Copyright (c) Ultraviolet
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
package manager
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestNewTTLManager(t *testing.T) {
|
||||
tm := NewTTLManager()
|
||||
|
||||
if tm == nil {
|
||||
t.Fatal("NewTTLManager() returned nil")
|
||||
}
|
||||
|
||||
if tm.timers == nil {
|
||||
t.Fatal("NewTTLManager() did not initialize timers map")
|
||||
}
|
||||
|
||||
if len(tm.timers) != 0 {
|
||||
t.Errorf("NewTTLManager() timers map should be empty, got %d entries", len(tm.timers))
|
||||
}
|
||||
}
|
||||
|
||||
func TestSetTTL_Basic(t *testing.T) {
|
||||
tm := NewTTLManager()
|
||||
|
||||
mu := sync.Mutex{}
|
||||
expired := false
|
||||
vmID := "test-vm-1"
|
||||
ttl := 50 * time.Millisecond
|
||||
|
||||
cancelFunc := tm.SetTTL(vmID, ttl, func() {
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
expired = true
|
||||
})
|
||||
|
||||
tm.mu.RLock()
|
||||
if _, exists := tm.timers[vmID]; !exists {
|
||||
t.Error("Timer was not created for VM")
|
||||
}
|
||||
tm.mu.RUnlock()
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
mu.Lock()
|
||||
if !expired {
|
||||
t.Error("TTL did not expire as expected")
|
||||
}
|
||||
mu.Unlock()
|
||||
|
||||
cancelFunc()
|
||||
|
||||
tm.mu.RLock()
|
||||
if _, exists := tm.timers[vmID]; exists {
|
||||
t.Error("Timer should be cleaned up after expiry")
|
||||
}
|
||||
tm.mu.RUnlock()
|
||||
}
|
||||
|
||||
func TestSetTTL_CancelBeforeExpiry(t *testing.T) {
|
||||
tm := NewTTLManager()
|
||||
|
||||
expired := false
|
||||
vmID := "test-vm-2"
|
||||
ttl := 100 * time.Millisecond
|
||||
|
||||
cancelFunc := tm.SetTTL(vmID, ttl, func() {
|
||||
expired = true
|
||||
})
|
||||
|
||||
time.Sleep(20 * time.Millisecond)
|
||||
cancelFunc()
|
||||
|
||||
time.Sleep(150 * time.Millisecond)
|
||||
|
||||
if expired {
|
||||
t.Error("TTL should not have expired after being cancelled")
|
||||
}
|
||||
|
||||
tm.mu.RLock()
|
||||
if _, exists := tm.timers[vmID]; exists {
|
||||
t.Error("Timer should be cleaned up after cancellation")
|
||||
}
|
||||
tm.mu.RUnlock()
|
||||
}
|
||||
|
||||
func TestSetTTL_OverwriteExistingTimer(t *testing.T) {
|
||||
tm := NewTTLManager()
|
||||
|
||||
mu := sync.Mutex{}
|
||||
firstExpired := false
|
||||
secondExpired := false
|
||||
vmID := "test-vm-3"
|
||||
|
||||
// Set first TTL
|
||||
tm.SetTTL(vmID, 200*time.Millisecond, func() {
|
||||
firstExpired = true
|
||||
})
|
||||
|
||||
// Immediately overwrite with second TTL
|
||||
tm.SetTTL(vmID, 50*time.Millisecond, func() {
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
secondExpired = true
|
||||
})
|
||||
|
||||
// Wait for second TTL to expire
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
if firstExpired {
|
||||
t.Error("First TTL should not have expired (it was overwritten)")
|
||||
}
|
||||
|
||||
mu.Lock()
|
||||
if !secondExpired {
|
||||
t.Error("Second TTL should have expired")
|
||||
}
|
||||
mu.Unlock()
|
||||
|
||||
// Verify only one timer entry exists (or none after cleanup)
|
||||
tm.mu.RLock()
|
||||
count := len(tm.timers)
|
||||
tm.mu.RUnlock()
|
||||
|
||||
if count > 1 {
|
||||
t.Errorf("Expected at most 1 timer entry, got %d", count)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSetTTL_MultipleConcurrentTimers(t *testing.T) {
|
||||
tm := NewTTLManager()
|
||||
|
||||
numVMs := 5
|
||||
expiredCount := int32(0)
|
||||
var mu sync.Mutex
|
||||
|
||||
for i := 0; i < numVMs; i++ {
|
||||
vmID := fmt.Sprintf("vm-%d", i)
|
||||
tm.SetTTL(vmID, 50*time.Millisecond, func() {
|
||||
mu.Lock()
|
||||
expiredCount++
|
||||
mu.Unlock()
|
||||
})
|
||||
}
|
||||
|
||||
tm.mu.RLock()
|
||||
if len(tm.timers) != numVMs {
|
||||
t.Errorf("Expected %d timers, got %d", numVMs, len(tm.timers))
|
||||
}
|
||||
tm.mu.RUnlock()
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
mu.Lock()
|
||||
finalCount := expiredCount
|
||||
mu.Unlock()
|
||||
|
||||
if int(finalCount) != numVMs {
|
||||
t.Errorf("Expected %d timers to expire, got %d", numVMs, finalCount)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCancelTTL_ExistingTimer(t *testing.T) {
|
||||
tm := NewTTLManager()
|
||||
|
||||
expired := false
|
||||
vmID := "test-vm-4"
|
||||
|
||||
tm.SetTTL(vmID, 100*time.Millisecond, func() {
|
||||
expired = true
|
||||
})
|
||||
|
||||
// Cancel the timer
|
||||
tm.CancelTTL(vmID)
|
||||
|
||||
time.Sleep(150 * time.Millisecond)
|
||||
|
||||
if expired {
|
||||
t.Error("TTL should not have expired after being cancelled")
|
||||
}
|
||||
|
||||
// Verify timer was removed
|
||||
tm.mu.RLock()
|
||||
if _, exists := tm.timers[vmID]; exists {
|
||||
t.Error("Timer should be removed after cancellation")
|
||||
}
|
||||
tm.mu.RUnlock()
|
||||
}
|
||||
|
||||
func TestCancelTTL_NonExistentTimer(t *testing.T) {
|
||||
tm := NewTTLManager()
|
||||
|
||||
// Should not panic when cancelling non-existent timer
|
||||
tm.CancelTTL("non-existent-vm")
|
||||
|
||||
// Verify timers map is still empty
|
||||
tm.mu.RLock()
|
||||
if len(tm.timers) != 0 {
|
||||
t.Errorf("Expected empty timers map, got %d entries", len(tm.timers))
|
||||
}
|
||||
tm.mu.RUnlock()
|
||||
}
|
||||
|
||||
func TestCancelAll_MultipleTimers(t *testing.T) {
|
||||
tm := NewTTLManager()
|
||||
|
||||
numVMs := 3
|
||||
expiredCount := int32(0)
|
||||
var mu sync.Mutex
|
||||
|
||||
for i := 0; i < numVMs; i++ {
|
||||
vmID := fmt.Sprintf("vm-%d", i)
|
||||
tm.SetTTL(vmID, 200*time.Millisecond, func() {
|
||||
mu.Lock()
|
||||
expiredCount++
|
||||
mu.Unlock()
|
||||
})
|
||||
}
|
||||
|
||||
tm.mu.RLock()
|
||||
if len(tm.timers) != numVMs {
|
||||
t.Errorf("Expected %d timers, got %d", numVMs, len(tm.timers))
|
||||
}
|
||||
tm.mu.RUnlock()
|
||||
|
||||
tm.CancelAll()
|
||||
|
||||
tm.mu.RLock()
|
||||
if len(tm.timers) != 0 {
|
||||
t.Errorf("Expected 0 timers after CancelAll, got %d", len(tm.timers))
|
||||
}
|
||||
tm.mu.RUnlock()
|
||||
|
||||
time.Sleep(250 * time.Millisecond)
|
||||
|
||||
mu.Lock()
|
||||
finalCount := expiredCount
|
||||
mu.Unlock()
|
||||
|
||||
if finalCount != 0 {
|
||||
t.Errorf("Expected 0 timers to expire after CancelAll, got %d", finalCount)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCancelAll_EmptyManager(t *testing.T) {
|
||||
tm := NewTTLManager()
|
||||
|
||||
tm.CancelAll()
|
||||
|
||||
tm.mu.RLock()
|
||||
if len(tm.timers) != 0 {
|
||||
t.Errorf("Expected empty timers map, got %d entries", len(tm.timers))
|
||||
}
|
||||
tm.mu.RUnlock()
|
||||
}
|
||||
|
||||
func TestConcurrentAccess(t *testing.T) {
|
||||
tm := NewTTLManager()
|
||||
|
||||
var wg sync.WaitGroup
|
||||
numGoroutines := 10
|
||||
|
||||
for i := 0; i < numGoroutines; i++ {
|
||||
wg.Add(1)
|
||||
go func(id int) {
|
||||
defer wg.Done()
|
||||
vmID := fmt.Sprintf("concurrent-vm-%d", id)
|
||||
cancelFunc := tm.SetTTL(vmID, 100*time.Millisecond, func() {})
|
||||
|
||||
// Sometimes cancel immediately
|
||||
if id%2 == 0 {
|
||||
cancelFunc()
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
|
||||
for i := 0; i < numGoroutines; i++ {
|
||||
wg.Add(1)
|
||||
go func(id int) {
|
||||
defer wg.Done()
|
||||
vmID := fmt.Sprintf("concurrent-vm-%d", id)
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
tm.CancelTTL(vmID)
|
||||
}(i)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
tm.CancelAll()
|
||||
|
||||
// This test primarily checks that no race conditions occur
|
||||
// The actual state at the end is unpredictable due to timing
|
||||
}
|
||||
|
||||
func TestSetTTL_ZeroDuration(t *testing.T) {
|
||||
tm := NewTTLManager()
|
||||
|
||||
mu := sync.Mutex{}
|
||||
expired := false
|
||||
vmID := "zero-duration-vm"
|
||||
|
||||
tm.SetTTL(vmID, 0, func() {
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
expired = true
|
||||
})
|
||||
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
|
||||
mu.Lock()
|
||||
if !expired {
|
||||
t.Error("TTL with zero duration should expire immediately")
|
||||
}
|
||||
mu.Unlock()
|
||||
}
|
||||
|
||||
func TestSetTTL_NegativeDuration(t *testing.T) {
|
||||
tm := NewTTLManager()
|
||||
|
||||
mu := sync.Mutex{}
|
||||
expired := false
|
||||
vmID := "negative-duration-vm"
|
||||
|
||||
tm.SetTTL(vmID, -100*time.Millisecond, func() {
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
expired = true
|
||||
})
|
||||
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
if !expired {
|
||||
t.Error("TTL with negative duration should expire immediately")
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user