COCOS-344 - New agent structure (#350)
CI / checkproto (push) Has been cancelled
CI / ci (push) Has been cancelled

* new agent structure

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* minor fixes and testing

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* fix lint

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* fix tests

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* cvm tests fix

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* fix test

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* fix cli test

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* rename

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* rename cvm to cvms plural

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* rename service

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* fix tests

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* remove context

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* refactor: reorder parameters in NewAlgorithm functions and update CVMClient to CVMSClient

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* fix(tests): update SendEvent mock to include an additional parameter

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* move expectations

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* fix(tests): move event initialization to the correct scope in service tests

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* fix(tests): update SendEvent mock to use EXPECT instead of On in service tests

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

---------

Signed-off-by: Sammy Oina <sammyoina@gmail.com>
This commit is contained in:
Sammy Kerata Oina
2025-01-17 14:50:53 +03:00
committed by GitHub
parent 59b8057e5c
commit ecad6514f3
53 changed files with 3300 additions and 1340 deletions
+2 -2
View File
@@ -33,8 +33,8 @@ jobs:
- name: Set up protoc
run: |
PROTOC_VERSION=28.1
PROTOC_GEN_VERSION=v1.34.2
PROTOC_VERSION=29.0
PROTOC_GEN_VERSION=v1.36.0
PROTOC_GRPC_VERSION=v1.5.1
# Download and install protoc
+1
View File
@@ -37,6 +37,7 @@ protoc:
protoc -I. --go_out=. --go_opt=paths=source_relative --go-grpc_out=. --go-grpc_opt=paths=source_relative agent/agent.proto
protoc -I. --go_out=. --go_opt=paths=source_relative --go-grpc_out=. --go-grpc_opt=paths=source_relative manager/manager.proto
protoc -I. --go_out=. --go_opt=paths=source_relative --go-grpc_out=. --go-grpc_opt=paths=source_relative agent/events/events.proto
protoc -I. --go_out=. --go_opt=paths=source_relative --go-grpc_out=. --go-grpc_opt=paths=source_relative agent/cvms/cvms.proto
mocks:
mockery --config ./mockery.yml
+57 -176
View File
@@ -3,8 +3,8 @@
// Code generated by protoc-gen-go. DO NOT EDIT.
// versions:
// protoc-gen-go v1.34.2
// protoc v5.28.1
// protoc-gen-go v1.36.0
// protoc v5.29.0
// source: agent/agent.proto
package agent
@@ -24,21 +24,18 @@ const (
)
type AlgoRequest struct {
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
state protoimpl.MessageState `protogen:"open.v1"`
Algorithm []byte `protobuf:"bytes,1,opt,name=algorithm,proto3" json:"algorithm,omitempty"`
Requirements []byte `protobuf:"bytes,2,opt,name=requirements,proto3" json:"requirements,omitempty"`
unknownFields protoimpl.UnknownFields
Algorithm []byte `protobuf:"bytes,1,opt,name=algorithm,proto3" json:"algorithm,omitempty"`
Requirements []byte `protobuf:"bytes,2,opt,name=requirements,proto3" json:"requirements,omitempty"`
sizeCache protoimpl.SizeCache
}
func (x *AlgoRequest) Reset() {
*x = AlgoRequest{}
if protoimpl.UnsafeEnabled {
mi := &file_agent_agent_proto_msgTypes[0]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
mi := &file_agent_agent_proto_msgTypes[0]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
func (x *AlgoRequest) String() string {
@@ -49,7 +46,7 @@ func (*AlgoRequest) ProtoMessage() {}
func (x *AlgoRequest) ProtoReflect() protoreflect.Message {
mi := &file_agent_agent_proto_msgTypes[0]
if protoimpl.UnsafeEnabled && x != nil {
if x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
@@ -79,18 +76,16 @@ func (x *AlgoRequest) GetRequirements() []byte {
}
type AlgoResponse struct {
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
state protoimpl.MessageState `protogen:"open.v1"`
unknownFields protoimpl.UnknownFields
sizeCache protoimpl.SizeCache
}
func (x *AlgoResponse) Reset() {
*x = AlgoResponse{}
if protoimpl.UnsafeEnabled {
mi := &file_agent_agent_proto_msgTypes[1]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
mi := &file_agent_agent_proto_msgTypes[1]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
func (x *AlgoResponse) String() string {
@@ -101,7 +96,7 @@ func (*AlgoResponse) ProtoMessage() {}
func (x *AlgoResponse) ProtoReflect() protoreflect.Message {
mi := &file_agent_agent_proto_msgTypes[1]
if protoimpl.UnsafeEnabled && x != nil {
if x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
@@ -117,21 +112,18 @@ func (*AlgoResponse) Descriptor() ([]byte, []int) {
}
type DataRequest struct {
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
state protoimpl.MessageState `protogen:"open.v1"`
Dataset []byte `protobuf:"bytes,1,opt,name=dataset,proto3" json:"dataset,omitempty"`
Filename string `protobuf:"bytes,2,opt,name=filename,proto3" json:"filename,omitempty"`
unknownFields protoimpl.UnknownFields
Dataset []byte `protobuf:"bytes,1,opt,name=dataset,proto3" json:"dataset,omitempty"`
Filename string `protobuf:"bytes,2,opt,name=filename,proto3" json:"filename,omitempty"`
sizeCache protoimpl.SizeCache
}
func (x *DataRequest) Reset() {
*x = DataRequest{}
if protoimpl.UnsafeEnabled {
mi := &file_agent_agent_proto_msgTypes[2]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
mi := &file_agent_agent_proto_msgTypes[2]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
func (x *DataRequest) String() string {
@@ -142,7 +134,7 @@ func (*DataRequest) ProtoMessage() {}
func (x *DataRequest) ProtoReflect() protoreflect.Message {
mi := &file_agent_agent_proto_msgTypes[2]
if protoimpl.UnsafeEnabled && x != nil {
if x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
@@ -172,18 +164,16 @@ func (x *DataRequest) GetFilename() string {
}
type DataResponse struct {
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
state protoimpl.MessageState `protogen:"open.v1"`
unknownFields protoimpl.UnknownFields
sizeCache protoimpl.SizeCache
}
func (x *DataResponse) Reset() {
*x = DataResponse{}
if protoimpl.UnsafeEnabled {
mi := &file_agent_agent_proto_msgTypes[3]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
mi := &file_agent_agent_proto_msgTypes[3]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
func (x *DataResponse) String() string {
@@ -194,7 +184,7 @@ func (*DataResponse) ProtoMessage() {}
func (x *DataResponse) ProtoReflect() protoreflect.Message {
mi := &file_agent_agent_proto_msgTypes[3]
if protoimpl.UnsafeEnabled && x != nil {
if x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
@@ -210,18 +200,16 @@ func (*DataResponse) Descriptor() ([]byte, []int) {
}
type ResultRequest struct {
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
state protoimpl.MessageState `protogen:"open.v1"`
unknownFields protoimpl.UnknownFields
sizeCache protoimpl.SizeCache
}
func (x *ResultRequest) Reset() {
*x = ResultRequest{}
if protoimpl.UnsafeEnabled {
mi := &file_agent_agent_proto_msgTypes[4]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
mi := &file_agent_agent_proto_msgTypes[4]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
func (x *ResultRequest) String() string {
@@ -232,7 +220,7 @@ func (*ResultRequest) ProtoMessage() {}
func (x *ResultRequest) ProtoReflect() protoreflect.Message {
mi := &file_agent_agent_proto_msgTypes[4]
if protoimpl.UnsafeEnabled && x != nil {
if x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
@@ -248,20 +236,17 @@ func (*ResultRequest) Descriptor() ([]byte, []int) {
}
type ResultResponse struct {
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
state protoimpl.MessageState `protogen:"open.v1"`
File []byte `protobuf:"bytes,1,opt,name=file,proto3" json:"file,omitempty"`
unknownFields protoimpl.UnknownFields
File []byte `protobuf:"bytes,1,opt,name=file,proto3" json:"file,omitempty"`
sizeCache protoimpl.SizeCache
}
func (x *ResultResponse) Reset() {
*x = ResultResponse{}
if protoimpl.UnsafeEnabled {
mi := &file_agent_agent_proto_msgTypes[5]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
mi := &file_agent_agent_proto_msgTypes[5]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
func (x *ResultResponse) String() string {
@@ -272,7 +257,7 @@ func (*ResultResponse) ProtoMessage() {}
func (x *ResultResponse) ProtoReflect() protoreflect.Message {
mi := &file_agent_agent_proto_msgTypes[5]
if protoimpl.UnsafeEnabled && x != nil {
if x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
@@ -295,20 +280,17 @@ func (x *ResultResponse) GetFile() []byte {
}
type AttestationRequest struct {
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
state protoimpl.MessageState `protogen:"open.v1"`
ReportData []byte `protobuf:"bytes,1,opt,name=report_data,json=reportData,proto3" json:"report_data,omitempty"` // Should be of length 64.
unknownFields protoimpl.UnknownFields
ReportData []byte `protobuf:"bytes,1,opt,name=report_data,json=reportData,proto3" json:"report_data,omitempty"` // Should be of length 64.
sizeCache protoimpl.SizeCache
}
func (x *AttestationRequest) Reset() {
*x = AttestationRequest{}
if protoimpl.UnsafeEnabled {
mi := &file_agent_agent_proto_msgTypes[6]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
mi := &file_agent_agent_proto_msgTypes[6]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
func (x *AttestationRequest) String() string {
@@ -319,7 +301,7 @@ func (*AttestationRequest) ProtoMessage() {}
func (x *AttestationRequest) ProtoReflect() protoreflect.Message {
mi := &file_agent_agent_proto_msgTypes[6]
if protoimpl.UnsafeEnabled && x != nil {
if x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
@@ -342,20 +324,17 @@ func (x *AttestationRequest) GetReportData() []byte {
}
type AttestationResponse struct {
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
state protoimpl.MessageState `protogen:"open.v1"`
File []byte `protobuf:"bytes,1,opt,name=file,proto3" json:"file,omitempty"`
unknownFields protoimpl.UnknownFields
File []byte `protobuf:"bytes,1,opt,name=file,proto3" json:"file,omitempty"`
sizeCache protoimpl.SizeCache
}
func (x *AttestationResponse) Reset() {
*x = AttestationResponse{}
if protoimpl.UnsafeEnabled {
mi := &file_agent_agent_proto_msgTypes[7]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
mi := &file_agent_agent_proto_msgTypes[7]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
func (x *AttestationResponse) String() string {
@@ -366,7 +345,7 @@ func (*AttestationResponse) ProtoMessage() {}
func (x *AttestationResponse) ProtoReflect() protoreflect.Message {
mi := &file_agent_agent_proto_msgTypes[7]
if protoimpl.UnsafeEnabled && x != nil {
if x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
@@ -477,104 +456,6 @@ func file_agent_agent_proto_init() {
if File_agent_agent_proto != nil {
return
}
if !protoimpl.UnsafeEnabled {
file_agent_agent_proto_msgTypes[0].Exporter = func(v any, i int) any {
switch v := v.(*AlgoRequest); i {
case 0:
return &v.state
case 1:
return &v.sizeCache
case 2:
return &v.unknownFields
default:
return nil
}
}
file_agent_agent_proto_msgTypes[1].Exporter = func(v any, i int) any {
switch v := v.(*AlgoResponse); i {
case 0:
return &v.state
case 1:
return &v.sizeCache
case 2:
return &v.unknownFields
default:
return nil
}
}
file_agent_agent_proto_msgTypes[2].Exporter = func(v any, i int) any {
switch v := v.(*DataRequest); i {
case 0:
return &v.state
case 1:
return &v.sizeCache
case 2:
return &v.unknownFields
default:
return nil
}
}
file_agent_agent_proto_msgTypes[3].Exporter = func(v any, i int) any {
switch v := v.(*DataResponse); i {
case 0:
return &v.state
case 1:
return &v.sizeCache
case 2:
return &v.unknownFields
default:
return nil
}
}
file_agent_agent_proto_msgTypes[4].Exporter = func(v any, i int) any {
switch v := v.(*ResultRequest); i {
case 0:
return &v.state
case 1:
return &v.sizeCache
case 2:
return &v.unknownFields
default:
return nil
}
}
file_agent_agent_proto_msgTypes[5].Exporter = func(v any, i int) any {
switch v := v.(*ResultResponse); i {
case 0:
return &v.state
case 1:
return &v.sizeCache
case 2:
return &v.unknownFields
default:
return nil
}
}
file_agent_agent_proto_msgTypes[6].Exporter = func(v any, i int) any {
switch v := v.(*AttestationRequest); i {
case 0:
return &v.state
case 1:
return &v.sizeCache
case 2:
return &v.unknownFields
default:
return nil
}
}
file_agent_agent_proto_msgTypes[7].Exporter = func(v any, i int) any {
switch v := v.(*AttestationResponse); i {
case 0:
return &v.state
case 1:
return &v.sizeCache
case 2:
return &v.unknownFields
default:
return nil
}
}
}
type x struct{}
out := protoimpl.TypeBuilder{
File: protoimpl.DescBuilder{
+1 -1
View File
@@ -4,7 +4,7 @@
// Code generated by protoc-gen-go-grpc. DO NOT EDIT.
// versions:
// - protoc-gen-go-grpc v1.5.1
// - protoc v5.28.1
// - protoc v5.29.0
// source: agent/agent.proto
package agent
+3
View File
@@ -46,4 +46,7 @@ func AlgorithmArgsFromContext(ctx context.Context) []string {
type Algorithm interface {
// Run executes the algorithm and returns the result.
Run() error
// Stop stops the algorithm.
Stop() error
}
+24 -7
View File
@@ -20,29 +20,46 @@ type binary struct {
stderr io.Writer
stdout io.Writer
args []string
cmd *exec.Cmd
}
func NewAlgorithm(logger *slog.Logger, eventsSvc events.Service, algoFile string, args []string) algorithm.Algorithm {
func NewAlgorithm(logger *slog.Logger, eventsSvc events.Service, algoFile string, args []string, cmpID string) algorithm.Algorithm {
return &binary{
algoFile: algoFile,
stderr: &logging.Stderr{Logger: logger, EventSvc: eventsSvc},
stderr: &logging.Stderr{Logger: logger, EventSvc: eventsSvc, CmpID: cmpID},
stdout: &logging.Stdout{Logger: logger},
args: args,
}
}
func (b *binary) Run() error {
cmd := exec.Command(b.algoFile, b.args...)
cmd.Stderr = b.stderr
cmd.Stdout = b.stdout
b.cmd = exec.Command(b.algoFile, b.args...)
b.cmd.Stderr = b.stderr
b.cmd.Stdout = b.stdout
if err := cmd.Start(); err != nil {
if err := b.cmd.Start(); err != nil {
return fmt.Errorf("error starting algorithm: %v", err)
}
if err := cmd.Wait(); err != nil {
if err := b.cmd.Wait(); err != nil {
return fmt.Errorf("algorithm execution error: %v", err)
}
return nil
}
func (b *binary) Stop() error {
if b.cmd == nil {
return nil
}
if b.cmd.ProcessState != nil && b.cmd.ProcessState.Exited() {
return nil
}
if err := b.cmd.Process.Kill(); err != nil {
return fmt.Errorf("error stopping algorithm: %v", err)
}
return nil
}
+2 -2
View File
@@ -18,7 +18,7 @@ func TestNewAlgorithm(t *testing.T) {
algoFile := "/path/to/algo"
args := []string{"arg1", "arg2"}
algo := NewAlgorithm(logger, eventsSvc, algoFile, args)
algo := NewAlgorithm(logger, eventsSvc, algoFile, args, "")
b, ok := algo.(*binary)
if !ok {
@@ -74,7 +74,7 @@ func TestBinaryRun(t *testing.T) {
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
eventsSvc := new(mocks.Service)
b := NewAlgorithm(logger, eventsSvc, tt.algoFile, tt.args).(*binary)
b := NewAlgorithm(logger, eventsSvc, tt.algoFile, tt.args, "").(*binary)
var stdout, stderr bytes.Buffer
b.stdout = &stdout
+8 -4
View File
@@ -35,11 +35,11 @@ type docker struct {
stdout io.Writer
}
func NewAlgorithm(logger *slog.Logger, eventsSvc events.Service, algoFile string) algorithm.Algorithm {
func NewAlgorithm(logger *slog.Logger, eventsSvc events.Service, algoFile, cmpID string) algorithm.Algorithm {
d := &docker{
algoFile: algoFile,
logger: logger,
stderr: &logging.Stderr{Logger: logger, EventSvc: eventsSvc},
stderr: &logging.Stderr{Logger: logger, EventSvc: eventsSvc, CmpID: cmpID},
stdout: &logging.Stdout{Logger: logger},
}
@@ -47,8 +47,6 @@ func NewAlgorithm(logger *slog.Logger, eventsSvc events.Service, algoFile string
}
func (d *docker) Run() error {
ctx := context.Background()
// Create a new Docker client.
cli, err := client.NewClientWithOpts(client.FromEnv, client.WithAPIVersionNegotiation())
if err != nil {
@@ -62,6 +60,7 @@ func (d *docker) Run() error {
}
defer imageFile.Close()
ctx := context.Background()
// Load the Docker image from the tar file.
resp, err := cli.ImageLoad(ctx, imageFile, true)
if err != nil {
@@ -176,3 +175,8 @@ func writeToOut(readCloser io.ReadCloser, ioWriter io.Writer) error {
return nil
}
func (d *docker) Stop() error {
// To be supported later.
return nil
}
+1 -1
View File
@@ -18,7 +18,7 @@ func TestNewAlgorithm(t *testing.T) {
eventsSvc := new(mocks.Service)
algoFile := "/path/to/algo.tar"
algo := NewAlgorithm(logger, eventsSvc, algoFile)
algo := NewAlgorithm(logger, eventsSvc, algoFile, "")
d, ok := algo.(*docker)
assert.True(t, ok, "NewAlgorithm should return a *docker")
+2 -3
View File
@@ -50,6 +50,7 @@ func (s *Stdout) Write(p []byte) (n int, err error) {
type Stderr struct {
Logger *slog.Logger
EventSvc events.Service
CmpID string
}
// Write implements io.Writer.
@@ -70,9 +71,7 @@ func (s *Stderr) Write(p []byte) (n int, err error) {
s.Logger.Error(string(buf[:n]))
}
if err := s.EventSvc.SendEvent(algorithmRun, warningStatus, json.RawMessage{}); err != nil {
return len(p), err
}
s.EventSvc.SendEvent(s.CmpID, algorithmRun, warningStatus, json.RawMessage{})
return len(p), nil
}
+1 -1
View File
@@ -73,7 +73,7 @@ func TestStderrWrite(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
mockEventService := mocks.NewService(t)
mockEventService.On("SendEvent", "AlgorithmRun", manager.Warning.String(), mock.Anything).Return(nil)
mockEventService.On("SendEvent", mock.Anything, "AlgorithmRun", manager.Warning.String(), mock.Anything).Return(nil)
stderr := &Stderr{Logger: mglog.NewMock(), EventSvc: mockEventService}
n, err := stderr.Write([]byte(tt.input))
+24 -7
View File
@@ -39,12 +39,13 @@ type python struct {
runtime string
requirementsFile string
args []string
cmd *exec.Cmd
}
func NewAlgorithm(logger *slog.Logger, eventsSvc events.Service, runtime, requirementsFile, algoFile string, args []string) algorithm.Algorithm {
func NewAlgorithm(logger *slog.Logger, eventsSvc events.Service, runtime, requirementsFile, algoFile string, args []string, cmpID string) algorithm.Algorithm {
p := &python{
algoFile: algoFile,
stderr: &logging.Stderr{Logger: logger, EventSvc: eventsSvc},
stderr: &logging.Stderr{Logger: logger, EventSvc: eventsSvc, CmpID: cmpID},
stdout: &logging.Stdout{Logger: logger},
requirementsFile: requirementsFile,
args: args,
@@ -85,15 +86,15 @@ func (p *python) Run() error {
}
args := append([]string{p.algoFile}, p.args...)
cmd := exec.Command(pythonPath, args...)
cmd.Stderr = p.stderr
cmd.Stdout = p.stdout
p.cmd = exec.Command(pythonPath, args...)
p.cmd.Stderr = p.stderr
p.cmd.Stdout = p.stdout
if err := cmd.Start(); err != nil {
if err := p.cmd.Start(); err != nil {
return fmt.Errorf("error starting algorithm: %v", err)
}
if err := cmd.Wait(); err != nil {
if err := p.cmd.Wait(); err != nil {
return fmt.Errorf("algorithm execution error: %v", err)
}
@@ -103,3 +104,19 @@ func (p *python) Run() error {
return nil
}
func (p *python) Stop() error {
if p.cmd == nil {
return nil
}
if p.cmd.ProcessState != nil && p.cmd.ProcessState.Exited() {
return nil
}
if err := p.cmd.Process.Kill(); err != nil {
return fmt.Errorf("error stopping algorithm: %v", err)
}
return nil
}
+1 -1
View File
@@ -50,7 +50,7 @@ func TestNewAlgorithm(t *testing.T) {
algoFile := "algorithm.py"
args := []string{"--arg1", "value1"}
algo := NewAlgorithm(logger, eventsSvc, runtime, requirementsFile, algoFile, args)
algo := NewAlgorithm(logger, eventsSvc, runtime, requirementsFile, algoFile, args, "")
p, ok := algo.(*python)
if !ok {
+24 -7
View File
@@ -24,12 +24,13 @@ type wasm struct {
stderr io.Writer
stdout io.Writer
args []string
cmd *exec.Cmd
}
func NewAlgorithm(logger *slog.Logger, eventsSvc events.Service, algoFile string, args []string) algorithm.Algorithm {
func NewAlgorithm(logger *slog.Logger, eventsSvc events.Service, args []string, algoFile, cmpID string) algorithm.Algorithm {
return &wasm{
algoFile: algoFile,
stderr: &logging.Stderr{Logger: logger, EventSvc: eventsSvc},
stderr: &logging.Stderr{Logger: logger, EventSvc: eventsSvc, CmpID: cmpID},
stdout: &logging.Stdout{Logger: logger},
args: args,
}
@@ -38,17 +39,33 @@ func NewAlgorithm(logger *slog.Logger, eventsSvc events.Service, algoFile string
func (w *wasm) Run() error {
args := append(mapDirOption, w.algoFile)
args = append(args, w.args...)
cmd := exec.Command(wasmRuntime, args...)
cmd.Stderr = w.stderr
cmd.Stdout = w.stdout
w.cmd = exec.Command(wasmRuntime, args...)
w.cmd.Stderr = w.stderr
w.cmd.Stdout = w.stdout
if err := cmd.Start(); err != nil {
if err := w.cmd.Start(); err != nil {
return fmt.Errorf("error starting algorithm: %v", err)
}
if err := cmd.Wait(); err != nil {
if err := w.cmd.Wait(); err != nil {
return fmt.Errorf("algorithm execution error: %v", err)
}
return nil
}
func (w *wasm) Stop() error {
if w.cmd == nil {
return nil
}
if w.cmd.ProcessState != nil && w.cmd.ProcessState.Exited() {
return nil
}
if err := w.cmd.Process.Kill(); err != nil {
return fmt.Errorf("error stopping algorithm: %v", err)
}
return nil
}
+2 -2
View File
@@ -18,7 +18,7 @@ func TestNewAlgorithm(t *testing.T) {
algoFile := "test.wasm"
args := []string{"arg1", "arg2"}
algo := NewAlgorithm(logger, eventsSvc, algoFile, args)
algo := NewAlgorithm(logger, eventsSvc, args, algoFile, "")
w, ok := algo.(*wasm)
if !ok {
@@ -54,7 +54,7 @@ func TestRunError(t *testing.T) {
algoFile := "test.wasm"
args := []string{"arg1", "arg2"}
w := NewAlgorithm(logger, eventsSvc, algoFile, args).(*wasm)
w := NewAlgorithm(logger, eventsSvc, args, algoFile, "").(*wasm)
err := w.Run()
if err == nil {
+28
View File
@@ -27,6 +27,34 @@ func LoggingMiddleware(svc agent.Service, logger *slog.Logger) agent.Service {
return &loggingMiddleware{logger, svc}
}
// InitComputation implements agent.Service.
func (lm *loggingMiddleware) InitComputation(ctx context.Context, cmp agent.Computation) (err error) {
defer func(begin time.Time) {
message := fmt.Sprintf("Method InitComputation for computation id %s took %s to complete", cmp.ID, time.Since(begin))
if err != nil {
lm.logger.WithGroup(cmp.ID).Warn(fmt.Sprintf("%s with error: %s", message, err))
return
}
lm.logger.WithGroup(cmp.ID).Info(fmt.Sprintf("%s without errors", message))
}(time.Now())
return lm.svc.InitComputation(ctx, cmp)
}
// StopComputation implements agent.Service.
func (lm *loggingMiddleware) StopComputation(ctx context.Context) (err error) {
defer func(begin time.Time) {
message := fmt.Sprintf("Method StopComputation took %s to complete", time.Since(begin))
if err != nil {
lm.logger.Warn(fmt.Sprintf("%s with error: %s", message, err))
return
}
lm.logger.Info(fmt.Sprintf("%s without errors", message))
}(time.Now())
return lm.svc.StopComputation(ctx)
}
func (lm *loggingMiddleware) Algo(ctx context.Context, algorithm agent.Algorithm) (err error) {
defer func(begin time.Time) {
message := fmt.Sprintf("Method Algo took %s to complete", time.Since(begin))
+20
View File
@@ -32,6 +32,26 @@ func MetricsMiddleware(svc agent.Service, counter metrics.Counter, latency metri
}
}
// InitComputation implements agent.Service.
func (ms *metricsMiddleware) InitComputation(ctx context.Context, cmp agent.Computation) error {
defer func(begin time.Time) {
ms.counter.With("method", "init_computation").Add(1)
ms.latency.With("method", "init_computation").Observe(time.Since(begin).Seconds())
}(time.Now())
return ms.svc.InitComputation(ctx, cmp)
}
// StopComputation implements agent.Service.
func (ms *metricsMiddleware) StopComputation(ctx context.Context) error {
defer func(begin time.Time) {
ms.counter.With("method", "stop_computation").Add(1)
ms.latency.With("method", "stop_computation").Observe(time.Since(begin).Seconds())
}(time.Now())
return ms.svc.StopComputation(ctx)
}
func (ms *metricsMiddleware) Algo(ctx context.Context, algorithm agent.Algorithm) error {
defer func(begin time.Time) {
ms.counter.With("method", "algo").Add(1)
-2
View File
@@ -13,7 +13,6 @@ import (
var _ fmt.Stringer = (*Datasets)(nil)
type AgentConfig struct {
LogLevel string `json:"log_level,omitempty"`
Host string `json:"host,omitempty"`
Port string `json:"port,omitempty"`
CertFile string `json:"cert_file,omitempty"`
@@ -30,7 +29,6 @@ type Computation struct {
Datasets Datasets `json:"datasets,omitempty"`
Algorithm Algorithm `json:"algorithm,omitempty"`
ResultConsumers []ResultConsumer `json:"result_consumers,omitempty"`
AgentConfig AgentConfig `json:"agent_config,omitempty"`
}
type ResultConsumer struct {
-1
View File
@@ -106,7 +106,6 @@ func TestDecompressToContext(t *testing.T) {
func TestAgentConfigJSON(t *testing.T) {
config := AgentConfig{
LogLevel: "info",
Host: "localhost",
Port: "8080",
CertFile: "cert.pem",
+261
View File
@@ -0,0 +1,261 @@
// Copyright (c) Ultraviolet
// SPDX-License-Identifier: Apache-2.0
package grpc
import (
"context"
"log/slog"
"sync"
"time"
"github.com/absmach/magistrala/pkg/errors"
"github.com/ultravioletrs/cocos/agent"
"github.com/ultravioletrs/cocos/agent/cvms"
"github.com/ultravioletrs/cocos/agent/cvms/server"
"golang.org/x/sync/errgroup"
"google.golang.org/protobuf/proto"
)
var (
errCorruptedManifest = errors.New("received manifest may be corrupted")
errUnknonwMessageType = errors.New("unknown message type")
sendTimeout = 5 * time.Second
)
type CVMSClient struct {
mu sync.Mutex
stream cvms.Service_ProcessClient
svc agent.Service
messageQueue chan *cvms.ClientStreamMessage
logger *slog.Logger
runReqManager *runRequestManager
sp server.AgentServer
}
// NewClient returns new gRPC client instance.
func NewClient(stream cvms.Service_ProcessClient, svc agent.Service, messageQueue chan *cvms.ClientStreamMessage, logger *slog.Logger, sp server.AgentServer) CVMSClient {
return CVMSClient{
stream: stream,
svc: svc,
messageQueue: messageQueue,
logger: logger,
runReqManager: newRunRequestManager(),
sp: sp,
}
}
func (client *CVMSClient) Process(ctx context.Context, cancel context.CancelFunc) error {
eg, ctx := errgroup.WithContext(ctx)
eg.Go(func() error {
return client.handleIncomingMessages(ctx)
})
eg.Go(func() error {
return client.handleOutgoingMessages(ctx)
})
return eg.Wait()
}
func (client *CVMSClient) handleIncomingMessages(ctx context.Context) error {
for {
select {
case <-ctx.Done():
return ctx.Err()
default:
req, err := client.stream.Recv()
if err != nil {
return err
}
if err := client.processIncomingMessage(ctx, req); err != nil {
return err
}
}
}
}
func (client *CVMSClient) processIncomingMessage(ctx context.Context, req *cvms.ServerStreamMessage) error {
switch mes := req.Message.(type) {
case *cvms.ServerStreamMessage_RunReqChunks:
return client.handleRunReqChunks(ctx, mes)
case *cvms.ServerStreamMessage_StopComputation:
go client.handleStopComputation(ctx, mes)
default:
return errUnknonwMessageType
}
return nil
}
func (client *CVMSClient) handleRunReqChunks(ctx context.Context, msg *cvms.ServerStreamMessage_RunReqChunks) error {
buffer, complete := client.runReqManager.addChunk(msg.RunReqChunks.Id, msg.RunReqChunks.Data, msg.RunReqChunks.IsLast)
if complete {
var runReq cvms.ComputationRunReq
if err := proto.Unmarshal(buffer, &runReq); err != nil {
return errors.Wrap(err, errCorruptedManifest)
}
go client.executeRun(ctx, &runReq)
}
return nil
}
func (client *CVMSClient) executeRun(ctx context.Context, runReq *cvms.ComputationRunReq) {
ac := agent.Computation{
ID: runReq.Id,
Name: runReq.Name,
Description: runReq.Description,
}
if runReq.Algorithm != nil {
ac.Algorithm = agent.Algorithm{
Hash: [32]byte(runReq.Algorithm.Hash),
UserKey: runReq.Algorithm.UserKey,
}
}
for _, ds := range runReq.Datasets {
ac.Datasets = append(ac.Datasets, agent.Dataset{
Hash: [32]byte(ds.Hash),
UserKey: ds.UserKey,
})
}
for _, rc := range runReq.ResultConsumers {
ac.ResultConsumers = append(ac.ResultConsumers, agent.ResultConsumer{
UserKey: rc.UserKey,
})
}
if err := client.svc.InitComputation(ctx, ac); err != nil {
client.logger.Warn(err.Error())
return
}
client.mu.Lock()
defer client.mu.Unlock()
if runReq.AgentConfig == nil {
runReq.AgentConfig = &cvms.AgentConfig{}
}
runRes := &cvms.ClientStreamMessage_RunRes{
RunRes: &cvms.RunResponse{
ComputationId: runReq.Id,
},
}
err := client.sp.Start(ctx, agent.AgentConfig{
Port: runReq.AgentConfig.Port,
Host: runReq.AgentConfig.Host,
CertFile: runReq.AgentConfig.CertFile,
KeyFile: runReq.AgentConfig.KeyFile,
ServerCAFile: runReq.AgentConfig.ServerCaFile,
ClientCAFile: runReq.AgentConfig.ClientCaFile,
AttestedTls: runReq.AgentConfig.AttestedTls,
}, ac)
if err != nil {
client.logger.Warn(err.Error())
runRes.RunRes.Error = err.Error()
}
client.sendMessage(&cvms.ClientStreamMessage{Message: runRes})
}
func (client *CVMSClient) handleStopComputation(ctx context.Context, mes *cvms.ServerStreamMessage_StopComputation) {
msg := &cvms.ClientStreamMessage_StopComputationRes{
StopComputationRes: &cvms.StopComputationResponse{
ComputationId: mes.StopComputation.ComputationId,
},
}
if err := client.svc.StopComputation(ctx); err != nil {
msg.StopComputationRes.Message = err.Error()
}
client.mu.Lock()
defer client.mu.Unlock()
if err := client.sp.Stop(); err != nil {
msg.StopComputationRes.Message = err.Error()
}
client.sendMessage(&cvms.ClientStreamMessage{Message: msg})
}
func (client *CVMSClient) handleOutgoingMessages(ctx context.Context) error {
for {
select {
case <-ctx.Done():
return ctx.Err()
case mes := <-client.messageQueue:
if err := client.stream.Send(mes); err != nil {
return err
}
}
}
}
func (client *CVMSClient) sendMessage(mes *cvms.ClientStreamMessage) {
ctx, cancel := context.WithTimeout(context.Background(), sendTimeout)
defer cancel()
select {
case client.messageQueue <- mes:
case <-ctx.Done():
client.logger.Warn("Failed to send message: timeout exceeded")
}
}
type runRequestManager struct {
requests map[string]*runRequest
mu sync.Mutex
}
type runRequest struct {
buffer []byte
lastChunk time.Time
timer *time.Timer
}
func newRunRequestManager() *runRequestManager {
return &runRequestManager{
requests: make(map[string]*runRequest),
}
}
func (m *runRequestManager) addChunk(id string, chunk []byte, isLast bool) ([]byte, bool) {
m.mu.Lock()
defer m.mu.Unlock()
req, exists := m.requests[id]
if !exists {
req = &runRequest{
buffer: make([]byte, 0),
lastChunk: time.Now(),
timer: time.AfterFunc(runReqTimeout, func() { m.timeoutRequest(id) }),
}
m.requests[id] = req
}
req.buffer = append(req.buffer, chunk...)
req.lastChunk = time.Now()
req.timer.Reset(runReqTimeout)
if isLast {
delete(m.requests, id)
req.timer.Stop()
return req.buffer, true
}
return nil, false
}
func (m *runRequestManager) timeoutRequest(id string) {
m.mu.Lock()
defer m.mu.Unlock()
delete(m.requests, id)
// Log timeout or handle it as needed
}
+202
View File
@@ -0,0 +1,202 @@
// Copyright (c) Ultraviolet
// SPDX-License-Identifier: Apache-2.0
package grpc
import (
"context"
"testing"
"time"
mglog "github.com/absmach/magistrala/logger"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"github.com/ultravioletrs/cocos/agent/cvms"
servermocks "github.com/ultravioletrs/cocos/agent/cvms/server/mocks"
"github.com/ultravioletrs/cocos/agent/mocks"
"google.golang.org/grpc"
"google.golang.org/protobuf/proto"
)
type mockStream struct {
mock.Mock
grpc.ClientStream
}
func (m *mockStream) Recv() (*cvms.ServerStreamMessage, error) {
args := m.Called()
return args.Get(0).(*cvms.ServerStreamMessage), args.Error(1)
}
func (m *mockStream) Send(msg *cvms.ClientStreamMessage) error {
args := m.Called(msg)
return args.Error(0)
}
func TestManagerClient_Process1(t *testing.T) {
tests := []struct {
name string
setupMocks func(mockStream *mockStream, mockSvc *mocks.Service, mockServerSvc *servermocks.AgentServerProvider)
expectError bool
errorMsg string
}{
{
name: "Stop computation",
setupMocks: func(mockStream *mockStream, mockSvc *mocks.Service, mockServerSvc *servermocks.AgentServerProvider) {
mockStream.On("Recv").Return(&cvms.ServerStreamMessage{
Message: &cvms.ServerStreamMessage_StopComputation{
StopComputation: &cvms.StopComputation{},
},
}, nil)
mockStream.On("Send", mock.Anything).Return(nil)
mockSvc.On("StopComputation", mock.Anything).Return(nil)
mockServerSvc.On("Stop").Return(nil)
},
expectError: true,
errorMsg: "context deadline exceeded",
},
{
name: "Run request chunks",
setupMocks: func(mockStream *mockStream, mockSvc *mocks.Service, mockServerSvc *servermocks.AgentServerProvider) {
mockStream.On("Recv").Return(&cvms.ServerStreamMessage{
Message: &cvms.ServerStreamMessage_RunReqChunks{
RunReqChunks: &cvms.RunReqChunks{},
},
}, nil)
mockStream.On("Send", mock.Anything).Return(nil).Once()
mockSvc.On("Run", mock.Anything, mock.Anything).Return("", assert.AnError).Once()
},
expectError: true,
},
{
name: "Receive error",
setupMocks: func(mockStream *mockStream, mockSvc *mocks.Service, mockServerSvc *servermocks.AgentServerProvider) {
mockStream.On("Recv").Return(&cvms.ServerStreamMessage{}, assert.AnError)
},
expectError: true,
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
mockStream := new(mockStream)
mockSvc := new(mocks.Service)
mockServerSvc := new(servermocks.AgentServerProvider)
messageQueue := make(chan *cvms.ClientStreamMessage, 10)
logger := mglog.NewMock()
client := NewClient(mockStream, mockSvc, messageQueue, logger, mockServerSvc)
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
defer cancel()
tc.setupMocks(mockStream, mockSvc, mockServerSvc)
err := client.Process(ctx, cancel)
if tc.expectError {
assert.Error(t, err)
if tc.errorMsg != "" {
assert.Contains(t, err.Error(), tc.errorMsg)
}
} else {
assert.NoError(t, err)
}
})
}
}
func TestManagerClient_handleRunReqChunks(t *testing.T) {
mockStream := new(mockStream)
mockSvc := new(mocks.Service)
mockServerSvc := new(servermocks.AgentServerProvider)
messageQueue := make(chan *cvms.ClientStreamMessage, 10)
logger := mglog.NewMock()
client := NewClient(mockStream, mockSvc, messageQueue, logger, mockServerSvc)
runReq := &cvms.ComputationRunReq{
Id: "test-id",
}
runReqBytes, _ := proto.Marshal(runReq)
chunk1 := &cvms.ServerStreamMessage_RunReqChunks{
RunReqChunks: &cvms.RunReqChunks{
Id: "chunk-1",
Data: runReqBytes[:len(runReqBytes)/2],
IsLast: false,
},
}
chunk2 := &cvms.ServerStreamMessage_RunReqChunks{
RunReqChunks: &cvms.RunReqChunks{
Id: "chunk-1",
Data: runReqBytes[len(runReqBytes)/2:],
IsLast: true,
},
}
mockSvc.On("InitComputation", mock.Anything, mock.Anything).Return(nil)
mockServerSvc.On("Start", mock.Anything, mock.Anything, mock.Anything).Return(nil)
err := client.handleRunReqChunks(context.Background(), chunk1)
assert.NoError(t, err)
err = client.handleRunReqChunks(context.Background(), chunk2)
assert.NoError(t, err)
// Wait for the goroutine to finish
time.Sleep(50 * time.Millisecond)
mockSvc.AssertExpectations(t)
assert.Len(t, messageQueue, 1)
msg := <-messageQueue
runRes, ok := msg.Message.(*cvms.ClientStreamMessage_RunRes)
assert.True(t, ok)
assert.Equal(t, "test-id", runRes.RunRes.ComputationId)
}
func TestManagerClient_handleStopComputation(t *testing.T) {
mockStream := new(mockStream)
mockSvc := new(mocks.Service)
mockServerSvc := new(servermocks.AgentServerProvider)
messageQueue := make(chan *cvms.ClientStreamMessage, 10)
logger := mglog.NewMock()
client := NewClient(mockStream, mockSvc, messageQueue, logger, mockServerSvc)
stopReq := &cvms.ServerStreamMessage_StopComputation{
StopComputation: &cvms.StopComputation{
ComputationId: "test-comp-id",
},
}
mockSvc.On("StopComputation", mock.Anything).Return(nil)
mockServerSvc.On("Stop").Return(nil)
client.handleStopComputation(context.Background(), stopReq)
// Wait for the goroutine to finish
time.Sleep(50 * time.Millisecond)
mockSvc.AssertExpectations(t)
assert.Len(t, messageQueue, 1)
msg := <-messageQueue
stopRes, ok := msg.Message.(*cvms.ClientStreamMessage_StopComputationRes)
assert.True(t, ok)
assert.Equal(t, "test-comp-id", stopRes.StopComputationRes.ComputationId)
assert.Empty(t, stopRes.StopComputationRes.Message)
}
func TestManagerClient_timeoutRequest(t *testing.T) {
rm := newRunRequestManager()
rm.requests["test-id"] = &runRequest{
timer: time.NewTimer(100 * time.Millisecond),
buffer: []byte("test-data"),
lastChunk: time.Now(),
}
rm.timeoutRequest("test-id")
assert.Len(t, rm.requests, 0)
}
+5
View File
@@ -0,0 +1,5 @@
// Copyright (c) Ultraviolet
// SPDX-License-Identifier: Apache-2.0
// Package grpc contains implementation of kit service gRPC API.
package grpc
+133
View File
@@ -0,0 +1,133 @@
// Copyright (c) Ultraviolet
// SPDX-License-Identifier: Apache-2.0
package grpc
import (
"bytes"
"context"
"errors"
"io"
"time"
"github.com/ultravioletrs/cocos/agent/cvms"
"golang.org/x/sync/errgroup"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/peer"
"google.golang.org/protobuf/proto"
)
var (
_ cvms.ServiceServer = (*grpcServer)(nil)
ErrUnexpectedMsg = errors.New("unknown message type")
)
const (
bufferSize = 1024 * 1024 // 1 MB
runReqTimeout = 30 * time.Second
)
type SendFunc func(*cvms.ServerStreamMessage) error
type grpcServer struct {
cvms.UnimplementedServiceServer
incoming chan *cvms.ClientStreamMessage
svc Service
}
type Service interface {
Run(ctx context.Context, ipAddress string, sendMessage SendFunc, authInfo credentials.AuthInfo)
}
// NewServer returns new AuthServiceServer instance.
func NewServer(incoming chan *cvms.ClientStreamMessage, svc Service) cvms.ServiceServer {
return &grpcServer{
incoming: incoming,
svc: svc,
}
}
func (s *grpcServer) Process(stream cvms.Service_ProcessServer) error {
client, ok := peer.FromContext(stream.Context())
if !ok {
return errors.New("failed to get peer info")
}
eg, ctx := errgroup.WithContext(stream.Context())
eg.Go(func() error {
for {
select {
case <-ctx.Done():
return ctx.Err()
default:
req, err := stream.Recv()
if err != nil {
return err
}
s.incoming <- req
}
}
})
eg.Go(func() error {
sendMessage := func(msg *cvms.ServerStreamMessage) error {
select {
case <-ctx.Done():
return ctx.Err()
default:
switch m := msg.Message.(type) {
case *cvms.ServerStreamMessage_RunReq:
return s.sendRunReqInChunks(stream, m.RunReq)
default:
return stream.Send(msg)
}
}
}
s.svc.Run(ctx, client.Addr.String(), sendMessage, client.AuthInfo)
return nil
})
return eg.Wait()
}
func (s *grpcServer) sendRunReqInChunks(stream cvms.Service_ProcessServer, runReq *cvms.ComputationRunReq) error {
data, err := proto.Marshal(runReq)
if err != nil {
return err
}
dataBuffer := bytes.NewBuffer(data)
buf := make([]byte, bufferSize)
for {
n, err := dataBuffer.Read(buf)
isLast := false
if err == io.EOF {
isLast = true
} else if err != nil {
return err
}
chunk := &cvms.ServerStreamMessage{
Message: &cvms.ServerStreamMessage_RunReqChunks{
RunReqChunks: &cvms.RunReqChunks{
Id: runReq.Id,
Data: buf[:n],
IsLast: isLast,
},
},
}
if err := stream.Send(chunk); err != nil {
return err
}
if isLast {
break
}
}
return nil
}
+273
View File
@@ -0,0 +1,273 @@
// Copyright (c) Ultraviolet
// SPDX-License-Identifier: Apache-2.0
package grpc
import (
"context"
"testing"
"time"
"github.com/absmach/magistrala/pkg/errors"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"github.com/ultravioletrs/cocos/agent/cvms"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/peer"
)
type mockServerStream struct {
mock.Mock
cvms.Service_ProcessServer
}
func (m *mockServerStream) Send(msg *cvms.ServerStreamMessage) error {
args := m.Called(msg)
return args.Error(0)
}
func (m *mockServerStream) Recv() (*cvms.ClientStreamMessage, error) {
args := m.Called()
return args.Get(0).(*cvms.ClientStreamMessage), args.Error(1)
}
func (m *mockServerStream) Context() context.Context {
args := m.Called()
return args.Get(0).(context.Context)
}
type mockService struct {
mock.Mock
}
func (m *mockService) Run(ctx context.Context, ipAddress string, sendMessage SendFunc, authInfo credentials.AuthInfo) {
m.Called(ctx, ipAddress, sendMessage, authInfo)
}
func TestNewServer(t *testing.T) {
incoming := make(chan *cvms.ClientStreamMessage)
mockSvc := new(mockService)
server := NewServer(incoming, mockSvc)
assert.NotNil(t, server)
assert.IsType(t, &grpcServer{}, server)
}
func TestGrpcServer_Process(t *testing.T) {
tests := []struct {
name string
recvReturn *cvms.ClientStreamMessage
recvError error
expectedError string
}{
{
name: "Process with context deadline exceeded",
recvReturn: &cvms.ClientStreamMessage{},
recvError: nil,
expectedError: "context deadline exceeded",
},
{
name: "Process with Recv error",
recvReturn: &cvms.ClientStreamMessage{},
recvError: errors.New("recv error"),
expectedError: "recv error",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
incoming := make(chan *cvms.ClientStreamMessage, 1)
mockSvc := new(mockService)
server := NewServer(incoming, mockSvc).(*grpcServer)
mockStream := new(mockServerStream)
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
defer cancel()
mockStream.On("Context").Return(peer.NewContext(ctx, &peer.Peer{
Addr: mockAddr{},
AuthInfo: mockAuthInfo{},
}))
if tt.recvError == nil {
go func() {
for mes := range incoming {
assert.NotNil(t, mes)
}
}()
}
mockStream.On("Recv").Return(tt.recvReturn, tt.recvError)
mockSvc.On("Run", mock.Anything, "test", mock.Anything, mock.AnythingOfType("mockAuthInfo")).Return()
err := server.Process(mockStream)
assert.Error(t, err)
assert.Contains(t, err.Error(), tt.expectedError)
mockStream.AssertExpectations(t)
mockSvc.AssertExpectations(t)
})
}
}
func TestGrpcServer_sendRunReqInChunks(t *testing.T) {
incoming := make(chan *cvms.ClientStreamMessage)
mockSvc := new(mockService)
server := NewServer(incoming, mockSvc).(*grpcServer)
mockStream := new(mockServerStream)
runReq := &cvms.ComputationRunReq{
Id: "test-id",
}
largePayload := make([]byte, bufferSize*2)
for i := range largePayload {
largePayload[i] = byte(i % 256)
}
runReq.Algorithm = &cvms.Algorithm{}
runReq.Algorithm.UserKey = largePayload
mockStream.On("Send", mock.AnythingOfType("*cvms.ServerStreamMessage")).Return(nil).Times(4)
err := server.sendRunReqInChunks(mockStream, runReq)
assert.NoError(t, err)
mockStream.AssertExpectations(t)
calls := mockStream.Calls
assert.Equal(t, 4, len(calls))
for i, call := range calls {
msg := call.Arguments[0].(*cvms.ServerStreamMessage)
chunk := msg.GetRunReqChunks()
assert.NotNil(t, chunk)
assert.Equal(t, "test-id", chunk.Id)
if i < 3 {
assert.False(t, chunk.IsLast)
} else {
assert.Equal(t, 0, len(chunk.Data))
assert.True(t, chunk.IsLast)
}
}
}
type mockAddr struct{}
func (mockAddr) Network() string { return "test network" }
func (mockAddr) String() string { return "test" }
type mockAuthInfo struct{}
func (mockAuthInfo) AuthType() string { return "test auth" }
func TestGrpcServer_ProcessWithMockService(t *testing.T) {
tests := []struct {
name string
setupMockFn func(*mockService, *mockServerStream)
}{
{
name: "Run Request Test",
setupMockFn: func(mockSvc *mockService, mockStream *mockServerStream) {
mockSvc.On("Run", mock.Anything, "test", mock.Anything, mock.AnythingOfType("mockAuthInfo")).
Run(func(args mock.Arguments) {
sendFunc := args.Get(2).(SendFunc)
runReq := &cvms.ComputationRunReq{Id: "test-run-id"}
err := sendFunc(&cvms.ServerStreamMessage{
Message: &cvms.ServerStreamMessage_RunReq{
RunReq: runReq,
},
})
assert.NoError(t, err)
}).
Return()
mockStream.On("Send", mock.MatchedBy(func(msg *cvms.ServerStreamMessage) bool {
chunks := msg.GetRunReqChunks()
return chunks != nil && chunks.Id == "test-run-id"
})).Return(nil)
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
incoming := make(chan *cvms.ClientStreamMessage, 10)
mockSvc := new(mockService)
server := NewServer(incoming, mockSvc).(*grpcServer)
go func() {
for mes := range incoming {
assert.NotNil(t, mes)
}
}()
mockStream := new(mockServerStream)
ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
defer cancel()
peerCtx := peer.NewContext(ctx, &peer.Peer{
Addr: mockAddr{},
AuthInfo: mockAuthInfo{},
})
mockStream.On("Context").Return(peerCtx)
mockStream.On("Recv").Return(&cvms.ClientStreamMessage{}, nil).Maybe()
tt.setupMockFn(mockSvc, mockStream)
go func() {
time.Sleep(150 * time.Millisecond)
cancel()
}()
err := server.Process(mockStream)
assert.Error(t, err)
assert.Contains(t, err.Error(), "context canceled")
mockStream.AssertExpectations(t)
mockSvc.AssertExpectations(t)
})
}
}
func TestGrpcServer_sendRunReqInChunksError(t *testing.T) {
incoming := make(chan *cvms.ClientStreamMessage)
mockSvc := new(mockService)
server := NewServer(incoming, mockSvc).(*grpcServer)
mockStream := new(mockServerStream)
runReq := &cvms.ComputationRunReq{
Id: "test-id",
}
// Simulate an error when sending
mockStream.On("Send", mock.AnythingOfType("*cvms.ServerStreamMessage")).Return(errors.New("send error")).Once()
err := server.sendRunReqInChunks(mockStream, runReq)
assert.Error(t, err)
assert.Contains(t, err.Error(), "send error")
mockStream.AssertExpectations(t)
}
func TestGrpcServer_ProcessMissingPeerInfo(t *testing.T) {
incoming := make(chan *cvms.ClientStreamMessage)
mockSvc := new(mockService)
server := NewServer(incoming, mockSvc).(*grpcServer)
mockStream := new(mockServerStream)
ctx := context.Background()
// Return a context without peer info
mockStream.On("Context").Return(ctx)
err := server.Process(mockStream)
assert.Error(t, err)
assert.Contains(t, err.Error(), "failed to get peer info")
mockStream.AssertExpectations(t)
}
File diff suppressed because it is too large Load Diff
+103
View File
@@ -0,0 +1,103 @@
// Copyright (c) Ultraviolet
// SPDX-License-Identifier: Apache-2.0
syntax = "proto3";
import "google/protobuf/timestamp.proto";
package cvms;
option go_package = "./cvms";
service Service {
rpc Process(stream ClientStreamMessage) returns (stream ServerStreamMessage) {}
}
message StopComputation {
string computation_id = 1;
}
message StopComputationResponse {
string computation_id = 1;
string message = 2;
}
message RunResponse{
string computation_id = 1;
string error = 2;
}
message AgentEvent {
string event_type = 1;
google.protobuf.Timestamp timestamp = 2;
string computation_id = 3;
bytes details = 4;
string originator = 5;
string status = 6;
}
message AgentLog {
string message = 1;
string computation_id = 2;
string level = 3;
google.protobuf.Timestamp timestamp = 4;
}
message ClientStreamMessage {
oneof message {
AgentLog agent_log = 1;
AgentEvent agent_event = 2;
RunResponse run_res = 3;
StopComputationResponse stopComputationRes = 4;
}
}
message ServerStreamMessage {
oneof message {
RunReqChunks runReqChunks = 1;
ComputationRunReq runReq = 2;
StopComputation stopComputation = 3;
}
}
message RunReqChunks {
bytes data = 1;
string id = 2;
bool is_last = 3;
}
message ComputationRunReq {
string id = 1;
string name = 2;
string description = 3;
repeated Dataset datasets = 4;
Algorithm algorithm = 5;
repeated ResultConsumer result_consumers = 6;
AgentConfig agent_config = 7;
}
message ResultConsumer {
bytes userKey = 1;
}
message Dataset {
bytes hash = 1; // should be sha3.Sum256, 32 byte length.
bytes userKey = 2;
string filename = 3;
}
message Algorithm {
bytes hash = 1; // should be sha3.Sum256, 32 byte length.
bytes userKey = 2;
}
message AgentConfig {
string port = 1;
string host = 2;
string cert_file = 3;
string key_file = 4;
string client_ca_file = 5;
string server_ca_file = 6;
string log_level = 7;
bool attested_tls = 8;
}
+118
View File
@@ -0,0 +1,118 @@
// Copyright (c) Ultraviolet
// SPDX-License-Identifier: Apache-2.0
// Code generated by protoc-gen-go-grpc. DO NOT EDIT.
// versions:
// - protoc-gen-go-grpc v1.5.1
// - protoc v5.29.0
// source: agent/cvms/cvms.proto
package cvms
import (
context "context"
grpc "google.golang.org/grpc"
codes "google.golang.org/grpc/codes"
status "google.golang.org/grpc/status"
)
// This is a compile-time assertion to ensure that this generated file
// is compatible with the grpc package it is being compiled against.
// Requires gRPC-Go v1.64.0 or later.
const _ = grpc.SupportPackageIsVersion9
const (
Service_Process_FullMethodName = "/cvms.Service/Process"
)
// ServiceClient is the client API for Service service.
//
// For semantics around ctx use and closing/ending streaming RPCs, please refer to https://pkg.go.dev/google.golang.org/grpc/?tab=doc#ClientConn.NewStream.
type ServiceClient interface {
Process(ctx context.Context, opts ...grpc.CallOption) (grpc.BidiStreamingClient[ClientStreamMessage, ServerStreamMessage], error)
}
type serviceClient struct {
cc grpc.ClientConnInterface
}
func NewServiceClient(cc grpc.ClientConnInterface) ServiceClient {
return &serviceClient{cc}
}
func (c *serviceClient) Process(ctx context.Context, opts ...grpc.CallOption) (grpc.BidiStreamingClient[ClientStreamMessage, ServerStreamMessage], error) {
cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
stream, err := c.cc.NewStream(ctx, &Service_ServiceDesc.Streams[0], Service_Process_FullMethodName, cOpts...)
if err != nil {
return nil, err
}
x := &grpc.GenericClientStream[ClientStreamMessage, ServerStreamMessage]{ClientStream: stream}
return x, nil
}
// This type alias is provided for backwards compatibility with existing code that references the prior non-generic stream type by name.
type Service_ProcessClient = grpc.BidiStreamingClient[ClientStreamMessage, ServerStreamMessage]
// ServiceServer is the server API for Service service.
// All implementations must embed UnimplementedServiceServer
// for forward compatibility.
type ServiceServer interface {
Process(grpc.BidiStreamingServer[ClientStreamMessage, ServerStreamMessage]) error
mustEmbedUnimplementedServiceServer()
}
// UnimplementedServiceServer must be embedded to have
// forward compatible implementations.
//
// NOTE: this should be embedded by value instead of pointer to avoid a nil
// pointer dereference when methods are called.
type UnimplementedServiceServer struct{}
func (UnimplementedServiceServer) Process(grpc.BidiStreamingServer[ClientStreamMessage, ServerStreamMessage]) error {
return status.Errorf(codes.Unimplemented, "method Process not implemented")
}
func (UnimplementedServiceServer) mustEmbedUnimplementedServiceServer() {}
func (UnimplementedServiceServer) testEmbeddedByValue() {}
// UnsafeServiceServer may be embedded to opt out of forward compatibility for this service.
// Use of this interface is not recommended, as added methods to ServiceServer will
// result in compilation errors.
type UnsafeServiceServer interface {
mustEmbedUnimplementedServiceServer()
}
func RegisterServiceServer(s grpc.ServiceRegistrar, srv ServiceServer) {
// If the following call pancis, it indicates UnimplementedServiceServer was
// embedded by pointer and is nil. This will cause panics if an
// unimplemented method is ever invoked, so we test this at initialization
// time to prevent it from happening at runtime later due to I/O.
if t, ok := srv.(interface{ testEmbeddedByValue() }); ok {
t.testEmbeddedByValue()
}
s.RegisterService(&Service_ServiceDesc, srv)
}
func _Service_Process_Handler(srv interface{}, stream grpc.ServerStream) error {
return srv.(ServiceServer).Process(&grpc.GenericServerStream[ClientStreamMessage, ServerStreamMessage]{ServerStream: stream})
}
// This type alias is provided for backwards compatibility with existing code that references the prior non-generic stream type by name.
type Service_ProcessServer = grpc.BidiStreamingServer[ClientStreamMessage, ServerStreamMessage]
// Service_ServiceDesc is the grpc.ServiceDesc for Service service.
// It's only intended for direct use with grpc.RegisterService,
// and not to be introspected or modified (even as a copy)
var Service_ServiceDesc = grpc.ServiceDesc{
ServiceName: "cvms.Service",
HandlerType: (*ServiceServer)(nil),
Methods: []grpc.MethodDesc{},
Streams: []grpc.StreamDesc{
{
StreamName: "Process",
Handler: _Service_Process_Handler,
ServerStreams: true,
ClientStreams: true,
},
},
Metadata: "agent/cvms/cvms.proto",
}
+89
View File
@@ -0,0 +1,89 @@
// Copyright (c) Ultraviolet
// SPDX-License-Identifier: Apache-2.0
package server
import (
context "context"
"fmt"
"log/slog"
"github.com/ultravioletrs/cocos/agent"
agentgrpc "github.com/ultravioletrs/cocos/agent/api/grpc"
"github.com/ultravioletrs/cocos/agent/auth"
"github.com/ultravioletrs/cocos/internal/server"
grpcserver "github.com/ultravioletrs/cocos/internal/server/grpc"
"github.com/ultravioletrs/cocos/pkg/attestation/quoteprovider"
"google.golang.org/grpc"
"google.golang.org/grpc/reflection"
)
const (
svcName = "agent"
defSvcGRPCPort = "7002"
)
type AgentServer interface {
Start(ctx context.Context, cfg agent.AgentConfig, cmp agent.Computation) error
Stop() error
}
type agentServer struct {
gs server.Server
logger *slog.Logger
svc agent.Service
}
func NewServer(logger *slog.Logger, svc agent.Service) AgentServer {
return &agentServer{
logger: logger,
svc: svc,
}
}
func (as *agentServer) Start(ctx context.Context, cfg agent.AgentConfig, cmp agent.Computation) error {
if cfg.Port == "" {
cfg.Port = defSvcGRPCPort
}
agentGrpcServerConfig := server.AgentConfig{
ServerConfig: server.ServerConfig{
BaseConfig: server.BaseConfig{
Host: cfg.Host,
Port: cfg.Port,
CertFile: cfg.CertFile,
KeyFile: cfg.KeyFile,
ServerCAFile: cfg.ServerCAFile,
ClientCAFile: cfg.ClientCAFile,
},
},
AttestedTLS: cfg.AttestedTls,
}
registerAgentServiceServer := func(srv *grpc.Server) {
reflection.Register(srv)
agent.RegisterAgentServiceServer(srv, agentgrpc.NewServer(as.svc))
}
authSvc, err := auth.New(cmp)
if err != nil {
as.logger.WithGroup(cmp.ID).Error(fmt.Sprintf("failed to create auth service %s", err.Error()))
return err
}
qp, err := quoteprovider.GetQuoteProvider()
if err != nil {
as.logger.Error(fmt.Sprintf("failed to create quote provider %s", err.Error()))
return err
}
ctx, cancel := context.WithCancel(ctx)
as.gs = grpcserver.New(ctx, cancel, svcName, agentGrpcServerConfig, registerAgentServiceServer, as.logger, qp, authSvc)
return as.gs.Start()
}
func (as *agentServer) Stop() error {
return as.gs.Stop()
}
+134
View File
@@ -0,0 +1,134 @@
// Copyright (c) Ultraviolet
// SPDX-License-Identifier: Apache-2.0
// Code generated by mockery v2.43.2. DO NOT EDIT.
package mocks
import (
context "context"
agent "github.com/ultravioletrs/cocos/agent"
mock "github.com/stretchr/testify/mock"
)
// AgentServerProvider is an autogenerated mock type for the AgentServerProvider type
type AgentServerProvider struct {
mock.Mock
}
type AgentServerProvider_Expecter struct {
mock *mock.Mock
}
func (_m *AgentServerProvider) EXPECT() *AgentServerProvider_Expecter {
return &AgentServerProvider_Expecter{mock: &_m.Mock}
}
// Start provides a mock function with given fields: ctx, cfg, cmp
func (_m *AgentServerProvider) Start(ctx context.Context, cfg agent.AgentConfig, cmp agent.Computation) error {
ret := _m.Called(ctx, cfg, cmp)
if len(ret) == 0 {
panic("no return value specified for Start")
}
var r0 error
if rf, ok := ret.Get(0).(func(context.Context, agent.AgentConfig, agent.Computation) error); ok {
r0 = rf(ctx, cfg, cmp)
} else {
r0 = ret.Error(0)
}
return r0
}
// AgentServerProvider_Start_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Start'
type AgentServerProvider_Start_Call struct {
*mock.Call
}
// Start is a helper method to define mock.On call
// - ctx context.Context
// - cfg agent.AgentConfig
// - cmp agent.Computation
func (_e *AgentServerProvider_Expecter) Start(ctx interface{}, cfg interface{}, cmp interface{}) *AgentServerProvider_Start_Call {
return &AgentServerProvider_Start_Call{Call: _e.mock.On("Start", ctx, cfg, cmp)}
}
func (_c *AgentServerProvider_Start_Call) Run(run func(ctx context.Context, cfg agent.AgentConfig, cmp agent.Computation)) *AgentServerProvider_Start_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(context.Context), args[1].(agent.AgentConfig), args[2].(agent.Computation))
})
return _c
}
func (_c *AgentServerProvider_Start_Call) Return(_a0 error) *AgentServerProvider_Start_Call {
_c.Call.Return(_a0)
return _c
}
func (_c *AgentServerProvider_Start_Call) RunAndReturn(run func(context.Context, agent.AgentConfig, agent.Computation) error) *AgentServerProvider_Start_Call {
_c.Call.Return(run)
return _c
}
// Stop provides a mock function with given fields:
func (_m *AgentServerProvider) Stop() error {
ret := _m.Called()
if len(ret) == 0 {
panic("no return value specified for Stop")
}
var r0 error
if rf, ok := ret.Get(0).(func() error); ok {
r0 = rf()
} else {
r0 = ret.Error(0)
}
return r0
}
// AgentServerProvider_Stop_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Stop'
type AgentServerProvider_Stop_Call struct {
*mock.Call
}
// Stop is a helper method to define mock.On call
func (_e *AgentServerProvider_Expecter) Stop() *AgentServerProvider_Stop_Call {
return &AgentServerProvider_Stop_Call{Call: _e.mock.On("Stop")}
}
func (_c *AgentServerProvider_Stop_Call) Run(run func()) *AgentServerProvider_Stop_Call {
_c.Call.Run(func(args mock.Arguments) {
run()
})
return _c
}
func (_c *AgentServerProvider_Stop_Call) Return(_a0 error) *AgentServerProvider_Stop_Call {
_c.Call.Return(_a0)
return _c
}
func (_c *AgentServerProvider_Stop_Call) RunAndReturn(run func() error) *AgentServerProvider_Stop_Call {
_c.Call.Return(run)
return _c
}
// NewAgentServerProvider creates a new instance of AgentServerProvider. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations.
// The first argument is typically a *testing.T value.
func NewAgentServerProvider(t interface {
mock.TestingT
Cleanup(func())
}) *AgentServerProvider {
mock := &AgentServerProvider{}
mock.Mock.Test(t)
t.Cleanup(func() { mock.AssertExpectations(t) })
return mock
}
+19 -24
View File
@@ -4,43 +4,38 @@ package events
import (
"encoding/json"
"io"
"google.golang.org/protobuf/proto"
"github.com/ultravioletrs/cocos/agent/cvms"
"google.golang.org/protobuf/types/known/timestamppb"
)
type service struct {
service string
computationID string
conn io.Writer
service string
queue chan *cvms.ClientStreamMessage
}
type Service interface {
SendEvent(event, status string, details json.RawMessage) error
SendEvent(cmpID, event, status string, details json.RawMessage)
}
func New(svc, computationID string, conn io.Writer) (Service, error) {
func New(svc string, queue chan *cvms.ClientStreamMessage) (Service, error) {
return &service{
service: svc,
computationID: computationID,
conn: conn,
service: svc,
queue: queue,
}, nil
}
func (s *service) SendEvent(event, status string, details json.RawMessage) error {
body := EventsLogs{Message: &EventsLogs_AgentEvent{AgentEvent: &AgentEvent{
EventType: event,
Timestamp: timestamppb.Now(),
ComputationId: s.computationID,
Originator: s.service,
Status: status,
Details: details,
}}}
protoBody, err := proto.Marshal(&body)
if err != nil {
return err
func (s *service) SendEvent(cmpID, event, status string, details json.RawMessage) {
s.queue <- &cvms.ClientStreamMessage{
Message: &cvms.ClientStreamMessage_AgentEvent{
AgentEvent: &cvms.AgentEvent{
EventType: event,
Timestamp: timestamppb.Now(),
ComputationId: cmpID,
Originator: s.service,
Status: status,
Details: details,
},
},
}
_, err = s.conn.Write(protoBody)
return err
}
+36 -79
View File
@@ -3,8 +3,8 @@
// Code generated by protoc-gen-go. DO NOT EDIT.
// versions:
// protoc-gen-go v1.34.2
// protoc v5.28.1
// protoc-gen-go v1.36.0
// protoc v5.29.0
// source: agent/events/events.proto
package events
@@ -25,25 +25,22 @@ const (
)
type AgentEvent struct {
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
state protoimpl.MessageState `protogen:"open.v1"`
EventType string `protobuf:"bytes,1,opt,name=event_type,json=eventType,proto3" json:"event_type,omitempty"`
Timestamp *timestamppb.Timestamp `protobuf:"bytes,2,opt,name=timestamp,proto3" json:"timestamp,omitempty"`
ComputationId string `protobuf:"bytes,3,opt,name=computation_id,json=computationId,proto3" json:"computation_id,omitempty"`
Details []byte `protobuf:"bytes,4,opt,name=details,proto3" json:"details,omitempty"`
Originator string `protobuf:"bytes,5,opt,name=originator,proto3" json:"originator,omitempty"`
Status string `protobuf:"bytes,6,opt,name=status,proto3" json:"status,omitempty"`
unknownFields protoimpl.UnknownFields
sizeCache protoimpl.SizeCache
}
func (x *AgentEvent) Reset() {
*x = AgentEvent{}
if protoimpl.UnsafeEnabled {
mi := &file_agent_events_events_proto_msgTypes[0]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
mi := &file_agent_events_events_proto_msgTypes[0]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
func (x *AgentEvent) String() string {
@@ -54,7 +51,7 @@ func (*AgentEvent) ProtoMessage() {}
func (x *AgentEvent) ProtoReflect() protoreflect.Message {
mi := &file_agent_events_events_proto_msgTypes[0]
if protoimpl.UnsafeEnabled && x != nil {
if x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
@@ -112,23 +109,20 @@ func (x *AgentEvent) GetStatus() string {
}
type AgentLog struct {
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
state protoimpl.MessageState `protogen:"open.v1"`
Message string `protobuf:"bytes,1,opt,name=message,proto3" json:"message,omitempty"`
ComputationId string `protobuf:"bytes,2,opt,name=computation_id,json=computationId,proto3" json:"computation_id,omitempty"`
Level string `protobuf:"bytes,3,opt,name=level,proto3" json:"level,omitempty"`
Timestamp *timestamppb.Timestamp `protobuf:"bytes,4,opt,name=timestamp,proto3" json:"timestamp,omitempty"`
unknownFields protoimpl.UnknownFields
sizeCache protoimpl.SizeCache
}
func (x *AgentLog) Reset() {
*x = AgentLog{}
if protoimpl.UnsafeEnabled {
mi := &file_agent_events_events_proto_msgTypes[1]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
mi := &file_agent_events_events_proto_msgTypes[1]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
func (x *AgentLog) String() string {
@@ -139,7 +133,7 @@ func (*AgentLog) ProtoMessage() {}
func (x *AgentLog) ProtoReflect() protoreflect.Message {
mi := &file_agent_events_events_proto_msgTypes[1]
if protoimpl.UnsafeEnabled && x != nil {
if x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
@@ -183,24 +177,21 @@ func (x *AgentLog) GetTimestamp() *timestamppb.Timestamp {
}
type EventsLogs struct {
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
// Types that are assignable to Message:
state protoimpl.MessageState `protogen:"open.v1"`
// Types that are valid to be assigned to Message:
//
// *EventsLogs_AgentLog
// *EventsLogs_AgentEvent
Message isEventsLogs_Message `protobuf_oneof:"message"`
Message isEventsLogs_Message `protobuf_oneof:"message"`
unknownFields protoimpl.UnknownFields
sizeCache protoimpl.SizeCache
}
func (x *EventsLogs) Reset() {
*x = EventsLogs{}
if protoimpl.UnsafeEnabled {
mi := &file_agent_events_events_proto_msgTypes[2]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
mi := &file_agent_events_events_proto_msgTypes[2]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
func (x *EventsLogs) String() string {
@@ -211,7 +202,7 @@ func (*EventsLogs) ProtoMessage() {}
func (x *EventsLogs) ProtoReflect() protoreflect.Message {
mi := &file_agent_events_events_proto_msgTypes[2]
if protoimpl.UnsafeEnabled && x != nil {
if x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
@@ -226,23 +217,27 @@ func (*EventsLogs) Descriptor() ([]byte, []int) {
return file_agent_events_events_proto_rawDescGZIP(), []int{2}
}
func (m *EventsLogs) GetMessage() isEventsLogs_Message {
if m != nil {
return m.Message
func (x *EventsLogs) GetMessage() isEventsLogs_Message {
if x != nil {
return x.Message
}
return nil
}
func (x *EventsLogs) GetAgentLog() *AgentLog {
if x, ok := x.GetMessage().(*EventsLogs_AgentLog); ok {
return x.AgentLog
if x != nil {
if x, ok := x.Message.(*EventsLogs_AgentLog); ok {
return x.AgentLog
}
}
return nil
}
func (x *EventsLogs) GetAgentEvent() *AgentEvent {
if x, ok := x.GetMessage().(*EventsLogs_AgentEvent); ok {
return x.AgentEvent
if x != nil {
if x, ok := x.Message.(*EventsLogs_AgentEvent); ok {
return x.AgentEvent
}
}
return nil
}
@@ -342,44 +337,6 @@ func file_agent_events_events_proto_init() {
if File_agent_events_events_proto != nil {
return
}
if !protoimpl.UnsafeEnabled {
file_agent_events_events_proto_msgTypes[0].Exporter = func(v any, i int) any {
switch v := v.(*AgentEvent); i {
case 0:
return &v.state
case 1:
return &v.sizeCache
case 2:
return &v.unknownFields
default:
return nil
}
}
file_agent_events_events_proto_msgTypes[1].Exporter = func(v any, i int) any {
switch v := v.(*AgentLog); i {
case 0:
return &v.state
case 1:
return &v.sizeCache
case 2:
return &v.unknownFields
default:
return nil
}
}
file_agent_events_events_proto_msgTypes[2].Exporter = func(v any, i int) any {
switch v := v.(*EventsLogs); i {
case 0:
return &v.state
case 1:
return &v.sizeCache
case 2:
return &v.unknownFields
default:
return nil
}
}
}
file_agent_events_events_proto_msgTypes[2].OneofWrappers = []any{
(*EventsLogs_AgentLog)(nil),
(*EventsLogs_AgentEvent)(nil),
+17 -43
View File
@@ -3,62 +3,36 @@
package events
import (
"bytes"
"encoding/json"
"errors"
"testing"
"time"
"github.com/stretchr/testify/assert"
"google.golang.org/protobuf/proto"
"github.com/ultravioletrs/cocos/agent/cvms"
)
type mockConn struct {
writeErr error
buf bytes.Buffer
}
func (m *mockConn) Write(p []byte) (n int, err error) {
if m.writeErr != nil {
return 0, m.writeErr
}
return m.buf.Write(p)
}
func TestSendEventSuccess(t *testing.T) {
mockConnection := &mockConn{}
svc, err := New("test_service", "12345", mockConnection)
queue := make(chan *cvms.ClientStreamMessage, 1)
svc, err := New("test_service", queue)
assert.NoError(t, err)
details := json.RawMessage(`{"key": "value"}`)
err = svc.SendEvent("test_event", "success", details)
assert.NoError(t, err)
go func() {
msg := <-queue
assert.NotNil(t, msg)
assert.NotNil(t, msg.GetAgentEvent())
assert.Equal(t, "test_event", msg.GetAgentEvent().EventType)
assert.Equal(t, "testid", msg.GetAgentEvent().ComputationId)
assert.Equal(t, "test_service", msg.GetAgentEvent().Originator)
assert.Equal(t, "success", msg.GetAgentEvent().Status)
var writtenMessage EventsLogs
err = proto.Unmarshal(mockConnection.buf.Bytes(), &writtenMessage)
assert.NoError(t, err)
now := time.Now()
eventTimestamp := msg.GetAgentEvent().GetTimestamp().AsTime()
assert.WithinDuration(t, now, eventTimestamp, 1*time.Second)
}()
assert.Equal(t, "test_event", writtenMessage.GetAgentEvent().EventType)
assert.Equal(t, "12345", writtenMessage.GetAgentEvent().ComputationId)
assert.Equal(t, "test_service", writtenMessage.GetAgentEvent().Originator)
assert.Equal(t, "success", writtenMessage.GetAgentEvent().Status)
svc.SendEvent("testid", "test_event", "success", details)
now := time.Now()
eventTimestamp := writtenMessage.GetAgentEvent().GetTimestamp().AsTime()
assert.WithinDuration(t, now, eventTimestamp, 1*time.Second)
}
func TestSendEventFailure(t *testing.T) {
mockConnection := &mockConn{writeErr: errors.New("write error")}
svc, err := New("test_service", "12345", mockConnection)
assert.NoError(t, err)
details := json.RawMessage(`{"key": "value"}`)
err = svc.SendEvent("test_event", "failure", details)
assert.Error(t, err)
assert.Equal(t, "write error", err.Error())
time.Sleep(1 * time.Second)
}
+11 -23
View File
@@ -24,22 +24,9 @@ func (_m *Service) EXPECT() *Service_Expecter {
return &Service_Expecter{mock: &_m.Mock}
}
// SendEvent provides a mock function with given fields: event, status, details
func (_m *Service) SendEvent(event string, status string, details json.RawMessage) error {
ret := _m.Called(event, status, details)
if len(ret) == 0 {
panic("no return value specified for SendEvent")
}
var r0 error
if rf, ok := ret.Get(0).(func(string, string, json.RawMessage) error); ok {
r0 = rf(event, status, details)
} else {
r0 = ret.Error(0)
}
return r0
// SendEvent provides a mock function with given fields: cmpID, event, status, details
func (_m *Service) SendEvent(cmpID string, event string, status string, details json.RawMessage) {
_m.Called(cmpID, event, status, details)
}
// Service_SendEvent_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'SendEvent'
@@ -48,26 +35,27 @@ type Service_SendEvent_Call struct {
}
// SendEvent is a helper method to define mock.On call
// - cmpID string
// - event string
// - status string
// - details json.RawMessage
func (_e *Service_Expecter) SendEvent(event interface{}, status interface{}, details interface{}) *Service_SendEvent_Call {
return &Service_SendEvent_Call{Call: _e.mock.On("SendEvent", event, status, details)}
func (_e *Service_Expecter) SendEvent(cmpID interface{}, event interface{}, status interface{}, details interface{}) *Service_SendEvent_Call {
return &Service_SendEvent_Call{Call: _e.mock.On("SendEvent", cmpID, event, status, details)}
}
func (_c *Service_SendEvent_Call) Run(run func(event string, status string, details json.RawMessage)) *Service_SendEvent_Call {
func (_c *Service_SendEvent_Call) Run(run func(cmpID string, event string, status string, details json.RawMessage)) *Service_SendEvent_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(string), args[1].(string), args[2].(json.RawMessage))
run(args[0].(string), args[1].(string), args[2].(string), args[3].(json.RawMessage))
})
return _c
}
func (_c *Service_SendEvent_Call) Return(_a0 error) *Service_SendEvent_Call {
_c.Call.Return(_a0)
func (_c *Service_SendEvent_Call) Return() *Service_SendEvent_Call {
_c.Call.Return()
return _c
}
func (_c *Service_SendEvent_Call) RunAndReturn(run func(string, string, json.RawMessage) error) *Service_SendEvent_Call {
func (_c *Service_SendEvent_Call) RunAndReturn(run func(string, string, string, json.RawMessage)) *Service_SendEvent_Call {
_c.Call.Return(run)
return _c
}
+93
View File
@@ -179,6 +179,53 @@ func (_c *Service_Data_Call) RunAndReturn(run func(context.Context, agent.Datase
return _c
}
// InitComputation provides a mock function with given fields: ctx, cmp
func (_m *Service) InitComputation(ctx context.Context, cmp agent.Computation) error {
ret := _m.Called(ctx, cmp)
if len(ret) == 0 {
panic("no return value specified for InitComputation")
}
var r0 error
if rf, ok := ret.Get(0).(func(context.Context, agent.Computation) error); ok {
r0 = rf(ctx, cmp)
} else {
r0 = ret.Error(0)
}
return r0
}
// Service_InitComputation_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'InitComputation'
type Service_InitComputation_Call struct {
*mock.Call
}
// InitComputation is a helper method to define mock.On call
// - ctx context.Context
// - cmp agent.Computation
func (_e *Service_Expecter) InitComputation(ctx interface{}, cmp interface{}) *Service_InitComputation_Call {
return &Service_InitComputation_Call{Call: _e.mock.On("InitComputation", ctx, cmp)}
}
func (_c *Service_InitComputation_Call) Run(run func(ctx context.Context, cmp agent.Computation)) *Service_InitComputation_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(context.Context), args[1].(agent.Computation))
})
return _c
}
func (_c *Service_InitComputation_Call) Return(_a0 error) *Service_InitComputation_Call {
_c.Call.Return(_a0)
return _c
}
func (_c *Service_InitComputation_Call) RunAndReturn(run func(context.Context, agent.Computation) error) *Service_InitComputation_Call {
_c.Call.Return(run)
return _c
}
// Result provides a mock function with given fields: ctx
func (_m *Service) Result(ctx context.Context) ([]byte, error) {
ret := _m.Called(ctx)
@@ -237,6 +284,52 @@ func (_c *Service_Result_Call) RunAndReturn(run func(context.Context) ([]byte, e
return _c
}
// StopComputation provides a mock function with given fields: ctx
func (_m *Service) StopComputation(ctx context.Context) error {
ret := _m.Called(ctx)
if len(ret) == 0 {
panic("no return value specified for StopComputation")
}
var r0 error
if rf, ok := ret.Get(0).(func(context.Context) error); ok {
r0 = rf(ctx)
} else {
r0 = ret.Error(0)
}
return r0
}
// Service_StopComputation_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'StopComputation'
type Service_StopComputation_Call struct {
*mock.Call
}
// StopComputation is a helper method to define mock.On call
// - ctx context.Context
func (_e *Service_Expecter) StopComputation(ctx interface{}) *Service_StopComputation_Call {
return &Service_StopComputation_Call{Call: _e.mock.On("StopComputation", ctx)}
}
func (_c *Service_StopComputation_Call) Run(run func(ctx context.Context)) *Service_StopComputation_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(context.Context))
})
return _c
}
func (_c *Service_StopComputation_Call) Return(_a0 error) *Service_StopComputation_Call {
_c.Call.Return(_a0)
return _c
}
func (_c *Service_StopComputation_Call) RunAndReturn(run func(context.Context) error) *Service_StopComputation_Call {
_c.Call.Return(run)
return _c
}
// NewService creates a new instance of Service. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations.
// The first argument is typically a *testing.T value.
func NewService(t interface {
+68 -19
View File
@@ -104,6 +104,8 @@ var (
// Service specifies an API that must be fullfiled by the domain service
// implementation, and all of its decorators (e.g. logging & metrics).
type Service interface {
InitComputation(ctx context.Context, cmp Computation) error
StopComputation(ctx context.Context) error
Algo(ctx context.Context, algorithm Algorithm) error
Data(ctx context.Context, dataset Dataset) error
Result(ctx context.Context) ([]byte, error)
@@ -121,19 +123,21 @@ type agentService struct {
quoteProvider client.QuoteProvider // Provider for generating attestation quotes.
logger *slog.Logger // Logger for the agent service.
resultsConsumed bool // Indicates if the results have been consumed.
cancel context.CancelFunc // Cancels the computation context.
}
var _ Service = (*agentService)(nil)
// New instantiates the agent service implementation.
func New(ctx context.Context, logger *slog.Logger, eventSvc events.Service, cmp Computation, quoteProvider client.QuoteProvider) Service {
func New(ctx context.Context, logger *slog.Logger, eventSvc events.Service, quoteProvider client.QuoteProvider) Service {
sm := statemachine.NewStateMachine(Idle)
ctx, cancel := context.WithCancel(ctx)
svc := &agentService{
sm: sm,
eventSvc: eventSvc,
quoteProvider: quoteProvider,
logger: logger,
computation: cmp,
cancel: cancel,
}
transitions := []statemachine.Transition{
@@ -141,13 +145,6 @@ func New(ctx context.Context, logger *slog.Logger, eventSvc events.Service, cmp
{From: ReceivingManifest, Event: ManifestReceived, To: ReceivingAlgorithm},
}
if len(cmp.Datasets) == 0 {
transitions = append(transitions, statemachine.Transition{From: ReceivingAlgorithm, Event: AlgorithmReceived, To: Running})
} else {
transitions = append(transitions, statemachine.Transition{From: ReceivingAlgorithm, Event: AlgorithmReceived, To: ReceivingData})
transitions = append(transitions, statemachine.Transition{From: ReceivingData, Event: DataReceived, To: Running})
}
transitions = append(transitions, []statemachine.Transition{
{From: Running, Event: RunComplete, To: ConsumingResults},
{From: Running, Event: RunFailed, To: Failed},
@@ -158,8 +155,6 @@ func New(ctx context.Context, logger *slog.Logger, eventSvc events.Service, cmp
sm.AddTransition(t)
}
sm.SetAction(Idle, svc.publishEvent(IdleState.String()))
sm.SetAction(ReceivingManifest, svc.publishEvent(InProgress.String()))
sm.SetAction(ReceivingAlgorithm, svc.publishEvent(InProgress.String()))
sm.SetAction(ReceivingData, svc.publishEvent(InProgress.String()))
sm.SetAction(Running, svc.runComputation)
@@ -173,11 +168,67 @@ func New(ctx context.Context, logger *slog.Logger, eventSvc events.Service, cmp
}
}()
sm.SendEvent(Start)
defer sm.SendEvent(ManifestReceived)
return svc
}
func (as *agentService) InitComputation(ctx context.Context, cmp Computation) error {
defer as.sm.SendEvent(ManifestReceived)
if as.sm.GetState() != ReceivingManifest {
return ErrStateNotReady
}
as.mu.Lock()
defer as.mu.Unlock()
as.computation = cmp
transitions := []statemachine.Transition{}
if len(cmp.Datasets) == 0 {
transitions = append(transitions, statemachine.Transition{From: ReceivingAlgorithm, Event: AlgorithmReceived, To: Running})
} else {
transitions = append(transitions, statemachine.Transition{From: ReceivingAlgorithm, Event: AlgorithmReceived, To: ReceivingData})
transitions = append(transitions, statemachine.Transition{From: ReceivingData, Event: DataReceived, To: Running})
}
for _, t := range transitions {
as.sm.AddTransition(t)
}
return nil
}
func (as *agentService) StopComputation(ctx context.Context) error {
as.mu.Lock()
defer as.mu.Unlock()
as.cancel()
if err := as.algorithm.Stop(); err != nil {
return fmt.Errorf("error stopping computation: %v", err)
}
sm := statemachine.NewStateMachine(Idle)
if err := os.RemoveAll(algorithm.DatasetsDir); err != nil {
return fmt.Errorf("error removing datasets directory: %v", err)
}
if err := os.RemoveAll(algorithm.ResultsDir); err != nil {
return fmt.Errorf("error removing results directory: %v", err)
}
as.sm = sm
as.computation = Computation{}
as.algorithm = nil
as.result = nil
as.runError = nil
as.resultsConsumed = false
return nil
}
func (as *agentService) Algo(ctx context.Context, algo Algorithm) error {
if as.sm.GetState() != ReceivingAlgorithm {
return ErrStateNotReady
@@ -225,7 +276,7 @@ func (as *agentService) Algo(ctx context.Context, algo Algorithm) error {
switch algoType {
case string(algorithm.AlgoTypeBin):
as.algorithm = binary.NewAlgorithm(as.logger, as.eventSvc, f.Name(), args)
as.algorithm = binary.NewAlgorithm(as.logger, as.eventSvc, f.Name(), args, as.computation.ID)
case string(algorithm.AlgoTypePython):
var requirementsFile string
if len(algo.Requirements) > 0 {
@@ -243,11 +294,11 @@ func (as *agentService) Algo(ctx context.Context, algo Algorithm) error {
requirementsFile = fr.Name()
}
runtime := python.PythonRunTimeFromContext(ctx)
as.algorithm = python.NewAlgorithm(as.logger, as.eventSvc, runtime, requirementsFile, f.Name(), args)
as.algorithm = python.NewAlgorithm(as.logger, as.eventSvc, runtime, requirementsFile, f.Name(), args, as.computation.ID)
case string(algorithm.AlgoTypeWasm):
as.algorithm = wasm.NewAlgorithm(as.logger, as.eventSvc, f.Name(), args)
as.algorithm = wasm.NewAlgorithm(as.logger, as.eventSvc, args, f.Name(), as.computation.ID)
case string(algorithm.AlgoTypeDocker):
as.algorithm = docker.NewAlgorithm(as.logger, as.eventSvc, f.Name())
as.algorithm = docker.NewAlgorithm(as.logger, as.eventSvc, f.Name(), as.computation.ID)
}
if err := os.Mkdir(algorithm.DatasetsDir, 0o755); err != nil {
@@ -400,8 +451,6 @@ func (as *agentService) runComputation(state statemachine.State) {
func (as *agentService) publishEvent(status string) statemachine.Action {
return func(state statemachine.State) {
if err := as.eventSvc.SendEvent(state.String(), status, json.RawMessage{}); err != nil {
as.logger.Warn(err.Error())
}
as.eventSvc.SendEvent(as.computation.ID, state.String(), status, json.RawMessage{})
}
}
+23 -29
View File
@@ -35,11 +35,6 @@ var (
const datasetFile = "iris.csv"
func TestAlgo(t *testing.T) {
events := new(mocks.Service)
evCall := events.On("SendEvent", mock.Anything, mock.Anything, mock.Anything).Return(nil)
defer evCall.Unset()
qp, err := quoteprovider.GetQuoteProvider()
require.NoError(t, err)
@@ -120,9 +115,15 @@ func TestAlgo(t *testing.T) {
metadata.Pairs(algorithm.AlgoTypeKey, tc.algoType, python.PyRuntimeKey, python.PyRuntime),
)
events := new(mocks.Service)
events.EXPECT().SendEvent(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return()
ctx, cancel := context.WithCancel(ctx)
defer cancel()
svc := New(ctx, mglog.NewMock(), events, testComputation(t), qp)
svc := New(ctx, mglog.NewMock(), events, qp)
err := svc.InitComputation(ctx, testComputation(t))
require.NoError(t, err)
time.Sleep(300 * time.Millisecond)
@@ -138,11 +139,6 @@ func TestAlgo(t *testing.T) {
}
func TestData(t *testing.T) {
events := new(mocks.Service)
evCall := events.On("SendEvent", mock.Anything, mock.Anything, mock.Anything).Return(nil)
defer evCall.Unset()
qp, err := quoteprovider.GetQuoteProvider()
require.NoError(t, err)
@@ -209,6 +205,9 @@ func TestData(t *testing.T) {
python.PyRuntimeKey, python.PyRuntime),
)
events := new(mocks.Service)
events.EXPECT().SendEvent(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return()
if tc.err != ErrUndeclaredDataset {
ctx = IndexToContext(ctx, 0)
}
@@ -216,13 +215,16 @@ func TestData(t *testing.T) {
ctx, cancel := context.WithCancel(ctx)
defer cancel()
comp := testComputation(t)
svc := New(ctx, mglog.NewMock(), events, qp)
err := svc.InitComputation(ctx, testComputation(t))
require.NoError(t, err)
svc := New(ctx, mglog.NewMock(), events, comp, qp)
time.Sleep(300 * time.Millisecond)
if tc.err != ErrStateNotReady {
_ = svc.Algo(ctx, alg)
err = svc.Algo(ctx, alg)
require.NoError(t, err)
time.Sleep(300 * time.Millisecond)
}
err = svc.Data(ctx, tc.data)
@@ -238,11 +240,6 @@ func TestData(t *testing.T) {
}
func TestResult(t *testing.T) {
events := new(mocks.Service)
evCall := events.On("SendEvent", mock.Anything, mock.Anything, mock.Anything).Return(nil)
defer evCall.Unset()
qp, err := quoteprovider.GetQuoteProvider()
require.NoError(t, err)
@@ -285,6 +282,9 @@ func TestResult(t *testing.T) {
}
for _, tc := range cases {
events := new(mocks.Service)
events.EXPECT().SendEvent(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return()
t.Run(tc.name, func(t *testing.T) {
ctx := metadata.NewIncomingContext(context.Background(),
metadata.Pairs(algorithm.AlgoTypeKey, "python", python.PyRuntimeKey, python.PyRuntime),
@@ -323,12 +323,8 @@ func TestResult(t *testing.T) {
}
func TestAttestation(t *testing.T) {
events := new(mocks.Service)
qp := new(mocks2.QuoteProvider)
evCall := events.On("SendEvent", mock.Anything, mock.Anything, mock.Anything).Return(nil)
defer evCall.Unset()
cases := []struct {
name string
reportData [ReportDataSize]byte
@@ -350,6 +346,9 @@ func TestAttestation(t *testing.T) {
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
events := new(mocks.Service)
events.EXPECT().SendEvent(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return()
ctx := metadata.NewIncomingContext(context.Background(),
metadata.Pairs(algorithm.AlgoTypeKey, "python", python.PyRuntimeKey, python.PyRuntime),
)
@@ -362,7 +361,7 @@ func TestAttestation(t *testing.T) {
}
defer getQuote.Unset()
svc := New(ctx, mglog.NewMock(), events, testComputation(t), qp)
svc := New(ctx, mglog.NewMock(), events, qp)
time.Sleep(300 * time.Millisecond)
_, err := svc.Attestation(ctx, tc.reportData)
assert.True(t, errors.Contains(err, tc.err), "expected %v, got %v", tc.err, err)
@@ -397,10 +396,5 @@ func testComputation(t *testing.T) Computation {
Datasets: []Dataset{{Hash: dataHash, UserKey: []byte("key"), Dataset: data, Filename: datasetFile}},
Algorithm: Algorithm{Hash: algoHash, UserKey: []byte("key"), Algorithm: algo},
ResultConsumers: []ResultConsumer{{UserKey: []byte("key")}},
AgentConfig: AgentConfig{
Port: "7002",
LogLevel: "debug",
AttestedTls: false,
},
}
}
+1 -1
View File
@@ -131,7 +131,7 @@ func TestManifestChecksum(t *testing.T) {
"name": "Example Computation",
"description": "This is an example computation"
}`,
expectedSum: "868825367c32c4b6d621d5d95e2890f233d8554df2348ab743aac2663a936f08",
expectedSum: "a99683e4d22ba54cefa51aa49fb2e97a92b828c088395992ddff16a6236f3299",
},
{
name: "Invalid JSON",
+57 -182
View File
@@ -3,77 +3,68 @@
package main
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"log"
"log/slog"
"os"
"os/signal"
"syscall"
"time"
mglog "github.com/absmach/magistrala/logger"
"github.com/absmach/magistrala/pkg/prometheus"
"github.com/cenkalti/backoff/v4"
"github.com/google/go-sev-guest/abi"
"github.com/caarlos0/env/v11"
"github.com/google/go-sev-guest/client"
"github.com/mdlayher/vsock"
"github.com/ultravioletrs/cocos/agent"
"github.com/ultravioletrs/cocos/agent/api"
agentgrpc "github.com/ultravioletrs/cocos/agent/api/grpc"
"github.com/ultravioletrs/cocos/agent/auth"
"github.com/ultravioletrs/cocos/agent/cvms"
cvmapi "github.com/ultravioletrs/cocos/agent/cvms/api/grpc"
"github.com/ultravioletrs/cocos/agent/cvms/server"
"github.com/ultravioletrs/cocos/agent/events"
agentlogger "github.com/ultravioletrs/cocos/internal/logger"
"github.com/ultravioletrs/cocos/internal/server"
grpcserver "github.com/ultravioletrs/cocos/internal/server/grpc"
ackvsock "github.com/ultravioletrs/cocos/internal/vsock"
managerevents "github.com/ultravioletrs/cocos/manager/events"
"github.com/ultravioletrs/cocos/manager/qemu"
"github.com/ultravioletrs/cocos/pkg/attestation/quoteprovider"
"golang.org/x/crypto/sha3"
pkggrpc "github.com/ultravioletrs/cocos/pkg/clients/grpc"
cvmgrpc "github.com/ultravioletrs/cocos/pkg/clients/grpc/cvm"
"golang.org/x/sync/errgroup"
"google.golang.org/grpc"
"google.golang.org/grpc/reflection"
)
const (
svcName = "agent"
defSvcGRPCPort = "7002"
retryInterval = 5 * time.Second
svcName = "agent"
defSvcGRPCPort = "7002"
retryInterval = 5 * time.Second
envPrefixCVMGRPC = "AGENT_CVM_GRPC_"
)
type config struct {
LogLevel string `env:"AGENT_LOG_LEVEL" envDefault:"debug"`
}
func main() {
ctx, cancel := context.WithCancel(context.Background())
g, ctx := errgroup.WithContext(ctx)
cfg, err := readConfig()
if err != nil {
log.Fatalf("failed to read agent configuration from vsock %s", err.Error())
var cfg config
if err := env.Parse(&cfg); err != nil {
log.Fatalf("failed to load %s configuration : %s", svcName, err)
}
conn, err := dialVsock()
if err != nil {
log.Fatal(err)
}
defer conn.Close()
ackConn := ackvsock.NewAckWriter(conn)
var exitCode int
defer mglog.ExitWithError(&exitCode)
var level slog.Level
if err := level.UnmarshalText([]byte(cfg.AgentConfig.LogLevel)); err != nil {
if err := level.UnmarshalText([]byte(cfg.LogLevel)); err != nil {
log.Println(err)
exitCode = 1
return
}
handler := agentlogger.NewProtoHandler(ackConn, &slog.HandlerOptions{Level: level}, cfg.ID)
eventsLogsQueue := make(chan *cvms.ClientStreamMessage, 1000)
handler := agentlogger.NewProtoHandler(os.Stdout, &slog.HandlerOptions{Level: level}, eventsLogsQueue)
logger := slog.New(handler)
eventSvc, err := events.New(svcName, cfg.ID, ackConn)
eventSvc, err := events.New(svcName, eventsLogsQueue)
if err != nil {
logger.Error(fmt.Sprintf("failed to create events service %s", err.Error()))
exitCode = 1
@@ -87,64 +78,49 @@ func main() {
return
}
if err := verifyManifest(cfg, qp); err != nil {
cvmGrpcConfig := pkggrpc.CVMClientConfig{}
if err := env.ParseWithOptions(&cvmGrpcConfig, env.Options{Prefix: envPrefixCVMGRPC}); err != nil {
logger.Error(fmt.Sprintf("failed to load %s gRPC client configuration : %s", svcName, err))
exitCode = 1
return
}
cvmGRPCClient, cvmClient, err := cvmgrpc.NewCVMClient(cvmGrpcConfig)
if err != nil {
logger.Error(err.Error())
exitCode = 1
return
}
defer cvmGRPCClient.Close()
pc, err := cvmClient.Process(ctx)
if err != nil {
logger.Error(err.Error())
exitCode = 1
return
}
setDefaultValues(&cfg)
svc := newService(ctx, logger, eventSvc, qp)
svc := newService(ctx, logger, eventSvc, cfg, qp)
agentGrpcServerConfig := server.AgentConfig{
ServerConfig: server.ServerConfig{
BaseConfig: server.BaseConfig{
Host: cfg.AgentConfig.Host,
Port: cfg.AgentConfig.Port,
CertFile: cfg.AgentConfig.CertFile,
KeyFile: cfg.AgentConfig.KeyFile,
ServerCAFile: cfg.AgentConfig.ServerCAFile,
ClientCAFile: cfg.AgentConfig.ClientCAFile,
},
},
AttestedTLS: cfg.AgentConfig.AttestedTls,
}
registerAgentServiceServer := func(srv *grpc.Server) {
reflection.Register(srv)
agent.RegisterAgentServiceServer(srv, agentgrpc.NewServer(svc))
}
authSvc, err := auth.New(cfg)
if err != nil {
logger.Error(fmt.Sprintf("failed to create auth service %s", err.Error()))
exitCode = 1
return
}
gs := grpcserver.New(ctx, cancel, svcName, agentGrpcServerConfig, registerAgentServiceServer, logger, qp, authSvc)
mc := cvmapi.NewClient(pc, svc, eventsLogsQueue, logger, server.NewServer(logger, svc))
g.Go(func() error {
for {
if _, err := io.Copy(io.Discard, conn); err != nil {
log.Printf("vsock connection lost: %v, reconnecting...", err)
conn.Close()
conn, err = dialVsock()
if err != nil {
log.Fatal("failed to reconnect: ", err)
}
}
time.Sleep(retryInterval)
ch := make(chan os.Signal, 1)
signal.Notify(ch, syscall.SIGINT, syscall.SIGTERM)
defer signal.Stop(ch)
select {
case <-ch:
logger.Info("Received signal, shutting down...")
cancel()
return nil
case <-ctx.Done():
return ctx.Err()
}
})
g.Go(func() error {
return gs.Start()
})
g.Go(func() error {
return server.StopHandler(ctx, cancel, logger, svcName, gs)
return mc.Process(ctx, cancel)
})
if err := g.Wait(); err != nil {
@@ -152,8 +128,8 @@ func main() {
}
}
func newService(ctx context.Context, logger *slog.Logger, eventSvc events.Service, cmp agent.Computation, qp client.QuoteProvider) agent.Service {
svc := agent.New(ctx, logger, eventSvc, cmp, qp)
func newService(ctx context.Context, logger *slog.Logger, eventSvc events.Service, qp client.QuoteProvider) agent.Service {
svc := agent.New(ctx, logger, eventSvc, qp)
svc = api.LoggingMiddleware(svc, logger)
counter, latency := prometheus.MakeMetrics(svcName, "api")
@@ -161,104 +137,3 @@ func newService(ctx context.Context, logger *slog.Logger, eventSvc events.Servic
return svc
}
func readConfig() (agent.Computation, error) {
l, err := vsock.Listen(qemu.VsockConfigPort, nil)
if err != nil {
return agent.Computation{}, err
}
defer l.Close()
conn, err := l.Accept()
if err != nil {
return agent.Computation{}, err
}
defer conn.Close()
var buffer []byte
for {
chunk := make([]byte, 1024)
n, err := conn.Read(chunk)
if err != nil {
if err == io.EOF {
break
}
return agent.Computation{}, err
}
buffer = append(buffer, chunk[:n]...)
}
ac := agent.Computation{
AgentConfig: agent.AgentConfig{},
}
if err := json.Unmarshal(buffer, &ac); err != nil {
return agent.Computation{}, err
}
return ac, nil
}
func setDefaultValues(cfg *agent.Computation) {
if cfg.AgentConfig.LogLevel == "" {
cfg.AgentConfig.LogLevel = "info"
}
if cfg.AgentConfig.Port == "" {
cfg.AgentConfig.Port = defSvcGRPCPort
}
}
func isTEE() bool {
_, err := os.Stat("/dev/sev-guest")
return !os.IsNotExist(err)
}
func dialVsock() (*vsock.Conn, error) {
var conn *vsock.Conn
var err error
err = backoff.Retry(func() error {
conn, err = vsock.Dial(vsock.Host, managerevents.ManagerVsockPort, nil)
if err == nil {
log.Println("vsock connection established")
return nil
}
log.Printf("vsock connection failed, retrying in %s... Error: %v", retryInterval, err)
return err
}, backoff.NewExponentialBackOff())
if err != nil {
return nil, err
}
return conn, nil
}
func verifyManifest(cfg agent.Computation, qp client.QuoteProvider) error {
if !isTEE() {
return nil
}
ar, err := qp.GetRawQuote(sha3.Sum512([]byte(cfg.ID)))
if err != nil {
return err
}
arProto, err := abi.ReportCertsToProto(ar[:abi.ReportSize])
if err != nil {
return err
}
cfgBytes, err := json.Marshal(cfg)
if err != nil {
return err
}
mcHash := sha3.Sum256(cfgBytes)
if arProto.Report.HostData == nil {
return fmt.Errorf("manifest verification failed: HostData is nil")
}
if !bytes.Equal(arProto.Report.HostData, mcHash[:]) {
return fmt.Errorf("manifest verification failed")
}
return nil
}
-94
View File
@@ -1,94 +0,0 @@
// Copyright (c) Ultraviolet
// SPDX-License-Identifier: Apache-2.0
package main
import (
"context"
"log/slog"
"os"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"github.com/ultravioletrs/cocos/agent"
"github.com/ultravioletrs/cocos/agent/events/mocks"
qpmocks "github.com/ultravioletrs/cocos/pkg/attestation/quoteprovider/mocks"
)
func TestSetDefaultValues(t *testing.T) {
tests := []struct {
name string
input agent.Computation
expected agent.Computation
}{
{
name: "Empty config",
input: agent.Computation{
AgentConfig: agent.AgentConfig{},
},
expected: agent.Computation{
AgentConfig: agent.AgentConfig{
LogLevel: "info",
Port: "7002",
},
},
},
{
name: "Partial config",
input: agent.Computation{
AgentConfig: agent.AgentConfig{
LogLevel: "debug",
},
},
expected: agent.Computation{
AgentConfig: agent.AgentConfig{
LogLevel: "debug",
Port: "7002",
},
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
setDefaultValues(&tt.input)
assert.Equal(t, tt.expected, tt.input)
})
}
}
func TestNewService(t *testing.T) {
ctx := context.Background()
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
eventSvc := new(mocks.Service)
eventSvc.On("SendEvent", mock.Anything, mock.Anything, mock.Anything).Return(nil)
cmp := agent.Computation{
ID: "test-computation",
AgentConfig: agent.AgentConfig{
LogLevel: "info",
Port: "7002",
},
}
qp := new(qpmocks.QuoteProvider)
svc := newService(ctx, logger, eventSvc, cmp, qp)
assert.NotNil(t, svc)
}
func TestVerifyManifest(t *testing.T) {
cfg := agent.Computation{
ID: "test-computation",
AgentConfig: agent.AgentConfig{
LogLevel: "info",
Port: "7002",
},
}
mockQP := new(qpmocks.QuoteProvider)
mockQP.On("GetRawQuote", mock.Anything).Return([]byte{}, nil)
err := verifyManifest(cfg, mockQP)
assert.NoError(t, err)
}
+1 -1
View File
@@ -92,7 +92,7 @@ func main() {
args := qemuCfg.ConstructQemuArgs()
logger.Info(strings.Join(args, " "))
managerGRPCConfig := pkggrpc.ManagerClientConfig{}
managerGRPCConfig := pkggrpc.CVMClientConfig{}
if err := env.ParseWithOptions(&managerGRPCConfig, env.Options{Prefix: envPrefixGRPC}); err != nil {
logger.Error(fmt.Sprintf("failed to load %s gRPC client configuration : %s", svcName, err))
exitCode = 1
+24 -5
View File
@@ -7,8 +7,9 @@ import (
"io"
"log/slog"
"github.com/ultravioletrs/cocos/agent/cvms"
"github.com/ultravioletrs/cocos/agent/events"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/encoding/protojson"
"google.golang.org/protobuf/types/known/timestamppb"
)
@@ -18,16 +19,17 @@ type handler struct {
opts slog.HandlerOptions
w io.Writer
cmpID string
queue chan *cvms.ClientStreamMessage
}
func NewProtoHandler(conn io.Writer, opts *slog.HandlerOptions, cmpID string) slog.Handler {
func NewProtoHandler(conn io.Writer, opts *slog.HandlerOptions, queue chan *cvms.ClientStreamMessage) slog.Handler {
if opts == nil {
opts = &slog.HandlerOptions{}
}
h := &handler{
opts: *opts,
w: conn,
cmpID: cmpID,
queue: queue,
}
return h
@@ -69,7 +71,18 @@ func (h *handler) Handle(_ context.Context, r slog.Record) error {
},
}
b, err := proto.Marshal(&agentLog)
h.queue <- &cvms.ClientStreamMessage{
Message: &cvms.ClientStreamMessage_AgentLog{
AgentLog: &cvms.AgentLog{
Timestamp: timestamp,
Message: chunk,
Level: level,
ComputationId: h.cmpID,
},
},
}
b, err := protojson.Marshal(&agentLog)
if err != nil {
return err
}
@@ -78,6 +91,11 @@ func (h *handler) Handle(_ context.Context, r slog.Record) error {
if err != nil {
return err
}
_, err = h.w.Write([]byte("\n"))
if err != nil {
return err
}
}
return nil
@@ -88,7 +106,8 @@ func (h *handler) WithAttrs(attrs []slog.Attr) slog.Handler {
}
func (h *handler) WithGroup(name string) slog.Handler {
panic("unimplemented")
h.cmpID = name
return h
}
func (h *handler) Close() error {
+6 -5
View File
@@ -11,6 +11,7 @@ import (
"github.com/absmach/magistrala/pkg/errors"
"github.com/stretchr/testify/assert"
"github.com/ultravioletrs/cocos/agent/cvms"
)
type failedWriter struct{}
@@ -21,14 +22,14 @@ func (f *failedWriter) Write(p []byte) (n int, err error) {
// TestNewProtoHandler tests the initialization of the ProtoHandler.
func TestNewProtoHandler(t *testing.T) {
handler := NewProtoHandler(io.Discard, nil, "testCmpID")
handler := NewProtoHandler(io.Discard, nil, make(chan *cvms.ClientStreamMessage))
assert.NotNil(t, handler, "Handler should not be nil")
}
// TestHandleMessageSuccess tests the handling of a message when the write succeeds.
func TestHandleMessageSuccess(t *testing.T) {
handler := NewProtoHandler(io.Discard, nil, "testCmpID")
handler := NewProtoHandler(io.Discard, nil, make(chan *cvms.ClientStreamMessage, 1))
record := slog.Record{
Time: time.Now(),
Message: "Test message",
@@ -42,7 +43,7 @@ func TestHandleMessageSuccess(t *testing.T) {
// TestHandleMessageFailure tests the caching mechanism when the write fails.
func TestHandleMessageFailure(t *testing.T) {
protohandler := NewProtoHandler(&failedWriter{}, nil, "testCmpID")
protohandler := NewProtoHandler(&failedWriter{}, nil, make(chan *cvms.ClientStreamMessage, 1))
record := slog.Record{
Time: time.Now(),
Message: "Test message",
@@ -56,7 +57,7 @@ func TestHandleMessageFailure(t *testing.T) {
// TestEnabled tests that the handler enables logging based on level.
func TestEnabled(t *testing.T) {
handler := NewProtoHandler(io.Discard, nil, "testCmpID")
handler := NewProtoHandler(io.Discard, nil, make(chan *cvms.ClientStreamMessage, 1))
assert.True(t, handler.Enabled(context.Background(), slog.LevelInfo), "Logging should be enabled for LevelInfo")
assert.False(t, handler.Enabled(context.Background(), slog.LevelDebug), "Logging should be disabled for LevelDebug by default")
@@ -66,7 +67,7 @@ func TestEnabled(t *testing.T) {
func TestCloseStopsRetry(t *testing.T) {
mockWriter := io.Discard
handler := NewProtoHandler(mockWriter, nil, "testCmpID").(*handler)
handler := NewProtoHandler(mockWriter, nil, make(chan *cvms.ClientStreamMessage, 1)).(*handler)
time.Sleep(2 * time.Second)
err := handler.Close()
+213 -461
View File
File diff suppressed because it is too large Load Diff
+1 -1
View File
@@ -4,7 +4,7 @@
// Code generated by protoc-gen-go-grpc. DO NOT EDIT.
// versions:
// - protoc-gen-go-grpc v1.5.1
// - protoc v5.28.1
// - protoc v5.29.0
// source: manager/manager.proto
package manager
-10
View File
@@ -166,16 +166,6 @@ func (ms *managerService) Run(ctx context.Context, c *ComputationRunReq) (string
ID: c.Id,
Name: c.Name,
Description: c.Description,
AgentConfig: agent.AgentConfig{
Port: c.AgentConfig.Port,
Host: c.AgentConfig.Host,
KeyFile: c.AgentConfig.KeyFile,
CertFile: c.AgentConfig.CertFile,
ServerCAFile: c.AgentConfig.ServerCaFile,
ClientCAFile: c.AgentConfig.ClientCaFile,
LogLevel: c.AgentConfig.LogLevel,
AttestedTls: c.AgentConfig.AttestedTls,
},
}
if len(c.Algorithm.Hash) != hashLength {
ms.publishEvent(manager.VmProvision.String(), c.Id, agent.Failed.String(), json.RawMessage{})
+7
View File
@@ -93,3 +93,10 @@ packages:
dir: "{{.InterfaceDir}}/mocks"
filename: "sdk.go"
mockname: "{{.InterfaceName}}"
github.com/ultravioletrs/cocos/agent/cvms/server:
interfaces:
AgentServerProvider:
config:
dir: "{{.InterfaceDir}}/mocks"
filename: "server.go"
mockname: "{{.InterfaceName}}"
+2 -2
View File
@@ -68,7 +68,7 @@ type AgentClientConfig struct {
AttestedTLS bool `env:"ATTESTED_TLS" envDefault:"false"`
}
type ManagerClientConfig struct {
type CVMClientConfig struct {
BaseConfig
}
@@ -80,7 +80,7 @@ func (a AgentClientConfig) GetBaseConfig() BaseConfig {
return a.BaseConfig
}
func (a ManagerClientConfig) GetBaseConfig() BaseConfig {
func (a CVMClientConfig) GetBaseConfig() BaseConfig {
return a.BaseConfig
}
+18
View File
@@ -0,0 +1,18 @@
// Copyright (c) Ultraviolet
// SPDX-License-Identifier: Apache-2.0
package cvm
import (
"github.com/ultravioletrs/cocos/agent/cvms"
"github.com/ultravioletrs/cocos/pkg/clients/grpc"
)
// NewManagerClient creates new manager gRPC client instance.
func NewCVMClient(cfg grpc.CVMClientConfig) (grpc.Client, cvms.ServiceClient, error) {
client, err := grpc.NewClient(cfg)
if err != nil {
return nil, nil, err
}
return client, cvms.NewServiceClient(client.Connection()), nil
}
+1 -1
View File
@@ -8,7 +8,7 @@ import (
)
// NewManagerClient creates new manager gRPC client instance.
func NewManagerClient(cfg grpc.ManagerClientConfig) (grpc.Client, manager.ManagerServiceClient, error) {
func NewManagerClient(cfg grpc.CVMClientConfig) (grpc.Client, manager.ManagerServiceClient, error) {
client, err := grpc.NewClient(cfg)
if err != nil {
return nil, nil, err
+2 -2
View File
@@ -13,12 +13,12 @@ import (
func TestNewManagerClient(t *testing.T) {
tests := []struct {
name string
cfg grpc.ManagerClientConfig
cfg grpc.CVMClientConfig
err error
}{
{
name: "Valid config",
cfg: grpc.ManagerClientConfig{
cfg: grpc.CVMClientConfig{
BaseConfig: grpc.BaseConfig{
URL: "localhost:7001",
},
+14 -14
View File
@@ -13,18 +13,18 @@ import (
mglog "github.com/absmach/magistrala/logger"
"github.com/caarlos0/env/v11"
"github.com/ultravioletrs/cocos/agent/cvms"
cvmgrpc "github.com/ultravioletrs/cocos/agent/cvms/api/grpc"
"github.com/ultravioletrs/cocos/internal"
"github.com/ultravioletrs/cocos/internal/server"
grpcserver "github.com/ultravioletrs/cocos/internal/server/grpc"
"github.com/ultravioletrs/cocos/manager"
managergrpc "github.com/ultravioletrs/cocos/manager/api/grpc"
"golang.org/x/sync/errgroup"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/reflection"
)
var _ managergrpc.Service = (*svc)(nil)
var _ cvmgrpc.Service = (*svc)(nil)
const (
svcName = "computations_test_server"
@@ -42,7 +42,7 @@ type svc struct {
logger *slog.Logger
}
func (s *svc) Run(ctx context.Context, ipAddress string, sendMessage managergrpc.SendFunc, authInfo credentials.AuthInfo) {
func (s *svc) Run(ctx context.Context, ipAddress string, sendMessage cvmgrpc.SendFunc, authInfo credentials.AuthInfo) {
s.logger.Debug(fmt.Sprintf("received who am on ip address %s", ipAddress))
pubKey, err := os.ReadFile(pubKeyFile)
@@ -52,7 +52,7 @@ func (s *svc) Run(ctx context.Context, ipAddress string, sendMessage managergrpc
}
pubPem, _ := pem.Decode(pubKey)
var datasets []*manager.Dataset
var datasets []*cvms.Dataset
for _, dataPath := range dataPaths {
if _, err := os.Stat(dataPath); os.IsNotExist(err) {
s.logger.Error(fmt.Sprintf("data file does not exist: %s", dataPath))
@@ -64,7 +64,7 @@ func (s *svc) Run(ctx context.Context, ipAddress string, sendMessage managergrpc
return
}
datasets = append(datasets, &manager.Dataset{Hash: dataHash[:], UserKey: pubPem.Bytes})
datasets = append(datasets, &cvms.Dataset{Hash: dataHash[:], UserKey: pubPem.Bytes})
}
algoHash, err := internal.Checksum(algoPath)
@@ -73,16 +73,16 @@ func (s *svc) Run(ctx context.Context, ipAddress string, sendMessage managergrpc
return
}
if err := sendMessage(&manager.ServerStreamMessage{
Message: &manager.ServerStreamMessage_RunReq{
RunReq: &manager.ComputationRunReq{
if err := sendMessage(&cvms.ServerStreamMessage{
Message: &cvms.ServerStreamMessage_RunReq{
RunReq: &cvms.ComputationRunReq{
Id: "1",
Name: "sample computation",
Description: "sample descrption",
Datasets: datasets,
Algorithm: &manager.Algorithm{Hash: algoHash[:], UserKey: pubPem.Bytes},
ResultConsumers: []*manager.ResultConsumer{{UserKey: pubPem.Bytes}},
AgentConfig: &manager.AgentConfig{
Algorithm: &cvms.Algorithm{Hash: algoHash[:], UserKey: pubPem.Bytes},
ResultConsumers: []*cvms.ResultConsumer{{UserKey: pubPem.Bytes}},
AgentConfig: &cvms.AgentConfig{
Port: "7002",
LogLevel: "debug",
AttestedTls: attestedTLS,
@@ -113,7 +113,7 @@ func main() {
ctx, cancel := context.WithCancel(context.Background())
g, ctx := errgroup.WithContext(ctx)
incomingChan := make(chan *manager.ClientStreamMessage)
incomingChan := make(chan *cvms.ClientStreamMessage)
logger, err := mglog.New(os.Stdout, "debug")
if err != nil {
@@ -128,7 +128,7 @@ func main() {
registerAgentServiceServer := func(srv *grpc.Server) {
reflection.Register(srv)
manager.RegisterManagerServiceServer(srv, managergrpc.NewServer(incomingChan, &svc{logger: logger}))
cvms.RegisterServiceServer(srv, cvmgrpc.NewServer(incomingChan, &svc{logger: logger}))
}
grpcServerConfig := server.ServerConfig{
BaseConfig: server.BaseConfig{
-127
View File
@@ -1,127 +0,0 @@
// Copyright (c) Ultraviolet
// SPDX-License-Identifier: Apache-2.0
// Simplified script to pass configs to agent without manager and read logs and events for manager.
// This tool is meant for testing purposes.
package main
import (
"encoding/json"
"encoding/pem"
"log"
"net"
"os"
"strconv"
"github.com/mdlayher/vsock"
"github.com/ultravioletrs/cocos/agent"
"github.com/ultravioletrs/cocos/internal"
internalvsock "github.com/ultravioletrs/cocos/internal/vsock"
"github.com/ultravioletrs/cocos/manager"
"github.com/ultravioletrs/cocos/manager/events"
"github.com/ultravioletrs/cocos/manager/qemu"
)
const (
managerVsockPort = events.ManagerVsockPort
vsockConfigPort = qemu.VsockConfigPort
)
func main() {
if len(os.Args) < 5 {
log.Fatalf("usage: %s <data-path> <algo-path> <public-key-path> <attested-tls-bool>", os.Args[0])
}
dataPath := os.Args[1]
algoPath := os.Args[2]
pubKeyFile := os.Args[3]
attestedTLSParam, err := strconv.ParseBool(os.Args[4])
if err != nil {
log.Fatalf("usage: %s <data-path> <algo-path> <public-key-path> <attested-tls-bool>, <attested-tls-bool> must be a bool value", os.Args[0])
}
attestedTLS := attestedTLSParam
pubKey, err := os.ReadFile(pubKeyFile)
if err != nil {
log.Fatalf("failed to read public key file: %s", err)
}
pubPem, _ := pem.Decode(pubKey)
algoHash, err := internal.Checksum(algoPath)
if err != nil {
log.Fatalf("failed to calculate checksum: %s", err)
}
dataHash, err := internal.Checksum(dataPath)
if err != nil {
log.Fatalf("failed to calculate checksum: %s", err)
}
ac := agent.Computation{
ID: "123",
Datasets: agent.Datasets{agent.Dataset{Hash: [32]byte(dataHash), UserKey: pubPem.Bytes}},
Algorithm: agent.Algorithm{Hash: [32]byte(algoHash), UserKey: pubPem.Bytes},
ResultConsumers: []agent.ResultConsumer{{UserKey: pubPem.Bytes}},
AgentConfig: agent.AgentConfig{
LogLevel: "debug",
Port: "7002",
AttestedTls: attestedTLS,
},
}
if err := sendAgentConfig(3, ac); err != nil {
log.Fatal(err)
}
listener, err := vsock.Listen(managerVsockPort, nil)
if err != nil {
log.Fatalf("failed to listen on vsock: %s", err)
}
defer listener.Close()
log.Printf("Listening on vsock port %d", managerVsockPort)
for {
conn, err := listener.Accept()
if err != nil {
log.Printf("failed to accept connection: %s", err)
continue
}
go handleConnection(conn)
}
}
func sendAgentConfig(cid uint32, ac agent.Computation) error {
conn, err := vsock.Dial(cid, qemu.VsockConfigPort, nil)
if err != nil {
return err
}
defer conn.Close()
payload, err := json.Marshal(ac)
if err != nil {
return err
}
var ac2 agent.Computation
if err := json.Unmarshal(payload, &ac2); err != nil {
return err
}
if _, err := conn.Write(payload); err != nil {
return err
}
return nil
}
func handleConnection(conn net.Conn) {
defer conn.Close()
ackReader := internalvsock.NewAckReader(conn)
for {
var message manager.ClientStreamMessage
err := ackReader.ReadProto(&message)
if err != nil {
log.Printf("Error reading message: %v", err)
return
}
log.Printf("Received message: %s", message.String())
}
}