From db7f3c7a4bac2a240683ce3e2d51356e4205e565 Mon Sep 17 00:00:00 2001 From: Sammy Kerata Oina <44265300+SammyOina@users.noreply.github.com> Date: Wed, 9 Oct 2024 14:19:12 +0300 Subject: [PATCH] COCOS-278 - Abstract state machine (#280) * abstract state machine Signed-off-by: Sammy Oina * perpetual results consumption Signed-off-by: Sammy Oina * async action Signed-off-by: Sammy Oina * fix failing tests Signed-off-by: Sammy Oina * fix failing test Signed-off-by: Sammy Oina --------- Signed-off-by: Sammy Oina --- agent/agentevent_string.go | 29 ++++ agent/agentstate_string.go | 30 +++++ agent/service.go | 177 ++++++++++++++++-------- agent/service_test.go | 32 ++--- agent/state.go | 150 --------------------- agent/state_string.go | 31 ----- agent/state_test.go | 215 ++++++++++++++++++++++-------- agent/statemachine/mocks/state.go | 86 ++++++++++++ agent/statemachine/state.go | 113 ++++++++++++++++ cli/result.go | 31 ++++- cli/result_test.go | 78 ++++++++++- 11 files changed, 661 insertions(+), 311 deletions(-) create mode 100644 agent/agentevent_string.go create mode 100644 agent/agentstate_string.go delete mode 100644 agent/state.go delete mode 100644 agent/state_string.go create mode 100644 agent/statemachine/mocks/state.go create mode 100644 agent/statemachine/state.go diff --git a/agent/agentevent_string.go b/agent/agentevent_string.go new file mode 100644 index 00000000..1cb344eb --- /dev/null +++ b/agent/agentevent_string.go @@ -0,0 +1,29 @@ +// Code generated by "stringer -type=AgentEvent"; DO NOT EDIT. + +package agent + +import "strconv" + +func _() { + // An "invalid array index" compiler error signifies that the constant values have changed. + // Re-run the stringer command to generate them again. + var x [1]struct{} + _ = x[Start-0] + _ = x[ManifestReceived-1] + _ = x[AlgorithmReceived-2] + _ = x[DataReceived-3] + _ = x[RunComplete-4] + _ = x[ResultsConsumed-5] + _ = x[RunFailed-6] +} + +const _AgentEvent_name = "StartManifestReceivedAlgorithmReceivedDataReceivedRunCompleteResultsConsumedRunFailed" + +var _AgentEvent_index = [...]uint8{0, 5, 21, 38, 50, 61, 76, 85} + +func (i AgentEvent) String() string { + if i < 0 || i >= AgentEvent(len(_AgentEvent_index)-1) { + return "AgentEvent(" + strconv.FormatInt(int64(i), 10) + ")" + } + return _AgentEvent_name[_AgentEvent_index[i]:_AgentEvent_index[i+1]] +} diff --git a/agent/agentstate_string.go b/agent/agentstate_string.go new file mode 100644 index 00000000..e620f39c --- /dev/null +++ b/agent/agentstate_string.go @@ -0,0 +1,30 @@ +// Code generated by "stringer -type=AgentState"; DO NOT EDIT. + +package agent + +import "strconv" + +func _() { + // An "invalid array index" compiler error signifies that the constant values have changed. + // Re-run the stringer command to generate them again. + var x [1]struct{} + _ = x[Idle-0] + _ = x[ReceivingManifest-1] + _ = x[ReceivingAlgorithm-2] + _ = x[ReceivingData-3] + _ = x[Running-4] + _ = x[ConsumingResults-5] + _ = x[Complete-6] + _ = x[Failed-7] +} + +const _AgentState_name = "IdleReceivingManifestReceivingAlgorithmReceivingDataRunningConsumingResultsCompleteFailed" + +var _AgentState_index = [...]uint8{0, 4, 21, 39, 52, 59, 75, 83, 89} + +func (i AgentState) String() string { + if i < 0 || i >= AgentState(len(_AgentState_index)-1) { + return "AgentState(" + strconv.FormatInt(int64(i), 10) + ")" + } + return _AgentState_name[_AgentState_index[i]:_AgentState_index[i+1]] +} diff --git a/agent/service.go b/agent/service.go index 540d475e..9f4e5e1a 100644 --- a/agent/service.go +++ b/agent/service.go @@ -20,12 +20,52 @@ import ( "github.com/ultravioletrs/cocos/agent/algorithm/python" "github.com/ultravioletrs/cocos/agent/algorithm/wasm" "github.com/ultravioletrs/cocos/agent/events" + "github.com/ultravioletrs/cocos/agent/statemachine" "github.com/ultravioletrs/cocos/internal" "golang.org/x/crypto/sha3" ) var _ Service = (*agentService)(nil) +//go:generate stringer -type=AgentState +type AgentState int + +const ( + Idle AgentState = iota + ReceivingManifest + ReceivingAlgorithm + ReceivingData + Running + ConsumingResults + Complete + Failed +) + +//go:generate stringer -type=AgentEvent +type AgentEvent int + +const ( + Start AgentEvent = iota + ManifestReceived + AlgorithmReceived + DataReceived + RunComplete + ResultsConsumed + RunFailed +) + +//go:generate stringer -type=Status +type Status uint8 + +const ( + IdleState Status = iota + InProgress + Ready + Completed + Terminated + Warning +) + const ( // ReportDataSize is the size of the report data expected by the attestation service. ReportDataSize = 64 @@ -71,40 +111,69 @@ type Service interface { } type agentService struct { - computation Computation // Holds the current computation request details. - algorithm algorithm.Algorithm // Filepath to the algorithm received for the computation. - result []byte // Stores the result of the computation. - sm *StateMachine // Manages the state transitions of the agent service. - runError error // Stores any error encountered during the computation run. - eventSvc events.Service // Service for publishing events related to computation. - quoteProvider client.QuoteProvider // Provider for generating attestation quotes. + computation Computation // Holds the current computation request details. + algorithm algorithm.Algorithm // Filepath to the algorithm received for the computation. + result []byte // Stores the result of the computation. + sm statemachine.StateMachine // Manages the state transitions of the agent service. + runError error // Stores any error encountered during the computation run. + eventSvc events.Service // Service for publishing events related to computation. + quoteProvider client.QuoteProvider // Provider for generating attestation quotes. + logger *slog.Logger // Logger for the agent service. + resultsConsumed bool // Indicates if the results have been consumed. } var _ Service = (*agentService)(nil) // New instantiates the agent service implementation. func New(ctx context.Context, logger *slog.Logger, eventSvc events.Service, cmp Computation, quoteProvider client.QuoteProvider) Service { + sm := statemachine.NewStateMachine(Idle) svc := &agentService{ - sm: NewStateMachine(logger, cmp), + sm: sm, eventSvc: eventSvc, quoteProvider: quoteProvider, + logger: logger, + computation: cmp, } - svc.sm.StateFunctions[Idle] = svc.publishEvent(IdleState.String(), json.RawMessage{}) - svc.sm.StateFunctions[ReceivingManifest] = svc.publishEvent(InProgress.String(), json.RawMessage{}) - svc.sm.StateFunctions[ReceivingAlgorithm] = svc.publishEvent(InProgress.String(), json.RawMessage{}) - svc.sm.StateFunctions[ReceivingData] = svc.publishEvent(InProgress.String(), json.RawMessage{}) - svc.sm.StateFunctions[ConsumingResults] = svc.publishEvent(Ready.String(), json.RawMessage{}) - svc.sm.StateFunctions[Complete] = svc.publishEvent(Completed.String(), json.RawMessage{}) - svc.sm.StateFunctions[Running] = svc.runComputation - svc.sm.StateFunctions[Failed] = svc.publishEvent(Failed.String(), json.RawMessage{}) + transitions := []statemachine.Transition{ + {From: Idle, Event: Start, To: ReceivingManifest}, + {From: ReceivingManifest, Event: ManifestReceived, To: ReceivingAlgorithm}, + } - go svc.sm.Start(ctx) - svc.sm.SendEvent(start) + if len(cmp.Datasets) == 0 { + transitions = append(transitions, statemachine.Transition{From: ReceivingAlgorithm, Event: AlgorithmReceived, To: Running}) + } else { + transitions = append(transitions, statemachine.Transition{From: ReceivingAlgorithm, Event: AlgorithmReceived, To: ReceivingData}) + transitions = append(transitions, statemachine.Transition{From: ReceivingData, Event: DataReceived, To: Running}) + } - svc.computation = cmp + transitions = append(transitions, []statemachine.Transition{ + {From: Running, Event: RunComplete, To: ConsumingResults}, + {From: Running, Event: RunFailed, To: Failed}, + {From: ConsumingResults, Event: ResultsConsumed, To: Complete}, + }...) + + for _, t := range transitions { + sm.AddTransition(t) + } + + sm.SetAction(Idle, svc.publishEvent(IdleState.String())) + sm.SetAction(ReceivingManifest, svc.publishEvent(InProgress.String())) + sm.SetAction(ReceivingAlgorithm, svc.publishEvent(InProgress.String())) + sm.SetAction(ReceivingData, svc.publishEvent(InProgress.String())) + sm.SetAction(Running, svc.runComputation) + sm.SetAction(ConsumingResults, svc.publishEvent(Ready.String())) + sm.SetAction(Complete, svc.publishEvent(Completed.String())) + sm.SetAction(Failed, svc.publishEvent(Failed.String())) + + go func() { + if err := sm.Start(ctx); err != nil { + logger.Error(err.Error()) + } + }() + sm.SendEvent(Start) + defer sm.SendEvent(ManifestReceived) - svc.sm.SendEvent(manifestReceived) return svc } @@ -153,7 +222,7 @@ func (as *agentService) Algo(ctx context.Context, algo Algorithm) error { switch algoType { case string(algorithm.AlgoTypeBin): - as.algorithm = binary.NewAlgorithm(as.sm.logger, as.eventSvc, f.Name(), args) + as.algorithm = binary.NewAlgorithm(as.logger, as.eventSvc, f.Name(), args) case string(algorithm.AlgoTypePython): var requirementsFile string if len(algo.Requirements) > 0 { @@ -171,11 +240,11 @@ func (as *agentService) Algo(ctx context.Context, algo Algorithm) error { requirementsFile = fr.Name() } runtime := python.PythonRunTimeFromContext(ctx) - as.algorithm = python.NewAlgorithm(as.sm.logger, as.eventSvc, runtime, requirementsFile, f.Name(), args) + as.algorithm = python.NewAlgorithm(as.logger, as.eventSvc, runtime, requirementsFile, f.Name(), args) case string(algorithm.AlgoTypeWasm): - as.algorithm = wasm.NewAlgorithm(as.sm.logger, as.eventSvc, f.Name(), args) + as.algorithm = wasm.NewAlgorithm(as.logger, as.eventSvc, f.Name(), args) case string(algorithm.AlgoTypeDocker): - as.algorithm = docker.NewAlgorithm(as.sm.logger, as.eventSvc, f.Name()) + as.algorithm = docker.NewAlgorithm(as.logger, as.eventSvc, f.Name()) } if err := os.Mkdir(algorithm.DatasetsDir, 0o755); err != nil { @@ -183,7 +252,7 @@ func (as *agentService) Algo(ctx context.Context, algo Algorithm) error { } if as.algorithm != nil { - as.sm.SendEvent(algorithmReceived) + as.sm.SendEvent(AlgorithmReceived) } return nil @@ -236,27 +305,30 @@ func (as *agentService) Data(ctx context.Context, dataset Dataset) error { } if len(as.computation.Datasets) == 0 { - defer as.sm.SendEvent(dataReceived) + defer as.sm.SendEvent(DataReceived) } return nil } func (as *agentService) Result(ctx context.Context) ([]byte, error) { - if as.sm.GetState() != ConsumingResults && as.sm.GetState() != Failed { + currentState := as.sm.GetState() + if currentState != ConsumingResults && currentState != Complete && currentState != Failed { return []byte{}, ErrResultsNotReady } - if len(as.computation.ResultConsumers) == 0 { - return []byte{}, ErrAllResultsConsumed - } + index, ok := IndexFromContext(ctx) if !ok { return []byte{}, ErrUndeclaredConsumer } - as.computation.ResultConsumers = slices.Delete(as.computation.ResultConsumers, index, index+1) - if len(as.computation.ResultConsumers) == 0 && as.sm.GetState() == ConsumingResults { - defer as.sm.SendEvent(resultsConsumed) + if index < 0 || index >= len(as.computation.ResultConsumers) { + return []byte{}, ErrUndeclaredConsumer + } + + if !as.resultsConsumed && currentState == ConsumingResults { + as.resultsConsumed = true + defer as.sm.SendEvent(ResultsConsumed) } return as.result, as.runError @@ -271,59 +343,58 @@ func (as *agentService) Attestation(ctx context.Context, reportData [ReportDataS return rawQuote, nil } -func (as *agentService) runComputation() { - as.publishEvent(InProgress.String(), json.RawMessage{})() - as.sm.logger.Debug("computation run started") +func (as *agentService) runComputation(state statemachine.State) { + as.publishEvent(InProgress.String())(state) + as.logger.Debug("computation run started") defer func() { if as.runError != nil { - as.sm.SendEvent(runFailed) + as.sm.SendEvent(RunFailed) } else { - as.sm.SendEvent(runComplete) + as.sm.SendEvent(RunComplete) } }() if err := os.Mkdir(algorithm.ResultsDir, 0o755); err != nil { as.runError = fmt.Errorf("error creating results directory: %s", err.Error()) - as.sm.logger.Warn(as.runError.Error()) - as.publishEvent(Failed.String(), json.RawMessage{})() + as.logger.Warn(as.runError.Error()) + as.publishEvent(Failed.String())(state) return } defer func() { if err := os.RemoveAll(algorithm.ResultsDir); err != nil { - as.sm.logger.Warn(fmt.Sprintf("error removing results directory and its contents: %s", err.Error())) + as.logger.Warn(fmt.Sprintf("error removing results directory and its contents: %s", err.Error())) } if err := os.RemoveAll(algorithm.DatasetsDir); err != nil { - as.sm.logger.Warn(fmt.Sprintf("error removing datasets directory and its contents: %s", err.Error())) + as.logger.Warn(fmt.Sprintf("error removing datasets directory and its contents: %s", err.Error())) } }() - as.publishEvent(InProgress.String(), json.RawMessage{})() + as.publishEvent(InProgress.String())(state) if err := as.algorithm.Run(); err != nil { as.runError = err - as.sm.logger.Warn(fmt.Sprintf("failed to run computation: %s", err.Error())) - as.publishEvent(Failed.String(), json.RawMessage{})() + as.logger.Warn(fmt.Sprintf("failed to run computation: %s", err.Error())) + as.publishEvent(Failed.String())(state) return } results, err := internal.ZipDirectoryToMemory(algorithm.ResultsDir) if err != nil { as.runError = err - as.sm.logger.Warn(fmt.Sprintf("failed to zip results: %s", err.Error())) - as.publishEvent(Failed.String(), json.RawMessage{})() + as.logger.Warn(fmt.Sprintf("failed to zip results: %s", err.Error())) + as.publishEvent(Failed.String())(state) return } - as.publishEvent(Completed.String(), json.RawMessage{})() + as.publishEvent(Completed.String())(state) as.result = results } -func (as *agentService) publishEvent(status string, details json.RawMessage) func() { - return func() { - st := as.sm.GetState().String() - if err := as.eventSvc.SendEvent(st, status, details); err != nil { - as.sm.logger.Warn(err.Error()) +func (as *agentService) publishEvent(status string) statemachine.Action { + return func(state statemachine.State) { + if err := as.eventSvc.SendEvent(state.String(), status, json.RawMessage{}); err != nil { + as.logger.Warn(err.Error()) } } } diff --git a/agent/service_test.go b/agent/service_test.go index 545dae48..f237b9df 100644 --- a/agent/service_test.go +++ b/agent/service_test.go @@ -20,6 +20,8 @@ import ( "github.com/ultravioletrs/cocos/agent/events/mocks" "github.com/ultravioletrs/cocos/agent/quoteprovider" mocks2 "github.com/ultravioletrs/cocos/agent/quoteprovider/mocks" + "github.com/ultravioletrs/cocos/agent/statemachine" + smmocks "github.com/ultravioletrs/cocos/agent/statemachine/mocks" "golang.org/x/crypto/sha3" "google.golang.org/grpc/metadata" ) @@ -249,45 +251,36 @@ func TestResult(t *testing.T) { err error setup func(svc *agentService) ctxSetup func(ctx context.Context) context.Context + state statemachine.State }{ { name: "Test results not ready", err: ErrResultsNotReady, setup: func(svc *agentService) { }, - }, - { - name: "Test all results consumed", - err: ErrAllResultsConsumed, - setup: func(svc *agentService) { - svc.sm.SetState(ConsumingResults) - svc.computation.ResultConsumers = []ResultConsumer{} - }, - ctxSetup: func(ctx context.Context) context.Context { - return IndexToContext(ctx, 0) - }, + state: Running, }, { name: "Test undeclared consumer", err: ErrUndeclaredConsumer, setup: func(svc *agentService) { - svc.sm.SetState(ConsumingResults) svc.computation.ResultConsumers = []ResultConsumer{{UserKey: []byte("user")}} }, ctxSetup: func(ctx context.Context) context.Context { return ctx }, + state: ConsumingResults, }, { name: "Test results consumed and event sent", err: nil, setup: func(svc *agentService) { - svc.sm.SetState(ConsumingResults) svc.computation.ResultConsumers = []ResultConsumer{{UserKey: []byte("key")}} }, ctxSetup: func(ctx context.Context) context.Context { return IndexToContext(ctx, 0) }, + state: ConsumingResults, }, } @@ -301,14 +294,23 @@ func TestResult(t *testing.T) { ctx = tc.ctxSetup(ctx) } + sm := new(smmocks.StateMachine) + sm.On("Start", ctx).Return(nil) + sm.On("GetState").Return(tc.state) + sm.On("SendEvent", mock.Anything).Return() + svc := &agentService{ - sm: NewStateMachine(mglog.NewMock(), testComputation(t)), + sm: sm, eventSvc: events, quoteProvider: qp, computation: testComputation(t), } - go svc.sm.Start(ctx) + go func() { + if err := svc.sm.Start(ctx); err != nil { + t.Errorf("Error starting state machine: %v", err) + } + }() tc.setup(svc) _, err := svc.Result(ctx) t.Cleanup(func() { diff --git a/agent/state.go b/agent/state.go deleted file mode 100644 index f109735c..00000000 --- a/agent/state.go +++ /dev/null @@ -1,150 +0,0 @@ -// Copyright (c) Ultraviolet -// SPDX-License-Identifier: Apache-2.0 -package agent - -import ( - "context" - "fmt" - "log/slog" - "sync" -) - -//go:generate stringer -type=State -type State uint8 - -const ( - Idle State = iota - ReceivingManifest - ReceivingAlgorithm - ReceivingData - Running - ConsumingResults - Complete - Failed - AlgorithmRun -) - -//go:generate stringer -type=Status -type Status uint8 - -const ( - IdleState Status = iota - InProgress - Ready - Completed - Terminated - Warning -) - -type event uint8 - -const ( - start event = iota - manifestReceived - algorithmReceived - dataReceived - runComplete - resultsConsumed - runFailed -) - -// StateMachine represents the state machine. -type StateMachine struct { - mu sync.Mutex - State State - EventChan chan event - Transitions map[State]map[event]State - StateFunctions map[State]func() - logger *slog.Logger - wg *sync.WaitGroup -} - -// NewStateMachine creates a new StateMachine. -func NewStateMachine(logger *slog.Logger, cmp Computation) *StateMachine { - sm := &StateMachine{ - State: Idle, - EventChan: make(chan event), - Transitions: make(map[State]map[event]State), - StateFunctions: make(map[State]func()), - logger: logger, - wg: &sync.WaitGroup{}, - } - - sm.Transitions[Idle] = make(map[event]State) - sm.Transitions[Idle][start] = ReceivingManifest - - sm.Transitions[ReceivingManifest] = make(map[event]State) - sm.Transitions[ReceivingManifest][manifestReceived] = ReceivingAlgorithm - - sm.Transitions[ReceivingAlgorithm] = make(map[event]State) - switch len(cmp.Datasets) { - case 0: - sm.Transitions[ReceivingAlgorithm][algorithmReceived] = Running - default: - sm.Transitions[ReceivingAlgorithm][algorithmReceived] = ReceivingData - } - - sm.Transitions[ReceivingData] = make(map[event]State) - sm.Transitions[ReceivingData][dataReceived] = Running - - sm.Transitions[Running] = make(map[event]State) - sm.Transitions[Running][runComplete] = ConsumingResults - sm.Transitions[Running][runFailed] = Failed - - sm.Transitions[ConsumingResults] = make(map[event]State) - sm.Transitions[ConsumingResults][resultsConsumed] = Complete - - return sm -} - -// Start the state machine. -func (sm *StateMachine) Start(ctx context.Context) { - sm.wg.Add(1) - defer sm.wg.Done() - for { - select { - case event := <-sm.EventChan: - currentState := sm.GetState() - var nextState State - var stateFunc func() - var valid bool - - sm.mu.Lock() - nextState, valid = sm.Transitions[sm.State][event] - if valid { - sm.State = nextState - stateFunc = sm.StateFunctions[nextState] - } - sm.mu.Unlock() - - if valid { - sm.logger.Debug(fmt.Sprintf("Transition: %v -> %v\n", currentState, nextState)) - if stateFunc != nil { - go stateFunc() - } - } else { - sm.logger.Error(fmt.Sprintf("Invalid transition: %v -> ???\n", sm.State)) - } - - case <-ctx.Done(): - return - } - } -} - -// SendEvent sends an event to the state machine. -func (sm *StateMachine) SendEvent(event event) { - sm.EventChan <- event -} - -func (sm *StateMachine) GetState() State { - sm.mu.Lock() - defer sm.mu.Unlock() - return sm.State -} - -func (sm *StateMachine) SetState(state State) { - sm.mu.Lock() - defer sm.mu.Unlock() - sm.State = state -} diff --git a/agent/state_string.go b/agent/state_string.go deleted file mode 100644 index b084ec0b..00000000 --- a/agent/state_string.go +++ /dev/null @@ -1,31 +0,0 @@ -// Code generated by "stringer -type=State"; DO NOT EDIT. - -package agent - -import "strconv" - -func _() { - // An "invalid array index" compiler error signifies that the constant values have changed. - // Re-run the stringer command to generate them again. - var x [1]struct{} - _ = x[Idle-0] - _ = x[ReceivingManifest-1] - _ = x[ReceivingAlgorithm-2] - _ = x[ReceivingData-3] - _ = x[Running-4] - _ = x[ConsumingResults-5] - _ = x[Complete-6] - _ = x[Failed-7] - _ = x[AlgorithmRun-8] -} - -const _State_name = "IdleReceivingManifestReceivingAlgorithmReceivingDataRunningConsumingResultsCompleteFailedAlgorithmRun" - -var _State_index = [...]uint8{0, 4, 21, 39, 52, 59, 75, 83, 89, 101} - -func (i State) String() string { - if i >= State(len(_State_index)-1) { - return "State(" + strconv.FormatInt(int64(i), 10) + ")" - } - return _State_name[_State_index[i]:_State_index[i+1]] -} diff --git a/agent/state_test.go b/agent/state_test.go index 43e3fb92..1bd26941 100644 --- a/agent/state_test.go +++ b/agent/state_test.go @@ -4,75 +4,182 @@ package agent import ( "context" - "fmt" + sync "sync" "testing" "time" - mglog "github.com/absmach/magistrala/logger" + "github.com/ultravioletrs/cocos/agent/statemachine" ) -var cmp = Computation{ - Datasets: []Dataset{ - { - Dataset: []byte("test"), - UserKey: []byte("test"), - }, - }, +type MockState int + +type MockEvent int + +func (s MockState) String() string { + return []string{"State1", "State2", "State3"}[s] } -func TestStateMachineTransitions(t *testing.T) { - cases := []struct { - fromState State - event event - expected State - cmp Computation - }{ - {Idle, start, ReceivingManifest, cmp}, - {ReceivingManifest, manifestReceived, ReceivingAlgorithm, cmp}, - {ReceivingAlgorithm, algorithmReceived, ReceivingData, cmp}, - {ReceivingAlgorithm, algorithmReceived, Running, Computation{}}, - {ReceivingData, dataReceived, Running, cmp}, - {Running, runComplete, ConsumingResults, cmp}, - {ConsumingResults, resultsConsumed, Complete, cmp}, +func (e MockEvent) String() string { + return []string{"Event1", "Event2", "Event3"}[e] +} + +const ( + State1 MockState = iota + State2 + State3 +) + +const ( + Event1 MockEvent = iota + Event2 + Event3 +) + +func TestNewStateMachine(t *testing.T) { + sm := statemachine.NewStateMachine(State1) + if sm == nil { + t.Fatal("NewStateMachine returned nil") } - - for _, tc := range cases { - t.Run(fmt.Sprintf("Transition from %v to %v", tc.fromState, tc.expected), func(t *testing.T) { - sm := NewStateMachine(mglog.NewMock(), tc.cmp) - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - go sm.Start(ctx) - - time.Sleep(50 * time.Millisecond) - - sm.SetState(tc.fromState) - sm.SendEvent(tc.event) - - time.Sleep(50 * time.Millisecond) - - if sm.GetState() != tc.expected { - t.Errorf("Expected state %v after the event, but got %v", tc.expected, sm.GetState()) - } - }) + if sm.GetState() != State1 { + t.Errorf("Initial state not set correctly, got %v, want %v", sm.GetState(), State1) } } -func TestStateMachineInvalidTransition(t *testing.T) { - sm := NewStateMachine(mglog.NewMock(), cmp) - ctx, cancel := context.WithCancel(context.Background()) +func TestAddTransition(t *testing.T) { + sm := statemachine.NewStateMachine(State1) + sm.AddTransition(statemachine.Transition{From: State1, Event: Event1, To: State2}) + + ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) defer cancel() - go sm.Start(ctx) + go func() { + if err := sm.Start(ctx); err != context.Canceled { + t.Errorf("Start returned error: %v", err) + } + }() + + sm.SendEvent(Event1) time.Sleep(50 * time.Millisecond) - sm.SetState(Idle) - sm.SendEvent(dataReceived) - - time.Sleep(50 * time.Millisecond) - - if sm.GetState() != Idle { - t.Errorf("State should not change on an invalid event, but got %v", sm.GetState()) + if sm.GetState() != State2 { + t.Errorf("Transition not applied correctly, got state %v, want %v", sm.GetState(), State2) + } +} + +func TestSetAction(t *testing.T) { + sm := statemachine.NewStateMachine(State1) + + var wg sync.WaitGroup + wg.Add(1) + + sm.SetAction(State2, func(s statemachine.State) { + defer wg.Done() + }) + + sm.AddTransition(statemachine.Transition{From: State1, Event: Event1, To: State2}) + + ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) + defer cancel() + + go func() { + if err := sm.Start(ctx); err != context.Canceled { + t.Errorf("Start returned error: %v", err) + } + }() + + sm.SendEvent(Event1) + + wg.Wait() + + if ctx.Err() != nil { + t.Error("Action was not called within the expected time") + } +} + +func TestInvalidTransition(t *testing.T) { + sm := statemachine.NewStateMachine(State1) + sm.AddTransition(statemachine.Transition{From: State1, Event: Event1, To: State2}) + + ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) + defer cancel() + + errChan := make(chan error) + go func() { + errChan <- sm.Start(ctx) + }() + + sm.SendEvent(Event2) + + select { + case err := <-errChan: + if err == nil { + t.Errorf("Expected invalid transition error, got: %v", err) + } + case <-time.After(150 * time.Millisecond): + t.Error("Timeout waiting for invalid transition error") + } +} + +func TestMultipleTransitions(t *testing.T) { + sm := statemachine.NewStateMachine(State1) + sm.AddTransition(statemachine.Transition{From: State1, Event: Event1, To: State2}) + sm.AddTransition(statemachine.Transition{From: State2, Event: Event2, To: State3}) + sm.AddTransition(statemachine.Transition{From: State3, Event: Event3, To: State1}) + + ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond) + defer cancel() + + go func() { + if err := sm.Start(ctx); err != context.Canceled { + t.Errorf("Start returned error: %v", err) + } + }() + + transitions := []struct { + event MockEvent + want MockState + }{ + {Event1, State2}, + {Event2, State3}, + {Event3, State1}, + } + + for _, tt := range transitions { + sm.SendEvent(tt.event) + time.Sleep(50 * time.Millisecond) + + if sm.GetState() != tt.want { + t.Errorf("After event %v, got state %v, want %v", tt.event, sm.GetState(), tt.want) + } + } +} + +func TestConcurrency(t *testing.T) { + sm := statemachine.NewStateMachine(State1) + sm.AddTransition(statemachine.Transition{From: State1, Event: Event1, To: State2}) + sm.AddTransition(statemachine.Transition{From: State2, Event: Event2, To: State1}) + + ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond) + defer cancel() + + go func() { + if err := sm.Start(ctx); err == nil { + t.Errorf("Expected context error, got nil") + } + }() + + for i := 0; i < 100; i++ { + go func() { + sm.SendEvent(Event1) + sm.SendEvent(Event2) + }() + } + + time.Sleep(400 * time.Millisecond) + + finalState := sm.GetState() + if finalState != State1 && finalState != State2 { + t.Errorf("Unexpected final state: %v", finalState) } } diff --git a/agent/statemachine/mocks/state.go b/agent/statemachine/mocks/state.go new file mode 100644 index 00000000..e3a3a6f8 --- /dev/null +++ b/agent/statemachine/mocks/state.go @@ -0,0 +1,86 @@ +// Code generated by mockery v2.43.2. DO NOT EDIT. + +// Copyright (c) Ultraviolet +// SPDX-License-Identifier: Apache-2.0 + +package mocks + +import ( + context "context" + + agent "github.com/ultravioletrs/cocos/agent/statemachine" + + mock "github.com/stretchr/testify/mock" +) + +// StateMachine is an autogenerated mock type for the StateMachine type +type StateMachine struct { + mock.Mock +} + +// AddTransition provides a mock function with given fields: t +func (_m *StateMachine) AddTransition(t agent.Transition) { + _m.Called(t) +} + +// GetState provides a mock function with given fields: +func (_m *StateMachine) GetState() agent.State { + ret := _m.Called() + + if len(ret) == 0 { + panic("no return value specified for GetState") + } + + var r0 agent.State + if rf, ok := ret.Get(0).(func() agent.State); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(agent.State) + } + } + + return r0 +} + +// SendEvent provides a mock function with given fields: event +func (_m *StateMachine) SendEvent(event agent.Event) { + _m.Called(event) +} + +// SetAction provides a mock function with given fields: state, action +func (_m *StateMachine) SetAction(state agent.State, action agent.Action) { + _m.Called(state, action) +} + +// Start provides a mock function with given fields: ctx +func (_m *StateMachine) Start(ctx context.Context) error { + ret := _m.Called(ctx) + + if len(ret) == 0 { + panic("no return value specified for Start") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context) error); ok { + r0 = rf(ctx) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// NewStateMachine creates a new instance of StateMachine. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewStateMachine(t interface { + mock.TestingT + Cleanup(func()) +}) *StateMachine { + mock := &StateMachine{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/agent/statemachine/state.go b/agent/statemachine/state.go new file mode 100644 index 00000000..27b5d1e4 --- /dev/null +++ b/agent/statemachine/state.go @@ -0,0 +1,113 @@ +// Copyright (c) Ultraviolet +// SPDX-License-Identifier: Apache-2.0 +package statemachine + +import ( + "context" + "fmt" + "sync" +) + +type State interface { + String() string +} + +type Event interface { + String() string +} + +type Action func(State) + +type Transition struct { + From State + Event Event + To State +} + +//go:generate mockery --name StateMachine --output=mocks --filename state.go --quiet --note "Copyright (c) Ultraviolet \n // SPDX-License-Identifier: Apache-2.0" +type StateMachine interface { + AddTransition(t Transition) + SetAction(state State, action Action) + GetState() State + SendEvent(event Event) + Start(ctx context.Context) error +} + +type stateMachine struct { + mu sync.Mutex + currentState State + transitions map[State]map[Event]State + actions map[State]Action + eventChan chan Event +} + +func NewStateMachine(initialState State) StateMachine { + return &stateMachine{ + currentState: initialState, + transitions: make(map[State]map[Event]State), + actions: make(map[State]Action), + eventChan: make(chan Event), + } +} + +func (sm *stateMachine) AddTransition(t Transition) { + sm.mu.Lock() + defer sm.mu.Unlock() + + if _, ok := sm.transitions[t.From]; !ok { + sm.transitions[t.From] = make(map[Event]State) + } + sm.transitions[t.From][t.Event] = t.To +} + +func (sm *stateMachine) SetAction(state State, action Action) { + sm.mu.Lock() + defer sm.mu.Unlock() + + sm.actions[state] = action +} + +func (sm *stateMachine) GetState() State { + sm.mu.Lock() + defer sm.mu.Unlock() + return sm.currentState +} + +func (sm *stateMachine) SendEvent(event Event) { + sm.eventChan <- event +} + +func (sm *stateMachine) Start(ctx context.Context) error { + for { + select { + case event := <-sm.eventChan: + if err := sm.handleEvent(event); err != nil { + return err + } + case <-ctx.Done(): + return ctx.Err() + } + } +} + +func (sm *stateMachine) handleEvent(event Event) error { + sm.mu.Lock() + currentState := sm.currentState + nextState, valid := sm.transitions[currentState][event] + sm.mu.Unlock() + + if !valid { + return fmt.Errorf("invalid transition: %v -> %v", currentState, event) + } + + sm.mu.Lock() + sm.currentState = nextState + action := sm.actions[nextState] + sm.mu.Unlock() + + if action != nil { + go action(nextState) + } + + return nil +} diff --git a/cli/result.go b/cli/result.go index d43fa59d..490172e3 100644 --- a/cli/result.go +++ b/cli/result.go @@ -4,13 +4,17 @@ package cli import ( "encoding/pem" + "fmt" "os" "github.com/fatih/color" "github.com/spf13/cobra" ) -const resultFilePath = "results.zip" +const ( + resultFilePrefix = "results" + resultFileExt = ".zip" +) func (cli *CLI) NewResultsCmd() *cobra.Command { return &cobra.Command{ @@ -42,12 +46,35 @@ func (cli *CLI) NewResultsCmd() *cobra.Command { return } + resultFilePath, err := getUniqueFilePath(resultFilePrefix, resultFileExt) + if err != nil { + printError(cmd, "Error generating unique file path: %v ❌ ", err) + return + } + if err := os.WriteFile(resultFilePath, result, 0o644); err != nil { printError(cmd, "Error saving computation result file: %v ❌ ", err) return } - cmd.Println(color.New(color.FgGreen).Sprint("Computation result retrieved and saved successfully! ✔ ")) + cmd.Println(color.New(color.FgGreen).Sprintf("Computation result retrieved and saved successfully as %s! ✔ ", resultFilePath)) }, } } + +func getUniqueFilePath(prefix, ext string) (string, error) { + for i := 0; ; i++ { + var filename string + if i == 0 { + filename = prefix + ext + } else { + filename = fmt.Sprintf("%s_%d%s", prefix, i, ext) + } + + if _, err := os.Stat(filename); os.IsNotExist(err) { + return filename, nil + } else if err != nil { + return "", err + } + } +} diff --git a/cli/result_test.go b/cli/result_test.go index 6ef363bc..fcd0f998 100644 --- a/cli/result_test.go +++ b/cli/result_test.go @@ -1,11 +1,14 @@ // Copyright (c) Ultraviolet // SPDX-License-Identifier: Apache-2.0 + package cli import ( "bytes" "errors" + "fmt" "os" + "path/filepath" "testing" "github.com/stretchr/testify/mock" @@ -32,12 +35,50 @@ func TestResultsCmd_Success(t *testing.T) { require.Contains(t, buf.String(), "Computation result retrieved and saved successfully") - resultFile, err := os.ReadFile("results.zip") + files, err := filepath.Glob("results*.zip") + require.NoError(t, err) + require.Len(t, files, 1) + + resultFile, err := os.ReadFile(files[0]) require.NoError(t, err) require.Equal(t, compResult, string(resultFile)) t.Cleanup(func() { - os.Remove("results.zip") + for _, file := range files { + os.Remove(file) + } + os.Remove(privateKeyFile) + }) +} + +func TestResultsCmd_MultipleExecutions(t *testing.T) { + mockSDK := new(mocks.SDK) + mockSDK.On("Result", mock.Anything, mock.Anything).Return([]byte(compResult), nil) + testCLI := New(mockSDK) + + err := generateRSAPrivateKeyFile(privateKeyFile) + require.NoError(t, err) + + cmd := testCLI.NewResultsCmd() + buf := new(bytes.Buffer) + cmd.SetOut(buf) + cmd.SetArgs([]string{privateKeyFile}) + + for i := 0; i < 3; i++ { + err = cmd.Execute() + require.NoError(t, err) + require.Contains(t, buf.String(), "Computation result retrieved and saved successfully") + buf.Reset() + } + + files, err := filepath.Glob("results*.zip") + require.NoError(t, err) + require.Len(t, files, 3) + + t.Cleanup(func() { + for _, file := range files { + os.Remove(file) + } os.Remove(privateKeyFile) }) } @@ -87,8 +128,8 @@ func TestResultsCmd_SaveFailure(t *testing.T) { err := generateRSAPrivateKeyFile(privateKeyFile) require.NoError(t, err) - // Simulate failure in saving the result file by making a directory with the same name as the result file - err = os.Mkdir("results.zip", 0o755) + // Simulate failure in saving the result file by making all files read-only + err = os.Chmod(".", 0o555) require.NoError(t, err) cmd := testCLI.NewResultsCmd() @@ -102,8 +143,10 @@ func TestResultsCmd_SaveFailure(t *testing.T) { mockSDK.AssertCalled(t, "Result", mock.Anything, mock.Anything) t.Cleanup(func() { - os.Remove("results.zip") - os.Remove(privateKeyFile) + err := os.Chmod(".", 0o755) + require.NoError(t, err) + err = os.Remove(privateKeyFile) + require.NoError(t, err) }) } @@ -132,3 +175,26 @@ func TestResultsCmd_InvalidPrivateKey(t *testing.T) { require.Contains(t, buf.String(), "Error decoding private key") mockSDK.AssertNotCalled(t, "Result", mock.Anything, mock.Anything) } + +func TestGetUniqueFilePath(t *testing.T) { + prefix := "test" + ext := ".txt" + + path, err := getUniqueFilePath(prefix, ext) + require.NoError(t, err) + require.Equal(t, "test.txt", path) + + _, err = os.Create("test.txt") + require.NoError(t, err) + defer os.Remove("test.txt") + for i := 1; i < 3; i++ { + fileName := fmt.Sprintf("%s_%d%s", prefix, i, ext) + _, err := os.Create(fileName) + require.NoError(t, err) + defer os.Remove(fileName) + } + + path, err = getUniqueFilePath(prefix, ext) + require.NoError(t, err) + require.Equal(t, "test_3.txt", path) +}