mirror of
https://github.com/ultravioletrs/cocos.git
synced 2026-06-23 04:10:25 +00:00
COCOS-152 - Refactor algorithm execution and add visibility through logging and events (#176)
* * feat(agent): add support for binary algorithm execution * * feat(agent/algorithm): add Algorithm interface and binary implementation * * feat(agent/algorithm/binary): implement Run method for binary algorithm execution * * feat(agent/algorithm/logging): implement Stdout and Stderr writers for algorithm logging * * feat(agent/algorithm/logging_test): add tests for Stdout and Stderr writers * * feat(agent/events): add Service interface for sending events * * feat(agent/events/mocks): add mock implementation for Service interface * * refactor(agent/service): update runComputation method to use binary algorithm implementation Signed-off-by: SammyOina <sammyoina@gmail.com> * * fix(logging.go): handle error when sending event in Write method of Stderr struct * test(logging_test.go): add copyright header * fix(backend_info.go): add missing type declaration in function signature * fix(agent.go): rename progressbar variable to pb for clarity and consistency Signed-off-by: SammyOina <sammyoina@gmail.com> --------- Signed-off-by: SammyOina <sammyoina@gmail.com>
This commit is contained in:
committed by
GitHub
parent
dc16e8a997
commit
2ceb1c3562
@@ -0,0 +1,9 @@
|
||||
// Copyright (c) Ultraviolet
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
package algorithm
|
||||
|
||||
// Algorithm is an interface that specifies the API for an algorithm.
|
||||
type Algorithm interface {
|
||||
// Run executes the algorithm and returns the result.
|
||||
Run() ([]byte, error)
|
||||
}
|
||||
@@ -0,0 +1,79 @@
|
||||
// Copyright (c) Ultraviolet
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
package binary
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"os"
|
||||
"os/exec"
|
||||
|
||||
"github.com/ultravioletrs/cocos/agent/algorithm"
|
||||
"github.com/ultravioletrs/cocos/agent/events"
|
||||
"github.com/ultravioletrs/cocos/pkg/socket"
|
||||
)
|
||||
|
||||
const socketPath = "unix_socket"
|
||||
|
||||
var _ algorithm.Algorithm = (*binary)(nil)
|
||||
|
||||
type binary struct {
|
||||
algoFile string
|
||||
datasets []string
|
||||
logger *slog.Logger
|
||||
stderr io.Writer
|
||||
stdout io.Writer
|
||||
}
|
||||
|
||||
func New(logger *slog.Logger, eventsSvc events.Service, algoFile string, datasets ...string) algorithm.Algorithm {
|
||||
return &binary{
|
||||
algoFile: algoFile,
|
||||
datasets: datasets,
|
||||
logger: logger,
|
||||
stderr: &algorithm.Stderr{Logger: logger, EventSvc: eventsSvc},
|
||||
stdout: &algorithm.Stdout{Logger: logger},
|
||||
}
|
||||
}
|
||||
|
||||
func (b *binary) Run() ([]byte, error) {
|
||||
defer os.Remove(b.algoFile)
|
||||
defer func() {
|
||||
for _, file := range b.datasets {
|
||||
os.Remove(file)
|
||||
}
|
||||
}()
|
||||
listener, err := socket.StartUnixSocketServer(socketPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error creating stdout pipe: %v", err)
|
||||
}
|
||||
defer listener.Close()
|
||||
|
||||
// Create channels for received data and errors
|
||||
dataChannel := make(chan []byte)
|
||||
errorChannel := make(chan error)
|
||||
|
||||
var result []byte
|
||||
|
||||
go socket.AcceptConnection(listener, dataChannel, errorChannel)
|
||||
|
||||
args := append([]string{socketPath}, b.datasets...)
|
||||
cmd := exec.Command(b.algoFile, args...)
|
||||
cmd.Stderr = b.stderr
|
||||
cmd.Stdout = b.stdout
|
||||
|
||||
if err := cmd.Start(); err != nil {
|
||||
return nil, fmt.Errorf("error starting algorithm: %v", err)
|
||||
}
|
||||
|
||||
if err := cmd.Wait(); err != nil {
|
||||
return nil, fmt.Errorf("algorithm execution error: %v", err)
|
||||
}
|
||||
|
||||
select {
|
||||
case result = <-dataChannel:
|
||||
return result, nil
|
||||
case err = <-errorChannel:
|
||||
return nil, fmt.Errorf("error receiving data: %v", err)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,73 @@
|
||||
// Copyright (c) Ultraviolet
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
package algorithm
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
"log/slog"
|
||||
|
||||
"github.com/ultravioletrs/cocos/agent/events"
|
||||
)
|
||||
|
||||
var (
|
||||
_ io.Writer = &Stdout{}
|
||||
_ io.Writer = &Stderr{}
|
||||
)
|
||||
|
||||
const bufSize = 1024
|
||||
|
||||
type Stdout struct {
|
||||
Logger *slog.Logger
|
||||
}
|
||||
|
||||
// Write implements io.Writer.
|
||||
func (s *Stdout) Write(p []byte) (n int, err error) {
|
||||
inBuf := bytes.NewBuffer(p)
|
||||
|
||||
buf := make([]byte, bufSize)
|
||||
|
||||
for {
|
||||
n, err := inBuf.Read(buf)
|
||||
if err != nil {
|
||||
if err == io.EOF {
|
||||
break
|
||||
}
|
||||
return len(p) - inBuf.Len(), err
|
||||
}
|
||||
|
||||
s.Logger.Debug(string(buf[:n]))
|
||||
}
|
||||
|
||||
return len(p), nil
|
||||
}
|
||||
|
||||
type Stderr struct {
|
||||
Logger *slog.Logger
|
||||
EventSvc events.Service
|
||||
}
|
||||
|
||||
// Write implements io.Writer.
|
||||
func (s *Stderr) Write(p []byte) (n int, err error) {
|
||||
inBuf := bytes.NewBuffer(p)
|
||||
|
||||
buf := make([]byte, bufSize)
|
||||
|
||||
for {
|
||||
n, err := inBuf.Read(buf)
|
||||
if err != nil {
|
||||
if err == io.EOF {
|
||||
break
|
||||
}
|
||||
return len(p) - inBuf.Len(), err
|
||||
}
|
||||
|
||||
s.Logger.Error(string(buf[:n]))
|
||||
}
|
||||
|
||||
if err := s.EventSvc.SendEvent("algorithm-run", "failed", nil); err != nil {
|
||||
return len(p), err
|
||||
}
|
||||
|
||||
return len(p), nil
|
||||
}
|
||||
@@ -0,0 +1,85 @@
|
||||
// Copyright (c) Ultraviolet
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
package algorithm
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
mglog "github.com/absmach/magistrala/logger"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/mock"
|
||||
"github.com/ultravioletrs/cocos/agent/events/mocks"
|
||||
)
|
||||
|
||||
func TestStdoutWrite(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
expected []string
|
||||
}{
|
||||
{
|
||||
name: "Single line",
|
||||
input: "Hello, World!",
|
||||
expected: []string{"Hello, World!"},
|
||||
},
|
||||
{
|
||||
name: "Multiple lines",
|
||||
input: "Line 1\nLine 2\nLine 3",
|
||||
expected: []string{"Line 1\nLine 2\nLine 3"},
|
||||
},
|
||||
{
|
||||
name: "Long input",
|
||||
input: strings.Repeat("a", bufSize+100),
|
||||
expected: []string{strings.Repeat("a", bufSize), strings.Repeat("a", 100)},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
stdout := &Stdout{Logger: mglog.NewMock()}
|
||||
n, err := stdout.Write([]byte(tt.input))
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, len(tt.input), n)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestStderrWrite(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
expected []string
|
||||
}{
|
||||
{
|
||||
name: "Single line",
|
||||
input: "Error: Something went wrong",
|
||||
expected: []string{"Error: Something went wrong"},
|
||||
},
|
||||
{
|
||||
name: "Multiple lines",
|
||||
input: "Error 1\nError 2\nError 3",
|
||||
expected: []string{"Error 1\nError 2\nError 3"},
|
||||
},
|
||||
{
|
||||
name: "Long input",
|
||||
input: strings.Repeat("e", bufSize+100),
|
||||
expected: []string{strings.Repeat("e", bufSize), strings.Repeat("e", 100)},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
mockEventService := mocks.NewService(t)
|
||||
mockEventService.On("SendEvent", "algorithm-run", "failed", mock.Anything).Return(nil)
|
||||
|
||||
stderr := &Stderr{Logger: mglog.NewMock(), EventSvc: mockEventService}
|
||||
n, err := stderr.Write([]byte(tt.input))
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, len(tt.input), n)
|
||||
mockEventService.AssertExpectations(t)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -27,6 +27,7 @@ type AgentEvent struct {
|
||||
Status string `json:"status,omitempty"`
|
||||
}
|
||||
|
||||
//go:generate mockery --name Service --output=./mocks --filename events.go --quiet --note "Copyright (c) Ultraviolet \n // SPDX-License-Identifier: Apache-2.0"
|
||||
type Service interface {
|
||||
SendEvent(event, status string, details json.RawMessage) error
|
||||
Close() error
|
||||
|
||||
@@ -0,0 +1,67 @@
|
||||
// Code generated by mockery v2.43.2. DO NOT EDIT.
|
||||
|
||||
// Copyright (c) Ultraviolet
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package mocks
|
||||
|
||||
import (
|
||||
json "encoding/json"
|
||||
|
||||
mock "github.com/stretchr/testify/mock"
|
||||
)
|
||||
|
||||
// Service is an autogenerated mock type for the Service type
|
||||
type Service struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
// Close provides a mock function with given fields:
|
||||
func (_m *Service) Close() error {
|
||||
ret := _m.Called()
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for Close")
|
||||
}
|
||||
|
||||
var r0 error
|
||||
if rf, ok := ret.Get(0).(func() error); ok {
|
||||
r0 = rf()
|
||||
} else {
|
||||
r0 = ret.Error(0)
|
||||
}
|
||||
|
||||
return r0
|
||||
}
|
||||
|
||||
// SendEvent provides a mock function with given fields: event, status, details
|
||||
func (_m *Service) SendEvent(event string, status string, details json.RawMessage) error {
|
||||
ret := _m.Called(event, status, details)
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for SendEvent")
|
||||
}
|
||||
|
||||
var r0 error
|
||||
if rf, ok := ret.Get(0).(func(string, string, json.RawMessage) error); ok {
|
||||
r0 = rf(event, status, details)
|
||||
} else {
|
||||
r0 = ret.Error(0)
|
||||
}
|
||||
|
||||
return r0
|
||||
}
|
||||
|
||||
// NewService creates a new instance of Service. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations.
|
||||
// The first argument is typically a *testing.T value.
|
||||
func NewService(t interface {
|
||||
mock.TestingT
|
||||
Cleanup(func())
|
||||
}) *Service {
|
||||
mock := &Service{}
|
||||
mock.Mock.Test(t)
|
||||
|
||||
t.Cleanup(func() { mock.AssertExpectations(t) })
|
||||
|
||||
return mock
|
||||
}
|
||||
+3
-51
@@ -4,19 +4,17 @@
|
||||
package agent
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"os"
|
||||
"os/exec"
|
||||
"slices"
|
||||
|
||||
"github.com/google/go-sev-guest/client"
|
||||
"github.com/ultravioletrs/cocos/agent/algorithm/binary"
|
||||
"github.com/ultravioletrs/cocos/agent/events"
|
||||
"github.com/ultravioletrs/cocos/pkg/socket"
|
||||
"golang.org/x/crypto/sha3"
|
||||
)
|
||||
|
||||
@@ -25,7 +23,6 @@ var _ Service = (*agentService)(nil)
|
||||
const (
|
||||
// ReportDataSize is the size of the report data expected by the attestation service.
|
||||
ReportDataSize = 64
|
||||
socketPath = "unix_socket"
|
||||
algoFilePermission = 0o700
|
||||
)
|
||||
|
||||
@@ -212,7 +209,8 @@ func (as *agentService) runComputation() {
|
||||
as.sm.logger.Debug("computation run started")
|
||||
defer as.sm.SendEvent(runComplete)
|
||||
as.publishEvent("in-progress", json.RawMessage{})()
|
||||
result, err := as.run(as.algorithm, as.datasets)
|
||||
algorithm := binary.New(as.sm.logger, as.eventSvc, as.algorithm, as.datasets...)
|
||||
result, err := algorithm.Run()
|
||||
if err != nil {
|
||||
as.runError = err
|
||||
as.sm.logger.Warn(fmt.Sprintf("computation failed with error: %s", err.Error()))
|
||||
@@ -230,49 +228,3 @@ func (as *agentService) publishEvent(status string, details json.RawMessage) fun
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (as *agentService) run(algoFile string, dataFiles []string) ([]byte, error) {
|
||||
defer os.Remove(algoFile)
|
||||
defer func() {
|
||||
for _, file := range dataFiles {
|
||||
os.Remove(file)
|
||||
}
|
||||
}()
|
||||
listener, err := socket.StartUnixSocketServer(socketPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error creating stdout pipe: %v", err)
|
||||
}
|
||||
defer listener.Close()
|
||||
|
||||
// Create channels for received data and errors
|
||||
dataChannel := make(chan []byte)
|
||||
errorChannel := make(chan error)
|
||||
|
||||
var result []byte
|
||||
|
||||
var outStd, outErr bytes.Buffer
|
||||
|
||||
go socket.AcceptConnection(listener, dataChannel, errorChannel)
|
||||
|
||||
args := append([]string{socketPath}, dataFiles...)
|
||||
cmd := exec.Command(algoFile, args...)
|
||||
cmd.Stderr = &outErr
|
||||
cmd.Stdout = &outStd
|
||||
|
||||
if err := cmd.Start(); err != nil {
|
||||
return nil, fmt.Errorf("error starting algorithm: %v", err)
|
||||
}
|
||||
|
||||
if err := cmd.Wait(); err != nil {
|
||||
as.sm.logger.Debug(outErr.String())
|
||||
return nil, fmt.Errorf("algorithm execution error: %v", err)
|
||||
}
|
||||
|
||||
select {
|
||||
case result = <-dataChannel:
|
||||
as.sm.logger.Debug(outStd.String())
|
||||
return result, nil
|
||||
case err = <-errorChannel:
|
||||
return nil, fmt.Errorf("error receiving data: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
+1
-1
@@ -103,7 +103,7 @@ func (cli *CLI) NewAddHostDataCmd() *cobra.Command {
|
||||
}
|
||||
}
|
||||
|
||||
func changeAttestationConfiguration(fileName string, base64Data string, expectedLength int, field fieldType) error {
|
||||
func changeAttestationConfiguration(fileName, base64Data string, expectedLength int, field fieldType) error {
|
||||
data, err := base64.StdEncoding.DecodeString(base64Data)
|
||||
if err != nil {
|
||||
return errDecode
|
||||
|
||||
+4
-4
@@ -61,8 +61,8 @@ func (sdk *agentSDK) Algo(ctx context.Context, algorithm agent.Algorithm, privKe
|
||||
}
|
||||
algoBuffer := bytes.NewBuffer(algorithm.Algorithm)
|
||||
|
||||
progressbar := progressbar.New()
|
||||
if err := progressbar.SendAlgorithm(algoProgressBarDescription, algoBuffer, &stream); err != nil {
|
||||
pb := progressbar.New()
|
||||
if err := pb.SendAlgorithm(algoProgressBarDescription, algoBuffer, &stream); err != nil {
|
||||
sdk.logger.Error("Failed to send Algorithm")
|
||||
return err
|
||||
}
|
||||
@@ -85,8 +85,8 @@ func (sdk *agentSDK) Data(ctx context.Context, dataset agent.Dataset, privKey an
|
||||
}
|
||||
dataBuffer := bytes.NewBuffer(dataset.Dataset)
|
||||
|
||||
progressbar := progressbar.New()
|
||||
if err := progressbar.SendData(dataProgressBarDescription, dataBuffer, &stream); err != nil {
|
||||
pb := progressbar.New()
|
||||
if err := pb.SendData(dataProgressBarDescription, dataBuffer, &stream); err != nil {
|
||||
sdk.logger.Error("Failed to send Data")
|
||||
return err
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user