mirror of
https://github.com/ultravioletrs/cocos.git
synced 2026-06-23 04:10:25 +00:00
NOISSUE - Reduce message loss via vsock with acks (#252)
* state check within func Signed-off-by: Sammy Oina <sammyoina@gmail.com> * debug logs sending Signed-off-by: Sammy Oina <sammyoina@gmail.com> * debug message sending Signed-off-by: Sammy Oina <sammyoina@gmail.com> * ack messages Signed-off-by: Sammy Oina <sammyoina@gmail.com> * handle proto better Signed-off-by: Sammy Oina <sammyoina@gmail.com> * improve concurrency Signed-off-by: Sammy Oina <sammyoina@gmail.com> * improve manager handling Signed-off-by: Sammy Oina <sammyoina@gmail.com> * remove debug lines Signed-off-by: Sammy Oina <sammyoina@gmail.com> * sync next id Signed-off-by: Sammy Oina <sammyoina@gmail.com> * reduce locks Signed-off-by: Sammy Oina <sammyoina@gmail.com> --------- Signed-off-by: Sammy Oina <sammyoina@gmail.com>
This commit is contained in:
committed by
GitHub
parent
df923f9b1f
commit
5d5ae35e2b
@@ -4,10 +4,10 @@ package events
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"io"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/mdlayher/vsock"
|
||||
"github.com/ultravioletrs/cocos/pkg/manager"
|
||||
"google.golang.org/protobuf/proto"
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
@@ -18,7 +18,7 @@ const retryInterval = 5 * time.Second
|
||||
type service struct {
|
||||
service string
|
||||
computationID string
|
||||
conn *vsock.Conn
|
||||
conn io.Writer
|
||||
cachedMessages [][]byte
|
||||
mutex sync.Mutex
|
||||
stopRetry chan struct{}
|
||||
@@ -39,7 +39,7 @@ type Service interface {
|
||||
Close()
|
||||
}
|
||||
|
||||
func New(svc, computationID string, conn *vsock.Conn) (Service, error) {
|
||||
func New(svc, computationID string, conn io.Writer) (Service, error) {
|
||||
s := &service{
|
||||
service: svc,
|
||||
computationID: computationID,
|
||||
|
||||
+3
-3
@@ -236,7 +236,7 @@ func (as *agentService) Data(ctx context.Context, dataset Dataset) error {
|
||||
}
|
||||
|
||||
if len(as.computation.Datasets) == 0 {
|
||||
as.sm.SendEvent(dataReceived)
|
||||
defer as.sm.SendEvent(dataReceived)
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -256,7 +256,7 @@ func (as *agentService) Result(ctx context.Context) ([]byte, error) {
|
||||
as.computation.ResultConsumers = slices.Delete(as.computation.ResultConsumers, index, index+1)
|
||||
|
||||
if len(as.computation.ResultConsumers) == 0 && as.sm.GetState() == ConsumingResults {
|
||||
as.sm.SendEvent(resultsConsumed)
|
||||
defer as.sm.SendEvent(resultsConsumed)
|
||||
}
|
||||
|
||||
return as.result, as.runError
|
||||
@@ -320,8 +320,8 @@ func (as *agentService) runComputation() {
|
||||
}
|
||||
|
||||
func (as *agentService) publishEvent(status string, details json.RawMessage) func() {
|
||||
st := as.sm.GetState().String()
|
||||
return func() {
|
||||
st := as.sm.GetState().String()
|
||||
if err := as.eventSvc.SendEvent(st, status, details); err != nil {
|
||||
as.sm.logger.Warn(err.Error())
|
||||
}
|
||||
|
||||
+15
-9
@@ -104,22 +104,28 @@ func (sm *StateMachine) Start(ctx context.Context) {
|
||||
for {
|
||||
select {
|
||||
case event := <-sm.EventChan:
|
||||
currentState := sm.GetState()
|
||||
var nextState State
|
||||
var stateFunc func()
|
||||
var valid bool
|
||||
|
||||
sm.mu.Lock()
|
||||
nextState, valid := sm.Transitions[sm.State][event]
|
||||
nextState, valid = sm.Transitions[sm.State][event]
|
||||
if valid {
|
||||
sm.State = nextState
|
||||
sm.logger.Debug(fmt.Sprintf("Transition: %v -> %v\n", sm.State, nextState))
|
||||
} else {
|
||||
sm.logger.Error(fmt.Sprintf("Invalid transition: %v -> ???\n", sm.State))
|
||||
stateFunc = sm.StateFunctions[nextState]
|
||||
}
|
||||
sm.mu.Unlock()
|
||||
|
||||
sm.mu.Lock()
|
||||
stateFunc, exists := sm.StateFunctions[sm.State]
|
||||
sm.mu.Unlock()
|
||||
if exists {
|
||||
go stateFunc()
|
||||
if valid {
|
||||
sm.logger.Debug(fmt.Sprintf("Transition: %v -> %v\n", currentState, nextState))
|
||||
if stateFunc != nil {
|
||||
go stateFunc()
|
||||
}
|
||||
} else {
|
||||
sm.logger.Error(fmt.Sprintf("Invalid transition: %v -> ???\n", sm.State))
|
||||
}
|
||||
|
||||
case <-ctx.Done():
|
||||
return
|
||||
}
|
||||
|
||||
+5
-4
@@ -25,6 +25,7 @@ import (
|
||||
agentlogger "github.com/ultravioletrs/cocos/internal/logger"
|
||||
"github.com/ultravioletrs/cocos/internal/server"
|
||||
grpcserver "github.com/ultravioletrs/cocos/internal/server/grpc"
|
||||
ackvsock "github.com/ultravioletrs/cocos/internal/vsock"
|
||||
"github.com/ultravioletrs/cocos/manager"
|
||||
"github.com/ultravioletrs/cocos/manager/qemu"
|
||||
"golang.org/x/sync/errgroup"
|
||||
@@ -53,6 +54,8 @@ func main() {
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
ackConn := ackvsock.NewAckWriter(conn)
|
||||
|
||||
var exitCode int
|
||||
defer mglog.ExitWithError(&exitCode)
|
||||
|
||||
@@ -63,10 +66,10 @@ func main() {
|
||||
return
|
||||
}
|
||||
|
||||
handler := agentlogger.NewProtoHandler(conn, &slog.HandlerOptions{Level: level}, cfg.ID)
|
||||
handler := agentlogger.NewProtoHandler(ackConn, &slog.HandlerOptions{Level: level}, cfg.ID)
|
||||
logger := slog.New(handler)
|
||||
|
||||
eventSvc, err := events.New(svcName, cfg.ID, conn)
|
||||
eventSvc, err := events.New(svcName, cfg.ID, ackConn)
|
||||
if err != nil {
|
||||
logger.Error(fmt.Sprintf("failed to create events service %s", err.Error()))
|
||||
exitCode = 1
|
||||
@@ -116,8 +119,6 @@ func main() {
|
||||
if err != nil {
|
||||
log.Fatal("failed to reconnect: ", err)
|
||||
}
|
||||
handler = agentlogger.NewProtoHandler(conn, &slog.HandlerOptions{Level: level}, cfg.ID)
|
||||
logger = slog.New(handler)
|
||||
}
|
||||
time.Sleep(retryInterval)
|
||||
}
|
||||
|
||||
@@ -0,0 +1,228 @@
|
||||
// Copyright (c) Ultraviolet
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package vsock
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"google.golang.org/protobuf/proto"
|
||||
)
|
||||
|
||||
const (
|
||||
maxRetries = 3
|
||||
retryDelay = time.Second
|
||||
maxMessageSize = 1 << 20 // 1 MB
|
||||
ackTimeout = 5 * time.Second
|
||||
maxConcurrent = 100 // Maximum number of concurrent messages
|
||||
)
|
||||
|
||||
type Message struct {
|
||||
ID uint32
|
||||
Content []byte
|
||||
}
|
||||
|
||||
type AckWriter struct {
|
||||
conn net.Conn
|
||||
pendingMessages chan *Message
|
||||
ackChannels map[uint32]chan bool
|
||||
ackMu sync.RWMutex
|
||||
nextID uint32
|
||||
done chan struct{}
|
||||
wg sync.WaitGroup
|
||||
}
|
||||
|
||||
func NewAckWriter(conn net.Conn) *AckWriter {
|
||||
aw := &AckWriter{
|
||||
conn: conn,
|
||||
pendingMessages: make(chan *Message, maxConcurrent),
|
||||
ackChannels: make(map[uint32]chan bool),
|
||||
nextID: 1,
|
||||
done: make(chan struct{}),
|
||||
}
|
||||
aw.wg.Add(2)
|
||||
go aw.sendMessages()
|
||||
go aw.handleAcknowledgments()
|
||||
return aw
|
||||
}
|
||||
|
||||
func (aw *AckWriter) WriteProto(msg proto.Message) (int, error) {
|
||||
data, err := proto.Marshal(msg)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("error marshaling protobuf message: %v", err)
|
||||
}
|
||||
return aw.Write(data)
|
||||
}
|
||||
|
||||
func (aw *AckWriter) Write(p []byte) (int, error) {
|
||||
if len(p) > maxMessageSize {
|
||||
return 0, fmt.Errorf("message size exceeds maximum allowed size of %d bytes", maxMessageSize)
|
||||
}
|
||||
|
||||
aw.ackMu.Lock()
|
||||
messageID := aw.nextID
|
||||
aw.nextID++
|
||||
|
||||
ackCh := make(chan bool, 1)
|
||||
aw.ackChannels[messageID] = ackCh
|
||||
aw.ackMu.Unlock()
|
||||
|
||||
message := &Message{ID: messageID, Content: p}
|
||||
|
||||
select {
|
||||
case aw.pendingMessages <- message:
|
||||
// Message queued successfully
|
||||
case <-aw.done:
|
||||
return 0, fmt.Errorf("writer is closed")
|
||||
}
|
||||
|
||||
select {
|
||||
case <-ackCh:
|
||||
return len(p), nil
|
||||
case <-time.After(ackTimeout):
|
||||
return 0, fmt.Errorf("timeout waiting for acknowledgment")
|
||||
case <-aw.done:
|
||||
return 0, fmt.Errorf("writer closed while waiting for acknowledgment")
|
||||
}
|
||||
}
|
||||
|
||||
func (aw *AckWriter) sendMessages() {
|
||||
defer aw.wg.Done()
|
||||
for {
|
||||
select {
|
||||
case <-aw.done:
|
||||
return
|
||||
case msg := <-aw.pendingMessages:
|
||||
for i := 0; i < maxRetries; i++ {
|
||||
if err := aw.writeMessage(msg.ID, msg.Content); err != nil {
|
||||
log.Printf("Error writing message %d (attempt %d): %v", msg.ID, i+1, err)
|
||||
time.Sleep(retryDelay)
|
||||
continue
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (aw *AckWriter) writeMessage(messageID uint32, p []byte) error {
|
||||
// Write message ID
|
||||
if err := binary.Write(aw.conn, binary.LittleEndian, messageID); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Write message length
|
||||
messageLen := uint32(len(p))
|
||||
if err := binary.Write(aw.conn, binary.LittleEndian, messageLen); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Write message content
|
||||
if _, err := aw.conn.Write(p); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (aw *AckWriter) handleAcknowledgments() {
|
||||
defer aw.wg.Done()
|
||||
for {
|
||||
select {
|
||||
case <-aw.done:
|
||||
return
|
||||
default:
|
||||
var ackID uint32
|
||||
err := binary.Read(aw.conn, binary.LittleEndian, &ackID)
|
||||
if err != nil {
|
||||
if err == io.EOF {
|
||||
log.Println("Connection closed, stopping acknowledgment handler")
|
||||
return
|
||||
}
|
||||
log.Printf("Error reading ACK: %v", err)
|
||||
time.Sleep(retryDelay)
|
||||
continue
|
||||
}
|
||||
|
||||
aw.ackMu.RLock()
|
||||
ackCh, ok := aw.ackChannels[ackID]
|
||||
aw.ackMu.RUnlock()
|
||||
|
||||
if ok {
|
||||
select {
|
||||
case ackCh <- true:
|
||||
default:
|
||||
// Channel is already closed or full
|
||||
}
|
||||
aw.ackMu.Lock()
|
||||
delete(aw.ackChannels, ackID)
|
||||
aw.ackMu.Unlock()
|
||||
} else {
|
||||
log.Printf("Received ACK for unknown message ID: %d", ackID)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (aw *AckWriter) Close() error {
|
||||
close(aw.done)
|
||||
aw.wg.Wait()
|
||||
return aw.conn.Close()
|
||||
}
|
||||
|
||||
type AckReader struct {
|
||||
conn net.Conn
|
||||
}
|
||||
|
||||
func NewAckReader(conn net.Conn) *AckReader {
|
||||
return &AckReader{
|
||||
conn: conn,
|
||||
}
|
||||
}
|
||||
|
||||
func (ar *AckReader) ReadProto(msg proto.Message) error {
|
||||
data, err := ar.Read()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return proto.Unmarshal(data, msg)
|
||||
}
|
||||
|
||||
func (ar *AckReader) Read() ([]byte, error) {
|
||||
var messageID uint32
|
||||
if err := binary.Read(ar.conn, binary.LittleEndian, &messageID); err != nil {
|
||||
return nil, fmt.Errorf("error reading message ID: %v", err)
|
||||
}
|
||||
|
||||
var messageLen uint32
|
||||
if err := binary.Read(ar.conn, binary.LittleEndian, &messageLen); err != nil {
|
||||
return nil, fmt.Errorf("error reading message length: %v", err)
|
||||
}
|
||||
|
||||
if messageLen > maxMessageSize {
|
||||
return nil, fmt.Errorf("message size exceeds maximum allowed size of %d bytes", maxMessageSize)
|
||||
}
|
||||
|
||||
data := make([]byte, messageLen)
|
||||
_, err := io.ReadFull(ar.conn, data)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error reading message content: %v", err)
|
||||
}
|
||||
|
||||
if err := ar.sendAck(messageID); err != nil {
|
||||
return nil, fmt.Errorf("error sending ACK: %v", err)
|
||||
}
|
||||
|
||||
return data, nil
|
||||
}
|
||||
|
||||
func (ar *AckReader) sendAck(messageID uint32) error {
|
||||
return binary.Write(ar.conn, binary.LittleEndian, messageID)
|
||||
}
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
"strconv"
|
||||
|
||||
"github.com/mdlayher/vsock"
|
||||
internalvsock "github.com/ultravioletrs/cocos/internal/vsock"
|
||||
"github.com/ultravioletrs/cocos/pkg/manager"
|
||||
"google.golang.org/protobuf/proto"
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
@@ -40,30 +41,35 @@ func (ms *managerService) RetrieveAgentEventsLogs() {
|
||||
continue
|
||||
}
|
||||
|
||||
go ms.handleConnections(conn)
|
||||
go ms.handleConnection(conn)
|
||||
}
|
||||
}
|
||||
|
||||
func (ms *managerService) handleConnections(conn net.Conn) {
|
||||
func (ms *managerService) handleConnection(conn net.Conn) {
|
||||
defer conn.Close()
|
||||
|
||||
cmpID, err := ms.computationIDFromAddress(conn.RemoteAddr().String())
|
||||
if err != nil {
|
||||
ms.logger.Warn(err.Error())
|
||||
return
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
ackReader := internalvsock.NewAckReader(conn)
|
||||
|
||||
for {
|
||||
b := make([]byte, messageSize)
|
||||
n, err := conn.Read(b)
|
||||
var message manager.ClientStreamMessage
|
||||
data, err := ackReader.Read()
|
||||
if err != nil {
|
||||
ms.logger.Warn(err.Error())
|
||||
go ms.reportBrokenConnection(cmpID)
|
||||
ms.logger.Warn(err.Error())
|
||||
return
|
||||
}
|
||||
var message manager.ClientStreamMessage
|
||||
if err := proto.Unmarshal(b[:n], &message); err != nil {
|
||||
|
||||
if err := proto.Unmarshal(data, &message); err != nil {
|
||||
ms.logger.Warn(err.Error())
|
||||
continue
|
||||
}
|
||||
|
||||
ms.eventsChan <- &message
|
||||
|
||||
args := []any{}
|
||||
|
||||
@@ -8,7 +8,6 @@ package main
|
||||
import (
|
||||
"encoding/json"
|
||||
"encoding/pem"
|
||||
"fmt"
|
||||
"log"
|
||||
"net"
|
||||
"os"
|
||||
@@ -17,10 +16,15 @@ import (
|
||||
"github.com/mdlayher/vsock"
|
||||
"github.com/ultravioletrs/cocos/agent"
|
||||
"github.com/ultravioletrs/cocos/internal"
|
||||
internalvsock "github.com/ultravioletrs/cocos/internal/vsock"
|
||||
"github.com/ultravioletrs/cocos/manager"
|
||||
"github.com/ultravioletrs/cocos/manager/qemu"
|
||||
pkgmanager "github.com/ultravioletrs/cocos/pkg/manager"
|
||||
"google.golang.org/protobuf/proto"
|
||||
)
|
||||
|
||||
const (
|
||||
managerVsockPort = manager.ManagerVsockPort
|
||||
vsockConfigPort = qemu.VsockConfigPort
|
||||
)
|
||||
|
||||
func main() {
|
||||
@@ -50,10 +54,6 @@ func main() {
|
||||
log.Fatalf("failed to calculate checksum: %s", err)
|
||||
}
|
||||
|
||||
l, err := vsock.Listen(manager.ManagerVsockPort, nil)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
ac := agent.Computation{
|
||||
ID: "123",
|
||||
Datasets: agent.Datasets{agent.Dataset{Hash: [32]byte(dataHash), UserKey: pubPem.Bytes}},
|
||||
@@ -65,21 +65,30 @@ func main() {
|
||||
AttestedTls: attestedTLS,
|
||||
},
|
||||
}
|
||||
if err := SendAgentConfig(3, ac); err != nil {
|
||||
if err := sendAgentConfig(3, ac); err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
listener, err := vsock.Listen(managerVsockPort, nil)
|
||||
if err != nil {
|
||||
log.Fatalf("failed to listen on vsock: %s", err)
|
||||
}
|
||||
defer listener.Close()
|
||||
|
||||
log.Printf("Listening on vsock port %d", managerVsockPort)
|
||||
|
||||
for {
|
||||
conn, err := l.Accept()
|
||||
conn, err := listener.Accept()
|
||||
if err != nil {
|
||||
log.Println(err)
|
||||
log.Printf("failed to accept connection: %s", err)
|
||||
continue
|
||||
}
|
||||
go handleConnections(conn)
|
||||
|
||||
go handleConnection(conn)
|
||||
}
|
||||
}
|
||||
|
||||
func SendAgentConfig(cid uint32, ac agent.Computation) error {
|
||||
func sendAgentConfig(cid uint32, ac agent.Computation) error {
|
||||
conn, err := vsock.Dial(cid, qemu.VsockConfigPort, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -100,20 +109,19 @@ func SendAgentConfig(cid uint32, ac agent.Computation) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func handleConnections(conn net.Conn) {
|
||||
func handleConnection(conn net.Conn) {
|
||||
defer conn.Close()
|
||||
|
||||
ackReader := internalvsock.NewAckReader(conn)
|
||||
|
||||
for {
|
||||
b := make([]byte, 1024)
|
||||
n, err := conn.Read(b)
|
||||
if err != nil {
|
||||
log.Println(err)
|
||||
return
|
||||
}
|
||||
var message pkgmanager.ClientStreamMessage
|
||||
if err := proto.Unmarshal(b[:n], &message); err != nil {
|
||||
log.Println(err)
|
||||
err := ackReader.ReadProto(&message)
|
||||
if err != nil {
|
||||
log.Printf("Error reading message: %v", err)
|
||||
return
|
||||
}
|
||||
fmt.Println(message.String())
|
||||
|
||||
log.Printf("Received message: %s", message.String())
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user