NOISSUE - Reduce message loss via vsock with acks (#252)

* state check within func

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* debug logs sending

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* debug message sending

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* ack messages

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* handle proto better

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* improve concurrency

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* improve manager handling

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* remove debug lines

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* sync next id

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* reduce locks

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

---------

Signed-off-by: Sammy Oina <sammyoina@gmail.com>
This commit is contained in:
Sammy Kerata Oina
2024-09-23 19:38:02 +03:00
committed by GitHub
parent df923f9b1f
commit 5d5ae35e2b
7 changed files with 297 additions and 48 deletions
+3 -3
View File
@@ -4,10 +4,10 @@ package events
import (
"encoding/json"
"io"
"sync"
"time"
"github.com/mdlayher/vsock"
"github.com/ultravioletrs/cocos/pkg/manager"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/types/known/timestamppb"
@@ -18,7 +18,7 @@ const retryInterval = 5 * time.Second
type service struct {
service string
computationID string
conn *vsock.Conn
conn io.Writer
cachedMessages [][]byte
mutex sync.Mutex
stopRetry chan struct{}
@@ -39,7 +39,7 @@ type Service interface {
Close()
}
func New(svc, computationID string, conn *vsock.Conn) (Service, error) {
func New(svc, computationID string, conn io.Writer) (Service, error) {
s := &service{
service: svc,
computationID: computationID,
+3 -3
View File
@@ -236,7 +236,7 @@ func (as *agentService) Data(ctx context.Context, dataset Dataset) error {
}
if len(as.computation.Datasets) == 0 {
as.sm.SendEvent(dataReceived)
defer as.sm.SendEvent(dataReceived)
}
return nil
@@ -256,7 +256,7 @@ func (as *agentService) Result(ctx context.Context) ([]byte, error) {
as.computation.ResultConsumers = slices.Delete(as.computation.ResultConsumers, index, index+1)
if len(as.computation.ResultConsumers) == 0 && as.sm.GetState() == ConsumingResults {
as.sm.SendEvent(resultsConsumed)
defer as.sm.SendEvent(resultsConsumed)
}
return as.result, as.runError
@@ -320,8 +320,8 @@ func (as *agentService) runComputation() {
}
func (as *agentService) publishEvent(status string, details json.RawMessage) func() {
st := as.sm.GetState().String()
return func() {
st := as.sm.GetState().String()
if err := as.eventSvc.SendEvent(st, status, details); err != nil {
as.sm.logger.Warn(err.Error())
}
+15 -9
View File
@@ -104,22 +104,28 @@ func (sm *StateMachine) Start(ctx context.Context) {
for {
select {
case event := <-sm.EventChan:
currentState := sm.GetState()
var nextState State
var stateFunc func()
var valid bool
sm.mu.Lock()
nextState, valid := sm.Transitions[sm.State][event]
nextState, valid = sm.Transitions[sm.State][event]
if valid {
sm.State = nextState
sm.logger.Debug(fmt.Sprintf("Transition: %v -> %v\n", sm.State, nextState))
} else {
sm.logger.Error(fmt.Sprintf("Invalid transition: %v -> ???\n", sm.State))
stateFunc = sm.StateFunctions[nextState]
}
sm.mu.Unlock()
sm.mu.Lock()
stateFunc, exists := sm.StateFunctions[sm.State]
sm.mu.Unlock()
if exists {
go stateFunc()
if valid {
sm.logger.Debug(fmt.Sprintf("Transition: %v -> %v\n", currentState, nextState))
if stateFunc != nil {
go stateFunc()
}
} else {
sm.logger.Error(fmt.Sprintf("Invalid transition: %v -> ???\n", sm.State))
}
case <-ctx.Done():
return
}
+5 -4
View File
@@ -25,6 +25,7 @@ import (
agentlogger "github.com/ultravioletrs/cocos/internal/logger"
"github.com/ultravioletrs/cocos/internal/server"
grpcserver "github.com/ultravioletrs/cocos/internal/server/grpc"
ackvsock "github.com/ultravioletrs/cocos/internal/vsock"
"github.com/ultravioletrs/cocos/manager"
"github.com/ultravioletrs/cocos/manager/qemu"
"golang.org/x/sync/errgroup"
@@ -53,6 +54,8 @@ func main() {
}
defer conn.Close()
ackConn := ackvsock.NewAckWriter(conn)
var exitCode int
defer mglog.ExitWithError(&exitCode)
@@ -63,10 +66,10 @@ func main() {
return
}
handler := agentlogger.NewProtoHandler(conn, &slog.HandlerOptions{Level: level}, cfg.ID)
handler := agentlogger.NewProtoHandler(ackConn, &slog.HandlerOptions{Level: level}, cfg.ID)
logger := slog.New(handler)
eventSvc, err := events.New(svcName, cfg.ID, conn)
eventSvc, err := events.New(svcName, cfg.ID, ackConn)
if err != nil {
logger.Error(fmt.Sprintf("failed to create events service %s", err.Error()))
exitCode = 1
@@ -116,8 +119,6 @@ func main() {
if err != nil {
log.Fatal("failed to reconnect: ", err)
}
handler = agentlogger.NewProtoHandler(conn, &slog.HandlerOptions{Level: level}, cfg.ID)
logger = slog.New(handler)
}
time.Sleep(retryInterval)
}
+228
View File
@@ -0,0 +1,228 @@
// Copyright (c) Ultraviolet
// SPDX-License-Identifier: Apache-2.0
package vsock
import (
"encoding/binary"
"fmt"
"io"
"log"
"net"
"sync"
"time"
"google.golang.org/protobuf/proto"
)
// Tunables for the acknowledged message framing implemented in this file.
const (
// maxRetries is the number of attempts sendMessages makes to write one message.
maxRetries = 3
// retryDelay is the pause taken after a failed write attempt or a failed ACK read.
retryDelay = time.Second
// maxMessageSize is the largest payload accepted by AckWriter.Write and AckReader.Read.
maxMessageSize = 1 << 20 // 1 MB
// ackTimeout is how long AckWriter.Write waits for the acknowledgment of a queued message.
ackTimeout = 5 * time.Second
// maxConcurrent is used as the buffer capacity of the pending-message queue.
maxConcurrent = 100 // Maximum number of concurrent messages
)
// Message is one payload queued for sending, paired with the identifier
// that is written ahead of it on the wire and echoed back as its ACK.
type Message struct {
// ID is the identifier assigned by AckWriter.Write.
ID uint32
// Content is the raw payload bytes.
Content []byte
}
// AckWriter sends length-prefixed, ID-tagged messages over a connection and
// makes each Write wait until the matching ID is read back from the same
// connection. It runs two background goroutines started by NewAckWriter.
type AckWriter struct {
// conn is the connection messages are written to and ACK IDs are read from.
conn net.Conn
// pendingMessages queues messages for the sendMessages goroutine.
pendingMessages chan *Message
// ackChannels maps a message ID to the channel signalled when its ACK arrives.
ackChannels map[uint32]chan bool
// ackMu is held by Write while assigning nextID and registering in ackChannels,
// and by handleAcknowledgments while looking up and deleting entries.
ackMu sync.RWMutex
// nextID is the ID the next Write call will use; it starts at 1.
nextID uint32
// done is closed by Close to stop the goroutines and unblock writers.
done chan struct{}
// wg tracks the two background goroutines so Close can wait for them.
wg sync.WaitGroup
}
// NewAckWriter wraps conn in an AckWriter and starts its two background
// goroutines: one that drains the pending-message queue onto conn and one
// that reads acknowledgment IDs back from conn.
func NewAckWriter(conn net.Conn) *AckWriter {
	w := new(AckWriter)
	w.conn = conn
	w.pendingMessages = make(chan *Message, maxConcurrent)
	w.ackChannels = map[uint32]chan bool{}
	w.nextID = 1
	w.done = make(chan struct{})

	// Register both goroutines before launching them so Close can wait on them.
	w.wg.Add(2)
	go w.sendMessages()
	go w.handleAcknowledgments()

	return w
}
// WriteProto marshals msg with proto.Marshal and sends the encoded bytes
// through Write. It returns Write's result, or a wrapped marshaling error
// (inspectable with errors.Is / errors.As) if encoding fails.
func (aw *AckWriter) WriteProto(msg proto.Message) (int, error) {
	data, err := proto.Marshal(msg)
	if err != nil {
		// %w keeps the underlying error in the chain; the text is unchanged.
		return 0, fmt.Errorf("error marshaling protobuf message: %w", err)
	}

	return aw.Write(data)
}
// Write queues p as a single message and blocks until the peer acknowledges
// it, the acknowledgment timeout (ackTimeout) elapses, or the writer is
// closed.
//
// It returns len(p) on acknowledgment. It returns 0 and an error when p is
// larger than maxMessageSize, when the writer is closed, or when no
// acknowledgment arrives in time.
func (aw *AckWriter) Write(p []byte) (int, error) {
	if len(p) > maxMessageSize {
		return 0, fmt.Errorf("message size exceeds maximum allowed size of %d bytes", maxMessageSize)
	}

	// io.Writer implementations must not retain p. The message can still be
	// sitting in the queue after this call has returned on timeout, so the
	// sender goroutine gets its own copy.
	content := append([]byte(nil), p...)

	// Buffered so the ACK handler never blocks on a writer that has given up.
	ackCh := make(chan bool, 1)

	aw.ackMu.Lock()
	messageID := aw.nextID
	aw.nextID++
	aw.ackChannels[messageID] = ackCh
	aw.ackMu.Unlock()

	// Drop the registration on every exit path; without this, each timed-out
	// or aborted write leaves an entry in ackChannels forever.
	defer func() {
		aw.ackMu.Lock()
		delete(aw.ackChannels, messageID)
		aw.ackMu.Unlock()
	}()

	select {
	case aw.pendingMessages <- &Message{ID: messageID, Content: content}:
		// Message queued successfully
	case <-aw.done:
		return 0, fmt.Errorf("writer is closed")
	}

	timer := time.NewTimer(ackTimeout)
	defer timer.Stop()

	select {
	case <-ackCh:
		return len(p), nil
	case <-timer.C:
		return 0, fmt.Errorf("timeout waiting for acknowledgment")
	case <-aw.done:
		return 0, fmt.Errorf("writer closed while waiting for acknowledgment")
	}
}
// sendMessages is the background sender loop. It takes messages off the
// pending queue and writes each one to the connection, retrying a failed
// write up to maxRetries times with retryDelay between attempts. It exits
// when done is closed, including while it is waiting between retries.
func (aw *AckWriter) sendMessages() {
	defer aw.wg.Done()

	for {
		select {
		case <-aw.done:
			return
		case msg := <-aw.pendingMessages:
			for attempt := 1; attempt <= maxRetries; attempt++ {
				err := aw.writeMessage(msg.ID, msg.Content)
				if err == nil {
					break
				}
				log.Printf("Error writing message %d (attempt %d): %v", msg.ID, attempt, err)

				// No point pausing after the final attempt; move on to the
				// next queued message instead.
				if attempt == maxRetries {
					break
				}

				// Wait before retrying, but stay responsive to Close so
				// shutdown is not held up by a sleeping sender.
				select {
				case <-aw.done:
					return
				case <-time.After(retryDelay):
				}
			}
		}
	}
}
// writeMessage writes one frame to the connection: a 4-byte little-endian
// message ID, a 4-byte little-endian payload length, then the payload.
//
// The frame is assembled in memory and handed to the connection in a single
// Write call. Issuing header and payload as separate writes meant a failure
// part-way through left a partial frame on the stream, and the caller's
// retry then started a fresh header in the middle of it.
func (aw *AckWriter) writeMessage(messageID uint32, p []byte) error {
	const headerSize = 8 // 4 bytes ID + 4 bytes length

	frame := make([]byte, headerSize+len(p))
	binary.LittleEndian.PutUint32(frame[0:4], messageID)
	binary.LittleEndian.PutUint32(frame[4:8], uint32(len(p)))
	copy(frame[headerSize:], p)

	_, err := aw.conn.Write(frame)
	return err
}
// handleAcknowledgments is the background ACK loop. It reads 4-byte
// little-endian message IDs from the connection and signals the channel
// registered for each ID, removing the registration as it does so.
//
// The loop ends when done is closed or when the stream ends (io.EOF, or
// io.ErrUnexpectedEOF for a stream cut off inside an ID). Other read errors
// are logged and retried after retryDelay.
func (aw *AckWriter) handleAcknowledgments() {
	defer aw.wg.Done()

	for {
		select {
		case <-aw.done:
			return
		default:
		}

		var ackID uint32
		if err := binary.Read(aw.conn, binary.LittleEndian, &ackID); err != nil {
			// A truncated ID means the peer went away mid-write; like EOF it
			// will never recover, so retrying would only spin and log forever.
			if err == io.EOF || err == io.ErrUnexpectedEOF {
				log.Println("Connection closed, stopping acknowledgment handler")
				return
			}

			// A read error after shutdown has begun is expected; exit quietly.
			select {
			case <-aw.done:
				return
			default:
			}

			log.Printf("Error reading ACK: %v", err)

			// Back off before retrying, but wake up immediately on Close.
			select {
			case <-aw.done:
				return
			case <-time.After(retryDelay):
			}
			continue
		}

		// Look up and unregister under one lock acquisition.
		aw.ackMu.Lock()
		ackCh, ok := aw.ackChannels[ackID]
		delete(aw.ackChannels, ackID)
		aw.ackMu.Unlock()

		if !ok {
			log.Printf("Received ACK for unknown message ID: %d", ackID)
			continue
		}

		// Non-blocking: the channel has capacity 1, so a duplicate ACK must
		// not stall the reader.
		select {
		case ackCh <- true:
		default:
		}
	}
}
// Close stops the writer: it signals the background goroutines to exit,
// closes the underlying connection, waits for both goroutines to finish and
// returns the error from closing the connection.
//
// The connection is closed before waiting because the ACK handler may be
// blocked in a read on it; closing done alone cannot wake it up, so waiting
// first could hang forever on an idle connection.
//
// Close must be called at most once; a second call panics on the
// already-closed done channel.
func (aw *AckWriter) Close() error {
	close(aw.done)
	err := aw.conn.Close()
	aw.wg.Wait()

	return err
}
// AckReader reads ID-tagged, length-prefixed messages from a connection and
// writes each message's ID back on the same connection as its acknowledgment.
type AckReader struct {
// conn is the connection messages are read from and ACKs are written to.
conn net.Conn
}
// NewAckReader returns an AckReader that reads from, and acknowledges on, conn.
func NewAckReader(conn net.Conn) *AckReader {
	reader := new(AckReader)
	reader.conn = conn

	return reader
}
// ReadProto reads the next message with Read and decodes its payload into
// msg using proto.Unmarshal. It returns the read error unchanged, otherwise
// whatever proto.Unmarshal returns.
func (ar *AckReader) ReadProto(msg proto.Message) error {
	payload, readErr := ar.Read()
	if readErr == nil {
		return proto.Unmarshal(payload, msg)
	}

	return readErr
}
// Read reads one frame from the connection (4-byte little-endian message
// ID, 4-byte little-endian length, then the payload), writes the ID back as
// the acknowledgment, and returns the payload.
//
// It returns an error if any part of the frame cannot be read, if the
// declared length exceeds maxMessageSize, or if the acknowledgment cannot be
// written. Underlying I/O errors are wrapped with %w, so callers can detect
// conditions such as a closed connection with errors.Is(err, io.EOF).
func (ar *AckReader) Read() ([]byte, error) {
	var messageID uint32
	if err := binary.Read(ar.conn, binary.LittleEndian, &messageID); err != nil {
		return nil, fmt.Errorf("error reading message ID: %w", err)
	}

	var messageLen uint32
	if err := binary.Read(ar.conn, binary.LittleEndian, &messageLen); err != nil {
		return nil, fmt.Errorf("error reading message length: %w", err)
	}

	// Validate the peer-supplied length before allocating a buffer for it.
	if messageLen > maxMessageSize {
		return nil, fmt.Errorf("message size exceeds maximum allowed size of %d bytes", maxMessageSize)
	}

	data := make([]byte, messageLen)
	if _, err := io.ReadFull(ar.conn, data); err != nil {
		return nil, fmt.Errorf("error reading message content: %w", err)
	}

	// Acknowledge only after the whole payload has been received.
	if err := ar.sendAck(messageID); err != nil {
		return nil, fmt.Errorf("error sending ACK: %w", err)
	}

	return data, nil
}
// sendAck writes messageID to the connection as 4 little-endian bytes and
// returns any error from that write.
func (ar *AckReader) sendAck(messageID uint32) error {
	var ack [4]byte
	binary.LittleEndian.PutUint32(ack[:], messageID)

	_, err := ar.conn.Write(ack[:])
	return err
}
+14 -8
View File
@@ -10,6 +10,7 @@ import (
"strconv"
"github.com/mdlayher/vsock"
internalvsock "github.com/ultravioletrs/cocos/internal/vsock"
"github.com/ultravioletrs/cocos/pkg/manager"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/types/known/timestamppb"
@@ -40,30 +41,35 @@ func (ms *managerService) RetrieveAgentEventsLogs() {
continue
}
go ms.handleConnections(conn)
go ms.handleConnection(conn)
}
}
func (ms *managerService) handleConnections(conn net.Conn) {
func (ms *managerService) handleConnection(conn net.Conn) {
defer conn.Close()
cmpID, err := ms.computationIDFromAddress(conn.RemoteAddr().String())
if err != nil {
ms.logger.Warn(err.Error())
return
}
defer conn.Close()
ackReader := internalvsock.NewAckReader(conn)
for {
b := make([]byte, messageSize)
n, err := conn.Read(b)
var message manager.ClientStreamMessage
data, err := ackReader.Read()
if err != nil {
ms.logger.Warn(err.Error())
go ms.reportBrokenConnection(cmpID)
ms.logger.Warn(err.Error())
return
}
var message manager.ClientStreamMessage
if err := proto.Unmarshal(b[:n], &message); err != nil {
if err := proto.Unmarshal(data, &message); err != nil {
ms.logger.Warn(err.Error())
continue
}
ms.eventsChan <- &message
args := []any{}
+29 -21
View File
@@ -8,7 +8,6 @@ package main
import (
"encoding/json"
"encoding/pem"
"fmt"
"log"
"net"
"os"
@@ -17,10 +16,15 @@ import (
"github.com/mdlayher/vsock"
"github.com/ultravioletrs/cocos/agent"
"github.com/ultravioletrs/cocos/internal"
internalvsock "github.com/ultravioletrs/cocos/internal/vsock"
"github.com/ultravioletrs/cocos/manager"
"github.com/ultravioletrs/cocos/manager/qemu"
pkgmanager "github.com/ultravioletrs/cocos/pkg/manager"
"google.golang.org/protobuf/proto"
)
const (
managerVsockPort = manager.ManagerVsockPort
vsockConfigPort = qemu.VsockConfigPort
)
func main() {
@@ -50,10 +54,6 @@ func main() {
log.Fatalf("failed to calculate checksum: %s", err)
}
l, err := vsock.Listen(manager.ManagerVsockPort, nil)
if err != nil {
log.Fatal(err)
}
ac := agent.Computation{
ID: "123",
Datasets: agent.Datasets{agent.Dataset{Hash: [32]byte(dataHash), UserKey: pubPem.Bytes}},
@@ -65,21 +65,30 @@ func main() {
AttestedTls: attestedTLS,
},
}
if err := SendAgentConfig(3, ac); err != nil {
if err := sendAgentConfig(3, ac); err != nil {
log.Fatal(err)
}
listener, err := vsock.Listen(managerVsockPort, nil)
if err != nil {
log.Fatalf("failed to listen on vsock: %s", err)
}
defer listener.Close()
log.Printf("Listening on vsock port %d", managerVsockPort)
for {
conn, err := l.Accept()
conn, err := listener.Accept()
if err != nil {
log.Println(err)
log.Printf("failed to accept connection: %s", err)
continue
}
go handleConnections(conn)
go handleConnection(conn)
}
}
func SendAgentConfig(cid uint32, ac agent.Computation) error {
func sendAgentConfig(cid uint32, ac agent.Computation) error {
conn, err := vsock.Dial(cid, qemu.VsockConfigPort, nil)
if err != nil {
return err
@@ -100,20 +109,19 @@ func SendAgentConfig(cid uint32, ac agent.Computation) error {
return nil
}
func handleConnections(conn net.Conn) {
func handleConnection(conn net.Conn) {
defer conn.Close()
ackReader := internalvsock.NewAckReader(conn)
for {
b := make([]byte, 1024)
n, err := conn.Read(b)
if err != nil {
log.Println(err)
return
}
var message pkgmanager.ClientStreamMessage
if err := proto.Unmarshal(b[:n], &message); err != nil {
log.Println(err)
err := ackReader.ReadProto(&message)
if err != nil {
log.Printf("Error reading message: %v", err)
return
}
fmt.Println(message.String())
log.Printf("Received message: %s", message.String())
}
}