COCOS-152 - Refactor algorithm execution and add visibility through logging and events (#176)

* * feat(agent): add support for binary algorithm execution
*
* feat(agent/algorithm): add Algorithm interface and binary implementation
*
* feat(agent/algorithm/binary): implement Run method for binary algorithm execution
*
* feat(agent/algorithm/logging): implement Stdout and Stderr writers for algorithm logging
*
* feat(agent/algorithm/logging_test): add tests for Stdout and Stderr writers
*
* feat(agent/events): add Service interface for sending events
*
* feat(agent/events/mocks): add mock implementation for Service interface
*
* refactor(agent/service): update runComputation method to use binary algorithm implementation

Signed-off-by: SammyOina <sammyoina@gmail.com>

* * fix(logging.go): handle error when sending event in Write method of Stderr struct
* test(logging_test.go): add copyright header
* fix(backend_info.go): add missing type declaration in function signature
* fix(agent.go): rename progressbar variable to pb for clarity and consistency

Signed-off-by: SammyOina <sammyoina@gmail.com>

---------

Signed-off-by: SammyOina <sammyoina@gmail.com>
This commit is contained in:
Sammy Kerata Oina
2024-07-11 13:24:19 +03:00
committed by GitHub
parent dc16e8a997
commit 2ceb1c3562
9 changed files with 322 additions and 56 deletions
+9
View File
@@ -0,0 +1,9 @@
// Copyright (c) Ultraviolet
// SPDX-License-Identifier: Apache-2.0
package algorithm
// Algorithm is an interface that specifies the API for an algorithm.
type Algorithm interface {
// Run executes the algorithm and returns the result.
Run() ([]byte, error)
}
+79
View File
@@ -0,0 +1,79 @@
// Copyright (c) Ultraviolet
// SPDX-License-Identifier: Apache-2.0
package binary
import (
"fmt"
"io"
"log/slog"
"os"
"os/exec"
"github.com/ultravioletrs/cocos/agent/algorithm"
"github.com/ultravioletrs/cocos/agent/events"
"github.com/ultravioletrs/cocos/pkg/socket"
)
const socketPath = "unix_socket"
var _ algorithm.Algorithm = (*binary)(nil)
type binary struct {
algoFile string
datasets []string
logger *slog.Logger
stderr io.Writer
stdout io.Writer
}
func New(logger *slog.Logger, eventsSvc events.Service, algoFile string, datasets ...string) algorithm.Algorithm {
return &binary{
algoFile: algoFile,
datasets: datasets,
logger: logger,
stderr: &algorithm.Stderr{Logger: logger, EventSvc: eventsSvc},
stdout: &algorithm.Stdout{Logger: logger},
}
}
func (b *binary) Run() ([]byte, error) {
defer os.Remove(b.algoFile)
defer func() {
for _, file := range b.datasets {
os.Remove(file)
}
}()
listener, err := socket.StartUnixSocketServer(socketPath)
if err != nil {
return nil, fmt.Errorf("error creating stdout pipe: %v", err)
}
defer listener.Close()
// Create channels for received data and errors
dataChannel := make(chan []byte)
errorChannel := make(chan error)
var result []byte
go socket.AcceptConnection(listener, dataChannel, errorChannel)
args := append([]string{socketPath}, b.datasets...)
cmd := exec.Command(b.algoFile, args...)
cmd.Stderr = b.stderr
cmd.Stdout = b.stdout
if err := cmd.Start(); err != nil {
return nil, fmt.Errorf("error starting algorithm: %v", err)
}
if err := cmd.Wait(); err != nil {
return nil, fmt.Errorf("algorithm execution error: %v", err)
}
select {
case result = <-dataChannel:
return result, nil
case err = <-errorChannel:
return nil, fmt.Errorf("error receiving data: %v", err)
}
}
+73
View File
@@ -0,0 +1,73 @@
// Copyright (c) Ultraviolet
// SPDX-License-Identifier: Apache-2.0
package algorithm
import (
"bytes"
"io"
"log/slog"
"github.com/ultravioletrs/cocos/agent/events"
)
var (
_ io.Writer = &Stdout{}
_ io.Writer = &Stderr{}
)
const bufSize = 1024
type Stdout struct {
Logger *slog.Logger
}
// Write implements io.Writer.
func (s *Stdout) Write(p []byte) (n int, err error) {
inBuf := bytes.NewBuffer(p)
buf := make([]byte, bufSize)
for {
n, err := inBuf.Read(buf)
if err != nil {
if err == io.EOF {
break
}
return len(p) - inBuf.Len(), err
}
s.Logger.Debug(string(buf[:n]))
}
return len(p), nil
}
type Stderr struct {
Logger *slog.Logger
EventSvc events.Service
}
// Write implements io.Writer.
func (s *Stderr) Write(p []byte) (n int, err error) {
inBuf := bytes.NewBuffer(p)
buf := make([]byte, bufSize)
for {
n, err := inBuf.Read(buf)
if err != nil {
if err == io.EOF {
break
}
return len(p) - inBuf.Len(), err
}
s.Logger.Error(string(buf[:n]))
}
if err := s.EventSvc.SendEvent("algorithm-run", "failed", nil); err != nil {
return len(p), err
}
return len(p), nil
}
+85
View File
@@ -0,0 +1,85 @@
// Copyright (c) Ultraviolet
// SPDX-License-Identifier: Apache-2.0
package algorithm
import (
"strings"
"testing"
mglog "github.com/absmach/magistrala/logger"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"github.com/ultravioletrs/cocos/agent/events/mocks"
)
func TestStdoutWrite(t *testing.T) {
tests := []struct {
name string
input string
expected []string
}{
{
name: "Single line",
input: "Hello, World!",
expected: []string{"Hello, World!"},
},
{
name: "Multiple lines",
input: "Line 1\nLine 2\nLine 3",
expected: []string{"Line 1\nLine 2\nLine 3"},
},
{
name: "Long input",
input: strings.Repeat("a", bufSize+100),
expected: []string{strings.Repeat("a", bufSize), strings.Repeat("a", 100)},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
stdout := &Stdout{Logger: mglog.NewMock()}
n, err := stdout.Write([]byte(tt.input))
assert.NoError(t, err)
assert.Equal(t, len(tt.input), n)
})
}
}
func TestStderrWrite(t *testing.T) {
tests := []struct {
name string
input string
expected []string
}{
{
name: "Single line",
input: "Error: Something went wrong",
expected: []string{"Error: Something went wrong"},
},
{
name: "Multiple lines",
input: "Error 1\nError 2\nError 3",
expected: []string{"Error 1\nError 2\nError 3"},
},
{
name: "Long input",
input: strings.Repeat("e", bufSize+100),
expected: []string{strings.Repeat("e", bufSize), strings.Repeat("e", 100)},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
mockEventService := mocks.NewService(t)
mockEventService.On("SendEvent", "algorithm-run", "failed", mock.Anything).Return(nil)
stderr := &Stderr{Logger: mglog.NewMock(), EventSvc: mockEventService}
n, err := stderr.Write([]byte(tt.input))
assert.NoError(t, err)
assert.Equal(t, len(tt.input), n)
mockEventService.AssertExpectations(t)
})
}
}
+1
View File
@@ -27,6 +27,7 @@ type AgentEvent struct {
Status string `json:"status,omitempty"`
}
//go:generate mockery --name Service --output=./mocks --filename events.go --quiet --note "Copyright (c) Ultraviolet \n // SPDX-License-Identifier: Apache-2.0"
type Service interface {
SendEvent(event, status string, details json.RawMessage) error
Close() error
+67
View File
@@ -0,0 +1,67 @@
// Code generated by mockery v2.43.2. DO NOT EDIT.
// Copyright (c) Ultraviolet
// SPDX-License-Identifier: Apache-2.0
package mocks
import (
json "encoding/json"
mock "github.com/stretchr/testify/mock"
)
// Service is an autogenerated mock type for the Service type
type Service struct {
mock.Mock
}
// Close provides a mock function with given fields:
func (_m *Service) Close() error {
ret := _m.Called()
if len(ret) == 0 {
panic("no return value specified for Close")
}
var r0 error
if rf, ok := ret.Get(0).(func() error); ok {
r0 = rf()
} else {
r0 = ret.Error(0)
}
return r0
}
// SendEvent provides a mock function with given fields: event, status, details
func (_m *Service) SendEvent(event string, status string, details json.RawMessage) error {
ret := _m.Called(event, status, details)
if len(ret) == 0 {
panic("no return value specified for SendEvent")
}
var r0 error
if rf, ok := ret.Get(0).(func(string, string, json.RawMessage) error); ok {
r0 = rf(event, status, details)
} else {
r0 = ret.Error(0)
}
return r0
}
// NewService creates a new instance of Service. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations.
// The first argument is typically a *testing.T value.
func NewService(t interface {
mock.TestingT
Cleanup(func())
}) *Service {
mock := &Service{}
mock.Mock.Test(t)
t.Cleanup(func() { mock.AssertExpectations(t) })
return mock
}
+3 -51
View File
@@ -4,19 +4,17 @@
package agent
import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"log/slog"
"os"
"os/exec"
"slices"
"github.com/google/go-sev-guest/client"
"github.com/ultravioletrs/cocos/agent/algorithm/binary"
"github.com/ultravioletrs/cocos/agent/events"
"github.com/ultravioletrs/cocos/pkg/socket"
"golang.org/x/crypto/sha3"
)
@@ -25,7 +23,6 @@ var _ Service = (*agentService)(nil)
const (
// ReportDataSize is the size of the report data expected by the attestation service.
ReportDataSize = 64
socketPath = "unix_socket"
algoFilePermission = 0o700
)
@@ -212,7 +209,8 @@ func (as *agentService) runComputation() {
as.sm.logger.Debug("computation run started")
defer as.sm.SendEvent(runComplete)
as.publishEvent("in-progress", json.RawMessage{})()
result, err := as.run(as.algorithm, as.datasets)
algorithm := binary.New(as.sm.logger, as.eventSvc, as.algorithm, as.datasets...)
result, err := algorithm.Run()
if err != nil {
as.runError = err
as.sm.logger.Warn(fmt.Sprintf("computation failed with error: %s", err.Error()))
@@ -230,49 +228,3 @@ func (as *agentService) publishEvent(status string, details json.RawMessage) fun
}
}
}
func (as *agentService) run(algoFile string, dataFiles []string) ([]byte, error) {
defer os.Remove(algoFile)
defer func() {
for _, file := range dataFiles {
os.Remove(file)
}
}()
listener, err := socket.StartUnixSocketServer(socketPath)
if err != nil {
return nil, fmt.Errorf("error creating stdout pipe: %v", err)
}
defer listener.Close()
// Create channels for received data and errors
dataChannel := make(chan []byte)
errorChannel := make(chan error)
var result []byte
var outStd, outErr bytes.Buffer
go socket.AcceptConnection(listener, dataChannel, errorChannel)
args := append([]string{socketPath}, dataFiles...)
cmd := exec.Command(algoFile, args...)
cmd.Stderr = &outErr
cmd.Stdout = &outStd
if err := cmd.Start(); err != nil {
return nil, fmt.Errorf("error starting algorithm: %v", err)
}
if err := cmd.Wait(); err != nil {
as.sm.logger.Debug(outErr.String())
return nil, fmt.Errorf("algorithm execution error: %v", err)
}
select {
case result = <-dataChannel:
as.sm.logger.Debug(outStd.String())
return result, nil
case err = <-errorChannel:
return nil, fmt.Errorf("error receiving data: %v", err)
}
}
+1 -1
View File
@@ -103,7 +103,7 @@ func (cli *CLI) NewAddHostDataCmd() *cobra.Command {
}
}
func changeAttestationConfiguration(fileName string, base64Data string, expectedLength int, field fieldType) error {
func changeAttestationConfiguration(fileName, base64Data string, expectedLength int, field fieldType) error {
data, err := base64.StdEncoding.DecodeString(base64Data)
if err != nil {
return errDecode
+4 -4
View File
@@ -61,8 +61,8 @@ func (sdk *agentSDK) Algo(ctx context.Context, algorithm agent.Algorithm, privKe
}
algoBuffer := bytes.NewBuffer(algorithm.Algorithm)
progressbar := progressbar.New()
if err := progressbar.SendAlgorithm(algoProgressBarDescription, algoBuffer, &stream); err != nil {
pb := progressbar.New()
if err := pb.SendAlgorithm(algoProgressBarDescription, algoBuffer, &stream); err != nil {
sdk.logger.Error("Failed to send Algorithm")
return err
}
@@ -85,8 +85,8 @@ func (sdk *agentSDK) Data(ctx context.Context, dataset agent.Dataset, privKey an
}
dataBuffer := bytes.NewBuffer(dataset.Dataset)
progressbar := progressbar.New()
if err := progressbar.SendData(dataProgressBarDescription, dataBuffer, &stream); err != nil {
pb := progressbar.New()
if err := pb.SendData(dataProgressBarDescription, dataBuffer, &stream); err != nil {
sdk.logger.Error("Failed to send Data")
return err
}