NOISSUE - Refactor manager events and detangle service (#287)

* extract events service

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* major refactor and detangling

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* small fixes

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* handle tests better

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* fix lint

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* fix race condition

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* fix race

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* use plain interface

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* move mutex

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

---------

Signed-off-by: Sammy Oina <sammyoina@gmail.com>
This commit is contained in:
Sammy Kerata Oina
2024-10-30 18:07:54 +03:00
committed by GitHub
parent 20e7ea76e0
commit fad3182638
52 changed files with 1190 additions and 523 deletions
+2 -1
View File
@@ -35,7 +35,8 @@ $(BACKEND_INFO):
protoc:
protoc -I. --go_out=. --go_opt=paths=source_relative --go-grpc_out=. --go-grpc_opt=paths=source_relative agent/agent.proto
protoc -I. --go_out=./pkg --go_opt=paths=source_relative --go-grpc_out=./pkg --go-grpc_opt=paths=source_relative manager/manager.proto
protoc -I. --go_out=. --go_opt=paths=source_relative --go-grpc_out=. --go-grpc_opt=paths=source_relative manager/manager.proto
protoc -I. --go_out=. --go_opt=paths=source_relative --go-grpc_out=. --go-grpc_opt=paths=source_relative agent/events/events.proto
mocks:
go generate ./...
+3 -2
View File
@@ -9,6 +9,7 @@ import (
"os/exec"
"github.com/ultravioletrs/cocos/agent/algorithm"
"github.com/ultravioletrs/cocos/agent/algorithm/logging"
"github.com/ultravioletrs/cocos/agent/events"
)
@@ -24,8 +25,8 @@ type binary struct {
func NewAlgorithm(logger *slog.Logger, eventsSvc events.Service, algoFile string, args []string) algorithm.Algorithm {
return &binary{
algoFile: algoFile,
stderr: &algorithm.Stderr{Logger: logger, EventSvc: eventsSvc},
stdout: &algorithm.Stdout{Logger: logger},
stderr: &logging.Stderr{Logger: logger, EventSvc: eventsSvc},
stdout: &logging.Stdout{Logger: logger},
args: args,
}
}
+3 -3
View File
@@ -8,7 +8,7 @@ import (
"os"
"testing"
"github.com/ultravioletrs/cocos/agent/algorithm"
"github.com/ultravioletrs/cocos/agent/algorithm/logging"
"github.com/ultravioletrs/cocos/agent/events/mocks"
)
@@ -39,11 +39,11 @@ func TestNewAlgorithm(t *testing.T) {
}
}
if _, ok := b.stderr.(*algorithm.Stderr); !ok {
if _, ok := b.stderr.(*logging.Stderr); !ok {
t.Errorf("Expected stderr to be *algorithm.Stderr")
}
if _, ok := b.stdout.(*algorithm.Stdout); !ok {
if _, ok := b.stdout.(*logging.Stdout); !ok {
t.Errorf("Expected stdout to be *algorithm.Stdout")
}
}
+3 -2
View File
@@ -16,6 +16,7 @@ import (
"github.com/docker/docker/api/types/mount"
"github.com/docker/docker/client"
"github.com/ultravioletrs/cocos/agent/algorithm"
"github.com/ultravioletrs/cocos/agent/algorithm/logging"
"github.com/ultravioletrs/cocos/agent/events"
)
@@ -38,8 +39,8 @@ func NewAlgorithm(logger *slog.Logger, eventsSvc events.Service, algoFile string
d := &docker{
algoFile: algoFile,
logger: logger,
stderr: &algorithm.Stderr{Logger: logger, EventSvc: eventsSvc},
stdout: &algorithm.Stdout{Logger: logger},
stderr: &logging.Stderr{Logger: logger, EventSvc: eventsSvc},
stdout: &logging.Stdout{Logger: logger},
}
return d
+3 -3
View File
@@ -8,7 +8,7 @@ import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/ultravioletrs/cocos/agent/algorithm"
"github.com/ultravioletrs/cocos/agent/algorithm/logging"
"github.com/ultravioletrs/cocos/agent/events/mocks"
)
@@ -24,6 +24,6 @@ func TestNewAlgorithm(t *testing.T) {
assert.True(t, ok, "NewAlgorithm should return a *docker")
assert.Equal(t, algoFile, d.algoFile, "algoFile should be set correctly")
assert.NotNil(t, d.logger, "logger should be set")
assert.IsType(t, &algorithm.Stderr{}, d.stderr, "stderr should be of type *algorithm.Stderr")
assert.IsType(t, &algorithm.Stdout{}, d.stdout, "stdout should be of type *algorithm.Stdout")
assert.IsType(t, &logging.Stderr{}, d.stderr, "stderr should be of type *algorithm.Stderr")
assert.IsType(t, &logging.Stdout{}, d.stdout, "stdout should be of type *algorithm.Stdout")
}
@@ -1,6 +1,6 @@
// Copyright (c) Ultraviolet
// SPDX-License-Identifier: Apache-2.0
package algorithm
package logging
import (
"bytes"
@@ -1,6 +1,6 @@
// Copyright (c) Ultraviolet
// SPDX-License-Identifier: Apache-2.0
package algorithm
package logging
import (
"strings"
+3 -2
View File
@@ -12,6 +12,7 @@ import (
"path/filepath"
"github.com/ultravioletrs/cocos/agent/algorithm"
"github.com/ultravioletrs/cocos/agent/algorithm/logging"
"github.com/ultravioletrs/cocos/agent/events"
"google.golang.org/grpc/metadata"
)
@@ -43,8 +44,8 @@ type python struct {
func NewAlgorithm(logger *slog.Logger, eventsSvc events.Service, runtime, requirementsFile, algoFile string, args []string) algorithm.Algorithm {
p := &python{
algoFile: algoFile,
stderr: &algorithm.Stderr{Logger: logger, EventSvc: eventsSvc},
stdout: &algorithm.Stdout{Logger: logger},
stderr: &logging.Stderr{Logger: logger, EventSvc: eventsSvc},
stdout: &logging.Stdout{Logger: logger},
requirementsFile: requirementsFile,
args: args,
}
+5 -5
View File
@@ -12,7 +12,7 @@ import (
"strings"
"testing"
"github.com/ultravioletrs/cocos/agent/algorithm"
"github.com/ultravioletrs/cocos/agent/algorithm/logging"
"github.com/ultravioletrs/cocos/agent/events/mocks"
"google.golang.org/grpc/metadata"
)
@@ -90,8 +90,8 @@ func TestRun(t *testing.T) {
algo := &python{
algoFile: scriptPath,
stderr: io.MultiWriter(&stderr, &algorithm.Stderr{Logger: slog.Default(), EventSvc: eventsSvc}),
stdout: io.MultiWriter(&stdout, &algorithm.Stdout{Logger: slog.Default()}),
stderr: io.MultiWriter(&stderr, &logging.Stderr{Logger: slog.Default(), EventSvc: eventsSvc}),
stdout: io.MultiWriter(&stdout, &logging.Stdout{Logger: slog.Default()}),
runtime: "python3",
}
@@ -132,8 +132,8 @@ func TestRunWithRequirements(t *testing.T) {
algo := &python{
algoFile: scriptPath,
requirementsFile: requirementsPath,
stderr: io.MultiWriter(&stderr, &algorithm.Stderr{Logger: slog.Default(), EventSvc: eventsSvc}),
stdout: io.MultiWriter(&stdout, &algorithm.Stdout{Logger: slog.Default()}),
stderr: io.MultiWriter(&stderr, &logging.Stderr{Logger: slog.Default(), EventSvc: eventsSvc}),
stdout: io.MultiWriter(&stdout, &logging.Stdout{Logger: slog.Default()}),
runtime: "python3",
}
+3 -2
View File
@@ -9,6 +9,7 @@ import (
"os/exec"
"github.com/ultravioletrs/cocos/agent/algorithm"
"github.com/ultravioletrs/cocos/agent/algorithm/logging"
"github.com/ultravioletrs/cocos/agent/events"
)
@@ -28,8 +29,8 @@ type wasm struct {
func NewAlgorithm(logger *slog.Logger, eventsSvc events.Service, algoFile string, args []string) algorithm.Algorithm {
return &wasm{
algoFile: algoFile,
stderr: &algorithm.Stderr{Logger: logger, EventSvc: eventsSvc},
stdout: &algorithm.Stdout{Logger: logger},
stderr: &logging.Stderr{Logger: logger, EventSvc: eventsSvc},
stdout: &logging.Stdout{Logger: logger},
args: args,
}
}
+3 -3
View File
@@ -8,7 +8,7 @@ import (
"os/exec"
"testing"
"github.com/ultravioletrs/cocos/agent/algorithm"
"github.com/ultravioletrs/cocos/agent/algorithm/logging"
"github.com/ultravioletrs/cocos/agent/events/mocks"
)
@@ -33,12 +33,12 @@ func TestNewAlgorithm(t *testing.T) {
t.Errorf("Expected %d args, got %d", len(args), len(w.args))
}
_, ok = w.stderr.(*algorithm.Stderr)
_, ok = w.stderr.(*logging.Stderr)
if !ok {
t.Errorf("Expected stderr to be *algorithm.Stderr")
}
_, ok = w.stdout.(*algorithm.Stdout)
_, ok = w.stdout.(*logging.Stdout)
if !ok {
t.Errorf("Expected stdout to be *algorithm.Stdout")
}
+1 -11
View File
@@ -8,7 +8,6 @@ import (
"sync"
"time"
"github.com/ultravioletrs/cocos/pkg/manager"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/types/known/timestamppb"
)
@@ -24,15 +23,6 @@ type service struct {
stopRetry chan struct{}
}
type AgentEvent struct {
EventType string `json:"event_type"`
Timestamp time.Time `json:"timestamp"`
ComputationID string `json:"computation_id,omitempty"`
Details json.RawMessage `json:"details,omitempty"`
Originator string `json:"originator"`
Status string `json:"status,omitempty"`
}
//go:generate mockery --name Service --output=./mocks --filename events.go --quiet --note "Copyright (c) Ultraviolet \n // SPDX-License-Identifier: Apache-2.0"
type Service interface {
SendEvent(event, status string, details json.RawMessage) error
@@ -54,7 +44,7 @@ func New(svc, computationID string, conn io.Writer) (Service, error) {
}
func (s *service) SendEvent(event, status string, details json.RawMessage) error {
body := manager.ClientStreamMessage{Message: &manager.ClientStreamMessage_AgentEvent{AgentEvent: &manager.AgentEvent{
body := EventsLogs{Message: &EventsLogs_AgentEvent{AgentEvent: &AgentEvent{
EventType: event,
Timestamp: timestamppb.Now(),
ComputationId: s.computationID,
+405
View File
@@ -0,0 +1,405 @@
// Copyright (c) Ultraviolet
// SPDX-License-Identifier: Apache-2.0
// Code generated by protoc-gen-go. DO NOT EDIT.
// versions:
// protoc-gen-go v1.34.2
// protoc v5.28.1
// source: agent/events/events.proto
package events
import (
protoreflect "google.golang.org/protobuf/reflect/protoreflect"
protoimpl "google.golang.org/protobuf/runtime/protoimpl"
timestamppb "google.golang.org/protobuf/types/known/timestamppb"
reflect "reflect"
sync "sync"
)
const (
// Verify that this generated code is sufficiently up-to-date.
_ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion)
// Verify that runtime/protoimpl is sufficiently up-to-date.
_ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20)
)
type AgentEvent struct {
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
EventType string `protobuf:"bytes,1,opt,name=event_type,json=eventType,proto3" json:"event_type,omitempty"`
Timestamp *timestamppb.Timestamp `protobuf:"bytes,2,opt,name=timestamp,proto3" json:"timestamp,omitempty"`
ComputationId string `protobuf:"bytes,3,opt,name=computation_id,json=computationId,proto3" json:"computation_id,omitempty"`
Details []byte `protobuf:"bytes,4,opt,name=details,proto3" json:"details,omitempty"`
Originator string `protobuf:"bytes,5,opt,name=originator,proto3" json:"originator,omitempty"`
Status string `protobuf:"bytes,6,opt,name=status,proto3" json:"status,omitempty"`
}
func (x *AgentEvent) Reset() {
*x = AgentEvent{}
if protoimpl.UnsafeEnabled {
mi := &file_agent_events_events_proto_msgTypes[0]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
}
func (x *AgentEvent) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*AgentEvent) ProtoMessage() {}
func (x *AgentEvent) ProtoReflect() protoreflect.Message {
mi := &file_agent_events_events_proto_msgTypes[0]
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use AgentEvent.ProtoReflect.Descriptor instead.
func (*AgentEvent) Descriptor() ([]byte, []int) {
return file_agent_events_events_proto_rawDescGZIP(), []int{0}
}
func (x *AgentEvent) GetEventType() string {
if x != nil {
return x.EventType
}
return ""
}
func (x *AgentEvent) GetTimestamp() *timestamppb.Timestamp {
if x != nil {
return x.Timestamp
}
return nil
}
func (x *AgentEvent) GetComputationId() string {
if x != nil {
return x.ComputationId
}
return ""
}
func (x *AgentEvent) GetDetails() []byte {
if x != nil {
return x.Details
}
return nil
}
func (x *AgentEvent) GetOriginator() string {
if x != nil {
return x.Originator
}
return ""
}
func (x *AgentEvent) GetStatus() string {
if x != nil {
return x.Status
}
return ""
}
type AgentLog struct {
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
Message string `protobuf:"bytes,1,opt,name=message,proto3" json:"message,omitempty"`
ComputationId string `protobuf:"bytes,2,opt,name=computation_id,json=computationId,proto3" json:"computation_id,omitempty"`
Level string `protobuf:"bytes,3,opt,name=level,proto3" json:"level,omitempty"`
Timestamp *timestamppb.Timestamp `protobuf:"bytes,4,opt,name=timestamp,proto3" json:"timestamp,omitempty"`
}
func (x *AgentLog) Reset() {
*x = AgentLog{}
if protoimpl.UnsafeEnabled {
mi := &file_agent_events_events_proto_msgTypes[1]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
}
func (x *AgentLog) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*AgentLog) ProtoMessage() {}
func (x *AgentLog) ProtoReflect() protoreflect.Message {
mi := &file_agent_events_events_proto_msgTypes[1]
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use AgentLog.ProtoReflect.Descriptor instead.
func (*AgentLog) Descriptor() ([]byte, []int) {
return file_agent_events_events_proto_rawDescGZIP(), []int{1}
}
func (x *AgentLog) GetMessage() string {
if x != nil {
return x.Message
}
return ""
}
func (x *AgentLog) GetComputationId() string {
if x != nil {
return x.ComputationId
}
return ""
}
func (x *AgentLog) GetLevel() string {
if x != nil {
return x.Level
}
return ""
}
func (x *AgentLog) GetTimestamp() *timestamppb.Timestamp {
if x != nil {
return x.Timestamp
}
return nil
}
type EventsLogs struct {
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
// Types that are assignable to Message:
//
// *EventsLogs_AgentLog
// *EventsLogs_AgentEvent
Message isEventsLogs_Message `protobuf_oneof:"message"`
}
func (x *EventsLogs) Reset() {
*x = EventsLogs{}
if protoimpl.UnsafeEnabled {
mi := &file_agent_events_events_proto_msgTypes[2]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
}
func (x *EventsLogs) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*EventsLogs) ProtoMessage() {}
func (x *EventsLogs) ProtoReflect() protoreflect.Message {
mi := &file_agent_events_events_proto_msgTypes[2]
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use EventsLogs.ProtoReflect.Descriptor instead.
func (*EventsLogs) Descriptor() ([]byte, []int) {
return file_agent_events_events_proto_rawDescGZIP(), []int{2}
}
func (m *EventsLogs) GetMessage() isEventsLogs_Message {
if m != nil {
return m.Message
}
return nil
}
func (x *EventsLogs) GetAgentLog() *AgentLog {
if x, ok := x.GetMessage().(*EventsLogs_AgentLog); ok {
return x.AgentLog
}
return nil
}
func (x *EventsLogs) GetAgentEvent() *AgentEvent {
if x, ok := x.GetMessage().(*EventsLogs_AgentEvent); ok {
return x.AgentEvent
}
return nil
}
type isEventsLogs_Message interface {
isEventsLogs_Message()
}
type EventsLogs_AgentLog struct {
AgentLog *AgentLog `protobuf:"bytes,1,opt,name=agent_log,json=agentLog,proto3,oneof"`
}
type EventsLogs_AgentEvent struct {
AgentEvent *AgentEvent `protobuf:"bytes,2,opt,name=agent_event,json=agentEvent,proto3,oneof"`
}
func (*EventsLogs_AgentLog) isEventsLogs_Message() {}
func (*EventsLogs_AgentEvent) isEventsLogs_Message() {}
var File_agent_events_events_proto protoreflect.FileDescriptor
var file_agent_events_events_proto_rawDesc = []byte{
0x0a, 0x19, 0x61, 0x67, 0x65, 0x6e, 0x74, 0x2f, 0x65, 0x76, 0x65, 0x6e, 0x74, 0x73, 0x2f, 0x65,
0x76, 0x65, 0x6e, 0x74, 0x73, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x12, 0x06, 0x65, 0x76, 0x65,
0x6e, 0x74, 0x73, 0x1a, 0x1f, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2f, 0x70, 0x72, 0x6f, 0x74,
0x6f, 0x62, 0x75, 0x66, 0x2f, 0x74, 0x69, 0x6d, 0x65, 0x73, 0x74, 0x61, 0x6d, 0x70, 0x2e, 0x70,
0x72, 0x6f, 0x74, 0x6f, 0x22, 0xde, 0x01, 0x0a, 0x0a, 0x41, 0x67, 0x65, 0x6e, 0x74, 0x45, 0x76,
0x65, 0x6e, 0x74, 0x12, 0x1d, 0x0a, 0x0a, 0x65, 0x76, 0x65, 0x6e, 0x74, 0x5f, 0x74, 0x79, 0x70,
0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x09, 0x65, 0x76, 0x65, 0x6e, 0x74, 0x54, 0x79,
0x70, 0x65, 0x12, 0x38, 0x0a, 0x09, 0x74, 0x69, 0x6d, 0x65, 0x73, 0x74, 0x61, 0x6d, 0x70, 0x18,
0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1a, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70,
0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x54, 0x69, 0x6d, 0x65, 0x73, 0x74, 0x61, 0x6d,
0x70, 0x52, 0x09, 0x74, 0x69, 0x6d, 0x65, 0x73, 0x74, 0x61, 0x6d, 0x70, 0x12, 0x25, 0x0a, 0x0e,
0x63, 0x6f, 0x6d, 0x70, 0x75, 0x74, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x5f, 0x69, 0x64, 0x18, 0x03,
0x20, 0x01, 0x28, 0x09, 0x52, 0x0d, 0x63, 0x6f, 0x6d, 0x70, 0x75, 0x74, 0x61, 0x74, 0x69, 0x6f,
0x6e, 0x49, 0x64, 0x12, 0x18, 0x0a, 0x07, 0x64, 0x65, 0x74, 0x61, 0x69, 0x6c, 0x73, 0x18, 0x04,
0x20, 0x01, 0x28, 0x0c, 0x52, 0x07, 0x64, 0x65, 0x74, 0x61, 0x69, 0x6c, 0x73, 0x12, 0x1e, 0x0a,
0x0a, 0x6f, 0x72, 0x69, 0x67, 0x69, 0x6e, 0x61, 0x74, 0x6f, 0x72, 0x18, 0x05, 0x20, 0x01, 0x28,
0x09, 0x52, 0x0a, 0x6f, 0x72, 0x69, 0x67, 0x69, 0x6e, 0x61, 0x74, 0x6f, 0x72, 0x12, 0x16, 0x0a,
0x06, 0x73, 0x74, 0x61, 0x74, 0x75, 0x73, 0x18, 0x06, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x73,
0x74, 0x61, 0x74, 0x75, 0x73, 0x22, 0x9b, 0x01, 0x0a, 0x08, 0x41, 0x67, 0x65, 0x6e, 0x74, 0x4c,
0x6f, 0x67, 0x12, 0x18, 0x0a, 0x07, 0x6d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x18, 0x01, 0x20,
0x01, 0x28, 0x09, 0x52, 0x07, 0x6d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x12, 0x25, 0x0a, 0x0e,
0x63, 0x6f, 0x6d, 0x70, 0x75, 0x74, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x5f, 0x69, 0x64, 0x18, 0x02,
0x20, 0x01, 0x28, 0x09, 0x52, 0x0d, 0x63, 0x6f, 0x6d, 0x70, 0x75, 0x74, 0x61, 0x74, 0x69, 0x6f,
0x6e, 0x49, 0x64, 0x12, 0x14, 0x0a, 0x05, 0x6c, 0x65, 0x76, 0x65, 0x6c, 0x18, 0x03, 0x20, 0x01,
0x28, 0x09, 0x52, 0x05, 0x6c, 0x65, 0x76, 0x65, 0x6c, 0x12, 0x38, 0x0a, 0x09, 0x74, 0x69, 0x6d,
0x65, 0x73, 0x74, 0x61, 0x6d, 0x70, 0x18, 0x04, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1a, 0x2e, 0x67,
0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x54,
0x69, 0x6d, 0x65, 0x73, 0x74, 0x61, 0x6d, 0x70, 0x52, 0x09, 0x74, 0x69, 0x6d, 0x65, 0x73, 0x74,
0x61, 0x6d, 0x70, 0x22, 0x7f, 0x0a, 0x0a, 0x45, 0x76, 0x65, 0x6e, 0x74, 0x73, 0x4c, 0x6f, 0x67,
0x73, 0x12, 0x2f, 0x0a, 0x09, 0x61, 0x67, 0x65, 0x6e, 0x74, 0x5f, 0x6c, 0x6f, 0x67, 0x18, 0x01,
0x20, 0x01, 0x28, 0x0b, 0x32, 0x10, 0x2e, 0x65, 0x76, 0x65, 0x6e, 0x74, 0x73, 0x2e, 0x41, 0x67,
0x65, 0x6e, 0x74, 0x4c, 0x6f, 0x67, 0x48, 0x00, 0x52, 0x08, 0x61, 0x67, 0x65, 0x6e, 0x74, 0x4c,
0x6f, 0x67, 0x12, 0x35, 0x0a, 0x0b, 0x61, 0x67, 0x65, 0x6e, 0x74, 0x5f, 0x65, 0x76, 0x65, 0x6e,
0x74, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x12, 0x2e, 0x65, 0x76, 0x65, 0x6e, 0x74, 0x73,
0x2e, 0x41, 0x67, 0x65, 0x6e, 0x74, 0x45, 0x76, 0x65, 0x6e, 0x74, 0x48, 0x00, 0x52, 0x0a, 0x61,
0x67, 0x65, 0x6e, 0x74, 0x45, 0x76, 0x65, 0x6e, 0x74, 0x42, 0x09, 0x0a, 0x07, 0x6d, 0x65, 0x73,
0x73, 0x61, 0x67, 0x65, 0x42, 0x0a, 0x5a, 0x08, 0x2e, 0x2f, 0x65, 0x76, 0x65, 0x6e, 0x74, 0x73,
0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33,
}
var (
file_agent_events_events_proto_rawDescOnce sync.Once
file_agent_events_events_proto_rawDescData = file_agent_events_events_proto_rawDesc
)
func file_agent_events_events_proto_rawDescGZIP() []byte {
file_agent_events_events_proto_rawDescOnce.Do(func() {
file_agent_events_events_proto_rawDescData = protoimpl.X.CompressGZIP(file_agent_events_events_proto_rawDescData)
})
return file_agent_events_events_proto_rawDescData
}
var file_agent_events_events_proto_msgTypes = make([]protoimpl.MessageInfo, 3)
var file_agent_events_events_proto_goTypes = []any{
(*AgentEvent)(nil), // 0: events.AgentEvent
(*AgentLog)(nil), // 1: events.AgentLog
(*EventsLogs)(nil), // 2: events.EventsLogs
(*timestamppb.Timestamp)(nil), // 3: google.protobuf.Timestamp
}
var file_agent_events_events_proto_depIdxs = []int32{
3, // 0: events.AgentEvent.timestamp:type_name -> google.protobuf.Timestamp
3, // 1: events.AgentLog.timestamp:type_name -> google.protobuf.Timestamp
1, // 2: events.EventsLogs.agent_log:type_name -> events.AgentLog
0, // 3: events.EventsLogs.agent_event:type_name -> events.AgentEvent
4, // [4:4] is the sub-list for method output_type
4, // [4:4] is the sub-list for method input_type
4, // [4:4] is the sub-list for extension type_name
4, // [4:4] is the sub-list for extension extendee
0, // [0:4] is the sub-list for field type_name
}
func init() { file_agent_events_events_proto_init() }
func file_agent_events_events_proto_init() {
if File_agent_events_events_proto != nil {
return
}
if !protoimpl.UnsafeEnabled {
file_agent_events_events_proto_msgTypes[0].Exporter = func(v any, i int) any {
switch v := v.(*AgentEvent); i {
case 0:
return &v.state
case 1:
return &v.sizeCache
case 2:
return &v.unknownFields
default:
return nil
}
}
file_agent_events_events_proto_msgTypes[1].Exporter = func(v any, i int) any {
switch v := v.(*AgentLog); i {
case 0:
return &v.state
case 1:
return &v.sizeCache
case 2:
return &v.unknownFields
default:
return nil
}
}
file_agent_events_events_proto_msgTypes[2].Exporter = func(v any, i int) any {
switch v := v.(*EventsLogs); i {
case 0:
return &v.state
case 1:
return &v.sizeCache
case 2:
return &v.unknownFields
default:
return nil
}
}
}
file_agent_events_events_proto_msgTypes[2].OneofWrappers = []any{
(*EventsLogs_AgentLog)(nil),
(*EventsLogs_AgentEvent)(nil),
}
type x struct{}
out := protoimpl.TypeBuilder{
File: protoimpl.DescBuilder{
GoPackagePath: reflect.TypeOf(x{}).PkgPath(),
RawDescriptor: file_agent_events_events_proto_rawDesc,
NumEnums: 0,
NumMessages: 3,
NumExtensions: 0,
NumServices: 0,
},
GoTypes: file_agent_events_events_proto_goTypes,
DependencyIndexes: file_agent_events_events_proto_depIdxs,
MessageInfos: file_agent_events_events_proto_msgTypes,
}.Build()
File_agent_events_events_proto = out.File
file_agent_events_events_proto_rawDesc = nil
file_agent_events_events_proto_goTypes = nil
file_agent_events_events_proto_depIdxs = nil
}
+33
View File
@@ -0,0 +1,33 @@
// Copyright (c) Ultraviolet
// SPDX-License-Identifier: Apache-2.0
syntax = "proto3";
package events;
import "google/protobuf/timestamp.proto";
option go_package = "./events";
message AgentEvent {
string event_type = 1;
google.protobuf.Timestamp timestamp = 2;
string computation_id = 3;
bytes details = 4;
string originator = 5;
string status = 6;
}
message AgentLog {
string message = 1;
string computation_id = 2;
string level = 3;
google.protobuf.Timestamp timestamp = 4;
}
message EventsLogs {
oneof message {
AgentLog agent_log = 1;
AgentEvent agent_event = 2;
}
}
+1 -2
View File
@@ -10,7 +10,6 @@ import (
"time"
"github.com/stretchr/testify/assert"
"github.com/ultravioletrs/cocos/pkg/manager"
"google.golang.org/protobuf/proto"
)
@@ -37,7 +36,7 @@ func TestSendEventSuccess(t *testing.T) {
err = svc.SendEvent("test_event", "success", details)
assert.NoError(t, err)
var writtenMessage manager.ClientStreamMessage
var writtenMessage EventsLogs
err = proto.Unmarshal(mockConnection.buf.Bytes(), &writtenMessage)
assert.NoError(t, err)
+8
View File
@@ -12,6 +12,7 @@ import (
"os"
"path/filepath"
"slices"
sync "sync"
"github.com/google/go-sev-guest/client"
"github.com/ultravioletrs/cocos/agent/algorithm"
@@ -111,6 +112,7 @@ type Service interface {
}
type agentService struct {
mu sync.Mutex
computation Computation // Holds the current computation request details.
algorithm algorithm.Algorithm // Filepath to the algorithm received for the computation.
result []byte // Stores the result of the computation.
@@ -181,6 +183,8 @@ func (as *agentService) Algo(ctx context.Context, algo Algorithm) error {
if as.sm.GetState() != ReceivingAlgorithm {
return ErrStateNotReady
}
as.mu.Lock()
defer as.mu.Unlock()
if as.algorithm != nil {
return ErrAllManifestItemsReceived
}
@@ -262,6 +266,8 @@ func (as *agentService) Data(ctx context.Context, dataset Dataset) error {
if as.sm.GetState() != ReceivingData {
return ErrStateNotReady
}
as.mu.Lock()
defer as.mu.Unlock()
if len(as.computation.Datasets) == 0 {
return ErrAllManifestItemsReceived
}
@@ -322,6 +328,8 @@ func (as *agentService) Result(ctx context.Context) ([]byte, error) {
return []byte{}, ErrUndeclaredConsumer
}
as.mu.Lock()
defer as.mu.Unlock()
if index < 0 || index >= len(as.computation.ResultConsumers) {
return []byte{}, ErrUndeclaredConsumer
}
+2 -2
View File
@@ -29,7 +29,7 @@ import (
"github.com/ultravioletrs/cocos/internal/server"
grpcserver "github.com/ultravioletrs/cocos/internal/server/grpc"
ackvsock "github.com/ultravioletrs/cocos/internal/vsock"
"github.com/ultravioletrs/cocos/manager"
managerevents "github.com/ultravioletrs/cocos/manager/events"
"github.com/ultravioletrs/cocos/manager/qemu"
"golang.org/x/crypto/sha3"
"golang.org/x/sync/errgroup"
@@ -213,7 +213,7 @@ func dialVsock() (*vsock.Conn, error) {
var err error
err = backoff.Retry(func() error {
conn, err = vsock.Dial(vsock.Host, manager.ManagerVsockPort, nil)
conn, err = vsock.Dial(vsock.Host, managerevents.ManagerVsockPort, nil)
if err == nil {
log.Println("vsock connection established")
return nil
+12 -4
View File
@@ -22,11 +22,11 @@ import (
"github.com/ultravioletrs/cocos/manager"
"github.com/ultravioletrs/cocos/manager/api"
managerapi "github.com/ultravioletrs/cocos/manager/api/grpc"
"github.com/ultravioletrs/cocos/manager/events"
"github.com/ultravioletrs/cocos/manager/qemu"
"github.com/ultravioletrs/cocos/manager/tracing"
"github.com/ultravioletrs/cocos/pkg/clients/grpc"
managergrpc "github.com/ultravioletrs/cocos/pkg/clients/grpc/manager"
pkgmanager "github.com/ultravioletrs/cocos/pkg/manager"
"go.opentelemetry.io/otel/trace"
"golang.org/x/sync/errgroup"
)
@@ -113,7 +113,7 @@ func main() {
return
}
eventsChan := make(chan *pkgmanager.ClientStreamMessage, clientBufferSize)
eventsChan := make(chan *manager.ClientStreamMessage, clientBufferSize)
svc, err := newService(logger, tracer, qemuCfg, eventsChan, cfg.BackendMeasurementBinary)
if err != nil {
logger.Error(err.Error())
@@ -121,6 +121,15 @@ func main() {
return
}
eventsSvc := events.New(logger, svc.ReportBrokenConnection, eventsChan)
if eventsSvc == nil {
logger.Error("Failed to create events service")
exitCode = 1
return
}
go eventsSvc.Listen(ctx)
mc := managerapi.NewClient(pc, svc, eventsChan, logger)
g.Go(func() error {
@@ -147,12 +156,11 @@ func main() {
}
}
func newService(logger *slog.Logger, tracer trace.Tracer, qemuCfg qemu.Config, eventsChan chan *pkgmanager.ClientStreamMessage, backendMeasurementPath string) (manager.Service, error) {
func newService(logger *slog.Logger, tracer trace.Tracer, qemuCfg qemu.Config, eventsChan chan *manager.ClientStreamMessage, backendMeasurementPath string) (manager.Service, error) {
svc, err := manager.New(qemuCfg, backendMeasurementPath, logger, eventsChan, qemu.NewVM)
if err != nil {
return nil, err
}
go svc.RetrieveAgentEventsLogs()
svc = api.LoggingMiddleware(svc, logger)
counter, latency := prometheus.MakeMetrics(svcName, "api")
svc = api.MetricsMiddleware(svc, counter, latency)
+4 -4
View File
@@ -9,7 +9,7 @@ import (
"sync"
"time"
"github.com/ultravioletrs/cocos/pkg/manager"
"github.com/ultravioletrs/cocos/agent/events"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/types/known/timestamppb"
)
@@ -71,9 +71,9 @@ func (h *handler) Handle(_ context.Context, r slog.Record) error {
chunk := message[start:end]
agentLog := manager.ClientStreamMessage{
Message: &manager.ClientStreamMessage_AgentLog{
AgentLog: &manager.AgentLog{
agentLog := events.EventsLogs{
Message: &events.EventsLogs_AgentLog{
AgentLog: &events.AgentLog{
Timestamp: timestamp,
Message: chunk,
Level: level,
-3
View File
@@ -104,18 +104,15 @@ func (aw *AckWriter) sendMessages() {
}
func (aw *AckWriter) writeMessage(messageID uint32, p []byte) error {
// Write message ID
if err := binary.Write(aw.conn, binary.LittleEndian, messageID); err != nil {
return err
}
// Write message length
messageLen := uint32(len(p))
if err := binary.Write(aw.conn, binary.LittleEndian, messageLen); err != nil {
return err
}
// Write message content
if _, err := aw.conn.Write(p); err != nil {
return err
}
+1 -1
View File
@@ -13,7 +13,7 @@ import (
"time"
"github.com/stretchr/testify/assert"
"github.com/ultravioletrs/cocos/pkg/manager"
"github.com/ultravioletrs/cocos/manager"
"google.golang.org/protobuf/proto"
)
+12 -82
View File
@@ -4,97 +4,18 @@ package manager
import (
"fmt"
"log/slog"
"net"
"regexp"
"strconv"
"github.com/mdlayher/vsock"
internalvsock "github.com/ultravioletrs/cocos/internal/vsock"
"github.com/ultravioletrs/cocos/pkg/manager"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/types/known/timestamppb"
)
const (
ManagerVsockPort = 9997
messageSize int = 1024
)
var (
errFailedToParseCID = fmt.Errorf("failed to parse computation ID")
errComputationNotFound = fmt.Errorf("computation not found")
)
// RetrieveAgentEventsLogs Retrieve and forward agent logs and events via vsock.
func (ms *managerService) RetrieveAgentEventsLogs() {
l, err := vsock.Listen(ManagerVsockPort, nil)
if err != nil {
ms.logger.Warn(err.Error())
return
}
for {
conn, err := l.Accept()
if err != nil {
ms.logger.Warn(err.Error())
continue
}
go ms.handleConnection(conn)
}
}
func (ms *managerService) handleConnection(conn net.Conn) {
defer conn.Close()
cmpID, err := ms.computationIDFromAddress(conn.RemoteAddr().String())
if err != nil {
ms.logger.Warn(err.Error())
return
}
ackReader := internalvsock.NewAckReader(conn)
for {
var message manager.ClientStreamMessage
data, err := ackReader.Read()
if err != nil {
go ms.reportBrokenConnection(cmpID)
ms.logger.Warn(err.Error())
return
}
if err := proto.Unmarshal(data, &message); err != nil {
ms.logger.Warn(err.Error())
continue
}
ms.eventsChan <- &message
args := []any{}
switch message.Message.(type) {
case *manager.ClientStreamMessage_AgentEvent:
args = append(args, slog.Group("agent-event",
slog.String("event-type", message.GetAgentEvent().GetEventType()),
slog.String("computation-id", message.GetAgentEvent().GetComputationId()),
slog.String("status", message.GetAgentEvent().GetStatus()),
slog.String("originator", message.GetAgentEvent().GetOriginator()),
slog.String("timestamp", message.GetAgentEvent().GetTimestamp().String()),
slog.String("details", string(message.GetAgentEvent().GetDetails()))))
case *manager.ClientStreamMessage_AgentLog:
args = append(args, slog.Group("agent-log",
slog.String("computation-id", message.GetAgentLog().GetComputationId()),
slog.String("level", message.GetAgentLog().GetLevel()),
slog.String("timestamp", message.GetAgentLog().GetTimestamp().String()),
slog.String("message", message.GetAgentLog().GetMessage())))
}
ms.logger.Info("", args...)
}
}
func (ms *managerService) computationIDFromAddress(address string) (string, error) {
re := regexp.MustCompile(`vm\((\d+)\)`)
matches := re.FindStringSubmatch(address)
@@ -122,9 +43,9 @@ func (ms *managerService) findComputationID(cid int) (string, error) {
}
func (ms *managerService) reportBrokenConnection(cmpID string) {
ms.eventsChan <- &manager.ClientStreamMessage{
Message: &manager.ClientStreamMessage_AgentEvent{
AgentEvent: &manager.AgentEvent{
ms.eventsChan <- &ClientStreamMessage{
Message: &ClientStreamMessage_AgentEvent{
AgentEvent: &AgentEvent{
EventType: ms.vms[cmpID].State(),
ComputationId: cmpID,
Status: manager.Disconnected.String(),
@@ -134,3 +55,12 @@ func (ms *managerService) reportBrokenConnection(cmpID string) {
},
}
}
func (ms *managerService) ReportBrokenConnection(addr string) {
cmpID, err := ms.computationIDFromAddress(addr)
if err != nil {
ms.logger.Warn(err.Error())
return
}
ms.reportBrokenConnection(cmpID)
}
+4 -109
View File
@@ -3,83 +3,19 @@
package manager
import (
"net"
"testing"
"time"
mglog "github.com/absmach/magistrala/logger"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"github.com/ultravioletrs/cocos/manager/qemu"
"github.com/ultravioletrs/cocos/manager/vm"
"github.com/ultravioletrs/cocos/pkg/manager"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/types/known/timestamppb"
)
type MockConn struct {
mock.Mock
}
func (m *MockConn) Read(b []byte) (n int, err error) {
args := m.Called(b)
return args.Int(0), args.Error(1)
}
func (m *MockConn) Write(b []byte) (n int, err error) {
args := m.Called(b)
return args.Int(0), args.Error(1)
}
func (m *MockConn) Close() error {
args := m.Called()
return args.Error(0)
}
func (m *MockConn) LocalAddr() net.Addr {
args := m.Called()
return args.Get(0).(net.Addr)
}
func (m *MockConn) RemoteAddr() net.Addr {
args := m.Called()
return args.Get(0).(net.Addr)
}
func (m *MockConn) SetDeadline(t time.Time) error {
args := m.Called(t)
return args.Error(0)
}
func (m *MockConn) SetReadDeadline(t time.Time) error {
args := m.Called(t)
return args.Error(0)
}
func (m *MockConn) SetWriteDeadline(t time.Time) error {
args := m.Called(t)
return args.Error(0)
}
type MockAddr struct {
mock.Mock
}
func (m *MockAddr) Network() string {
args := m.Called()
return args.String(0)
}
func (m *MockAddr) String() string {
args := m.Called()
return args.String(0)
}
func TestComputationIDFromAddress(t *testing.T) {
ms := &managerService{
vms: map[string]vm.VM{
"comp1": qemu.NewVM(qemu.Config{VSockConfig: qemu.VSockConfig{GuestCID: 3}}, make(chan *manager.ClientStreamMessage), "comp1"),
"comp2": qemu.NewVM(qemu.Config{VSockConfig: qemu.VSockConfig{GuestCID: 5}}, make(chan *manager.ClientStreamMessage), "comp2"),
"comp1": qemu.NewVM(qemu.Config{VSockConfig: qemu.VSockConfig{GuestCID: 3}}, func(event interface{}) error { return nil }, "comp1"),
"comp2": qemu.NewVM(qemu.Config{VSockConfig: qemu.VSockConfig{GuestCID: 5}}, func(event interface{}) error { return nil }, "comp2"),
},
}
@@ -107,52 +43,11 @@ func TestComputationIDFromAddress(t *testing.T) {
}
}
func TestHandleConnection(t *testing.T) {
ms := &managerService{
vms: map[string]vm.VM{
"comp1": qemu.NewVM(qemu.Config{VSockConfig: qemu.VSockConfig{GuestCID: 3}}, make(chan *manager.ClientStreamMessage), "comp1"),
},
eventsChan: make(chan *manager.ClientStreamMessage, 1),
logger: mglog.NewMock(),
}
mockConn := new(MockConn)
mockAddr := new(MockAddr)
mockConn.On("RemoteAddr").Return(mockAddr)
mockConn.On("Close").Return(nil)
mockAddr.On("String").Return("vm(3)")
msg := &manager.ClientStreamMessage{
Message: &manager.ClientStreamMessage_AgentEvent{
AgentEvent: &manager.AgentEvent{
EventType: manager.VmProvision.String(),
ComputationId: "comp1",
Status: manager.VmProvision.String(),
Timestamp: timestamppb.Now(),
Originator: "agent",
},
},
}
msgBytes, _ := proto.Marshal(msg)
mockConn.On("Read", mock.Anything).Return(len(msgBytes), nil).Run(func(args mock.Arguments) {
copy(args.Get(0).([]byte), msgBytes)
}).Once()
mockConn.On("Read", mock.Anything).Return(0, net.ErrClosed)
go ms.handleConnection(mockConn)
receivedMsg := <-ms.eventsChan
assert.Equal(t, msg.GetAgentEvent().EventType, receivedMsg.GetAgentEvent().EventType)
assert.Equal(t, msg.GetAgentEvent().ComputationId, receivedMsg.GetAgentEvent().ComputationId)
}
func TestReportBrokenConnection(t *testing.T) {
ms := &managerService{
eventsChan: make(chan *manager.ClientStreamMessage, 1),
eventsChan: make(chan *ClientStreamMessage, 1),
vms: map[string]vm.VM{
"comp1": qemu.NewVM(qemu.Config{VSockConfig: qemu.VSockConfig{GuestCID: 3}}, make(chan *manager.ClientStreamMessage), "comp1"),
"comp1": qemu.NewVM(qemu.Config{VSockConfig: qemu.VSockConfig{GuestCID: 3}}, func(event interface{}) error { return nil }, "comp1"),
},
}
+24 -25
View File
@@ -10,7 +10,6 @@ import (
"github.com/absmach/magistrala/pkg/errors"
"github.com/ultravioletrs/cocos/manager"
pkgmanager "github.com/ultravioletrs/cocos/pkg/manager"
"golang.org/x/sync/errgroup"
"google.golang.org/protobuf/proto"
)
@@ -22,15 +21,15 @@ var (
)
type ManagerClient struct {
stream pkgmanager.ManagerService_ProcessClient
stream manager.ManagerService_ProcessClient
svc manager.Service
messageQueue chan *pkgmanager.ClientStreamMessage
messageQueue chan *manager.ClientStreamMessage
logger *slog.Logger
runReqManager *runRequestManager
}
// NewClient returns new gRPC client instance.
func NewClient(stream pkgmanager.ManagerService_ProcessClient, svc manager.Service, messageQueue chan *pkgmanager.ClientStreamMessage, logger *slog.Logger) ManagerClient {
func NewClient(stream manager.ManagerService_ProcessClient, svc manager.Service, messageQueue chan *manager.ClientStreamMessage, logger *slog.Logger) ManagerClient {
return ManagerClient{
stream: stream,
svc: svc,
@@ -71,15 +70,15 @@ func (client ManagerClient) handleIncomingMessages(ctx context.Context) error {
}
}
func (client ManagerClient) processIncomingMessage(ctx context.Context, req *pkgmanager.ServerStreamMessage) error {
func (client ManagerClient) processIncomingMessage(ctx context.Context, req *manager.ServerStreamMessage) error {
switch mes := req.Message.(type) {
case *pkgmanager.ServerStreamMessage_RunReqChunks:
case *manager.ServerStreamMessage_RunReqChunks:
return client.handleRunReqChunks(ctx, mes)
case *pkgmanager.ServerStreamMessage_TerminateReq:
case *manager.ServerStreamMessage_TerminateReq:
return client.handleTerminateReq(mes)
case *pkgmanager.ServerStreamMessage_StopComputation:
case *manager.ServerStreamMessage_StopComputation:
go client.handleStopComputation(ctx, mes)
case *pkgmanager.ServerStreamMessage_BackendInfoReq:
case *manager.ServerStreamMessage_BackendInfoReq:
go client.handleBackendInfoReq(ctx, mes)
default:
return errors.New("unknown message type")
@@ -87,11 +86,11 @@ func (client ManagerClient) processIncomingMessage(ctx context.Context, req *pkg
return nil
}
func (client *ManagerClient) handleRunReqChunks(ctx context.Context, mes *pkgmanager.ServerStreamMessage_RunReqChunks) error {
func (client *ManagerClient) handleRunReqChunks(ctx context.Context, mes *manager.ServerStreamMessage_RunReqChunks) error {
buffer, complete := client.runReqManager.addChunk(mes.RunReqChunks.Id, mes.RunReqChunks.Data, mes.RunReqChunks.IsLast)
if complete {
var runReq pkgmanager.ComputationRunReq
var runReq manager.ComputationRunReq
if err := proto.Unmarshal(buffer, &runReq); err != nil {
return errors.Wrap(err, errCorruptedManifest)
}
@@ -102,50 +101,50 @@ func (client *ManagerClient) handleRunReqChunks(ctx context.Context, mes *pkgman
return nil
}
func (client ManagerClient) executeRun(ctx context.Context, runReq *pkgmanager.ComputationRunReq) {
func (client ManagerClient) executeRun(ctx context.Context, runReq *manager.ComputationRunReq) {
port, err := client.svc.Run(ctx, runReq)
if err != nil {
client.logger.Warn(err.Error())
return
}
runRes := &pkgmanager.ClientStreamMessage_RunRes{
RunRes: &pkgmanager.RunResponse{
runRes := &manager.ClientStreamMessage_RunRes{
RunRes: &manager.RunResponse{
AgentPort: port,
ComputationId: runReq.Id,
},
}
client.sendMessage(&pkgmanager.ClientStreamMessage{Message: runRes})
client.sendMessage(&manager.ClientStreamMessage{Message: runRes})
}
func (client ManagerClient) handleTerminateReq(mes *pkgmanager.ServerStreamMessage_TerminateReq) error {
func (client ManagerClient) handleTerminateReq(mes *manager.ServerStreamMessage_TerminateReq) error {
return errors.Wrap(errTerminationFromServer, errors.New(mes.TerminateReq.Message))
}
func (client ManagerClient) handleStopComputation(ctx context.Context, mes *pkgmanager.ServerStreamMessage_StopComputation) {
msg := &pkgmanager.ClientStreamMessage_StopComputationRes{
StopComputationRes: &pkgmanager.StopComputationResponse{
func (client ManagerClient) handleStopComputation(ctx context.Context, mes *manager.ServerStreamMessage_StopComputation) {
msg := &manager.ClientStreamMessage_StopComputationRes{
StopComputationRes: &manager.StopComputationResponse{
ComputationId: mes.StopComputation.ComputationId,
},
}
if err := client.svc.Stop(ctx, mes.StopComputation.ComputationId); err != nil {
msg.StopComputationRes.Message = err.Error()
}
client.sendMessage(&pkgmanager.ClientStreamMessage{Message: msg})
client.sendMessage(&manager.ClientStreamMessage{Message: msg})
}
func (client ManagerClient) handleBackendInfoReq(ctx context.Context, mes *pkgmanager.ServerStreamMessage_BackendInfoReq) {
func (client ManagerClient) handleBackendInfoReq(ctx context.Context, mes *manager.ServerStreamMessage_BackendInfoReq) {
res, err := client.svc.FetchBackendInfo(ctx, mes.BackendInfoReq.Id)
if err != nil {
client.logger.Warn(err.Error())
return
}
info := &pkgmanager.ClientStreamMessage_BackendInfo{
BackendInfo: &pkgmanager.BackendInfo{
info := &manager.ClientStreamMessage_BackendInfo{
BackendInfo: &manager.BackendInfo{
Info: res,
Id: mes.BackendInfoReq.Id,
},
}
client.sendMessage(&pkgmanager.ClientStreamMessage{Message: info})
client.sendMessage(&manager.ClientStreamMessage{Message: info})
}
func (client ManagerClient) handleOutgoingMessages(ctx context.Context) error {
@@ -161,7 +160,7 @@ func (client ManagerClient) handleOutgoingMessages(ctx context.Context) error {
}
}
func (client ManagerClient) sendMessage(mes *pkgmanager.ClientStreamMessage) {
func (client ManagerClient) sendMessage(mes *manager.ClientStreamMessage) {
ctx, cancel := context.WithTimeout(context.Background(), sendTimeout)
defer cancel()
+26 -26
View File
@@ -11,8 +11,8 @@ import (
"github.com/absmach/magistrala/pkg/errors"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"github.com/ultravioletrs/cocos/manager"
"github.com/ultravioletrs/cocos/manager/mocks"
pkgmanager "github.com/ultravioletrs/cocos/pkg/manager"
"google.golang.org/grpc"
"google.golang.org/protobuf/proto"
)
@@ -22,12 +22,12 @@ type mockStream struct {
grpc.ClientStream
}
func (m *mockStream) Recv() (*pkgmanager.ServerStreamMessage, error) {
func (m *mockStream) Recv() (*manager.ServerStreamMessage, error) {
args := m.Called()
return args.Get(0).(*pkgmanager.ServerStreamMessage), args.Error(1)
return args.Get(0).(*manager.ServerStreamMessage), args.Error(1)
}
func (m *mockStream) Send(msg *pkgmanager.ClientStreamMessage) error {
func (m *mockStream) Send(msg *manager.ClientStreamMessage) error {
args := m.Called(msg)
return args.Error(0)
}
@@ -35,7 +35,7 @@ func (m *mockStream) Send(msg *pkgmanager.ClientStreamMessage) error {
func TestManagerClient_Process(t *testing.T) {
mockStream := new(mockStream)
mockSvc := new(mocks.Service)
messageQueue := make(chan *pkgmanager.ClientStreamMessage, 10)
messageQueue := make(chan *manager.ClientStreamMessage, 10)
logger := mglog.NewMock()
client := NewClient(mockStream, mockSvc, messageQueue, logger)
@@ -43,7 +43,7 @@ func TestManagerClient_Process(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
defer cancel()
mockStream.On("Recv").Return(&pkgmanager.ServerStreamMessage{Message: &pkgmanager.ServerStreamMessage_StopComputation{StopComputation: &pkgmanager.StopComputation{}}}, nil).Maybe()
mockStream.On("Recv").Return(&manager.ServerStreamMessage{Message: &manager.ServerStreamMessage_StopComputation{StopComputation: &manager.StopComputation{}}}, nil).Maybe()
mockStream.On("Send", mock.Anything).Return(nil).Maybe()
mockSvc.On("Stop", mock.Anything, mock.Anything).Return(nil).Maybe()
@@ -57,25 +57,25 @@ func TestManagerClient_Process(t *testing.T) {
func TestManagerClient_handleRunReqChunks(t *testing.T) {
mockStream := new(mockStream)
mockSvc := new(mocks.Service)
messageQueue := make(chan *pkgmanager.ClientStreamMessage, 10)
messageQueue := make(chan *manager.ClientStreamMessage, 10)
logger := mglog.NewMock()
client := NewClient(mockStream, mockSvc, messageQueue, logger)
runReq := &pkgmanager.ComputationRunReq{
runReq := &manager.ComputationRunReq{
Id: "test-id",
}
runReqBytes, _ := proto.Marshal(runReq)
chunk1 := &pkgmanager.ServerStreamMessage_RunReqChunks{
RunReqChunks: &pkgmanager.RunReqChunks{
chunk1 := &manager.ServerStreamMessage_RunReqChunks{
RunReqChunks: &manager.RunReqChunks{
Id: "chunk-1",
Data: runReqBytes[:len(runReqBytes)/2],
IsLast: false,
},
}
chunk2 := &pkgmanager.ServerStreamMessage_RunReqChunks{
RunReqChunks: &pkgmanager.RunReqChunks{
chunk2 := &manager.ServerStreamMessage_RunReqChunks{
RunReqChunks: &manager.RunReqChunks{
Id: "chunk-1",
Data: runReqBytes[len(runReqBytes)/2:],
IsLast: true,
@@ -97,7 +97,7 @@ func TestManagerClient_handleRunReqChunks(t *testing.T) {
assert.Len(t, messageQueue, 1)
msg := <-messageQueue
runRes, ok := msg.Message.(*pkgmanager.ClientStreamMessage_RunRes)
runRes, ok := msg.Message.(*manager.ClientStreamMessage_RunRes)
assert.True(t, ok)
assert.Equal(t, "8080", runRes.RunRes.AgentPort)
assert.Equal(t, "test-id", runRes.RunRes.ComputationId)
@@ -106,8 +106,8 @@ func TestManagerClient_handleRunReqChunks(t *testing.T) {
func TestManagerClient_handleTerminateReq(t *testing.T) {
client := ManagerClient{}
terminateReq := &pkgmanager.ServerStreamMessage_TerminateReq{
TerminateReq: &pkgmanager.Terminate{
terminateReq := &manager.ServerStreamMessage_TerminateReq{
TerminateReq: &manager.Terminate{
Message: "Test termination",
},
}
@@ -121,13 +121,13 @@ func TestManagerClient_handleTerminateReq(t *testing.T) {
func TestManagerClient_handleStopComputation(t *testing.T) {
mockStream := new(mockStream)
mockSvc := new(mocks.Service)
messageQueue := make(chan *pkgmanager.ClientStreamMessage, 10)
messageQueue := make(chan *manager.ClientStreamMessage, 10)
logger := mglog.NewMock()
client := NewClient(mockStream, mockSvc, messageQueue, logger)
stopReq := &pkgmanager.ServerStreamMessage_StopComputation{
StopComputation: &pkgmanager.StopComputation{
stopReq := &manager.ServerStreamMessage_StopComputation{
StopComputation: &manager.StopComputation{
ComputationId: "test-comp-id",
},
}
@@ -143,7 +143,7 @@ func TestManagerClient_handleStopComputation(t *testing.T) {
assert.Len(t, messageQueue, 1)
msg := <-messageQueue
stopRes, ok := msg.Message.(*pkgmanager.ClientStreamMessage_StopComputationRes)
stopRes, ok := msg.Message.(*manager.ClientStreamMessage_StopComputationRes)
assert.True(t, ok)
assert.Equal(t, "test-comp-id", stopRes.StopComputationRes.ComputationId)
assert.Empty(t, stopRes.StopComputationRes.Message)
@@ -153,13 +153,13 @@ func TestManagerClient_handleBackendInfoReq(t *testing.T) {
t.Run("success", func(t *testing.T) {
mockStream := new(mockStream)
mockSvc := new(mocks.Service)
messageQueue := make(chan *pkgmanager.ClientStreamMessage, 10)
messageQueue := make(chan *manager.ClientStreamMessage, 10)
logger := mglog.NewMock()
client := NewClient(mockStream, mockSvc, messageQueue, logger)
infoReq := &pkgmanager.ServerStreamMessage_BackendInfoReq{
BackendInfoReq: &pkgmanager.BackendInfoReq{
infoReq := &manager.ServerStreamMessage_BackendInfoReq{
BackendInfoReq: &manager.BackendInfoReq{
Id: "test-info-id",
},
}
@@ -175,7 +175,7 @@ func TestManagerClient_handleBackendInfoReq(t *testing.T) {
assert.Len(t, messageQueue, 1)
msg := <-messageQueue
infoRes, ok := msg.Message.(*pkgmanager.ClientStreamMessage_BackendInfo)
infoRes, ok := msg.Message.(*manager.ClientStreamMessage_BackendInfo)
assert.True(t, ok)
assert.Equal(t, "test-info-id", infoRes.BackendInfo.Id)
assert.Equal(t, []byte("test-backend-info"), infoRes.BackendInfo.Info)
@@ -183,13 +183,13 @@ func TestManagerClient_handleBackendInfoReq(t *testing.T) {
t.Run("error", func(t *testing.T) {
mockStream := new(mockStream)
mockSvc := new(mocks.Service)
messageQueue := make(chan *pkgmanager.ClientStreamMessage, 10)
messageQueue := make(chan *manager.ClientStreamMessage, 10)
logger := mglog.NewMock()
client := NewClient(mockStream, mockSvc, messageQueue, logger)
infoReq := &pkgmanager.ServerStreamMessage_BackendInfoReq{
BackendInfoReq: &pkgmanager.BackendInfoReq{
infoReq := &manager.ServerStreamMessage_BackendInfoReq{
BackendInfoReq: &manager.BackendInfoReq{
Id: "test-info-id",
},
}
+1 -1
View File
@@ -9,7 +9,7 @@ import (
"io"
"time"
"github.com/ultravioletrs/cocos/pkg/manager"
"github.com/ultravioletrs/cocos/manager"
"golang.org/x/sync/errgroup"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/peer"
+1 -1
View File
@@ -10,7 +10,7 @@ import (
"github.com/absmach/magistrala/pkg/errors"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"github.com/ultravioletrs/cocos/pkg/manager"
"github.com/ultravioletrs/cocos/manager"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/peer"
)
+5 -6
View File
@@ -13,7 +13,6 @@ import (
"time"
"github.com/ultravioletrs/cocos/manager"
pkgmanager "github.com/ultravioletrs/cocos/pkg/manager"
)
var _ manager.Service = (*loggingMiddleware)(nil)
@@ -28,7 +27,7 @@ func LoggingMiddleware(svc manager.Service, logger *slog.Logger) manager.Service
return &loggingMiddleware{logger, svc}
}
func (lm *loggingMiddleware) Run(ctx context.Context, mc *pkgmanager.ComputationRunReq) (agentAddr string, err error) {
func (lm *loggingMiddleware) Run(ctx context.Context, mc *manager.ComputationRunReq) (agentAddr string, err error) {
defer func(begin time.Time) {
message := fmt.Sprintf("Method Run for computation took %s to complete", time.Since(begin))
if err != nil {
@@ -54,10 +53,6 @@ func (lm *loggingMiddleware) Stop(ctx context.Context, computationID string) (er
return lm.svc.Stop(ctx, computationID)
}
func (lm *loggingMiddleware) RetrieveAgentEventsLogs() {
lm.svc.RetrieveAgentEventsLogs()
}
func (lm *loggingMiddleware) FetchBackendInfo(ctx context.Context, cmpId string) (body []byte, err error) {
defer func(begin time.Time) {
message := fmt.Sprintf("Method FetchBackendInfo for computation %s took %s to complete", cmpId, time.Since(begin))
@@ -71,3 +66,7 @@ func (lm *loggingMiddleware) FetchBackendInfo(ctx context.Context, cmpId string)
return lm.svc.FetchBackendInfo(ctx, cmpId)
}
func (lm *loggingMiddleware) ReportBrokenConnection(addr string) {
lm.svc.ReportBrokenConnection(addr)
}
+5 -6
View File
@@ -12,7 +12,6 @@ import (
"github.com/go-kit/kit/metrics"
"github.com/ultravioletrs/cocos/manager"
pkgmanager "github.com/ultravioletrs/cocos/pkg/manager"
)
var _ manager.Service = (*metricsMiddleware)(nil)
@@ -33,7 +32,7 @@ func MetricsMiddleware(svc manager.Service, counter metrics.Counter, latency met
}
}
func (ms *metricsMiddleware) Run(ctx context.Context, mc *pkgmanager.ComputationRunReq) (string, error) {
func (ms *metricsMiddleware) Run(ctx context.Context, mc *manager.ComputationRunReq) (string, error) {
defer func(begin time.Time) {
ms.counter.With("method", "Run").Add(1)
ms.latency.With("method", "Run").Observe(time.Since(begin).Seconds())
@@ -51,10 +50,6 @@ func (ms *metricsMiddleware) Stop(ctx context.Context, computationID string) err
return ms.svc.Stop(ctx, computationID)
}
func (ms *metricsMiddleware) RetrieveAgentEventsLogs() {
ms.svc.RetrieveAgentEventsLogs()
}
func (ms *metricsMiddleware) FetchBackendInfo(ctx context.Context, cmpId string) ([]byte, error) {
defer func(begin time.Time) {
ms.counter.With("method", "FetchBackendInfo").Add(1)
@@ -63,3 +58,7 @@ func (ms *metricsMiddleware) FetchBackendInfo(ctx context.Context, cmpId string)
return ms.svc.FetchBackendInfo(ctx, cmpId)
}
func (ms *metricsMiddleware) ReportBrokenConnection(addr string) {
ms.svc.ReportBrokenConnection(addr)
}
+2 -2
View File
@@ -14,8 +14,8 @@ import (
"os"
"os/exec"
"github.com/ultravioletrs/cocos/cli"
"github.com/ultravioletrs/cocos/manager/qemu"
"github.com/ultravioletrs/cocos/pkg/clients/grpc"
"github.com/virtee/sev-snp-measure-go/cpuid"
"github.com/virtee/sev-snp-measure-go/guest"
"github.com/virtee/sev-snp-measure-go/vmmtypes"
@@ -48,7 +48,7 @@ func (ms *managerService) FetchBackendInfo(_ context.Context, computationId stri
return nil, err
}
var backendInfo cli.AttestationConfiguration
var backendInfo grpc.AttestationConfiguration
if err = json.Unmarshal(f, &backendInfo); err != nil {
return nil, err
+9
View File
@@ -0,0 +1,9 @@
// Copyright (c) Ultraviolet
// SPDX-License-Identifier: Apache-2.0
package events
import "context"
type Listener interface {
Listen(ctx context.Context)
}
+128
View File
@@ -0,0 +1,128 @@
// Copyright (c) Ultraviolet
// SPDX-License-Identifier: Apache-2.0
package events
import (
"context"
"log/slog"
"net"
"github.com/mdlayher/vsock"
agentevents "github.com/ultravioletrs/cocos/agent/events"
internalvsock "github.com/ultravioletrs/cocos/internal/vsock"
"github.com/ultravioletrs/cocos/manager"
"google.golang.org/protobuf/proto"
)
const (
ManagerVsockPort = 9997
messageSize int = 1024 * 1024
)
type ReportBrokenConnectionFunc func(address string)
type events struct {
reportBrokenConnection ReportBrokenConnectionFunc
lis net.Listener
logger *slog.Logger
eventsChan chan *manager.ClientStreamMessage
}
func New(logger *slog.Logger, reportBrokenConnection ReportBrokenConnectionFunc, eventsChan chan *manager.ClientStreamMessage) Listener {
l, err := vsock.Listen(ManagerVsockPort, nil)
if err != nil {
return nil
}
return &events{
lis: l,
reportBrokenConnection: reportBrokenConnection,
logger: logger,
eventsChan: eventsChan,
}
}
func (e *events) Listen(ctx context.Context) {
for {
select {
case <-ctx.Done():
e.logger.Info("Listener shutting down")
return
default:
conn, err := e.lis.Accept()
if err != nil {
e.logger.Warn(err.Error())
continue
}
go e.handleConnection(conn)
}
}
}
func (e *events) handleConnection(conn net.Conn) {
defer conn.Close()
ackReader := internalvsock.NewAckReader(conn)
for {
var message agentevents.EventsLogs
data, err := ackReader.Read()
if err != nil {
go e.reportBrokenConnection(conn.RemoteAddr().String())
e.logger.Warn(err.Error())
return
}
if err := proto.Unmarshal(data, &message); err != nil {
e.logger.Warn(err.Error())
continue
}
var mes manager.ClientStreamMessage
args := []any{}
switch message.Message.(type) {
case *agentevents.EventsLogs_AgentEvent:
args = append(args, slog.Group("agent-event",
slog.String("event-type", message.GetAgentEvent().GetEventType()),
slog.String("computation-id", message.GetAgentEvent().GetComputationId()),
slog.String("status", message.GetAgentEvent().GetStatus()),
slog.String("originator", message.GetAgentEvent().GetOriginator()),
slog.String("timestamp", message.GetAgentEvent().GetTimestamp().String()),
slog.String("details", string(message.GetAgentEvent().GetDetails()))))
mes = manager.ClientStreamMessage{
Message: &manager.ClientStreamMessage_AgentEvent{
AgentEvent: &manager.AgentEvent{
EventType: message.GetAgentEvent().GetEventType(),
ComputationId: message.GetAgentEvent().GetComputationId(),
Status: message.GetAgentEvent().GetStatus(),
Originator: message.GetAgentEvent().GetOriginator(),
Timestamp: message.GetAgentEvent().GetTimestamp(),
Details: message.GetAgentEvent().GetDetails(),
},
},
}
case *agentevents.EventsLogs_AgentLog:
args = append(args, slog.Group("agent-log",
slog.String("computation-id", message.GetAgentLog().GetComputationId()),
slog.String("level", message.GetAgentLog().GetLevel()),
slog.String("timestamp", message.GetAgentLog().GetTimestamp().String()),
slog.String("message", message.GetAgentLog().GetMessage())))
mes = manager.ClientStreamMessage{
Message: &manager.ClientStreamMessage_AgentLog{
AgentLog: &manager.AgentLog{
ComputationId: message.GetAgentLog().GetComputationId(),
Level: message.GetAgentLog().GetLevel(),
Timestamp: message.GetAgentLog().GetTimestamp(),
Message: message.GetAgentLog().GetMessage(),
},
},
}
}
e.eventsChan <- &mes
e.logger.Info("", args...)
}
}
+251
View File
@@ -0,0 +1,251 @@
// Copyright (c) Ultraviolet
// SPDX-License-Identifier: Apache-2.0
package events
import (
"bytes"
"context"
"encoding/binary"
"fmt"
"log/slog"
"net"
"os"
"testing"
"time"
mglog "github.com/absmach/magistrala/logger"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"github.com/ultravioletrs/cocos/manager"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/types/known/timestamppb"
)
type MockVsockListener struct {
mock.Mock
}
func (m *MockVsockListener) Accept() (net.Conn, error) {
args := m.Called()
return args.Get(0).(net.Conn), args.Error(1)
}
func (m *MockVsockListener) Close() error {
args := m.Called()
return args.Error(0)
}
func (m *MockVsockListener) Addr() net.Addr {
args := m.Called()
return args.Get(0).(net.Addr)
}
var _ net.Conn = (*MockConn)(nil)
type MockConn struct {
mock.Mock
}
func (m *MockConn) Read(b []byte) (n int, err error) {
args := m.Called(b)
return args.Int(0), args.Error(1)
}
func (m *MockConn) Write(b []byte) (n int, err error) {
args := m.Called(b)
return args.Int(0), args.Error(1)
}
func (m *MockConn) Close() error {
args := m.Called()
return args.Error(0)
}
func (m *MockConn) LocalAddr() net.Addr {
args := m.Called()
return args.Get(0).(net.Addr)
}
func (m *MockConn) RemoteAddr() net.Addr {
args := m.Called()
return args.Get(0).(net.Addr)
}
func (m *MockConn) SetDeadline(t time.Time) error {
args := m.Called(t)
return args.Error(0)
}
func (m *MockConn) SetReadDeadline(t time.Time) error {
args := m.Called(t)
return args.Error(0)
}
func (m *MockConn) SetWriteDeadline(t time.Time) error {
args := m.Called(t)
return args.Error(0)
}
func TestNew(t *testing.T) {
if !vsockDeviceExists() {
t.Skip("Skipping test: vsock device not available")
}
logger := &slog.Logger{}
reportBrokenConnection := func(address string) {}
eventsChan := make(chan *manager.ClientStreamMessage)
e := New(logger, reportBrokenConnection, eventsChan)
assert.NotNil(t, e)
assert.IsType(t, &events{}, e)
}
func TestListen(t *testing.T) {
mockListener := new(MockVsockListener)
mockConn := new(MockConn)
e := &events{
lis: mockListener,
logger: mglog.NewMock(),
}
mockListener.On("Accept").Return(mockConn, fmt.Errorf("mock error")).Once()
mockListener.On("Accept").Return(mockConn, nil)
mockConn.On("Close").Return(nil)
mockConn.On("Read", mock.Anything).Return(0, nil)
go e.Listen(context.Background())
time.Sleep(100 * time.Millisecond)
mockListener.AssertExpectations(t)
}
func vsockDeviceExists() bool {
fs, err := os.Stat("/dev/vsock")
if err != nil {
return false
}
if fs.Mode()&os.ModeDevice == 0 {
return false
}
return true
}
type MockConnWithBuffer struct {
mock.Mock
readBuf *bytes.Buffer
writeBuf *bytes.Buffer
}
func NewMockConnWithBuffer() *MockConnWithBuffer {
return &MockConnWithBuffer{
readBuf: new(bytes.Buffer),
writeBuf: new(bytes.Buffer),
}
}
func (m *MockConnWithBuffer) Read(b []byte) (n int, err error) {
return m.readBuf.Read(b)
}
func (m *MockConnWithBuffer) Write(b []byte) (n int, err error) {
return m.writeBuf.Write(b)
}
func (m *MockConnWithBuffer) Close() error {
return nil
}
func (m *MockConnWithBuffer) LocalAddr() net.Addr {
return nil
}
func (m *MockConnWithBuffer) RemoteAddr() net.Addr {
return &net.IPAddr{IP: net.ParseIP("localhost")}
}
func (m *MockConnWithBuffer) SetDeadline(t time.Time) error {
return nil
}
func (m *MockConnWithBuffer) SetReadDeadline(t time.Time) error {
return nil
}
func (m *MockConnWithBuffer) SetWriteDeadline(t time.Time) error {
return nil
}
func TestHandleConnection(t *testing.T) {
mockConn := NewMockConnWithBuffer()
eventsChan := make(chan *manager.ClientStreamMessage, 1)
e := &events{
logger: mglog.NewMock(),
eventsChan: eventsChan,
reportBrokenConnection: func(address string) {},
}
message := &manager.ClientStreamMessage{
Message: &manager.ClientStreamMessage_AgentEvent{
AgentEvent: &manager.AgentEvent{
EventType: "test_event",
ComputationId: "test_computation",
Status: "test_status",
Originator: "test_originator",
Timestamp: timestamppb.Now(),
Details: []byte("test_details"),
},
},
}
data, err := proto.Marshal(message)
assert.NoError(t, err)
messageID := uint32(1)
err = binary.Write(mockConn.readBuf, binary.LittleEndian, messageID)
assert.NoError(t, err)
err = binary.Write(mockConn.readBuf, binary.LittleEndian, uint32(len(data)))
assert.NoError(t, err)
_, err = mockConn.readBuf.Write(data)
assert.NoError(t, err)
// Add EOF to signal end of stream
err = binary.Write(mockConn.readBuf, binary.LittleEndian, uint32(0))
assert.NoError(t, err)
err = binary.Write(mockConn.readBuf, binary.LittleEndian, uint32(0))
assert.NoError(t, err)
done := make(chan struct{})
go func() {
e.handleConnection(mockConn)
close(done)
}()
var receivedMessage *manager.ClientStreamMessage
select {
case receivedMessage = <-eventsChan:
case <-time.After(2 * time.Second):
t.Fatal("Timeout waiting for message in eventsChan")
}
assert.NotNil(t, receivedMessage)
select {
case <-done:
// handleConnection has exited
case <-time.After(2 * time.Second):
t.Fatal("Timeout waiting for handleConnection to exit")
}
// Check if ack was written
var receivedAck uint32
err = binary.Read(mockConn.writeBuf, binary.LittleEndian, &receivedAck)
assert.NoError(t, err)
assert.Equal(t, messageID, receivedAck)
// Ensure no unexpected calls were made on the mock
mockConn.AssertExpectations(t)
}
+1 -1
View File
@@ -7,7 +7,7 @@ import (
"context"
"testing"
"github.com/ultravioletrs/cocos/pkg/manager"
"github.com/ultravioletrs/cocos/manager"
"google.golang.org/grpc"
"google.golang.org/protobuf/proto"
)
+8 -9
View File
@@ -9,8 +9,7 @@ import (
context "context"
mock "github.com/stretchr/testify/mock"
pkgmanager "github.com/ultravioletrs/cocos/pkg/manager"
manager "github.com/ultravioletrs/cocos/manager"
)
// Service is an autogenerated mock type for the Service type
@@ -48,13 +47,13 @@ func (_m *Service) FetchBackendInfo(ctx context.Context, computationID string) (
return r0, r1
}
// RetrieveAgentEventsLogs provides a mock function with given fields:
func (_m *Service) RetrieveAgentEventsLogs() {
_m.Called()
// ReportBrokenConnection provides a mock function with given fields: addr
func (_m *Service) ReportBrokenConnection(addr string) {
_m.Called(addr)
}
// Run provides a mock function with given fields: ctx, c
func (_m *Service) Run(ctx context.Context, c *pkgmanager.ComputationRunReq) (string, error) {
func (_m *Service) Run(ctx context.Context, c *manager.ComputationRunReq) (string, error) {
ret := _m.Called(ctx, c)
if len(ret) == 0 {
@@ -63,16 +62,16 @@ func (_m *Service) Run(ctx context.Context, c *pkgmanager.ComputationRunReq) (st
var r0 string
var r1 error
if rf, ok := ret.Get(0).(func(context.Context, *pkgmanager.ComputationRunReq) (string, error)); ok {
if rf, ok := ret.Get(0).(func(context.Context, *manager.ComputationRunReq) (string, error)); ok {
return rf(ctx, c)
}
if rf, ok := ret.Get(0).(func(context.Context, *pkgmanager.ComputationRunReq) string); ok {
if rf, ok := ret.Get(0).(func(context.Context, *manager.ComputationRunReq) string); ok {
r0 = rf(ctx, c)
} else {
r0 = ret.Get(0).(string)
}
if rf, ok := ret.Get(1).(func(context.Context, *pkgmanager.ComputationRunReq) error); ok {
if rf, ok := ret.Get(1).(func(context.Context, *manager.ComputationRunReq) error); ok {
r1 = rf(ctx, c)
} else {
r1 = ret.Error(1)
+27 -31
View File
@@ -26,19 +26,19 @@ const (
)
type qemuVM struct {
config Config
cmd *exec.Cmd
logsChan chan *manager.ClientStreamMessage
computationId string
config Config
cmd *exec.Cmd
eventsLogsSender vm.EventSender
computationId string
vm.StateMachine
}
func NewVM(config interface{}, logsChan chan *manager.ClientStreamMessage, computationId string) vm.VM {
func NewVM(config interface{}, eventsLogsSender vm.EventSender, computationId string) vm.VM {
return &qemuVM{
config: config.(Config),
logsChan: logsChan,
computationId: computationId,
StateMachine: vm.NewStateMachine(),
config: config.(Config),
eventsLogsSender: eventsLogsSender,
computationId: computationId,
StateMachine: vm.NewStateMachine(),
}
}
@@ -74,8 +74,8 @@ func (v *qemuVM) Start() (err error) {
}
v.cmd = exec.Command(exe, args...)
v.cmd.Stdout = &vm.Stdout{LogsChan: v.logsChan, ComputationId: v.computationId}
v.cmd.Stderr = &vm.Stderr{LogsChan: v.logsChan, ComputationId: v.computationId, StateMachine: v.StateMachine}
v.cmd.Stdout = &vm.Stdout{ComputationId: v.computationId, EventSender: v.eventsLogsSender}
v.cmd.Stderr = &vm.Stderr{EventSender: v.eventsLogsSender, ComputationId: v.computationId, StateMachine: v.StateMachine}
return v.cmd.Start()
}
@@ -84,16 +84,14 @@ func (v *qemuVM) Stop() error {
defer func() {
err := v.StateMachine.Transition(manager.StopComputationRun)
if err != nil {
v.logsChan <- &manager.ClientStreamMessage{
Message: &manager.ClientStreamMessage_AgentEvent{
AgentEvent: &manager.AgentEvent{
ComputationId: v.computationId,
EventType: v.StateMachine.State(),
Status: manager.Warning.String(),
Timestamp: timestamppb.Now(),
Originator: "manager",
},
},
if err := v.eventsLogsSender(&vm.Event{
EventType: v.StateMachine.State(),
Timestamp: timestamppb.Now(),
ComputationId: v.computationId,
Originator: "manager",
Status: manager.Warning.String(),
}); err != nil {
return
}
}
}()
@@ -160,16 +158,14 @@ func (v *qemuVM) executableAndArgs() (string, []string, error) {
func (v *qemuVM) checkVMProcessPeriodically() {
for {
if !processExists(v.GetProcess()) {
v.logsChan <- &manager.ClientStreamMessage{
Message: &manager.ClientStreamMessage_AgentEvent{
AgentEvent: &manager.AgentEvent{
ComputationId: v.computationId,
EventType: v.StateMachine.State(),
Status: manager.Stopped.String(),
Timestamp: timestamppb.Now(),
Originator: "manager",
},
},
if err := v.eventsLogsSender(&vm.Event{
EventType: v.StateMachine.State(),
Timestamp: timestamppb.Now(),
ComputationId: v.computationId,
Originator: "manager",
Status: manager.Stopped.String(),
}); err != nil {
return
}
break
}
+22 -23
View File
@@ -11,16 +11,15 @@ import (
"github.com/stretchr/testify/assert"
"github.com/ultravioletrs/cocos/manager/vm"
"github.com/ultravioletrs/cocos/manager/vm/mocks"
"github.com/ultravioletrs/cocos/pkg/manager"
pkgmanager "github.com/ultravioletrs/cocos/pkg/manager"
)
const testComputationID = "test-computation"
func TestNewVM(t *testing.T) {
config := Config{}
logsChan := make(chan *manager.ClientStreamMessage)
vm := NewVM(config, logsChan, testComputationID)
vm := NewVM(config, func(event interface{}) error { return nil }, testComputationID)
assert.NotNil(t, vm)
assert.IsType(t, &qemuVM{}, vm)
@@ -38,9 +37,8 @@ func TestStart(t *testing.T) {
},
QemuBinPath: "echo",
}
logsChan := make(chan *manager.ClientStreamMessage)
vm := NewVM(config, logsChan, testComputationID).(*qemuVM)
vm := NewVM(config, func(event interface{}) error { return nil }, testComputationID).(*qemuVM)
err = vm.Start()
assert.NoError(t, err)
@@ -62,9 +60,8 @@ func TestStartSudo(t *testing.T) {
QemuBinPath: "echo",
UseSudo: true,
}
logsChan := make(chan *manager.ClientStreamMessage)
vm := NewVM(config, logsChan, testComputationID).(*qemuVM)
vm := NewVM(config, func(event interface{}) error { return nil }, testComputationID).(*qemuVM)
err = vm.Start()
assert.NoError(t, err)
@@ -79,7 +76,7 @@ func TestStop(t *testing.T) {
err := cmd.Start()
assert.NoError(t, err)
sm := new(mocks.StateMachine)
sm.On("Transition", manager.StopComputationRun).Return(nil)
sm.On("Transition", pkgmanager.StopComputationRun).Return(nil)
vm := &qemuVM{
cmd: &exec.Cmd{
@@ -96,21 +93,19 @@ func TestStop(t *testing.T) {
err := cmd.Start()
assert.NoError(t, err)
sm := new(mocks.StateMachine)
sm.On("Transition", manager.StopComputationRun).Return(assert.AnError)
sm.On("State").Return(manager.Stopped.String())
sm.On("Transition", pkgmanager.StopComputationRun).Return(assert.AnError)
sm.On("State").Return(pkgmanager.Stopped.String())
vm := &qemuVM{
cmd: &exec.Cmd{
Process: cmd.Process,
},
StateMachine: sm,
logsChan: make(chan *manager.ClientStreamMessage),
eventsLogsSender: func(event interface{}) error {
return nil
},
}
go func() {
<-vm.logsChan
}()
err = vm.Stop()
assert.NoError(t, err)
})
@@ -168,9 +163,12 @@ func TestGetConfig(t *testing.T) {
}
func TestCheckVMProcessPeriodically(t *testing.T) {
logsChan := make(chan *manager.ClientStreamMessage, 1)
vm := &qemuVM{
logsChan: logsChan,
logsChan := make(chan interface{}, 1)
vmi := &qemuVM{
eventsLogsSender: func(event interface{}) error {
logsChan <- event
return nil
},
computationId: testComputationID,
cmd: &exec.Cmd{
Process: &os.Process{Pid: -1}, // Use an invalid PID to simulate a stopped process
@@ -178,14 +176,15 @@ func TestCheckVMProcessPeriodically(t *testing.T) {
StateMachine: vm.NewStateMachine(),
}
go vm.checkVMProcessPeriodically()
go vmi.checkVMProcessPeriodically()
select {
case msg := <-logsChan:
assert.NotNil(t, msg.GetAgentEvent())
assert.Equal(t, testComputationID, msg.GetAgentEvent().ComputationId)
assert.Equal(t, manager.VmProvision.String(), msg.GetAgentEvent().EventType)
assert.Equal(t, manager.Stopped.String(), msg.GetAgentEvent().Status)
assert.NotNil(t, msg)
msgE := msg.(*vm.Event)
assert.Equal(t, testComputationID, msgE.ComputationId)
assert.Equal(t, pkgmanager.VmProvision.String(), msgE.EventType)
assert.Equal(t, pkgmanager.Stopped.String(), msgE.Status)
case <-time.After(2 * interval):
t.Fatal("Timeout waiting for VM stopped message")
}
+41 -11
View File
@@ -57,13 +57,13 @@ var (
//go:generate mockery --name Service --output=./mocks --filename service.go --quiet --note "Copyright (c) Ultraviolet \n // SPDX-License-Identifier: Apache-2.0"
type Service interface {
// Run create a computation.
Run(ctx context.Context, c *manager.ComputationRunReq) (string, error)
Run(ctx context.Context, c *ComputationRunReq) (string, error)
// Stop stops a computation.
Stop(ctx context.Context, computationID string) error
// RetrieveAgentEventsLogs Retrieve and forward agent logs and events via vsock.
RetrieveAgentEventsLogs()
// FetchBackendInfo measures and fetches the backend information.
FetchBackendInfo(ctx context.Context, computationID string) ([]byte, error)
// ReportBrokenConnection reports a broken connection.
ReportBrokenConnection(addr string)
}
type managerService struct {
@@ -71,7 +71,7 @@ type managerService struct {
qemuCfg qemu.Config
backendMeasurementBinaryPath string
logger *slog.Logger
eventsChan chan *manager.ClientStreamMessage
eventsChan chan *ClientStreamMessage
vms map[string]vm.VM
vmFactory vm.Provider
portRangeMin int
@@ -82,7 +82,7 @@ type managerService struct {
var _ Service = (*managerService)(nil)
// New instantiates the manager service implementation.
func New(cfg qemu.Config, backendMeasurementBinPath string, logger *slog.Logger, eventsChan chan *manager.ClientStreamMessage, vmFactory vm.Provider) (Service, error) {
func New(cfg qemu.Config, backendMeasurementBinPath string, logger *slog.Logger, eventsChan chan *ClientStreamMessage, vmFactory vm.Provider) (Service, error) {
start, end, err := decodeRange(cfg.HostFwdRange)
if err != nil {
return nil, err
@@ -112,7 +112,7 @@ func New(cfg qemu.Config, backendMeasurementBinPath string, logger *slog.Logger,
return ms, nil
}
func (ms *managerService) Run(ctx context.Context, c *manager.ComputationRunReq) (string, error) {
func (ms *managerService) Run(ctx context.Context, c *ComputationRunReq) (string, error) {
ms.publishEvent(manager.VmProvision.String(), c.Id, manager.Starting.String(), json.RawMessage{})
ac := agent.Computation{
ID: c.Id,
@@ -164,7 +164,7 @@ func (ms *managerService) Run(ctx context.Context, c *manager.ComputationRunReq)
// Define host-data value of QEMU for SEV-SNP, with a base64 encoding of the computation hash.
ms.qemuCfg.SevConfig.HostData = base64.StdEncoding.EncodeToString(ch[:])
cvm := ms.vmFactory(ms.qemuCfg, ms.eventsChan, c.Id)
cvm := ms.vmFactory(ms.qemuCfg, ms.eventsLogsSender, c.Id)
ms.publishEvent(manager.VmProvision.String(), c.Id, agent.InProgress.String(), json.RawMessage{})
if err = cvm.Start(); err != nil {
ms.publishEvent(manager.VmProvision.String(), c.Id, agent.Failed.String(), json.RawMessage{})
@@ -267,9 +267,9 @@ func checkPortisFree(port int) bool {
}
func (ms *managerService) publishEvent(event, cmpID, status string, details json.RawMessage) {
ms.eventsChan <- &manager.ClientStreamMessage{
Message: &manager.ClientStreamMessage_AgentEvent{
AgentEvent: &manager.AgentEvent{
ms.eventsChan <- &ClientStreamMessage{
Message: &ClientStreamMessage_AgentEvent{
AgentEvent: &AgentEvent{
EventType: event,
ComputationId: cmpID,
Status: status,
@@ -329,7 +329,7 @@ func (ms *managerService) restoreVMs() error {
continue
}
cvm := ms.vmFactory(state.Config, ms.eventsChan, state.ID)
cvm := ms.vmFactory(state.Config, ms.eventsLogsSender, state.ID)
if err = cvm.SetProcess(state.PID); err != nil {
ms.logger.Warn("Failed to reattach to process", "computation", state.ID, "pid", state.PID, "error", err)
@@ -362,3 +362,33 @@ func (ms *managerService) processExists(pid int) bool {
}
return false
}
func (ms *managerService) eventsLogsSender(e interface{}) error {
switch msg := e.(type) {
case *vm.Event:
ms.eventsChan <- &ClientStreamMessage{
Message: &ClientStreamMessage_AgentEvent{
AgentEvent: &AgentEvent{
EventType: msg.EventType,
Timestamp: msg.Timestamp,
ComputationId: msg.ComputationId,
Originator: msg.Originator,
Status: msg.Status,
Details: msg.Details,
},
},
}
case *vm.Log:
ms.eventsChan <- &ClientStreamMessage{
Message: &ClientStreamMessage_AgentLog{
AgentLog: &AgentLog{
ComputationId: msg.ComputationId,
Level: msg.Level,
Timestamp: msg.Timestamp,
Message: msg.Message,
},
},
}
}
return nil
}
+19 -20
View File
@@ -22,7 +22,6 @@ import (
persistenceMocks "github.com/ultravioletrs/cocos/manager/qemu/mocks"
"github.com/ultravioletrs/cocos/manager/vm"
"github.com/ultravioletrs/cocos/manager/vm/mocks"
"github.com/ultravioletrs/cocos/pkg/manager"
)
func TestNew(t *testing.T) {
@@ -30,7 +29,7 @@ func TestNew(t *testing.T) {
HostFwdRange: "6000-6100",
}
logger := slog.Default()
eventsChan := make(chan *manager.ClientStreamMessage)
eventsChan := make(chan *ClientStreamMessage)
vmf := new(mocks.Provider)
service, err := New(cfg, "", logger, eventsChan, vmf.Execute)
@@ -47,59 +46,59 @@ func TestRun(t *testing.T) {
vmf.On("Execute", mock.Anything, mock.Anything, mock.Anything).Return(vmMock)
tests := []struct {
name string
req *manager.ComputationRunReq
req *ComputationRunReq
vmStartError error
expectedError error
}{
{
name: "Successful run",
req: &manager.ComputationRunReq{
req: &ComputationRunReq{
Id: "test-computation",
Name: "Test Computation",
Algorithm: &manager.Algorithm{
Algorithm: &Algorithm{
Hash: make([]byte, hashLength),
},
AgentConfig: &manager.AgentConfig{},
AgentConfig: &AgentConfig{},
},
vmStartError: nil,
expectedError: nil,
},
{
name: "VM start failure",
req: &manager.ComputationRunReq{
req: &ComputationRunReq{
Id: "test-computation",
Name: "Test Computation",
Algorithm: &manager.Algorithm{
Algorithm: &Algorithm{
Hash: make([]byte, hashLength),
},
AgentConfig: &manager.AgentConfig{},
AgentConfig: &AgentConfig{},
},
vmStartError: assert.AnError,
expectedError: assert.AnError,
},
{
name: "Invalid algorithm hash",
req: &manager.ComputationRunReq{
req: &ComputationRunReq{
Id: "test-computation",
Name: "Test Computation",
Algorithm: &manager.Algorithm{
Algorithm: &Algorithm{
Hash: make([]byte, hashLength-1),
},
AgentConfig: &manager.AgentConfig{},
AgentConfig: &AgentConfig{},
},
vmStartError: nil,
expectedError: errInvalidHashLength,
},
{
name: "Invalid dataset hash",
req: &manager.ComputationRunReq{
req: &ComputationRunReq{
Id: "test-computation",
Name: "Test Computation",
Algorithm: &manager.Algorithm{
Algorithm: &Algorithm{
Hash: make([]byte, hashLength),
},
AgentConfig: &manager.AgentConfig{},
Datasets: []*manager.Dataset{
AgentConfig: &AgentConfig{},
Datasets: []*Dataset{
{
Hash: make([]byte, hashLength-1),
},
@@ -130,7 +129,7 @@ func TestRun(t *testing.T) {
},
}
logger := slog.Default()
eventsChan := make(chan *manager.ClientStreamMessage, 10)
eventsChan := make(chan *ClientStreamMessage, 10)
ms := &managerService{
qemuCfg: qemuCfg,
@@ -203,7 +202,7 @@ func TestStop(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
logger := slog.Default()
eventsChan := make(chan *manager.ClientStreamMessage, 10)
eventsChan := make(chan *ClientStreamMessage, 10)
ms := &managerService{
logger: logger,
vms: make(map[string]vm.VM),
@@ -281,7 +280,7 @@ func TestPublishEvent(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
eventsChan := make(chan *manager.ClientStreamMessage, 1)
eventsChan := make(chan *ClientStreamMessage, 1)
ms := &managerService{
eventsChan: eventsChan,
}
@@ -370,7 +369,7 @@ func TestRestoreVMs(t *testing.T) {
ms := &managerService{
persistence: mockPersistence,
vms: make(map[string]vm.VM),
eventsChan: make(chan *manager.ClientStreamMessage, 10),
eventsChan: make(chan *ClientStreamMessage, 10),
vmFactory: vmf.Execute,
logger: mglog.NewMock(),
}
+1 -1
View File
@@ -15,8 +15,8 @@ import (
"time"
mglog "github.com/absmach/magistrala/logger"
"github.com/ultravioletrs/cocos/manager"
managergrpc "github.com/ultravioletrs/cocos/manager/api/grpc"
"github.com/ultravioletrs/cocos/pkg/manager"
"golang.org/x/crypto/sha3"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials"
+5 -6
View File
@@ -6,7 +6,6 @@ import (
"context"
"github.com/ultravioletrs/cocos/manager"
pkgmanager "github.com/ultravioletrs/cocos/pkg/manager"
"go.opentelemetry.io/otel/trace"
)
@@ -22,7 +21,7 @@ func New(svc manager.Service, tracer trace.Tracer) manager.Service {
return &tracingMiddleware{tracer, svc}
}
func (tm *tracingMiddleware) Run(ctx context.Context, mc *pkgmanager.ComputationRunReq) (string, error) {
func (tm *tracingMiddleware) Run(ctx context.Context, mc *manager.ComputationRunReq) (string, error) {
ctx, span := tm.tracer.Start(ctx, "run")
defer span.End()
@@ -36,13 +35,13 @@ func (tm *tracingMiddleware) Stop(ctx context.Context, computationID string) err
return tm.svc.Stop(ctx, computationID)
}
func (tm *tracingMiddleware) RetrieveAgentEventsLogs() {
tm.svc.RetrieveAgentEventsLogs()
}
func (tm *tracingMiddleware) FetchBackendInfo(ctx context.Context, computationId string) ([]byte, error) {
_, span := tm.tracer.Start(ctx, "fetch_backend_info")
defer span.End()
return tm.svc.FetchBackendInfo(ctx, computationId)
}
func (tm *tracingMiddleware) ReportBrokenConnection(addr string) {
tm.svc.ReportBrokenConnection(addr)
}
+21 -58
View File
@@ -4,46 +4,26 @@ package vm
import (
"bytes"
"errors"
"io"
"log/slog"
"strings"
"github.com/ultravioletrs/cocos/pkg/manager"
pkgmanager "github.com/ultravioletrs/cocos/pkg/manager"
"google.golang.org/protobuf/types/known/timestamppb"
)
var (
_ io.Writer = &Stdout{}
_ io.Writer = &Stderr{}
ErrFailedToSendMessage = errors.New("failed to send message to channel")
ErrPanicRecovered = errors.New("panic recovered: channel may be closed")
_ io.Writer = &Stdout{}
_ io.Writer = &Stderr{}
)
const bufSize = 1024
type Stdout struct {
LogsChan chan *manager.ClientStreamMessage
EventSender EventSender
ComputationId string
}
// safeSend safely sends a message to the channel and returns an error on failure.
func safeSend(ch chan *manager.ClientStreamMessage, msg *manager.ClientStreamMessage) (err error) {
defer func() {
if r := recover(); r != nil {
// Recover from panic if the channel is closed
err = ErrPanicRecovered
}
}()
select {
case ch <- msg:
return nil
default:
// Channel is full or closed
return ErrFailedToSendMessage
}
}
// Write implements io.Writer.
func (s *Stdout) Write(p []byte) (n int, err error) {
inBuf := bytes.NewBuffer(p)
@@ -59,7 +39,7 @@ func (s *Stdout) Write(p []byte) (n int, err error) {
return len(p) - inBuf.Len(), err
}
if err := sendLog(s.LogsChan, s.ComputationId, string(buf[:n]), slog.LevelDebug.String()); err != nil {
if err := sendLog(s.EventSender, s.ComputationId, string(buf[:n]), slog.LevelDebug.String()); err != nil {
return len(p) - inBuf.Len(), err
}
}
@@ -68,7 +48,7 @@ func (s *Stdout) Write(p []byte) (n int, err error) {
}
type Stderr struct {
LogsChan chan *manager.ClientStreamMessage
EventSender EventSender
ComputationId string
StateMachine StateMachine
}
@@ -88,32 +68,23 @@ func (s *Stderr) Write(p []byte) (n int, err error) {
return len(p) - inBuf.Len(), err
}
if err := sendLog(s.LogsChan, s.ComputationId, string(buf[:n]), ""); err != nil {
if err := sendLog(s.EventSender, s.ComputationId, string(buf[:n]), ""); err != nil {
return len(p) - inBuf.Len(), err
}
}
// Ensure vm-provision failure message is sent
eventMsg := &manager.ClientStreamMessage{
Message: &manager.ClientStreamMessage_AgentEvent{
AgentEvent: &manager.AgentEvent{
ComputationId: s.ComputationId,
EventType: s.StateMachine.State(),
Timestamp: timestamppb.Now(),
Originator: "manager",
Status: manager.Warning.String(),
},
},
eventMsg := &Event{
ComputationId: s.ComputationId,
EventType: s.StateMachine.State(),
Timestamp: timestamppb.Now(),
Originator: "manager",
Status: pkgmanager.Warning.String(),
}
if err := safeSend(s.LogsChan, eventMsg); err != nil {
return len(p), err
}
return len(p), nil
return len(p), s.EventSender(eventMsg)
}
func sendLog(logsChan chan *manager.ClientStreamMessage, computationID, message, level string) error {
func sendLog(eventSender EventSender, computationID, message, level string) error {
if len(message) < 3 {
return nil
}
@@ -126,20 +97,12 @@ func sendLog(logsChan chan *manager.ClientStreamMessage, computationID, message,
}
}
msg := &manager.ClientStreamMessage{
Message: &manager.ClientStreamMessage_AgentLog{
AgentLog: &manager.AgentLog{
Message: message,
ComputationId: computationID,
Level: level,
Timestamp: timestamppb.Now(),
},
},
msg := Log{
Message: message,
ComputationId: computationID,
Level: level,
Timestamp: timestamppb.Now(),
}
if err := safeSend(logsChan, msg); err != nil {
return err
}
return nil
return eventSender(&msg)
}
+40 -36
View File
@@ -8,7 +8,7 @@ import (
"time"
"github.com/stretchr/testify/assert"
"github.com/ultravioletrs/cocos/pkg/manager"
pkgmanager "github.com/ultravioletrs/cocos/pkg/manager"
)
func TestStdoutWrite(t *testing.T) {
@@ -36,9 +36,12 @@ func TestStdoutWrite(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
logsChan := make(chan *manager.ClientStreamMessage, 10)
eventLogChan := make(chan interface{}, 10)
s := &Stdout{
LogsChan: logsChan,
EventSender: func(event interface{}) error {
eventLogChan <- event
return nil
},
ComputationId: "test-computation",
}
@@ -50,9 +53,9 @@ func TestStdoutWrite(t *testing.T) {
var receivedWrites int
for i := 0; i < tt.expectedWrites; i++ {
select {
case msg := <-logsChan:
case msg := <-eventLogChan:
receivedWrites++
agentLog := msg.GetAgentLog()
agentLog := msg.(*Log)
assert.NotNil(t, agentLog)
assert.Equal(t, "test-computation", agentLog.ComputationId)
assert.Equal(t, slog.LevelDebug.String(), agentLog.Level)
@@ -93,14 +96,17 @@ func TestStderrWrite(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
logsChan := make(chan *manager.ClientStreamMessage, 10)
eventLogChan := make(chan interface{}, 10)
s := &Stderr{
LogsChan: logsChan,
EventSender: func(event interface{}) error {
eventLogChan <- event
return nil
},
ComputationId: "test-computation",
StateMachine: NewStateMachine(),
}
err := s.StateMachine.Transition(manager.VmRunning)
err := s.StateMachine.Transition(pkgmanager.VmRunning)
assert.NoError(t, err)
n, err := s.Write([]byte(tt.input))
@@ -111,23 +117,21 @@ func TestStderrWrite(t *testing.T) {
var receivedWrites int
for i := 0; i < tt.expectedWrites; i++ {
select {
case msg := <-logsChan:
case msg := <-eventLogChan:
receivedWrites++
switch msg.Message.(type) {
case *manager.ClientStreamMessage_AgentLog:
agentLog := msg.GetAgentLog()
assert.NotNil(t, agentLog)
assert.Equal(t, "test-computation", agentLog.ComputationId)
assert.Equal(t, slog.LevelError.String(), agentLog.Level)
assert.NotEmpty(t, agentLog.Message)
assert.NotNil(t, agentLog.Timestamp)
case *manager.ClientStreamMessage_AgentEvent:
agentEvent := msg.GetAgentEvent()
assert.NotNil(t, agentEvent)
assert.Equal(t, "test-computation", agentEvent.ComputationId)
assert.Equal(t, manager.VmRunning.String(), agentEvent.EventType)
assert.Equal(t, manager.Warning.String(), agentEvent.Status)
assert.NotNil(t, agentEvent.Timestamp)
switch logEv := msg.(type) {
case *Log:
assert.NotNil(t, logEv)
assert.Equal(t, "test-computation", logEv.ComputationId)
assert.Equal(t, slog.LevelError.String(), logEv.Level)
assert.NotEmpty(t, logEv.Message)
assert.NotNil(t, logEv.Timestamp)
case *Event:
assert.NotNil(t, logEv)
assert.Equal(t, "test-computation", logEv.ComputationId)
assert.Equal(t, pkgmanager.VmRunning.String(), logEv.EventType)
assert.Equal(t, pkgmanager.Warning.String(), logEv.Status)
assert.NotNil(t, logEv.Timestamp)
}
case <-time.After(time.Second):
t.Fatal("Timed out waiting for log message")
@@ -140,37 +144,37 @@ func TestStderrWrite(t *testing.T) {
}
func TestStdoutWriteErrorHandling(t *testing.T) {
logsChan := make(chan *manager.ClientStreamMessage, 1)
eventLogChan := make(chan interface{}, 10)
s := &Stdout{
LogsChan: logsChan,
EventSender: func(event interface{}) error {
eventLogChan <- event
return assert.AnError
},
ComputationId: "test-computation",
}
// Test with a closed channel to simulate an error condition
close(logsChan)
message := []byte("This should fail")
n, err := s.Write(message)
assert.Error(t, err)
assert.Equal(t, len(message), n)
assert.Equal(t, ErrPanicRecovered, err)
assert.Equal(t, assert.AnError, err)
}
func TestStderrWriteErrorHandling(t *testing.T) {
logsChan := make(chan *manager.ClientStreamMessage, 1)
eventLogChan := make(chan interface{}, 10)
s := &Stderr{
LogsChan: logsChan,
EventSender: func(event interface{}) error {
eventLogChan <- event
return assert.AnError
},
ComputationId: "test-computation",
}
// Test with a closed channel to simulate an error condition
close(logsChan)
message := []byte("This should fail")
n, err := s.Write(message)
assert.Error(t, err)
assert.Equal(t, len(message), n)
assert.Equal(t, ErrPanicRecovered, err)
assert.Equal(t, assert.AnError, err)
}
+5 -7
View File
@@ -7,8 +7,6 @@ package mocks
import (
mock "github.com/stretchr/testify/mock"
manager "github.com/ultravioletrs/cocos/pkg/manager"
vm "github.com/ultravioletrs/cocos/manager/vm"
)
@@ -17,17 +15,17 @@ type Provider struct {
mock.Mock
}
// Execute provides a mock function with given fields: config, logsChan, computationId
func (_m *Provider) Execute(config interface{}, logsChan chan *manager.ClientStreamMessage, computationId string) vm.VM {
ret := _m.Called(config, logsChan, computationId)
// Execute provides a mock function with given fields: config, eventSender, computationId
func (_m *Provider) Execute(config interface{}, eventSender vm.EventSender, computationId string) vm.VM {
ret := _m.Called(config, eventSender, computationId)
if len(ret) == 0 {
panic("no return value specified for Execute")
}
var r0 vm.VM
if rf, ok := ret.Get(0).(func(interface{}, chan *manager.ClientStreamMessage, string) vm.VM); ok {
r0 = rf(config, logsChan, computationId)
if rf, ok := ret.Get(0).(func(interface{}, vm.EventSender, string) vm.VM); ok {
r0 = rf(config, eventSender, computationId)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(vm.VM)
+26 -3
View File
@@ -4,7 +4,8 @@ package vm
import (
"github.com/ultravioletrs/cocos/agent"
"github.com/ultravioletrs/cocos/pkg/manager"
pkgmanager "github.com/ultravioletrs/cocos/pkg/manager"
"google.golang.org/protobuf/types/known/timestamppb"
)
// VM represents a virtual machine.
@@ -17,10 +18,32 @@ type VM interface {
SetProcess(pid int) error
GetProcess() int
GetCID() int
Transition(newState manager.ManagerState) error
Transition(newState pkgmanager.ManagerState) error
State() string
GetConfig() interface{}
}
//go:generate mockery --name Provider --output=./mocks --filename provider.go --quiet --note "Copyright (c) Ultraviolet \n // SPDX-License-Identifier: Apache-2.0"
type Provider func(config interface{}, logsChan chan *manager.ClientStreamMessage, computationId string) VM
type Provider func(config interface{}, eventSender EventSender, computationId string) VM
type Event struct {
EventType string
Timestamp *timestamppb.Timestamp
ComputationId string
Details []byte
Originator string
Status string
}
type Log struct {
Message string
ComputationId string
Level string
Timestamp *timestamppb.Timestamp
}
func (l *Log) IsEventLog() bool {
return true
}
type EventSender func(event interface{}) error
+1 -1
View File
@@ -3,8 +3,8 @@
package manager
import (
"github.com/ultravioletrs/cocos/manager"
"github.com/ultravioletrs/cocos/pkg/clients/grpc"
"github.com/ultravioletrs/cocos/pkg/manager"
)
// NewManagerClient creates new manager gRPC client instance.
+1
View File
@@ -19,4 +19,5 @@ const (
Stopped
Warning
Disconnected
Failed
)
+3 -2
View File
@@ -12,11 +12,12 @@ func _() {
_ = x[Stopped-1]
_ = x[Warning-2]
_ = x[Disconnected-3]
_ = x[Failed-4]
}
const _ManagerStatus_name = "StartingStoppedWarningDisconnected"
const _ManagerStatus_name = "StartingStoppedWarningDisconnectedFailed"
var _ManagerStatus_index = [...]uint8{0, 8, 15, 22, 34}
var _ManagerStatus_index = [...]uint8{0, 8, 15, 22, 34, 40}
func (i ManagerStatus) String() string {
if i >= ManagerStatus(len(_ManagerStatus_index)-1) {
+1 -1
View File
@@ -16,8 +16,8 @@ import (
"github.com/ultravioletrs/cocos/internal"
"github.com/ultravioletrs/cocos/internal/server"
grpcserver "github.com/ultravioletrs/cocos/internal/server/grpc"
"github.com/ultravioletrs/cocos/manager"
managergrpc "github.com/ultravioletrs/cocos/manager/api/grpc"
"github.com/ultravioletrs/cocos/pkg/manager"
"golang.org/x/sync/errgroup"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials"
+3 -3
View File
@@ -18,12 +18,12 @@ import (
"github.com/ultravioletrs/cocos/internal"
internalvsock "github.com/ultravioletrs/cocos/internal/vsock"
"github.com/ultravioletrs/cocos/manager"
"github.com/ultravioletrs/cocos/manager/events"
"github.com/ultravioletrs/cocos/manager/qemu"
pkgmanager "github.com/ultravioletrs/cocos/pkg/manager"
)
const (
managerVsockPort = manager.ManagerVsockPort
managerVsockPort = events.ManagerVsockPort
vsockConfigPort = qemu.VsockConfigPort
)
@@ -115,7 +115,7 @@ func handleConnection(conn net.Conn) {
ackReader := internalvsock.NewAckReader(conn)
for {
var message pkgmanager.ClientStreamMessage
var message manager.ClientStreamMessage
err := ackReader.ReadProto(&message)
if err != nil {
log.Printf("Error reading message: %v", err)