mirror of
https://github.com/ultravioletrs/cocos.git
synced 2026-06-22 20:00:18 +00:00
COCOS-278 - Abstract state machine (#280)
* abstract state machine Signed-off-by: Sammy Oina <sammyoina@gmail.com> * perpetual results consumption Signed-off-by: Sammy Oina <sammyoina@gmail.com> * async action Signed-off-by: Sammy Oina <sammyoina@gmail.com> * fix failing tests Signed-off-by: Sammy Oina <sammyoina@gmail.com> * fix failing test Signed-off-by: Sammy Oina <sammyoina@gmail.com> --------- Signed-off-by: Sammy Oina <sammyoina@gmail.com>
This commit is contained in:
committed by
GitHub
parent
fb0fbaeb9a
commit
db7f3c7a4b
@@ -0,0 +1,29 @@
|
||||
// Code generated by "stringer -type=AgentEvent"; DO NOT EDIT.
|
||||
|
||||
package agent
|
||||
|
||||
import "strconv"
|
||||
|
||||
func _() {
|
||||
// An "invalid array index" compiler error signifies that the constant values have changed.
|
||||
// Re-run the stringer command to generate them again.
|
||||
var x [1]struct{}
|
||||
_ = x[Start-0]
|
||||
_ = x[ManifestReceived-1]
|
||||
_ = x[AlgorithmReceived-2]
|
||||
_ = x[DataReceived-3]
|
||||
_ = x[RunComplete-4]
|
||||
_ = x[ResultsConsumed-5]
|
||||
_ = x[RunFailed-6]
|
||||
}
|
||||
|
||||
const _AgentEvent_name = "StartManifestReceivedAlgorithmReceivedDataReceivedRunCompleteResultsConsumedRunFailed"
|
||||
|
||||
var _AgentEvent_index = [...]uint8{0, 5, 21, 38, 50, 61, 76, 85}
|
||||
|
||||
func (i AgentEvent) String() string {
|
||||
if i < 0 || i >= AgentEvent(len(_AgentEvent_index)-1) {
|
||||
return "AgentEvent(" + strconv.FormatInt(int64(i), 10) + ")"
|
||||
}
|
||||
return _AgentEvent_name[_AgentEvent_index[i]:_AgentEvent_index[i+1]]
|
||||
}
|
||||
@@ -0,0 +1,30 @@
|
||||
// Code generated by "stringer -type=AgentState"; DO NOT EDIT.
|
||||
|
||||
package agent
|
||||
|
||||
import "strconv"
|
||||
|
||||
func _() {
|
||||
// An "invalid array index" compiler error signifies that the constant values have changed.
|
||||
// Re-run the stringer command to generate them again.
|
||||
var x [1]struct{}
|
||||
_ = x[Idle-0]
|
||||
_ = x[ReceivingManifest-1]
|
||||
_ = x[ReceivingAlgorithm-2]
|
||||
_ = x[ReceivingData-3]
|
||||
_ = x[Running-4]
|
||||
_ = x[ConsumingResults-5]
|
||||
_ = x[Complete-6]
|
||||
_ = x[Failed-7]
|
||||
}
|
||||
|
||||
const _AgentState_name = "IdleReceivingManifestReceivingAlgorithmReceivingDataRunningConsumingResultsCompleteFailed"
|
||||
|
||||
var _AgentState_index = [...]uint8{0, 4, 21, 39, 52, 59, 75, 83, 89}
|
||||
|
||||
func (i AgentState) String() string {
|
||||
if i < 0 || i >= AgentState(len(_AgentState_index)-1) {
|
||||
return "AgentState(" + strconv.FormatInt(int64(i), 10) + ")"
|
||||
}
|
||||
return _AgentState_name[_AgentState_index[i]:_AgentState_index[i+1]]
|
||||
}
|
||||
+124
-53
@@ -20,12 +20,52 @@ import (
|
||||
"github.com/ultravioletrs/cocos/agent/algorithm/python"
|
||||
"github.com/ultravioletrs/cocos/agent/algorithm/wasm"
|
||||
"github.com/ultravioletrs/cocos/agent/events"
|
||||
"github.com/ultravioletrs/cocos/agent/statemachine"
|
||||
"github.com/ultravioletrs/cocos/internal"
|
||||
"golang.org/x/crypto/sha3"
|
||||
)
|
||||
|
||||
var _ Service = (*agentService)(nil)
|
||||
|
||||
//go:generate stringer -type=AgentState
|
||||
type AgentState int
|
||||
|
||||
const (
|
||||
Idle AgentState = iota
|
||||
ReceivingManifest
|
||||
ReceivingAlgorithm
|
||||
ReceivingData
|
||||
Running
|
||||
ConsumingResults
|
||||
Complete
|
||||
Failed
|
||||
)
|
||||
|
||||
//go:generate stringer -type=AgentEvent
|
||||
type AgentEvent int
|
||||
|
||||
const (
|
||||
Start AgentEvent = iota
|
||||
ManifestReceived
|
||||
AlgorithmReceived
|
||||
DataReceived
|
||||
RunComplete
|
||||
ResultsConsumed
|
||||
RunFailed
|
||||
)
|
||||
|
||||
//go:generate stringer -type=Status
|
||||
type Status uint8
|
||||
|
||||
const (
|
||||
IdleState Status = iota
|
||||
InProgress
|
||||
Ready
|
||||
Completed
|
||||
Terminated
|
||||
Warning
|
||||
)
|
||||
|
||||
const (
|
||||
// ReportDataSize is the size of the report data expected by the attestation service.
|
||||
ReportDataSize = 64
|
||||
@@ -71,40 +111,69 @@ type Service interface {
|
||||
}
|
||||
|
||||
type agentService struct {
|
||||
computation Computation // Holds the current computation request details.
|
||||
algorithm algorithm.Algorithm // Filepath to the algorithm received for the computation.
|
||||
result []byte // Stores the result of the computation.
|
||||
sm *StateMachine // Manages the state transitions of the agent service.
|
||||
runError error // Stores any error encountered during the computation run.
|
||||
eventSvc events.Service // Service for publishing events related to computation.
|
||||
quoteProvider client.QuoteProvider // Provider for generating attestation quotes.
|
||||
computation Computation // Holds the current computation request details.
|
||||
algorithm algorithm.Algorithm // Filepath to the algorithm received for the computation.
|
||||
result []byte // Stores the result of the computation.
|
||||
sm statemachine.StateMachine // Manages the state transitions of the agent service.
|
||||
runError error // Stores any error encountered during the computation run.
|
||||
eventSvc events.Service // Service for publishing events related to computation.
|
||||
quoteProvider client.QuoteProvider // Provider for generating attestation quotes.
|
||||
logger *slog.Logger // Logger for the agent service.
|
||||
resultsConsumed bool // Indicates if the results have been consumed.
|
||||
}
|
||||
|
||||
var _ Service = (*agentService)(nil)
|
||||
|
||||
// New instantiates the agent service implementation.
|
||||
func New(ctx context.Context, logger *slog.Logger, eventSvc events.Service, cmp Computation, quoteProvider client.QuoteProvider) Service {
|
||||
sm := statemachine.NewStateMachine(Idle)
|
||||
svc := &agentService{
|
||||
sm: NewStateMachine(logger, cmp),
|
||||
sm: sm,
|
||||
eventSvc: eventSvc,
|
||||
quoteProvider: quoteProvider,
|
||||
logger: logger,
|
||||
computation: cmp,
|
||||
}
|
||||
|
||||
svc.sm.StateFunctions[Idle] = svc.publishEvent(IdleState.String(), json.RawMessage{})
|
||||
svc.sm.StateFunctions[ReceivingManifest] = svc.publishEvent(InProgress.String(), json.RawMessage{})
|
||||
svc.sm.StateFunctions[ReceivingAlgorithm] = svc.publishEvent(InProgress.String(), json.RawMessage{})
|
||||
svc.sm.StateFunctions[ReceivingData] = svc.publishEvent(InProgress.String(), json.RawMessage{})
|
||||
svc.sm.StateFunctions[ConsumingResults] = svc.publishEvent(Ready.String(), json.RawMessage{})
|
||||
svc.sm.StateFunctions[Complete] = svc.publishEvent(Completed.String(), json.RawMessage{})
|
||||
svc.sm.StateFunctions[Running] = svc.runComputation
|
||||
svc.sm.StateFunctions[Failed] = svc.publishEvent(Failed.String(), json.RawMessage{})
|
||||
transitions := []statemachine.Transition{
|
||||
{From: Idle, Event: Start, To: ReceivingManifest},
|
||||
{From: ReceivingManifest, Event: ManifestReceived, To: ReceivingAlgorithm},
|
||||
}
|
||||
|
||||
go svc.sm.Start(ctx)
|
||||
svc.sm.SendEvent(start)
|
||||
if len(cmp.Datasets) == 0 {
|
||||
transitions = append(transitions, statemachine.Transition{From: ReceivingAlgorithm, Event: AlgorithmReceived, To: Running})
|
||||
} else {
|
||||
transitions = append(transitions, statemachine.Transition{From: ReceivingAlgorithm, Event: AlgorithmReceived, To: ReceivingData})
|
||||
transitions = append(transitions, statemachine.Transition{From: ReceivingData, Event: DataReceived, To: Running})
|
||||
}
|
||||
|
||||
svc.computation = cmp
|
||||
transitions = append(transitions, []statemachine.Transition{
|
||||
{From: Running, Event: RunComplete, To: ConsumingResults},
|
||||
{From: Running, Event: RunFailed, To: Failed},
|
||||
{From: ConsumingResults, Event: ResultsConsumed, To: Complete},
|
||||
}...)
|
||||
|
||||
for _, t := range transitions {
|
||||
sm.AddTransition(t)
|
||||
}
|
||||
|
||||
sm.SetAction(Idle, svc.publishEvent(IdleState.String()))
|
||||
sm.SetAction(ReceivingManifest, svc.publishEvent(InProgress.String()))
|
||||
sm.SetAction(ReceivingAlgorithm, svc.publishEvent(InProgress.String()))
|
||||
sm.SetAction(ReceivingData, svc.publishEvent(InProgress.String()))
|
||||
sm.SetAction(Running, svc.runComputation)
|
||||
sm.SetAction(ConsumingResults, svc.publishEvent(Ready.String()))
|
||||
sm.SetAction(Complete, svc.publishEvent(Completed.String()))
|
||||
sm.SetAction(Failed, svc.publishEvent(Failed.String()))
|
||||
|
||||
go func() {
|
||||
if err := sm.Start(ctx); err != nil {
|
||||
logger.Error(err.Error())
|
||||
}
|
||||
}()
|
||||
sm.SendEvent(Start)
|
||||
defer sm.SendEvent(ManifestReceived)
|
||||
|
||||
svc.sm.SendEvent(manifestReceived)
|
||||
return svc
|
||||
}
|
||||
|
||||
@@ -153,7 +222,7 @@ func (as *agentService) Algo(ctx context.Context, algo Algorithm) error {
|
||||
|
||||
switch algoType {
|
||||
case string(algorithm.AlgoTypeBin):
|
||||
as.algorithm = binary.NewAlgorithm(as.sm.logger, as.eventSvc, f.Name(), args)
|
||||
as.algorithm = binary.NewAlgorithm(as.logger, as.eventSvc, f.Name(), args)
|
||||
case string(algorithm.AlgoTypePython):
|
||||
var requirementsFile string
|
||||
if len(algo.Requirements) > 0 {
|
||||
@@ -171,11 +240,11 @@ func (as *agentService) Algo(ctx context.Context, algo Algorithm) error {
|
||||
requirementsFile = fr.Name()
|
||||
}
|
||||
runtime := python.PythonRunTimeFromContext(ctx)
|
||||
as.algorithm = python.NewAlgorithm(as.sm.logger, as.eventSvc, runtime, requirementsFile, f.Name(), args)
|
||||
as.algorithm = python.NewAlgorithm(as.logger, as.eventSvc, runtime, requirementsFile, f.Name(), args)
|
||||
case string(algorithm.AlgoTypeWasm):
|
||||
as.algorithm = wasm.NewAlgorithm(as.sm.logger, as.eventSvc, f.Name(), args)
|
||||
as.algorithm = wasm.NewAlgorithm(as.logger, as.eventSvc, f.Name(), args)
|
||||
case string(algorithm.AlgoTypeDocker):
|
||||
as.algorithm = docker.NewAlgorithm(as.sm.logger, as.eventSvc, f.Name())
|
||||
as.algorithm = docker.NewAlgorithm(as.logger, as.eventSvc, f.Name())
|
||||
}
|
||||
|
||||
if err := os.Mkdir(algorithm.DatasetsDir, 0o755); err != nil {
|
||||
@@ -183,7 +252,7 @@ func (as *agentService) Algo(ctx context.Context, algo Algorithm) error {
|
||||
}
|
||||
|
||||
if as.algorithm != nil {
|
||||
as.sm.SendEvent(algorithmReceived)
|
||||
as.sm.SendEvent(AlgorithmReceived)
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -236,27 +305,30 @@ func (as *agentService) Data(ctx context.Context, dataset Dataset) error {
|
||||
}
|
||||
|
||||
if len(as.computation.Datasets) == 0 {
|
||||
defer as.sm.SendEvent(dataReceived)
|
||||
defer as.sm.SendEvent(DataReceived)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (as *agentService) Result(ctx context.Context) ([]byte, error) {
|
||||
if as.sm.GetState() != ConsumingResults && as.sm.GetState() != Failed {
|
||||
currentState := as.sm.GetState()
|
||||
if currentState != ConsumingResults && currentState != Complete && currentState != Failed {
|
||||
return []byte{}, ErrResultsNotReady
|
||||
}
|
||||
if len(as.computation.ResultConsumers) == 0 {
|
||||
return []byte{}, ErrAllResultsConsumed
|
||||
}
|
||||
|
||||
index, ok := IndexFromContext(ctx)
|
||||
if !ok {
|
||||
return []byte{}, ErrUndeclaredConsumer
|
||||
}
|
||||
as.computation.ResultConsumers = slices.Delete(as.computation.ResultConsumers, index, index+1)
|
||||
|
||||
if len(as.computation.ResultConsumers) == 0 && as.sm.GetState() == ConsumingResults {
|
||||
defer as.sm.SendEvent(resultsConsumed)
|
||||
if index < 0 || index >= len(as.computation.ResultConsumers) {
|
||||
return []byte{}, ErrUndeclaredConsumer
|
||||
}
|
||||
|
||||
if !as.resultsConsumed && currentState == ConsumingResults {
|
||||
as.resultsConsumed = true
|
||||
defer as.sm.SendEvent(ResultsConsumed)
|
||||
}
|
||||
|
||||
return as.result, as.runError
|
||||
@@ -271,59 +343,58 @@ func (as *agentService) Attestation(ctx context.Context, reportData [ReportDataS
|
||||
return rawQuote, nil
|
||||
}
|
||||
|
||||
func (as *agentService) runComputation() {
|
||||
as.publishEvent(InProgress.String(), json.RawMessage{})()
|
||||
as.sm.logger.Debug("computation run started")
|
||||
func (as *agentService) runComputation(state statemachine.State) {
|
||||
as.publishEvent(InProgress.String())(state)
|
||||
as.logger.Debug("computation run started")
|
||||
defer func() {
|
||||
if as.runError != nil {
|
||||
as.sm.SendEvent(runFailed)
|
||||
as.sm.SendEvent(RunFailed)
|
||||
} else {
|
||||
as.sm.SendEvent(runComplete)
|
||||
as.sm.SendEvent(RunComplete)
|
||||
}
|
||||
}()
|
||||
|
||||
if err := os.Mkdir(algorithm.ResultsDir, 0o755); err != nil {
|
||||
as.runError = fmt.Errorf("error creating results directory: %s", err.Error())
|
||||
as.sm.logger.Warn(as.runError.Error())
|
||||
as.publishEvent(Failed.String(), json.RawMessage{})()
|
||||
as.logger.Warn(as.runError.Error())
|
||||
as.publishEvent(Failed.String())(state)
|
||||
return
|
||||
}
|
||||
|
||||
defer func() {
|
||||
if err := os.RemoveAll(algorithm.ResultsDir); err != nil {
|
||||
as.sm.logger.Warn(fmt.Sprintf("error removing results directory and its contents: %s", err.Error()))
|
||||
as.logger.Warn(fmt.Sprintf("error removing results directory and its contents: %s", err.Error()))
|
||||
}
|
||||
if err := os.RemoveAll(algorithm.DatasetsDir); err != nil {
|
||||
as.sm.logger.Warn(fmt.Sprintf("error removing datasets directory and its contents: %s", err.Error()))
|
||||
as.logger.Warn(fmt.Sprintf("error removing datasets directory and its contents: %s", err.Error()))
|
||||
}
|
||||
}()
|
||||
|
||||
as.publishEvent(InProgress.String(), json.RawMessage{})()
|
||||
as.publishEvent(InProgress.String())(state)
|
||||
if err := as.algorithm.Run(); err != nil {
|
||||
as.runError = err
|
||||
as.sm.logger.Warn(fmt.Sprintf("failed to run computation: %s", err.Error()))
|
||||
as.publishEvent(Failed.String(), json.RawMessage{})()
|
||||
as.logger.Warn(fmt.Sprintf("failed to run computation: %s", err.Error()))
|
||||
as.publishEvent(Failed.String())(state)
|
||||
return
|
||||
}
|
||||
|
||||
results, err := internal.ZipDirectoryToMemory(algorithm.ResultsDir)
|
||||
if err != nil {
|
||||
as.runError = err
|
||||
as.sm.logger.Warn(fmt.Sprintf("failed to zip results: %s", err.Error()))
|
||||
as.publishEvent(Failed.String(), json.RawMessage{})()
|
||||
as.logger.Warn(fmt.Sprintf("failed to zip results: %s", err.Error()))
|
||||
as.publishEvent(Failed.String())(state)
|
||||
return
|
||||
}
|
||||
|
||||
as.publishEvent(Completed.String(), json.RawMessage{})()
|
||||
as.publishEvent(Completed.String())(state)
|
||||
|
||||
as.result = results
|
||||
}
|
||||
|
||||
func (as *agentService) publishEvent(status string, details json.RawMessage) func() {
|
||||
return func() {
|
||||
st := as.sm.GetState().String()
|
||||
if err := as.eventSvc.SendEvent(st, status, details); err != nil {
|
||||
as.sm.logger.Warn(err.Error())
|
||||
func (as *agentService) publishEvent(status string) statemachine.Action {
|
||||
return func(state statemachine.State) {
|
||||
if err := as.eventSvc.SendEvent(state.String(), status, json.RawMessage{}); err != nil {
|
||||
as.logger.Warn(err.Error())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
+17
-15
@@ -20,6 +20,8 @@ import (
|
||||
"github.com/ultravioletrs/cocos/agent/events/mocks"
|
||||
"github.com/ultravioletrs/cocos/agent/quoteprovider"
|
||||
mocks2 "github.com/ultravioletrs/cocos/agent/quoteprovider/mocks"
|
||||
"github.com/ultravioletrs/cocos/agent/statemachine"
|
||||
smmocks "github.com/ultravioletrs/cocos/agent/statemachine/mocks"
|
||||
"golang.org/x/crypto/sha3"
|
||||
"google.golang.org/grpc/metadata"
|
||||
)
|
||||
@@ -249,45 +251,36 @@ func TestResult(t *testing.T) {
|
||||
err error
|
||||
setup func(svc *agentService)
|
||||
ctxSetup func(ctx context.Context) context.Context
|
||||
state statemachine.State
|
||||
}{
|
||||
{
|
||||
name: "Test results not ready",
|
||||
err: ErrResultsNotReady,
|
||||
setup: func(svc *agentService) {
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Test all results consumed",
|
||||
err: ErrAllResultsConsumed,
|
||||
setup: func(svc *agentService) {
|
||||
svc.sm.SetState(ConsumingResults)
|
||||
svc.computation.ResultConsumers = []ResultConsumer{}
|
||||
},
|
||||
ctxSetup: func(ctx context.Context) context.Context {
|
||||
return IndexToContext(ctx, 0)
|
||||
},
|
||||
state: Running,
|
||||
},
|
||||
{
|
||||
name: "Test undeclared consumer",
|
||||
err: ErrUndeclaredConsumer,
|
||||
setup: func(svc *agentService) {
|
||||
svc.sm.SetState(ConsumingResults)
|
||||
svc.computation.ResultConsumers = []ResultConsumer{{UserKey: []byte("user")}}
|
||||
},
|
||||
ctxSetup: func(ctx context.Context) context.Context {
|
||||
return ctx
|
||||
},
|
||||
state: ConsumingResults,
|
||||
},
|
||||
{
|
||||
name: "Test results consumed and event sent",
|
||||
err: nil,
|
||||
setup: func(svc *agentService) {
|
||||
svc.sm.SetState(ConsumingResults)
|
||||
svc.computation.ResultConsumers = []ResultConsumer{{UserKey: []byte("key")}}
|
||||
},
|
||||
ctxSetup: func(ctx context.Context) context.Context {
|
||||
return IndexToContext(ctx, 0)
|
||||
},
|
||||
state: ConsumingResults,
|
||||
},
|
||||
}
|
||||
|
||||
@@ -301,14 +294,23 @@ func TestResult(t *testing.T) {
|
||||
ctx = tc.ctxSetup(ctx)
|
||||
}
|
||||
|
||||
sm := new(smmocks.StateMachine)
|
||||
sm.On("Start", ctx).Return(nil)
|
||||
sm.On("GetState").Return(tc.state)
|
||||
sm.On("SendEvent", mock.Anything).Return()
|
||||
|
||||
svc := &agentService{
|
||||
sm: NewStateMachine(mglog.NewMock(), testComputation(t)),
|
||||
sm: sm,
|
||||
eventSvc: events,
|
||||
quoteProvider: qp,
|
||||
computation: testComputation(t),
|
||||
}
|
||||
|
||||
go svc.sm.Start(ctx)
|
||||
go func() {
|
||||
if err := svc.sm.Start(ctx); err != nil {
|
||||
t.Errorf("Error starting state machine: %v", err)
|
||||
}
|
||||
}()
|
||||
tc.setup(svc)
|
||||
_, err := svc.Result(ctx)
|
||||
t.Cleanup(func() {
|
||||
|
||||
-150
@@ -1,150 +0,0 @@
|
||||
// Copyright (c) Ultraviolet
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
package agent
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"sync"
|
||||
)
|
||||
|
||||
//go:generate stringer -type=State
|
||||
type State uint8
|
||||
|
||||
const (
|
||||
Idle State = iota
|
||||
ReceivingManifest
|
||||
ReceivingAlgorithm
|
||||
ReceivingData
|
||||
Running
|
||||
ConsumingResults
|
||||
Complete
|
||||
Failed
|
||||
AlgorithmRun
|
||||
)
|
||||
|
||||
//go:generate stringer -type=Status
|
||||
type Status uint8
|
||||
|
||||
const (
|
||||
IdleState Status = iota
|
||||
InProgress
|
||||
Ready
|
||||
Completed
|
||||
Terminated
|
||||
Warning
|
||||
)
|
||||
|
||||
type event uint8
|
||||
|
||||
const (
|
||||
start event = iota
|
||||
manifestReceived
|
||||
algorithmReceived
|
||||
dataReceived
|
||||
runComplete
|
||||
resultsConsumed
|
||||
runFailed
|
||||
)
|
||||
|
||||
// StateMachine represents the state machine.
|
||||
type StateMachine struct {
|
||||
mu sync.Mutex
|
||||
State State
|
||||
EventChan chan event
|
||||
Transitions map[State]map[event]State
|
||||
StateFunctions map[State]func()
|
||||
logger *slog.Logger
|
||||
wg *sync.WaitGroup
|
||||
}
|
||||
|
||||
// NewStateMachine creates a new StateMachine.
|
||||
func NewStateMachine(logger *slog.Logger, cmp Computation) *StateMachine {
|
||||
sm := &StateMachine{
|
||||
State: Idle,
|
||||
EventChan: make(chan event),
|
||||
Transitions: make(map[State]map[event]State),
|
||||
StateFunctions: make(map[State]func()),
|
||||
logger: logger,
|
||||
wg: &sync.WaitGroup{},
|
||||
}
|
||||
|
||||
sm.Transitions[Idle] = make(map[event]State)
|
||||
sm.Transitions[Idle][start] = ReceivingManifest
|
||||
|
||||
sm.Transitions[ReceivingManifest] = make(map[event]State)
|
||||
sm.Transitions[ReceivingManifest][manifestReceived] = ReceivingAlgorithm
|
||||
|
||||
sm.Transitions[ReceivingAlgorithm] = make(map[event]State)
|
||||
switch len(cmp.Datasets) {
|
||||
case 0:
|
||||
sm.Transitions[ReceivingAlgorithm][algorithmReceived] = Running
|
||||
default:
|
||||
sm.Transitions[ReceivingAlgorithm][algorithmReceived] = ReceivingData
|
||||
}
|
||||
|
||||
sm.Transitions[ReceivingData] = make(map[event]State)
|
||||
sm.Transitions[ReceivingData][dataReceived] = Running
|
||||
|
||||
sm.Transitions[Running] = make(map[event]State)
|
||||
sm.Transitions[Running][runComplete] = ConsumingResults
|
||||
sm.Transitions[Running][runFailed] = Failed
|
||||
|
||||
sm.Transitions[ConsumingResults] = make(map[event]State)
|
||||
sm.Transitions[ConsumingResults][resultsConsumed] = Complete
|
||||
|
||||
return sm
|
||||
}
|
||||
|
||||
// Start the state machine.
|
||||
func (sm *StateMachine) Start(ctx context.Context) {
|
||||
sm.wg.Add(1)
|
||||
defer sm.wg.Done()
|
||||
for {
|
||||
select {
|
||||
case event := <-sm.EventChan:
|
||||
currentState := sm.GetState()
|
||||
var nextState State
|
||||
var stateFunc func()
|
||||
var valid bool
|
||||
|
||||
sm.mu.Lock()
|
||||
nextState, valid = sm.Transitions[sm.State][event]
|
||||
if valid {
|
||||
sm.State = nextState
|
||||
stateFunc = sm.StateFunctions[nextState]
|
||||
}
|
||||
sm.mu.Unlock()
|
||||
|
||||
if valid {
|
||||
sm.logger.Debug(fmt.Sprintf("Transition: %v -> %v\n", currentState, nextState))
|
||||
if stateFunc != nil {
|
||||
go stateFunc()
|
||||
}
|
||||
} else {
|
||||
sm.logger.Error(fmt.Sprintf("Invalid transition: %v -> ???\n", sm.State))
|
||||
}
|
||||
|
||||
case <-ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// SendEvent sends an event to the state machine.
|
||||
func (sm *StateMachine) SendEvent(event event) {
|
||||
sm.EventChan <- event
|
||||
}
|
||||
|
||||
func (sm *StateMachine) GetState() State {
|
||||
sm.mu.Lock()
|
||||
defer sm.mu.Unlock()
|
||||
return sm.State
|
||||
}
|
||||
|
||||
func (sm *StateMachine) SetState(state State) {
|
||||
sm.mu.Lock()
|
||||
defer sm.mu.Unlock()
|
||||
sm.State = state
|
||||
}
|
||||
@@ -1,31 +0,0 @@
|
||||
// Code generated by "stringer -type=State"; DO NOT EDIT.
|
||||
|
||||
package agent
|
||||
|
||||
import "strconv"
|
||||
|
||||
func _() {
|
||||
// An "invalid array index" compiler error signifies that the constant values have changed.
|
||||
// Re-run the stringer command to generate them again.
|
||||
var x [1]struct{}
|
||||
_ = x[Idle-0]
|
||||
_ = x[ReceivingManifest-1]
|
||||
_ = x[ReceivingAlgorithm-2]
|
||||
_ = x[ReceivingData-3]
|
||||
_ = x[Running-4]
|
||||
_ = x[ConsumingResults-5]
|
||||
_ = x[Complete-6]
|
||||
_ = x[Failed-7]
|
||||
_ = x[AlgorithmRun-8]
|
||||
}
|
||||
|
||||
const _State_name = "IdleReceivingManifestReceivingAlgorithmReceivingDataRunningConsumingResultsCompleteFailedAlgorithmRun"
|
||||
|
||||
var _State_index = [...]uint8{0, 4, 21, 39, 52, 59, 75, 83, 89, 101}
|
||||
|
||||
func (i State) String() string {
|
||||
if i >= State(len(_State_index)-1) {
|
||||
return "State(" + strconv.FormatInt(int64(i), 10) + ")"
|
||||
}
|
||||
return _State_name[_State_index[i]:_State_index[i+1]]
|
||||
}
|
||||
+161
-54
@@ -4,75 +4,182 @@ package agent
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
sync "sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
mglog "github.com/absmach/magistrala/logger"
|
||||
"github.com/ultravioletrs/cocos/agent/statemachine"
|
||||
)
|
||||
|
||||
var cmp = Computation{
|
||||
Datasets: []Dataset{
|
||||
{
|
||||
Dataset: []byte("test"),
|
||||
UserKey: []byte("test"),
|
||||
},
|
||||
},
|
||||
type MockState int
|
||||
|
||||
type MockEvent int
|
||||
|
||||
func (s MockState) String() string {
|
||||
return []string{"State1", "State2", "State3"}[s]
|
||||
}
|
||||
|
||||
func TestStateMachineTransitions(t *testing.T) {
|
||||
cases := []struct {
|
||||
fromState State
|
||||
event event
|
||||
expected State
|
||||
cmp Computation
|
||||
}{
|
||||
{Idle, start, ReceivingManifest, cmp},
|
||||
{ReceivingManifest, manifestReceived, ReceivingAlgorithm, cmp},
|
||||
{ReceivingAlgorithm, algorithmReceived, ReceivingData, cmp},
|
||||
{ReceivingAlgorithm, algorithmReceived, Running, Computation{}},
|
||||
{ReceivingData, dataReceived, Running, cmp},
|
||||
{Running, runComplete, ConsumingResults, cmp},
|
||||
{ConsumingResults, resultsConsumed, Complete, cmp},
|
||||
func (e MockEvent) String() string {
|
||||
return []string{"Event1", "Event2", "Event3"}[e]
|
||||
}
|
||||
|
||||
const (
|
||||
State1 MockState = iota
|
||||
State2
|
||||
State3
|
||||
)
|
||||
|
||||
const (
|
||||
Event1 MockEvent = iota
|
||||
Event2
|
||||
Event3
|
||||
)
|
||||
|
||||
func TestNewStateMachine(t *testing.T) {
|
||||
sm := statemachine.NewStateMachine(State1)
|
||||
if sm == nil {
|
||||
t.Fatal("NewStateMachine returned nil")
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(fmt.Sprintf("Transition from %v to %v", tc.fromState, tc.expected), func(t *testing.T) {
|
||||
sm := NewStateMachine(mglog.NewMock(), tc.cmp)
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
go sm.Start(ctx)
|
||||
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
sm.SetState(tc.fromState)
|
||||
sm.SendEvent(tc.event)
|
||||
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
if sm.GetState() != tc.expected {
|
||||
t.Errorf("Expected state %v after the event, but got %v", tc.expected, sm.GetState())
|
||||
}
|
||||
})
|
||||
if sm.GetState() != State1 {
|
||||
t.Errorf("Initial state not set correctly, got %v, want %v", sm.GetState(), State1)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStateMachineInvalidTransition(t *testing.T) {
|
||||
sm := NewStateMachine(mglog.NewMock(), cmp)
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
func TestAddTransition(t *testing.T) {
|
||||
sm := statemachine.NewStateMachine(State1)
|
||||
sm.AddTransition(statemachine.Transition{From: State1, Event: Event1, To: State2})
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
||||
defer cancel()
|
||||
|
||||
go sm.Start(ctx)
|
||||
go func() {
|
||||
if err := sm.Start(ctx); err != context.Canceled {
|
||||
t.Errorf("Start returned error: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
sm.SendEvent(Event1)
|
||||
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
sm.SetState(Idle)
|
||||
sm.SendEvent(dataReceived)
|
||||
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
if sm.GetState() != Idle {
|
||||
t.Errorf("State should not change on an invalid event, but got %v", sm.GetState())
|
||||
if sm.GetState() != State2 {
|
||||
t.Errorf("Transition not applied correctly, got state %v, want %v", sm.GetState(), State2)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSetAction(t *testing.T) {
|
||||
sm := statemachine.NewStateMachine(State1)
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(1)
|
||||
|
||||
sm.SetAction(State2, func(s statemachine.State) {
|
||||
defer wg.Done()
|
||||
})
|
||||
|
||||
sm.AddTransition(statemachine.Transition{From: State1, Event: Event1, To: State2})
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
||||
defer cancel()
|
||||
|
||||
go func() {
|
||||
if err := sm.Start(ctx); err != context.Canceled {
|
||||
t.Errorf("Start returned error: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
sm.SendEvent(Event1)
|
||||
|
||||
wg.Wait()
|
||||
|
||||
if ctx.Err() != nil {
|
||||
t.Error("Action was not called within the expected time")
|
||||
}
|
||||
}
|
||||
|
||||
func TestInvalidTransition(t *testing.T) {
|
||||
sm := statemachine.NewStateMachine(State1)
|
||||
sm.AddTransition(statemachine.Transition{From: State1, Event: Event1, To: State2})
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
||||
defer cancel()
|
||||
|
||||
errChan := make(chan error)
|
||||
go func() {
|
||||
errChan <- sm.Start(ctx)
|
||||
}()
|
||||
|
||||
sm.SendEvent(Event2)
|
||||
|
||||
select {
|
||||
case err := <-errChan:
|
||||
if err == nil {
|
||||
t.Errorf("Expected invalid transition error, got: %v", err)
|
||||
}
|
||||
case <-time.After(150 * time.Millisecond):
|
||||
t.Error("Timeout waiting for invalid transition error")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMultipleTransitions(t *testing.T) {
|
||||
sm := statemachine.NewStateMachine(State1)
|
||||
sm.AddTransition(statemachine.Transition{From: State1, Event: Event1, To: State2})
|
||||
sm.AddTransition(statemachine.Transition{From: State2, Event: Event2, To: State3})
|
||||
sm.AddTransition(statemachine.Transition{From: State3, Event: Event3, To: State1})
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
|
||||
defer cancel()
|
||||
|
||||
go func() {
|
||||
if err := sm.Start(ctx); err != context.Canceled {
|
||||
t.Errorf("Start returned error: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
transitions := []struct {
|
||||
event MockEvent
|
||||
want MockState
|
||||
}{
|
||||
{Event1, State2},
|
||||
{Event2, State3},
|
||||
{Event3, State1},
|
||||
}
|
||||
|
||||
for _, tt := range transitions {
|
||||
sm.SendEvent(tt.event)
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
if sm.GetState() != tt.want {
|
||||
t.Errorf("After event %v, got state %v, want %v", tt.event, sm.GetState(), tt.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestConcurrency(t *testing.T) {
|
||||
sm := statemachine.NewStateMachine(State1)
|
||||
sm.AddTransition(statemachine.Transition{From: State1, Event: Event1, To: State2})
|
||||
sm.AddTransition(statemachine.Transition{From: State2, Event: Event2, To: State1})
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond)
|
||||
defer cancel()
|
||||
|
||||
go func() {
|
||||
if err := sm.Start(ctx); err == nil {
|
||||
t.Errorf("Expected context error, got nil")
|
||||
}
|
||||
}()
|
||||
|
||||
for i := 0; i < 100; i++ {
|
||||
go func() {
|
||||
sm.SendEvent(Event1)
|
||||
sm.SendEvent(Event2)
|
||||
}()
|
||||
}
|
||||
|
||||
time.Sleep(400 * time.Millisecond)
|
||||
|
||||
finalState := sm.GetState()
|
||||
if finalState != State1 && finalState != State2 {
|
||||
t.Errorf("Unexpected final state: %v", finalState)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,86 @@
|
||||
// Code generated by mockery v2.43.2. DO NOT EDIT.
|
||||
|
||||
// Copyright (c) Ultraviolet
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package mocks
|
||||
|
||||
import (
|
||||
context "context"
|
||||
|
||||
agent "github.com/ultravioletrs/cocos/agent/statemachine"
|
||||
|
||||
mock "github.com/stretchr/testify/mock"
|
||||
)
|
||||
|
||||
// StateMachine is an autogenerated mock type for the StateMachine type
|
||||
type StateMachine struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
// AddTransition provides a mock function with given fields: t
|
||||
func (_m *StateMachine) AddTransition(t agent.Transition) {
|
||||
_m.Called(t)
|
||||
}
|
||||
|
||||
// GetState provides a mock function with given fields:
|
||||
func (_m *StateMachine) GetState() agent.State {
|
||||
ret := _m.Called()
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for GetState")
|
||||
}
|
||||
|
||||
var r0 agent.State
|
||||
if rf, ok := ret.Get(0).(func() agent.State); ok {
|
||||
r0 = rf()
|
||||
} else {
|
||||
if ret.Get(0) != nil {
|
||||
r0 = ret.Get(0).(agent.State)
|
||||
}
|
||||
}
|
||||
|
||||
return r0
|
||||
}
|
||||
|
||||
// SendEvent provides a mock function with given fields: event
|
||||
func (_m *StateMachine) SendEvent(event agent.Event) {
|
||||
_m.Called(event)
|
||||
}
|
||||
|
||||
// SetAction provides a mock function with given fields: state, action
|
||||
func (_m *StateMachine) SetAction(state agent.State, action agent.Action) {
|
||||
_m.Called(state, action)
|
||||
}
|
||||
|
||||
// Start provides a mock function with given fields: ctx
|
||||
func (_m *StateMachine) Start(ctx context.Context) error {
|
||||
ret := _m.Called(ctx)
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for Start")
|
||||
}
|
||||
|
||||
var r0 error
|
||||
if rf, ok := ret.Get(0).(func(context.Context) error); ok {
|
||||
r0 = rf(ctx)
|
||||
} else {
|
||||
r0 = ret.Error(0)
|
||||
}
|
||||
|
||||
return r0
|
||||
}
|
||||
|
||||
// NewStateMachine creates a new instance of StateMachine. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations.
|
||||
// The first argument is typically a *testing.T value.
|
||||
func NewStateMachine(t interface {
|
||||
mock.TestingT
|
||||
Cleanup(func())
|
||||
}) *StateMachine {
|
||||
mock := &StateMachine{}
|
||||
mock.Mock.Test(t)
|
||||
|
||||
t.Cleanup(func() { mock.AssertExpectations(t) })
|
||||
|
||||
return mock
|
||||
}
|
||||
@@ -0,0 +1,113 @@
|
||||
// Copyright (c) Ultraviolet
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
package statemachine
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync"
|
||||
)
|
||||
|
||||
type State interface {
|
||||
String() string
|
||||
}
|
||||
|
||||
type Event interface {
|
||||
String() string
|
||||
}
|
||||
|
||||
type Action func(State)
|
||||
|
||||
type Transition struct {
|
||||
From State
|
||||
Event Event
|
||||
To State
|
||||
}
|
||||
|
||||
//go:generate mockery --name StateMachine --output=mocks --filename state.go --quiet --note "Copyright (c) Ultraviolet \n // SPDX-License-Identifier: Apache-2.0"
|
||||
type StateMachine interface {
|
||||
AddTransition(t Transition)
|
||||
SetAction(state State, action Action)
|
||||
GetState() State
|
||||
SendEvent(event Event)
|
||||
Start(ctx context.Context) error
|
||||
}
|
||||
|
||||
type stateMachine struct {
|
||||
mu sync.Mutex
|
||||
currentState State
|
||||
transitions map[State]map[Event]State
|
||||
actions map[State]Action
|
||||
eventChan chan Event
|
||||
}
|
||||
|
||||
func NewStateMachine(initialState State) StateMachine {
|
||||
return &stateMachine{
|
||||
currentState: initialState,
|
||||
transitions: make(map[State]map[Event]State),
|
||||
actions: make(map[State]Action),
|
||||
eventChan: make(chan Event),
|
||||
}
|
||||
}
|
||||
|
||||
func (sm *stateMachine) AddTransition(t Transition) {
|
||||
sm.mu.Lock()
|
||||
defer sm.mu.Unlock()
|
||||
|
||||
if _, ok := sm.transitions[t.From]; !ok {
|
||||
sm.transitions[t.From] = make(map[Event]State)
|
||||
}
|
||||
sm.transitions[t.From][t.Event] = t.To
|
||||
}
|
||||
|
||||
func (sm *stateMachine) SetAction(state State, action Action) {
|
||||
sm.mu.Lock()
|
||||
defer sm.mu.Unlock()
|
||||
|
||||
sm.actions[state] = action
|
||||
}
|
||||
|
||||
func (sm *stateMachine) GetState() State {
|
||||
sm.mu.Lock()
|
||||
defer sm.mu.Unlock()
|
||||
return sm.currentState
|
||||
}
|
||||
|
||||
func (sm *stateMachine) SendEvent(event Event) {
|
||||
sm.eventChan <- event
|
||||
}
|
||||
|
||||
func (sm *stateMachine) Start(ctx context.Context) error {
|
||||
for {
|
||||
select {
|
||||
case event := <-sm.eventChan:
|
||||
if err := sm.handleEvent(event); err != nil {
|
||||
return err
|
||||
}
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (sm *stateMachine) handleEvent(event Event) error {
|
||||
sm.mu.Lock()
|
||||
currentState := sm.currentState
|
||||
nextState, valid := sm.transitions[currentState][event]
|
||||
sm.mu.Unlock()
|
||||
|
||||
if !valid {
|
||||
return fmt.Errorf("invalid transition: %v -> %v", currentState, event)
|
||||
}
|
||||
|
||||
sm.mu.Lock()
|
||||
sm.currentState = nextState
|
||||
action := sm.actions[nextState]
|
||||
sm.mu.Unlock()
|
||||
|
||||
if action != nil {
|
||||
go action(nextState)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
+29
-2
@@ -4,13 +4,17 @@ package cli
|
||||
|
||||
import (
|
||||
"encoding/pem"
|
||||
"fmt"
|
||||
"os"
|
||||
|
||||
"github.com/fatih/color"
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
const resultFilePath = "results.zip"
|
||||
const (
|
||||
resultFilePrefix = "results"
|
||||
resultFileExt = ".zip"
|
||||
)
|
||||
|
||||
func (cli *CLI) NewResultsCmd() *cobra.Command {
|
||||
return &cobra.Command{
|
||||
@@ -42,12 +46,35 @@ func (cli *CLI) NewResultsCmd() *cobra.Command {
|
||||
return
|
||||
}
|
||||
|
||||
resultFilePath, err := getUniqueFilePath(resultFilePrefix, resultFileExt)
|
||||
if err != nil {
|
||||
printError(cmd, "Error generating unique file path: %v ❌ ", err)
|
||||
return
|
||||
}
|
||||
|
||||
if err := os.WriteFile(resultFilePath, result, 0o644); err != nil {
|
||||
printError(cmd, "Error saving computation result file: %v ❌ ", err)
|
||||
return
|
||||
}
|
||||
|
||||
cmd.Println(color.New(color.FgGreen).Sprint("Computation result retrieved and saved successfully! ✔ "))
|
||||
cmd.Println(color.New(color.FgGreen).Sprintf("Computation result retrieved and saved successfully as %s! ✔ ", resultFilePath))
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func getUniqueFilePath(prefix, ext string) (string, error) {
|
||||
for i := 0; ; i++ {
|
||||
var filename string
|
||||
if i == 0 {
|
||||
filename = prefix + ext
|
||||
} else {
|
||||
filename = fmt.Sprintf("%s_%d%s", prefix, i, ext)
|
||||
}
|
||||
|
||||
if _, err := os.Stat(filename); os.IsNotExist(err) {
|
||||
return filename, nil
|
||||
} else if err != nil {
|
||||
return "", err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
+72
-6
@@ -1,11 +1,14 @@
|
||||
// Copyright (c) Ultraviolet
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package cli
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/mock"
|
||||
@@ -32,12 +35,50 @@ func TestResultsCmd_Success(t *testing.T) {
|
||||
|
||||
require.Contains(t, buf.String(), "Computation result retrieved and saved successfully")
|
||||
|
||||
resultFile, err := os.ReadFile("results.zip")
|
||||
files, err := filepath.Glob("results*.zip")
|
||||
require.NoError(t, err)
|
||||
require.Len(t, files, 1)
|
||||
|
||||
resultFile, err := os.ReadFile(files[0])
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, compResult, string(resultFile))
|
||||
|
||||
t.Cleanup(func() {
|
||||
os.Remove("results.zip")
|
||||
for _, file := range files {
|
||||
os.Remove(file)
|
||||
}
|
||||
os.Remove(privateKeyFile)
|
||||
})
|
||||
}
|
||||
|
||||
func TestResultsCmd_MultipleExecutions(t *testing.T) {
|
||||
mockSDK := new(mocks.SDK)
|
||||
mockSDK.On("Result", mock.Anything, mock.Anything).Return([]byte(compResult), nil)
|
||||
testCLI := New(mockSDK)
|
||||
|
||||
err := generateRSAPrivateKeyFile(privateKeyFile)
|
||||
require.NoError(t, err)
|
||||
|
||||
cmd := testCLI.NewResultsCmd()
|
||||
buf := new(bytes.Buffer)
|
||||
cmd.SetOut(buf)
|
||||
cmd.SetArgs([]string{privateKeyFile})
|
||||
|
||||
for i := 0; i < 3; i++ {
|
||||
err = cmd.Execute()
|
||||
require.NoError(t, err)
|
||||
require.Contains(t, buf.String(), "Computation result retrieved and saved successfully")
|
||||
buf.Reset()
|
||||
}
|
||||
|
||||
files, err := filepath.Glob("results*.zip")
|
||||
require.NoError(t, err)
|
||||
require.Len(t, files, 3)
|
||||
|
||||
t.Cleanup(func() {
|
||||
for _, file := range files {
|
||||
os.Remove(file)
|
||||
}
|
||||
os.Remove(privateKeyFile)
|
||||
})
|
||||
}
|
||||
@@ -87,8 +128,8 @@ func TestResultsCmd_SaveFailure(t *testing.T) {
|
||||
err := generateRSAPrivateKeyFile(privateKeyFile)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Simulate failure in saving the result file by making a directory with the same name as the result file
|
||||
err = os.Mkdir("results.zip", 0o755)
|
||||
// Simulate failure in saving the result file by making all files read-only
|
||||
err = os.Chmod(".", 0o555)
|
||||
require.NoError(t, err)
|
||||
|
||||
cmd := testCLI.NewResultsCmd()
|
||||
@@ -102,8 +143,10 @@ func TestResultsCmd_SaveFailure(t *testing.T) {
|
||||
mockSDK.AssertCalled(t, "Result", mock.Anything, mock.Anything)
|
||||
|
||||
t.Cleanup(func() {
|
||||
os.Remove("results.zip")
|
||||
os.Remove(privateKeyFile)
|
||||
err := os.Chmod(".", 0o755)
|
||||
require.NoError(t, err)
|
||||
err = os.Remove(privateKeyFile)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
@@ -132,3 +175,26 @@ func TestResultsCmd_InvalidPrivateKey(t *testing.T) {
|
||||
require.Contains(t, buf.String(), "Error decoding private key")
|
||||
mockSDK.AssertNotCalled(t, "Result", mock.Anything, mock.Anything)
|
||||
}
|
||||
|
||||
func TestGetUniqueFilePath(t *testing.T) {
|
||||
prefix := "test"
|
||||
ext := ".txt"
|
||||
|
||||
path, err := getUniqueFilePath(prefix, ext)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "test.txt", path)
|
||||
|
||||
_, err = os.Create("test.txt")
|
||||
require.NoError(t, err)
|
||||
defer os.Remove("test.txt")
|
||||
for i := 1; i < 3; i++ {
|
||||
fileName := fmt.Sprintf("%s_%d%s", prefix, i, ext)
|
||||
_, err := os.Create(fileName)
|
||||
require.NoError(t, err)
|
||||
defer os.Remove(fileName)
|
||||
}
|
||||
|
||||
path, err = getUniqueFilePath(prefix, ext)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "test_3.txt", path)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user