COCOS-278 - Abstract state machine (#280)

* abstract state machine

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* perpetual results consumption

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* async action

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* fix failing tests

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* fix failing test

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

---------

Signed-off-by: Sammy Oina <sammyoina@gmail.com>
This commit is contained in:
Sammy Kerata Oina
2024-10-09 14:19:12 +03:00
committed by GitHub
parent fb0fbaeb9a
commit db7f3c7a4b
11 changed files with 661 additions and 311 deletions
+29
View File
@@ -0,0 +1,29 @@
// Code generated by "stringer -type=AgentEvent"; DO NOT EDIT.
package agent
import "strconv"
func _() {
// An "invalid array index" compiler error signifies that the constant values have changed.
// Re-run the stringer command to generate them again.
var x [1]struct{}
_ = x[Start-0]
_ = x[ManifestReceived-1]
_ = x[AlgorithmReceived-2]
_ = x[DataReceived-3]
_ = x[RunComplete-4]
_ = x[ResultsConsumed-5]
_ = x[RunFailed-6]
}
const _AgentEvent_name = "StartManifestReceivedAlgorithmReceivedDataReceivedRunCompleteResultsConsumedRunFailed"
var _AgentEvent_index = [...]uint8{0, 5, 21, 38, 50, 61, 76, 85}
func (i AgentEvent) String() string {
if i < 0 || i >= AgentEvent(len(_AgentEvent_index)-1) {
return "AgentEvent(" + strconv.FormatInt(int64(i), 10) + ")"
}
return _AgentEvent_name[_AgentEvent_index[i]:_AgentEvent_index[i+1]]
}
+30
View File
@@ -0,0 +1,30 @@
// Code generated by "stringer -type=AgentState"; DO NOT EDIT.
package agent
import "strconv"
func _() {
// An "invalid array index" compiler error signifies that the constant values have changed.
// Re-run the stringer command to generate them again.
var x [1]struct{}
_ = x[Idle-0]
_ = x[ReceivingManifest-1]
_ = x[ReceivingAlgorithm-2]
_ = x[ReceivingData-3]
_ = x[Running-4]
_ = x[ConsumingResults-5]
_ = x[Complete-6]
_ = x[Failed-7]
}
const _AgentState_name = "IdleReceivingManifestReceivingAlgorithmReceivingDataRunningConsumingResultsCompleteFailed"
var _AgentState_index = [...]uint8{0, 4, 21, 39, 52, 59, 75, 83, 89}
func (i AgentState) String() string {
if i < 0 || i >= AgentState(len(_AgentState_index)-1) {
return "AgentState(" + strconv.FormatInt(int64(i), 10) + ")"
}
return _AgentState_name[_AgentState_index[i]:_AgentState_index[i+1]]
}
+124 -53
View File
@@ -20,12 +20,52 @@ import (
"github.com/ultravioletrs/cocos/agent/algorithm/python"
"github.com/ultravioletrs/cocos/agent/algorithm/wasm"
"github.com/ultravioletrs/cocos/agent/events"
"github.com/ultravioletrs/cocos/agent/statemachine"
"github.com/ultravioletrs/cocos/internal"
"golang.org/x/crypto/sha3"
)
var _ Service = (*agentService)(nil)
//go:generate stringer -type=AgentState
type AgentState int
const (
Idle AgentState = iota
ReceivingManifest
ReceivingAlgorithm
ReceivingData
Running
ConsumingResults
Complete
Failed
)
//go:generate stringer -type=AgentEvent
type AgentEvent int
const (
Start AgentEvent = iota
ManifestReceived
AlgorithmReceived
DataReceived
RunComplete
ResultsConsumed
RunFailed
)
//go:generate stringer -type=Status
type Status uint8
const (
IdleState Status = iota
InProgress
Ready
Completed
Terminated
Warning
)
const (
// ReportDataSize is the size of the report data expected by the attestation service.
ReportDataSize = 64
@@ -71,40 +111,69 @@ type Service interface {
}
type agentService struct {
computation Computation // Holds the current computation request details.
algorithm algorithm.Algorithm // Filepath to the algorithm received for the computation.
result []byte // Stores the result of the computation.
sm *StateMachine // Manages the state transitions of the agent service.
runError error // Stores any error encountered during the computation run.
eventSvc events.Service // Service for publishing events related to computation.
quoteProvider client.QuoteProvider // Provider for generating attestation quotes.
computation Computation // Holds the current computation request details.
algorithm algorithm.Algorithm // Filepath to the algorithm received for the computation.
result []byte // Stores the result of the computation.
sm statemachine.StateMachine // Manages the state transitions of the agent service.
runError error // Stores any error encountered during the computation run.
eventSvc events.Service // Service for publishing events related to computation.
quoteProvider client.QuoteProvider // Provider for generating attestation quotes.
logger *slog.Logger // Logger for the agent service.
resultsConsumed bool // Indicates if the results have been consumed.
}
var _ Service = (*agentService)(nil)
// New instantiates the agent service implementation.
func New(ctx context.Context, logger *slog.Logger, eventSvc events.Service, cmp Computation, quoteProvider client.QuoteProvider) Service {
sm := statemachine.NewStateMachine(Idle)
svc := &agentService{
sm: NewStateMachine(logger, cmp),
sm: sm,
eventSvc: eventSvc,
quoteProvider: quoteProvider,
logger: logger,
computation: cmp,
}
svc.sm.StateFunctions[Idle] = svc.publishEvent(IdleState.String(), json.RawMessage{})
svc.sm.StateFunctions[ReceivingManifest] = svc.publishEvent(InProgress.String(), json.RawMessage{})
svc.sm.StateFunctions[ReceivingAlgorithm] = svc.publishEvent(InProgress.String(), json.RawMessage{})
svc.sm.StateFunctions[ReceivingData] = svc.publishEvent(InProgress.String(), json.RawMessage{})
svc.sm.StateFunctions[ConsumingResults] = svc.publishEvent(Ready.String(), json.RawMessage{})
svc.sm.StateFunctions[Complete] = svc.publishEvent(Completed.String(), json.RawMessage{})
svc.sm.StateFunctions[Running] = svc.runComputation
svc.sm.StateFunctions[Failed] = svc.publishEvent(Failed.String(), json.RawMessage{})
transitions := []statemachine.Transition{
{From: Idle, Event: Start, To: ReceivingManifest},
{From: ReceivingManifest, Event: ManifestReceived, To: ReceivingAlgorithm},
}
go svc.sm.Start(ctx)
svc.sm.SendEvent(start)
if len(cmp.Datasets) == 0 {
transitions = append(transitions, statemachine.Transition{From: ReceivingAlgorithm, Event: AlgorithmReceived, To: Running})
} else {
transitions = append(transitions, statemachine.Transition{From: ReceivingAlgorithm, Event: AlgorithmReceived, To: ReceivingData})
transitions = append(transitions, statemachine.Transition{From: ReceivingData, Event: DataReceived, To: Running})
}
svc.computation = cmp
transitions = append(transitions, []statemachine.Transition{
{From: Running, Event: RunComplete, To: ConsumingResults},
{From: Running, Event: RunFailed, To: Failed},
{From: ConsumingResults, Event: ResultsConsumed, To: Complete},
}...)
for _, t := range transitions {
sm.AddTransition(t)
}
sm.SetAction(Idle, svc.publishEvent(IdleState.String()))
sm.SetAction(ReceivingManifest, svc.publishEvent(InProgress.String()))
sm.SetAction(ReceivingAlgorithm, svc.publishEvent(InProgress.String()))
sm.SetAction(ReceivingData, svc.publishEvent(InProgress.String()))
sm.SetAction(Running, svc.runComputation)
sm.SetAction(ConsumingResults, svc.publishEvent(Ready.String()))
sm.SetAction(Complete, svc.publishEvent(Completed.String()))
sm.SetAction(Failed, svc.publishEvent(Failed.String()))
go func() {
if err := sm.Start(ctx); err != nil {
logger.Error(err.Error())
}
}()
sm.SendEvent(Start)
defer sm.SendEvent(ManifestReceived)
svc.sm.SendEvent(manifestReceived)
return svc
}
@@ -153,7 +222,7 @@ func (as *agentService) Algo(ctx context.Context, algo Algorithm) error {
switch algoType {
case string(algorithm.AlgoTypeBin):
as.algorithm = binary.NewAlgorithm(as.sm.logger, as.eventSvc, f.Name(), args)
as.algorithm = binary.NewAlgorithm(as.logger, as.eventSvc, f.Name(), args)
case string(algorithm.AlgoTypePython):
var requirementsFile string
if len(algo.Requirements) > 0 {
@@ -171,11 +240,11 @@ func (as *agentService) Algo(ctx context.Context, algo Algorithm) error {
requirementsFile = fr.Name()
}
runtime := python.PythonRunTimeFromContext(ctx)
as.algorithm = python.NewAlgorithm(as.sm.logger, as.eventSvc, runtime, requirementsFile, f.Name(), args)
as.algorithm = python.NewAlgorithm(as.logger, as.eventSvc, runtime, requirementsFile, f.Name(), args)
case string(algorithm.AlgoTypeWasm):
as.algorithm = wasm.NewAlgorithm(as.sm.logger, as.eventSvc, f.Name(), args)
as.algorithm = wasm.NewAlgorithm(as.logger, as.eventSvc, f.Name(), args)
case string(algorithm.AlgoTypeDocker):
as.algorithm = docker.NewAlgorithm(as.sm.logger, as.eventSvc, f.Name())
as.algorithm = docker.NewAlgorithm(as.logger, as.eventSvc, f.Name())
}
if err := os.Mkdir(algorithm.DatasetsDir, 0o755); err != nil {
@@ -183,7 +252,7 @@ func (as *agentService) Algo(ctx context.Context, algo Algorithm) error {
}
if as.algorithm != nil {
as.sm.SendEvent(algorithmReceived)
as.sm.SendEvent(AlgorithmReceived)
}
return nil
@@ -236,27 +305,30 @@ func (as *agentService) Data(ctx context.Context, dataset Dataset) error {
}
if len(as.computation.Datasets) == 0 {
defer as.sm.SendEvent(dataReceived)
defer as.sm.SendEvent(DataReceived)
}
return nil
}
func (as *agentService) Result(ctx context.Context) ([]byte, error) {
if as.sm.GetState() != ConsumingResults && as.sm.GetState() != Failed {
currentState := as.sm.GetState()
if currentState != ConsumingResults && currentState != Complete && currentState != Failed {
return []byte{}, ErrResultsNotReady
}
if len(as.computation.ResultConsumers) == 0 {
return []byte{}, ErrAllResultsConsumed
}
index, ok := IndexFromContext(ctx)
if !ok {
return []byte{}, ErrUndeclaredConsumer
}
as.computation.ResultConsumers = slices.Delete(as.computation.ResultConsumers, index, index+1)
if len(as.computation.ResultConsumers) == 0 && as.sm.GetState() == ConsumingResults {
defer as.sm.SendEvent(resultsConsumed)
if index < 0 || index >= len(as.computation.ResultConsumers) {
return []byte{}, ErrUndeclaredConsumer
}
if !as.resultsConsumed && currentState == ConsumingResults {
as.resultsConsumed = true
defer as.sm.SendEvent(ResultsConsumed)
}
return as.result, as.runError
@@ -271,59 +343,58 @@ func (as *agentService) Attestation(ctx context.Context, reportData [ReportDataS
return rawQuote, nil
}
func (as *agentService) runComputation() {
as.publishEvent(InProgress.String(), json.RawMessage{})()
as.sm.logger.Debug("computation run started")
func (as *agentService) runComputation(state statemachine.State) {
as.publishEvent(InProgress.String())(state)
as.logger.Debug("computation run started")
defer func() {
if as.runError != nil {
as.sm.SendEvent(runFailed)
as.sm.SendEvent(RunFailed)
} else {
as.sm.SendEvent(runComplete)
as.sm.SendEvent(RunComplete)
}
}()
if err := os.Mkdir(algorithm.ResultsDir, 0o755); err != nil {
as.runError = fmt.Errorf("error creating results directory: %s", err.Error())
as.sm.logger.Warn(as.runError.Error())
as.publishEvent(Failed.String(), json.RawMessage{})()
as.logger.Warn(as.runError.Error())
as.publishEvent(Failed.String())(state)
return
}
defer func() {
if err := os.RemoveAll(algorithm.ResultsDir); err != nil {
as.sm.logger.Warn(fmt.Sprintf("error removing results directory and its contents: %s", err.Error()))
as.logger.Warn(fmt.Sprintf("error removing results directory and its contents: %s", err.Error()))
}
if err := os.RemoveAll(algorithm.DatasetsDir); err != nil {
as.sm.logger.Warn(fmt.Sprintf("error removing datasets directory and its contents: %s", err.Error()))
as.logger.Warn(fmt.Sprintf("error removing datasets directory and its contents: %s", err.Error()))
}
}()
as.publishEvent(InProgress.String(), json.RawMessage{})()
as.publishEvent(InProgress.String())(state)
if err := as.algorithm.Run(); err != nil {
as.runError = err
as.sm.logger.Warn(fmt.Sprintf("failed to run computation: %s", err.Error()))
as.publishEvent(Failed.String(), json.RawMessage{})()
as.logger.Warn(fmt.Sprintf("failed to run computation: %s", err.Error()))
as.publishEvent(Failed.String())(state)
return
}
results, err := internal.ZipDirectoryToMemory(algorithm.ResultsDir)
if err != nil {
as.runError = err
as.sm.logger.Warn(fmt.Sprintf("failed to zip results: %s", err.Error()))
as.publishEvent(Failed.String(), json.RawMessage{})()
as.logger.Warn(fmt.Sprintf("failed to zip results: %s", err.Error()))
as.publishEvent(Failed.String())(state)
return
}
as.publishEvent(Completed.String(), json.RawMessage{})()
as.publishEvent(Completed.String())(state)
as.result = results
}
func (as *agentService) publishEvent(status string, details json.RawMessage) func() {
return func() {
st := as.sm.GetState().String()
if err := as.eventSvc.SendEvent(st, status, details); err != nil {
as.sm.logger.Warn(err.Error())
func (as *agentService) publishEvent(status string) statemachine.Action {
return func(state statemachine.State) {
if err := as.eventSvc.SendEvent(state.String(), status, json.RawMessage{}); err != nil {
as.logger.Warn(err.Error())
}
}
}
+17 -15
View File
@@ -20,6 +20,8 @@ import (
"github.com/ultravioletrs/cocos/agent/events/mocks"
"github.com/ultravioletrs/cocos/agent/quoteprovider"
mocks2 "github.com/ultravioletrs/cocos/agent/quoteprovider/mocks"
"github.com/ultravioletrs/cocos/agent/statemachine"
smmocks "github.com/ultravioletrs/cocos/agent/statemachine/mocks"
"golang.org/x/crypto/sha3"
"google.golang.org/grpc/metadata"
)
@@ -249,45 +251,36 @@ func TestResult(t *testing.T) {
err error
setup func(svc *agentService)
ctxSetup func(ctx context.Context) context.Context
state statemachine.State
}{
{
name: "Test results not ready",
err: ErrResultsNotReady,
setup: func(svc *agentService) {
},
},
{
name: "Test all results consumed",
err: ErrAllResultsConsumed,
setup: func(svc *agentService) {
svc.sm.SetState(ConsumingResults)
svc.computation.ResultConsumers = []ResultConsumer{}
},
ctxSetup: func(ctx context.Context) context.Context {
return IndexToContext(ctx, 0)
},
state: Running,
},
{
name: "Test undeclared consumer",
err: ErrUndeclaredConsumer,
setup: func(svc *agentService) {
svc.sm.SetState(ConsumingResults)
svc.computation.ResultConsumers = []ResultConsumer{{UserKey: []byte("user")}}
},
ctxSetup: func(ctx context.Context) context.Context {
return ctx
},
state: ConsumingResults,
},
{
name: "Test results consumed and event sent",
err: nil,
setup: func(svc *agentService) {
svc.sm.SetState(ConsumingResults)
svc.computation.ResultConsumers = []ResultConsumer{{UserKey: []byte("key")}}
},
ctxSetup: func(ctx context.Context) context.Context {
return IndexToContext(ctx, 0)
},
state: ConsumingResults,
},
}
@@ -301,14 +294,23 @@ func TestResult(t *testing.T) {
ctx = tc.ctxSetup(ctx)
}
sm := new(smmocks.StateMachine)
sm.On("Start", ctx).Return(nil)
sm.On("GetState").Return(tc.state)
sm.On("SendEvent", mock.Anything).Return()
svc := &agentService{
sm: NewStateMachine(mglog.NewMock(), testComputation(t)),
sm: sm,
eventSvc: events,
quoteProvider: qp,
computation: testComputation(t),
}
go svc.sm.Start(ctx)
go func() {
if err := svc.sm.Start(ctx); err != nil {
t.Errorf("Error starting state machine: %v", err)
}
}()
tc.setup(svc)
_, err := svc.Result(ctx)
t.Cleanup(func() {
-150
View File
@@ -1,150 +0,0 @@
// Copyright (c) Ultraviolet
// SPDX-License-Identifier: Apache-2.0
package agent
import (
"context"
"fmt"
"log/slog"
"sync"
)
//go:generate stringer -type=State
type State uint8
const (
Idle State = iota
ReceivingManifest
ReceivingAlgorithm
ReceivingData
Running
ConsumingResults
Complete
Failed
AlgorithmRun
)
//go:generate stringer -type=Status
type Status uint8
const (
IdleState Status = iota
InProgress
Ready
Completed
Terminated
Warning
)
type event uint8
const (
start event = iota
manifestReceived
algorithmReceived
dataReceived
runComplete
resultsConsumed
runFailed
)
// StateMachine represents the state machine.
type StateMachine struct {
mu sync.Mutex
State State
EventChan chan event
Transitions map[State]map[event]State
StateFunctions map[State]func()
logger *slog.Logger
wg *sync.WaitGroup
}
// NewStateMachine creates a new StateMachine.
func NewStateMachine(logger *slog.Logger, cmp Computation) *StateMachine {
sm := &StateMachine{
State: Idle,
EventChan: make(chan event),
Transitions: make(map[State]map[event]State),
StateFunctions: make(map[State]func()),
logger: logger,
wg: &sync.WaitGroup{},
}
sm.Transitions[Idle] = make(map[event]State)
sm.Transitions[Idle][start] = ReceivingManifest
sm.Transitions[ReceivingManifest] = make(map[event]State)
sm.Transitions[ReceivingManifest][manifestReceived] = ReceivingAlgorithm
sm.Transitions[ReceivingAlgorithm] = make(map[event]State)
switch len(cmp.Datasets) {
case 0:
sm.Transitions[ReceivingAlgorithm][algorithmReceived] = Running
default:
sm.Transitions[ReceivingAlgorithm][algorithmReceived] = ReceivingData
}
sm.Transitions[ReceivingData] = make(map[event]State)
sm.Transitions[ReceivingData][dataReceived] = Running
sm.Transitions[Running] = make(map[event]State)
sm.Transitions[Running][runComplete] = ConsumingResults
sm.Transitions[Running][runFailed] = Failed
sm.Transitions[ConsumingResults] = make(map[event]State)
sm.Transitions[ConsumingResults][resultsConsumed] = Complete
return sm
}
// Start the state machine.
func (sm *StateMachine) Start(ctx context.Context) {
sm.wg.Add(1)
defer sm.wg.Done()
for {
select {
case event := <-sm.EventChan:
currentState := sm.GetState()
var nextState State
var stateFunc func()
var valid bool
sm.mu.Lock()
nextState, valid = sm.Transitions[sm.State][event]
if valid {
sm.State = nextState
stateFunc = sm.StateFunctions[nextState]
}
sm.mu.Unlock()
if valid {
sm.logger.Debug(fmt.Sprintf("Transition: %v -> %v\n", currentState, nextState))
if stateFunc != nil {
go stateFunc()
}
} else {
sm.logger.Error(fmt.Sprintf("Invalid transition: %v -> ???\n", sm.State))
}
case <-ctx.Done():
return
}
}
}
// SendEvent sends an event to the state machine.
func (sm *StateMachine) SendEvent(event event) {
sm.EventChan <- event
}
func (sm *StateMachine) GetState() State {
sm.mu.Lock()
defer sm.mu.Unlock()
return sm.State
}
func (sm *StateMachine) SetState(state State) {
sm.mu.Lock()
defer sm.mu.Unlock()
sm.State = state
}
-31
View File
@@ -1,31 +0,0 @@
// Code generated by "stringer -type=State"; DO NOT EDIT.
package agent
import "strconv"
func _() {
// An "invalid array index" compiler error signifies that the constant values have changed.
// Re-run the stringer command to generate them again.
var x [1]struct{}
_ = x[Idle-0]
_ = x[ReceivingManifest-1]
_ = x[ReceivingAlgorithm-2]
_ = x[ReceivingData-3]
_ = x[Running-4]
_ = x[ConsumingResults-5]
_ = x[Complete-6]
_ = x[Failed-7]
_ = x[AlgorithmRun-8]
}
const _State_name = "IdleReceivingManifestReceivingAlgorithmReceivingDataRunningConsumingResultsCompleteFailedAlgorithmRun"
var _State_index = [...]uint8{0, 4, 21, 39, 52, 59, 75, 83, 89, 101}
func (i State) String() string {
if i >= State(len(_State_index)-1) {
return "State(" + strconv.FormatInt(int64(i), 10) + ")"
}
return _State_name[_State_index[i]:_State_index[i+1]]
}
+161 -54
View File
@@ -4,75 +4,182 @@ package agent
import (
"context"
"fmt"
sync "sync"
"testing"
"time"
mglog "github.com/absmach/magistrala/logger"
"github.com/ultravioletrs/cocos/agent/statemachine"
)
var cmp = Computation{
Datasets: []Dataset{
{
Dataset: []byte("test"),
UserKey: []byte("test"),
},
},
type MockState int
type MockEvent int
func (s MockState) String() string {
return []string{"State1", "State2", "State3"}[s]
}
func TestStateMachineTransitions(t *testing.T) {
cases := []struct {
fromState State
event event
expected State
cmp Computation
}{
{Idle, start, ReceivingManifest, cmp},
{ReceivingManifest, manifestReceived, ReceivingAlgorithm, cmp},
{ReceivingAlgorithm, algorithmReceived, ReceivingData, cmp},
{ReceivingAlgorithm, algorithmReceived, Running, Computation{}},
{ReceivingData, dataReceived, Running, cmp},
{Running, runComplete, ConsumingResults, cmp},
{ConsumingResults, resultsConsumed, Complete, cmp},
func (e MockEvent) String() string {
return []string{"Event1", "Event2", "Event3"}[e]
}
const (
State1 MockState = iota
State2
State3
)
const (
Event1 MockEvent = iota
Event2
Event3
)
func TestNewStateMachine(t *testing.T) {
sm := statemachine.NewStateMachine(State1)
if sm == nil {
t.Fatal("NewStateMachine returned nil")
}
for _, tc := range cases {
t.Run(fmt.Sprintf("Transition from %v to %v", tc.fromState, tc.expected), func(t *testing.T) {
sm := NewStateMachine(mglog.NewMock(), tc.cmp)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
go sm.Start(ctx)
time.Sleep(50 * time.Millisecond)
sm.SetState(tc.fromState)
sm.SendEvent(tc.event)
time.Sleep(50 * time.Millisecond)
if sm.GetState() != tc.expected {
t.Errorf("Expected state %v after the event, but got %v", tc.expected, sm.GetState())
}
})
if sm.GetState() != State1 {
t.Errorf("Initial state not set correctly, got %v, want %v", sm.GetState(), State1)
}
}
func TestStateMachineInvalidTransition(t *testing.T) {
sm := NewStateMachine(mglog.NewMock(), cmp)
ctx, cancel := context.WithCancel(context.Background())
func TestAddTransition(t *testing.T) {
sm := statemachine.NewStateMachine(State1)
sm.AddTransition(statemachine.Transition{From: State1, Event: Event1, To: State2})
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
defer cancel()
go sm.Start(ctx)
go func() {
if err := sm.Start(ctx); err != context.Canceled {
t.Errorf("Start returned error: %v", err)
}
}()
sm.SendEvent(Event1)
time.Sleep(50 * time.Millisecond)
sm.SetState(Idle)
sm.SendEvent(dataReceived)
time.Sleep(50 * time.Millisecond)
if sm.GetState() != Idle {
t.Errorf("State should not change on an invalid event, but got %v", sm.GetState())
if sm.GetState() != State2 {
t.Errorf("Transition not applied correctly, got state %v, want %v", sm.GetState(), State2)
}
}
func TestSetAction(t *testing.T) {
sm := statemachine.NewStateMachine(State1)
var wg sync.WaitGroup
wg.Add(1)
sm.SetAction(State2, func(s statemachine.State) {
defer wg.Done()
})
sm.AddTransition(statemachine.Transition{From: State1, Event: Event1, To: State2})
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
defer cancel()
go func() {
if err := sm.Start(ctx); err != context.Canceled {
t.Errorf("Start returned error: %v", err)
}
}()
sm.SendEvent(Event1)
wg.Wait()
if ctx.Err() != nil {
t.Error("Action was not called within the expected time")
}
}
func TestInvalidTransition(t *testing.T) {
sm := statemachine.NewStateMachine(State1)
sm.AddTransition(statemachine.Transition{From: State1, Event: Event1, To: State2})
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
defer cancel()
errChan := make(chan error)
go func() {
errChan <- sm.Start(ctx)
}()
sm.SendEvent(Event2)
select {
case err := <-errChan:
if err == nil {
t.Errorf("Expected invalid transition error, got: %v", err)
}
case <-time.After(150 * time.Millisecond):
t.Error("Timeout waiting for invalid transition error")
}
}
func TestMultipleTransitions(t *testing.T) {
sm := statemachine.NewStateMachine(State1)
sm.AddTransition(statemachine.Transition{From: State1, Event: Event1, To: State2})
sm.AddTransition(statemachine.Transition{From: State2, Event: Event2, To: State3})
sm.AddTransition(statemachine.Transition{From: State3, Event: Event3, To: State1})
ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
defer cancel()
go func() {
if err := sm.Start(ctx); err != context.Canceled {
t.Errorf("Start returned error: %v", err)
}
}()
transitions := []struct {
event MockEvent
want MockState
}{
{Event1, State2},
{Event2, State3},
{Event3, State1},
}
for _, tt := range transitions {
sm.SendEvent(tt.event)
time.Sleep(50 * time.Millisecond)
if sm.GetState() != tt.want {
t.Errorf("After event %v, got state %v, want %v", tt.event, sm.GetState(), tt.want)
}
}
}
func TestConcurrency(t *testing.T) {
sm := statemachine.NewStateMachine(State1)
sm.AddTransition(statemachine.Transition{From: State1, Event: Event1, To: State2})
sm.AddTransition(statemachine.Transition{From: State2, Event: Event2, To: State1})
ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond)
defer cancel()
go func() {
if err := sm.Start(ctx); err == nil {
t.Errorf("Expected context error, got nil")
}
}()
for i := 0; i < 100; i++ {
go func() {
sm.SendEvent(Event1)
sm.SendEvent(Event2)
}()
}
time.Sleep(400 * time.Millisecond)
finalState := sm.GetState()
if finalState != State1 && finalState != State2 {
t.Errorf("Unexpected final state: %v", finalState)
}
}
+86
View File
@@ -0,0 +1,86 @@
// Code generated by mockery v2.43.2. DO NOT EDIT.
// Copyright (c) Ultraviolet
// SPDX-License-Identifier: Apache-2.0
package mocks
import (
context "context"
agent "github.com/ultravioletrs/cocos/agent/statemachine"
mock "github.com/stretchr/testify/mock"
)
// StateMachine is an autogenerated mock type for the StateMachine type
type StateMachine struct {
mock.Mock
}
// AddTransition provides a mock function with given fields: t
func (_m *StateMachine) AddTransition(t agent.Transition) {
_m.Called(t)
}
// GetState provides a mock function with given fields:
func (_m *StateMachine) GetState() agent.State {
ret := _m.Called()
if len(ret) == 0 {
panic("no return value specified for GetState")
}
var r0 agent.State
if rf, ok := ret.Get(0).(func() agent.State); ok {
r0 = rf()
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(agent.State)
}
}
return r0
}
// SendEvent provides a mock function with given fields: event
func (_m *StateMachine) SendEvent(event agent.Event) {
_m.Called(event)
}
// SetAction provides a mock function with given fields: state, action
func (_m *StateMachine) SetAction(state agent.State, action agent.Action) {
_m.Called(state, action)
}
// Start provides a mock function with given fields: ctx
func (_m *StateMachine) Start(ctx context.Context) error {
ret := _m.Called(ctx)
if len(ret) == 0 {
panic("no return value specified for Start")
}
var r0 error
if rf, ok := ret.Get(0).(func(context.Context) error); ok {
r0 = rf(ctx)
} else {
r0 = ret.Error(0)
}
return r0
}
// NewStateMachine creates a new instance of StateMachine. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations.
// The first argument is typically a *testing.T value.
func NewStateMachine(t interface {
mock.TestingT
Cleanup(func())
}) *StateMachine {
mock := &StateMachine{}
mock.Mock.Test(t)
t.Cleanup(func() { mock.AssertExpectations(t) })
return mock
}
+113
View File
@@ -0,0 +1,113 @@
// Copyright (c) Ultraviolet
// SPDX-License-Identifier: Apache-2.0
package statemachine
import (
"context"
"fmt"
"sync"
)
type State interface {
String() string
}
type Event interface {
String() string
}
type Action func(State)
type Transition struct {
From State
Event Event
To State
}
//go:generate mockery --name StateMachine --output=mocks --filename state.go --quiet --note "Copyright (c) Ultraviolet \n // SPDX-License-Identifier: Apache-2.0"
type StateMachine interface {
AddTransition(t Transition)
SetAction(state State, action Action)
GetState() State
SendEvent(event Event)
Start(ctx context.Context) error
}
type stateMachine struct {
mu sync.Mutex
currentState State
transitions map[State]map[Event]State
actions map[State]Action
eventChan chan Event
}
func NewStateMachine(initialState State) StateMachine {
return &stateMachine{
currentState: initialState,
transitions: make(map[State]map[Event]State),
actions: make(map[State]Action),
eventChan: make(chan Event),
}
}
func (sm *stateMachine) AddTransition(t Transition) {
sm.mu.Lock()
defer sm.mu.Unlock()
if _, ok := sm.transitions[t.From]; !ok {
sm.transitions[t.From] = make(map[Event]State)
}
sm.transitions[t.From][t.Event] = t.To
}
func (sm *stateMachine) SetAction(state State, action Action) {
sm.mu.Lock()
defer sm.mu.Unlock()
sm.actions[state] = action
}
func (sm *stateMachine) GetState() State {
sm.mu.Lock()
defer sm.mu.Unlock()
return sm.currentState
}
func (sm *stateMachine) SendEvent(event Event) {
sm.eventChan <- event
}
func (sm *stateMachine) Start(ctx context.Context) error {
for {
select {
case event := <-sm.eventChan:
if err := sm.handleEvent(event); err != nil {
return err
}
case <-ctx.Done():
return ctx.Err()
}
}
}
func (sm *stateMachine) handleEvent(event Event) error {
sm.mu.Lock()
currentState := sm.currentState
nextState, valid := sm.transitions[currentState][event]
sm.mu.Unlock()
if !valid {
return fmt.Errorf("invalid transition: %v -> %v", currentState, event)
}
sm.mu.Lock()
sm.currentState = nextState
action := sm.actions[nextState]
sm.mu.Unlock()
if action != nil {
go action(nextState)
}
return nil
}
+29 -2
View File
@@ -4,13 +4,17 @@ package cli
import (
"encoding/pem"
"fmt"
"os"
"github.com/fatih/color"
"github.com/spf13/cobra"
)
const resultFilePath = "results.zip"
const (
resultFilePrefix = "results"
resultFileExt = ".zip"
)
func (cli *CLI) NewResultsCmd() *cobra.Command {
return &cobra.Command{
@@ -42,12 +46,35 @@ func (cli *CLI) NewResultsCmd() *cobra.Command {
return
}
resultFilePath, err := getUniqueFilePath(resultFilePrefix, resultFileExt)
if err != nil {
printError(cmd, "Error generating unique file path: %v ❌ ", err)
return
}
if err := os.WriteFile(resultFilePath, result, 0o644); err != nil {
printError(cmd, "Error saving computation result file: %v ❌ ", err)
return
}
cmd.Println(color.New(color.FgGreen).Sprint("Computation result retrieved and saved successfully! ✔ "))
cmd.Println(color.New(color.FgGreen).Sprintf("Computation result retrieved and saved successfully as %s! ✔ ", resultFilePath))
},
}
}
func getUniqueFilePath(prefix, ext string) (string, error) {
for i := 0; ; i++ {
var filename string
if i == 0 {
filename = prefix + ext
} else {
filename = fmt.Sprintf("%s_%d%s", prefix, i, ext)
}
if _, err := os.Stat(filename); os.IsNotExist(err) {
return filename, nil
} else if err != nil {
return "", err
}
}
}
+72 -6
View File
@@ -1,11 +1,14 @@
// Copyright (c) Ultraviolet
// SPDX-License-Identifier: Apache-2.0
package cli
import (
"bytes"
"errors"
"fmt"
"os"
"path/filepath"
"testing"
"github.com/stretchr/testify/mock"
@@ -32,12 +35,50 @@ func TestResultsCmd_Success(t *testing.T) {
require.Contains(t, buf.String(), "Computation result retrieved and saved successfully")
resultFile, err := os.ReadFile("results.zip")
files, err := filepath.Glob("results*.zip")
require.NoError(t, err)
require.Len(t, files, 1)
resultFile, err := os.ReadFile(files[0])
require.NoError(t, err)
require.Equal(t, compResult, string(resultFile))
t.Cleanup(func() {
os.Remove("results.zip")
for _, file := range files {
os.Remove(file)
}
os.Remove(privateKeyFile)
})
}
func TestResultsCmd_MultipleExecutions(t *testing.T) {
mockSDK := new(mocks.SDK)
mockSDK.On("Result", mock.Anything, mock.Anything).Return([]byte(compResult), nil)
testCLI := New(mockSDK)
err := generateRSAPrivateKeyFile(privateKeyFile)
require.NoError(t, err)
cmd := testCLI.NewResultsCmd()
buf := new(bytes.Buffer)
cmd.SetOut(buf)
cmd.SetArgs([]string{privateKeyFile})
for i := 0; i < 3; i++ {
err = cmd.Execute()
require.NoError(t, err)
require.Contains(t, buf.String(), "Computation result retrieved and saved successfully")
buf.Reset()
}
files, err := filepath.Glob("results*.zip")
require.NoError(t, err)
require.Len(t, files, 3)
t.Cleanup(func() {
for _, file := range files {
os.Remove(file)
}
os.Remove(privateKeyFile)
})
}
@@ -87,8 +128,8 @@ func TestResultsCmd_SaveFailure(t *testing.T) {
err := generateRSAPrivateKeyFile(privateKeyFile)
require.NoError(t, err)
// Simulate failure in saving the result file by making a directory with the same name as the result file
err = os.Mkdir("results.zip", 0o755)
// Simulate failure in saving the result file by making all files read-only
err = os.Chmod(".", 0o555)
require.NoError(t, err)
cmd := testCLI.NewResultsCmd()
@@ -102,8 +143,10 @@ func TestResultsCmd_SaveFailure(t *testing.T) {
mockSDK.AssertCalled(t, "Result", mock.Anything, mock.Anything)
t.Cleanup(func() {
os.Remove("results.zip")
os.Remove(privateKeyFile)
err := os.Chmod(".", 0o755)
require.NoError(t, err)
err = os.Remove(privateKeyFile)
require.NoError(t, err)
})
}
@@ -132,3 +175,26 @@ func TestResultsCmd_InvalidPrivateKey(t *testing.T) {
require.Contains(t, buf.String(), "Error decoding private key")
mockSDK.AssertNotCalled(t, "Result", mock.Anything, mock.Anything)
}
func TestGetUniqueFilePath(t *testing.T) {
prefix := "test"
ext := ".txt"
path, err := getUniqueFilePath(prefix, ext)
require.NoError(t, err)
require.Equal(t, "test.txt", path)
_, err = os.Create("test.txt")
require.NoError(t, err)
defer os.Remove("test.txt")
for i := 1; i < 3; i++ {
fileName := fmt.Sprintf("%s_%d%s", prefix, i, ext)
_, err := os.Create(fileName)
require.NoError(t, err)
defer os.Remove(fileName)
}
path, err = getUniqueFilePath(prefix, ext)
require.NoError(t, err)
require.Equal(t, "test_3.txt", path)
}