mirror of
https://github.com/ultravioletrs/cocos.git
synced 2026-06-23 04:10:25 +00:00
NOISSUE - Refactor manager events and detangle service (#287)
* extract events service Signed-off-by: Sammy Oina <sammyoina@gmail.com> * major refactor and detangling Signed-off-by: Sammy Oina <sammyoina@gmail.com> * small fixes Signed-off-by: Sammy Oina <sammyoina@gmail.com> * handle tests better Signed-off-by: Sammy Oina <sammyoina@gmail.com> * fix lint Signed-off-by: Sammy Oina <sammyoina@gmail.com> * fix race condition Signed-off-by: Sammy Oina <sammyoina@gmail.com> * fix race Signed-off-by: Sammy Oina <sammyoina@gmail.com> * use plain interface Signed-off-by: Sammy Oina <sammyoina@gmail.com> * move mutex Signed-off-by: Sammy Oina <sammyoina@gmail.com> --------- Signed-off-by: Sammy Oina <sammyoina@gmail.com>
This commit is contained in:
committed by
GitHub
parent
20e7ea76e0
commit
fad3182638
@@ -35,7 +35,8 @@ $(BACKEND_INFO):
|
||||
|
||||
protoc:
|
||||
protoc -I. --go_out=. --go_opt=paths=source_relative --go-grpc_out=. --go-grpc_opt=paths=source_relative agent/agent.proto
|
||||
protoc -I. --go_out=./pkg --go_opt=paths=source_relative --go-grpc_out=./pkg --go-grpc_opt=paths=source_relative manager/manager.proto
|
||||
protoc -I. --go_out=. --go_opt=paths=source_relative --go-grpc_out=. --go-grpc_opt=paths=source_relative manager/manager.proto
|
||||
protoc -I. --go_out=. --go_opt=paths=source_relative --go-grpc_out=. --go-grpc_opt=paths=source_relative agent/events/events.proto
|
||||
|
||||
mocks:
|
||||
go generate ./...
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"os/exec"
|
||||
|
||||
"github.com/ultravioletrs/cocos/agent/algorithm"
|
||||
"github.com/ultravioletrs/cocos/agent/algorithm/logging"
|
||||
"github.com/ultravioletrs/cocos/agent/events"
|
||||
)
|
||||
|
||||
@@ -24,8 +25,8 @@ type binary struct {
|
||||
func NewAlgorithm(logger *slog.Logger, eventsSvc events.Service, algoFile string, args []string) algorithm.Algorithm {
|
||||
return &binary{
|
||||
algoFile: algoFile,
|
||||
stderr: &algorithm.Stderr{Logger: logger, EventSvc: eventsSvc},
|
||||
stdout: &algorithm.Stdout{Logger: logger},
|
||||
stderr: &logging.Stderr{Logger: logger, EventSvc: eventsSvc},
|
||||
stdout: &logging.Stdout{Logger: logger},
|
||||
args: args,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -8,7 +8,7 @@ import (
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/ultravioletrs/cocos/agent/algorithm"
|
||||
"github.com/ultravioletrs/cocos/agent/algorithm/logging"
|
||||
"github.com/ultravioletrs/cocos/agent/events/mocks"
|
||||
)
|
||||
|
||||
@@ -39,11 +39,11 @@ func TestNewAlgorithm(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
if _, ok := b.stderr.(*algorithm.Stderr); !ok {
|
||||
if _, ok := b.stderr.(*logging.Stderr); !ok {
|
||||
t.Errorf("Expected stderr to be *algorithm.Stderr")
|
||||
}
|
||||
|
||||
if _, ok := b.stdout.(*algorithm.Stdout); !ok {
|
||||
if _, ok := b.stdout.(*logging.Stdout); !ok {
|
||||
t.Errorf("Expected stdout to be *algorithm.Stdout")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -16,6 +16,7 @@ import (
|
||||
"github.com/docker/docker/api/types/mount"
|
||||
"github.com/docker/docker/client"
|
||||
"github.com/ultravioletrs/cocos/agent/algorithm"
|
||||
"github.com/ultravioletrs/cocos/agent/algorithm/logging"
|
||||
"github.com/ultravioletrs/cocos/agent/events"
|
||||
)
|
||||
|
||||
@@ -38,8 +39,8 @@ func NewAlgorithm(logger *slog.Logger, eventsSvc events.Service, algoFile string
|
||||
d := &docker{
|
||||
algoFile: algoFile,
|
||||
logger: logger,
|
||||
stderr: &algorithm.Stderr{Logger: logger, EventSvc: eventsSvc},
|
||||
stdout: &algorithm.Stdout{Logger: logger},
|
||||
stderr: &logging.Stderr{Logger: logger, EventSvc: eventsSvc},
|
||||
stdout: &logging.Stdout{Logger: logger},
|
||||
}
|
||||
|
||||
return d
|
||||
|
||||
@@ -8,7 +8,7 @@ import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/ultravioletrs/cocos/agent/algorithm"
|
||||
"github.com/ultravioletrs/cocos/agent/algorithm/logging"
|
||||
"github.com/ultravioletrs/cocos/agent/events/mocks"
|
||||
)
|
||||
|
||||
@@ -24,6 +24,6 @@ func TestNewAlgorithm(t *testing.T) {
|
||||
assert.True(t, ok, "NewAlgorithm should return a *docker")
|
||||
assert.Equal(t, algoFile, d.algoFile, "algoFile should be set correctly")
|
||||
assert.NotNil(t, d.logger, "logger should be set")
|
||||
assert.IsType(t, &algorithm.Stderr{}, d.stderr, "stderr should be of type *algorithm.Stderr")
|
||||
assert.IsType(t, &algorithm.Stdout{}, d.stdout, "stdout should be of type *algorithm.Stdout")
|
||||
assert.IsType(t, &logging.Stderr{}, d.stderr, "stderr should be of type *algorithm.Stderr")
|
||||
assert.IsType(t, &logging.Stdout{}, d.stdout, "stdout should be of type *algorithm.Stdout")
|
||||
}
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
// Copyright (c) Ultraviolet
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
package algorithm
|
||||
package logging
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
@@ -1,6 +1,6 @@
|
||||
// Copyright (c) Ultraviolet
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
package algorithm
|
||||
package logging
|
||||
|
||||
import (
|
||||
"strings"
|
||||
@@ -12,6 +12,7 @@ import (
|
||||
"path/filepath"
|
||||
|
||||
"github.com/ultravioletrs/cocos/agent/algorithm"
|
||||
"github.com/ultravioletrs/cocos/agent/algorithm/logging"
|
||||
"github.com/ultravioletrs/cocos/agent/events"
|
||||
"google.golang.org/grpc/metadata"
|
||||
)
|
||||
@@ -43,8 +44,8 @@ type python struct {
|
||||
func NewAlgorithm(logger *slog.Logger, eventsSvc events.Service, runtime, requirementsFile, algoFile string, args []string) algorithm.Algorithm {
|
||||
p := &python{
|
||||
algoFile: algoFile,
|
||||
stderr: &algorithm.Stderr{Logger: logger, EventSvc: eventsSvc},
|
||||
stdout: &algorithm.Stdout{Logger: logger},
|
||||
stderr: &logging.Stderr{Logger: logger, EventSvc: eventsSvc},
|
||||
stdout: &logging.Stdout{Logger: logger},
|
||||
requirementsFile: requirementsFile,
|
||||
args: args,
|
||||
}
|
||||
|
||||
@@ -12,7 +12,7 @@ import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/ultravioletrs/cocos/agent/algorithm"
|
||||
"github.com/ultravioletrs/cocos/agent/algorithm/logging"
|
||||
"github.com/ultravioletrs/cocos/agent/events/mocks"
|
||||
"google.golang.org/grpc/metadata"
|
||||
)
|
||||
@@ -90,8 +90,8 @@ func TestRun(t *testing.T) {
|
||||
|
||||
algo := &python{
|
||||
algoFile: scriptPath,
|
||||
stderr: io.MultiWriter(&stderr, &algorithm.Stderr{Logger: slog.Default(), EventSvc: eventsSvc}),
|
||||
stdout: io.MultiWriter(&stdout, &algorithm.Stdout{Logger: slog.Default()}),
|
||||
stderr: io.MultiWriter(&stderr, &logging.Stderr{Logger: slog.Default(), EventSvc: eventsSvc}),
|
||||
stdout: io.MultiWriter(&stdout, &logging.Stdout{Logger: slog.Default()}),
|
||||
runtime: "python3",
|
||||
}
|
||||
|
||||
@@ -132,8 +132,8 @@ func TestRunWithRequirements(t *testing.T) {
|
||||
algo := &python{
|
||||
algoFile: scriptPath,
|
||||
requirementsFile: requirementsPath,
|
||||
stderr: io.MultiWriter(&stderr, &algorithm.Stderr{Logger: slog.Default(), EventSvc: eventsSvc}),
|
||||
stdout: io.MultiWriter(&stdout, &algorithm.Stdout{Logger: slog.Default()}),
|
||||
stderr: io.MultiWriter(&stderr, &logging.Stderr{Logger: slog.Default(), EventSvc: eventsSvc}),
|
||||
stdout: io.MultiWriter(&stdout, &logging.Stdout{Logger: slog.Default()}),
|
||||
runtime: "python3",
|
||||
}
|
||||
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"os/exec"
|
||||
|
||||
"github.com/ultravioletrs/cocos/agent/algorithm"
|
||||
"github.com/ultravioletrs/cocos/agent/algorithm/logging"
|
||||
"github.com/ultravioletrs/cocos/agent/events"
|
||||
)
|
||||
|
||||
@@ -28,8 +29,8 @@ type wasm struct {
|
||||
func NewAlgorithm(logger *slog.Logger, eventsSvc events.Service, algoFile string, args []string) algorithm.Algorithm {
|
||||
return &wasm{
|
||||
algoFile: algoFile,
|
||||
stderr: &algorithm.Stderr{Logger: logger, EventSvc: eventsSvc},
|
||||
stdout: &algorithm.Stdout{Logger: logger},
|
||||
stderr: &logging.Stderr{Logger: logger, EventSvc: eventsSvc},
|
||||
stdout: &logging.Stdout{Logger: logger},
|
||||
args: args,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -8,7 +8,7 @@ import (
|
||||
"os/exec"
|
||||
"testing"
|
||||
|
||||
"github.com/ultravioletrs/cocos/agent/algorithm"
|
||||
"github.com/ultravioletrs/cocos/agent/algorithm/logging"
|
||||
"github.com/ultravioletrs/cocos/agent/events/mocks"
|
||||
)
|
||||
|
||||
@@ -33,12 +33,12 @@ func TestNewAlgorithm(t *testing.T) {
|
||||
t.Errorf("Expected %d args, got %d", len(args), len(w.args))
|
||||
}
|
||||
|
||||
_, ok = w.stderr.(*algorithm.Stderr)
|
||||
_, ok = w.stderr.(*logging.Stderr)
|
||||
if !ok {
|
||||
t.Errorf("Expected stderr to be *algorithm.Stderr")
|
||||
}
|
||||
|
||||
_, ok = w.stdout.(*algorithm.Stdout)
|
||||
_, ok = w.stdout.(*logging.Stdout)
|
||||
if !ok {
|
||||
t.Errorf("Expected stdout to be *algorithm.Stdout")
|
||||
}
|
||||
|
||||
+1
-11
@@ -8,7 +8,6 @@ import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/ultravioletrs/cocos/pkg/manager"
|
||||
"google.golang.org/protobuf/proto"
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
)
|
||||
@@ -24,15 +23,6 @@ type service struct {
|
||||
stopRetry chan struct{}
|
||||
}
|
||||
|
||||
type AgentEvent struct {
|
||||
EventType string `json:"event_type"`
|
||||
Timestamp time.Time `json:"timestamp"`
|
||||
ComputationID string `json:"computation_id,omitempty"`
|
||||
Details json.RawMessage `json:"details,omitempty"`
|
||||
Originator string `json:"originator"`
|
||||
Status string `json:"status,omitempty"`
|
||||
}
|
||||
|
||||
//go:generate mockery --name Service --output=./mocks --filename events.go --quiet --note "Copyright (c) Ultraviolet \n // SPDX-License-Identifier: Apache-2.0"
|
||||
type Service interface {
|
||||
SendEvent(event, status string, details json.RawMessage) error
|
||||
@@ -54,7 +44,7 @@ func New(svc, computationID string, conn io.Writer) (Service, error) {
|
||||
}
|
||||
|
||||
func (s *service) SendEvent(event, status string, details json.RawMessage) error {
|
||||
body := manager.ClientStreamMessage{Message: &manager.ClientStreamMessage_AgentEvent{AgentEvent: &manager.AgentEvent{
|
||||
body := EventsLogs{Message: &EventsLogs_AgentEvent{AgentEvent: &AgentEvent{
|
||||
EventType: event,
|
||||
Timestamp: timestamppb.Now(),
|
||||
ComputationId: s.computationID,
|
||||
|
||||
@@ -0,0 +1,405 @@
|
||||
// Copyright (c) Ultraviolet
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
// Code generated by protoc-gen-go. DO NOT EDIT.
|
||||
// versions:
|
||||
// protoc-gen-go v1.34.2
|
||||
// protoc v5.28.1
|
||||
// source: agent/events/events.proto
|
||||
|
||||
package events
|
||||
|
||||
import (
|
||||
protoreflect "google.golang.org/protobuf/reflect/protoreflect"
|
||||
protoimpl "google.golang.org/protobuf/runtime/protoimpl"
|
||||
timestamppb "google.golang.org/protobuf/types/known/timestamppb"
|
||||
reflect "reflect"
|
||||
sync "sync"
|
||||
)
|
||||
|
||||
const (
|
||||
// Verify that this generated code is sufficiently up-to-date.
|
||||
_ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion)
|
||||
// Verify that runtime/protoimpl is sufficiently up-to-date.
|
||||
_ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20)
|
||||
)
|
||||
|
||||
type AgentEvent struct {
|
||||
state protoimpl.MessageState
|
||||
sizeCache protoimpl.SizeCache
|
||||
unknownFields protoimpl.UnknownFields
|
||||
|
||||
EventType string `protobuf:"bytes,1,opt,name=event_type,json=eventType,proto3" json:"event_type,omitempty"`
|
||||
Timestamp *timestamppb.Timestamp `protobuf:"bytes,2,opt,name=timestamp,proto3" json:"timestamp,omitempty"`
|
||||
ComputationId string `protobuf:"bytes,3,opt,name=computation_id,json=computationId,proto3" json:"computation_id,omitempty"`
|
||||
Details []byte `protobuf:"bytes,4,opt,name=details,proto3" json:"details,omitempty"`
|
||||
Originator string `protobuf:"bytes,5,opt,name=originator,proto3" json:"originator,omitempty"`
|
||||
Status string `protobuf:"bytes,6,opt,name=status,proto3" json:"status,omitempty"`
|
||||
}
|
||||
|
||||
func (x *AgentEvent) Reset() {
|
||||
*x = AgentEvent{}
|
||||
if protoimpl.UnsafeEnabled {
|
||||
mi := &file_agent_events_events_proto_msgTypes[0]
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
}
|
||||
|
||||
func (x *AgentEvent) String() string {
|
||||
return protoimpl.X.MessageStringOf(x)
|
||||
}
|
||||
|
||||
func (*AgentEvent) ProtoMessage() {}
|
||||
|
||||
func (x *AgentEvent) ProtoReflect() protoreflect.Message {
|
||||
mi := &file_agent_events_events_proto_msgTypes[0]
|
||||
if protoimpl.UnsafeEnabled && x != nil {
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
if ms.LoadMessageInfo() == nil {
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
return ms
|
||||
}
|
||||
return mi.MessageOf(x)
|
||||
}
|
||||
|
||||
// Deprecated: Use AgentEvent.ProtoReflect.Descriptor instead.
|
||||
func (*AgentEvent) Descriptor() ([]byte, []int) {
|
||||
return file_agent_events_events_proto_rawDescGZIP(), []int{0}
|
||||
}
|
||||
|
||||
func (x *AgentEvent) GetEventType() string {
|
||||
if x != nil {
|
||||
return x.EventType
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (x *AgentEvent) GetTimestamp() *timestamppb.Timestamp {
|
||||
if x != nil {
|
||||
return x.Timestamp
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (x *AgentEvent) GetComputationId() string {
|
||||
if x != nil {
|
||||
return x.ComputationId
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (x *AgentEvent) GetDetails() []byte {
|
||||
if x != nil {
|
||||
return x.Details
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (x *AgentEvent) GetOriginator() string {
|
||||
if x != nil {
|
||||
return x.Originator
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (x *AgentEvent) GetStatus() string {
|
||||
if x != nil {
|
||||
return x.Status
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
type AgentLog struct {
|
||||
state protoimpl.MessageState
|
||||
sizeCache protoimpl.SizeCache
|
||||
unknownFields protoimpl.UnknownFields
|
||||
|
||||
Message string `protobuf:"bytes,1,opt,name=message,proto3" json:"message,omitempty"`
|
||||
ComputationId string `protobuf:"bytes,2,opt,name=computation_id,json=computationId,proto3" json:"computation_id,omitempty"`
|
||||
Level string `protobuf:"bytes,3,opt,name=level,proto3" json:"level,omitempty"`
|
||||
Timestamp *timestamppb.Timestamp `protobuf:"bytes,4,opt,name=timestamp,proto3" json:"timestamp,omitempty"`
|
||||
}
|
||||
|
||||
func (x *AgentLog) Reset() {
|
||||
*x = AgentLog{}
|
||||
if protoimpl.UnsafeEnabled {
|
||||
mi := &file_agent_events_events_proto_msgTypes[1]
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
}
|
||||
|
||||
func (x *AgentLog) String() string {
|
||||
return protoimpl.X.MessageStringOf(x)
|
||||
}
|
||||
|
||||
func (*AgentLog) ProtoMessage() {}
|
||||
|
||||
func (x *AgentLog) ProtoReflect() protoreflect.Message {
|
||||
mi := &file_agent_events_events_proto_msgTypes[1]
|
||||
if protoimpl.UnsafeEnabled && x != nil {
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
if ms.LoadMessageInfo() == nil {
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
return ms
|
||||
}
|
||||
return mi.MessageOf(x)
|
||||
}
|
||||
|
||||
// Deprecated: Use AgentLog.ProtoReflect.Descriptor instead.
|
||||
func (*AgentLog) Descriptor() ([]byte, []int) {
|
||||
return file_agent_events_events_proto_rawDescGZIP(), []int{1}
|
||||
}
|
||||
|
||||
func (x *AgentLog) GetMessage() string {
|
||||
if x != nil {
|
||||
return x.Message
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (x *AgentLog) GetComputationId() string {
|
||||
if x != nil {
|
||||
return x.ComputationId
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (x *AgentLog) GetLevel() string {
|
||||
if x != nil {
|
||||
return x.Level
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (x *AgentLog) GetTimestamp() *timestamppb.Timestamp {
|
||||
if x != nil {
|
||||
return x.Timestamp
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type EventsLogs struct {
|
||||
state protoimpl.MessageState
|
||||
sizeCache protoimpl.SizeCache
|
||||
unknownFields protoimpl.UnknownFields
|
||||
|
||||
// Types that are assignable to Message:
|
||||
//
|
||||
// *EventsLogs_AgentLog
|
||||
// *EventsLogs_AgentEvent
|
||||
Message isEventsLogs_Message `protobuf_oneof:"message"`
|
||||
}
|
||||
|
||||
func (x *EventsLogs) Reset() {
|
||||
*x = EventsLogs{}
|
||||
if protoimpl.UnsafeEnabled {
|
||||
mi := &file_agent_events_events_proto_msgTypes[2]
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
}
|
||||
|
||||
func (x *EventsLogs) String() string {
|
||||
return protoimpl.X.MessageStringOf(x)
|
||||
}
|
||||
|
||||
func (*EventsLogs) ProtoMessage() {}
|
||||
|
||||
func (x *EventsLogs) ProtoReflect() protoreflect.Message {
|
||||
mi := &file_agent_events_events_proto_msgTypes[2]
|
||||
if protoimpl.UnsafeEnabled && x != nil {
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
if ms.LoadMessageInfo() == nil {
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
return ms
|
||||
}
|
||||
return mi.MessageOf(x)
|
||||
}
|
||||
|
||||
// Deprecated: Use EventsLogs.ProtoReflect.Descriptor instead.
|
||||
func (*EventsLogs) Descriptor() ([]byte, []int) {
|
||||
return file_agent_events_events_proto_rawDescGZIP(), []int{2}
|
||||
}
|
||||
|
||||
func (m *EventsLogs) GetMessage() isEventsLogs_Message {
|
||||
if m != nil {
|
||||
return m.Message
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (x *EventsLogs) GetAgentLog() *AgentLog {
|
||||
if x, ok := x.GetMessage().(*EventsLogs_AgentLog); ok {
|
||||
return x.AgentLog
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (x *EventsLogs) GetAgentEvent() *AgentEvent {
|
||||
if x, ok := x.GetMessage().(*EventsLogs_AgentEvent); ok {
|
||||
return x.AgentEvent
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type isEventsLogs_Message interface {
|
||||
isEventsLogs_Message()
|
||||
}
|
||||
|
||||
type EventsLogs_AgentLog struct {
|
||||
AgentLog *AgentLog `protobuf:"bytes,1,opt,name=agent_log,json=agentLog,proto3,oneof"`
|
||||
}
|
||||
|
||||
type EventsLogs_AgentEvent struct {
|
||||
AgentEvent *AgentEvent `protobuf:"bytes,2,opt,name=agent_event,json=agentEvent,proto3,oneof"`
|
||||
}
|
||||
|
||||
func (*EventsLogs_AgentLog) isEventsLogs_Message() {}
|
||||
|
||||
func (*EventsLogs_AgentEvent) isEventsLogs_Message() {}
|
||||
|
||||
var File_agent_events_events_proto protoreflect.FileDescriptor
|
||||
|
||||
var file_agent_events_events_proto_rawDesc = []byte{
|
||||
0x0a, 0x19, 0x61, 0x67, 0x65, 0x6e, 0x74, 0x2f, 0x65, 0x76, 0x65, 0x6e, 0x74, 0x73, 0x2f, 0x65,
|
||||
0x76, 0x65, 0x6e, 0x74, 0x73, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x12, 0x06, 0x65, 0x76, 0x65,
|
||||
0x6e, 0x74, 0x73, 0x1a, 0x1f, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2f, 0x70, 0x72, 0x6f, 0x74,
|
||||
0x6f, 0x62, 0x75, 0x66, 0x2f, 0x74, 0x69, 0x6d, 0x65, 0x73, 0x74, 0x61, 0x6d, 0x70, 0x2e, 0x70,
|
||||
0x72, 0x6f, 0x74, 0x6f, 0x22, 0xde, 0x01, 0x0a, 0x0a, 0x41, 0x67, 0x65, 0x6e, 0x74, 0x45, 0x76,
|
||||
0x65, 0x6e, 0x74, 0x12, 0x1d, 0x0a, 0x0a, 0x65, 0x76, 0x65, 0x6e, 0x74, 0x5f, 0x74, 0x79, 0x70,
|
||||
0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x09, 0x65, 0x76, 0x65, 0x6e, 0x74, 0x54, 0x79,
|
||||
0x70, 0x65, 0x12, 0x38, 0x0a, 0x09, 0x74, 0x69, 0x6d, 0x65, 0x73, 0x74, 0x61, 0x6d, 0x70, 0x18,
|
||||
0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1a, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70,
|
||||
0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x54, 0x69, 0x6d, 0x65, 0x73, 0x74, 0x61, 0x6d,
|
||||
0x70, 0x52, 0x09, 0x74, 0x69, 0x6d, 0x65, 0x73, 0x74, 0x61, 0x6d, 0x70, 0x12, 0x25, 0x0a, 0x0e,
|
||||
0x63, 0x6f, 0x6d, 0x70, 0x75, 0x74, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x5f, 0x69, 0x64, 0x18, 0x03,
|
||||
0x20, 0x01, 0x28, 0x09, 0x52, 0x0d, 0x63, 0x6f, 0x6d, 0x70, 0x75, 0x74, 0x61, 0x74, 0x69, 0x6f,
|
||||
0x6e, 0x49, 0x64, 0x12, 0x18, 0x0a, 0x07, 0x64, 0x65, 0x74, 0x61, 0x69, 0x6c, 0x73, 0x18, 0x04,
|
||||
0x20, 0x01, 0x28, 0x0c, 0x52, 0x07, 0x64, 0x65, 0x74, 0x61, 0x69, 0x6c, 0x73, 0x12, 0x1e, 0x0a,
|
||||
0x0a, 0x6f, 0x72, 0x69, 0x67, 0x69, 0x6e, 0x61, 0x74, 0x6f, 0x72, 0x18, 0x05, 0x20, 0x01, 0x28,
|
||||
0x09, 0x52, 0x0a, 0x6f, 0x72, 0x69, 0x67, 0x69, 0x6e, 0x61, 0x74, 0x6f, 0x72, 0x12, 0x16, 0x0a,
|
||||
0x06, 0x73, 0x74, 0x61, 0x74, 0x75, 0x73, 0x18, 0x06, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x73,
|
||||
0x74, 0x61, 0x74, 0x75, 0x73, 0x22, 0x9b, 0x01, 0x0a, 0x08, 0x41, 0x67, 0x65, 0x6e, 0x74, 0x4c,
|
||||
0x6f, 0x67, 0x12, 0x18, 0x0a, 0x07, 0x6d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x18, 0x01, 0x20,
|
||||
0x01, 0x28, 0x09, 0x52, 0x07, 0x6d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x12, 0x25, 0x0a, 0x0e,
|
||||
0x63, 0x6f, 0x6d, 0x70, 0x75, 0x74, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x5f, 0x69, 0x64, 0x18, 0x02,
|
||||
0x20, 0x01, 0x28, 0x09, 0x52, 0x0d, 0x63, 0x6f, 0x6d, 0x70, 0x75, 0x74, 0x61, 0x74, 0x69, 0x6f,
|
||||
0x6e, 0x49, 0x64, 0x12, 0x14, 0x0a, 0x05, 0x6c, 0x65, 0x76, 0x65, 0x6c, 0x18, 0x03, 0x20, 0x01,
|
||||
0x28, 0x09, 0x52, 0x05, 0x6c, 0x65, 0x76, 0x65, 0x6c, 0x12, 0x38, 0x0a, 0x09, 0x74, 0x69, 0x6d,
|
||||
0x65, 0x73, 0x74, 0x61, 0x6d, 0x70, 0x18, 0x04, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1a, 0x2e, 0x67,
|
||||
0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x54,
|
||||
0x69, 0x6d, 0x65, 0x73, 0x74, 0x61, 0x6d, 0x70, 0x52, 0x09, 0x74, 0x69, 0x6d, 0x65, 0x73, 0x74,
|
||||
0x61, 0x6d, 0x70, 0x22, 0x7f, 0x0a, 0x0a, 0x45, 0x76, 0x65, 0x6e, 0x74, 0x73, 0x4c, 0x6f, 0x67,
|
||||
0x73, 0x12, 0x2f, 0x0a, 0x09, 0x61, 0x67, 0x65, 0x6e, 0x74, 0x5f, 0x6c, 0x6f, 0x67, 0x18, 0x01,
|
||||
0x20, 0x01, 0x28, 0x0b, 0x32, 0x10, 0x2e, 0x65, 0x76, 0x65, 0x6e, 0x74, 0x73, 0x2e, 0x41, 0x67,
|
||||
0x65, 0x6e, 0x74, 0x4c, 0x6f, 0x67, 0x48, 0x00, 0x52, 0x08, 0x61, 0x67, 0x65, 0x6e, 0x74, 0x4c,
|
||||
0x6f, 0x67, 0x12, 0x35, 0x0a, 0x0b, 0x61, 0x67, 0x65, 0x6e, 0x74, 0x5f, 0x65, 0x76, 0x65, 0x6e,
|
||||
0x74, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x12, 0x2e, 0x65, 0x76, 0x65, 0x6e, 0x74, 0x73,
|
||||
0x2e, 0x41, 0x67, 0x65, 0x6e, 0x74, 0x45, 0x76, 0x65, 0x6e, 0x74, 0x48, 0x00, 0x52, 0x0a, 0x61,
|
||||
0x67, 0x65, 0x6e, 0x74, 0x45, 0x76, 0x65, 0x6e, 0x74, 0x42, 0x09, 0x0a, 0x07, 0x6d, 0x65, 0x73,
|
||||
0x73, 0x61, 0x67, 0x65, 0x42, 0x0a, 0x5a, 0x08, 0x2e, 0x2f, 0x65, 0x76, 0x65, 0x6e, 0x74, 0x73,
|
||||
0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33,
|
||||
}
|
||||
|
||||
var (
|
||||
file_agent_events_events_proto_rawDescOnce sync.Once
|
||||
file_agent_events_events_proto_rawDescData = file_agent_events_events_proto_rawDesc
|
||||
)
|
||||
|
||||
func file_agent_events_events_proto_rawDescGZIP() []byte {
|
||||
file_agent_events_events_proto_rawDescOnce.Do(func() {
|
||||
file_agent_events_events_proto_rawDescData = protoimpl.X.CompressGZIP(file_agent_events_events_proto_rawDescData)
|
||||
})
|
||||
return file_agent_events_events_proto_rawDescData
|
||||
}
|
||||
|
||||
var file_agent_events_events_proto_msgTypes = make([]protoimpl.MessageInfo, 3)
|
||||
var file_agent_events_events_proto_goTypes = []any{
|
||||
(*AgentEvent)(nil), // 0: events.AgentEvent
|
||||
(*AgentLog)(nil), // 1: events.AgentLog
|
||||
(*EventsLogs)(nil), // 2: events.EventsLogs
|
||||
(*timestamppb.Timestamp)(nil), // 3: google.protobuf.Timestamp
|
||||
}
|
||||
var file_agent_events_events_proto_depIdxs = []int32{
|
||||
3, // 0: events.AgentEvent.timestamp:type_name -> google.protobuf.Timestamp
|
||||
3, // 1: events.AgentLog.timestamp:type_name -> google.protobuf.Timestamp
|
||||
1, // 2: events.EventsLogs.agent_log:type_name -> events.AgentLog
|
||||
0, // 3: events.EventsLogs.agent_event:type_name -> events.AgentEvent
|
||||
4, // [4:4] is the sub-list for method output_type
|
||||
4, // [4:4] is the sub-list for method input_type
|
||||
4, // [4:4] is the sub-list for extension type_name
|
||||
4, // [4:4] is the sub-list for extension extendee
|
||||
0, // [0:4] is the sub-list for field type_name
|
||||
}
|
||||
|
||||
func init() { file_agent_events_events_proto_init() }
|
||||
func file_agent_events_events_proto_init() {
|
||||
if File_agent_events_events_proto != nil {
|
||||
return
|
||||
}
|
||||
if !protoimpl.UnsafeEnabled {
|
||||
file_agent_events_events_proto_msgTypes[0].Exporter = func(v any, i int) any {
|
||||
switch v := v.(*AgentEvent); i {
|
||||
case 0:
|
||||
return &v.state
|
||||
case 1:
|
||||
return &v.sizeCache
|
||||
case 2:
|
||||
return &v.unknownFields
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
file_agent_events_events_proto_msgTypes[1].Exporter = func(v any, i int) any {
|
||||
switch v := v.(*AgentLog); i {
|
||||
case 0:
|
||||
return &v.state
|
||||
case 1:
|
||||
return &v.sizeCache
|
||||
case 2:
|
||||
return &v.unknownFields
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
file_agent_events_events_proto_msgTypes[2].Exporter = func(v any, i int) any {
|
||||
switch v := v.(*EventsLogs); i {
|
||||
case 0:
|
||||
return &v.state
|
||||
case 1:
|
||||
return &v.sizeCache
|
||||
case 2:
|
||||
return &v.unknownFields
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
file_agent_events_events_proto_msgTypes[2].OneofWrappers = []any{
|
||||
(*EventsLogs_AgentLog)(nil),
|
||||
(*EventsLogs_AgentEvent)(nil),
|
||||
}
|
||||
type x struct{}
|
||||
out := protoimpl.TypeBuilder{
|
||||
File: protoimpl.DescBuilder{
|
||||
GoPackagePath: reflect.TypeOf(x{}).PkgPath(),
|
||||
RawDescriptor: file_agent_events_events_proto_rawDesc,
|
||||
NumEnums: 0,
|
||||
NumMessages: 3,
|
||||
NumExtensions: 0,
|
||||
NumServices: 0,
|
||||
},
|
||||
GoTypes: file_agent_events_events_proto_goTypes,
|
||||
DependencyIndexes: file_agent_events_events_proto_depIdxs,
|
||||
MessageInfos: file_agent_events_events_proto_msgTypes,
|
||||
}.Build()
|
||||
File_agent_events_events_proto = out.File
|
||||
file_agent_events_events_proto_rawDesc = nil
|
||||
file_agent_events_events_proto_goTypes = nil
|
||||
file_agent_events_events_proto_depIdxs = nil
|
||||
}
|
||||
@@ -0,0 +1,33 @@
|
||||
// Copyright (c) Ultraviolet
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
syntax = "proto3";
|
||||
|
||||
package events;
|
||||
|
||||
import "google/protobuf/timestamp.proto";
|
||||
|
||||
option go_package = "./events";
|
||||
|
||||
message AgentEvent {
|
||||
string event_type = 1;
|
||||
google.protobuf.Timestamp timestamp = 2;
|
||||
string computation_id = 3;
|
||||
bytes details = 4;
|
||||
string originator = 5;
|
||||
string status = 6;
|
||||
}
|
||||
|
||||
message AgentLog {
|
||||
string message = 1;
|
||||
string computation_id = 2;
|
||||
string level = 3;
|
||||
google.protobuf.Timestamp timestamp = 4;
|
||||
}
|
||||
|
||||
message EventsLogs {
|
||||
oneof message {
|
||||
AgentLog agent_log = 1;
|
||||
AgentEvent agent_event = 2;
|
||||
}
|
||||
}
|
||||
@@ -10,7 +10,6 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/ultravioletrs/cocos/pkg/manager"
|
||||
"google.golang.org/protobuf/proto"
|
||||
)
|
||||
|
||||
@@ -37,7 +36,7 @@ func TestSendEventSuccess(t *testing.T) {
|
||||
err = svc.SendEvent("test_event", "success", details)
|
||||
assert.NoError(t, err)
|
||||
|
||||
var writtenMessage manager.ClientStreamMessage
|
||||
var writtenMessage EventsLogs
|
||||
err = proto.Unmarshal(mockConnection.buf.Bytes(), &writtenMessage)
|
||||
assert.NoError(t, err)
|
||||
|
||||
|
||||
@@ -12,6 +12,7 @@ import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"slices"
|
||||
sync "sync"
|
||||
|
||||
"github.com/google/go-sev-guest/client"
|
||||
"github.com/ultravioletrs/cocos/agent/algorithm"
|
||||
@@ -111,6 +112,7 @@ type Service interface {
|
||||
}
|
||||
|
||||
type agentService struct {
|
||||
mu sync.Mutex
|
||||
computation Computation // Holds the current computation request details.
|
||||
algorithm algorithm.Algorithm // Filepath to the algorithm received for the computation.
|
||||
result []byte // Stores the result of the computation.
|
||||
@@ -181,6 +183,8 @@ func (as *agentService) Algo(ctx context.Context, algo Algorithm) error {
|
||||
if as.sm.GetState() != ReceivingAlgorithm {
|
||||
return ErrStateNotReady
|
||||
}
|
||||
as.mu.Lock()
|
||||
defer as.mu.Unlock()
|
||||
if as.algorithm != nil {
|
||||
return ErrAllManifestItemsReceived
|
||||
}
|
||||
@@ -262,6 +266,8 @@ func (as *agentService) Data(ctx context.Context, dataset Dataset) error {
|
||||
if as.sm.GetState() != ReceivingData {
|
||||
return ErrStateNotReady
|
||||
}
|
||||
as.mu.Lock()
|
||||
defer as.mu.Unlock()
|
||||
if len(as.computation.Datasets) == 0 {
|
||||
return ErrAllManifestItemsReceived
|
||||
}
|
||||
@@ -322,6 +328,8 @@ func (as *agentService) Result(ctx context.Context) ([]byte, error) {
|
||||
return []byte{}, ErrUndeclaredConsumer
|
||||
}
|
||||
|
||||
as.mu.Lock()
|
||||
defer as.mu.Unlock()
|
||||
if index < 0 || index >= len(as.computation.ResultConsumers) {
|
||||
return []byte{}, ErrUndeclaredConsumer
|
||||
}
|
||||
|
||||
+2
-2
@@ -29,7 +29,7 @@ import (
|
||||
"github.com/ultravioletrs/cocos/internal/server"
|
||||
grpcserver "github.com/ultravioletrs/cocos/internal/server/grpc"
|
||||
ackvsock "github.com/ultravioletrs/cocos/internal/vsock"
|
||||
"github.com/ultravioletrs/cocos/manager"
|
||||
managerevents "github.com/ultravioletrs/cocos/manager/events"
|
||||
"github.com/ultravioletrs/cocos/manager/qemu"
|
||||
"golang.org/x/crypto/sha3"
|
||||
"golang.org/x/sync/errgroup"
|
||||
@@ -213,7 +213,7 @@ func dialVsock() (*vsock.Conn, error) {
|
||||
var err error
|
||||
|
||||
err = backoff.Retry(func() error {
|
||||
conn, err = vsock.Dial(vsock.Host, manager.ManagerVsockPort, nil)
|
||||
conn, err = vsock.Dial(vsock.Host, managerevents.ManagerVsockPort, nil)
|
||||
if err == nil {
|
||||
log.Println("vsock connection established")
|
||||
return nil
|
||||
|
||||
+12
-4
@@ -22,11 +22,11 @@ import (
|
||||
"github.com/ultravioletrs/cocos/manager"
|
||||
"github.com/ultravioletrs/cocos/manager/api"
|
||||
managerapi "github.com/ultravioletrs/cocos/manager/api/grpc"
|
||||
"github.com/ultravioletrs/cocos/manager/events"
|
||||
"github.com/ultravioletrs/cocos/manager/qemu"
|
||||
"github.com/ultravioletrs/cocos/manager/tracing"
|
||||
"github.com/ultravioletrs/cocos/pkg/clients/grpc"
|
||||
managergrpc "github.com/ultravioletrs/cocos/pkg/clients/grpc/manager"
|
||||
pkgmanager "github.com/ultravioletrs/cocos/pkg/manager"
|
||||
"go.opentelemetry.io/otel/trace"
|
||||
"golang.org/x/sync/errgroup"
|
||||
)
|
||||
@@ -113,7 +113,7 @@ func main() {
|
||||
return
|
||||
}
|
||||
|
||||
eventsChan := make(chan *pkgmanager.ClientStreamMessage, clientBufferSize)
|
||||
eventsChan := make(chan *manager.ClientStreamMessage, clientBufferSize)
|
||||
svc, err := newService(logger, tracer, qemuCfg, eventsChan, cfg.BackendMeasurementBinary)
|
||||
if err != nil {
|
||||
logger.Error(err.Error())
|
||||
@@ -121,6 +121,15 @@ func main() {
|
||||
return
|
||||
}
|
||||
|
||||
eventsSvc := events.New(logger, svc.ReportBrokenConnection, eventsChan)
|
||||
if eventsSvc == nil {
|
||||
logger.Error("Failed to create events service")
|
||||
exitCode = 1
|
||||
return
|
||||
}
|
||||
|
||||
go eventsSvc.Listen(ctx)
|
||||
|
||||
mc := managerapi.NewClient(pc, svc, eventsChan, logger)
|
||||
|
||||
g.Go(func() error {
|
||||
@@ -147,12 +156,11 @@ func main() {
|
||||
}
|
||||
}
|
||||
|
||||
func newService(logger *slog.Logger, tracer trace.Tracer, qemuCfg qemu.Config, eventsChan chan *pkgmanager.ClientStreamMessage, backendMeasurementPath string) (manager.Service, error) {
|
||||
func newService(logger *slog.Logger, tracer trace.Tracer, qemuCfg qemu.Config, eventsChan chan *manager.ClientStreamMessage, backendMeasurementPath string) (manager.Service, error) {
|
||||
svc, err := manager.New(qemuCfg, backendMeasurementPath, logger, eventsChan, qemu.NewVM)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
go svc.RetrieveAgentEventsLogs()
|
||||
svc = api.LoggingMiddleware(svc, logger)
|
||||
counter, latency := prometheus.MakeMetrics(svcName, "api")
|
||||
svc = api.MetricsMiddleware(svc, counter, latency)
|
||||
|
||||
@@ -9,7 +9,7 @@ import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/ultravioletrs/cocos/pkg/manager"
|
||||
"github.com/ultravioletrs/cocos/agent/events"
|
||||
"google.golang.org/protobuf/proto"
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
)
|
||||
@@ -71,9 +71,9 @@ func (h *handler) Handle(_ context.Context, r slog.Record) error {
|
||||
|
||||
chunk := message[start:end]
|
||||
|
||||
agentLog := manager.ClientStreamMessage{
|
||||
Message: &manager.ClientStreamMessage_AgentLog{
|
||||
AgentLog: &manager.AgentLog{
|
||||
agentLog := events.EventsLogs{
|
||||
Message: &events.EventsLogs_AgentLog{
|
||||
AgentLog: &events.AgentLog{
|
||||
Timestamp: timestamp,
|
||||
Message: chunk,
|
||||
Level: level,
|
||||
|
||||
@@ -104,18 +104,15 @@ func (aw *AckWriter) sendMessages() {
|
||||
}
|
||||
|
||||
func (aw *AckWriter) writeMessage(messageID uint32, p []byte) error {
|
||||
// Write message ID
|
||||
if err := binary.Write(aw.conn, binary.LittleEndian, messageID); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Write message length
|
||||
messageLen := uint32(len(p))
|
||||
if err := binary.Write(aw.conn, binary.LittleEndian, messageLen); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Write message content
|
||||
if _, err := aw.conn.Write(p); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -13,7 +13,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/ultravioletrs/cocos/pkg/manager"
|
||||
"github.com/ultravioletrs/cocos/manager"
|
||||
"google.golang.org/protobuf/proto"
|
||||
)
|
||||
|
||||
|
||||
+12
-82
@@ -4,97 +4,18 @@ package manager
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"net"
|
||||
"regexp"
|
||||
"strconv"
|
||||
|
||||
"github.com/mdlayher/vsock"
|
||||
internalvsock "github.com/ultravioletrs/cocos/internal/vsock"
|
||||
"github.com/ultravioletrs/cocos/pkg/manager"
|
||||
"google.golang.org/protobuf/proto"
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
)
|
||||
|
||||
const (
|
||||
ManagerVsockPort = 9997
|
||||
messageSize int = 1024
|
||||
)
|
||||
|
||||
var (
|
||||
errFailedToParseCID = fmt.Errorf("failed to parse computation ID")
|
||||
errComputationNotFound = fmt.Errorf("computation not found")
|
||||
)
|
||||
|
||||
// RetrieveAgentEventsLogs Retrieve and forward agent logs and events via vsock.
|
||||
func (ms *managerService) RetrieveAgentEventsLogs() {
|
||||
l, err := vsock.Listen(ManagerVsockPort, nil)
|
||||
if err != nil {
|
||||
ms.logger.Warn(err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
for {
|
||||
conn, err := l.Accept()
|
||||
if err != nil {
|
||||
ms.logger.Warn(err.Error())
|
||||
continue
|
||||
}
|
||||
|
||||
go ms.handleConnection(conn)
|
||||
}
|
||||
}
|
||||
|
||||
func (ms *managerService) handleConnection(conn net.Conn) {
|
||||
defer conn.Close()
|
||||
|
||||
cmpID, err := ms.computationIDFromAddress(conn.RemoteAddr().String())
|
||||
if err != nil {
|
||||
ms.logger.Warn(err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
ackReader := internalvsock.NewAckReader(conn)
|
||||
|
||||
for {
|
||||
var message manager.ClientStreamMessage
|
||||
data, err := ackReader.Read()
|
||||
if err != nil {
|
||||
go ms.reportBrokenConnection(cmpID)
|
||||
ms.logger.Warn(err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
if err := proto.Unmarshal(data, &message); err != nil {
|
||||
ms.logger.Warn(err.Error())
|
||||
continue
|
||||
}
|
||||
|
||||
ms.eventsChan <- &message
|
||||
|
||||
args := []any{}
|
||||
|
||||
switch message.Message.(type) {
|
||||
case *manager.ClientStreamMessage_AgentEvent:
|
||||
args = append(args, slog.Group("agent-event",
|
||||
slog.String("event-type", message.GetAgentEvent().GetEventType()),
|
||||
slog.String("computation-id", message.GetAgentEvent().GetComputationId()),
|
||||
slog.String("status", message.GetAgentEvent().GetStatus()),
|
||||
slog.String("originator", message.GetAgentEvent().GetOriginator()),
|
||||
slog.String("timestamp", message.GetAgentEvent().GetTimestamp().String()),
|
||||
slog.String("details", string(message.GetAgentEvent().GetDetails()))))
|
||||
case *manager.ClientStreamMessage_AgentLog:
|
||||
args = append(args, slog.Group("agent-log",
|
||||
slog.String("computation-id", message.GetAgentLog().GetComputationId()),
|
||||
slog.String("level", message.GetAgentLog().GetLevel()),
|
||||
slog.String("timestamp", message.GetAgentLog().GetTimestamp().String()),
|
||||
slog.String("message", message.GetAgentLog().GetMessage())))
|
||||
}
|
||||
|
||||
ms.logger.Info("", args...)
|
||||
}
|
||||
}
|
||||
|
||||
func (ms *managerService) computationIDFromAddress(address string) (string, error) {
|
||||
re := regexp.MustCompile(`vm\((\d+)\)`)
|
||||
matches := re.FindStringSubmatch(address)
|
||||
@@ -122,9 +43,9 @@ func (ms *managerService) findComputationID(cid int) (string, error) {
|
||||
}
|
||||
|
||||
func (ms *managerService) reportBrokenConnection(cmpID string) {
|
||||
ms.eventsChan <- &manager.ClientStreamMessage{
|
||||
Message: &manager.ClientStreamMessage_AgentEvent{
|
||||
AgentEvent: &manager.AgentEvent{
|
||||
ms.eventsChan <- &ClientStreamMessage{
|
||||
Message: &ClientStreamMessage_AgentEvent{
|
||||
AgentEvent: &AgentEvent{
|
||||
EventType: ms.vms[cmpID].State(),
|
||||
ComputationId: cmpID,
|
||||
Status: manager.Disconnected.String(),
|
||||
@@ -134,3 +55,12 @@ func (ms *managerService) reportBrokenConnection(cmpID string) {
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (ms *managerService) ReportBrokenConnection(addr string) {
|
||||
cmpID, err := ms.computationIDFromAddress(addr)
|
||||
if err != nil {
|
||||
ms.logger.Warn(err.Error())
|
||||
return
|
||||
}
|
||||
ms.reportBrokenConnection(cmpID)
|
||||
}
|
||||
|
||||
@@ -3,83 +3,19 @@
|
||||
package manager
|
||||
|
||||
import (
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
mglog "github.com/absmach/magistrala/logger"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/mock"
|
||||
"github.com/ultravioletrs/cocos/manager/qemu"
|
||||
"github.com/ultravioletrs/cocos/manager/vm"
|
||||
"github.com/ultravioletrs/cocos/pkg/manager"
|
||||
"google.golang.org/protobuf/proto"
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
)
|
||||
|
||||
type MockConn struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
func (m *MockConn) Read(b []byte) (n int, err error) {
|
||||
args := m.Called(b)
|
||||
return args.Int(0), args.Error(1)
|
||||
}
|
||||
|
||||
func (m *MockConn) Write(b []byte) (n int, err error) {
|
||||
args := m.Called(b)
|
||||
return args.Int(0), args.Error(1)
|
||||
}
|
||||
|
||||
func (m *MockConn) Close() error {
|
||||
args := m.Called()
|
||||
return args.Error(0)
|
||||
}
|
||||
|
||||
func (m *MockConn) LocalAddr() net.Addr {
|
||||
args := m.Called()
|
||||
return args.Get(0).(net.Addr)
|
||||
}
|
||||
|
||||
func (m *MockConn) RemoteAddr() net.Addr {
|
||||
args := m.Called()
|
||||
return args.Get(0).(net.Addr)
|
||||
}
|
||||
|
||||
func (m *MockConn) SetDeadline(t time.Time) error {
|
||||
args := m.Called(t)
|
||||
return args.Error(0)
|
||||
}
|
||||
|
||||
func (m *MockConn) SetReadDeadline(t time.Time) error {
|
||||
args := m.Called(t)
|
||||
return args.Error(0)
|
||||
}
|
||||
|
||||
func (m *MockConn) SetWriteDeadline(t time.Time) error {
|
||||
args := m.Called(t)
|
||||
return args.Error(0)
|
||||
}
|
||||
|
||||
type MockAddr struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
func (m *MockAddr) Network() string {
|
||||
args := m.Called()
|
||||
return args.String(0)
|
||||
}
|
||||
|
||||
func (m *MockAddr) String() string {
|
||||
args := m.Called()
|
||||
return args.String(0)
|
||||
}
|
||||
|
||||
func TestComputationIDFromAddress(t *testing.T) {
|
||||
ms := &managerService{
|
||||
vms: map[string]vm.VM{
|
||||
"comp1": qemu.NewVM(qemu.Config{VSockConfig: qemu.VSockConfig{GuestCID: 3}}, make(chan *manager.ClientStreamMessage), "comp1"),
|
||||
"comp2": qemu.NewVM(qemu.Config{VSockConfig: qemu.VSockConfig{GuestCID: 5}}, make(chan *manager.ClientStreamMessage), "comp2"),
|
||||
"comp1": qemu.NewVM(qemu.Config{VSockConfig: qemu.VSockConfig{GuestCID: 3}}, func(event interface{}) error { return nil }, "comp1"),
|
||||
"comp2": qemu.NewVM(qemu.Config{VSockConfig: qemu.VSockConfig{GuestCID: 5}}, func(event interface{}) error { return nil }, "comp2"),
|
||||
},
|
||||
}
|
||||
|
||||
@@ -107,52 +43,11 @@ func TestComputationIDFromAddress(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleConnection(t *testing.T) {
|
||||
ms := &managerService{
|
||||
vms: map[string]vm.VM{
|
||||
"comp1": qemu.NewVM(qemu.Config{VSockConfig: qemu.VSockConfig{GuestCID: 3}}, make(chan *manager.ClientStreamMessage), "comp1"),
|
||||
},
|
||||
eventsChan: make(chan *manager.ClientStreamMessage, 1),
|
||||
logger: mglog.NewMock(),
|
||||
}
|
||||
|
||||
mockConn := new(MockConn)
|
||||
mockAddr := new(MockAddr)
|
||||
mockConn.On("RemoteAddr").Return(mockAddr)
|
||||
mockConn.On("Close").Return(nil)
|
||||
mockAddr.On("String").Return("vm(3)")
|
||||
|
||||
msg := &manager.ClientStreamMessage{
|
||||
Message: &manager.ClientStreamMessage_AgentEvent{
|
||||
AgentEvent: &manager.AgentEvent{
|
||||
EventType: manager.VmProvision.String(),
|
||||
ComputationId: "comp1",
|
||||
Status: manager.VmProvision.String(),
|
||||
Timestamp: timestamppb.Now(),
|
||||
Originator: "agent",
|
||||
},
|
||||
},
|
||||
}
|
||||
msgBytes, _ := proto.Marshal(msg)
|
||||
|
||||
mockConn.On("Read", mock.Anything).Return(len(msgBytes), nil).Run(func(args mock.Arguments) {
|
||||
copy(args.Get(0).([]byte), msgBytes)
|
||||
}).Once()
|
||||
|
||||
mockConn.On("Read", mock.Anything).Return(0, net.ErrClosed)
|
||||
|
||||
go ms.handleConnection(mockConn)
|
||||
|
||||
receivedMsg := <-ms.eventsChan
|
||||
assert.Equal(t, msg.GetAgentEvent().EventType, receivedMsg.GetAgentEvent().EventType)
|
||||
assert.Equal(t, msg.GetAgentEvent().ComputationId, receivedMsg.GetAgentEvent().ComputationId)
|
||||
}
|
||||
|
||||
func TestReportBrokenConnection(t *testing.T) {
|
||||
ms := &managerService{
|
||||
eventsChan: make(chan *manager.ClientStreamMessage, 1),
|
||||
eventsChan: make(chan *ClientStreamMessage, 1),
|
||||
vms: map[string]vm.VM{
|
||||
"comp1": qemu.NewVM(qemu.Config{VSockConfig: qemu.VSockConfig{GuestCID: 3}}, make(chan *manager.ClientStreamMessage), "comp1"),
|
||||
"comp1": qemu.NewVM(qemu.Config{VSockConfig: qemu.VSockConfig{GuestCID: 3}}, func(event interface{}) error { return nil }, "comp1"),
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
+24
-25
@@ -10,7 +10,6 @@ import (
|
||||
|
||||
"github.com/absmach/magistrala/pkg/errors"
|
||||
"github.com/ultravioletrs/cocos/manager"
|
||||
pkgmanager "github.com/ultravioletrs/cocos/pkg/manager"
|
||||
"golang.org/x/sync/errgroup"
|
||||
"google.golang.org/protobuf/proto"
|
||||
)
|
||||
@@ -22,15 +21,15 @@ var (
|
||||
)
|
||||
|
||||
type ManagerClient struct {
|
||||
stream pkgmanager.ManagerService_ProcessClient
|
||||
stream manager.ManagerService_ProcessClient
|
||||
svc manager.Service
|
||||
messageQueue chan *pkgmanager.ClientStreamMessage
|
||||
messageQueue chan *manager.ClientStreamMessage
|
||||
logger *slog.Logger
|
||||
runReqManager *runRequestManager
|
||||
}
|
||||
|
||||
// NewClient returns new gRPC client instance.
|
||||
func NewClient(stream pkgmanager.ManagerService_ProcessClient, svc manager.Service, messageQueue chan *pkgmanager.ClientStreamMessage, logger *slog.Logger) ManagerClient {
|
||||
func NewClient(stream manager.ManagerService_ProcessClient, svc manager.Service, messageQueue chan *manager.ClientStreamMessage, logger *slog.Logger) ManagerClient {
|
||||
return ManagerClient{
|
||||
stream: stream,
|
||||
svc: svc,
|
||||
@@ -71,15 +70,15 @@ func (client ManagerClient) handleIncomingMessages(ctx context.Context) error {
|
||||
}
|
||||
}
|
||||
|
||||
func (client ManagerClient) processIncomingMessage(ctx context.Context, req *pkgmanager.ServerStreamMessage) error {
|
||||
func (client ManagerClient) processIncomingMessage(ctx context.Context, req *manager.ServerStreamMessage) error {
|
||||
switch mes := req.Message.(type) {
|
||||
case *pkgmanager.ServerStreamMessage_RunReqChunks:
|
||||
case *manager.ServerStreamMessage_RunReqChunks:
|
||||
return client.handleRunReqChunks(ctx, mes)
|
||||
case *pkgmanager.ServerStreamMessage_TerminateReq:
|
||||
case *manager.ServerStreamMessage_TerminateReq:
|
||||
return client.handleTerminateReq(mes)
|
||||
case *pkgmanager.ServerStreamMessage_StopComputation:
|
||||
case *manager.ServerStreamMessage_StopComputation:
|
||||
go client.handleStopComputation(ctx, mes)
|
||||
case *pkgmanager.ServerStreamMessage_BackendInfoReq:
|
||||
case *manager.ServerStreamMessage_BackendInfoReq:
|
||||
go client.handleBackendInfoReq(ctx, mes)
|
||||
default:
|
||||
return errors.New("unknown message type")
|
||||
@@ -87,11 +86,11 @@ func (client ManagerClient) processIncomingMessage(ctx context.Context, req *pkg
|
||||
return nil
|
||||
}
|
||||
|
||||
func (client *ManagerClient) handleRunReqChunks(ctx context.Context, mes *pkgmanager.ServerStreamMessage_RunReqChunks) error {
|
||||
func (client *ManagerClient) handleRunReqChunks(ctx context.Context, mes *manager.ServerStreamMessage_RunReqChunks) error {
|
||||
buffer, complete := client.runReqManager.addChunk(mes.RunReqChunks.Id, mes.RunReqChunks.Data, mes.RunReqChunks.IsLast)
|
||||
|
||||
if complete {
|
||||
var runReq pkgmanager.ComputationRunReq
|
||||
var runReq manager.ComputationRunReq
|
||||
if err := proto.Unmarshal(buffer, &runReq); err != nil {
|
||||
return errors.Wrap(err, errCorruptedManifest)
|
||||
}
|
||||
@@ -102,50 +101,50 @@ func (client *ManagerClient) handleRunReqChunks(ctx context.Context, mes *pkgman
|
||||
return nil
|
||||
}
|
||||
|
||||
func (client ManagerClient) executeRun(ctx context.Context, runReq *pkgmanager.ComputationRunReq) {
|
||||
func (client ManagerClient) executeRun(ctx context.Context, runReq *manager.ComputationRunReq) {
|
||||
port, err := client.svc.Run(ctx, runReq)
|
||||
if err != nil {
|
||||
client.logger.Warn(err.Error())
|
||||
return
|
||||
}
|
||||
runRes := &pkgmanager.ClientStreamMessage_RunRes{
|
||||
RunRes: &pkgmanager.RunResponse{
|
||||
runRes := &manager.ClientStreamMessage_RunRes{
|
||||
RunRes: &manager.RunResponse{
|
||||
AgentPort: port,
|
||||
ComputationId: runReq.Id,
|
||||
},
|
||||
}
|
||||
client.sendMessage(&pkgmanager.ClientStreamMessage{Message: runRes})
|
||||
client.sendMessage(&manager.ClientStreamMessage{Message: runRes})
|
||||
}
|
||||
|
||||
func (client ManagerClient) handleTerminateReq(mes *pkgmanager.ServerStreamMessage_TerminateReq) error {
|
||||
func (client ManagerClient) handleTerminateReq(mes *manager.ServerStreamMessage_TerminateReq) error {
|
||||
return errors.Wrap(errTerminationFromServer, errors.New(mes.TerminateReq.Message))
|
||||
}
|
||||
|
||||
func (client ManagerClient) handleStopComputation(ctx context.Context, mes *pkgmanager.ServerStreamMessage_StopComputation) {
|
||||
msg := &pkgmanager.ClientStreamMessage_StopComputationRes{
|
||||
StopComputationRes: &pkgmanager.StopComputationResponse{
|
||||
func (client ManagerClient) handleStopComputation(ctx context.Context, mes *manager.ServerStreamMessage_StopComputation) {
|
||||
msg := &manager.ClientStreamMessage_StopComputationRes{
|
||||
StopComputationRes: &manager.StopComputationResponse{
|
||||
ComputationId: mes.StopComputation.ComputationId,
|
||||
},
|
||||
}
|
||||
if err := client.svc.Stop(ctx, mes.StopComputation.ComputationId); err != nil {
|
||||
msg.StopComputationRes.Message = err.Error()
|
||||
}
|
||||
client.sendMessage(&pkgmanager.ClientStreamMessage{Message: msg})
|
||||
client.sendMessage(&manager.ClientStreamMessage{Message: msg})
|
||||
}
|
||||
|
||||
func (client ManagerClient) handleBackendInfoReq(ctx context.Context, mes *pkgmanager.ServerStreamMessage_BackendInfoReq) {
|
||||
func (client ManagerClient) handleBackendInfoReq(ctx context.Context, mes *manager.ServerStreamMessage_BackendInfoReq) {
|
||||
res, err := client.svc.FetchBackendInfo(ctx, mes.BackendInfoReq.Id)
|
||||
if err != nil {
|
||||
client.logger.Warn(err.Error())
|
||||
return
|
||||
}
|
||||
info := &pkgmanager.ClientStreamMessage_BackendInfo{
|
||||
BackendInfo: &pkgmanager.BackendInfo{
|
||||
info := &manager.ClientStreamMessage_BackendInfo{
|
||||
BackendInfo: &manager.BackendInfo{
|
||||
Info: res,
|
||||
Id: mes.BackendInfoReq.Id,
|
||||
},
|
||||
}
|
||||
client.sendMessage(&pkgmanager.ClientStreamMessage{Message: info})
|
||||
client.sendMessage(&manager.ClientStreamMessage{Message: info})
|
||||
}
|
||||
|
||||
func (client ManagerClient) handleOutgoingMessages(ctx context.Context) error {
|
||||
@@ -161,7 +160,7 @@ func (client ManagerClient) handleOutgoingMessages(ctx context.Context) error {
|
||||
}
|
||||
}
|
||||
|
||||
func (client ManagerClient) sendMessage(mes *pkgmanager.ClientStreamMessage) {
|
||||
func (client ManagerClient) sendMessage(mes *manager.ClientStreamMessage) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), sendTimeout)
|
||||
defer cancel()
|
||||
|
||||
|
||||
@@ -11,8 +11,8 @@ import (
|
||||
"github.com/absmach/magistrala/pkg/errors"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/mock"
|
||||
"github.com/ultravioletrs/cocos/manager"
|
||||
"github.com/ultravioletrs/cocos/manager/mocks"
|
||||
pkgmanager "github.com/ultravioletrs/cocos/pkg/manager"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/protobuf/proto"
|
||||
)
|
||||
@@ -22,12 +22,12 @@ type mockStream struct {
|
||||
grpc.ClientStream
|
||||
}
|
||||
|
||||
func (m *mockStream) Recv() (*pkgmanager.ServerStreamMessage, error) {
|
||||
func (m *mockStream) Recv() (*manager.ServerStreamMessage, error) {
|
||||
args := m.Called()
|
||||
return args.Get(0).(*pkgmanager.ServerStreamMessage), args.Error(1)
|
||||
return args.Get(0).(*manager.ServerStreamMessage), args.Error(1)
|
||||
}
|
||||
|
||||
func (m *mockStream) Send(msg *pkgmanager.ClientStreamMessage) error {
|
||||
func (m *mockStream) Send(msg *manager.ClientStreamMessage) error {
|
||||
args := m.Called(msg)
|
||||
return args.Error(0)
|
||||
}
|
||||
@@ -35,7 +35,7 @@ func (m *mockStream) Send(msg *pkgmanager.ClientStreamMessage) error {
|
||||
func TestManagerClient_Process(t *testing.T) {
|
||||
mockStream := new(mockStream)
|
||||
mockSvc := new(mocks.Service)
|
||||
messageQueue := make(chan *pkgmanager.ClientStreamMessage, 10)
|
||||
messageQueue := make(chan *manager.ClientStreamMessage, 10)
|
||||
logger := mglog.NewMock()
|
||||
|
||||
client := NewClient(mockStream, mockSvc, messageQueue, logger)
|
||||
@@ -43,7 +43,7 @@ func TestManagerClient_Process(t *testing.T) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
||||
defer cancel()
|
||||
|
||||
mockStream.On("Recv").Return(&pkgmanager.ServerStreamMessage{Message: &pkgmanager.ServerStreamMessage_StopComputation{StopComputation: &pkgmanager.StopComputation{}}}, nil).Maybe()
|
||||
mockStream.On("Recv").Return(&manager.ServerStreamMessage{Message: &manager.ServerStreamMessage_StopComputation{StopComputation: &manager.StopComputation{}}}, nil).Maybe()
|
||||
mockStream.On("Send", mock.Anything).Return(nil).Maybe()
|
||||
|
||||
mockSvc.On("Stop", mock.Anything, mock.Anything).Return(nil).Maybe()
|
||||
@@ -57,25 +57,25 @@ func TestManagerClient_Process(t *testing.T) {
|
||||
func TestManagerClient_handleRunReqChunks(t *testing.T) {
|
||||
mockStream := new(mockStream)
|
||||
mockSvc := new(mocks.Service)
|
||||
messageQueue := make(chan *pkgmanager.ClientStreamMessage, 10)
|
||||
messageQueue := make(chan *manager.ClientStreamMessage, 10)
|
||||
logger := mglog.NewMock()
|
||||
|
||||
client := NewClient(mockStream, mockSvc, messageQueue, logger)
|
||||
|
||||
runReq := &pkgmanager.ComputationRunReq{
|
||||
runReq := &manager.ComputationRunReq{
|
||||
Id: "test-id",
|
||||
}
|
||||
runReqBytes, _ := proto.Marshal(runReq)
|
||||
|
||||
chunk1 := &pkgmanager.ServerStreamMessage_RunReqChunks{
|
||||
RunReqChunks: &pkgmanager.RunReqChunks{
|
||||
chunk1 := &manager.ServerStreamMessage_RunReqChunks{
|
||||
RunReqChunks: &manager.RunReqChunks{
|
||||
Id: "chunk-1",
|
||||
Data: runReqBytes[:len(runReqBytes)/2],
|
||||
IsLast: false,
|
||||
},
|
||||
}
|
||||
chunk2 := &pkgmanager.ServerStreamMessage_RunReqChunks{
|
||||
RunReqChunks: &pkgmanager.RunReqChunks{
|
||||
chunk2 := &manager.ServerStreamMessage_RunReqChunks{
|
||||
RunReqChunks: &manager.RunReqChunks{
|
||||
Id: "chunk-1",
|
||||
Data: runReqBytes[len(runReqBytes)/2:],
|
||||
IsLast: true,
|
||||
@@ -97,7 +97,7 @@ func TestManagerClient_handleRunReqChunks(t *testing.T) {
|
||||
assert.Len(t, messageQueue, 1)
|
||||
|
||||
msg := <-messageQueue
|
||||
runRes, ok := msg.Message.(*pkgmanager.ClientStreamMessage_RunRes)
|
||||
runRes, ok := msg.Message.(*manager.ClientStreamMessage_RunRes)
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, "8080", runRes.RunRes.AgentPort)
|
||||
assert.Equal(t, "test-id", runRes.RunRes.ComputationId)
|
||||
@@ -106,8 +106,8 @@ func TestManagerClient_handleRunReqChunks(t *testing.T) {
|
||||
func TestManagerClient_handleTerminateReq(t *testing.T) {
|
||||
client := ManagerClient{}
|
||||
|
||||
terminateReq := &pkgmanager.ServerStreamMessage_TerminateReq{
|
||||
TerminateReq: &pkgmanager.Terminate{
|
||||
terminateReq := &manager.ServerStreamMessage_TerminateReq{
|
||||
TerminateReq: &manager.Terminate{
|
||||
Message: "Test termination",
|
||||
},
|
||||
}
|
||||
@@ -121,13 +121,13 @@ func TestManagerClient_handleTerminateReq(t *testing.T) {
|
||||
func TestManagerClient_handleStopComputation(t *testing.T) {
|
||||
mockStream := new(mockStream)
|
||||
mockSvc := new(mocks.Service)
|
||||
messageQueue := make(chan *pkgmanager.ClientStreamMessage, 10)
|
||||
messageQueue := make(chan *manager.ClientStreamMessage, 10)
|
||||
logger := mglog.NewMock()
|
||||
|
||||
client := NewClient(mockStream, mockSvc, messageQueue, logger)
|
||||
|
||||
stopReq := &pkgmanager.ServerStreamMessage_StopComputation{
|
||||
StopComputation: &pkgmanager.StopComputation{
|
||||
stopReq := &manager.ServerStreamMessage_StopComputation{
|
||||
StopComputation: &manager.StopComputation{
|
||||
ComputationId: "test-comp-id",
|
||||
},
|
||||
}
|
||||
@@ -143,7 +143,7 @@ func TestManagerClient_handleStopComputation(t *testing.T) {
|
||||
assert.Len(t, messageQueue, 1)
|
||||
|
||||
msg := <-messageQueue
|
||||
stopRes, ok := msg.Message.(*pkgmanager.ClientStreamMessage_StopComputationRes)
|
||||
stopRes, ok := msg.Message.(*manager.ClientStreamMessage_StopComputationRes)
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, "test-comp-id", stopRes.StopComputationRes.ComputationId)
|
||||
assert.Empty(t, stopRes.StopComputationRes.Message)
|
||||
@@ -153,13 +153,13 @@ func TestManagerClient_handleBackendInfoReq(t *testing.T) {
|
||||
t.Run("success", func(t *testing.T) {
|
||||
mockStream := new(mockStream)
|
||||
mockSvc := new(mocks.Service)
|
||||
messageQueue := make(chan *pkgmanager.ClientStreamMessage, 10)
|
||||
messageQueue := make(chan *manager.ClientStreamMessage, 10)
|
||||
logger := mglog.NewMock()
|
||||
|
||||
client := NewClient(mockStream, mockSvc, messageQueue, logger)
|
||||
|
||||
infoReq := &pkgmanager.ServerStreamMessage_BackendInfoReq{
|
||||
BackendInfoReq: &pkgmanager.BackendInfoReq{
|
||||
infoReq := &manager.ServerStreamMessage_BackendInfoReq{
|
||||
BackendInfoReq: &manager.BackendInfoReq{
|
||||
Id: "test-info-id",
|
||||
},
|
||||
}
|
||||
@@ -175,7 +175,7 @@ func TestManagerClient_handleBackendInfoReq(t *testing.T) {
|
||||
assert.Len(t, messageQueue, 1)
|
||||
|
||||
msg := <-messageQueue
|
||||
infoRes, ok := msg.Message.(*pkgmanager.ClientStreamMessage_BackendInfo)
|
||||
infoRes, ok := msg.Message.(*manager.ClientStreamMessage_BackendInfo)
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, "test-info-id", infoRes.BackendInfo.Id)
|
||||
assert.Equal(t, []byte("test-backend-info"), infoRes.BackendInfo.Info)
|
||||
@@ -183,13 +183,13 @@ func TestManagerClient_handleBackendInfoReq(t *testing.T) {
|
||||
t.Run("error", func(t *testing.T) {
|
||||
mockStream := new(mockStream)
|
||||
mockSvc := new(mocks.Service)
|
||||
messageQueue := make(chan *pkgmanager.ClientStreamMessage, 10)
|
||||
messageQueue := make(chan *manager.ClientStreamMessage, 10)
|
||||
logger := mglog.NewMock()
|
||||
|
||||
client := NewClient(mockStream, mockSvc, messageQueue, logger)
|
||||
|
||||
infoReq := &pkgmanager.ServerStreamMessage_BackendInfoReq{
|
||||
BackendInfoReq: &pkgmanager.BackendInfoReq{
|
||||
infoReq := &manager.ServerStreamMessage_BackendInfoReq{
|
||||
BackendInfoReq: &manager.BackendInfoReq{
|
||||
Id: "test-info-id",
|
||||
},
|
||||
}
|
||||
|
||||
@@ -9,7 +9,7 @@ import (
|
||||
"io"
|
||||
"time"
|
||||
|
||||
"github.com/ultravioletrs/cocos/pkg/manager"
|
||||
"github.com/ultravioletrs/cocos/manager"
|
||||
"golang.org/x/sync/errgroup"
|
||||
"google.golang.org/grpc/credentials"
|
||||
"google.golang.org/grpc/peer"
|
||||
|
||||
@@ -10,7 +10,7 @@ import (
|
||||
"github.com/absmach/magistrala/pkg/errors"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/mock"
|
||||
"github.com/ultravioletrs/cocos/pkg/manager"
|
||||
"github.com/ultravioletrs/cocos/manager"
|
||||
"google.golang.org/grpc/credentials"
|
||||
"google.golang.org/grpc/peer"
|
||||
)
|
||||
|
||||
@@ -13,7 +13,6 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/ultravioletrs/cocos/manager"
|
||||
pkgmanager "github.com/ultravioletrs/cocos/pkg/manager"
|
||||
)
|
||||
|
||||
var _ manager.Service = (*loggingMiddleware)(nil)
|
||||
@@ -28,7 +27,7 @@ func LoggingMiddleware(svc manager.Service, logger *slog.Logger) manager.Service
|
||||
return &loggingMiddleware{logger, svc}
|
||||
}
|
||||
|
||||
func (lm *loggingMiddleware) Run(ctx context.Context, mc *pkgmanager.ComputationRunReq) (agentAddr string, err error) {
|
||||
func (lm *loggingMiddleware) Run(ctx context.Context, mc *manager.ComputationRunReq) (agentAddr string, err error) {
|
||||
defer func(begin time.Time) {
|
||||
message := fmt.Sprintf("Method Run for computation took %s to complete", time.Since(begin))
|
||||
if err != nil {
|
||||
@@ -54,10 +53,6 @@ func (lm *loggingMiddleware) Stop(ctx context.Context, computationID string) (er
|
||||
return lm.svc.Stop(ctx, computationID)
|
||||
}
|
||||
|
||||
func (lm *loggingMiddleware) RetrieveAgentEventsLogs() {
|
||||
lm.svc.RetrieveAgentEventsLogs()
|
||||
}
|
||||
|
||||
func (lm *loggingMiddleware) FetchBackendInfo(ctx context.Context, cmpId string) (body []byte, err error) {
|
||||
defer func(begin time.Time) {
|
||||
message := fmt.Sprintf("Method FetchBackendInfo for computation %s took %s to complete", cmpId, time.Since(begin))
|
||||
@@ -71,3 +66,7 @@ func (lm *loggingMiddleware) FetchBackendInfo(ctx context.Context, cmpId string)
|
||||
|
||||
return lm.svc.FetchBackendInfo(ctx, cmpId)
|
||||
}
|
||||
|
||||
func (lm *loggingMiddleware) ReportBrokenConnection(addr string) {
|
||||
lm.svc.ReportBrokenConnection(addr)
|
||||
}
|
||||
|
||||
@@ -12,7 +12,6 @@ import (
|
||||
|
||||
"github.com/go-kit/kit/metrics"
|
||||
"github.com/ultravioletrs/cocos/manager"
|
||||
pkgmanager "github.com/ultravioletrs/cocos/pkg/manager"
|
||||
)
|
||||
|
||||
var _ manager.Service = (*metricsMiddleware)(nil)
|
||||
@@ -33,7 +32,7 @@ func MetricsMiddleware(svc manager.Service, counter metrics.Counter, latency met
|
||||
}
|
||||
}
|
||||
|
||||
func (ms *metricsMiddleware) Run(ctx context.Context, mc *pkgmanager.ComputationRunReq) (string, error) {
|
||||
func (ms *metricsMiddleware) Run(ctx context.Context, mc *manager.ComputationRunReq) (string, error) {
|
||||
defer func(begin time.Time) {
|
||||
ms.counter.With("method", "Run").Add(1)
|
||||
ms.latency.With("method", "Run").Observe(time.Since(begin).Seconds())
|
||||
@@ -51,10 +50,6 @@ func (ms *metricsMiddleware) Stop(ctx context.Context, computationID string) err
|
||||
return ms.svc.Stop(ctx, computationID)
|
||||
}
|
||||
|
||||
func (ms *metricsMiddleware) RetrieveAgentEventsLogs() {
|
||||
ms.svc.RetrieveAgentEventsLogs()
|
||||
}
|
||||
|
||||
func (ms *metricsMiddleware) FetchBackendInfo(ctx context.Context, cmpId string) ([]byte, error) {
|
||||
defer func(begin time.Time) {
|
||||
ms.counter.With("method", "FetchBackendInfo").Add(1)
|
||||
@@ -63,3 +58,7 @@ func (ms *metricsMiddleware) FetchBackendInfo(ctx context.Context, cmpId string)
|
||||
|
||||
return ms.svc.FetchBackendInfo(ctx, cmpId)
|
||||
}
|
||||
|
||||
func (ms *metricsMiddleware) ReportBrokenConnection(addr string) {
|
||||
ms.svc.ReportBrokenConnection(addr)
|
||||
}
|
||||
|
||||
@@ -14,8 +14,8 @@ import (
|
||||
"os"
|
||||
"os/exec"
|
||||
|
||||
"github.com/ultravioletrs/cocos/cli"
|
||||
"github.com/ultravioletrs/cocos/manager/qemu"
|
||||
"github.com/ultravioletrs/cocos/pkg/clients/grpc"
|
||||
"github.com/virtee/sev-snp-measure-go/cpuid"
|
||||
"github.com/virtee/sev-snp-measure-go/guest"
|
||||
"github.com/virtee/sev-snp-measure-go/vmmtypes"
|
||||
@@ -48,7 +48,7 @@ func (ms *managerService) FetchBackendInfo(_ context.Context, computationId stri
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var backendInfo cli.AttestationConfiguration
|
||||
var backendInfo grpc.AttestationConfiguration
|
||||
|
||||
if err = json.Unmarshal(f, &backendInfo); err != nil {
|
||||
return nil, err
|
||||
|
||||
@@ -0,0 +1,9 @@
|
||||
// Copyright (c) Ultraviolet
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
package events
|
||||
|
||||
import "context"
|
||||
|
||||
type Listener interface {
|
||||
Listen(ctx context.Context)
|
||||
}
|
||||
@@ -0,0 +1,128 @@
|
||||
// Copyright (c) Ultraviolet
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
package events
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log/slog"
|
||||
"net"
|
||||
|
||||
"github.com/mdlayher/vsock"
|
||||
agentevents "github.com/ultravioletrs/cocos/agent/events"
|
||||
internalvsock "github.com/ultravioletrs/cocos/internal/vsock"
|
||||
"github.com/ultravioletrs/cocos/manager"
|
||||
"google.golang.org/protobuf/proto"
|
||||
)
|
||||
|
||||
const (
|
||||
ManagerVsockPort = 9997
|
||||
messageSize int = 1024 * 1024
|
||||
)
|
||||
|
||||
type ReportBrokenConnectionFunc func(address string)
|
||||
|
||||
type events struct {
|
||||
reportBrokenConnection ReportBrokenConnectionFunc
|
||||
lis net.Listener
|
||||
logger *slog.Logger
|
||||
eventsChan chan *manager.ClientStreamMessage
|
||||
}
|
||||
|
||||
func New(logger *slog.Logger, reportBrokenConnection ReportBrokenConnectionFunc, eventsChan chan *manager.ClientStreamMessage) Listener {
|
||||
l, err := vsock.Listen(ManagerVsockPort, nil)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
return &events{
|
||||
lis: l,
|
||||
reportBrokenConnection: reportBrokenConnection,
|
||||
logger: logger,
|
||||
eventsChan: eventsChan,
|
||||
}
|
||||
}
|
||||
|
||||
func (e *events) Listen(ctx context.Context) {
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
e.logger.Info("Listener shutting down")
|
||||
return
|
||||
default:
|
||||
conn, err := e.lis.Accept()
|
||||
if err != nil {
|
||||
e.logger.Warn(err.Error())
|
||||
continue
|
||||
}
|
||||
|
||||
go e.handleConnection(conn)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (e *events) handleConnection(conn net.Conn) {
|
||||
defer conn.Close()
|
||||
|
||||
ackReader := internalvsock.NewAckReader(conn)
|
||||
|
||||
for {
|
||||
var message agentevents.EventsLogs
|
||||
data, err := ackReader.Read()
|
||||
if err != nil {
|
||||
go e.reportBrokenConnection(conn.RemoteAddr().String())
|
||||
e.logger.Warn(err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
if err := proto.Unmarshal(data, &message); err != nil {
|
||||
e.logger.Warn(err.Error())
|
||||
continue
|
||||
}
|
||||
|
||||
var mes manager.ClientStreamMessage
|
||||
|
||||
args := []any{}
|
||||
|
||||
switch message.Message.(type) {
|
||||
case *agentevents.EventsLogs_AgentEvent:
|
||||
args = append(args, slog.Group("agent-event",
|
||||
slog.String("event-type", message.GetAgentEvent().GetEventType()),
|
||||
slog.String("computation-id", message.GetAgentEvent().GetComputationId()),
|
||||
slog.String("status", message.GetAgentEvent().GetStatus()),
|
||||
slog.String("originator", message.GetAgentEvent().GetOriginator()),
|
||||
slog.String("timestamp", message.GetAgentEvent().GetTimestamp().String()),
|
||||
slog.String("details", string(message.GetAgentEvent().GetDetails()))))
|
||||
mes = manager.ClientStreamMessage{
|
||||
Message: &manager.ClientStreamMessage_AgentEvent{
|
||||
AgentEvent: &manager.AgentEvent{
|
||||
EventType: message.GetAgentEvent().GetEventType(),
|
||||
ComputationId: message.GetAgentEvent().GetComputationId(),
|
||||
Status: message.GetAgentEvent().GetStatus(),
|
||||
Originator: message.GetAgentEvent().GetOriginator(),
|
||||
Timestamp: message.GetAgentEvent().GetTimestamp(),
|
||||
Details: message.GetAgentEvent().GetDetails(),
|
||||
},
|
||||
},
|
||||
}
|
||||
case *agentevents.EventsLogs_AgentLog:
|
||||
args = append(args, slog.Group("agent-log",
|
||||
slog.String("computation-id", message.GetAgentLog().GetComputationId()),
|
||||
slog.String("level", message.GetAgentLog().GetLevel()),
|
||||
slog.String("timestamp", message.GetAgentLog().GetTimestamp().String()),
|
||||
slog.String("message", message.GetAgentLog().GetMessage())))
|
||||
mes = manager.ClientStreamMessage{
|
||||
Message: &manager.ClientStreamMessage_AgentLog{
|
||||
AgentLog: &manager.AgentLog{
|
||||
ComputationId: message.GetAgentLog().GetComputationId(),
|
||||
Level: message.GetAgentLog().GetLevel(),
|
||||
Timestamp: message.GetAgentLog().GetTimestamp(),
|
||||
Message: message.GetAgentLog().GetMessage(),
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
e.eventsChan <- &mes
|
||||
|
||||
e.logger.Info("", args...)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,251 @@
|
||||
// Copyright (c) Ultraviolet
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
package events
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"net"
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
mglog "github.com/absmach/magistrala/logger"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/mock"
|
||||
"github.com/ultravioletrs/cocos/manager"
|
||||
"google.golang.org/protobuf/proto"
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
)
|
||||
|
||||
type MockVsockListener struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
func (m *MockVsockListener) Accept() (net.Conn, error) {
|
||||
args := m.Called()
|
||||
return args.Get(0).(net.Conn), args.Error(1)
|
||||
}
|
||||
|
||||
func (m *MockVsockListener) Close() error {
|
||||
args := m.Called()
|
||||
return args.Error(0)
|
||||
}
|
||||
|
||||
func (m *MockVsockListener) Addr() net.Addr {
|
||||
args := m.Called()
|
||||
return args.Get(0).(net.Addr)
|
||||
}
|
||||
|
||||
var _ net.Conn = (*MockConn)(nil)
|
||||
|
||||
type MockConn struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
func (m *MockConn) Read(b []byte) (n int, err error) {
|
||||
args := m.Called(b)
|
||||
return args.Int(0), args.Error(1)
|
||||
}
|
||||
|
||||
func (m *MockConn) Write(b []byte) (n int, err error) {
|
||||
args := m.Called(b)
|
||||
return args.Int(0), args.Error(1)
|
||||
}
|
||||
|
||||
func (m *MockConn) Close() error {
|
||||
args := m.Called()
|
||||
return args.Error(0)
|
||||
}
|
||||
|
||||
func (m *MockConn) LocalAddr() net.Addr {
|
||||
args := m.Called()
|
||||
return args.Get(0).(net.Addr)
|
||||
}
|
||||
|
||||
func (m *MockConn) RemoteAddr() net.Addr {
|
||||
args := m.Called()
|
||||
return args.Get(0).(net.Addr)
|
||||
}
|
||||
|
||||
func (m *MockConn) SetDeadline(t time.Time) error {
|
||||
args := m.Called(t)
|
||||
return args.Error(0)
|
||||
}
|
||||
|
||||
func (m *MockConn) SetReadDeadline(t time.Time) error {
|
||||
args := m.Called(t)
|
||||
return args.Error(0)
|
||||
}
|
||||
|
||||
func (m *MockConn) SetWriteDeadline(t time.Time) error {
|
||||
args := m.Called(t)
|
||||
return args.Error(0)
|
||||
}
|
||||
|
||||
func TestNew(t *testing.T) {
|
||||
if !vsockDeviceExists() {
|
||||
t.Skip("Skipping test: vsock device not available")
|
||||
}
|
||||
|
||||
logger := &slog.Logger{}
|
||||
reportBrokenConnection := func(address string) {}
|
||||
eventsChan := make(chan *manager.ClientStreamMessage)
|
||||
|
||||
e := New(logger, reportBrokenConnection, eventsChan)
|
||||
|
||||
assert.NotNil(t, e)
|
||||
assert.IsType(t, &events{}, e)
|
||||
}
|
||||
|
||||
func TestListen(t *testing.T) {
|
||||
mockListener := new(MockVsockListener)
|
||||
mockConn := new(MockConn)
|
||||
|
||||
e := &events{
|
||||
lis: mockListener,
|
||||
logger: mglog.NewMock(),
|
||||
}
|
||||
|
||||
mockListener.On("Accept").Return(mockConn, fmt.Errorf("mock error")).Once()
|
||||
mockListener.On("Accept").Return(mockConn, nil)
|
||||
mockConn.On("Close").Return(nil)
|
||||
mockConn.On("Read", mock.Anything).Return(0, nil)
|
||||
|
||||
go e.Listen(context.Background())
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
mockListener.AssertExpectations(t)
|
||||
}
|
||||
|
||||
func vsockDeviceExists() bool {
|
||||
fs, err := os.Stat("/dev/vsock")
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
if fs.Mode()&os.ModeDevice == 0 {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
type MockConnWithBuffer struct {
|
||||
mock.Mock
|
||||
readBuf *bytes.Buffer
|
||||
writeBuf *bytes.Buffer
|
||||
}
|
||||
|
||||
func NewMockConnWithBuffer() *MockConnWithBuffer {
|
||||
return &MockConnWithBuffer{
|
||||
readBuf: new(bytes.Buffer),
|
||||
writeBuf: new(bytes.Buffer),
|
||||
}
|
||||
}
|
||||
|
||||
func (m *MockConnWithBuffer) Read(b []byte) (n int, err error) {
|
||||
return m.readBuf.Read(b)
|
||||
}
|
||||
|
||||
func (m *MockConnWithBuffer) Write(b []byte) (n int, err error) {
|
||||
return m.writeBuf.Write(b)
|
||||
}
|
||||
|
||||
func (m *MockConnWithBuffer) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *MockConnWithBuffer) LocalAddr() net.Addr {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *MockConnWithBuffer) RemoteAddr() net.Addr {
|
||||
return &net.IPAddr{IP: net.ParseIP("localhost")}
|
||||
}
|
||||
|
||||
func (m *MockConnWithBuffer) SetDeadline(t time.Time) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *MockConnWithBuffer) SetReadDeadline(t time.Time) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *MockConnWithBuffer) SetWriteDeadline(t time.Time) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestHandleConnection(t *testing.T) {
|
||||
mockConn := NewMockConnWithBuffer()
|
||||
eventsChan := make(chan *manager.ClientStreamMessage, 1)
|
||||
|
||||
e := &events{
|
||||
logger: mglog.NewMock(),
|
||||
eventsChan: eventsChan,
|
||||
reportBrokenConnection: func(address string) {},
|
||||
}
|
||||
|
||||
message := &manager.ClientStreamMessage{
|
||||
Message: &manager.ClientStreamMessage_AgentEvent{
|
||||
AgentEvent: &manager.AgentEvent{
|
||||
EventType: "test_event",
|
||||
ComputationId: "test_computation",
|
||||
Status: "test_status",
|
||||
Originator: "test_originator",
|
||||
Timestamp: timestamppb.Now(),
|
||||
Details: []byte("test_details"),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
data, err := proto.Marshal(message)
|
||||
assert.NoError(t, err)
|
||||
|
||||
messageID := uint32(1)
|
||||
err = binary.Write(mockConn.readBuf, binary.LittleEndian, messageID)
|
||||
assert.NoError(t, err)
|
||||
err = binary.Write(mockConn.readBuf, binary.LittleEndian, uint32(len(data)))
|
||||
assert.NoError(t, err)
|
||||
_, err = mockConn.readBuf.Write(data)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Add EOF to signal end of stream
|
||||
err = binary.Write(mockConn.readBuf, binary.LittleEndian, uint32(0))
|
||||
assert.NoError(t, err)
|
||||
err = binary.Write(mockConn.readBuf, binary.LittleEndian, uint32(0))
|
||||
assert.NoError(t, err)
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
e.handleConnection(mockConn)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
var receivedMessage *manager.ClientStreamMessage
|
||||
select {
|
||||
case receivedMessage = <-eventsChan:
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("Timeout waiting for message in eventsChan")
|
||||
}
|
||||
|
||||
assert.NotNil(t, receivedMessage)
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
// handleConnection has exited
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("Timeout waiting for handleConnection to exit")
|
||||
}
|
||||
|
||||
// Check if ack was written
|
||||
var receivedAck uint32
|
||||
err = binary.Read(mockConn.writeBuf, binary.LittleEndian, &receivedAck)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, messageID, receivedAck)
|
||||
|
||||
// Ensure no unexpected calls were made on the mock
|
||||
mockConn.AssertExpectations(t)
|
||||
}
|
||||
@@ -7,7 +7,7 @@ import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/ultravioletrs/cocos/pkg/manager"
|
||||
"github.com/ultravioletrs/cocos/manager"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/protobuf/proto"
|
||||
)
|
||||
|
||||
@@ -9,8 +9,7 @@ import (
|
||||
context "context"
|
||||
|
||||
mock "github.com/stretchr/testify/mock"
|
||||
|
||||
pkgmanager "github.com/ultravioletrs/cocos/pkg/manager"
|
||||
manager "github.com/ultravioletrs/cocos/manager"
|
||||
)
|
||||
|
||||
// Service is an autogenerated mock type for the Service type
|
||||
@@ -48,13 +47,13 @@ func (_m *Service) FetchBackendInfo(ctx context.Context, computationID string) (
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
// RetrieveAgentEventsLogs provides a mock function with given fields:
|
||||
func (_m *Service) RetrieveAgentEventsLogs() {
|
||||
_m.Called()
|
||||
// ReportBrokenConnection provides a mock function with given fields: addr
|
||||
func (_m *Service) ReportBrokenConnection(addr string) {
|
||||
_m.Called(addr)
|
||||
}
|
||||
|
||||
// Run provides a mock function with given fields: ctx, c
|
||||
func (_m *Service) Run(ctx context.Context, c *pkgmanager.ComputationRunReq) (string, error) {
|
||||
func (_m *Service) Run(ctx context.Context, c *manager.ComputationRunReq) (string, error) {
|
||||
ret := _m.Called(ctx, c)
|
||||
|
||||
if len(ret) == 0 {
|
||||
@@ -63,16 +62,16 @@ func (_m *Service) Run(ctx context.Context, c *pkgmanager.ComputationRunReq) (st
|
||||
|
||||
var r0 string
|
||||
var r1 error
|
||||
if rf, ok := ret.Get(0).(func(context.Context, *pkgmanager.ComputationRunReq) (string, error)); ok {
|
||||
if rf, ok := ret.Get(0).(func(context.Context, *manager.ComputationRunReq) (string, error)); ok {
|
||||
return rf(ctx, c)
|
||||
}
|
||||
if rf, ok := ret.Get(0).(func(context.Context, *pkgmanager.ComputationRunReq) string); ok {
|
||||
if rf, ok := ret.Get(0).(func(context.Context, *manager.ComputationRunReq) string); ok {
|
||||
r0 = rf(ctx, c)
|
||||
} else {
|
||||
r0 = ret.Get(0).(string)
|
||||
}
|
||||
|
||||
if rf, ok := ret.Get(1).(func(context.Context, *pkgmanager.ComputationRunReq) error); ok {
|
||||
if rf, ok := ret.Get(1).(func(context.Context, *manager.ComputationRunReq) error); ok {
|
||||
r1 = rf(ctx, c)
|
||||
} else {
|
||||
r1 = ret.Error(1)
|
||||
|
||||
+27
-31
@@ -26,19 +26,19 @@ const (
|
||||
)
|
||||
|
||||
type qemuVM struct {
|
||||
config Config
|
||||
cmd *exec.Cmd
|
||||
logsChan chan *manager.ClientStreamMessage
|
||||
computationId string
|
||||
config Config
|
||||
cmd *exec.Cmd
|
||||
eventsLogsSender vm.EventSender
|
||||
computationId string
|
||||
vm.StateMachine
|
||||
}
|
||||
|
||||
func NewVM(config interface{}, logsChan chan *manager.ClientStreamMessage, computationId string) vm.VM {
|
||||
func NewVM(config interface{}, eventsLogsSender vm.EventSender, computationId string) vm.VM {
|
||||
return &qemuVM{
|
||||
config: config.(Config),
|
||||
logsChan: logsChan,
|
||||
computationId: computationId,
|
||||
StateMachine: vm.NewStateMachine(),
|
||||
config: config.(Config),
|
||||
eventsLogsSender: eventsLogsSender,
|
||||
computationId: computationId,
|
||||
StateMachine: vm.NewStateMachine(),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -74,8 +74,8 @@ func (v *qemuVM) Start() (err error) {
|
||||
}
|
||||
|
||||
v.cmd = exec.Command(exe, args...)
|
||||
v.cmd.Stdout = &vm.Stdout{LogsChan: v.logsChan, ComputationId: v.computationId}
|
||||
v.cmd.Stderr = &vm.Stderr{LogsChan: v.logsChan, ComputationId: v.computationId, StateMachine: v.StateMachine}
|
||||
v.cmd.Stdout = &vm.Stdout{ComputationId: v.computationId, EventSender: v.eventsLogsSender}
|
||||
v.cmd.Stderr = &vm.Stderr{EventSender: v.eventsLogsSender, ComputationId: v.computationId, StateMachine: v.StateMachine}
|
||||
|
||||
return v.cmd.Start()
|
||||
}
|
||||
@@ -84,16 +84,14 @@ func (v *qemuVM) Stop() error {
|
||||
defer func() {
|
||||
err := v.StateMachine.Transition(manager.StopComputationRun)
|
||||
if err != nil {
|
||||
v.logsChan <- &manager.ClientStreamMessage{
|
||||
Message: &manager.ClientStreamMessage_AgentEvent{
|
||||
AgentEvent: &manager.AgentEvent{
|
||||
ComputationId: v.computationId,
|
||||
EventType: v.StateMachine.State(),
|
||||
Status: manager.Warning.String(),
|
||||
Timestamp: timestamppb.Now(),
|
||||
Originator: "manager",
|
||||
},
|
||||
},
|
||||
if err := v.eventsLogsSender(&vm.Event{
|
||||
EventType: v.StateMachine.State(),
|
||||
Timestamp: timestamppb.Now(),
|
||||
ComputationId: v.computationId,
|
||||
Originator: "manager",
|
||||
Status: manager.Warning.String(),
|
||||
}); err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
@@ -160,16 +158,14 @@ func (v *qemuVM) executableAndArgs() (string, []string, error) {
|
||||
func (v *qemuVM) checkVMProcessPeriodically() {
|
||||
for {
|
||||
if !processExists(v.GetProcess()) {
|
||||
v.logsChan <- &manager.ClientStreamMessage{
|
||||
Message: &manager.ClientStreamMessage_AgentEvent{
|
||||
AgentEvent: &manager.AgentEvent{
|
||||
ComputationId: v.computationId,
|
||||
EventType: v.StateMachine.State(),
|
||||
Status: manager.Stopped.String(),
|
||||
Timestamp: timestamppb.Now(),
|
||||
Originator: "manager",
|
||||
},
|
||||
},
|
||||
if err := v.eventsLogsSender(&vm.Event{
|
||||
EventType: v.StateMachine.State(),
|
||||
Timestamp: timestamppb.Now(),
|
||||
ComputationId: v.computationId,
|
||||
Originator: "manager",
|
||||
Status: manager.Stopped.String(),
|
||||
}); err != nil {
|
||||
return
|
||||
}
|
||||
break
|
||||
}
|
||||
|
||||
+22
-23
@@ -11,16 +11,15 @@ import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/ultravioletrs/cocos/manager/vm"
|
||||
"github.com/ultravioletrs/cocos/manager/vm/mocks"
|
||||
"github.com/ultravioletrs/cocos/pkg/manager"
|
||||
pkgmanager "github.com/ultravioletrs/cocos/pkg/manager"
|
||||
)
|
||||
|
||||
const testComputationID = "test-computation"
|
||||
|
||||
func TestNewVM(t *testing.T) {
|
||||
config := Config{}
|
||||
logsChan := make(chan *manager.ClientStreamMessage)
|
||||
|
||||
vm := NewVM(config, logsChan, testComputationID)
|
||||
vm := NewVM(config, func(event interface{}) error { return nil }, testComputationID)
|
||||
|
||||
assert.NotNil(t, vm)
|
||||
assert.IsType(t, &qemuVM{}, vm)
|
||||
@@ -38,9 +37,8 @@ func TestStart(t *testing.T) {
|
||||
},
|
||||
QemuBinPath: "echo",
|
||||
}
|
||||
logsChan := make(chan *manager.ClientStreamMessage)
|
||||
|
||||
vm := NewVM(config, logsChan, testComputationID).(*qemuVM)
|
||||
vm := NewVM(config, func(event interface{}) error { return nil }, testComputationID).(*qemuVM)
|
||||
|
||||
err = vm.Start()
|
||||
assert.NoError(t, err)
|
||||
@@ -62,9 +60,8 @@ func TestStartSudo(t *testing.T) {
|
||||
QemuBinPath: "echo",
|
||||
UseSudo: true,
|
||||
}
|
||||
logsChan := make(chan *manager.ClientStreamMessage)
|
||||
|
||||
vm := NewVM(config, logsChan, testComputationID).(*qemuVM)
|
||||
vm := NewVM(config, func(event interface{}) error { return nil }, testComputationID).(*qemuVM)
|
||||
|
||||
err = vm.Start()
|
||||
assert.NoError(t, err)
|
||||
@@ -79,7 +76,7 @@ func TestStop(t *testing.T) {
|
||||
err := cmd.Start()
|
||||
assert.NoError(t, err)
|
||||
sm := new(mocks.StateMachine)
|
||||
sm.On("Transition", manager.StopComputationRun).Return(nil)
|
||||
sm.On("Transition", pkgmanager.StopComputationRun).Return(nil)
|
||||
|
||||
vm := &qemuVM{
|
||||
cmd: &exec.Cmd{
|
||||
@@ -96,21 +93,19 @@ func TestStop(t *testing.T) {
|
||||
err := cmd.Start()
|
||||
assert.NoError(t, err)
|
||||
sm := new(mocks.StateMachine)
|
||||
sm.On("Transition", manager.StopComputationRun).Return(assert.AnError)
|
||||
sm.On("State").Return(manager.Stopped.String())
|
||||
sm.On("Transition", pkgmanager.StopComputationRun).Return(assert.AnError)
|
||||
sm.On("State").Return(pkgmanager.Stopped.String())
|
||||
|
||||
vm := &qemuVM{
|
||||
cmd: &exec.Cmd{
|
||||
Process: cmd.Process,
|
||||
},
|
||||
StateMachine: sm,
|
||||
logsChan: make(chan *manager.ClientStreamMessage),
|
||||
eventsLogsSender: func(event interface{}) error {
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
go func() {
|
||||
<-vm.logsChan
|
||||
}()
|
||||
|
||||
err = vm.Stop()
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
@@ -168,9 +163,12 @@ func TestGetConfig(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestCheckVMProcessPeriodically(t *testing.T) {
|
||||
logsChan := make(chan *manager.ClientStreamMessage, 1)
|
||||
vm := &qemuVM{
|
||||
logsChan: logsChan,
|
||||
logsChan := make(chan interface{}, 1)
|
||||
vmi := &qemuVM{
|
||||
eventsLogsSender: func(event interface{}) error {
|
||||
logsChan <- event
|
||||
return nil
|
||||
},
|
||||
computationId: testComputationID,
|
||||
cmd: &exec.Cmd{
|
||||
Process: &os.Process{Pid: -1}, // Use an invalid PID to simulate a stopped process
|
||||
@@ -178,14 +176,15 @@ func TestCheckVMProcessPeriodically(t *testing.T) {
|
||||
StateMachine: vm.NewStateMachine(),
|
||||
}
|
||||
|
||||
go vm.checkVMProcessPeriodically()
|
||||
go vmi.checkVMProcessPeriodically()
|
||||
|
||||
select {
|
||||
case msg := <-logsChan:
|
||||
assert.NotNil(t, msg.GetAgentEvent())
|
||||
assert.Equal(t, testComputationID, msg.GetAgentEvent().ComputationId)
|
||||
assert.Equal(t, manager.VmProvision.String(), msg.GetAgentEvent().EventType)
|
||||
assert.Equal(t, manager.Stopped.String(), msg.GetAgentEvent().Status)
|
||||
assert.NotNil(t, msg)
|
||||
msgE := msg.(*vm.Event)
|
||||
assert.Equal(t, testComputationID, msgE.ComputationId)
|
||||
assert.Equal(t, pkgmanager.VmProvision.String(), msgE.EventType)
|
||||
assert.Equal(t, pkgmanager.Stopped.String(), msgE.Status)
|
||||
case <-time.After(2 * interval):
|
||||
t.Fatal("Timeout waiting for VM stopped message")
|
||||
}
|
||||
|
||||
+41
-11
@@ -57,13 +57,13 @@ var (
|
||||
//go:generate mockery --name Service --output=./mocks --filename service.go --quiet --note "Copyright (c) Ultraviolet \n // SPDX-License-Identifier: Apache-2.0"
|
||||
type Service interface {
|
||||
// Run create a computation.
|
||||
Run(ctx context.Context, c *manager.ComputationRunReq) (string, error)
|
||||
Run(ctx context.Context, c *ComputationRunReq) (string, error)
|
||||
// Stop stops a computation.
|
||||
Stop(ctx context.Context, computationID string) error
|
||||
// RetrieveAgentEventsLogs Retrieve and forward agent logs and events via vsock.
|
||||
RetrieveAgentEventsLogs()
|
||||
// FetchBackendInfo measures and fetches the backend information.
|
||||
FetchBackendInfo(ctx context.Context, computationID string) ([]byte, error)
|
||||
// ReportBrokenConnection reports a broken connection.
|
||||
ReportBrokenConnection(addr string)
|
||||
}
|
||||
|
||||
type managerService struct {
|
||||
@@ -71,7 +71,7 @@ type managerService struct {
|
||||
qemuCfg qemu.Config
|
||||
backendMeasurementBinaryPath string
|
||||
logger *slog.Logger
|
||||
eventsChan chan *manager.ClientStreamMessage
|
||||
eventsChan chan *ClientStreamMessage
|
||||
vms map[string]vm.VM
|
||||
vmFactory vm.Provider
|
||||
portRangeMin int
|
||||
@@ -82,7 +82,7 @@ type managerService struct {
|
||||
var _ Service = (*managerService)(nil)
|
||||
|
||||
// New instantiates the manager service implementation.
|
||||
func New(cfg qemu.Config, backendMeasurementBinPath string, logger *slog.Logger, eventsChan chan *manager.ClientStreamMessage, vmFactory vm.Provider) (Service, error) {
|
||||
func New(cfg qemu.Config, backendMeasurementBinPath string, logger *slog.Logger, eventsChan chan *ClientStreamMessage, vmFactory vm.Provider) (Service, error) {
|
||||
start, end, err := decodeRange(cfg.HostFwdRange)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -112,7 +112,7 @@ func New(cfg qemu.Config, backendMeasurementBinPath string, logger *slog.Logger,
|
||||
return ms, nil
|
||||
}
|
||||
|
||||
func (ms *managerService) Run(ctx context.Context, c *manager.ComputationRunReq) (string, error) {
|
||||
func (ms *managerService) Run(ctx context.Context, c *ComputationRunReq) (string, error) {
|
||||
ms.publishEvent(manager.VmProvision.String(), c.Id, manager.Starting.String(), json.RawMessage{})
|
||||
ac := agent.Computation{
|
||||
ID: c.Id,
|
||||
@@ -164,7 +164,7 @@ func (ms *managerService) Run(ctx context.Context, c *manager.ComputationRunReq)
|
||||
// Define host-data value of QEMU for SEV-SNP, with a base64 encoding of the computation hash.
|
||||
ms.qemuCfg.SevConfig.HostData = base64.StdEncoding.EncodeToString(ch[:])
|
||||
|
||||
cvm := ms.vmFactory(ms.qemuCfg, ms.eventsChan, c.Id)
|
||||
cvm := ms.vmFactory(ms.qemuCfg, ms.eventsLogsSender, c.Id)
|
||||
ms.publishEvent(manager.VmProvision.String(), c.Id, agent.InProgress.String(), json.RawMessage{})
|
||||
if err = cvm.Start(); err != nil {
|
||||
ms.publishEvent(manager.VmProvision.String(), c.Id, agent.Failed.String(), json.RawMessage{})
|
||||
@@ -267,9 +267,9 @@ func checkPortisFree(port int) bool {
|
||||
}
|
||||
|
||||
func (ms *managerService) publishEvent(event, cmpID, status string, details json.RawMessage) {
|
||||
ms.eventsChan <- &manager.ClientStreamMessage{
|
||||
Message: &manager.ClientStreamMessage_AgentEvent{
|
||||
AgentEvent: &manager.AgentEvent{
|
||||
ms.eventsChan <- &ClientStreamMessage{
|
||||
Message: &ClientStreamMessage_AgentEvent{
|
||||
AgentEvent: &AgentEvent{
|
||||
EventType: event,
|
||||
ComputationId: cmpID,
|
||||
Status: status,
|
||||
@@ -329,7 +329,7 @@ func (ms *managerService) restoreVMs() error {
|
||||
continue
|
||||
}
|
||||
|
||||
cvm := ms.vmFactory(state.Config, ms.eventsChan, state.ID)
|
||||
cvm := ms.vmFactory(state.Config, ms.eventsLogsSender, state.ID)
|
||||
|
||||
if err = cvm.SetProcess(state.PID); err != nil {
|
||||
ms.logger.Warn("Failed to reattach to process", "computation", state.ID, "pid", state.PID, "error", err)
|
||||
@@ -362,3 +362,33 @@ func (ms *managerService) processExists(pid int) bool {
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (ms *managerService) eventsLogsSender(e interface{}) error {
|
||||
switch msg := e.(type) {
|
||||
case *vm.Event:
|
||||
ms.eventsChan <- &ClientStreamMessage{
|
||||
Message: &ClientStreamMessage_AgentEvent{
|
||||
AgentEvent: &AgentEvent{
|
||||
EventType: msg.EventType,
|
||||
Timestamp: msg.Timestamp,
|
||||
ComputationId: msg.ComputationId,
|
||||
Originator: msg.Originator,
|
||||
Status: msg.Status,
|
||||
Details: msg.Details,
|
||||
},
|
||||
},
|
||||
}
|
||||
case *vm.Log:
|
||||
ms.eventsChan <- &ClientStreamMessage{
|
||||
Message: &ClientStreamMessage_AgentLog{
|
||||
AgentLog: &AgentLog{
|
||||
ComputationId: msg.ComputationId,
|
||||
Level: msg.Level,
|
||||
Timestamp: msg.Timestamp,
|
||||
Message: msg.Message,
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
+19
-20
@@ -22,7 +22,6 @@ import (
|
||||
persistenceMocks "github.com/ultravioletrs/cocos/manager/qemu/mocks"
|
||||
"github.com/ultravioletrs/cocos/manager/vm"
|
||||
"github.com/ultravioletrs/cocos/manager/vm/mocks"
|
||||
"github.com/ultravioletrs/cocos/pkg/manager"
|
||||
)
|
||||
|
||||
func TestNew(t *testing.T) {
|
||||
@@ -30,7 +29,7 @@ func TestNew(t *testing.T) {
|
||||
HostFwdRange: "6000-6100",
|
||||
}
|
||||
logger := slog.Default()
|
||||
eventsChan := make(chan *manager.ClientStreamMessage)
|
||||
eventsChan := make(chan *ClientStreamMessage)
|
||||
vmf := new(mocks.Provider)
|
||||
|
||||
service, err := New(cfg, "", logger, eventsChan, vmf.Execute)
|
||||
@@ -47,59 +46,59 @@ func TestRun(t *testing.T) {
|
||||
vmf.On("Execute", mock.Anything, mock.Anything, mock.Anything).Return(vmMock)
|
||||
tests := []struct {
|
||||
name string
|
||||
req *manager.ComputationRunReq
|
||||
req *ComputationRunReq
|
||||
vmStartError error
|
||||
expectedError error
|
||||
}{
|
||||
{
|
||||
name: "Successful run",
|
||||
req: &manager.ComputationRunReq{
|
||||
req: &ComputationRunReq{
|
||||
Id: "test-computation",
|
||||
Name: "Test Computation",
|
||||
Algorithm: &manager.Algorithm{
|
||||
Algorithm: &Algorithm{
|
||||
Hash: make([]byte, hashLength),
|
||||
},
|
||||
AgentConfig: &manager.AgentConfig{},
|
||||
AgentConfig: &AgentConfig{},
|
||||
},
|
||||
vmStartError: nil,
|
||||
expectedError: nil,
|
||||
},
|
||||
{
|
||||
name: "VM start failure",
|
||||
req: &manager.ComputationRunReq{
|
||||
req: &ComputationRunReq{
|
||||
Id: "test-computation",
|
||||
Name: "Test Computation",
|
||||
Algorithm: &manager.Algorithm{
|
||||
Algorithm: &Algorithm{
|
||||
Hash: make([]byte, hashLength),
|
||||
},
|
||||
AgentConfig: &manager.AgentConfig{},
|
||||
AgentConfig: &AgentConfig{},
|
||||
},
|
||||
vmStartError: assert.AnError,
|
||||
expectedError: assert.AnError,
|
||||
},
|
||||
{
|
||||
name: "Invalid algorithm hash",
|
||||
req: &manager.ComputationRunReq{
|
||||
req: &ComputationRunReq{
|
||||
Id: "test-computation",
|
||||
Name: "Test Computation",
|
||||
Algorithm: &manager.Algorithm{
|
||||
Algorithm: &Algorithm{
|
||||
Hash: make([]byte, hashLength-1),
|
||||
},
|
||||
AgentConfig: &manager.AgentConfig{},
|
||||
AgentConfig: &AgentConfig{},
|
||||
},
|
||||
vmStartError: nil,
|
||||
expectedError: errInvalidHashLength,
|
||||
},
|
||||
{
|
||||
name: "Invalid dataset hash",
|
||||
req: &manager.ComputationRunReq{
|
||||
req: &ComputationRunReq{
|
||||
Id: "test-computation",
|
||||
Name: "Test Computation",
|
||||
Algorithm: &manager.Algorithm{
|
||||
Algorithm: &Algorithm{
|
||||
Hash: make([]byte, hashLength),
|
||||
},
|
||||
AgentConfig: &manager.AgentConfig{},
|
||||
Datasets: []*manager.Dataset{
|
||||
AgentConfig: &AgentConfig{},
|
||||
Datasets: []*Dataset{
|
||||
{
|
||||
Hash: make([]byte, hashLength-1),
|
||||
},
|
||||
@@ -130,7 +129,7 @@ func TestRun(t *testing.T) {
|
||||
},
|
||||
}
|
||||
logger := slog.Default()
|
||||
eventsChan := make(chan *manager.ClientStreamMessage, 10)
|
||||
eventsChan := make(chan *ClientStreamMessage, 10)
|
||||
|
||||
ms := &managerService{
|
||||
qemuCfg: qemuCfg,
|
||||
@@ -203,7 +202,7 @@ func TestStop(t *testing.T) {
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
logger := slog.Default()
|
||||
eventsChan := make(chan *manager.ClientStreamMessage, 10)
|
||||
eventsChan := make(chan *ClientStreamMessage, 10)
|
||||
ms := &managerService{
|
||||
logger: logger,
|
||||
vms: make(map[string]vm.VM),
|
||||
@@ -281,7 +280,7 @@ func TestPublishEvent(t *testing.T) {
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
eventsChan := make(chan *manager.ClientStreamMessage, 1)
|
||||
eventsChan := make(chan *ClientStreamMessage, 1)
|
||||
ms := &managerService{
|
||||
eventsChan: eventsChan,
|
||||
}
|
||||
@@ -370,7 +369,7 @@ func TestRestoreVMs(t *testing.T) {
|
||||
ms := &managerService{
|
||||
persistence: mockPersistence,
|
||||
vms: make(map[string]vm.VM),
|
||||
eventsChan: make(chan *manager.ClientStreamMessage, 10),
|
||||
eventsChan: make(chan *ClientStreamMessage, 10),
|
||||
vmFactory: vmf.Execute,
|
||||
logger: mglog.NewMock(),
|
||||
}
|
||||
|
||||
@@ -15,8 +15,8 @@ import (
|
||||
"time"
|
||||
|
||||
mglog "github.com/absmach/magistrala/logger"
|
||||
"github.com/ultravioletrs/cocos/manager"
|
||||
managergrpc "github.com/ultravioletrs/cocos/manager/api/grpc"
|
||||
"github.com/ultravioletrs/cocos/pkg/manager"
|
||||
"golang.org/x/crypto/sha3"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/credentials"
|
||||
|
||||
@@ -6,7 +6,6 @@ import (
|
||||
"context"
|
||||
|
||||
"github.com/ultravioletrs/cocos/manager"
|
||||
pkgmanager "github.com/ultravioletrs/cocos/pkg/manager"
|
||||
"go.opentelemetry.io/otel/trace"
|
||||
)
|
||||
|
||||
@@ -22,7 +21,7 @@ func New(svc manager.Service, tracer trace.Tracer) manager.Service {
|
||||
return &tracingMiddleware{tracer, svc}
|
||||
}
|
||||
|
||||
func (tm *tracingMiddleware) Run(ctx context.Context, mc *pkgmanager.ComputationRunReq) (string, error) {
|
||||
func (tm *tracingMiddleware) Run(ctx context.Context, mc *manager.ComputationRunReq) (string, error) {
|
||||
ctx, span := tm.tracer.Start(ctx, "run")
|
||||
defer span.End()
|
||||
|
||||
@@ -36,13 +35,13 @@ func (tm *tracingMiddleware) Stop(ctx context.Context, computationID string) err
|
||||
return tm.svc.Stop(ctx, computationID)
|
||||
}
|
||||
|
||||
func (tm *tracingMiddleware) RetrieveAgentEventsLogs() {
|
||||
tm.svc.RetrieveAgentEventsLogs()
|
||||
}
|
||||
|
||||
func (tm *tracingMiddleware) FetchBackendInfo(ctx context.Context, computationId string) ([]byte, error) {
|
||||
_, span := tm.tracer.Start(ctx, "fetch_backend_info")
|
||||
defer span.End()
|
||||
|
||||
return tm.svc.FetchBackendInfo(ctx, computationId)
|
||||
}
|
||||
|
||||
func (tm *tracingMiddleware) ReportBrokenConnection(addr string) {
|
||||
tm.svc.ReportBrokenConnection(addr)
|
||||
}
|
||||
|
||||
+21
-58
@@ -4,46 +4,26 @@ package vm
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"io"
|
||||
"log/slog"
|
||||
"strings"
|
||||
|
||||
"github.com/ultravioletrs/cocos/pkg/manager"
|
||||
pkgmanager "github.com/ultravioletrs/cocos/pkg/manager"
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
)
|
||||
|
||||
var (
|
||||
_ io.Writer = &Stdout{}
|
||||
_ io.Writer = &Stderr{}
|
||||
ErrFailedToSendMessage = errors.New("failed to send message to channel")
|
||||
ErrPanicRecovered = errors.New("panic recovered: channel may be closed")
|
||||
_ io.Writer = &Stdout{}
|
||||
_ io.Writer = &Stderr{}
|
||||
)
|
||||
|
||||
const bufSize = 1024
|
||||
|
||||
type Stdout struct {
|
||||
LogsChan chan *manager.ClientStreamMessage
|
||||
EventSender EventSender
|
||||
ComputationId string
|
||||
}
|
||||
|
||||
// safeSend safely sends a message to the channel and returns an error on failure.
|
||||
func safeSend(ch chan *manager.ClientStreamMessage, msg *manager.ClientStreamMessage) (err error) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
// Recover from panic if the channel is closed
|
||||
err = ErrPanicRecovered
|
||||
}
|
||||
}()
|
||||
select {
|
||||
case ch <- msg:
|
||||
return nil
|
||||
default:
|
||||
// Channel is full or closed
|
||||
return ErrFailedToSendMessage
|
||||
}
|
||||
}
|
||||
|
||||
// Write implements io.Writer.
|
||||
func (s *Stdout) Write(p []byte) (n int, err error) {
|
||||
inBuf := bytes.NewBuffer(p)
|
||||
@@ -59,7 +39,7 @@ func (s *Stdout) Write(p []byte) (n int, err error) {
|
||||
return len(p) - inBuf.Len(), err
|
||||
}
|
||||
|
||||
if err := sendLog(s.LogsChan, s.ComputationId, string(buf[:n]), slog.LevelDebug.String()); err != nil {
|
||||
if err := sendLog(s.EventSender, s.ComputationId, string(buf[:n]), slog.LevelDebug.String()); err != nil {
|
||||
return len(p) - inBuf.Len(), err
|
||||
}
|
||||
}
|
||||
@@ -68,7 +48,7 @@ func (s *Stdout) Write(p []byte) (n int, err error) {
|
||||
}
|
||||
|
||||
type Stderr struct {
|
||||
LogsChan chan *manager.ClientStreamMessage
|
||||
EventSender EventSender
|
||||
ComputationId string
|
||||
StateMachine StateMachine
|
||||
}
|
||||
@@ -88,32 +68,23 @@ func (s *Stderr) Write(p []byte) (n int, err error) {
|
||||
return len(p) - inBuf.Len(), err
|
||||
}
|
||||
|
||||
if err := sendLog(s.LogsChan, s.ComputationId, string(buf[:n]), ""); err != nil {
|
||||
if err := sendLog(s.EventSender, s.ComputationId, string(buf[:n]), ""); err != nil {
|
||||
return len(p) - inBuf.Len(), err
|
||||
}
|
||||
}
|
||||
|
||||
// Ensure vm-provision failure message is sent
|
||||
eventMsg := &manager.ClientStreamMessage{
|
||||
Message: &manager.ClientStreamMessage_AgentEvent{
|
||||
AgentEvent: &manager.AgentEvent{
|
||||
ComputationId: s.ComputationId,
|
||||
EventType: s.StateMachine.State(),
|
||||
Timestamp: timestamppb.Now(),
|
||||
Originator: "manager",
|
||||
Status: manager.Warning.String(),
|
||||
},
|
||||
},
|
||||
eventMsg := &Event{
|
||||
ComputationId: s.ComputationId,
|
||||
EventType: s.StateMachine.State(),
|
||||
Timestamp: timestamppb.Now(),
|
||||
Originator: "manager",
|
||||
Status: pkgmanager.Warning.String(),
|
||||
}
|
||||
|
||||
if err := safeSend(s.LogsChan, eventMsg); err != nil {
|
||||
return len(p), err
|
||||
}
|
||||
|
||||
return len(p), nil
|
||||
return len(p), s.EventSender(eventMsg)
|
||||
}
|
||||
|
||||
func sendLog(logsChan chan *manager.ClientStreamMessage, computationID, message, level string) error {
|
||||
func sendLog(eventSender EventSender, computationID, message, level string) error {
|
||||
if len(message) < 3 {
|
||||
return nil
|
||||
}
|
||||
@@ -126,20 +97,12 @@ func sendLog(logsChan chan *manager.ClientStreamMessage, computationID, message,
|
||||
}
|
||||
}
|
||||
|
||||
msg := &manager.ClientStreamMessage{
|
||||
Message: &manager.ClientStreamMessage_AgentLog{
|
||||
AgentLog: &manager.AgentLog{
|
||||
Message: message,
|
||||
ComputationId: computationID,
|
||||
Level: level,
|
||||
Timestamp: timestamppb.Now(),
|
||||
},
|
||||
},
|
||||
msg := Log{
|
||||
Message: message,
|
||||
ComputationId: computationID,
|
||||
Level: level,
|
||||
Timestamp: timestamppb.Now(),
|
||||
}
|
||||
|
||||
if err := safeSend(logsChan, msg); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
return eventSender(&msg)
|
||||
}
|
||||
|
||||
+40
-36
@@ -8,7 +8,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/ultravioletrs/cocos/pkg/manager"
|
||||
pkgmanager "github.com/ultravioletrs/cocos/pkg/manager"
|
||||
)
|
||||
|
||||
func TestStdoutWrite(t *testing.T) {
|
||||
@@ -36,9 +36,12 @@ func TestStdoutWrite(t *testing.T) {
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
logsChan := make(chan *manager.ClientStreamMessage, 10)
|
||||
eventLogChan := make(chan interface{}, 10)
|
||||
s := &Stdout{
|
||||
LogsChan: logsChan,
|
||||
EventSender: func(event interface{}) error {
|
||||
eventLogChan <- event
|
||||
return nil
|
||||
},
|
||||
ComputationId: "test-computation",
|
||||
}
|
||||
|
||||
@@ -50,9 +53,9 @@ func TestStdoutWrite(t *testing.T) {
|
||||
var receivedWrites int
|
||||
for i := 0; i < tt.expectedWrites; i++ {
|
||||
select {
|
||||
case msg := <-logsChan:
|
||||
case msg := <-eventLogChan:
|
||||
receivedWrites++
|
||||
agentLog := msg.GetAgentLog()
|
||||
agentLog := msg.(*Log)
|
||||
assert.NotNil(t, agentLog)
|
||||
assert.Equal(t, "test-computation", agentLog.ComputationId)
|
||||
assert.Equal(t, slog.LevelDebug.String(), agentLog.Level)
|
||||
@@ -93,14 +96,17 @@ func TestStderrWrite(t *testing.T) {
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
logsChan := make(chan *manager.ClientStreamMessage, 10)
|
||||
eventLogChan := make(chan interface{}, 10)
|
||||
s := &Stderr{
|
||||
LogsChan: logsChan,
|
||||
EventSender: func(event interface{}) error {
|
||||
eventLogChan <- event
|
||||
return nil
|
||||
},
|
||||
ComputationId: "test-computation",
|
||||
StateMachine: NewStateMachine(),
|
||||
}
|
||||
|
||||
err := s.StateMachine.Transition(manager.VmRunning)
|
||||
err := s.StateMachine.Transition(pkgmanager.VmRunning)
|
||||
assert.NoError(t, err)
|
||||
|
||||
n, err := s.Write([]byte(tt.input))
|
||||
@@ -111,23 +117,21 @@ func TestStderrWrite(t *testing.T) {
|
||||
var receivedWrites int
|
||||
for i := 0; i < tt.expectedWrites; i++ {
|
||||
select {
|
||||
case msg := <-logsChan:
|
||||
case msg := <-eventLogChan:
|
||||
receivedWrites++
|
||||
switch msg.Message.(type) {
|
||||
case *manager.ClientStreamMessage_AgentLog:
|
||||
agentLog := msg.GetAgentLog()
|
||||
assert.NotNil(t, agentLog)
|
||||
assert.Equal(t, "test-computation", agentLog.ComputationId)
|
||||
assert.Equal(t, slog.LevelError.String(), agentLog.Level)
|
||||
assert.NotEmpty(t, agentLog.Message)
|
||||
assert.NotNil(t, agentLog.Timestamp)
|
||||
case *manager.ClientStreamMessage_AgentEvent:
|
||||
agentEvent := msg.GetAgentEvent()
|
||||
assert.NotNil(t, agentEvent)
|
||||
assert.Equal(t, "test-computation", agentEvent.ComputationId)
|
||||
assert.Equal(t, manager.VmRunning.String(), agentEvent.EventType)
|
||||
assert.Equal(t, manager.Warning.String(), agentEvent.Status)
|
||||
assert.NotNil(t, agentEvent.Timestamp)
|
||||
switch logEv := msg.(type) {
|
||||
case *Log:
|
||||
assert.NotNil(t, logEv)
|
||||
assert.Equal(t, "test-computation", logEv.ComputationId)
|
||||
assert.Equal(t, slog.LevelError.String(), logEv.Level)
|
||||
assert.NotEmpty(t, logEv.Message)
|
||||
assert.NotNil(t, logEv.Timestamp)
|
||||
case *Event:
|
||||
assert.NotNil(t, logEv)
|
||||
assert.Equal(t, "test-computation", logEv.ComputationId)
|
||||
assert.Equal(t, pkgmanager.VmRunning.String(), logEv.EventType)
|
||||
assert.Equal(t, pkgmanager.Warning.String(), logEv.Status)
|
||||
assert.NotNil(t, logEv.Timestamp)
|
||||
}
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("Timed out waiting for log message")
|
||||
@@ -140,37 +144,37 @@ func TestStderrWrite(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestStdoutWriteErrorHandling(t *testing.T) {
|
||||
logsChan := make(chan *manager.ClientStreamMessage, 1)
|
||||
eventLogChan := make(chan interface{}, 10)
|
||||
s := &Stdout{
|
||||
LogsChan: logsChan,
|
||||
EventSender: func(event interface{}) error {
|
||||
eventLogChan <- event
|
||||
return assert.AnError
|
||||
},
|
||||
ComputationId: "test-computation",
|
||||
}
|
||||
|
||||
// Test with a closed channel to simulate an error condition
|
||||
close(logsChan)
|
||||
|
||||
message := []byte("This should fail")
|
||||
n, err := s.Write(message)
|
||||
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, len(message), n)
|
||||
assert.Equal(t, ErrPanicRecovered, err)
|
||||
assert.Equal(t, assert.AnError, err)
|
||||
}
|
||||
|
||||
func TestStderrWriteErrorHandling(t *testing.T) {
|
||||
logsChan := make(chan *manager.ClientStreamMessage, 1)
|
||||
eventLogChan := make(chan interface{}, 10)
|
||||
s := &Stderr{
|
||||
LogsChan: logsChan,
|
||||
EventSender: func(event interface{}) error {
|
||||
eventLogChan <- event
|
||||
return assert.AnError
|
||||
},
|
||||
ComputationId: "test-computation",
|
||||
}
|
||||
|
||||
// Test with a closed channel to simulate an error condition
|
||||
close(logsChan)
|
||||
|
||||
message := []byte("This should fail")
|
||||
n, err := s.Write(message)
|
||||
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, len(message), n)
|
||||
assert.Equal(t, ErrPanicRecovered, err)
|
||||
assert.Equal(t, assert.AnError, err)
|
||||
}
|
||||
|
||||
@@ -7,8 +7,6 @@ package mocks
|
||||
|
||||
import (
|
||||
mock "github.com/stretchr/testify/mock"
|
||||
manager "github.com/ultravioletrs/cocos/pkg/manager"
|
||||
|
||||
vm "github.com/ultravioletrs/cocos/manager/vm"
|
||||
)
|
||||
|
||||
@@ -17,17 +15,17 @@ type Provider struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
// Execute provides a mock function with given fields: config, logsChan, computationId
|
||||
func (_m *Provider) Execute(config interface{}, logsChan chan *manager.ClientStreamMessage, computationId string) vm.VM {
|
||||
ret := _m.Called(config, logsChan, computationId)
|
||||
// Execute provides a mock function with given fields: config, eventSender, computationId
|
||||
func (_m *Provider) Execute(config interface{}, eventSender vm.EventSender, computationId string) vm.VM {
|
||||
ret := _m.Called(config, eventSender, computationId)
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for Execute")
|
||||
}
|
||||
|
||||
var r0 vm.VM
|
||||
if rf, ok := ret.Get(0).(func(interface{}, chan *manager.ClientStreamMessage, string) vm.VM); ok {
|
||||
r0 = rf(config, logsChan, computationId)
|
||||
if rf, ok := ret.Get(0).(func(interface{}, vm.EventSender, string) vm.VM); ok {
|
||||
r0 = rf(config, eventSender, computationId)
|
||||
} else {
|
||||
if ret.Get(0) != nil {
|
||||
r0 = ret.Get(0).(vm.VM)
|
||||
|
||||
+26
-3
@@ -4,7 +4,8 @@ package vm
|
||||
|
||||
import (
|
||||
"github.com/ultravioletrs/cocos/agent"
|
||||
"github.com/ultravioletrs/cocos/pkg/manager"
|
||||
pkgmanager "github.com/ultravioletrs/cocos/pkg/manager"
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
)
|
||||
|
||||
// VM represents a virtual machine.
|
||||
@@ -17,10 +18,32 @@ type VM interface {
|
||||
SetProcess(pid int) error
|
||||
GetProcess() int
|
||||
GetCID() int
|
||||
Transition(newState manager.ManagerState) error
|
||||
Transition(newState pkgmanager.ManagerState) error
|
||||
State() string
|
||||
GetConfig() interface{}
|
||||
}
|
||||
|
||||
//go:generate mockery --name Provider --output=./mocks --filename provider.go --quiet --note "Copyright (c) Ultraviolet \n // SPDX-License-Identifier: Apache-2.0"
|
||||
type Provider func(config interface{}, logsChan chan *manager.ClientStreamMessage, computationId string) VM
|
||||
type Provider func(config interface{}, eventSender EventSender, computationId string) VM
|
||||
|
||||
type Event struct {
|
||||
EventType string
|
||||
Timestamp *timestamppb.Timestamp
|
||||
ComputationId string
|
||||
Details []byte
|
||||
Originator string
|
||||
Status string
|
||||
}
|
||||
|
||||
type Log struct {
|
||||
Message string
|
||||
ComputationId string
|
||||
Level string
|
||||
Timestamp *timestamppb.Timestamp
|
||||
}
|
||||
|
||||
func (l *Log) IsEventLog() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
type EventSender func(event interface{}) error
|
||||
|
||||
@@ -3,8 +3,8 @@
|
||||
package manager
|
||||
|
||||
import (
|
||||
"github.com/ultravioletrs/cocos/manager"
|
||||
"github.com/ultravioletrs/cocos/pkg/clients/grpc"
|
||||
"github.com/ultravioletrs/cocos/pkg/manager"
|
||||
)
|
||||
|
||||
// NewManagerClient creates new manager gRPC client instance.
|
||||
|
||||
@@ -19,4 +19,5 @@ const (
|
||||
Stopped
|
||||
Warning
|
||||
Disconnected
|
||||
Failed
|
||||
)
|
||||
|
||||
@@ -12,11 +12,12 @@ func _() {
|
||||
_ = x[Stopped-1]
|
||||
_ = x[Warning-2]
|
||||
_ = x[Disconnected-3]
|
||||
_ = x[Failed-4]
|
||||
}
|
||||
|
||||
const _ManagerStatus_name = "StartingStoppedWarningDisconnected"
|
||||
const _ManagerStatus_name = "StartingStoppedWarningDisconnectedFailed"
|
||||
|
||||
var _ManagerStatus_index = [...]uint8{0, 8, 15, 22, 34}
|
||||
var _ManagerStatus_index = [...]uint8{0, 8, 15, 22, 34, 40}
|
||||
|
||||
func (i ManagerStatus) String() string {
|
||||
if i >= ManagerStatus(len(_ManagerStatus_index)-1) {
|
||||
|
||||
@@ -16,8 +16,8 @@ import (
|
||||
"github.com/ultravioletrs/cocos/internal"
|
||||
"github.com/ultravioletrs/cocos/internal/server"
|
||||
grpcserver "github.com/ultravioletrs/cocos/internal/server/grpc"
|
||||
"github.com/ultravioletrs/cocos/manager"
|
||||
managergrpc "github.com/ultravioletrs/cocos/manager/api/grpc"
|
||||
"github.com/ultravioletrs/cocos/pkg/manager"
|
||||
"golang.org/x/sync/errgroup"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/credentials"
|
||||
|
||||
@@ -18,12 +18,12 @@ import (
|
||||
"github.com/ultravioletrs/cocos/internal"
|
||||
internalvsock "github.com/ultravioletrs/cocos/internal/vsock"
|
||||
"github.com/ultravioletrs/cocos/manager"
|
||||
"github.com/ultravioletrs/cocos/manager/events"
|
||||
"github.com/ultravioletrs/cocos/manager/qemu"
|
||||
pkgmanager "github.com/ultravioletrs/cocos/pkg/manager"
|
||||
)
|
||||
|
||||
const (
|
||||
managerVsockPort = manager.ManagerVsockPort
|
||||
managerVsockPort = events.ManagerVsockPort
|
||||
vsockConfigPort = qemu.VsockConfigPort
|
||||
)
|
||||
|
||||
@@ -115,7 +115,7 @@ func handleConnection(conn net.Conn) {
|
||||
ackReader := internalvsock.NewAckReader(conn)
|
||||
|
||||
for {
|
||||
var message pkgmanager.ClientStreamMessage
|
||||
var message manager.ClientStreamMessage
|
||||
err := ackReader.ReadProto(&message)
|
||||
if err != nil {
|
||||
log.Printf("Error reading message: %v", err)
|
||||
|
||||
Reference in New Issue
Block a user