mirror of
https://github.com/ultravioletrs/cocos.git
synced 2026-06-23 04:10:25 +00:00
COCOS-168 - Allow running Computations without datasets (#175)
* feat(agent): Allow empty dataset Allow running of algorithm with empty dataset since not all algorithms require datasets. Allow state-machine transition from algo-received state to running state incase of no dataset provided Fixes https://github.com/ultravioletrs/cocos/issues/168 Signed-off-by: Rodney Osodo <socials@rodneyosodo.com> * chore(gitignore): Remove build artefacts Signed-off-by: Rodney Osodo <socials@rodneyosodo.com> * feat(algorithms): Add test algorithm for addition Signed-off-by: Rodney Osodo <socials@rodneyosodo.com> * refactor(addition): Modify addition algo to one file Signed-off-by: Rodney Osodo <socials@rodneyosodo.com> * fix(agent): move state transition to callback func Move state transition from `receivingAlgorithm` to `running` to state call back function Signed-off-by: Rodney Osodo <socials@rodneyosodo.com> * feat(agent-event): Add `algoReceivedNoData` event `algoReceivedNoData` is an event that is sent if we receive an algorithm and it should not have a dataset hence changes the state from `receivingAlgorithm` to `running` * fix(agent-state): Change state depending on manifest Change state from `receivingAlgorithm` to either `receivingData` if there is a dataset or `running` if there is no dataset provided Signed-off-by: Rodney Osodo <socials@rodneyosodo.com> --------- Signed-off-by: Rodney Osodo <socials@rodneyosodo.com>
This commit is contained in:
@@ -8,3 +8,7 @@ cmd/manager/tmp
|
||||
.cov
|
||||
|
||||
*.pem
|
||||
|
||||
dist/
|
||||
result.bin
|
||||
*.spec
|
||||
|
||||
+1
-1
@@ -73,7 +73,7 @@ var _ Service = (*agentService)(nil)
|
||||
// New instantiates the agent service implementation.
|
||||
func New(ctx context.Context, logger *slog.Logger, eventSvc events.Service, cmp Computation) Service {
|
||||
svc := &agentService{
|
||||
sm: NewStateMachine(logger),
|
||||
sm: NewStateMachine(logger, cmp),
|
||||
eventSvc: eventSvc,
|
||||
}
|
||||
|
||||
|
||||
+9
-4
@@ -10,7 +10,7 @@ import (
|
||||
)
|
||||
|
||||
//go:generate stringer -type=state
|
||||
type state int
|
||||
type state uint8
|
||||
|
||||
const (
|
||||
idle state = iota
|
||||
@@ -22,7 +22,7 @@ const (
|
||||
complete
|
||||
)
|
||||
|
||||
type event int
|
||||
type event uint8
|
||||
|
||||
const (
|
||||
start event = iota
|
||||
@@ -45,7 +45,7 @@ type StateMachine struct {
|
||||
}
|
||||
|
||||
// NewStateMachine creates a new StateMachine.
|
||||
func NewStateMachine(logger *slog.Logger) *StateMachine {
|
||||
func NewStateMachine(logger *slog.Logger, cmp Computation) *StateMachine {
|
||||
sm := &StateMachine{
|
||||
State: idle,
|
||||
EventChan: make(chan event),
|
||||
@@ -62,7 +62,12 @@ func NewStateMachine(logger *slog.Logger) *StateMachine {
|
||||
sm.Transitions[receivingManifest][manifestReceived] = receivingAlgorithm
|
||||
|
||||
sm.Transitions[receivingAlgorithm] = make(map[event]state)
|
||||
sm.Transitions[receivingAlgorithm][algorithmReceived] = receivingData
|
||||
switch len(cmp.Datasets) {
|
||||
case 0:
|
||||
sm.Transitions[receivingAlgorithm][algorithmReceived] = running
|
||||
default:
|
||||
sm.Transitions[receivingAlgorithm][algorithmReceived] = receivingData
|
||||
}
|
||||
|
||||
sm.Transitions[receivingData] = make(map[event]state)
|
||||
sm.Transitions[receivingData][dataReceived] = running
|
||||
|
||||
+26
-15
@@ -10,34 +10,45 @@ import (
|
||||
mglog "github.com/absmach/magistrala/logger"
|
||||
)
|
||||
|
||||
var cmp = Computation{
|
||||
Datasets: []Dataset{
|
||||
{
|
||||
Dataset: []byte("test"),
|
||||
UserKey: []byte("test"),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
func TestStateMachineTransitions(t *testing.T) {
|
||||
testCases := []struct {
|
||||
cases := []struct {
|
||||
fromState state
|
||||
event event
|
||||
expected state
|
||||
cmp Computation
|
||||
}{
|
||||
{idle, start, receivingManifest},
|
||||
{receivingManifest, manifestReceived, receivingAlgorithm},
|
||||
{receivingAlgorithm, algorithmReceived, receivingData},
|
||||
{receivingData, dataReceived, running},
|
||||
{running, runComplete, resultsReady},
|
||||
{resultsReady, resultsConsumed, complete},
|
||||
{idle, start, receivingManifest, cmp},
|
||||
{receivingManifest, manifestReceived, receivingAlgorithm, cmp},
|
||||
{receivingAlgorithm, algorithmReceived, receivingData, cmp},
|
||||
{receivingAlgorithm, algorithmReceived, running, Computation{}},
|
||||
{receivingData, dataReceived, running, cmp},
|
||||
{running, runComplete, resultsReady, cmp},
|
||||
{resultsReady, resultsConsumed, complete, cmp},
|
||||
}
|
||||
|
||||
for _, testCase := range testCases {
|
||||
t.Run(fmt.Sprintf("Transition from %v to %v", testCase.fromState, testCase.expected), func(t *testing.T) {
|
||||
sm := NewStateMachine(mglog.NewMock())
|
||||
for _, tc := range cases {
|
||||
t.Run(fmt.Sprintf("Transition from %v to %v", tc.fromState, tc.expected), func(t *testing.T) {
|
||||
sm := NewStateMachine(mglog.NewMock(), tc.cmp)
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
go func() {
|
||||
sm.Start(ctx)
|
||||
}()
|
||||
sm.wg.Wait()
|
||||
sm.SetState(testCase.fromState)
|
||||
sm.SetState(tc.fromState)
|
||||
|
||||
sm.SendEvent(testCase.event)
|
||||
sm.SendEvent(tc.event)
|
||||
|
||||
if sm.GetState() != testCase.expected {
|
||||
t.Errorf("Expected state %v after the event, but got %v", testCase.expected, sm.GetState())
|
||||
if sm.GetState() != tc.expected {
|
||||
t.Errorf("Expected state %v after the event, but got %v", tc.expected, sm.GetState())
|
||||
}
|
||||
close(sm.EventChan)
|
||||
cancel()
|
||||
@@ -46,7 +57,7 @@ func TestStateMachineTransitions(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestStateMachineInvalidTransition(t *testing.T) {
|
||||
sm := NewStateMachine(mglog.NewMock())
|
||||
sm := NewStateMachine(mglog.NewMock(), cmp)
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
go sm.Start(ctx)
|
||||
|
||||
|
||||
@@ -49,26 +49,34 @@ func (s *svc) Run(ipAdress string, reqChan chan *manager.ServerStreamMessage, au
|
||||
s.logger.Error(fmt.Sprintf("failed to read algorithm file: %s", err))
|
||||
return
|
||||
}
|
||||
data, err := os.ReadFile(dataPath)
|
||||
if err != nil {
|
||||
s.logger.Error(fmt.Sprintf("failed to read data file: %s", err))
|
||||
return
|
||||
}
|
||||
|
||||
pubKey, err := os.ReadFile(pubKeyFile)
|
||||
if err != nil {
|
||||
s.logger.Error(fmt.Sprintf("failed to read public key file: %s", err))
|
||||
return
|
||||
}
|
||||
pubPem, _ := pem.Decode(pubKey)
|
||||
|
||||
var dataset []*manager.Dataset
|
||||
if dataPath != "" {
|
||||
data, err := os.ReadFile(dataPath)
|
||||
if err != nil {
|
||||
s.logger.Error(fmt.Sprintf("failed to read data file: %s", err))
|
||||
return
|
||||
}
|
||||
dataHash := sha3.Sum256(data)
|
||||
|
||||
dataset = []*manager.Dataset{{Hash: dataHash[:], UserKey: pubPem.Bytes}}
|
||||
}
|
||||
|
||||
algoHash := sha3.Sum256(algo)
|
||||
dataHash := sha3.Sum256(data)
|
||||
reqChan <- &manager.ServerStreamMessage{
|
||||
Message: &manager.ServerStreamMessage_RunReq{
|
||||
RunReq: &manager.ComputationRunReq{
|
||||
Id: "1",
|
||||
Name: "sample computation",
|
||||
Description: "sample descrption",
|
||||
Datasets: []*manager.Dataset{{Hash: dataHash[:], UserKey: pubPem.Bytes}},
|
||||
Datasets: dataset,
|
||||
Algorithm: &manager.Algorithm{Hash: algoHash[:], UserKey: pubPem.Bytes},
|
||||
ResultConsumers: []*manager.ResultConsumer{{UserKey: pubPem.Bytes}},
|
||||
AgentConfig: &manager.AgentConfig{
|
||||
|
||||
@@ -0,0 +1,76 @@
|
||||
import sys, io
|
||||
import joblib
|
||||
import socket
|
||||
|
||||
class Computation:
|
||||
result = 0
|
||||
def __init__(self):
|
||||
"""
|
||||
Initializes a new instance of the Computation class.
|
||||
"""
|
||||
pass
|
||||
|
||||
def compute(self, a, b):
|
||||
"""
|
||||
Computes the sum of two numbers.
|
||||
"""
|
||||
self.result = a + b
|
||||
|
||||
def send_result(self, socket_path):
|
||||
"""
|
||||
Sends the result to a socket.
|
||||
"""
|
||||
buffer = io.BytesIO()
|
||||
|
||||
try:
|
||||
joblib.dump(self.result, buffer)
|
||||
except Exception as e:
|
||||
print("Failed to dump the result to the buffer: ", e)
|
||||
return
|
||||
|
||||
data = buffer.getvalue()
|
||||
|
||||
client = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
|
||||
try:
|
||||
try:
|
||||
client.connect(socket_path)
|
||||
except Exception as e:
|
||||
print("Failed to connect to the socket: ", e)
|
||||
return
|
||||
try:
|
||||
client.send(data)
|
||||
except Exception as e:
|
||||
print("Failed to send data to the socket: ", e)
|
||||
return
|
||||
finally:
|
||||
client.close()
|
||||
|
||||
def read_results_from_file(self, results_file):
|
||||
"""
|
||||
Reads the results from a file.
|
||||
"""
|
||||
try:
|
||||
results = joblib.load(results_file)
|
||||
print("Results: ", results)
|
||||
except Exception as e:
|
||||
print("Failed to load results from file: ", e)
|
||||
return
|
||||
|
||||
if __name__ == "__main__":
|
||||
a = 5
|
||||
b = 10
|
||||
computation = Computation()
|
||||
|
||||
if len(sys.argv) == 1:
|
||||
print("Please provide a socket path or a file path")
|
||||
exit(1)
|
||||
|
||||
if sys.argv[1] == "test" and len(sys.argv) == 3:
|
||||
computation.read_results_from_file(sys.argv[2])
|
||||
elif len(sys.argv) == 2:
|
||||
computation.compute(a, b)
|
||||
computation.send_result(sys.argv[1])
|
||||
else:
|
||||
print("Invalid arguments")
|
||||
exit(1)
|
||||
|
||||
Reference in New Issue
Block a user