NOISSUE - Remove redundant retry logic (#293)

* remove redundant logic

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* fix test

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* remove line

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* fix internal tests

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* add test cases

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

* all pb files

Signed-off-by: Sammy Oina <sammyoina@gmail.com>

---------

Signed-off-by: Sammy Oina <sammyoina@gmail.com>
This commit is contained in:
Sammy Kerata Oina
2024-10-31 17:46:56 +03:00
committed by GitHub
parent 69b8dfa3ea
commit 534ad91623
9 changed files with 242 additions and 201 deletions
+10 -62
View File
@@ -5,42 +5,28 @@ package events
import (
"encoding/json"
"io"
"sync"
"time"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/types/known/timestamppb"
)
const retryInterval = 5 * time.Second
type service struct {
service string
computationID string
conn io.Writer
cachedMessages [][]byte
mutex sync.Mutex
stopRetry chan struct{}
service string
computationID string
conn io.Writer
}
//go:generate mockery --name Service --output=./mocks --filename events.go --quiet --note "Copyright (c) Ultraviolet \n // SPDX-License-Identifier: Apache-2.0"
type Service interface {
SendEvent(event, status string, details json.RawMessage) error
Close()
}
func New(svc, computationID string, conn io.Writer) (Service, error) {
s := &service{
service: svc,
computationID: computationID,
conn: conn,
cachedMessages: make([][]byte, 0),
stopRetry: make(chan struct{}),
}
go s.periodicRetry()
return s, nil
return &service{
service: svc,
computationID: computationID,
conn: conn,
}, nil
}
func (s *service) SendEvent(event, status string, details json.RawMessage) error {
@@ -56,44 +42,6 @@ func (s *service) SendEvent(event, status string, details json.RawMessage) error
if err != nil {
return err
}
s.mutex.Lock()
defer s.mutex.Unlock()
if _, err := s.conn.Write(protoBody); err != nil {
s.cachedMessages = append(s.cachedMessages, protoBody)
return err
}
return nil
}
func (s *service) periodicRetry() {
ticker := time.NewTicker(retryInterval)
defer ticker.Stop()
for {
select {
case <-ticker.C:
s.retrySendCachedMessages()
case <-s.stopRetry:
return
}
}
}
func (s *service) retrySendCachedMessages() {
s.mutex.Lock()
defer s.mutex.Unlock()
tmp := [][]byte{}
for _, msg := range s.cachedMessages {
if _, err := s.conn.Write(msg); err != nil {
tmp = append(tmp, msg)
}
}
s.cachedMessages = tmp
}
func (s *service) Close() {
close(s.stopRetry)
_, err = s.conn.Write(protoBody)
return err
}
-17
View File
@@ -61,21 +61,4 @@ func TestSendEventFailure(t *testing.T) {
err = svc.SendEvent("test_event", "failure", details)
assert.Error(t, err)
assert.Equal(t, "write error", err.Error())
assert.Len(t, svc.(*service).cachedMessages, 1)
}
func TestClose(t *testing.T) {
mockConnection := &mockConn{}
svc, err := New("test_service", "12345", mockConnection)
assert.NoError(t, err)
svc.Close()
time.Sleep(1 * time.Second)
details := json.RawMessage(`{"key": "value"}`)
err = svc.SendEvent("test_event", "success", details)
assert.NoError(t, err)
}
-1
View File
@@ -79,7 +79,6 @@ func main() {
exitCode = 1
return
}
defer eventSvc.Close()
qp, err := quoteprovider.GetQuoteProvider()
if err != nil {
+1
View File
@@ -9,3 +9,4 @@ coverage:
- "**/logging.go"
- "**/metrics.go"
- "**/tracing.go"
- "**/*.pb.go"
+7 -47
View File
@@ -6,25 +6,18 @@ import (
"context"
"io"
"log/slog"
"sync"
"time"
"github.com/ultravioletrs/cocos/agent/events"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/types/known/timestamppb"
)
const retryInterval = 5 * time.Second
var _ slog.Handler = (*handler)(nil)
type handler struct {
opts slog.HandlerOptions
w io.Writer
cmpID string
cachedMessages [][]byte
mutex sync.Mutex
stopRetry chan struct{}
opts slog.HandlerOptions
w io.Writer
cmpID string
}
//go:generate mockery --name io.Writer --output ./mocks --filename io_writer.go
@@ -34,15 +27,11 @@ func NewProtoHandler(conn io.Writer, opts *slog.HandlerOptions, cmpID string) sl
opts = &slog.HandlerOptions{}
}
h := &handler{
opts: *opts,
w: conn,
cmpID: cmpID,
cachedMessages: make([][]byte, 0),
stopRetry: make(chan struct{}),
opts: *opts,
w: conn,
cmpID: cmpID,
}
go h.periodicRetry()
return h
}
@@ -87,43 +76,15 @@ func (h *handler) Handle(_ context.Context, r slog.Record) error {
return err
}
h.mutex.Lock()
_, err = h.w.Write(b)
if err != nil {
h.cachedMessages = append(h.cachedMessages, b)
return err
}
h.mutex.Unlock()
}
return nil
}
func (h *handler) periodicRetry() {
ticker := time.NewTicker(retryInterval)
defer ticker.Stop()
for {
select {
case <-ticker.C:
h.retrySendCachedMessages()
case <-h.stopRetry:
return
}
}
}
func (h *handler) retrySendCachedMessages() {
h.mutex.Lock()
defer h.mutex.Unlock()
tmp := [][]byte{}
for _, msg := range h.cachedMessages {
if _, err := h.w.Write(msg); err != nil {
tmp = append(tmp, msg)
}
}
h.cachedMessages = tmp
}
func (h *handler) WithAttrs(attrs []slog.Attr) slog.Handler {
panic("unimplemented")
}
@@ -133,6 +94,5 @@ func (h *handler) WithGroup(name string) slog.Handler {
}
func (h *handler) Close() error {
close(h.stopRetry)
return nil
}
+2 -9
View File
@@ -9,6 +9,7 @@ import (
"testing"
"time"
"github.com/absmach/magistrala/pkg/errors"
"github.com/stretchr/testify/assert"
)
@@ -50,8 +51,7 @@ func TestHandleMessageFailure(t *testing.T) {
err := protohandler.Handle(context.Background(), record)
assert.NoError(t, err, "Handle should not return an error even when write fails")
assert.NotEmpty(t, protohandler.(*handler).CachedMessages(), "Cached messages should not be empty")
assert.True(t, errors.Contains(err, io.ErrUnexpectedEOF), "Handle should return an error")
}
// TestEnabled tests that the handler enables logging based on level.
@@ -74,10 +74,3 @@ func TestCloseStopsRetry(t *testing.T) {
assert.NoError(t, err, "Close should not return an error")
time.Sleep(1 * time.Second) // Ensure no retry after close
}
// Utility function to retrieve cached messages.
func (h *handler) CachedMessages() [][]byte {
h.mutex.Lock()
defer h.mutex.Unlock()
return h.cachedMessages
}
+109 -61
View File
@@ -4,12 +4,14 @@
package vsock
import (
"context"
"encoding/binary"
"fmt"
"io"
"log"
"net"
"sync"
"sync/atomic"
"time"
"google.golang.org/protobuf/proto"
@@ -20,31 +22,43 @@ const (
retryDelay = time.Second
maxMessageSize = 1 << 20 // 1 MB
ackTimeout = 5 * time.Second
maxConcurrent = 100 // Maximum number of concurrent messages
maxConcurrent = 100
)
// MessageStatus is the delivery state of a Message tracked by AckWriter.
type MessageStatus int

// Delivery states for a Message. The numeric values follow declaration order,
// starting at zero.
const (
	// StatusPending is assigned when a message is created in Write and again
	// after a failed write attempt in sendWithRetry.
	StatusPending MessageStatus = iota
	// StatusSent is assigned once writeMessage has succeeded for the message.
	StatusSent
	// StatusAcknowledged is assigned when an ACK carrying the message ID is read.
	StatusAcknowledged
	// StatusFailed is assigned when sendWithRetry gives up on the message.
	StatusFailed
)
// Message is one payload handed to AckWriter.Write, together with the
// bookkeeping needed to track its delivery.
type Message struct {
	// ID is written on the wire ahead of the payload and is the key used to
	// match incoming acknowledgments.
	ID uint32
	// Content holds the payload bytes to transmit.
	Content []byte
	// Status is the current delivery state.
	Status MessageStatus
	// Retries counts write attempts that have failed so far.
	Retries int
}
type AckWriter struct {
conn net.Conn
pendingMessages chan *Message
ackChannels map[uint32]chan bool
ackMu sync.RWMutex
messageStore sync.Map // map[uint32]*Message
nextID uint32
done chan struct{}
ctx context.Context
cancel context.CancelFunc
wg sync.WaitGroup
}
func NewAckWriter(conn net.Conn) io.WriteCloser {
ctx, cancel := context.WithCancel(context.Background())
aw := &AckWriter{
conn: conn,
pendingMessages: make(chan *Message, maxConcurrent),
ackChannels: make(map[uint32]chan bool),
nextID: 1,
done: make(chan struct{}),
ctx: ctx,
cancel: cancel,
}
aw.wg.Add(2)
go aw.sendMessages()
@@ -57,64 +71,91 @@ func (aw *AckWriter) Write(p []byte) (int, error) {
return 0, fmt.Errorf("message size exceeds maximum allowed size of %d bytes", maxMessageSize)
}
aw.ackMu.Lock()
messageID := aw.nextID
aw.nextID++
messageID := atomic.AddUint32(&aw.nextID, 1)
message := &Message{
ID: messageID,
Content: make([]byte, len(p)),
Status: StatusPending,
}
copy(message.Content, p)
ackCh := make(chan bool, 1)
aw.ackChannels[messageID] = ackCh
aw.ackMu.Unlock()
message := &Message{ID: messageID, Content: p}
aw.messageStore.Store(messageID, message)
select {
case aw.pendingMessages <- message:
// Message queued successfully
case <-aw.done:
return 0, fmt.Errorf("writer is closed")
}
timer := time.NewTimer(ackTimeout)
defer timer.Stop()
select {
case <-ackCh:
return len(p), nil
case <-time.After(ackTimeout):
return 0, fmt.Errorf("timeout waiting for acknowledgment")
case <-aw.done:
return 0, fmt.Errorf("writer closed while waiting for acknowledgment")
for {
if msg, ok := aw.messageStore.Load(messageID); ok {
m := msg.(*Message)
if m.Status == StatusAcknowledged {
return len(p), nil
}
if m.Status == StatusFailed {
return 0, fmt.Errorf("message delivery failed after %d retries", maxRetries)
}
}
select {
case <-timer.C:
return 0, fmt.Errorf("timeout waiting for acknowledgment")
case <-aw.ctx.Done():
return 0, fmt.Errorf("writer closed while waiting for acknowledgment")
case <-time.After(100 * time.Millisecond):
continue
}
}
case <-aw.ctx.Done():
return 0, fmt.Errorf("writer is closed")
}
}
func (aw *AckWriter) sendMessages() {
defer aw.wg.Done()
for {
select {
case <-aw.done:
case <-aw.ctx.Done():
return
case msg := <-aw.pendingMessages:
for i := 0; i < maxRetries; i++ {
if err := aw.writeMessage(msg.ID, msg.Content); err != nil {
log.Printf("Error writing message %d (attempt %d): %v", msg.ID, i+1, err)
time.Sleep(retryDelay)
continue
}
break
if err := aw.sendWithRetry(msg); err != nil {
log.Printf("Failed to send message %d after all retries: %v", msg.ID, err)
msg.Status = StatusFailed
aw.messageStore.Store(msg.ID, msg)
}
}
}
}
// sendWithRetry writes msg to the connection, retrying after retryDelay until
// the write succeeds or msg.Retries reaches maxRetries.
//
// On success the message is marked StatusSent, stored back into messageStore
// and nil is returned. When the retry budget is exhausted an error is
// returned; the caller is responsible for marking the message as failed.
//
// The wait between attempts observes aw.ctx, so a cancelled writer stops
// retrying immediately instead of sleeping through the remaining attempts.
// Without this, Close (which cancels the context and then waits for the
// sender goroutine) could block for the whole remaining retry budget.
func (aw *AckWriter) sendWithRetry(msg *Message) error {
	for msg.Retries < maxRetries {
		err := aw.writeMessage(msg.ID, msg.Content)
		if err == nil {
			msg.Status = StatusSent
			aw.messageStore.Store(msg.ID, msg)
			return nil
		}

		msg.Retries++
		msg.Status = StatusPending
		log.Printf("Error writing message %d (attempt %d): %v", msg.ID, msg.Retries, err)

		// Back off before the next attempt, but give up as soon as the
		// writer is shut down.
		timer := time.NewTimer(retryDelay)
		select {
		case <-timer.C:
		case <-aw.ctx.Done():
			timer.Stop()
			return fmt.Errorf("retry aborted: %w", aw.ctx.Err())
		}
	}

	return fmt.Errorf("max retries reached")
}
func (aw *AckWriter) writeMessage(messageID uint32, p []byte) error {
if err := binary.Write(aw.conn, binary.LittleEndian, messageID); err != nil {
return err
return fmt.Errorf("failed to write message ID: %w", err)
}
messageLen := uint32(len(p))
if err := binary.Write(aw.conn, binary.LittleEndian, messageLen); err != nil {
return err
return fmt.Errorf("failed to write message length: %w", err)
}
if _, err := aw.conn.Write(p); err != nil {
return err
return fmt.Errorf("failed to write message content: %w", err)
}
return nil
@@ -122,14 +163,14 @@ func (aw *AckWriter) writeMessage(messageID uint32, p []byte) error {
func (aw *AckWriter) handleAcknowledgments() {
defer aw.wg.Done()
for {
select {
case <-aw.done:
case <-aw.ctx.Done():
return
default:
var ackID uint32
err := binary.Read(aw.conn, binary.LittleEndian, &ackID)
if err != nil {
if err := binary.Read(aw.conn, binary.LittleEndian, &ackID); err != nil {
if err == io.EOF {
log.Println("Connection closed, stopping acknowledgment handler")
return
@@ -139,19 +180,13 @@ func (aw *AckWriter) handleAcknowledgments() {
continue
}
aw.ackMu.RLock()
ackCh, ok := aw.ackChannels[ackID]
aw.ackMu.RUnlock()
if msg, ok := aw.messageStore.Load(ackID); ok {
m := msg.(*Message)
m.Status = StatusAcknowledged
aw.messageStore.Store(ackID, m)
if ok {
select {
case ackCh <- true:
default:
// Channel is already closed or full
}
aw.ackMu.Lock()
delete(aw.ackChannels, ackID)
aw.ackMu.Unlock()
// Clean up old messages periodically
go aw.cleanupOldMessages(ackID)
} else {
log.Printf("Received ACK for unknown message ID: %d", ackID)
}
@@ -159,8 +194,21 @@ func (aw *AckWriter) handleAcknowledgments() {
}
}
// cleanupOldMessages removes acknowledged messages from messageStore whose ID
// is more than maxConcurrent below currentID. Messages in any other state are
// kept regardless of their ID.
//
// currentID is unsigned, so currentID-maxConcurrent would wrap around to a
// very large value whenever currentID is smaller than maxConcurrent. Every
// acknowledged entry would then be deleted, including the one that was just
// acknowledged and that Write is still polling for. The guard below skips
// cleanup until the window has actually been exceeded.
func (aw *AckWriter) cleanupOldMessages(currentID uint32) {
	if currentID <= maxConcurrent {
		// Nothing can be older than the retention window yet.
		return
	}
	threshold := currentID - maxConcurrent

	aw.messageStore.Range(func(key, value interface{}) bool {
		msgID := key.(uint32)
		msg := value.(*Message)

		if msg.Status == StatusAcknowledged && msgID < threshold {
			aw.messageStore.Delete(msgID)
		}
		return true
	})
}
func (aw *AckWriter) Close() error {
close(aw.done)
aw.cancel()
aw.wg.Wait()
return aw.conn.Close()
}
@@ -172,46 +220,46 @@ type Reader interface {
// AckReader reads framed messages from a connection and acknowledges each
// one it receives.
type AckReader struct {
	// conn is the connection frames are read from.
	conn net.Conn
	// ctx is set to context.Background() by NewAckReader.
	// NOTE(review): no use of ctx is visible in this part of the file —
	// confirm whether it is needed.
	ctx context.Context
}
// NewAckReader wraps conn in an AckReader carrying a background context and
// returns it as a Reader.
func NewAckReader(conn net.Conn) Reader {
	reader := AckReader{
		conn: conn,
		ctx:  context.Background(),
	}
	return &reader
}
func (ar *AckReader) ReadProto(msg proto.Message) error {
data, err := ar.Read()
if err != nil {
return err
return fmt.Errorf("failed to read proto message: %w", err)
}
return proto.Unmarshal(data, msg)
}
func (ar *AckReader) Read() ([]byte, error) {
var messageID uint32
if err := binary.Read(ar.conn, binary.LittleEndian, &messageID); err != nil {
return nil, fmt.Errorf("error reading message ID: %v", err)
return nil, fmt.Errorf("error reading message ID: %w", err)
}
var messageLen uint32
if err := binary.Read(ar.conn, binary.LittleEndian, &messageLen); err != nil {
return nil, fmt.Errorf("error reading message length: %v", err)
return nil, fmt.Errorf("error reading message length: %w", err)
}
if messageLen > maxMessageSize {
return nil, fmt.Errorf("message size exceeds maximum allowed size of %d bytes", maxMessageSize)
return nil, fmt.Errorf("message size %d exceeds maximum allowed size of %d bytes", messageLen, maxMessageSize)
}
data := make([]byte, messageLen)
_, err := io.ReadFull(ar.conn, data)
if err != nil {
return nil, fmt.Errorf("error reading message content: %v", err)
if _, err := io.ReadFull(ar.conn, data); err != nil {
return nil, fmt.Errorf("error reading message content: %w", err)
}
if err := ar.sendAck(messageID); err != nil {
return nil, fmt.Errorf("error sending ACK: %v", err)
return nil, fmt.Errorf("error sending ACK: %w", err)
}
return data, nil
+112
View File
@@ -203,3 +203,115 @@ func TestAckWriter_Close(t *testing.T) {
t.Errorf("AckWriter.Close() did not close the connection")
}
}
// TestAckWriter_Write is a table-driven test of AckWriter.Write. The current
// cases cover the two failure paths: a payload larger than maxMessageSize, and
// a write for which the mock connection is never configured to deliver an ACK.
func TestAckWriter_Write(t *testing.T) {
	type writeCase struct {
		name          string
		input         []byte
		mockBehavior  func(*MockConn)
		expectErr     bool
		expectedError string
	}

	cases := []writeCase{
		{
			name:          "Message exceeds max size",
			input:         make([]byte, maxMessageSize+1),
			mockBehavior:  func(m *MockConn) {},
			expectErr:     true,
			expectedError: "message size exceeds maximum allowed size",
		},
		{
			name:  "Timeout waiting for acknowledgment",
			input: []byte("timeout message"),
			// The mock is left untouched so that no ACK is prepared and
			// Write is expected to run into its timeout.
			mockBehavior:  func(m *MockConn) {},
			expectErr:     true,
			expectedError: "timeout waiting for acknowledgment",
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			conn := &MockConn{
				mu: sync.Mutex{},
			}
			if tc.mockBehavior != nil {
				tc.mockBehavior(conn)
			}

			writer := NewAckWriter(conn)
			defer writer.Close()

			n, err := writer.Write(tc.input)

			if !tc.expectErr {
				// Success path: the frame on the wire is ID (4 bytes),
				// length (4 bytes), then the payload.
				assert.NoError(t, err)
				assert.Equal(t, len(tc.input), n)

				assert.GreaterOrEqual(t, len(conn.WrittenData), 8+len(tc.input))
				messageLen := binary.LittleEndian.Uint32(conn.WrittenData[4:8])
				assert.Equal(t, uint32(len(tc.input)), messageLen)
				assert.Equal(t, tc.input, conn.WrittenData[8:8+len(tc.input)])
				return
			}

			assert.Error(t, err)
			if tc.expectedError != "" {
				assert.Contains(t, err.Error(), tc.expectedError)
			}
			assert.Zero(t, n)
		})
	}
}
// TestAckWriter_CleanupOldMessages fills the message store with
// maxConcurrent+10 acknowledged messages, runs a cleanup pass and checks that
// no more than maxConcurrent entries are left.
func TestAckWriter_CleanupOldMessages(t *testing.T) {
	conn := &MockConn{}
	writer := NewAckWriter(conn).(*AckWriter)
	defer writer.Close()

	const total = maxConcurrent + 10
	for id := uint32(1); id <= total; id++ {
		writer.messageStore.Store(id, &Message{
			ID:      id,
			Content: []byte("test"),
			Status:  StatusAcknowledged,
		})
	}

	// Pretend the next acknowledged ID is one past the last stored message.
	writer.cleanupOldMessages(total + 1)

	remaining := 0
	writer.messageStore.Range(func(_, _ interface{}) bool {
		remaining++
		return true
	})

	assert.LessOrEqual(t, remaining, maxConcurrent)
}
// TestAckReader_LargeMessage feeds the reader a single frame whose payload is
// one byte under maxMessageSize and checks that the payload is returned intact
// and that exactly one 4-byte ACK carrying the message ID is written back.
func TestAckReader_LargeMessage(t *testing.T) {
	conn := &MockConn{}
	reader := NewAckReader(conn)

	// Payload with a repeating 0..255 byte pattern.
	payload := make([]byte, maxMessageSize-1)
	for idx := range payload {
		payload[idx] = byte(idx % 256)
	}

	// Frame layout: little-endian ID, little-endian length, payload.
	wantID := uint32(1)
	frame := make([]byte, 8+len(payload))
	binary.LittleEndian.PutUint32(frame[:4], wantID)
	binary.LittleEndian.PutUint32(frame[4:8], uint32(len(payload)))
	copy(frame[8:], payload)
	conn.ReadData = frame

	got, err := reader.Read()
	assert.NoError(t, err)
	assert.Equal(t, payload, got)

	assert.Equal(t, 4, len(conn.WrittenData))
	gotID := binary.LittleEndian.Uint32(conn.WrittenData)
	assert.Equal(t, wantID, gotID)
}
+1 -4
View File
@@ -14,10 +14,7 @@ import (
"google.golang.org/protobuf/proto"
)
const (
ManagerVsockPort = 9997
messageSize int = 1024 * 1024
)
const ManagerVsockPort = 9997
type ReportBrokenConnectionFunc func(address string)