mirror of
https://github.com/ultravioletrs/cocos.git
synced 2026-06-23 04:10:25 +00:00
COCOS-344 - New agent structure (#350)
* new agent structure Signed-off-by: Sammy Oina <sammyoina@gmail.com> * minor fixes and testing Signed-off-by: Sammy Oina <sammyoina@gmail.com> * fix lint Signed-off-by: Sammy Oina <sammyoina@gmail.com> * fix tests Signed-off-by: Sammy Oina <sammyoina@gmail.com> * cvm tests fix Signed-off-by: Sammy Oina <sammyoina@gmail.com> * fix test Signed-off-by: Sammy Oina <sammyoina@gmail.com> * fix cli test Signed-off-by: Sammy Oina <sammyoina@gmail.com> * rename Signed-off-by: Sammy Oina <sammyoina@gmail.com> * rename cvm to cvms plural Signed-off-by: Sammy Oina <sammyoina@gmail.com> * rename service Signed-off-by: Sammy Oina <sammyoina@gmail.com> * fix tests Signed-off-by: Sammy Oina <sammyoina@gmail.com> * remove context Signed-off-by: Sammy Oina <sammyoina@gmail.com> * refactor: reorder parameters in NewAlgorithm functions and update CVMClient to CVMSClient Signed-off-by: Sammy Oina <sammyoina@gmail.com> * fix(tests): update SendEvent mock to include an additional parameter Signed-off-by: Sammy Oina <sammyoina@gmail.com> * move expectations Signed-off-by: Sammy Oina <sammyoina@gmail.com> * fix(tests): move event initialization to the correct scope in service tests Signed-off-by: Sammy Oina <sammyoina@gmail.com> * fix(tests): update SendEvent mock to use EXPECT instead of On in service tests Signed-off-by: Sammy Oina <sammyoina@gmail.com> --------- Signed-off-by: Sammy Oina <sammyoina@gmail.com>
This commit is contained in:
committed by
GitHub
parent
59b8057e5c
commit
ecad6514f3
@@ -33,8 +33,8 @@ jobs:
|
||||
|
||||
- name: Set up protoc
|
||||
run: |
|
||||
PROTOC_VERSION=28.1
|
||||
PROTOC_GEN_VERSION=v1.34.2
|
||||
PROTOC_VERSION=29.0
|
||||
PROTOC_GEN_VERSION=v1.36.0
|
||||
PROTOC_GRPC_VERSION=v1.5.1
|
||||
|
||||
# Download and install protoc
|
||||
|
||||
@@ -37,6 +37,7 @@ protoc:
|
||||
protoc -I. --go_out=. --go_opt=paths=source_relative --go-grpc_out=. --go-grpc_opt=paths=source_relative agent/agent.proto
|
||||
protoc -I. --go_out=. --go_opt=paths=source_relative --go-grpc_out=. --go-grpc_opt=paths=source_relative manager/manager.proto
|
||||
protoc -I. --go_out=. --go_opt=paths=source_relative --go-grpc_out=. --go-grpc_opt=paths=source_relative agent/events/events.proto
|
||||
protoc -I. --go_out=. --go_opt=paths=source_relative --go-grpc_out=. --go-grpc_opt=paths=source_relative agent/cvms/cvms.proto
|
||||
|
||||
mocks:
|
||||
mockery --config ./mockery.yml
|
||||
|
||||
+57
-176
@@ -3,8 +3,8 @@
|
||||
|
||||
// Code generated by protoc-gen-go. DO NOT EDIT.
|
||||
// versions:
|
||||
// protoc-gen-go v1.34.2
|
||||
// protoc v5.28.1
|
||||
// protoc-gen-go v1.36.0
|
||||
// protoc v5.29.0
|
||||
// source: agent/agent.proto
|
||||
|
||||
package agent
|
||||
@@ -24,21 +24,18 @@ const (
|
||||
)
|
||||
|
||||
type AlgoRequest struct {
|
||||
state protoimpl.MessageState
|
||||
sizeCache protoimpl.SizeCache
|
||||
state protoimpl.MessageState `protogen:"open.v1"`
|
||||
Algorithm []byte `protobuf:"bytes,1,opt,name=algorithm,proto3" json:"algorithm,omitempty"`
|
||||
Requirements []byte `protobuf:"bytes,2,opt,name=requirements,proto3" json:"requirements,omitempty"`
|
||||
unknownFields protoimpl.UnknownFields
|
||||
|
||||
Algorithm []byte `protobuf:"bytes,1,opt,name=algorithm,proto3" json:"algorithm,omitempty"`
|
||||
Requirements []byte `protobuf:"bytes,2,opt,name=requirements,proto3" json:"requirements,omitempty"`
|
||||
sizeCache protoimpl.SizeCache
|
||||
}
|
||||
|
||||
func (x *AlgoRequest) Reset() {
|
||||
*x = AlgoRequest{}
|
||||
if protoimpl.UnsafeEnabled {
|
||||
mi := &file_agent_agent_proto_msgTypes[0]
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
mi := &file_agent_agent_proto_msgTypes[0]
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
|
||||
func (x *AlgoRequest) String() string {
|
||||
@@ -49,7 +46,7 @@ func (*AlgoRequest) ProtoMessage() {}
|
||||
|
||||
func (x *AlgoRequest) ProtoReflect() protoreflect.Message {
|
||||
mi := &file_agent_agent_proto_msgTypes[0]
|
||||
if protoimpl.UnsafeEnabled && x != nil {
|
||||
if x != nil {
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
if ms.LoadMessageInfo() == nil {
|
||||
ms.StoreMessageInfo(mi)
|
||||
@@ -79,18 +76,16 @@ func (x *AlgoRequest) GetRequirements() []byte {
|
||||
}
|
||||
|
||||
type AlgoResponse struct {
|
||||
state protoimpl.MessageState
|
||||
sizeCache protoimpl.SizeCache
|
||||
state protoimpl.MessageState `protogen:"open.v1"`
|
||||
unknownFields protoimpl.UnknownFields
|
||||
sizeCache protoimpl.SizeCache
|
||||
}
|
||||
|
||||
func (x *AlgoResponse) Reset() {
|
||||
*x = AlgoResponse{}
|
||||
if protoimpl.UnsafeEnabled {
|
||||
mi := &file_agent_agent_proto_msgTypes[1]
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
mi := &file_agent_agent_proto_msgTypes[1]
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
|
||||
func (x *AlgoResponse) String() string {
|
||||
@@ -101,7 +96,7 @@ func (*AlgoResponse) ProtoMessage() {}
|
||||
|
||||
func (x *AlgoResponse) ProtoReflect() protoreflect.Message {
|
||||
mi := &file_agent_agent_proto_msgTypes[1]
|
||||
if protoimpl.UnsafeEnabled && x != nil {
|
||||
if x != nil {
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
if ms.LoadMessageInfo() == nil {
|
||||
ms.StoreMessageInfo(mi)
|
||||
@@ -117,21 +112,18 @@ func (*AlgoResponse) Descriptor() ([]byte, []int) {
|
||||
}
|
||||
|
||||
type DataRequest struct {
|
||||
state protoimpl.MessageState
|
||||
sizeCache protoimpl.SizeCache
|
||||
state protoimpl.MessageState `protogen:"open.v1"`
|
||||
Dataset []byte `protobuf:"bytes,1,opt,name=dataset,proto3" json:"dataset,omitempty"`
|
||||
Filename string `protobuf:"bytes,2,opt,name=filename,proto3" json:"filename,omitempty"`
|
||||
unknownFields protoimpl.UnknownFields
|
||||
|
||||
Dataset []byte `protobuf:"bytes,1,opt,name=dataset,proto3" json:"dataset,omitempty"`
|
||||
Filename string `protobuf:"bytes,2,opt,name=filename,proto3" json:"filename,omitempty"`
|
||||
sizeCache protoimpl.SizeCache
|
||||
}
|
||||
|
||||
func (x *DataRequest) Reset() {
|
||||
*x = DataRequest{}
|
||||
if protoimpl.UnsafeEnabled {
|
||||
mi := &file_agent_agent_proto_msgTypes[2]
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
mi := &file_agent_agent_proto_msgTypes[2]
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
|
||||
func (x *DataRequest) String() string {
|
||||
@@ -142,7 +134,7 @@ func (*DataRequest) ProtoMessage() {}
|
||||
|
||||
func (x *DataRequest) ProtoReflect() protoreflect.Message {
|
||||
mi := &file_agent_agent_proto_msgTypes[2]
|
||||
if protoimpl.UnsafeEnabled && x != nil {
|
||||
if x != nil {
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
if ms.LoadMessageInfo() == nil {
|
||||
ms.StoreMessageInfo(mi)
|
||||
@@ -172,18 +164,16 @@ func (x *DataRequest) GetFilename() string {
|
||||
}
|
||||
|
||||
type DataResponse struct {
|
||||
state protoimpl.MessageState
|
||||
sizeCache protoimpl.SizeCache
|
||||
state protoimpl.MessageState `protogen:"open.v1"`
|
||||
unknownFields protoimpl.UnknownFields
|
||||
sizeCache protoimpl.SizeCache
|
||||
}
|
||||
|
||||
func (x *DataResponse) Reset() {
|
||||
*x = DataResponse{}
|
||||
if protoimpl.UnsafeEnabled {
|
||||
mi := &file_agent_agent_proto_msgTypes[3]
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
mi := &file_agent_agent_proto_msgTypes[3]
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
|
||||
func (x *DataResponse) String() string {
|
||||
@@ -194,7 +184,7 @@ func (*DataResponse) ProtoMessage() {}
|
||||
|
||||
func (x *DataResponse) ProtoReflect() protoreflect.Message {
|
||||
mi := &file_agent_agent_proto_msgTypes[3]
|
||||
if protoimpl.UnsafeEnabled && x != nil {
|
||||
if x != nil {
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
if ms.LoadMessageInfo() == nil {
|
||||
ms.StoreMessageInfo(mi)
|
||||
@@ -210,18 +200,16 @@ func (*DataResponse) Descriptor() ([]byte, []int) {
|
||||
}
|
||||
|
||||
type ResultRequest struct {
|
||||
state protoimpl.MessageState
|
||||
sizeCache protoimpl.SizeCache
|
||||
state protoimpl.MessageState `protogen:"open.v1"`
|
||||
unknownFields protoimpl.UnknownFields
|
||||
sizeCache protoimpl.SizeCache
|
||||
}
|
||||
|
||||
func (x *ResultRequest) Reset() {
|
||||
*x = ResultRequest{}
|
||||
if protoimpl.UnsafeEnabled {
|
||||
mi := &file_agent_agent_proto_msgTypes[4]
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
mi := &file_agent_agent_proto_msgTypes[4]
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
|
||||
func (x *ResultRequest) String() string {
|
||||
@@ -232,7 +220,7 @@ func (*ResultRequest) ProtoMessage() {}
|
||||
|
||||
func (x *ResultRequest) ProtoReflect() protoreflect.Message {
|
||||
mi := &file_agent_agent_proto_msgTypes[4]
|
||||
if protoimpl.UnsafeEnabled && x != nil {
|
||||
if x != nil {
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
if ms.LoadMessageInfo() == nil {
|
||||
ms.StoreMessageInfo(mi)
|
||||
@@ -248,20 +236,17 @@ func (*ResultRequest) Descriptor() ([]byte, []int) {
|
||||
}
|
||||
|
||||
type ResultResponse struct {
|
||||
state protoimpl.MessageState
|
||||
sizeCache protoimpl.SizeCache
|
||||
state protoimpl.MessageState `protogen:"open.v1"`
|
||||
File []byte `protobuf:"bytes,1,opt,name=file,proto3" json:"file,omitempty"`
|
||||
unknownFields protoimpl.UnknownFields
|
||||
|
||||
File []byte `protobuf:"bytes,1,opt,name=file,proto3" json:"file,omitempty"`
|
||||
sizeCache protoimpl.SizeCache
|
||||
}
|
||||
|
||||
func (x *ResultResponse) Reset() {
|
||||
*x = ResultResponse{}
|
||||
if protoimpl.UnsafeEnabled {
|
||||
mi := &file_agent_agent_proto_msgTypes[5]
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
mi := &file_agent_agent_proto_msgTypes[5]
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
|
||||
func (x *ResultResponse) String() string {
|
||||
@@ -272,7 +257,7 @@ func (*ResultResponse) ProtoMessage() {}
|
||||
|
||||
func (x *ResultResponse) ProtoReflect() protoreflect.Message {
|
||||
mi := &file_agent_agent_proto_msgTypes[5]
|
||||
if protoimpl.UnsafeEnabled && x != nil {
|
||||
if x != nil {
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
if ms.LoadMessageInfo() == nil {
|
||||
ms.StoreMessageInfo(mi)
|
||||
@@ -295,20 +280,17 @@ func (x *ResultResponse) GetFile() []byte {
|
||||
}
|
||||
|
||||
type AttestationRequest struct {
|
||||
state protoimpl.MessageState
|
||||
sizeCache protoimpl.SizeCache
|
||||
state protoimpl.MessageState `protogen:"open.v1"`
|
||||
ReportData []byte `protobuf:"bytes,1,opt,name=report_data,json=reportData,proto3" json:"report_data,omitempty"` // Should be of length 64.
|
||||
unknownFields protoimpl.UnknownFields
|
||||
|
||||
ReportData []byte `protobuf:"bytes,1,opt,name=report_data,json=reportData,proto3" json:"report_data,omitempty"` // Should be of length 64.
|
||||
sizeCache protoimpl.SizeCache
|
||||
}
|
||||
|
||||
func (x *AttestationRequest) Reset() {
|
||||
*x = AttestationRequest{}
|
||||
if protoimpl.UnsafeEnabled {
|
||||
mi := &file_agent_agent_proto_msgTypes[6]
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
mi := &file_agent_agent_proto_msgTypes[6]
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
|
||||
func (x *AttestationRequest) String() string {
|
||||
@@ -319,7 +301,7 @@ func (*AttestationRequest) ProtoMessage() {}
|
||||
|
||||
func (x *AttestationRequest) ProtoReflect() protoreflect.Message {
|
||||
mi := &file_agent_agent_proto_msgTypes[6]
|
||||
if protoimpl.UnsafeEnabled && x != nil {
|
||||
if x != nil {
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
if ms.LoadMessageInfo() == nil {
|
||||
ms.StoreMessageInfo(mi)
|
||||
@@ -342,20 +324,17 @@ func (x *AttestationRequest) GetReportData() []byte {
|
||||
}
|
||||
|
||||
type AttestationResponse struct {
|
||||
state protoimpl.MessageState
|
||||
sizeCache protoimpl.SizeCache
|
||||
state protoimpl.MessageState `protogen:"open.v1"`
|
||||
File []byte `protobuf:"bytes,1,opt,name=file,proto3" json:"file,omitempty"`
|
||||
unknownFields protoimpl.UnknownFields
|
||||
|
||||
File []byte `protobuf:"bytes,1,opt,name=file,proto3" json:"file,omitempty"`
|
||||
sizeCache protoimpl.SizeCache
|
||||
}
|
||||
|
||||
func (x *AttestationResponse) Reset() {
|
||||
*x = AttestationResponse{}
|
||||
if protoimpl.UnsafeEnabled {
|
||||
mi := &file_agent_agent_proto_msgTypes[7]
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
mi := &file_agent_agent_proto_msgTypes[7]
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
|
||||
func (x *AttestationResponse) String() string {
|
||||
@@ -366,7 +345,7 @@ func (*AttestationResponse) ProtoMessage() {}
|
||||
|
||||
func (x *AttestationResponse) ProtoReflect() protoreflect.Message {
|
||||
mi := &file_agent_agent_proto_msgTypes[7]
|
||||
if protoimpl.UnsafeEnabled && x != nil {
|
||||
if x != nil {
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
if ms.LoadMessageInfo() == nil {
|
||||
ms.StoreMessageInfo(mi)
|
||||
@@ -477,104 +456,6 @@ func file_agent_agent_proto_init() {
|
||||
if File_agent_agent_proto != nil {
|
||||
return
|
||||
}
|
||||
if !protoimpl.UnsafeEnabled {
|
||||
file_agent_agent_proto_msgTypes[0].Exporter = func(v any, i int) any {
|
||||
switch v := v.(*AlgoRequest); i {
|
||||
case 0:
|
||||
return &v.state
|
||||
case 1:
|
||||
return &v.sizeCache
|
||||
case 2:
|
||||
return &v.unknownFields
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
file_agent_agent_proto_msgTypes[1].Exporter = func(v any, i int) any {
|
||||
switch v := v.(*AlgoResponse); i {
|
||||
case 0:
|
||||
return &v.state
|
||||
case 1:
|
||||
return &v.sizeCache
|
||||
case 2:
|
||||
return &v.unknownFields
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
file_agent_agent_proto_msgTypes[2].Exporter = func(v any, i int) any {
|
||||
switch v := v.(*DataRequest); i {
|
||||
case 0:
|
||||
return &v.state
|
||||
case 1:
|
||||
return &v.sizeCache
|
||||
case 2:
|
||||
return &v.unknownFields
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
file_agent_agent_proto_msgTypes[3].Exporter = func(v any, i int) any {
|
||||
switch v := v.(*DataResponse); i {
|
||||
case 0:
|
||||
return &v.state
|
||||
case 1:
|
||||
return &v.sizeCache
|
||||
case 2:
|
||||
return &v.unknownFields
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
file_agent_agent_proto_msgTypes[4].Exporter = func(v any, i int) any {
|
||||
switch v := v.(*ResultRequest); i {
|
||||
case 0:
|
||||
return &v.state
|
||||
case 1:
|
||||
return &v.sizeCache
|
||||
case 2:
|
||||
return &v.unknownFields
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
file_agent_agent_proto_msgTypes[5].Exporter = func(v any, i int) any {
|
||||
switch v := v.(*ResultResponse); i {
|
||||
case 0:
|
||||
return &v.state
|
||||
case 1:
|
||||
return &v.sizeCache
|
||||
case 2:
|
||||
return &v.unknownFields
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
file_agent_agent_proto_msgTypes[6].Exporter = func(v any, i int) any {
|
||||
switch v := v.(*AttestationRequest); i {
|
||||
case 0:
|
||||
return &v.state
|
||||
case 1:
|
||||
return &v.sizeCache
|
||||
case 2:
|
||||
return &v.unknownFields
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
file_agent_agent_proto_msgTypes[7].Exporter = func(v any, i int) any {
|
||||
switch v := v.(*AttestationResponse); i {
|
||||
case 0:
|
||||
return &v.state
|
||||
case 1:
|
||||
return &v.sizeCache
|
||||
case 2:
|
||||
return &v.unknownFields
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
type x struct{}
|
||||
out := protoimpl.TypeBuilder{
|
||||
File: protoimpl.DescBuilder{
|
||||
|
||||
@@ -4,7 +4,7 @@
|
||||
// Code generated by protoc-gen-go-grpc. DO NOT EDIT.
|
||||
// versions:
|
||||
// - protoc-gen-go-grpc v1.5.1
|
||||
// - protoc v5.28.1
|
||||
// - protoc v5.29.0
|
||||
// source: agent/agent.proto
|
||||
|
||||
package agent
|
||||
|
||||
@@ -46,4 +46,7 @@ func AlgorithmArgsFromContext(ctx context.Context) []string {
|
||||
type Algorithm interface {
|
||||
// Run executes the algorithm and returns the result.
|
||||
Run() error
|
||||
|
||||
// Stop stops the algorithm.
|
||||
Stop() error
|
||||
}
|
||||
|
||||
@@ -20,29 +20,46 @@ type binary struct {
|
||||
stderr io.Writer
|
||||
stdout io.Writer
|
||||
args []string
|
||||
cmd *exec.Cmd
|
||||
}
|
||||
|
||||
func NewAlgorithm(logger *slog.Logger, eventsSvc events.Service, algoFile string, args []string) algorithm.Algorithm {
|
||||
func NewAlgorithm(logger *slog.Logger, eventsSvc events.Service, algoFile string, args []string, cmpID string) algorithm.Algorithm {
|
||||
return &binary{
|
||||
algoFile: algoFile,
|
||||
stderr: &logging.Stderr{Logger: logger, EventSvc: eventsSvc},
|
||||
stderr: &logging.Stderr{Logger: logger, EventSvc: eventsSvc, CmpID: cmpID},
|
||||
stdout: &logging.Stdout{Logger: logger},
|
||||
args: args,
|
||||
}
|
||||
}
|
||||
|
||||
func (b *binary) Run() error {
|
||||
cmd := exec.Command(b.algoFile, b.args...)
|
||||
cmd.Stderr = b.stderr
|
||||
cmd.Stdout = b.stdout
|
||||
b.cmd = exec.Command(b.algoFile, b.args...)
|
||||
b.cmd.Stderr = b.stderr
|
||||
b.cmd.Stdout = b.stdout
|
||||
|
||||
if err := cmd.Start(); err != nil {
|
||||
if err := b.cmd.Start(); err != nil {
|
||||
return fmt.Errorf("error starting algorithm: %v", err)
|
||||
}
|
||||
|
||||
if err := cmd.Wait(); err != nil {
|
||||
if err := b.cmd.Wait(); err != nil {
|
||||
return fmt.Errorf("algorithm execution error: %v", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (b *binary) Stop() error {
|
||||
if b.cmd == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
if b.cmd.ProcessState != nil && b.cmd.ProcessState.Exited() {
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := b.cmd.Process.Kill(); err != nil {
|
||||
return fmt.Errorf("error stopping algorithm: %v", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -18,7 +18,7 @@ func TestNewAlgorithm(t *testing.T) {
|
||||
algoFile := "/path/to/algo"
|
||||
args := []string{"arg1", "arg2"}
|
||||
|
||||
algo := NewAlgorithm(logger, eventsSvc, algoFile, args)
|
||||
algo := NewAlgorithm(logger, eventsSvc, algoFile, args, "")
|
||||
|
||||
b, ok := algo.(*binary)
|
||||
if !ok {
|
||||
@@ -74,7 +74,7 @@ func TestBinaryRun(t *testing.T) {
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
|
||||
eventsSvc := new(mocks.Service)
|
||||
|
||||
b := NewAlgorithm(logger, eventsSvc, tt.algoFile, tt.args).(*binary)
|
||||
b := NewAlgorithm(logger, eventsSvc, tt.algoFile, tt.args, "").(*binary)
|
||||
|
||||
var stdout, stderr bytes.Buffer
|
||||
b.stdout = &stdout
|
||||
|
||||
@@ -35,11 +35,11 @@ type docker struct {
|
||||
stdout io.Writer
|
||||
}
|
||||
|
||||
func NewAlgorithm(logger *slog.Logger, eventsSvc events.Service, algoFile string) algorithm.Algorithm {
|
||||
func NewAlgorithm(logger *slog.Logger, eventsSvc events.Service, algoFile, cmpID string) algorithm.Algorithm {
|
||||
d := &docker{
|
||||
algoFile: algoFile,
|
||||
logger: logger,
|
||||
stderr: &logging.Stderr{Logger: logger, EventSvc: eventsSvc},
|
||||
stderr: &logging.Stderr{Logger: logger, EventSvc: eventsSvc, CmpID: cmpID},
|
||||
stdout: &logging.Stdout{Logger: logger},
|
||||
}
|
||||
|
||||
@@ -47,8 +47,6 @@ func NewAlgorithm(logger *slog.Logger, eventsSvc events.Service, algoFile string
|
||||
}
|
||||
|
||||
func (d *docker) Run() error {
|
||||
ctx := context.Background()
|
||||
|
||||
// Create a new Docker client.
|
||||
cli, err := client.NewClientWithOpts(client.FromEnv, client.WithAPIVersionNegotiation())
|
||||
if err != nil {
|
||||
@@ -62,6 +60,7 @@ func (d *docker) Run() error {
|
||||
}
|
||||
defer imageFile.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
// Load the Docker image from the tar file.
|
||||
resp, err := cli.ImageLoad(ctx, imageFile, true)
|
||||
if err != nil {
|
||||
@@ -176,3 +175,8 @@ func writeToOut(readCloser io.ReadCloser, ioWriter io.Writer) error {
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d *docker) Stop() error {
|
||||
// To be supported later.
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -18,7 +18,7 @@ func TestNewAlgorithm(t *testing.T) {
|
||||
eventsSvc := new(mocks.Service)
|
||||
algoFile := "/path/to/algo.tar"
|
||||
|
||||
algo := NewAlgorithm(logger, eventsSvc, algoFile)
|
||||
algo := NewAlgorithm(logger, eventsSvc, algoFile, "")
|
||||
|
||||
d, ok := algo.(*docker)
|
||||
assert.True(t, ok, "NewAlgorithm should return a *docker")
|
||||
|
||||
@@ -50,6 +50,7 @@ func (s *Stdout) Write(p []byte) (n int, err error) {
|
||||
type Stderr struct {
|
||||
Logger *slog.Logger
|
||||
EventSvc events.Service
|
||||
CmpID string
|
||||
}
|
||||
|
||||
// Write implements io.Writer.
|
||||
@@ -70,9 +71,7 @@ func (s *Stderr) Write(p []byte) (n int, err error) {
|
||||
s.Logger.Error(string(buf[:n]))
|
||||
}
|
||||
|
||||
if err := s.EventSvc.SendEvent(algorithmRun, warningStatus, json.RawMessage{}); err != nil {
|
||||
return len(p), err
|
||||
}
|
||||
s.EventSvc.SendEvent(s.CmpID, algorithmRun, warningStatus, json.RawMessage{})
|
||||
|
||||
return len(p), nil
|
||||
}
|
||||
|
||||
@@ -73,7 +73,7 @@ func TestStderrWrite(t *testing.T) {
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
mockEventService := mocks.NewService(t)
|
||||
mockEventService.On("SendEvent", "AlgorithmRun", manager.Warning.String(), mock.Anything).Return(nil)
|
||||
mockEventService.On("SendEvent", mock.Anything, "AlgorithmRun", manager.Warning.String(), mock.Anything).Return(nil)
|
||||
|
||||
stderr := &Stderr{Logger: mglog.NewMock(), EventSvc: mockEventService}
|
||||
n, err := stderr.Write([]byte(tt.input))
|
||||
|
||||
@@ -39,12 +39,13 @@ type python struct {
|
||||
runtime string
|
||||
requirementsFile string
|
||||
args []string
|
||||
cmd *exec.Cmd
|
||||
}
|
||||
|
||||
func NewAlgorithm(logger *slog.Logger, eventsSvc events.Service, runtime, requirementsFile, algoFile string, args []string) algorithm.Algorithm {
|
||||
func NewAlgorithm(logger *slog.Logger, eventsSvc events.Service, runtime, requirementsFile, algoFile string, args []string, cmpID string) algorithm.Algorithm {
|
||||
p := &python{
|
||||
algoFile: algoFile,
|
||||
stderr: &logging.Stderr{Logger: logger, EventSvc: eventsSvc},
|
||||
stderr: &logging.Stderr{Logger: logger, EventSvc: eventsSvc, CmpID: cmpID},
|
||||
stdout: &logging.Stdout{Logger: logger},
|
||||
requirementsFile: requirementsFile,
|
||||
args: args,
|
||||
@@ -85,15 +86,15 @@ func (p *python) Run() error {
|
||||
}
|
||||
|
||||
args := append([]string{p.algoFile}, p.args...)
|
||||
cmd := exec.Command(pythonPath, args...)
|
||||
cmd.Stderr = p.stderr
|
||||
cmd.Stdout = p.stdout
|
||||
p.cmd = exec.Command(pythonPath, args...)
|
||||
p.cmd.Stderr = p.stderr
|
||||
p.cmd.Stdout = p.stdout
|
||||
|
||||
if err := cmd.Start(); err != nil {
|
||||
if err := p.cmd.Start(); err != nil {
|
||||
return fmt.Errorf("error starting algorithm: %v", err)
|
||||
}
|
||||
|
||||
if err := cmd.Wait(); err != nil {
|
||||
if err := p.cmd.Wait(); err != nil {
|
||||
return fmt.Errorf("algorithm execution error: %v", err)
|
||||
}
|
||||
|
||||
@@ -103,3 +104,19 @@ func (p *python) Run() error {
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *python) Stop() error {
|
||||
if p.cmd == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
if p.cmd.ProcessState != nil && p.cmd.ProcessState.Exited() {
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := p.cmd.Process.Kill(); err != nil {
|
||||
return fmt.Errorf("error stopping algorithm: %v", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -50,7 +50,7 @@ func TestNewAlgorithm(t *testing.T) {
|
||||
algoFile := "algorithm.py"
|
||||
args := []string{"--arg1", "value1"}
|
||||
|
||||
algo := NewAlgorithm(logger, eventsSvc, runtime, requirementsFile, algoFile, args)
|
||||
algo := NewAlgorithm(logger, eventsSvc, runtime, requirementsFile, algoFile, args, "")
|
||||
|
||||
p, ok := algo.(*python)
|
||||
if !ok {
|
||||
|
||||
@@ -24,12 +24,13 @@ type wasm struct {
|
||||
stderr io.Writer
|
||||
stdout io.Writer
|
||||
args []string
|
||||
cmd *exec.Cmd
|
||||
}
|
||||
|
||||
func NewAlgorithm(logger *slog.Logger, eventsSvc events.Service, algoFile string, args []string) algorithm.Algorithm {
|
||||
func NewAlgorithm(logger *slog.Logger, eventsSvc events.Service, args []string, algoFile, cmpID string) algorithm.Algorithm {
|
||||
return &wasm{
|
||||
algoFile: algoFile,
|
||||
stderr: &logging.Stderr{Logger: logger, EventSvc: eventsSvc},
|
||||
stderr: &logging.Stderr{Logger: logger, EventSvc: eventsSvc, CmpID: cmpID},
|
||||
stdout: &logging.Stdout{Logger: logger},
|
||||
args: args,
|
||||
}
|
||||
@@ -38,17 +39,33 @@ func NewAlgorithm(logger *slog.Logger, eventsSvc events.Service, algoFile string
|
||||
func (w *wasm) Run() error {
|
||||
args := append(mapDirOption, w.algoFile)
|
||||
args = append(args, w.args...)
|
||||
cmd := exec.Command(wasmRuntime, args...)
|
||||
cmd.Stderr = w.stderr
|
||||
cmd.Stdout = w.stdout
|
||||
w.cmd = exec.Command(wasmRuntime, args...)
|
||||
w.cmd.Stderr = w.stderr
|
||||
w.cmd.Stdout = w.stdout
|
||||
|
||||
if err := cmd.Start(); err != nil {
|
||||
if err := w.cmd.Start(); err != nil {
|
||||
return fmt.Errorf("error starting algorithm: %v", err)
|
||||
}
|
||||
|
||||
if err := cmd.Wait(); err != nil {
|
||||
if err := w.cmd.Wait(); err != nil {
|
||||
return fmt.Errorf("algorithm execution error: %v", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (w *wasm) Stop() error {
|
||||
if w.cmd == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
if w.cmd.ProcessState != nil && w.cmd.ProcessState.Exited() {
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := w.cmd.Process.Kill(); err != nil {
|
||||
return fmt.Errorf("error stopping algorithm: %v", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -18,7 +18,7 @@ func TestNewAlgorithm(t *testing.T) {
|
||||
algoFile := "test.wasm"
|
||||
args := []string{"arg1", "arg2"}
|
||||
|
||||
algo := NewAlgorithm(logger, eventsSvc, algoFile, args)
|
||||
algo := NewAlgorithm(logger, eventsSvc, args, algoFile, "")
|
||||
|
||||
w, ok := algo.(*wasm)
|
||||
if !ok {
|
||||
@@ -54,7 +54,7 @@ func TestRunError(t *testing.T) {
|
||||
algoFile := "test.wasm"
|
||||
args := []string{"arg1", "arg2"}
|
||||
|
||||
w := NewAlgorithm(logger, eventsSvc, algoFile, args).(*wasm)
|
||||
w := NewAlgorithm(logger, eventsSvc, args, algoFile, "").(*wasm)
|
||||
|
||||
err := w.Run()
|
||||
if err == nil {
|
||||
|
||||
@@ -27,6 +27,34 @@ func LoggingMiddleware(svc agent.Service, logger *slog.Logger) agent.Service {
|
||||
return &loggingMiddleware{logger, svc}
|
||||
}
|
||||
|
||||
// InitComputation implements agent.Service.
|
||||
func (lm *loggingMiddleware) InitComputation(ctx context.Context, cmp agent.Computation) (err error) {
|
||||
defer func(begin time.Time) {
|
||||
message := fmt.Sprintf("Method InitComputation for computation id %s took %s to complete", cmp.ID, time.Since(begin))
|
||||
if err != nil {
|
||||
lm.logger.WithGroup(cmp.ID).Warn(fmt.Sprintf("%s with error: %s", message, err))
|
||||
return
|
||||
}
|
||||
lm.logger.WithGroup(cmp.ID).Info(fmt.Sprintf("%s without errors", message))
|
||||
}(time.Now())
|
||||
|
||||
return lm.svc.InitComputation(ctx, cmp)
|
||||
}
|
||||
|
||||
// StopComputation implements agent.Service.
|
||||
func (lm *loggingMiddleware) StopComputation(ctx context.Context) (err error) {
|
||||
defer func(begin time.Time) {
|
||||
message := fmt.Sprintf("Method StopComputation took %s to complete", time.Since(begin))
|
||||
if err != nil {
|
||||
lm.logger.Warn(fmt.Sprintf("%s with error: %s", message, err))
|
||||
return
|
||||
}
|
||||
lm.logger.Info(fmt.Sprintf("%s without errors", message))
|
||||
}(time.Now())
|
||||
|
||||
return lm.svc.StopComputation(ctx)
|
||||
}
|
||||
|
||||
func (lm *loggingMiddleware) Algo(ctx context.Context, algorithm agent.Algorithm) (err error) {
|
||||
defer func(begin time.Time) {
|
||||
message := fmt.Sprintf("Method Algo took %s to complete", time.Since(begin))
|
||||
|
||||
@@ -32,6 +32,26 @@ func MetricsMiddleware(svc agent.Service, counter metrics.Counter, latency metri
|
||||
}
|
||||
}
|
||||
|
||||
// InitComputation implements agent.Service.
|
||||
func (ms *metricsMiddleware) InitComputation(ctx context.Context, cmp agent.Computation) error {
|
||||
defer func(begin time.Time) {
|
||||
ms.counter.With("method", "init_computation").Add(1)
|
||||
ms.latency.With("method", "init_computation").Observe(time.Since(begin).Seconds())
|
||||
}(time.Now())
|
||||
|
||||
return ms.svc.InitComputation(ctx, cmp)
|
||||
}
|
||||
|
||||
// StopComputation implements agent.Service.
|
||||
func (ms *metricsMiddleware) StopComputation(ctx context.Context) error {
|
||||
defer func(begin time.Time) {
|
||||
ms.counter.With("method", "stop_computation").Add(1)
|
||||
ms.latency.With("method", "stop_computation").Observe(time.Since(begin).Seconds())
|
||||
}(time.Now())
|
||||
|
||||
return ms.svc.StopComputation(ctx)
|
||||
}
|
||||
|
||||
func (ms *metricsMiddleware) Algo(ctx context.Context, algorithm agent.Algorithm) error {
|
||||
defer func(begin time.Time) {
|
||||
ms.counter.With("method", "algo").Add(1)
|
||||
|
||||
@@ -13,7 +13,6 @@ import (
|
||||
var _ fmt.Stringer = (*Datasets)(nil)
|
||||
|
||||
type AgentConfig struct {
|
||||
LogLevel string `json:"log_level,omitempty"`
|
||||
Host string `json:"host,omitempty"`
|
||||
Port string `json:"port,omitempty"`
|
||||
CertFile string `json:"cert_file,omitempty"`
|
||||
@@ -30,7 +29,6 @@ type Computation struct {
|
||||
Datasets Datasets `json:"datasets,omitempty"`
|
||||
Algorithm Algorithm `json:"algorithm,omitempty"`
|
||||
ResultConsumers []ResultConsumer `json:"result_consumers,omitempty"`
|
||||
AgentConfig AgentConfig `json:"agent_config,omitempty"`
|
||||
}
|
||||
|
||||
type ResultConsumer struct {
|
||||
|
||||
@@ -106,7 +106,6 @@ func TestDecompressToContext(t *testing.T) {
|
||||
|
||||
func TestAgentConfigJSON(t *testing.T) {
|
||||
config := AgentConfig{
|
||||
LogLevel: "info",
|
||||
Host: "localhost",
|
||||
Port: "8080",
|
||||
CertFile: "cert.pem",
|
||||
|
||||
@@ -0,0 +1,261 @@
|
||||
// Copyright (c) Ultraviolet
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
package grpc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log/slog"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/absmach/magistrala/pkg/errors"
|
||||
"github.com/ultravioletrs/cocos/agent"
|
||||
"github.com/ultravioletrs/cocos/agent/cvms"
|
||||
"github.com/ultravioletrs/cocos/agent/cvms/server"
|
||||
"golang.org/x/sync/errgroup"
|
||||
"google.golang.org/protobuf/proto"
|
||||
)
|
||||
|
||||
var (
|
||||
errCorruptedManifest = errors.New("received manifest may be corrupted")
|
||||
errUnknonwMessageType = errors.New("unknown message type")
|
||||
sendTimeout = 5 * time.Second
|
||||
)
|
||||
|
||||
type CVMSClient struct {
|
||||
mu sync.Mutex
|
||||
stream cvms.Service_ProcessClient
|
||||
svc agent.Service
|
||||
messageQueue chan *cvms.ClientStreamMessage
|
||||
logger *slog.Logger
|
||||
runReqManager *runRequestManager
|
||||
sp server.AgentServer
|
||||
}
|
||||
|
||||
// NewClient returns new gRPC client instance.
|
||||
func NewClient(stream cvms.Service_ProcessClient, svc agent.Service, messageQueue chan *cvms.ClientStreamMessage, logger *slog.Logger, sp server.AgentServer) CVMSClient {
|
||||
return CVMSClient{
|
||||
stream: stream,
|
||||
svc: svc,
|
||||
messageQueue: messageQueue,
|
||||
logger: logger,
|
||||
runReqManager: newRunRequestManager(),
|
||||
sp: sp,
|
||||
}
|
||||
}
|
||||
|
||||
func (client *CVMSClient) Process(ctx context.Context, cancel context.CancelFunc) error {
|
||||
eg, ctx := errgroup.WithContext(ctx)
|
||||
|
||||
eg.Go(func() error {
|
||||
return client.handleIncomingMessages(ctx)
|
||||
})
|
||||
|
||||
eg.Go(func() error {
|
||||
return client.handleOutgoingMessages(ctx)
|
||||
})
|
||||
|
||||
return eg.Wait()
|
||||
}
|
||||
|
||||
func (client *CVMSClient) handleIncomingMessages(ctx context.Context) error {
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
default:
|
||||
req, err := client.stream.Recv()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := client.processIncomingMessage(ctx, req); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (client *CVMSClient) processIncomingMessage(ctx context.Context, req *cvms.ServerStreamMessage) error {
|
||||
switch mes := req.Message.(type) {
|
||||
case *cvms.ServerStreamMessage_RunReqChunks:
|
||||
return client.handleRunReqChunks(ctx, mes)
|
||||
case *cvms.ServerStreamMessage_StopComputation:
|
||||
go client.handleStopComputation(ctx, mes)
|
||||
default:
|
||||
return errUnknonwMessageType
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (client *CVMSClient) handleRunReqChunks(ctx context.Context, msg *cvms.ServerStreamMessage_RunReqChunks) error {
|
||||
buffer, complete := client.runReqManager.addChunk(msg.RunReqChunks.Id, msg.RunReqChunks.Data, msg.RunReqChunks.IsLast)
|
||||
|
||||
if complete {
|
||||
var runReq cvms.ComputationRunReq
|
||||
if err := proto.Unmarshal(buffer, &runReq); err != nil {
|
||||
return errors.Wrap(err, errCorruptedManifest)
|
||||
}
|
||||
|
||||
go client.executeRun(ctx, &runReq)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (client *CVMSClient) executeRun(ctx context.Context, runReq *cvms.ComputationRunReq) {
|
||||
ac := agent.Computation{
|
||||
ID: runReq.Id,
|
||||
Name: runReq.Name,
|
||||
Description: runReq.Description,
|
||||
}
|
||||
|
||||
if runReq.Algorithm != nil {
|
||||
ac.Algorithm = agent.Algorithm{
|
||||
Hash: [32]byte(runReq.Algorithm.Hash),
|
||||
UserKey: runReq.Algorithm.UserKey,
|
||||
}
|
||||
}
|
||||
|
||||
for _, ds := range runReq.Datasets {
|
||||
ac.Datasets = append(ac.Datasets, agent.Dataset{
|
||||
Hash: [32]byte(ds.Hash),
|
||||
UserKey: ds.UserKey,
|
||||
})
|
||||
}
|
||||
|
||||
for _, rc := range runReq.ResultConsumers {
|
||||
ac.ResultConsumers = append(ac.ResultConsumers, agent.ResultConsumer{
|
||||
UserKey: rc.UserKey,
|
||||
})
|
||||
}
|
||||
|
||||
if err := client.svc.InitComputation(ctx, ac); err != nil {
|
||||
client.logger.Warn(err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
client.mu.Lock()
|
||||
defer client.mu.Unlock()
|
||||
|
||||
if runReq.AgentConfig == nil {
|
||||
runReq.AgentConfig = &cvms.AgentConfig{}
|
||||
}
|
||||
|
||||
runRes := &cvms.ClientStreamMessage_RunRes{
|
||||
RunRes: &cvms.RunResponse{
|
||||
ComputationId: runReq.Id,
|
||||
},
|
||||
}
|
||||
|
||||
err := client.sp.Start(ctx, agent.AgentConfig{
|
||||
Port: runReq.AgentConfig.Port,
|
||||
Host: runReq.AgentConfig.Host,
|
||||
CertFile: runReq.AgentConfig.CertFile,
|
||||
KeyFile: runReq.AgentConfig.KeyFile,
|
||||
ServerCAFile: runReq.AgentConfig.ServerCaFile,
|
||||
ClientCAFile: runReq.AgentConfig.ClientCaFile,
|
||||
AttestedTls: runReq.AgentConfig.AttestedTls,
|
||||
}, ac)
|
||||
if err != nil {
|
||||
client.logger.Warn(err.Error())
|
||||
runRes.RunRes.Error = err.Error()
|
||||
}
|
||||
|
||||
client.sendMessage(&cvms.ClientStreamMessage{Message: runRes})
|
||||
}
|
||||
|
||||
func (client *CVMSClient) handleStopComputation(ctx context.Context, mes *cvms.ServerStreamMessage_StopComputation) {
|
||||
msg := &cvms.ClientStreamMessage_StopComputationRes{
|
||||
StopComputationRes: &cvms.StopComputationResponse{
|
||||
ComputationId: mes.StopComputation.ComputationId,
|
||||
},
|
||||
}
|
||||
if err := client.svc.StopComputation(ctx); err != nil {
|
||||
msg.StopComputationRes.Message = err.Error()
|
||||
}
|
||||
|
||||
client.mu.Lock()
|
||||
defer client.mu.Unlock()
|
||||
|
||||
if err := client.sp.Stop(); err != nil {
|
||||
msg.StopComputationRes.Message = err.Error()
|
||||
}
|
||||
|
||||
client.sendMessage(&cvms.ClientStreamMessage{Message: msg})
|
||||
}
|
||||
|
||||
func (client *CVMSClient) handleOutgoingMessages(ctx context.Context) error {
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case mes := <-client.messageQueue:
|
||||
if err := client.stream.Send(mes); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (client *CVMSClient) sendMessage(mes *cvms.ClientStreamMessage) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), sendTimeout)
|
||||
defer cancel()
|
||||
|
||||
select {
|
||||
case client.messageQueue <- mes:
|
||||
case <-ctx.Done():
|
||||
client.logger.Warn("Failed to send message: timeout exceeded")
|
||||
}
|
||||
}
|
||||
|
||||
type runRequestManager struct {
|
||||
requests map[string]*runRequest
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
type runRequest struct {
|
||||
buffer []byte
|
||||
lastChunk time.Time
|
||||
timer *time.Timer
|
||||
}
|
||||
|
||||
func newRunRequestManager() *runRequestManager {
|
||||
return &runRequestManager{
|
||||
requests: make(map[string]*runRequest),
|
||||
}
|
||||
}
|
||||
|
||||
func (m *runRequestManager) addChunk(id string, chunk []byte, isLast bool) ([]byte, bool) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
req, exists := m.requests[id]
|
||||
if !exists {
|
||||
req = &runRequest{
|
||||
buffer: make([]byte, 0),
|
||||
lastChunk: time.Now(),
|
||||
timer: time.AfterFunc(runReqTimeout, func() { m.timeoutRequest(id) }),
|
||||
}
|
||||
m.requests[id] = req
|
||||
}
|
||||
|
||||
req.buffer = append(req.buffer, chunk...)
|
||||
req.lastChunk = time.Now()
|
||||
req.timer.Reset(runReqTimeout)
|
||||
|
||||
if isLast {
|
||||
delete(m.requests, id)
|
||||
req.timer.Stop()
|
||||
return req.buffer, true
|
||||
}
|
||||
|
||||
return nil, false
|
||||
}
|
||||
|
||||
func (m *runRequestManager) timeoutRequest(id string) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
delete(m.requests, id)
|
||||
// Log timeout or handle it as needed
|
||||
}
|
||||
@@ -0,0 +1,202 @@
|
||||
// Copyright (c) Ultraviolet
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
package grpc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
mglog "github.com/absmach/magistrala/logger"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/mock"
|
||||
"github.com/ultravioletrs/cocos/agent/cvms"
|
||||
servermocks "github.com/ultravioletrs/cocos/agent/cvms/server/mocks"
|
||||
"github.com/ultravioletrs/cocos/agent/mocks"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/protobuf/proto"
|
||||
)
|
||||
|
||||
type mockStream struct {
|
||||
mock.Mock
|
||||
grpc.ClientStream
|
||||
}
|
||||
|
||||
func (m *mockStream) Recv() (*cvms.ServerStreamMessage, error) {
|
||||
args := m.Called()
|
||||
return args.Get(0).(*cvms.ServerStreamMessage), args.Error(1)
|
||||
}
|
||||
|
||||
func (m *mockStream) Send(msg *cvms.ClientStreamMessage) error {
|
||||
args := m.Called(msg)
|
||||
return args.Error(0)
|
||||
}
|
||||
|
||||
func TestManagerClient_Process1(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
setupMocks func(mockStream *mockStream, mockSvc *mocks.Service, mockServerSvc *servermocks.AgentServerProvider)
|
||||
expectError bool
|
||||
errorMsg string
|
||||
}{
|
||||
{
|
||||
name: "Stop computation",
|
||||
setupMocks: func(mockStream *mockStream, mockSvc *mocks.Service, mockServerSvc *servermocks.AgentServerProvider) {
|
||||
mockStream.On("Recv").Return(&cvms.ServerStreamMessage{
|
||||
Message: &cvms.ServerStreamMessage_StopComputation{
|
||||
StopComputation: &cvms.StopComputation{},
|
||||
},
|
||||
}, nil)
|
||||
mockStream.On("Send", mock.Anything).Return(nil)
|
||||
mockSvc.On("StopComputation", mock.Anything).Return(nil)
|
||||
mockServerSvc.On("Stop").Return(nil)
|
||||
},
|
||||
expectError: true,
|
||||
errorMsg: "context deadline exceeded",
|
||||
},
|
||||
{
|
||||
name: "Run request chunks",
|
||||
setupMocks: func(mockStream *mockStream, mockSvc *mocks.Service, mockServerSvc *servermocks.AgentServerProvider) {
|
||||
mockStream.On("Recv").Return(&cvms.ServerStreamMessage{
|
||||
Message: &cvms.ServerStreamMessage_RunReqChunks{
|
||||
RunReqChunks: &cvms.RunReqChunks{},
|
||||
},
|
||||
}, nil)
|
||||
mockStream.On("Send", mock.Anything).Return(nil).Once()
|
||||
mockSvc.On("Run", mock.Anything, mock.Anything).Return("", assert.AnError).Once()
|
||||
},
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "Receive error",
|
||||
setupMocks: func(mockStream *mockStream, mockSvc *mocks.Service, mockServerSvc *servermocks.AgentServerProvider) {
|
||||
mockStream.On("Recv").Return(&cvms.ServerStreamMessage{}, assert.AnError)
|
||||
},
|
||||
expectError: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
mockStream := new(mockStream)
|
||||
mockSvc := new(mocks.Service)
|
||||
mockServerSvc := new(servermocks.AgentServerProvider)
|
||||
messageQueue := make(chan *cvms.ClientStreamMessage, 10)
|
||||
logger := mglog.NewMock()
|
||||
|
||||
client := NewClient(mockStream, mockSvc, messageQueue, logger, mockServerSvc)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
||||
defer cancel()
|
||||
|
||||
tc.setupMocks(mockStream, mockSvc, mockServerSvc)
|
||||
|
||||
err := client.Process(ctx, cancel)
|
||||
|
||||
if tc.expectError {
|
||||
assert.Error(t, err)
|
||||
if tc.errorMsg != "" {
|
||||
assert.Contains(t, err.Error(), tc.errorMsg)
|
||||
}
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestManagerClient_handleRunReqChunks(t *testing.T) {
|
||||
mockStream := new(mockStream)
|
||||
mockSvc := new(mocks.Service)
|
||||
mockServerSvc := new(servermocks.AgentServerProvider)
|
||||
messageQueue := make(chan *cvms.ClientStreamMessage, 10)
|
||||
logger := mglog.NewMock()
|
||||
|
||||
client := NewClient(mockStream, mockSvc, messageQueue, logger, mockServerSvc)
|
||||
|
||||
runReq := &cvms.ComputationRunReq{
|
||||
Id: "test-id",
|
||||
}
|
||||
runReqBytes, _ := proto.Marshal(runReq)
|
||||
|
||||
chunk1 := &cvms.ServerStreamMessage_RunReqChunks{
|
||||
RunReqChunks: &cvms.RunReqChunks{
|
||||
Id: "chunk-1",
|
||||
Data: runReqBytes[:len(runReqBytes)/2],
|
||||
IsLast: false,
|
||||
},
|
||||
}
|
||||
chunk2 := &cvms.ServerStreamMessage_RunReqChunks{
|
||||
RunReqChunks: &cvms.RunReqChunks{
|
||||
Id: "chunk-1",
|
||||
Data: runReqBytes[len(runReqBytes)/2:],
|
||||
IsLast: true,
|
||||
},
|
||||
}
|
||||
|
||||
mockSvc.On("InitComputation", mock.Anything, mock.Anything).Return(nil)
|
||||
mockServerSvc.On("Start", mock.Anything, mock.Anything, mock.Anything).Return(nil)
|
||||
|
||||
err := client.handleRunReqChunks(context.Background(), chunk1)
|
||||
assert.NoError(t, err)
|
||||
|
||||
err = client.handleRunReqChunks(context.Background(), chunk2)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Wait for the goroutine to finish
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
mockSvc.AssertExpectations(t)
|
||||
assert.Len(t, messageQueue, 1)
|
||||
|
||||
msg := <-messageQueue
|
||||
runRes, ok := msg.Message.(*cvms.ClientStreamMessage_RunRes)
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, "test-id", runRes.RunRes.ComputationId)
|
||||
}
|
||||
|
||||
func TestManagerClient_handleStopComputation(t *testing.T) {
|
||||
mockStream := new(mockStream)
|
||||
mockSvc := new(mocks.Service)
|
||||
mockServerSvc := new(servermocks.AgentServerProvider)
|
||||
messageQueue := make(chan *cvms.ClientStreamMessage, 10)
|
||||
logger := mglog.NewMock()
|
||||
|
||||
client := NewClient(mockStream, mockSvc, messageQueue, logger, mockServerSvc)
|
||||
|
||||
stopReq := &cvms.ServerStreamMessage_StopComputation{
|
||||
StopComputation: &cvms.StopComputation{
|
||||
ComputationId: "test-comp-id",
|
||||
},
|
||||
}
|
||||
|
||||
mockSvc.On("StopComputation", mock.Anything).Return(nil)
|
||||
mockServerSvc.On("Stop").Return(nil)
|
||||
|
||||
client.handleStopComputation(context.Background(), stopReq)
|
||||
|
||||
// Wait for the goroutine to finish
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
mockSvc.AssertExpectations(t)
|
||||
assert.Len(t, messageQueue, 1)
|
||||
|
||||
msg := <-messageQueue
|
||||
stopRes, ok := msg.Message.(*cvms.ClientStreamMessage_StopComputationRes)
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, "test-comp-id", stopRes.StopComputationRes.ComputationId)
|
||||
assert.Empty(t, stopRes.StopComputationRes.Message)
|
||||
}
|
||||
|
||||
func TestManagerClient_timeoutRequest(t *testing.T) {
|
||||
rm := newRunRequestManager()
|
||||
rm.requests["test-id"] = &runRequest{
|
||||
timer: time.NewTimer(100 * time.Millisecond),
|
||||
buffer: []byte("test-data"),
|
||||
lastChunk: time.Now(),
|
||||
}
|
||||
|
||||
rm.timeoutRequest("test-id")
|
||||
|
||||
assert.Len(t, rm.requests, 0)
|
||||
}
|
||||
@@ -0,0 +1,5 @@
|
||||
// Copyright (c) Ultraviolet
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
// Package grpc contains implementation of kit service gRPC API.
|
||||
package grpc
|
||||
@@ -0,0 +1,133 @@
|
||||
// Copyright (c) Ultraviolet
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
package grpc
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
"time"
|
||||
|
||||
"github.com/ultravioletrs/cocos/agent/cvms"
|
||||
"golang.org/x/sync/errgroup"
|
||||
"google.golang.org/grpc/credentials"
|
||||
"google.golang.org/grpc/peer"
|
||||
"google.golang.org/protobuf/proto"
|
||||
)
|
||||
|
||||
var (
|
||||
_ cvms.ServiceServer = (*grpcServer)(nil)
|
||||
ErrUnexpectedMsg = errors.New("unknown message type")
|
||||
)
|
||||
|
||||
const (
|
||||
bufferSize = 1024 * 1024 // 1 MB
|
||||
runReqTimeout = 30 * time.Second
|
||||
)
|
||||
|
||||
type SendFunc func(*cvms.ServerStreamMessage) error
|
||||
|
||||
type grpcServer struct {
|
||||
cvms.UnimplementedServiceServer
|
||||
incoming chan *cvms.ClientStreamMessage
|
||||
svc Service
|
||||
}
|
||||
|
||||
type Service interface {
|
||||
Run(ctx context.Context, ipAddress string, sendMessage SendFunc, authInfo credentials.AuthInfo)
|
||||
}
|
||||
|
||||
// NewServer returns new AuthServiceServer instance.
|
||||
func NewServer(incoming chan *cvms.ClientStreamMessage, svc Service) cvms.ServiceServer {
|
||||
return &grpcServer{
|
||||
incoming: incoming,
|
||||
svc: svc,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *grpcServer) Process(stream cvms.Service_ProcessServer) error {
|
||||
client, ok := peer.FromContext(stream.Context())
|
||||
if !ok {
|
||||
return errors.New("failed to get peer info")
|
||||
}
|
||||
|
||||
eg, ctx := errgroup.WithContext(stream.Context())
|
||||
|
||||
eg.Go(func() error {
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
default:
|
||||
req, err := stream.Recv()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
s.incoming <- req
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
eg.Go(func() error {
|
||||
sendMessage := func(msg *cvms.ServerStreamMessage) error {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
default:
|
||||
switch m := msg.Message.(type) {
|
||||
case *cvms.ServerStreamMessage_RunReq:
|
||||
return s.sendRunReqInChunks(stream, m.RunReq)
|
||||
default:
|
||||
return stream.Send(msg)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
s.svc.Run(ctx, client.Addr.String(), sendMessage, client.AuthInfo)
|
||||
return nil
|
||||
})
|
||||
|
||||
return eg.Wait()
|
||||
}
|
||||
|
||||
func (s *grpcServer) sendRunReqInChunks(stream cvms.Service_ProcessServer, runReq *cvms.ComputationRunReq) error {
|
||||
data, err := proto.Marshal(runReq)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
dataBuffer := bytes.NewBuffer(data)
|
||||
buf := make([]byte, bufferSize)
|
||||
|
||||
for {
|
||||
n, err := dataBuffer.Read(buf)
|
||||
isLast := false
|
||||
|
||||
if err == io.EOF {
|
||||
isLast = true
|
||||
} else if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
chunk := &cvms.ServerStreamMessage{
|
||||
Message: &cvms.ServerStreamMessage_RunReqChunks{
|
||||
RunReqChunks: &cvms.RunReqChunks{
|
||||
Id: runReq.Id,
|
||||
Data: buf[:n],
|
||||
IsLast: isLast,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
if err := stream.Send(chunk); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if isLast {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,273 @@
|
||||
// Copyright (c) Ultraviolet
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
package grpc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/absmach/magistrala/pkg/errors"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/mock"
|
||||
"github.com/ultravioletrs/cocos/agent/cvms"
|
||||
"google.golang.org/grpc/credentials"
|
||||
"google.golang.org/grpc/peer"
|
||||
)
|
||||
|
||||
type mockServerStream struct {
|
||||
mock.Mock
|
||||
cvms.Service_ProcessServer
|
||||
}
|
||||
|
||||
func (m *mockServerStream) Send(msg *cvms.ServerStreamMessage) error {
|
||||
args := m.Called(msg)
|
||||
return args.Error(0)
|
||||
}
|
||||
|
||||
func (m *mockServerStream) Recv() (*cvms.ClientStreamMessage, error) {
|
||||
args := m.Called()
|
||||
return args.Get(0).(*cvms.ClientStreamMessage), args.Error(1)
|
||||
}
|
||||
|
||||
func (m *mockServerStream) Context() context.Context {
|
||||
args := m.Called()
|
||||
return args.Get(0).(context.Context)
|
||||
}
|
||||
|
||||
type mockService struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
func (m *mockService) Run(ctx context.Context, ipAddress string, sendMessage SendFunc, authInfo credentials.AuthInfo) {
|
||||
m.Called(ctx, ipAddress, sendMessage, authInfo)
|
||||
}
|
||||
|
||||
func TestNewServer(t *testing.T) {
|
||||
incoming := make(chan *cvms.ClientStreamMessage)
|
||||
mockSvc := new(mockService)
|
||||
|
||||
server := NewServer(incoming, mockSvc)
|
||||
|
||||
assert.NotNil(t, server)
|
||||
assert.IsType(t, &grpcServer{}, server)
|
||||
}
|
||||
|
||||
func TestGrpcServer_Process(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
recvReturn *cvms.ClientStreamMessage
|
||||
recvError error
|
||||
expectedError string
|
||||
}{
|
||||
{
|
||||
name: "Process with context deadline exceeded",
|
||||
recvReturn: &cvms.ClientStreamMessage{},
|
||||
recvError: nil,
|
||||
expectedError: "context deadline exceeded",
|
||||
},
|
||||
{
|
||||
name: "Process with Recv error",
|
||||
recvReturn: &cvms.ClientStreamMessage{},
|
||||
recvError: errors.New("recv error"),
|
||||
expectedError: "recv error",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
incoming := make(chan *cvms.ClientStreamMessage, 1)
|
||||
mockSvc := new(mockService)
|
||||
server := NewServer(incoming, mockSvc).(*grpcServer)
|
||||
|
||||
mockStream := new(mockServerStream)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
||||
defer cancel()
|
||||
|
||||
mockStream.On("Context").Return(peer.NewContext(ctx, &peer.Peer{
|
||||
Addr: mockAddr{},
|
||||
AuthInfo: mockAuthInfo{},
|
||||
}))
|
||||
|
||||
if tt.recvError == nil {
|
||||
go func() {
|
||||
for mes := range incoming {
|
||||
assert.NotNil(t, mes)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
mockStream.On("Recv").Return(tt.recvReturn, tt.recvError)
|
||||
mockSvc.On("Run", mock.Anything, "test", mock.Anything, mock.AnythingOfType("mockAuthInfo")).Return()
|
||||
|
||||
err := server.Process(mockStream)
|
||||
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), tt.expectedError)
|
||||
mockStream.AssertExpectations(t)
|
||||
mockSvc.AssertExpectations(t)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGrpcServer_sendRunReqInChunks(t *testing.T) {
|
||||
incoming := make(chan *cvms.ClientStreamMessage)
|
||||
mockSvc := new(mockService)
|
||||
server := NewServer(incoming, mockSvc).(*grpcServer)
|
||||
|
||||
mockStream := new(mockServerStream)
|
||||
|
||||
runReq := &cvms.ComputationRunReq{
|
||||
Id: "test-id",
|
||||
}
|
||||
|
||||
largePayload := make([]byte, bufferSize*2)
|
||||
for i := range largePayload {
|
||||
largePayload[i] = byte(i % 256)
|
||||
}
|
||||
runReq.Algorithm = &cvms.Algorithm{}
|
||||
runReq.Algorithm.UserKey = largePayload
|
||||
|
||||
mockStream.On("Send", mock.AnythingOfType("*cvms.ServerStreamMessage")).Return(nil).Times(4)
|
||||
|
||||
err := server.sendRunReqInChunks(mockStream, runReq)
|
||||
|
||||
assert.NoError(t, err)
|
||||
mockStream.AssertExpectations(t)
|
||||
|
||||
calls := mockStream.Calls
|
||||
assert.Equal(t, 4, len(calls))
|
||||
|
||||
for i, call := range calls {
|
||||
msg := call.Arguments[0].(*cvms.ServerStreamMessage)
|
||||
chunk := msg.GetRunReqChunks()
|
||||
|
||||
assert.NotNil(t, chunk)
|
||||
assert.Equal(t, "test-id", chunk.Id)
|
||||
|
||||
if i < 3 {
|
||||
assert.False(t, chunk.IsLast)
|
||||
} else {
|
||||
assert.Equal(t, 0, len(chunk.Data))
|
||||
assert.True(t, chunk.IsLast)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
type mockAddr struct{}
|
||||
|
||||
func (mockAddr) Network() string { return "test network" }
|
||||
func (mockAddr) String() string { return "test" }
|
||||
|
||||
type mockAuthInfo struct{}
|
||||
|
||||
func (mockAuthInfo) AuthType() string { return "test auth" }
|
||||
|
||||
func TestGrpcServer_ProcessWithMockService(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
setupMockFn func(*mockService, *mockServerStream)
|
||||
}{
|
||||
{
|
||||
name: "Run Request Test",
|
||||
setupMockFn: func(mockSvc *mockService, mockStream *mockServerStream) {
|
||||
mockSvc.On("Run", mock.Anything, "test", mock.Anything, mock.AnythingOfType("mockAuthInfo")).
|
||||
Run(func(args mock.Arguments) {
|
||||
sendFunc := args.Get(2).(SendFunc)
|
||||
runReq := &cvms.ComputationRunReq{Id: "test-run-id"}
|
||||
err := sendFunc(&cvms.ServerStreamMessage{
|
||||
Message: &cvms.ServerStreamMessage_RunReq{
|
||||
RunReq: runReq,
|
||||
},
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
}).
|
||||
Return()
|
||||
|
||||
mockStream.On("Send", mock.MatchedBy(func(msg *cvms.ServerStreamMessage) bool {
|
||||
chunks := msg.GetRunReqChunks()
|
||||
return chunks != nil && chunks.Id == "test-run-id"
|
||||
})).Return(nil)
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
incoming := make(chan *cvms.ClientStreamMessage, 10)
|
||||
mockSvc := new(mockService)
|
||||
server := NewServer(incoming, mockSvc).(*grpcServer)
|
||||
|
||||
go func() {
|
||||
for mes := range incoming {
|
||||
assert.NotNil(t, mes)
|
||||
}
|
||||
}()
|
||||
|
||||
mockStream := new(mockServerStream)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
|
||||
defer cancel()
|
||||
|
||||
peerCtx := peer.NewContext(ctx, &peer.Peer{
|
||||
Addr: mockAddr{},
|
||||
AuthInfo: mockAuthInfo{},
|
||||
})
|
||||
|
||||
mockStream.On("Context").Return(peerCtx)
|
||||
mockStream.On("Recv").Return(&cvms.ClientStreamMessage{}, nil).Maybe()
|
||||
|
||||
tt.setupMockFn(mockSvc, mockStream)
|
||||
|
||||
go func() {
|
||||
time.Sleep(150 * time.Millisecond)
|
||||
cancel()
|
||||
}()
|
||||
|
||||
err := server.Process(mockStream)
|
||||
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "context canceled")
|
||||
mockStream.AssertExpectations(t)
|
||||
mockSvc.AssertExpectations(t)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGrpcServer_sendRunReqInChunksError(t *testing.T) {
|
||||
incoming := make(chan *cvms.ClientStreamMessage)
|
||||
mockSvc := new(mockService)
|
||||
server := NewServer(incoming, mockSvc).(*grpcServer)
|
||||
|
||||
mockStream := new(mockServerStream)
|
||||
|
||||
runReq := &cvms.ComputationRunReq{
|
||||
Id: "test-id",
|
||||
}
|
||||
|
||||
// Simulate an error when sending
|
||||
mockStream.On("Send", mock.AnythingOfType("*cvms.ServerStreamMessage")).Return(errors.New("send error")).Once()
|
||||
|
||||
err := server.sendRunReqInChunks(mockStream, runReq)
|
||||
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "send error")
|
||||
mockStream.AssertExpectations(t)
|
||||
}
|
||||
|
||||
func TestGrpcServer_ProcessMissingPeerInfo(t *testing.T) {
|
||||
incoming := make(chan *cvms.ClientStreamMessage)
|
||||
mockSvc := new(mockService)
|
||||
server := NewServer(incoming, mockSvc).(*grpcServer)
|
||||
|
||||
mockStream := new(mockServerStream)
|
||||
ctx := context.Background()
|
||||
|
||||
// Return a context without peer info
|
||||
mockStream.On("Context").Return(ctx)
|
||||
|
||||
err := server.Process(mockStream)
|
||||
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "failed to get peer info")
|
||||
mockStream.AssertExpectations(t)
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,103 @@
|
||||
// Copyright (c) Ultraviolet
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
syntax = "proto3";
|
||||
|
||||
import "google/protobuf/timestamp.proto";
|
||||
|
||||
package cvms;
|
||||
|
||||
option go_package = "./cvms";
|
||||
|
||||
service Service {
|
||||
rpc Process(stream ClientStreamMessage) returns (stream ServerStreamMessage) {}
|
||||
}
|
||||
|
||||
message StopComputation {
|
||||
string computation_id = 1;
|
||||
}
|
||||
|
||||
message StopComputationResponse {
|
||||
string computation_id = 1;
|
||||
string message = 2;
|
||||
}
|
||||
|
||||
message RunResponse{
|
||||
string computation_id = 1;
|
||||
string error = 2;
|
||||
}
|
||||
|
||||
message AgentEvent {
|
||||
string event_type = 1;
|
||||
google.protobuf.Timestamp timestamp = 2;
|
||||
string computation_id = 3;
|
||||
bytes details = 4;
|
||||
string originator = 5;
|
||||
string status = 6;
|
||||
}
|
||||
|
||||
message AgentLog {
|
||||
string message = 1;
|
||||
string computation_id = 2;
|
||||
string level = 3;
|
||||
google.protobuf.Timestamp timestamp = 4;
|
||||
}
|
||||
|
||||
message ClientStreamMessage {
|
||||
oneof message {
|
||||
AgentLog agent_log = 1;
|
||||
AgentEvent agent_event = 2;
|
||||
RunResponse run_res = 3;
|
||||
StopComputationResponse stopComputationRes = 4;
|
||||
}
|
||||
}
|
||||
|
||||
message ServerStreamMessage {
|
||||
oneof message {
|
||||
RunReqChunks runReqChunks = 1;
|
||||
ComputationRunReq runReq = 2;
|
||||
StopComputation stopComputation = 3;
|
||||
}
|
||||
}
|
||||
|
||||
message RunReqChunks {
|
||||
bytes data = 1;
|
||||
string id = 2;
|
||||
bool is_last = 3;
|
||||
}
|
||||
|
||||
message ComputationRunReq {
|
||||
string id = 1;
|
||||
string name = 2;
|
||||
string description = 3;
|
||||
repeated Dataset datasets = 4;
|
||||
Algorithm algorithm = 5;
|
||||
repeated ResultConsumer result_consumers = 6;
|
||||
AgentConfig agent_config = 7;
|
||||
}
|
||||
|
||||
message ResultConsumer {
|
||||
bytes userKey = 1;
|
||||
}
|
||||
|
||||
message Dataset {
|
||||
bytes hash = 1; // should be sha3.Sum256, 32 byte length.
|
||||
bytes userKey = 2;
|
||||
string filename = 3;
|
||||
}
|
||||
|
||||
message Algorithm {
|
||||
bytes hash = 1; // should be sha3.Sum256, 32 byte length.
|
||||
bytes userKey = 2;
|
||||
}
|
||||
|
||||
message AgentConfig {
|
||||
string port = 1;
|
||||
string host = 2;
|
||||
string cert_file = 3;
|
||||
string key_file = 4;
|
||||
string client_ca_file = 5;
|
||||
string server_ca_file = 6;
|
||||
string log_level = 7;
|
||||
bool attested_tls = 8;
|
||||
}
|
||||
@@ -0,0 +1,118 @@
|
||||
// Copyright (c) Ultraviolet
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
// Code generated by protoc-gen-go-grpc. DO NOT EDIT.
|
||||
// versions:
|
||||
// - protoc-gen-go-grpc v1.5.1
|
||||
// - protoc v5.29.0
|
||||
// source: agent/cvms/cvms.proto
|
||||
|
||||
package cvms
|
||||
|
||||
import (
|
||||
context "context"
|
||||
grpc "google.golang.org/grpc"
|
||||
codes "google.golang.org/grpc/codes"
|
||||
status "google.golang.org/grpc/status"
|
||||
)
|
||||
|
||||
// This is a compile-time assertion to ensure that this generated file
|
||||
// is compatible with the grpc package it is being compiled against.
|
||||
// Requires gRPC-Go v1.64.0 or later.
|
||||
const _ = grpc.SupportPackageIsVersion9
|
||||
|
||||
const (
|
||||
Service_Process_FullMethodName = "/cvms.Service/Process"
|
||||
)
|
||||
|
||||
// ServiceClient is the client API for Service service.
|
||||
//
|
||||
// For semantics around ctx use and closing/ending streaming RPCs, please refer to https://pkg.go.dev/google.golang.org/grpc/?tab=doc#ClientConn.NewStream.
|
||||
type ServiceClient interface {
|
||||
Process(ctx context.Context, opts ...grpc.CallOption) (grpc.BidiStreamingClient[ClientStreamMessage, ServerStreamMessage], error)
|
||||
}
|
||||
|
||||
type serviceClient struct {
|
||||
cc grpc.ClientConnInterface
|
||||
}
|
||||
|
||||
func NewServiceClient(cc grpc.ClientConnInterface) ServiceClient {
|
||||
return &serviceClient{cc}
|
||||
}
|
||||
|
||||
func (c *serviceClient) Process(ctx context.Context, opts ...grpc.CallOption) (grpc.BidiStreamingClient[ClientStreamMessage, ServerStreamMessage], error) {
|
||||
cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
|
||||
stream, err := c.cc.NewStream(ctx, &Service_ServiceDesc.Streams[0], Service_Process_FullMethodName, cOpts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
x := &grpc.GenericClientStream[ClientStreamMessage, ServerStreamMessage]{ClientStream: stream}
|
||||
return x, nil
|
||||
}
|
||||
|
||||
// This type alias is provided for backwards compatibility with existing code that references the prior non-generic stream type by name.
|
||||
type Service_ProcessClient = grpc.BidiStreamingClient[ClientStreamMessage, ServerStreamMessage]
|
||||
|
||||
// ServiceServer is the server API for Service service.
|
||||
// All implementations must embed UnimplementedServiceServer
|
||||
// for forward compatibility.
|
||||
type ServiceServer interface {
|
||||
Process(grpc.BidiStreamingServer[ClientStreamMessage, ServerStreamMessage]) error
|
||||
mustEmbedUnimplementedServiceServer()
|
||||
}
|
||||
|
||||
// UnimplementedServiceServer must be embedded to have
|
||||
// forward compatible implementations.
|
||||
//
|
||||
// NOTE: this should be embedded by value instead of pointer to avoid a nil
|
||||
// pointer dereference when methods are called.
|
||||
type UnimplementedServiceServer struct{}
|
||||
|
||||
func (UnimplementedServiceServer) Process(grpc.BidiStreamingServer[ClientStreamMessage, ServerStreamMessage]) error {
|
||||
return status.Errorf(codes.Unimplemented, "method Process not implemented")
|
||||
}
|
||||
func (UnimplementedServiceServer) mustEmbedUnimplementedServiceServer() {}
|
||||
func (UnimplementedServiceServer) testEmbeddedByValue() {}
|
||||
|
||||
// UnsafeServiceServer may be embedded to opt out of forward compatibility for this service.
|
||||
// Use of this interface is not recommended, as added methods to ServiceServer will
|
||||
// result in compilation errors.
|
||||
type UnsafeServiceServer interface {
|
||||
mustEmbedUnimplementedServiceServer()
|
||||
}
|
||||
|
||||
func RegisterServiceServer(s grpc.ServiceRegistrar, srv ServiceServer) {
|
||||
// If the following call pancis, it indicates UnimplementedServiceServer was
|
||||
// embedded by pointer and is nil. This will cause panics if an
|
||||
// unimplemented method is ever invoked, so we test this at initialization
|
||||
// time to prevent it from happening at runtime later due to I/O.
|
||||
if t, ok := srv.(interface{ testEmbeddedByValue() }); ok {
|
||||
t.testEmbeddedByValue()
|
||||
}
|
||||
s.RegisterService(&Service_ServiceDesc, srv)
|
||||
}
|
||||
|
||||
func _Service_Process_Handler(srv interface{}, stream grpc.ServerStream) error {
|
||||
return srv.(ServiceServer).Process(&grpc.GenericServerStream[ClientStreamMessage, ServerStreamMessage]{ServerStream: stream})
|
||||
}
|
||||
|
||||
// This type alias is provided for backwards compatibility with existing code that references the prior non-generic stream type by name.
|
||||
type Service_ProcessServer = grpc.BidiStreamingServer[ClientStreamMessage, ServerStreamMessage]
|
||||
|
||||
// Service_ServiceDesc is the grpc.ServiceDesc for Service service.
|
||||
// It's only intended for direct use with grpc.RegisterService,
|
||||
// and not to be introspected or modified (even as a copy)
|
||||
var Service_ServiceDesc = grpc.ServiceDesc{
|
||||
ServiceName: "cvms.Service",
|
||||
HandlerType: (*ServiceServer)(nil),
|
||||
Methods: []grpc.MethodDesc{},
|
||||
Streams: []grpc.StreamDesc{
|
||||
{
|
||||
StreamName: "Process",
|
||||
Handler: _Service_Process_Handler,
|
||||
ServerStreams: true,
|
||||
ClientStreams: true,
|
||||
},
|
||||
},
|
||||
Metadata: "agent/cvms/cvms.proto",
|
||||
}
|
||||
@@ -0,0 +1,89 @@
|
||||
// Copyright (c) Ultraviolet
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package server
|
||||
|
||||
import (
|
||||
context "context"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
|
||||
"github.com/ultravioletrs/cocos/agent"
|
||||
agentgrpc "github.com/ultravioletrs/cocos/agent/api/grpc"
|
||||
"github.com/ultravioletrs/cocos/agent/auth"
|
||||
"github.com/ultravioletrs/cocos/internal/server"
|
||||
grpcserver "github.com/ultravioletrs/cocos/internal/server/grpc"
|
||||
"github.com/ultravioletrs/cocos/pkg/attestation/quoteprovider"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/reflection"
|
||||
)
|
||||
|
||||
const (
|
||||
svcName = "agent"
|
||||
defSvcGRPCPort = "7002"
|
||||
)
|
||||
|
||||
type AgentServer interface {
|
||||
Start(ctx context.Context, cfg agent.AgentConfig, cmp agent.Computation) error
|
||||
Stop() error
|
||||
}
|
||||
|
||||
type agentServer struct {
|
||||
gs server.Server
|
||||
logger *slog.Logger
|
||||
svc agent.Service
|
||||
}
|
||||
|
||||
func NewServer(logger *slog.Logger, svc agent.Service) AgentServer {
|
||||
return &agentServer{
|
||||
logger: logger,
|
||||
svc: svc,
|
||||
}
|
||||
}
|
||||
|
||||
func (as *agentServer) Start(ctx context.Context, cfg agent.AgentConfig, cmp agent.Computation) error {
|
||||
if cfg.Port == "" {
|
||||
cfg.Port = defSvcGRPCPort
|
||||
}
|
||||
|
||||
agentGrpcServerConfig := server.AgentConfig{
|
||||
ServerConfig: server.ServerConfig{
|
||||
BaseConfig: server.BaseConfig{
|
||||
Host: cfg.Host,
|
||||
Port: cfg.Port,
|
||||
CertFile: cfg.CertFile,
|
||||
KeyFile: cfg.KeyFile,
|
||||
ServerCAFile: cfg.ServerCAFile,
|
||||
ClientCAFile: cfg.ClientCAFile,
|
||||
},
|
||||
},
|
||||
AttestedTLS: cfg.AttestedTls,
|
||||
}
|
||||
|
||||
registerAgentServiceServer := func(srv *grpc.Server) {
|
||||
reflection.Register(srv)
|
||||
agent.RegisterAgentServiceServer(srv, agentgrpc.NewServer(as.svc))
|
||||
}
|
||||
|
||||
authSvc, err := auth.New(cmp)
|
||||
if err != nil {
|
||||
as.logger.WithGroup(cmp.ID).Error(fmt.Sprintf("failed to create auth service %s", err.Error()))
|
||||
return err
|
||||
}
|
||||
|
||||
qp, err := quoteprovider.GetQuoteProvider()
|
||||
if err != nil {
|
||||
as.logger.Error(fmt.Sprintf("failed to create quote provider %s", err.Error()))
|
||||
return err
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
|
||||
as.gs = grpcserver.New(ctx, cancel, svcName, agentGrpcServerConfig, registerAgentServiceServer, as.logger, qp, authSvc)
|
||||
|
||||
return as.gs.Start()
|
||||
}
|
||||
|
||||
func (as *agentServer) Stop() error {
|
||||
return as.gs.Stop()
|
||||
}
|
||||
@@ -0,0 +1,134 @@
|
||||
// Copyright (c) Ultraviolet
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
// Code generated by mockery v2.43.2. DO NOT EDIT.
|
||||
|
||||
package mocks
|
||||
|
||||
import (
|
||||
context "context"
|
||||
|
||||
agent "github.com/ultravioletrs/cocos/agent"
|
||||
|
||||
mock "github.com/stretchr/testify/mock"
|
||||
)
|
||||
|
||||
// AgentServerProvider is an autogenerated mock type for the AgentServerProvider type
|
||||
type AgentServerProvider struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
type AgentServerProvider_Expecter struct {
|
||||
mock *mock.Mock
|
||||
}
|
||||
|
||||
func (_m *AgentServerProvider) EXPECT() *AgentServerProvider_Expecter {
|
||||
return &AgentServerProvider_Expecter{mock: &_m.Mock}
|
||||
}
|
||||
|
||||
// Start provides a mock function with given fields: ctx, cfg, cmp
|
||||
func (_m *AgentServerProvider) Start(ctx context.Context, cfg agent.AgentConfig, cmp agent.Computation) error {
|
||||
ret := _m.Called(ctx, cfg, cmp)
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for Start")
|
||||
}
|
||||
|
||||
var r0 error
|
||||
if rf, ok := ret.Get(0).(func(context.Context, agent.AgentConfig, agent.Computation) error); ok {
|
||||
r0 = rf(ctx, cfg, cmp)
|
||||
} else {
|
||||
r0 = ret.Error(0)
|
||||
}
|
||||
|
||||
return r0
|
||||
}
|
||||
|
||||
// AgentServerProvider_Start_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Start'
|
||||
type AgentServerProvider_Start_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// Start is a helper method to define mock.On call
|
||||
// - ctx context.Context
|
||||
// - cfg agent.AgentConfig
|
||||
// - cmp agent.Computation
|
||||
func (_e *AgentServerProvider_Expecter) Start(ctx interface{}, cfg interface{}, cmp interface{}) *AgentServerProvider_Start_Call {
|
||||
return &AgentServerProvider_Start_Call{Call: _e.mock.On("Start", ctx, cfg, cmp)}
|
||||
}
|
||||
|
||||
func (_c *AgentServerProvider_Start_Call) Run(run func(ctx context.Context, cfg agent.AgentConfig, cmp agent.Computation)) *AgentServerProvider_Start_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
run(args[0].(context.Context), args[1].(agent.AgentConfig), args[2].(agent.Computation))
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *AgentServerProvider_Start_Call) Return(_a0 error) *AgentServerProvider_Start_Call {
|
||||
_c.Call.Return(_a0)
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *AgentServerProvider_Start_Call) RunAndReturn(run func(context.Context, agent.AgentConfig, agent.Computation) error) *AgentServerProvider_Start_Call {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
// Stop provides a mock function with given fields:
|
||||
func (_m *AgentServerProvider) Stop() error {
|
||||
ret := _m.Called()
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for Stop")
|
||||
}
|
||||
|
||||
var r0 error
|
||||
if rf, ok := ret.Get(0).(func() error); ok {
|
||||
r0 = rf()
|
||||
} else {
|
||||
r0 = ret.Error(0)
|
||||
}
|
||||
|
||||
return r0
|
||||
}
|
||||
|
||||
// AgentServerProvider_Stop_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Stop'
|
||||
type AgentServerProvider_Stop_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// Stop is a helper method to define mock.On call
|
||||
func (_e *AgentServerProvider_Expecter) Stop() *AgentServerProvider_Stop_Call {
|
||||
return &AgentServerProvider_Stop_Call{Call: _e.mock.On("Stop")}
|
||||
}
|
||||
|
||||
func (_c *AgentServerProvider_Stop_Call) Run(run func()) *AgentServerProvider_Stop_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
run()
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *AgentServerProvider_Stop_Call) Return(_a0 error) *AgentServerProvider_Stop_Call {
|
||||
_c.Call.Return(_a0)
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *AgentServerProvider_Stop_Call) RunAndReturn(run func() error) *AgentServerProvider_Stop_Call {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
// NewAgentServerProvider creates a new instance of AgentServerProvider. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations.
|
||||
// The first argument is typically a *testing.T value.
|
||||
func NewAgentServerProvider(t interface {
|
||||
mock.TestingT
|
||||
Cleanup(func())
|
||||
}) *AgentServerProvider {
|
||||
mock := &AgentServerProvider{}
|
||||
mock.Mock.Test(t)
|
||||
|
||||
t.Cleanup(func() { mock.AssertExpectations(t) })
|
||||
|
||||
return mock
|
||||
}
|
||||
+19
-24
@@ -4,43 +4,38 @@ package events
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"io"
|
||||
|
||||
"google.golang.org/protobuf/proto"
|
||||
"github.com/ultravioletrs/cocos/agent/cvms"
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
)
|
||||
|
||||
type service struct {
|
||||
service string
|
||||
computationID string
|
||||
conn io.Writer
|
||||
service string
|
||||
queue chan *cvms.ClientStreamMessage
|
||||
}
|
||||
|
||||
type Service interface {
|
||||
SendEvent(event, status string, details json.RawMessage) error
|
||||
SendEvent(cmpID, event, status string, details json.RawMessage)
|
||||
}
|
||||
|
||||
func New(svc, computationID string, conn io.Writer) (Service, error) {
|
||||
func New(svc string, queue chan *cvms.ClientStreamMessage) (Service, error) {
|
||||
return &service{
|
||||
service: svc,
|
||||
computationID: computationID,
|
||||
conn: conn,
|
||||
service: svc,
|
||||
queue: queue,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *service) SendEvent(event, status string, details json.RawMessage) error {
|
||||
body := EventsLogs{Message: &EventsLogs_AgentEvent{AgentEvent: &AgentEvent{
|
||||
EventType: event,
|
||||
Timestamp: timestamppb.Now(),
|
||||
ComputationId: s.computationID,
|
||||
Originator: s.service,
|
||||
Status: status,
|
||||
Details: details,
|
||||
}}}
|
||||
protoBody, err := proto.Marshal(&body)
|
||||
if err != nil {
|
||||
return err
|
||||
func (s *service) SendEvent(cmpID, event, status string, details json.RawMessage) {
|
||||
s.queue <- &cvms.ClientStreamMessage{
|
||||
Message: &cvms.ClientStreamMessage_AgentEvent{
|
||||
AgentEvent: &cvms.AgentEvent{
|
||||
EventType: event,
|
||||
Timestamp: timestamppb.Now(),
|
||||
ComputationId: cmpID,
|
||||
Originator: s.service,
|
||||
Status: status,
|
||||
Details: details,
|
||||
},
|
||||
},
|
||||
}
|
||||
_, err = s.conn.Write(protoBody)
|
||||
return err
|
||||
}
|
||||
|
||||
+36
-79
@@ -3,8 +3,8 @@
|
||||
|
||||
// Code generated by protoc-gen-go. DO NOT EDIT.
|
||||
// versions:
|
||||
// protoc-gen-go v1.34.2
|
||||
// protoc v5.28.1
|
||||
// protoc-gen-go v1.36.0
|
||||
// protoc v5.29.0
|
||||
// source: agent/events/events.proto
|
||||
|
||||
package events
|
||||
@@ -25,25 +25,22 @@ const (
|
||||
)
|
||||
|
||||
type AgentEvent struct {
|
||||
state protoimpl.MessageState
|
||||
sizeCache protoimpl.SizeCache
|
||||
unknownFields protoimpl.UnknownFields
|
||||
|
||||
state protoimpl.MessageState `protogen:"open.v1"`
|
||||
EventType string `protobuf:"bytes,1,opt,name=event_type,json=eventType,proto3" json:"event_type,omitempty"`
|
||||
Timestamp *timestamppb.Timestamp `protobuf:"bytes,2,opt,name=timestamp,proto3" json:"timestamp,omitempty"`
|
||||
ComputationId string `protobuf:"bytes,3,opt,name=computation_id,json=computationId,proto3" json:"computation_id,omitempty"`
|
||||
Details []byte `protobuf:"bytes,4,opt,name=details,proto3" json:"details,omitempty"`
|
||||
Originator string `protobuf:"bytes,5,opt,name=originator,proto3" json:"originator,omitempty"`
|
||||
Status string `protobuf:"bytes,6,opt,name=status,proto3" json:"status,omitempty"`
|
||||
unknownFields protoimpl.UnknownFields
|
||||
sizeCache protoimpl.SizeCache
|
||||
}
|
||||
|
||||
func (x *AgentEvent) Reset() {
|
||||
*x = AgentEvent{}
|
||||
if protoimpl.UnsafeEnabled {
|
||||
mi := &file_agent_events_events_proto_msgTypes[0]
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
mi := &file_agent_events_events_proto_msgTypes[0]
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
|
||||
func (x *AgentEvent) String() string {
|
||||
@@ -54,7 +51,7 @@ func (*AgentEvent) ProtoMessage() {}
|
||||
|
||||
func (x *AgentEvent) ProtoReflect() protoreflect.Message {
|
||||
mi := &file_agent_events_events_proto_msgTypes[0]
|
||||
if protoimpl.UnsafeEnabled && x != nil {
|
||||
if x != nil {
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
if ms.LoadMessageInfo() == nil {
|
||||
ms.StoreMessageInfo(mi)
|
||||
@@ -112,23 +109,20 @@ func (x *AgentEvent) GetStatus() string {
|
||||
}
|
||||
|
||||
type AgentLog struct {
|
||||
state protoimpl.MessageState
|
||||
sizeCache protoimpl.SizeCache
|
||||
unknownFields protoimpl.UnknownFields
|
||||
|
||||
state protoimpl.MessageState `protogen:"open.v1"`
|
||||
Message string `protobuf:"bytes,1,opt,name=message,proto3" json:"message,omitempty"`
|
||||
ComputationId string `protobuf:"bytes,2,opt,name=computation_id,json=computationId,proto3" json:"computation_id,omitempty"`
|
||||
Level string `protobuf:"bytes,3,opt,name=level,proto3" json:"level,omitempty"`
|
||||
Timestamp *timestamppb.Timestamp `protobuf:"bytes,4,opt,name=timestamp,proto3" json:"timestamp,omitempty"`
|
||||
unknownFields protoimpl.UnknownFields
|
||||
sizeCache protoimpl.SizeCache
|
||||
}
|
||||
|
||||
func (x *AgentLog) Reset() {
|
||||
*x = AgentLog{}
|
||||
if protoimpl.UnsafeEnabled {
|
||||
mi := &file_agent_events_events_proto_msgTypes[1]
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
mi := &file_agent_events_events_proto_msgTypes[1]
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
|
||||
func (x *AgentLog) String() string {
|
||||
@@ -139,7 +133,7 @@ func (*AgentLog) ProtoMessage() {}
|
||||
|
||||
func (x *AgentLog) ProtoReflect() protoreflect.Message {
|
||||
mi := &file_agent_events_events_proto_msgTypes[1]
|
||||
if protoimpl.UnsafeEnabled && x != nil {
|
||||
if x != nil {
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
if ms.LoadMessageInfo() == nil {
|
||||
ms.StoreMessageInfo(mi)
|
||||
@@ -183,24 +177,21 @@ func (x *AgentLog) GetTimestamp() *timestamppb.Timestamp {
|
||||
}
|
||||
|
||||
type EventsLogs struct {
|
||||
state protoimpl.MessageState
|
||||
sizeCache protoimpl.SizeCache
|
||||
unknownFields protoimpl.UnknownFields
|
||||
|
||||
// Types that are assignable to Message:
|
||||
state protoimpl.MessageState `protogen:"open.v1"`
|
||||
// Types that are valid to be assigned to Message:
|
||||
//
|
||||
// *EventsLogs_AgentLog
|
||||
// *EventsLogs_AgentEvent
|
||||
Message isEventsLogs_Message `protobuf_oneof:"message"`
|
||||
Message isEventsLogs_Message `protobuf_oneof:"message"`
|
||||
unknownFields protoimpl.UnknownFields
|
||||
sizeCache protoimpl.SizeCache
|
||||
}
|
||||
|
||||
func (x *EventsLogs) Reset() {
|
||||
*x = EventsLogs{}
|
||||
if protoimpl.UnsafeEnabled {
|
||||
mi := &file_agent_events_events_proto_msgTypes[2]
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
mi := &file_agent_events_events_proto_msgTypes[2]
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
|
||||
func (x *EventsLogs) String() string {
|
||||
@@ -211,7 +202,7 @@ func (*EventsLogs) ProtoMessage() {}
|
||||
|
||||
func (x *EventsLogs) ProtoReflect() protoreflect.Message {
|
||||
mi := &file_agent_events_events_proto_msgTypes[2]
|
||||
if protoimpl.UnsafeEnabled && x != nil {
|
||||
if x != nil {
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
if ms.LoadMessageInfo() == nil {
|
||||
ms.StoreMessageInfo(mi)
|
||||
@@ -226,23 +217,27 @@ func (*EventsLogs) Descriptor() ([]byte, []int) {
|
||||
return file_agent_events_events_proto_rawDescGZIP(), []int{2}
|
||||
}
|
||||
|
||||
func (m *EventsLogs) GetMessage() isEventsLogs_Message {
|
||||
if m != nil {
|
||||
return m.Message
|
||||
func (x *EventsLogs) GetMessage() isEventsLogs_Message {
|
||||
if x != nil {
|
||||
return x.Message
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (x *EventsLogs) GetAgentLog() *AgentLog {
|
||||
if x, ok := x.GetMessage().(*EventsLogs_AgentLog); ok {
|
||||
return x.AgentLog
|
||||
if x != nil {
|
||||
if x, ok := x.Message.(*EventsLogs_AgentLog); ok {
|
||||
return x.AgentLog
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (x *EventsLogs) GetAgentEvent() *AgentEvent {
|
||||
if x, ok := x.GetMessage().(*EventsLogs_AgentEvent); ok {
|
||||
return x.AgentEvent
|
||||
if x != nil {
|
||||
if x, ok := x.Message.(*EventsLogs_AgentEvent); ok {
|
||||
return x.AgentEvent
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -342,44 +337,6 @@ func file_agent_events_events_proto_init() {
|
||||
if File_agent_events_events_proto != nil {
|
||||
return
|
||||
}
|
||||
if !protoimpl.UnsafeEnabled {
|
||||
file_agent_events_events_proto_msgTypes[0].Exporter = func(v any, i int) any {
|
||||
switch v := v.(*AgentEvent); i {
|
||||
case 0:
|
||||
return &v.state
|
||||
case 1:
|
||||
return &v.sizeCache
|
||||
case 2:
|
||||
return &v.unknownFields
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
file_agent_events_events_proto_msgTypes[1].Exporter = func(v any, i int) any {
|
||||
switch v := v.(*AgentLog); i {
|
||||
case 0:
|
||||
return &v.state
|
||||
case 1:
|
||||
return &v.sizeCache
|
||||
case 2:
|
||||
return &v.unknownFields
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
file_agent_events_events_proto_msgTypes[2].Exporter = func(v any, i int) any {
|
||||
switch v := v.(*EventsLogs); i {
|
||||
case 0:
|
||||
return &v.state
|
||||
case 1:
|
||||
return &v.sizeCache
|
||||
case 2:
|
||||
return &v.unknownFields
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
file_agent_events_events_proto_msgTypes[2].OneofWrappers = []any{
|
||||
(*EventsLogs_AgentLog)(nil),
|
||||
(*EventsLogs_AgentEvent)(nil),
|
||||
|
||||
+17
-43
@@ -3,62 +3,36 @@
|
||||
package events
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"google.golang.org/protobuf/proto"
|
||||
"github.com/ultravioletrs/cocos/agent/cvms"
|
||||
)
|
||||
|
||||
type mockConn struct {
|
||||
writeErr error
|
||||
buf bytes.Buffer
|
||||
}
|
||||
|
||||
func (m *mockConn) Write(p []byte) (n int, err error) {
|
||||
if m.writeErr != nil {
|
||||
return 0, m.writeErr
|
||||
}
|
||||
return m.buf.Write(p)
|
||||
}
|
||||
|
||||
func TestSendEventSuccess(t *testing.T) {
|
||||
mockConnection := &mockConn{}
|
||||
|
||||
svc, err := New("test_service", "12345", mockConnection)
|
||||
queue := make(chan *cvms.ClientStreamMessage, 1)
|
||||
svc, err := New("test_service", queue)
|
||||
assert.NoError(t, err)
|
||||
|
||||
details := json.RawMessage(`{"key": "value"}`)
|
||||
|
||||
err = svc.SendEvent("test_event", "success", details)
|
||||
assert.NoError(t, err)
|
||||
go func() {
|
||||
msg := <-queue
|
||||
assert.NotNil(t, msg)
|
||||
assert.NotNil(t, msg.GetAgentEvent())
|
||||
assert.Equal(t, "test_event", msg.GetAgentEvent().EventType)
|
||||
assert.Equal(t, "testid", msg.GetAgentEvent().ComputationId)
|
||||
assert.Equal(t, "test_service", msg.GetAgentEvent().Originator)
|
||||
assert.Equal(t, "success", msg.GetAgentEvent().Status)
|
||||
|
||||
var writtenMessage EventsLogs
|
||||
err = proto.Unmarshal(mockConnection.buf.Bytes(), &writtenMessage)
|
||||
assert.NoError(t, err)
|
||||
now := time.Now()
|
||||
eventTimestamp := msg.GetAgentEvent().GetTimestamp().AsTime()
|
||||
assert.WithinDuration(t, now, eventTimestamp, 1*time.Second)
|
||||
}()
|
||||
|
||||
assert.Equal(t, "test_event", writtenMessage.GetAgentEvent().EventType)
|
||||
assert.Equal(t, "12345", writtenMessage.GetAgentEvent().ComputationId)
|
||||
assert.Equal(t, "test_service", writtenMessage.GetAgentEvent().Originator)
|
||||
assert.Equal(t, "success", writtenMessage.GetAgentEvent().Status)
|
||||
svc.SendEvent("testid", "test_event", "success", details)
|
||||
|
||||
now := time.Now()
|
||||
eventTimestamp := writtenMessage.GetAgentEvent().GetTimestamp().AsTime()
|
||||
assert.WithinDuration(t, now, eventTimestamp, 1*time.Second)
|
||||
}
|
||||
|
||||
func TestSendEventFailure(t *testing.T) {
|
||||
mockConnection := &mockConn{writeErr: errors.New("write error")}
|
||||
|
||||
svc, err := New("test_service", "12345", mockConnection)
|
||||
assert.NoError(t, err)
|
||||
|
||||
details := json.RawMessage(`{"key": "value"}`)
|
||||
|
||||
err = svc.SendEvent("test_event", "failure", details)
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, "write error", err.Error())
|
||||
time.Sleep(1 * time.Second)
|
||||
}
|
||||
|
||||
@@ -24,22 +24,9 @@ func (_m *Service) EXPECT() *Service_Expecter {
|
||||
return &Service_Expecter{mock: &_m.Mock}
|
||||
}
|
||||
|
||||
// SendEvent provides a mock function with given fields: event, status, details
|
||||
func (_m *Service) SendEvent(event string, status string, details json.RawMessage) error {
|
||||
ret := _m.Called(event, status, details)
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for SendEvent")
|
||||
}
|
||||
|
||||
var r0 error
|
||||
if rf, ok := ret.Get(0).(func(string, string, json.RawMessage) error); ok {
|
||||
r0 = rf(event, status, details)
|
||||
} else {
|
||||
r0 = ret.Error(0)
|
||||
}
|
||||
|
||||
return r0
|
||||
// SendEvent provides a mock function with given fields: cmpID, event, status, details
|
||||
func (_m *Service) SendEvent(cmpID string, event string, status string, details json.RawMessage) {
|
||||
_m.Called(cmpID, event, status, details)
|
||||
}
|
||||
|
||||
// Service_SendEvent_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'SendEvent'
|
||||
@@ -48,26 +35,27 @@ type Service_SendEvent_Call struct {
|
||||
}
|
||||
|
||||
// SendEvent is a helper method to define mock.On call
|
||||
// - cmpID string
|
||||
// - event string
|
||||
// - status string
|
||||
// - details json.RawMessage
|
||||
func (_e *Service_Expecter) SendEvent(event interface{}, status interface{}, details interface{}) *Service_SendEvent_Call {
|
||||
return &Service_SendEvent_Call{Call: _e.mock.On("SendEvent", event, status, details)}
|
||||
func (_e *Service_Expecter) SendEvent(cmpID interface{}, event interface{}, status interface{}, details interface{}) *Service_SendEvent_Call {
|
||||
return &Service_SendEvent_Call{Call: _e.mock.On("SendEvent", cmpID, event, status, details)}
|
||||
}
|
||||
|
||||
func (_c *Service_SendEvent_Call) Run(run func(event string, status string, details json.RawMessage)) *Service_SendEvent_Call {
|
||||
func (_c *Service_SendEvent_Call) Run(run func(cmpID string, event string, status string, details json.RawMessage)) *Service_SendEvent_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
run(args[0].(string), args[1].(string), args[2].(json.RawMessage))
|
||||
run(args[0].(string), args[1].(string), args[2].(string), args[3].(json.RawMessage))
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *Service_SendEvent_Call) Return(_a0 error) *Service_SendEvent_Call {
|
||||
_c.Call.Return(_a0)
|
||||
func (_c *Service_SendEvent_Call) Return() *Service_SendEvent_Call {
|
||||
_c.Call.Return()
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *Service_SendEvent_Call) RunAndReturn(run func(string, string, json.RawMessage) error) *Service_SendEvent_Call {
|
||||
func (_c *Service_SendEvent_Call) RunAndReturn(run func(string, string, string, json.RawMessage)) *Service_SendEvent_Call {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
@@ -179,6 +179,53 @@ func (_c *Service_Data_Call) RunAndReturn(run func(context.Context, agent.Datase
|
||||
return _c
|
||||
}
|
||||
|
||||
// InitComputation provides a mock function with given fields: ctx, cmp
|
||||
func (_m *Service) InitComputation(ctx context.Context, cmp agent.Computation) error {
|
||||
ret := _m.Called(ctx, cmp)
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for InitComputation")
|
||||
}
|
||||
|
||||
var r0 error
|
||||
if rf, ok := ret.Get(0).(func(context.Context, agent.Computation) error); ok {
|
||||
r0 = rf(ctx, cmp)
|
||||
} else {
|
||||
r0 = ret.Error(0)
|
||||
}
|
||||
|
||||
return r0
|
||||
}
|
||||
|
||||
// Service_InitComputation_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'InitComputation'
|
||||
type Service_InitComputation_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// InitComputation is a helper method to define mock.On call
|
||||
// - ctx context.Context
|
||||
// - cmp agent.Computation
|
||||
func (_e *Service_Expecter) InitComputation(ctx interface{}, cmp interface{}) *Service_InitComputation_Call {
|
||||
return &Service_InitComputation_Call{Call: _e.mock.On("InitComputation", ctx, cmp)}
|
||||
}
|
||||
|
||||
func (_c *Service_InitComputation_Call) Run(run func(ctx context.Context, cmp agent.Computation)) *Service_InitComputation_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
run(args[0].(context.Context), args[1].(agent.Computation))
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *Service_InitComputation_Call) Return(_a0 error) *Service_InitComputation_Call {
|
||||
_c.Call.Return(_a0)
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *Service_InitComputation_Call) RunAndReturn(run func(context.Context, agent.Computation) error) *Service_InitComputation_Call {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
// Result provides a mock function with given fields: ctx
|
||||
func (_m *Service) Result(ctx context.Context) ([]byte, error) {
|
||||
ret := _m.Called(ctx)
|
||||
@@ -237,6 +284,52 @@ func (_c *Service_Result_Call) RunAndReturn(run func(context.Context) ([]byte, e
|
||||
return _c
|
||||
}
|
||||
|
||||
// StopComputation provides a mock function with given fields: ctx
|
||||
func (_m *Service) StopComputation(ctx context.Context) error {
|
||||
ret := _m.Called(ctx)
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for StopComputation")
|
||||
}
|
||||
|
||||
var r0 error
|
||||
if rf, ok := ret.Get(0).(func(context.Context) error); ok {
|
||||
r0 = rf(ctx)
|
||||
} else {
|
||||
r0 = ret.Error(0)
|
||||
}
|
||||
|
||||
return r0
|
||||
}
|
||||
|
||||
// Service_StopComputation_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'StopComputation'
|
||||
type Service_StopComputation_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// StopComputation is a helper method to define mock.On call
|
||||
// - ctx context.Context
|
||||
func (_e *Service_Expecter) StopComputation(ctx interface{}) *Service_StopComputation_Call {
|
||||
return &Service_StopComputation_Call{Call: _e.mock.On("StopComputation", ctx)}
|
||||
}
|
||||
|
||||
func (_c *Service_StopComputation_Call) Run(run func(ctx context.Context)) *Service_StopComputation_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
run(args[0].(context.Context))
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *Service_StopComputation_Call) Return(_a0 error) *Service_StopComputation_Call {
|
||||
_c.Call.Return(_a0)
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *Service_StopComputation_Call) RunAndReturn(run func(context.Context) error) *Service_StopComputation_Call {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
// NewService creates a new instance of Service. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations.
|
||||
// The first argument is typically a *testing.T value.
|
||||
func NewService(t interface {
|
||||
|
||||
+68
-19
@@ -104,6 +104,8 @@ var (
|
||||
// Service specifies an API that must be fullfiled by the domain service
|
||||
// implementation, and all of its decorators (e.g. logging & metrics).
|
||||
type Service interface {
|
||||
InitComputation(ctx context.Context, cmp Computation) error
|
||||
StopComputation(ctx context.Context) error
|
||||
Algo(ctx context.Context, algorithm Algorithm) error
|
||||
Data(ctx context.Context, dataset Dataset) error
|
||||
Result(ctx context.Context) ([]byte, error)
|
||||
@@ -121,19 +123,21 @@ type agentService struct {
|
||||
quoteProvider client.QuoteProvider // Provider for generating attestation quotes.
|
||||
logger *slog.Logger // Logger for the agent service.
|
||||
resultsConsumed bool // Indicates if the results have been consumed.
|
||||
cancel context.CancelFunc // Cancels the computation context.
|
||||
}
|
||||
|
||||
var _ Service = (*agentService)(nil)
|
||||
|
||||
// New instantiates the agent service implementation.
|
||||
func New(ctx context.Context, logger *slog.Logger, eventSvc events.Service, cmp Computation, quoteProvider client.QuoteProvider) Service {
|
||||
func New(ctx context.Context, logger *slog.Logger, eventSvc events.Service, quoteProvider client.QuoteProvider) Service {
|
||||
sm := statemachine.NewStateMachine(Idle)
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
svc := &agentService{
|
||||
sm: sm,
|
||||
eventSvc: eventSvc,
|
||||
quoteProvider: quoteProvider,
|
||||
logger: logger,
|
||||
computation: cmp,
|
||||
cancel: cancel,
|
||||
}
|
||||
|
||||
transitions := []statemachine.Transition{
|
||||
@@ -141,13 +145,6 @@ func New(ctx context.Context, logger *slog.Logger, eventSvc events.Service, cmp
|
||||
{From: ReceivingManifest, Event: ManifestReceived, To: ReceivingAlgorithm},
|
||||
}
|
||||
|
||||
if len(cmp.Datasets) == 0 {
|
||||
transitions = append(transitions, statemachine.Transition{From: ReceivingAlgorithm, Event: AlgorithmReceived, To: Running})
|
||||
} else {
|
||||
transitions = append(transitions, statemachine.Transition{From: ReceivingAlgorithm, Event: AlgorithmReceived, To: ReceivingData})
|
||||
transitions = append(transitions, statemachine.Transition{From: ReceivingData, Event: DataReceived, To: Running})
|
||||
}
|
||||
|
||||
transitions = append(transitions, []statemachine.Transition{
|
||||
{From: Running, Event: RunComplete, To: ConsumingResults},
|
||||
{From: Running, Event: RunFailed, To: Failed},
|
||||
@@ -158,8 +155,6 @@ func New(ctx context.Context, logger *slog.Logger, eventSvc events.Service, cmp
|
||||
sm.AddTransition(t)
|
||||
}
|
||||
|
||||
sm.SetAction(Idle, svc.publishEvent(IdleState.String()))
|
||||
sm.SetAction(ReceivingManifest, svc.publishEvent(InProgress.String()))
|
||||
sm.SetAction(ReceivingAlgorithm, svc.publishEvent(InProgress.String()))
|
||||
sm.SetAction(ReceivingData, svc.publishEvent(InProgress.String()))
|
||||
sm.SetAction(Running, svc.runComputation)
|
||||
@@ -173,11 +168,67 @@ func New(ctx context.Context, logger *slog.Logger, eventSvc events.Service, cmp
|
||||
}
|
||||
}()
|
||||
sm.SendEvent(Start)
|
||||
defer sm.SendEvent(ManifestReceived)
|
||||
|
||||
return svc
|
||||
}
|
||||
|
||||
func (as *agentService) InitComputation(ctx context.Context, cmp Computation) error {
|
||||
defer as.sm.SendEvent(ManifestReceived)
|
||||
if as.sm.GetState() != ReceivingManifest {
|
||||
return ErrStateNotReady
|
||||
}
|
||||
|
||||
as.mu.Lock()
|
||||
defer as.mu.Unlock()
|
||||
|
||||
as.computation = cmp
|
||||
|
||||
transitions := []statemachine.Transition{}
|
||||
|
||||
if len(cmp.Datasets) == 0 {
|
||||
transitions = append(transitions, statemachine.Transition{From: ReceivingAlgorithm, Event: AlgorithmReceived, To: Running})
|
||||
} else {
|
||||
transitions = append(transitions, statemachine.Transition{From: ReceivingAlgorithm, Event: AlgorithmReceived, To: ReceivingData})
|
||||
transitions = append(transitions, statemachine.Transition{From: ReceivingData, Event: DataReceived, To: Running})
|
||||
}
|
||||
|
||||
for _, t := range transitions {
|
||||
as.sm.AddTransition(t)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (as *agentService) StopComputation(ctx context.Context) error {
|
||||
as.mu.Lock()
|
||||
defer as.mu.Unlock()
|
||||
|
||||
as.cancel()
|
||||
|
||||
if err := as.algorithm.Stop(); err != nil {
|
||||
return fmt.Errorf("error stopping computation: %v", err)
|
||||
}
|
||||
|
||||
sm := statemachine.NewStateMachine(Idle)
|
||||
|
||||
if err := os.RemoveAll(algorithm.DatasetsDir); err != nil {
|
||||
return fmt.Errorf("error removing datasets directory: %v", err)
|
||||
}
|
||||
|
||||
if err := os.RemoveAll(algorithm.ResultsDir); err != nil {
|
||||
return fmt.Errorf("error removing results directory: %v", err)
|
||||
}
|
||||
|
||||
as.sm = sm
|
||||
as.computation = Computation{}
|
||||
as.algorithm = nil
|
||||
as.result = nil
|
||||
as.runError = nil
|
||||
as.resultsConsumed = false
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (as *agentService) Algo(ctx context.Context, algo Algorithm) error {
|
||||
if as.sm.GetState() != ReceivingAlgorithm {
|
||||
return ErrStateNotReady
|
||||
@@ -225,7 +276,7 @@ func (as *agentService) Algo(ctx context.Context, algo Algorithm) error {
|
||||
|
||||
switch algoType {
|
||||
case string(algorithm.AlgoTypeBin):
|
||||
as.algorithm = binary.NewAlgorithm(as.logger, as.eventSvc, f.Name(), args)
|
||||
as.algorithm = binary.NewAlgorithm(as.logger, as.eventSvc, f.Name(), args, as.computation.ID)
|
||||
case string(algorithm.AlgoTypePython):
|
||||
var requirementsFile string
|
||||
if len(algo.Requirements) > 0 {
|
||||
@@ -243,11 +294,11 @@ func (as *agentService) Algo(ctx context.Context, algo Algorithm) error {
|
||||
requirementsFile = fr.Name()
|
||||
}
|
||||
runtime := python.PythonRunTimeFromContext(ctx)
|
||||
as.algorithm = python.NewAlgorithm(as.logger, as.eventSvc, runtime, requirementsFile, f.Name(), args)
|
||||
as.algorithm = python.NewAlgorithm(as.logger, as.eventSvc, runtime, requirementsFile, f.Name(), args, as.computation.ID)
|
||||
case string(algorithm.AlgoTypeWasm):
|
||||
as.algorithm = wasm.NewAlgorithm(as.logger, as.eventSvc, f.Name(), args)
|
||||
as.algorithm = wasm.NewAlgorithm(as.logger, as.eventSvc, args, f.Name(), as.computation.ID)
|
||||
case string(algorithm.AlgoTypeDocker):
|
||||
as.algorithm = docker.NewAlgorithm(as.logger, as.eventSvc, f.Name())
|
||||
as.algorithm = docker.NewAlgorithm(as.logger, as.eventSvc, f.Name(), as.computation.ID)
|
||||
}
|
||||
|
||||
if err := os.Mkdir(algorithm.DatasetsDir, 0o755); err != nil {
|
||||
@@ -400,8 +451,6 @@ func (as *agentService) runComputation(state statemachine.State) {
|
||||
|
||||
func (as *agentService) publishEvent(status string) statemachine.Action {
|
||||
return func(state statemachine.State) {
|
||||
if err := as.eventSvc.SendEvent(state.String(), status, json.RawMessage{}); err != nil {
|
||||
as.logger.Warn(err.Error())
|
||||
}
|
||||
as.eventSvc.SendEvent(as.computation.ID, state.String(), status, json.RawMessage{})
|
||||
}
|
||||
}
|
||||
|
||||
+23
-29
@@ -35,11 +35,6 @@ var (
|
||||
const datasetFile = "iris.csv"
|
||||
|
||||
func TestAlgo(t *testing.T) {
|
||||
events := new(mocks.Service)
|
||||
|
||||
evCall := events.On("SendEvent", mock.Anything, mock.Anything, mock.Anything).Return(nil)
|
||||
defer evCall.Unset()
|
||||
|
||||
qp, err := quoteprovider.GetQuoteProvider()
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -120,9 +115,15 @@ func TestAlgo(t *testing.T) {
|
||||
metadata.Pairs(algorithm.AlgoTypeKey, tc.algoType, python.PyRuntimeKey, python.PyRuntime),
|
||||
)
|
||||
|
||||
events := new(mocks.Service)
|
||||
events.EXPECT().SendEvent(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return()
|
||||
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
defer cancel()
|
||||
svc := New(ctx, mglog.NewMock(), events, testComputation(t), qp)
|
||||
svc := New(ctx, mglog.NewMock(), events, qp)
|
||||
|
||||
err := svc.InitComputation(ctx, testComputation(t))
|
||||
require.NoError(t, err)
|
||||
|
||||
time.Sleep(300 * time.Millisecond)
|
||||
|
||||
@@ -138,11 +139,6 @@ func TestAlgo(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestData(t *testing.T) {
|
||||
events := new(mocks.Service)
|
||||
|
||||
evCall := events.On("SendEvent", mock.Anything, mock.Anything, mock.Anything).Return(nil)
|
||||
defer evCall.Unset()
|
||||
|
||||
qp, err := quoteprovider.GetQuoteProvider()
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -209,6 +205,9 @@ func TestData(t *testing.T) {
|
||||
python.PyRuntimeKey, python.PyRuntime),
|
||||
)
|
||||
|
||||
events := new(mocks.Service)
|
||||
events.EXPECT().SendEvent(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return()
|
||||
|
||||
if tc.err != ErrUndeclaredDataset {
|
||||
ctx = IndexToContext(ctx, 0)
|
||||
}
|
||||
@@ -216,13 +215,16 @@ func TestData(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
defer cancel()
|
||||
|
||||
comp := testComputation(t)
|
||||
svc := New(ctx, mglog.NewMock(), events, qp)
|
||||
|
||||
err := svc.InitComputation(ctx, testComputation(t))
|
||||
require.NoError(t, err)
|
||||
|
||||
svc := New(ctx, mglog.NewMock(), events, comp, qp)
|
||||
time.Sleep(300 * time.Millisecond)
|
||||
|
||||
if tc.err != ErrStateNotReady {
|
||||
_ = svc.Algo(ctx, alg)
|
||||
err = svc.Algo(ctx, alg)
|
||||
require.NoError(t, err)
|
||||
time.Sleep(300 * time.Millisecond)
|
||||
}
|
||||
err = svc.Data(ctx, tc.data)
|
||||
@@ -238,11 +240,6 @@ func TestData(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestResult(t *testing.T) {
|
||||
events := new(mocks.Service)
|
||||
|
||||
evCall := events.On("SendEvent", mock.Anything, mock.Anything, mock.Anything).Return(nil)
|
||||
defer evCall.Unset()
|
||||
|
||||
qp, err := quoteprovider.GetQuoteProvider()
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -285,6 +282,9 @@ func TestResult(t *testing.T) {
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
events := new(mocks.Service)
|
||||
events.EXPECT().SendEvent(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return()
|
||||
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
ctx := metadata.NewIncomingContext(context.Background(),
|
||||
metadata.Pairs(algorithm.AlgoTypeKey, "python", python.PyRuntimeKey, python.PyRuntime),
|
||||
@@ -323,12 +323,8 @@ func TestResult(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestAttestation(t *testing.T) {
|
||||
events := new(mocks.Service)
|
||||
qp := new(mocks2.QuoteProvider)
|
||||
|
||||
evCall := events.On("SendEvent", mock.Anything, mock.Anything, mock.Anything).Return(nil)
|
||||
defer evCall.Unset()
|
||||
|
||||
cases := []struct {
|
||||
name string
|
||||
reportData [ReportDataSize]byte
|
||||
@@ -350,6 +346,9 @@ func TestAttestation(t *testing.T) {
|
||||
}
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
events := new(mocks.Service)
|
||||
events.EXPECT().SendEvent(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return()
|
||||
|
||||
ctx := metadata.NewIncomingContext(context.Background(),
|
||||
metadata.Pairs(algorithm.AlgoTypeKey, "python", python.PyRuntimeKey, python.PyRuntime),
|
||||
)
|
||||
@@ -362,7 +361,7 @@ func TestAttestation(t *testing.T) {
|
||||
}
|
||||
defer getQuote.Unset()
|
||||
|
||||
svc := New(ctx, mglog.NewMock(), events, testComputation(t), qp)
|
||||
svc := New(ctx, mglog.NewMock(), events, qp)
|
||||
time.Sleep(300 * time.Millisecond)
|
||||
_, err := svc.Attestation(ctx, tc.reportData)
|
||||
assert.True(t, errors.Contains(err, tc.err), "expected %v, got %v", tc.err, err)
|
||||
@@ -397,10 +396,5 @@ func testComputation(t *testing.T) Computation {
|
||||
Datasets: []Dataset{{Hash: dataHash, UserKey: []byte("key"), Dataset: data, Filename: datasetFile}},
|
||||
Algorithm: Algorithm{Hash: algoHash, UserKey: []byte("key"), Algorithm: algo},
|
||||
ResultConsumers: []ResultConsumer{{UserKey: []byte("key")}},
|
||||
AgentConfig: AgentConfig{
|
||||
Port: "7002",
|
||||
LogLevel: "debug",
|
||||
AttestedTls: false,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
@@ -131,7 +131,7 @@ func TestManifestChecksum(t *testing.T) {
|
||||
"name": "Example Computation",
|
||||
"description": "This is an example computation"
|
||||
}`,
|
||||
expectedSum: "868825367c32c4b6d621d5d95e2890f233d8554df2348ab743aac2663a936f08",
|
||||
expectedSum: "a99683e4d22ba54cefa51aa49fb2e97a92b828c088395992ddff16a6236f3299",
|
||||
},
|
||||
{
|
||||
name: "Invalid JSON",
|
||||
|
||||
+57
-182
@@ -3,77 +3,68 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"log/slog"
|
||||
"os"
|
||||
"os/signal"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
mglog "github.com/absmach/magistrala/logger"
|
||||
"github.com/absmach/magistrala/pkg/prometheus"
|
||||
"github.com/cenkalti/backoff/v4"
|
||||
"github.com/google/go-sev-guest/abi"
|
||||
"github.com/caarlos0/env/v11"
|
||||
"github.com/google/go-sev-guest/client"
|
||||
"github.com/mdlayher/vsock"
|
||||
"github.com/ultravioletrs/cocos/agent"
|
||||
"github.com/ultravioletrs/cocos/agent/api"
|
||||
agentgrpc "github.com/ultravioletrs/cocos/agent/api/grpc"
|
||||
"github.com/ultravioletrs/cocos/agent/auth"
|
||||
"github.com/ultravioletrs/cocos/agent/cvms"
|
||||
cvmapi "github.com/ultravioletrs/cocos/agent/cvms/api/grpc"
|
||||
"github.com/ultravioletrs/cocos/agent/cvms/server"
|
||||
"github.com/ultravioletrs/cocos/agent/events"
|
||||
agentlogger "github.com/ultravioletrs/cocos/internal/logger"
|
||||
"github.com/ultravioletrs/cocos/internal/server"
|
||||
grpcserver "github.com/ultravioletrs/cocos/internal/server/grpc"
|
||||
ackvsock "github.com/ultravioletrs/cocos/internal/vsock"
|
||||
managerevents "github.com/ultravioletrs/cocos/manager/events"
|
||||
"github.com/ultravioletrs/cocos/manager/qemu"
|
||||
"github.com/ultravioletrs/cocos/pkg/attestation/quoteprovider"
|
||||
"golang.org/x/crypto/sha3"
|
||||
pkggrpc "github.com/ultravioletrs/cocos/pkg/clients/grpc"
|
||||
cvmgrpc "github.com/ultravioletrs/cocos/pkg/clients/grpc/cvm"
|
||||
"golang.org/x/sync/errgroup"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/reflection"
|
||||
)
|
||||
|
||||
const (
|
||||
svcName = "agent"
|
||||
defSvcGRPCPort = "7002"
|
||||
retryInterval = 5 * time.Second
|
||||
svcName = "agent"
|
||||
defSvcGRPCPort = "7002"
|
||||
retryInterval = 5 * time.Second
|
||||
envPrefixCVMGRPC = "AGENT_CVM_GRPC_"
|
||||
)
|
||||
|
||||
type config struct {
|
||||
LogLevel string `env:"AGENT_LOG_LEVEL" envDefault:"debug"`
|
||||
}
|
||||
|
||||
func main() {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
g, ctx := errgroup.WithContext(ctx)
|
||||
|
||||
cfg, err := readConfig()
|
||||
if err != nil {
|
||||
log.Fatalf("failed to read agent configuration from vsock %s", err.Error())
|
||||
var cfg config
|
||||
if err := env.Parse(&cfg); err != nil {
|
||||
log.Fatalf("failed to load %s configuration : %s", svcName, err)
|
||||
}
|
||||
|
||||
conn, err := dialVsock()
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
ackConn := ackvsock.NewAckWriter(conn)
|
||||
|
||||
var exitCode int
|
||||
defer mglog.ExitWithError(&exitCode)
|
||||
|
||||
var level slog.Level
|
||||
if err := level.UnmarshalText([]byte(cfg.AgentConfig.LogLevel)); err != nil {
|
||||
if err := level.UnmarshalText([]byte(cfg.LogLevel)); err != nil {
|
||||
log.Println(err)
|
||||
exitCode = 1
|
||||
return
|
||||
}
|
||||
|
||||
handler := agentlogger.NewProtoHandler(ackConn, &slog.HandlerOptions{Level: level}, cfg.ID)
|
||||
eventsLogsQueue := make(chan *cvms.ClientStreamMessage, 1000)
|
||||
|
||||
handler := agentlogger.NewProtoHandler(os.Stdout, &slog.HandlerOptions{Level: level}, eventsLogsQueue)
|
||||
logger := slog.New(handler)
|
||||
|
||||
eventSvc, err := events.New(svcName, cfg.ID, ackConn)
|
||||
eventSvc, err := events.New(svcName, eventsLogsQueue)
|
||||
if err != nil {
|
||||
logger.Error(fmt.Sprintf("failed to create events service %s", err.Error()))
|
||||
exitCode = 1
|
||||
@@ -87,64 +78,49 @@ func main() {
|
||||
return
|
||||
}
|
||||
|
||||
if err := verifyManifest(cfg, qp); err != nil {
|
||||
cvmGrpcConfig := pkggrpc.CVMClientConfig{}
|
||||
if err := env.ParseWithOptions(&cvmGrpcConfig, env.Options{Prefix: envPrefixCVMGRPC}); err != nil {
|
||||
logger.Error(fmt.Sprintf("failed to load %s gRPC client configuration : %s", svcName, err))
|
||||
exitCode = 1
|
||||
return
|
||||
}
|
||||
|
||||
cvmGRPCClient, cvmClient, err := cvmgrpc.NewCVMClient(cvmGrpcConfig)
|
||||
if err != nil {
|
||||
logger.Error(err.Error())
|
||||
exitCode = 1
|
||||
return
|
||||
}
|
||||
defer cvmGRPCClient.Close()
|
||||
|
||||
pc, err := cvmClient.Process(ctx)
|
||||
if err != nil {
|
||||
logger.Error(err.Error())
|
||||
exitCode = 1
|
||||
return
|
||||
}
|
||||
|
||||
setDefaultValues(&cfg)
|
||||
svc := newService(ctx, logger, eventSvc, qp)
|
||||
|
||||
svc := newService(ctx, logger, eventSvc, cfg, qp)
|
||||
|
||||
agentGrpcServerConfig := server.AgentConfig{
|
||||
ServerConfig: server.ServerConfig{
|
||||
BaseConfig: server.BaseConfig{
|
||||
Host: cfg.AgentConfig.Host,
|
||||
Port: cfg.AgentConfig.Port,
|
||||
CertFile: cfg.AgentConfig.CertFile,
|
||||
KeyFile: cfg.AgentConfig.KeyFile,
|
||||
ServerCAFile: cfg.AgentConfig.ServerCAFile,
|
||||
ClientCAFile: cfg.AgentConfig.ClientCAFile,
|
||||
},
|
||||
},
|
||||
AttestedTLS: cfg.AgentConfig.AttestedTls,
|
||||
}
|
||||
|
||||
registerAgentServiceServer := func(srv *grpc.Server) {
|
||||
reflection.Register(srv)
|
||||
agent.RegisterAgentServiceServer(srv, agentgrpc.NewServer(svc))
|
||||
}
|
||||
|
||||
authSvc, err := auth.New(cfg)
|
||||
if err != nil {
|
||||
logger.Error(fmt.Sprintf("failed to create auth service %s", err.Error()))
|
||||
exitCode = 1
|
||||
return
|
||||
}
|
||||
|
||||
gs := grpcserver.New(ctx, cancel, svcName, agentGrpcServerConfig, registerAgentServiceServer, logger, qp, authSvc)
|
||||
mc := cvmapi.NewClient(pc, svc, eventsLogsQueue, logger, server.NewServer(logger, svc))
|
||||
|
||||
g.Go(func() error {
|
||||
for {
|
||||
if _, err := io.Copy(io.Discard, conn); err != nil {
|
||||
log.Printf("vsock connection lost: %v, reconnecting...", err)
|
||||
conn.Close()
|
||||
conn, err = dialVsock()
|
||||
if err != nil {
|
||||
log.Fatal("failed to reconnect: ", err)
|
||||
}
|
||||
}
|
||||
time.Sleep(retryInterval)
|
||||
ch := make(chan os.Signal, 1)
|
||||
signal.Notify(ch, syscall.SIGINT, syscall.SIGTERM)
|
||||
defer signal.Stop(ch)
|
||||
|
||||
select {
|
||||
case <-ch:
|
||||
logger.Info("Received signal, shutting down...")
|
||||
cancel()
|
||||
return nil
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
}
|
||||
})
|
||||
|
||||
g.Go(func() error {
|
||||
return gs.Start()
|
||||
})
|
||||
|
||||
g.Go(func() error {
|
||||
return server.StopHandler(ctx, cancel, logger, svcName, gs)
|
||||
return mc.Process(ctx, cancel)
|
||||
})
|
||||
|
||||
if err := g.Wait(); err != nil {
|
||||
@@ -152,8 +128,8 @@ func main() {
|
||||
}
|
||||
}
|
||||
|
||||
func newService(ctx context.Context, logger *slog.Logger, eventSvc events.Service, cmp agent.Computation, qp client.QuoteProvider) agent.Service {
|
||||
svc := agent.New(ctx, logger, eventSvc, cmp, qp)
|
||||
func newService(ctx context.Context, logger *slog.Logger, eventSvc events.Service, qp client.QuoteProvider) agent.Service {
|
||||
svc := agent.New(ctx, logger, eventSvc, qp)
|
||||
|
||||
svc = api.LoggingMiddleware(svc, logger)
|
||||
counter, latency := prometheus.MakeMetrics(svcName, "api")
|
||||
@@ -161,104 +137,3 @@ func newService(ctx context.Context, logger *slog.Logger, eventSvc events.Servic
|
||||
|
||||
return svc
|
||||
}
|
||||
|
||||
func readConfig() (agent.Computation, error) {
|
||||
l, err := vsock.Listen(qemu.VsockConfigPort, nil)
|
||||
if err != nil {
|
||||
return agent.Computation{}, err
|
||||
}
|
||||
defer l.Close()
|
||||
|
||||
conn, err := l.Accept()
|
||||
if err != nil {
|
||||
return agent.Computation{}, err
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
var buffer []byte
|
||||
for {
|
||||
chunk := make([]byte, 1024)
|
||||
n, err := conn.Read(chunk)
|
||||
if err != nil {
|
||||
if err == io.EOF {
|
||||
break
|
||||
}
|
||||
return agent.Computation{}, err
|
||||
}
|
||||
buffer = append(buffer, chunk[:n]...)
|
||||
}
|
||||
|
||||
ac := agent.Computation{
|
||||
AgentConfig: agent.AgentConfig{},
|
||||
}
|
||||
if err := json.Unmarshal(buffer, &ac); err != nil {
|
||||
return agent.Computation{}, err
|
||||
}
|
||||
return ac, nil
|
||||
}
|
||||
|
||||
func setDefaultValues(cfg *agent.Computation) {
|
||||
if cfg.AgentConfig.LogLevel == "" {
|
||||
cfg.AgentConfig.LogLevel = "info"
|
||||
}
|
||||
if cfg.AgentConfig.Port == "" {
|
||||
cfg.AgentConfig.Port = defSvcGRPCPort
|
||||
}
|
||||
}
|
||||
|
||||
func isTEE() bool {
|
||||
_, err := os.Stat("/dev/sev-guest")
|
||||
return !os.IsNotExist(err)
|
||||
}
|
||||
|
||||
func dialVsock() (*vsock.Conn, error) {
|
||||
var conn *vsock.Conn
|
||||
var err error
|
||||
|
||||
err = backoff.Retry(func() error {
|
||||
conn, err = vsock.Dial(vsock.Host, managerevents.ManagerVsockPort, nil)
|
||||
if err == nil {
|
||||
log.Println("vsock connection established")
|
||||
return nil
|
||||
}
|
||||
log.Printf("vsock connection failed, retrying in %s... Error: %v", retryInterval, err)
|
||||
return err
|
||||
}, backoff.NewExponentialBackOff())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
func verifyManifest(cfg agent.Computation, qp client.QuoteProvider) error {
|
||||
if !isTEE() {
|
||||
return nil
|
||||
}
|
||||
|
||||
ar, err := qp.GetRawQuote(sha3.Sum512([]byte(cfg.ID)))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
arProto, err := abi.ReportCertsToProto(ar[:abi.ReportSize])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
cfgBytes, err := json.Marshal(cfg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
mcHash := sha3.Sum256(cfgBytes)
|
||||
|
||||
if arProto.Report.HostData == nil {
|
||||
return fmt.Errorf("manifest verification failed: HostData is nil")
|
||||
}
|
||||
if !bytes.Equal(arProto.Report.HostData, mcHash[:]) {
|
||||
return fmt.Errorf("manifest verification failed")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -1,94 +0,0 @@
|
||||
// Copyright (c) Ultraviolet
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log/slog"
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/mock"
|
||||
"github.com/ultravioletrs/cocos/agent"
|
||||
"github.com/ultravioletrs/cocos/agent/events/mocks"
|
||||
qpmocks "github.com/ultravioletrs/cocos/pkg/attestation/quoteprovider/mocks"
|
||||
)
|
||||
|
||||
func TestSetDefaultValues(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input agent.Computation
|
||||
expected agent.Computation
|
||||
}{
|
||||
{
|
||||
name: "Empty config",
|
||||
input: agent.Computation{
|
||||
AgentConfig: agent.AgentConfig{},
|
||||
},
|
||||
expected: agent.Computation{
|
||||
AgentConfig: agent.AgentConfig{
|
||||
LogLevel: "info",
|
||||
Port: "7002",
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Partial config",
|
||||
input: agent.Computation{
|
||||
AgentConfig: agent.AgentConfig{
|
||||
LogLevel: "debug",
|
||||
},
|
||||
},
|
||||
expected: agent.Computation{
|
||||
AgentConfig: agent.AgentConfig{
|
||||
LogLevel: "debug",
|
||||
Port: "7002",
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
setDefaultValues(&tt.input)
|
||||
assert.Equal(t, tt.expected, tt.input)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewService(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
|
||||
eventSvc := new(mocks.Service)
|
||||
eventSvc.On("SendEvent", mock.Anything, mock.Anything, mock.Anything).Return(nil)
|
||||
cmp := agent.Computation{
|
||||
ID: "test-computation",
|
||||
AgentConfig: agent.AgentConfig{
|
||||
LogLevel: "info",
|
||||
Port: "7002",
|
||||
},
|
||||
}
|
||||
qp := new(qpmocks.QuoteProvider)
|
||||
|
||||
svc := newService(ctx, logger, eventSvc, cmp, qp)
|
||||
|
||||
assert.NotNil(t, svc)
|
||||
}
|
||||
|
||||
func TestVerifyManifest(t *testing.T) {
|
||||
cfg := agent.Computation{
|
||||
ID: "test-computation",
|
||||
AgentConfig: agent.AgentConfig{
|
||||
LogLevel: "info",
|
||||
Port: "7002",
|
||||
},
|
||||
}
|
||||
|
||||
mockQP := new(qpmocks.QuoteProvider)
|
||||
mockQP.On("GetRawQuote", mock.Anything).Return([]byte{}, nil)
|
||||
|
||||
err := verifyManifest(cfg, mockQP)
|
||||
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
+1
-1
@@ -92,7 +92,7 @@ func main() {
|
||||
args := qemuCfg.ConstructQemuArgs()
|
||||
logger.Info(strings.Join(args, " "))
|
||||
|
||||
managerGRPCConfig := pkggrpc.ManagerClientConfig{}
|
||||
managerGRPCConfig := pkggrpc.CVMClientConfig{}
|
||||
if err := env.ParseWithOptions(&managerGRPCConfig, env.Options{Prefix: envPrefixGRPC}); err != nil {
|
||||
logger.Error(fmt.Sprintf("failed to load %s gRPC client configuration : %s", svcName, err))
|
||||
exitCode = 1
|
||||
|
||||
@@ -7,8 +7,9 @@ import (
|
||||
"io"
|
||||
"log/slog"
|
||||
|
||||
"github.com/ultravioletrs/cocos/agent/cvms"
|
||||
"github.com/ultravioletrs/cocos/agent/events"
|
||||
"google.golang.org/protobuf/proto"
|
||||
"google.golang.org/protobuf/encoding/protojson"
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
)
|
||||
|
||||
@@ -18,16 +19,17 @@ type handler struct {
|
||||
opts slog.HandlerOptions
|
||||
w io.Writer
|
||||
cmpID string
|
||||
queue chan *cvms.ClientStreamMessage
|
||||
}
|
||||
|
||||
func NewProtoHandler(conn io.Writer, opts *slog.HandlerOptions, cmpID string) slog.Handler {
|
||||
func NewProtoHandler(conn io.Writer, opts *slog.HandlerOptions, queue chan *cvms.ClientStreamMessage) slog.Handler {
|
||||
if opts == nil {
|
||||
opts = &slog.HandlerOptions{}
|
||||
}
|
||||
h := &handler{
|
||||
opts: *opts,
|
||||
w: conn,
|
||||
cmpID: cmpID,
|
||||
queue: queue,
|
||||
}
|
||||
|
||||
return h
|
||||
@@ -69,7 +71,18 @@ func (h *handler) Handle(_ context.Context, r slog.Record) error {
|
||||
},
|
||||
}
|
||||
|
||||
b, err := proto.Marshal(&agentLog)
|
||||
h.queue <- &cvms.ClientStreamMessage{
|
||||
Message: &cvms.ClientStreamMessage_AgentLog{
|
||||
AgentLog: &cvms.AgentLog{
|
||||
Timestamp: timestamp,
|
||||
Message: chunk,
|
||||
Level: level,
|
||||
ComputationId: h.cmpID,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
b, err := protojson.Marshal(&agentLog)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -78,6 +91,11 @@ func (h *handler) Handle(_ context.Context, r slog.Record) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = h.w.Write([]byte("\n"))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -88,7 +106,8 @@ func (h *handler) WithAttrs(attrs []slog.Attr) slog.Handler {
|
||||
}
|
||||
|
||||
func (h *handler) WithGroup(name string) slog.Handler {
|
||||
panic("unimplemented")
|
||||
h.cmpID = name
|
||||
return h
|
||||
}
|
||||
|
||||
func (h *handler) Close() error {
|
||||
|
||||
@@ -11,6 +11,7 @@ import (
|
||||
|
||||
"github.com/absmach/magistrala/pkg/errors"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/ultravioletrs/cocos/agent/cvms"
|
||||
)
|
||||
|
||||
type failedWriter struct{}
|
||||
@@ -21,14 +22,14 @@ func (f *failedWriter) Write(p []byte) (n int, err error) {
|
||||
|
||||
// TestNewProtoHandler tests the initialization of the ProtoHandler.
|
||||
func TestNewProtoHandler(t *testing.T) {
|
||||
handler := NewProtoHandler(io.Discard, nil, "testCmpID")
|
||||
handler := NewProtoHandler(io.Discard, nil, make(chan *cvms.ClientStreamMessage))
|
||||
|
||||
assert.NotNil(t, handler, "Handler should not be nil")
|
||||
}
|
||||
|
||||
// TestHandleMessageSuccess tests the handling of a message when the write succeeds.
|
||||
func TestHandleMessageSuccess(t *testing.T) {
|
||||
handler := NewProtoHandler(io.Discard, nil, "testCmpID")
|
||||
handler := NewProtoHandler(io.Discard, nil, make(chan *cvms.ClientStreamMessage, 1))
|
||||
record := slog.Record{
|
||||
Time: time.Now(),
|
||||
Message: "Test message",
|
||||
@@ -42,7 +43,7 @@ func TestHandleMessageSuccess(t *testing.T) {
|
||||
|
||||
// TestHandleMessageFailure tests the caching mechanism when the write fails.
|
||||
func TestHandleMessageFailure(t *testing.T) {
|
||||
protohandler := NewProtoHandler(&failedWriter{}, nil, "testCmpID")
|
||||
protohandler := NewProtoHandler(&failedWriter{}, nil, make(chan *cvms.ClientStreamMessage, 1))
|
||||
record := slog.Record{
|
||||
Time: time.Now(),
|
||||
Message: "Test message",
|
||||
@@ -56,7 +57,7 @@ func TestHandleMessageFailure(t *testing.T) {
|
||||
|
||||
// TestEnabled tests that the handler enables logging based on level.
|
||||
func TestEnabled(t *testing.T) {
|
||||
handler := NewProtoHandler(io.Discard, nil, "testCmpID")
|
||||
handler := NewProtoHandler(io.Discard, nil, make(chan *cvms.ClientStreamMessage, 1))
|
||||
|
||||
assert.True(t, handler.Enabled(context.Background(), slog.LevelInfo), "Logging should be enabled for LevelInfo")
|
||||
assert.False(t, handler.Enabled(context.Background(), slog.LevelDebug), "Logging should be disabled for LevelDebug by default")
|
||||
@@ -66,7 +67,7 @@ func TestEnabled(t *testing.T) {
|
||||
func TestCloseStopsRetry(t *testing.T) {
|
||||
mockWriter := io.Discard
|
||||
|
||||
handler := NewProtoHandler(mockWriter, nil, "testCmpID").(*handler)
|
||||
handler := NewProtoHandler(mockWriter, nil, make(chan *cvms.ClientStreamMessage, 1)).(*handler)
|
||||
|
||||
time.Sleep(2 * time.Second)
|
||||
err := handler.Close()
|
||||
|
||||
+213
-461
File diff suppressed because it is too large
Load Diff
@@ -4,7 +4,7 @@
|
||||
// Code generated by protoc-gen-go-grpc. DO NOT EDIT.
|
||||
// versions:
|
||||
// - protoc-gen-go-grpc v1.5.1
|
||||
// - protoc v5.28.1
|
||||
// - protoc v5.29.0
|
||||
// source: manager/manager.proto
|
||||
|
||||
package manager
|
||||
|
||||
@@ -166,16 +166,6 @@ func (ms *managerService) Run(ctx context.Context, c *ComputationRunReq) (string
|
||||
ID: c.Id,
|
||||
Name: c.Name,
|
||||
Description: c.Description,
|
||||
AgentConfig: agent.AgentConfig{
|
||||
Port: c.AgentConfig.Port,
|
||||
Host: c.AgentConfig.Host,
|
||||
KeyFile: c.AgentConfig.KeyFile,
|
||||
CertFile: c.AgentConfig.CertFile,
|
||||
ServerCAFile: c.AgentConfig.ServerCaFile,
|
||||
ClientCAFile: c.AgentConfig.ClientCaFile,
|
||||
LogLevel: c.AgentConfig.LogLevel,
|
||||
AttestedTls: c.AgentConfig.AttestedTls,
|
||||
},
|
||||
}
|
||||
if len(c.Algorithm.Hash) != hashLength {
|
||||
ms.publishEvent(manager.VmProvision.String(), c.Id, agent.Failed.String(), json.RawMessage{})
|
||||
|
||||
@@ -93,3 +93,10 @@ packages:
|
||||
dir: "{{.InterfaceDir}}/mocks"
|
||||
filename: "sdk.go"
|
||||
mockname: "{{.InterfaceName}}"
|
||||
github.com/ultravioletrs/cocos/agent/cvms/server:
|
||||
interfaces:
|
||||
AgentServerProvider:
|
||||
config:
|
||||
dir: "{{.InterfaceDir}}/mocks"
|
||||
filename: "server.go"
|
||||
mockname: "{{.InterfaceName}}"
|
||||
|
||||
@@ -68,7 +68,7 @@ type AgentClientConfig struct {
|
||||
AttestedTLS bool `env:"ATTESTED_TLS" envDefault:"false"`
|
||||
}
|
||||
|
||||
type ManagerClientConfig struct {
|
||||
type CVMClientConfig struct {
|
||||
BaseConfig
|
||||
}
|
||||
|
||||
@@ -80,7 +80,7 @@ func (a AgentClientConfig) GetBaseConfig() BaseConfig {
|
||||
return a.BaseConfig
|
||||
}
|
||||
|
||||
func (a ManagerClientConfig) GetBaseConfig() BaseConfig {
|
||||
func (a CVMClientConfig) GetBaseConfig() BaseConfig {
|
||||
return a.BaseConfig
|
||||
}
|
||||
|
||||
|
||||
@@ -0,0 +1,18 @@
|
||||
// Copyright (c) Ultraviolet
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
package cvm
|
||||
|
||||
import (
|
||||
"github.com/ultravioletrs/cocos/agent/cvms"
|
||||
"github.com/ultravioletrs/cocos/pkg/clients/grpc"
|
||||
)
|
||||
|
||||
// NewManagerClient creates new manager gRPC client instance.
|
||||
func NewCVMClient(cfg grpc.CVMClientConfig) (grpc.Client, cvms.ServiceClient, error) {
|
||||
client, err := grpc.NewClient(cfg)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
return client, cvms.NewServiceClient(client.Connection()), nil
|
||||
}
|
||||
@@ -8,7 +8,7 @@ import (
|
||||
)
|
||||
|
||||
// NewManagerClient creates new manager gRPC client instance.
|
||||
func NewManagerClient(cfg grpc.ManagerClientConfig) (grpc.Client, manager.ManagerServiceClient, error) {
|
||||
func NewManagerClient(cfg grpc.CVMClientConfig) (grpc.Client, manager.ManagerServiceClient, error) {
|
||||
client, err := grpc.NewClient(cfg)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
|
||||
@@ -13,12 +13,12 @@ import (
|
||||
func TestNewManagerClient(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
cfg grpc.ManagerClientConfig
|
||||
cfg grpc.CVMClientConfig
|
||||
err error
|
||||
}{
|
||||
{
|
||||
name: "Valid config",
|
||||
cfg: grpc.ManagerClientConfig{
|
||||
cfg: grpc.CVMClientConfig{
|
||||
BaseConfig: grpc.BaseConfig{
|
||||
URL: "localhost:7001",
|
||||
},
|
||||
|
||||
@@ -13,18 +13,18 @@ import (
|
||||
|
||||
mglog "github.com/absmach/magistrala/logger"
|
||||
"github.com/caarlos0/env/v11"
|
||||
"github.com/ultravioletrs/cocos/agent/cvms"
|
||||
cvmgrpc "github.com/ultravioletrs/cocos/agent/cvms/api/grpc"
|
||||
"github.com/ultravioletrs/cocos/internal"
|
||||
"github.com/ultravioletrs/cocos/internal/server"
|
||||
grpcserver "github.com/ultravioletrs/cocos/internal/server/grpc"
|
||||
"github.com/ultravioletrs/cocos/manager"
|
||||
managergrpc "github.com/ultravioletrs/cocos/manager/api/grpc"
|
||||
"golang.org/x/sync/errgroup"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/credentials"
|
||||
"google.golang.org/grpc/reflection"
|
||||
)
|
||||
|
||||
var _ managergrpc.Service = (*svc)(nil)
|
||||
var _ cvmgrpc.Service = (*svc)(nil)
|
||||
|
||||
const (
|
||||
svcName = "computations_test_server"
|
||||
@@ -42,7 +42,7 @@ type svc struct {
|
||||
logger *slog.Logger
|
||||
}
|
||||
|
||||
func (s *svc) Run(ctx context.Context, ipAddress string, sendMessage managergrpc.SendFunc, authInfo credentials.AuthInfo) {
|
||||
func (s *svc) Run(ctx context.Context, ipAddress string, sendMessage cvmgrpc.SendFunc, authInfo credentials.AuthInfo) {
|
||||
s.logger.Debug(fmt.Sprintf("received who am on ip address %s", ipAddress))
|
||||
|
||||
pubKey, err := os.ReadFile(pubKeyFile)
|
||||
@@ -52,7 +52,7 @@ func (s *svc) Run(ctx context.Context, ipAddress string, sendMessage managergrpc
|
||||
}
|
||||
pubPem, _ := pem.Decode(pubKey)
|
||||
|
||||
var datasets []*manager.Dataset
|
||||
var datasets []*cvms.Dataset
|
||||
for _, dataPath := range dataPaths {
|
||||
if _, err := os.Stat(dataPath); os.IsNotExist(err) {
|
||||
s.logger.Error(fmt.Sprintf("data file does not exist: %s", dataPath))
|
||||
@@ -64,7 +64,7 @@ func (s *svc) Run(ctx context.Context, ipAddress string, sendMessage managergrpc
|
||||
return
|
||||
}
|
||||
|
||||
datasets = append(datasets, &manager.Dataset{Hash: dataHash[:], UserKey: pubPem.Bytes})
|
||||
datasets = append(datasets, &cvms.Dataset{Hash: dataHash[:], UserKey: pubPem.Bytes})
|
||||
}
|
||||
|
||||
algoHash, err := internal.Checksum(algoPath)
|
||||
@@ -73,16 +73,16 @@ func (s *svc) Run(ctx context.Context, ipAddress string, sendMessage managergrpc
|
||||
return
|
||||
}
|
||||
|
||||
if err := sendMessage(&manager.ServerStreamMessage{
|
||||
Message: &manager.ServerStreamMessage_RunReq{
|
||||
RunReq: &manager.ComputationRunReq{
|
||||
if err := sendMessage(&cvms.ServerStreamMessage{
|
||||
Message: &cvms.ServerStreamMessage_RunReq{
|
||||
RunReq: &cvms.ComputationRunReq{
|
||||
Id: "1",
|
||||
Name: "sample computation",
|
||||
Description: "sample descrption",
|
||||
Datasets: datasets,
|
||||
Algorithm: &manager.Algorithm{Hash: algoHash[:], UserKey: pubPem.Bytes},
|
||||
ResultConsumers: []*manager.ResultConsumer{{UserKey: pubPem.Bytes}},
|
||||
AgentConfig: &manager.AgentConfig{
|
||||
Algorithm: &cvms.Algorithm{Hash: algoHash[:], UserKey: pubPem.Bytes},
|
||||
ResultConsumers: []*cvms.ResultConsumer{{UserKey: pubPem.Bytes}},
|
||||
AgentConfig: &cvms.AgentConfig{
|
||||
Port: "7002",
|
||||
LogLevel: "debug",
|
||||
AttestedTls: attestedTLS,
|
||||
@@ -113,7 +113,7 @@ func main() {
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
g, ctx := errgroup.WithContext(ctx)
|
||||
incomingChan := make(chan *manager.ClientStreamMessage)
|
||||
incomingChan := make(chan *cvms.ClientStreamMessage)
|
||||
|
||||
logger, err := mglog.New(os.Stdout, "debug")
|
||||
if err != nil {
|
||||
@@ -128,7 +128,7 @@ func main() {
|
||||
|
||||
registerAgentServiceServer := func(srv *grpc.Server) {
|
||||
reflection.Register(srv)
|
||||
manager.RegisterManagerServiceServer(srv, managergrpc.NewServer(incomingChan, &svc{logger: logger}))
|
||||
cvms.RegisterServiceServer(srv, cvmgrpc.NewServer(incomingChan, &svc{logger: logger}))
|
||||
}
|
||||
grpcServerConfig := server.ServerConfig{
|
||||
BaseConfig: server.BaseConfig{
|
||||
@@ -1,127 +0,0 @@
|
||||
// Copyright (c) Ultraviolet
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
// Simplified script to pass configs to agent without manager and read logs and events for manager.
|
||||
// This tool is meant for testing purposes.
|
||||
package main
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"encoding/pem"
|
||||
"log"
|
||||
"net"
|
||||
"os"
|
||||
"strconv"
|
||||
|
||||
"github.com/mdlayher/vsock"
|
||||
"github.com/ultravioletrs/cocos/agent"
|
||||
"github.com/ultravioletrs/cocos/internal"
|
||||
internalvsock "github.com/ultravioletrs/cocos/internal/vsock"
|
||||
"github.com/ultravioletrs/cocos/manager"
|
||||
"github.com/ultravioletrs/cocos/manager/events"
|
||||
"github.com/ultravioletrs/cocos/manager/qemu"
|
||||
)
|
||||
|
||||
const (
|
||||
managerVsockPort = events.ManagerVsockPort
|
||||
vsockConfigPort = qemu.VsockConfigPort
|
||||
)
|
||||
|
||||
func main() {
|
||||
if len(os.Args) < 5 {
|
||||
log.Fatalf("usage: %s <data-path> <algo-path> <public-key-path> <attested-tls-bool>", os.Args[0])
|
||||
}
|
||||
dataPath := os.Args[1]
|
||||
algoPath := os.Args[2]
|
||||
pubKeyFile := os.Args[3]
|
||||
attestedTLSParam, err := strconv.ParseBool(os.Args[4])
|
||||
if err != nil {
|
||||
log.Fatalf("usage: %s <data-path> <algo-path> <public-key-path> <attested-tls-bool>, <attested-tls-bool> must be a bool value", os.Args[0])
|
||||
}
|
||||
attestedTLS := attestedTLSParam
|
||||
|
||||
pubKey, err := os.ReadFile(pubKeyFile)
|
||||
if err != nil {
|
||||
log.Fatalf("failed to read public key file: %s", err)
|
||||
}
|
||||
pubPem, _ := pem.Decode(pubKey)
|
||||
algoHash, err := internal.Checksum(algoPath)
|
||||
if err != nil {
|
||||
log.Fatalf("failed to calculate checksum: %s", err)
|
||||
}
|
||||
dataHash, err := internal.Checksum(dataPath)
|
||||
if err != nil {
|
||||
log.Fatalf("failed to calculate checksum: %s", err)
|
||||
}
|
||||
|
||||
ac := agent.Computation{
|
||||
ID: "123",
|
||||
Datasets: agent.Datasets{agent.Dataset{Hash: [32]byte(dataHash), UserKey: pubPem.Bytes}},
|
||||
Algorithm: agent.Algorithm{Hash: [32]byte(algoHash), UserKey: pubPem.Bytes},
|
||||
ResultConsumers: []agent.ResultConsumer{{UserKey: pubPem.Bytes}},
|
||||
AgentConfig: agent.AgentConfig{
|
||||
LogLevel: "debug",
|
||||
Port: "7002",
|
||||
AttestedTls: attestedTLS,
|
||||
},
|
||||
}
|
||||
if err := sendAgentConfig(3, ac); err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
listener, err := vsock.Listen(managerVsockPort, nil)
|
||||
if err != nil {
|
||||
log.Fatalf("failed to listen on vsock: %s", err)
|
||||
}
|
||||
defer listener.Close()
|
||||
|
||||
log.Printf("Listening on vsock port %d", managerVsockPort)
|
||||
|
||||
for {
|
||||
conn, err := listener.Accept()
|
||||
if err != nil {
|
||||
log.Printf("failed to accept connection: %s", err)
|
||||
continue
|
||||
}
|
||||
|
||||
go handleConnection(conn)
|
||||
}
|
||||
}
|
||||
|
||||
func sendAgentConfig(cid uint32, ac agent.Computation) error {
|
||||
conn, err := vsock.Dial(cid, qemu.VsockConfigPort, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer conn.Close()
|
||||
payload, err := json.Marshal(ac)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var ac2 agent.Computation
|
||||
if err := json.Unmarshal(payload, &ac2); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := conn.Write(payload); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func handleConnection(conn net.Conn) {
|
||||
defer conn.Close()
|
||||
|
||||
ackReader := internalvsock.NewAckReader(conn)
|
||||
|
||||
for {
|
||||
var message manager.ClientStreamMessage
|
||||
err := ackReader.ReadProto(&message)
|
||||
if err != nil {
|
||||
log.Printf("Error reading message: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
log.Printf("Received message: %s", message.String())
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user